# Source code for nlsq.result.optimize_result

"""Optimization result container for NLSQ curve fitting operations.

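This module provides containers for the complete results of nonlinear least
squares optimization performed using JAX-accelerated algorithms:
``OptimizeResult``, a ``dict`` subclass that also supports attribute access
(also exported as ``OptimizeResultLegacy``), and ``OptimizeResultV2``, an
immutable, slotted dataclass.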

Usage
-----
Access results using attribute syntax::

    result.x        # Optimized parameters
    result.success  # Convergence status
    result.cost     # Final cost value

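``OptimizeResultV2`` instances can be converted to a plain dictionary with::

    result.to_dict()  # Convert to dict

``OptimizeResult`` already is a ``dict`` subclass and needs no conversion.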
"""

from dataclasses import dataclass, fields
from typing import Any

import jax.numpy as jnp


[docs] @dataclass(frozen=True, slots=True) class OptimizeResultV2: """Memory-efficient optimization result container (v2). This class provides a memory-efficient alternative to OptimizeResult using Python's frozen dataclass with slots. It offers: - ~40% memory reduction per instance (no __dict__) - ~2x faster attribute access (direct slot access vs dict lookup) - Immutability for thread-safety and caching Core Attributes --------------- x : jnp.ndarray Optimized parameter vector containing the final fitted parameters. success : bool Indicates whether the optimization terminated successfully. cost : float Final cost function value: 0.5 * ||f(x)||². fun : jnp.ndarray Final residual vector f(x) at the solution. Optional Attributes ------------------- jac : jnp.ndarray | None Final Jacobian matrix J(x). None if not requested (saves ~400KB for 10k×50). grad : jnp.ndarray | None Final gradient vector g = J^T * f. optimality : float Final gradient norm ||g||_inf. active_mask : jnp.ndarray | None Boolean mask indicating which parameters hit bounds. nfev : int Total number of objective function evaluations. njev : int Total number of Jacobian evaluations. nit : int Number of optimization iterations completed. status : int Numerical termination status code. message : str Human-readable description of termination cause. pcov : jnp.ndarray | None Parameter covariance matrix. all_times : dict | None Detailed timing information for profiling. Examples -------- >>> result.x # Access optimized parameters >>> result.success # Check convergence >>> result.cost # Get final cost value >>> result.to_dict() # Convert to dictionary """ x: jnp.ndarray success: bool cost: float fun: jnp.ndarray jac: jnp.ndarray | None = None grad: jnp.ndarray | None = None optimality: float = 0.0 active_mask: jnp.ndarray | None = None nfev: int = 0 njev: int = 0 nit: int = 0 status: int = 0 message: str = "" pcov: jnp.ndarray | None = None all_times: dict[str, Any] | None = None
[docs] def to_dict(self) -> dict[str, Any]: """Convert to dictionary. Returns ------- dict Dictionary containing all non-None fields. """ result = {} for field_name in self.__slots__: value = getattr(self, field_name) if value is not None or field_name in ("x", "success", "cost", "fun"): result[field_name] = value return result
[docs] def __repr__(self) -> str: """Compact representation showing key fields.""" return ( f"OptimizeResultV2(success={self.success}, cost={self.cost:.6e}, " f"nfev={self.nfev}, status={self.status})" )
class OptimizeResult(dict):
    """Optimization result container for NLSQ curve fitting operations.

    This class stores the complete results from nonlinear least squares
    optimization performed using JAX-accelerated algorithms. It extends
    ``dict`` to provide both dictionary-style (``result["x"]``) and
    attribute-style (``result.x``) access to optimization results.

    Core Attributes
    ---------------
    x : jax.numpy.ndarray or numpy.ndarray
        Optimized parameter vector containing the final fitted parameters.
        These represent the solution to the nonlinear least squares problem.
    success : bool
        Indicates whether the optimization terminated successfully. True means
        convergence criteria were satisfied within tolerance limits.
    status : int
        Numerical termination status code indicating why optimization stopped.
        The code is set by the solver that builds the result, not by this
        class. Documented codes:

        - 1: Gradient convergence (||g||_inf < gtol)
        - 2: Step size convergence (||dx||/||x|| < xtol)
        - 3: Function value convergence (delta_f/f < ftol)
        - 0: Maximum iterations reached
        - -1: Evaluation limit exceeded
        - -3: Inner loop iteration limit (algorithm-specific)

        NOTE(review): SciPy's ``least_squares`` uses 2 for ``ftol`` and 3 for
        ``xtol`` convergence. Confirm this table against the solver that sets
        ``status``.
    message : str
        Human-readable description of the termination cause, with details
        about convergence status or failure reasons.

    Objective Function Results
    --------------------------
    fun : jax.numpy.ndarray
        Final residual vector f(x) at the solution. For curve fitting, these
        are the differences between model predictions and data points.
    cost : float
        Final cost function value: 0.5 * ||f(x)||² for standard least squares,
        or 0.5 * sum(ρ(f_i²/σ²)) for robust loss functions.
    jac : jax.numpy.ndarray
        Final Jacobian matrix J(x) with shape (m, n), where m is the number of
        data points and n is the number of parameters. Computed using JAX
        autodiff.
    grad : jax.numpy.ndarray
        Final gradient vector g = J^T * f with shape (n,). Used for
        convergence checking and parameter uncertainty estimation.

    Convergence Metrics
    -------------------
    optimality : float
        Final gradient norm ||g||_inf used for convergence assessment. Should
        be less than gtol for successful convergence.
    active_mask : numpy.ndarray
        Boolean mask indicating which parameters hit bounds (for bounded
        optimization). Shape (n,), with True for parameters at constraints.

    Iteration Statistics
    --------------------
    nfev : int
        Total number of objective function evaluations during optimization.
        Each evaluation computes residuals f(x) for given parameters.
    njev : int
        Total number of Jacobian evaluations. With JAX autodiff, this equals
        the number of combined function+gradient evaluations.
    nit : int
        Number of optimization iterations completed. Not always available for
        all algorithms.

    Algorithm-Specific Results
    --------------------------
    pcov : jax.numpy.ndarray, optional
        Parameter covariance matrix with shape (n, n), providing parameter
        uncertainty estimates. Available when uncertainty estimation is
        requested. Computed as: pcov = inv(J^T * J) * residual_variance
    all_times : dict, optional
        Detailed timing information for algorithm profiling. Contains timing
        data for different optimization phases (function evaluation, Jacobian
        computation, linear algebra operations, etc.).

    Usage Examples
    --------------
    Basic result access::

        import nlsq

        # Perform curve fitting
        result = nlsq.curve_fit(model_func, x_data, y_data, p0=initial_guess)

        # Access optimized parameters
        fitted_params = result.x

        # Check convergence
        if result.success:
            print(f"Optimization converged: {result.message}")
            print(f"Final cost: {result.cost}")
            print(f"Function evaluations: {result.nfev}")
        else:
            print(f"Optimization failed: {result.message}")

        # Parameter uncertainties (if covariance computed)
        if hasattr(result, 'pcov'):
            param_errors = jnp.sqrt(jnp.diag(result.pcov))
            print(f"Parameter uncertainties: {param_errors}")

    Advanced result inspection::

        # Examine residuals and fit quality
        final_residuals = result.fun
        rms_error = jnp.sqrt(jnp.mean(final_residuals**2))

        # Check gradient convergence
        gradient_norm = result.optimality
        print(f"Final gradient norm: {gradient_norm}")

        # Analyze Jacobian condition
        jacobian = result.jac
        condition_number = jnp.linalg.cond(jacobian)
        print(f"Jacobian condition number: {condition_number}")

        # For bounded problems, check active constraints
        if hasattr(result, 'active_mask'):
            constrained_params = jnp.where(result.active_mask)[0]
            print(f"Parameters at bounds: {constrained_params}")

    Integration with SciPy
    ----------------------
    This class maintains compatibility with scipy.optimize.OptimizeResult
    while adding JAX-specific features and NLSQ-specific results. It can be
    used interchangeably with SciPy optimization results in most contexts.

    Technical Notes
    ---------------
    - Values are stored exactly as provided. This class performs no conversion
      between JAX and NumPy arrays, and no dtype changes.
    - Reading or deleting a missing key through attribute syntax raises
      ``AttributeError``, so ``hasattr`` and ``getattr(obj, name, default)``
      work as for ordinary objects. Item syntax (``result["key"]``) keeps the
      usual ``KeyError`` behavior.
    - Solvers may store additional algorithm-specific keys (for example
      memory-management statistics, GPU timing data or progress-monitoring
      data). Which keys are present depends on the solver and its options.
    """

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails. Translate a missing
        # key into AttributeError so hasattr()/getattr(default) behave normally.
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e

    # Attribute assignment writes straight into the underlying dict.
    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        """Delete key *name*. Raise AttributeError (not KeyError) if it is absent.

        Mirrors ``__getattr__`` so that attribute syntax consistently reports
        missing attributes with the exception type Python expects.
        """
        try:
            del self[name]
        except KeyError as e:
            raise AttributeError(name) from e

    def __repr__(self):
        """Return one ``key: value`` line per entry, keys sorted and right-aligned."""
        if self.keys():
            # Width of the longest key plus one leading space of padding.
            m = max(map(len, list(self.keys()))) + 1
            return "\n".join(
                [k.rjust(m) + ": " + repr(v) for k, v in sorted(self.items())]
            )
        else:
            return self.__class__.__name__ + "()"

    def __dir__(self):
        # Expose only the stored result keys (e.g. for tab completion).
        return list(self.keys())
# Legacy alias for explicit backward compatibility # Users who want the dict-based behavior after v1.0.0 can use this OptimizeResultLegacy = OptimizeResult