Source code for nlsq.interfaces.data_source_protocol

"""Protocol definition for data sources.

This module defines the DataSourceProtocol that data providers should implement,
enabling support for different data backends (arrays, HDF5, streaming).
"""

from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class DataSourceProtocol(Protocol):
    """Structural interface for chunk-addressable data providers.

    Any object exposing the members listed below satisfies this protocol,
    whatever storage sits behind it (NumPy arrays, HDF5 files, streaming
    readers), so consumers do not have to depend on a concrete backend.

    Because the class is decorated with ``runtime_checkable``, it can be used
    with ``isinstance``; that check only looks for the presence of the
    members, not at their signatures or return types.

    Attributes
    ----------
    n_points : int
        Total number of data points.
    n_dims : int
        Number of dimensions in x data (1 for scalar x).
    dtype : np.dtype
        Data type of the arrays.

    Methods
    -------
    get_chunk(start, end)
        Return the data between the ``start`` and ``end`` indices.
    __len__()
        Return the total number of data points.
    """

    @property
    def n_points(self) -> int:
        """Return the total number of data points."""
        ...

    @property
    def n_dims(self) -> int:
        """Return the number of dimensions in the x data."""
        ...

    @property
    def dtype(self) -> np.dtype:
        """Return the data type of the arrays."""
        ...

    def get_chunk(self, start: int, end: int) -> tuple[np.ndarray, np.ndarray]:
        """Return one contiguous chunk of the data.

        Parameters
        ----------
        start : int
            Index of the first point in the chunk (inclusive).
        end : int
            Index one past the last point in the chunk (exclusive).

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            The ``(xdata, ydata)`` arrays covering the requested range.
        """
        ...

    def __len__(self) -> int:
        """Return the total number of data points."""
        ...
@runtime_checkable
class StreamingDataSourceProtocol(Protocol):
    """Structural interface for batch-iterating data sources.

    Describes sources that are consumed one batch at a time through the
    iterator protocol, which is useful for datasets that do not fit in
    memory. ``reset`` rewinds the source so it can be iterated again.

    Because the class is decorated with ``runtime_checkable``, it can be used
    with ``isinstance``; that check only looks for the presence of the
    members, not at their signatures or return types.
    """

    @property
    def n_points(self) -> int:
        """Return the total number of data points."""
        ...

    @property
    def batch_size(self) -> int:
        """Return the size of each batch."""
        ...

    def __iter__(self) -> "StreamingDataSourceProtocol":
        """Return the iterator over batches."""
        ...

    def __next__(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the next ``(xdata, ydata)`` batch."""
        ...

    def reset(self) -> None:
        """Move the iterator back to the beginning."""
        ...
class ArrayDataSource:
    """In-memory data source backed by a pair of NumPy arrays.

    This is the default provider for data that is already held in memory.
    Both inputs are passed through ``np.asarray`` on construction and are
    sliced along their first axis by ``get_chunk``.

    Parameters
    ----------
    xdata : np.ndarray
        Independent variable data.
    ydata : np.ndarray
        Dependent variable data.
    """

    __slots__ = ("_xdata", "_ydata")

    def __init__(self, xdata: np.ndarray, ydata: np.ndarray) -> None:
        """Store ``xdata`` and ``ydata`` as NumPy arrays."""
        self._xdata, self._ydata = np.asarray(xdata), np.asarray(ydata)

    @property
    def n_points(self) -> int:
        """Return the number of data points (the length of ``ydata``)."""
        return len(self._ydata)

    @property
    def n_dims(self) -> int:
        """Return the number of dimensions in the x data.

        One-dimensional ``xdata`` counts as a single dimension; otherwise the
        size of the second axis is reported.
        """
        x = self._xdata
        if x.ndim == 1:
            return 1
        return x.shape[1]

    @property
    def dtype(self) -> np.dtype:
        """Return the data type of the arrays (taken from ``ydata``)."""
        return self._ydata.dtype

    def get_chunk(self, start: int, end: int) -> tuple[np.ndarray, np.ndarray]:
        """Return the ``(xdata, ydata)`` slices for ``[start, end)``.

        Parameters
        ----------
        start : int
            Start index (inclusive).
        end : int
            End index (exclusive).

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Both arrays sliced with the same ``start:end`` range.
        """
        # One slice object keeps the x and y selections aligned.
        window = slice(start, end)
        return self._xdata[window], self._ydata[window]

    def __len__(self) -> int:
        """Return the total number of data points."""
        return self.n_points