nlsq.types module

Type aliases and protocols for NLSQ.

This module provides type hints for the NLSQ public API to improve IDE support, documentation, and static type checking with mypy.

Note: These types are primarily for documentation and tooling. Python’s duck typing means functions will work with any compatible objects at runtime.

Note: JAX imports are deferred using TYPE_CHECKING to avoid import-time errors during documentation builds or in environments where JAX is not fully configured. At runtime, JAX array types are represented as Any to allow the module to be imported without JAX being available.
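The deferred-import pattern in the note above can be sketched as follows. This is a minimal illustration, not the module's actual source. It assumes `jax.Array` as the static array type, and the `scale` function is hypothetical.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Only static type checkers (mypy, pyright) evaluate this branch;
    # it never runs, so JAX is not imported at runtime.
    import jax

    JAXArray = jax.Array
else:
    # At runtime the alias degrades to Any, so this module imports cleanly
    # even where JAX is missing or not fully configured.
    JAXArray = Any


def scale(x: JAXArray, factor: float) -> JAXArray:
    """Hypothetical function annotated with the deferred alias."""
    return x * factor
```

Because `from __future__ import annotations` keeps annotations as strings, `scale` can be defined and called without JAX installed.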

class nlsq.types.AggregateStats

Bases: TypedDict

Aggregate statistics across batches.

mean_loss: float
std_loss: float
mean_grad_norm: float
min_loss: float
max_loss: float
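A `TypedDict` is a plain `dict` at runtime, so an `AggregateStats` value can be returned as a dict literal. The sketch below reduces per-batch values to the five documented fields. The helper `summarize_batches` is hypothetical, and the choice of population standard deviation for `std_loss` is an assumption. The import sits under `TYPE_CHECKING`, so the sketch runs without NLSQ installed.

```python
from __future__ import annotations

import statistics
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from nlsq.types import AggregateStats


def summarize_batches(losses: list[float], grad_norms: list[float]) -> AggregateStats:
    """Hypothetical helper: reduce per-batch losses and gradient norms
    to the fields of AggregateStats."""
    if not losses or not grad_norms:
        raise ValueError("need at least one batch")
    return {
        "mean_loss": statistics.fmean(losses),
        # Population standard deviation (assumption; NLSQ may use another estimator).
        "std_loss": statistics.pstdev(losses),
        "mean_grad_norm": statistics.fmean(grad_norms),
        "min_loss": min(losses),
        "max_loss": max(losses),
    }


stats = summarize_batches([1.0, 2.0, 3.0], [0.5, 0.25, 0.25])
```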
class nlsq.types.CheckpointInfo

Bases: TypedDict

Checkpoint information in streaming diagnostics.

path: str | None
saved_at: str
batch_idx: int
iteration: int
file_size: int
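The field types alone do not fix their units or formats. The sketch below fills a `CheckpointInfo` under three labeled assumptions: `saved_at` is an ISO-8601 UTC timestamp, `file_size` is in bytes, and `path=None` means no checkpoint was written. The helper `checkpoint_info` is hypothetical.

```python
from __future__ import annotations

import os
import tempfile
from datetime import datetime, timezone
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from nlsq.types import CheckpointInfo


def checkpoint_info(path: str | None, batch_idx: int, iteration: int) -> CheckpointInfo:
    """Hypothetical helper: describe a checkpoint that has already been written.

    Assumptions: saved_at is an ISO-8601 UTC timestamp, file_size is in bytes,
    and path=None means no checkpoint file exists (reported as size 0).
    """
    return {
        "path": path,
        "saved_at": datetime.now(timezone.utc).isoformat(),
        "batch_idx": batch_idx,
        "iteration": iteration,
        "file_size": os.path.getsize(path) if path is not None else 0,
    }


# Usage: write a dummy 128-byte checkpoint, describe it, then clean up.
with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as handle:
    handle.write(b"\0" * 128)
info = checkpoint_info(handle.name, batch_idx=7, iteration=42)
os.remove(handle.name)
```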
class nlsq.types.CommonError

Bases: TypedDict

Common error entry in diagnostics.

type: str
count: int
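`StreamingDiagnostics` carries both an `error_types` count mapping and a `common_errors` list. One plausible way to derive the list from the mapping is to rank the counts, shown below. How NLSQ actually builds `common_errors` is not stated here, so the helper `most_common_errors`, its tie-breaking, and its `limit` parameter are all assumptions.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from nlsq.types import CommonError


def most_common_errors(error_types: dict[str, int], limit: int = 3) -> list[CommonError]:
    """Hypothetical helper: turn error counts into CommonError entries,
    most frequent first, with ties broken alphabetically by error type."""
    ranked = sorted(error_types.items(), key=lambda item: (-item[1], item[0]))
    return [{"type": name, "count": count} for name, count in ranked[:limit]]
```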
nlsq.types.FloatArray

alias of ndarray

class nlsq.types.HasShape(*args, **kwargs)

Bases: Protocol

Protocol for objects with a shape attribute.

property shape: tuple[int, ...]

Shape of the array.

__init__(*args, **kwargs)
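Because `HasShape` is a `Protocol`, a class satisfies it structurally, by exposing a `shape` property, without inheriting from it. In the sketch below, the `Grid` class and the `n_points` function are hypothetical.

```python
from __future__ import annotations

import math
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from nlsq.types import HasShape


class Grid:
    """Hypothetical container: it never subclasses HasShape, but it
    satisfies the protocol because it exposes a ``shape`` property."""

    def __init__(self, rows: int, cols: int) -> None:
        self._shape = (rows, cols)

    @property
    def shape(self) -> tuple[int, ...]:
        return self._shape


def n_points(data: HasShape) -> int:
    """Total number of elements implied by ``data.shape``."""
    return math.prod(data.shape)
```

NumPy and JAX arrays also have a `shape` attribute of this form, so a type checker accepts them for `n_points` as well.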
nlsq.types.JAXArray

alias of Any

nlsq.types.MethodLiteral

alias of str

nlsq.types.SolverLiteral

alias of str

class nlsq.types.StreamingDiagnostics

Bases: TypedDict

Comprehensive diagnostics for streaming optimization.

Its fields follow the same format as the diagnostics produced by chunked processing, so the two modes report results consistently.

failed_batches: list[int]
retry_counts: dict[int, int]
error_types: dict[str, int]
batch_success_rate: float
checkpoint_info: CheckpointInfo
recent_batch_stats: list[dict[str, Any]]
aggregate_stats: AggregateStats
common_errors: list[CommonError]
total_batches_processed: int
total_retries: int
convergence_achieved: bool
final_epoch: int
total_time: float
mean_batch_time: float
checkpoint_save_time: float
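Consumers can annotate code that reads these diagnostics with `StreamingDiagnostics` and let a type checker validate the key names. The sketch below formats a one-line summary. The helper `describe_run` is hypothetical and reads only fields listed above. It makes no assumption about how `failed_batches` relates to `total_batches_processed`.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from nlsq.types import StreamingDiagnostics


def describe_run(diag: StreamingDiagnostics) -> str:
    """Hypothetical helper: one-line summary of a streaming run."""
    status = "converged" if diag["convergence_achieved"] else "did not converge"
    errors = diag["error_types"]
    # Most frequent error type, or "none" when no errors were recorded.
    worst = max(errors, key=errors.__getitem__) if errors else "none"
    return (
        f"{status} at epoch {diag['final_epoch']}: "
        f"{diag['total_batches_processed']} batches processed, "
        f"{len(diag['failed_batches'])} failed, "
        f"{diag['total_retries']} retries, most frequent error: {worst}"
    )
```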
class nlsq.types.SupportsFloat(*args, **kwargs)

Bases: Protocol

Protocol for objects that can be converted to float.

__float__()

Convert to float.

__init__(*args, **kwargs)
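A parameter annotated with `SupportsFloat` accepts any object that defines `__float__`, not just `float`. In the sketch below, the helper `as_positive_float` is hypothetical. The standard library's `typing.SupportsFloat` describes the same one-method contract.

```python
from __future__ import annotations

from decimal import Decimal
from fractions import Fraction
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from nlsq.types import SupportsFloat


def as_positive_float(value: SupportsFloat) -> float:
    """Hypothetical helper: accept anything with __float__ (int, Fraction,
    Decimal, NumPy scalars, ...) and return a positive built-in float."""
    result = float(value)
    if not result > 0.0:  # the negated comparison also rejects NaN
        raise ValueError(f"expected a positive number, got {value!r}")
    return result


tolerances = [as_positive_float(v) for v in (1, Fraction(1, 4), Decimal("1e-8"))]
```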

Overview

The types module defines type annotations and type aliases used throughout NLSQ.

Type Aliases

This module provides type aliases for:

  • Array types: Array, ArrayLike

  • Function types: ModelFunction, ResidualFunction

  • Result types: OptimizeResult

  • Configuration types: Config, MemoryConfig

Example Usage

from nlsq.types import Array, ModelFunction
import jax.numpy as jnp


# Type-annotated function
def my_model(x: Array, a: float, b: float) -> Array:
    return a * jnp.exp(-b * x)


# Using ModelFunction type alias
model: ModelFunction = my_model

Benefits

Using the type annotations from this module provides:

  • Improved IDE support with autocomplete

  • Type checking with mypy/pyright

  • Better documentation for function signatures

  • Clearer code with explicit types

Type Definitions

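from typing import Callable, Sequence, Union

import jax.numpy as jnp
import numpy as np

# Common type aliases
Array = jnp.ndarray
ArrayLike = Union[Array, np.ndarray, Sequence[float]]

# Models take the independent variable followed by the fit parameters.
# Callable cannot express "one Array, then any number of parameters", so
# the argument list is left open with "...".
ModelFunction = Callable[..., Array]
# Residual functions map a single array to an array of residuals.
ResidualFunction = Callable[[Array], Array]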
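The two function aliases connect in the usual least-squares way: a residual function can be built by closing a model over the data and subtracting the observations. The sketch below shows this. It is a generic illustration, not a description of NLSQ internals. The helper `make_residual` is hypothetical, and it uses NumPy stand-ins for the JAX-based aliases so that it runs without JAX. With NLSQ you would write the model with `jax.numpy`.

```python
from typing import Callable

import numpy as np

# Stand-ins so the sketch runs without JAX; NLSQ's own aliases are JAX-based.
Array = np.ndarray
ModelFunction = Callable[..., Array]
ResidualFunction = Callable[[Array], Array]


def make_residual(model: ModelFunction, xdata: Array, ydata: Array) -> ResidualFunction:
    """Hypothetical helper: close over the data so the result depends only on params."""

    def residual(params: Array) -> Array:
        # Unpack the parameter vector into the model's positional parameters.
        return model(xdata, *params) - ydata

    return residual


def exponential(x: Array, a: float, b: float) -> Array:
    return a * np.exp(-b * x)


x = np.linspace(0.0, 4.0, 5)
y = exponential(x, 2.0, 0.5)
residual = make_residual(exponential, x, y)
```

Evaluating `residual` at the parameters that generated `y` gives zeros. A least-squares solver minimizes the sum of squares of this vector.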

See Also