Source code for nlsq.utils.validators

"""Input validation for NLSQ optimization functions.

This module provides comprehensive input validation to catch errors early
and provide helpful error messages to users.
"""

import logging
import warnings
import weakref
from collections.abc import Callable
from contextlib import suppress
from functools import wraps
from inspect import signature
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np

from nlsq.constants import DEFAULT_FTOL, DEFAULT_GTOL, DEFAULT_XTOL

logger = logging.getLogger(__name__)


[docs] class InputValidator: """Comprehensive input validation for curve fitting functions."""
[docs] def __init__(self, fast_mode: bool = True) -> None: """Initialize the input validator. Parameters ---------- fast_mode : bool, default True If True, skip expensive validation checks for better performance. If False, perform all validation checks. """ self.fast_mode = fast_mode # WeakKeyDictionary prevents stale id() reuse when functions are GC'd self._function_cache: weakref.WeakKeyDictionary[Callable, bool] = ( weakref.WeakKeyDictionary() )
def _validate_and_convert_arrays( self, xdata: Any, ydata: Any ) -> tuple[list[str], list[str], Any, Any, int]: """Validate and convert xdata/ydata to arrays. Parameters ---------- xdata : Any Independent variable data ydata : Any Dependent variable data Returns ------- errors : list List of error messages warnings : list List of warning messages xdata_converted : Any Converted xdata (array or tuple) ydata_converted : np.ndarray Converted ydata array n_points : int Number of data points """ errors: list[str] = [] warnings_list: list[str] = [] # Handle tuple xdata (for multi-dimensional fitting) if isinstance(xdata, tuple): try: n_points = len(xdata[0]) if len(xdata) > 0 else 0 # Check all arrays in tuple have same length for i, x_arr in enumerate(xdata): if len(x_arr) != n_points: errors.append("All arrays in xdata tuple must have same length") break warnings_list.append(f"xdata is tuple with {len(xdata)} arrays") except Exception as e: errors.append(f"Invalid xdata tuple: {e}") return errors, warnings_list, xdata, ydata, 0 else: # Convert to numpy arrays and check types try: if not isinstance(xdata, (np.ndarray, jax.Array)): xdata = np.asarray(xdata) warnings_list.append("xdata converted to numpy array") except Exception as e: errors.append(f"Cannot convert xdata to array: {e}") return errors, warnings_list, xdata, ydata, 0 # Check dimensions if xdata.ndim == 0: errors.append("xdata must be at least 1-dimensional") return errors, warnings_list, xdata, ydata, 0 # Handle 2D xdata (multiple independent variables) if xdata.ndim == 2: n_points = xdata.shape[0] n_vars = xdata.shape[1] warnings_list.append(f"xdata has {n_vars} independent variables") else: n_points = len(xdata) if hasattr(xdata, "__len__") else 1 # Convert and validate ydata try: if not isinstance(ydata, (np.ndarray, jax.Array)): ydata = np.asarray(ydata) warnings_list.append("ydata converted to numpy array") except Exception as e: errors.append(f"Cannot convert ydata to array: {e}") return errors, warnings_list, xdata, ydata, n_points if ydata.ndim == 0: errors.append("ydata must be at least 1-dimensional") return errors, warnings_list, xdata, ydata, n_points def _estimate_n_params(self, f: Callable, p0: Any | None) -> int: """Estimate number of parameters from function signature or p0. Parameters ---------- f : Callable Model function p0 : Any | None Initial parameter guess Returns ------- n_params : int Estimated number of parameters """ n_params = 2 # Default estimate try: sig = signature(f) # Count parameters excluding x params = list(sig.parameters.keys()) if params: n_params = len(params) - 1 except Exception: if p0 is not None: with suppress(Exception): n_params = len(p0) return n_params def _validate_data_shapes( self, n_points: int, ydata: np.ndarray, n_params: int ) -> tuple[list[str], list[str]]: """Validate data shapes and minimum requirements. Parameters ---------- n_points : int Number of data points ydata : np.ndarray Dependent variable data n_params : int Number of parameters Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings_list: list[str] = [] # Check shapes match if len(ydata) != n_points: errors.append( f"xdata ({n_points} points) and ydata ({len(ydata)} points) must have same length" ) # Check for minimum data points if n_points < 2: errors.append("Need at least 2 data points for fitting") if n_points <= n_params: errors.append( f"Need more data points ({n_points}) than parameters ({n_params}) for fitting" ) return errors, warnings_list def _validate_finite_values( self, xdata: Any, ydata: np.ndarray ) -> tuple[list[str], list[str]]: """Validate that arrays contain only finite values (no NaN/Inf). Parameters ---------- xdata : Any Independent variable data (array or tuple) ydata : np.ndarray Dependent variable data Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings_list: list[str] = [] # Check xdata for finite values if isinstance(xdata, tuple): # Check each array in the tuple for i, x_arr in enumerate(xdata): if not np.all(np.isfinite(x_arr)): n_bad = np.sum(~np.isfinite(x_arr)) errors.append(f"xdata[{i}] contains {n_bad} NaN or Inf values") elif not np.all(np.isfinite(xdata)): n_bad = np.sum(~np.isfinite(xdata)) errors.append(f"xdata contains {n_bad} NaN or Inf values") # Check ydata for finite values if not np.all(np.isfinite(ydata)): n_bad = np.sum(~np.isfinite(ydata)) errors.append(f"ydata contains {n_bad} NaN or Inf values") return errors, warnings_list def _validate_initial_guess( self, p0: Any | None, n_params: int ) -> tuple[list[str], list[str]]: """Validate initial parameter guess. Parameters ---------- p0 : Any | None Initial parameter guess n_params : int Expected number of parameters Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings_list: list[str] = [] if p0 is None: return errors, warnings_list try: p0 = np.asarray(p0) if len(p0) != n_params: errors.append( f"Initial guess p0 has {len(p0)} parameters, " f"but function expects {n_params}" ) if not np.all(np.isfinite(p0)): errors.append("Initial parameter guess p0 contains NaN or Inf values") except Exception as e: errors.append(f"Invalid initial parameter guess p0: {e}") return errors, warnings_list def _validate_bounds( self, bounds: tuple | None, n_params: int, p0: Any | None ) -> tuple[list[str], list[str]]: """Validate parameter bounds. Parameters ---------- bounds : tuple | None Parameter bounds (lower, upper) n_params : int Number of parameters p0 : Any | None Initial parameter guess (to check if within bounds) Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings_list: list[str] = [] if bounds is None: return errors, warnings_list try: if len(bounds) != 2: errors.append("bounds must be a 2-tuple of (lower, upper)") else: lb, ub = bounds if lb is not None and ub is not None: lb = np.asarray(lb) ub = np.asarray(ub) if len(lb) != n_params or len(ub) != n_params: errors.append( f"bounds must have length {n_params} to match parameters" ) if np.any(lb >= ub): errors.append("Lower bounds must be less than upper bounds") # Check if p0 is within bounds if p0 is not None: p0_array = np.asarray(p0) if np.any(p0_array < lb) or np.any(p0_array > ub): warnings_list.append("Initial guess p0 is outside bounds") except Exception as e: errors.append(f"Invalid bounds: {e}") return errors, warnings_list def _validate_sigma( self, sigma: Any | None, ydata: np.ndarray ) -> tuple[list[str], list[str]]: """Validate uncertainty (sigma) parameters. Parameters ---------- sigma : Any | None Uncertainties in ydata ydata : np.ndarray Dependent variable data Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings_list: list[str] = [] if sigma is None: return errors, warnings_list try: sigma = np.asarray(sigma) if sigma.shape != ydata.shape: errors.append("sigma must have same shape as ydata") if np.any(sigma <= 0): errors.append("sigma values must be positive") if not np.all(np.isfinite(sigma)): errors.append("sigma contains NaN or Inf values") except Exception as e: errors.append(f"Invalid sigma: {e}") return errors, warnings_list def _check_degenerate_x_values(self, xdata: Any) -> tuple[list[str], list[str]]: """Check for degenerate x data (all identical, very small/large range). Parameters ---------- xdata : array_like Independent variable data Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings: list[str] = [] # Only check for non-tuple xdata if isinstance(xdata, tuple): return errors, warnings if not (hasattr(xdata, "ndim") and xdata.ndim == 1 and len(xdata) > 0): return errors, warnings # Check for all identical values (handle JAX arrays) try: xdata_first = ( xdata.flatten()[0] if hasattr(xdata, "flatten") else xdata.flat[0] ) if np.all(xdata == xdata_first): errors.append("All x values are identical - cannot fit") except (AttributeError, NotImplementedError): # Skip if array type doesn't support .flat or .flatten() pass # Check for very small range x_range = np.ptp(xdata) if x_range < 1e-10 and x_range > 0: warnings.append( f"x data range is very small ({x_range:.2e}) - consider rescaling" ) # Check for very large range if x_range > 1e10: warnings.append( f"x data range is very large ({x_range:.2e}) - consider rescaling" ) return errors, warnings def _check_degenerate_y_values(self, ydata: np.ndarray) -> list[str]: """Check for degenerate y data (all identical, very small range). Parameters ---------- ydata : np.ndarray Dependent variable data Returns ------- warnings : list List of warning messages """ warnings: list[str] = [] # Check if all y values are identical try: ydata_first = ( ydata.flatten()[0] if hasattr(ydata, "flatten") else ydata.flat[0] ) if np.all(ydata == ydata_first): warnings.append("All y values are identical - trivial fit") except Exception as e: # Skip this check if it fails - log for debugging logger.debug(f"Y-value uniformity check failed (non-critical): {e}") # Check for very small range y_range = np.ptp(ydata) if y_range < 1e-10 and y_range > 0: warnings.append(f"y data range is very small ({y_range:.2e})") return warnings def _check_function_callable( self, f: Callable, xdata: Any, ydata: np.ndarray, p0: Any, n_params: int ) -> tuple[list[str], list[str]]: """Check if function can be called with test data. Parameters ---------- f : Callable Model function xdata : array_like Independent variable data ydata : np.ndarray Dependent variable data p0 : array_like Initial parameters n_params : int Number of parameters Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings: list[str] = [] try: # Cache function test results to avoid repeated calls. # Use the function object directly (WeakKeyDictionary auto-expires on GC). # Fall back to always-test if f is not weakly referenceable (e.g., builtins). try: already_tested = f in self._function_cache except TypeError: already_tested = False if not already_tested: if isinstance(xdata, tuple): # For tuple xdata, sample from each array test_x = tuple(arr[: min(10, len(arr))] for arr in xdata) expected_len = min(10, len(xdata[0])) else: if hasattr(xdata, "ndim") and xdata.ndim > 1: test_x = xdata[: min(10, len(xdata))] else: test_x = xdata[: min(10, len(xdata))] expected_len = min(10, len(xdata)) if p0 is not None: test_result = f(test_x, *p0) else: # Try with dummy parameters dummy_params = np.ones(n_params) test_result = f(test_x, *dummy_params) # Cache the result (suppress TypeError for non-referenceable callables) with suppress(TypeError): self._function_cache[f] = True # Check output shape/length if hasattr(test_result, "__len__"): if len(test_result) != expected_len: warnings.append( f"Function output length {len(test_result)} doesn't match " f"expected length {expected_len}" ) except Exception as e: errors.append(f"Cannot evaluate function: {e}") return errors, warnings def _check_data_quality(self, xdata: Any, ydata: np.ndarray) -> list[str]: """Check data quality (duplicates, outliers). Parameters ---------- xdata : array_like Independent variable data ydata : np.ndarray Dependent variable data Returns ------- warnings : list List of warning messages """ warnings: list[str] = [] # Check for duplicates in x if not isinstance(xdata, tuple) and hasattr(xdata, "ndim") and xdata.ndim == 1: unique_x = np.unique(xdata) if len(unique_x) < len(xdata): n_dup = len(xdata) - len(unique_x) warnings.append(f"xdata contains {n_dup} duplicate values") # Check for outliers in y if len(ydata) > 10: q1, q3 = np.percentile(ydata, [25, 75]) iqr = q3 - q1 lower = q1 - 3 * iqr upper = q3 + 3 * iqr n_outliers = np.sum((ydata < lower) | (ydata > upper)) if n_outliers > 0: warnings.append( f"ydata may contain {n_outliers} outliers - " "consider using robust loss function" ) return warnings
[docs] def validate_curve_fit_inputs( self, f: Callable, xdata: Any, ydata: Any, p0: Any | None = None, bounds: tuple | None = None, sigma: Any | None = None, absolute_sigma: bool = True, check_finite: bool = True, ) -> tuple[list[str], list[str], np.ndarray, np.ndarray]: """Validate inputs for curve_fit function. This method orchestrates the validation pipeline by calling focused helper methods for each validation step. Parameters ---------- f : callable Model function to fit xdata : array_like Independent variable data ydata : array_like Dependent variable data p0 : array_like, optional Initial parameter guess bounds : tuple, optional Parameter bounds sigma : array_like, optional Uncertainties in ydata absolute_sigma : bool Whether sigma is absolute or relative check_finite : bool Whether to check for finite values Returns ------- errors : list List of error messages (empty if no errors) warnings : list List of warning messages xdata_clean : np.ndarray Cleaned and validated xdata ydata_clean : np.ndarray Cleaned and validated ydata """ errors: list[str] = [] warnings_list: list[str] = [] # Step 1: Validate and convert arrays arr_errors, arr_warnings, xdata, ydata, n_points = ( self._validate_and_convert_arrays(xdata, ydata) ) errors.extend(arr_errors) warnings_list.extend(arr_warnings) if errors: return errors, warnings_list, xdata, ydata # Step 2: Estimate number of parameters n_params = self._estimate_n_params(f, p0) # Step 2.5: Security validation (array size limits, bounds ranges) security_errors, security_warnings = self.validate_security_constraints( n_points, n_params, bounds, p0 ) errors.extend(security_errors) warnings_list.extend(security_warnings) if security_errors: # Return early on critical security errors return errors, warnings_list, xdata, ydata # Step 3: Validate data shapes shape_errors, shape_warnings = self._validate_data_shapes( n_points, ydata, n_params ) errors.extend(shape_errors) warnings_list.extend(shape_warnings) # Step 4: Check for degenerate x values deg_x_errors, deg_x_warnings = self._check_degenerate_x_values(xdata) errors.extend(deg_x_errors) warnings_list.extend(deg_x_warnings) # Step 5: Check for degenerate y values deg_y_warnings = self._check_degenerate_y_values(ydata) warnings_list.extend(deg_y_warnings) # Step 6: Validate finite values if requested if check_finite: finite_errors, finite_warnings = self._validate_finite_values(xdata, ydata) errors.extend(finite_errors) warnings_list.extend(finite_warnings) # Step 7: Validate initial parameters p0_errors, p0_warnings = self._validate_initial_guess(p0, n_params) errors.extend(p0_errors) warnings_list.extend(p0_warnings) # Step 8: Validate bounds if provided bounds_errors, bounds_warnings = self._validate_bounds(bounds, n_params, p0) errors.extend(bounds_errors) warnings_list.extend(bounds_warnings) # Step 9: Validate sigma if provided sigma_errors, sigma_warnings = self._validate_sigma(sigma, ydata) errors.extend(sigma_errors) warnings_list.extend(sigma_warnings) # Step 10: Check function can be called (skip in fast mode) if not self.fast_mode: func_errors, func_warnings = self._check_function_callable( f, xdata, ydata, p0, n_params ) errors.extend(func_errors) warnings_list.extend(func_warnings) # Step 11: Data quality checks (skip in fast mode) if not self.fast_mode: quality_warnings = self._check_data_quality(xdata, ydata) warnings_list.extend(quality_warnings) # Return cleaned data # Keep tuples as tuples, convert arrays to numpy if not isinstance(xdata, tuple): xdata = np.asarray(xdata) ydata = np.asarray(ydata) return errors, warnings_list, xdata, ydata
def _validate_x0_array(self, x0: Any) -> tuple[list[str], np.ndarray]: """Validate and convert x0 to array. Parameters ---------- x0 : array_like Initial parameter guess Returns ------- errors : list List of error messages x0 : np.ndarray Converted x0 array """ errors: list[str] = [] # Convert x0 try: x0 = np.asarray(x0) except Exception as e: errors.append(f"Cannot convert x0 to array: {e}") return errors, x0 # Check x0 dimensions and values if x0.ndim != 1: errors.append("x0 must be 1-dimensional") if len(x0) == 0: errors.append("x0 cannot be empty") if not np.all(np.isfinite(x0)): errors.append("x0 contains NaN or Inf values") return errors, x0 def _validate_method(self, method: str) -> list[str]: """Validate optimization method. Parameters ---------- method : str Optimization method Returns ------- errors : list List of error messages """ errors: list[str] = [] valid_methods = ["trf", "dogbox", "lm"] if method not in valid_methods: errors.append(f"method must be one of {valid_methods}, got {method}") return errors def _validate_tolerances( self, ftol: float, xtol: float, gtol: float ) -> tuple[list[str], list[str]]: """Validate convergence tolerances. Parameters ---------- ftol : float Function tolerance xtol : float Parameter tolerance gtol : float Gradient tolerance Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings: list[str] = [] # Check positive if ftol <= 0: errors.append(f"ftol must be positive, got {ftol}") if xtol <= 0: errors.append(f"xtol must be positive, got {xtol}") if gtol <= 0: errors.append(f"gtol must be positive, got {gtol}") # Check very small values if ftol < 1e-15: warnings.append(f"ftol={ftol} is very small, may not converge") if xtol < 1e-15: warnings.append(f"xtol={xtol} is very small, may not converge") return errors, warnings def _validate_max_nfev( self, max_nfev: int | None, n_params: int ) -> tuple[list[str], list[str]]: """Validate maximum function evaluations. Parameters ---------- max_nfev : int or None Maximum function evaluations n_params : int Number of parameters Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings: list[str] = [] if max_nfev is not None: if max_nfev <= 0: errors.append(f"max_nfev must be positive, got {max_nfev}") elif max_nfev < n_params: warnings.append( f"max_nfev={max_nfev} is less than number of parameters {n_params}" ) return errors, warnings def _validate_bounds_and_x0( self, bounds: tuple | None, x0: np.ndarray, method: str ) -> list[str]: """Validate bounds and check x0 within bounds. Parameters ---------- bounds : tuple or None Parameter bounds as (lb, ub) x0 : np.ndarray Initial parameter guess method : str Optimization method Returns ------- errors : list List of error messages """ errors: list[str] = [] if bounds is None: return errors # Check method compatibility if method == "lm": errors.append("Levenberg-Marquardt method does not support bounds") return errors # Validate bounds structure try: lb, ub = bounds lb = np.asarray(lb) ub = np.asarray(ub) if len(lb) != len(x0) or len(ub) != len(x0): errors.append("bounds must have same length as x0") if np.any(lb >= ub): errors.append("Lower bounds must be less than upper bounds") # Check x0 within bounds if np.any(x0 < lb) or np.any(x0 > ub): errors.append("Initial guess x0 is outside bounds") except Exception as e: errors.append(f"Invalid bounds: {e}") return errors def _validate_function_at_x0( self, fun: Callable, x0: np.ndarray ) -> tuple[list[str], list[str]]: """Validate function can be evaluated at x0. Parameters ---------- fun : callable Residual function x0 : np.ndarray Initial parameter guess Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings: list[str] = [] try: result = fun(x0) result = np.asarray(result) if result.ndim != 1: errors.append("Function must return 1-dimensional residuals") if not np.all(np.isfinite(result)): warnings.append("Function returns NaN or Inf at initial guess") except Exception as e: errors.append(f"Cannot evaluate function at x0: {e}") return errors, warnings # ========================================================================= # Security-focused validation methods # ========================================================================= def _validate_array_size_limits( self, n_points: int, n_params: int, max_data_points: int = 10_000_000_000, # 10 billion max_jacobian_elements: int = 100_000_000_000, # 100 billion ) -> tuple[list[str], list[str]]: """Validate array sizes to prevent memory exhaustion and integer overflow. This is a security-focused check to prevent denial-of-service via excessive memory allocation or integer overflow in Jacobian computation. Parameters ---------- n_points : int Number of data points n_params : int Number of parameters max_data_points : int, default 10_000_000_000 Maximum allowed data points (10 billion) max_jacobian_elements : int, default 100_000_000_000 Maximum allowed Jacobian elements (100 billion) Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings_list: list[str] = [] # Check data point limit if n_points > max_data_points: errors.append( f"Dataset size ({n_points:,} points) exceeds maximum allowed " f"({max_data_points:,} points). This limit prevents memory exhaustion." ) if n_points < 0: errors.append(f"Invalid negative data point count: {n_points}") if n_params < 0: errors.append(f"Invalid negative parameter count: {n_params}") # Check for potential integer overflow in Jacobian size calculation # Jacobian has shape (n_points, n_params) try: jacobian_elements = n_points * n_params if jacobian_elements > max_jacobian_elements: errors.append( f"Jacobian size ({n_points:,} x {n_params:,} = {jacobian_elements:,} elements) " f"exceeds maximum allowed ({max_jacobian_elements:,} elements). " "Consider using streaming optimization or reducing dataset size." ) except OverflowError: errors.append( f"Integer overflow computing Jacobian size: {n_points} x {n_params}. " "Dataset is too large." ) # Memory estimation warning (assuming float64) estimated_memory_gb = (n_points * n_params * 8) / (1024**3) if estimated_memory_gb > 100: warnings_list.append( f"Estimated Jacobian memory usage: {estimated_memory_gb:.1f} GB. " "Consider using streaming optimization." ) elif estimated_memory_gb > 10: warnings_list.append( f"Large Jacobian estimated at {estimated_memory_gb:.1f} GB memory." ) return errors, warnings_list def _validate_bounds_numeric_range( self, bounds: tuple | None, max_bound_magnitude: float = 1e100, ) -> tuple[list[str], list[str]]: """Validate that bounds are within reasonable numeric ranges. This prevents numerical issues from extreme bound values that could cause overflow or underflow during optimization. Parameters ---------- bounds : tuple | None Parameter bounds as (lb, ub) max_bound_magnitude : float, default 1e100 Maximum allowed absolute value for bounds Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings_list: list[str] = [] if bounds is None: return errors, warnings_list try: lb, ub = bounds # Convert to arrays for checking if lb is not None: lb_arr = np.asarray(lb) # Check for extreme values (excluding inf which is valid) finite_lb = lb_arr[np.isfinite(lb_arr)] if len(finite_lb) > 0 and np.any( np.abs(finite_lb) > max_bound_magnitude ): warnings_list.append( f"Lower bounds contain very large values (|lb| > {max_bound_magnitude:.0e}). " "This may cause numerical issues." ) if ub is not None: ub_arr = np.asarray(ub) finite_ub = ub_arr[np.isfinite(ub_arr)] if len(finite_ub) > 0 and np.any( np.abs(finite_ub) > max_bound_magnitude ): warnings_list.append( f"Upper bounds contain very large values (|ub| > {max_bound_magnitude:.0e}). " "This may cause numerical issues." ) # Check for NaN in bounds (invalid) if lb is not None and np.any(np.isnan(np.asarray(lb))): errors.append("Lower bounds contain NaN values") if ub is not None and np.any(np.isnan(np.asarray(ub))): errors.append("Upper bounds contain NaN values") except Exception as e: errors.append(f"Error validating bounds numeric range: {e}") return errors, warnings_list def _validate_parameter_values( self, p0: Any | None, max_param_magnitude: float = 1e50, ) -> tuple[list[str], list[str]]: """Validate that initial parameters are within reasonable numeric ranges. This prevents numerical issues from extreme parameter values that could cause overflow during function evaluation. Parameters ---------- p0 : array_like | None Initial parameter guess max_param_magnitude : float, default 1e50 Maximum allowed absolute value for parameters Returns ------- errors : list List of error messages warnings : list List of warning messages """ errors: list[str] = [] warnings_list: list[str] = [] if p0 is None: return errors, warnings_list try: p0_arr = np.asarray(p0) # Check for extreme values if np.any(np.abs(p0_arr) > max_param_magnitude): max_val = np.max(np.abs(p0_arr)) warnings_list.append( f"Initial parameters contain very large values (max |p0| = {max_val:.2e}). " "This may cause numerical overflow." ) # Check for subnormal values that might cause underflow finite_p0 = p0_arr[np.isfinite(p0_arr) & (p0_arr != 0)] if len(finite_p0) > 0: min_abs = np.min(np.abs(finite_p0)) if min_abs < 1e-300: warnings_list.append( f"Initial parameters contain very small values (min |p0| = {min_abs:.2e}). " "This may cause numerical underflow." ) except Exception as e: errors.append(f"Error validating parameter values: {e}") return errors, warnings_list
[docs] def validate_security_constraints( self, n_points: int, n_params: int, bounds: tuple | None = None, p0: Any | None = None, ) -> tuple[list[str], list[str]]: """Validate security constraints to prevent DoS and numerical issues. This method combines all security-focused validation checks. Parameters ---------- n_points : int Number of data points n_params : int Number of parameters bounds : tuple | None, optional Parameter bounds p0 : array_like | None, optional Initial parameter guess Returns ------- errors : list List of critical error messages warnings : list List of warning messages """ errors: list[str] = [] warnings_list: list[str] = [] # Check array size limits size_errors, size_warnings = self._validate_array_size_limits( n_points, n_params ) errors.extend(size_errors) warnings_list.extend(size_warnings) # Check bounds numeric range bounds_errors, bounds_warnings = self._validate_bounds_numeric_range(bounds) errors.extend(bounds_errors) warnings_list.extend(bounds_warnings) # Check parameter value range param_errors, param_warnings = self._validate_parameter_values(p0) errors.extend(param_errors) warnings_list.extend(param_warnings) return errors, warnings_list
[docs] def validate_least_squares_inputs( self, fun: Callable, x0: Any, bounds: tuple | None = None, method: str = "trf", ftol: float = DEFAULT_FTOL, xtol: float = DEFAULT_XTOL, gtol: float = DEFAULT_GTOL, max_nfev: int | None = None, ) -> tuple[list[str], list[str], np.ndarray]: """Validate inputs for least_squares function. This method orchestrates the validation pipeline by calling focused helper methods for each validation step. Parameters ---------- fun : callable Residual function x0 : array_like Initial parameter guess bounds : tuple, optional Parameter bounds method : str Optimization method ftol : float Function tolerance xtol : float Parameter tolerance gtol : float Gradient tolerance max_nfev : int, optional Maximum function evaluations Returns ------- errors : list List of error messages warnings : list List of warning messages x0_clean : np.ndarray Cleaned initial guess """ errors: list[str] = [] warnings_list: list[str] = [] # Step 1: Validate and convert x0 x0_errors, x0 = self._validate_x0_array(x0) errors.extend(x0_errors) if x0_errors: return errors, warnings_list, x0 # Step 2: Validate method method_errors = self._validate_method(method) errors.extend(method_errors) # Step 3: Validate tolerances tol_errors, tol_warnings = self._validate_tolerances(ftol, xtol, gtol) errors.extend(tol_errors) warnings_list.extend(tol_warnings) # Step 4: Validate max_nfev nfev_errors, nfev_warnings = self._validate_max_nfev(max_nfev, len(x0)) errors.extend(nfev_errors) warnings_list.extend(nfev_warnings) # Step 5: Validate bounds and check x0 within bounds bounds_errors = self._validate_bounds_and_x0(bounds, x0, method) errors.extend(bounds_errors) # Step 6: Validate function can be called at x0 func_errors, func_warnings = self._validate_function_at_x0(fun, x0) errors.extend(func_errors) warnings_list.extend(func_warnings) return errors, warnings_list, x0
[docs] def validate_inputs(validation_type: str = "curve_fit") -> Callable: """Decorator for automatic input validation. Parameters ---------- validation_type : str Type of validation to perform ('curve_fit' or 'least_squares') Returns ------- decorator : function Decorator function that validates inputs """ def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: validator = InputValidator() if validation_type == "curve_fit": # Extract arguments if len(args) < 3: raise ValueError( "curve_fit requires at least 3 arguments (f, xdata, ydata)" ) f, xdata, ydata = args[:3] remaining_args = args[3:] # Get optional arguments p0 = kwargs.get("p0") bounds = kwargs.get("bounds") sigma = kwargs.get("sigma") absolute_sigma = kwargs.get("absolute_sigma", True) check_finite = kwargs.get("check_finite", True) # Validate errors, warnings_list, xdata_clean, ydata_clean = ( validator.validate_curve_fit_inputs( f, xdata, ydata, p0, bounds, sigma, absolute_sigma, check_finite ) ) # Handle errors and warnings if errors: raise ValueError(f"Input validation failed: {'; '.join(errors)}") for warning in warnings_list: warnings.warn(warning, UserWarning, stacklevel=2) # Replace with cleaned data args = (f, xdata_clean, ydata_clean, *remaining_args) elif validation_type == "least_squares": # Extract arguments if len(args) < 2: raise ValueError( "least_squares requires at least 2 arguments (fun, x0)" ) fun, x0 = args[:2] remaining_args = args[2:] # Get optional arguments bounds = kwargs.get("bounds") method = kwargs.get("method", "trf") ftol = kwargs.get("ftol", 1e-8) xtol = kwargs.get("xtol", 1e-8) gtol = kwargs.get("gtol", 1e-8) max_nfev = kwargs.get("max_nfev") # Validate errors, warnings_list, x0_clean = ( validator.validate_least_squares_inputs( fun, x0, bounds, method, ftol, xtol, gtol, max_nfev ) ) # Handle errors and warnings if errors: raise ValueError(f"Input validation failed: {'; '.join(errors)}") for warning in warnings_list: warnings.warn(warning, UserWarning, stacklevel=2) # Replace with cleaned data args = (fun, x0_clean, *remaining_args) else: raise ValueError(f"Unknown validation type: {validation_type}") # Call original function return func(*args, **kwargs) return wrapper return decorator