Source code for nlsq.common_jax

"""
These are functions that were initially in the common.py file, but are have
large data operations and are therefore better suited to be compiled with
JAX.  They are compiled with JAX and then added to the CommonJIT class.
"""

import numpy as np

# Initialize JAX configuration through central config
from nlsq.config import JAXConfig

_jax_config = JAXConfig()

import jax.numpy as jnp
from jax import jit, lax

EPS = np.finfo(float).eps


[docs] @jit def phi_and_derivative_jax( alpha: float, suf: jnp.ndarray, s: jnp.ndarray, Delta: float ) -> tuple[jnp.ndarray, jnp.ndarray]: """JAX-compiled phi function for trust region subproblem. This function computes the value and derivative of the secular equation used to find the optimal Levenberg-Marquardt parameter. It is defined as "norm of regularized (by alpha) least-squares solution minus Delta". The function is used iteratively to find the root, which gives the optimal regularization parameter for the trust region subproblem. Parameters ---------- alpha : float Current regularization parameter (Levenberg-Marquardt lambda) suf : jnp.ndarray Product of singular values and U^T @ f (s * uf) s : jnp.ndarray Singular values of the Jacobian Delta : float Trust region radius Returns ------- phi : jnp.ndarray Value of the secular equation: ||p|| - Delta phi_prime : jnp.ndarray Derivative of phi with respect to alpha Notes ----- This is a JAX-jitted version of the function in common_scipy.py. Using JAX enables GPU acceleration and avoids NumPy-JAX data transfers when used with other JAX operations. The computation follows [12] Branch, M.A., Coleman, T.F., Li, Y., "A Subspace, Interior, and Conjugate Gradient Method for Large-Scale Bound-Constrained Minimization Problems", SIAM Journal on Scientific Computing, Vol. 21, Number 1, pp 1-23, 1999. """ denom = s**2 + alpha p_norm = jnp.linalg.norm(suf / denom) phi = p_norm - Delta safe_p_norm = jnp.where(p_norm == 0.0, jnp.finfo(p_norm.dtype).tiny, p_norm) phi_prime = -jnp.sum(suf**2 / denom**3) / safe_p_norm return phi, phi_prime
def _solve_lsq_trust_region_jax_impl( n: int, m: int, uf: jnp.ndarray, s: jnp.ndarray, V: jnp.ndarray, Delta: float, initial_alpha: float, has_initial_alpha: bool, rtol: float, max_iter: int, ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Internal implementation for solve_lsq_trust_region_jax.""" suf = s * uf # Check if J has full rank threshold = EPS * m * s[0] full_rank = lax.cond( m >= n, lambda: s[-1] > threshold, lambda: False, ) # Gauss-Newton step (if full rank) p_gn = -V @ (uf / s) p_gn_norm = jnp.linalg.norm(p_gn) use_gauss_newton = full_rank & (p_gn_norm <= Delta) # Compute alpha bounds safe_Delta = jnp.maximum(Delta, jnp.finfo(jnp.result_type(Delta)).tiny) alpha_upper = jnp.linalg.norm(suf) / safe_Delta # Compute alpha_lower based on full_rank def compute_alpha_lower_full_rank(): phi_val, phi_prime_val = phi_and_derivative_jax(0.0, suf, s, Delta) safe_phi_prime = jnp.where( phi_prime_val == 0.0, jnp.finfo(phi_prime_val.dtype).tiny, phi_prime_val ) return -phi_val / safe_phi_prime alpha_lower = lax.cond( full_rank, compute_alpha_lower_full_rank, lambda: 0.0, ) # Compute default alpha default_alpha = jnp.maximum( 0.001 * alpha_upper, jnp.sqrt(alpha_lower * alpha_upper) ) # Use provided initial_alpha only if valid # Invalid if: no initial_alpha provided, OR (not full_rank AND initial_alpha == 0) use_provided = has_initial_alpha & ~(~full_rank & (initial_alpha == 0.0)) alpha_start = lax.cond( use_provided, lambda: initial_alpha, lambda: default_alpha, ) # State for while_loop: (alpha, alpha_lower, alpha_upper, iteration, converged) def loop_cond(state): _alpha, _alpha_lower, _alpha_upper, iteration, converged = state return (iteration < max_iter) & ~converged def loop_body(state): alpha, alpha_lower, alpha_upper, iteration, _ = state # Reset alpha if out of bounds alpha = lax.cond( (alpha < alpha_lower) | (alpha > alpha_upper), lambda: jnp.maximum( 0.001 * alpha_upper, jnp.sqrt(alpha_lower * alpha_upper) ), lambda: alpha, ) phi_val, phi_prime_val = phi_and_derivative_jax(alpha, suf, s, Delta) # Update alpha_upper if phi < 0 alpha_upper_new = lax.cond( phi_val < 0, lambda: alpha, lambda: alpha_upper, ) # Update alpha using Newton step (guard phi_prime_val == 0) safe_phi_prime = jnp.where( phi_prime_val == 0.0, jnp.finfo(phi_prime_val.dtype).tiny, phi_prime_val ) ratio = phi_val / safe_phi_prime alpha_lower_new = jnp.maximum(alpha_lower, alpha - ratio) alpha_new = alpha - (phi_val + Delta) * ratio / Delta # Check convergence converged = jnp.abs(phi_val) < rtol * Delta return (alpha_new, alpha_lower_new, alpha_upper_new, iteration + 1, converged) # Run the while loop init_state = (alpha_start, alpha_lower, alpha_upper, jnp.array(0), jnp.array(False)) final_alpha, _, _, n_iter_final, _ = lax.while_loop( loop_cond, loop_body, init_state ) # Compute final solution p p_iterative = -V @ (suf / (s**2 + final_alpha)) # Normalize p to exactly Delta to prevent numerical drift p_iterative_norm = jnp.linalg.norm(p_iterative) safe_norm = jnp.maximum(p_iterative_norm, jnp.finfo(p_iterative.dtype).tiny) p_iterative_normalized = p_iterative * (Delta / safe_norm) # Select between Gauss-Newton and iterative solution p_final = lax.cond( use_gauss_newton, lambda: p_gn, lambda: p_iterative_normalized, ) alpha_final = lax.cond( use_gauss_newton, lambda: jnp.array(0.0), lambda: final_alpha, ) n_iter_out = lax.cond( use_gauss_newton, lambda: jnp.array(0), lambda: n_iter_final, ) return p_final, alpha_final, n_iter_out # Create a JIT-compiled version of the implementation _solve_lsq_trust_region_jax_jit = jit( _solve_lsq_trust_region_jax_impl, static_argnums=(0, 1, 8, 9) )
[docs] def solve_lsq_trust_region_jax( n: int, m: int, uf: jnp.ndarray, s: jnp.ndarray, V: jnp.ndarray, Delta: float, initial_alpha: float | None = None, rtol: float = 0.01, max_iter: int = 10, ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """JAX-compiled trust-region problem solver for least-squares minimization. This function implements a method described by J. J. More [12] and used in MINPACK, but relies on a single SVD of Jacobian instead of series of Cholesky decompositions. Before running this function, compute: ``U, s, VT = svd(J, full_matrices=False)``. This is a pure JAX implementation using lax.while_loop for JIT compilation, avoiding NumPy-JAX data transfers when used in the optimization hot path. Parameters ---------- n : int Number of variables. m : int Number of residuals. uf : jnp.ndarray Computed as U.T @ f. s : jnp.ndarray Singular values of J. V : jnp.ndarray Transpose of VT (i.e., V = VT.T). Delta : float Radius of a trust region. initial_alpha : float, optional Initial guess for alpha. If None, determined automatically. rtol : float, optional Stopping tolerance for the root-finding procedure. max_iter : int, optional Maximum allowed number of iterations. Returns ------- p : jnp.ndarray, shape (n,) Found solution of a trust-region problem. alpha : jnp.ndarray Levenberg-Marquardt parameter (scalar as 0-d array). n_iter : jnp.ndarray Number of iterations made (scalar as 0-d array). References ---------- .. [12] More, J. J., "The Levenberg-Marquardt Algorithm: Implementation and Theory," Numerical Analysis, ed. G. A. Watson, Lecture Notes in Mathematics 630, Springer Verlag, pp. 105-116, 1977. """ # Handle None initial_alpha if initial_alpha is None: init_alpha_val = 0.0 has_initial_alpha = False else: init_alpha_val = initial_alpha has_initial_alpha = True return _solve_lsq_trust_region_jax_jit( n, m, uf, s, V, Delta, init_alpha_val, has_initial_alpha, rtol, max_iter )
[docs] class CommonJIT: """JIT-compiled common functions for nonlinear least squares optimization. This class provides GPU/TPU-accelerated implementations of mathematical operations commonly used across different optimization algorithms. All functions are JIT-compiled for maximum performance and memory efficiency. Core Functionality ------------------ - **Quadratic Function Operations**: Build and evaluate quadratic forms - **Matrix-Vector Products**: Optimized Jacobian operations - **Robust Loss Scaling**: Jacobian and residual scaling for robust methods - **Numerical Utilities**: Condition-aware computations with overflow protection JIT Compilation Benefits ------------------------- - **GPU/TPU Acceleration**: All operations optimized for parallel hardware - **Memory Efficiency**: Reduced allocations through compilation optimization - **Automatic Fusion**: Operations automatically fused for better performance - **Type Specialization**: Functions compiled for specific array shapes/types Mathematical Operations ------------------------ The class implements several categories of operations: 1. **Quadratic Functions**: For trust region subproblems - 1D quadratic parameterization along search directions - Quadratic model evaluation for step selection - Hessian approximation using J^T * J structure 2. **Matrix Operations**: Optimized linear algebra - Jacobian-vector products with broadcasting - Scaling operations for robust loss functions - Condition-aware computations with numerical stability 3. **Robust Loss Support**: For outlier-resistant fitting - Jacobian scaling using loss function derivatives - Residual weighting based on robustness weights - Numerical stability for extreme scaling factors Performance Characteristics ---------------------------- - **Memory Usage**: O(1) additional memory overhead - **Compilation Time**: One-time cost during initialization - **Execution Speed**: 10-100x faster than pure NumPy on GPU - **Numerical Precision**: Full double precision support Technical Implementation ------------------------ All functions use JAX JIT compilation with the following features: - Static argument handling for shape polymorphism - Automatic differentiation compatibility - XLA optimization for target hardware - Memory layout optimization for cache efficiency Integration with Optimization Algorithms ----------------------------------------- This class is used by: - Trust Region Reflective (TRF) algorithm for quadratic models - Levenberg-Marquardt algorithm for step computation - Robust loss functions for residual scaling - Large dataset processing for memory-efficient operations Usage Example ------------- :: from nlsq.common_jax import CommonJIT import jax.numpy as jnp # Initialize JIT-compiled functions cjit = CommonJIT() # Example: Scale Jacobian for robust loss jacobian = jnp.array([[1.0, 2.0], [3.0, 4.0]]) residuals = jnp.array([0.1, 2.5]) # Contains outlier rho = jnp.array([loss_val, loss_deriv1, loss_deriv2]) # Apply robust scaling scaled_J, scaled_f = cjit.scale_for_robust_loss_function(jacobian, residuals, rho) # Example: Build quadratic model gradient = jnp.array([0.1, -0.3]) direction = jnp.array([1.0, 0.5]) a, b, c = cjit.build_quadratic_1d(jacobian, gradient, direction) Numerical Considerations ------------------------- - **Overflow Protection**: Automatic handling of extreme scaling factors - **Underflow Prevention**: Minimum threshold enforcement (EPS) - **Condition Monitoring**: Numerical stability checks for ill-conditioned operations - **Precision Control**: Double precision arithmetic throughout """
[docs] def __init__(self): """Initialize CommonJIT with all compiled functions. This creates and compiles all JIT functions during initialization for optimal runtime performance. Functions are compiled once and reused across multiple optimization runs. Compiled Functions Created --------------------------- - Quadratic function builders and evaluators - Matrix-vector dot products with broadcasting - Jacobian sum operations for constraint handling - Robust loss function scaling operations """ self.create_quadratic_funcs() self.create_js_dot() self.create_jac_sum() self.create_scale_for_robust_loss_function()
[docs] def create_scale_for_robust_loss_function(self): """Create the scaling function for the loss functions""" @jit def scale_for_robust_loss_function( J: jnp.ndarray, f: jnp.ndarray, rho: jnp.ndarray ) -> tuple[jnp.ndarray, jnp.ndarray]: """Scale Jacobian and residuals for a robust loss function. Arrays are modified in place. Parameters ---------- J : jnp.ndarray Jacobian matrix. f : jnp.ndarray Residuals. rho : jnp.ndarray Cost function evaluation. """ # Scale Jacobian and residuals for robust loss function J_scale = rho[1] + 2 * rho[2] * f**2 # Prevent division by zero mask = J_scale < EPS J_scale = jnp.where(mask, EPS, J_scale) J_scale = J_scale**0.5 # Compute scaling factors fscale = rho[1] / J_scale # Apply scaling f = f * fscale J = J * J_scale[:, jnp.newaxis] return J, f self.scale_for_robust_loss_function = scale_for_robust_loss_function
[docs] def build_quadratic_1d( self, J: jnp.ndarray, g: jnp.ndarray, s: jnp.ndarray, diag: jnp.ndarray | None = None, s0: jnp.ndarray | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray]: """Parameterize a multivariate quadratic function along a line. The resulting univariate quadratic function is given as follows: f(t) = 0.5 * (s0 + s*t).T * (J.T*J + diag) * (s0 + s*t) + g.T * (s0 + s*t) Parameters ---------- J : ndarray, sparse matrix or LinearOperator, shape (m, n) Jacobian matrix, affects the quadratic term. g : ndarray, shape (n,) Gradient, defines the linear term. s : ndarray, shape (n,) Direction vector of a line. diag : None or ndarray with shape (n,), optional Addition diagonal part, affects the quadratic term. If None, assumed to be 0. s0 : None or ndarray with shape (n,), optional Initial point. If None, assumed to be 0. Returns ------- a : float Coefficient for t**2. b : float Coefficient for t. c : float Free term. Returned only if `s0` is provided. """ # OPT-2: Use jnp.asarray() to avoid copy if already JAX array s_jnp = jnp.asarray(s) v = self.js_dot(J, s_jnp) a = np.dot(v, v) if diag is not None: a += np.dot(s * diag, s) a *= 0.5 b = np.dot(g, s) if s0 is not None: # OPT-2: Use jnp.asarray() to avoid copy if already JAX array s0_jnp = jnp.asarray(s0) u = self.js0_dot(J, s0_jnp) b += np.dot(u, v) c = 0.5 * np.dot(u, u) + np.dot(g, s0) if diag is not None: b += np.dot(s0 * diag, s) c += 0.5 * np.dot(s0 * diag, s0) return a, b, c else: return a, b
[docs] def compute_jac_scale( self, J: jnp.ndarray, scale_inv_old: jnp.ndarray | np.ndarray | None = None ) -> tuple[jnp.ndarray, jnp.ndarray]: """Compute variables scale based on the Jacobian matrix. Returns JAX arrays to keep results on-device and avoid GPU→CPU sync. Parameters ---------- J : jnp.ndarray Jacobian matrix. scale_inv_old : jnp.ndarray | np.ndarray | None, optional Previous scale, by default None Returns ------- scale : jnp.ndarray Scale for the variables. scale_inv : jnp.ndarray Inverse of the scale for the variables. """ scale_inv = self.jac_sum_func(J) if scale_inv_old is None: scale_inv = jnp.where(scale_inv == 0, 1.0, scale_inv) else: scale_inv = jnp.maximum(scale_inv, scale_inv_old) return 1 / scale_inv, scale_inv
[docs] def create_js_dot(self): """Create the functions for the dot product of the Jacobian and the search direction. We need two functions because s and s0 are different shapes which causes retracing of the function if we use the same function for both. """ @jit def js_dot(J: jnp.ndarray, s: jnp.ndarray) -> jnp.ndarray: return J.dot(s) @jit def js0_dot(J: jnp.ndarray, s0: jnp.ndarray) -> jnp.ndarray: return J.dot(s0) self.js_dot = js_dot self.js0_dot = js0_dot
[docs] def evaluate_quadratic( self, J: jnp.ndarray, g: jnp.ndarray, s_np: np.ndarray, diag: np.ndarray | None = None, ) -> jnp.ndarray: """Compute values of a quadratic function arising in least squares. The function is 0.5 * s.T * (J.T * J + diag) * s + g.T * s. Parameters ---------- J : ndarray, sparse matrix or LinearOperator, shape (m, n) Jacobian matrix, affects the quadratic term. g : ndarray, shape (n,) Gradient, defines the linear term. s : ndarray, shape (k, n) or (n,) Array containing steps as rows. diag : ndarray, shape (n,), optional Addition diagonal part, affects the quadratic term. If None, assumed to be 0. Returns ------- values : ndarray with shape (k,) or float Values of the function. If `s` was 2-D, then ndarray is returned, otherwise, float is returned. """ # OPT-2: Use jnp.asarray() to avoid copy if already JAX array s = jnp.asarray(s_np) if s.ndim == 1: if diag is None: return self.evaluate_quadratic1(J, g, s) else: return self.evaluate_quadratic_diagonal1(J, g, s, diag) elif diag is None: return self.evaluate_quadratic2(J, g, s) else: return self.evaluate_quadratic_diagonal2(J, g, s, diag)
[docs] def create_quadratic_funcs(self): @jit def evaluate_quadratic1(J, g, s): Js = J.dot(s) q = jnp.dot(Js, Js) l = jnp.dot(s, g) return 0.5 * q + l @jit def evaluate_quadratic_diagonal1(J, g, s, diag): Js = J.dot(s) q = jnp.dot(Js, Js) + jnp.dot(s * diag, s) l = jnp.dot(s, g) return 0.5 * q + l @jit def evaluate_quadratic2(J, g, s): Js = J.dot(s.T) q = jnp.sum(Js**2, axis=0) l = jnp.dot(s, g) return 0.5 * q + l @jit def evaluate_quadratic_diagonal2(J, g, s, diag): Js = J.dot(s.T) q = jnp.sum(Js**2, axis=0) + jnp.sum(diag * s**2, axis=1) l = jnp.dot(s, g) return 0.5 * q + l self.evaluate_quadratic1 = evaluate_quadratic1 self.evaluate_quadratic_diagonal1 = evaluate_quadratic_diagonal1 self.evaluate_quadratic2 = evaluate_quadratic2 self.evaluate_quadratic_diagonal2 = evaluate_quadratic_diagonal2
[docs] def create_jac_sum(self): """Create the function for the sum of the Jacobian squared and then taking the square root. This is used to compute the scale for the variables. Can potentially remove this. """ @jit def jac_sum_func(J): return jnp.sum(J**2, axis=0) ** 0.5 self.jac_sum_func = jac_sum_func
# --------------------------------------------------------------------------- # JIT-compiled versions of trust-region helper functions (B002, B003 fixes) # These eliminate GPU→CPU sync overhead by keeping all operations on device. # ---------------------------------------------------------------------------
[docs] @jit def update_tr_radius_jax( Delta: float, actual_reduction: float, predicted_reduction: float, step_norm: float, bound_hit: bool, ) -> tuple[jnp.ndarray, jnp.ndarray]: """JIT-compiled trust-region radius update (replaces common_scipy.update_tr_radius). Uses branchless jnp.where to avoid Python control flow and GPU→CPU sync. Parameters ---------- Delta : float Current trust region radius. actual_reduction : float Actual cost reduction. predicted_reduction : float Predicted cost reduction. step_norm : float Norm of the step. bound_hit : bool Whether the step hit the trust region boundary. Returns ------- Delta : jnp.ndarray Updated trust region radius. ratio : jnp.ndarray Ratio of actual to predicted reduction. """ # Compute ratio with branchless logic safe_pred = jnp.where(predicted_reduction != 0, predicted_reduction, 1.0) ratio = jnp.where( predicted_reduction > 0, actual_reduction / safe_pred, jnp.where( (predicted_reduction == 0) & (actual_reduction == 0), 1.0, 0.0, ), ) # Update Delta with branchless logic Delta_shrunk = 0.25 * step_norm Delta_grown = Delta * 2.0 Delta_new = jnp.where( ratio < 0.25, Delta_shrunk, jnp.where((ratio > 0.75) & bound_hit, Delta_grown, Delta), ) return Delta_new, ratio
[docs] @jit def check_termination_jax( dF: float, F: float, dx_norm: float, x_norm: float, ratio: float, ftol: float, xtol: float, ) -> jnp.ndarray: """JIT-compiled termination check (replaces common_scipy.check_termination). Returns integer status code as a JAX scalar (no GPU→CPU sync needed until the Python-level loop checks it). Returns 0 for "not terminated". Parameters ---------- dF : float Change in cost function. F : float Current cost function value. dx_norm : float Norm of the step. x_norm : float Norm of the current parameters. ratio : float Ratio of actual to predicted reduction. ftol : float Cost function tolerance. xtol : float Parameter tolerance. Returns ------- status : jnp.ndarray 0=not terminated, 2=ftol, 3=xtol, 4=both. """ ftol_satisfied = (dF < ftol * F) & (ratio > 0.25) xtol_satisfied = dx_norm < xtol * (xtol + x_norm) status = jnp.where( ftol_satisfied & xtol_satisfied, 4, jnp.where(ftol_satisfied, 2, jnp.where(xtol_satisfied, 3, 0)), ) return status
# --------------------------------------------------------------------------- # JIT-compiled bounds helper functions (S07-S10) # These keep bounds operations on-device, avoiding NumPy→JAX transfers. # The NumPy originals in common_scipy.py remain for functions that need them. # ---------------------------------------------------------------------------
[docs] @jit def CL_scaling_vector_jax( x: jnp.ndarray, g: jnp.ndarray, lb: jnp.ndarray, ub: jnp.ndarray ) -> tuple[jnp.ndarray, jnp.ndarray]: """JIT-compiled Coleman-Li scaling vector (replaces common_scipy.CL_scaling_vector). Keeps result on-device to avoid GPU→CPU sync at call sites. """ v = jnp.ones_like(x) dv = jnp.zeros_like(x) mask_neg = (g < 0) & jnp.isfinite(ub) mask_pos = (g > 0) & jnp.isfinite(lb) v = jnp.where(mask_neg, ub - x, jnp.where(mask_pos, x - lb, v)) dv = jnp.where(mask_neg, -1.0, jnp.where(mask_pos, 1.0, dv)) return v, dv
[docs] @jit def in_bounds_jax(x: jnp.ndarray, lb: jnp.ndarray, ub: jnp.ndarray) -> jnp.ndarray: """JIT-compiled bounds check (replaces common_scipy.in_bounds). Returns a JAX scalar boolean, avoiding GPU→CPU sync until Python checks it. """ return jnp.all((x >= lb) & (x <= ub))
[docs] @jit def make_strictly_feasible_jax( x: jnp.ndarray, lb: jnp.ndarray, ub: jnp.ndarray ) -> jnp.ndarray: """JIT-compiled strict feasibility projection for rstep=0 case. Shifts boundary points to the interior using jnp.nextafter, with a guard against XLA's DAZ (Denormals-Are-Zero) behavior on CPU where nextafter(0, x) produces a denormal that compares equal to zero. Only supports rstep=0 (the only value used in trf.py hot path). """ # find_active_constraints with rtol=0 inlined lower_mask = x <= lb upper_mask = x >= ub lower_nudge = jnp.nextafter(lb, ub) upper_nudge = jnp.nextafter(ub, lb) # Guard against DAZ: if nextafter produced a denormal that XLA treats as # equal to the bound, use the smallest normal float as a minimum offset. tiny = jnp.finfo(x.dtype).tiny lower_nudge = jnp.where(lower_nudge <= lb, lb + tiny, lower_nudge) upper_nudge = jnp.where(upper_nudge >= ub, ub - tiny, upper_nudge) x_new = jnp.where( lower_mask, lower_nudge, jnp.where(upper_mask, upper_nudge, x), ) # Handle tight bounds where nextafter still violates tight = (x_new < lb) | (x_new > ub) return jnp.where(tight, 0.5 * (lb + ub), x_new)
# ===================================================================== # JAX-compiled bounded-path helpers (B004 optimization) # These replace NumPy versions from common_scipy.py in the hot path # to avoid implicit device-to-host transfers. # =====================================================================
[docs] @jit def step_size_to_bound_jax( x: jnp.ndarray, s: jnp.ndarray, lb: jnp.ndarray, ub: jnp.ndarray ) -> tuple[jnp.ndarray, jnp.ndarray]: """Compute min step size to reach a bound (JAX version). Parameters ---------- x : jnp.ndarray Current parameter vector. s : jnp.ndarray Step direction. lb : jnp.ndarray Lower bounds. ub : jnp.ndarray Upper bounds. Returns ------- min_step : jnp.ndarray Scalar minimum step to reach any bound. hits : jnp.ndarray Array indicating which bounds are hit (-1=lower, 0=none, 1=upper). """ # Compute step to each bound, handling division by zero non_zero = s != 0 steps_lb = jnp.where(non_zero, (lb - x) / jnp.where(non_zero, s, 1.0), jnp.inf) steps_ub = jnp.where(non_zero, (ub - x) / jnp.where(non_zero, s, 1.0), jnp.inf) # For each component, the relevant bound step is the positive one steps = jnp.maximum(steps_lb, steps_ub) min_step = jnp.min(steps) hits = jnp.where(steps == min_step, jnp.sign(s).astype(jnp.int32), 0) return min_step, hits
[docs] @jit def intersect_trust_region_jax( x: jnp.ndarray, s: jnp.ndarray, Delta: float ) -> tuple[jnp.ndarray, jnp.ndarray]: """Find intersection of a line with trust region boundary (JAX version). Solves ||(x + s*t)||^2 = Delta^2 for t. Parameters ---------- x : jnp.ndarray Starting point. s : jnp.ndarray Direction. Delta : float Trust region radius. Returns ------- t_neg : jnp.ndarray Negative root. t_pos : jnp.ndarray Positive root. """ a = jnp.dot(s, s) b = jnp.dot(x, s) c = jnp.dot(x, x) - Delta**2 d = jnp.sqrt(jnp.maximum(b * b - a * c, 0.0)) # Avoid loss of significance (Numerical Recipes) q = -(b + jnp.where(b != 0, jnp.sign(b), 1.0) * d) # Guard against a == 0 or q == 0 t1 = jnp.where(a != 0, q / jnp.where(a != 0, a, 1.0), 0.0) t2 = jnp.where(q != 0, c / jnp.where(q != 0, q, 1.0), 0.0) t_neg = jnp.minimum(t1, t2) t_pos = jnp.maximum(t1, t2) return t_neg, t_pos
[docs] @jit def minimize_quadratic_1d_jax( a: jnp.ndarray, b: jnp.ndarray, lb: float, ub: float, c: float | jnp.ndarray = 0.0 ) -> tuple[jnp.ndarray, jnp.ndarray]: """Minimize 1-D quadratic function subject to bounds (JAX version). Minimizes f(t) = a*t^2 + b*t + c on [lb, ub]. Parameters ---------- a : jnp.ndarray Quadratic coefficient. b : jnp.ndarray Linear coefficient. lb : float Lower bound. ub : float Upper bound. c : jnp.ndarray Constant term (default 0). Returns ------- t_min : jnp.ndarray Minimizing point. y_min : jnp.ndarray Minimum value. """ # Evaluate at endpoints y_lb = lb * (a * lb + b) + c y_ub = ub * (a * ub + b) + c # Check if extremum is in bounds extremum = -0.5 * b / jnp.where(a != 0, a, 1.0) y_ext = extremum * (a * extremum + b) + c ext_valid = (a != 0) & (lb < extremum) & (extremum < ub) # Select minimum among valid candidates # Start with lb vs ub t_min = jnp.where(y_lb <= y_ub, lb, ub) y_min = jnp.minimum(y_lb, y_ub) # Consider extremum if valid t_min = jnp.where(ext_valid & (y_ext < y_min), extremum, t_min) y_min = jnp.where(ext_valid & (y_ext < y_min), y_ext, y_min) return t_min, y_min