"""Type aliases and protocols for NLSQ.
This module provides type hints for the NLSQ public API to improve IDE support,
documentation, and static type checking with mypy.
Note: These types are primarily for documentation and tooling. Python's duck typing
means functions will work with any compatible objects at runtime.
Note: JAX imports are deferred using TYPE_CHECKING to avoid import-time errors
during documentation builds or in environments where JAX is not fully configured.
At runtime, JAX array types are represented as Any to allow the module to be
imported without JAX being available.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypedDict
import numpy as np
# =============================================================================
# Array-like types
# =============================================================================
# At runtime, we use Any for JAX arrays to avoid import-time JAX dependency.
# Static type checkers will see the full type via TYPE_CHECKING block.
if TYPE_CHECKING:
import jax.numpy as jnp
# Full type definitions for static analysis
ArrayLike: TypeAlias = np.ndarray | jnp.ndarray | list | tuple
JAXArray: TypeAlias = jnp.ndarray
else:
# Runtime definitions - avoid JAX import
# Use Any as a stand-in for jax.numpy.ndarray
ArrayLike: TypeAlias = np.ndarray | Any | list | tuple
JAXArray: TypeAlias = Any
# NumPy array of floating point numbers.
FloatArray: TypeAlias = np.ndarray
# =============================================================================
# Function types
# =============================================================================
# Model function f(x, *params) -> y_pred.
#
# The model function takes independent variable(s) x and fit parameters,
# returning predicted dependent variable values.
#
# Examples:
# - Linear: f(x, a, b) = a*x + b
# - Exponential: f(x, a, b) = a * exp(-b * x)
# - Multi-parameter: f(x, p1, p2, ..., pN) = ...
ModelFunction: TypeAlias = Callable[..., ArrayLike]
# Jacobian function jac(x, *params) -> J.
#
# The Jacobian function computes the matrix of partial derivatives:
# J[i, j] = ∂f[i]/∂params[j]
#
# Parameters:
# x: Independent variable(s)
# *params: Fit parameters
#
# Returns:
# J: Jacobian matrix of shape (m, n) where m = len(f(x)) and n = len(params)
JacobianFunction: TypeAlias = Callable[..., ArrayLike]
# Callback function for monitoring optimization progress.
#
# Parameters:
# params: Current parameter estimates
# residuals: Current residual values
#
# Returns:
# True to stop optimization, False/None to continue
CallbackFunction: TypeAlias = Callable[[FloatArray, FloatArray], bool | None]
# Loss function rho(z) for robust fitting.
#
# Robust loss functions reduce the influence of outliers by applying
# a non-linear transformation to residuals.
#
# Parameters:
# z: Squared residuals (z = residuals**2)
#
# Returns:
# rho: Transformed residuals for robust fitting
LossFunction: TypeAlias = Callable[[FloatArray], FloatArray]
# =============================================================================
# Bounds types
# =============================================================================
# Parameter bounds as (lower, upper) tuple.
#
# Examples:
# - Unbounded: (-np.inf, np.inf)
# - Lower only: (0, np.inf) for positive parameters
# - Both: ([-1, 0], [1, 10]) for constrained parameters
BoundsTuple: TypeAlias = tuple[ArrayLike, ArrayLike]
# =============================================================================
# Result types
# =============================================================================
# Optimization result dictionary with parameters and diagnostics.
#
# Common fields:
# - x: Optimized parameters
# - success: Whether optimization converged
# - message: Optimization status message
# - fun: Final residual values (optional)
# - jac: Final Jacobian (optional)
# - cost: Final cost function value
# - optimality: Final gradient norm
# - nfev: Number of function evaluations
# - njev: Number of Jacobian evaluations (optional)
OptimizeResultDict: TypeAlias = dict[str, Any]
# =============================================================================
# Configuration types
# =============================================================================
# Optimization method name.
#
# Options:
# - "trf": Trust Region Reflective (default, supports bounds)
# - "dogbox": Dogleg algorithm for box-constrained problems
# - "lm": Levenberg-Marquardt (unconstrained only, faster)
MethodLiteral: TypeAlias = str # "trf" | "dogbox" | "lm"
# Linear solver for trust region subproblems.
#
# Options:
# - "exact": Direct solver using SVD (default, more accurate)
# - "lsmr": Iterative solver (faster for large problems)
SolverLiteral: TypeAlias = str # "exact" | "lsmr"
# =============================================================================
# Streaming diagnostic types (Task 6.2)
# =============================================================================
[docs]
class CheckpointInfo(TypedDict, total=False):
"""Checkpoint information in streaming diagnostics."""
path: str | None # Path to checkpoint file
saved_at: str # Timestamp when checkpoint was saved
batch_idx: int # Batch index at checkpoint
iteration: int # Iteration number at checkpoint
file_size: int # Size of checkpoint file in bytes (optional)
[docs]
class CommonError(TypedDict):
"""Common error entry in diagnostics."""
type: str # Error type name
count: int # Number of occurrences
[docs]
class AggregateStats(TypedDict, total=False):
"""Aggregate statistics across batches."""
mean_loss: float # Mean batch loss
std_loss: float # Standard deviation of batch losses
mean_grad_norm: float # Mean gradient norm
min_loss: float # Minimum batch loss
max_loss: float # Maximum batch loss
[docs]
class StreamingDiagnostics(TypedDict, total=False):
"""Comprehensive diagnostics for streaming optimization.
This structure matches the format from chunked processing for consistency.
"""
failed_batches: list[int] # List of failed batch indices
retry_counts: dict[int, int] # Retry attempts per batch index
error_types: dict[str, int] # Count of each error type
batch_success_rate: float # Overall success rate (0.0 to 1.0)
checkpoint_info: CheckpointInfo # Checkpoint details
recent_batch_stats: list[dict[str, Any]] # Circular buffer of recent batch stats
aggregate_stats: AggregateStats # Aggregate metrics across all batches
common_errors: list[CommonError] # Top 3 most common errors
# Streaming-specific fields
total_batches_processed: int # Total number of batches attempted
total_retries: int # Total number of retry attempts
convergence_achieved: bool # Whether convergence criteria was met
final_epoch: int # Epoch at which optimization ended
# Timing information
total_time: float # Total optimization time in seconds
mean_batch_time: float # Average time per batch
checkpoint_save_time: float # Total time spent saving checkpoints
# =============================================================================
# Protocols for structural typing
# =============================================================================
[docs]
class HasShape(Protocol):
"""Protocol for objects with a shape attribute."""
@property
def shape(self) -> tuple[int, ...]:
"""Shape of the array."""
...
[docs]
class SupportsFloat(Protocol):
"""Protocol for objects that can be converted to float."""
[docs]
def __float__(self) -> float:
"""Convert to float."""
...
# Re-export commonly used types from dependencies
__all__ = [
"AggregateStats",
# Array types
"ArrayLike",
# Bounds and results
"BoundsTuple",
"CallbackFunction",
"CheckpointInfo",
"CommonError",
"FloatArray",
# Protocols
"HasShape",
"JAXArray",
"JacobianFunction",
"LossFunction",
# Method/solver literals
"MethodLiteral",
# Function types
"ModelFunction",
"OptimizeResultDict",
"SolverLiteral",
# Streaming diagnostics
"StreamingDiagnostics",
"SupportsFloat",
]