# Source code for nlsq.result.optimize_result
"""Optimization result container for NLSQ curve fitting operations.

This module provides the OptimizeResult class, which stores the complete
results from nonlinear least squares optimization performed using JAX-accelerated
algorithms, together with OptimizeResultV2, an immutable slotted variant.

Usage
-----
Access results using attribute syntax::

    result.x        # Optimized parameters
    result.success  # Convergence status
    result.cost     # Final cost value

For dictionary conversion (available on OptimizeResultV2), use::

    result.to_dict()  # Convert to dict
"""
from dataclasses import dataclass
from typing import Any
import jax.numpy as jnp
[docs]
@dataclass(frozen=True, slots=True)
class OptimizeResultV2:
"""Memory-efficient optimization result container (v2).
This class provides a memory-efficient alternative to OptimizeResult using
Python's frozen dataclass with slots. It offers:
- ~40% memory reduction per instance (no __dict__)
- ~2x faster attribute access (direct slot access vs dict lookup)
- Immutability for thread-safety and caching
Core Attributes
---------------
x : jnp.ndarray
Optimized parameter vector containing the final fitted parameters.
success : bool
Indicates whether the optimization terminated successfully.
cost : float
Final cost function value: 0.5 * ||f(x)||².
fun : jnp.ndarray
Final residual vector f(x) at the solution.
Optional Attributes
-------------------
jac : jnp.ndarray | None
Final Jacobian matrix J(x). None if not requested (saves ~400KB for 10k×50).
grad : jnp.ndarray | None
Final gradient vector g = J^T * f.
optimality : float
Final gradient norm ||g||_inf.
active_mask : jnp.ndarray | None
Boolean mask indicating which parameters hit bounds.
nfev : int
Total number of objective function evaluations.
njev : int
Total number of Jacobian evaluations.
nit : int
Number of optimization iterations completed.
status : int
Numerical termination status code.
message : str
Human-readable description of termination cause.
pcov : jnp.ndarray | None
Parameter covariance matrix.
all_times : dict | None
Detailed timing information for profiling.
Examples
--------
>>> result.x # Access optimized parameters
>>> result.success # Check convergence
>>> result.cost # Get final cost value
>>> result.to_dict() # Convert to dictionary
"""
x: jnp.ndarray
success: bool
cost: float
fun: jnp.ndarray
jac: jnp.ndarray | None = None
grad: jnp.ndarray | None = None
optimality: float = 0.0
active_mask: jnp.ndarray | None = None
nfev: int = 0
njev: int = 0
nit: int = 0
status: int = 0
message: str = ""
pcov: jnp.ndarray | None = None
all_times: dict[str, Any] | None = None
[docs]
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary.
Returns
-------
dict
Dictionary containing all non-None fields.
"""
result = {}
for field_name in self.__slots__:
value = getattr(self, field_name)
if value is not None or field_name in ("x", "success", "cost", "fun"):
result[field_name] = value
return result
[docs]
def __repr__(self) -> str:
"""Compact representation showing key fields."""
return (
f"OptimizeResultV2(success={self.success}, cost={self.cost:.6e}, "
f"nfev={self.nfev}, status={self.status})"
)
class OptimizeResult(dict):
    """Optimization result container for NLSQ curve fitting operations.

    This class stores the complete results from nonlinear least squares optimization
    performed using JAX-accelerated algorithms. It extends dict to provide both
    dictionary-style and attribute-style access to optimization results.

    Core Attributes
    ---------------
    x : jax.numpy.ndarray or numpy.ndarray
        Optimized parameter vector containing the final fitted parameters.
        These represent the solution to the nonlinear least squares problem.
    success : bool
        Indicates whether the optimization terminated successfully. True means
        convergence criteria were satisfied within tolerance limits.
    status : int
        Numerical termination status code indicating why optimization stopped:

        - 1: Gradient convergence (||g||_inf < gtol)
        - 2: Step size convergence (||dx||/||x|| < xtol)
        - 3: Function value convergence (delta_f/f < ftol)
        - 0: Maximum iterations reached
        - -1: Evaluation limit exceeded
        - -3: Inner loop iteration limit (algorithm-specific)
    message : str
        Human-readable description of termination cause. Provides detailed
        information about convergence status or failure reasons.

    Objective Function Results
    --------------------------
    fun : jax.numpy.ndarray
        Final residual vector f(x) at the solution. For curve fitting, these
        are the differences between model predictions and data points.
    cost : float
        Final cost function value: 0.5 * ||f(x)||² for standard least squares,
        or 0.5 * sum(ρ(f_i²/σ²)) for robust loss functions.
    jac : jax.numpy.ndarray
        Final Jacobian matrix J(x) with shape (m, n) where m is number of
        data points and n is number of parameters. Computed using JAX autodiff.
    grad : jax.numpy.ndarray
        Final gradient vector g = J^T * f with shape (n,). Used for
        convergence checking and parameter uncertainty estimation.

    Convergence Metrics
    -------------------
    optimality : float
        Final gradient norm ||g||_inf used for convergence assessment.
        Should be less than gtol for successful convergence.
    active_mask : numpy.ndarray
        Boolean mask indicating which parameters hit bounds (for bounded
        optimization). Shape (n,) with True for parameters at constraints.

    Iteration Statistics
    --------------------
    nfev : int
        Total number of objective function evaluations during optimization.
        Each evaluation computes residuals f(x) for given parameters.
    njev : int
        Total number of Jacobian evaluations. With JAX autodiff, this equals
        the number of combined function+gradient evaluations.
    nit : int
        Number of optimization iterations completed. Not always available
        for all algorithms.

    Algorithm-Specific Results
    --------------------------
    pcov : jax.numpy.ndarray, optional
        Parameter covariance matrix with shape (n, n). Provides parameter
        uncertainty estimates. Available when uncertainty estimation is requested.
        Computed as: pcov = inv(J^T * J) * residual_variance
    all_times : dict, optional
        Detailed timing information for algorithm profiling. Contains timing
        data for different optimization phases (function evaluation, Jacobian
        computation, linear algebra operations, etc.).

    Usage Examples
    --------------
    Basic result access::

        import nlsq

        # Perform curve fitting
        result = nlsq.curve_fit(model_func, x_data, y_data, p0=initial_guess)

        # Access optimized parameters
        fitted_params = result.x

        # Check convergence
        if result.success:
            print(f"Optimization converged: {result.message}")
            print(f"Final cost: {result.cost}")
            print(f"Function evaluations: {result.nfev}")
        else:
            print(f"Optimization failed: {result.message}")

        # Parameter uncertainties (if covariance computed)
        if hasattr(result, 'pcov'):
            param_errors = jnp.sqrt(jnp.diag(result.pcov))
            print(f"Parameter uncertainties: {param_errors}")

    Advanced result inspection::

        # Examine residuals and fit quality
        final_residuals = result.fun
        rms_error = jnp.sqrt(jnp.mean(final_residuals**2))

        # Check gradient convergence
        gradient_norm = result.optimality
        print(f"Final gradient norm: {gradient_norm}")

        # Analyze Jacobian condition
        jacobian = result.jac
        condition_number = jnp.linalg.cond(jacobian)
        print(f"Jacobian condition number: {condition_number}")

        # For bounded problems, check active constraints
        if hasattr(result, 'active_mask'):
            constrained_params = jnp.where(result.active_mask)[0]
            print(f"Parameters at bounds: {constrained_params}")

    Integration with SciPy
    ----------------------
    This class maintains compatibility with scipy.optimize.OptimizeResult
    while adding JAX-specific features and NLSQ-specific results. It can
    be used interchangeably with SciPy optimization results in most contexts.

    Technical Notes
    ---------------
    - All JAX arrays are automatically converted to NumPy arrays for compatibility
    - Covariance matrices use double precision for numerical stability
    - Large dataset results may include memory management statistics
    - GPU timing results require explicit timing mode activation
    - Progress monitoring data is stored in algorithm-specific attributes
    """

    def __getattr__(self, name):
        """Return ``self[name]`` for attribute-style reads.

        Python only calls this hook after normal attribute lookup fails, so
        real dict methods are never shadowed by stored keys.

        Raises
        ------
        AttributeError
            If ``name`` is not a key of the mapping. The original KeyError
            is chained as the cause.
        """
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e

    # Attribute assignment and deletion go straight to the dict item slots,
    # so ``result.foo = v`` stores the key "foo". Note that ``del result.foo``
    # on a missing key raises KeyError (from dict.__delitem__), not
    # AttributeError.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __repr__(self):
        """Return one ``key: repr(value)`` line per entry, sorted by key.

        Keys are right-justified to a common width so the colons line up.
        An empty result is rendered as ``ClassName()``.
        """
        if not self:
            return self.__class__.__name__ + "()"
        # One column wider than the longest key.
        width = max(len(key) for key in self) + 1
        return "\n".join(
            f"{key.rjust(width)}: {value!r}" for key, value in sorted(self.items())
        )

    def __dir__(self):
        """List only the stored keys, so completion shows result fields."""
        return list(self)
# Legacy alias for explicit backward compatibility. The name is bound to the
# very same class object as OptimizeResult, so the two are interchangeable.
# Users who want the dict-based behavior after v1.0.0 can use this name.
OptimizeResultLegacy = OptimizeResult