Source code for nlsq.diagnostics.gradient_health

"""Gradient health monitoring for nonlinear least squares optimization.

This module provides the GradientMonitor class for tracking gradient behavior
during optimization iterations. It detects:

- Vanishing gradients (GRAD-001): Gradient magnitude becomes very small
  while cost remains significant
- Gradient imbalance (GRAD-002): Large disparity in gradient magnitudes
  across parameters
- Gradient stagnation (GRAD-003): Gradient norm remains nearly constant
  for multiple iterations

Memory usage is bounded at <1KB regardless of iteration count using:
- Sliding window for gradient norm history (configurable, default 100)
- Welford's online algorithm for running mean/variance per parameter

Issue codes follow the pattern GRAD-NNN for consistency with the
Model Health Diagnostics System.

Integration with TRF Optimizer
------------------------------
The GradientMonitor can be integrated with the TRF optimizer via callbacks:

>>> from nlsq import curve_fit
>>> from nlsq.diagnostics import DiagnosticsConfig, GradientMonitor
>>>
>>> config = DiagnosticsConfig()
>>> monitor = GradientMonitor(config)
>>> callback = monitor.create_callback()
>>>
>>> # Fit with gradient monitoring
>>> result = curve_fit(model, x, y, p0=p0, callback=callback)
>>>
>>> # Get gradient health report
>>> report = monitor.get_report()
>>> print(report)
"""

from __future__ import annotations

import time
from collections import deque
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import numpy as np

from nlsq.diagnostics.recommendations import get_recommendation
from nlsq.diagnostics.types import (
    DiagnosticsConfig,
    GradientHealthReport,
    HealthStatus,
    IssueCategory,
    IssueSeverity,
    ModelHealthIssue,
)

if TYPE_CHECKING:
    from collections.abc import Sequence


