"""Generic interface for least-squares minimization."""
# mypy: disable-error-code="arg-type,assignment"
# Note: Remaining mypy errors are arg-type/assignment mismatches where Optional values
# are passed to methods expecting non-Optional, or Literal type narrowing issues.
# These require deeper refactoring. Fixed in this file: check_x_scale types,
# safe_clip→jnp.clip, logger.warning, None callable guards, return type fixes.
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import Any, Literal
from warnings import warn
import numpy as np
# Initialize JAX configuration through central config
from nlsq.config import JAXConfig
__jax_config = JAXConfig()
import jax.numpy as jnp
from jax import jacfwd, jacrev, jit
from jax.scipy.linalg import solve_triangular as jax_solve_triangular
from nlsq.caching.memory_manager import get_memory_manager
from nlsq.caching.unified_cache import get_global_cache
from nlsq.common_scipy import EPS, in_bounds, make_strictly_feasible
from nlsq.constants import DEFAULT_FTOL, DEFAULT_GTOL, DEFAULT_XTOL
from nlsq.core.loss_functions import LossFunctionsJIT
from nlsq.core.trf import TrustRegionReflective
from nlsq.stability.guard import NumericalStabilityGuard
from nlsq.types import ArrayLike, BoundsTuple, CallbackFunction, MethodLiteral
from nlsq.utils.diagnostics import OptimizationDiagnostics
from nlsq.utils.logging import get_logger
[docs]
def jacobian_mode_selector(
    n_params: int, n_residuals: int, mode: str = "auto"
) -> tuple[str, str]:
    """Pick forward- or reverse-mode autodiff for the Jacobian.

    The Jacobian has ``n_residuals`` rows and ``n_params`` columns. In
    ``'auto'`` mode the cheaper differentiation direction is chosen from
    those two sizes; ``'fwd'`` and ``'rev'`` are passed through untouched.

    Parameters
    ----------
    n_params : int
        Number of parameters (columns in Jacobian)
    n_residuals : int
        Number of residuals (rows in Jacobian)
    mode : {'auto', 'fwd', 'rev'}, optional
        Requested mode. Default is 'auto'.

        - 'auto': decide from the problem dimensions
        - 'fwd': force forward-mode AD (jacfwd)
        - 'rev': force reverse-mode AD (jacrev)

    Returns
    -------
    selected_mode : str
        Either 'fwd' or 'rev'
    rationale : str
        Human-readable explanation of the selection

    Raises
    ------
    ValueError
        If mode is not one of 'auto', 'fwd', 'rev'

    Notes
    -----
    Heuristic used by 'auto':

    - ``n_params > n_residuals`` (wide Jacobian: more columns than rows)
      selects jacrev, whose cost scales with the number of residuals.
    - ``n_params <= n_residuals`` (square or tall Jacobian) selects jacfwd,
      whose cost scales with the number of parameters.

    Examples
    --------
    >>> from nlsq.core.least_squares import jacobian_mode_selector
    >>> # Many parameters, few residuals
    >>> mode, rationale = jacobian_mode_selector(1000, 100, mode='auto')
    >>> print(mode, rationale)
    rev jacrev (1000 params > 100 residuals)
    >>> # Few parameters, many residuals
    >>> mode, rationale = jacobian_mode_selector(100, 1000, mode='auto')
    >>> print(mode, rationale)
    fwd jacfwd (100 params <= 1000 residuals)
    >>> # Manual override
    >>> mode, rationale = jacobian_mode_selector(1000, 100, mode='fwd')
    >>> print(mode, rationale)
    fwd explicit override: fwd
    """
    if mode == "auto":
        # Differentiate along the smaller of the two Jacobian dimensions;
        # ties go to forward mode.
        if n_params > n_residuals:
            return "rev", f"jacrev ({n_params} params > {n_residuals} residuals)"
        return "fwd", f"jacfwd ({n_params} params <= {n_residuals} residuals)"

    if mode in ("fwd", "rev"):
        return mode, f"explicit override: {mode}"

    raise ValueError(
        f"Invalid jacobian_mode: {mode}. Must be 'auto', 'fwd', or 'rev'"
    )
# Human-readable message for each integer termination status.
# `LeastSquares._process_optimization_result` looks a message up with
# `result.status` and marks a run successful only when the status is > 0,
# so the zero and negative codes below all describe unsuccessful runs.
TERMINATION_MESSAGES = {
    -3: "Inner optimization loop exceeded maximum iterations.",
    -2: "Maximum iterations reached.",
    -1: "Improper input parameters status returned from `leastsq`",
    0: "The maximum number of function evaluations is exceeded.",
    1: "`gtol` termination condition is satisfied.",
    2: "`ftol` termination condition is satisfied.",
    3: "`xtol` termination condition is satisfied.",
    4: "Both `ftol` and `xtol` termination conditions are satisfied.",
}
[docs]
def prepare_bounds(bounds, n) -> tuple[np.ndarray, np.ndarray]:
    """Convert a ``(lower, upper)`` pair into float arrays.

    Each bound is converted with ``np.asarray(..., dtype=float)``. A scalar
    (0-d) bound is repeated to length `n`; a bound that is already an array
    is returned as converted, without any length check.

    Parameters
    ----------
    bounds : Tuple[np.ndarray, np.ndarray]
        The lower and upper bounds; must unpack into exactly two items.
    n : int
        Length used when expanding a scalar bound.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The prepared lower and upper bounds arrays.
    """
    lower, upper = (np.asarray(bound, dtype=float) for bound in bounds)
    # Only scalar bounds are broadcast; array bounds are left untouched.
    lower = np.resize(lower, n) if lower.ndim == 0 else lower
    upper = np.resize(upper, n) if upper.ndim == 0 else upper
    return lower, upper
