nlsq.profiling module
Performance profiling utilities for NLSQ.
Provides lightweight profiling infrastructure for measuring optimization performance and validating improvements without heavy dependencies.
Examples
>>> from nlsq.utils.profiling import profile_optimization
>>> from nlsq import least_squares
>>> import jax.numpy as jnp
>>>
>>> with profile_optimization() as metrics:
...     result = least_squares(lambda x: x**2, x0=jnp.array([1.0]), max_nfev=100)
>>>
>>> print(f"Average iteration: {metrics.avg_iteration_time_ms:.2f}ms")
>>> print(f"Total time: {metrics.total_time_sec:.3f}s")
- class nlsq.utils.profiling.PerformanceMetrics(iteration_count=0, total_time_sec=0.0, iteration_times=<factory>)
Bases: object
Performance metrics for optimization runs.
- iteration_count
Total number of iterations performed.
- Type: int
- total_time_sec
Total elapsed time in seconds.
- Type: float
Properties
- avg_iteration_time_ms
Average iteration time in milliseconds.
- Type: float
- min_iteration_time_ms
Minimum iteration time in milliseconds.
- Type: float
- max_iteration_time_ms
Maximum iteration time in milliseconds.
- Type: float
- iteration_count: int
- total_time_sec: float
- property avg_iteration_time_ms: float
Average iteration time in milliseconds.
- property min_iteration_time_ms: float
Minimum iteration time in milliseconds.
- property max_iteration_time_ms: float
Maximum iteration time in milliseconds.
- __init__(iteration_count=0, total_time_sec=0.0, iteration_times=<factory>)
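The fields and derived properties listed above can be illustrated with a small dataclass. This is a minimal sketch, not the library's implementation: the class name `MetricsSketch` is hypothetical, and it assumes `iteration_times` holds per-iteration durations in seconds and that an empty run reports 0.0 for each property.

```python
from dataclasses import dataclass, field


@dataclass
class MetricsSketch:
    """Hypothetical stand-in for PerformanceMetrics; not the library's code."""

    iteration_count: int = 0
    total_time_sec: float = 0.0
    # Assumption: per-iteration durations are stored in seconds.
    iteration_times: list = field(default_factory=list)

    @property
    def avg_iteration_time_ms(self) -> float:
        # Assumption: report 0.0 for an empty run instead of dividing by zero.
        if not self.iteration_times:
            return 0.0
        return 1000.0 * sum(self.iteration_times) / len(self.iteration_times)

    @property
    def min_iteration_time_ms(self) -> float:
        return 1000.0 * min(self.iteration_times, default=0.0)

    @property
    def max_iteration_time_ms(self) -> float:
        return 1000.0 * max(self.iteration_times, default=0.0)


metrics = MetricsSketch(
    iteration_count=3,
    total_time_sec=0.75,
    iteration_times=[0.25, 0.125, 0.375],
)
print(f"avg={metrics.avg_iteration_time_ms:.1f}ms")
print(f"min={metrics.min_iteration_time_ms:.1f}ms")
print(f"max={metrics.max_iteration_time_ms:.1f}ms")
```

The `default_factory` matches the `<factory>` default shown in the signature: each instance gets its own list rather than sharing one mutable default.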
- nlsq.utils.profiling.profile_optimization(enabled=True)
Profile optimization performance.
Measures total runtime as a proxy for performance improvements. As of beta.1, it focuses on wall-clock time rather than detailed profiling.
- Parameters:
enabled (bool, default=True) – Whether to enable profiling
- Yields:
PerformanceMetrics – Performance statistics
Examples
>>> from nlsq import curve_fit
>>> from nlsq.utils.profiling import profile_optimization
>>> import jax.numpy as jnp
>>>
>>> with profile_optimization() as metrics:
...     x = jnp.linspace(0, 10, 100)
...     y = 2.0 * jnp.exp(-0.5 * x)
...     popt, pcov = curve_fit(lambda x, a, b: a * jnp.exp(-b * x), x, y)
>>>
>>> print(f"Optimization took {metrics.total_time_sec:.3f}s")
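A wall-clock context manager of the kind described above can be sketched with the standard library alone. The name `profile_wall_clock` and the use of `SimpleNamespace` as the metrics object are assumptions for illustration, as is the choice to yield a zeroed metrics object when `enabled` is false.

```python
import time
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def profile_wall_clock(enabled=True):
    """Hypothetical sketch: yield a metrics object and fill it in on exit."""
    metrics = SimpleNamespace(total_time_sec=0.0)
    if not enabled:
        # Assumption: a disabled profiler still yields a metrics object,
        # left at its zero defaults.
        yield metrics
        return
    start = time.perf_counter()
    try:
        yield metrics
    finally:
        # Record elapsed time even if the profiled block raises.
        metrics.total_time_sec = time.perf_counter() - start


with profile_wall_clock() as m:
    sum(i * i for i in range(100_000))  # stand-in workload

with profile_wall_clock(enabled=False) as d:
    pass

print(f"Total time: {m.total_time_sec:.6f}s")
```

Because JAX dispatches work asynchronously, a timing of this kind reflects the computation only if the result is materialized inside the `with` block, for example by calling `.block_until_ready()` on the output.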
- nlsq.utils.profiling.analyze_source_transfers(source_code)
Analyze source code for host-device transfer patterns.
This is a static analysis tool for validating transfer reduction. Counts patterns that typically induce GPU-CPU transfers.
- Parameters:
source_code (str) – Source code to analyze
- Returns:
Analysis results with counts of transfer-inducing patterns:
- 'np_array_calls': Number of np.array() calls
- 'np_asarray_calls': Number of np.asarray() calls
- 'block_until_ready_calls': Number of .block_until_ready() calls
- 'total_potential_transfers': Sum of all transfer patterns
- Return type: dict
Notes
This is a heuristic analysis tool, not a comprehensive profiler. It provides a relative measure for before/after comparison. It counts only NumPy conversions (np.*), not JAX operations (jnp.*).
Examples
>>> from nlsq.utils.profiling import analyze_source_transfers
>>>
>>> code = '''
... def my_function(x):
...     y = np.array(x)  # Transfer!
...     return y
... '''
>>>
>>> analysis = analyze_source_transfers(code)
>>> print(f"Potential transfers: {analysis['total_potential_transfers']}")
Potential transfers: 1
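One plausible way to implement such a static count is with regular expressions. The sketch below is an assumption about the approach, not the library's code: the function name `count_transfer_patterns` and the exact regexes are illustrative, while the dictionary keys follow the return value documented above.

```python
import re

# Assumption: these regexes approximate the three documented patterns.
# The \b anchor keeps "jnp.array(" from matching the "np.array(" pattern,
# because there is no word boundary between "j" and "n".
PATTERNS = {
    "np_array_calls": re.compile(r"\bnp\.array\s*\("),
    "np_asarray_calls": re.compile(r"\bnp\.asarray\s*\("),
    "block_until_ready_calls": re.compile(r"\.block_until_ready\s*\("),
}


def count_transfer_patterns(source_code: str) -> dict:
    """Count occurrences of each pattern and add their sum."""
    counts = {name: len(rx.findall(source_code)) for name, rx in PATTERNS.items()}
    counts["total_potential_transfers"] = sum(counts.values())
    return counts


sample = """
def step(x):
    a = np.array(x)
    b = jnp.array(x)
    c = np.asarray(a)
    return c.block_until_ready()
"""

report = count_transfer_patterns(sample)
print(report)
```

A text-based count like this also matches occurrences inside comments and string literals, which is one reason the result is best read as a relative before/after measure.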
- nlsq.utils.profiling.compare_transfer_reduction(source_before, source_after, module_name='module')
Compare transfer patterns before and after optimization.
- Parameters:
- Returns:
Comparison results with reduction percentages
- Return type: dict
Examples
>>> from nlsq.utils.profiling import compare_transfer_reduction
>>>
>>> before = "x = np.array(y); z = np.array(w)"
>>> after = "x = jnp.asarray(y); z = jnp.asarray(w)"
>>>
>>> comparison = compare_transfer_reduction(before, after, "mymodule")
>>> print(f"Reduction: {comparison['reduction_percent']:.1f}%")
Reduction: 100.0%
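The reduction percentage in the example above is consistent with the usual relative-change formula applied to the two pattern counts. The helper below is a hypothetical sketch of that arithmetic; returning 0.0 when the "before" count is zero is an assumption, since the documented behavior for that case is not stated.

```python
def reduction_percent(count_before: int, count_after: int) -> float:
    """Percentage drop from count_before to count_after."""
    if count_before == 0:
        # Assumption: nothing to reduce, so report 0.0 rather than divide by zero.
        return 0.0
    return 100.0 * (count_before - count_after) / count_before


# Two np.array() calls before, none after: matches the 100.0% shown above.
print(f"{reduction_percent(2, 0):.1f}%")
# Four potential transfers reduced to one.
print(f"{reduction_percent(4, 1):.1f}%")
```

A negative result indicates that the "after" source contains more transfer-inducing patterns than the "before" source.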
- nlsq.utils.profiling.profile_transfers_runtime(func, *args, trace_dir=None, **kwargs)
Profile actual host-device transfers using JAX profiler.
This function provides runtime measurement of host-device transfers using JAX’s built-in profiler. Unlike static analysis, this captures actual transfer events during execution.
- Parameters:
func (callable) – Function to profile
*args – Positional arguments to pass to func
trace_dir (str or None, optional) – Directory to store profiler trace. If None, uses system temp directory with “jax-profiling” subdirectory.
**kwargs – Keyword arguments to pass to func
- Returns:
(result, transfer_stats) where result is func’s return value and transfer_stats contains profiling information
- Return type: tuple
Notes
Requires JAX profiler support. On CPU, transfers are minimal. Most useful for GPU profiling.
Examples
>>> from nlsq.utils.profiling import profile_transfers_runtime
>>> import jax.numpy as jnp
>>>
>>> def my_computation(x):
...     return jnp.sum(x ** 2)
>>>
>>> result, stats = profile_transfers_runtime(
...     my_computation,
...     jnp.array([1.0, 2.0, 3.0])
... )
>>> print(f"Result: {result}, Transfers: {stats['transfer_count']}")
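A wrapper with this call shape could be built on `jax.profiler.trace`, which is JAX's context manager for writing a profiler trace to a directory. The sketch below is an assumption about the general approach, not the library's implementation: `run_with_trace` and the keys of its stats dictionary are hypothetical, and it does not parse the trace to count transfers. It falls back to a plain call when JAX cannot be imported.

```python
import os
import tempfile

try:
    import jax.profiler as jax_profiler  # optional dependency
except ImportError:
    jax_profiler = None


def run_with_trace(func, *args, trace_dir=None, **kwargs):
    """Hypothetical sketch: run func, tracing with jax.profiler when available."""
    if trace_dir is None:
        # Default described above: "jax-profiling" under the system temp dir.
        trace_dir = os.path.join(tempfile.gettempdir(), "jax-profiling")
    os.makedirs(trace_dir, exist_ok=True)

    if jax_profiler is None:
        # No profiler available: just call the function.
        result = func(*args, **kwargs)
        traced = False
    else:
        with jax_profiler.trace(trace_dir):
            result = func(*args, **kwargs)
        traced = True

    # Assumption: illustrative keys only; the real transfer_stats contents
    # (such as 'transfer_count') come from the library, not from this sketch.
    stats = {"trace_dir": trace_dir, "traced": traced}
    return result, stats


result, stats = run_with_trace(lambda a, b: a + b, 2, 3)
print(result, stats["trace_dir"])
```

When profiling JAX computations this way, materializing the result inside the traced region (for example with `.block_until_ready()`) helps ensure the asynchronous work is captured in the trace.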
Overview
The profiling module provides JAX profiler integration for monitoring host-device transfers in the TRF solver, enabling measurement of transfer bytes and counts per iteration.
Performance Targets (Task Group 2)
Host-device transfer bytes: 80% reduction (current ~80KB → <16KB per iteration)
Transfer count: Reduce from 24+ → <5 per iteration
GPU iteration time: 5-15% reduction through transfer optimization
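The byte and count targets above can be expressed as a simple budget check. This is a hypothetical helper for illustration only: the constant names are invented, and it assumes 1 KB = 1000 bytes, which the targets do not specify.

```python
# Assumption: 1 KB = 1000 bytes; the targets above do not say which unit is meant.
BASELINE_BYTES_PER_ITER = 80_000   # current ~80KB per iteration
TARGET_REDUCTION_PCT = 80          # 80% reduction
MAX_TRANSFERS_PER_ITER = 5         # target: fewer than 5 transfers

# Integer arithmetic avoids floating-point rounding: 80_000 * 20 // 100 = 16_000.
BYTE_BUDGET = BASELINE_BYTES_PER_ITER * (100 - TARGET_REDUCTION_PCT) // 100


def within_budget(bytes_per_iter: int, transfers_per_iter: int) -> bool:
    """True when both the byte target and the count target are met."""
    return bytes_per_iter < BYTE_BUDGET and transfers_per_iter < MAX_TRANSFERS_PER_ITER


print(BYTE_BUDGET)
print(within_budget(12_000, 4))
print(within_budget(80_000, 24))
```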
Classes
Example Usage
from nlsq.utils.profiling import TransferProfiler
# Create profiler (requires JAX profiler)
profiler = TransferProfiler(enable=True)
# Profile TRF iteration
with profiler.profile_iteration(iteration=0):
# TRF solver iteration code
pass
# Get diagnostics
diagnostics = profiler.get_diagnostics()
print(f"Transfer bytes: {diagnostics['transfer_bytes']}")
print(f"Transfer count: {diagnostics['transfer_count']}")
print(f"Avg per iteration: {diagnostics['avg_bytes_per_iter']:.2f} bytes")
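To show how the three diagnostics keys relate to each other, here is a toy accumulator with the same call shape as the example above. It is a hypothetical stand-in, not the `TransferProfiler` implementation: the class `ToyTransferLog` and its `record_transfer` hook are invented for illustration, since this page does not show how the real profiler records transfers.

```python
from contextlib import contextmanager


class ToyTransferLog:
    """Hypothetical stand-in; not the TransferProfiler implementation."""

    def __init__(self):
        self.transfer_bytes = 0
        self.transfer_count = 0
        self.iterations = 0

    @contextmanager
    def profile_iteration(self, iteration):
        # `iteration` is accepted only to mirror the call shape shown above.
        try:
            yield self
        finally:
            self.iterations += 1

    def record_transfer(self, nbytes):
        # Hypothetical hook: stands in for whatever the real profiler observes.
        self.transfer_bytes += nbytes
        self.transfer_count += 1

    def get_diagnostics(self):
        # Average is total bytes divided by the number of profiled iterations.
        avg = self.transfer_bytes / self.iterations if self.iterations else 0.0
        return {
            "transfer_bytes": self.transfer_bytes,
            "transfer_count": self.transfer_count,
            "avg_bytes_per_iter": avg,
        }


log = ToyTransferLog()
for i in range(2):
    with log.profile_iteration(iteration=i) as it:
        it.record_transfer(800)
        it.record_transfer(1200)

diag = log.get_diagnostics()
print(diag)
```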
Chrome Trace Visualization
The profiler can generate Chrome trace files for visualization:
# Profile with trace output
profiler = TransferProfiler(enable=True)
with profiler.profile_iteration(iteration=0):
# Code to profile
pass
# View at chrome://tracing
profiler.save_trace("trace.json")
Requirements
The profiler requires the JAX profiler optional dependency:
pip install "jax[profiler]"
Notes
Profiling overhead is minimal (~1-2% when enabled)
Transfer bytes are estimated based on array sizes
Chrome traces provide detailed timeline visualization
Automatically disabled if JAX profiler not available
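Estimating transfer bytes from array sizes, as the notes above mention, comes down to multiplying the element count by the element width. The function below is a minimal sketch of that arithmetic under the assumption of a dense array; the name `estimate_nbytes` is hypothetical, and how the profiler itself derives its estimate is not shown on this page.

```python
import math


def estimate_nbytes(shape, itemsize):
    """Bytes occupied by a dense array: number of elements times bytes per element."""
    # math.prod(()) is 1, so a scalar (shape ()) counts as one element.
    return math.prod(shape) * itemsize


# A (100, 3) float32 array: 300 elements * 4 bytes each.
print(estimate_nbytes((100, 3), 4))
# A float64 scalar: 1 element * 8 bytes.
print(estimate_nbytes((), 8))
```

For NumPy and JAX arrays the same quantity is available directly as the `nbytes` attribute.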
See Also
nlsq.trf module - Trust Region Reflective algorithm
nlsq.diagnostics package - General optimization diagnostics
Performance Optimization Guide