nlsq.common_jax module¶
JAX-specific utilities and functions.
These functions were originally in the common.py file, but they involve large data operations and are therefore better suited to compilation with JAX. They are compiled with JAX and then added to the CommonJIT class.
- nlsq.common_jax.phi_and_derivative_jax(alpha, suf, s, Delta)[source]
JAX-compiled phi function for trust region subproblem.
This function computes the value and derivative of the secular equation used to find the optimal Levenberg-Marquardt parameter. It is defined as “norm of regularized (by alpha) least-squares solution minus Delta”.
The function is used iteratively to find the root, which gives the optimal regularization parameter for the trust region subproblem.
- Parameters:
- Returns:
phi (jnp.ndarray) – Value of the secular equation: ||p|| - Delta
phi_prime (jnp.ndarray) – Derivative of phi with respect to alpha
- Return type:
Notes
This is a JAX-jitted version of the function in common_scipy.py. Using JAX enables GPU acceleration and avoids NumPy-JAX data transfers when used with other JAX operations.
The computation follows [12] Branch, M.A., Coleman, T.F., Li, Y., “A Subspace, Interior, and Conjugate Gradient Method for Large-Scale Bound-Constrained Minimization Problems”, SIAM Journal on Scientific Computing, Vol. 21, Number 1, pp 1-23, 1999.
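The secular equation described above can be sketched as follows. This is a minimal illustration, not the library's code: it assumes suf = s * (U.T @ f), as in SciPy's equivalent helper, and the name phi_and_derivative_sketch and the sample values are hypothetical. The closed-form derivative is cross-checked against jax.grad.

```python
import jax
import jax.numpy as jnp


def phi_and_derivative_sketch(alpha, suf, s, Delta):
    """phi(alpha) = ||p(alpha)|| - Delta and its derivative in alpha.

    Assumption: suf = s * (U.T @ f), as in SciPy's helper of the same name.
    """
    denom = s**2 + alpha
    p_norm = jnp.linalg.norm(suf / denom)
    phi = p_norm - Delta
    phi_prime = -jnp.sum(suf**2 / denom**3) / p_norm
    return phi, phi_prime


# Illustrative values, not taken from the library.
s = jnp.array([3.0, 2.0, 1.0])
uf = jnp.array([1.0, 1.0, 1.0])
suf = s * uf
Delta = 0.5

phi, phi_prime = phi_and_derivative_sketch(1.0, suf, s, Delta)

# Cross-check the closed-form derivative with automatic differentiation.
phi_prime_ad = jax.grad(
    lambda a: phi_and_derivative_sketch(a, suf, s, Delta)[0]
)(1.0)
```

Because phi decreases monotonically in alpha, phi_prime is negative, which is what a Newton-type root finder relies on.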
- nlsq.common_jax.solve_lsq_trust_region_jax(n, m, uf, s, V, Delta, initial_alpha=None, rtol=0.01, max_iter=10)[source]
JAX-compiled trust-region problem solver for least-squares minimization.
This function implements a method described by J. J. Moré [12] and used in MINPACK, but relies on a single SVD of the Jacobian instead of a series of Cholesky decompositions. Before running this function, compute:
U, s, VT = svd(J, full_matrices=False)
This is a pure JAX implementation using lax.while_loop for JIT compilation, avoiding NumPy-JAX data transfers when used in the optimization hot path.
- Parameters:
n (int) – Number of variables.
m (int) – Number of residuals.
uf (jnp.ndarray) – Computed as U.T @ f.
s (jnp.ndarray) – Singular values of J.
V (jnp.ndarray) – Transpose of VT (i.e., V = VT.T).
Delta (float) – Radius of a trust region.
initial_alpha (float, optional) – Initial guess for alpha. If None, determined automatically.
rtol (float, optional) – Stopping tolerance for the root-finding procedure.
max_iter (int, optional) – Maximum allowed number of iterations.
- Returns:
p (jnp.ndarray, shape (n,)) – Found solution of a trust-region problem.
alpha (jnp.ndarray) – Levenberg-Marquardt parameter (scalar as 0-d array).
n_iter (jnp.ndarray) – Number of iterations made (scalar as 0-d array).
- Return type:
References
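The root-finding loop can be sketched with lax.while_loop as below. This is a simplified illustration under stated assumptions, not the library's implementation: it assumes a full-rank Jacobian, starts from alpha = 0, and omits the bracketing safeguards and the initial_alpha handling of the real solver. The name solve_tr_sketch is hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def solve_tr_sketch(uf, s, V, Delta, rtol=0.01, max_iter=10):
    """Simplified trust-region solve (full-rank J, no safeguards)."""
    suf = s * uf

    def phi_and_derivative(alpha):
        denom = s**2 + alpha
        p_norm = jnp.linalg.norm(suf / denom)
        return p_norm - Delta, -jnp.sum(suf**2 / denom**3) / p_norm

    def cond(state):
        _, phi, _, n_iter = state
        # Iterate while the step is still outside the trust region.
        return (phi > rtol * Delta) & (n_iter < max_iter)

    def body(state):
        alpha, phi, phi_prime, n_iter = state
        # Newton step on the reciprocal form of the secular equation.
        alpha = alpha - (phi + Delta) * (phi / phi_prime) / Delta
        phi, phi_prime = phi_and_derivative(alpha)
        return alpha, phi, phi_prime, n_iter + 1

    alpha0 = jnp.zeros((), dtype=s.dtype)
    phi0, phi_prime0 = phi_and_derivative(alpha0)
    n0 = jnp.zeros((), dtype=jnp.int32)
    alpha, _, _, n_iter = lax.while_loop(
        cond, body, (alpha0, phi0, phi_prime0, n0)
    )

    # alpha == 0 reproduces the Gauss-Newton step -V @ (uf / s).
    p = -V @ (suf / (s**2 + alpha))
    return p, alpha, n_iter


# Illustrative problem: the Gauss-Newton step is longer than Delta.
J = jnp.diag(jnp.array([3.0, 2.0, 1.0]))
f = jnp.ones(3)
U, s, VT = jnp.linalg.svd(J, full_matrices=False)
p, alpha, n_iter = solve_tr_sketch(U.T @ f, s, VT.T, Delta=0.5)
```

When the Gauss-Newton step already lies inside the trust region, phi is non-positive at alpha = 0, the loop body never runs, and the sketch returns that step with alpha = 0.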
- class nlsq.common_jax.CommonJIT[source]
Bases: object
JIT-compiled common functions for nonlinear least squares optimization.
This class provides GPU/TPU-accelerated implementations of mathematical operations commonly used across different optimization algorithms. All functions are JIT-compiled for maximum performance and memory efficiency.
Core Functionality¶
Quadratic Function Operations: Build and evaluate quadratic forms
Matrix-Vector Products: Optimized Jacobian operations
Robust Loss Scaling: Jacobian and residual scaling for robust methods
Numerical Utilities: Condition-aware computations with overflow protection
JIT Compilation Benefits¶
GPU/TPU Acceleration: All operations optimized for parallel hardware
Memory Efficiency: Reduced allocations through compilation optimization
Automatic Fusion: Operations automatically fused for better performance
Type Specialization: Functions compiled for specific array shapes/types
Mathematical Operations¶
The class implements several categories of operations:
Quadratic Functions: For trust region subproblems
- 1D quadratic parameterization along search directions
- Quadratic model evaluation for step selection
- Hessian approximation using J^T * J structure
Matrix Operations: Optimized linear algebra
- Jacobian-vector products with broadcasting
- Scaling operations for robust loss functions
- Condition-aware computations with numerical stability
Robust Loss Support: For outlier-resistant fitting
- Jacobian scaling using loss function derivatives
- Residual weighting based on robustness weights
- Numerical stability for extreme scaling factors
Performance Characteristics¶
Memory Usage: O(1) additional memory overhead
Compilation Time: One-time cost during initialization
Execution Speed: 10-100x faster on GPU than pure NumPy
Numerical Precision: Full double precision support
Technical Implementation¶
All functions use JAX JIT compilation with the following features:
- Static argument handling for shape polymorphism
- Automatic differentiation compatibility
- XLA optimization for target hardware
- Memory layout optimization for cache efficiency
Integration with Optimization Algorithms¶
This class is used by:
- Trust Region Reflective (TRF) algorithm for quadratic models
- Levenberg-Marquardt algorithm for step computation
- Robust loss functions for residual scaling
- Large dataset processing for memory-efficient operations
Usage Example¶
from nlsq.common_jax import CommonJIT
import jax.numpy as jnp

# Initialize JIT-compiled functions
cjit = CommonJIT()

# Example: Scale Jacobian for robust loss
jacobian = jnp.array([[1.0, 2.0], [3.0, 4.0]])
residuals = jnp.array([0.1, 2.5])  # Contains outlier
# loss_val, loss_deriv1 and loss_deriv2 are placeholders for the loss value
# and its first and second derivatives; compute them before this line
rho = jnp.array([loss_val, loss_deriv1, loss_deriv2])

# Apply robust scaling
scaled_J, scaled_f = cjit.scale_for_robust_loss_function(jacobian, residuals, rho)

# Example: Build quadratic model
gradient = jnp.array([0.1, -0.3])
direction = jnp.array([1.0, 0.5])
# Without s0, only a and b are returned; c is returned only if s0 is provided
a, b = cjit.build_quadratic_1d(jacobian, gradient, direction)
Numerical Considerations¶
Overflow Protection: Automatic handling of extreme scaling factors
Underflow Prevention: Minimum threshold enforcement (EPS)
Condition Monitoring: Numerical stability checks for ill-conditioned operations
Precision Control: Double precision arithmetic throughout
- __init__()[source]
Initialize CommonJIT with all compiled functions.
This creates and compiles all JIT functions during initialization for optimal runtime performance. Functions are compiled once and reused across multiple optimization runs.
Compiled Functions Created¶
Quadratic function builders and evaluators
Matrix-vector dot products with broadcasting
Jacobian sum operations for constraint handling
Robust loss function scaling operations
- create_scale_for_robust_loss_function()[source]
Create the scaling function for robust loss functions.
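A scaling function of this kind might look like the sketch below. It assumes the SciPy convention that rho has shape (3, m) and holds the loss value and its first and second derivatives; the threshold EPS and the name scale_for_robust_loss_sketch are assumptions for illustration, and the Cauchy loss is only an example.

```python
import jax.numpy as jnp

# Assumed lower threshold; the library's EPS may be defined differently.
EPS = jnp.finfo(jnp.float32).eps


def scale_for_robust_loss_sketch(J, f, rho):
    """Scale J (m, n) and f (m,) given rho = [loss, loss', loss''] of shape (3, m)."""
    J_scale = rho[1] + 2.0 * rho[2] * f**2
    # Clamp before the square root so the scale stays positive and finite.
    J_scale = jnp.sqrt(jnp.maximum(J_scale, EPS))
    f_scaled = f * rho[1] / J_scale
    J_scaled = J * J_scale[:, None]
    return J_scaled, f_scaled


# Illustrative data with a Cauchy loss rho(z) = log(1 + z), z = f**2.
J = jnp.array([[1.0, 2.0], [3.0, 4.0]])
f = jnp.array([0.1, 2.5])  # second residual is an outlier
z = f**2
rho = jnp.stack([jnp.log1p(z), 1.0 / (1.0 + z), -1.0 / (1.0 + z) ** 2])

J_scaled, f_scaled = scale_for_robust_loss_sketch(J, f, rho)
```

For the outlier row the unclamped scale is negative, so the threshold takes over and that row of the Jacobian is strongly down-weighted.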
- build_quadratic_1d(J, g, s, diag=None, s0=None)[source]
Parameterize a multivariate quadratic function along a line.
The resulting univariate quadratic function is given as follows:
f(t) = 0.5 * (s0 + s*t).T * (J.T*J + diag) * (s0 + s*t) + g.T * (s0 + s*t)
- Parameters:
J (ndarray, sparse matrix or LinearOperator, shape (m, n)) – Jacobian matrix, affects the quadratic term.
g (ndarray, shape (n,)) – Gradient, defines the linear term.
s (ndarray, shape (n,)) – Direction vector of a line.
diag (None or ndarray with shape (n,), optional) – Additional diagonal part, affects the quadratic term. If None, assumed to be 0.
s0 (None or ndarray with shape (n,), optional) – Initial point. If None, assumed to be 0.
- Returns:
a (float) – Coefficient for t**2.
b (float) – Coefficient for t.
c (float) – Free term. Returned only if s0 is provided.
- Return type:
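The coefficients follow from expanding f(t) in powers of t. The sketch below is a plain jax.numpy illustration for dense arrays only (no sparse matrix or LinearOperator support); the name build_quadratic_1d_sketch and the sample values are hypothetical.

```python
import jax.numpy as jnp


def build_quadratic_1d_sketch(J, g, s, diag=None, s0=None):
    """Return (a, b) or (a, b, c) such that f(t) = a*t**2 + b*t + c."""
    v = J @ s
    a = v @ v
    if diag is not None:
        a = a + s @ (diag * s)
    a = 0.5 * a
    b = g @ s
    if s0 is None:
        return a, b
    u = J @ s0
    b = b + u @ v
    c = 0.5 * (u @ u) + g @ s0
    if diag is not None:
        b = b + s0 @ (diag * s)
        c = c + 0.5 * (s0 @ (diag * s0))
    return a, b, c


# Illustrative data.
J = jnp.array([[1.0, 2.0], [3.0, 4.0]])
g = jnp.array([0.1, -0.3])
s = jnp.array([1.0, 0.5])
s0 = jnp.array([0.2, -0.1])
diag = jnp.array([0.5, 0.25])

a, b, c = build_quadratic_1d_sketch(J, g, s, diag=diag, s0=s0)

# Compare with a direct evaluation of the multivariate quadratic at t = 0.7.
t = 0.7
x = s0 + s * t
direct = 0.5 * x @ ((J.T @ J + jnp.diag(diag)) @ x) + g @ x
along_line = a * t**2 + b * t + c
```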
- compute_jac_scale(J, scale_inv_old=None)[source]
Compute variables scale based on the Jacobian matrix.
Returns JAX arrays to keep results on-device and avoid GPU→CPU sync.
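The return values are not listed here. The sketch below assumes the function follows SciPy's compute_jac_scale, which returns both the scale and its inverse; the name compute_jac_scale_sketch is hypothetical.

```python
import jax.numpy as jnp


def compute_jac_scale_sketch(J, scale_inv_old=None):
    """Return (scale, scale_inv) from the column norms of J (SciPy's rule, assumed)."""
    scale_inv = jnp.sqrt(jnp.sum(J**2, axis=0))
    if scale_inv_old is None:
        # Avoid division by zero for all-zero columns.
        scale_inv = jnp.where(scale_inv == 0, 1.0, scale_inv)
    else:
        # Never let a column's inverse scale shrink between iterations.
        scale_inv = jnp.maximum(scale_inv, scale_inv_old)
    return 1.0 / scale_inv, scale_inv


J = jnp.array([[3.0, 0.0], [4.0, 0.0]])
scale, scale_inv = compute_jac_scale_sketch(J)
scale2, scale_inv2 = compute_jac_scale_sketch(J, jnp.array([6.0, 0.5]))
```

Both results stay JAX arrays, so nothing forces a device-to-host transfer until a caller converts them.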
- create_js_dot()[source]
Create the functions for the dot product of the Jacobian and the search direction. Two functions are needed because s and s0 have different shapes, which would cause the function to be retraced if the same one were used for both.
- evaluate_quadratic(J, g, s_np, diag=None)[source]
Compute values of a quadratic function arising in least squares. The function is 0.5 * s.T * (J.T * J + diag) * s + g.T * s.
- Parameters:
J (ndarray, sparse matrix or LinearOperator, shape (m, n)) – Jacobian matrix, affects the quadratic term.
g (ndarray, shape (n,)) – Gradient, defines the linear term.
s_np (ndarray, shape (k, n) or (n,)) – Array containing steps as rows (denoted s in the formula above).
diag (ndarray, shape (n,), optional) – Additional diagonal part, affects the quadratic term. If None, assumed to be 0.
- Returns:
values – Values of the function. If s is 2-D, an ndarray is returned; otherwise, a float is returned.
- Return type:
ndarray with shape (k,) or float
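The formula can be evaluated for one step or for a batch of steps as in the sketch below. It is a dense-array illustration in jax.numpy; the name evaluate_quadratic_sketch is hypothetical.

```python
import jax.numpy as jnp


def evaluate_quadratic_sketch(J, g, s, diag=None):
    """Evaluate 0.5 * s.T (J.T J + diag) s + g.T s for s of shape (n,) or (k, n)."""
    if s.ndim == 1:
        Js = J @ s
        q = Js @ Js
        if diag is not None:
            q = q + s @ (diag * s)
    else:
        Js = J @ s.T  # shape (m, k), one column per step
        q = jnp.sum(Js**2, axis=0)
        if diag is not None:
            q = q + jnp.sum(diag * s**2, axis=1)
    return 0.5 * q + s @ g


J = jnp.array([[1.0, 2.0], [3.0, 4.0]])
g = jnp.array([0.1, -0.3])
steps = jnp.array([[1.0, 0.0], [0.0, 1.0]])

values = evaluate_quadratic_sketch(J, g, steps)    # shape (2,)
single = evaluate_quadratic_sketch(J, g, steps[0])  # 0-d array
```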
- create_quadratic_funcs()[source]
- create_jac_sum()[source]
Create the function that sums the squared Jacobian entries and then takes the square root. This is used to compute the scale for the variables. Can potentially remove this.
- nlsq.common_jax.update_tr_radius_jax(Delta, actual_reduction, predicted_reduction, step_norm, bound_hit)[source]
JIT-compiled trust-region radius update (replaces common_scipy.update_tr_radius).
Uses branchless jnp.where to avoid Python control flow and GPU→CPU sync.
- Parameters:
- Returns:
Delta (jnp.ndarray) – Updated trust region radius.
ratio (jnp.ndarray) – Ratio of actual to predicted reduction.
- Return type:
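A branchless update written with jnp.where could look like the sketch below. The thresholds (0.25 and 0.75) and the shrink and grow factors are those of SciPy's update_tr_radius and are assumed here, not confirmed for this library; the name update_tr_radius_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def update_tr_radius_sketch(Delta, actual_reduction, predicted_reduction,
                            step_norm, bound_hit):
    """Branchless trust-region radius update (SciPy's rules, assumed)."""
    # Guard the division so the untaken branch cannot produce inf or nan.
    safe_pred = jnp.where(predicted_reduction > 0, predicted_reduction, 1.0)
    ratio = jnp.where(
        predicted_reduction > 0,
        actual_reduction / safe_pred,
        jnp.where(
            (predicted_reduction == 0) & (actual_reduction == 0), 1.0, 0.0
        ),
    )
    Delta_new = jnp.where(
        ratio < 0.25,
        0.25 * step_norm,                       # poor agreement: shrink
        jnp.where((ratio > 0.75) & bound_hit,   # good agreement at boundary
                  2.0 * Delta, Delta),
    )
    return Delta_new, ratio


grow = update_tr_radius_sketch(1.0, 0.9, 1.0, 1.0, True)
shrink = update_tr_radius_sketch(1.0, 0.1, 1.0, 0.8, False)
keep = update_tr_radius_sketch(1.0, 0.5, 1.0, 1.0, True)
```

Every branch is computed and then selected elementwise, so the function traces once and no Python `if` has to read a device value.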
- nlsq.common_jax.check_termination_jax(dF, F, dx_norm, x_norm, ratio, ftol, xtol)[source]
JIT-compiled termination check (replaces common_scipy.check_termination).
Returns integer status code as a JAX scalar (no GPU→CPU sync needed until the Python-level loop checks it). Returns 0 for “not terminated”.
- Parameters:
- Returns:
status – 0=not terminated, 2=ftol, 3=xtol, 4=both.
- Return type:
jnp.ndarray
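The status codes above can be produced without Python branching, as in the sketch below. The two tolerance tests are those of SciPy's check_termination and are an assumption about this library; the name check_termination_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def check_termination_sketch(dF, F, dx_norm, x_norm, ratio, ftol, xtol):
    """Return 0 (continue), 2 (ftol), 3 (xtol) or 4 (both) as a JAX scalar."""
    ftol_ok = (dF < ftol * F) & (ratio > 0.25)
    xtol_ok = dx_norm < xtol * (xtol + x_norm)
    return jnp.where(
        ftol_ok & xtol_ok, 4,
        jnp.where(ftol_ok, 2, jnp.where(xtol_ok, 3, 0)),
    )


status_ftol = check_termination_sketch(1e-12, 1.0, 1.0, 1.0, 0.9, 1e-8, 1e-8)
status_xtol = check_termination_sketch(1.0, 1.0, 1e-12, 1.0, 0.9, 1e-8, 1e-8)
status_both = check_termination_sketch(1e-12, 1.0, 1e-12, 1.0, 0.9, 1e-8, 1e-8)
status_none = check_termination_sketch(1.0, 1.0, 1.0, 1.0, 0.9, 1e-8, 1e-8)
```

The result stays on the device; converting it with int(status) in the outer Python loop is the point where a synchronization happens.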
- nlsq.common_jax.CL_scaling_vector_jax(x, g, lb, ub)[source]
JIT-compiled Coleman-Li scaling vector (replaces common_scipy.CL_scaling_vector).
Keeps result on-device to avoid GPU→CPU sync at call sites.
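For reference, the Coleman-Li scaling vector as defined in SciPy's CL_scaling_vector can be written with jnp.where as below. This sketch assumes the JAX version returns the same pair (v, dv); the name cl_scaling_vector_sketch is hypothetical.

```python
import jax.numpy as jnp


def cl_scaling_vector_sketch(x, g, lb, ub):
    """Return (v, dv): distance to the bound the gradient points to, and its derivative."""
    toward_upper = (g < 0) & jnp.isfinite(ub)
    toward_lower = (g > 0) & jnp.isfinite(lb)
    v = jnp.where(toward_upper, ub - x, jnp.where(toward_lower, x - lb, 1.0))
    dv = jnp.where(toward_upper, -1.0, jnp.where(toward_lower, 1.0, 0.0))
    return v, dv


x = jnp.array([1.0, 2.0, 3.0, 1.0])
g = jnp.array([-1.0, 1.0, 0.0, -1.0])
lb = jnp.array([0.0, 0.0, -jnp.inf, 0.0])
ub = jnp.array([4.0, jnp.inf, jnp.inf, jnp.inf])

v, dv = cl_scaling_vector_sketch(x, g, lb, ub)
```

Components whose relevant bound is infinite, or whose gradient component is zero, get v = 1 and dv = 0.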
- nlsq.common_jax.in_bounds_jax(x, lb, ub)[source]
JIT-compiled bounds check (replaces common_scipy.in_bounds).
Returns a JAX scalar boolean, avoiding GPU→CPU sync until Python checks it.
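A check of this kind reduces to a single fused comparison. The sketch below is an illustration with the hypothetical name in_bounds_sketch, not the library's code.

```python
import jax
import jax.numpy as jnp


@jax.jit
def in_bounds_sketch(x, lb, ub):
    """True if every component of x lies in [lb, ub]; returned as a 0-d bool array."""
    return jnp.all((x >= lb) & (x <= ub))


lb = jnp.array([0.0, 0.0])
ub = jnp.array([1.0, 1.0])

inside = in_bounds_sketch(jnp.array([0.5, 1.0]), lb, ub)
outside = in_bounds_sketch(jnp.array([0.5, 1.5]), lb, ub)

# bool(inside) is the point where the value is pulled back to Python.
```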
- nlsq.common_jax.make_strictly_feasible_jax(x, lb, ub)[source]
JIT-compiled strict feasibility projection for rstep=0 case.
Shifts boundary points to the interior using jnp.nextafter, with a guard against XLA’s DAZ (Denormals-Are-Zero) behavior on CPU where nextafter(0, x) produces a denormal that compares equal to zero. Only supports rstep=0 (the only value used in trf.py hot path).
- nlsq.common_jax.step_size_to_bound_jax(x, s, lb, ub)[source]
Compute min step size to reach a bound (JAX version).
- Parameters:
x (jnp.ndarray) – Current parameter vector.
s (jnp.ndarray) – Step direction.
lb (jnp.ndarray) – Lower bounds.
ub (jnp.ndarray) – Upper bounds.
- Returns:
min_step (jnp.ndarray) – Scalar minimum step to reach any bound.
hits (jnp.ndarray) – Array indicating which bounds are hit (-1=lower, 0=none, 1=upper).
- Return type:
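The computation can be sketched as follows, assuming it mirrors SciPy's step_size_to_bound. Zero components of s are masked before dividing so that no inf or nan leaks out of the untaken branch; the name step_size_to_bound_sketch is hypothetical.

```python
import jax.numpy as jnp


def step_size_to_bound_sketch(x, s, lb, ub):
    """Return (min_step, hits) for the ray x + t*s, t >= 0, inside [lb, ub]."""
    nonzero = s != 0
    safe_s = jnp.where(nonzero, s, 1.0)
    steps = jnp.where(
        nonzero,
        jnp.maximum((lb - x) / safe_s, (ub - x) / safe_s),
        jnp.inf,  # a zero component never reaches a bound
    )
    min_step = jnp.min(steps)
    hits = jnp.where(steps == min_step, jnp.sign(s), 0.0).astype(jnp.int32)
    return min_step, hits


x = jnp.array([0.5, 0.5])
s = jnp.array([1.0, -2.0])
lb = jnp.zeros(2)
ub = jnp.ones(2)

min_step, hits = step_size_to_bound_sketch(x, s, lb, ub)
```

In this example the second component moves toward its lower bound twice as fast as the first moves toward its upper bound, so the lower bound of the second component is hit first.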
- nlsq.common_jax.intersect_trust_region_jax(x, s, Delta)[source]
Find intersection of a line with trust region boundary (JAX version).
Solves ||(x + s*t)||^2 = Delta^2 for t.
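Expanding the norm gives a quadratic in t, which can be solved in a numerically stable way as sketched below. The sketch assumes x lies strictly inside the trust region and s is nonzero, and it returns the negative and positive roots in that order, following SciPy's convention; the name intersect_trust_region_sketch is hypothetical.

```python
import jax.numpy as jnp


def intersect_trust_region_sketch(x, s, Delta):
    """Return (t_neg, t_pos) with ||x + s*t|| = Delta.

    Assumes ||x|| < Delta and s != 0, so the discriminant is positive.
    """
    a = s @ s
    b = x @ s
    c = x @ x - Delta**2
    d = jnp.sqrt(b * b - a * c)
    # Add terms of equal sign to avoid cancellation.
    q = -(b + jnp.copysign(d, b))
    t1 = q / a
    t2 = c / q
    return jnp.minimum(t1, t2), jnp.maximum(t1, t2)


x = jnp.array([0.5, 0.0])
s = jnp.array([1.0, 0.0])
t_neg, t_pos = intersect_trust_region_sketch(x, s, 1.0)
```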
- nlsq.common_jax.minimize_quadratic_1d_jax(a, b, lb, ub, c=0.0)[source]
Minimize 1-D quadratic function subject to bounds (JAX version).
Minimizes f(t) = a*t^2 + b*t + c on [lb, ub].
- Parameters:
- Returns:
t_min (jnp.ndarray) – Minimizing point.
y_min (jnp.ndarray) – Minimum value.
- Return type:
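One branch-free way to do this is to evaluate the quadratic at both bounds and at the unconstrained extremum when it falls inside the interval, then take the smallest value. The sketch below illustrates that idea for scalar inputs; it is not the library's code, it casts inputs to float32 for simplicity, and the name minimize_quadratic_1d_sketch is hypothetical.

```python
import jax.numpy as jnp


def minimize_quadratic_1d_sketch(a, b, lb, ub, c=0.0):
    """Minimize a*t**2 + b*t + c on [lb, ub]; return (t_min, y_min)."""
    a, b, lb, ub = (jnp.asarray(v, dtype=jnp.float32) for v in (a, b, lb, ub))
    safe_a = jnp.where(a != 0, a, 1.0)
    t_ext = -0.5 * b / safe_a
    interior = (a != 0) & (lb < t_ext) & (t_ext < ub)
    # If the extremum is not usable, repeat lb as a harmless third candidate.
    t = jnp.stack([lb, ub, jnp.where(interior, t_ext, lb)])
    y = t * (a * t + b) + c
    i = jnp.argmin(y)
    return t[i], y[i]


# Extremum inside the interval.
t_in, y_in = minimize_quadratic_1d_sketch(1.0, -2.0, 0.0, 3.0)
# Extremum outside the interval: the minimum sits on a bound.
t_edge, y_edge = minimize_quadratic_1d_sketch(1.0, -2.0, 2.0, 3.0)
```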