[docs] class GradientMonitor: """Monitor gradient health during optimization iterations. This class tracks gradient behavior to detect potential optimization issues such as vanishing gradients, gradient imbalance between parameters, and gradient stagnation. It uses memory-efficient algorithms to ensure memory usage stays below 1KB regardless of iteration count. Parameters ---------- config : DiagnosticsConfig Configuration containing thresholds and settings for gradient monitoring. Attributes ---------- config : DiagnosticsConfig Configuration for the monitor. iteration_count : int Total number of iterations recorded. Examples -------- >>> from nlsq.diagnostics import DiagnosticsConfig >>> from nlsq.diagnostics.gradient_health import GradientMonitor >>> import numpy as np >>> >>> config = DiagnosticsConfig() >>> monitor = GradientMonitor(config) >>> >>> # Record gradients during optimization >>> for i in range(100): ... gradient = np.array([0.1, 0.08, 0.12]) / (i + 1) ... monitor.record_gradient(gradient, cost=1.0 / (i + 1)) >>> >>> report = monitor.get_report() >>> print(report.health_status) HealthStatus.HEALTHY Integration with curve_fit callback: >>> from nlsq import curve_fit >>> from nlsq.diagnostics import DiagnosticsConfig, GradientMonitor >>> >>> config = DiagnosticsConfig() >>> monitor = GradientMonitor(config) >>> callback = monitor.create_callback() >>> >>> # Use in curve_fit (gradient is estimated from Jacobian) >>> # result = curve_fit(model, x, y, p0=p0, callback=callback) >>> # report = monitor.get_report() Notes ----- Memory efficiency is achieved through: 1. **Sliding window**: Stores only the last N gradient norms (default 100), using a deque with maxlen for O(1) append/pop. 2. **Welford's algorithm**: Computes running mean and variance in O(1) space per parameter, without storing individual values. The total memory footprint is approximately: - Sliding window: window_size * 8 bytes (floats) - Per-parameter stats: 3 * n_params * 8 bytes (mean, M2, count) - Overhead: ~100 bytes for scalars and bookkeeping For 100 window size and 10 parameters: ~900 bytes < 1KB """ __slots__ = ( "_cost_history", "_gradient_norm_history", "_has_numerical_issues", "_initial_cost", "_initial_gradient_norm", "_last_gradient_norm", "_last_params", "_max_imbalance_ratio", "_n_params", "_param_count", "_param_m2", "_param_means", "config", "iteration_count", )
[docs] def __init__(self, config: DiagnosticsConfig) -> None: """Initialize the gradient monitor. Parameters ---------- config : DiagnosticsConfig Configuration containing monitoring thresholds. """ self.config = config self.iteration_count: int = 0 # Sliding window for gradient norms (bounded memory) self._gradient_norm_history: deque[float] = deque( maxlen=config.gradient_window_size ) self._cost_history: deque[float] = deque(maxlen=config.gradient_window_size) # Welford's algorithm state for per-parameter running statistics self._param_means: np.ndarray = np.array([]) self._param_m2: np.ndarray = np.array([]) # Sum of squared differences self._param_count: int = 0 # Tracking variables self._last_gradient_norm: float = 0.0 self._has_numerical_issues: bool = False self._n_params: int = 0 self._max_imbalance_ratio: float = 1.0 self._initial_gradient_norm: float = 0.0 self._initial_cost: float = 0.0 self._last_params: np.ndarray | None = None
[docs] def record_gradient( self, gradient: np.ndarray | Sequence[float], cost: float, ) -> None: """Record a gradient observation from an optimization iteration. Parameters ---------- gradient : np.ndarray or Sequence[float] The gradient vector (partial derivatives w.r.t. each parameter). cost : float The current cost/loss value at this iteration. Raises ------ ValueError If gradient is empty. Notes ----- This method uses Welford's online algorithm to update running statistics for per-parameter gradient magnitudes. This allows computing mean and variance without storing individual values, achieving O(1) memory per parameter. The algorithm maintains: - mean: Running mean of absolute gradient values - M2: Sum of squared differences from the mean Variance is computed as M2 / (n - 1) when needed. """ gradient = np.asarray(gradient) if gradient.size == 0: raise ValueError("Gradient array cannot be empty") # Check for numerical issues if np.any(np.isnan(gradient)) or np.any(np.isinf(gradient)): self._has_numerical_issues = True # Replace non-finite values with safe large values (avoid overflow in Welford's) gradient = np.nan_to_num(gradient, nan=0.0, posinf=1e100, neginf=-1e100) self.iteration_count += 1 n_params = len(gradient) # Initialize per-parameter stats on first call if self._n_params == 0: self._n_params = n_params self._param_means = np.zeros(n_params) self._param_m2 = np.zeros(n_params) self._initial_cost = cost # Compute gradient norm gradient_norm = float(np.linalg.norm(gradient)) self._gradient_norm_history.append(gradient_norm) self._cost_history.append(cost) self._last_gradient_norm = gradient_norm # Store initial gradient norm if self.iteration_count == 1: self._initial_gradient_norm = gradient_norm # Update per-parameter running statistics (Welford's algorithm) abs_gradient = np.abs(gradient) self._param_count += 1 delta = abs_gradient - self._param_means self._param_means += delta / self._param_count delta2 = abs_gradient - self._param_means self._param_m2 += delta * delta2 # Track max imbalance ratio (with overflow guard) min_grad = ( np.min(abs_gradient[abs_gradient > 0]) if np.any(abs_gradient > 0) else 1.0 ) max_grad = np.max(abs_gradient) if min_grad > 0 and np.isfinite(max_grad) and np.isfinite(min_grad): imbalance = max_grad / min_grad if np.isfinite(imbalance): self._max_imbalance_ratio = max(self._max_imbalance_ratio, imbalance)
[docs] def create_callback( self, user_callback: Callable[..., None] | None = None, ) -> Callable[..., None]: """Create a callback function for integration with curve_fit/TRF. This method creates a callback compatible with NLSQ's optimization callbacks. The callback extracts gradient information from the optimization state and records it in the monitor. Parameters ---------- user_callback : callable, optional An optional user callback to chain with the gradient monitor. Will be called after gradient recording with the same arguments. Returns ------- callable A callback function compatible with curve_fit's callback parameter. Examples -------- >>> from nlsq import curve_fit >>> from nlsq.diagnostics import DiagnosticsConfig, GradientMonitor >>> >>> monitor = GradientMonitor(DiagnosticsConfig()) >>> callback = monitor.create_callback() >>> >>> # result = curve_fit(model, x, y, p0=p0, callback=callback) >>> # report = monitor.get_report() Notes ----- The callback receives iteration information including: - iteration: Current iteration number - cost: Current cost value - params: Current parameter values - info: Dictionary with gradient_norm, nfev, step_norm, etc. When the gradient is not directly available, we estimate it from changes in parameters and cost, or use gradient_norm from info. """ def gradient_monitor_callback( iteration: int, cost: float, params: np.ndarray, info: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Callback function for gradient monitoring. Parameters ---------- iteration : int Current iteration number. cost : float Current cost value. params : np.ndarray Current parameter values. info : dict, optional Additional information including gradient_norm. **kwargs : Any Additional keyword arguments (ignored). """ # Extract gradient information if info is not None and "gradient" in info: # Direct gradient available gradient = np.asarray(info["gradient"]) elif info is not None and "gradient_norm" in info: # Only gradient norm available — do not estimate per-component # direction from parameter steps (step direction != gradient # direction for trust-region methods). Use a uniform proxy # that preserves the norm without fabricating component ratios. gradient_norm = info["gradient_norm"] n_params = len(params) gradient = ( np.ones(n_params) * (gradient_norm / np.sqrt(n_params)) if gradient_norm > 0 else np.ones(n_params) ) # No gradient info - estimate from parameters elif self._last_params is not None: gradient = -(params - self._last_params) if np.linalg.norm(gradient) == 0: gradient = np.ones_like(params) * 1e-10 else: gradient = np.ones_like(params) # Store current params for next iteration self._last_params = params.copy() # Record in monitor self.record_gradient(gradient, cost) # Call user callback if provided if user_callback is not None: user_callback( iteration=iteration, cost=cost, params=params, info=info, **kwargs ) return gradient_monitor_callback
[docs] def get_report(self) -> GradientHealthReport: """Generate a gradient health report from recorded observations. Returns ------- GradientHealthReport Report containing gradient health metrics and any detected issues. Notes ----- The report includes: - Overall health score (0-1, higher is healthier) - Mean and final gradient norms - Per-parameter mean and variance of gradient magnitudes - Detection of vanishing gradients, imbalance, and stagnation - List of ModelHealthIssue objects for any detected problems """ start_time = time.perf_counter() if self.iteration_count == 0: return GradientHealthReport( available=True, n_iterations=0, health_score=1.0, issues=[], health_status=HealthStatus.HEALTHY, computation_time_ms=0.0, ) # Compute statistics from sliding window norm_history = list(self._gradient_norm_history) cost_history = list(self._cost_history) mean_gradient_norm = float(np.mean(norm_history)) if norm_history else 0.0 final_gradient_norm = self._last_gradient_norm # Compute per-parameter variance from Welford's M2 if self._param_count > 1: variance = self._param_m2 / (self._param_count - 1) else: variance = np.zeros_like(self._param_means) # Detect issues issues: list[ModelHealthIssue] = [] # Check for vanishing gradients (GRAD-001) vanishing_detected = self._detect_vanishing_gradients( norm_history, cost_history ) if vanishing_detected: issues.append(self._create_grad_001_issue(norm_history, cost_history)) # Check for gradient imbalance (GRAD-002) imbalance_detected = self._detect_gradient_imbalance() if imbalance_detected: issues.append(self._create_grad_002_issue()) # Check for gradient stagnation (GRAD-003) stagnation_detected = self._detect_gradient_stagnation(norm_history) if stagnation_detected: issues.append(self._create_grad_003_issue(norm_history)) # Compute health score health_score = self._compute_health_score( vanishing_detected, imbalance_detected, stagnation_detected ) # Determine overall health status health_status = self._determine_health_status(issues) computation_time = (time.perf_counter() - start_time) * 1000 return GradientHealthReport( available=True, n_iterations=self.iteration_count, health_score=health_score, mean_gradient_norm=mean_gradient_norm, final_gradient_norm=final_gradient_norm, mean_gradient_magnitudes=self._param_means.copy(), variance_gradient_magnitudes=variance, max_imbalance_ratio=self._max_imbalance_ratio, has_numerical_issues=self._has_numerical_issues, vanishing_detected=vanishing_detected, imbalance_detected=imbalance_detected, stagnation_detected=stagnation_detected, issues=issues, health_status=health_status, computation_time_ms=computation_time, )
[docs] def reset(self) -> None: """Reset the monitor to its initial state. Clears all recorded gradients and statistics. Useful when starting a new optimization run. """ self.iteration_count = 0 self._gradient_norm_history.clear() self._cost_history.clear() self._param_means = np.array([]) self._param_m2 = np.array([]) self._param_count = 0 self._last_gradient_norm = 0.0 self._has_numerical_issues = False self._n_params = 0 self._max_imbalance_ratio = 1.0 self._initial_gradient_norm = 0.0 self._initial_cost = 0.0 self._last_params = None
def _detect_vanishing_gradients( self, norm_history: list[float], cost_history: list[float], ) -> bool: """Detect if gradients are vanishing while cost remains significant. Vanishing gradients occur when the gradient magnitude becomes very small relative to the vanishing_threshold, but the cost function is still significant (not converged). Parameters ---------- norm_history : list[float] Recent gradient norm history. cost_history : list[float] Recent cost history. Returns ------- bool True if vanishing gradients detected. """ if len(norm_history) < 5: return False # Check recent gradient norms recent_norms = norm_history[-10:] if len(norm_history) >= 10 else norm_history recent_costs = cost_history[-10:] if len(cost_history) >= 10 else cost_history avg_recent_norm = np.mean(recent_norms) avg_recent_cost = np.mean(recent_costs) # Gradient is vanishing if: # 1. Average norm is below threshold # 2. Cost is still significant (not effectively zero) # 3. Initial gradient was not already vanishing (we had signal to start) threshold = self.config.vanishing_threshold if self._initial_gradient_norm > 0: # Use relative threshold based on initial gradient relative_threshold = self._initial_gradient_norm * threshold else: relative_threshold = threshold # Cost is significant if it's above a small epsilon cost_significant = avg_recent_cost > 1e-10 # Also check if gradient has dropped significantly from initial gradient_dropped = ( self._initial_gradient_norm > 0 and avg_recent_norm < relative_threshold ) return gradient_dropped and cost_significant def _detect_gradient_imbalance(self) -> bool: """Detect if gradient magnitudes are severely imbalanced across parameters. Imbalance occurs when the ratio between the largest and smallest gradient components exceeds the imbalance_threshold. Returns ------- bool True if gradient imbalance detected. """ if self._n_params < 2: return False return self._max_imbalance_ratio > self.config.imbalance_threshold def _detect_gradient_stagnation(self, norm_history: list[float]) -> bool: """Detect if gradient norm has stagnated (no significant change). Stagnation occurs when the gradient norm remains nearly constant for multiple consecutive iterations, which may indicate the optimizer is stuck or has reached a saddle point. Parameters ---------- norm_history : list[float] Recent gradient norm history. Returns ------- bool True if gradient stagnation detected. """ window = self.config.stagnation_window tolerance = self.config.stagnation_tolerance if len(norm_history) < window: return False recent = norm_history[-window:] mean_norm = np.mean(recent) if mean_norm < 1e-15: # Effectively zero gradient - this is convergence, not stagnation return False # Check if relative variation is below tolerance std_norm = np.std(recent) relative_variation = std_norm / mean_norm return relative_variation < tolerance def _compute_health_score( self, vanishing: bool, imbalance: bool, stagnation: bool, ) -> float: """Compute overall gradient health score. The score ranges from 0 (poor) to 1 (healthy). Parameters ---------- vanishing : bool Whether vanishing gradients were detected. imbalance : bool Whether gradient imbalance was detected. stagnation : bool Whether gradient stagnation was detected. Returns ------- float Health score in [0, 1]. """ score = 1.0 # Deduct for each detected issue if vanishing: score -= 0.3 if imbalance: score -= 0.3 if stagnation: score -= 0.2 if self._has_numerical_issues: score -= 0.2 return max(0.0, score) def _determine_health_status(self, issues: list[ModelHealthIssue]) -> HealthStatus: """Determine overall health status from detected issues. Parameters ---------- issues : list[ModelHealthIssue] List of detected issues. Returns ------- HealthStatus Overall health status. """ if not issues: return HealthStatus.HEALTHY has_critical = any(issue.severity == IssueSeverity.CRITICAL for issue in issues) if has_critical: return HealthStatus.CRITICAL has_warning = any(issue.severity == IssueSeverity.WARNING for issue in issues) if has_warning: return HealthStatus.WARNING return HealthStatus.HEALTHY def _create_grad_001_issue( self, norm_history: list[float], cost_history: list[float], ) -> ModelHealthIssue: """Create GRAD-001 issue for vanishing gradients. Parameters ---------- norm_history : list[float] Recent gradient norm history. cost_history : list[float] Recent cost history. Returns ------- ModelHealthIssue Issue describing vanishing gradients. """ recent_norm = ( np.mean(norm_history[-10:]) if len(norm_history) >= 10 else np.mean(norm_history) ) recent_cost = ( np.mean(cost_history[-10:]) if len(cost_history) >= 10 else np.mean(cost_history) ) return ModelHealthIssue( category=IssueCategory.GRADIENT, severity=IssueSeverity.WARNING, code="GRAD-001", message=( f"Vanishing gradients detected: gradient norm ({recent_norm:.2e}) " f"is very small while cost ({recent_cost:.2e}) remains significant." ), affected_parameters=None, details={ "recent_gradient_norm": recent_norm, "recent_cost": recent_cost, "threshold": self.config.vanishing_threshold, "initial_gradient_norm": self._initial_gradient_norm, }, recommendation=get_recommendation("GRAD-001"), ) def _create_grad_002_issue(self) -> ModelHealthIssue: """Create GRAD-002 issue for gradient imbalance. Returns ------- ModelHealthIssue Issue describing gradient imbalance. """ # Identify which parameters have extreme gradients if len(self._param_means) > 0: min_idx = int(np.argmin(self._param_means)) max_idx = int(np.argmax(self._param_means)) affected = (min_idx, max_idx) if min_idx != max_idx else (min_idx,) else: affected = None return ModelHealthIssue( category=IssueCategory.GRADIENT, severity=IssueSeverity.WARNING, code="GRAD-002", message=( f"Gradient imbalance detected: ratio between largest and smallest " f"gradient components is {self._max_imbalance_ratio:.2e}, " f"exceeding threshold {self.config.imbalance_threshold:.2e}." ), affected_parameters=affected, details={ "imbalance_ratio": self._max_imbalance_ratio, "threshold": self.config.imbalance_threshold, "mean_gradient_magnitudes": self._param_means.tolist(), }, recommendation=get_recommendation("GRAD-002"), ) def _create_grad_003_issue(self, norm_history: list[float]) -> ModelHealthIssue: """Create GRAD-003 issue for gradient stagnation. Parameters ---------- norm_history : list[float] Recent gradient norm history. Returns ------- ModelHealthIssue Issue describing gradient stagnation. """ window = self.config.stagnation_window recent = norm_history[-window:] if len(norm_history) >= window else norm_history mean_norm = np.mean(recent) std_norm = np.std(recent) return ModelHealthIssue( category=IssueCategory.GRADIENT, severity=IssueSeverity.WARNING, code="GRAD-003", message=( f"Gradient stagnation detected: gradient norm has remained " f"nearly constant ({mean_norm:.2e} +/- {std_norm:.2e}) " f"over the last {len(recent)} iterations." ), affected_parameters=None, details={ "mean_gradient_norm": mean_norm, "std_gradient_norm": std_norm, "relative_variation": std_norm / mean_norm if mean_norm > 0 else 0, "stagnation_window": window, "tolerance": self.config.stagnation_tolerance, }, recommendation=get_recommendation("GRAD-003"), )