Source code for nlsq.core.orchestration.data_preprocessor

"""DataPreprocessor component for CurveFit decomposition.

Handles input validation, array conversion, data masking, and padding
for curve fitting operations. This component is extracted from the
CurveFit class as part of the God class decomposition.

Reference: specs/017-curve-fit-decomposition/spec.md FR-001
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import jax.numpy as jnp
import numpy as np

from nlsq.interfaces.orchestration_protocol import PreprocessedData

if TYPE_CHECKING:
    from collections.abc import Callable

    from nlsq.types import ArrayLike


class DataPreprocessor:
    """Preprocessor for curve fitting input data.

    Handles:
    1. Input validation (type checking, finiteness)
    2. Array conversion (numpy/list to JAX)
    3. Length consistency checking
    4. Data masking for invalid points
    5. NaN/Inf handling via nan_policy

    Example:
        >>> preprocessor = DataPreprocessor()
        >>> data = preprocessor.preprocess(
        ...     f=my_model,
        ...     xdata=x_values,
        ...     ydata=y_values,
        ...     sigma=uncertainties,
        ...     check_finite=True,
        ... )
        >>> print(f"Valid points: {data.n_points}")
    """

    def preprocess(
        self,
        f: Callable[..., ArrayLike],
        xdata: ArrayLike,
        ydata: ArrayLike,
        *,
        sigma: ArrayLike | None = None,
        absolute_sigma: bool = False,
        check_finite: bool = True,
        nan_policy: str = "raise",
        stability_check: bool = False,
    ) -> PreprocessedData:
        """Validate and preprocess input data for curve fitting.

        Args:
            f: Model function to fit. Accepted for interface compatibility;
                this method does not call or inspect it.
            xdata: Independent variable data.
            ydata: Dependent variable data (observations).
            sigma: Uncertainty/weights for observations (1D) or a full
                covariance matrix (2D).
            absolute_sigma: Accepted for interface compatibility; not used
                by this method.
            check_finite: If True, raise on NaN/Inf values in ``xdata`` or
                ``ydata``. The check is skipped when ``nan_policy`` is
                'omit' or 'propagate'.
            nan_policy: How to handle NaN/Inf: 'raise', 'omit', or
                'propagate'. Only 'omit' filters data; any value other than
                'omit' or 'propagate' leaves ``check_finite`` in charge.
            stability_check: Accepted for interface compatibility; not used
                by this method.

        Returns:
            PreprocessedData with validated, converted arrays.

        Raises:
            ValueError: If inputs are invalid (wrong shape, non-finite, etc.)
            TypeError: If inputs have wrong types
        """
        # Step 1: Convert to arrays.
        # Defer finite check for 'omit' and 'propagate': 'omit' filters
        # invalids in step 4; 'propagate' intentionally keeps NaN/Inf
        # without raising.
        effective_check_finite = check_finite and nan_policy not in (
            "omit",
            "propagate",
        )
        xdata_arr, ydata_arr = self._convert_to_arrays(
            xdata, ydata, effective_check_finite
        )

        # Step 2: Validate data is not empty
        if ydata_arr.size == 0:
            msg = "`ydata` must not be empty!"
            raise ValueError(msg)

        # Step 3: Validate length consistency
        m, xdims = self._validate_lengths(xdata_arr, ydata_arr)

        # Step 4: Handle NaN values based on policy
        has_nans_removed = False
        has_infs_removed = False
        if nan_policy == "omit":
            (
                xdata_arr,
                ydata_arr,
                sigma,
                mask,
                has_nans_removed,
                has_infs_removed,
            ) = self._handle_nan_omit(xdata_arr, ydata_arr, sigma, xdims)
            # From here on, m is the number of points that survived filtering.
            m = len(ydata_arr)
        else:
            mask = np.ones(m, dtype=bool)

        # Step 5: Validate sigma if provided
        sigma_arr = (
            self._validate_sigma(sigma, ydata_arr.shape)
            if sigma is not None
            else None
        )

        # Step 6: Convert to JAX arrays
        jnp_xdata = jnp.asarray(xdata_arr)
        jnp_ydata = jnp.asarray(ydata_arr)
        jnp_sigma = jnp.asarray(sigma_arr) if sigma_arr is not None else None
        jnp_mask = jnp.asarray(mask)

        # NOTE(review): with nan_policy='omit', original_length receives the
        # post-filter count (same as n_points) - confirm whether the
        # pre-filter length was intended.
        return PreprocessedData(
            xdata=jnp_xdata,
            ydata=jnp_ydata,
            sigma=jnp_sigma,
            mask=jnp_mask,
            n_points=m,
            is_padded=False,
            original_length=m,
            has_nans_removed=has_nans_removed,
            has_infs_removed=has_infs_removed,
        )

    def _convert_to_arrays(
        self,
        xdata: ArrayLike,
        ydata: ArrayLike,
        check_finite: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Convert inputs to float numpy arrays and validate finiteness.

        ``ydata`` is converted first, so a non-finite ``ydata`` is reported
        before an unsupported ``xdata`` type.

        Args:
            xdata: Independent variable data
            ydata: Dependent variable data
            check_finite: Whether to check for finite values

        Returns:
            Tuple of (xdata_array, ydata_array)

        Raises:
            ValueError: If check_finite=True and data contains NaN/Inf, or
                if xdata is not array-like
        """
        # asarray_chkfinite raises ValueError on NaN/Inf; asarray does not.
        to_array = np.asarray_chkfinite if check_finite else np.asarray

        # Convert ydata
        ydata_arr = to_array(ydata, float)

        # Convert xdata
        if not (
            hasattr(xdata, "__array__")
            or isinstance(xdata, (list, tuple, np.ndarray))
        ):
            msg = (
                f"xdata must be array-like (list, tuple, ndarray, or JAX array), "
                f"got {type(xdata).__name__!r}"
            )
            raise ValueError(msg)
        xdata_arr = to_array(xdata, float)

        return xdata_arr, ydata_arr

    def _validate_lengths(
        self, xdata: np.ndarray, ydata: np.ndarray
    ) -> tuple[int, int]:
        """Validate that X and Y data lengths match.

        For 1D ``xdata`` the length is its first axis; for higher
        dimensional ``xdata`` the second axis is compared against
        ``len(ydata)``.

        Args:
            xdata: X data array
            ydata: Y data array

        Returns:
            Tuple of (data_length, x_dimensions)

        Raises:
            ValueError: If xdata is 0-dimensional or X and Y lengths
                don't match
        """
        m = len(ydata)
        xdims = xdata.ndim

        # A scalar xdata has no length; fail with a clear ValueError instead
        # of the IndexError that indexing a 0-d array would produce.
        if xdims == 0:
            msg = "xdata must be at least 1-dimensional, got a 0-dimensional array"
            raise ValueError(msg)

        # shape[1] equals len(xdata[0]) but also works when xdata has no rows.
        xlen = xdata.shape[0] if xdims == 1 else xdata.shape[1]
        if xlen != m:
            msg = "X and Y data lengths dont match"
            raise ValueError(msg)
        return m, xdims

    def _handle_nan_omit(
        self,
        xdata: np.ndarray,
        ydata: np.ndarray,
        sigma: ArrayLike | None,
        xdims: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None, np.ndarray, bool, bool]:
        """Handle NaN/Inf values by omitting the affected points from data.

        A point is dropped when its y value, or any of its x values, is
        non-finite. Sigma is filtered with the same mask (rows and columns
        for a 2D covariance matrix).

        Args:
            xdata: X data array
            ydata: Y data array
            sigma: Sigma array or None
            xdims: Dimensionality of xdata

        Returns:
            Tuple of (xdata, ydata, sigma, mask, has_nans_removed,
            has_infs_removed)

        Raises:
            ValueError: If sigma's shape is incompatible with ydata, or if
                a 2D sigma still contains non-finite values after filtering
        """
        # Find valid indices (not NaN or Inf)
        y_valid = np.isfinite(ydata)
        if xdims == 1:
            x_valid = np.isfinite(xdata)
        else:
            # For 2D xdata, a point is valid only if every row is finite there
            x_valid = np.all(np.isfinite(xdata), axis=0)
        valid_mask = y_valid & x_valid

        # Track what was removed
        has_nans = bool(np.any(np.isnan(ydata)) or np.any(np.isnan(xdata)))
        has_infs = bool(np.any(np.isinf(ydata)) or np.any(np.isinf(xdata)))

        # Filter data
        ydata_clean = ydata[valid_mask]
        xdata_clean = xdata[valid_mask] if xdims == 1 else xdata[:, valid_mask]

        sigma_clean = None
        if sigma is not None:
            sigma_arr = np.asarray(sigma)
            n = len(ydata)
            # Check the shape against the unfiltered length BEFORE masking;
            # otherwise a mismatched sigma surfaces as a raw NumPy IndexError
            # from the boolean indexing below.
            if sigma_arr.ndim == 1:
                if len(sigma_arr) != n:
                    msg = (
                        f"Sigma length ({len(sigma_arr)}) must match "
                        f"ydata length ({n})"
                    )
                    raise ValueError(msg)
                sigma_clean = sigma_arr[valid_mask]
            elif sigma_arr.ndim == 2:
                if sigma_arr.shape != (n, n):
                    msg = f"Sigma shape {sigma_arr.shape} must be ({n}, {n})"
                    raise ValueError(msg)
                # 2D covariance matrix - need to extract submatrix
                sigma_clean = sigma_arr[np.ix_(valid_mask, valid_mask)]
                if not np.all(np.isfinite(sigma_clean)):
                    raise ValueError(
                        "Sigma covariance matrix contains non-finite values in "
                        "valid data rows/columns after NaN filtering"
                    )
            else:
                msg = f"Sigma must be 1D or 2D, got {sigma_arr.ndim}D"
                raise ValueError(msg)

        # Create mask for clean data (all True since we filtered)
        mask = np.ones(len(ydata_clean), dtype=bool)

        return xdata_clean, ydata_clean, sigma_clean, mask, has_nans, has_infs

    def validate_sigma(
        self,
        sigma: ArrayLike | None,
        ydata_shape: tuple[int, ...],
    ) -> np.ndarray | None:
        """Validate and convert sigma to appropriate format.

        Public interface matching DataPreprocessorProtocol.

        Args:
            sigma: Input sigma (1D for diagonal, 2D for full covariance)
            ydata_shape: Shape of ydata for compatibility check

        Returns:
            Validated numpy array or None

        Raises:
            ValueError: If sigma shape is incompatible with ydata
        """
        return self._validate_sigma(sigma, ydata_shape)

    def _validate_sigma(
        self,
        sigma: ArrayLike | None,
        ydata_shape: tuple[int, ...],
    ) -> np.ndarray | None:
        """Validate and convert sigma to appropriate format.

        Args:
            sigma: Input sigma (1D for diagonal, 2D for full covariance)
            ydata_shape: Shape of ydata for compatibility check; only the
                first entry is used

        Returns:
            Validated float numpy array, or None if sigma is None

        Raises:
            ValueError: If sigma shape is incompatible with ydata, contains
                non-finite values, has non-positive entries (1D), or is not
                positive definite (2D)
        """
        if sigma is None:
            return None

        sigma_arr = np.asarray(sigma, dtype=float)
        n = ydata_shape[0]

        if sigma_arr.ndim == 1:
            if len(sigma_arr) != n:
                msg = f"Sigma length ({len(sigma_arr)}) must match ydata length ({n})"
                raise ValueError(msg)
            if not np.all(np.isfinite(sigma_arr)):
                msg = "Sigma contains non-finite values (NaN or Inf)"
                raise ValueError(msg)
            if not np.all(sigma_arr > 0):
                msg = "Sigma values must be strictly positive (used as 1/sigma weights)"
                raise ValueError(msg)
        elif sigma_arr.ndim == 2:
            if sigma_arr.shape != (n, n):
                msg = f"Sigma shape {sigma_arr.shape} must be ({n}, {n})"
                raise ValueError(msg)
            if not np.all(np.isfinite(sigma_arr)):
                msg = "Sigma covariance matrix contains non-finite values (NaN or Inf)"
                raise ValueError(msg)
            # All eigenvalues > 0 is the positive-definiteness test used here.
            eigenvalues = np.linalg.eigvalsh(sigma_arr)
            if not np.all(eigenvalues > 0):
                msg = "Sigma covariance matrix must be positive definite"
                raise ValueError(msg)
        else:
            msg = f"Sigma must be 1D or 2D, got {sigma_arr.ndim}D"
            raise ValueError(msg)

        return sigma_arr