"""
These functions were originally in the common.py file, but they perform
large data operations and are therefore better suited to compilation with
JAX. They are JIT-compiled and exposed either as module-level functions or
through the CommonJIT class.
"""
import numpy as np
# Initialize JAX configuration through central config
from nlsq.config import JAXConfig
_jax_config = JAXConfig()
import jax.numpy as jnp
from jax import jit, lax
EPS = np.finfo(float).eps
@jit
def phi_and_derivative_jax(
    alpha: float, suf: jnp.ndarray, s: jnp.ndarray, Delta: float
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate the trust-region secular function and its derivative.

    The secular function is ``phi(alpha) = ||suf / (s**2 + alpha)|| - Delta``:
    the norm of the least-squares solution regularized by ``alpha`` minus
    the trust-region radius. A root finder calls this function repeatedly;
    the root in ``alpha`` is the Levenberg-Marquardt parameter of the
    trust-region subproblem.

    Parameters
    ----------
    alpha : float
        Current regularization parameter (Levenberg-Marquardt lambda).
    suf : jnp.ndarray
        Element-wise product of the singular values and ``U.T @ f``
        (``s * uf``).
    s : jnp.ndarray
        Singular values of the Jacobian.
    Delta : float
        Trust-region radius.

    Returns
    -------
    phi : jnp.ndarray
        Value of the secular function, ``||p|| - Delta``.
    phi_prime : jnp.ndarray
        Derivative of ``phi`` with respect to ``alpha``.

    Notes
    -----
    This is a JAX-jitted version of the function in common_scipy.py.
    Using JAX enables GPU acceleration and avoids NumPy-JAX data transfers
    when used with other JAX operations.

    When the step norm is exactly zero, the derivative is divided by the
    smallest positive normal float of the working dtype instead, so the
    result is a (signed) zero rather than NaN.

    The computation follows [12] Branch, M.A., Coleman, T.F., Li, Y.,
    "A Subspace, Interior, and Conjugate Gradient Method for Large-Scale
    Bound-Constrained Minimization Problems", SIAM Journal on Scientific
    Computing, Vol. 21, Number 1, pp 1-23, 1999.
    """
    shifted_sq = s**2 + alpha
    step_norm = jnp.linalg.norm(suf / shifted_sq)

    # Guard the division below against an exactly-zero step norm.
    divisor = jnp.where(
        step_norm == 0.0, jnp.finfo(step_norm.dtype).tiny, step_norm
    )
    slope = -jnp.sum(suf**2 / shifted_sq**3) / divisor

    return step_norm - Delta, slope
def _solve_lsq_trust_region_jax_impl(
    n: int,
    m: int,
    uf: jnp.ndarray,
    s: jnp.ndarray,
    V: jnp.ndarray,
    Delta: float,
    initial_alpha: float,
    has_initial_alpha: bool,
    rtol: float,
    max_iter: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Traceable core of ``solve_lsq_trust_region_jax``.

    Both candidate steps are always computed: the Gauss-Newton step and the
    step obtained by root-finding on the secular function. The one to
    return is selected at the end with ``lax.cond``, which keeps the whole
    routine free of Python branching on traced values.

    Parameters
    ----------
    n, m : int
        Number of variables and of residuals. They are compared as plain
        Python integers to decide whether full rank is possible at all.
    uf : jnp.ndarray
        ``U.T @ f``.
    s : jnp.ndarray
        Singular values of the Jacobian; ``s[0]`` and ``s[-1]`` are used as
        the largest and the smallest one.
    V : jnp.ndarray
        Right singular vectors stored as columns.
    Delta : float
        Trust-region radius.
    initial_alpha : float
        Starting guess for the Levenberg-Marquardt parameter. Ignored when
        ``has_initial_alpha`` is false.
    has_initial_alpha : bool
        Whether ``initial_alpha`` carries a caller-supplied value.
    rtol : float
        The iteration stops once ``|phi| < rtol * Delta``.
    max_iter : int
        Upper bound on the number of root-finding iterations.

    Returns
    -------
    p : jnp.ndarray
        Selected step.
    alpha : jnp.ndarray
        Levenberg-Marquardt parameter; ``0.0`` for the Gauss-Newton step.
    n_iter : jnp.ndarray
        Iterations performed; ``0`` for the Gauss-Newton step.
    """
    suf = s * uf

    def _geometric_guess(lower, upper):
        # Geometric mean of the bracket, floored at 0.1% of its upper end.
        return jnp.maximum(0.001 * upper, jnp.sqrt(lower * upper))

    def _nonzero(value):
        # Swap an exact zero for the smallest normal float so it can divide.
        return jnp.where(value == 0.0, jnp.finfo(value.dtype).tiny, value)

    # The Jacobian counts as full rank when there are at least as many
    # residuals as variables and the smallest singular value exceeds a
    # threshold proportional to the largest one.
    rank_threshold = EPS * m * s[0]
    full_rank = lax.cond(
        m >= n,
        lambda: s[-1] > rank_threshold,
        lambda: False,
    )

    # Gauss-Newton candidate: accepted only if it lies inside the region.
    gn_step = -V @ (uf / s)
    gn_accepted = full_rank & (jnp.linalg.norm(gn_step) <= Delta)

    # Initial bracket [lower0, upper0] for alpha.
    radius_floor = jnp.finfo(jnp.result_type(Delta)).tiny
    upper0 = jnp.linalg.norm(suf) / jnp.maximum(Delta, radius_floor)

    def _newton_lower_bound():
        phi0, dphi0 = phi_and_derivative_jax(0.0, suf, s, Delta)
        return -phi0 / _nonzero(dphi0)

    lower0 = lax.cond(
        full_rank,
        _newton_lower_bound,
        lambda: 0.0,
    )

    # A supplied initial alpha is honoured unless it is exactly zero while
    # the Jacobian is rank deficient; otherwise fall back to the default.
    fallback_alpha = _geometric_guess(lower0, upper0)
    trust_initial = has_initial_alpha & (full_rank | (initial_alpha != 0.0))
    alpha0 = lax.cond(
        trust_initial,
        lambda: initial_alpha,
        lambda: fallback_alpha,
    )

    # Loop carry: (alpha, lower bound, upper bound, iteration count, done).
    def _keep_iterating(carry):
        _, _, _, step_count, done = carry
        return (step_count < max_iter) & ~done

    def _newton_update(carry):
        alpha_in, lower, upper, step_count, _ = carry

        # Restart from the default guess if alpha has left the bracket.
        alpha = lax.cond(
            (alpha_in < lower) | (alpha_in > upper),
            lambda: _geometric_guess(lower, upper),
            lambda: alpha_in,
        )

        phi, dphi = phi_and_derivative_jax(alpha, suf, s, Delta)

        # A negative phi means alpha is already too large: tighten the top.
        upper_next = lax.cond(
            phi < 0,
            lambda: alpha,
            lambda: upper,
        )

        newton_ratio = phi / _nonzero(dphi)
        lower_next = jnp.maximum(lower, alpha - newton_ratio)
        alpha_next = alpha - (phi + Delta) * newton_ratio / Delta

        # Convergence is judged on the phi evaluated before the update,
        # while the carried alpha is the updated one.
        done = jnp.abs(phi) < rtol * Delta
        return (alpha_next, lower_next, upper_next, step_count + 1, done)

    start = (alpha0, lower0, upper0, jnp.array(0), jnp.array(False))
    alpha_root, _, _, steps_taken, _ = lax.while_loop(
        _keep_iterating, _newton_update, start
    )

    # Regularized step, rescaled so that its norm equals Delta exactly.
    lm_step = -V @ (suf / (s**2 + alpha_root))
    lm_norm = jnp.maximum(
        jnp.linalg.norm(lm_step), jnp.finfo(lm_step.dtype).tiny
    )
    lm_step = lm_step * (Delta / lm_norm)

    # Pick the Gauss-Newton result when it was accepted, else the
    # iterative one; all three outputs switch together.
    return lax.cond(
        gn_accepted,
        lambda: (gn_step, jnp.array(0.0), jnp.array(0)),
        lambda: (lm_step, alpha_root, steps_taken),
    )
# JIT-compiled version of the implementation above. Positional arguments
# 0, 1, 8 and 9 (n, m, rtol, max_iter) are declared static, so they reach
# the implementation as ordinary Python values rather than traced arrays;
# the remaining arguments, including has_initial_alpha, are traced.
_solve_lsq_trust_region_jax_jit = jit(
    _solve_lsq_trust_region_jax_impl, static_argnums=(0, 1, 8, 9)
)
def solve_lsq_trust_region_jax(
    n: int,
    m: int,
    uf: jnp.ndarray,
    s: jnp.ndarray,
    V: jnp.ndarray,
    Delta: float,
    initial_alpha: float | None = None,
    rtol: float = 0.01,
    max_iter: int = 10,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Solve a trust-region subproblem arising in least-squares minimization.

    Implements the method described by J. J. More [12] and used in MINPACK,
    but works from a single SVD of the Jacobian rather than a series of
    Cholesky decompositions. The caller is expected to have computed
    ``U, s, VT = svd(J, full_matrices=False)`` beforehand.

    The work is delegated to a jitted implementation built on
    ``lax.while_loop``, which avoids NumPy-JAX data transfers when this
    routine is used in the optimization hot path. This wrapper only turns
    the optional ``initial_alpha`` into a (value, flag) pair, because
    ``None`` cannot be passed through as a traced argument.

    Parameters
    ----------
    n : int
        Number of variables.
    m : int
        Number of residuals.
    uf : jnp.ndarray
        Computed as U.T @ f.
    s : jnp.ndarray
        Singular values of J.
    V : jnp.ndarray
        Transpose of VT (i.e., V = VT.T).
    Delta : float
        Radius of a trust region.
    initial_alpha : float, optional
        Initial guess for alpha. If None, determined automatically.
    rtol : float, optional
        Stopping tolerance for the root-finding procedure.
    max_iter : int, optional
        Maximum allowed number of iterations.

    Returns
    -------
    p : jnp.ndarray, shape (n,)
        Found solution of a trust-region problem.
    alpha : jnp.ndarray
        Levenberg-Marquardt parameter (scalar as 0-d array).
    n_iter : jnp.ndarray
        Number of iterations made (scalar as 0-d array).

    References
    ----------
    .. [12] More, J. J., "The Levenberg-Marquardt Algorithm: Implementation
        and Theory," Numerical Analysis, ed. G. A. Watson, Lecture Notes
        in Mathematics 630, Springer Verlag, pp. 105-116, 1977.
    """
    # A missing guess is encoded as 0.0 together with a False flag.
    has_initial_alpha = initial_alpha is not None
    init_alpha_val = initial_alpha if has_initial_alpha else 0.0

    return _solve_lsq_trust_region_jax_jit(
        n, m, uf, s, V, Delta, init_alpha_val, has_initial_alpha, rtol, max_iter
    )
[docs]
class CommonJIT:
"""JIT-compiled common functions for nonlinear least squares optimization.
This class provides GPU/TPU-accelerated implementations of mathematical
operations commonly used across different optimization algorithms. All
functions are JIT-compiled for maximum performance and memory efficiency.
Core Functionality
------------------
- **Quadratic Function Operations**: Build and evaluate quadratic forms
- **Matrix-Vector Products**: Optimized Jacobian operations
- **Robust Loss Scaling**: Jacobian and residual scaling for robust methods
- **Numerical Utilities**: Condition-aware computations with overflow protection
JIT Compilation Benefits
-------------------------
- **GPU/TPU Acceleration**: All operations optimized for parallel hardware
- **Memory Efficiency**: Reduced allocations through compilation optimization
- **Automatic Fusion**: Operations automatically fused for better performance
- **Type Specialization**: Functions compiled for specific array shapes/types
Mathematical Operations
------------------------
The class implements several categories of operations:
1. **Quadratic Functions**: For trust region subproblems
- 1D quadratic parameterization along search directions
- Quadratic model evaluation for step selection
- Hessian approximation using J^T * J structure
2. **Matrix Operations**: Optimized linear algebra
- Jacobian-vector products with broadcasting
- Scaling operations for robust loss functions
- Condition-aware computations with numerical stability
3. **Robust Loss Support**: For outlier-resistant fitting
- Jacobian scaling using loss function derivatives
- Residual weighting based on robustness weights
- Numerical stability for extreme scaling factors
Performance Characteristics
----------------------------
- **Memory Usage**: O(1) additional memory overhead
- **Compilation Time**: One-time cost during initialization
- **Execution Speed**: 10-100x faster than pure NumPy on GPU
- **Numerical Precision**: Full double precision support
Technical Implementation
------------------------
All functions use JAX JIT compilation with the following features:
- Static argument handling for shape polymorphism
- Automatic differentiation compatibility
- XLA optimization for target hardware
- Memory layout optimization for cache efficiency
Integration with Optimization Algorithms
-----------------------------------------
This class is used by:
- Trust Region Reflective (TRF) algorithm for quadratic models
- Levenberg-Marquardt algorithm for step computation
- Robust loss functions for residual scaling
- Large dataset processing for memory-efficient operations
Usage Example
-------------
::
from nlsq.common_jax import CommonJIT
import jax.numpy as jnp
# Initialize JIT-compiled functions
cjit = CommonJIT()
# Example: Scale Jacobian for robust loss
jacobian = jnp.array([[1.0, 2.0], [3.0, 4.0]])
residuals = jnp.array([0.1, 2.5]) # Contains outlier
rho = jnp.array([loss_val, loss_deriv1, loss_deriv2])
# Apply robust scaling
scaled_J, scaled_f = cjit.scale_for_robust_loss_function(jacobian, residuals, rho)
# Example: Build quadratic model
gradient = jnp.array([0.1, -0.3])
direction = jnp.array([1.0, 0.5])
a, b, c = cjit.build_quadratic_1d(jacobian, gradient, direction)
Numerical Considerations
-------------------------
- **Overflow Protection**: Automatic handling of extreme scaling factors
- **Underflow Prevention**: Minimum threshold enforcement (EPS)
- **Condition Monitoring**: Numerical stability checks for ill-conditioned operations
- **Precision Control**: Double precision arithmetic throughout
"""
[docs]
def __init__(self):
"""Initialize CommonJIT with all compiled functions.
This creates and compiles all JIT functions during initialization
for optimal runtime performance. Functions are compiled once and
reused across multiple optimization runs.
Compiled Functions Created
---------------------------
- Quadratic function builders and evaluators
- Matrix-vector dot products with broadcasting
- Jacobian sum operations for constraint handling
- Robust loss function scaling operations
"""
self.create_quadratic_funcs()
self.create_js_dot()
self.create_jac_sum()
self.create_scale_for_robust_loss_function()
[docs]
def create_scale_for_robust_loss_function(self):
"""Create the scaling function for the loss functions"""
@jit
def scale_for_robust_loss_function(
J: jnp.ndarray, f: jnp.ndarray, rho: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Scale Jacobian and residuals for a robust loss function.
Arrays are modified in place.
Parameters
----------
J : jnp.ndarray
Jacobian matrix.
f : jnp.ndarray
Residuals.
rho : jnp.ndarray
Cost function evaluation.
"""
# Scale Jacobian and residuals for robust loss function
J_scale = rho[1] + 2 * rho[2] * f**2
# Prevent division by zero
mask = J_scale < EPS
J_scale = jnp.where(mask, EPS, J_scale)
J_scale = J_scale**0.5
# Compute scaling factors
fscale = rho[1] / J_scale
# Apply scaling
f = f * fscale
J = J * J_scale[:, jnp.newaxis]
return J, f
self.scale_for_robust_loss_function = scale_for_robust_loss_function
[docs]
def build_quadratic_1d(
self,
J: jnp.ndarray,
g: jnp.ndarray,
s: jnp.ndarray,
diag: jnp.ndarray | None = None,
s0: jnp.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray]:
"""Parameterize a multivariate quadratic function along a line.
The resulting univariate quadratic function is given as follows:
f(t) = 0.5 * (s0 + s*t).T * (J.T*J + diag) * (s0 + s*t) + g.T * (s0 + s*t)
Parameters
----------
J : ndarray, sparse matrix or LinearOperator, shape (m, n)
Jacobian matrix, affects the quadratic term.
g : ndarray, shape (n,)
Gradient, defines the linear term.
s : ndarray, shape (n,)
Direction vector of a line.
diag : None or ndarray with shape (n,), optional
Addition diagonal part, affects the quadratic term.
If None, assumed to be 0.
s0 : None or ndarray with shape (n,), optional
Initial point. If None, assumed to be 0.
Returns
-------
a : float
Coefficient for t**2.
b : float
Coefficient for t.
c : float
Free term. Returned only if `s0` is provided.
"""
# OPT-2: Use jnp.asarray() to avoid copy if already JAX array
s_jnp = jnp.asarray(s)
v = self.js_dot(J, s_jnp)
a = np.dot(v, v)
if diag is not None:
a += np.dot(s * diag, s)
a *= 0.5
b = np.dot(g, s)
if s0 is not None:
# OPT-2: Use jnp.asarray() to avoid copy if already JAX array
s0_jnp = jnp.asarray(s0)
u = self.js0_dot(J, s0_jnp)
b += np.dot(u, v)
c = 0.5 * np.dot(u, u) + np.dot(g, s0)
if diag is not None:
b += np.dot(s0 * diag, s)
c += 0.5 * np.dot(s0 * diag, s0)
return a, b, c
else:
return a, b
[docs]
def compute_jac_scale(
self, J: jnp.ndarray, scale_inv_old: jnp.ndarray | np.ndarray | None = None
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Compute variables scale based on the Jacobian matrix.
Returns JAX arrays to keep results on-device and avoid GPU→CPU sync.
Parameters
----------
J : jnp.ndarray
Jacobian matrix.
scale_inv_old : jnp.ndarray | np.ndarray | None, optional
Previous scale, by default None
Returns
-------
scale : jnp.ndarray
Scale for the variables.
scale_inv : jnp.ndarray
Inverse of the scale for the variables.
"""
scale_inv = self.jac_sum_func(J)
if scale_inv_old is None:
scale_inv = jnp.where(scale_inv == 0, 1.0, scale_inv)
else:
scale_inv = jnp.maximum(scale_inv, scale_inv_old)
return 1 / scale_inv, scale_inv
[docs]
def create_js_dot(self):
"""Create the functions for the dot product of the Jacobian and the
search direction. We need two functions because s and s0 are different
shapes which causes retracing of the function if we use the same
function for both.
"""
@jit
def js_dot(J: jnp.ndarray, s: jnp.ndarray) -> jnp.ndarray:
return J.dot(s)
@jit
def js0_dot(J: jnp.ndarray, s0: jnp.ndarray) -> jnp.ndarray:
return J.dot(s0)
self.js_dot = js_dot
self.js0_dot = js0_dot
[docs]
def evaluate_quadratic(
self,
J: jnp.ndarray,
g: jnp.ndarray,
s_np: np.ndarray,
diag: np.ndarray | None = None,
) -> jnp.ndarray:
"""Compute values of a quadratic function arising in least squares.
The function is 0.5 * s.T * (J.T * J + diag) * s + g.T * s.
Parameters
----------
J : ndarray, sparse matrix or LinearOperator, shape (m, n)
Jacobian matrix, affects the quadratic term.
g : ndarray, shape (n,)
Gradient, defines the linear term.
s : ndarray, shape (k, n) or (n,)
Array containing steps as rows.
diag : ndarray, shape (n,), optional
Addition diagonal part, affects the quadratic term.
If None, assumed to be 0.
Returns
-------
values : ndarray with shape (k,) or float
Values of the function. If `s` was 2-D, then ndarray is
returned, otherwise, float is returned.
"""
# OPT-2: Use jnp.asarray() to avoid copy if already JAX array
s = jnp.asarray(s_np)
if s.ndim == 1:
if diag is None:
return self.evaluate_quadratic1(J, g, s)
else:
return self.evaluate_quadratic_diagonal1(J, g, s, diag)
elif diag is None:
return self.evaluate_quadratic2(J, g, s)
else:
return self.evaluate_quadratic_diagonal2(J, g, s, diag)
[docs]
def create_quadratic_funcs(self):
@jit
def evaluate_quadratic1(J, g, s):
Js = J.dot(s)
q = jnp.dot(Js, Js)
l = jnp.dot(s, g)
return 0.5 * q + l
@jit
def evaluate_quadratic_diagonal1(J, g, s, diag):
Js = J.dot(s)
q = jnp.dot(Js, Js) + jnp.dot(s * diag, s)
l = jnp.dot(s, g)
return 0.5 * q + l
@jit
def evaluate_quadratic2(J, g, s):
Js = J.dot(s.T)
q = jnp.sum(Js**2, axis=0)
l = jnp.dot(s, g)
return 0.5 * q + l
@jit
def evaluate_quadratic_diagonal2(J, g, s, diag):
Js = J.dot(s.T)
q = jnp.sum(Js**2, axis=0) + jnp.sum(diag * s**2, axis=1)
l = jnp.dot(s, g)
return 0.5 * q + l
self.evaluate_quadratic1 = evaluate_quadratic1
self.evaluate_quadratic_diagonal1 = evaluate_quadratic_diagonal1
self.evaluate_quadratic2 = evaluate_quadratic2
self.evaluate_quadratic_diagonal2 = evaluate_quadratic_diagonal2
[docs]
def create_jac_sum(self):
"""Create the function for the sum of the Jacobian squared and then
taking the square root. This is used to compute the scale for the
variables. Can potentially remove this.
"""
@jit
def jac_sum_func(J):
return jnp.sum(J**2, axis=0) ** 0.5
self.jac_sum_func = jac_sum_func
# ---------------------------------------------------------------------------
# JIT-compiled versions of trust-region helper functions (B002, B003 fixes)
# These eliminate GPU→CPU sync overhead by keeping all operations on device.
# ---------------------------------------------------------------------------
@jit
def update_tr_radius_jax(
    Delta: float,
    actual_reduction: float,
    predicted_reduction: float,
    step_norm: float,
    bound_hit: bool,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """JIT-compiled trust-region radius update (replaces common_scipy.update_tr_radius).

    Every decision is expressed with ``jnp.where`` instead of Python
    branching, so the function can be traced and needs no GPU→CPU sync.

    Parameters
    ----------
    Delta : float
        Current trust region radius.
    actual_reduction : float
        Actual cost reduction.
    predicted_reduction : float
        Predicted cost reduction.
    step_norm : float
        Norm of the step.
    bound_hit : bool
        Whether the step hit the trust region boundary.

    Returns
    -------
    Delta : jnp.ndarray
        Updated trust region radius: ``0.25 * step_norm`` when the ratio is
        below 0.25, doubled when the ratio exceeds 0.75 and the boundary
        was hit, unchanged otherwise.
    ratio : jnp.ndarray
        Ratio of actual to predicted reduction. It is 1 when both
        reductions are zero and 0 for any other non-positive prediction.
    """
    # Ratio used when the predicted reduction is not positive.
    both_zero = (predicted_reduction == 0) & (actual_reduction == 0)
    fallback_ratio = jnp.where(both_zero, 1.0, 0.0)

    # Divide by 1 instead of 0; that quotient is discarded by the select.
    divisor = jnp.where(predicted_reduction == 0, 1.0, predicted_reduction)
    ratio = jnp.where(
        predicted_reduction > 0,
        actual_reduction / divisor,
        fallback_ratio,
    )

    kept_or_grown = jnp.where((ratio > 0.75) & bound_hit, Delta * 2.0, Delta)
    Delta_new = jnp.where(ratio < 0.25, 0.25 * step_norm, kept_or_grown)
    return Delta_new, ratio
@jit
def check_termination_jax(
    dF: float,
    F: float,
    dx_norm: float,
    x_norm: float,
    ratio: float,
    ftol: float,
    xtol: float,
) -> jnp.ndarray:
    """JIT-compiled termination check (replaces common_scipy.check_termination).

    The status is returned as a JAX scalar, so no GPU→CPU sync is needed
    until the Python-level loop inspects it. "Not terminated" is encoded
    as 0.

    Parameters
    ----------
    dF : float
        Change in cost function.
    F : float
        Current cost function value.
    dx_norm : float
        Norm of the step.
    x_norm : float
        Norm of the current parameters.
    ratio : float
        Ratio of actual to predicted reduction.
    ftol : float
        Cost function tolerance.
    xtol : float
        Parameter tolerance.

    Returns
    -------
    status : jnp.ndarray
        0=not terminated, 2=ftol, 3=xtol, 4=both.
    """
    # The ftol criterion only counts for steps with ratio above 0.25.
    ftol_satisfied = (ratio > 0.25) & (dF < ftol * F)
    xtol_satisfied = dx_norm < xtol * (xtol + x_norm)

    # Build the status from the lowest-priority case upwards.
    xtol_only = jnp.where(xtol_satisfied, 3, 0)
    single_criterion = jnp.where(ftol_satisfied, 2, xtol_only)
    return jnp.where(ftol_satisfied & xtol_satisfied, 4, single_criterion)
# ---------------------------------------------------------------------------
# JIT-compiled bounds helper functions (S07-S10)
# These keep bounds operations on-device, avoiding NumPy→JAX transfers.
# The NumPy originals in common_scipy.py remain for functions that need them.
# ---------------------------------------------------------------------------
@jit
def CL_scaling_vector_jax(
    x: jnp.ndarray, g: jnp.ndarray, lb: jnp.ndarray, ub: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """JIT-compiled Coleman-Li scaling vector (replaces common_scipy.CL_scaling_vector).

    For each component:

    - negative gradient and finite upper bound: ``v = ub - x``, ``dv = -1``;
    - positive gradient and finite lower bound: ``v = x - lb``, ``dv = 1``;
    - otherwise: ``v = 1``, ``dv = 0``.

    Parameters
    ----------
    x : jnp.ndarray
        Current point.
    g : jnp.ndarray
        Gradient at ``x``.
    lb, ub : jnp.ndarray
        Lower and upper bounds; infinite entries mean "no bound".

    Returns
    -------
    v : jnp.ndarray
        Scaling vector.
    dv : jnp.ndarray
        Its derivative, with entries -1, 0 or 1.

    Notes
    -----
    Results stay on-device, avoiding a GPU→CPU sync at call sites.
    """
    heading_to_upper = (g < 0) & jnp.isfinite(ub)
    heading_to_lower = (g > 0) & jnp.isfinite(lb)

    v = jnp.where(
        heading_to_upper,
        ub - x,
        jnp.where(heading_to_lower, x - lb, jnp.ones_like(x)),
    )
    dv = jnp.where(
        heading_to_upper,
        -1.0,
        jnp.where(heading_to_lower, 1.0, jnp.zeros_like(x)),
    )
    return v, dv
@jit
def in_bounds_jax(x: jnp.ndarray, lb: jnp.ndarray, ub: jnp.ndarray) -> jnp.ndarray:
    """JIT-compiled bounds check (replaces common_scipy.in_bounds).

    Returns a JAX scalar boolean that is true when every component of ``x``
    satisfies ``lb <= x <= ub`` (both ends inclusive). Because the result is
    a JAX value, no GPU→CPU sync happens until Python inspects it.
    """
    not_below = x >= lb
    not_above = x <= ub
    return jnp.all(not_below & not_above)
@jit
def make_strictly_feasible_jax(
    x: jnp.ndarray, lb: jnp.ndarray, ub: jnp.ndarray
) -> jnp.ndarray:
    """JIT-compiled strict feasibility projection for the rstep=0 case.

    Components at or below ``lb`` are replaced by the next float after
    ``lb`` towards ``ub``. Components at or above ``ub`` are replaced by
    the next float after ``ub`` towards ``lb``. All other components are
    returned unchanged. If a shifted component still violates a bound, it
    is set to the midpoint of the two bounds.

    Includes a guard against XLA's DAZ (Denormals-Are-Zero) behavior on
    CPU, where ``nextafter(0, x)`` produces a denormal that compares equal
    to zero. Whenever the nudged value is not strictly past the bound, the
    smallest normal float is used as the offset instead.

    Only rstep=0 is supported (the only value used in the trf.py hot path).
    """
    smallest_normal = jnp.finfo(x.dtype).tiny

    # Candidate just inside the lower bound, with the DAZ fallback.
    just_above_lb = jnp.nextafter(lb, ub)
    just_above_lb = jnp.where(
        just_above_lb <= lb, lb + smallest_normal, just_above_lb
    )

    # Candidate just inside the upper bound, with the DAZ fallback.
    just_below_ub = jnp.nextafter(ub, lb)
    just_below_ub = jnp.where(
        just_below_ub >= ub, ub - smallest_normal, just_below_ub
    )

    # Active-constraint detection with zero tolerance; the lower bound
    # takes precedence when both masks are set.
    at_lower = x <= lb
    at_upper = x >= ub
    shifted = jnp.where(
        at_lower,
        just_above_lb,
        jnp.where(at_upper, just_below_ub, x),
    )

    # Bounds so tight that the shifted point still violates one of them.
    still_outside = (shifted < lb) | (shifted > ub)
    return jnp.where(still_outside, 0.5 * (lb + ub), shifted)
# =====================================================================
# JAX-compiled bounded-path helpers (B004 optimization)
# These replace NumPy versions from common_scipy.py in the hot path
# to avoid implicit device-to-host transfers.
# =====================================================================
@jit
def step_size_to_bound_jax(
    x: jnp.ndarray, s: jnp.ndarray, lb: jnp.ndarray, ub: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute min step size to reach a bound (JAX version).

    Parameters
    ----------
    x : jnp.ndarray
        Current parameter vector.
    s : jnp.ndarray
        Step direction.
    lb : jnp.ndarray
        Lower bounds.
    ub : jnp.ndarray
        Upper bounds.

    Returns
    -------
    min_step : jnp.ndarray
        Scalar minimum step to reach any bound; infinite when every
        component of ``s`` is zero.
    hits : jnp.ndarray
        Array indicating which bounds are hit (-1=lower, 0=none, 1=upper).
        The entry is the sign of ``s`` for components that attain the
        minimum step.
    """
    moving = s != 0
    # Divide by 1 where s is zero; those quotients are replaced by inf.
    safe_s = jnp.where(moving, s, 1.0)

    to_lower = jnp.where(moving, (lb - x) / safe_s, jnp.inf)
    to_upper = jnp.where(moving, (ub - x) / safe_s, jnp.inf)

    # The larger of the two distances is the one along the direction of s.
    per_component = jnp.maximum(to_lower, to_upper)
    min_step = jnp.min(per_component)

    directions = jnp.sign(s).astype(jnp.int32)
    hits = jnp.where(per_component == min_step, directions, 0)
    return min_step, hits
@jit
def intersect_trust_region_jax(
    x: jnp.ndarray, s: jnp.ndarray, Delta: float
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Find intersection of a line with trust region boundary (JAX version).

    Solves ||(x + s*t)||^2 = Delta^2 for t. This is the quadratic
    ``a*t**2 + 2*b*t + c = 0`` with ``a = s.s``, ``b = x.s`` and
    ``c = x.x - Delta**2``.

    Parameters
    ----------
    x : jnp.ndarray
        Starting point.
    s : jnp.ndarray
        Direction.
    Delta : float
        Trust region radius.

    Returns
    -------
    t_neg : jnp.ndarray
        The smaller root.
    t_pos : jnp.ndarray
        The larger root.

    Notes
    -----
    No exception is raised for degenerate input. A negative discriminant
    is clamped to zero, and a root whose denominator is zero is reported
    as 0.
    """
    a = jnp.dot(s, s)
    b = jnp.dot(x, s)
    c = jnp.dot(x, x) - Delta**2

    disc_root = jnp.sqrt(jnp.maximum(b * b - a * c, 0.0))

    # Add the root with the sign of b to avoid loss of significance
    # (Numerical Recipes); a zero b counts as positive.
    b_sign = jnp.where(b != 0, jnp.sign(b), 1.0)
    q = -(b + b_sign * disc_root)

    a_nonzero = a != 0
    q_nonzero = q != 0
    root_from_a = jnp.where(a_nonzero, q / jnp.where(a_nonzero, a, 1.0), 0.0)
    root_from_q = jnp.where(q_nonzero, c / jnp.where(q_nonzero, q, 1.0), 0.0)

    return jnp.minimum(root_from_a, root_from_q), jnp.maximum(root_from_a, root_from_q)
[docs]
@jit
def minimize_quadratic_1d_jax(
a: jnp.ndarray, b: jnp.ndarray, lb: float, ub: float, c: float | jnp.ndarray = 0.0
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Minimize 1-D quadratic function subject to bounds (JAX version).
Minimizes f(t) = a*t^2 + b*t + c on [lb, ub].
Parameters
----------
a : jnp.ndarray
Quadratic coefficient.
b : jnp.ndarray
Linear coefficient.
lb : float
Lower bound.
ub : float
Upper bound.
c : jnp.ndarray
Constant term (default 0).
Returns
-------
t_min : jnp.ndarray
Minimizing point.
y_min : jnp.ndarray
Minimum value.
"""
# Evaluate at endpoints
y_lb = lb * (a * lb + b) + c
y_ub = ub * (a * ub + b) + c
# Check if extremum is in bounds
extremum = -0.5 * b / jnp.where(a != 0, a, 1.0)
y_ext = extremum * (a * extremum + b) + c
ext_valid = (a != 0) & (lb < extremum) & (extremum < ub)
# Select minimum among valid candidates
# Start with lb vs ub
t_min = jnp.where(y_lb <= y_ub, lb, ub)
y_min = jnp.minimum(y_lb, y_ub)
# Consider extremum if valid
t_min = jnp.where(ext_valid & (y_ext < y_min), extremum, t_min)
y_min = jnp.where(ext_valid & (y_ext < y_min), y_ext, y_min)
return t_min, y_min