Source code for nlsq.core.trf

"""Trust Region Reflective algorithm for least-squares optimization.

The algorithm is based on ideas from the paper [STIR]_. The main idea is to
account for the presence of the bounds by appropriate scaling of the variables (or,
equivalently, changing a trust-region shape). Let's introduce a vector v::

           | ub[i] - x[i], if g[i] < 0 and ub[i] < np.inf
    v[i] = | x[i] - lb[i], if g[i] > 0 and lb[i] > -np.inf
           | 1,           otherwise

where g is the gradient of a cost function and lb, ub are the bounds. Its
components are distances to the bounds at which the anti-gradient points (if
this distance is finite). Define a scaling matrix D = diag(v**0.5).
First-order optimality conditions can be stated as::

    D^2 g(x) = 0.

This means that the components of the gradient must be zero for strictly
interior variables, and must point inside the feasible region for variables
on a bound.
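
The following is a minimal NumPy sketch of the definitions above, for illustration only. `scaling_vector` is a hypothetical helper; the module itself imports `CL_scaling_vector_jax`, whose implementation is not shown here. The sketch builds v and the diagonal of Jv (-1, 0 or 1). It then checks that C = diag(g) Jv is non-negative and that the optimality measure D^2 g = v * g vanishes for a variable that sits on its lower bound with the gradient pointing inside.

```python
import numpy as np


def scaling_vector(x, g, lb, ub):
    # Hypothetical helper. v[i]: distance to the bound the anti-gradient points
    # at (1 if that bound is infinite); dv[i]: diagonal entry of Jv (-1, 0, 1).
    v = np.ones_like(x, dtype=float)
    dv = np.zeros_like(x, dtype=float)

    # g < 0: the anti-gradient points towards a finite upper bound.
    mask = (g < 0) & np.isfinite(ub)
    v[mask] = ub[mask] - x[mask]
    dv[mask] = -1.0

    # g > 0: the anti-gradient points towards a finite lower bound.
    mask = (g > 0) & np.isfinite(lb)
    v[mask] = x[mask] - lb[mask]
    dv[mask] = 1.0
    return v, dv


x = np.array([0.5, 0.5, 0.5, 0.0])
lb = np.array([0.0, 0.0, -np.inf, 0.0])
ub = np.array([1.0, 2.0, np.inf, 1.0])
g = np.array([-1.0, 1.0, 1.0, 3.0])

v, dv = scaling_vector(x, g, lb, ub)
D = np.diag(np.sqrt(v))   # D = diag(v**0.5)
C = g * dv                # diagonal of C = diag(g) Jv, non-negative by construction
optimality = v * g        # D^2 g; a zero entry satisfies the first-order condition
print(v, dv, C, optimality)
```

The last component shows why the condition is stated as D^2 g = 0 rather than g = 0. The gradient there is 3, yet the variable is optimal because it cannot move below its lower bound.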

Now consider this system of equations as a new optimization problem. If the
point x is strictly interior (not on the bound), then the left-hand side is
differentiable and the Newton step for it satisfies::

    (D^2 H + diag(g) Jv) p = -D^2 g

where H is the Hessian matrix (or its J^T J approximation in least squares),
Jv is the Jacobian matrix of v with components -1, 1 or 0, such that all
elements of the matrix C = diag(g) Jv are non-negative. Introduce the change
of variables x = D x_h (_h would be "hat" in LaTeX). In the new variables,
we have a Newton step satisfying::

    B_h p_h = -g_h,

where B_h = D H D + C and g_h = D g. In least squares D H D = J_h^T J_h with
J_h = J D, so B_h = J_h^T J_h + C. Note that J_h and g_h are the proper
Jacobian and gradient with respect to the "hat" variables. To guarantee
global convergence we formulate a trust-region problem based on the Newton
step in the new variables::

    0.5 * p_h^T B_h p_h + g_h^T p_h -> min, ||p_h|| <= Delta

In the original space B = H + D^{-1} C D^{-1}, and the equivalent trust-region
problem is::

    0.5 * p^T B p + g^T p -> min, ||D^{-1} p|| <= Delta

Here the role of the matrix D becomes clearer: it alters the shape of the
trust region so that large steps towards the bounds are not allowed.
In the implementation, the trust-region problem is solved in "hat" space,
but handling of the bounds is done in the original space (see below and read
the code).
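
The change of variables can be checked numerically. The sketch below uses random data, a positive scaling vector v and a non-negative diagonal C, all drawn purely for illustration. It verifies the following for p = D p_h:

- the "hat"-space model and the original-space model agree;
- ||p_h|| = ||D^{-1} p||;
- with H = J^T J, J_h^T J_h = D H D and J_h^T f = g_h.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 8, 3
J = rng.standard_normal((m, n))        # Jacobian (random, for illustration)
f = rng.standard_normal(m)             # residuals
v = rng.uniform(0.1, 2.0, n)           # assumed positive scaling vector
C = np.diag(rng.uniform(0.0, 1.0, n))  # assumed non-negative diagonal C

D = np.diag(np.sqrt(v))                # D = diag(v**0.5)
D_inv = np.diag(1.0 / np.sqrt(v))
H = J.T @ J                            # Gauss-Newton approximation of the Hessian
g = J.T @ f                            # gradient of 0.5 * ||f||^2

# Hat-space quantities.
J_h = J @ D
g_h = D @ g
B_h = D @ H @ D + C

# The same model expressed in the original space.
B = H + D_inv @ C @ D_inv

p_h = rng.standard_normal(n)           # arbitrary step in hat space
p = D @ p_h                            # corresponding step in the original space

model_hat = 0.5 * p_h @ B_h @ p_h + g_h @ p_h
model_orig = 0.5 * p @ B @ p + g @ p
print(np.isclose(model_hat, model_orig))
print(np.isclose(np.linalg.norm(p_h), np.linalg.norm(D_inv @ p)))
```

Because the two formulations agree term by term, the subproblem can be solved with a spherical trust region in hat space. The bounds are still enforced in the original space.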

The introduction of the matrix D does not allow the bounds to be ignored:
the algorithm must keep iterates strictly feasible (to satisfy the
aforementioned differentiability), and the parameter theta controls the
step back from the boundary (see the code for details).

The algorithm uses another important trick. If the trust-region solution
does not fit into the bounds, then a search direction reflected from the
first bound encountered is considered. For motivation and analysis refer to
the [STIR]_ paper (and other papers by the same authors). In practice the
trick needs little justification; the algorithm simply chooses the best of
three candidate steps: a constrained trust-region step, a reflected step,
and a constrained Cauchy step (a minimizer along -g_h in "hat" space, or
-D^2 g in the original space).
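
The selection among the three candidates can be sketched in NumPy as below. This is a simplified illustration, not this module's implementation. It rests on four assumptions:

- all bounds are finite;
- the trust-region radius limit on the reflected and Cauchy segments is omitted;
- theta is fixed at 0.995;
- the reflected segment starts a fraction (1 - theta) of its feasible length away from the bound, so the result stays strictly feasible.

The helper names (`model`, `step_to_bound`, `best_on_segment`, `select_step`) are hypothetical.

```python
import numpy as np


def model(p, B, g):
    # Quadratic model 0.5 * p^T B p + g^T p in the original space.
    return 0.5 * p @ B @ p + g @ p


def step_to_bound(x, s, lb, ub):
    # Largest t >= 0 with lb <= x + t * s <= ub, and a mask of the components
    # that reach their bound first (t is inf where s == 0).
    with np.errstate(divide="ignore", invalid="ignore"):
        t = np.where(s > 0, (ub - x) / s, np.where(s < 0, (lb - x) / s, np.inf))
    t_min = float(np.min(t))
    return t_min, t == t_min


def best_on_segment(p0, d, lo, hi, B, g):
    # Minimize the 1-D quadratic model(p0 + s * d) over s in [lo, hi].
    a = d @ B @ d
    b = d @ (B @ p0 + g)
    candidates = [lo, hi]
    if a > 0 and lo < -b / a < hi:
        candidates.append(-b / a)
    values = [model(p0 + s * d, B, g) for s in candidates]
    return p0 + candidates[int(np.argmin(values))] * d


def select_step(x, p_tr, g, B, v, lb, ub, theta=0.995):
    t_hit, hits = step_to_bound(x, p_tr, lb, ub)
    if t_hit >= 1.0:
        return p_tr  # the trust-region step already fits inside the bounds

    # 1) Constrained trust-region step: stop a factor theta short of the bound.
    p_constrained = theta * t_hit * p_tr

    # 2) Reflected step: walk to the first bound, flip the components that hit
    #    it, then minimize the model along the reflected direction while
    #    staying strictly inside the box.
    p_hit = t_hit * p_tr
    r = np.where(hits, -p_tr, p_tr)
    r_max, _ = step_to_bound(x + p_hit, r, lb, ub)
    p_reflected = best_on_segment(p_hit, r, (1 - theta) * r_max, theta * r_max, B, g)

    # 3) Constrained Cauchy step along -D^2 g, where D^2 = diag(v).
    c = -v * g
    c_max, _ = step_to_bound(x, c, lb, ub)
    if np.isfinite(c_max):
        p_cauchy = best_on_segment(np.zeros_like(x), c, 0.0, theta * c_max, B, g)
    else:  # c == 0, so the Cauchy direction makes no progress
        p_cauchy = np.zeros_like(x)

    candidates = [p_constrained, p_reflected, p_cauchy]
    return candidates[int(np.argmin([model(p, B, g) for p in candidates]))]


x = np.array([0.5, 0.5])
lb, ub = np.zeros(2), np.ones(2)
B = np.eye(2)
g = np.array([-2.0, -1.0])
v = np.array([0.5, 0.5])      # distances to the upper bounds, since g < 0
p_tr = np.array([2.0, 1.0])   # unconstrained model minimizer -B^{-1} g
p = select_step(x, p_tr, g, B, v, lb, ub)
print(p, model(p, B, g))      # the reflected candidate wins here
```

In this example the first component hits its upper bound first. Reflecting it lets the second component keep decreasing the model, so the reflected step beats both the truncated trust-region step and the Cauchy step.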

Another feature is that the trust-region radius control strategy is modified
to account for the appearance of the diagonal matrix C (called diag_h in the
code). Note that all of these peculiarities disappear for problems without
bounds (the algorithm then becomes a standard trust-region type algorithm,
very similar to the ones implemented in MINPACK).

The implementation supports two methods of solving the trust-region problem.
The first, called 'exact', applies SVD to the Jacobian and then solves the
problem very accurately using the algorithm described in [JJMore]_. It is
not applicable to large problems. The second, called 'lsmr', uses the 2-D
subspace approach (sometimes called "indefinite dogleg"), where the problem
is solved in a subspace spanned by the gradient and the approximate
Gauss-Newton step found by ``scipy.sparse.linalg.lsmr``. The 2-D
trust-region problem is reformulated as a 4th-order algebraic equation and
solved very accurately by ``numpy.roots``. The subspace approach allows
solving very large problems (up to a couple of million residuals on a
regular PC), provided the Jacobian matrix is sufficiently sparse.
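
A compact NumPy sketch of the 2-D subproblem solve described above follows. It assumes a symmetric 2x2 matrix B, such as the model matrix projected onto the two subspace directions. The function name and the half-angle parametrization are this sketch's own choices, not taken from this module.

If B is positive definite and the Newton step fits, that step is returned. Otherwise the sketch writes p = Delta * (sin(phi), cos(phi)) and substitutes t = tan(phi / 2). The stationarity condition then becomes a quartic in t, which is solved with ``numpy.roots``.

```python
import numpy as np


def solve_trust_region_2d(B, g, Delta):
    # Minimize 0.5 * p^T B p + g^T p subject to ||p|| <= Delta, with B a
    # symmetric 2x2 matrix. Returns (p, on_boundary).
    try:
        np.linalg.cholesky(B)  # raises LinAlgError unless B is positive definite
        p = -np.linalg.solve(B, g)
        if p @ p <= Delta**2:
            return p, False  # interior Newton step
    except np.linalg.LinAlgError:
        pass

    # On the boundary write p = Delta * (sin(phi), cos(phi)) and substitute
    # t = tan(phi / 2). Setting the derivative of the model with respect to
    # phi to zero gives a 4th-order polynomial in t.
    a = B[0, 0] * Delta**2
    b = B[0, 1] * Delta**2
    c = B[1, 1] * Delta**2
    d = g[0] * Delta
    f = g[1] * Delta
    coeffs = np.array([-b + d, 2 * (a - c + f), 6 * b, 2 * (-a + c + f), -b - d])
    t = np.roots(coeffs)  # leading zero coefficients are stripped by np.roots
    t = np.real(t[np.abs(np.imag(t)) < 1e-10])

    # Candidate points on the circle; phi = pi (t -> infinity) is added by hand.
    p = Delta * np.vstack((2 * t / (1 + t**2), (1 - t**2) / (1 + t**2)))
    p = np.hstack((p, [[0.0], [-Delta]]))
    values = 0.5 * np.sum(p * (B @ p), axis=0) + g @ p
    return p[:, np.argmin(values)], True


p, on_boundary = solve_trust_region_2d(np.eye(2), np.array([3.0, 4.0]), 1.0)
print(p, on_boundary)
```

For B = I, g = (3, 4) and Delta = 1, the Newton step has norm 5, so the solution lies on the boundary at -g / ||g|| = (-0.6, -0.8).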

References
----------
.. [STIR] Branch, M.A., T.F. Coleman, and Y. Li, "A Subspace, Interior,
    and Conjugate Gradient Method for Large-Scale Bound-Constrained
    Minimization Problems," SIAM Journal on Scientific Computing,
    Vol. 21, Number 1, pp 1-23, 1999.
.. [JJMore] More, J. J., "The Levenberg-Marquardt Algorithm: Implementation
    and Theory," Numerical Analysis, ed. G. A. Watson, Lecture
"""

# mypy: disable-error-code="arg-type,assignment,attr-defined,operator,misc,index,var-annotated,override"
# Note: mypy errors are mostly arg-type/assignment mismatches where Optional values
# are passed to methods expecting non-Optional, plus operator type conflicts between
# JAX arrays and numpy arrays. These require deeper refactoring of the TRF API.

from __future__ import annotations

import math
import warnings
from collections.abc import Callable

import numpy as np

# Initialize JAX configuration through central config
from nlsq.config import JAXConfig

__jax_config = JAXConfig()
import jax.numpy as jnp
from jax import debug
from jax.numpy.linalg import norm as jnorm

# GPU initialization helper (from the safe SVD fallback module)
from nlsq.stability.svd_fallback import (
    initialize_gpu_safely,
)

# Setup logging
from nlsq.utils.logging import LogLevel, get_logger

# Cache the performance log level for fast guard checks in hot loops
PERFORMANCE_LEVEL = LogLevel.PERFORMANCE

logger = get_logger("trf")

# Initialize GPU settings safely
initialize_gpu_safely()

# dataclass for the TRF parameter objects below; NamedTuple for SVDCache
from dataclasses import dataclass
from typing import Any, NamedTuple

from nlsq.caching.unified_cache import get_global_cache
from nlsq.callbacks import StopOptimization
from nlsq.common_jax import (
    CL_scaling_vector_jax,
    CommonJIT,
    check_termination_jax,
    in_bounds_jax,
    intersect_trust_region_jax,
    make_strictly_feasible_jax,
    minimize_quadratic_1d_jax,
    solve_lsq_trust_region_jax,
    step_size_to_bound_jax,
    update_tr_radius_jax,
)
from nlsq.common_scipy import (
    find_active_constraints,
    in_bounds,
    print_header_nonlinear,
    print_iteration_nonlinear,
)
from nlsq.constants import (
    DEFAULT_MAX_NFEV_MULTIPLIER,
    INITIAL_LEVENBERG_MARQUARDT_LAMBDA,
)

# Optimizer base class
from nlsq.core.optimizer_base import TrustRegionOptimizerBase

# Profiling support
from nlsq.core.profiler import NullProfiler, TRFProfiler

# JIT-compiled helper functions (extracted for modularity)
from nlsq.core.trf_jit import TrustRegionJITFunctions
from nlsq.result import OptimizeResult
from nlsq.stability.guard import NumericalStabilityGuard
from nlsq.utils.diagnostics import OptimizationDiagnostics


class SVDCache(NamedTuple):
    """Cache SVD decomposition across inner loop iterations when Jacobian unchanged.

    This cache stores the SVD components (U, s, V) along with the scaled
    Jacobian J_h to avoid redundant SVD computations during inner loop
    iterations where the step is rejected and parameters remain unchanged.

    Attributes
    ----------
    U : jnp.ndarray
        Left singular vectors (m x k), where m is the number of residuals and
        k = min(m, n).
    s : jnp.ndarray
        Singular values (k,).
    V : jnp.ndarray
        Right singular vectors (n x k), where n is the number of parameters.
    J_h : jnp.ndarray
        Scaled Jacobian in "hat" space (m x n).
    x_hash : int
        Hash of parameter vector for cache validation. Cache is valid only
        when the current parameter hash matches this value.

    Notes
    -----
    The cache is valid only when `x_hash` matches the current parameter
    vector's hash. When a step is rejected (actual_reduction <= 0), the
    parameters don't change, so the SVD can be reused. When a step is
    accepted, the cache must be invalidated.

    The expected speedup from SVD caching is 20-40% on problems with frequent
    step rejections, as SVD computation is O(mn^2) and dominates iteration
    time.
    """

    U: jnp.ndarray
    s: jnp.ndarray
    V: jnp.ndarray
    J_h: jnp.ndarray
    x_hash: int
# =====================================================================
# TRF Configuration Dataclasses (US4 - Parameter Objects)
# =====================================================================
@dataclass(frozen=True, slots=True)
class TRFConfig:
    """Immutable TRF algorithm configuration.

    Groups algorithm configuration parameters passed to TRF optimizer
    functions. This is an internal implementation detail; the public API
    remains unchanged.

    Attributes
    ----------
    ftol : float
        Tolerance for termination by change of cost function.
    xtol : float
        Tolerance for termination by change of independent variables.
    gtol : float
        Tolerance for termination by norm of gradient.
    max_nfev : int or None
        Maximum number of function evaluations. None for unlimited.
    x_scale : str
        Characteristic scale of variables. 'jac' for automatic scaling.
    loss : str
        Loss function type ('linear', 'soft_l1', 'huber', 'cauchy', 'arctan').
    tr_solver : str
        Trust-region subproblem solver ('exact', 'lsmr', 'cg').
    verbose : int
        Verbosity level (0=silent, 1=termination, 2=iterations).
    """

    ftol: float = 1e-8
    xtol: float = 1e-8
    gtol: float = 1e-8
    max_nfev: int | None = None
    x_scale: str = "jac"
    loss: str = "linear"
    tr_solver: str = "exact"
    verbose: int = 0

    def __post_init__(self) -> None:
        """Validate configuration values."""
        if self.ftol < 0 or not math.isfinite(self.ftol):
            raise ValueError(
                f"ftol must be a finite non-negative number, got {self.ftol}"
            )
        if self.xtol < 0 or not math.isfinite(self.xtol):
            raise ValueError(
                f"xtol must be a finite non-negative number, got {self.xtol}"
            )
        if self.gtol < 0 or not math.isfinite(self.gtol):
            raise ValueError(
                f"gtol must be a finite non-negative number, got {self.gtol}"
            )
        if self.max_nfev is not None and self.max_nfev <= 0:
            raise ValueError(f"max_nfev must be positive, got {self.max_nfev}")

        valid_losses = {"linear", "soft_l1", "huber", "cauchy", "arctan"}
        if self.loss not in valid_losses:
            raise ValueError(f"loss must be one of {valid_losses}, got {self.loss}")

        valid_solvers = {"exact", "lsmr", "cg"}
        if self.tr_solver not in valid_solvers:
            raise ValueError(
                f"tr_solver must be one of {valid_solvers}, got {self.tr_solver}"
            )
@dataclass(slots=True)
class StepContext:
    """Mutable state container for TRF step computation.

    Groups the iteration state variables passed between TRF helper methods.
    This reduces parameter count and improves code clarity.

    Attributes
    ----------
    x : jnp.ndarray
        Current parameter values.
    f : jnp.ndarray
        Residual vector at x.
    J : jnp.ndarray
        Jacobian matrix at x.
    cost : float
        Current cost (0.5 * ||f||^2).
    g : jnp.ndarray
        Gradient vector (J^T @ f).
    trust_radius : float
        Current trust region radius (Delta).
    iteration : int
        Current iteration number.
    scale : jnp.ndarray
        Variable scaling factors.
    scale_inv : jnp.ndarray
        Inverse scaling factors.
    alpha : float
        Levenberg-Marquardt parameter.
    """

    x: jnp.ndarray
    f: jnp.ndarray
    J: jnp.ndarray
    cost: float
    g: jnp.ndarray
    trust_radius: float
    iteration: int
    scale: jnp.ndarray
    scale_inv: jnp.ndarray
    alpha: float = 0.0
@dataclass(frozen=True, slots=True)
class BoundsContext:
    """Bound constraint data for TRF optimization.

    Groups the bound-related arrays used in bounded optimization.

    Attributes
    ----------
    lb : jnp.ndarray
        Lower bounds on parameters.
    ub : jnp.ndarray
        Upper bounds on parameters.
    x_scale : jnp.ndarray
        Scaling factors for bounded variables.
    x_offset : jnp.ndarray
        Offset for bounded variables: the center of the bounds when both are
        finite, the finite bound when only one is finite, and 0 otherwise.
    lb_scaled : jnp.ndarray
        Scaled lower bounds.
    ub_scaled : jnp.ndarray
        Scaled upper bounds.
    """

    lb: jnp.ndarray
    ub: jnp.ndarray
    x_scale: jnp.ndarray
    x_offset: jnp.ndarray
    lb_scaled: jnp.ndarray
    ub_scaled: jnp.ndarray

    @classmethod
    def from_bounds(
        cls,
        lb: jnp.ndarray,
        ub: jnp.ndarray,
        x_scale: jnp.ndarray | None = None,
    ) -> BoundsContext:
        """Create BoundsContext from bounds arrays.

        Parameters
        ----------
        lb : array_like
            Lower bounds.
        ub : array_like
            Upper bounds.
        x_scale : array_like, optional
            Scaling factors. If None, uses 1.0.

        Returns
        -------
        BoundsContext
            Initialized bounds context.
        """
        lb = jnp.asarray(lb)
        ub = jnp.asarray(ub)
        if x_scale is None:
            x_scale = jnp.ones_like(lb)
        else:
            x_scale = jnp.asarray(x_scale)

        # Handle infinite bounds: when both are infinite the midpoint is 0;
        # when only one is infinite the offset is the finite bound.
        both_finite = jnp.isfinite(lb) & jnp.isfinite(ub)
        lb_finite_only = jnp.isfinite(lb) & ~jnp.isfinite(ub)
        ub_finite_only = ~jnp.isfinite(lb) & jnp.isfinite(ub)
        x_offset = jnp.where(
            both_finite,
            (lb + ub) / 2.0,
            jnp.where(lb_finite_only, lb, jnp.where(ub_finite_only, ub, 0.0)),
        )

        lb_scaled = (lb - x_offset) / x_scale
        ub_scaled = (ub - x_offset) / x_scale

        return cls(
            lb=lb,
            ub=ub,
            x_scale=x_scale,
            x_offset=x_offset,
            lb_scaled=lb_scaled,
            ub_scaled=ub_scaled,
        )
# Algorithm constants
# Trust region parameters
TR_REDUCTION_FACTOR = 0.25  # Factor to reduce trust region when numerical issues occur
TR_BOUNDARY_THRESHOLD = 0.95  # Threshold for checking if step is close to boundary
SQRT_EXPONENT = 0.5  # Exponent for square root in scaling (v**0.5)
class TrustRegionReflective(TrustRegionJITFunctions, TrustRegionOptimizerBase):
    """Trust Region Reflective algorithm for bounded least squares optimization.

    Implements the TRF algorithm with variable scaling to handle parameter
    bounds. Supports exact (SVD) and iterative (CG) solvers for trust region
    subproblems.
    """
    def __init__(self, enable_stability: bool = False):
        """Initialize the TrustRegionReflective optimizer.

        Creates JIT-compiled functions and sets up logging infrastructure.
        All optimization functions are compiled during initialization for
        maximum performance during solve operations.

        Parameters
        ----------
        enable_stability : bool, default False
            Enable numerical stability checks and fixes
        """
        TrustRegionJITFunctions.__init__(self)
        TrustRegionOptimizerBase.__init__(self, name="trf")
        self.cJIT = CommonJIT()

        # Initialize unified cache for JIT compilation tracking
        self.cache = get_global_cache()

        # Initialize stability system
        self.enable_stability = enable_stability
        if enable_stability:
            self.stability_guard = NumericalStabilityGuard()
    @staticmethod
    def _log_iteration_callback(
        iteration: Any,
        nfev: Any,
        cost: Any,
        actual_reduction: Any,
        step_norm: Any,
        g_norm: Any,
    ) -> None:
        """Wrapper for logging callback that converts JAX arrays to Python scalars.

        This function is called by jax.debug.callback and ensures all
        arguments are converted from JAX arrays to Python scalars before
        logging.

        Parameters
        ----------
        iteration : int or jax.Array
            Iteration number
        nfev : int or jax.Array
            Number of function evaluations
        cost : float or jax.Array
            Current cost
        actual_reduction : float or jax.Array or None
            Actual cost reduction
        step_norm : float or jax.Array or None
            Step norm
        g_norm : float or jax.Array
            Gradient norm
        """
        # Convert JAX arrays to Python scalars
        iteration = int(iteration) if hasattr(iteration, "item") else iteration
        nfev = int(nfev) if hasattr(nfev, "item") else nfev
        cost = float(cost) if hasattr(cost, "item") else cost
        g_norm = float(g_norm) if hasattr(g_norm, "item") else g_norm

        # Handle optional values
        if actual_reduction is not None:
            actual_reduction = (
                float(actual_reduction)
                if hasattr(actual_reduction, "item")
                else actual_reduction
            )
        if step_norm is not None:
            step_norm = float(step_norm) if hasattr(step_norm, "item") else step_norm

        print_iteration_nonlinear(
            iteration, nfev, cost, actual_reduction, step_norm, g_norm
        )
    def trf(
        self,
        fun: Callable[..., Any],
        xdata: jnp.ndarray | tuple[jnp.ndarray],
        ydata: jnp.ndarray,
        jac: Callable[..., Any],
        data_mask: jnp.ndarray,
        transform: jnp.ndarray,
        x0: np.ndarray,
        f0: jnp.ndarray,
        J0: jnp.ndarray,
        lb: np.ndarray,
        ub: np.ndarray,
        ftol: float,
        xtol: float,
        gtol: float,
        max_nfev: int,
        f_scale: float,
        x_scale: np.ndarray,
        loss_function: None | Callable[..., Any],
        tr_options: dict[str, Any],
        verbose: int,
        timeit: bool = False,
        solver: str = "exact",
        diagnostics: OptimizationDiagnostics | None = None,
        callback: Callable[..., Any] | None = None,
        **kwargs: Any,
    ) -> OptimizeResult:
        """Minimize a scalar function of one or more variables using the
        trust-region reflective algorithm.

        Although I think this is not good coding style, I maintained the
        original code format from SciPy such that the code is easier to
        compare with the original. See the note from the algorithm's original
        author below.

        For efficiency, it makes sense to run the simplified version of the
        algorithm when no bounds are imposed. We decided to write the two
        separate functions. It violates the DRY principle, but the individual
        functions are kept the most readable.

        Parameters
        ----------
        fun : callable
            The residual function
        xdata : array_like or tuple of array_like
            The independent variable where the data is measured. If `xdata`
            is a tuple, then the input arguments to `fun` are assumed to be
            ``(xdata[0], xdata[1], ...)``.
        ydata : jnp.ndarray
            The dependent data
        jac : callable
            The Jacobian of `fun`.
        data_mask : jnp.ndarray
            The mask for the data.
        transform : jnp.ndarray
            The uncertainty transform for the data.
        x0 : jnp.ndarray
            Initial guess. Array of real elements of size (n,), where 'n' is
            the number of independent variables.
        f0 : jnp.ndarray
            Initial residuals. Array of real elements of size (m,), where 'm'
            is the number of data points.
        J0 : jnp.ndarray
            Initial Jacobian. Array of real elements of size (m, n), where 'm'
            is the number of data points and 'n' is the number of independent
            variables.
        lb : jnp.ndarray
            Lower bounds on independent variables. Array of real elements of
            size (n,), where 'n' is the number of independent variables.
        ub : jnp.ndarray
            Upper bounds on independent variables. Array of real elements of
            size (n,), where 'n' is the number of independent variables.
        ftol : float
            Tolerance for termination by the change of the cost function.
        xtol : float
            Tolerance for termination by the change of the independent
            variables.
        gtol : float
            Tolerance for termination by the norm of the gradient.
        max_nfev : int
            Maximum number of function evaluations.
        f_scale : float
            Scale factor for the residuals, passed to the loss function.
        x_scale : jnp.ndarray
            Scaling factors for independent variables.
        loss_function : callable, optional
            Loss function. If None, the standard least-squares problem is
            solved.
        tr_options : dict
            Options for the trust-region algorithm.
        verbose : int
            Level of algorithm's verbosity:

            * 0 (default) : work silently.
            * 1 : display a termination report.
        timeit : bool, optional
            If True, the time for each step is measured if the unbounded
            version is being run. Default is False.
        """
        # bounded or unbounded version
        if np.all(lb == -np.inf) and np.all(ub == np.inf):
            # unbounded version as timed and untimed version
            if not timeit:
                return self.trf_no_bounds(
                    fun, xdata, ydata, jac, data_mask, transform, x0, f0, J0,
                    lb, ub, ftol, xtol, gtol, max_nfev, f_scale, x_scale,
                    loss_function, tr_options, verbose, solver, callback,
                    **kwargs,
                )
            else:
                return self.trf_no_bounds(
                    fun, xdata, ydata, jac, data_mask, transform, x0, f0, J0,
                    lb, ub, ftol, xtol, gtol, max_nfev, f_scale, x_scale,
                    loss_function, tr_options, verbose, solver, callback,
                    profiler=TRFProfiler(),
                )
        else:
            return self.trf_bounds(
                fun, xdata, ydata, jac, data_mask, transform, x0, f0, J0,
                lb, ub, ftol, xtol, gtol, max_nfev, f_scale, x_scale,
                loss_function, tr_options, verbose, solver, callback,
                **kwargs,
            )
    def _initialize_trf_state(
        self,
        x0: np.ndarray,
        f: jnp.ndarray,
        J: jnp.ndarray,
        loss_function: Callable[..., Any] | None,
        x_scale: np.ndarray | str,
        f_scale: float,
        data_mask: jnp.ndarray,
    ) -> dict[str, Any]:
        """Initialize optimization state for TRF algorithm.

        This helper extracts the initialization logic from trf_no_bounds,
        reducing complexity and improving testability.

        Parameters
        ----------
        x0 : np.ndarray
            Initial parameter guess
        f : jnp.ndarray
            Initial residuals
        J : jnp.ndarray
            Initial Jacobian matrix
        loss_function : Callable or None
            Loss function (None for standard least squares)
        x_scale : np.ndarray or str
            Parameter scaling factors or 'jac' for Jacobian-based scaling
        f_scale : float
            Residual scaling factor
        data_mask : jnp.ndarray
            Data masking array

        Returns
        -------
        dict
            Initial state containing x, f, J, cost, g, scale, Delta, etc.
        """
        m, n = J.shape
        state = {
            "x": jnp.asarray(x0),  # OPT-2: Use JAX array directly, no copy
            "f": f,
            "J": J,
            "nfev": 1,
            "njev": 1,
            "m": m,
            "n": n,
        }

        # Apply loss function if provided
        if loss_function is not None:
            rho = loss_function(f, f_scale)
            state["cost"] = self.calculate_cost(rho, data_mask)
            # Save original residuals before scaling (for res.fun)
            state["f_true"] = f
            state["J"], state["f"] = self.cJIT.scale_for_robust_loss_function(J, f, rho)
        else:
            state["cost"] = self.default_loss_func(f)
            # No scaling applied, so f is already the true residuals
            state["f_true"] = f

        # Compute gradient
        state["g"] = self.compute_grad(state["J"], state["f"])

        # Compute scaling factors
        jac_scale = isinstance(x_scale, str) and x_scale == "jac"
        if jac_scale:
            scale, scale_inv = self.cJIT.compute_jac_scale(J)
            state["scale"], state["scale_inv"] = scale, scale_inv
            state["jac_scale"] = True
        else:
            # T022: Convert scale to JAX arrays once at entry to avoid repeated conversion
            scale_arr = jnp.asarray(x_scale)
            # Guard against zero scale values to avoid Inf in scale_inv
            safe_scale = jnp.where(scale_arr == 0, 1.0, scale_arr)
            state["scale"], state["scale_inv"] = (
                scale_arr,
                jnp.asarray(1 / safe_scale),
            )
            state["jac_scale"] = False

        # Initialize trust region radius
        Delta = jnorm(x0 * state["scale_inv"])  # Use JAX norm
        state["Delta"] = Delta if Delta > 0 else 1.0

        return state

    def _check_convergence_criteria(
        self,
        g: jnp.ndarray,
        gtol: float,
    ) -> tuple[int | None, float]:
        """Check if gradient convergence criterion is met.

        This helper extracts convergence checking logic from trf_no_bounds,
        reducing complexity and improving readability.

        Parameters
        ----------
        g : jnp.ndarray
            Current gradient vector
        gtol : float
            Gradient tolerance for convergence

        Returns
        -------
        tuple[int | None, float]
            Tuple of (termination_status, g_norm):

            - termination_status: 1 if gradient tolerance satisfied, None otherwise
            - g_norm: Computed gradient norm (OPT-8: returned to avoid redundant
              computation)
        """
        # OPT-8: Compute g_norm once and return it to avoid redundant computation
        # Keep as JAX scalar to defer GPU→CPU sync until actually needed
        g_norm = jnorm(g, ord=jnp.inf)
        if g_norm < gtol:
            self.logger.debug(
                "Convergence: gradient tolerance satisfied",
                g_norm=float(g_norm),
                gtol=gtol,
            )
            return 1, g_norm
        return None, g_norm

    def _solve_trust_region_subproblem(
        self,
        J: jnp.ndarray,
        f: jnp.ndarray,
        g: jnp.ndarray,
        scale: jnp.ndarray,
        Delta: float,
        alpha: float,
        solver: str,
    ) -> dict[str, Any]:
        """Solve the trust region subproblem.

        This helper extracts the subproblem setup and solving logic,
        reducing complexity and improving readability.

        Parameters
        ----------
        J : jnp.ndarray
            Current Jacobian matrix
        f : jnp.ndarray
            Current residuals
        g : jnp.ndarray
            Current gradient
        scale : jnp.ndarray
            Parameter scaling factors
        Delta : float
            Current trust region radius
        alpha : float
            Levenberg-Marquardt parameter
        solver : str
            Solver type ('cg', 'sparse' or 'exact'); 'sparse' currently falls
            back to the dense SVD path

        Returns
        -------
        dict
            Subproblem solution containing:

            - J_h: Scaled Jacobian
            - g_h: Scaled gradient
            - d: Scaling vector
            - d_jnp: JAX scaling vector
            - step_h: Step in scaled space (for CG solver)
            - s, V, uf: SVD components (for exact solver)
        """
        # Setup scaled variables
        # T022: scale is already a JAX array (converted once in _initialize_trf_state)
        d = scale  # Already JAX array, no conversion needed
        g_h = self.compute_grad_hat(g, d)

        result = {
            "d": d,
            "d_jnp": d,  # Same as d now (backwards compatibility)
            "g_h": g_h,
        }

        # Solve trust region subproblem
        if solver == "cg":
            # Conjugate gradient solver
            J_h = J * d
            step_h = self.solve_tr_subproblem_cg(J, f, d, Delta, alpha)
            result.update(
                {
                    "J_h": J_h,
                    "step_h": step_h,
                    "s": None,
                    "V": None,
                    "uf": None,
                }
            )
        elif solver == "sparse":
            # Sparse solver path (Task 6.4: Sparse Activation)
            # TODO: Implement sparse SVD using JAX sparse operations
            # For now, fall back to dense exact solver to maintain correctness
            # Full sparse implementation would use:
            # - JAX sparse matrix operations for Jacobian
            # - Sparse QR or sparse SVD decomposition
            # - Iterative sparse linear solvers
            # Target: 3-10x speed, 5-50x memory reduction on sparse problems
            svd_output = self.svd_no_bounds(J, d, f)
            J_h = svd_output[0]
            s, V, uf = svd_output[2:]
            result.update(
                {
                    "J_h": J_h,
                    "step_h": None,
                    "s": s,
                    "V": V,
                    "uf": uf,
                }
            )
        else:
            # SVD-based exact solver (default dense)
            svd_output = self.svd_no_bounds(J, d, f)
            J_h = svd_output[0]
            # PERFORMANCE FIX: Keep arrays as JAX to avoid conversion overhead (8-12% gain)
            # JAX arrays work with NumPy operations through duck typing, eliminating
            # explicit array conversion reduces memory allocations and copies
            s, V, uf = svd_output[2:]  # Keep as JAX arrays instead of converting
            result.update(
                {
                    "J_h": J_h,
                    "step_h": None,  # Computed later in inner loop
                    "s": s,
                    "V": V,
                    "uf": uf,
                }
            )

        return result

    def _evaluate_step_acceptance(
        self,
        fun: Callable[..., Any],
        jac: Callable[..., Any],
        x: np.ndarray,
        f: jnp.ndarray,
        J: jnp.ndarray,
        J_h: jnp.ndarray,
        g_h_jnp: jnp.ndarray,
        cost: float,
        d: np.ndarray,
        d_jnp: jnp.ndarray,
        Delta: float,
        alpha: float,
        step_h: jnp.ndarray | None,
        s: np.ndarray | None,
        V: np.ndarray | None,
        uf: np.ndarray | None,
        xdata: np.ndarray,
        ydata: np.ndarray,
        data_mask: jnp.ndarray,
        transform: Callable[..., Any] | None,
        loss_function: Callable[..., Any] | None,
        f_scale: float,
        scale_inv: np.ndarray,
        jac_scale: bool,
        solver: str,
        ftol: float,
        xtol: float,
        max_nfev: int,
        nfev: int,
    ) -> dict[str, Any]:
        """Evaluate step acceptance through inner trust region loop.

        This method implements the inner loop of the TRF algorithm, which
        repeatedly solves the trust region subproblem and evaluates candidate
        steps until an acceptable step is found.

        Parameters
        ----------
        fun : Callable
            Function to evaluate residuals
        jac : Callable
            Function to evaluate Jacobian
        x : np.ndarray
            Current parameter values
        f : jnp.ndarray
            Current residuals (possibly scaled by loss function)
        J : jnp.ndarray
            Current Jacobian (possibly scaled by loss function)
        J_h : jnp.ndarray
            Scaled Jacobian for subproblem
        g_h_jnp : jnp.ndarray
            Scaled gradient for subproblem
        cost : float
            Current cost value
        d : np.ndarray
            Parameter scaling factors
        d_jnp : jnp.ndarray
            Parameter scaling factors (JAX array)
        Delta : float
            Trust region radius
        alpha : float
            Levenberg-Marquardt parameter
        step_h : jnp.ndarray | None
            Pre-computed step (for CG solver), None for exact solver
        s : np.ndarray | None
            SVD singular values (for exact solver), None for CG
        V : np.ndarray | None
            SVD V matrix (for exact solver), None for CG
        uf : np.ndarray | None
            SVD U^T @ f (for exact solver), None for CG
        xdata : np.ndarray
            Independent variable data
        ydata : np.ndarray
            Dependent variable data
        data_mask : jnp.ndarray
            Mask for valid data points
        transform : Callable | None
            Parameter transformation function
        loss_function : Callable | None
            Robust loss function
        f_scale : float
            Residual scale factor
        scale_inv : np.ndarray
            Inverse parameter scaling
        jac_scale : bool
            Whether using Jacobian-based scaling
        solver : str
            Trust region solver ('cg' or 'exact')
        ftol : float
            Cost function tolerance
        xtol : float
            Parameter tolerance
        max_nfev : int
            Maximum function evaluations
        nfev : int
            Current function evaluation count

        Returns
        -------
        dict
            Dictionary containing:

            - accepted : bool - Whether a step was accepted
            - x_new : np.ndarray - New parameter values (if accepted)
            - f_new : jnp.ndarray - New residuals (if accepted)
            - J_new : jnp.ndarray - New Jacobian (if accepted)
            - cost_new : float - New cost value (if accepted)
            - g_new : jnp.ndarray - New gradient (if accepted)
            - scale : np.ndarray - Updated parameter scaling (if accepted)
            - scale_inv : np.ndarray - Updated inverse scaling (if accepted)
            - actual_reduction : float - Actual cost reduction
            - step_norm : float - Step norm
            - Delta : float - Updated trust region radius
            - alpha : float - Updated Levenberg-Marquardt parameter
            - termination_status : int | None - Termination status code
            - nfev : int - Updated function evaluation count
            - njev : int - Jacobian evaluation count (1 if accepted, 0 otherwise)
        """
        n, m = len(x), len(f)
        actual_reduction = -1.0  # Python float, avoids JAX sync in while condition
        inner_loop_count = 0
        max_inner_iterations = 100
        termination_status = None
        step_norm = 0
        x_new = x
        f_new = f
        cost_new = cost

        while (
            actual_reduction <= 0
            and nfev < max_nfev
            and inner_loop_count < max_inner_iterations
        ):
            inner_loop_count += 1

            # Solve subproblem (reuse step or compute new one)
            if solver == "cg":
                if inner_loop_count > 1:
                    step_h = self.solve_tr_subproblem_cg(J, f, d_jnp, Delta, alpha)
                _n_iter = 1  # Dummy value for compatibility
            else:
                step_h, alpha, _n_iter = solve_lsq_trust_region_jax(
                    n, m, uf, s, V, Delta, initial_alpha=alpha
                )

            # Compute predicted reduction
            predicted_reduction_jnp = -self.cJIT.evaluate_quadratic(
                J_h, g_h_jnp, step_h
            )
            predicted_reduction = predicted_reduction_jnp

            # Transform step and evaluate objective
            # OPT-18: Fused step computation to reduce intermediate allocations
            step = d * step_h
            x_new = x + step  # Keep step for later use in convergence check
            f_new = fun(x_new, xdata, ydata, data_mask, transform)
            nfev += 1
            step_h_norm = jnorm(step_h)

            # Check for numerical issues
            if not self.check_isfinite(f_new):
                Delta = TR_REDUCTION_FACTOR * step_h_norm
                continue

            # Compute actual reduction
            if loss_function is not None:
                cost_new_jnp = loss_function(f_new, f_scale, data_mask, cost_only=True)
            else:
                cost_new_jnp = self.default_loss_func(f_new)
            cost_new = cost_new_jnp
            actual_reduction = cost - cost_new

            # Update trust region radius (JIT-compiled, stays on device)
            Delta_new, ratio = update_tr_radius_jax(
                Delta,
                actual_reduction,
                predicted_reduction,
                step_h_norm,
                step_h_norm > TR_BOUNDARY_THRESHOLD * Delta,
            )

            # Check termination criteria (JIT-compiled, stays on device)
            step_norm = jnorm(step)
            term_code = check_termination_jax(
                actual_reduction, cost, step_norm, jnorm(x), ratio, ftol, xtol
            )

            # Single GPU→CPU sync per inner iteration: materialize term_code
            # This sync also implicitly materializes actual_reduction and Delta_new
            term_code_int = int(term_code)
            # Materialize actual_reduction as Python float to avoid repeated JAX syncs
            # in the while-loop condition and the acceptance check below
            actual_reduction = float(actual_reduction)

            if term_code_int != 0:
                termination_status = term_code_int
                break

            if Delta_new > 0:
                raw_alpha = alpha * Delta / Delta_new
                alpha = min(raw_alpha if math.isfinite(raw_alpha) else 1e30, 1e30)
            Delta = Delta_new

            # Exit inner loop if we have a successful step
            if actual_reduction > 0:
                break

        # Check if inner loop hit iteration limit
        if inner_loop_count >= max_inner_iterations:
            self.logger.warning(
                "Inner optimization loop hit iteration limit",
                inner_iterations=inner_loop_count,
                actual_reduction=actual_reduction,
            )
            termination_status = -3  # Inner loop limit exceeded

        # Prepare result
        result: dict[str, Any] = {
            "accepted": actual_reduction > 0,
            "actual_reduction": actual_reduction,
            "step_norm": step_norm if actual_reduction > 0 else 0,
            "Delta": Delta,
            "alpha": alpha,
            "termination_status": termination_status,
            "nfev": nfev,
            "njev": 0,  # Will be set to 1 if step is accepted
        }

        # If step was accepted, compute new state
        if actual_reduction > 0:
            result.update(
                {
                    "x_new": x_new,
                    "f_new": f_new,
                    "cost_new": cost_new,
                    "njev": 1,
                }
            )

            # Compute new Jacobian
            J_new = jac(x_new, xdata, ydata, data_mask, transform)

            # Apply loss function if provided
            if loss_function is not None:
                rho = loss_function(f_new, f_scale)
                J_new, f_new_scaled = self.cJIT.scale_for_robust_loss_function(
                    J_new, f_new, rho
                )
                result["f_new"] = f_new_scaled  # Scaled residuals for optimization
                result["f_true_new"] = f_new  # Unscaled residuals for res.fun
            else:
                result["f_new"] = f_new
                result["f_true_new"] = f_new  # No scaling, so both are the same

            result["J_new"] = J_new

            # Compute new gradient
            g_new = self.compute_grad(J_new, result["f_new"])
            result["g_new"] = g_new

            # Update scaling if using Jacobian-based scaling
            if jac_scale:
                scale_new, scale_inv_new = self.cJIT.compute_jac_scale(J_new, scale_inv)
                result["scale"] = scale_new
                result["scale_inv"] = scale_inv_new

        return result

    def _invoke_callback(
        self,
        callback: Callable[..., Any],
        iteration: int,
        cost: float,
        x: np.ndarray,
        g_norm: float,
        nfev: int,
        step_norm: float | None,
        actual_reduction: float | None,
    ) -> int | None:
        """Invoke user callback with proper exception handling.

        This helper extracts callback handling logic from trf_no_bounds,
        reducing complexity.

        Parameters
        ----------
        callback : Callable
            User-provided callback function
        iteration : int
            Current iteration number
        cost : float
            Current cost value
        x : np.ndarray
            Current parameter values
        g_norm : float
            Gradient norm
        nfev : int
            Number of function evaluations
        step_norm : float | None
            Step norm (None if not computed)
        actual_reduction : float | None
            Actual cost reduction (None if not computed)

        Returns
        -------
        int | None
            Termination status if callback requested stop, None otherwise.
        """
        try:
            callback(
                iteration=iteration,
                cost=float(cost),  # JAX scalar -> Python float
                params=np.array(x),  # JAX array -> NumPy array
                info={
                    "gradient_norm": float(g_norm),
                    "nfev": nfev,
                    "step_norm": float(step_norm) if step_norm is not None else None,
                    "actual_reduction": float(actual_reduction)
                    if actual_reduction is not None
                    else None,
                },
            )
        except StopOptimization:
            self.logger.info("Optimization stopped by callback (StopOptimization)")
            return 2  # User-requested stop
        except Exception as e:
            warnings.warn(
                f"Callback raised exception: {e}. Continuing optimization.",
                RuntimeWarning,
            )
        return None

    def _initialize_bounds_state(
        self,
        x0: np.ndarray,
        f: jnp.ndarray,
        J: jnp.ndarray,
        g: jnp.ndarray,
        lb_jnp: jnp.ndarray,
        ub_jnp: jnp.ndarray,
        x_scale: np.ndarray | str,
    ) -> dict[str, Any]:
        """Initialize bounds-specific state for TRF algorithm.

        This helper extracts bounds initialization logic from trf_bounds,
        reducing complexity.

        Parameters
        ----------
        x0 : np.ndarray
            Initial parameter guess
        f : jnp.ndarray
            Initial residuals
        J : jnp.ndarray
            Initial Jacobian
        g : jnp.ndarray
            Initial gradient
        lb_jnp : jnp.ndarray
            Lower bounds (pre-converted to JAX)
        ub_jnp : jnp.ndarray
            Upper bounds (pre-converted to JAX)
        x_scale : np.ndarray | str
            Parameter scaling

        Returns
        -------
        dict
            State containing v, dv, scale, scale_inv, Delta, jac_scale
        """
        jac_scale = isinstance(x_scale, str) and x_scale == "jac"
        if jac_scale:
            scale, scale_inv = self.cJIT.compute_jac_scale(J)
        else:
            safe_scale = jnp.where(jnp.asarray(x_scale) == 0, 1.0, jnp.asarray(x_scale))
            scale, scale_inv = safe_scale, 1.0 / safe_scale

        v, dv = CL_scaling_vector_jax(x0, g, lb_jnp, ub_jnp)
        mask = dv != 0
        v = v.at[mask].set(v[mask] * scale_inv[mask])
        Delta = jnorm(x0 * scale_inv / v**SQRT_EXPONENT)
        if Delta == 0:
            Delta = 1.0

        return {
            "v": v,
            "dv": dv,
            "scale": scale,
            "scale_inv": scale_inv,
            "Delta": Delta,
            "jac_scale": jac_scale,
        }

    def _solve_bounds_subproblem(
        self,
        J: jnp.ndarray,
        f: jnp.ndarray,
        g: jnp.ndarray,
        v: jnp.ndarray,
        dv: jnp.ndarray,
        scale: np.ndarray,
        scale_inv: np.ndarray,
        Delta: float,
        alpha: float,
        solver: str,
        n: int,
    ) -> dict[str, Any]:
        """Solve trust region subproblem with bounds.

        This helper extracts bounds subproblem logic from trf_bounds,
        reducing complexity.

        Parameters
        ----------
        J : jnp.ndarray
            Current Jacobian
        f : jnp.ndarray
            Current residuals
        g : jnp.ndarray
            Current gradient
        v : jnp.ndarray
            Coleman-Li scaling vector
        dv : jnp.ndarray
            Derivative of v
        scale : np.ndarray
            Parameter scaling
        scale_inv : np.ndarray
            Inverse parameter scaling
        Delta : float
            Trust region radius
        alpha : float
            Levenberg-Marquardt parameter
        solver : str
            Solver type ("cg", "sparse" or "exact")
        n : int
            Number of parameters

        Returns
        -------
        dict
            Subproblem solution containing d, d_jnp, g_h, J_h, diag_h, p_h, s,
            V, uf, f_zeros and J_diag.
        """
        # Apply two types of scaling
        d = v**SQRT_EXPONENT * scale

        # C = diag(g * scale) Jv
        diag_h = g * dv * scale

        # "hat" gradient
        g_h = d * g

        J_diag = jnp.diag(jnp.sqrt(jnp.maximum(diag_h, 0.0)))
        # OPT-2: Use jnp.asarray() to avoid copy if already JAX array
        d_jnp = jnp.asarray(d)
        f_zeros = jnp.zeros(n, dtype=jnp.float64)

        if solver == "cg":
            J_h = J * d_jnp
            p_h = self.solve_tr_subproblem_cg_bounds(
                J, f, d_jnp, J_diag, f_zeros, Delta, alpha
            )
            s, V, uf = None, None, None
        elif solver == "sparse":
            # Sparse solver path - fall back to dense for correctness
            output = self.svd_bounds(f, J, d_jnp, J_diag, f_zeros)
            J_h = output[0]
            s, V, uf = output[2:]
            p_h = None
        else:
            # Exact SVD solver (default)
            output = self.svd_bounds(f, J, d_jnp, J_diag, f_zeros)
            J_h = output[0]
            s, V, uf = output[2:]
            p_h = None

        return {
            "d": d,
            "d_jnp": d_jnp,
            "g_h": g_h,
            "J_h": J_h,
            "diag_h": diag_h,
            "p_h": p_h,
            "s": s,
            "V": V,
            "uf": uf,
            "f_zeros": f_zeros,
            "J_diag": J_diag,
        }

    def _evaluate_bounds_inner_loop(
        self,
        fun: Callable[..., Any],
        x: np.ndarray,
        f: jnp.ndarray,
        J: jnp.ndarray,
        J_h: jnp.ndarray,
        g_h: jnp.ndarray,
        diag_h: jnp.ndarray,
        cost: float,
        d: np.ndarray,
        d_jnp: jnp.ndarray,
        Delta: float,
        alpha: float,
        p_h: jnp.ndarray | None,
        s: jnp.ndarray | None,
        V: jnp.ndarray | None,
        uf: jnp.ndarray | None,
        f_zeros: jnp.ndarray,
        J_diag: jnp.ndarray,
        xdata: np.ndarray,
        ydata: np.ndarray,
        data_mask: jnp.ndarray,
        transform: Callable[..., Any] | None,
        loss_function: Callable[..., Any] | None,
        f_scale: float,
        lb: np.ndarray,
        ub: np.ndarray,
        lb_jnp: jnp.ndarray,
        ub_jnp: jnp.ndarray,
        theta: float,
        solver: str,
        ftol: float,
        xtol: float,
        max_nfev: int,
        nfev: int,
        n: int,
        m: int,
    ) -> dict[str, Any]:
        """Evaluate inner loop for bounds optimization.

        This helper extracts the inner loop logic from trf_bounds, reducing
        complexity.

        Parameters
        ----------
        fun : Callable
            Residual function
        x : np.ndarray
            Current parameters
        f : jnp.ndarray
            Current residuals
        J : jnp.ndarray
            Current Jacobian
        J_h : jnp.ndarray
            Scaled Jacobian
        g_h : jnp.ndarray
            Scaled gradient
        diag_h : jnp.ndarray
            Diagonal of C = diag(g * scale) Jv
        cost : float
            Current cost
        d : np.ndarray
            Scaling vector
        d_jnp : jnp.ndarray
            JAX scaling vector
        Delta : float
            Trust region radius
        alpha : float
            LM parameter
        p_h : jnp.ndarray | None
            Pre-computed step (CG)
        s, V, uf : jnp.ndarray | None
            SVD components (exact solver; None for CG)
        f_zeros : jnp.ndarray
            Zero vector
        J_diag : jnp.ndarray
            Diagonal matrix with entries sqrt(max(diag_h, 0))
        xdata, ydata : np.ndarray
            Data arrays
        data_mask : jnp.ndarray
            Data mask
        transform : Callable | None
            Transform function
        loss_function : Callable | None
            Loss function
        f_scale : float
            Residual scale
        lb, ub : np.ndarray
            Bounds
        lb_jnp, ub_jnp : jnp.ndarray
            Bounds pre-converted to JAX
        theta : float
            Step-back ratio from the bounds
        solver : str
            Solver type
        ftol, xtol : float
            Tolerances
        max_nfev : int
            Max function evals
        nfev : int
            Current function evals
        n, m : int
            Problem dimensions (number of parameters, number of residuals)

        Returns
        -------
        dict
            Inner loop result with keys accepted, x_new, f_new, cost_new,
            actual_reduction, step_norm, Delta, alpha, termination_status and
            nfev.
        """
        actual_reduction = -1.0  # Python float, avoids JAX sync in while condition
        inner_loop_count = 0
        max_inner_iterations = 100
        termination_status = None
        step_norm = 0
        x_new = x
        f_new = f
        cost_new = cost

        while (
            actual_reduction <= 0
            and nfev < max_nfev
            and inner_loop_count < max_inner_iterations
        ):
            inner_loop_count += 1

            if solver == "cg":
                if inner_loop_count > 1:
                    p_h = self.solve_tr_subproblem_cg_bounds(
                        J, f, d_jnp, J_diag, f_zeros, Delta, alpha
                    )
                _n_iter = 1
            else:
                # Use m + n for the augmented system row count (J_augmented is (m+n) x n)
                p_h, alpha, _n_iter = solve_lsq_trust_region_jax(
                    n, m + n, uf, s, V, Delta, initial_alpha=alpha
                )

            p = d * p_h
            step, step_h, predicted_reduction = self.select_step(
                x,
                J_h,
                diag_h,
                g_h,
                p,
                p_h,
                d,
                Delta,
                lb,
                ub,
                theta,
                lb_jnp=lb_jnp,
                ub_jnp=ub_jnp,
            )

            x_new = make_strictly_feasible_jax(x + step, lb_jnp, ub_jnp)
            f_new = fun(x_new, xdata, ydata, data_mask, transform)
            nfev += 1

            step_h_norm = jnorm(step_h)

            if not self.check_isfinite(f_new):
                Delta = 0.25 * step_h_norm
                continue

            if loss_function is not None:
                cost_new = loss_function(f_new, f_scale, data_mask, cost_only=True)
            else:
                cost_new = self.default_loss_func(f_new)

            actual_reduction = cost - cost_new
            Delta_new, ratio = update_tr_radius_jax(
                Delta,
                actual_reduction,
                predicted_reduction,
                step_h_norm,
                step_h_norm > 0.95 * Delta,
            )

            step_norm = jnorm(step)
            term_code = check_termination_jax(
                actual_reduction, cost, step_norm, jnorm(x), ratio, ftol, xtol
            )
            # Single GPU→CPU sync per inner iteration: materialize term_code
            term_code_int = int(term_code)
            # Piggyback float materialization to avoid JAX sync in while/if checks
            actual_reduction = float(actual_reduction)

            if term_code_int != 0:
                termination_status = term_code_int
                break

            if Delta_new > 0:
                raw_alpha = alpha * Delta / Delta_new
                alpha = min(raw_alpha if math.isfinite(raw_alpha) else 1e30, 1e30)
            Delta = Delta_new

        # Check inner loop limit
        if inner_loop_count >= max_inner_iterations:
            self.logger.warning(
                "Inner optimization loop hit iteration limit",
                inner_iterations=inner_loop_count,
                actual_reduction=actual_reduction,
            )
            termination_status = -3

        return {
            "accepted": actual_reduction > 0,
            "x_new": x_new if actual_reduction > 0 else x,
            "f_new": f_new if actual_reduction > 0 else f,
            "cost_new": cost_new if actual_reduction > 0 else cost,
            "actual_reduction": actual_reduction,
            "step_norm": step_norm if actual_reduction > 0 else 0,
            "Delta": Delta,
            "alpha": alpha,
            "termination_status": termination_status,
            "nfev": nfev,
        }

    def _apply_accepted_step(
        self,
        acceptance_result: dict[str, Any],
        jac_scale: bool,
        njev: int,
    ) -> dict[str, Any]:
        """Apply accepted step updates to optimization state.

        This helper extracts the state update logic after step acceptance from
        trf_no_bounds, reducing complexity.

        Parameters
        ----------
        acceptance_result : dict
            Result from _evaluate_step_acceptance
        jac_scale : bool
            Whether Jacobian scaling is enabled
        njev : int
            Current Jacobian evaluation count

        Returns
        -------
        dict
            Updated state variables: x, f, f_true, J, cost, g, njev, and
            optionally scale, scale_inv.
        """
        result = {
            "x": acceptance_result["x_new"],
            "f": acceptance_result["f_new"],
            "f_true": acceptance_result["f_true_new"],
            "J": acceptance_result["J_new"],
            "cost": acceptance_result["cost_new"],
            "g": acceptance_result["g_new"],
            "njev": njev + acceptance_result["njev"],
        }

        if jac_scale and "scale" in acceptance_result:
            result["scale"] = acceptance_result["scale"]
            result["scale_inv"] = acceptance_result["scale_inv"]

        return result

    def _build_optimize_result(
        self,
        x: jnp.ndarray,
        cost: float,
        f_true: jnp.ndarray,
        J: jnp.ndarray,
        g: jnp.ndarray,
        g_norm: float,
        nfev: int,
        njev: int,
        iteration: int,
        termination_status: int,
    ) -> OptimizeResult:
        """Build OptimizeResult from optimization state.

        This helper extracts result construction logic from trf_no_bounds,
        reducing complexity. trf_bounds builds its own result so that it can
        report the active constraints.

        Parameters
        ----------
        x : jnp.ndarray
            Final parameter values
        cost : float
            Final cost value
        f_true : jnp.ndarray
            Final residuals (unscaled)
        J : jnp.ndarray
            Final Jacobian matrix
        g : jnp.ndarray
            Final gradient
        g_norm : float
            Final gradient norm
        nfev : int
            Total function evaluations
        njev : int
            Total Jacobian evaluations
        iteration : int
            Total iterations performed
        termination_status : int
            Termination status code

        Returns
        -------
        OptimizeResult
            The optimization result object.
        """
        active_mask = jnp.zeros_like(x)  # JAX zeros instead of NumPy

        return OptimizeResult(
            x=x,
            cost=float(cost),  # Convert JAX scalar to Python float
            fun=f_true,
            jac=J,
            grad=np.array(g),  # Convert JAX array to NumPy
            optimality=float(g_norm),  # Convert JAX scalar to Python float
            active_mask=active_mask,
            nfev=nfev,
            njev=njev,
            nit=iteration,  # Number of iterations performed
            status=termination_status,
            success=termination_status > 0,
            all_times={},
        )
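The bounded helpers build the Coleman-Li vector `v` and its derivative `dv` through `CL_scaling_vector_jax`, whose body is not part of this listing. Below is a plain-Python sketch. It assumes that function follows the definition of `v` in the module docstring, with `dv` equal to -1, 1 or 0 as in SciPy's `CL_scaling_vector`. The sketch also mirrors two uses shown above: the optimality measure `||g * v||_inf` in `trf_bounds`, and the initial radius in `_initialize_bounds_state`, where `SQRT_EXPONENT` is assumed to be 0.5. The function names are hypothetical.

```python
import math


def cl_scaling_vector(x, g, lb, ub):
    """Hypothetical pure-Python Coleman-Li vector: returns (v, dv)."""
    v, dv = [], []
    for xi, gi, lo, hi in zip(x, g, lb, ub):
        if gi < 0 and math.isfinite(hi):
            # Anti-gradient points toward a finite upper bound.
            v.append(hi - xi)
            dv.append(-1.0)
        elif gi > 0 and math.isfinite(lo):
            # Anti-gradient points toward a finite lower bound.
            v.append(xi - lo)
            dv.append(1.0)
        else:
            # No finite bound in the descent direction: leave unscaled.
            v.append(1.0)
            dv.append(0.0)
    return v, dv


def scaled_grad_norm(g, v):
    """First-order optimality measure ||g * v||_inf, as computed in trf_bounds."""
    return max(abs(gi * vi) for gi, vi in zip(g, v))


def initial_radius(x0, v, dv, scale_inv):
    """Initial Delta as in _initialize_bounds_state (assumes SQRT_EXPONENT = 0.5)."""
    # Rescale v only where dv != 0, like v.at[mask].set(v[mask] * scale_inv[mask]).
    v_scaled = [vi * si if di != 0 else vi for vi, di, si in zip(v, dv, scale_inv)]
    delta = math.sqrt(
        sum((xi * si / math.sqrt(vi)) ** 2 for xi, si, vi in zip(x0, scale_inv, v_scaled))
    )
    # Fall back to a unit radius when x0 is the origin.
    return delta if delta != 0 else 1.0


x = [0.5, 2.0, 0.0]
g = [-1.0, 3.0, 0.0]
lb = [0.0, 1.0, -math.inf]
ub = [1.0, math.inf, math.inf]
v, dv = cl_scaling_vector(x, g, lb, ub)  # v = [0.5, 1.0, 1.0], dv = [-1.0, 1.0, 0.0]
print(scaled_grad_norm(g, v))  # 3.0
print(initial_radius(x, v, dv, [1.0, 1.0, 1.0]))  # sqrt(4.5), about 2.1213
```

The third component has a zero gradient, so it keeps `v = 1` and `dv = 0`. Its scaling is untouched by `scale_inv`.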
    def trf_no_bounds(
        self,
        fun: Callable[..., Any],
        xdata: jnp.ndarray | tuple[jnp.ndarray],
        ydata: jnp.ndarray,
        jac: Callable[..., Any],
        data_mask: jnp.ndarray,
        transform: jnp.ndarray,
        x0: np.ndarray,
        f: jnp.ndarray,
        J: jnp.ndarray,
        lb: np.ndarray,
        ub: np.ndarray,
        ftol: float,
        xtol: float,
        gtol: float,
        max_nfev: int,
        f_scale: float,
        x_scale: np.ndarray,
        loss_function: None | Callable[..., Any],
        tr_options: dict[str, Any],
        verbose: int,
        solver: str = "exact",
        callback: Callable[..., Any] | None = None,
        profiler: TRFProfiler | NullProfiler | None = None,
        **kwargs: Any,
    ) -> OptimizeResult:
        """Unbounded version of the trust-region reflective algorithm.

        Parameters
        ----------
        fun : callable
            The residual function.
        xdata : array_like or tuple of array_like
            The independent variable where the data is measured. If `xdata` is
            a tuple, then the input arguments to `fun` are assumed to be
            ``(xdata[0], xdata[1], ...)``.
        ydata : jnp.ndarray
            The dependent data.
        jac : callable
            The Jacobian of `fun`.
        data_mask : jnp.ndarray
            The mask for the data.
        transform : jnp.ndarray
            The uncertainty transform for the data.
        x0 : np.ndarray
            Initial guess. Array of real elements of size (n,), where 'n' is
            the number of independent variables.
        f : jnp.ndarray
            Initial residuals. Array of real elements of size (m,), where 'm'
            is the number of data points.
        J : jnp.ndarray
            Initial Jacobian. Array of real elements of size (m, n), where 'm'
            is the number of data points and 'n' is the number of independent
            variables.
        lb : np.ndarray
            Lower bounds on independent variables. Array of real elements of
            size (n,), where 'n' is the number of independent variables.
        ub : np.ndarray
            Upper bounds on independent variables. Array of real elements of
            size (n,), where 'n' is the number of independent variables.
        ftol : float
            Tolerance for termination by the change of the cost function.
        xtol : float
            Tolerance for termination by the change of the independent
            variables.
        gtol : float
            Tolerance for termination by the norm of the gradient.
        max_nfev : int
            Maximum number of function evaluations.
        f_scale : float
            Residual scale passed to `loss_function`.
        x_scale : np.ndarray or str
            Scaling factors for independent variables, or ``"jac"`` for
            Jacobian-based scaling.
        loss_function : callable, optional
            Loss function. If None, the standard least-squares problem is
            solved.
        tr_options : dict
            Options for the trust-region algorithm.
        verbose : int
            Level of algorithm's verbosity:

            * 0 : work silently.
            * 1 : display a termination report.
            * 2 : display progress during iterations.
        solver : str, optional
            Trust-region subproblem solver. Default is "exact".
        callback : callable, optional
            Called after every iteration as
            ``callback(iteration=..., cost=..., params=..., info=...)``, where
            ``info`` contains ``gradient_norm``, ``nfev``, ``step_norm`` and
            ``actual_reduction``. Raising ``StopOptimization`` stops the
            optimization with status 2; any other exception is reported as a
            ``RuntimeWarning`` and ignored.
        profiler : TRFProfiler, NullProfiler, or None, optional
            Profiler for timing algorithm operations. If None, uses
            NullProfiler (zero overhead). Use TRFProfiler() for detailed
            performance analysis. Default is None.

        Returns
        -------
        result : OptimizeResult
            The optimization result represented as an ``OptimizeResult``
            object. Important attributes are: ``x`` the solution array,
            ``success`` a Boolean flag indicating if the optimizer exited
            successfully and ``message`` which describes the cause of the
            termination. See `OptimizeResult` for a description of other
            attributes.

        Notes
        -----
        The algorithm is described in [13]_.
        """
        # Initialize profiler (NullProfiler if not provided for zero overhead)
        if profiler is None:
            profiler = NullProfiler()

        # Initialize optimization state using helper
        state = self._initialize_trf_state(
            x0=x0,
            f=f,
            J=J,
            loss_function=loss_function,
            x_scale=x_scale,
            f_scale=f_scale,
            data_mask=data_mask,
        )

        # Extract state variables
        x = state["x"]
        f = state["f"]
        J = state["J"]
        cost = state["cost"]
        g = state["g"]
        scale = state["scale"]
        scale_inv = state["scale_inv"]
        Delta = state["Delta"]
        nfev = state["nfev"]
        njev = state["njev"]
        m = state["m"]
        n = state["n"]
        jac_scale = state["jac_scale"]
        f_true = state["f_true"]  # Original unscaled residuals (for res.fun)

        # Log optimization start
        self.logger.info(
            "Starting TRF optimization (no bounds)",
            n_params=n,
            n_residuals=m,
            max_nfev=max_nfev,
        )

        # Set max_nfev if not provided
        if max_nfev is None:
            max_nfev = x0.size * DEFAULT_MAX_NFEV_MULTIPLIER

        alpha = INITIAL_LEVENBERG_MARQUARDT_LAMBDA  # "Levenberg-Marquardt" parameter

        termination_status = None
        iteration = 0
        step_norm = None
        actual_reduction = None

        if verbose == 2:
            print_header_nonlinear()

        # Trust region optimization loop
        with self.logger.timer("optimization", log_result=False):
            while True:
                # Check gradient convergence using helper (only if not already terminated)
                # OPT-8: Get g_norm from convergence check to avoid redundant computation
                if termination_status is None:
                    termination_status, g_norm = self._check_convergence_criteria(
                        g, gtol
                    )
                else:
                    g_norm = jnorm(g, ord=jnp.inf)  # Only compute if already terminated

                if verbose == 2:
                    # Use jax.debug.callback to avoid blocking host-device transfers
                    debug.callback(
                        self._log_iteration_callback,
                        iteration,
                        nfev,
                        cost,
                        actual_reduction,
                        step_norm,
                        g_norm,
                    )

                if termination_status is not None or nfev == max_nfev:
                    if nfev == max_nfev:
                        self.logger.warning(
                            "Maximum number of function evaluations reached", nfev=nfev
                        )
                    break

                # Log iteration details (call-site guard avoids kwargs construction)
                if self.logger.logger.isEnabledFor(PERFORMANCE_LEVEL):
                    self.logger.optimization_step(
                        iteration=iteration,
                        cost=cost,
                        gradient_norm=g_norm,
                        step_size=Delta if iteration > 0 else None,
                        nfev=nfev,
                    )

                # Solve trust region subproblem using helper
                subproblem_result = self._solve_trust_region_subproblem(
                    J=J,
                    f=f,
                    g=g,
                    scale=scale,
                    Delta=Delta,
                    alpha=alpha,
                    solver=solver,
                )

                # Extract subproblem solution
                d = subproblem_result["d"]
                d_jnp = subproblem_result["d_jnp"]
                g_h_jnp = subproblem_result["g_h"]
                J_h = subproblem_result["J_h"]
                step_h = subproblem_result["step_h"]
                s = subproblem_result["s"]
                V = subproblem_result["V"]
                uf = subproblem_result["uf"]

                # Evaluate and potentially accept step using helper
                acceptance_result = self._evaluate_step_acceptance(
                    fun=fun,
                    jac=jac,
                    x=x,
                    f=f,
                    J=J,
                    J_h=J_h,
                    g_h_jnp=g_h_jnp,
                    cost=cost,
                    d=d,
                    d_jnp=d_jnp,
                    Delta=Delta,
                    alpha=alpha,
                    step_h=step_h,
                    s=s,
                    V=V,
                    uf=uf,
                    xdata=xdata,
                    ydata=ydata,
                    data_mask=data_mask,
                    transform=transform,
                    loss_function=loss_function,
                    f_scale=f_scale,
                    scale_inv=scale_inv,
                    jac_scale=jac_scale,
                    solver=solver,
                    ftol=ftol,
                    xtol=xtol,
                    max_nfev=max_nfev,
                    nfev=nfev,
                )

                # Update state from acceptance result using helper
                if acceptance_result["accepted"]:
                    step_update = self._apply_accepted_step(
                        acceptance_result=acceptance_result,
                        jac_scale=jac_scale,
                        njev=njev,
                    )
                    x = step_update["x"]
                    f = step_update["f"]
                    f_true = step_update["f_true"]
                    J = step_update["J"]
                    cost = step_update["cost"]
                    g = step_update["g"]
                    njev = step_update["njev"]
                    if "scale" in step_update:
                        scale = step_update["scale"]
                        scale_inv = step_update["scale_inv"]

                # Update common values regardless of acceptance
                actual_reduction = acceptance_result["actual_reduction"]
                step_norm = acceptance_result["step_norm"]
                Delta = acceptance_result["Delta"]
                alpha = acceptance_result["alpha"]
                nfev = acceptance_result["nfev"]

                if acceptance_result["termination_status"] is not None:
                    termination_status = acceptance_result["termination_status"]

                iteration += 1

                # Invoke user callback if provided using helper
                if callback is not None:
                    callback_status = self._invoke_callback(
                        callback=callback,
                        iteration=iteration,
                        cost=cost,
                        x=x,
                        g_norm=g_norm,
                        nfev=nfev,
                        step_norm=step_norm,
                        actual_reduction=actual_reduction,
                    )
                    if callback_status is not None:
                        termination_status = callback_status
                        break

        if termination_status is None:
            termination_status = 0

        # Build and return final result using helper
        return self._build_optimize_result(
            x=x,
            cost=cost,
            f_true=f_true,
            J=J,
            g=g,
            g_norm=g_norm,
            nfev=nfev,
            njev=njev,
            iteration=iteration,
            termination_status=termination_status,
        )
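Both inner loops stop early when `check_termination_jax` returns a non-zero code, but its body is not shown here. The sketch below assumes it applies the same ftol/xtol tests as SciPy's `check_termination`. The one difference assumed is that it returns 0 instead of `None` when neither test passes, since the callers compare `term_code_int != 0`. `check_termination` is a hypothetical host-side stand-in.

```python
def check_termination(dF, F, dx_norm, x_norm, ratio, ftol, xtol):
    """Hypothetical host-side ftol/xtol test.

    Returns 0 to continue, 2 if ftol is met, 3 if xtol is met, 4 if both are met.
    """
    # Relative cost decrease is small AND the quadratic model was reasonably accurate.
    ftol_satisfied = dF < ftol * F and ratio > 0.25
    # Step is small relative to the current parameter norm.
    xtol_satisfied = dx_norm < xtol * (xtol + x_norm)
    if ftol_satisfied and xtol_satisfied:
        return 4
    if ftol_satisfied:
        return 2
    if xtol_satisfied:
        return 3
    return 0


print(check_termination(1e-12, 1.0, 0.5, 1.0, 0.9, 1e-8, 1e-8))  # 2
```

This module also sets other status values:

- 1 when `g_norm < gtol` in `trf_bounds`.
- 2 when a callback raises `StopOptimization`.
- -3 when the inner loop hits its iteration limit.
- 0 when the outer loop ends without a status, for example at `max_nfev`.

Success is reported as `status > 0`.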
    def trf_bounds(
        self,
        fun: Callable[..., Any],
        xdata: jnp.ndarray | tuple[jnp.ndarray],
        ydata: jnp.ndarray,
        jac: Callable[..., Any],
        data_mask: jnp.ndarray,
        transform: jnp.ndarray,
        x0: np.ndarray,
        f: jnp.ndarray,
        J: jnp.ndarray,
        lb: np.ndarray,
        ub: np.ndarray,
        ftol: float,
        xtol: float,
        gtol: float,
        max_nfev: int,
        f_scale: float,
        x_scale: np.ndarray,
        loss_function: None | Callable[..., Any],
        tr_options: dict[str, Any],
        verbose: int,
        solver: str = "exact",
        callback: Callable[..., Any] | None = None,
        **kwargs: Any,
    ) -> OptimizeResult:
        """Bounded version of the trust-region reflective algorithm.

        Parameters
        ----------
        fun : callable
            The residual function.
        xdata : array_like or tuple of array_like
            The independent variable where the data is measured. If `xdata` is
            a tuple, then the input arguments to `fun` are assumed to be
            ``(xdata[0], xdata[1], ...)``.
        ydata : jnp.ndarray
            The dependent data.
        jac : callable
            The Jacobian of `fun`.
        data_mask : jnp.ndarray
            The mask for the data.
        transform : jnp.ndarray
            The uncertainty transform for the data.
        x0 : np.ndarray
            Initial guess. Array of real elements of size (n,), where 'n' is
            the number of independent variables.
        f : jnp.ndarray
            Initial residuals. Array of real elements of size (m,), where 'm'
            is the number of data points.
        J : jnp.ndarray
            Initial Jacobian. Array of real elements of size (m, n), where 'm'
            is the number of data points and 'n' is the number of independent
            variables.
        lb : np.ndarray
            Lower bounds on independent variables. Array of real elements of
            size (n,), where 'n' is the number of independent variables.
        ub : np.ndarray
            Upper bounds on independent variables. Array of real elements of
            size (n,), where 'n' is the number of independent variables.
        ftol : float
            Tolerance for termination by the change of the cost function.
        xtol : float
            Tolerance for termination by the change of the independent
            variables.
        gtol : float
            Tolerance for termination by the norm of the gradient.
        max_nfev : int
            Maximum number of function evaluations.
        f_scale : float
            Residual scale passed to `loss_function`.
        x_scale : np.ndarray or str
            Scaling factors for independent variables, or ``"jac"`` for
            Jacobian-based scaling.
        loss_function : callable, optional
            Loss function. If None, the standard least-squares problem is
            solved.
        tr_options : dict
            Options for the trust-region algorithm.
        verbose : int
            Level of algorithm's verbosity:

            * 0 : work silently.
            * 1 : display a termination report.
            * 2 : display progress during iterations.
        solver : str, optional
            Trust-region subproblem solver ("exact", "cg" or "sparse"; the
            sparse path currently falls back to the dense SVD solver).
            Default is "exact".
        callback : callable, optional
            Called after every iteration with the same keyword arguments and
            stopping rules as in `trf_no_bounds`.

        Returns
        -------
        result : OptimizeResult
            The optimization result represented as an ``OptimizeResult``
            object. Important attributes are: ``x`` the solution array,
            ``success`` a Boolean flag indicating if the optimizer exited
            successfully and ``message`` which describes the cause of the
            termination. See `OptimizeResult` for a description of other
            attributes.

        Notes
        -----
        The algorithm is described in [13]_; the handling of bounds follows
        [2]_.

        References
        ----------
        .. [13] J. J. Moré, "The Levenberg-Marquardt Algorithm: Implementation
           and Theory," in Numerical Analysis, ed. G. A. Watson, Lecture Notes
           in Mathematics, vol. 630, Springer, 1978, pp. 105-116.
           DOI: 10.1007/BFb0067700
        .. [2] T. F. Coleman and Y. Li, "An interior trust region approach for
           nonlinear minimization subject to bounds," SIAM Journal on
           Optimization, vol. 6, no. 2, pp. 418-445, 1996.
        """
        x = x0
        f_true = f
        nfev = 1
        njev = 1
        m, n = J.shape

        # Convert bounds to JAX once at entry to avoid repeated transfers
        lb_jnp = jnp.asarray(lb)
        ub_jnp = jnp.asarray(ub)

        if loss_function is not None:
            rho = loss_function(f, f_scale)
            cost = self.calculate_cost(rho, data_mask)
            J, f = self.cJIT.scale_for_robust_loss_function(J, f, rho)
        else:
            cost = self.default_loss_func(f)

        g = self.compute_grad(J, f)

        # Initialize bounds state using helper
        bounds_state = self._initialize_bounds_state(
            x0, f, J, g, lb_jnp, ub_jnp, x_scale
        )
        v = bounds_state["v"]
        dv = bounds_state["dv"]
        scale = bounds_state["scale"]
        scale_inv = bounds_state["scale_inv"]
        Delta = bounds_state["Delta"]
        jac_scale = bounds_state["jac_scale"]

        # Use JAX norm for gradient norm calculation
        g_norm = jnorm(g * v, ord=jnp.inf)

        if max_nfev is None:
            max_nfev = x0.size * DEFAULT_MAX_NFEV_MULTIPLIER

        alpha = INITIAL_LEVENBERG_MARQUARDT_LAMBDA

        termination_status = None
        iteration = 0
        step_norm = None
        actual_reduction = None

        if verbose == 2:
            print_header_nonlinear()

        while True:
            v, dv = CL_scaling_vector_jax(x, g, lb_jnp, ub_jnp)

            g_norm = jnorm(g * v, ord=jnp.inf)
            if termination_status is None and g_norm < gtol:
                termination_status = 1

            if verbose == 2:
                debug.callback(
                    self._log_iteration_callback,
                    iteration,
                    nfev,
                    cost,
                    actual_reduction,
                    step_norm,
                    g_norm,
                )

            if termination_status is not None or nfev == max_nfev:
                break

            # Update v with scaling
            mask = dv != 0
            v = v.at[mask].set(v[mask] * scale_inv[mask])

            # Solve bounds subproblem using helper
            subproblem = self._solve_bounds_subproblem(
                J=J,
                f=f,
                g=g,
                v=v,
                dv=dv,
                scale=scale,
                scale_inv=scale_inv,
                Delta=Delta,
                alpha=alpha,
                solver=solver,
                n=n,
            )

            # theta controls the step-back ratio from the bounds; it must be < 1
            # to maintain the strict interior feasibility required by TRF bounds mode
            theta = jnp.minimum(jnp.maximum(0.995, 1.0 - g_norm), 1.0 - 1e-8)

            # Evaluate inner loop using helper
            inner_result = self._evaluate_bounds_inner_loop(
                fun=fun,
                x=x,
                f=f,
                J=J,
                J_h=subproblem["J_h"],
                g_h=subproblem["g_h"],
                diag_h=subproblem["diag_h"],
                cost=cost,
                d=subproblem["d"],
                d_jnp=subproblem["d_jnp"],
                Delta=Delta,
                alpha=alpha,
                p_h=subproblem["p_h"],
                s=subproblem["s"],
                V=subproblem["V"],
                uf=subproblem["uf"],
                f_zeros=subproblem["f_zeros"],
                J_diag=subproblem["J_diag"],
                xdata=xdata,
                ydata=ydata,
                data_mask=data_mask,
                transform=transform,
                loss_function=loss_function,
                f_scale=f_scale,
                lb=lb,
                ub=ub,
                lb_jnp=lb_jnp,
                ub_jnp=ub_jnp,
                theta=theta,
                solver=solver,
                ftol=ftol,
                xtol=xtol,
                max_nfev=max_nfev,
                nfev=nfev,
                n=n,
                m=m,
            )

            # Update from inner loop result
            actual_reduction = inner_result["actual_reduction"]
            step_norm = inner_result["step_norm"]
            Delta = inner_result["Delta"]
            alpha = inner_result["alpha"]
            nfev = inner_result["nfev"]

            if inner_result["termination_status"] is not None:
                termination_status = inner_result["termination_status"]

            if inner_result["accepted"]:
                x = inner_result["x_new"]
                f_unscaled = inner_result["f_new"]
                f_true = f_unscaled
                cost = inner_result["cost_new"]

                J = jac(x, xdata, ydata, data_mask, transform)
                njev += 1

                if loss_function is not None:
                    rho = loss_function(f_unscaled, f_scale)
                    J, f = self.cJIT.scale_for_robust_loss_function(J, f_unscaled, rho)
                else:
                    f = f_unscaled

                g = self.compute_grad(J, f)

                if jac_scale:
                    scale, scale_inv = self.cJIT.compute_jac_scale(J, scale_inv)
            else:
                step_norm = 0
                # actual_reduction is already set from inner_result (negative value);
                # retaining it lets callbacks report the true optimization state

            iteration += 1

            # Invoke user callback using helper
            if callback is not None:
                callback_status = self._invoke_callback(
                    callback=callback,
                    iteration=iteration,
                    cost=cost,
                    x=x,
                    g_norm=g_norm,
                    nfev=nfev,
                    step_norm=step_norm,
                    actual_reduction=actual_reduction,
                )
                if callback_status is not None:
                    termination_status = callback_status
                    break

        if termination_status is None:
            termination_status = 0

        x_out = np.asarray(x)
        active_mask = find_active_constraints(x_out, lb, ub, rtol=xtol)

        return OptimizeResult(
            x=x_out,
            cost=float(cost),
            fun=f_true,
            jac=J,
            grad=np.array(g),
            optimality=float(g_norm),
            active_mask=active_mask,
            nfev=nfev,
            njev=njev,
            nit=iteration,
            status=termination_status,
            success=termination_status > 0,
        )
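After each trial step, the loops call `update_tr_radius_jax(Delta, actual_reduction, predicted_reduction, step_h_norm, step_h_norm > 0.95 * Delta)` and then rescale `alpha` by `Delta / Delta_new`. The radius rule itself is not in this listing. This sketch assumes it matches SciPy's `update_tr_radius`:

- Shrink the radius to a quarter of the step when the reduction ratio is below 0.25.
- Double the radius when the ratio exceeds 0.75 and the step reached the boundary.

The `alpha` update copies the guard used in the inner loops above. Both function names are hypothetical.

```python
import math


def update_tr_radius(Delta, actual_reduction, predicted_reduction, step_norm, bound_hit):
    """Hypothetical host-side radius update; returns (Delta_new, ratio)."""
    # Ratio of actual to predicted reduction, with SciPy's edge-case conventions.
    if predicted_reduction > 0:
        ratio = actual_reduction / predicted_reduction
    elif predicted_reduction == actual_reduction == 0:
        ratio = 1.0
    else:
        ratio = 0.0

    if ratio < 0.25:
        # The model predicted poorly: shrink to a quarter of the attempted step.
        Delta = 0.25 * step_norm
    elif ratio > 0.75 and bound_hit:
        # The model was accurate and the step reached the region boundary: expand.
        Delta *= 2.0
    return Delta, ratio


def rescale_alpha(alpha, Delta, Delta_new, cap=1e30):
    """Rescale the Levenberg-Marquardt parameter when the radius changes."""
    if Delta_new > 0:
        raw_alpha = alpha * Delta / Delta_new
        # Clamp overflow (inf) or huge values to the cap.
        alpha = min(raw_alpha if math.isfinite(raw_alpha) else cap, cap)
    return alpha


Delta = 1.0
step_h_norm = 0.98
Delta_new, ratio = update_tr_radius(Delta, 0.9, 1.0, step_h_norm, step_h_norm > 0.95 * Delta)
print(Delta_new, ratio)  # 2.0 0.9
print(rescale_alpha(1.0, Delta, Delta_new))  # 0.5
```

Scaling `alpha` by `Delta / Delta_new` keeps the damping roughly consistent with the new radius. A larger region gets a smaller Levenberg-Marquardt parameter, so the next subproblem solve starts from a sensible guess.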
    def select_step(
        self,
        x: np.ndarray,
        J_h: jnp.ndarray,
        diag_h: jnp.ndarray,
        g_h: jnp.ndarray,
        p: np.ndarray,
        p_h: np.ndarray,
        d: np.ndarray,
        Delta: float,
        lb: np.ndarray,
        ub: np.ndarray,
        theta: float,
        lb_jnp: jnp.ndarray | None = None,
        ub_jnp: jnp.ndarray | None = None,
    ) -> tuple[np.ndarray, Any, Any]:
        """Select the best step according to the Trust Region Reflective algorithm.

        Parameters
        ----------
        x : np.ndarray
            Current parameter vector.
        J_h : jnp.ndarray
            Jacobian matrix in the scaled 'hat' space.
        diag_h : jnp.ndarray
            Diagonal of the scaled matrix C = diag(g * scale) Jv.
        g_h : jnp.ndarray
            Gradient vector in the scaled 'hat' space.
        p : np.ndarray
            Trust-region step in the original space.
        p_h : np.ndarray
            Trust-region step in the scaled 'hat' space.
        d : np.ndarray
            Scaling vector.
        Delta : float
            Trust-region radius.
        lb : np.ndarray
            Lower bounds on variables.
        ub : np.ndarray
            Upper bounds on variables.
        theta : float
            Step-back ratio (< 1) applied to steps that would reach a bound, so
            that the new iterate stays strictly interior.
        lb_jnp : jnp.ndarray, optional
            Pre-converted JAX lower bounds (avoids repeated conversion).
        ub_jnp : jnp.ndarray, optional
            Pre-converted JAX upper bounds (avoids repeated conversion).

        Returns
        -------
        step : np.ndarray
            Step in the original space.
        step_h : np.ndarray
            Step in the scaled 'hat' space.
        predicted_reduction : float
            Predicted reduction in the cost function.
        """
        if lb_jnp is not None and ub_jnp is not None:
            _in_bounds = in_bounds_jax(x + p, lb_jnp, ub_jnp)
        else:
            _in_bounds = in_bounds(x + p, lb, ub)

        if _in_bounds:
            p_value = self.cJIT.evaluate_quadratic(J_h, g_h, p_h, diag=diag_h)
            return p, p_h, -p_value

        # B004: Use JAX-compiled versions to avoid D2H array transfers.
        # Convert inputs to JAX arrays for on-device computation.
        x_jnp = jnp.asarray(x)
        p_jnp = jnp.asarray(p)
        d_jnp = jnp.asarray(d)
        lb_j = lb_jnp if lb_jnp is not None else jnp.asarray(lb)
        ub_j = ub_jnp if ub_jnp is not None else jnp.asarray(ub)

        p_stride_jax, hits = step_size_to_bound_jax(x_jnp, p_jnp, lb_j, ub_j)
        p_stride = float(p_stride_jax)

        # Compute the reflected direction.
        r_h = jnp.asarray(p_h)
        # Negate components that hit bounds
        r_h = jnp.where(hits != 0, -r_h, r_h)
        r = d_jnp * r_h

        # Restrict trust-region step, such that it hits the bound.
        p_h_scaled = jnp.asarray(p_h) * p_stride
        x_on_bound = x_jnp + p_jnp * p_stride

        # Reflected direction will cross first either feasible region or trust
        # region boundary.
        _, to_tr_jax = intersect_trust_region_jax(p_h_scaled, r_h, Delta)
        to_bound_jax, _ = step_size_to_bound_jax(x_on_bound, r, lb_j, ub_j)

        # Materialize scalars for host-side branching
        to_tr = float(to_tr_jax)
        to_bound = float(to_bound_jax)

        # Find lower and upper bounds on a step size along the reflected
        # direction, considering the strict feasibility requirement. There is no
        # single correct way to do that, the chosen approach seems to work best
        # on test problems.
        r_stride = min(to_bound, to_tr)
        if r_stride > 0:
            r_stride_l = (1 - theta) * p_stride / r_stride
            r_stride_u = theta * to_bound if r_stride == to_bound else to_tr
        else:
            r_stride_l = 0
            r_stride_u = -1

        # Check if reflection step is available.
        if r_stride_l <= r_stride_u:
            a, b, c = self.cJIT.build_quadratic_1d(
                J_h, g_h, r_h, s0=p_h_scaled, diag=diag_h
            )
            r_stride_jax, r_value = minimize_quadratic_1d_jax(
                a, b, r_stride_l, r_stride_u, c=c
            )
            r_stride = float(r_stride_jax)
            r_h = r_h * r_stride + p_h_scaled
            r = r_h * d_jnp
        else:
            r_value = jnp.inf

        # Now correct p_h to make it strictly interior.
        p = np.asarray(p_jnp * p_stride * theta)
        p_h = p_h_scaled * theta
        p_value = self.cJIT.evaluate_quadratic(J_h, g_h, p_h, diag=diag_h)

        ag_h = -g_h
        ag = d_jnp * ag_h

        ag_h_norm = jnorm(ag_h)
        to_tr = float(Delta / ag_h_norm) if float(ag_h_norm) > 0 else 1.0
        to_bound_jax, _ = step_size_to_bound_jax(x_jnp, ag, lb_j, ub_j)
        to_bound = float(to_bound_jax)
        ag_stride = theta * to_bound if to_bound < to_tr else to_tr

        a, b = self.cJIT.build_quadratic_1d(J_h, g_h, ag_h, diag=diag_h)
        ag_stride_jax, ag_value = minimize_quadratic_1d_jax(a, b, 0, ag_stride)
        ag_h = ag_h * ag_stride_jax
        ag = ag * ag_stride_jax

        # Materialize reduction values for host-side comparison
        p_value = float(p_value)
        r_value = float(r_value)
        ag_value = float(ag_value)

        if p_value < r_value and p_value < ag_value:
            return np.asarray(p), p_h, -p_value
        elif r_value < p_value and r_value < ag_value:
            return np.asarray(r), r_h, -r_value
        else:
            return np.asarray(ag), ag_h, -ag_value
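When the full step `x + p` leaves the feasible box, `select_step` first finds how far it can go with `step_size_to_bound_jax`. It then reflects the scaled step in the components that hit a bound. Below is a plain-Python version, assuming the JAX helper follows SciPy's `step_size_to_bound` convention: `hits[i] = sign(p[i])` for the components that hit a bound first, and 0 elsewhere. The names are hypothetical. As in `select_step`, the truncated step is finally multiplied by `theta` so the iterate stays strictly interior.

```python
import math


def step_size_to_bound(x, s, lb, ub):
    """Smallest t >= 0 with x + t * s on a bound, plus which components hit it.

    hits[i] is +1 (upper bound), -1 (lower bound) or 0.
    """
    steps = []
    for xi, si, lo, hi in zip(x, s, lb, ub):
        if si == 0:
            steps.append(math.inf)  # no motion along this axis
        else:
            # The larger ratio is the bound lying in the direction of motion.
            steps.append(max((lo - xi) / si, (hi - xi) / si))
    min_step = min(steps)
    hits = [
        (1 if si > 0 else -1) if si != 0 and st == min_step else 0
        for st, si in zip(steps, s)
    ]
    return min_step, hits


def reflect(p_h, hits):
    """Flip the components of the scaled step that hit a bound first."""
    return [-pi if h != 0 else pi for pi, h in zip(p_h, hits)]


x = [0.5, 0.5]
p = [1.0, 0.25]
p_stride, hits = step_size_to_bound(x, p, [0.0, 0.0], [1.0, 1.0])
print(p_stride, hits)  # 0.5 [1, 0]
print(reflect(p, hits))  # [-1.0, 0.25]  (taking d = 1, so p_h == p)

# Step back from the bound by theta, as select_step does with p * p_stride * theta.
theta = 0.995
x_interior = [xi + theta * p_stride * pi for xi, pi in zip(x, p)]
```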
    def optimize(
        self,
        fun: Callable[..., Any],
        x0: np.ndarray,
        jac: Callable[..., Any] | None = None,
        bounds: tuple[np.ndarray, np.ndarray] | tuple[float, float] = (-np.inf, np.inf),
        **kwargs: Any,
    ) -> OptimizeResult:
        """Perform optimization using trust region reflective algorithm.

        This method is intended to provide a simplified interface to the TRF
        algorithm, but it is not implemented yet. For full control and curve
        fitting applications, use the `trf` method directly.

        Parameters
        ----------
        fun : callable
            The objective function to minimize. Should return residuals.
        x0 : np.ndarray
            Initial guess for parameters.
        jac : callable, optional
            Jacobian function. If None, uses automatic differentiation.
        bounds : tuple of arrays
            Lower and upper bounds for parameters.
        **kwargs
            Additional optimization parameters.

        Returns
        -------
        OptimizeResult
            The optimization result.

        Raises
        ------
        NotImplementedError
            This simplified interface is not yet implemented. Use the `trf`
            method for full curve fitting functionality.
        """
        raise NotImplementedError(
            "The simplified optimize() interface is not yet implemented for TrustRegionReflective. "
            "This class is designed for curve fitting applications. "
            "Use the `trf()` method directly, or use the higher-level interfaces in "
            "`nlsq.curve_fit()` or `LeastSquares.least_squares()`."
        )
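Because `optimize()` is not implemented, the `callback` argument of `trf_no_bounds` and `trf_bounds` is the main hook for watching progress. `_invoke_callback` calls it with the keywords `iteration`, `cost`, `params` and `info`. It returns status 2 on `StopOptimization` and downgrades any other exception to a `RuntimeWarning`. The sketch below shows a callback that records its history and stops once the gradient norm is small, together with a local replica of that dispatch logic. `StopOptimization` here is a hypothetical stand-in class. In real code, use the exception nlsq provides; its import path is not shown in this section.

```python
import warnings


class StopOptimization(Exception):
    """Hypothetical stand-in for nlsq's StopOptimization exception."""


def make_history_callback(history, gtol=1e-10):
    """Return a callback that records progress and stops on a small gradient."""

    def callback(iteration, cost, params, info):
        history.append((iteration, cost, info["gradient_norm"]))
        if info["gradient_norm"] < gtol:
            raise StopOptimization

    return callback


def invoke_callback(callback, iteration, cost, params, info):
    """Replica of _invoke_callback's dispatch: 2 on stop, None otherwise."""
    try:
        callback(iteration=iteration, cost=cost, params=params, info=info)
    except StopOptimization:
        return 2  # user-requested stop
    except Exception as exc:
        # Any other failure is reported but does not abort the optimization.
        warnings.warn(
            f"Callback raised exception: {exc}. Continuing optimization.",
            RuntimeWarning,
        )
    return None


history = []
cb = make_history_callback(history, gtol=1e-6)
info = {"gradient_norm": 1e-3, "nfev": 3, "step_norm": 0.1, "actual_reduction": 0.5}
print(invoke_callback(cb, 1, 0.5, [1.0], info))  # None
print(invoke_callback(cb, 2, 0.1, [1.0], dict(info, gradient_norm=1e-9)))  # 2
```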