nlsq.utils.async_logger module

JAX-aware asynchronous logging to prevent host-device synchronization.

This module provides logging functions that use jax.debug.callback to run logging asynchronously on the host, so that logging does not block device computation.

Examples

>>> from nlsq.utils.async_logger import log_iteration_async
>>> import jax.numpy as jnp
>>>
>>> # Inside optimization loop
>>> log_iteration_async(
...     iteration=nit,
...     cost=cost,
...     gradient_norm=jnp.linalg.norm(g),
...     message=f"step={step:.6e}",
...     verbose=2
... )
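The underlying mechanism can be sketched with jax.debug.callback directly. This is a minimal illustration under stated assumptions, not the nlsq implementation: `_emit`, `step`, the logger name and the `records` list are hypothetical, and the host function simply records each value and writes a log line.

```python
import logging

import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("async_logger_sketch")  # hypothetical logger name
records = []  # host-side copy of what was logged, for inspection


def _emit(iteration, cost):
    # Runs on the host and receives concrete NumPy values, not tracers
    records.append((int(iteration), float(cost)))
    logger.info("iter=%d cost=%.6e", int(iteration), float(cost))


@jax.jit
def step(params, iteration):
    cost = jnp.sum(params**2)
    # Stage a host callback inside the compiled program instead of calling
    # float() on `cost`, which fails on a tracer and would force a device sync
    # on a concrete array
    jax.debug.callback(_emit, iteration, cost)
    return params * 0.5


params = jnp.array([1.0, 2.0])
for nit in range(3):
    params = step(params, nit)
jax.effects_barrier()  # wait for pending callbacks before reading `records`
```

jax.effects_barrier() is only needed here because the sketch inspects `records` immediately; a long-running optimizer can let callbacks drain on their own.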

nlsq.utils.async_logger.is_jax_array(x)

Check if value is a JAX array.

Parameters:

x (Any) – Value to check

Returns:

True if x is a JAX array or tracer

Return type:

bool
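One plausible way to implement such a check using only public JAX API is a single isinstance test against jax.Array, which is the public type for instance checks of both concrete JAX arrays and tracers. `is_jax_array_sketch` is a hypothetical stand-in; the actual nlsq implementation may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array_sketch(x):
    # jax.Array covers concrete JAX arrays and, during tracing, tracers,
    # so one isinstance check handles both cases
    return isinstance(x, jax.Array)


print(is_jax_array_sketch(jnp.array([1, 2, 3])))  # True
print(is_jax_array_sketch(np.array([1, 2, 3])))   # False
print(is_jax_array_sketch(3.0))                   # False

seen_during_trace = []


@jax.jit
def f(x):
    # Inside jit, `x` is a tracer rather than a concrete array
    seen_during_trace.append(is_jax_array_sketch(x))
    return x + 1


f(jnp.ones(3))
print(seen_during_trace)  # [True]
```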

nlsq.utils.async_logger.log_iteration_async(iteration, cost, gradient_norm, message='', verbose=0)

Log an optimization iteration asynchronously without blocking the device.

Uses jax.debug.callback to execute logging on the host thread while device computation continues. This prevents the performance penalty of host-device synchronization during optimization.

Parameters:
  • iteration (int or jax.Array) – Current iteration number

  • cost (float or jax.Array) – Current cost/loss value

  • gradient_norm (float or jax.Array) – Norm of gradient vector

  • message (str, optional) – Additional message to append to log

  • verbose (int, optional) – Verbosity level:
      ◦ 0: No logging
      ◦ 1: Log every 10 iterations
      ◦ 2: Log every iteration
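The three verbosity levels can be expressed as a small predicate. This is a hypothetical sketch: it assumes "every 10 iterations" means iterations divisible by 10 (starting at 0) and that levels above 2 behave like 2; the source does not specify either detail.

```python
def should_log(iteration: int, verbose: int) -> bool:
    """Hypothetical verbosity gate mirroring the documented levels."""
    if verbose <= 0:
        return False  # level 0: no logging
    if verbose == 1:
        return iteration % 10 == 0  # level 1: every 10th iteration (assumed to start at 0)
    return True  # level 2 (and, by assumption, higher): every iteration


print([i for i in range(25) if should_log(i, 1)])  # [0, 10, 20]
print(sum(should_log(i, 2) for i in range(25)))    # 25
print(any(should_log(i, 0) for i in range(25)))    # False
```

Because `iteration` may be a tracer inside compiled code, a design like this would typically skip staging the callback entirely when verbose is 0 (a static Python value) and evaluate the modulo check on the host, inside the callback.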

Notes

The callback executes asynchronously, so log messages may appear slightly out of order. Each message includes the iteration number, so the original order can be reconstructed during analysis.

This function has minimal overhead (~1-2μs) and does not block device computation, making it suitable for use in tight optimization loops.
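Because each message carries its iteration number, out-of-order lines can be re-sorted after the fact. The sketch below assumes a hypothetical `iter=<n>` tag in each line; the actual nlsq message format is not specified here.

```python
import re

# Hypothetical log lines, in the order they happened to reach the host
arrived = [
    "iter=2 cost=3.125000e-01",
    "iter=0 cost=5.000000e+00",
    "iter=1 cost=1.250000e+00",
]

_ITER_PATTERN = re.compile(r"iter=(\d+)")


def iteration_of(line: str) -> int:
    # Lines without an iteration tag sort first
    match = _ITER_PATTERN.search(line)
    return int(match.group(1)) if match else -1


in_order = sorted(arrived, key=iteration_of)
for line in in_order:
    print(line)
```

If strict runtime ordering matters more than throughput, jax.debug.callback also accepts ordered=True, which enforces ordering relative to other ordered callbacks.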

Examples

>>> import jax.numpy as jnp
>>> from nlsq.utils.async_logger import log_iteration_async
>>>
>>> # In optimization loop
>>> for nit in range(100):
...     cost = compute_cost()  # Returns JAX array
...     g = compute_gradient()  # Returns JAX array
...     log_iteration_async(nit, cost, jnp.linalg.norm(g), verbose=2)
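The loop above runs eagerly in Python. A self-contained variant in which the loop itself is compiled with jax.lax.fori_loop, and logging goes through jax.debug.callback, might look like the following. It is a sketch of the pattern, not nlsq's TRF code: `cost_fn`, `descend`, `_host_log`, the `records` list and the plain gradient-descent update are all hypothetical.

```python
import functools

import jax
import jax.numpy as jnp

records = []  # host-side sink standing in for a logger


def _host_log(iteration, cost, grad_norm, *, verbose):
    # Host side: the every-10th check runs here, so the traced loop
    # needs no Python branching on the (traced) iteration counter
    it = int(iteration)
    if verbose >= 2 or (verbose == 1 and it % 10 == 0):
        records.append((it, float(cost), float(grad_norm)))


def cost_fn(p):
    return 0.5 * jnp.sum(p**2)


@functools.partial(jax.jit, static_argnums=1)
def descend(p0, verbose):
    def body(nit, p):
        cost, g = jax.value_and_grad(cost_fn)(p)
        if verbose > 0:  # static Python int, resolved at trace time
            jax.debug.callback(
                functools.partial(_host_log, verbose=verbose),
                nit, cost, jnp.linalg.norm(g),
            )
        return p - 0.1 * g

    return jax.lax.fori_loop(0, 25, body, p0)


p_final = descend(jnp.array([3.0, 4.0]), 1)
jax.effects_barrier()
print(sorted(r[0] for r in records))  # [0, 10, 20]
```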

nlsq.utils.async_logger.log_convergence_async(reason, iterations, final_cost, time_sec, final_gradient_norm, verbose=1)

Log the convergence result asynchronously.

Parameters:
  • reason (str) – Convergence termination reason

  • iterations (int or jax.Array) – Total number of iterations performed

  • final_cost (float or jax.Array) – Final cost value achieved

  • time_sec (float) – Total optimization time in seconds

  • final_gradient_norm (float or jax.Array) – Final gradient norm

  • verbose (int, optional) – Verbosity level (0 = no logging)

Examples

>>> log_convergence_async(
...     reason="`gtol` termination condition is satisfied.",
...     iterations=42,
...     final_cost=1.23e-10,
...     time_sec=2.456,
...     final_gradient_norm=5.67e-9,
...     verbose=1
... )
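The parameters mix host values (`reason` is a string, `time_sec` a Python float) with values that may be JAX arrays. A minimal sketch of one way to route them, assuming nothing about nlsq internals: the host-only values are bound with functools.partial, and only the numeric values pass through jax.debug.callback. `log_convergence_sketch`, `_emit_convergence` and the message layout are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp

messages = []  # stand-in for a logger


def _emit_convergence(iterations, final_cost, final_gradient_norm, *, reason, time_sec):
    # Host side: array arguments arrive as concrete NumPy values
    messages.append(
        f"{reason} Iterations: {int(iterations)}, "
        f"final cost: {float(final_cost):.3e}, "
        f"gradient norm: {float(final_gradient_norm):.3e}, "
        f"time: {time_sec:.3f} s"
    )


def log_convergence_sketch(reason, iterations, final_cost, time_sec,
                           final_gradient_norm, verbose=1):
    if verbose == 0:
        return
    # The string and the host-side float are bound with functools.partial;
    # only numeric, array-like values are passed as callback arguments
    callback = functools.partial(_emit_convergence, reason=reason, time_sec=time_sec)
    jax.debug.callback(callback, iterations, final_cost, final_gradient_norm)


log_convergence_sketch(
    reason="`gtol` termination condition is satisfied.",
    iterations=jnp.asarray(42),
    final_cost=jnp.asarray(1.23e-10),
    time_sec=2.456,
    final_gradient_norm=jnp.asarray(5.67e-9),
)
jax.effects_barrier()
print(messages[0])
```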

Overview

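The async_logger module provides JAX-aware asynchronous logging infrastructure that avoids host-device (for example GPU-CPU) synchronization during optimization. Eagerly logging a device value forces the host to wait for the device to finish; deferring the logging to a host callback removes that blocking.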

Key Features

  • Low-overhead logging: Uses jax.debug.callback for non-blocking execution

  • JAX array detection: Automatically identifies JAX arrays to prevent blocking

  • Verbosity control: Configure logging frequency (0=off, 1=every 10th, 2=all)

  • Performance impact: <5% overhead for async logging

Examples

Basic usage with curve fitting:

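import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit

def model(x, a, b):
    return a * jnp.exp(-b * x)

# Synthetic data for illustration
rng = np.random.default_rng(0)
x = np.linspace(0, 5, 50)
y = 2.0 * np.exp(-0.7 * x) + 0.01 * rng.normal(size=x.size)

# Async logging enabled with verbose=2
popt, pcov = curve_fit(
    model, x, y,
    p0=[1.0, 0.5],
    verbose=2  # Logs all iterations asynchronously
)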

Integration with the TRF optimizer:

The Trust Region Reflective (TRF) optimizer calls log_iteration_async automatically when verbose > 0, so no manual integration is needed.

Type Detection:

from nlsq.utils.async_logger import is_jax_array
import jax.numpy as jnp
import numpy as np

jax_arr = jnp.array([1, 2, 3])
numpy_arr = np.array([1, 2, 3])

print(is_jax_array(jax_arr))   # True
print(is_jax_array(numpy_arr))  # False
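Type detection matters because the two kinds of value need different handling: calling float() on a tracer raises an error, and on a concrete device array it forces a device sync. A hypothetical dispatcher using the same isinstance idea (`log_value`, `_record` and `sink` are illustrative names, not nlsq API) could look like:

```python
import jax
import jax.numpy as jnp
import numpy as np

sink = []  # stand-in for a logger


def _record(label, value):
    sink.append((label, float(value)))


def log_value(label, value):
    if isinstance(value, jax.Array):
        # JAX array or tracer: hand the value to a host callback
        jax.debug.callback(lambda v: _record(label, v), value)
    else:
        # Already on the host (Python scalar or NumPy value): record directly
        _record(label, value)


log_value("numpy", np.float64(1.5))


@jax.jit
def total(x):
    s = jnp.sum(x)
    log_value("traced", s)  # `s` is a tracer here
    return s


total(jnp.ones(4))
jax.effects_barrier()
print(sink)
```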

Performance Characteristics

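Async Logging Performance

=====================  ================  ================
Operation              Blocking (old)    Async (new)
=====================  ================  ================
Iteration logging      ~10 ms GPU sync   <0.5 ms callback
Large array logging    Transfer to CPU   Reference only
Overall overhead       15-25%            <5%
=====================  ================  ================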
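To get a feel for callback overhead on your own hardware, one option is to time the same compiled loop with and without a no-op jax.debug.callback. This sketch does not reproduce the figures in the table; results depend on backend, array size, loop length and JAX version, and `N_STEPS`, `make_loop` and `timed` are hypothetical names.

```python
import time

import jax
import jax.numpy as jnp

N_STEPS = 200  # hypothetical loop length


def _noop(*args):
    # Host callback that does nothing, so timing isolates callback overhead
    return None


def make_loop(with_callback):
    @jax.jit
    def run(p):
        def body(i, p):
            cost = jnp.sum(p**2)
            if with_callback:  # Python flag, resolved at trace time
                jax.debug.callback(_noop, i, cost)
            return p * 0.999

        return jax.lax.fori_loop(0, N_STEPS, body, p)

    return run


def timed(fn, p):
    jax.block_until_ready(fn(p))  # warm-up: compile and run once
    jax.effects_barrier()
    start = time.perf_counter()
    jax.block_until_ready(fn(p))
    jax.effects_barrier()  # include pending callbacks in the measurement
    return time.perf_counter() - start


p0 = jnp.ones(1000)
t_plain = timed(make_loop(False), p0)
t_callback = timed(make_loop(True), p0)
print(f"without callback: {t_plain * 1e3:.2f} ms")
print(f"with callback:    {t_callback * 1e3:.2f} ms")
```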

See Also

  • nlsq.profiling - Performance profiling infrastructure

  • nlsq.diagnostics - Optimization diagnostics and monitoring

  • Performance Optimization Guide - Complete performance optimization guide

Version History

Added in version 0.3.0-beta.3: Initial async logging implementation with JAX integration