nlsq.utils.async_logger module¶
JAX-aware asynchronous logging to prevent host-device synchronization.
This module provides logging functions that use jax.debug.callback to run logging asynchronously on the host, so that logging does not block device computation.
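The host-callback mechanism described above can be sketched in a few lines of JAX. This is a minimal illustration built directly on jax.debug.callback and assumes nothing about nlsq's internals; the names host_log, _host_log_step, and step are hypothetical.

```python
import jax
import jax.numpy as jnp

# Stand-in for a real logger: entries received on the host.
host_log = []


def _host_log_step(iteration, cost):
    # Runs on the host with concrete values, never with tracers.
    host_log.append((int(iteration), float(cost)))


@jax.jit
def step(params, iteration):
    cost = jnp.sum(params**2)
    # float(cost) would raise under jit (and force a device sync outside it);
    # instead, stage a host callback that receives the value when it is ready.
    jax.debug.callback(_host_log_step, iteration, cost)
    return params * 0.5, cost


params = jnp.array([1.0, 2.0])
for it in range(3):
    params, cost = step(params, it)

# Block until every pending callback has run before reading the log.
jax.effects_barrier()
print(sorted(host_log))  # [(0, 5.0), (1, 1.25), (2, 0.3125)]
```

The jitted function never converts cost to a Python number itself; only the host-side callback does, after the device has produced the value.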
Examples
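>>> from nlsq.utils.async_logger import log_iteration_async
>>> import jax.numpy as jnp
>>>
>>> # Inside an optimization loop; illustrative values stand in for
>>> # the loop's iteration count, cost, gradient, and step size
>>> nit, step = 5, 1e-2
>>> cost = jnp.asarray(0.5)
>>> g = jnp.array([1e-3, -2e-3])
>>> log_iteration_async(
...     iteration=nit,
...     cost=cost,
...     gradient_norm=jnp.linalg.norm(g),
...     message=f"step={step:.6e}",
...     verbose=2,
... )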
- nlsq.utils.async_logger.is_jax_array(x)
Check whether a value is a JAX array.
- Parameters:
x (Any) – Value to check
- Returns:
True if x is a JAX array or tracer, False otherwise
- Return type:
bool
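A check with the behavior documented for is_jax_array could be written as follows. This is a sketch, not nlsq's implementation; is_jax_array_sketch is a hypothetical name. It relies on the fact that tracers created under jit, vmap, or grad are themselves instances of jax.Array.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array_sketch(x) -> bool:
    # jax.Array covers concrete device arrays and, because tracers subclass
    # it, the abstract values seen while tracing under jit/vmap/grad.
    return isinstance(x, jax.Array)


print(is_jax_array_sketch(jnp.ones(3)))  # True
print(is_jax_array_sketch(np.ones(3)))   # False
print(is_jax_array_sketch(1.0))          # False

seen_under_jit = []


@jax.jit
def double(x):
    # Executes once at trace time, when x is a tracer.
    seen_under_jit.append(is_jax_array_sketch(x))
    return x * 2


double(jnp.ones(3))
print(seen_under_jit)  # [True]
```

Detecting tracers matters because calling float() or print() on one inside jit either fails or forces the synchronization the module is designed to avoid.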
- nlsq.utils.async_logger.log_iteration_async(iteration, cost, gradient_norm, message='', verbose=0)
Log an optimization iteration asynchronously, without blocking the device.
Uses jax.debug.callback to execute logging on the host thread while device computation continues. This prevents the performance penalty of host-device synchronization during optimization.
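- Parameters:
iteration – Current iteration number
cost – Current cost value
gradient_norm – Norm of the current gradient
message (str) – Optional additional message
verbose (int) – Verbosity level (0 disables logging)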
Notes
The callback executes asynchronously, so log messages may appear slightly out of order. However, each message includes the iteration number, so entries can be re-sorted during analysis.
This function has minimal overhead (~1-2 μs) and does not block device computation, making it suitable for use in tight optimization loops.
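Because each entry carries its iteration number, captured output can be put back in order after the fact. The sketch below assumes a hypothetical line format containing `iteration=<n>`; the real log format may differ, so adjust the regular expression accordingly.

```python
import re

# Hypothetical captured output; the line format is an assumption.
captured = [
    "iteration=2 cost=1.25e-01",
    "iteration=0 cost=5.00e-01",
    "iteration=1 cost=2.50e-01",
]

_ITER_RE = re.compile(r"iteration=(\d+)")


def sort_by_iteration(lines):
    """Restore iteration order using the number embedded in each line."""

    def key(line):
        match = _ITER_RE.search(line)
        # Lines without an iteration number (e.g. summaries) sort last.
        return int(match.group(1)) if match else float("inf")

    return sorted(lines, key=key)


for line in sort_by_iteration(captured):
    print(line)
```

Python's sort is stable, so lines that share an iteration number keep their captured order.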
Examples
>>> import jax.numpy as jnp
>>> from nlsq.utils.async_logger import log_iteration_async
>>>
>>> # In an optimization loop
>>> for nit in range(100):
...     cost = compute_cost()  # Returns a JAX array
...     g = compute_gradient()  # Returns a JAX array
...     log_iteration_async(nit, cost, jnp.linalg.norm(g), verbose=2)
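The same pattern works inside a fully compiled loop, which is where avoiding host synchronization matters most. The sketch below uses plain jax.debug.callback inside jax.lax.fori_loop on a toy gradient-descent problem rather than nlsq's function; records, _record, and gradient_descent are hypothetical names. jax.debug.callback also accepts an ordered keyword; this sketch leaves it at its default, so the host sorts the records.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side store of (iteration, cost, gradient norm).
records = []


def _record(iteration, cost, grad_norm):
    records.append((int(iteration), float(cost), float(grad_norm)))


def cost_fn(x):
    return jnp.sum(x**2)


@jax.jit
def gradient_descent(x0):
    grad_fn = jax.grad(cost_fn)

    def body(i, x):
        g = grad_fn(x)
        # Staged inside the compiled loop; the device does not wait for it.
        jax.debug.callback(_record, i, cost_fn(x), jnp.linalg.norm(g))
        return x - 0.1 * g

    return jax.lax.fori_loop(0, 5, body, x0)


x_final = gradient_descent(jnp.array([3.0, 4.0]))
x_final.block_until_ready()
jax.effects_barrier()  # wait for outstanding callbacks

for it, cost, gnorm in sorted(records):
    print(f"iter={it} cost={cost:.4f} |g|={gnorm:.4f}")
```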
- nlsq.utils.async_logger.log_convergence_async(reason, iterations, final_cost, time_sec, final_gradient_norm, verbose=1)
Log convergence result asynchronously.
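- Parameters:
reason (str) – Termination message
iterations (int) – Number of iterations performed
final_cost – Final cost value
time_sec (float) – Elapsed time in seconds
final_gradient_norm – Norm of the gradient at termination
verbose (int) – Verbosity level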
Examples
>>> from nlsq.utils.async_logger import log_convergence_async
>>> log_convergence_async(
...     reason="`gtol` termination condition is satisfied.",
...     iterations=42,
...     final_cost=1.23e-10,
...     time_sec=2.456,
...     final_gradient_norm=5.67e-9,
...     verbose=1,
... )
Overview¶
The async_logger module provides JAX-aware asynchronous logging infrastructure
that avoids host-device (CPU-GPU) synchronization during optimization. This removes
a source of host-device blocking that can degrade performance.
Key Features¶
Low-overhead logging: Uses jax.debug.callback for non-blocking execution
JAX array detection: Automatically identifies JAX arrays to prevent blocking
Verbosity control: Configure logging frequency (0=off, 1=every 10th iteration, 2=every iteration)
Performance impact: <5% overhead for async logging
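The verbosity levels listed above (0=off, 1=every 10th, 2=all) amount to a small predicate. This is a hypothetical helper, not part of nlsq; it assumes iterations are counted from 0 and that levels above 2 behave like 2. It is meant to run on the host, for example inside the callback, where the iteration is a concrete integer rather than a tracer.

```python
def should_log(iteration: int, verbose: int) -> bool:
    """Decide whether to emit a log entry for this iteration.

    Assumed semantics: 0 = off, 1 = every 10th iteration (0, 10, 20, ...),
    2 or higher = every iteration.
    """
    if verbose <= 0:
        return False
    if verbose == 1:
        return iteration % 10 == 0
    return True


print([i for i in range(25) if should_log(i, 1)])  # [0, 10, 20]
```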
Examples¶
Basic usage with curve fitting:
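import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit

def model(x, a, b):
    return a * jnp.exp(-b * x)

# Synthetic data for illustration
x = np.linspace(0, 4, 50)
y = 2.5 * np.exp(-1.3 * x) + 0.05 * np.random.default_rng(0).normal(size=x.size)

# Async logging enabled with verbose=2
popt, pcov = curve_fit(
    model, x, y,
    p0=[1.0, 0.5],
    verbose=2,  # Logs all iterations asynchronously
)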
Integration with TRF optimizer:
from nlsq.utils.async_logger import log_iteration_async
# Called automatically by TRF when verbose > 0
# No manual integration needed
Type Detection:
from nlsq.utils.async_logger import is_jax_array
import jax.numpy as jnp
import numpy as np
jax_arr = jnp.array([1, 2, 3])
numpy_arr = np.array([1, 2, 3])
print(is_jax_array(jax_arr)) # True
print(is_jax_array(numpy_arr)) # False
Performance Characteristics¶
| Operation | Blocking logging (old) | Async logging (new) |
|---|---|---|
| Iteration logging | ~10 ms GPU sync | <0.5 ms callback |
| Large array logging | Transfer to CPU | Reference only |
| Overall overhead | 15-25% | <5% |
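The figures in the table depend on hardware, backend, and workload. A rough harness like the following compares a blocking log line (converting the cost to a Python float every iteration) with a callback-based one. It is a sketch that assumes only JAX; all function names are hypothetical, and on CPU the difference may be small.

```python
import time

import jax
import jax.numpy as jnp

host_lines = []


def _host_format(i, cost):
    # Runs on the host once the device has produced `cost`.
    host_lines.append(f"iter={int(i)} cost={float(cost):.3e}")


@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()


@jax.jit
def step_logged(x, i):
    cost = jnp.tanh(x @ x.T).sum()
    jax.debug.callback(_host_format, i, cost)
    return cost


def time_blocking(x, n):
    step(x).block_until_ready()  # compile outside the timed region
    start = time.perf_counter()
    for i in range(n):
        # float() waits for the device result and copies it to the host.
        _ = f"iter={i} cost={float(step(x)):.3e}"
    return time.perf_counter() - start


def time_callback(x, n):
    step_logged(x, 0).block_until_ready()  # compile outside the timed region
    jax.effects_barrier()
    host_lines.clear()
    start = time.perf_counter()
    for i in range(n):
        cost = step_logged(x, i)  # returns without waiting for the value
    cost.block_until_ready()
    jax.effects_barrier()  # include time spent in pending callbacks
    return time.perf_counter() - start


x = jnp.ones((256, 256))
t_block = time_blocking(x, 50)
t_async = time_callback(x, 50)
print(f"blocking: {t_block:.4f} s, callback: {t_async:.4f} s")
```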
See Also¶
nlsq.profiling - Performance profiling infrastructure
nlsq.diagnostics - Optimization diagnostics and monitoring
Performance Optimization Guide - Complete performance optimization guide
Version History¶
Added in version 0.3.0-beta.3: Initial async logging implementation with JAX integration