Source code for nlsq.stability.guard

"""Numerical stability management for NLSQ optimization.

This module provides comprehensive numerical stability monitoring and correction
capabilities for the NLSQ package, ensuring robust optimization even with
ill-conditioned problems or extreme parameter values.

Stability Modes
===============

The stability parameter in curve_fit() controls numerical stability behavior:

- ``stability=False`` (default): No stability checks. Maximum performance.
- ``stability='check'``: Check for issues and warn, but don't modify data.
- ``stability='auto'``: Automatically detect and fix numerical issues.

Key Design Decisions (v0.3.0)
=============================

1. **Initialization-only Jacobian checks**: Stability checks on the Jacobian are
   performed only at optimization initialization, not per-iteration. Per-iteration
   Jacobian modification was found to cause optimization divergence due to
   accumulated numerical perturbations. (commit 8028a03)

2. **SVD skip for large Jacobians**: For Jacobians exceeding max_jacobian_elements_for_svd
   (default 10M elements), SVD computation is skipped to avoid O(min(m,n)^2 * max(m,n))
   overhead. Only O(n) NaN/Inf checking is performed.

3. **rescale_data parameter**: Applications requiring unit preservation can
   set rescale_data=False to preserve physical units. The default (True) rescales
   xdata to [0, 1] when ill-conditioning or a large data range is detected, and
   rescales ydata to [0, 1] when its range is large.

Module Constants
================

MAX_JACOBIAN_ELEMENTS_FOR_SVD : int = 10_000_000
    Default maximum Jacobian elements for SVD computation. For Jacobians
    exceeding this threshold, SVD is skipped to avoid excessive computation.

    Performance characteristics:
    - 10M element Jacobian: ~1.5GB RAM, ~5 seconds (GPU)
    - SVD per iteration can exceed total optimization time
    - For large datasets (1M points x 3 params = 3M), SVD is computed
    - For large datasets (10M points x 3 params = 30M), SVD is skipped

CONDITION_THRESHOLD : float = 1e12
    Threshold for detecting ill-conditioned matrices. Condition numbers
    above this trigger regularization warnings or automatic fixes.

REGULARIZATION_FACTOR : float = 1e-10
    Default Tikhonov regularization factor for ill-conditioned problems.

See Also
--------
NumericalStabilityGuard : Main class for stability operations
apply_automatic_fixes : Function to apply stability fixes to data
check_problem_stability : Pre-flight stability check for optimization
"""

from __future__ import annotations

import warnings
from collections.abc import Callable

import numpy as np

from nlsq.config import JAXConfig

_jax_config = JAXConfig()

import jax
import jax.numpy as jnp
from jax import jit, lax

# Module-level cached constants (avoid repeated np.finfo lookups)
_FLOAT64_INFO = np.finfo(np.float64)
_EPS = _FLOAT64_INFO.eps  # float64 machine epsilon
_MAX_FLOAT = _FLOAT64_INFO.max  # largest finite float64
_MIN_FLOAT = _FLOAT64_INFO.tiny  # smallest positive normal float64

# Names exported by ``from nlsq.stability.guard import *``.
__all__ = [
    "NumericalStabilityGuard",
    "apply_automatic_fixes",
    "check_problem_stability",
    "detect_collinearity",
    "detect_parameter_scale_mismatch",
    "estimate_condition_number",
    "solve_with_cholesky_fallback",
    "stability_guard",
]


