Source code for nlsq.core.least_squares

"""Generic interface for least-squares minimization."""

# mypy: disable-error-code="arg-type,assignment"
# Note: Remaining mypy errors are arg-type/assignment mismatches where Optional values
# are passed to methods expecting non-Optional, or Literal type narrowing issues.
# These require deeper refactoring. Fixed in this file: check_x_scale types,
# safe_clip→jnp.clip, logger.warning, None callable guards, return type fixes.

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any, Literal
from warnings import warn

import numpy as np

# Initialize JAX configuration through central config
from nlsq.config import JAXConfig

__jax_config = JAXConfig()
import jax.numpy as jnp
from jax import jacfwd, jacrev, jit
from jax.scipy.linalg import solve_triangular as jax_solve_triangular

from nlsq.caching.memory_manager import get_memory_manager
from nlsq.caching.unified_cache import get_global_cache
from nlsq.common_scipy import EPS, in_bounds, make_strictly_feasible
from nlsq.constants import DEFAULT_FTOL, DEFAULT_GTOL, DEFAULT_XTOL
from nlsq.core.loss_functions import LossFunctionsJIT
from nlsq.core.trf import TrustRegionReflective
from nlsq.stability.guard import NumericalStabilityGuard
from nlsq.types import ArrayLike, BoundsTuple, CallbackFunction, MethodLiteral
from nlsq.utils.diagnostics import OptimizationDiagnostics
from nlsq.utils.logging import get_logger


