nlsq.core.trf_jit
JIT-compiled Trust Region Reflective helper functions.
Added in version 1.2.0: Extracted from nlsq.core.trf for better code organization.
This module contains the JIT-compiled helper functions used by the Trust Region Reflective optimizer. These functions are performance-critical and benefit from JAX’s just-in-time compilation.
Key Components
TrustRegionJITFunctions
A dataclass containing JIT-compiled functions for:
Gradient computation with automatic differentiation
SVD-based trust region subproblem solving
Conjugate gradient (CG) solver for large problems
Trust region step computation
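The SVD-based subproblem and step computation listed above can be illustrated with a minimal JAX sketch. It assumes the standard regularized formulation, where the step p(α) minimizes ‖J p + f‖² + α‖p‖², and finds α by plain bisection so that ‖p(α)‖ matches the trust radius Δ. The function name `trust_region_step`, the bisection strategy, and the fixed iteration count are illustrative assumptions, not the TrustRegionJITFunctions API. The sketch also assumes J has full column rank.

```python
import jax
import jax.numpy as jnp

# Double precision, matching the "Numerical Stability" note in these docs.
jax.config.update("jax_enable_x64", True)


@jax.jit
def trust_region_step(J, f, delta):
    """Hypothetical helper: approximately solve min ||J p + f|| s.t. ||p|| <= delta.

    Assumes J (m x n, m >= n) has full column rank.
    """
    # Thin SVD: J = U diag(s) Vt, computed once and reused for every trial alpha.
    U, s, Vt = jnp.linalg.svd(J, full_matrices=False)
    uf = U.T @ f

    def step(alpha):
        # p(alpha) = -(J^T J + alpha I)^{-1} J^T f, written in SVD coordinates.
        return -Vt.T @ (s * uf / (s**2 + alpha))

    def excess(alpha):
        # ||p(alpha)|| - delta; decreases monotonically as alpha grows.
        return jnp.linalg.norm(step(alpha)) - delta

    p_gn = step(0.0)  # Gauss-Newton step (alpha = 0)

    # ||p(alpha)|| <= ||J^T f|| / alpha, so this upper bound satisfies excess <= 0.
    upper = jnp.linalg.norm(s * uf) / delta

    def bisect(_, bounds):
        a_lo, a_hi = bounds
        mid = 0.5 * (a_lo + a_hi)
        too_long = excess(mid) > 0  # step still outside the trust region
        return jnp.where(too_long, mid, a_lo), jnp.where(too_long, a_hi, mid)

    _, alpha = jax.lax.fori_loop(0, 100, bisect, (jnp.zeros_like(upper), upper))

    # Use the Gauss-Newton step when it already fits inside the trust region.
    inside = jnp.linalg.norm(p_gn) <= delta
    return jnp.where(inside, p_gn, step(alpha))
```

Because the factorization is computed once per call, each trial α only costs a matrix-vector product with V, which is why SVD-based solvers can afford many root-finding iterations on small problems.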
Module Contents
JIT-compiled functions for Trust Region Reflective optimization.
This module contains JAX JIT-compiled helper functions for the TRF algorithm, providing GPU/TPU-accelerated implementations of core mathematical operations.
XLA Memory Optimization
All JIT functions are defined at module level as singletons. Because JAX caches compiled code per function object, each function is compiled only once per unique combination of input shapes and dtypes, regardless of how many TrustRegionJITFunctions instances are created. Previously, each instance created more than ten new @jit closures; since every closure is a distinct function, each one was compiled separately and the XLA compilation cache grew without bound.
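The effect of module-level singletons can be observed by counting traces. This sketch relies on the fact that a Python side effect inside a jitted function runs only while JAX traces it, so the counter serves as a proxy for compilations. The class and function names are hypothetical stand-ins, not nlsq code.

```python
import jax
import jax.numpy as jnp

TRACE_COUNT = {"module": 0, "closure": 0}


@jax.jit
def _module_gradient(J, f):
    # Python side effects run only during tracing, so this counts compilations.
    TRACE_COUNT["module"] += 1
    return J.T @ f


class SharedFunctions:
    """Hypothetical: every instance reuses the same module-level jitted function."""

    def __init__(self):
        self.gradient = _module_gradient


class PerInstanceFunctions:
    """Hypothetical anti-pattern: each instance builds a fresh @jit closure."""

    def __init__(self):
        @jax.jit
        def gradient(J, f):
            TRACE_COUNT["closure"] += 1
            return J.T @ f

        self.gradient = gradient


J = jnp.ones((5, 2))
f = jnp.ones(5)
for _ in range(3):
    SharedFunctions().gradient(J, f)
    PerInstanceFunctions().gradient(J, f)

print(TRACE_COUNT)  # {'module': 1, 'closure': 3}
```

The shared function is traced once and then served from the cache, while every new closure triggers a fresh trace and compilation even though the input shapes never change.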
class nlsq.core.trf_jit.TrustRegionJITFunctions
Bases: object
JIT-compiled functions for the Trust Region Reflective optimization algorithm.
All JIT functions are module-level singletons to prevent XLA compilation cache bloat: each function is compiled once per unique input shape, and the compiled code is shared across all TrustRegionJITFunctions instances.
Core Operations
Gradient Computation: JAX-accelerated gradient calculation as the matrix-vector product J^T f
SVD: Singular value decomposition of the Jacobian for solving trust region subproblems
Conjugate Gradient: Iterative solver for large-scale problems
Cost Function Evaluation: Loss function computation with masking support
Hat Space Transformation: Scaled variable transformations for bounds handling
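The gradient, masked cost, and hat-space operations can be sketched as small jitted functions. Assumptions: the mask is a boolean array marking residuals to keep, the cost is 0.5 Σ fᵢ² over unmasked residuals, and the hat-space change of variables is x = d ⊙ x̂ with a per-parameter scale vector d, similar to the convention in SciPy's `least_squares` TRF implementation. The function names are hypothetical and not taken from nlsq.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def masked_cost(f, mask):
    # 0.5 * sum of squared residuals, ignoring entries where mask is False.
    return 0.5 * jnp.sum(jnp.where(mask, f, 0.0) ** 2)


@jax.jit
def masked_gradient(J, f, mask):
    # Gradient of masked_cost w.r.t. parameters: g = J^T f, masked residuals zeroed.
    return J.T @ jnp.where(mask, f, 0.0)


@jax.jit
def to_hat_space(J, g, d):
    # With x = d * x_hat, the chain rule scales each Jacobian column by d_j
    # (J_hat = J diag(d)) and the gradient elementwise (g_hat = d * g).
    return J * d, g * d
```

Any step p̂ computed in hat space maps back to the original variables as p = d ⊙ p̂.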
Performance Characteristics
Small Problems: Direct SVD solution, O(mn^2 + n^3) for an m×n Jacobian (m residuals, n parameters)
Large Problems: CG iteration, O(kmn), where k is the number of CG iterations
GPU Memory: Module-level singletons prevent per-instance recompilation
Numerical Stability: Double precision arithmetic with condition monitoring
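For large problems, the CG path can be sketched with `jax.scipy.sparse.linalg.cg` applied matrix-free to the Gauss-Newton normal equations JᵀJ p = −Jᵀf. Each iteration needs one product with J and one with Jᵀ, which is where the O(kmn) cost comes from. The function name is hypothetical, and nlsq may solve a regularized or preconditioned variant instead.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)


@jax.jit
def cg_gauss_newton_step(J, f):
    """Hypothetical helper: solve (J^T J) p = -J^T f without forming J^T J."""

    def normal_matvec(p):
        # One product with J and one with J^T: O(mn) work per CG iteration.
        return J.T @ (J @ p)

    p, _ = cg(normal_matvec, -(J.T @ f), tol=1e-10, maxiter=100)
    return p
```

Working with the normal equations squares the condition number of J, which is one reason double precision and condition monitoring matter on this path.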
See Also
nlsq.core.trf module - Main Trust Region Reflective optimizer
nlsq.core.profiler - Performance profiling for TRF