# Source code for nlsq.core.profiler

"""Profiling utilities for Trust Region Reflective optimization.

This module provides profiling classes for timing TRF algorithm operations,
enabling performance analysis and optimization tuning.
"""

from __future__ import annotations

import time

# Public names exported by ``from <this module> import *``: the real
# profiler and its no-op stand-in.
__all__ = ["NullProfiler", "TRFProfiler"]


class TRFProfiler:
    """Profiler for timing TRF algorithm operations.

    Records detailed timing information for each operation in the TRF
    algorithm. It includes GPU synchronization via ``block_until_ready()``
    so that asynchronously dispatched work is waited for before the timer
    stops. This enables performance analysis without duplicating the entire
    algorithm.

    Attributes
    ----------
    ftimes : list[float]
        Function evaluation times.
    jtimes : list[float]
        Jacobian evaluation times.
    svd_times : list[float]
        SVD computation times.
    ctimes : list[float]
        Cost computation times (JAX).
    gtimes : list[float]
        Gradient computation times (JAX).
    gtimes2 : list[float]
        Gradient norm computation times.
    ptimes : list[float]
        Parameter update times.
    svd_ctimes : list[float]
        SVD conversion times (JAX → NumPy).
    g_ctimes : list[float]
        Gradient conversion times (JAX → NumPy).
    c_ctimes : list[float]
        Cost conversion times (JAX → NumPy).
    p_ctimes : list[float]
        Parameter conversion times (JAX → NumPy).
    """

    __slots__ = (
        "c_ctimes",
        "ctimes",
        "ftimes",
        "g_ctimes",
        "gtimes",
        "gtimes2",
        "jtimes",
        "p_ctimes",
        "ptimes",
        "svd_ctimes",
        "svd_times",
    )

    # Operation name accepted by ``time_operation`` -> list attribute it appends to.
    _OPERATION_ATTRS = {
        "fun": "ftimes",
        "jac": "jtimes",
        "svd": "svd_times",
        "cost": "ctimes",
        "grad": "gtimes",
        "grad_norm": "gtimes2",
        "param_update": "ptimes",
    }

    # Conversion name accepted by ``time_conversion`` -> list attribute it appends to.
    _CONVERSION_ATTRS = {
        "svd_convert": "svd_ctimes",
        "grad_convert": "g_ctimes",
        "cost_convert": "c_ctimes",
        "param_convert": "p_ctimes",
    }

    # Keys (and their order) of the dict returned by ``get_timing_data``.
    _TIMING_KEYS = (
        "ftimes",
        "jtimes",
        "svd_times",
        "ctimes",
        "gtimes",
        "gtimes2",
        "ptimes",
        "svd_ctimes",
        "g_ctimes",
        "c_ctimes",
        "p_ctimes",
    )

    def __init__(self) -> None:
        """Initialize profiler with empty timing arrays."""
        self.ftimes: list[float] = []
        self.jtimes: list[float] = []
        self.svd_times: list[float] = []
        self.ctimes: list[float] = []
        self.gtimes: list[float] = []
        self.gtimes2: list[float] = []
        self.ptimes: list[float] = []
        # Conversion times (JAX → NumPy)
        self.svd_ctimes: list[float] = []
        self.g_ctimes: list[float] = []
        self.c_ctimes: list[float] = []
        self.p_ctimes: list[float] = []

    @staticmethod
    def _synchronize(value):
        """Block until ``value`` is ready and return the synchronized value.

        Objects exposing a callable ``block_until_ready`` (JAX arrays) are
        synchronized, and that method's return value is passed back. Tuples
        and lists are synchronized element-wise (recursively) and returned
        unchanged. Anything else (NumPy arrays, Python scalars) is returned
        as-is.
        """
        block = getattr(value, "block_until_ready", None)
        if callable(block):
            return block()
        if isinstance(value, (tuple, list)):
            for item in value:
                TRFProfiler._synchronize(item)
        return value

    def time_operation(self, operation: str, jax_result):
        """Time a JAX operation with GPU synchronization.

        The timer starts when this method is called, so it measures only the
        time spent waiting for ``jax_result`` to become ready. Work performed
        before the call is not included; for example, anything done
        synchronously by the call that produced ``jax_result``.

        Parameters
        ----------
        operation : str
            Operation name: 'fun', 'jac', 'svd', 'cost', 'grad',
            'grad_norm' or 'param_update'. Any other name is still
            synchronized, but its timing is not recorded.
        jax_result
            JAX array to synchronize, or a tuple/list of them (e.g. the
            factors of an SVD). Values without a ``block_until_ready``
            method are passed through unchanged.

        Returns
        -------
        result
            The synchronized result (same as input).
        """
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # when the system clock is adjusted and is coarse on some platforms.
        # Both timestamps are taken here, so the clock choice is internal.
        start = time.perf_counter()
        result = self._synchronize(jax_result)
        elapsed = time.perf_counter() - start

        attr = self._OPERATION_ATTRS.get(operation)
        if attr is not None:
            getattr(self, attr).append(elapsed)
        return result

    def time_conversion(self, operation: str, start_time: float) -> None:
        """Record timing for JAX → NumPy conversion.

        Parameters
        ----------
        operation : str
            Conversion operation: 'svd_convert', 'grad_convert',
            'cost_convert' or 'param_convert'. Unrecognized names are
            ignored.
        start_time : float
            Start time from ``time.time()``.
        """
        # The start timestamp comes from the caller's time.time(), so the end
        # timestamp must be read from the same clock (not perf_counter).
        elapsed = time.time() - start_time
        attr = self._CONVERSION_ATTRS.get(operation)
        if attr is not None:
            getattr(self, attr).append(elapsed)

    def get_timing_data(self) -> dict[str, list[float]]:
        """Get all recorded timing data.

        Returns
        -------
        dict[str, list[float]]
            Dictionary containing a copy of every timing array. Later
            recording does not alter the returned lists, and mutating them
            does not affect the profiler.
        """
        return {key: list(getattr(self, key)) for key in self._TIMING_KEYS}
class NullProfiler:
    """Do-nothing stand-in for ``TRFProfiler``.

    Offers the same ``time_operation`` / ``time_conversion`` /
    ``get_timing_data`` methods but records nothing and never synchronizes.
    Profiling can therefore be switched off by swapping in an instance of
    this class, with negligible overhead.
    """

    __slots__ = ()

    def time_operation(self, operation: str, jax_result):
        """Hand ``jax_result`` straight back; nothing is timed or synchronized."""
        return jax_result

    def time_conversion(self, operation: str, start_time: float) -> None:
        """Discard a conversion-timing request."""
        return None

    def get_timing_data(self) -> dict[str, list[float]]:
        """Build a timing dict whose keys are all present and whose lists are empty.

        A fresh empty list is created for every key on each call.
        """
        timing_keys = (
            "ftimes",
            "jtimes",
            "svd_times",
            "ctimes",
            "gtimes",
            "gtimes2",
            "ptimes",
            "svd_ctimes",
            "g_ctimes",
            "c_ctimes",
            "p_ctimes",
        )
        return {name: [] for name in timing_keys}