[docs]
def check_tolerance(
    ftol: float, xtol: float, gtol: float, method: str
) -> tuple[float, float, float]:
    """Normalise the three termination tolerances.

    A tolerance given as `None` becomes 0. A tolerance below machine epsilon
    (`EPS`) triggers a warning but is returned unchanged. If all three
    tolerances end up below `EPS`, a `ValueError` is raised because no
    termination condition would remain active.

    Parameters
    ----------
    ftol : float
        The tolerance for the optimization function value.
    xtol : float
        The tolerance for the optimization variable values.
    gtol : float
        The tolerance for the optimization gradient values.
    method : str
        The name of the optimization method (not used by this function).

    Returns
    -------
    Tuple[float, float, float]
        The prepared ``(ftol, xtol, gtol)`` values.

    Raises
    ------
    ValueError
        If every tolerance is below machine epsilon.
    """

    def _sanitize(value: float, label: str) -> float:
        # None means "condition disabled" and is represented as 0.
        if value is None:
            return 0
        if value < EPS:
            warn(
                f"Setting `{label}` below the machine epsilon ({EPS:.2e}) effectively "
                "disables the corresponding termination condition.",
                stacklevel=2,
            )
        return value

    ftol = _sanitize(ftol, "ftol")
    xtol = _sanitize(xtol, "xtol")
    gtol = _sanitize(gtol, "gtol")

    every_condition_disabled = ftol < EPS and xtol < EPS and gtol < EPS
    if every_condition_disabled:
        raise ValueError(
            "At least one of the tolerances must be higher than "
            f"machine epsilon ({EPS:.2e})."
        )
    return ftol, xtol, gtol
[docs]
def check_x_scale(
    x_scale: str | Sequence[float] | np.ndarray, x0: np.ndarray
) -> str | np.ndarray:
    """Validate `x_scale` and expand it to the shape of `x0`.

    The string ``'jac'`` is returned as-is. Anything else must convert to a
    float array whose entries are all finite and strictly positive; a scalar
    is repeated to ``x0.shape``.

    Parameters
    ----------
    x_scale : str | Sequence[float] | np.ndarray
        The scaling for the optimization variables.
    x0 : np.ndarray
        The initial guess for the optimization variables.

    Returns
    -------
    str | np.ndarray
        ``'jac'`` or a float array with the same shape as `x0`.

    Raises
    ------
    ValueError
        If `x_scale` cannot be converted, contains non-finite or
        non-positive values, or has a shape different from `x0`.
    """
    if isinstance(x_scale, str) and x_scale == "jac":
        return x_scale

    # `scale` stays None unless conversion and the value checks all succeed.
    scale = None
    try:
        candidate = np.asarray(x_scale, dtype=float)
        if np.all(np.isfinite(candidate)) and np.all(candidate > 0):
            scale = candidate
    except (ValueError, TypeError):
        scale = None

    if scale is None:
        raise ValueError("`x_scale` must be 'jac' or array_like with positive numbers.")

    if scale.ndim == 0:
        scale = np.resize(scale, x0.shape)

    if scale.shape != x0.shape:
        raise ValueError("Inconsistent shapes between `x_scale` and `x0`.")
    return scale
[docs]
class AutoDiffJacobian:
    """Build JIT-compiled autodiff Jacobians of a residual function.

    Either forward-mode (jacfwd) or reverse-mode (jacrev) differentiation
    can be used. This is a class rather than a function so that each
    instance can keep its own most recently built Jacobian in ``self.jac``.
    """

    def create_ad_jacobian(
        self, func: Callable, num_args: int, masked: bool = True, mode: str = "fwd"
    ) -> Callable:
        """Create, store and return the autodiff Jacobian of `func`.

        Parameters
        ----------
        func : Callable
            Residual function called as
            ``func(args, xdata, ydata, data_mask, atransform)``.
        num_args : int
            Number of fit parameters to differentiate with respect to.
        masked : bool, optional
            If True (default) the returned Jacobian is zeroed wherever
            `data_mask` is false; otherwise the mask is ignored.
        mode : str, optional
            ``'fwd'`` selects jacfwd; any other value selects jacrev.
            Default is ``'fwd'``.

        Returns
        -------
        Callable
            Function with signature
            ``(args, xdata, ydata, data_mask, atransform)`` returning the
            Jacobian as an at-least-2-D array. The same callable is also
            stored on ``self.jac``.
        """
        differentiate = jacfwd if mode == "fwd" else jacrev

        # The flattened call places the four data arguments first, so the
        # fit parameters occupy positions 4, 5, ..., 4 + num_args - 1.
        param_positions = [4 + offset for offset in range(num_args)]

        # Plain @jit is used: these closures capture `func`, which can be a
        # different object on every call.
        @jit
        def flat_residuals(*flat_args) -> jnp.ndarray:
            """Residuals with every fit parameter as its own positional arg."""
            xdata, ydata, data_mask, atransform = flat_args[:4]
            params = jnp.array(flat_args[4:])
            return func(params, xdata, ydata, data_mask, atransform)

        @jit
        def stacked_partials(
            args: jnp.ndarray,
            xdata: jnp.ndarray,
            ydata: jnp.ndarray,
            data_mask: jnp.ndarray,
            atransform: jnp.ndarray,
        ) -> jnp.ndarray:
            """Differentiate w.r.t. each parameter and stack the results.

            The leading axis of the returned array indexes the parameters.
            """
            flat_args = [xdata, ydata, data_mask, atransform, *args]
            partials = differentiate(flat_residuals, argnums=param_positions)(
                *flat_args
            )
            return jnp.array(partials)

        @jit
        def masked_jac(
            args: jnp.ndarray,
            xdata: jnp.ndarray,
            ydata: jnp.ndarray,
            data_mask: jnp.ndarray,
            atransform: jnp.ndarray,
        ) -> jnp.ndarray:
            """Jacobian with entries zeroed where `data_mask` is false."""
            stacked = stacked_partials(args, xdata, ydata, data_mask, atransform)
            kept = jnp.where(data_mask, stacked, 0)
            return jnp.atleast_2d(kept.T)

        @jit
        def no_mask_jac(
            args: jnp.ndarray,
            xdata: jnp.ndarray,
            ydata: jnp.ndarray,
            data_mask: jnp.ndarray,
            atransform: jnp.ndarray,
        ) -> jnp.ndarray:
            """Jacobian without any masking applied."""
            stacked = stacked_partials(args, xdata, ydata, data_mask, atransform)
            return jnp.atleast_2d(stacked.T)

        self.jac = masked_jac if masked else no_mask_jac
        return self.jac
