nlsq.core.trf_jit

JIT-compiled Trust Region Reflective helper functions.

Added in version 1.2.0: Extracted from nlsq.core.trf for better code organization.

This module contains the JIT-compiled helper functions used by the Trust Region Reflective optimizer. These functions are performance-critical and benefit from JAX’s just-in-time compilation.

Key Components

TrustRegionJITFunctions

A dataclass containing JIT-compiled functions for:

  • Gradient computation with automatic differentiation

  • SVD-based trust region subproblem solving

  • Conjugate gradient (CG) solver for large problems

  • Trust region step computation
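To illustrate the gradient item, the sketch below computes the gradient of the least-squares cost 0.5 * ||f(x)||^2 as J^T f. The Jacobian J comes from forward-mode automatic differentiation, and the result is checked against `jax.grad`. The exponential-decay model and the names `residuals`, `gradient_via_jacobian`, and `gradient_via_grad` are hypothetical and are not part of nlsq's API.

```python
import jax
import jax.numpy as jnp

# Use float64 throughout (assumption: mirrors the double-precision note below).
jax.config.update("jax_enable_x64", True)


def residuals(x, t, y):
    # Hypothetical model: y is approximately x[0] * exp(-x[1] * t)
    return x[0] * jnp.exp(-x[1] * t) - y


@jax.jit
def gradient_via_jacobian(x, t, y):
    f = residuals(x, t, y)
    J = jax.jacfwd(residuals)(x, t, y)  # (m, n) Jacobian with respect to x
    return J.T @ f  # gradient of 0.5 * ||f||^2


@jax.jit
def gradient_via_grad(x, t, y):
    cost = lambda p: 0.5 * jnp.sum(residuals(p, t, y) ** 2)
    return jax.grad(cost)(x)


t = jnp.linspace(0.0, 4.0, 50)
y = 2.5 * jnp.exp(-1.3 * t)
x = jnp.array([2.0, 1.0])

g1 = gradient_via_jacobian(x, t, y)
g2 = gradient_via_grad(x, t, y)
print(jnp.allclose(g1, g2))  # True
```

Forming J explicitly costs one forward-mode pass per parameter. In exchange, the same J can be reused for the subproblem solve rather than serving only the gradient.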

Module Contents

JIT-compiled functions for Trust Region Reflective optimization.

This module contains JAX JIT-compiled helper functions for the TRF algorithm, providing GPU/TPU-accelerated implementations of core mathematical operations.

XLA Memory Optimization

All JIT functions are defined at module level as singletons. This ensures each function is compiled only once per unique input shape, regardless of how many TrustRegionJITFunctions instances are created. Previously, each instance created 10+ new @jit closures. Because JAX keys its compilation cache on the function object, every new closure was compiled again, and the XLA compilation cache grew without bound.
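The effect of module-level singletons can be reproduced with a small experiment. JAX caches compiled code per jitted function object, and Python side effects inside a jitted function run only while it is being traced, so counting those side effects counts traces. In this sketch, five instances that share one module-level function trigger a single trace, while five instances that each create their own `@jax.jit` closure trigger five. The names `_scaled_norm`, `SharedJIT`, `PerInstanceJIT`, and `TRACE_COUNTS` are hypothetical and are not taken from nlsq.

```python
import jax
import jax.numpy as jnp

TRACE_COUNTS = {"module_level": 0, "per_instance": 0}


@jax.jit
def _scaled_norm(v, d):
    # Python side effects run only during tracing, so this counts traces.
    TRACE_COUNTS["module_level"] += 1
    return jnp.linalg.norm(v * d)


class SharedJIT:
    """Binds the module-level singleton; all instances share one cache."""

    def __init__(self):
        self.scaled_norm = _scaled_norm


class PerInstanceJIT:
    """Anti-pattern: a fresh @jit closure (and cache) per instance."""

    def __init__(self):
        @jax.jit
        def scaled_norm(v, d):
            TRACE_COUNTS["per_instance"] += 1
            return jnp.linalg.norm(v * d)

        self.scaled_norm = scaled_norm


v = jnp.ones(3)
d = jnp.arange(3.0)
for _ in range(5):
    SharedJIT().scaled_norm(v, d)
    PerInstanceJIT().scaled_norm(v, d)

print(TRACE_COUNTS)  # {'module_level': 1, 'per_instance': 5}
```

Binding the module-level function to an attribute in `__init__` keeps the instance-attribute calling style while leaving a single compilation cache. This matches what `TrustRegionJITFunctions.__init__` is documented to do, although nlsq's actual attribute names are not listed on this page.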

class nlsq.core.trf_jit.TrustRegionJITFunctions

Bases: object

JIT-compiled functions for Trust Region Reflective optimization algorithm.

All JIT functions are module-level singletons to prevent XLA compilation cache bloat. Each function is compiled once per unique input shape, shared across all TrustRegionJITFunctions instances.

Core Operations

  • Gradient Computation: JAX-accelerated gradient calculation as J^T f (the transposed Jacobian times the residual vector)

  • SVD Decomposition: Singular value decomposition for trust region subproblems

  • Conjugate Gradient: Iterative solver for large-scale problems

  • Cost Function Evaluation: Loss function computation with masking support

  • Hat Space Transformation: Scaled variable transformations for bounds handling
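This page does not spell out the hat-space transformation. The sketch below assumes nlsq follows the Coleman-Li scaling used by SciPy's `least_squares(method='trf')`. Each variable is scaled by the square root of its distance to the bound that the negative gradient points toward, and variables with no finite bound in that direction keep a scale of 1. The helper names `cl_scaling_vector` and `to_hat_space` are hypothetical, and the optional `x_scale` factor that SciPy also applies is omitted.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def cl_scaling_vector(x, g, lb, ub):
    # Distance to the bound the negative gradient points toward; 1 if that bound is infinite.
    v = jnp.ones_like(x)
    toward_ub = (g < 0) & jnp.isfinite(ub)  # -g > 0: moving up, toward ub
    toward_lb = (g > 0) & jnp.isfinite(lb)  # -g < 0: moving down, toward lb
    v = jnp.where(toward_ub, ub - x, v)
    v = jnp.where(toward_lb, x - lb, v)
    return v


@jax.jit
def to_hat_space(J, g, x, lb, ub):
    d = jnp.sqrt(cl_scaling_vector(x, g, lb, ub))
    J_h = J * d  # broadcasting scales column j of J by d[j]
    g_h = d * g
    return J_h, g_h, d


x = jnp.array([0.5, 2.0, 1.0])
lb = jnp.array([0.0, -jnp.inf, 0.0])
ub = jnp.array([1.0, 3.0, jnp.inf])
g = jnp.array([-2.0, 1.0, 4.0])
J = jnp.eye(3)

J_h, g_h, d = to_hat_space(J, g, x, lb, ub)
```

A step `p_h` computed in hat space from `J_h` and `g_h` is mapped back to the original variables as `p = d * p_h`. Variables close to an active bound therefore take proportionally smaller steps.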

Performance Characteristics

  • Small Problems: Direct SVD solution O(mn^2 + n^3)

  • Large Problems: CG iteration O(k*mn) where k is iteration count

  • GPU Memory: Module-level singletons prevent per-instance recompilation

  • Numerical Stability: Double precision arithmetic with condition monitoring
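The two regimes above correspond to two ways of solving the linearized subproblem min_p ||J p + f||. The sketch below computes only the unconstrained Gauss-Newton step, which is the trust-region solution only when that step already fits inside the radius. Trust-radius handling is omitted. One version uses a thin SVD. The other runs `jax.scipy.sparse.linalg.cg` on the normal equations J^T J p = -J^T f through a matrix-free product, so each iteration costs O(mn). The function names and the random test problem are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)


@jax.jit
def gauss_newton_step_svd(J, f):
    # Thin SVD costs O(m n^2) for m >= n; assumes J has full column rank.
    U, s, Vt = jnp.linalg.svd(J, full_matrices=False)
    return -Vt.T @ ((U.T @ f) / s)


@jax.jit
def gauss_newton_step_cg(J, f):
    # Matrix-free product: two mat-vecs per iteration, J^T J is never formed.
    matvec = lambda p: J.T @ (J @ p)
    p, _ = cg(matvec, -J.T @ f, tol=1e-12, maxiter=10 * J.shape[1])
    return p


key_J, key_f = jax.random.split(jax.random.PRNGKey(0))
J = jax.random.normal(key_J, (200, 5))
f = jax.random.normal(key_f, (200,))

p_svd = gauss_newton_step_svd(J, f)
p_cg = gauss_newton_step_cg(J, f)
print(jnp.allclose(p_svd, p_cg, atol=1e-8))  # True
```

For a well-conditioned J, the two steps agree closely. CG on the normal equations works with cond(J)^2 rather than cond(J), which is one reason a direct SVD is preferable whenever its O(mn^2) cost is affordable.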

__init__()

Bind module-level JIT singletons to instance attributes.

See Also