def solve_with_cholesky_fallback(
    A: jnp.ndarray, b: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Solve linear system Ax = b using Cholesky with eigenvalue fallback.

    Cholesky decomposition is attempted first. If its solution contains
    NaN/Inf (which is how a failed factorization shows up in JAX), the
    result of a regularized eigenvalue solve is returned instead. The
    selection is made with ``lax.cond`` rather than Python ``try/except``,
    so the function can be traced by ``jax.jit``.

    Parameters
    ----------
    A : jnp.ndarray
        Symmetric matrix (n × n). Should be positive definite for Cholesky
        to succeed. The matrix is symmetrized as ``0.5 * (A + A.T)`` first.
    b : jnp.ndarray
        Right-hand side vector (n,).

    Returns
    -------
    x : jnp.ndarray
        Solution vector (n,).
    used_cholesky : jnp.ndarray
        Scalar boolean array: True if the Cholesky solution was finite and
        used, False if the eigenvalue fallback was needed.

    Notes
    -----
    In the fallback, eigenvalues are floored at ``eps * max|eigenvalue|``
    (``eps`` being the machine epsilon of ``A.dtype``). When every
    eigenvalue is zero the floor is ``eps`` itself, so the division never
    hits an exact zero.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> A = jnp.array([[4.0, 2.0], [2.0, 3.0]])  # Positive definite
    >>> b = jnp.array([1.0, 2.0])
    >>> x, used_cholesky = solve_with_cholesky_fallback(A, b)
    >>> bool(used_cholesky)
    True
    """
    # Import the submodule explicitly instead of relying on ``jax.scipy``
    # already being loaded as an attribute of the top-level ``jax`` package.
    from jax.scipy.linalg import cho_solve

    # Ensure symmetry
    A = 0.5 * (A + A.T)

    # Try Cholesky decomposition (lower-triangular factor)
    L = jnp.linalg.cholesky(A)
    x_chol = cho_solve((L, True), b)

    # A failed factorization propagates NaN into the solution.
    cholesky_valid = jnp.all(jnp.isfinite(x_chol))

    def eigenvalue_solve():
        eigvals, eigvecs = jnp.linalg.eigh(A)
        # Regularize small/negative eigenvalues relative to the spectrum.
        # Use a unit scale for an all-zero spectrum so the floor stays > 0.
        spectrum_scale = jnp.max(jnp.abs(eigvals))
        spectrum_scale = jnp.where(spectrum_scale > 0, spectrum_scale, 1.0)
        floor = jnp.finfo(A.dtype).eps * spectrum_scale
        eigvals_safe = jnp.maximum(eigvals, floor)
        return eigvecs @ (eigvecs.T @ b / eigvals_safe)

    # Use lax.cond for JIT-compatible conditional
    x = lax.cond(cholesky_valid, lambda: x_chol, eigenvalue_solve)
    return x, cholesky_valid
class NumericalStabilityGuard:
    """Comprehensive numerical stability monitoring and correction.

    This class provides methods to detect and correct numerical issues that
    can arise during optimization, including:

    - NaN/Inf detection and correction
    - Ill-conditioning detection and regularization
    - Overflow/underflow protection
    - Safe mathematical operations

    Attributes
    ----------
    eps : float
        Machine epsilon for float64
    max_float : float
        Maximum representable float64 value
    min_float : float
        Minimum positive (normal) float64 value
    condition_threshold : float
        Threshold for detecting ill-conditioned matrices
    regularization_factor : float
        Default regularization factor for ill-conditioned problems
    max_exp_arg, min_exp_arg : int
        Clipping bounds applied to the argument of ``safe_exp``
    max_jacobian_elements_for_svd : int
        Jacobians with more elements than this skip the SVD-based checks
    """

    __slots__ = (
        "_check_gradient_jit",
        "_check_jacobian_fast_jit",
        "_safe_divide_jit",
        "_safe_exp_jit",
        "_safe_log_jit",
        "_safe_norm_jit",
        "_safe_sqrt_jit",
        "condition_threshold",
        "eps",
        "max_exp_arg",
        "max_float",
        "max_jacobian_elements_for_svd",
        "min_exp_arg",
        "min_float",
        "regularization_factor",
    )

    def __init__(self, max_jacobian_elements_for_svd: int = 10_000_000):
        """Initialize stability guard with numerical constants.

        Parameters
        ----------
        max_jacobian_elements_for_svd : int, default 10_000_000
            Maximum number of elements in Jacobian (m × n) for which SVD will
            be computed. For larger Jacobians, SVD is skipped to avoid
            excessive computation time. Default is 10M elements, which
            corresponds to e.g., a (1M × 10) or (100K × 100) matrix.
        """
        # Use cached module-level constants for fast initialization
        self.eps = _EPS
        self.max_float = _MAX_FLOAT
        self.min_float = _MIN_FLOAT
        self.condition_threshold = 1e12
        self.regularization_factor = 1e-10
        self.max_exp_arg = 700  # log(max_float) ≈ 709
        self.min_exp_arg = -700
        self.max_jacobian_elements_for_svd = max_jacobian_elements_for_svd

        # Create JIT-compiled versions of key functions
        self._create_jit_functions()

    def _create_jit_functions(self):
        """Create JIT-compiled versions of numerical operations.

        The compiled closures read the numerical constants from ``self`` and
        are stored on the instance (see ``__slots__``).
        """

        @jit
        def _safe_exp_jit(x):
            """JIT-compiled safe exponential."""
            x_clipped = jnp.clip(x, self.min_exp_arg, self.max_exp_arg)
            return jnp.exp(x_clipped)

        @jit
        def _safe_log_jit(x):
            """JIT-compiled safe logarithm."""
            x_safe = jnp.maximum(x, self.min_float)
            return jnp.log(x_safe)

        @jit
        def _safe_divide_jit(numerator, denominator):
            """JIT-compiled safe division."""
            # Keep the sign of a tiny denominator so that the sign of the
            # quotient is not flipped; exact zeros map to +eps.
            signed_eps = jnp.where(denominator < 0, -self.eps, self.eps)
            safe_denom = jnp.where(
                jnp.abs(denominator) < self.eps, signed_eps, denominator
            )
            return numerator / safe_denom

        @jit
        def _safe_sqrt_jit(x):
            """JIT-compiled safe square root."""
            x_safe = jnp.maximum(x, 0.0)
            return jnp.sqrt(x_safe)

        self._safe_exp_jit = _safe_exp_jit
        self._safe_log_jit = _safe_log_jit
        self._safe_divide_jit = _safe_divide_jit
        self._safe_sqrt_jit = _safe_sqrt_jit

        # JIT-compiled fast Jacobian check (for large Jacobians - NaN/Inf only)
        @jit
        def _check_jacobian_fast_jit(J):
            """Fast NaN/Inf check for large Jacobians."""
            has_invalid = jnp.any(~jnp.isfinite(J))
            J_fixed = jnp.where(jnp.isfinite(J), J, 0.0)
            return J_fixed, has_invalid

        # JIT-compiled gradient checking
        @jit
        def _check_gradient_jit(gradient, max_grad_norm):
            """JIT-compiled gradient checking."""
            is_finite_mask = jnp.isfinite(gradient)
            gradient_clean = jnp.where(is_finite_mask, gradient, 0.0)
            grad_norm = jnp.linalg.norm(gradient_clean)
            needs_clipping = grad_norm > max_grad_norm
            gradient_fixed = jnp.where(
                needs_clipping,
                gradient_clean * (max_grad_norm / grad_norm),
                gradient_clean,
            )
            has_invalid = jnp.any(~is_finite_mask)
            return gradient_fixed, has_invalid, needs_clipping, grad_norm

        # JIT-compiled safe norm (L2 only — ord is not a valid traced arg for
        # jnp.linalg.norm)
        @jit
        def _safe_norm_jit(x, scale_factor):
            """JIT-compiled safe norm with pre-computed scale (L2 norm)."""
            x_scaled = x / scale_factor
            return jnp.linalg.norm(x_scaled) * scale_factor

        self._check_jacobian_fast_jit = _check_jacobian_fast_jit
        self._check_gradient_jit = _check_gradient_jit
        self._safe_norm_jit = _safe_norm_jit

    def check_and_fix_jacobian(self, J: jnp.ndarray) -> tuple[jnp.ndarray, dict]:
        """Check Jacobian for numerical issues and fix them.

        This method performs several checks and corrections:

        1. Detects and replaces NaN/Inf values
        2. Computes condition number (skipped for large Jacobians)
        3. Applies regularization if ill-conditioned
        4. Checks for near-zero singular values

        For large Jacobians (> max_jacobian_elements_for_svd elements), SVD
        computation is skipped to avoid excessive computation time. Only
        NaN/Inf checking is performed.

        Parameters
        ----------
        J : jnp.ndarray
            Jacobian matrix (2-D) to check and fix

        Returns
        -------
        J_fixed : jnp.ndarray
            Fixed Jacobian matrix
        issues : dict
            Dictionary with keys 'has_nan', 'has_inf', 'is_ill_conditioned',
            'condition_number', 'regularized' and 'svd_skipped' (plus
            'reason' when SVD was skipped). 'has_nan' and 'has_inf' are both
            set from a single "any non-finite value" test.
        """
        m, n = J.shape
        n_elements = m * n

        # Skip SVD for large Jacobians to avoid excessive computation time
        # SVD of (m, n) matrix is O(min(m,n)^2 * max(m,n)) which is very
        # expensive for large matrices (e.g., 10^6 x 7 = 7M elements)
        if n_elements > self.max_jacobian_elements_for_svd:
            # Fast path: JIT-compiled NaN/Inf check only
            J_fixed, has_invalid = self._check_jacobian_fast_jit(J)
            if has_invalid:
                warnings.warn(
                    "Jacobian contains NaN or Inf values, replacing with zeros"
                )
            issues: dict[str, object] = {
                "has_nan": bool(has_invalid),
                "has_inf": bool(has_invalid),
                "is_ill_conditioned": False,
                "condition_number": None,
                "regularized": False,
                "svd_skipped": True,
                "reason": f"Jacobian too large ({n_elements:,} > {self.max_jacobian_elements_for_svd:,} elements)",
            }
            return J_fixed, issues

        # Standard path: Check for NaN/Inf
        has_invalid = bool(jnp.any(~jnp.isfinite(J)))
        if has_invalid:
            warnings.warn("Jacobian contains NaN or Inf values, replacing with zeros")
            J = jnp.where(jnp.isfinite(J), J, 0.0)

        # Check if matrix is all zeros - use broadcast add
        if jnp.allclose(J, 0.0):
            warnings.warn("Jacobian is all zeros, adding small perturbation")
            J = J + self.eps
            # Report replaced NaN/Inf here too, as the other return paths do.
            return J, {
                "has_nan": has_invalid,
                "has_inf": has_invalid,
                "is_ill_conditioned": True,
                "condition_number": np.inf,
                "regularized": True,
                "svd_skipped": False,
            }

        # Compute singular values for condition number
        condition_number = np.inf
        min_sv_cached = None  # Cache min singular value (avoid recomputation)
        try:
            svd_vals = jnp.linalg.svdvals(J)

            # Handle empty or invalid SVD; return the same key set as the
            # other paths so callers can index the report uniformly.
            if len(svd_vals) == 0:
                return J, {
                    "has_nan": has_invalid,
                    "has_inf": has_invalid,
                    "is_ill_conditioned": True,
                    "condition_number": np.inf,
                    "regularized": False,
                    "svd_skipped": False,
                }

            # Singular values are sorted in descending order.
            max_sv = svd_vals[0]
            min_sv_cached = svd_vals[-1]

            # Compute condition number safely
            if min_sv_cached < self.eps * max_sv:
                condition_number = np.inf
            else:
                condition_number = float(max_sv / min_sv_cached)
        except Exception as e:
            # Deliberate best-effort: treat a failed SVD as ill-conditioned.
            warnings.warn(f"Could not compute SVD for condition number: {e}")
            condition_number = np.inf

        # Apply fixes based on condition number
        regularized = False
        k = min(m, n)  # Diagonal size
        diag_indices = jnp.arange(k)

        if condition_number > self.condition_threshold:
            warnings.warn(
                f"Ill-conditioned Jacobian (condition number: {condition_number:.2e})"
            )
            # Add to the main diagonal only (no dense jnp.eye(m, n) needed)
            J = J.at[diag_indices, diag_indices].add(self.regularization_factor)
            regularized = True

        # Check for near-zero singular values (use cached value)
        if min_sv_cached is not None and min_sv_cached < self.eps * 10:
            J = J.at[diag_indices, diag_indices].add(self.eps * 10)
            regularized = True

        issues = {
            "has_nan": has_invalid,
            "has_inf": has_invalid,
            "is_ill_conditioned": condition_number > self.condition_threshold,
            "condition_number": condition_number,
            "regularized": regularized,
            "svd_skipped": False,
        }
        return J, issues

    def check_parameters(self, params: jnp.ndarray) -> jnp.ndarray:
        """Check and fix parameter values.

        Non-finite entries are replaced with 1.0, and if the largest
        magnitude exceeds 1e10 the whole vector is scaled down so that the
        largest magnitude becomes 1e10. A warning is emitted for each fix.

        Parameters
        ----------
        params : jnp.ndarray
            Parameter vector to check

        Returns
        -------
        params_fixed : jnp.ndarray
            Fixed parameter vector (returned unchanged when empty)
        """
        # Nothing to check, and jnp.max is undefined on an empty array.
        if params.size == 0:
            return params

        # Check for NaN/Inf
        has_invalid = jnp.any(~jnp.isfinite(params))
        if has_invalid:
            warnings.warn("Parameters contain NaN or Inf values")
            # Replace with reasonable defaults
            params = jnp.where(jnp.isfinite(params), params, 1.0)

        # Check for extreme values
        max_param = jnp.max(jnp.abs(params))
        if max_param > 1e10:
            warnings.warn(f"Parameters have extreme values (max: {max_param:.2e})")
            # Scale down if needed
            params = params / (max_param / 1e10)

        return params

    def safe_exp(self, x: jnp.ndarray) -> jnp.ndarray:
        """Exponential with overflow/underflow protection.

        Parameters
        ----------
        x : jnp.ndarray
            Input array

        Returns
        -------
        result : jnp.ndarray
            exp(x) with the argument clipped to [min_exp_arg, max_exp_arg]
        """
        return self._safe_exp_jit(x)

    def safe_log(self, x: jnp.ndarray) -> jnp.ndarray:
        """Logarithm with domain protection.

        Parameters
        ----------
        x : jnp.ndarray
            Input array (expected to be positive)

        Returns
        -------
        result : jnp.ndarray
            log(x) with the argument floored at ``min_float``
        """
        return self._safe_log_jit(x)

    def safe_divide(
        self, numerator: jnp.ndarray, denominator: jnp.ndarray
    ) -> jnp.ndarray:
        """Division with zero-protection.

        Parameters
        ----------
        numerator : jnp.ndarray
            Numerator array
        denominator : jnp.ndarray
            Denominator array

        Returns
        -------
        result : jnp.ndarray
            numerator/denominator, where denominators with magnitude below
            ``eps`` are replaced by ``eps`` carrying the denominator's sign
            (exact zero is replaced by ``+eps``)
        """
        return self._safe_divide_jit(numerator, denominator)

    def safe_sqrt(self, x: jnp.ndarray) -> jnp.ndarray:
        """Square root with domain protection.

        Parameters
        ----------
        x : jnp.ndarray
            Input array

        Returns
        -------
        result : jnp.ndarray
            sqrt(x) with negative values set to 0
        """
        return self._safe_sqrt_jit(x)

    def safe_power(self, base: jnp.ndarray, exponent: float) -> jnp.ndarray:
        """Safe power operation.

        Parameters
        ----------
        base : jnp.ndarray
            Base array
        exponent : float
            Power exponent

        Returns
        -------
        result : jnp.ndarray
            base^exponent, computed on ``abs(base)`` for non-integer
            exponents and with the base clipped to
            ``±max_float ** (1 / |exponent|)``
        """
        # Handle negative base with fractional exponent
        if not float(exponent).is_integer():
            base = jnp.abs(base)

        # Prevent overflow
        max_base = (
            jnp.power(self.max_float, 1.0 / abs(exponent)) if exponent != 0 else np.inf
        )
        base_clipped = jnp.clip(base, -max_base, max_base)

        return jnp.power(base_clipped, exponent)

    def check_gradient(self, gradient: jnp.ndarray) -> jnp.ndarray:
        """Check and fix gradient values.

        Non-finite entries are zeroed and the result is rescaled to norm
        1e6 when its L2 norm exceeds that value.

        Parameters
        ----------
        gradient : jnp.ndarray
            Gradient vector

        Returns
        -------
        gradient_fixed : jnp.ndarray
            Fixed gradient with clipping applied if needed
        """
        max_grad_norm = 1e6

        gradient_fixed, has_invalid, needs_clipping, grad_norm = (
            self._check_gradient_jit(gradient, max_grad_norm)
        )

        # Warnings are issued here, outside the jitted function.
        if has_invalid:
            warnings.warn("Gradient contains NaN or Inf values")
        if needs_clipping:
            warnings.warn(f"Gradient norm too large ({float(grad_norm):.2e}), clipping")

        return gradient_fixed

    def regularize_hessian(
        self, H: jnp.ndarray, min_eigenvalue: float = 1e-8
    ) -> jnp.ndarray:
        """Regularize Hessian to ensure positive definiteness.

        Parameters
        ----------
        H : jnp.ndarray
            Hessian or Hessian approximation matrix
        min_eigenvalue : float
            Minimum eigenvalue to ensure

        Returns
        -------
        H_reg : jnp.ndarray
            Symmetrized Hessian, shifted along the diagonal when its
            smallest eigenvalue is below ``min_eigenvalue``
        """
        n = H.shape[0]

        # Ensure symmetry
        H = 0.5 * (H + H.T)

        try:
            # Check minimum eigenvalue
            eigenvalues = jnp.linalg.eigvalsh(H)
            min_eig = jnp.min(eigenvalues)

            if min_eig < min_eigenvalue:
                # Add diagonal to ensure positive definiteness
                shift = min_eigenvalue - min_eig + self.eps
                H = H + shift * jnp.eye(n)
        except Exception as e:
            # Deliberate best-effort fallback when the decomposition fails.
            warnings.warn(
                f"regularize_hessian: eigendecomposition failed ({e}), applying fallback diagonal shift"
            )
            H = H + min_eigenvalue * jnp.eye(n)

        return H

    def check_residuals(self, residuals: jnp.ndarray) -> tuple[jnp.ndarray, bool]:
        """Check residuals for numerical issues and outliers.

        Parameters
        ----------
        residuals : jnp.ndarray
            Residual vector

        Returns
        -------
        residuals_fixed : jnp.ndarray
            Residuals with NaN/Inf replaced by 0 (outliers are only
            reported, not modified)
        has_outliers : bool
            Whether any residual lies more than 5 robust standard
            deviations (1.4826 * MAD) from the median
        """
        # The median of an empty vector is undefined; nothing to flag.
        if residuals.size == 0:
            return residuals, False

        # Check for NaN/Inf
        if jnp.any(~jnp.isfinite(residuals)):
            warnings.warn("Residuals contain NaN or Inf values")
            residuals = jnp.where(jnp.isfinite(residuals), residuals, 0.0)

        # Detect outliers using MAD (Median Absolute Deviation)
        median_res = jnp.median(residuals)
        mad = jnp.median(jnp.abs(residuals - median_res))

        # Robust standard deviation estimate
        robust_std = 1.4826 * mad

        # Guard: if MAD is zero (constant residuals), no outliers possible
        if robust_std < jnp.finfo(residuals.dtype).eps:
            return residuals, False

        # Detect outliers (more than 5 robust std from median)
        outlier_mask = jnp.abs(residuals - median_res) > 5 * robust_std
        has_outliers = jnp.any(outlier_mask)

        if has_outliers:
            n_outliers = int(jnp.sum(outlier_mask))
            warnings.warn(f"Detected {n_outliers} outliers in residuals")

        return residuals, bool(has_outliers)

    def safe_norm(self, x: jnp.ndarray, ord: float = 2) -> float:
        """Compute norm with overflow protection.

        When the largest magnitude in ``x`` is above 1e100 or below 1e-100
        (and non-zero), ``x`` is divided by that magnitude before the norm
        is taken and the result is multiplied back. Scaling is skipped for
        ``ord=0`` (a count of non-zeros, which does not scale) and when
        ``x`` contains Inf.

        Parameters
        ----------
        x : jnp.ndarray
            Input vector or matrix
        ord : float
            Order of the norm, passed to ``jnp.linalg.norm``

        Returns
        -------
        norm_value : float
            Norm of x with overflow protection; 0.0 for empty input
        """
        x = jnp.asarray(x)
        if x.size == 0:
            return 0.0

        # Scale if needed to prevent overflow
        max_val = jnp.max(jnp.abs(x))
        extreme = max_val > 1e100 or (max_val < 1e-100 and max_val > 0)

        if not extreme or ord == 0 or not jnp.isfinite(max_val):
            return float(jnp.linalg.norm(x, ord=ord))

        # The JIT helper computes jnp.linalg.norm without ``ord``, which is
        # the L2 norm only for vectors (for matrices it is Frobenius).
        if ord == 2 and x.ndim == 1:
            return float(self._safe_norm_jit(x, max_val))

        # Other norms: scale manually without JIT (rare path)
        x_scaled = x / max_val
        return float(jnp.linalg.norm(x_scaled, ord=ord) * max_val)

    def detect_numerical_issues(self, x: jnp.ndarray) -> dict:
        """Detect numerical issues in array.

        Parameters
        ----------
        x : jnp.ndarray
            Array to check

        Returns
        -------
        issues : dict
            Dictionary with keys 'has_nan', 'has_inf', 'has_negative'
        """
        return {
            "has_nan": bool(jnp.any(jnp.isnan(x))),
            "has_inf": bool(jnp.any(jnp.isinf(x))),
            "has_negative": bool(jnp.any(x < 0)) if x.size > 0 else False,
        }
# Create a global instance for convenience stability_guard = NumericalStabilityGuard() # ============================================================================== # Pre-flight Stability Checks (Day 18 - Phase 3) # ==============================================================================
def estimate_condition_number(xdata: np.ndarray) -> float:
    """
    Estimate the condition number of the data matrix.

    This checks if the independent variable data is well-conditioned
    for least squares fitting. High condition numbers indicate numerical
    instability.

    Parameters
    ----------
    xdata : array_like
        Independent variable data (can be 1D or 2D)

    Returns
    -------
    condition_number : float
        Estimated condition number. Values > 1e10 indicate potential
        problems. ``inf`` is returned when the computation fails or yields
        NaN (e.g. because the data contains NaN).

    Notes
    -----
    For 1D data, constructs a Vandermonde-like matrix with [1, x, x^2].
    For multidimensional data, computes the condition number directly.
    Integer and boolean input is converted to float64 first so that ``x^2``
    cannot overflow the integer range.
    """
    xdata = np.asarray(xdata)
    if xdata.dtype.kind in "iub":
        # Integer squares wrap around silently; do the arithmetic in float.
        xdata = xdata.astype(np.float64)

    if xdata.ndim == 1:
        # Create a simple design matrix [1, x, x^2]
        X = np.column_stack([np.ones_like(xdata), xdata, xdata**2])
    else:
        # Use data directly for multidimensional case
        X = xdata

    try:
        cond = float(np.linalg.cond(X))
    except (np.linalg.LinAlgError, ValueError):
        # If computation fails, return infinity
        return np.inf

    # A NaN result would make every ``cond > threshold`` test False; report
    # it as the worst case instead.
    return np.inf if np.isnan(cond) else cond
def detect_parameter_scale_mismatch(
    p0: np.ndarray, threshold: float = 1e6
) -> tuple[bool, float]:
    """
    Report whether the magnitudes of the parameters are badly mismatched.

    Parameters whose scales are many orders of magnitude apart can cause
    numerical issues and slow convergence.

    Parameters
    ----------
    p0 : array_like
        Initial parameter guess
    threshold : float, optional
        A mismatch is reported when the ratio is strictly greater than
        this value. Default: 1e6

    Returns
    -------
    has_mismatch : bool
        True if the ratio exceeds ``threshold``
    scale_ratio : float
        Largest non-zero magnitude divided by the smallest non-zero
        magnitude; 1.0 when no entry is greater than zero in magnitude

    Examples
    --------
    >>> p0 = np.array([1e-4, 1e3, 1.0])
    >>> has_mismatch, ratio = detect_parameter_scale_mismatch(p0)
    >>> print(f"Mismatch: {has_mismatch}, Ratio: {ratio:.2e}")
    Mismatch: True, Ratio: 1.00e+07
    """
    magnitudes = np.abs(np.asarray(p0))

    # Zeros carry no scale information, so leave them out of the ratio.
    positive = magnitudes[magnitudes > 0]
    if positive.size == 0:
        return False, 1.0

    scale_ratio = np.max(positive) / np.min(positive)
    return bool(scale_ratio > threshold), float(scale_ratio)
def detect_collinearity(
    xdata: np.ndarray, threshold: float = 0.95
) -> tuple[bool, list]:
    """
    Find pairs of strongly correlated columns in multidimensional input data.

    Highly correlated predictors lead to unstable parameter estimates.

    Parameters
    ----------
    xdata : array_like
        Independent variable data; columns are treated as variables
    threshold : float, optional
        A pair is reported when the absolute correlation is strictly
        greater than this value. Default: 0.95

    Returns
    -------
    has_collinearity : bool
        True if at least one pair was reported
    collinear_pairs : list of tuple
        ``(i, j, abs_correlation)`` for each reported pair with ``i < j``.
        Empty when the input is not 2D, has fewer than two columns, or the
        correlation matrix could not be computed.

    Examples
    --------
    >>> x1 = np.linspace(0, 10, 100)
    >>> x2 = 2 * x1 + 0.01 * np.random.randn(100)  # Nearly collinear
    >>> xdata = np.column_stack([x1, x2])
    >>> has_coll, pairs = detect_collinearity(xdata)
    >>> print(f"Collinear: {has_coll}")
    Collinear: True
    """
    data = np.asarray(xdata)

    # Pairwise correlation needs at least two columns of a 2D array.
    if data.ndim != 2 or data.shape[1] < 2:
        return False, []

    try:
        corr = np.corrcoef(data, rowvar=False)
    except (ValueError, np.linalg.LinAlgError):
        return False, []

    # Walk the strict upper triangle so each pair is looked at once.
    rows, cols = np.triu_indices(corr.shape[0], k=1)
    strength = np.abs(corr[rows, cols])
    hits = np.flatnonzero(strength > threshold)

    collinear_pairs = [
        (int(rows[k]), int(cols[k]), float(strength[k])) for k in hits
    ]
    return bool(collinear_pairs), collinear_pairs
def check_problem_stability(
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | None = None,
    f: Callable | None = None,
) -> dict:
    """
    Comprehensive pre-flight stability check for optimization problem.

    Identifies potential numerical issues before optimization begins,
    providing warnings and recommendations for fixes.

    Parameters
    ----------
    xdata : array_like
        Independent variable data
    ydata : array_like
        Dependent variable data
    p0 : array_like, optional
        Initial parameter guess
    f : callable, optional
        Model function (currently unused, reserved for future checks)

    Returns
    -------
    report : dict
        Stability report with keys:

        - 'issues': list of (issue_type, message, severity) tuples
        - 'condition_number': float
        - 'parameter_scale_ratio': float or None
        - 'has_collinearity': bool
        - 'collinear_pairs': list of (i, j, correlation) tuples
        - 'recommendations': list of str
        - 'severity': str ('ok', 'warning', 'critical')

    Notes
    -----
    Empty ``xdata``/``ydata`` is treated as having zero range, so it is
    reported through the existing zero-range issues instead of raising.

    Examples
    --------
    >>> x = np.linspace(0, 1e6, 100)
    >>> y = 2.0 * x + 1.0
    >>> p0 = [2.0, 1.0]
    >>> report = check_problem_stability(x, y, p0)
    >>> print(f"Severity: {report['severity']}")
    >>> for issue_type, message, severity in report['issues']:
    ...     print(f"{severity}: {message}")
    """
    xdata = np.asarray(xdata)
    ydata = np.asarray(ydata)
    if p0 is not None:
        p0 = np.asarray(p0)

    issues = []
    recommendations = []

    # Check 1: Data validity
    if np.any(~np.isfinite(xdata)):
        issues.append(("invalid_xdata", "xdata contains NaN or Inf values", "critical"))
        recommendations.append("Remove or interpolate NaN/Inf values in xdata")

    if np.any(~np.isfinite(ydata)):
        issues.append(("invalid_ydata", "ydata contains NaN or Inf values", "critical"))
        recommendations.append("Remove or interpolate NaN/Inf values in ydata")

    # Check 2: Condition number
    cond = estimate_condition_number(xdata)
    if cond > 1e12:
        issues.append(
            (
                "ill_conditioned_data",
                f"xdata is ill-conditioned (cond={cond:.2e})",
                "critical",
            )
        )
        recommendations.append(
            "Rescale xdata to a smaller range (e.g., [0, 1] or [-1, 1])"
        )
    elif cond > 1e10:
        issues.append(
            (
                "poor_conditioning",
                f"xdata has poor conditioning (cond={cond:.2e})",
                "warning",
            )
        )
        recommendations.append("Consider rescaling xdata")

    # Check 3: Data range issues
    # np.ptp raises on empty arrays; an empty array has no spread.
    x_range = np.ptp(xdata) if xdata.size else 0.0
    y_range = np.ptp(ydata) if ydata.size else 0.0

    if x_range == 0:
        issues.append(("constant_xdata", "xdata has zero range", "critical"))
        recommendations.append("xdata must vary to fit a model")

    if y_range == 0:
        issues.append(("constant_ydata", "ydata has zero range", "warning"))
        recommendations.append("ydata is constant - model fit may be trivial")

    # Extreme ranges
    if x_range > 1e6:
        issues.append(
            ("large_x_range", f"xdata spans large range ({x_range:.2e})", "warning")
        )
        recommendations.append("Consider normalizing xdata")

    if y_range > 1e6:
        issues.append(
            ("large_y_range", f"ydata spans large range ({y_range:.2e})", "warning")
        )
        recommendations.append("Consider normalizing ydata")

    # Check 4: Parameter scale mismatch
    param_scale_ratio = None
    # ``size`` rather than ``len`` so a 0-d p0 does not raise.
    if p0 is not None and p0.size > 1:
        has_mismatch, param_scale_ratio = detect_parameter_scale_mismatch(p0)
        if has_mismatch:
            issues.append(
                (
                    "parameter_scale_mismatch",
                    f"Parameter scales differ by {param_scale_ratio:.2e}",
                    "warning",
                )
            )
            recommendations.append(
                "Use x_scale parameter or rescale p0 to similar magnitudes"
            )

    # Check 5: Collinearity (for multidimensional data)
    has_collinearity = False
    collinear_pairs: list[tuple[int, int, float]] = []
    if xdata.ndim == 2 and xdata.shape[1] > 1:
        has_collinearity, collinear_pairs = detect_collinearity(xdata)
        if has_collinearity:
            # Only the first three pairs are quoted in the message.
            pair_info = ", ".join(
                [f"({i},{j}): {corr:.3f}" for i, j, corr in collinear_pairs[:3]]
            )
            issues.append(
                (
                    "collinear_data",
                    f"Collinear predictors detected: {pair_info}",
                    "warning",
                )
            )
            recommendations.append(
                "Remove or combine highly correlated predictors, or use regularization"
            )

    # Determine overall severity: the worst severity among all issues.
    if any(sev == "critical" for _, _, sev in issues):
        severity = "critical"
    elif any(sev == "warning" for _, _, sev in issues):
        severity = "warning"
    else:
        severity = "ok"

    return {
        "issues": issues,
        "condition_number": cond,
        "parameter_scale_ratio": param_scale_ratio,
        "has_collinearity": has_collinearity,
        "collinear_pairs": collinear_pairs,
        "recommendations": recommendations,
        "severity": severity,
    }
def apply_automatic_fixes(
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | None = None,
    stability_report: dict | None = None,
    rescale_data: bool = True,
) -> tuple[np.ndarray, np.ndarray, np.ndarray | None, dict]:
    """
    Automatically apply fixes for detected stability issues.

    This function applies common fixes such as rescaling data and
    parameters based on the stability report.

    Parameters
    ----------
    xdata : array_like
        Independent variable data
    ydata : array_like
        Dependent variable data
    p0 : array_like, optional
        Initial parameter guess
    stability_report : dict, optional
        Report from check_problem_stability(). If None, will be computed.
        The keys 'condition_number' and 'parameter_scale_ratio' are read.
    rescale_data : bool, optional
        If True (default), rescale xdata/ydata to [0, 1] when ill-conditioned
        or large range is detected. Set to False for applications requiring
        unit preservation where data must maintain physical units (e.g., time
        in seconds, frequency in Hz). NaN/Inf handling and parameter
        normalization are still applied. Default: True.

    Returns
    -------
    xdata_fixed : ndarray
        Fixed xdata (float64)
    ydata_fixed : ndarray
        Fixed ydata (float64)
    p0_fixed : ndarray or None
        Fixed p0
    fix_info : dict
        Information about applied fixes with keys:

        - 'applied_fixes': list of str
        - 'x_scale': float
        - 'y_scale': float
        - 'x_offset': float
        - 'y_offset': float

    Notes
    -----
    When p0 is normalized, every finite non-zero parameter is divided by its
    own power of ten; zero, NaN and infinite parameters are left unchanged.

    Examples
    --------
    >>> x = np.linspace(0, 1e6, 100)
    >>> y = 2.0 * x + 1.0
    >>> x_fixed, y_fixed, p0_fixed, info = apply_automatic_fixes(x, y, [2.0, 1.0])
    >>> print(f"Applied fixes: {info['applied_fixes']}")

    >>> # For applications requiring unit preservation, disable rescaling
    >>> x_fixed, y_fixed, p0_fixed, info = apply_automatic_fixes(
    ...     x, y, [2.0, 1.0], rescale_data=False
    ... )
    """
    xdata = np.asarray(xdata, dtype=np.float64)
    ydata = np.asarray(ydata, dtype=np.float64)
    if p0 is not None:
        p0 = np.asarray(p0, dtype=np.float64)

    applied_fixes: list[str] = []
    fix_info: dict[str, float | list[str]] = {
        "x_scale": 1.0,
        "y_scale": 1.0,
        "x_offset": 0.0,
        "y_offset": 0.0,
    }

    # Get stability report if not provided
    if stability_report is None:
        stability_report = check_problem_stability(xdata, ydata, p0)

    # NOTE(review): rescaling runs before NaN/Inf replacement, so data that
    # contains NaN/Inf has a non-finite range and is never rescaled —
    # confirm this ordering is intended.
    if rescale_data:
        # Fix 1: Rescale xdata if ill-conditioned or large range
        cond = stability_report["condition_number"]
        # np.ptp raises on empty arrays; treat empty data as zero range.
        x_range = np.ptp(xdata) if xdata.size else 0.0

        if cond > 1e10 or x_range > 1e4:
            # Require a finite, positive range to avoid dividing by 0/inf.
            if x_range > 0 and np.isfinite(x_range):
                x_min = np.min(xdata)
                x_max = np.max(xdata)
                # Normalize to [0, 1]
                xdata = (xdata - x_min) / x_range
                fix_info["x_scale"] = x_range
                fix_info["x_offset"] = x_min
                applied_fixes.append(
                    f"Rescaled xdata from [{x_min:.2e}, {x_max:.2e}] to [0, 1]"
                )

        # Fix 2: Rescale ydata if large range
        y_range = np.ptp(ydata) if ydata.size else 0.0
        if y_range > 1e4 and np.isfinite(y_range):
            y_min = np.min(ydata)
            y_max = np.max(ydata)
            ydata = (ydata - y_min) / y_range
            fix_info["y_scale"] = y_range
            fix_info["y_offset"] = y_min
            applied_fixes.append(
                f"Rescaled ydata from [{y_min:.2e}, {y_max:.2e}] to [0, 1]"
            )

    # Fix 3: Replace NaN/Inf in data
    x_finite = np.isfinite(xdata)
    if not np.all(x_finite):
        # Use mean of finite values only (nanmean doesn't ignore inf)
        finite_mean = np.mean(xdata[x_finite]) if np.any(x_finite) else 0.0
        xdata = np.where(x_finite, xdata, finite_mean)
        applied_fixes.append("Replaced NaN/Inf in xdata with mean")

    y_finite = np.isfinite(ydata)
    if not np.all(y_finite):
        finite_mean = np.mean(ydata[y_finite]) if np.any(y_finite) else 0.0
        ydata = np.where(y_finite, ydata, finite_mean)
        applied_fixes.append("Replaced NaN/Inf in ydata with mean")

    # Fix 4: Adjust p0 scales if needed
    p0_fixed = p0
    scale_ratio = stability_report["parameter_scale_ratio"]
    if p0 is not None and scale_ratio and scale_ratio > 1e6:
        # Normalize each parameter independently to order of magnitude 1.
        # Only finite, non-zero entries have a usable order of magnitude
        # (inf would otherwise become inf / inf = NaN).
        p0_fixed = np.copy(p0)
        scalable = np.isfinite(p0_fixed) & (p0_fixed != 0)
        magnitudes = 10 ** np.floor(np.log10(np.abs(p0_fixed[scalable])))
        p0_fixed[scalable] = p0_fixed[scalable] / magnitudes
        applied_fixes.append("Normalized p0 to unit order of magnitude")

    fix_info["applied_fixes"] = applied_fixes
    return xdata, ydata, p0_fixed, fix_info