[docs]
class LeastSquares:
"""Core least squares optimization engine with JAX acceleration.
This class implements the main optimization algorithms for nonlinear least squares
problems, including Trust Region Reflective (TRF) and Levenberg-Marquardt (LM).
It handles automatic differentiation, bound constraints, loss functions, and
uncertainty propagation.
The class maintains separate automatic differentiation instances for different
sigma configurations (no sigma, 1D sigma, 2D covariance matrix) to optimize
compilation and execution performance.
Attributes
----------
trf : TrustRegionReflective
Trust Region Reflective algorithm implementation
ls : LossFunctionsJIT
JIT-compiled loss function implementations
logger : Logger
Internal logger for debugging and performance tracking
f : callable
Current objective function being optimized
jac : callable or None
Current Jacobian function (None for automatic differentiation)
adjn : AutoDiffJacobian
Automatic differentiation instance for unweighted problems
adj1d : AutoDiffJacobian
Automatic differentiation instance for 1D sigma weighting
adj2d : AutoDiffJacobian
Automatic differentiation instance for 2D covariance matrix weighting
Methods
-------
least_squares : Main optimization method
"""
[docs]
def __init__(
    self,
    enable_stability: bool = False,
    enable_diagnostics: bool = False,
    max_jacobian_elements_for_svd: int = 10_000_000,
) -> None:
    """Create the solver, loss helpers, autodiff wrappers and optional guards.

    Parameters
    ----------
    enable_stability : bool, default False
        When True, a `NumericalStabilityGuard` and a memory manager are
        attached to the instance (`stability_guard`, `memory_manager`).
    enable_diagnostics : bool, default False
        When True, an `OptimizationDiagnostics` collector is attached as
        `diagnostics`.
    max_jacobian_elements_for_svd : int, default 10_000_000
        Stored on the instance and forwarded to the stability guard as its
        Jacobian-size limit for SVD.
    """
    super().__init__()  # not sure if this is needed

    # Solver back-end, loss implementations and logger.
    self.trf = TrustRegionReflective()
    self.ls = LossFunctionsJIT()
    self.logger = get_logger("least_squares")

    # Placeholder fit function and "no Jacobian yet" marker.
    self.f = lambda x: None
    self.jac: Callable[..., jnp.ndarray] | None = None

    # One autodiff wrapper per weighting case: none, 1-D sigma, 2-D covariance.
    self.adjn, self.adj1d, self.adj2d = (AutoDiffJacobian() for _ in range(3))

    # Unified cache used for JIT compilation tracking.
    self.cache = get_global_cache()

    # Remember the feature switches; optional helpers exist only when enabled.
    self.enable_stability = enable_stability
    self.enable_diagnostics = enable_diagnostics
    self.max_jacobian_elements_for_svd = max_jacobian_elements_for_svd

    if enable_stability:
        self.stability_guard = NumericalStabilityGuard(
            max_jacobian_elements_for_svd=max_jacobian_elements_for_svd
        )
        self.memory_manager = get_memory_manager()

    if enable_diagnostics:
        self.diagnostics = OptimizationDiagnostics()
