nlsq.common_jax module

JAX-specific utilities and functions.

These functions were originally in the common.py file, but they involve large data operations and are therefore better suited to compilation with JAX. They are compiled with JAX and then added to the CommonJIT class.

nlsq.common_jax.phi_and_derivative_jax(alpha, suf, s, Delta)[source]

JAX-compiled phi function for trust region subproblem.

This function computes the value and derivative of the secular equation used to find the optimal Levenberg-Marquardt parameter. It is defined as “norm of regularized (by alpha) least-squares solution minus Delta”.

The function is used iteratively to find the root, which gives the optimal regularization parameter for the trust region subproblem.

Parameters:
  • alpha (float) – Current regularization parameter (Levenberg-Marquardt lambda)

  • suf (jnp.ndarray) – Product of singular values and U^T @ f (s * uf)

  • s (jnp.ndarray) – Singular values of the Jacobian

  • Delta (float) – Trust region radius

Returns:

  • phi (jnp.ndarray) – Value of the secular equation: ||p|| - Delta

  • phi_prime (jnp.ndarray) – Derivative of phi with respect to alpha

Return type:

tuple[Array, Array]

Notes

This is a JAX-jitted version of the function in common_scipy.py. Using JAX enables GPU acceleration and avoids NumPy-JAX data transfers when used with other JAX operations.

The computation follows [12] Branch, M.A., Coleman, T.F., Li, Y., “A Subspace, Interior, and Conjugate Gradient Method for Large-Scale Bound-Constrained Minimization Problems”, SIAM Journal on Scientific Computing, Vol. 21, Number 1, pp 1-23, 1999.
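As a rough illustration of the secular equation described above, here is a minimal NumPy sketch. The name `phi_and_derivative_ref` is hypothetical, and the formulas are an assumption modeled on the classic SciPy helper that this function is said to mirror; it is not the nlsq source.

```python
import numpy as np

def phi_and_derivative_ref(alpha, suf, s, Delta):
    """NumPy reference sketch (hypothetical name), not the nlsq implementation."""
    denom = s**2 + alpha
    # Norm of the regularized least-squares step, expressed in the SVD basis.
    p_norm = np.linalg.norm(suf / denom)
    phi = p_norm - Delta
    # Derivative of ||suf / (s**2 + alpha)|| with respect to alpha.
    phi_prime = -np.sum(suf**2 / denom**3) / p_norm
    return phi, phi_prime

s = np.array([2.0, 1.0])    # singular values
uf = np.array([1.0, 1.0])   # U.T @ f
phi, phi_prime = phi_and_derivative_ref(0.5, s * uf, s, Delta=0.5)
```

Because the step norm shrinks as alpha grows, phi_prime is negative, which is what lets a Newton-type iteration home in on the root.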

nlsq.common_jax.solve_lsq_trust_region_jax(n, m, uf, s, V, Delta, initial_alpha=None, rtol=0.01, max_iter=10)[source]

JAX-compiled trust-region problem solver for least-squares minimization.

This function implements a method described by J. J. Moré [12] and used in MINPACK, but relies on a single SVD of the Jacobian instead of a series of Cholesky decompositions. Before running this function, compute: U, s, VT = svd(J, full_matrices=False).

This is a pure JAX implementation using lax.while_loop for JIT compilation, avoiding NumPy-JAX data transfers when used in the optimization hot path.

Parameters:
  • n (int) – Number of variables.

  • m (int) – Number of residuals.

  • uf (jnp.ndarray) – Computed as U.T @ f.

  • s (jnp.ndarray) – Singular values of J.

  • V (jnp.ndarray) – Transpose of VT (i.e., V = VT.T).

  • Delta (float) – Radius of a trust region.

  • initial_alpha (float, optional) – Initial guess for alpha. If None, determined automatically.

  • rtol (float, optional) – Stopping tolerance for the root-finding procedure.

  • max_iter (int, optional) – Maximum allowed number of iterations.

Returns:

  • p (jnp.ndarray, shape (n,)) – Found solution of a trust-region problem.

  • alpha (jnp.ndarray) – Levenberg-Marquardt parameter (scalar as 0-d array).

  • n_iter (jnp.ndarray) – Number of iterations made (scalar as 0-d array).

Return type:

tuple[Array, Array, Array]
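The following NumPy sketch shows how the inputs might be prepared from a single SVD and how the safeguarded root-finding could proceed. It is an assumption modeled on the SciPy routine of the same name, written with a plain Python loop rather than lax.while_loop. The names `solve_lsq_trust_region_ref` and `_phi` are hypothetical, and the `initial_alpha` argument is omitted for brevity.

```python
import numpy as np

EPS = np.finfo(float).eps

def _phi(alpha, suf, s, Delta):
    # Secular equation value and derivative (hypothetical helper).
    denom = s**2 + alpha
    p_norm = np.linalg.norm(suf / denom)
    return p_norm - Delta, -np.sum(suf**2 / denom**3) / p_norm

def solve_lsq_trust_region_ref(n, m, uf, s, V, Delta, rtol=0.01, max_iter=10):
    """Hypothetical NumPy sketch; not the nlsq implementation."""
    suf = s * uf
    full_rank = m >= n and s[-1] > EPS * m * s[0]
    if full_rank:
        p = -V @ (uf / s)  # Gauss-Newton step
        if np.linalg.norm(p) <= Delta:
            return p, 0.0, 0
    alpha_upper = np.linalg.norm(suf) / Delta
    if full_rank:
        phi, phi_prime = _phi(0.0, suf, s, Delta)
        alpha_lower = -phi / phi_prime
    else:
        alpha_lower = 0.0
    alpha = max(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5)
    for it in range(max_iter):
        # Reset alpha if it left the bracketing interval.
        if alpha < alpha_lower or alpha > alpha_upper:
            alpha = max(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5)
        phi, phi_prime = _phi(alpha, suf, s, Delta)
        if phi < 0:
            alpha_upper = alpha
        ratio = phi / phi_prime
        alpha_lower = max(alpha_lower, alpha - ratio)
        alpha -= (phi + Delta) * ratio / Delta
        if abs(phi) < rtol * Delta:
            break
    p = -V @ (suf / (s**2 + alpha))
    p *= Delta / np.linalg.norm(p)  # place the step exactly on the boundary
    return p, alpha, it + 1

J = np.array([[2.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
f = np.array([1.0, 1.0, 1.0])
U, s, VT = np.linalg.svd(J, full_matrices=False)
p, alpha, n_iter = solve_lsq_trust_region_ref(2, 3, U.T @ f, s, VT.T, Delta=0.5)
```

When the Gauss-Newton step already fits inside the trust region, the sketch returns it with alpha equal to 0 and zero iterations.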

References

class nlsq.common_jax.CommonJIT[source]

Bases: object

JIT-compiled common functions for nonlinear least squares optimization.

This class provides GPU/TPU-accelerated implementations of mathematical operations commonly used across different optimization algorithms. All functions are JIT-compiled for maximum performance and memory efficiency.

Core Functionality

  • Quadratic Function Operations: Build and evaluate quadratic forms

  • Matrix-Vector Products: Optimized Jacobian operations

  • Robust Loss Scaling: Jacobian and residual scaling for robust methods

  • Numerical Utilities: Condition-aware computations with overflow protection

JIT Compilation Benefits

  • GPU/TPU Acceleration: All operations optimized for parallel hardware

  • Memory Efficiency: Reduced allocations through compilation optimization

  • Automatic Fusion: Operations automatically fused for better performance

  • Type Specialization: Functions compiled for specific array shapes/types

Mathematical Operations

The class implements several categories of operations:

  1. Quadratic Functions: For trust region subproblems

       • 1D quadratic parameterization along search directions

       • Quadratic model evaluation for step selection

       • Hessian approximation using J^T * J structure

  2. Matrix Operations: Optimized linear algebra

       • Jacobian-vector products with broadcasting

       • Scaling operations for robust loss functions

       • Condition-aware computations with numerical stability

  3. Robust Loss Support: For outlier-resistant fitting

       • Jacobian scaling using loss function derivatives

       • Residual weighting based on robustness weights

       • Numerical stability for extreme scaling factors

Performance Characteristics

  • Memory Usage: O(1) additional memory overhead

  • Compilation Time: One-time cost during initialization

  • Execution Speed: 10-100x faster than pure NumPy on GPU

  • Numerical Precision: Full double precision support

Technical Implementation

All functions use JAX JIT compilation with the following features:

  • Static argument handling for shape polymorphism

  • Automatic differentiation compatibility

  • XLA optimization for target hardware

  • Memory layout optimization for cache efficiency

Integration with Optimization Algorithms

This class is used by:

  • Trust Region Reflective (TRF) algorithm for quadratic models

  • Levenberg-Marquardt algorithm for step computation

  • Robust loss functions for residual scaling

  • Large dataset processing for memory-efficient operations

Usage Example

from nlsq.common_jax import CommonJIT
import jax.numpy as jnp

# Initialize JIT-compiled functions
cjit = CommonJIT()

# Example: Scale Jacobian for robust loss
jacobian = jnp.array([[1.0, 2.0], [3.0, 4.0]])
residuals = jnp.array([0.1, 2.5])  # Contains outlier

# rho stacks the loss value and its first and second derivatives.
# Illustrative values only: soft-L1 loss evaluated at z = residuals**2
# (assumed convention, as in SciPy).
z = residuals**2
rho = jnp.stack([2 * ((1 + z) ** 0.5 - 1), (1 + z) ** -0.5, -0.5 * (1 + z) ** -1.5])

# Apply robust scaling
scaled_J, scaled_f = cjit.scale_for_robust_loss_function(jacobian, residuals, rho)

# Example: Build quadratic model
gradient = jnp.array([0.1, -0.3])
direction = jnp.array([1.0, 0.5])
# The free term c is returned only when s0 is provided
a, b = cjit.build_quadratic_1d(jacobian, gradient, direction)

Numerical Considerations

  • Overflow Protection: Automatic handling of extreme scaling factors

  • Underflow Prevention: Minimum threshold enforcement (EPS)

  • Condition Monitoring: Numerical stability checks for ill-conditioned operations

  • Precision Control: Double precision arithmetic throughout

__init__()[source]

Initialize CommonJIT with all compiled functions.

This creates and compiles all JIT functions during initialization for optimal runtime performance. Functions are compiled once and reused across multiple optimization runs.

Compiled Functions Created

  • Quadratic function builders and evaluators

  • Matrix-vector dot products with broadcasting

  • Jacobian sum operations for constraint handling

  • Robust loss function scaling operations

create_scale_for_robust_loss_function()[source]

Create the scaling function for the robust loss functions.

build_quadratic_1d(J, g, s, diag=None, s0=None)[source]

Parameterize a multivariate quadratic function along a line.

The resulting univariate quadratic function is given as follows:

f(t) = 0.5 * (s0 + s*t).T * (J.T*J + diag) * (s0 + s*t) + g.T * (s0 + s*t)

Parameters:
  • J (ndarray, sparse matrix or LinearOperator, shape (m, n)) – Jacobian matrix, affects the quadratic term.

  • g (ndarray, shape (n,)) – Gradient, defines the linear term.

  • s (ndarray, shape (n,)) – Direction vector of a line.

  • diag (None or ndarray with shape (n,), optional) – Additional diagonal part, affects the quadratic term. If None, assumed to be 0.

  • s0 (None or ndarray with shape (n,), optional) – Initial point. If None, assumed to be 0.

Returns:

  • a (float) – Coefficient for t**2.

  • b (float) – Coefficient for t.

  • c (float) – Free term. Returned only if s0 is provided.

Return type:

tuple[ndarray, ndarray, ndarray] | tuple[ndarray, ndarray]
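To make the parameterization concrete, the sketch below expands f(t) into its coefficients with NumPy. The name `build_quadratic_1d_ref` is hypothetical and the code is a reference derivation from the formula above, not the JIT-compiled method itself.

```python
import numpy as np

def build_quadratic_1d_ref(J, g, s, diag=None, s0=None):
    """Hypothetical NumPy sketch of the coefficients of f(t) = a*t**2 + b*t + c."""
    v = J @ s
    a = v @ v
    if diag is not None:
        a += (s * diag) @ s
    a *= 0.5
    b = g @ s
    if s0 is None:
        return a, b  # no free term without an initial point
    u = J @ s0
    b += u @ v
    c = 0.5 * (u @ u) + g @ s0
    if diag is not None:
        b += (s0 * diag) @ s
        c += 0.5 * (s0 * diag) @ s0
    return a, b, c

J = np.array([[1.0, 2.0], [3.0, 4.0]])
g = np.array([0.1, -0.3])
s = np.array([1.0, 0.5])
s0 = np.array([0.2, -0.1])
diag = np.array([0.5, 0.25])
a, b, c = build_quadratic_1d_ref(J, g, s, diag=diag, s0=s0)
```

Evaluating a*t**2 + b*t + c should agree with the full quadratic form at s0 + s*t for any t, which is a convenient sanity check.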

compute_jac_scale(J, scale_inv_old=None)[source]

Compute variables scale based on the Jacobian matrix.

Returns JAX arrays to keep results on-device and avoid GPU→CPU sync.

Parameters:
  • J (jnp.ndarray) – Jacobian matrix.

  • scale_inv_old (jnp.ndarray | np.ndarray | None, optional) – Previous scale, by default None

Returns:

  • scale (jnp.ndarray) – Scale for the variables.

  • scale_inv (jnp.ndarray) – Inverse of the scale for the variables.

Return type:

tuple[Array, Array]
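A minimal NumPy sketch of one way such a scale could be computed is shown below. It assumes, following the SciPy helper of the same name, that the inverse scale is the column norm of J, with zero norms replaced by 1 on the first call and a running maximum afterwards. The name `compute_jac_scale_ref` is hypothetical.

```python
import numpy as np

def compute_jac_scale_ref(J, scale_inv_old=None):
    """Hypothetical NumPy sketch; column-norm scaling is an assumption."""
    scale_inv = np.sqrt(np.sum(J**2, axis=0))  # Euclidean norm of each column
    if scale_inv_old is None:
        # Avoid dividing by zero for all-zero columns.
        scale_inv = np.where(scale_inv == 0, 1.0, scale_inv)
    else:
        # Never let the inverse scale decrease between iterations.
        scale_inv = np.maximum(scale_inv, scale_inv_old)
    return 1.0 / scale_inv, scale_inv

J = np.array([[3.0, 0.0], [4.0, 0.0]])
scale, scale_inv = compute_jac_scale_ref(J)
```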

create_js_dot()[source]

Create the functions for the dot product of the Jacobian and the search direction. Two functions are needed because s and s0 have different shapes, and using a single function for both would cause it to be retraced.

evaluate_quadratic(J, g, s_np, diag=None)[source]

Compute values of a quadratic function arising in least squares. The function is 0.5 * s.T * (J.T * J + diag) * s + g.T * s.

Parameters:
  • J (ndarray, sparse matrix or LinearOperator, shape (m, n)) – Jacobian matrix, affects the quadratic term.

  • g (ndarray, shape (n,)) – Gradient, defines the linear term.

  • s_np (ndarray, shape (k, n) or (n,)) – Array containing steps as rows; this is s in the formula above.

  • diag (ndarray, shape (n,), optional) – Additional diagonal part, affects the quadratic term. If None, assumed to be 0.

Returns:

values – Values of the function. If s_np is 2-D, an ndarray is returned; otherwise, a float is returned.

Return type:

ndarray with shape (k,) or float
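The quadratic above can be evaluated without forming J.T * J explicitly. The NumPy sketch below illustrates this for both a single step and a batch of steps stored as rows; the name `evaluate_quadratic_ref` is hypothetical and the code is not the JIT-compiled method.

```python
import numpy as np

def evaluate_quadratic_ref(J, g, s, diag=None):
    """Hypothetical NumPy sketch of 0.5 * s.T (J.T J + diag) s + g.T s."""
    if s.ndim == 1:
        Js = J @ s
        q = Js @ Js
        if diag is not None:
            q += (s * diag) @ s
    else:
        Js = J @ s.T                   # shape (m, k): one column per step
        q = np.sum(Js**2, axis=0)
        if diag is not None:
            q += np.sum(diag * s**2, axis=1)
    return 0.5 * q + s @ g

J = np.array([[1.0, 2.0], [3.0, 4.0]])
g = np.array([0.1, -0.3])
single = evaluate_quadratic_ref(J, g, np.array([1.0, 0.5]))
batch = evaluate_quadratic_ref(J, g, np.array([[1.0, 0.5], [0.0, 0.0]]))
```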

create_quadratic_funcs()[source]

create_jac_sum()[source]

Create the function that sums the squared entries of the Jacobian and then takes the square root. This is used to compute the scale for the variables. This function could potentially be removed.

nlsq.common_jax.update_tr_radius_jax(Delta, actual_reduction, predicted_reduction, step_norm, bound_hit)[source]

JIT-compiled trust-region radius update (replaces common_scipy.update_tr_radius).

Uses branchless jnp.where to avoid Python control flow and GPU→CPU sync.

Parameters:
  • Delta (float) – Current trust region radius.

  • actual_reduction (float) – Actual cost reduction.

  • predicted_reduction (float) – Predicted cost reduction.

  • step_norm (float) – Norm of the step.

  • bound_hit (bool) – Whether the step hit the trust region boundary.

Returns:

  • Delta (jnp.ndarray) – Updated trust region radius.

  • ratio (jnp.ndarray) – Ratio of actual to predicted reduction.

Return type:

tuple[Array, Array]
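The branchless style mentioned above can be sketched with np.where, which has the same selection semantics as jnp.where. The thresholds (0.25 and 0.75) and the shrink and grow factors are assumptions taken from the SciPy routine this function is said to replace, and the name `update_tr_radius_ref` is hypothetical.

```python
import numpy as np

def update_tr_radius_ref(Delta, actual_reduction, predicted_reduction,
                         step_norm, bound_hit):
    """Hypothetical branchless sketch; constants are assumed from SciPy."""
    positive = predicted_reduction > 0
    # Guard the denominator: both np.where branches are always evaluated.
    safe_pred = np.where(positive, predicted_reduction, 1.0)
    both_zero = (predicted_reduction == 0) & (actual_reduction == 0)
    ratio = np.where(positive, actual_reduction / safe_pred,
                     np.where(both_zero, 1.0, 0.0))
    # Shrink on poor agreement, grow on good agreement at the boundary.
    new_delta = np.where(ratio < 0.25, 0.25 * step_norm,
                         np.where((ratio > 0.75) & bound_hit, 2.0 * Delta, Delta))
    return new_delta, ratio

grown, r_good = update_tr_radius_ref(1.0, 0.9, 1.0, 0.8, True)
shrunk, r_bad = update_tr_radius_ref(1.0, 0.1, 1.0, 0.8, False)
```

Since no Python `if` depends on array values, the same expression can be traced and compiled without forcing a device-to-host transfer.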

nlsq.common_jax.check_termination_jax(dF, F, dx_norm, x_norm, ratio, ftol, xtol)[source]

JIT-compiled termination check (replaces common_scipy.check_termination).

Returns integer status code as a JAX scalar (no GPU→CPU sync needed until the Python-level loop checks it). Returns 0 for “not terminated”.

Parameters:
  • dF (float) – Change in cost function.

  • F (float) – Current cost function value.

  • dx_norm (float) – Norm of the step.

  • x_norm (float) – Norm of the current parameters.

  • ratio (float) – Ratio of actual to predicted reduction.

  • ftol (float) – Cost function tolerance.

  • xtol (float) – Parameter tolerance.

Returns:

status – 0=not terminated, 2=ftol, 3=xtol, 4=both.

Return type:

jnp.ndarray
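The status codes listed above can be produced without Python branching. In the NumPy sketch below, the convergence criteria themselves are an assumption borrowed from the SciPy routine this function is said to replace; only the code values (0, 2, 3, 4) come from this page. The name `check_termination_ref` is hypothetical.

```python
import numpy as np

def check_termination_ref(dF, F, dx_norm, x_norm, ratio, ftol, xtol):
    """Hypothetical branchless sketch; criteria are assumed from SciPy."""
    ftol_ok = (dF < ftol * F) & (ratio > 0.25)
    xtol_ok = dx_norm < xtol * (xtol + x_norm)
    # 4 = both, 2 = ftol only, 3 = xtol only, 0 = not terminated.
    return np.where(ftol_ok & xtol_ok, 4,
                    np.where(ftol_ok, 2, np.where(xtol_ok, 3, 0)))

status_ftol = check_termination_ref(1e-12, 1.0, 1.0, 1.0, 0.9, 1e-8, 1e-8)
status_none = check_termination_ref(1.0, 1.0, 1.0, 1.0, 0.9, 1e-8, 1e-8)
```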

nlsq.common_jax.CL_scaling_vector_jax(x, g, lb, ub)[source]

JIT-compiled Coleman-Li scaling vector (replaces common_scipy.CL_scaling_vector).

Keeps result on-device to avoid GPU→CPU sync at call sites.
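For orientation, here is a NumPy sketch of a Coleman-Li scaling vector. It assumes the function returns the pair (v, dv) and uses the same rules as the SciPy helper it is said to replace; the return signature is not stated on this page, and the name `cl_scaling_vector_ref` is hypothetical.

```python
import numpy as np

def cl_scaling_vector_ref(x, g, lb, ub):
    """Hypothetical sketch; the (v, dv) return pair is an assumption."""
    # Components pushed toward a finite upper or lower bound by the gradient.
    to_upper = (g < 0) & np.isfinite(ub)
    to_lower = (g > 0) & np.isfinite(lb)
    # v is the distance to the relevant bound, or 1 when no bound applies.
    v = np.where(to_upper, ub - x, np.where(to_lower, x - lb, 1.0))
    # dv is the derivative of v with respect to x.
    dv = np.where(to_upper, -1.0, np.where(to_lower, 1.0, 0.0))
    return v, dv

x = np.array([0.5, 0.5, 0.5])
g = np.array([-1.0, 1.0, 0.0])
v, dv = cl_scaling_vector_ref(x, g, np.zeros(3), np.full(3, 2.0))
```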

nlsq.common_jax.in_bounds_jax(x, lb, ub)[source]

JIT-compiled bounds check (replaces common_scipy.in_bounds).

Returns a JAX scalar boolean, avoiding GPU→CPU sync until Python checks it.

nlsq.common_jax.make_strictly_feasible_jax(x, lb, ub)[source]

JIT-compiled strict feasibility projection for rstep=0 case.

Shifts boundary points to the interior using jnp.nextafter, with a guard against XLA’s DAZ (Denormals-Are-Zero) behavior on CPU where nextafter(0, x) produces a denormal that compares equal to zero. Only supports rstep=0 (the only value used in trf.py hot path).
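The rstep=0 projection can be sketched in NumPy as follows. This is a simplified illustration with a hypothetical name, `make_strictly_feasible_ref`; it does not reproduce the DAZ guard described above, and the midpoint fallback is an assumption taken from the SciPy helper of the same name.

```python
import numpy as np

def make_strictly_feasible_ref(x, lb, ub):
    """Hypothetical NumPy sketch of the rstep=0 case (no DAZ guard)."""
    # Move points on or below lb one float toward ub, and vice versa.
    x_new = np.where(x <= lb, np.nextafter(lb, ub), x)
    x_new = np.where(x >= ub, np.nextafter(ub, lb), x_new)
    # If the bounds are too tight for that to work, fall back to the midpoint.
    tight = (x_new < lb) | (x_new > ub)
    return np.where(tight, 0.5 * (lb + ub), x_new)

lb = np.array([1.0, 1.0, 1.0])
ub = np.array([2.0, 2.0, 2.0])
x_new = make_strictly_feasible_ref(np.array([1.0, 2.0, 1.5]), lb, ub)
```

Interior points are left untouched, while points on a bound move by a single representable float.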

nlsq.common_jax.step_size_to_bound_jax(x, s, lb, ub)[source]

Compute the minimum step size required to reach a bound (JAX version).

Parameters:
  • x (jnp.ndarray) – Current parameter vector.

  • s (jnp.ndarray) – Step direction.

  • lb (jnp.ndarray) – Lower bounds.

  • ub (jnp.ndarray) – Upper bounds.

Returns:

  • min_step (jnp.ndarray) – Scalar minimum step to reach any bound.

  • hits (jnp.ndarray) – Array indicating which bounds are hit (-1=lower, 0=none, 1=upper).

Return type:

tuple[Array, Array]
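A NumPy sketch of this computation is given below. It is modeled on the SciPy helper of the same name, which is an assumption about how the JAX version behaves, and the name `step_size_to_bound_ref` is hypothetical.

```python
import numpy as np

def step_size_to_bound_ref(x, s, lb, ub):
    """Hypothetical NumPy sketch; not the nlsq implementation."""
    nonzero = s != 0
    # Guard the division for components that do not move.
    safe_s = np.where(nonzero, s, 1.0)
    steps = np.where(
        nonzero,
        np.maximum((lb - x) / safe_s, (ub - x) / safe_s),
        np.inf,
    )
    min_step = np.min(steps)
    # -1 = lower bound hit, 1 = upper bound hit, 0 = no bound hit at min_step.
    hits = np.where(steps == min_step, np.sign(s), 0.0).astype(int)
    return min_step, hits

x = np.array([0.5, 0.5])
min_step, hits = step_size_to_bound_ref(
    x, np.array([1.0, -0.5]), np.zeros(2), np.ones(2)
)
```

In this usage, the first component reaches its upper bound after a step of 0.5, while the second would need a step of 1.0 to reach its lower bound.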

nlsq.common_jax.intersect_trust_region_jax(x, s, Delta)[source]

Find intersection of a line with trust region boundary (JAX version).

Solves ||(x + s*t)||^2 = Delta^2 for t.

Parameters:
  • x (jnp.ndarray) – Starting point.

  • s (jnp.ndarray) – Direction.

  • Delta (float) – Trust region radius.

Returns:

  • t_neg (jnp.ndarray) – Negative root.

  • t_pos (jnp.ndarray) – Positive root.

Return type:

tuple[Array, Array]
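Expanding the norm gives the quadratic (s·s) t^2 + 2 (x·s) t + (x·x - Delta^2) = 0. The sketch below solves it with a numerically stable root formula; it assumes s is nonzero and x lies inside the trust region, and the name `intersect_trust_region_ref` is hypothetical.

```python
import math
import numpy as np

def intersect_trust_region_ref(x, s, Delta):
    """Hypothetical sketch; assumes s != 0 and ||x|| <= Delta."""
    a = s @ s
    b = x @ s
    c = x @ x - Delta**2  # non-positive when x is inside the region
    d = np.sqrt(b * b - a * c)
    # Avoid cancellation: compute one root via q, the other via c / q.
    q = -(b + math.copysign(d, b))
    t1, t2 = q / a, c / q
    return min(t1, t2), max(t1, t2)

t_neg, t_pos = intersect_trust_region_ref(
    np.array([1.0, 0.0]), np.array([1.0, 0.0]), Delta=2.0
)
```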

nlsq.common_jax.minimize_quadratic_1d_jax(a, b, lb, ub, c=0.0)[source]

Minimize 1-D quadratic function subject to bounds (JAX version).

Minimizes f(t) = a*t^2 + b*t + c on [lb, ub].

Parameters:
  • a (jnp.ndarray) – Quadratic coefficient.

  • b (jnp.ndarray) – Linear coefficient.

  • lb (float) – Lower bound.

  • ub (float) – Upper bound.

  • c (jnp.ndarray) – Constant term (default 0).

Returns:

  • t_min (jnp.ndarray) – Minimizing point.

  • y_min (jnp.ndarray) – Minimum value.

Return type:

tuple[Array, Array]
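One way to carry out this minimization is to compare the two endpoints with the interior stationary point, when one exists. The NumPy sketch below does so without Python branching on the values; the name `minimize_quadratic_1d_ref` is hypothetical and the code is not the nlsq implementation.

```python
import numpy as np

def minimize_quadratic_1d_ref(a, b, lb, ub, c=0.0):
    """Hypothetical sketch: minimize a*t**2 + b*t + c on [lb, ub]."""
    safe_a = np.where(a != 0, a, 1.0)  # guard the division when a == 0
    extremum = -0.5 * b / safe_a
    interior = (a != 0) & (lb < extremum) & (extremum < ub)
    # Candidates: both endpoints, plus the stationary point if it is interior
    # (otherwise lb is repeated so the extra slot cannot change the result).
    t = np.array([lb, ub, np.where(interior, extremum, lb)])
    y = t * (a * t + b) + c
    i = np.argmin(y)
    return t[i], y[i]

t_min, y_min = minimize_quadratic_1d_ref(1.0, -2.0, 0.0, 3.0)
```

For a concave quadratic (a < 0) the stationary point is a maximum, so the comparison correctly falls back to one of the endpoints.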