nlsq.utils.profiling module

Performance profiling utilities for NLSQ.

Provides lightweight profiling infrastructure for measuring optimization performance and validating improvements without heavy dependencies.

Examples

>>> from nlsq.utils.profiling import profile_optimization
>>> from nlsq import least_squares
>>> import jax.numpy as jnp
>>>
>>> with profile_optimization() as metrics:
...     result = least_squares(lambda x: x**2, x0=jnp.array([1.0]), max_nfev=100)
>>>
>>> print(f"Average iteration: {metrics.avg_iteration_time_ms:.2f}ms")
>>> print(f"Total time: {metrics.total_time_sec:.3f}s")
class nlsq.utils.profiling.PerformanceMetrics(iteration_count=0, total_time_sec=0.0, iteration_times=<factory>)

Bases: object

Performance metrics for optimization runs.

Attributes:
  • iteration_count (int) – Total number of iterations performed

  • total_time_sec (float) – Total elapsed time in seconds

  • iteration_times (list of float) – Individual iteration times (if tracked)

Properties:
  • avg_iteration_time_ms (float) – Average iteration time in milliseconds

  • min_iteration_time_ms (float) – Minimum iteration time in milliseconds

  • max_iteration_time_ms (float) – Maximum iteration time in milliseconds

iteration_count: int
total_time_sec: float
iteration_times: list[float]
property avg_iteration_time_ms: float

Average iteration time in milliseconds.

property min_iteration_time_ms: float

Minimum iteration time in milliseconds.

property max_iteration_time_ms: float

Maximum iteration time in milliseconds.

__init__(iteration_count=0, total_time_sec=0.0, iteration_times=<factory>)
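The three properties are simple reductions over iteration_times. The following is a minimal sketch of such a dataclass, not the library's source. It assumes that iteration_times holds seconds, that the average falls back to total_time_sec / iteration_count when per-iteration times were not tracked, and that the properties return 0.0 when nothing is available; none of these details is documented above, and the class name PerformanceMetricsSketch is hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class PerformanceMetricsSketch:
    """Hypothetical stand-in for PerformanceMetrics (not the library class)."""

    iteration_count: int = 0
    total_time_sec: float = 0.0
    # Assumption: per-iteration times are stored in seconds.
    iteration_times: list[float] = field(default_factory=list)

    @property
    def avg_iteration_time_ms(self) -> float:
        # Prefer per-iteration samples; otherwise fall back to total / count.
        if self.iteration_times:
            return 1000.0 * sum(self.iteration_times) / len(self.iteration_times)
        if self.iteration_count > 0:
            return 1000.0 * self.total_time_sec / self.iteration_count
        return 0.0

    @property
    def min_iteration_time_ms(self) -> float:
        # 0.0 when no per-iteration times were tracked (assumed behavior).
        return 1000.0 * min(self.iteration_times) if self.iteration_times else 0.0

    @property
    def max_iteration_time_ms(self) -> float:
        # 0.0 when no per-iteration times were tracked (assumed behavior).
        return 1000.0 * max(self.iteration_times) if self.iteration_times else 0.0
```

For example, iteration times of 1, 2 and 3 ms (stored as 0.001, 0.002, 0.003 s) give an average of 2 ms, a minimum of 1 ms and a maximum of 3 ms.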
nlsq.utils.profiling.profile_optimization(enabled=True)

Profile optimization performance.

Measures total runtime as a proxy for performance improvements. In beta.1, only wall-clock time is recorded; detailed profiling is not performed.

Parameters:

enabled (bool, default=True) – Whether to enable profiling

Yields:

PerformanceMetrics – Performance statistics

Examples

>>> from nlsq import curve_fit
>>> from nlsq.utils.profiling import profile_optimization
>>> import jax.numpy as jnp
>>>
>>> x = jnp.linspace(0, 10, 100)
>>> y = 2.0 * jnp.exp(-0.5 * x)
>>>
>>> with profile_optimization() as metrics:
...     popt, pcov = curve_fit(lambda x, a, b: a * jnp.exp(-b * x), x, y)
>>>
>>> print(f"Optimization took {metrics.total_time_sec:.3f}s")
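A wall-clock profiler of this kind can be written as a small context manager around time.perf_counter. The sketch below illustrates that approach under stated assumptions and is not nlsq's implementation; profile_wall_clock and _Metrics are hypothetical names. One JAX-specific caveat applies to any such timer: JAX dispatches work asynchronously, so when the timed code returns JAX arrays, call .block_until_ready() on them inside the with block, or the timer may stop before the computation has finished.

```python
import time
from contextlib import contextmanager
from dataclasses import dataclass, field


@dataclass
class _Metrics:
    """Hypothetical metrics holder used only by this sketch."""

    iteration_count: int = 0
    total_time_sec: float = 0.0
    iteration_times: list[float] = field(default_factory=list)


@contextmanager
def profile_wall_clock(enabled: bool = True):
    """Time the body of a ``with`` block using a monotonic high-resolution clock."""
    metrics = _Metrics()
    if not enabled:
        # Profiling off: yield an empty metrics object and measure nothing.
        yield metrics
        return
    start = time.perf_counter()
    try:
        yield metrics
    finally:
        # Runs even if the body raises, so elapsed time is always recorded.
        metrics.total_time_sec = time.perf_counter() - start


# Usage: time.sleep stands in for the optimization call.
with profile_wall_clock() as metrics:
    time.sleep(0.01)
print(f"Optimization took {metrics.total_time_sec:.3f}s")
```

Filling total_time_sec in a finally clause means a failed fit still reports how long it ran before raising.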
nlsq.utils.profiling.analyze_source_transfers(source_code)

Analyze source code for host-device transfer patterns.

This is a static analysis tool for validating transfer reduction: it counts source-code patterns that typically induce GPU-to-CPU transfers.

Parameters:

source_code (str) – Source code to analyze

Returns:

Analysis results with counts of transfer-inducing patterns:
  • ‘np_array_calls’: Number of np.array() calls

  • ‘np_asarray_calls’: Number of np.asarray() calls

  • ‘block_until_ready_calls’: Number of .block_until_ready() calls

  • ‘total_potential_transfers’: Sum of all transfer patterns

Return type:

dict

Notes

This is a heuristic analysis tool, not a comprehensive profiler: it provides a relative measure for before/after comparison. Only NumPy conversions (np.*) are counted, not their JAX counterparts (jnp.*).

Examples

>>> from nlsq.utils.profiling import analyze_source_transfers
>>>
>>> code = '''
... def my_function(x):
...     y = np.array(x)  # Transfer!
...     return y
... '''
>>>
>>> analysis = analyze_source_transfers(code)
>>> print(f"Potential transfers: {analysis['total_potential_transfers']}")
Potential transfers: 1
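The counting described above can be approximated with regular expressions. This is a minimal sketch under assumptions, since the library's exact matching rules are not documented: the negative lookbehind keeps jnp.array( and jnp.asarray( from being counted as NumPy calls, and, like any regex approach, it also counts matches inside comments and string literals. The function name count_transfer_patterns is hypothetical; the result keys mirror the ones listed above.

```python
import re

# Assumed patterns; the library's exact matching rules are not documented.
# The lookbehind (?<![\w.]) stops "jnp.array(" from matching as "np.array(".
_PATTERNS = {
    "np_array_calls": re.compile(r"(?<![\w.])np\.array\s*\("),
    "np_asarray_calls": re.compile(r"(?<![\w.])np\.asarray\s*\("),
    "block_until_ready_calls": re.compile(r"\.block_until_ready\s*\("),
}


def count_transfer_patterns(source_code: str) -> dict:
    """Count call patterns that commonly trigger device-to-host transfers.

    Plain regex matching: occurrences inside comments and strings are counted too.
    """
    counts = {key: len(pattern.findall(source_code)) for key, pattern in _PATTERNS.items()}
    counts["total_potential_transfers"] = sum(counts.values())
    return counts


sample = '''
def my_function(x):
    y = np.array(x)  # Transfer!
    return y
'''
print(count_transfer_patterns(sample)["total_potential_transfers"])  # 1
```

A stricter variant could walk the syntax tree with the standard ast module to skip comments and strings, at the cost of failing on code that does not parse.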
nlsq.utils.profiling.compare_transfer_reduction(source_before, source_after, module_name='module')

Compare transfer patterns before and after optimization.

Parameters:
  • source_before (str) – Source code before optimization

  • source_after (str) – Source code after optimization

  • module_name (str, optional) – Name of module being analyzed (for reporting)

Returns:

Comparison results, including ‘reduction_percent’, the percentage reduction in potential transfers

Return type:

dict

Examples

>>> from nlsq.utils.profiling import compare_transfer_reduction
>>> before = "x = np.array(y); z = np.array(w)"
>>> after = "x = jnp.asarray(y); z = jnp.asarray(w)"
>>>
>>> comparison = compare_transfer_reduction(before, after, "mymodule")
>>> print(f"Reduction: {comparison['reduction_percent']:.1f}%")
Reduction: 100.0%
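The reduction percentage is 100 * (before - after) / before, computed over the pattern totals of each version. The following self-contained sketch assumes the same regex heuristic as static analysis; compare_sketch and every result key except reduction_percent are hypothetical, and returning 0.0 when the "before" code contains no patterns is an assumed convention.

```python
import re

# Assumed matchers: NumPy conversion calls (not jnp.*) and explicit sync points.
_NP_CONVERSION = re.compile(r"(?<![\w.])np\.(?:array|asarray)\s*\(")
_SYNC = re.compile(r"\.block_until_ready\s*\(")


def _count(source: str) -> int:
    """Total number of potential-transfer patterns in a source string."""
    return len(_NP_CONVERSION.findall(source)) + len(_SYNC.findall(source))


def compare_sketch(source_before: str, source_after: str, module_name: str = "module") -> dict:
    """Hypothetical before/after comparison of potential-transfer counts."""
    before = _count(source_before)
    after = _count(source_after)
    # Avoid dividing by zero when the original code has no matching patterns.
    reduction = 100.0 * (before - after) / before if before else 0.0
    return {
        "module": module_name,
        "before": before,
        "after": after,
        "reduction_percent": reduction,
    }


result = compare_sketch(
    "x = np.array(y); z = np.array(w)",
    "x = jnp.asarray(y); z = jnp.asarray(w)",
    "mymodule",
)
print(f"Reduction: {result['reduction_percent']:.1f}%")  # Reduction: 100.0%
```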
nlsq.utils.profiling.profile_transfers_runtime(func, *args, trace_dir=None, **kwargs)

Profile actual host-device transfers using JAX profiler.

This function provides runtime measurement of host-device transfers using JAX’s built-in profiler. Unlike static analysis, this captures actual transfer events during execution.

Parameters:
  • func (callable) – Function to profile

  • *args – Positional arguments to pass to func

  • trace_dir (str or None, optional) – Directory in which to store the profiler trace. If None, a “jax-profiling” subdirectory of the system temporary directory is used.

  • **kwargs – Keyword arguments to pass to func

Returns:

(result, transfer_stats), where result is func’s return value and transfer_stats is a dict of profiling information (for example, ‘transfer_count’)

Return type:

tuple

Notes

Requires JAX profiler support. On CPU, few host-device transfers occur, so this function is most useful when profiling on GPU.

Examples

>>> from nlsq.utils.profiling import profile_transfers_runtime
>>> import jax.numpy as jnp
>>>
>>> def my_computation(x):
...     return jnp.sum(x ** 2)
>>>
>>> result, stats = profile_transfers_runtime(
...     my_computation,
...     jnp.array([1.0, 2.0, 3.0])
... )
>>> print(f"Result: {result}, Transfers: {stats['transfer_count']}")
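A runtime wrapper of this kind can run func inside jax.profiler.trace, which writes a trace viewable in TensorBoard's profile plugin. This sketch only resolves the documented default trace_dir, captures the trace, and times the call. It does not extract transfer counts, which would require parsing the written trace files. run_under_profiler and its stats keys are hypothetical, and falling back to an untraced call when JAX is not installed is an assumed design choice.

```python
import os
import tempfile
import time
from contextlib import nullcontext


def run_under_profiler(func, *args, trace_dir=None, **kwargs):
    """Hypothetical sketch: call func inside a JAX profiler trace when JAX is installed."""
    if trace_dir is None:
        # Documented default: a "jax-profiling" subdirectory of the system temp dir.
        trace_dir = os.path.join(tempfile.gettempdir(), "jax-profiling")
    os.makedirs(trace_dir, exist_ok=True)

    try:
        import jax.profiler

        tracer = jax.profiler.trace(trace_dir)
        traced = True
    except ImportError:
        tracer = nullcontext()  # No JAX: run the call untraced.
        traced = False

    start = time.perf_counter()
    with tracer:
        result = func(*args, **kwargs)
        # JAX dispatches asynchronously; wait so the work lands inside the trace.
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
    elapsed = time.perf_counter() - start

    stats = {"trace_dir": trace_dir, "traced": traced, "wall_time_sec": elapsed}
    return result, stats


value, stats = run_under_profiler(lambda xs: sum(v * v for v in xs), [1.0, 2.0, 3.0])
print(value, stats["traced"])
```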

Overview

The profiling module provides JAX profiler integration for monitoring host-device transfers in the TRF solver, enabling measurement of transfer bytes and counts per iteration.

Performance Targets (Task Group 2)

  • Host-device transfer bytes: 80% reduction (current ~80KB → <16KB per iteration)

  • Transfer count: from 24+ to <5 per iteration

  • GPU iteration time: 5-15% reduction through transfer optimization
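The byte and count targets imply reductions of similar size. A quick check of the arithmetic, using a hypothetical reduction_percent helper (the ratio is the same whether "KB" means 1000 or 1024 bytes):

```python
def reduction_percent(before: float, after: float) -> float:
    """Percentage reduction going from `before` to `after` (hypothetical helper)."""
    return 100.0 * (before - after) / before


# Bytes: ~80 KB -> 16 KB per iteration is exactly the 80% target;
# anything strictly under 16 KB exceeds it.
print(reduction_percent(80_000, 16_000))  # 80.0

# Count: "<5" means at most 4 transfers, i.e. at least ~83% fewer than 24.
print(round(reduction_percent(24, 4), 1))  # 83.3
```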


Example Usage

from nlsq.utils.profiling import TransferProfiler

# Create profiler (requires JAX profiler)
profiler = TransferProfiler(enable=True)

# Profile TRF iteration
with profiler.profile_iteration(iteration=0):
    # TRF solver iteration code
    pass

# Get diagnostics
diagnostics = profiler.get_diagnostics()
print(f"Transfer bytes: {diagnostics['transfer_bytes']}")
print(f"Transfer count: {diagnostics['transfer_count']}")
print(f"Avg per iteration: {diagnostics['avg_bytes_per_iter']:.2f} bytes")
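TransferProfiler's internals are not documented in this section. As a hypothetical sketch of the per-iteration bookkeeping behind the diagnostics keys used above, the class below requires each transfer to be recorded explicitly through a record_transfer method (an invented name) and sizes it from the array's nbytes, in line with the note that transfer bytes are estimated from array sizes. NumPy and JAX arrays both expose nbytes; the usage below substitutes a memoryview so that it runs without either library.

```python
from contextlib import contextmanager


class TransferCounterSketch:
    """Hypothetical per-iteration transfer bookkeeping (not nlsq's TransferProfiler)."""

    def __init__(self, enable: bool = True):
        self.enable = enable
        self._bytes = {}   # iteration -> estimated bytes transferred
        self._counts = {}  # iteration -> number of recorded transfers
        self._current = None

    @contextmanager
    def profile_iteration(self, iteration: int):
        if not self.enable:
            yield self  # Disabled: record nothing.
            return
        self._current = iteration
        # Register the iteration even if it records no transfers.
        self._bytes.setdefault(iteration, 0)
        self._counts.setdefault(iteration, 0)
        try:
            yield self
        finally:
            self._current = None

    def record_transfer(self, array) -> None:
        """Add one transfer, sized by the array's nbytes."""
        if not self.enable or self._current is None:
            return
        self._bytes[self._current] += int(array.nbytes)
        self._counts[self._current] += 1

    def get_diagnostics(self) -> dict:
        n_iter = len(self._bytes)
        total_bytes = sum(self._bytes.values())
        return {
            "transfer_bytes": total_bytes,
            "transfer_count": sum(self._counts.values()),
            "avg_bytes_per_iter": total_bytes / n_iter if n_iter else 0.0,
        }


profiler = TransferCounterSketch(enable=True)
with profiler.profile_iteration(iteration=0) as prof:
    prof.record_transfer(memoryview(bytearray(800)))  # stand-in for an 800-byte array
diagnostics = profiler.get_diagnostics()
print(diagnostics)
```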

Chrome Trace Visualization

The profiler can generate Chrome trace files for visualization:

# Profile with trace output
profiler = TransferProfiler(enable=True)

with profiler.profile_iteration(iteration=0):
    # Code to profile
    pass

# Save the trace, then load trace.json in chrome://tracing
profiler.save_trace("trace.json")
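The output format of save_trace is not specified above. Files that chrome://tracing (and the Perfetto UI) can open commonly use the Chrome Trace Event Format: a JSON object whose traceEvents list holds events such as complete events ("ph": "X") with ts and dur in microseconds. A minimal sketch of a recorder that writes that format, using the hypothetical class ChromeTraceSketch and its span method:

```python
import json
import os
import tempfile
import time
from contextlib import contextmanager


class ChromeTraceSketch:
    """Hypothetical recorder that writes Chrome Trace Event Format JSON."""

    def __init__(self):
        self.events = []
        self._t0 = time.perf_counter()

    @contextmanager
    def span(self, name: str, **args):
        start = time.perf_counter()
        try:
            yield
        finally:
            end = time.perf_counter()
            # "X" marks a complete event; "ts" and "dur" are in microseconds.
            self.events.append({
                "name": name,
                "ph": "X",
                "ts": (start - self._t0) * 1e6,
                "dur": (end - start) * 1e6,
                "pid": 0,
                "tid": 0,
                "args": args,
            })

    def save_trace(self, path: str) -> None:
        with open(path, "w", encoding="utf-8") as fh:
            json.dump({"traceEvents": self.events}, fh)


tracer = ChromeTraceSketch()
with tracer.span("iteration 0", iteration=0):
    sum(i * i for i in range(10_000))  # stand-in for one solver iteration
out_path = os.path.join(tempfile.gettempdir(), "trace_sketch.json")
tracer.save_trace(out_path)  # then open this file at chrome://tracing
```

Each span becomes one bar on the timeline, and the keyword arguments appear in the event's details pane.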

Requirements

The profiler requires the JAX profiler optional dependency:

pip install jax[profiler]

Notes

  • Profiling overhead is minimal (~1-2% when enabled)

  • Transfer bytes are estimated based on array sizes

  • Chrome traces provide detailed timeline visualization

  • Profiling is automatically disabled if the JAX profiler is not available
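Automatic disabling can amount to checking whether jax.profiler is importable before honoring enable=True. A minimal sketch, assuming an import check is sufficient (a real implementation might also probe the active backend). Both function names are hypothetical:

```python
import importlib.util


def jax_profiler_available() -> bool:
    """Return True if jax.profiler can be imported (hypothetical helper)."""
    if importlib.util.find_spec("jax") is None:
        return False  # JAX is not installed at all.
    try:
        import jax.profiler  # noqa: F401
    except ImportError:
        return False
    return True


def resolve_enable(requested: bool) -> bool:
    """Honor enable=True only when the profiler is actually importable."""
    return requested and jax_profiler_available()


print(resolve_enable(True))
```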

See Also