def _validate_least_squares_inputs(
self,
x0: np.ndarray,
bounds: tuple,
method: str,
jac,
loss: str,
verbose: int,
max_nfev: float | None,
ftol: float,
xtol: float,
gtol: float,
x_scale,
prepared_bounds: tuple[np.ndarray, np.ndarray] | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, np.ndarray]:
"""Validate and prepare least squares inputs.
Parameters
----------
prepared_bounds : tuple[np.ndarray, np.ndarray] or None, optional
Pre-prepared bounds as (lb, ub). If provided, skips prepare_bounds call.
This optimization avoids redundant bounds preparation when caller has
already called prepare_bounds. Default is None.
Returns
-------
x0 : np.ndarray
Validated initial guess
lb : np.ndarray
Lower bounds
ub : np.ndarray
Upper bounds
ftol : float
Function tolerance
xtol : float
Parameter tolerance
gtol : float
Gradient tolerance
x_scale : np.ndarray
Parameter scaling
"""
# Validate loss function
if loss not in self.ls.IMPLEMENTED_LOSSES and not callable(loss):
raise ValueError(
f"`loss` must be one of {self.ls.IMPLEMENTED_LOSSES.keys()} or a callable."
)
# Validate method
if method != "trf":
raise ValueError("`method` must be 'trf'")
# Validate jac parameter
if jac is not None and not callable(jac):
raise ValueError("`jac` must be None or callable.")
# Validate verbose level
if verbose not in [0, 1, 2]:
raise ValueError("`verbose` must be in [0, 1, 2].")
# Validate bounds
if len(bounds) != 2:
raise ValueError("`bounds` must contain 2 elements.")
# Validate max_nfev
if max_nfev is not None and max_nfev <= 0:
raise ValueError("`max_nfev` must be None or positive integer.")
# Validate x0
if np.iscomplexobj(x0):
raise ValueError("`x0` must be real.")
x0 = np.atleast_1d(x0).astype(float)
if x0.ndim > 1:
raise ValueError("`x0` must have at most 1 dimension.")
# Use prepared bounds if provided, otherwise prepare them
if prepared_bounds is not None:
lb, ub = prepared_bounds
else:
lb, ub = prepare_bounds(bounds, x0.shape[0])
if lb.shape != x0.shape or ub.shape != x0.shape:
raise ValueError("Inconsistent shapes between bounds and `x0`.")
if np.any(lb >= ub):
raise ValueError(
"Each lower bound must be strictly less than each upper bound."
)
if not in_bounds(x0, lb, ub):
raise ValueError("`x0` is infeasible.")
# Check and prepare scaling/tolerances
x_scale = check_x_scale(x_scale, x0)
ftol, xtol, gtol = check_tolerance(ftol, xtol, gtol, method)
x0 = make_strictly_feasible(x0, lb, ub)
return x0, lb, ub, ftol, xtol, gtol, x_scale
def _setup_functions(
self,
fun: Callable,
jac: Callable | None,
xdata: jnp.ndarray | None,
ydata: jnp.ndarray | None,
transform: jnp.ndarray | None,
x0: np.ndarray,
args: tuple,
kwargs: dict,
jacobian_mode_selected: str = "fwd",
) -> tuple:
"""Setup residual and Jacobian functions.
Returns
-------
rfunc : callable
Residual function
jac_func : callable
Jacobian function
"""
if xdata is not None and ydata is not None:
# Check if fit function needs updating
func_update = False
try:
if hasattr(self.f, "__code__") and hasattr(fun, "__code__"):
func_update = self.f.__code__.co_code != fun.__code__.co_code
else:
func_update = self.f != fun
except Exception:
func_update = True
# Update function if needed
if func_update:
self.update_function(fun)
if jac is None:
self.autdiff_jac(jac, mode=jacobian_mode_selected)
# Handle analytical Jacobian
if jac is not None:
if (
self.jac is None
or self.jac.__code__.co_code != jac.__code__.co_code
):
self.wrap_jac(jac)
elif self.jac is not None and not func_update:
self.autdiff_jac(jac, mode=jacobian_mode_selected)
# Select appropriate residual function and Jacobian
if transform is None:
rfunc = self.func_none
jac_func = self.jac_none
elif transform.ndim == 1:
rfunc = self.func_1d
jac_func = self.jac_1d
else:
rfunc = self.func_2d
jac_func = self.jac_2d
else:
# SciPy compatibility mode
def wrap_func(fargs, xdata, ydata, data_mask, atransform):
return jnp.atleast_1d(fun(fargs, *args, **kwargs))
rfunc = wrap_func
if jac is None:
adj = AutoDiffJacobian()
jac_func = adj.create_ad_jacobian(
wrap_func, x0.size, masked=False, mode=jacobian_mode_selected
)
else:
# Capture jac in closure with proper type narrowing
jac_callable = jac
def wrap_jac(fargs, xdata, ydata, data_mask, atransform):
return jnp.atleast_2d(jac_callable(fargs, *args, **kwargs))
jac_func = wrap_jac
return rfunc, jac_func
def _evaluate_initial_residuals_and_jacobian(
self,
rfunc: Callable,
jac_func: Callable,
x0: np.ndarray,
xdata: jnp.ndarray | None,
ydata: jnp.ndarray | None,
data_mask: jnp.ndarray | None,
transform: jnp.ndarray | None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Evaluate initial residuals and Jacobian, with stability checks.
Parameters
----------
rfunc : Callable
Residual function
jac_func : Callable
Jacobian function
x0 : np.ndarray
Initial parameters
xdata : jnp.ndarray | None
X data
ydata : jnp.ndarray | None
Y data
data_mask : jnp.ndarray | None
Data mask
transform : jnp.ndarray | None
Transform matrix
Returns
-------
f0 : jnp.ndarray
Initial residuals
J0 : jnp.ndarray
Initial Jacobian
Raises
------
ValueError
If residuals are not 1-D or not finite
"""
f0 = rfunc(x0, xdata, ydata, data_mask, transform)
J0 = jac_func(x0, xdata, ydata, data_mask, transform)
if f0.ndim != 1:
raise ValueError(
f"`fun` must return at most 1-d array_like. f0.shape: {f0.shape}"
)
if not np.all(np.isfinite(f0)):
if self.enable_stability:
self.logger.warning("Non-finite residuals detected, attempting to fix")
f0 = jnp.clip(f0, -1e10, 1e10)
if not np.all(np.isfinite(f0)):
raise ValueError("Residuals are not finite after stabilization")
else:
raise ValueError("Residuals are not finite in the initial point.")
return f0, J0
def _check_and_fix_initial_jacobian(
self, J0: jnp.ndarray, m: int, n: int
) -> jnp.ndarray:
"""Check and fix initial Jacobian if stability is enabled.
Parameters
----------
J0 : jnp.ndarray
Initial Jacobian
m : int
Number of residuals
n : int
Number of parameters
Returns
-------
J0 : jnp.ndarray
Validated/fixed Jacobian
Raises
------
ValueError
If Jacobian has wrong shape
"""
# Check and fix Jacobian if stability is enabled
if self.enable_stability and J0 is not None:
J0_fixed, issues = self.stability_guard.check_and_fix_jacobian(J0)
if issues:
# Only warn if there's an actual problem, not just SVD skipped for performance
has_problem = (
issues.get("has_nan")
or issues.get("has_inf")
or issues.get("is_ill_conditioned")
or issues.get("regularized")
)
if has_problem:
self.logger.warning(
"Jacobian issues detected and fixed", issues=issues
)
elif issues.get("svd_skipped"):
self.logger.debug(
"SVD skipped for large Jacobian (expected for datasets > 10M points)",
issues=issues,
)
J0 = J0_fixed
if J0 is not None and J0.shape != (m, n):
raise ValueError(
f"The return value of `jac` has wrong shape: expected {(m, n)}, "
f"actual {J0.shape}."
)
return J0
def _compute_initial_cost(
self,
f0: jnp.ndarray,
loss: str | Callable,
loss_function: Callable | None,
f_scale: float,
data_mask: jnp.ndarray,
) -> float:
"""Compute initial cost from residuals and loss function.
Parameters
----------
f0 : jnp.ndarray
Initial residuals
loss : str | Callable
Loss function name or callable
loss_function : Callable | None
Loss function implementation
f_scale : float
Loss function scale parameter
data_mask : jnp.ndarray
Data mask
Returns
-------
initial_cost : float
Initial cost value
Raises
------
ValueError
If callable loss returns wrong shape
"""
m = f0.size
self.logger.debug("Computing initial cost", loss_type=loss, f_scale=f_scale)
if callable(loss):
if loss_function is None:
raise ValueError("loss_function must be provided when loss is callable")
rho = loss_function(f0, f_scale, data_mask=data_mask)
if rho.shape != (3, m):
raise ValueError("The return value of `loss` callable has wrong shape.")
initial_cost_jnp = self.trf.calculate_cost(rho, data_mask)
elif loss_function is not None:
initial_cost_jnp = loss_function(
f0, f_scale, data_mask=data_mask, cost_only=True
)
else:
initial_cost_jnp = self.trf.default_loss_func(f0)
return float(initial_cost_jnp)
def _check_memory_and_adjust_solver(
self, m: int, n: int, method: str, tr_solver: str | None
) -> str | None:
"""Check memory requirements and adjust solver if needed.
Parameters
----------
m : int
Number of residuals
n : int
Number of parameters
method : str
Optimization method
tr_solver : str | None
Current trust region solver
Returns
-------
tr_solver : str | None
Adjusted trust region solver (or original if no adjustment needed)
"""
if self.enable_stability:
memory_required = self.memory_manager.predict_memory_requirement(
m, n, method
)
is_available, msg = self.memory_manager.check_memory_availability(
memory_required
)
if not is_available:
self.logger.warning("Memory constraint detected", details=msg)
# Switch to memory-efficient solver
tr_solver = "lsmr"
return tr_solver
def _create_stable_wrappers(
self, rfunc: Callable, jac_func: Callable
) -> tuple[Callable, Callable]:
"""Create stability wrapper functions for residuals and Jacobian.
NOTE: Stability checks are only performed at initialization, not per-iteration.
Per-iteration Jacobian modification was found to cause optimization divergence
due to accumulated numerical perturbations and expensive SVD computations.
The residual wrapper still checks for NaN/Inf at each evaluation since this
is a cheap O(n) check that can catch numerical explosions early.
Parameters
----------
rfunc : Callable
Original residual function
jac_func : Callable
Original Jacobian function
Returns
-------
rfunc : Callable
Wrapped residual function (with NaN/Inf checking)
jac_func : Callable
Original Jacobian function (NOT wrapped - stability checked at init only)
"""
if self.enable_stability:
original_rfunc = rfunc
def stable_rfunc(x, xd, yd, dm, tf):
result = original_rfunc(x, xd, yd, dm, tf)
result = jnp.where(
jnp.isfinite(result), result, jnp.clip(result, -1e10, 1e10)
)
return result
# NOTE: Jacobian is NOT wrapped - stability checked only at initialization
# via _check_and_fix_initial_jacobian(). Per-iteration Jacobian modification
# causes optimization divergence due to SVD overhead and accumulated perturbations.
return stable_rfunc, jac_func
return rfunc, jac_func
def _run_trf_optimization(
self,
rfunc: Callable,
jac_func: Callable,
xdata: jnp.ndarray | None,
ydata: jnp.ndarray | None,
data_mask: jnp.ndarray,
transform: jnp.ndarray | None,
x0: np.ndarray,
f0: jnp.ndarray,
J0: jnp.ndarray,
lb: np.ndarray,
ub: np.ndarray,
ftol: float,
xtol: float,
gtol: float,
max_nfev: float | None,
f_scale: float,
x_scale: np.ndarray,
loss_function: Callable | None,
tr_options: dict,
verbose: int,
timeit: bool,
tr_solver: str | None,
method: str,
loss: str,
n: int,
m: int,
initial_cost: float,
timeout_kwargs: dict,
callback: Callable | None,
):
"""Run TRF optimization with diagnostics and logging.
Returns
-------
result : OptimizeResult
Optimization result
"""
with self.logger.timer("optimization"):
self.logger.debug("Calling TRF optimizer", initial_cost=initial_cost)
# Initialize diagnostics if enabled
if self.enable_diagnostics:
self.diagnostics.start_optimization(
n_params=n, n_data=m, method=method, loss=loss
)
result = self.trf.trf(
rfunc,
xdata,
ydata,
jac_func,
data_mask,
transform,
x0,
f0,
J0,
lb,
ub,
ftol,
xtol,
gtol,
max_nfev,
f_scale,
x_scale,
loss_function,
tr_options.copy(),
verbose,
timeit,
solver=tr_solver or "exact",
diagnostics=self.diagnostics if self.enable_diagnostics else None,
callback=callback,
**timeout_kwargs,
)
return result
def _process_optimization_result(self, result, initial_cost: float, verbose: int):
    """Attach message/success to the result and log how the run ended.

    The message is looked up in `TERMINATION_MESSAGES` by ``result.status``;
    a run counts as successful only for a positive status. Convergence
    details are always sent to the logger; a summary is additionally logged
    at info level when ``verbose >= 1``.

    Parameters
    ----------
    result : OptimizeResult
        Optimization result
    initial_cost : float
        Initial cost value
    verbose : int
        Verbosity level

    Returns
    -------
    result : OptimizeResult
        The same object, with ``message`` and ``success`` set
    """
    status = result.status
    result.message = TERMINATION_MESSAGES[status]
    result.success = status > 0

    # `nit` and `optimality` are optional on the result; report None if absent.
    iterations = getattr(result, "nit", None)
    gradient_norm = getattr(result, "optimality", None)
    elapsed = self.logger.timers.get("optimization", 0)
    self.logger.convergence(
        reason=result.message,
        iterations=iterations,
        final_cost=result.cost,
        time_elapsed=elapsed,
        final_gradient_norm=gradient_norm,
    )

    if verbose >= 1:
        summary = (
            f"Function evaluations {result.nfev}, initial cost {initial_cost:.4e}, final cost "
            f"{result.cost:.4e}, first-order optimality {result.optimality:.2e}."
        )
        self.logger.info(result.message)
        self.logger.info(summary)

    return result
[docs]
def least_squares(
self,
fun: Callable,
x0: ArrayLike,
jac: Callable | None = None,
bounds: BoundsTuple | tuple[float, float] = (-np.inf, np.inf),
method: MethodLiteral = "trf",
ftol: float = DEFAULT_FTOL,
xtol: float = DEFAULT_XTOL,
gtol: float = DEFAULT_GTOL,
x_scale: Literal["jac"] | ArrayLike | float = 1.0,
loss: str = "linear",
f_scale: float = 1.0,
diff_step: ArrayLike | None = None,
tr_solver: Literal["exact", "lsmr"] | None = None,
tr_options: dict[str, Any] | None = None,
jac_sparsity: ArrayLike | None = None,
max_nfev: float | None = None,
verbose: int = 0,
jacobian_mode: Literal["auto", "fwd", "rev"] | None = None,
xdata: ArrayLike | None = None,
ydata: ArrayLike | None = None,
data_mask: ArrayLike | None = None,
transform: ArrayLike | None = None,
timeit: bool = False,
callback: CallbackFunction | None = None,
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
prepared_bounds: tuple[np.ndarray, np.ndarray] | None = None,
**timeout_kwargs: Any,
) -> dict[str, Any]:
"""Solve nonlinear least squares problem using JAX-accelerated algorithms.
This method orchestrates the optimization process by calling focused
helper methods for each major step: validation, function setup,
initial evaluation, stability checks, and optimization execution.
Parameters
----------
fun : callable
Residual function. Must use jax.numpy operations.
x0 : array_like
Initial parameter guess.
jac : callable or None, optional
Jacobian function. If None, uses JAX autodiff.
bounds : 2-tuple, optional
Parameter bounds as (lower, upper).
method : str, optional
Optimization algorithm ('trf').
ftol, xtol, gtol : float, optional
Convergence tolerances for function, parameters, and gradient.
x_scale : str or array_like, optional
Parameter scaling ('jac' for automatic).
loss : str or callable, optional
Robust loss function ('linear', 'huber', 'soft_l1', etc.).
f_scale : float, optional
Scale parameter for robust loss functions.
max_nfev : int, optional
Maximum function evaluations.
verbose : int, optional
Verbosity level (0, 1, or 2).
jacobian_mode : {'auto', 'fwd', 'rev'}, optional
Jacobian automatic differentiation mode. If None, uses configuration
from environment variable, config file, or auto-default. Default is None.
- 'auto': Automatically select based on problem dimensions
- 'fwd': Force forward-mode AD (jacfwd)
- 'rev': Force reverse-mode AD (jacrev)
xdata, ydata : array_like, optional
Data for curve fitting applications.
data_mask : array_like, optional
Boolean mask for data exclusion.
transform : array_like, optional
Transformation matrix for weighted fitting.
timeit : bool, optional
Enable detailed timing analysis.
callback : callable or None, optional
Callback function called after each optimization iteration with signature
``callback(iteration, cost, params, info)``. Useful for monitoring
optimization progress, logging, or implementing custom stopping criteria.
If None (default), no callback is invoked.
args : tuple, optional
Additional arguments for objective function.
kwargs : dict, optional
Additional optimization parameters.
prepared_bounds : tuple[np.ndarray, np.ndarray] or None, optional
Pre-prepared bounds as (lb, ub). If provided, skips prepare_bounds call.
This optimization avoids redundant bounds preparation when caller has
already called prepare_bounds. Default is None.
Returns
-------
result : OptimizeResult
Optimization result with solution, convergence info, and statistics.
"""
# Step 1: Initialize parameters and validate options
if kwargs is None:
kwargs = {}
if tr_options is None:
tr_options = {}
if "options" in timeout_kwargs:
raise TypeError("'options' is not a supported keyword argument")
if data_mask is None and ydata is not None:
data_mask = jnp.ones(len(ydata), dtype=bool)
# Step 2: Validate inputs (pass prepared_bounds if available)
x0, lb, ub, ftol, xtol, gtol, x_scale = self._validate_least_squares_inputs(
x0,
bounds,
method,
jac,
loss,
verbose,
max_nfev,
ftol,
xtol,
gtol,
x_scale,
prepared_bounds=prepared_bounds,
)
self.n = len(x0)
n = x0.size
# Step 2.5: Determine Jacobian mode with configuration precedence
# Precedence: function parameter > env var > config file > auto-default
if jacobian_mode is not None:
# Function parameter has highest priority
jacobian_mode_config = jacobian_mode
jacobian_mode_source = "function parameter"
else:
# Get from environment/config/default
from nlsq.config import get_jacobian_mode
jacobian_mode_config, jacobian_mode_source = get_jacobian_mode()
# Step 3: Log optimization setup
self.logger.info(
"Starting least squares optimization",
method=method,
n_params=self.n,
loss=loss,
ftol=ftol,
xtol=xtol,
gtol=gtol,
)
# Step 4: Setup residual and Jacobian functions
# Use ydata length as proxy for n_residuals to avoid redundant forward pass.
# This is correct because for standard NLSQ, n_residuals == len(ydata).
# The actual m is verified after initial evaluation at Step 5.
if jac is None and xdata is not None and ydata is not None:
m_estimate = (
len(ydata) if hasattr(ydata, "__len__") else np.asarray(ydata).size
)
# Select Jacobian mode based on problem dimensions
jacobian_mode_selected, jacobian_rationale = jacobian_mode_selector(
n, m_estimate, mode=jacobian_mode_config
)
# Log Jacobian mode selection in debug mode
self.logger.debug(
f"Jacobian mode: '{jacobian_mode_selected}' (from {jacobian_mode_source}). Rationale: {jacobian_rationale}"
)
else:
# Analytical Jacobian or SciPy mode - use default forward mode
jacobian_mode_selected = "fwd"
jacobian_rationale = "analytical Jacobian or SciPy compatibility mode"
self.logger.debug(
f"Jacobian mode: '{jacobian_mode_selected}'. Rationale: {jacobian_rationale}"
)
rfunc, jac_func = self._setup_functions(
fun, jac, xdata, ydata, transform, x0, args, kwargs, jacobian_mode_selected
)
# Step 5: Evaluate initial residuals and Jacobian
f0, J0 = self._evaluate_initial_residuals_and_jacobian(
rfunc, jac_func, x0, xdata, ydata, data_mask, transform
)
m = f0.size
# Step 6: Check and fix initial Jacobian
J0 = self._check_and_fix_initial_jacobian(J0, m, n)
# Step 7: Setup data mask and loss function
if data_mask is None:
data_mask = jnp.ones(m)
loss_function = self.ls.get_loss_function(loss)
# Step 8: Compute initial cost
initial_cost = self._compute_initial_cost(
f0, loss, loss_function, f_scale, data_mask
)
# Step 8.5: Detect sparsity and auto-select sparse solver if beneficial
# This happens AFTER initial Jacobian computation (so we have J0 available)
# Auto-selection triggers when: sparsity >50% AND n_residuals >10K
sparsity_ratio = 0.0
is_sparse_problem = False
sparse_solver_selected = False
# Skip sparsity detection when tr_solver is explicitly set (the caller
# already knows what solver to use) or for small problems where the
# overhead exceeds any benefit (sparse solver requires m > 10K).
should_detect_sparsity = (
tr_solver is None
and m > 10000
and xdata is not None
and ydata is not None
and fun is not None
)
if should_detect_sparsity:
# Import sparsity detection function
from nlsq.core.sparse_jacobian import detect_sparsity_at_p0
# Detect sparsity at p0 (uses sampling for efficiency)
try:
sparsity_ratio, is_sparse_problem = detect_sparsity_at_p0(
func=fun,
p0=x0,
xdata=xdata,
n_residuals=m,
threshold=0.01,
sample_size=min(100, m),
)
# Auto-select sparse solver only if user hasn't specified one
if is_sparse_problem and tr_solver is None:
tr_solver = "sparse"
sparse_solver_selected = True
self.logger.info(
f"Sparse solver activated: sparsity={sparsity_ratio:.1%}, "
f"n_residuals={m}, n_params={n}"
)
elif tr_solver is None:
tr_solver = "exact"
self.logger.debug(
f"Dense solver selected: low sparsity ({sparsity_ratio:.1%})"
)
except Exception as e:
# If sparsity detection fails, fall back to dense solver
self.logger.warning(
f"Sparsity detection failed: {e}. Using dense solver."
)
if tr_solver is None:
tr_solver = "exact"
elif tr_solver is None:
tr_solver = "exact"
# Step 9: Check memory and adjust solver if needed
tr_solver = self._check_memory_and_adjust_solver(m, n, method, tr_solver)
# Step 10: Create stable wrappers for residual and Jacobian functions
rfunc, jac_func = self._create_stable_wrappers(rfunc, jac_func)
# Step 11: Run TRF optimization
result = self._run_trf_optimization(
rfunc,
jac_func,
xdata,
ydata,
data_mask,
transform,
x0,
f0,
J0,
lb,
ub,
ftol,
xtol,
gtol,
max_nfev,
f_scale,
x_scale,
loss_function,
tr_options,
verbose,
timeit,
tr_solver,
method,
loss,
n,
m,
initial_cost,
timeout_kwargs,
callback,
)
# Step 12: Process optimization result
result = self._process_optimization_result(result, initial_cost, verbose)
# Step 13: Add sparsity diagnostics to result (Task 6.5)
# This provides transparency about whether sparse solver was used
result.sparsity_detected = {
"detected": is_sparse_problem,
"ratio": float(sparsity_ratio),
"solver": "sparse" if sparse_solver_selected else "dense",
"n_residuals": m,
"n_params": n,
}
# Log sparsity info in debug mode
self.logger.debug(
f"Sparsity diagnostics: detected={is_sparse_problem}, "
f"ratio={sparsity_ratio:.1%}, solver={'sparse' if sparse_solver_selected else 'dense'}"
)
return result
[docs]
def autdiff_jac(self, jac: None, mode: str = "fwd") -> None:
    """Create automatic-differentiation Jacobians for every sigma variant.

    One Jacobian is built for each of the three sigma-transformed residual
    functions (no transform, 1D, 2D) so that switching sigma from none to 1D
    to a covariance matrix does not require retracing.

    Parameters
    ----------
    jac : None
        Accepted only to keep the signature compatible with the user-defined
        Jacobian path; it is stored unchanged on ``self.jac``.
    mode : str, optional
        Jacobian mode ('fwd' or 'rev'), by default 'fwd'
    """
    # (result/residual suffix, attribute holding the AD helper), processed in
    # the order none -> 1d -> 2d.
    variants = (("none", "adjn"), ("1d", "adj1d"), ("2d", "adj2d"))
    for suffix, helper_attr in variants:
        ad_helper = getattr(self, helper_attr)
        residual_fn = getattr(self, f"func_{suffix}")
        jacobian_fn = ad_helper.create_ad_jacobian(residual_fn, self.n, mode=mode)
        setattr(self, f"jac_{suffix}", jacobian_fn)

    # Remember what the caller passed in as the Jacobian argument.
    self.jac = jac
[docs]
def update_function(self, func: Callable) -> None:
    """Turn a fit function into JIT-compiled residual functions.

    The fit function is wrapped, together with the data, into residual
    functions in a purely functional, JAX JIT compatible form. Because of
    JAX's functional nature both the data mask and the uncertainty transform
    must always be passed in, even when the mask is all True and there is no
    uncertainty transform.

    Three variants are stored on the instance: ``self.func_none``,
    ``self.func_1d`` and ``self.func_2d``. The fit function itself is stored
    as ``self.f``.

    Parameters
    ----------
    func : Callable
        The fit function to wrap. It is called as ``func(xdata, *args)``.

    Returns
    -------
    None
    """

    # Note: plain @jit (not cached_jit) is used because these closures capture
    # 'func', which changes on each call, so caching keyed on the source
    # wouldn't work.
    @jit
    def residual_with_mask(
        args: jnp.ndarray,
        xdata: jnp.ndarray,
        ydata: jnp.ndarray,
        data_mask: jnp.ndarray,
    ) -> jnp.ndarray:
        """Return the masked residual ``func(xdata, *args) - ydata``.

        Parameters
        ----------
        args : jnp.ndarray
            The parameters of the function.
        xdata : jnp.ndarray
            The independent variable data.
        ydata : jnp.ndarray
            The dependent variable data.
        data_mask : jnp.ndarray
            The mask for the data.

        Returns
        -------
        jnp.ndarray
            The residual, set to 0 wherever ``data_mask`` is falsy.
        """
        # The parameter array is unpacked directly into positional arguments
        # (see OPTIMIZATION_QUICK_REFERENCE.md, Optimization #2).
        residual = func(xdata, *args) - ydata
        return jnp.where(data_mask, residual, 0)

    # A separate function is defined for each sigma/covariance case because
    # the uncertainty transform is applied differently in each. In future the
    # no-transform case could be removed by using an all-ones transform.
    @jit
    def func_no_transform(
        args: jnp.ndarray,
        xdata: jnp.ndarray,
        ydata: jnp.ndarray,
        data_mask: jnp.ndarray,
        atransform: jnp.ndarray,
    ) -> jnp.ndarray:
        """Residual function without an uncertainty transform.

        ``atransform`` is unused here; it is accepted only so that all three
        variants share one signature.
        """
        return residual_with_mask(args, xdata, ydata, data_mask)

    self.func_none = func_no_transform

    @jit
    def func_1d_transform(
        args: jnp.ndarray,
        xdata: jnp.ndarray,
        ydata: jnp.ndarray,
        data_mask: jnp.ndarray,
        atransform: jnp.ndarray,
    ) -> jnp.ndarray:
        """Residual function for a 1D uncertainty transform, i.e. when only
        the diagonal elements of the inverse covariance matrix are used."""
        # OPT-11: the masked residual is computed inline (instead of calling
        # the helper) so everything sits in one JIT scope for XLA fusion.
        residual = func(xdata, *args) - ydata
        masked = jnp.where(data_mask, residual, 0)
        return atransform * masked

    self.func_1d = func_1d_transform

    @jit
    def func_2d_transform(
        args: jnp.ndarray,
        xdata: jnp.ndarray,
        ydata: jnp.ndarray,
        data_mask: jnp.ndarray,
        atransform: jnp.ndarray,
    ) -> jnp.ndarray:
        """Residual function for a 2D uncertainty transform, i.e. when the
        full covariance matrix is given."""
        # OPT-11: inlined masked residual, as in the 1D case.
        residual = func(xdata, *args) - ydata
        masked = jnp.where(data_mask, residual, 0)
        # atransform is used as the lower-triangular system matrix.
        return jax_solve_triangular(atransform, masked, lower=True)

    self.func_2d = func_2d_transform

    self.f = func
[docs]
def wrap_jac(self, jac: Callable) -> None:
    """Wrap a user-defined Jacobian to support data masking and uncertainty
    transforms.

    The wrapped functions are in a purely functional, JAX JIT compatible
    form, which requires that both the data mask and the uncertainty
    transform are always passed in. Using an analytical Jacobian of the fit
    function is equivalent to the Jacobian of the residual function.

    The analytical Jacobian does not need ``ydata``, but the wrappers still
    accept it so that their signature matches the automatic-differentiation
    Jacobians, which do require it.

    Three variants are stored on the instance so that the sigma transform can
    change from none to 1D to 2D without retracing: ``self.jac_none``,
    ``self.jac_1d`` and ``self.jac_2d``. The user function itself is stored
    as ``self.jac``.

    Parameters
    ----------
    jac : Callable
        The Jacobian function to wrap. It is called as
        ``jac(coords, *args)``.

    Returns
    -------
    None
    """

    # Note: plain @jit (not cached_jit) is used because these closures capture
    # 'jac', which changes on each call, so caching keyed on the source
    # wouldn't work.
    @jit
    def jac_func(coords: jnp.ndarray, args: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the user Jacobian and return the result as one array."""
        # Unpack the parameter vector into positional arguments, exactly as
        # the residual wrappers call the fit function. This replaces a
        # per-size if/elif chain whose branches were all equivalent.
        return jnp.array(jac(coords, *args))

    @jit
    def masked_jac(
        coords: jnp.ndarray, args: jnp.ndarray, data_mask: jnp.ndarray
    ) -> jnp.ndarray:
        """Compute the wrapped Jacobian, mask out the padded elements with
        0s and return the transpose."""
        # NOTE(review): presumably the user Jacobian is laid out as
        # (n_params, n_points), so the mask broadcasts over the last axis and
        # the transpose yields one row per data point - confirm with callers.
        Jt = jac_func(coords, args)
        return jnp.where(data_mask, Jt, 0).T

    @jit
    def jac_no_transform(
        args: jnp.ndarray,
        coords: jnp.ndarray,
        ydata: jnp.ndarray,
        data_mask: jnp.ndarray,
        atransform: jnp.ndarray,
    ) -> jnp.ndarray:
        """The wrapped Jacobian function when there is no uncertainty
        transform. ``ydata`` and ``atransform`` are unused."""
        return jnp.atleast_2d(masked_jac(coords, args, data_mask))

    @jit
    def jac_1d_transform(
        args: jnp.ndarray,
        coords: jnp.ndarray,
        ydata: jnp.ndarray,
        data_mask: jnp.ndarray,
        atransform: jnp.ndarray,
    ) -> jnp.ndarray:
        """The wrapped Jacobian function when there is a 1D uncertainty
        transform, that is when only the diagonal elements of the inverse
        covariance matrix are used."""
        J = masked_jac(coords, args, data_mask)
        # Scale each row of the masked Jacobian by its transform entry.
        return jnp.atleast_2d(atransform[:, jnp.newaxis] * J)

    @jit
    def jac_2d_transform(
        args: jnp.ndarray,
        coords: jnp.ndarray,
        ydata: jnp.ndarray,
        data_mask: jnp.ndarray,
        atransform: jnp.ndarray,
    ) -> jnp.ndarray:
        """The wrapped Jacobian function when there is a 2D uncertainty
        transform, that is when the full covariance matrix is given."""
        J = masked_jac(coords, args, data_mask)
        # atransform is used as the lower-triangular system matrix.
        return jnp.atleast_2d(jax_solve_triangular(atransform, J, lower=True))

    # All three versions are kept so that changing the sigma transform from
    # none to 1D to 2D does not require retracing.
    self.jac_none = jac_no_transform
    self.jac_1d = jac_1d_transform
    self.jac_2d = jac_2d_transform
    self.jac = jac