"""Trust Region Reflective algorithm for least-squares optimization.
The algorithm is based on ideas from paper [STIR]_. The main idea is to
account for the presence of the bounds by appropriate scaling of the variables (or,
equivalently, changing a trust-region shape). Let's introduce a vector v::

           | ub[i] - x[i],  if g[i] < 0 and ub[i] < np.inf
    v[i] = | x[i] - lb[i],  if g[i] > 0 and lb[i] > -np.inf
           | 1,             otherwise

where g is the gradient of a cost function and lb, ub are the bounds. Its
components are distances to the bounds at which the anti-gradient points (if
this distance is finite). Define a scaling matrix D = diag(v**0.5).
First-order optimality conditions can be stated as::

    D^2 g(x) = 0.

This means that the gradient components must be zero for strictly interior
variables and must point inside the feasible region for variables on a
bound.
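As a minimal NumPy sketch of the definitions above, the vector v and the scaled gradient D^2 g can be computed as shown here. The helper name `cl_scaling_vector` and the sample numbers are illustrative assumptions, not this module's API.

```python
import numpy as np


def cl_scaling_vector(x, g, lb, ub):
    # Hypothetical helper: v[i] is the distance to the bound that the
    # anti-gradient points at, or 1 when that bound is infinite or g[i] == 0.
    v = np.ones_like(x, dtype=float)
    to_upper = (g < 0) & np.isfinite(ub)
    to_lower = (g > 0) & np.isfinite(lb)
    v[to_upper] = ub[to_upper] - x[to_upper]
    v[to_lower] = x[to_lower] - lb[to_lower]
    return v


x = np.array([0.5, 2.0, -1.0])
g = np.array([-1.0, 3.0, 0.5])
lb = np.array([0.0, 1.0, -np.inf])
ub = np.array([2.0, np.inf, np.inf])

v = cl_scaling_vector(x, g, lb, ub)
D = np.diag(v**0.5)
scaled_gradient = D @ D @ g  # D^2 g, which vanishes at a first-order optimal point
```

Here the first variable is scaled by its distance to the upper bound, the second by its distance to the lower bound, and the third is left unscaled because the bound it moves towards is infinite.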
Now consider this system of equations as a new optimization problem. If the
point x is strictly interior (not on the bound), then the left-hand side is
differentiable and the Newton step for it satisfies::

    (D^2 H + diag(g) Jv) p = -D^2 g

where H is the Hessian matrix (or its J^T J approximation in least squares),
Jv is the Jacobian matrix of v with components -1, 1 or 0, such that all
elements of matrix C = diag(g) Jv are non-negative. Introduce the change
of the variables x = D x_h (_h would be "hat" in LaTeX). In the new variables,
we have a Newton step satisfying::

    B_h p_h = -g_h,

where B_h = D H D + C, g_h = D g. In least squares B_h = J_h^T J_h, where
J_h = J D. Note that J_h and g_h are proper Jacobian and gradient with respect
to "hat" variables. To guarantee global convergence we formulate a
trust-region problem based on the Newton step in the new variables::

    0.5 * p_h^T B_h p_h + g_h^T p_h -> min, ||p_h|| <= Delta

In the original space B = H + D^{-1} C D^{-1}, and the equivalent trust-region
problem is::

    0.5 * p^T B p + g^T p -> min, ||D^{-1} p|| <= Delta

Here, the meaning of the matrix D becomes clearer: it alters the shape
of the trust region, such that large steps towards the bounds are not allowed.
In the implementation, the trust-region problem is solved in "hat" space,
but handling of the bounds is done in the original space (see below and read
the code).
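The equivalence of the two formulations can be checked numerically. The sketch below is an illustration under stated assumptions: H is taken as J^T J, and the vectors `v` and `c` (the diagonal of C) are made-up sample values rather than quantities computed by this module.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 6, 3
J = rng.standard_normal((m, n))
f = rng.standard_normal(m)
g = J.T @ f

v = np.array([0.5, 2.0, 1.0])   # assumed scaling vector
c = np.array([0.3, 0.0, 1.2])   # assumed non-negative diagonal of C
d = v**0.5                      # diagonal of D

J_h = J * d                     # J D: column j is scaled by d[j]
g_h = d * g                     # D g
B_h = J_h.T @ J_h + np.diag(c)  # D H D + C with H = J^T J
B = J.T @ J + np.diag(c / v)    # H + D^{-1} C D^{-1}

p_h = rng.standard_normal(n)    # arbitrary step in "hat" space
p = d * p_h                     # the same step in the original space

model_hat = 0.5 * p_h @ B_h @ p_h + g_h @ p_h
model_orig = 0.5 * p @ B @ p + g @ p
```

Both model values agree, and the constraint norms match as well, since ||D^{-1} p|| equals ||p_h|| by construction.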
Introducing the matrix D does not make it possible to ignore the bounds: the
algorithm must keep iterates strictly feasible (to satisfy the aforementioned
differentiability). The parameter theta controls how far a step is pulled
back from the boundary (see the code for details).
The algorithm uses another important trick. If the trust-region solution
does not fit within the bounds, then a search direction reflected from the
first bound encountered is considered. For motivation and analysis refer to
the [STIR]_ paper (and other papers by its authors). In practice, little
justification is needed: the algorithm simply chooses the best of three
steps, namely a constrained trust-region step, a reflected step and a
constrained Cauchy step (a minimizer along -g_h in "hat" space, or -D^2 g in
the original space).
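The reflection idea can be illustrated with a small NumPy sketch. The function `step_to_bound` and the sample point are hypothetical and only mirror the description above; the module's own helpers (such as `step_size_to_bound_jax`) may differ in signature and detail.

```python
import numpy as np


def step_to_bound(x, p, lb, ub):
    # Largest t >= 0 such that x + t * p stays inside [lb, ub], together
    # with a mask of the coordinates that reach their bound at that t.
    with np.errstate(divide="ignore", invalid="ignore"):
        t = np.where(p > 0, (ub - x) / p,
                     np.where(p < 0, (lb - x) / p, np.inf))
    t_min = t.min()
    return t_min, t == t_min


x = np.array([0.5, 0.5])
p = np.array([1.0, 0.25])       # candidate step that leaves the box
lb = np.zeros(2)
ub = np.ones(2)

t_hit, hits = step_to_bound(x, p, lb, ub)
x_on_bound = x + t_hit * p      # point where the first bound is encountered

p_reflected = p.copy()
p_reflected[hits] *= -1.0       # flip the components that hit a bound
```

A line search along `p_reflected` starting from `x_on_bound` would then give the reflected candidate step, which competes with the constrained trust-region and Cauchy steps.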
Another feature is that the trust-region radius control strategy is modified
to account for the diagonal matrix C (called diag_h in the code).
Note that all of these peculiarities disappear for problems without bounds:
the algorithm becomes a standard trust-region type algorithm very similar to
the ones implemented in MINPACK.
The implementation supports two methods of solving the trust-region problem.
The first, called 'exact', applies SVD to the Jacobian and then solves the
problem very accurately using the algorithm described in [JJMore]_. It is not
applicable to large problems. The second, called 'lsmr', uses the 2-D subspace
approach (sometimes called "indefinite dogleg"), where the problem is solved
in a subspace spanned by the gradient and the approximate Gauss-Newton step
found by ``scipy.sparse.linalg.lsmr``. A 2-D trust-region problem is
reformulated as a 4th order algebraic equation and solved very accurately by
``numpy.roots``. The subspace approach makes it possible to solve very large
problems (up to a couple of million residuals on a regular PC), provided the
Jacobian matrix is sufficiently sparse.
References
----------
.. [STIR] Branch, M.A., T.F. Coleman, and Y. Li, "A Subspace, Interior,
and Conjugate Gradient Method for Large-Scale Bound-Constrained
Minimization Problems," SIAM Journal on Scientific Computing,
Vol. 21, Number 1, pp 1-23, 1999.
.. [JJMore] More, J. J., "The Levenberg-Marquardt Algorithm: Implementation
and Theory," Numerical Analysis, ed. G. A. Watson, Lecture
"""
# mypy: disable-error-code="arg-type,assignment,attr-defined,operator,misc,index,var-annotated,override"
# Note: mypy errors are mostly arg-type/assignment mismatches where Optional values
# are passed to methods expecting non-Optional, plus operator type conflicts between
# JAX arrays and numpy arrays. These require deeper refactoring of the TRF API.
from __future__ import annotations
import math
import warnings
from collections.abc import Callable
import numpy as np
# Initialize JAX configuration through central config
from nlsq.config import JAXConfig
__jax_config = JAXConfig()
import jax.numpy as jnp
from jax import debug
from jax.numpy.linalg import norm as jnorm
# Import the safe GPU initialization helper from the SVD fallback module
from nlsq.stability.svd_fallback import (
initialize_gpu_safely,
)
# Setup logging
from nlsq.utils.logging import LogLevel, get_logger
# Cache the performance log level for fast guard checks in hot loops
PERFORMANCE_LEVEL = LogLevel.PERFORMANCE
logger = get_logger("trf")
# Initialize GPU settings safely
initialize_gpu_safely()
# Dataclass and typing helpers for the cache, configuration and state containers
from dataclasses import dataclass
from typing import Any, NamedTuple
from nlsq.caching.unified_cache import get_global_cache
from nlsq.callbacks import StopOptimization
from nlsq.common_jax import (
CL_scaling_vector_jax,
CommonJIT,
check_termination_jax,
in_bounds_jax,
intersect_trust_region_jax,
make_strictly_feasible_jax,
minimize_quadratic_1d_jax,
solve_lsq_trust_region_jax,
step_size_to_bound_jax,
update_tr_radius_jax,
)
from nlsq.common_scipy import (
find_active_constraints,
in_bounds,
print_header_nonlinear,
print_iteration_nonlinear,
)
from nlsq.constants import (
DEFAULT_MAX_NFEV_MULTIPLIER,
INITIAL_LEVENBERG_MARQUARDT_LAMBDA,
)
# Optimizer base class
from nlsq.core.optimizer_base import TrustRegionOptimizerBase
# Profiling support
from nlsq.core.profiler import NullProfiler, TRFProfiler
# JIT-compiled helper functions (extracted for modularity)
from nlsq.core.trf_jit import TrustRegionJITFunctions
from nlsq.result import OptimizeResult
from nlsq.stability.guard import NumericalStabilityGuard
from nlsq.utils.diagnostics import OptimizationDiagnostics
class SVDCache(NamedTuple):
"""Cache SVD decomposition across inner loop iterations when Jacobian unchanged.
This cache stores the SVD components (U, s, V) along with the scaled Jacobian
J_h to avoid redundant SVD computations during inner loop iterations where
the step is rejected and parameters remain unchanged.
Attributes
----------
U : jnp.ndarray
Left singular vectors (m x k), where m is residuals and k = min(m, n).
s : jnp.ndarray
Singular values (k,).
V : jnp.ndarray
Right singular vectors (n x k), where n is parameters.
J_h : jnp.ndarray
Scaled Jacobian in "hat" space (m x n).
x_hash : int
Hash of parameter vector for cache validation. Cache is valid only
when the current parameter hash matches this value.
Notes
-----
The cache is valid only when `x_hash` matches the current parameter vector's hash.
When a step is rejected (actual_reduction <= 0), the parameters don't change,
so the SVD can be reused. When a step is accepted, the cache must be invalidated.
The expected speedup from SVD caching is 20-40% on problems with frequent step
rejections, as SVD computation is O(mn^2) and dominates iteration time.
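The validity rule can be sketched as follows. This is an illustration only: `MiniCache`, `params_hash` and `singular_values` are hypothetical names, and hashing the raw parameter bytes is an assumption about how `x_hash` might be produced.

```python
from typing import NamedTuple, Optional

import numpy as np


class MiniCache(NamedTuple):
    # Reduced stand-in for SVDCache that stores only the singular values.
    s: np.ndarray
    x_hash: int


def params_hash(x):
    # Assumed hashing scheme: hash the raw bytes of the parameter vector.
    return hash(np.asarray(x, dtype=float).tobytes())


def singular_values(J, x, cache: Optional[MiniCache]):
    # Reuse the cached values when the parameters are unchanged,
    # otherwise recompute the SVD and build a fresh cache entry.
    key = params_hash(x)
    if cache is not None and cache.x_hash == key:
        return cache.s, cache, True
    s = np.linalg.svd(J, compute_uv=False)
    return s, MiniCache(s, key), False


J = np.array([[3.0, 0.0], [0.0, 4.0]])
x = np.array([1.0, 2.0])
s1, cache, hit1 = singular_values(J, x, None)         # first call: miss
s2, cache, hit2 = singular_values(J, x, cache)        # rejected step: hit
s3, cache, hit3 = singular_values(J, x + 0.1, cache)  # accepted step: miss
```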
"""
U: jnp.ndarray
s: jnp.ndarray
V: jnp.ndarray
J_h: jnp.ndarray
x_hash: int
# =====================================================================
# TRF Configuration Dataclasses (US4 - Parameter Objects)
# =====================================================================
@dataclass(frozen=True, slots=True)
class TRFConfig:
"""Immutable TRF algorithm configuration.
Groups algorithm configuration parameters passed to TRF optimizer functions.
This is an internal implementation detail - the public API remains unchanged.
Attributes
----------
ftol : float
Tolerance for termination by change of cost function.
xtol : float
Tolerance for termination by change of independent variables.
gtol : float
Tolerance for termination by norm of gradient.
max_nfev : int or None
Maximum number of function evaluations. None for unlimited.
x_scale : str
Characteristic scale of variables. 'jac' for automatic scaling.
loss : str
Loss function type ('linear', 'soft_l1', 'huber', 'cauchy', 'arctan').
tr_solver : str
Trust-region subproblem solver ('exact', 'lsmr', 'cg').
verbose : int
Verbosity level (0=silent, 1=termination, 2=iterations).
"""
ftol: float = 1e-8
xtol: float = 1e-8
gtol: float = 1e-8
max_nfev: int | None = None
x_scale: str = "jac"
loss: str = "linear"
tr_solver: str = "exact"
verbose: int = 0
def __post_init__(self) -> None:
"""Validate configuration values."""
if self.ftol < 0 or not math.isfinite(self.ftol):
raise ValueError(
f"ftol must be a finite non-negative number, got {self.ftol}"
)
if self.xtol < 0 or not math.isfinite(self.xtol):
raise ValueError(
f"xtol must be a finite non-negative number, got {self.xtol}"
)
if self.gtol < 0 or not math.isfinite(self.gtol):
raise ValueError(
f"gtol must be a finite non-negative number, got {self.gtol}"
)
if self.max_nfev is not None and self.max_nfev <= 0:
raise ValueError(f"max_nfev must be positive, got {self.max_nfev}")
valid_losses = {"linear", "soft_l1", "huber", "cauchy", "arctan"}
if self.loss not in valid_losses:
raise ValueError(f"loss must be one of {valid_losses}, got {self.loss}")
valid_solvers = {"exact", "lsmr", "cg"}
if self.tr_solver not in valid_solvers:
raise ValueError(
f"tr_solver must be one of {valid_solvers}, got {self.tr_solver}"
)
@dataclass(slots=True)
class StepContext:
"""Mutable state container for TRF step computation.
Groups the iteration state variables passed between TRF helper methods.
This reduces parameter count and improves code clarity.
Attributes
----------
x : jnp.ndarray
Current parameter values.
f : jnp.ndarray
Residual vector at x.
J : jnp.ndarray
Jacobian matrix at x.
cost : float
Current cost (0.5 * ||f||^2).
g : jnp.ndarray
Gradient vector (J^T @ f).
trust_radius : float
Current trust region radius (Delta).
iteration : int
Current iteration number.
scale : jnp.ndarray
Variable scaling factors.
scale_inv : jnp.ndarray
Inverse scaling factors.
alpha : float
Levenberg-Marquardt parameter.
"""
x: jnp.ndarray
f: jnp.ndarray
J: jnp.ndarray
cost: float
g: jnp.ndarray
trust_radius: float
iteration: int
scale: jnp.ndarray
scale_inv: jnp.ndarray
alpha: float = 0.0
@dataclass(frozen=True, slots=True)
class BoundsContext:
"""Bound constraint data for TRF optimization.
Groups the bound-related arrays used in bounded optimization.
Attributes
----------
lb : jnp.ndarray
Lower bounds on parameters.
ub : jnp.ndarray
Upper bounds on parameters.
x_scale : jnp.ndarray
Scaling factors for bounded variables.
x_offset : jnp.ndarray
Offset for bounded variables (center of bounds).
lb_scaled : jnp.ndarray
Scaled lower bounds.
ub_scaled : jnp.ndarray
Scaled upper bounds.
"""
lb: jnp.ndarray
ub: jnp.ndarray
x_scale: jnp.ndarray
x_offset: jnp.ndarray
lb_scaled: jnp.ndarray
ub_scaled: jnp.ndarray
@classmethod
def from_bounds(
cls,
lb: jnp.ndarray,
ub: jnp.ndarray,
x_scale: jnp.ndarray | None = None,
) -> BoundsContext:
"""Create BoundsContext from bounds arrays.
Parameters
----------
lb : array_like
Lower bounds.
ub : array_like
Upper bounds.
x_scale : array_like, optional
Scaling factors. If None, uses 1.0.
Returns
-------
BoundsContext
Initialized bounds context.
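The offset rule used by this constructor (midpoint when both bounds are finite, the finite bound when only one is, zero otherwise) can be reproduced with plain NumPy. The sample bounds below are illustrative assumptions.

```python
import numpy as np

lb = np.array([0.0, -np.inf, 1.0, -np.inf])
ub = np.array([4.0, 3.0, np.inf, np.inf])
x_scale = np.ones_like(lb)

both_finite = np.isfinite(lb) & np.isfinite(ub)
lb_only = np.isfinite(lb) & ~np.isfinite(ub)
ub_only = ~np.isfinite(lb) & np.isfinite(ub)

# -inf + inf produces NaN in slots that the mask discards, so silence the warning.
with np.errstate(invalid="ignore"):
    midpoint = (lb + ub) / 2.0

x_offset = np.where(
    both_finite,
    midpoint,
    np.where(lb_only, lb, np.where(ub_only, ub, 0.0)),
)
lb_scaled = (lb - x_offset) / x_scale
ub_scaled = (ub - x_offset) / x_scale
```

With these inputs the offsets are 2, 3, 1 and 0, so every finite bound maps to a finite scaled bound and infinite bounds stay infinite.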
"""
lb = jnp.asarray(lb)
ub = jnp.asarray(ub)
if x_scale is None:
x_scale = jnp.ones_like(lb)
else:
x_scale = jnp.asarray(x_scale)
# Handle infinite bounds: when both are infinite the midpoint is 0;
# when only one is infinite the offset is the finite bound.
both_finite = jnp.isfinite(lb) & jnp.isfinite(ub)
lb_finite_only = jnp.isfinite(lb) & ~jnp.isfinite(ub)
ub_finite_only = ~jnp.isfinite(lb) & jnp.isfinite(ub)
x_offset = jnp.where(
both_finite,
(lb + ub) / 2.0,
jnp.where(lb_finite_only, lb, jnp.where(ub_finite_only, ub, 0.0)),
)
lb_scaled = (lb - x_offset) / x_scale
ub_scaled = (ub - x_offset) / x_scale
return cls(
lb=lb,
ub=ub,
x_scale=x_scale,
x_offset=x_offset,
lb_scaled=lb_scaled,
ub_scaled=ub_scaled,
)
# Algorithm constants
# Trust region parameters
TR_REDUCTION_FACTOR = 0.25 # Factor to reduce trust region when numerical issues occur
TR_BOUNDARY_THRESHOLD = 0.95 # Threshold for checking if step is close to boundary
SQRT_EXPONENT = 0.5 # Exponent for square root in scaling (v**0.5)
class TrustRegionReflective(TrustRegionJITFunctions, TrustRegionOptimizerBase):
"""Trust Region Reflective algorithm for bounded least squares optimization.
Implements the TRF algorithm with variable scaling to handle parameter bounds.
Supports exact (SVD) and iterative (CG) solvers for trust region subproblems.
"""
def __init__(self, enable_stability: bool = False):
"""Initialize the TrustRegionReflective optimizer.
Creates JIT-compiled functions and sets up logging infrastructure.
All optimization functions are compiled during initialization for
maximum performance during solve operations.
Parameters
----------
enable_stability : bool, default False
Enable numerical stability checks and fixes
"""
TrustRegionJITFunctions.__init__(self)
TrustRegionOptimizerBase.__init__(self, name="trf")
self.cJIT = CommonJIT()
# Initialize unified cache for JIT compilation tracking
self.cache = get_global_cache()
# Initialize stability system
self.enable_stability = enable_stability
if enable_stability:
self.stability_guard = NumericalStabilityGuard()
@staticmethod
def _log_iteration_callback(
iteration: Any,
nfev: Any,
cost: Any,
actual_reduction: Any,
step_norm: Any,
g_norm: Any,
) -> None:
"""Wrapper for logging callback that converts JAX arrays to Python scalars.
This function is called by jax.debug.callback and ensures all arguments
are converted from JAX arrays to Python scalars before logging.
Parameters
----------
iteration : int or jax.Array
Iteration number
nfev : int or jax.Array
Number of function evaluations
cost : float or jax.Array
Current cost
actual_reduction : float or jax.Array or None
Actual cost reduction
step_norm : float or jax.Array or None
Step norm
g_norm : float or jax.Array
Gradient norm
"""
# Convert JAX arrays to Python scalars
iteration = int(iteration) if hasattr(iteration, "item") else iteration
nfev = int(nfev) if hasattr(nfev, "item") else nfev
cost = float(cost) if hasattr(cost, "item") else cost
g_norm = float(g_norm) if hasattr(g_norm, "item") else g_norm
# Handle optional values
if actual_reduction is not None:
actual_reduction = (
float(actual_reduction)
if hasattr(actual_reduction, "item")
else actual_reduction
)
if step_norm is not None:
step_norm = float(step_norm) if hasattr(step_norm, "item") else step_norm
print_iteration_nonlinear(
iteration, nfev, cost, actual_reduction, step_norm, g_norm
)
def trf(
self,
fun: Callable[..., Any],
xdata: jnp.ndarray | tuple[jnp.ndarray],
ydata: jnp.ndarray,
jac: Callable[..., Any],
data_mask: jnp.ndarray,
transform: jnp.ndarray,
x0: np.ndarray,
f0: jnp.ndarray,
J0: jnp.ndarray,
lb: np.ndarray,
ub: np.ndarray,
ftol: float,
xtol: float,
gtol: float,
max_nfev: int,
f_scale: float,
x_scale: np.ndarray,
loss_function: None | Callable[..., Any],
tr_options: dict[str, Any],
verbose: int,
timeit: bool = False,
solver: str = "exact",
diagnostics: OptimizationDiagnostics | None = None,
callback: Callable[..., Any] | None = None,
**kwargs: Any,
) -> OptimizeResult:
"""Minimize a scalar function of one or more variables using the
trust-region reflective algorithm. The code layout deliberately follows
the original SciPy implementation, even where that is not ideal coding
style, so that the two are easy to compare. The note from the
algorithm's original author follows.
For efficiency, it makes sense to run
the simplified version of the algorithm when no bounds are imposed.
We decided to write the two separate functions. It violates the DRY
principle, but the individual functions are kept the most readable.
Parameters
----------
fun : callable
The residual function
xdata : array_like or tuple of array_like
The independent variable where the data is measured. If `xdata` is a
tuple, then the input arguments to `fun` are assumed to be
``(xdata[0], xdata[1], ...)``.
ydata : jnp.ndarray
The dependent data
jac : callable
The Jacobian of `fun`.
data_mask : jnp.ndarray
The mask for the data.
transform : jnp.ndarray
The uncertainty transform for the data.
x0 : np.ndarray
Initial guess. Array of real elements of size (n,), where 'n' is the
number of independent variables.
f0 : jnp.ndarray
Initial residuals. Array of real elements of size (m,), where 'm' is
the number of data points.
J0 : jnp.ndarray
Initial Jacobian. Array of real elements of size (m, n), where 'm' is
the number of data points and 'n' is the number of independent
variables.
lb : np.ndarray
Lower bounds on independent variables. Array of real elements of size
(n,), where 'n' is the number of independent variables.
ub : np.ndarray
Upper bounds on independent variables. Array of real elements of size
(n,), where 'n' is the number of independent variables.
ftol : float
Tolerance for termination by the change of the cost function.
xtol : float
Tolerance for termination by the change of the independent variables.
gtol : float
Tolerance for termination by the norm of the gradient.
max_nfev : int
Maximum number of function evaluations.
f_scale : float
Cost function scalar
x_scale : np.ndarray
Scaling factors for independent variables.
loss_function : callable, optional
Loss function. If None, the standard least-squares problem is
solved.
tr_options : dict
Options for the trust-region algorithm.
verbose : int
Level of algorithm's verbosity:
* 0 (default) : work silently.
* 1 : display a termination report.
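timeit : bool, optional
If True and no bounds are imposed, the unbounded solver is run with a
profiler that times each step. Default is False.
solver : str, optional
Trust-region subproblem solver, passed through to the bounded or
unbounded implementation. Default is 'exact'.
callback : callable, optional
User callback, passed through to the bounded or unbounded
implementation. Default is None.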
"""
# bounded or unbounded version
if np.all(lb == -np.inf) and np.all(ub == np.inf):
# unbounded version as timed and untimed version
if not timeit:
return self.trf_no_bounds(
fun,
xdata,
ydata,
jac,
data_mask,
transform,
x0,
f0,
J0,
lb,
ub,
ftol,
xtol,
gtol,
max_nfev,
f_scale,
x_scale,
loss_function,
tr_options,
verbose,
solver,
callback,
**kwargs,
)
else:
return self.trf_no_bounds(
fun,
xdata,
ydata,
jac,
data_mask,
transform,
x0,
f0,
J0,
lb,
ub,
ftol,
xtol,
gtol,
max_nfev,
f_scale,
x_scale,
loss_function,
tr_options,
verbose,
solver,
callback,
profiler=TRFProfiler(),
)
else:
return self.trf_bounds(
fun,
xdata,
ydata,
jac,
data_mask,
transform,
x0,
f0,
J0,
lb,
ub,
ftol,
xtol,
gtol,
max_nfev,
f_scale,
x_scale,
loss_function,
tr_options,
verbose,
solver,
callback,
**kwargs,
)
def _initialize_trf_state(
self,
x0: np.ndarray,
f: jnp.ndarray,
J: jnp.ndarray,
loss_function: Callable[..., Any] | None,
x_scale: np.ndarray | str,
f_scale: float,
data_mask: jnp.ndarray,
) -> dict[str, Any]:
"""Initialize optimization state for TRF algorithm.
This helper extracts the initialization logic from trf_no_bounds,
reducing complexity and improving testability.
Parameters
----------
x0 : np.ndarray
Initial parameter guess
f : jnp.ndarray
Initial residuals
J : jnp.ndarray
Initial Jacobian matrix
loss_function : Callable or None
Loss function (None for standard least squares)
x_scale : np.ndarray or str
Parameter scaling factors or 'jac' for Jacobian-based scaling
f_scale : float
Residual scaling factor
data_mask : jnp.ndarray
Data masking array
Returns
-------
dict
Initial state containing x, f, J, cost, g, scale, Delta, etc.
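The scaling guard and the initial trust-region radius computed by this helper can be sketched in NumPy as follows. The function name `initial_trust_radius` is hypothetical, and only the explicit `x_scale` branch is shown.

```python
import numpy as np


def initial_trust_radius(x0, x_scale):
    scale = np.asarray(x_scale, dtype=float)
    # Replace zero scales by 1 so that the inverse stays finite.
    safe_scale = np.where(scale == 0, 1.0, scale)
    scale_inv = 1.0 / safe_scale
    # Delta is the norm of the scaled initial guess, or 1 if that norm is 0.
    delta = np.linalg.norm(np.asarray(x0, dtype=float) * scale_inv)
    return float(delta) if delta > 0 else 1.0


delta_unit = initial_trust_radius([3.0, 4.0], [1.0, 1.0])
delta_origin = initial_trust_radius([0.0, 0.0], [1.0, 1.0])
delta_guarded = initial_trust_radius([3.0, 4.0], [0.0, 2.0])
```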
"""
m, n = J.shape
state = {
"x": jnp.asarray(x0), # OPT-2: Use JAX array directly, no copy
"f": f,
"J": J,
"nfev": 1,
"njev": 1,
"m": m,
"n": n,
}
# Apply loss function if provided
if loss_function is not None:
rho = loss_function(f, f_scale)
state["cost"] = self.calculate_cost(rho, data_mask)
# Save original residuals before scaling (for res.fun)
state["f_true"] = f
state["J"], state["f"] = self.cJIT.scale_for_robust_loss_function(J, f, rho)
else:
state["cost"] = self.default_loss_func(f)
# No scaling applied, so f is already the true residuals
state["f_true"] = f
# Compute gradient
state["g"] = self.compute_grad(state["J"], state["f"])
# Compute scaling factors
jac_scale = isinstance(x_scale, str) and x_scale == "jac"
if jac_scale:
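# Use the Jacobian stored in the state so that robust-loss scaling, if any,
# is reflected in the scale (consistent with the update after accepted steps)
scale, scale_inv = self.cJIT.compute_jac_scale(state["J"])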
state["scale"], state["scale_inv"] = scale, scale_inv
state["jac_scale"] = True
else:
# T022: Convert scale to JAX arrays once at entry to avoid repeated conversion
scale_arr = jnp.asarray(x_scale)
# Guard against zero scale values to avoid Inf in scale_inv
safe_scale = jnp.where(scale_arr == 0, 1.0, scale_arr)
state["scale"], state["scale_inv"] = (
scale_arr,
jnp.asarray(1 / safe_scale),
)
state["jac_scale"] = False
# Initialize trust region radius
Delta = jnorm(x0 * state["scale_inv"]) # Use JAX norm
state["Delta"] = Delta if Delta > 0 else 1.0
return state
def _check_convergence_criteria(
self,
g: jnp.ndarray,
gtol: float,
) -> tuple[int | None, float]:
"""Check if gradient convergence criterion is met.
This helper extracts convergence checking logic from trf_no_bounds,
reducing complexity and improving readability.
Parameters
----------
g : jnp.ndarray
Current gradient vector
gtol : float
Gradient tolerance for convergence
Returns
-------
tuple[int | None, float]
Tuple of (termination_status, g_norm):
- termination_status: 1 if gradient tolerance satisfied, None otherwise
- g_norm: Computed gradient norm (OPT-8: returned to avoid redundant computation)
"""
# OPT-8: Compute g_norm once and return it to avoid redundant computation
# Keep as JAX scalar to defer GPU→CPU sync until actually needed
g_norm = jnorm(g, ord=jnp.inf)
if g_norm < gtol:
self.logger.debug(
"Convergence: gradient tolerance satisfied",
g_norm=float(g_norm),
gtol=gtol,
)
return 1, g_norm
return None, g_norm
def _solve_trust_region_subproblem(
self,
J: jnp.ndarray,
f: jnp.ndarray,
g: jnp.ndarray,
scale: jnp.ndarray,
Delta: float,
alpha: float,
solver: str,
) -> dict[str, Any]:
"""Solve the trust region subproblem.
This helper extracts the subproblem setup and solving logic,
reducing complexity and improving readability.
Parameters
----------
J : jnp.ndarray
Current Jacobian matrix
f : jnp.ndarray
Current residuals
g : jnp.ndarray
Current gradient
scale : jnp.ndarray
Parameter scaling factors
Delta : float
Current trust region radius
alpha : float
Levenberg-Marquardt parameter
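solver : str
Solver type: 'cg', 'sparse' or 'exact'. The 'sparse' option currently
falls back to the dense exact (SVD) solver, as does any other value.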
Returns
-------
dict
Subproblem solution containing:
- J_h: Scaled Jacobian
- g_h: Scaled gradient
- d: Scaling vector
- d_jnp: JAX scaling vector
- step_h: Step in scaled space (for CG solver)
- s, V, uf: SVD components (for exact solver)
"""
# Setup scaled variables
# T022: scale is already a JAX array (converted once in _initialize_trf_state)
d = scale # Already JAX array, no conversion needed
g_h = self.compute_grad_hat(g, d)
result = {
"d": d,
"d_jnp": d, # Same as d now (backwards compatibility)
"g_h": g_h,
}
# Solve trust region subproblem
if solver == "cg":
# Conjugate gradient solver
J_h = J * d
step_h = self.solve_tr_subproblem_cg(J, f, d, Delta, alpha)
result.update(
{
"J_h": J_h,
"step_h": step_h,
"s": None,
"V": None,
"uf": None,
}
)
elif solver == "sparse":
# Sparse solver path (Task 6.4: Sparse Activation)
# TODO: Implement sparse SVD using JAX sparse operations
# For now, fall back to dense exact solver to maintain correctness
# Full sparse implementation would use:
# - JAX sparse matrix operations for Jacobian
# - Sparse QR or sparse SVD decomposition
# - Iterative sparse linear solvers
# Target: 3-10x speed, 5-50x memory reduction on sparse problems
svd_output = self.svd_no_bounds(J, d, f)
J_h = svd_output[0]
s, V, uf = svd_output[2:]
result.update(
{
"J_h": J_h,
"step_h": None,
"s": s,
"V": V,
"uf": uf,
}
)
else:
# SVD-based exact solver (default dense)
svd_output = self.svd_no_bounds(J, d, f)
J_h = svd_output[0]
# PERFORMANCE FIX: Keep arrays as JAX to avoid conversion overhead (8-12% gain)
# JAX arrays work with NumPy operations through duck typing, eliminating
# explicit array conversion reduces memory allocations and copies
s, V, uf = svd_output[2:] # Keep as JAX arrays instead of converting
result.update(
{
"J_h": J_h,
"step_h": None, # Computed later in inner loop
"s": s,
"V": V,
"uf": uf,
}
)
return result
def _evaluate_step_acceptance(
self,
fun: Callable[..., Any],
jac: Callable[..., Any],
x: np.ndarray,
f: jnp.ndarray,
J: jnp.ndarray,
J_h: jnp.ndarray,
g_h_jnp: jnp.ndarray,
cost: float,
d: np.ndarray,
d_jnp: jnp.ndarray,
Delta: float,
alpha: float,
step_h: jnp.ndarray | None,
s: np.ndarray | None,
V: np.ndarray | None,
uf: np.ndarray | None,
xdata: np.ndarray,
ydata: np.ndarray,
data_mask: jnp.ndarray,
transform: Callable[..., Any] | None,
loss_function: Callable[..., Any] | None,
f_scale: float,
scale_inv: np.ndarray,
jac_scale: bool,
solver: str,
ftol: float,
xtol: float,
max_nfev: int,
nfev: int,
) -> dict[str, Any]:
"""Evaluate step acceptance through inner trust region loop.
This method implements the inner loop of the TRF algorithm, which
repeatedly solves the trust region subproblem and evaluates candidate
steps until an acceptable step is found.
Parameters
----------
fun : Callable
Function to evaluate residuals
jac : Callable
Function to evaluate Jacobian
x : np.ndarray
Current parameter values
f : jnp.ndarray
Current residuals (possibly scaled by loss function)
J : jnp.ndarray
Current Jacobian (possibly scaled by loss function)
J_h : jnp.ndarray
Scaled Jacobian for subproblem
g_h_jnp : jnp.ndarray
Scaled gradient for subproblem
cost : float
Current cost value
d : np.ndarray
Parameter scaling factors
d_jnp : jnp.ndarray
Parameter scaling factors (JAX array)
Delta : float
Trust region radius
alpha : float
Levenberg-Marquardt parameter
step_h : jnp.ndarray | None
Pre-computed step (for CG solver), None for exact solver
s : np.ndarray | None
SVD singular values (for exact solver), None for CG
V : np.ndarray | None
SVD V matrix (for exact solver), None for CG
uf : np.ndarray | None
SVD U^T @ f (for exact solver), None for CG
xdata : np.ndarray
Independent variable data
ydata : np.ndarray
Dependent variable data
data_mask : jnp.ndarray
Mask for valid data points
transform : Callable | None
Parameter transformation function
loss_function : Callable | None
Robust loss function
f_scale : float
Residual scale factor
scale_inv : np.ndarray
Inverse parameter scaling
jac_scale : bool
Whether using Jacobian-based scaling
solver : str
Trust region solver ('cg' or 'exact')
ftol : float
Cost function tolerance
xtol : float
Parameter tolerance
max_nfev : int
Maximum function evaluations
nfev : int
Current function evaluation count
Returns
-------
dict
Dictionary containing:
- accepted : bool - Whether a step was accepted
- x_new : np.ndarray - New parameter values (if accepted)
- f_new : jnp.ndarray - New residuals (if accepted)
- J_new : jnp.ndarray - New Jacobian (if accepted)
- cost_new : float - New cost value (if accepted)
- g_new : jnp.ndarray - New gradient (if accepted)
- scale : np.ndarray - Updated parameter scaling (if accepted)
- scale_inv : np.ndarray - Updated inverse scaling (if accepted)
- actual_reduction : float - Actual cost reduction
- step_norm : float - Step norm
- Delta : float - Updated trust region radius
- alpha : float - Updated Levenberg-Marquardt parameter
- termination_status : int | None - Termination status code
- nfev : int - Updated function evaluation count
- njev : int - Jacobian evaluation count (1 if accepted, 0 otherwise)
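The radius and damping updates performed in the inner loop can be sketched in plain Python. The ratio thresholds (0.25 and 0.75) follow the convention used in SciPy's least-squares code and are an assumption about what `update_tr_radius_jax` does; the alpha rescaling mirrors the loop body of this method.

```python
import math


def update_tr_radius(delta, actual_reduction, predicted_reduction,
                     step_norm, bound_hit):
    # Ratio of actual to predicted reduction measures model quality.
    if predicted_reduction > 0:
        ratio = actual_reduction / predicted_reduction
    elif predicted_reduction == actual_reduction == 0:
        ratio = 1.0
    else:
        ratio = 0.0
    if ratio < 0.25:
        delta = 0.25 * step_norm   # poor agreement: shrink the region
    elif ratio > 0.75 and bound_hit:
        delta *= 2.0               # good agreement at the boundary: grow
    return delta, ratio


def rescale_alpha(alpha, delta, delta_new, cap=1e30):
    # Keep the Levenberg-Marquardt parameter consistent with the new radius.
    raw = alpha * delta / delta_new
    return min(raw if math.isfinite(raw) else cap, cap)


shrunk, poor_ratio = update_tr_radius(1.0, 0.1, 1.0, 0.8, False)
grown, good_ratio = update_tr_radius(1.0, 0.9, 1.0, 1.0, True)
alpha_new = rescale_alpha(2.0, 1.0, shrunk)
```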
"""
n, m = len(x), len(f)
actual_reduction = -1.0 # Python float, avoids JAX sync in while condition
inner_loop_count = 0
max_inner_iterations = 100
termination_status = None
step_norm = 0
x_new = x
f_new = f
cost_new = cost
while (
actual_reduction <= 0
and nfev < max_nfev
and inner_loop_count < max_inner_iterations
):
inner_loop_count += 1
# Solve subproblem (reuse step or compute new one)
if solver == "cg":
if inner_loop_count > 1:
step_h = self.solve_tr_subproblem_cg(J, f, d_jnp, Delta, alpha)
_n_iter = 1 # Dummy value for compatibility
else:
step_h, alpha, _n_iter = solve_lsq_trust_region_jax(
n, m, uf, s, V, Delta, initial_alpha=alpha
)
# Compute predicted reduction
predicted_reduction_jnp = -self.cJIT.evaluate_quadratic(
J_h, g_h_jnp, step_h
)
predicted_reduction = predicted_reduction_jnp
# Transform step and evaluate objective
# OPT-18: Fused step computation to reduce intermediate allocations
step = d * step_h
x_new = x + step # Keep step for later use in convergence check
f_new = fun(x_new, xdata, ydata, data_mask, transform)
nfev += 1
step_h_norm = jnorm(step_h)
# Check for numerical issues
if not self.check_isfinite(f_new):
Delta = TR_REDUCTION_FACTOR * step_h_norm
continue
# Compute actual reduction
if loss_function is not None:
cost_new_jnp = loss_function(f_new, f_scale, data_mask, cost_only=True)
else:
cost_new_jnp = self.default_loss_func(f_new)
cost_new = cost_new_jnp
actual_reduction = cost - cost_new
# Update trust region radius (JIT-compiled, stays on device)
Delta_new, ratio = update_tr_radius_jax(
Delta,
actual_reduction,
predicted_reduction,
step_h_norm,
step_h_norm > TR_BOUNDARY_THRESHOLD * Delta,
)
# Check termination criteria (JIT-compiled, stays on device)
step_norm = jnorm(step)
term_code = check_termination_jax(
actual_reduction, cost, step_norm, jnorm(x), ratio, ftol, xtol
)
# Single GPU→CPU sync per inner iteration: materialize term_code
# This sync also implicitly materializes actual_reduction and Delta_new
term_code_int = int(term_code)
# Materialize actual_reduction as Python float to avoid repeated JAX syncs
# in the while-loop condition and the acceptance check below
actual_reduction = float(actual_reduction)
if term_code_int != 0:
termination_status = term_code_int
break
if Delta_new > 0:
raw_alpha = alpha * Delta / Delta_new
alpha = min(raw_alpha if math.isfinite(raw_alpha) else 1e30, 1e30)
Delta = Delta_new
# Exit inner loop if we have a successful step
if actual_reduction > 0:
break
# Check if inner loop hit the iteration limit without an acceptable step
if (
inner_loop_count >= max_inner_iterations
and actual_reduction <= 0
and termination_status is None
):
self.logger.warning(
"Inner optimization loop hit iteration limit",
inner_iterations=inner_loop_count,
actual_reduction=actual_reduction,
)
termination_status = -3 # Inner loop limit exceeded
# Prepare result
result: dict[str, Any] = {
"accepted": actual_reduction > 0,
"actual_reduction": actual_reduction,
"step_norm": step_norm if actual_reduction > 0 else 0,
"Delta": Delta,
"alpha": alpha,
"termination_status": termination_status,
"nfev": nfev,
"njev": 0, # Will be set to 1 if step is accepted
}
# If step was accepted, compute new state
if actual_reduction > 0:
result.update(
{
"x_new": x_new,
"f_new": f_new,
"cost_new": cost_new,
"njev": 1,
}
)
# Compute new Jacobian
J_new = jac(x_new, xdata, ydata, data_mask, transform)
# Apply loss function if provided
if loss_function is not None:
rho = loss_function(f_new, f_scale)
J_new, f_new_scaled = self.cJIT.scale_for_robust_loss_function(
J_new, f_new, rho
)
result["f_new"] = f_new_scaled # Scaled residuals for optimization
result["f_true_new"] = f_new # Unscaled residuals for res.fun
else:
result["f_new"] = f_new
result["f_true_new"] = f_new # No scaling, so both are the same
result["J_new"] = J_new
# Compute new gradient
g_new = self.compute_grad(J_new, result["f_new"])
result["g_new"] = g_new
# Update scaling if using Jacobian-based scaling
if jac_scale:
scale_new, scale_inv_new = self.cJIT.compute_jac_scale(J_new, scale_inv)
result["scale"] = scale_new
result["scale_inv"] = scale_inv_new
return result
def _invoke_callback(
self,
callback: Callable[..., Any],
iteration: int,
cost: float,
x: np.ndarray,
g_norm: float,
nfev: int,
step_norm: float | None,
actual_reduction: float | None,
) -> int | None:
"""Invoke user callback with proper exception handling.
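This helper holds the callback handling shared by trf_no_bounds and
trf_bounds. A StopOptimization raised by the callback requests termination;
any other exception is reported as a RuntimeWarning and optimization continues.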
Parameters
----------
callback : Callable
User-provided callback function
iteration : int
Current iteration number
cost : float
Current cost value
x : np.ndarray
Current parameter values
g_norm : float
Gradient norm
nfev : int
Number of function evaluations
step_norm : float | None
Step norm (None if not computed)
actual_reduction : float | None
Actual cost reduction (None if not computed)
Returns
-------
int | None
Termination status if callback requested stop, None otherwise.
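Examples
--------
A minimal sketch of the callback protocol handled here. It assumes a
stand-in ``StopOptimization`` class and a hypothetical ``invoke`` helper
(neither is this module's code) and mirrors only the try/except flow:

```python
import warnings


class StopOptimization(Exception):
    # Stand-in for the library's exception; defined here only for illustration.
    pass


def invoke(callback, iteration, cost, params, info):
    # Hypothetical helper: returns 2 on a user-requested stop, None otherwise.
    try:
        callback(iteration=iteration, cost=cost, params=params, info=info)
    except StopOptimization:
        return 2
    except Exception as exc:
        # Any other error is reported and the optimization keeps going.
        warnings.warn(f"Callback raised exception: {exc}.", RuntimeWarning)
    return None


def stop_when_converged(iteration, cost, params, info):
    # Example user callback: request a stop once the cost is small enough.
    if cost < 1e-6:
        raise StopOptimization


status_running = invoke(stop_when_converged, 1, 0.5, [1.0], {"nfev": 2})
status_stopped = invoke(stop_when_converged, 2, 1e-9, [1.0], {"nfev": 3})
```

The callback receives keyword arguments only, so its parameter names must
match ``iteration``, ``cost``, ``params`` and ``info``.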
"""
try:
callback(
iteration=iteration,
cost=float(cost), # JAX scalar -> Python float
params=np.array(x), # JAX array -> NumPy array
info={
"gradient_norm": float(g_norm),
"nfev": nfev,
"step_norm": float(step_norm) if step_norm is not None else None,
"actual_reduction": float(actual_reduction)
if actual_reduction is not None
else None,
},
)
except StopOptimization:
self.logger.info("Optimization stopped by callback (StopOptimization)")
return 2 # User-requested stop
except Exception as e:
warnings.warn(
f"Callback raised exception: {e}. Continuing optimization.",
RuntimeWarning,
)
return None
def _initialize_bounds_state(
self,
x0: np.ndarray,
f: jnp.ndarray,
J: jnp.ndarray,
g: jnp.ndarray,
lb_jnp: jnp.ndarray,
ub_jnp: jnp.ndarray,
x_scale: np.ndarray | str,
) -> dict[str, Any]:
"""Initialize bounds-specific state for TRF algorithm.
This helper extracts bounds initialization logic from trf_bounds,
reducing complexity.
Parameters
----------
x0 : np.ndarray
Initial parameter guess
f : jnp.ndarray
Initial residuals
J : jnp.ndarray
Initial Jacobian
g : jnp.ndarray
Initial gradient
lb_jnp : jnp.ndarray
Lower bounds (pre-converted to JAX)
ub_jnp : jnp.ndarray
Upper bounds (pre-converted to JAX)
x_scale : np.ndarray | str
Parameter scaling
Returns
-------
dict
State containing v, dv, scale, scale_inv, Delta, jac_scale
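Examples
--------
A pure-Python sketch of the Coleman-Li scaling vector defined in the module
docstring. ``cl_scaling_vector`` is a hypothetical stand-in written for
illustration; ``CL_scaling_vector_jax`` may differ in detail:

```python
import math


def cl_scaling_vector(x, g, lb, ub):
    # Hypothetical stand-in: returns v and its derivative dv per component.
    v, dv = [], []
    for xi, gi, lo, hi in zip(x, g, lb, ub):
        if gi < 0 and hi < math.inf:
            v.append(hi - xi)   # anti-gradient points at a finite upper bound
            dv.append(-1.0)
        elif gi > 0 and lo > -math.inf:
            v.append(xi - lo)   # anti-gradient points at a finite lower bound
            dv.append(1.0)
        else:
            v.append(1.0)       # no finite bound in that direction
            dv.append(0.0)
    return v, dv


v, dv = cl_scaling_vector(
    x=[0.5, 2.0, 1.0],
    g=[-1.0, 3.0, 0.7],
    lb=[0.0, 1.5, -math.inf],
    ub=[2.0, math.inf, math.inf],
)
```

Only components with ``dv != 0`` depend on a bound, which is why the
rescaling by ``scale_inv`` is restricted to them.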
"""
jac_scale = isinstance(x_scale, str) and x_scale == "jac"
if jac_scale:
scale, scale_inv = self.cJIT.compute_jac_scale(J)
else:
safe_scale = jnp.where(jnp.asarray(x_scale) == 0, 1.0, jnp.asarray(x_scale))
scale, scale_inv = safe_scale, 1.0 / safe_scale
v, dv = CL_scaling_vector_jax(x0, g, lb_jnp, ub_jnp)
# Rescale only the components that depend on a bound (dv != 0)
v = jnp.where(dv != 0, v * scale_inv, v)
Delta = jnorm(x0 * scale_inv / v**SQRT_EXPONENT)
if Delta == 0:
Delta = 1.0
return {
"v": v,
"dv": dv,
"scale": scale,
"scale_inv": scale_inv,
"Delta": Delta,
"jac_scale": jac_scale,
}
def _solve_bounds_subproblem(
self,
J: jnp.ndarray,
f: jnp.ndarray,
g: jnp.ndarray,
v: jnp.ndarray,
dv: jnp.ndarray,
scale: np.ndarray,
scale_inv: np.ndarray,
Delta: float,
alpha: float,
solver: str,
n: int,
) -> dict[str, Any]:
"""Solve trust region subproblem with bounds.
This helper extracts bounds subproblem logic from trf_bounds,
reducing complexity.
Parameters
----------
J : jnp.ndarray
Current Jacobian
f : jnp.ndarray
Current residuals
g : jnp.ndarray
Current gradient
v : jnp.ndarray
Coleman-Li scaling vector
dv : jnp.ndarray
Derivative of v
scale : np.ndarray
Parameter scaling
scale_inv : np.ndarray
Inverse parameter scaling
Delta : float
Trust region radius
alpha : float
Levenberg-Marquardt parameter
solver : str
Solver type
n : int
Number of parameters
Returns
-------
dict
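Subproblem solution containing d, d_jnp, g_h, J_h, diag_h, p_h, s, V, uf,
f_zeros, J_diag. p_h is None for the SVD-based solvers; s, V, uf are None
for the "cg" solver.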
"""
# Apply two types of scaling
d = v**SQRT_EXPONENT * scale
# C = diag(g * scale) Jv
diag_h = g * dv * scale
# "hat" gradient
g_h = d * g
J_diag = jnp.diag(jnp.sqrt(jnp.maximum(diag_h, 0.0)))
# OPT-2: Use jnp.asarray() to avoid copy if already JAX array
d_jnp = jnp.asarray(d)
f_zeros = jnp.zeros(n, dtype=jnp.float64)
if solver == "cg":
J_h = J * d_jnp
p_h = self.solve_tr_subproblem_cg_bounds(
J, f, d_jnp, J_diag, f_zeros, Delta, alpha
)
s, V, uf = None, None, None
else:
# Exact SVD solver (default). The "sparse" solver also takes this dense
# path, falling back to it for correctness.
output = self.svd_bounds(f, J, d_jnp, J_diag, f_zeros)
J_h = output[0]
s, V, uf = output[2:]
p_h = None
return {
"d": d,
"d_jnp": d_jnp,
"g_h": g_h,
"J_h": J_h,
"diag_h": diag_h,
"p_h": p_h,
"s": s,
"V": V,
"uf": uf,
"f_zeros": f_zeros,
"J_diag": J_diag,
}
def _evaluate_bounds_inner_loop(
self,
fun: Callable[..., Any],
x: np.ndarray,
f: jnp.ndarray,
J: jnp.ndarray,
J_h: jnp.ndarray,
g_h: jnp.ndarray,
diag_h: jnp.ndarray,
cost: float,
d: np.ndarray,
d_jnp: jnp.ndarray,
Delta: float,
alpha: float,
p_h: jnp.ndarray | None,
s: jnp.ndarray | None,
V: jnp.ndarray | None,
uf: jnp.ndarray | None,
f_zeros: jnp.ndarray,
J_diag: jnp.ndarray,
xdata: np.ndarray,
ydata: np.ndarray,
data_mask: jnp.ndarray,
transform: Callable[..., Any] | None,
loss_function: Callable[..., Any] | None,
f_scale: float,
lb: np.ndarray,
ub: np.ndarray,
lb_jnp: jnp.ndarray,
ub_jnp: jnp.ndarray,
theta: float,
solver: str,
ftol: float,
xtol: float,
max_nfev: int,
nfev: int,
n: int,
m: int,
) -> dict[str, Any]:
"""Evaluate inner loop for bounds optimization.
This helper extracts the inner loop logic from trf_bounds,
reducing complexity.
Parameters
----------
fun : Callable
Residual function
x : np.ndarray
Current parameters
f : jnp.ndarray
Current residuals
J : jnp.ndarray
Current Jacobian
J_h : jnp.ndarray
Scaled Jacobian
g_h : jnp.ndarray
Scaled gradient
diag_h : jnp.ndarray
Diagonal scaling
cost : float
Current cost
d : np.ndarray
Scaling vector
d_jnp : jnp.ndarray
JAX scaling vector
Delta : float
Trust region radius
alpha : float
LM parameter
p_h : jnp.ndarray | None
Pre-computed step (CG)
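s, V, uf : jnp.ndarray | None
SVD components (None for the CG solver)
f_zeros : jnp.ndarray
Zero vector
J_diag : jnp.ndarray
Diagonal Jacobian
xdata, ydata : np.ndarray
Data arrays
data_mask : jnp.ndarray
Data mask
transform : Callable | None
Transform function
loss_function : Callable | None
Loss function
f_scale : float
Residual scale
lb, ub : np.ndarray
Bounds
lb_jnp, ub_jnp : jnp.ndarray
Bounds pre-converted to JAX
theta : float
Step-back ratio
solver : str
Solver type
ftol, xtol : float
Tolerances
max_nfev : int
Maximum number of function evaluations
nfev : int
Current number of function evaluations
n, m : int
Problem dimensions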
Returns
-------
dict
Inner loop result
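Examples
--------
A sketch of the radius and damping bookkeeping done on each pass of the
loop. ``update_tr_radius`` is modeled on SciPy's helper of the same name and
is an assumption about ``update_tr_radius_jax``; ``rescale_alpha`` is a
hypothetical name for the clamp applied to ``alpha``:

```python
import math


def update_tr_radius(delta, actual, predicted, step_norm, bound_hit):
    # Modeled on SciPy's update_tr_radius; the JAX version may differ.
    if predicted > 0:
        ratio = actual / predicted
    elif predicted == actual == 0:
        ratio = 1.0
    else:
        ratio = 0.0
    if ratio < 0.25:
        delta = 0.25 * step_norm      # poor agreement: shrink the region
    elif ratio > 0.75 and bound_hit:
        delta *= 2.0                  # good agreement at the boundary: grow
    return delta, ratio


def rescale_alpha(alpha, delta, delta_new, cap=1e30):
    # alpha is left unchanged unless the new radius is positive.
    if delta_new <= 0:
        return alpha
    raw = alpha * delta / delta_new
    # Non-finite or huge values are clamped to `cap`.
    return min(raw if math.isfinite(raw) else cap, cap)


# Poor step: the ratio is below 0.25, so the radius shrinks.
delta_new, ratio = update_tr_radius(1.0, 0.1, 1.0, 0.8, bound_hit=False)
alpha_new = rescale_alpha(alpha=0.5, delta=1.0, delta_new=delta_new)
```

Scaling ``alpha`` by ``Delta / Delta_new`` gives the next subproblem solve a
starting damping value consistent with the smaller region.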
"""
actual_reduction = -1.0 # Python float, avoids JAX sync in while condition
inner_loop_count = 0
max_inner_iterations = 100
termination_status = None
step_norm = 0
x_new = x
f_new = f
cost_new = cost
while (
actual_reduction <= 0
and nfev < max_nfev
and inner_loop_count < max_inner_iterations
):
inner_loop_count += 1
if solver == "cg":
if inner_loop_count > 1:
p_h = self.solve_tr_subproblem_cg_bounds(
J, f, d_jnp, J_diag, f_zeros, Delta, alpha
)
_n_iter = 1
else:
# Use m + n for the augmented system row count (J_augmented is (m+n) x n)
p_h, alpha, _n_iter = solve_lsq_trust_region_jax(
n, m + n, uf, s, V, Delta, initial_alpha=alpha
)
p = d * p_h
step, step_h, predicted_reduction = self.select_step(
x,
J_h,
diag_h,
g_h,
p,
p_h,
d,
Delta,
lb,
ub,
theta,
lb_jnp=lb_jnp,
ub_jnp=ub_jnp,
)
x_new = make_strictly_feasible_jax(x + step, lb_jnp, ub_jnp)
f_new = fun(x_new, xdata, ydata, data_mask, transform)
nfev += 1
step_h_norm = jnorm(step_h)
if not self.check_isfinite(f_new):
Delta = 0.25 * step_h_norm
continue
if loss_function is not None:
cost_new = loss_function(f_new, f_scale, data_mask, cost_only=True)
else:
cost_new = self.default_loss_func(f_new)
actual_reduction = cost - cost_new
Delta_new, ratio = update_tr_radius_jax(
Delta,
actual_reduction,
predicted_reduction,
step_h_norm,
step_h_norm > 0.95 * Delta,
)
step_norm = jnorm(step)
term_code = check_termination_jax(
actual_reduction, cost, step_norm, jnorm(x), ratio, ftol, xtol
)
# Single GPU→CPU sync per inner iteration: materialize term_code
term_code_int = int(term_code)
# Piggyback float materialization to avoid JAX sync in while/if checks
actual_reduction = float(actual_reduction)
if term_code_int != 0:
termination_status = term_code_int
break
if Delta_new > 0:
raw_alpha = alpha * Delta / Delta_new
alpha = min(raw_alpha if math.isfinite(raw_alpha) else 1e30, 1e30)
Delta = Delta_new
# Check inner loop limit
if inner_loop_count >= max_inner_iterations:
self.logger.warning(
"Inner optimization loop hit iteration limit",
inner_iterations=inner_loop_count,
actual_reduction=actual_reduction,
)
termination_status = -3
return {
"accepted": actual_reduction > 0,
"x_new": x_new if actual_reduction > 0 else x,
"f_new": f_new if actual_reduction > 0 else f,
"cost_new": cost_new if actual_reduction > 0 else cost,
"actual_reduction": actual_reduction,
"step_norm": step_norm if actual_reduction > 0 else 0,
"Delta": Delta,
"alpha": alpha,
"termination_status": termination_status,
"nfev": nfev,
}
def _apply_accepted_step(
self,
acceptance_result: dict[str, Any],
jac_scale: bool,
njev: int,
) -> dict[str, Any]:
"""Apply accepted step updates to optimization state.
This helper extracts the state update logic after step acceptance
from trf_no_bounds, reducing complexity.
Parameters
----------
acceptance_result : dict
Result from _evaluate_step_acceptance
jac_scale : bool
Whether Jacobian scaling is enabled
njev : int
Current Jacobian evaluation count
Returns
-------
dict
Updated state variables: x, f, f_true, J, cost, g, njev,
and optionally scale, scale_inv.
"""
result = {
"x": acceptance_result["x_new"],
"f": acceptance_result["f_new"],
"f_true": acceptance_result["f_true_new"],
"J": acceptance_result["J_new"],
"cost": acceptance_result["cost_new"],
"g": acceptance_result["g_new"],
"njev": njev + acceptance_result["njev"],
}
if jac_scale and "scale" in acceptance_result:
result["scale"] = acceptance_result["scale"]
result["scale_inv"] = acceptance_result["scale_inv"]
return result
def _build_optimize_result(
self,
x: jnp.ndarray,
cost: float,
f_true: jnp.ndarray,
J: jnp.ndarray,
g: jnp.ndarray,
g_norm: float,
nfev: int,
njev: int,
iteration: int,
termination_status: int,
) -> OptimizeResult:
"""Build OptimizeResult from optimization state.
This helper extracts result construction logic from trf_no_bounds,
reducing complexity. With no bounds, active_mask is all zeros.
Parameters
----------
x : jnp.ndarray
Final parameter values
cost : float
Final cost value
f_true : jnp.ndarray
Final residuals (unscaled)
J : jnp.ndarray
Final Jacobian matrix
g : jnp.ndarray
Final gradient
g_norm : float
Final gradient norm
nfev : int
Total function evaluations
njev : int
Total Jacobian evaluations
iteration : int
Total iterations performed
termination_status : int
Termination status code
Returns
-------
OptimizeResult
The optimization result object.
"""
active_mask = jnp.zeros_like(x) # No bounds, so no constraint is active
return OptimizeResult(
x=x,
cost=float(cost), # Convert JAX scalar to Python float
fun=f_true,
jac=J,
grad=np.array(g), # Convert JAX array to NumPy
optimality=float(g_norm), # Convert JAX scalar to Python float
active_mask=active_mask,
nfev=nfev,
njev=njev,
nit=iteration, # Number of iterations performed
status=termination_status,
success=termination_status > 0,
all_times={},
)
def trf_no_bounds(
self,
fun: Callable[..., Any],
xdata: jnp.ndarray | tuple[jnp.ndarray],
ydata: jnp.ndarray,
jac: Callable[..., Any],
data_mask: jnp.ndarray,
transform: jnp.ndarray,
x0: np.ndarray,
f: jnp.ndarray,
J: jnp.ndarray,
lb: np.ndarray,
ub: np.ndarray,
ftol: float,
xtol: float,
gtol: float,
max_nfev: int,
f_scale: float,
x_scale: np.ndarray,
loss_function: None | Callable[..., Any],
tr_options: dict[str, Any],
verbose: int,
solver: str = "exact",
callback: Callable[..., Any] | None = None,
profiler: TRFProfiler | NullProfiler | None = None,
**kwargs: Any,
) -> OptimizeResult:
"""Unbounded version of the trust-region reflective algorithm.
Parameters
----------
fun : callable
The residual function
xdata : array_like or tuple of array_like
The independent variable where the data is measured. If `xdata` is a
tuple, then the input arguments to `fun` are assumed to be
``(xdata[0], xdata[1], ...)``.
ydata : jnp.ndarray
The dependent data
jac : callable
The Jacobian of `fun`.
data_mask : jnp.ndarray
The mask for the data.
transform : jnp.ndarray
The uncertainty transform for the data.
x0 : jnp.ndarray
Initial guess. Array of real elements of size (n,), where 'n' is the
number of independent variables.
f : jnp.ndarray
Initial residuals. Array of real elements of size (m,), where 'm' is
the number of data points.
J : jnp.ndarray
Initial Jacobian. Array of real elements of size (m, n), where 'm' is
the number of data points and 'n' is the number of independent
variables.
lb : jnp.ndarray
Lower bounds on independent variables. Array of real elements of size
(n,), where 'n' is the number of independent variables.
ub : jnp.ndarray
Upper bounds on independent variables. Array of real elements of size
(n,), where 'n' is the number of independent variables.
ftol : float
Tolerance for termination by the change of the cost function.
xtol : float
Tolerance for termination by the change of the independent variables.
gtol : float
Tolerance for termination by the norm of the gradient.
max_nfev : int
Maximum number of function evaluations.
f_scale : float
Scale parameter passed to `loss_function` together with the residuals.
x_scale : jnp.ndarray
Scaling factors for independent variables.
loss_function : callable, optional
Loss function. If None, the standard least-squares problem is
solved.
tr_options : dict
Options for the trust-region algorithm.
verbose : int
Level of algorithm's verbosity:
* 0 (default) : work silently.
* 1 : display a termination report.
* 2 : display progress during iterations.
solver : str, optional
Trust-region subproblem solver. Default is "exact".
callback : callable, optional
Called after each iteration with keyword arguments ``iteration``,
``cost``, ``params`` and ``info``. Raising ``StopOptimization`` stops
the optimization.
profiler : TRFProfiler, NullProfiler, or None, optional
Profiler for timing algorithm operations. If None, uses NullProfiler
(zero overhead). Use TRFProfiler() for detailed performance analysis.
Default is None.
Returns
-------
result : OptimizeResult
The optimization result represented as an ``OptimizeResult`` object.
Important attributes are: ``x`` the solution array, ``success`` a
Boolean flag indicating if the optimizer exited successfully and
``status`` the integer termination code. See `OptimizeResult` for a
description of other attributes.
Notes
-----
The algorithm is described in [13]_.
"""
# Initialize profiler (NullProfiler if not provided for zero overhead)
if profiler is None:
profiler = NullProfiler()
# Initialize optimization state using helper
state = self._initialize_trf_state(
x0=x0,
f=f,
J=J,
loss_function=loss_function,
x_scale=x_scale,
f_scale=f_scale,
data_mask=data_mask,
)
# Extract state variables
x = state["x"]
f = state["f"]
J = state["J"]
cost = state["cost"]
g = state["g"]
scale = state["scale"]
scale_inv = state["scale_inv"]
Delta = state["Delta"]
nfev = state["nfev"]
njev = state["njev"]
m = state["m"]
n = state["n"]
jac_scale = state["jac_scale"]
f_true = state["f_true"] # Original unscaled residuals (for res.fun)
# Set max_nfev if not provided (done first so the log shows the real limit)
if max_nfev is None:
max_nfev = x0.size * DEFAULT_MAX_NFEV_MULTIPLIER
# Log optimization start
self.logger.info(
"Starting TRF optimization (no bounds)",
n_params=n,
n_residuals=m,
max_nfev=max_nfev,
)
alpha = INITIAL_LEVENBERG_MARQUARDT_LAMBDA # "Levenberg-Marquardt" parameter
termination_status = None
iteration = 0
step_norm = None
actual_reduction = None
if verbose == 2:
print_header_nonlinear()
# Trust region optimization loop
with self.logger.timer("optimization", log_result=False):
while True:
# Check gradient convergence using helper (only if not already terminated)
# OPT-8: Get g_norm from convergence check to avoid redundant computation
if termination_status is None:
termination_status, g_norm = self._check_convergence_criteria(
g, gtol
)
else:
g_norm = jnorm(g, ord=jnp.inf) # Only compute if already terminated
if verbose == 2:
# Use jax.debug.callback to avoid blocking host-device transfers
debug.callback(
self._log_iteration_callback,
iteration,
nfev,
cost,
actual_reduction,
step_norm,
g_norm,
)
if termination_status is not None or nfev == max_nfev:
if nfev == max_nfev:
self.logger.warning(
"Maximum number of function evaluations reached", nfev=nfev
)
break
# Log iteration details (call-site guard avoids kwargs construction)
if self.logger.logger.isEnabledFor(PERFORMANCE_LEVEL):
self.logger.optimization_step(
iteration=iteration,
cost=cost,
gradient_norm=g_norm,
step_size=Delta if iteration > 0 else None,
nfev=nfev,
)
# Solve trust region subproblem using helper
subproblem_result = self._solve_trust_region_subproblem(
J=J,
f=f,
g=g,
scale=scale,
Delta=Delta,
alpha=alpha,
solver=solver,
)
# Extract subproblem solution
d = subproblem_result["d"]
d_jnp = subproblem_result["d_jnp"]
g_h_jnp = subproblem_result["g_h"]
J_h = subproblem_result["J_h"]
step_h = subproblem_result["step_h"]
s = subproblem_result["s"]
V = subproblem_result["V"]
uf = subproblem_result["uf"]
# Evaluate and potentially accept step using helper
acceptance_result = self._evaluate_step_acceptance(
fun=fun,
jac=jac,
x=x,
f=f,
J=J,
J_h=J_h,
g_h_jnp=g_h_jnp,
cost=cost,
d=d,
d_jnp=d_jnp,
Delta=Delta,
alpha=alpha,
step_h=step_h,
s=s,
V=V,
uf=uf,
xdata=xdata,
ydata=ydata,
data_mask=data_mask,
transform=transform,
loss_function=loss_function,
f_scale=f_scale,
scale_inv=scale_inv,
jac_scale=jac_scale,
solver=solver,
ftol=ftol,
xtol=xtol,
max_nfev=max_nfev,
nfev=nfev,
)
# Update state from acceptance result using helper
if acceptance_result["accepted"]:
step_update = self._apply_accepted_step(
acceptance_result=acceptance_result,
jac_scale=jac_scale,
njev=njev,
)
x = step_update["x"]
f = step_update["f"]
f_true = step_update["f_true"]
J = step_update["J"]
cost = step_update["cost"]
g = step_update["g"]
njev = step_update["njev"]
if "scale" in step_update:
scale = step_update["scale"]
scale_inv = step_update["scale_inv"]
# Update common values regardless of acceptance
actual_reduction = acceptance_result["actual_reduction"]
step_norm = acceptance_result["step_norm"]
Delta = acceptance_result["Delta"]
alpha = acceptance_result["alpha"]
nfev = acceptance_result["nfev"]
if acceptance_result["termination_status"] is not None:
termination_status = acceptance_result["termination_status"]
iteration += 1
# Invoke user callback if provided using helper
if callback is not None:
callback_status = self._invoke_callback(
callback=callback,
iteration=iteration,
cost=cost,
x=x,
g_norm=g_norm,
nfev=nfev,
step_norm=step_norm,
actual_reduction=actual_reduction,
)
if callback_status is not None:
termination_status = callback_status
break
if termination_status is None:
termination_status = 0
# Build and return final result using helper
return self._build_optimize_result(
x=x,
cost=cost,
f_true=f_true,
J=J,
g=g,
g_norm=g_norm,
nfev=nfev,
njev=njev,
iteration=iteration,
termination_status=termination_status,
)
def trf_bounds(
self,
fun: Callable[..., Any],
xdata: jnp.ndarray | tuple[jnp.ndarray],
ydata: jnp.ndarray,
jac: Callable[..., Any],
data_mask: jnp.ndarray,
transform: jnp.ndarray,
x0: np.ndarray,
f: jnp.ndarray,
J: jnp.ndarray,
lb: np.ndarray,
ub: np.ndarray,
ftol: float,
xtol: float,
gtol: float,
max_nfev: int,
f_scale: float,
x_scale: np.ndarray,
loss_function: None | Callable[..., Any],
tr_options: dict[str, Any],
verbose: int,
solver: str = "exact",
callback: Callable[..., Any] | None = None,
**kwargs: Any,
) -> OptimizeResult:
"""Bounded version of the trust-region reflective algorithm.
Parameters
----------
fun : callable
The residual function
xdata : array_like or tuple of array_like
The independent variable where the data is measured. If `xdata` is a
tuple, then the input arguments to `fun` are assumed to be
``(xdata[0], xdata[1], ...)``.
ydata : jnp.ndarray
The dependent data
jac : callable
The Jacobian of `fun`.
data_mask : jnp.ndarray
The mask for the data.
transform : jnp.ndarray
The uncertainty transform for the data.
x0 : jnp.ndarray
Initial guess. Array of real elements of size (n,), where 'n' is the
number of independent variables.
f : jnp.ndarray
Initial residuals. Array of real elements of size (m,), where 'm' is
the number of data points.
J : jnp.ndarray
Initial Jacobian. Array of real elements of size (m, n), where 'm' is
the number of data points and 'n' is the number of independent
variables.
lb : jnp.ndarray
Lower bounds on independent variables. Array of real elements of size
(n,), where 'n' is the number of independent variables.
ub : jnp.ndarray
Upper bounds on independent variables. Array of real elements of size
(n,), where 'n' is the number of independent variables.
ftol : float
Tolerance for termination by the change of the cost function.
xtol : float
Tolerance for termination by the change of the independent variables.
gtol : float
Tolerance for termination by the norm of the gradient.
max_nfev : int
Maximum number of function evaluations.
f_scale : float
Scale parameter passed to `loss_function` together with the residuals.
x_scale : jnp.ndarray
Scaling factors for independent variables.
loss_function : callable, optional
Loss function. If None, the standard least-squares problem is
solved.
tr_options : dict
Options for the trust-region algorithm.
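verbose : int
Level of algorithm's verbosity:
* 0 (default) : work silently.
* 1 : display a termination report.
* 2 : display progress during iterations.
solver : str, optional
Trust-region subproblem solver. "cg" selects the CG solver; other values
use the exact SVD-based solver. Default is "exact".
callback : callable, optional
Called after each iteration with keyword arguments ``iteration``,
``cost``, ``params`` and ``info``. Raising ``StopOptimization`` stops
the optimization.
Returns
-------
result : OptimizeResult
The optimization result represented as an ``OptimizeResult`` object.
Important attributes are: ``x`` the solution array, ``success`` a
Boolean flag indicating if the optimizer exited successfully and
``status`` the integer termination code. See `OptimizeResult` for a
description of other attributes.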
Notes
-----
The algorithm is described in [13]_.
References
----------
.. [13] J. J. More, "The Levenberg-Marquardt Algorithm: Implementation and
Theory," in Numerical Analysis, ed. G. A. Watson (1978), pp. 105-116.
DOI: 10.1017/CBO9780511819595.006
.. [2] T. F. Coleman and Y. Li, "An interior trust region approach for
nonlinear minimization subject to bounds," SIAM Journal on
Optimization, vol. 6, no. 2, pp. 418-445, 1996.
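Examples
--------
A small sketch of the first-order optimality measure used in the main loop,
the infinity norm of ``g * v``. ``scaled_gradient_norm`` is a hypothetical
pure-Python helper, and the numbers are made up for illustration:

```python
def scaled_gradient_norm(g, v):
    # Infinity norm of the element-wise product g * v.
    return max(abs(gi * vi) for gi, vi in zip(g, v))


# Variable 0 sits almost on a bound that the anti-gradient points at, so its
# large gradient component is damped by the tiny distance v[0].
g = [4.0, 1e-9]
v = [1e-10, 1.0]
g_norm = scaled_gradient_norm(g, v)
converged = g_norm < 1e-8  # compared against gtol in the loop
```

This is why a variable resting on its bound does not block convergence even
when its raw gradient component is far from zero.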
"""
x = x0
f_true = f
nfev = 1
njev = 1
m, n = J.shape
# Convert bounds to JAX once at entry to avoid repeated transfers
lb_jnp = jnp.asarray(lb)
ub_jnp = jnp.asarray(ub)
if loss_function is not None:
rho = loss_function(f, f_scale)
cost = self.calculate_cost(rho, data_mask)
J, f = self.cJIT.scale_for_robust_loss_function(J, f, rho)
else:
cost = self.default_loss_func(f)
g = self.compute_grad(J, f)
# Initialize bounds state using helper
bounds_state = self._initialize_bounds_state(
x0, f, J, g, lb_jnp, ub_jnp, x_scale
)
v = bounds_state["v"]
dv = bounds_state["dv"]
scale = bounds_state["scale"]
scale_inv = bounds_state["scale_inv"]
Delta = bounds_state["Delta"]
jac_scale = bounds_state["jac_scale"]
# Use JAX norm for gradient norm calculation
g_norm = jnorm(g * v, ord=jnp.inf)
if max_nfev is None:
max_nfev = x0.size * DEFAULT_MAX_NFEV_MULTIPLIER
alpha = INITIAL_LEVENBERG_MARQUARDT_LAMBDA
termination_status = None
iteration = 0
step_norm = None
actual_reduction = None
if verbose == 2:
print_header_nonlinear()
while True:
v, dv = CL_scaling_vector_jax(x, g, lb_jnp, ub_jnp)
g_norm = jnorm(g * v, ord=jnp.inf)
if termination_status is None and g_norm < gtol:
termination_status = 1
if verbose == 2:
debug.callback(
self._log_iteration_callback,
iteration,
nfev,
cost,
actual_reduction,
step_norm,
g_norm,
)
if termination_status is not None or nfev == max_nfev:
break
# Update v with scaling
v = jnp.where(dv != 0, v * scale_inv, v)
# Solve bounds subproblem using helper
subproblem = self._solve_bounds_subproblem(
J=J,
f=f,
g=g,
v=v,
dv=dv,
scale=scale,
scale_inv=scale_inv,
Delta=Delta,
alpha=alpha,
solver=solver,
n=n,
)
# theta controls the step-back ratio from the bounds; it must be < 1
# to maintain the strict interior feasibility required by TRF bounds mode
theta = jnp.minimum(jnp.maximum(0.995, 1.0 - g_norm), 1.0 - 1e-8)
# Evaluate inner loop using helper
inner_result = self._evaluate_bounds_inner_loop(
fun=fun,
x=x,
f=f,
J=J,
J_h=subproblem["J_h"],
g_h=subproblem["g_h"],
diag_h=subproblem["diag_h"],
cost=cost,
d=subproblem["d"],
d_jnp=subproblem["d_jnp"],
Delta=Delta,
alpha=alpha,
p_h=subproblem["p_h"],
s=subproblem["s"],
V=subproblem["V"],
uf=subproblem["uf"],
f_zeros=subproblem["f_zeros"],
J_diag=subproblem["J_diag"],
xdata=xdata,
ydata=ydata,
data_mask=data_mask,
transform=transform,
loss_function=loss_function,
f_scale=f_scale,
lb=lb,
ub=ub,
lb_jnp=lb_jnp,
ub_jnp=ub_jnp,
theta=theta,
solver=solver,
ftol=ftol,
xtol=xtol,
max_nfev=max_nfev,
nfev=nfev,
n=n,
m=m,
)
# Update from inner loop result
actual_reduction = inner_result["actual_reduction"]
step_norm = inner_result["step_norm"]
Delta = inner_result["Delta"]
alpha = inner_result["alpha"]
nfev = inner_result["nfev"]
if inner_result["termination_status"] is not None:
termination_status = inner_result["termination_status"]
if inner_result["accepted"]:
x = inner_result["x_new"]
f_unscaled = inner_result["f_new"]
f_true = f_unscaled
cost = inner_result["cost_new"]
J = jac(x, xdata, ydata, data_mask, transform)
njev += 1
if loss_function is not None:
rho = loss_function(f_unscaled, f_scale)
J, f = self.cJIT.scale_for_robust_loss_function(J, f_unscaled, rho)
else:
f = f_unscaled
g = self.compute_grad(J, f)
if jac_scale:
scale, scale_inv = self.cJIT.compute_jac_scale(J, scale_inv)
else:
step_norm = 0
# actual_reduction is already set from inner_result (non-positive value);
# retaining it lets callbacks report the true optimization state
iteration += 1
# Invoke user callback using helper
if callback is not None:
callback_status = self._invoke_callback(
callback=callback,
iteration=iteration,
cost=cost,
x=x,
g_norm=g_norm,
nfev=nfev,
step_norm=step_norm,
actual_reduction=actual_reduction,
)
if callback_status is not None:
termination_status = callback_status
break
if termination_status is None:
termination_status = 0
x_out = np.asarray(x)
active_mask = find_active_constraints(x_out, lb, ub, rtol=xtol)
return OptimizeResult(
x=x_out,
cost=float(cost),
fun=f_true,
jac=J,
grad=np.array(g),
optimality=float(g_norm),
active_mask=active_mask,
nfev=nfev,
njev=njev,
nit=iteration,
status=termination_status,
success=termination_status > 0,
)
def select_step(
self,
x: np.ndarray,
J_h: jnp.ndarray,
diag_h: jnp.ndarray,
g_h: jnp.ndarray,
p: np.ndarray,
p_h: np.ndarray,
d: np.ndarray,
Delta: float,
lb: np.ndarray,
ub: np.ndarray,
theta: float,
lb_jnp: jnp.ndarray | None = None,
ub_jnp: jnp.ndarray | None = None,
) -> tuple[np.ndarray, Any, Any]:
"""Select the best step according to Trust Region Reflective algorithm.
Parameters
----------
x : np.ndarray
Current parameter vector.
J_h : jnp.ndarray
Jacobian matrix in the scaled 'hat' space.
diag_h : jnp.ndarray
Diagonal of the scaled matrix C = diag(g * scale) Jv.
g_h : jnp.ndarray
Gradient vector in the scaled 'hat' space.
p : np.ndarray
Trust-region step in the original space.
p_h : np.ndarray
Trust-region step in the scaled 'hat' space.
d : np.ndarray
Scaling vector.
Delta : float
Trust-region radius.
lb : np.ndarray
Lower bounds on variables.
ub : np.ndarray
Upper bounds on variables.
theta : float
Controls the step-back ratio from the bounds.
lb_jnp : jnp.ndarray, optional
Pre-converted JAX lower bounds (avoids repeated conversion).
ub_jnp : jnp.ndarray, optional
Pre-converted JAX upper bounds (avoids repeated conversion).
Returns
-------
step : np.ndarray
Step in the original space.
step_h : np.ndarray
Step in the scaled 'hat' space.
predicted_reduction : float
Predicted reduction in the cost function.
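Examples
--------
A pure-Python sketch of the reflection step. ``step_size_to_bound`` is
modeled on SciPy's helper of the same name and is an assumption about
``step_size_to_bound_jax``; the sketch reflects the step in the original
space, whereas the method itself reflects the 'hat' step:

```python
import math


def step_size_to_bound(x, s, lb, ub):
    # Largest multiple of s that keeps x + t * s inside [lb, ub].
    steps = []
    for xi, si, lo, hi in zip(x, s, lb, ub):
        if si == 0:
            steps.append(math.inf)
        else:
            steps.append(max((lo - xi) / si, (hi - xi) / si))
    min_step = min(steps)
    # hits[i] is the sign of s[i] where a bound is reached first, else 0.
    hits = [
        int(math.copysign(1, si)) if st == min_step and si != 0 else 0
        for si, st in zip(s, steps)
    ]
    return min_step, hits


x, p = [0.5, 0.5], [1.0, 0.25]
lb, ub = [0.0, 0.0], [1.0, 1.0]
stride, hits = step_size_to_bound(x, p, lb, ub)
# Reflect the step: negate only the components that hit a bound.
reflected = [-p_i if h != 0 else p_i for p_i, h in zip(p, hits)]
```

The full step would leave the box through the upper bound of the first
variable, so only that component changes sign in the reflected direction.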
"""
if lb_jnp is not None and ub_jnp is not None:
_in_bounds = in_bounds_jax(x + p, lb_jnp, ub_jnp)
else:
_in_bounds = in_bounds(x + p, lb, ub)
if _in_bounds:
p_value = self.cJIT.evaluate_quadratic(J_h, g_h, p_h, diag=diag_h)
return p, p_h, -p_value
# B004: Use JAX-compiled versions to avoid D2H array transfers.
# Convert inputs to JAX arrays for on-device computation.
x_jnp = jnp.asarray(x)
p_jnp = jnp.asarray(p)
d_jnp = jnp.asarray(d)
lb_j = lb_jnp if lb_jnp is not None else jnp.asarray(lb)
ub_j = ub_jnp if ub_jnp is not None else jnp.asarray(ub)
p_stride_jax, hits = step_size_to_bound_jax(x_jnp, p_jnp, lb_j, ub_j)
p_stride = float(p_stride_jax)
# Compute the reflected direction.
r_h = jnp.asarray(p_h)
# Negate components that hit bounds
r_h = jnp.where(hits != 0, -r_h, r_h)
r = d_jnp * r_h
# Restrict trust-region step, such that it hits the bound.
p_h_scaled = jnp.asarray(p_h) * p_stride
x_on_bound = x_jnp + p_jnp * p_stride
# Reflected direction will cross first either feasible region or trust
# region boundary.
_, to_tr_jax = intersect_trust_region_jax(p_h_scaled, r_h, Delta)
to_bound_jax, _ = step_size_to_bound_jax(x_on_bound, r, lb_j, ub_j)
# Materialize scalars for host-side branching
to_tr = float(to_tr_jax)
to_bound = float(to_bound_jax)
# Find lower and upper bounds on a step size along the reflected
# direction, considering the strict feasibility requirement. There is no
# single correct way to do that, the chosen approach seems to work best
# on test problems.
r_stride = min(to_bound, to_tr)
if r_stride > 0:
r_stride_l = (1 - theta) * p_stride / r_stride
r_stride_u = theta * to_bound if r_stride == to_bound else to_tr
else:
r_stride_l = 0
r_stride_u = -1
# Check if reflection step is available.
if r_stride_l <= r_stride_u:
a, b, c = self.cJIT.build_quadratic_1d(
J_h, g_h, r_h, s0=p_h_scaled, diag=diag_h
)
r_stride_jax, r_value = minimize_quadratic_1d_jax(
a, b, r_stride_l, r_stride_u, c=c
)
r_stride = float(r_stride_jax)
r_h = r_h * r_stride + p_h_scaled
r = r_h * d_jnp
else:
r_value = jnp.inf
# Now correct p_h to make it strictly interior.
p = np.asarray(p_jnp * p_stride * theta)
p_h = p_h_scaled * theta
p_value = self.cJIT.evaluate_quadratic(J_h, g_h, p_h, diag=diag_h)
ag_h = -g_h
ag = d_jnp * ag_h
ag_h_norm = jnorm(ag_h)
to_tr = float(Delta / ag_h_norm) if float(ag_h_norm) > 0 else 1.0
to_bound_jax, _ = step_size_to_bound_jax(x_jnp, ag, lb_j, ub_j)
to_bound = float(to_bound_jax)
ag_stride = theta * to_bound if to_bound < to_tr else to_tr
a, b = self.cJIT.build_quadratic_1d(J_h, g_h, ag_h, diag=diag_h)
ag_stride_jax, ag_value = minimize_quadratic_1d_jax(a, b, 0, ag_stride)
ag_h = ag_h * ag_stride_jax
ag = ag * ag_stride_jax
# Materialize reduction values for host-side comparison
p_value = float(p_value)
r_value = float(r_value)
ag_value = float(ag_value)
if p_value < r_value and p_value < ag_value:
return np.asarray(p), p_h, -p_value
elif r_value < p_value and r_value < ag_value:
return np.asarray(r), r_h, -r_value
else:
return np.asarray(ag), ag_h, -ag_value
def optimize(
self,
fun: Callable[..., Any],
x0: np.ndarray,
jac: Callable[..., Any] | None = None,
bounds: tuple[np.ndarray, np.ndarray] | tuple[float, float] = (-np.inf, np.inf),
**kwargs: Any,
) -> OptimizeResult:
"""Perform optimization using trust region reflective algorithm.
This method provides a simplified interface to the TRF algorithm.
For full control and curve fitting applications, use the `trf` method directly.
Parameters
----------
fun : callable
The objective function to minimize. Should return residuals.
x0 : np.ndarray
Initial guess for parameters
jac : callable, optional
Jacobian function. If None, uses automatic differentiation.
bounds : tuple of arrays
Lower and upper bounds for parameters
**kwargs
Additional optimization parameters
Returns
-------
OptimizeResult
The optimization result
Raises
------
NotImplementedError
This simplified interface is not yet implemented.
Use the `trf` method for full curve fitting functionality.
"""
raise NotImplementedError(
"The simplified optimize() interface is not yet implemented for TrustRegionReflective. "
"This class is designed for curve fitting applications. "
"Use the `trf()` method directly, or use the higher-level interfaces in "
"`nlsq.curve_fit()` or `LeastSquares.least_squares()`."
)