[docs] def jacobian_mode_selector( n_params: int, n_residuals: int, mode: str = "auto" ) -> tuple[str, str]: """Select Jacobian automatic differentiation mode based on problem dimensions. Automatically chooses between forward-mode (jacfwd) and reverse-mode (jacrev) automatic differentiation based on the Jacobian shape to minimize computational cost. Parameters ---------- n_params : int Number of parameters (columns in Jacobian) n_residuals : int Number of residuals (rows in Jacobian) mode : {'auto', 'fwd', 'rev'}, optional Jacobian mode selection. Default is 'auto'. - 'auto': Automatically select based on problem dimensions - 'fwd': Force forward-mode AD (jacfwd) - 'rev': Force reverse-mode AD (jacrev) Returns ------- selected_mode : str Selected mode ('fwd' or 'rev') rationale : str Human-readable explanation of the selection Raises ------ ValueError If mode is not one of 'auto', 'fwd', 'rev' Notes ----- Selection heuristic for 'auto' mode: - Use jacrev when n_params > n_residuals (tall Jacobian, more params than residuals) - Reverse-mode is O(n_residuals) operations - Forward-mode would be O(n_params) operations - Use jacfwd when n_params <= n_residuals (wide Jacobian, more residuals than params) - Forward-mode is O(n_params) operations - Reverse-mode would be O(n_residuals) operations For high-parameter problems (e.g., 1000 params, 100 residuals), jacrev can be 10-100x faster than jacfwd. Examples -------- >>> from nlsq.core.least_squares import jacobian_mode_selector >>> # Tall Jacobian (many parameters, few residuals) >>> mode, rationale = jacobian_mode_selector(1000, 100, mode='auto') >>> print(mode, rationale) rev jacrev (1000 params > 100 residuals) >>> # Wide Jacobian (few parameters, many residuals) >>> mode, rationale = jacobian_mode_selector(100, 1000, mode='auto') >>> print(mode, rationale) fwd jacfwd (100 params <= 1000 residuals) >>> # Manual override >>> mode, rationale = jacobian_mode_selector(1000, 100, mode='fwd') >>> print(mode, rationale) fwd explicit override: fwd """ if mode == "auto": # Heuristic: use jacrev for tall Jacobians (n_params > n_residuals) # because reverse-mode is O(n_residuals) vs forward-mode O(n_params) if n_params > n_residuals: return "rev", f"jacrev ({n_params} params > {n_residuals} residuals)" else: return "fwd", f"jacfwd ({n_params} params <= {n_residuals} residuals)" elif mode in ("fwd", "rev"): return mode, f"explicit override: {mode}" else: raise ValueError( f"Invalid jacobian_mode: {mode}. Must be 'auto', 'fwd', or 'rev'" )
TERMINATION_MESSAGES = { -3: "Inner optimization loop exceeded maximum iterations.", -2: "Maximum iterations reached.", -1: "Improper input parameters status returned from `leastsq`", 0: "The maximum number of function evaluations is exceeded.", 1: "`gtol` termination condition is satisfied.", 2: "`ftol` termination condition is satisfied.", 3: "`xtol` termination condition is satisfied.", 4: "Both `ftol` and `xtol` termination conditions are satisfied.", }
[docs] def prepare_bounds(bounds, n) -> tuple[np.ndarray, np.ndarray]: """Prepare bounds for optimization. This function prepares the bounds for the optimization by ensuring that they are both 1-D arrays of length `n`. If either bound is a scalar, it is resized to an array of length `n`. Parameters ---------- bounds : Tuple[np.ndarray, np.ndarray] The lower and upper bounds for the optimization. n : int The length of the bounds arrays. Returns ------- Tuple[np.ndarray, np.ndarray] The prepared lower and upper bounds arrays. """ lb, ub = (np.asarray(b, dtype=float) for b in bounds) if lb.ndim == 0: lb = np.resize(lb, n) if ub.ndim == 0: ub = np.resize(ub, n) return lb, ub
[docs] def check_tolerance( ftol: float, xtol: float, gtol: float, method: str ) -> tuple[float, float, float]: """Check and prepare tolerance values for optimization. This function checks the tolerance values for the optimization and prepares them for use. If any of the tolerances is `None`, it is set to 0. If any of the tolerances is lower than the machine epsilon, a warning is issued and the tolerance is set to the machine epsilon. If all tolerances are lower than the machine epsilon, a `ValueError` is raised. Parameters ---------- ftol : float The tolerance for the optimization function value. xtol : float The tolerance for the optimization variable values. gtol : float The tolerance for the optimization gradient values. method : str The name of the optimization method. Returns ------- Tuple[float, float, float] The prepared tolerance values. """ def check(tol: float, name: str) -> float: if tol is None: tol = 0 elif tol < EPS: warn( f"Setting `{name}` below the machine epsilon ({EPS:.2e}) effectively " "disables the corresponding termination condition.", stacklevel=2, ) return tol ftol = check(ftol, "ftol") xtol = check(xtol, "xtol") gtol = check(gtol, "gtol") if ftol < EPS and xtol < EPS and gtol < EPS: raise ValueError( "At least one of the tolerances must be higher than " f"machine epsilon ({EPS:.2e})." ) return ftol, xtol, gtol
[docs] def check_x_scale( x_scale: str | Sequence[float] | np.ndarray, x0: np.ndarray ) -> str | np.ndarray: """Check and prepare the `x_scale` parameter for optimization. This function checks and prepares the `x_scale` parameter for the optimization. `x_scale` can either be 'jac' or an array_like with positive numbers. If it's 'jac' the jacobian is used as the scaling. Parameters ---------- x_scale : str | Sequence[float] | np.ndarray The scaling for the optimization variables. x0 : np.ndarray The initial guess for the optimization variables. Returns ------- str | np.ndarray The prepared `x_scale` parameter. """ if isinstance(x_scale, str) and x_scale == "jac": return x_scale try: x_scale_arr = np.asarray(x_scale, dtype=float) valid: bool = bool(np.all(np.isfinite(x_scale_arr)) and np.all(x_scale_arr > 0)) except (ValueError, TypeError): valid = False if not valid: raise ValueError("`x_scale` must be 'jac' or array_like with positive numbers.") if x_scale_arr.ndim == 0: x_scale_arr = np.resize(x_scale_arr, x0.shape) if x_scale_arr.shape != x0.shape: raise ValueError("Inconsistent shapes between `x_scale` and `x0`.") return x_scale_arr
[docs] class AutoDiffJacobian: """Wraps the residual fit function such that automatic differentiation is performed. Supports both forward-mode (jacfwd) and reverse-mode (jacrev) automatic differentiation. This needs to be a class since we need to maintain in memory three different versions of the Jacobian for different sigma/covariance cases. """
[docs] def create_ad_jacobian( self, func: Callable, num_args: int, masked: bool = True, mode: str = "fwd" ) -> Callable: """Creates a function that returns the autodiff jacobian of the residual fit function. The Jacobian of the residual fit function is equivalent to the Jacobian of the fit function. Parameters ---------- func : Callable The function to take the jacobian of. num_args : int The number of arguments the function takes. masked : bool, optional Whether to use a masked jacobian, by default True mode : str, optional Jacobian mode ('fwd' or 'rev'), by default 'fwd' Returns ------- Callable The function that returns the autodiff jacobian of the given function. """ # create a list of argument indices for the wrapped function which # will correspond to the arguments of the residual fit function and # will be passed to JAX's jacfwd/jacrev function. arg_list = [4 + i for i in range(num_args)] # Select the appropriate JAX differentiation function jac_func_ad = jacfwd if mode == "fwd" else jacrev # Note: Uses @jit (not cached_jit) because these closures capture 'func' # which changes each call, so caching based on source wouldn't work @jit def wrap_func(*all_args) -> jnp.ndarray: """Wraps the residual fit function such that it can be passed to the jacfwd/jacrev function. Both require the function to have a single list of arguments. """ xdata, ydata, data_mask, atransform = all_args[:4] args = jnp.array(all_args[4:]) return func(args, xdata, ydata, data_mask, atransform) @jit def jac_func( args: jnp.ndarray, xdata: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, atransform: jnp.ndarray, ) -> jnp.ndarray: """Returns the jacobian. Places all the residual fit function arguments into a single list for the wrapped residual fit function. Then calls the jacfwd or jacrev function on the wrapped function with the arglist of the arguments to differentiate with respect to which is only the arguments of the original fit function. """ fixed_args = [xdata, ydata, data_mask, atransform] all_args = [*fixed_args, *args] jac_result = jac_func_ad(wrap_func, argnums=arg_list)(*all_args) return jnp.array(jac_result) @jit def masked_jac( args: jnp.ndarray, xdata: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, atransform: jnp.ndarray, ) -> jnp.ndarray: """Returns the masked jacobian.""" Jt = jac_func(args, xdata, ydata, data_mask, atransform) J = jnp.where(data_mask, Jt, 0).T return jnp.atleast_2d(J) @jit def no_mask_jac( args: jnp.ndarray, xdata: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, atransform: jnp.ndarray, ) -> jnp.ndarray: """Returns the unmasked jacobian.""" J = jac_func(args, xdata, ydata, data_mask, atransform).T return jnp.atleast_2d(J) if masked: self.jac = masked_jac else: self.jac = no_mask_jac return self.jac
[docs] class LeastSquares: """Core least squares optimization engine with JAX acceleration. This class implements the main optimization algorithms for nonlinear least squares problems, including Trust Region Reflective (TRF) and Levenberg-Marquardt (LM). It handles automatic differentiation, bound constraints, loss functions, and uncertainty propagation. The class maintains separate automatic differentiation instances for different sigma configurations (no sigma, 1D sigma, 2D covariance matrix) to optimize compilation and execution performance. Attributes ---------- trf : TrustRegionReflective Trust Region Reflective algorithm implementation ls : LossFunctionsJIT JIT-compiled loss function implementations logger : Logger Internal logger for debugging and performance tracking f : callable Current objective function being optimized jac : callable or None Current Jacobian function (None for automatic differentiation) adjn : AutoDiffJacobian Automatic differentiation instance for unweighted problems adj1d : AutoDiffJacobian Automatic differentiation instance for 1D sigma weighting adj2d : AutoDiffJacobian Automatic differentiation instance for 2D covariance matrix weighting Methods ------- least_squares : Main optimization method """
[docs] def __init__( self, enable_stability: bool = False, enable_diagnostics: bool = False, max_jacobian_elements_for_svd: int = 10_000_000, ) -> None: """Initialize LeastSquares with optimization algorithms and autodiff instances. Sets up the Trust Region Reflective solver, loss functions, and separate automatic differentiation instances for different weighting schemes to maximize JAX compilation efficiency. Parameters ---------- enable_stability : bool, default False Enable numerical stability checks and fixes enable_diagnostics : bool, default False Enable optimization diagnostics collection max_jacobian_elements_for_svd : int, default 10_000_000 Maximum Jacobian size (m × n elements) for SVD computation during stability checks. SVD is skipped for larger Jacobians. """ super().__init__() # not sure if this is needed self.trf = TrustRegionReflective() self.ls = LossFunctionsJIT() self.logger = get_logger("least_squares") # initialize jacobian to None and f to a dummy function self.f = lambda x: None self.jac: Callable[..., jnp.ndarray] | None = None # need a separate instance of the autodiff class for each of the # the different sigma/covariance cases self.adjn = AutoDiffJacobian() self.adj1d = AutoDiffJacobian() self.adj2d = AutoDiffJacobian() # Initialize unified cache for JIT compilation tracking self.cache = get_global_cache() # Stability and diagnostics systems self.enable_stability = enable_stability self.enable_diagnostics = enable_diagnostics self.max_jacobian_elements_for_svd = max_jacobian_elements_for_svd if enable_stability: self.stability_guard = NumericalStabilityGuard( max_jacobian_elements_for_svd=max_jacobian_elements_for_svd ) self.memory_manager = get_memory_manager() if enable_diagnostics: self.diagnostics = OptimizationDiagnostics()
def _validate_least_squares_inputs( self, x0: np.ndarray, bounds: tuple, method: str, jac, loss: str, verbose: int, max_nfev: float | None, ftol: float, xtol: float, gtol: float, x_scale, prepared_bounds: tuple[np.ndarray, np.ndarray] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, np.ndarray]: """Validate and prepare least squares inputs. Parameters ---------- prepared_bounds : tuple[np.ndarray, np.ndarray] or None, optional Pre-prepared bounds as (lb, ub). If provided, skips prepare_bounds call. This optimization avoids redundant bounds preparation when caller has already called prepare_bounds. Default is None. Returns ------- x0 : np.ndarray Validated initial guess lb : np.ndarray Lower bounds ub : np.ndarray Upper bounds ftol : float Function tolerance xtol : float Parameter tolerance gtol : float Gradient tolerance x_scale : np.ndarray Parameter scaling """ # Validate loss function if loss not in self.ls.IMPLEMENTED_LOSSES and not callable(loss): raise ValueError( f"`loss` must be one of {self.ls.IMPLEMENTED_LOSSES.keys()} or a callable." ) # Validate method if method != "trf": raise ValueError("`method` must be 'trf'") # Validate jac parameter if jac is not None and not callable(jac): raise ValueError("`jac` must be None or callable.") # Validate verbose level if verbose not in [0, 1, 2]: raise ValueError("`verbose` must be in [0, 1, 2].") # Validate bounds if len(bounds) != 2: raise ValueError("`bounds` must contain 2 elements.") # Validate max_nfev if max_nfev is not None and max_nfev <= 0: raise ValueError("`max_nfev` must be None or positive integer.") # Validate x0 if np.iscomplexobj(x0): raise ValueError("`x0` must be real.") x0 = np.atleast_1d(x0).astype(float) if x0.ndim > 1: raise ValueError("`x0` must have at most 1 dimension.") # Use prepared bounds if provided, otherwise prepare them if prepared_bounds is not None: lb, ub = prepared_bounds else: lb, ub = prepare_bounds(bounds, x0.shape[0]) if lb.shape != x0.shape or ub.shape != x0.shape: raise ValueError("Inconsistent shapes between bounds and `x0`.") if np.any(lb >= ub): raise ValueError( "Each lower bound must be strictly less than each upper bound." ) if not in_bounds(x0, lb, ub): raise ValueError("`x0` is infeasible.") # Check and prepare scaling/tolerances x_scale = check_x_scale(x_scale, x0) ftol, xtol, gtol = check_tolerance(ftol, xtol, gtol, method) x0 = make_strictly_feasible(x0, lb, ub) return x0, lb, ub, ftol, xtol, gtol, x_scale def _setup_functions( self, fun: Callable, jac: Callable | None, xdata: jnp.ndarray | None, ydata: jnp.ndarray | None, transform: jnp.ndarray | None, x0: np.ndarray, args: tuple, kwargs: dict, jacobian_mode_selected: str = "fwd", ) -> tuple: """Setup residual and Jacobian functions. Returns ------- rfunc : callable Residual function jac_func : callable Jacobian function """ if xdata is not None and ydata is not None: # Check if fit function needs updating func_update = False try: if hasattr(self.f, "__code__") and hasattr(fun, "__code__"): func_update = self.f.__code__.co_code != fun.__code__.co_code else: func_update = self.f != fun except Exception: func_update = True # Update function if needed if func_update: self.update_function(fun) if jac is None: self.autdiff_jac(jac, mode=jacobian_mode_selected) # Handle analytical Jacobian if jac is not None: if ( self.jac is None or self.jac.__code__.co_code != jac.__code__.co_code ): self.wrap_jac(jac) elif self.jac is not None and not func_update: self.autdiff_jac(jac, mode=jacobian_mode_selected) # Select appropriate residual function and Jacobian if transform is None: rfunc = self.func_none jac_func = self.jac_none elif transform.ndim == 1: rfunc = self.func_1d jac_func = self.jac_1d else: rfunc = self.func_2d jac_func = self.jac_2d else: # SciPy compatibility mode def wrap_func(fargs, xdata, ydata, data_mask, atransform): return jnp.atleast_1d(fun(fargs, *args, **kwargs)) rfunc = wrap_func if jac is None: adj = AutoDiffJacobian() jac_func = adj.create_ad_jacobian( wrap_func, x0.size, masked=False, mode=jacobian_mode_selected ) else: # Capture jac in closure with proper type narrowing jac_callable = jac def wrap_jac(fargs, xdata, ydata, data_mask, atransform): return jnp.atleast_2d(jac_callable(fargs, *args, **kwargs)) jac_func = wrap_jac return rfunc, jac_func def _evaluate_initial_residuals_and_jacobian( self, rfunc: Callable, jac_func: Callable, x0: np.ndarray, xdata: jnp.ndarray | None, ydata: jnp.ndarray | None, data_mask: jnp.ndarray | None, transform: jnp.ndarray | None, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Evaluate initial residuals and Jacobian, with stability checks. Parameters ---------- rfunc : Callable Residual function jac_func : Callable Jacobian function x0 : np.ndarray Initial parameters xdata : jnp.ndarray | None X data ydata : jnp.ndarray | None Y data data_mask : jnp.ndarray | None Data mask transform : jnp.ndarray | None Transform matrix Returns ------- f0 : jnp.ndarray Initial residuals J0 : jnp.ndarray Initial Jacobian Raises ------ ValueError If residuals are not 1-D or not finite """ f0 = rfunc(x0, xdata, ydata, data_mask, transform) J0 = jac_func(x0, xdata, ydata, data_mask, transform) if f0.ndim != 1: raise ValueError( f"`fun` must return at most 1-d array_like. f0.shape: {f0.shape}" ) if not np.all(np.isfinite(f0)): if self.enable_stability: self.logger.warning("Non-finite residuals detected, attempting to fix") f0 = jnp.clip(f0, -1e10, 1e10) if not np.all(np.isfinite(f0)): raise ValueError("Residuals are not finite after stabilization") else: raise ValueError("Residuals are not finite in the initial point.") return f0, J0 def _check_and_fix_initial_jacobian( self, J0: jnp.ndarray, m: int, n: int ) -> jnp.ndarray: """Check and fix initial Jacobian if stability is enabled. Parameters ---------- J0 : jnp.ndarray Initial Jacobian m : int Number of residuals n : int Number of parameters Returns ------- J0 : jnp.ndarray Validated/fixed Jacobian Raises ------ ValueError If Jacobian has wrong shape """ # Check and fix Jacobian if stability is enabled if self.enable_stability and J0 is not None: J0_fixed, issues = self.stability_guard.check_and_fix_jacobian(J0) if issues: # Only warn if there's an actual problem, not just SVD skipped for performance has_problem = ( issues.get("has_nan") or issues.get("has_inf") or issues.get("is_ill_conditioned") or issues.get("regularized") ) if has_problem: self.logger.warning( "Jacobian issues detected and fixed", issues=issues ) elif issues.get("svd_skipped"): self.logger.debug( "SVD skipped for large Jacobian (expected for datasets > 10M points)", issues=issues, ) J0 = J0_fixed if J0 is not None and J0.shape != (m, n): raise ValueError( f"The return value of `jac` has wrong shape: expected {(m, n)}, " f"actual {J0.shape}." ) return J0 def _compute_initial_cost( self, f0: jnp.ndarray, loss: str | Callable, loss_function: Callable | None, f_scale: float, data_mask: jnp.ndarray, ) -> float: """Compute initial cost from residuals and loss function. Parameters ---------- f0 : jnp.ndarray Initial residuals loss : str | Callable Loss function name or callable loss_function : Callable | None Loss function implementation f_scale : float Loss function scale parameter data_mask : jnp.ndarray Data mask Returns ------- initial_cost : float Initial cost value Raises ------ ValueError If callable loss returns wrong shape """ m = f0.size self.logger.debug("Computing initial cost", loss_type=loss, f_scale=f_scale) if callable(loss): if loss_function is None: raise ValueError("loss_function must be provided when loss is callable") rho = loss_function(f0, f_scale, data_mask=data_mask) if rho.shape != (3, m): raise ValueError("The return value of `loss` callable has wrong shape.") initial_cost_jnp = self.trf.calculate_cost(rho, data_mask) elif loss_function is not None: initial_cost_jnp = loss_function( f0, f_scale, data_mask=data_mask, cost_only=True ) else: initial_cost_jnp = self.trf.default_loss_func(f0) return float(initial_cost_jnp) def _check_memory_and_adjust_solver( self, m: int, n: int, method: str, tr_solver: str | None ) -> str | None: """Check memory requirements and adjust solver if needed. Parameters ---------- m : int Number of residuals n : int Number of parameters method : str Optimization method tr_solver : str | None Current trust region solver Returns ------- tr_solver : str | None Adjusted trust region solver (or original if no adjustment needed) """ if self.enable_stability: memory_required = self.memory_manager.predict_memory_requirement( m, n, method ) is_available, msg = self.memory_manager.check_memory_availability( memory_required ) if not is_available: self.logger.warning("Memory constraint detected", details=msg) # Switch to memory-efficient solver tr_solver = "lsmr" return tr_solver def _create_stable_wrappers( self, rfunc: Callable, jac_func: Callable ) -> tuple[Callable, Callable]: """Create stability wrapper functions for residuals and Jacobian. NOTE: Stability checks are only performed at initialization, not per-iteration. Per-iteration Jacobian modification was found to cause optimization divergence due to accumulated numerical perturbations and expensive SVD computations. The residual wrapper still checks for NaN/Inf at each evaluation since this is a cheap O(n) check that can catch numerical explosions early. Parameters ---------- rfunc : Callable Original residual function jac_func : Callable Original Jacobian function Returns ------- rfunc : Callable Wrapped residual function (with NaN/Inf checking) jac_func : Callable Original Jacobian function (NOT wrapped - stability checked at init only) """ if self.enable_stability: original_rfunc = rfunc def stable_rfunc(x, xd, yd, dm, tf): result = original_rfunc(x, xd, yd, dm, tf) result = jnp.where( jnp.isfinite(result), result, jnp.clip(result, -1e10, 1e10) ) return result # NOTE: Jacobian is NOT wrapped - stability checked only at initialization # via _check_and_fix_initial_jacobian(). Per-iteration Jacobian modification # causes optimization divergence due to SVD overhead and accumulated perturbations. return stable_rfunc, jac_func return rfunc, jac_func def _run_trf_optimization( self, rfunc: Callable, jac_func: Callable, xdata: jnp.ndarray | None, ydata: jnp.ndarray | None, data_mask: jnp.ndarray, transform: jnp.ndarray | None, x0: np.ndarray, f0: jnp.ndarray, J0: jnp.ndarray, lb: np.ndarray, ub: np.ndarray, ftol: float, xtol: float, gtol: float, max_nfev: float | None, f_scale: float, x_scale: np.ndarray, loss_function: Callable | None, tr_options: dict, verbose: int, timeit: bool, tr_solver: str | None, method: str, loss: str, n: int, m: int, initial_cost: float, timeout_kwargs: dict, callback: Callable | None, ): """Run TRF optimization with diagnostics and logging. Returns ------- result : OptimizeResult Optimization result """ with self.logger.timer("optimization"): self.logger.debug("Calling TRF optimizer", initial_cost=initial_cost) # Initialize diagnostics if enabled if self.enable_diagnostics: self.diagnostics.start_optimization( n_params=n, n_data=m, method=method, loss=loss ) result = self.trf.trf( rfunc, xdata, ydata, jac_func, data_mask, transform, x0, f0, J0, lb, ub, ftol, xtol, gtol, max_nfev, f_scale, x_scale, loss_function, tr_options.copy(), verbose, timeit, solver=tr_solver or "exact", diagnostics=self.diagnostics if self.enable_diagnostics else None, callback=callback, **timeout_kwargs, ) return result def _process_optimization_result(self, result, initial_cost: float, verbose: int): """Process optimization result and log convergence. Parameters ---------- result : OptimizeResult Optimization result initial_cost : float Initial cost value verbose : int Verbosity level Returns ------- result : OptimizeResult Processed result with message and success flag """ result.message = TERMINATION_MESSAGES[result.status] result.success = result.status > 0 # Log convergence self.logger.convergence( reason=result.message, iterations=getattr(result, "nit", None), final_cost=result.cost, time_elapsed=self.logger.timers.get("optimization", 0), final_gradient_norm=getattr(result, "optimality", None), ) if verbose >= 1: self.logger.info(result.message) self.logger.info( f"Function evaluations {result.nfev}, initial cost {initial_cost:.4e}, final cost " f"{result.cost:.4e}, first-order optimality {result.optimality:.2e}." ) return result
[docs] def least_squares( self, fun: Callable, x0: ArrayLike, jac: Callable | None = None, bounds: BoundsTuple | tuple[float, float] = (-np.inf, np.inf), method: MethodLiteral = "trf", ftol: float = DEFAULT_FTOL, xtol: float = DEFAULT_XTOL, gtol: float = DEFAULT_GTOL, x_scale: Literal["jac"] | ArrayLike | float = 1.0, loss: str = "linear", f_scale: float = 1.0, diff_step: ArrayLike | None = None, tr_solver: Literal["exact", "lsmr"] | None = None, tr_options: dict[str, Any] | None = None, jac_sparsity: ArrayLike | None = None, max_nfev: float | None = None, verbose: int = 0, jacobian_mode: Literal["auto", "fwd", "rev"] | None = None, xdata: ArrayLike | None = None, ydata: ArrayLike | None = None, data_mask: ArrayLike | None = None, transform: ArrayLike | None = None, timeit: bool = False, callback: CallbackFunction | None = None, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None, prepared_bounds: tuple[np.ndarray, np.ndarray] | None = None, **timeout_kwargs: Any, ) -> dict[str, Any]: """Solve nonlinear least squares problem using JAX-accelerated algorithms. This method orchestrates the optimization process by calling focused helper methods for each major step: validation, function setup, initial evaluation, stability checks, and optimization execution. Parameters ---------- fun : callable Residual function. Must use jax.numpy operations. x0 : array_like Initial parameter guess. jac : callable or None, optional Jacobian function. If None, uses JAX autodiff. bounds : 2-tuple, optional Parameter bounds as (lower, upper). method : str, optional Optimization algorithm ('trf'). ftol, xtol, gtol : float, optional Convergence tolerances for function, parameters, and gradient. x_scale : str or array_like, optional Parameter scaling ('jac' for automatic). loss : str or callable, optional Robust loss function ('linear', 'huber', 'soft_l1', etc.). f_scale : float, optional Scale parameter for robust loss functions. max_nfev : int, optional Maximum function evaluations. verbose : int, optional Verbosity level (0, 1, or 2). jacobian_mode : {'auto', 'fwd', 'rev'}, optional Jacobian automatic differentiation mode. If None, uses configuration from environment variable, config file, or auto-default. Default is None. - 'auto': Automatically select based on problem dimensions - 'fwd': Force forward-mode AD (jacfwd) - 'rev': Force reverse-mode AD (jacrev) xdata, ydata : array_like, optional Data for curve fitting applications. data_mask : array_like, optional Boolean mask for data exclusion. transform : array_like, optional Transformation matrix for weighted fitting. timeit : bool, optional Enable detailed timing analysis. callback : callable or None, optional Callback function called after each optimization iteration with signature ``callback(iteration, cost, params, info)``. Useful for monitoring optimization progress, logging, or implementing custom stopping criteria. If None (default), no callback is invoked. args : tuple, optional Additional arguments for objective function. kwargs : dict, optional Additional optimization parameters. prepared_bounds : tuple[np.ndarray, np.ndarray] or None, optional Pre-prepared bounds as (lb, ub). If provided, skips prepare_bounds call. This optimization avoids redundant bounds preparation when caller has already called prepare_bounds. Default is None. Returns ------- result : OptimizeResult Optimization result with solution, convergence info, and statistics. """ # Step 1: Initialize parameters and validate options if kwargs is None: kwargs = {} if tr_options is None: tr_options = {} if "options" in timeout_kwargs: raise TypeError("'options' is not a supported keyword argument") if data_mask is None and ydata is not None: data_mask = jnp.ones(len(ydata), dtype=bool) # Step 2: Validate inputs (pass prepared_bounds if available) x0, lb, ub, ftol, xtol, gtol, x_scale = self._validate_least_squares_inputs( x0, bounds, method, jac, loss, verbose, max_nfev, ftol, xtol, gtol, x_scale, prepared_bounds=prepared_bounds, ) self.n = len(x0) n = x0.size # Step 2.5: Determine Jacobian mode with configuration precedence # Precedence: function parameter > env var > config file > auto-default if jacobian_mode is not None: # Function parameter has highest priority jacobian_mode_config = jacobian_mode jacobian_mode_source = "function parameter" else: # Get from environment/config/default from nlsq.config import get_jacobian_mode jacobian_mode_config, jacobian_mode_source = get_jacobian_mode() # Step 3: Log optimization setup self.logger.info( "Starting least squares optimization", method=method, n_params=self.n, loss=loss, ftol=ftol, xtol=xtol, gtol=gtol, ) # Step 4: Setup residual and Jacobian functions # Use ydata length as proxy for n_residuals to avoid redundant forward pass. # This is correct because for standard NLSQ, n_residuals == len(ydata). # The actual m is verified after initial evaluation at Step 5. if jac is None and xdata is not None and ydata is not None: m_estimate = ( len(ydata) if hasattr(ydata, "__len__") else np.asarray(ydata).size ) # Select Jacobian mode based on problem dimensions jacobian_mode_selected, jacobian_rationale = jacobian_mode_selector( n, m_estimate, mode=jacobian_mode_config ) # Log Jacobian mode selection in debug mode self.logger.debug( f"Jacobian mode: '{jacobian_mode_selected}' (from {jacobian_mode_source}). Rationale: {jacobian_rationale}" ) else: # Analytical Jacobian or SciPy mode - use default forward mode jacobian_mode_selected = "fwd" jacobian_rationale = "analytical Jacobian or SciPy compatibility mode" self.logger.debug( f"Jacobian mode: '{jacobian_mode_selected}'. Rationale: {jacobian_rationale}" ) rfunc, jac_func = self._setup_functions( fun, jac, xdata, ydata, transform, x0, args, kwargs, jacobian_mode_selected ) # Step 5: Evaluate initial residuals and Jacobian f0, J0 = self._evaluate_initial_residuals_and_jacobian( rfunc, jac_func, x0, xdata, ydata, data_mask, transform ) m = f0.size # Step 6: Check and fix initial Jacobian J0 = self._check_and_fix_initial_jacobian(J0, m, n) # Step 7: Setup data mask and loss function if data_mask is None: data_mask = jnp.ones(m) loss_function = self.ls.get_loss_function(loss) # Step 8: Compute initial cost initial_cost = self._compute_initial_cost( f0, loss, loss_function, f_scale, data_mask ) # Step 8.5: Detect sparsity and auto-select sparse solver if beneficial # This happens AFTER initial Jacobian computation (so we have J0 available) # Auto-selection triggers when: sparsity >50% AND n_residuals >10K sparsity_ratio = 0.0 is_sparse_problem = False sparse_solver_selected = False # Skip sparsity detection when tr_solver is explicitly set (the caller # already knows what solver to use) or for small problems where the # overhead exceeds any benefit (sparse solver requires m > 10K). should_detect_sparsity = ( tr_solver is None and m > 10000 and xdata is not None and ydata is not None and fun is not None ) if should_detect_sparsity: # Import sparsity detection function from nlsq.core.sparse_jacobian import detect_sparsity_at_p0 # Detect sparsity at p0 (uses sampling for efficiency) try: sparsity_ratio, is_sparse_problem = detect_sparsity_at_p0( func=fun, p0=x0, xdata=xdata, n_residuals=m, threshold=0.01, sample_size=min(100, m), ) # Auto-select sparse solver only if user hasn't specified one if is_sparse_problem and tr_solver is None: tr_solver = "sparse" sparse_solver_selected = True self.logger.info( f"Sparse solver activated: sparsity={sparsity_ratio:.1%}, " f"n_residuals={m}, n_params={n}" ) elif tr_solver is None: tr_solver = "exact" self.logger.debug( f"Dense solver selected: low sparsity ({sparsity_ratio:.1%})" ) except Exception as e: # If sparsity detection fails, fall back to dense solver self.logger.warning( f"Sparsity detection failed: {e}. Using dense solver." ) if tr_solver is None: tr_solver = "exact" elif tr_solver is None: tr_solver = "exact" # Step 9: Check memory and adjust solver if needed tr_solver = self._check_memory_and_adjust_solver(m, n, method, tr_solver) # Step 10: Create stable wrappers for residual and Jacobian functions rfunc, jac_func = self._create_stable_wrappers(rfunc, jac_func) # Step 11: Run TRF optimization result = self._run_trf_optimization( rfunc, jac_func, xdata, ydata, data_mask, transform, x0, f0, J0, lb, ub, ftol, xtol, gtol, max_nfev, f_scale, x_scale, loss_function, tr_options, verbose, timeit, tr_solver, method, loss, n, m, initial_cost, timeout_kwargs, callback, ) # Step 12: Process optimization result result = self._process_optimization_result(result, initial_cost, verbose) # Step 13: Add sparsity diagnostics to result (Task 6.5) # This provides transparency about whether sparse solver was used result.sparsity_detected = { "detected": is_sparse_problem, "ratio": float(sparsity_ratio), "solver": "sparse" if sparse_solver_selected else "dense", "n_residuals": m, "n_params": n, } # Log sparsity info in debug mode self.logger.debug( f"Sparsity diagnostics: detected={is_sparse_problem}, " f"ratio={sparsity_ratio:.1%}, solver={'sparse' if sparse_solver_selected else 'dense'}" ) return result
[docs] def autdiff_jac(self, jac: None, mode: str = "fwd") -> None: """We do this for all three sigma transformed functions such that if sigma is changed from none to 1D to covariance sigma then no retracing is needed. Parameters ---------- jac : None Passed in to maintain compatibility with the user defined Jacobian function. mode : str, optional Jacobian mode ('fwd' or 'rev'), by default 'fwd' """ self.jac_none = self.adjn.create_ad_jacobian(self.func_none, self.n, mode=mode) self.jac_1d = self.adj1d.create_ad_jacobian(self.func_1d, self.n, mode=mode) self.jac_2d = self.adj2d.create_ad_jacobian(self.func_2d, self.n, mode=mode) # jac is self.jac = jac
[docs] def update_function(self, func: Callable) -> None: """Wraps the given fit function to be a residual function using the data. The wrapped function is in a JAX JIT compatible format which is purely functional. This requires that both the data mask and the uncertainty transform are passed to the function. Even for the case where the data mask is all True and the uncertainty transform is None we still need to pass these arguments to the function due JAX's functional nature. Parameters ---------- func : Callable The fit function to wrap. Returns ------- None """ # Note: Uses @jit (not cached_jit) because this closure captures 'func' # which changes each call, so caching based on source wouldn't work @jit def masked_residual_func( args: jnp.ndarray, xdata: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, ) -> jnp.ndarray: """Compute the residual of the function evaluated at `args` with respect to the data. This function computes the residual of the user fit function evaluated at `args` with respect to the data `(xdata, ydata)`, masked by `data_mask`. The residual is defined as the difference between the function evaluation and the data. The masked residual is obtained by setting the residual to 0 wherever the corresponding element of `data_mask` is 0. Parameters ---------- args : jnp.ndarray The parameters of the function. xdata : jnp.ndarray The independent variable data. ydata : jnp.ndarray The dependent variable data. data_mask : jnp.ndarray The mask for the data. Returns ------- jnp.ndarray The masked residual of the function evaluated at `args` with respect to the data. """ # JAX 0.8.0+ handles tuple unpacking efficiently without TracerArrayConversionError # This replaces the previous 100-line if-elif chain (Optimization #2) # See: OPTIMIZATION_QUICK_REFERENCE.md for performance analysis func_eval = func(xdata, *args) - ydata return jnp.where(data_mask, func_eval, 0) # need to define a separate function for each of the different # sigma/covariance cases as the uncertainty transform is different # for each case. In future could remove the no transfore bit by setting # the uncertainty transform to all ones in the case where there is no # uncertainty transform. @jit def func_no_transform( args: jnp.ndarray, xdata: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, atransform: jnp.ndarray, ) -> jnp.ndarray: """The residual function when there is no uncertainty transform. The atranform argument is not used in this case, but is included for consistency with the other cases.""" return masked_residual_func(args, xdata, ydata, data_mask) @jit def func_1d_transform( args: jnp.ndarray, xdata: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, atransform: jnp.ndarray, ) -> jnp.ndarray: """The residual function when there is a 1D uncertainty transform, that is when only the diagonal elements of the inverse covariance matrix are used.""" # OPT-11: Inlined masked_residual_func for XLA fusion optimization # XLA can better fuse operations when they're in the same JIT scope func_eval = func(xdata, *args) - ydata masked_residual = jnp.where(data_mask, func_eval, 0) return atransform * masked_residual @jit def func_2d_transform( args: jnp.ndarray, xdata: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, atransform: jnp.ndarray, ) -> jnp.ndarray: """The residual function when there is a 2D uncertainty transform, that is when the full covariance matrix is given.""" # OPT-11: Inlined masked_residual_func for XLA fusion optimization func_eval = func(xdata, *args) - ydata masked_residual = jnp.where(data_mask, func_eval, 0) return jax_solve_triangular(atransform, masked_residual, lower=True) self.func_none = func_no_transform self.func_1d = func_1d_transform self.func_2d = func_2d_transform self.f = func
[docs] def wrap_jac(self, jac: Callable) -> None: """Wraps an user defined Jacobian function to allow for data masking and uncertainty transforms. The wrapped function is in a JAX JIT compatible format which is purely functional. This requires that both the data mask and the uncertainty transform are passed to the function. Using an analytical Jacobian of the fit function is equivalent to the Jacobian of the residual function. Also note that the analytical Jacobian doesn't require the independent ydata, but we still need to pass it to the function to maintain compatibility with autdiff version which does require the ydata. Parameters ---------- jac : Callable The Jacobian function to wrap. Returns ------- jnp.ndarray The masked Jacobian of the function evaluated at `args` with respect to the data. """ # Note: Uses @jit (not cached_jit) because these closures capture 'jac' # which changes each call, so caching based on source wouldn't work @jit def jac_func(coords: jnp.ndarray, args: jnp.ndarray) -> jnp.ndarray: # Create individual arguments from the array for JAX compatibility # This avoids the TracerArrayConversionError with dynamic unpacking if args.size == 1: jac_fwd = jac(coords, args[0]) elif args.size == 2: jac_fwd = jac(coords, args[0], args[1]) elif args.size == 3: jac_fwd = jac(coords, args[0], args[1], args[2]) elif args.size == 4: jac_fwd = jac(coords, args[0], args[1], args[2], args[3]) elif args.size == 5: jac_fwd = jac(coords, args[0], args[1], args[2], args[3], args[4]) elif args.size == 6: jac_fwd = jac( coords, args[0], args[1], args[2], args[3], args[4], args[5] ) else: # For more parameters, use a more generic approach args_list = [args[i] for i in range(args.size)] jac_fwd = jac(coords, *args_list) return jnp.array(jac_fwd) @jit def masked_jac( coords: jnp.ndarray, args: jnp.ndarray, data_mask: jnp.ndarray ) -> jnp.ndarray: """Compute the wrapped Jacobian but masks out the padded elements with 0s""" Jt = jac_func(coords, args) return jnp.where(data_mask, Jt, 0).T @jit def jac_no_transform( args: jnp.ndarray, coords: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, atransform: jnp.ndarray, ) -> jnp.ndarray: """The wrapped Jacobian function when there is no uncertainty transform.""" return jnp.atleast_2d(masked_jac(coords, args, data_mask)) @jit def jac_1d_transform( args: jnp.ndarray, coords: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, atransform: jnp.ndarray, ) -> jnp.ndarray: """The wrapped Jacobian function when there is a 1D uncertainty transform, that is when only the diagonal elements of the inverse covariance matrix are used.""" J = masked_jac(coords, args, data_mask) return jnp.atleast_2d(atransform[:, jnp.newaxis] * jnp.asarray(J)) @jit def jac_2d_transform( args: jnp.ndarray, coords: jnp.ndarray, ydata: jnp.ndarray, data_mask: jnp.ndarray, atransform: jnp.ndarray, ) -> jnp.ndarray: """The wrapped Jacobian function when there is a 2D uncertainty transform, that is when the full covariance matrix is given.""" J = masked_jac(coords, args, data_mask) return jnp.atleast_2d( jax_solve_triangular(atransform, jnp.asarray(J), lower=True) ) # we need all three versions of the Jacobian function to allow for # changing the sigma transform from none to 1D to 2D without having # to retrace the function self.jac_none = jac_no_transform self.jac_1d = jac_1d_transform self.jac_2d = jac_2d_transform self.jac = jac