nlsq.types module
Type aliases and protocols for NLSQ.
This module provides type hints for the NLSQ public API to improve IDE support, documentation, and static type checking with mypy.
Note: These types are primarily for documentation and tooling. Python’s duck typing means functions will work with any compatible objects at runtime.
Note: JAX imports are deferred using TYPE_CHECKING to avoid import-time errors during documentation builds or in environments where JAX is not fully configured. At runtime, JAX array types are represented as Any to allow the module to be imported without JAX being available.
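Below is a minimal sketch of the deferred-import pattern described above. It assumes the alias falls back to `Any` whenever the module is not being type-checked. This is not NLSQ's actual source, and `scale` is a hypothetical function that only shows the alias used in an annotation.

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Evaluated only by static type checkers such as mypy; never executed at runtime,
    # so importing this module does not require JAX to be installed.
    import jax

    JAXArray = jax.Array
else:
    # At runtime the alias degrades to Any, matching the documented behavior.
    JAXArray = Any


def scale(x: JAXArray, factor: float) -> JAXArray:
    """Hypothetical helper: multiply an array (or any numeric object) by a factor."""
    return x * factor
```

Because the `else` branch binds the name to `Any`, runtime code can still use `JAXArray` in annotations, while mypy sees the precise `jax.Array` type.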
- class nlsq.types.AggregateStats
Bases: TypedDict
Aggregate statistics across batches.
- mean_loss: float
- std_loss: float
- mean_grad_norm: float
- min_loss: float
- max_loss: float
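This sketch shows one way a caller might fill in `AggregateStats` from per-batch measurements. The TypedDict is re-declared locally with the documented fields so the snippet runs without NLSQ. The helper `summarize_batches` is hypothetical. Using the population standard deviation for `std_loss` is an assumption.

```python
import statistics
from typing import TypedDict


class AggregateStats(TypedDict):
    # Local mirror of the documented nlsq.types.AggregateStats fields.
    mean_loss: float
    std_loss: float
    mean_grad_norm: float
    min_loss: float
    max_loss: float


def summarize_batches(losses: list[float], grad_norms: list[float]) -> AggregateStats:
    """Hypothetical helper: reduce per-batch losses and gradient norms to aggregate stats."""
    if not losses or not grad_norms:
        raise ValueError("need at least one batch")
    return AggregateStats(
        mean_loss=statistics.fmean(losses),
        # Population standard deviation (assumption; NLSQ may use another estimator).
        std_loss=statistics.pstdev(losses),
        mean_grad_norm=statistics.fmean(grad_norms),
        min_loss=min(losses),
        max_loss=max(losses),
    )


stats = summarize_batches([0.5, 0.25, 0.75], [1.0, 2.0, 3.0])
```

At runtime the result is an ordinary `dict`. The TypedDict only tells the type checker which keys and value types to expect.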
- class nlsq.types.CheckpointInfo
Bases: TypedDict
Checkpoint information in streaming diagnostics.
- saved_at: str
- batch_idx: int
- iteration: int
- file_size: int
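This sketch fills in a `CheckpointInfo` entry after a checkpoint file has been written. The field meanings are inferred from their names and are assumptions: `saved_at` is taken to be an ISO 8601 UTC timestamp and `file_size` a size in bytes. The helper `describe_checkpoint` is hypothetical, and the TypedDict is mirrored locally.

```python
import os
import tempfile
from datetime import datetime, timezone
from typing import TypedDict


class CheckpointInfo(TypedDict):
    # Local mirror of the documented nlsq.types.CheckpointInfo fields.
    saved_at: str
    batch_idx: int
    iteration: int
    file_size: int


def describe_checkpoint(path: str, batch_idx: int, iteration: int) -> CheckpointInfo:
    """Hypothetical helper: record metadata about a checkpoint already written to `path`."""
    return CheckpointInfo(
        # Assumption: timestamps are stored as ISO 8601 strings in UTC.
        saved_at=datetime.now(timezone.utc).isoformat(),
        batch_idx=batch_idx,
        iteration=iteration,
        # Assumption: file_size is the on-disk size in bytes.
        file_size=os.path.getsize(path),
    )


with tempfile.NamedTemporaryFile(delete=False) as fh:
    fh.write(b"\x00" * 128)  # stand-in for real checkpoint contents
info = describe_checkpoint(fh.name, batch_idx=7, iteration=42)
os.remove(fh.name)
```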
- class nlsq.types.CommonError
Bases: TypedDict
Common error entry in diagnostics.
- type: str
- count: int
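`StreamingDiagnostics.common_errors` is a `list[CommonError]`. This minimal sketch builds that list from raw error type names with `collections.Counter`, most frequent first. The helper `tally_errors`, its `top_n` cutoff, and the ordering are assumptions made for illustration.

```python
from collections import Counter
from typing import TypedDict


class CommonError(TypedDict):
    # Local mirror of the documented nlsq.types.CommonError fields.
    type: str
    count: int


def tally_errors(error_types: list[str], top_n: int = 3) -> list[CommonError]:
    """Hypothetical helper: count error type names and keep the top_n most frequent."""
    counts = Counter(error_types)
    return [CommonError(type=name, count=n) for name, n in counts.most_common(top_n)]


errors = tally_errors(
    ["ValueError", "FloatingPointError", "ValueError",
     "MemoryError", "ValueError", "FloatingPointError"]
)
```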
- nlsq.types.FloatArray
alias of ndarray
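`FloatArray` resolves to plain `numpy.ndarray`. The annotation therefore documents intent but does not constrain the dtype. This sketch uses a hypothetical `as_float_array` helper that enforces float64 at runtime. It also shows `numpy.typing.NDArray[np.float64]` as a stricter option for static checkers. The local `FloatArray` mirrors the documented alias, and `Float64Array` is an illustrative name, not part of NLSQ.

```python
import numpy as np
import numpy.typing as npt

# Local mirror of the documented alias: nlsq.types.FloatArray -> numpy.ndarray.
FloatArray = np.ndarray

# Stricter alternative for static checkers: an ndarray whose dtype is float64.
Float64Array = npt.NDArray[np.float64]


def as_float_array(values: npt.ArrayLike) -> Float64Array:
    """Hypothetical helper: coerce lists, tuples, or arrays to a float64 ndarray."""
    return np.asarray(values, dtype=np.float64)


y = as_float_array([1, 2, 3])
```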
- class nlsq.types.HasShape(*args, **kwargs)
Bases: Protocol
Protocol for objects with a shape attribute.
- __init__(*args, **kwargs)
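A protocol like `HasShape` lets a function accept NumPy arrays, JAX arrays, or any custom container that exposes `.shape`, with no inheritance required. This local sketch assumes the protocol declares a read-only `shape` of type `tuple[int, ...]`. NLSQ's exact declaration may differ. `count_elements` and `FakeBatch` are hypothetical.

```python
import math
from typing import Protocol

import numpy as np


class HasShape(Protocol):
    # Local sketch: any object exposing a `shape` tuple satisfies this protocol.
    @property
    def shape(self) -> tuple[int, ...]: ...


def count_elements(obj: HasShape) -> int:
    """Hypothetical helper: total number of elements implied by obj.shape."""
    return math.prod(obj.shape)


class FakeBatch:
    # Hypothetical class: it never subclasses HasShape, yet matches it structurally.
    def __init__(self, rows: int, cols: int) -> None:
        self.shape = (rows, cols)


n_array = count_elements(np.zeros((2, 3, 4)))
n_batch = count_elements(FakeBatch(5, 6))
```

A 0-d object such as a NumPy scalar has `shape == ()`, and `math.prod(())` is 1, so scalars count as a single element.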
- nlsq.types.JAXArray
alias of Any
- nlsq.types.MethodLiteral
alias of str
- nlsq.types.SolverLiteral
alias of str
- class nlsq.types.StreamingDiagnostics
Bases: TypedDict
Comprehensive diagnostics for streaming optimization.
This structure matches the format from chunked processing for consistency.
- batch_success_rate: float
- checkpoint_info: CheckpointInfo
- aggregate_stats: AggregateStats
- common_errors: list[CommonError]
- total_batches_processed: int
- total_retries: int
- convergence_achieved: bool
- final_epoch: int
- total_time: float
- mean_batch_time: float
- checkpoint_save_time: float
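TypedDicts are ordinary dicts at runtime, so nothing enforces that a diagnostics payload carries every field. This sketch adds a lightweight completeness check using the `__required_keys__` attribute that `typing.TypedDict` classes expose (Python 3.9+). All four TypedDicts are mirrored locally from the documented fields. `missing_fields` is hypothetical.

```python
from typing import TypedDict


class CheckpointInfo(TypedDict):
    saved_at: str
    batch_idx: int
    iteration: int
    file_size: int


class AggregateStats(TypedDict):
    mean_loss: float
    std_loss: float
    mean_grad_norm: float
    min_loss: float
    max_loss: float


class CommonError(TypedDict):
    type: str
    count: int


class StreamingDiagnostics(TypedDict):
    # Local mirror of the documented nlsq.types.StreamingDiagnostics fields.
    batch_success_rate: float
    checkpoint_info: CheckpointInfo
    aggregate_stats: AggregateStats
    common_errors: list[CommonError]
    total_batches_processed: int
    total_retries: int
    convergence_achieved: bool
    final_epoch: int
    total_time: float
    mean_batch_time: float
    checkpoint_save_time: float


def missing_fields(payload: dict) -> set[str]:
    """Hypothetical helper: top-level StreamingDiagnostics keys absent from payload."""
    return set(StreamingDiagnostics.__required_keys__) - set(payload)


partial = {
    "batch_success_rate": 0.98,
    "total_batches_processed": 50,
    "convergence_achieved": True,
}
```

The check covers only top-level keys. Nested entries such as `checkpoint_info` would need the same check against their own TypedDicts.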
- class nlsq.types.SupportsFloat(*args, **kwargs)
Bases: Protocol
Protocol for objects that can be converted to float.
- __float__()
Convert to float.
- __init__(*args, **kwargs)
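Any object that defines `__float__` satisfies a protocol like this structurally. That lets parameters be passed as Python numbers, NumPy scalars, `fractions.Fraction`, or `decimal.Decimal` values alike. Below is a local sketch. The standard library ships an equivalent `typing.SupportsFloat`. `as_param` and `Celsius` are hypothetical.

```python
from decimal import Decimal
from fractions import Fraction
from typing import Protocol

import numpy as np


class SupportsFloat(Protocol):
    # Local sketch mirroring nlsq.types.SupportsFloat: one required method.
    def __float__(self) -> float: ...


def as_param(value: SupportsFloat) -> float:
    """Hypothetical helper: normalize a fit parameter to a built-in float."""
    return float(value)


class Celsius:
    # Hypothetical user type; defining __float__ is enough to satisfy the protocol.
    def __init__(self, degrees: float) -> None:
        self.degrees = degrees

    def __float__(self) -> float:
        return float(self.degrees)


params = [as_param(v) for v in (3, Fraction(1, 4), Decimal("2.5"), np.float32(0.5), Celsius(21.5))]
```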
Overview
The types module defines type annotations and type aliases used throughout NLSQ.
Type Aliases
This module provides type aliases for:
- Array types: Array, ArrayLike
- Function types: ModelFunction, ResidualFunction
- Result types: OptimizeResult
- Configuration types: Config, MemoryConfig
Example Usage
from nlsq.types import Array, ModelFunction
import jax.numpy as jnp
# Type-annotated function
def my_model(x: Array, a: float, b: float) -> Array:
    return a * jnp.exp(-b * x)
# Using ModelFunction type alias
model: ModelFunction = my_model
Benefits
Using the type annotations from this module provides:
- Improved IDE support with autocomplete
- Type checking with mypy or pyright
- Better documentation for function signatures
- Clearer code with explicit types
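Type Definitions
from typing import Callable, Sequence, Union
import jax.numpy as jnp
import numpy as np
# Common type aliases
Array = jnp.ndarray
ArrayLike = Union[Array, np.ndarray, Sequence[float]]
# "..." is only valid as the whole argument list, so a model taking x plus
# any number of fit parameters is typed as Callable[..., Array].
ModelFunction = Callable[..., Array]
ResidualFunction = Callable[[Array], Array]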
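This sketch shows how the two function aliases could relate in a least-squares setting. A closure fixes the data so that the residual depends only on a parameter vector. Two assumptions apply: `ResidualFunction`'s single argument is taken to be that parameter vector, and NumPy stands in for `jax.numpy` so the snippet runs without JAX. `make_residual` is hypothetical, not part of NLSQ.

```python
from typing import Callable

import numpy as np

# NumPy stands in for jax.numpy so the sketch runs without JAX installed.
Array = np.ndarray
ModelFunction = Callable[..., Array]
ResidualFunction = Callable[[Array], Array]


def make_residual(model: ModelFunction, x: Array, y: Array) -> ResidualFunction:
    """Hypothetical helper: close over data so the residual depends only on params."""

    def residual(params: Array) -> Array:
        # Unpack the parameter vector into the model's positional parameters.
        return model(x, *params) - y

    return residual


def my_model(x: Array, a: float, b: float) -> Array:
    return a * np.exp(-b * x)


x = np.linspace(0.0, 1.0, 5)
y = my_model(x, 2.0, 0.5)
residual = make_residual(my_model, x, y)
```

Passing `my_model` where a `ModelFunction` is expected type-checks because `Callable[..., Array]` places no constraints on the parameters. The trade-off is that a checker also cannot catch a model with the wrong parameter count.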
See Also
nlsq.result module - Result containers
nlsq.config module - Configuration types