Source code for nlsq.stability.recovery

"""Recovery strategies for optimization failures.

This module provides automatic recovery mechanisms for handling
optimization failures with multiple retry strategies.
"""

import warnings
from collections.abc import Callable

import numpy as np

from nlsq.config import JAXConfig

_jax_config = JAXConfig()

import jax.numpy as jnp

from nlsq.stability.guard import NumericalStabilityGuard
from nlsq.utils.diagnostics import OptimizationDiagnostics


[docs] class OptimizationRecovery: """Automatic recovery from optimization failures. This class provides multiple recovery strategies for handling optimization failures including: - Parameter perturbation - Algorithm switching - Regularization adjustment - Problem reformulation - Multi-start optimization Attributes ---------- max_retries : int Maximum number of recovery attempts strategies : list List of recovery strategies to try diagnostics : nlsq.diagnostics.OptimizationDiagnostics Diagnostics collector for monitoring stability_guard : nlsq.stability.NumericalStabilityGuard Numerical stability checker """
[docs] def __init__(self, max_retries: int = 3, enable_diagnostics: bool = True): """Initialize recovery system. Parameters ---------- max_retries : int Maximum recovery attempts enable_diagnostics : bool Enable diagnostic collection """ self.max_retries = max_retries self.enable_diagnostics = enable_diagnostics self.diagnostics = None if enable_diagnostics: self.diagnostics = OptimizationDiagnostics() self.stability_guard = NumericalStabilityGuard() # Recovery strategies in order of preference self.strategies = [ self._perturb_parameters, self._switch_algorithm, self._adjust_regularization, self._reformulate_problem, self._multi_start, ] # Track recovery history self.recovery_history: list[dict] = []
[docs] def recover_from_failure( self, failure_type: str, optimization_state: dict, optimization_func: Callable, **kwargs, ) -> tuple[bool, dict]: """Attempt recovery from optimization failure. Parameters ---------- failure_type : str Type of failure ('convergence', 'numerical', 'memory', etc.) optimization_state : dict Current state of optimization optimization_func : callable Function to retry optimization **kwargs Additional arguments for optimization function Returns ------- success : bool Whether recovery succeeded result : dict Recovered optimization result or error info """ self.recovery_history.append( { "failure_type": failure_type, "iteration": optimization_state.get("iteration", 0), "cost": optimization_state.get("cost", np.inf), } ) for retry in range(self.max_retries): for strategy in self.strategies: try: # Apply recovery strategy modified_state = strategy(failure_type, optimization_state, retry) # Retry optimization with modified state result = optimization_func(**modified_state, **kwargs) # Check if recovery succeeded if self._check_recovery_success(result): if self.enable_diagnostics and self.diagnostics is not None: self.diagnostics.record_event( "recovery_success", {"strategy": strategy.__name__, "retry": retry}, ) return True, result except Exception as e: warnings.warn(f"Recovery strategy {strategy.__name__} failed: {e}") continue # All recovery attempts failed if self.enable_diagnostics and self.diagnostics is not None: self.diagnostics.record_event( "recovery_failed", {"attempts": self.max_retries * len(self.strategies)} ) return False, {"error": f"Recovery failed for {failure_type}"}
def _perturb_parameters(self, failure_type: str, state: dict, retry: int) -> dict: """Perturb parameters to escape local minima. Parameters ---------- failure_type : str Type of failure state : dict Current optimization state retry : int Retry attempt number Returns ------- modified_state : dict State with perturbed parameters """ modified_state = state.copy() params = state.get("params", state.get("x")) if params is not None: # Increase perturbation with each retry noise_scale = 0.01 * (2**retry) # Add Gaussian noise using JAX import jax.random as jr key = jr.PRNGKey(42 + retry * 2) # Different seed for noise perturbation noise = jr.normal(key, shape=params.shape) * noise_scale # Scale noise by parameter magnitude param_scale = np.abs(params) + 1e-10 scaled_noise = noise * param_scale modified_state["params"] = params + scaled_noise # Also try different initial guess if available if "p0" in state: modified_state["p0"] = params + scaled_noise return modified_state def _switch_algorithm(self, failure_type: str, state: dict, retry: int) -> dict: """Switch to different optimization algorithm. Parameters ---------- failure_type : str Type of failure state : dict Current optimization state retry : int Retry attempt number Returns ------- modified_state : dict State with different algorithm """ modified_state = state.copy() current_method = state.get("method", "trf") # Algorithm switching strategy algorithm_chain = { "trf": ["lm", "dogbox"], "lm": ["trf", "dogbox"], "dogbox": ["trf", "lm"], } alternatives = algorithm_chain.get(current_method, ["trf"]) if retry < len(alternatives): modified_state["method"] = alternatives[retry] # Adjust tolerances for new algorithm if alternatives[retry] == "lm": # LM is less robust but more accurate modified_state["ftol"] = 1e-10 modified_state["xtol"] = 1e-10 elif alternatives[retry] == "dogbox": # Dogbox for bounded problems modified_state["ftol"] = 1e-8 modified_state["xtol"] = 1e-8 return modified_state def _adjust_regularization( self, failure_type: str, state: dict, retry: int ) -> dict: """Adjust regularization parameters. Parameters ---------- failure_type : str Type of failure state : dict Current optimization state retry : int Retry attempt number Returns ------- modified_state : dict State with adjusted regularization """ modified_state = state.copy() if failure_type in ["numerical", "ill_conditioned"]: # Increase regularization for numerical issues current_reg = state.get("regularization", 0) new_reg = max(1e-8, current_reg) * (10 ** (retry + 1)) modified_state["regularization"] = new_reg # Also adjust trust region parameters if applicable if state.get("method") == "trf": modified_state["tr_solver"] = "lsmr" # More stable modified_state["x_scale"] = "jac" # Jacobian scaling # Adjust loss function for outliers if failure_type == "outliers" or state.get("has_outliers", False): loss_progression = ["linear", "soft_l1", "huber", "cauchy"] current_loss = state.get("loss", "linear") try: current_idx = loss_progression.index(current_loss) if current_idx < len(loss_progression) - 1: modified_state["loss"] = loss_progression[current_idx + 1] except ValueError: modified_state["loss"] = "huber" return modified_state def _reformulate_problem(self, failure_type: str, state: dict, retry: int) -> dict: """Reformulate the optimization problem. Parameters ---------- failure_type : str Type of failure state : dict Current optimization state retry : int Retry attempt number Returns ------- modified_state : dict State with reformulated problem """ modified_state = state.copy() # Scale variables for better conditioning if "xdata" in state and "ydata" in state: xdata = np.asarray(state["xdata"]) ydata = np.asarray(state["ydata"]) # Normalize data x_mean = np.mean(xdata, axis=0) x_std = np.std(xdata, axis=0) + 1e-10 y_mean = np.mean(ydata) y_std = np.std(ydata) + 1e-10 modified_state["xdata"] = (xdata - x_mean) / x_std modified_state["ydata"] = (ydata - y_mean) / y_std # Store transformation for later modified_state["data_transform"] = { "x_mean": x_mean, "x_std": x_std, "y_mean": y_mean, "y_std": y_std, } # Adjust bounds if present if "bounds" in state and state["bounds"] is not None: bounds = state["bounds"] if isinstance(bounds, tuple) and len(bounds) == 2: lower, upper = bounds # Relax bounds slightly bound_range = upper - lower relaxation = 0.01 * (2**retry) * bound_range modified_state["bounds"] = (lower - relaxation, upper + relaxation) return modified_state def _multi_start(self, failure_type: str, state: dict, retry: int) -> dict: """Multi-start optimization from different initial points. Parameters ---------- failure_type : str Type of failure state : dict Current optimization state retry : int Retry attempt number Returns ------- modified_state : dict State with new starting point """ modified_state = state.copy() # Generate new starting point if "bounds" in state and state["bounds"] is not None: bounds = state["bounds"] if isinstance(bounds, tuple) and len(bounds) == 2: lower, upper = bounds n_params = len(state.get("p0", state.get("params", []))) if n_params > 0: # Use JAX-compatible random sampling for better coverage import jax.random as jr # Create deterministic key based on retry count key = jr.PRNGKey(42 + retry) # Generate uniform random sample sample = jr.uniform(key, shape=(n_params,)) # Scale to bounds new_p0 = lower + sample * (upper - lower) modified_state["p0"] = new_p0 else: # Random initialization around current point current = state.get("p0", state.get("params")) if current is not None: # Use JAX random for diversity import jax.random as jr key = jr.PRNGKey(42 + retry * 3) # Different seed for unbounded init scale = jnp.abs(current) + 1 new_p0 = current + jr.normal(key, shape=current.shape) * scale modified_state["p0"] = new_p0 return modified_state def _check_recovery_success(self, result: dict) -> bool: """Check if recovery was successful. Parameters ---------- result : dict Optimization result (dict or object with attributes) Returns ------- success : bool Whether recovery succeeded """ def _get_value(obj, key, default=None): """Get value from dict or object attribute.""" if isinstance(obj, dict): return obj.get(key, default) return getattr(obj, key, default) # Check for explicit success flag success = _get_value(result, "success") if success is not None: return success # Check for valid parameters params = _get_value(result, "x") if params is None: params = _get_value(result, "params") if params is not None: # Check for NaN/Inf if not np.all(np.isfinite(params)): return False # Check cost if available cost = _get_value(result, "cost") if cost is not None: if not np.isfinite(cost) or cost > 1e10: return False return True return False