# Source code for nlsq.streaming.phases.warmup

"""Phase 1: L-BFGS Warmup optimization.

This module contains the WarmupPhase class that encapsulates
the L-BFGS warmup logic for the AdaptiveHybridStreamingOptimizer.

The warmup phase runs L-BFGS on sampled data chunks to provide
warm-started parameters for the streaming Gauss-Newton phase.
"""

from __future__ import annotations

import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import jax
import jax.numpy as jnp
import optax  # type: ignore[import-not-found,import-untyped]

from nlsq.streaming.telemetry import get_defense_telemetry
from nlsq.utils.logging import get_logger

if TYPE_CHECKING:
    from jax import Array

    from nlsq.precision.parameter_normalizer import NormalizedModelWrapper
    from nlsq.streaming.hybrid_config import HybridStreamingConfig

# Module-level logger used by WarmupPhase for debug/info/warning diagnostics.
_logger = get_logger("warmup_phase")


[docs] @dataclass(frozen=True, slots=True) class WarmupResult: """Result from L-BFGS warmup phase. Attributes: params: Optimized parameters after warmup. cost: Final cost after warmup. iterations: Number of warmup iterations performed. converged: Whether warmup converged. cost_history: Cost at each iteration. """ params: Array cost: float iterations: int converged: bool cost_history: list[float]
class WarmupPhase:
    """Phase 1: L-BFGS warmup for initial convergence.

    This class encapsulates the L-BFGS warmup logic that provides
    warm-started parameters for the streaming Gauss-Newton phase.

    L-BFGS provides 5-10x faster convergence to the basin of attraction
    compared to first-order warmup by using approximate second-order
    (Hessian) information.

    Parameters
    ----------
    config : HybridStreamingConfig
        Configuration for streaming optimization.
    normalized_model : NormalizedModelWrapper
        Model wrapper operating in normalized parameter space.

    Attributes
    ----------
    config : HybridStreamingConfig
        Configuration object.
    normalized_model : NormalizedModelWrapper
        Normalized model wrapper.

    Notes
    -----
    The 4-Layer Defense Strategy is implemented to prevent warmup divergence:

    - Layer 1: Warm start detection (skip if already near optimum)
    - Layer 2: Adaptive initial step size based on relative loss
    - Layer 3: Cost-increase guard (abort if loss increases beyond tolerance)
    - Layer 4: Trust region constraint (clip update magnitude)
    """

    def __init__(
        self,
        config: HybridStreamingConfig,
        normalized_model: NormalizedModelWrapper,
    ) -> None:
        """Initialize WarmupPhase.

        Parameters
        ----------
        config : HybridStreamingConfig
            Configuration for streaming optimization.
        normalized_model : NormalizedModelWrapper
            Model wrapper operating in normalized parameter space.
        """
        self.config = config
        self.normalized_model = normalized_model

        # State for 4-layer defense strategy
        self._initial_loss: float | None = None
        self._relative_loss: float | None = None
        self._lr_mode: str | None = None
        self._clip_count: int = 0

        # Residual weighting state (optional)
        self._residual_weights: jnp.ndarray | None = None

    def set_residual_weights(self, weights: jnp.ndarray | None) -> None:
        """Set per-group residual weights for weighted least squares.

        Parameters
        ----------
        weights : array_like or None
            Per-group weights for weighted MSE computation. ``None``
            disables weighting in loss functions created afterwards.
        """
        self._residual_weights = weights

    def run(
        self,
        data_source: tuple[jnp.ndarray, jnp.ndarray],
        initial_params: jnp.ndarray,
        phase_history: list[dict[str, Any]],
        best_tracker: dict[str, Any],
    ) -> dict[str, Any]:
        """Run Phase 1 L-BFGS warmup.

        Parameters
        ----------
        data_source : tuple of array_like
            Data source as (x_data, y_data).
        initial_params : array_like
            Initial parameters in normalized space.
        phase_history : list
            Phase history list to append records to.
        best_tracker : dict
            Dictionary tracking best_params_global and best_cost_global.

        Returns
        -------
        result : dict
            Phase 1 result with keys:

            - 'final_params': Final parameters in normalized space
            - 'best_params': Best parameters found during warmup
            - 'best_loss': Best loss value
            - 'final_loss': Final loss value
            - 'iterations': Number of iterations performed
            - 'switch_reason': Reason for switching to Phase 2
            - 'warmup_result': WarmupResult dataclass instance

        Notes
        -----
        Each L-BFGS step reports the loss evaluated at the parameters it was
        given, not at the parameters it returns. ``best_params`` is therefore
        always the step *input* whose loss was actually measured, so
        ``best_loss`` and ``best_params`` describe the same point.
        ``final_loss`` is the last measured loss, i.e. it lags
        ``final_params`` by one step.
        """
        # Extract data
        x_data, y_data = data_source
        x_data = jnp.asarray(x_data, dtype=jnp.float64)
        y_data = jnp.asarray(y_data, dtype=jnp.float64)

        current_params = initial_params

        # Create loss function
        loss_fn = self._create_loss_fn()

        # Record telemetry for warmup start
        telemetry = get_defense_telemetry()
        telemetry.record_warmup_start()

        # =====================================================
        # LAYER 1: Warm Start Detection
        # =====================================================
        initial_loss = float(loss_fn(current_params, x_data, y_data))
        y_variance = float(jnp.var(y_data))
        # Small constant keeps the ratio finite for constant y_data.
        relative_loss = initial_loss / (y_variance + 1e-10)

        self._initial_loss = initial_loss
        self._relative_loss = relative_loss

        if self.config.verbose >= 2:
            _logger.debug(
                f"Phase 1 initial assessment: loss={initial_loss:.6e}, "
                f"y_var={y_variance:.6e}, relative_loss={relative_loss:.6e}"
            )

        # Check warm start threshold
        if (
            self.config.enable_warm_start_detection
            and relative_loss < self.config.warm_start_threshold
        ):
            telemetry.record_layer1_trigger(
                relative_loss=relative_loss,
                threshold=self.config.warm_start_threshold,
            )

            phase_record = {
                "phase": 1,
                "name": "lbfgs_warmup",
                "iterations": 0,
                "final_loss": initial_loss,
                "best_loss": initial_loss,
                "switch_reason": (
                    f"Warm start detected (relative_loss={relative_loss:.4e} "
                    f"< {self.config.warm_start_threshold})"
                ),
                "timestamp": time.time(),
                "skipped": True,
                "warm_start": True,
                "relative_loss": relative_loss,
            }
            phase_history.append(phase_record)

            if self.config.verbose >= 1:
                _logger.info(
                    f"Phase 1: Skipping L-BFGS warmup - warm start detected "
                    f"(relative_loss={relative_loss:.4e})"
                )

            warmup_result = WarmupResult(
                params=current_params,
                cost=initial_loss,
                iterations=0,
                converged=True,
                cost_history=[initial_loss],
            )

            return {
                "final_params": current_params,
                "best_params": current_params,
                "best_loss": initial_loss,
                "final_loss": initial_loss,
                "iterations": 0,
                "switch_reason": "Warm start detected - skipping L-BFGS warmup",
                "warm_start": True,
                "relative_loss": relative_loss,
                "warmup_result": warmup_result,
            }

        # =====================================================
        # LAYER 2: Adaptive Initial Step Size Selection
        # =====================================================
        initial_step, lr_mode = self._select_initial_step_size(
            relative_loss, telemetry
        )
        self._lr_mode = lr_mode

        # Create L-BFGS optimizer
        optimizer, opt_state = self._create_lbfgs_optimizer(
            current_params, initial_step
        )

        # Best parameter tracking
        best_params = current_params
        best_loss = initial_loss
        cost_history = [initial_loss]
        prev_loss = initial_loss
        # Guard: used after the loop in case max_warmup_iterations == 0.
        loss_value = initial_loss
        self._clip_count = 0

        # Warmup loop
        for iteration in range(self.config.max_warmup_iterations):
            # The step reports the loss at the parameters it receives, so
            # remember them to pair the loss with the right point.
            evaluated_params = current_params

            (
                current_params,
                loss_value,
                grad_norm,
                opt_state,
                _line_search_failed,
            ) = self._lbfgs_step(
                params=current_params,
                opt_state=opt_state,
                optimizer=optimizer,
                loss_fn=loss_fn,
                x_batch=x_data,
                y_batch=y_data,
                iteration=iteration,
                best_tracker=best_tracker,
            )

            cost_history.append(loss_value)

            # Track best parameters (the ones the loss was measured at)
            if loss_value < best_loss:
                best_loss = loss_value
                best_params = evaluated_params

            # =====================================================
            # LAYER 3: Cost-Increase Guard
            # =====================================================
            if self.config.enable_cost_guard and iteration > 0:
                result = self._check_cost_guard(
                    iteration=iteration,
                    loss_value=loss_value,
                    best_params=best_params,
                    best_loss=best_loss,
                    lr_mode=lr_mode,
                    relative_loss=relative_loss,
                    phase_history=phase_history,
                    telemetry=telemetry,
                    cost_history=cost_history,
                )
                if result is not None:
                    return result

            # Check switch criteria after minimum warmup iterations
            if iteration >= self.config.warmup_iterations:
                should_switch, reason = self._check_switch_criteria(
                    iteration=iteration,
                    current_loss=loss_value,
                    prev_loss=prev_loss,
                    grad_norm=grad_norm,
                )

                if should_switch:
                    return self._finish_warmup(
                        phase_history=phase_history,
                        final_params=current_params,
                        best_params=best_params,
                        best_loss=best_loss,
                        final_loss=loss_value,
                        iterations=iteration + 1,
                        reason=reason,
                        converged=True,
                        lr_mode=lr_mode,
                        relative_loss=relative_loss,
                        cost_history=cost_history,
                    )

            prev_loss = loss_value

        # Maximum iterations reached
        return self._finish_warmup(
            phase_history=phase_history,
            final_params=current_params,
            best_params=best_params,
            best_loss=best_loss,
            final_loss=loss_value,
            iterations=self.config.max_warmup_iterations,
            reason="Maximum iterations reached",
            converged=False,
            lr_mode=lr_mode,
            relative_loss=relative_loss,
            cost_history=cost_history,
        )

    def _finish_warmup(
        self,
        *,
        phase_history: list[dict[str, Any]],
        final_params: jnp.ndarray,
        best_params: jnp.ndarray,
        best_loss: float,
        final_loss: float,
        iterations: int,
        reason: str,
        converged: bool,
        lr_mode: str,
        relative_loss: float,
        cost_history: list[float],
    ) -> dict[str, Any]:
        """Append the phase record and build the result for a regular exit.

        Shared by the "switch criteria met" and "maximum iterations reached"
        exits of :meth:`run`, which differ only in ``reason``,
        ``iterations`` and ``converged``.
        """
        phase_history.append(
            {
                "phase": 1,
                "name": "lbfgs_warmup",
                "iterations": iterations,
                "final_loss": final_loss,
                "best_loss": best_loss,
                "switch_reason": reason,
                "timestamp": time.time(),
                "lr_mode": lr_mode,
                "relative_loss": relative_loss,
            }
        )

        warmup_result = WarmupResult(
            params=final_params,
            cost=final_loss,
            iterations=iterations,
            converged=converged,
            cost_history=cost_history,
        )

        return {
            "final_params": final_params,
            "best_params": best_params,
            "best_loss": best_loss,
            "final_loss": final_loss,
            "iterations": iterations,
            "switch_reason": reason,
            "lr_mode": lr_mode,
            "relative_loss": relative_loss,
            "warmup_result": warmup_result,
        }

    def _create_loss_fn(self) -> Callable:
        """Create loss function for warmup phase.

        Returns a single JIT-compiled closure that handles all combinations
        of variance regularization and residual weighting. Python-level
        branching selects the code path at trace time; the variance penalty
        uses ``lax.fori_loop`` with padded group extraction instead of a
        Python for-loop to prevent re-tracing when the number of groups
        changes.

        Returns
        -------
        loss_fn : callable
            Loss function: loss_fn(params, x_batch, y_batch) -> scalar.

        Notes
        -----
        Groups are gathered with a clipped index gather rather than
        ``lax.dynamic_slice``. ``dynamic_slice`` clamps the start index so
        that a fixed-size window stays in bounds, which silently shifts the
        window for any group shorter than ``max_group_size`` that sits near
        the end of ``params`` and makes the mask select the wrong elements.
        """
        normalized_model = self.normalized_model
        group_indices = self.config.group_variance_indices
        enable_var_reg = bool(
            self.config.enable_group_variance_regularization and group_indices
        )
        var_lambda = self.config.group_variance_lambda
        enable_weighting = (
            self.config.enable_residual_weighting
            and self._residual_weights is not None
        )
        residual_weights = self._residual_weights

        # Pre-convert group (start, end) pairs to JAX arrays for
        # lax.fori_loop. Each group is gathered with a fixed max_group_size
        # window, then masked to the actual group size.
        if enable_var_reg:
            group_starts = jnp.array(
                [s for s, _e in group_indices], dtype=jnp.int32
            )
            group_sizes = jnp.array(
                [e - s for s, e in group_indices], dtype=jnp.int32
            )
            n_groups = len(group_indices)
            max_group_size = max(e - s for s, e in group_indices)
        else:
            group_starts = jnp.zeros(0, dtype=jnp.int32)
            group_sizes = jnp.zeros(0, dtype=jnp.int32)
            n_groups = 0
            max_group_size = 0

        @jax.jit
        def loss_fn(
            params: jnp.ndarray, x_batch: jnp.ndarray, y_batch: jnp.ndarray
        ) -> jnp.ndarray:
            predictions = normalized_model(x_batch, *params)
            residuals = y_batch - predictions

            # Base loss: weighted or unweighted MSE
            if enable_weighting:
                # The first column of x_batch is read as the integer group
                # index used to look up each sample's weight.
                group_idx = x_batch[:, 0].astype(jnp.int32)
                assert residual_weights is not None
                weights = residual_weights[group_idx]
                base_loss = jnp.sum(weights * residuals**2) / jnp.sum(weights)
            else:
                base_loss = jnp.mean(residuals**2)

            # Variance regularization via lax.fori_loop (fixed XLA trace
            # regardless of number of groups)
            if enable_var_reg:
                offsets = jnp.arange(max_group_size)

                def var_body(i, penalty):
                    start = group_starts[i]
                    size = group_sizes[i]
                    # Gather a fixed-size window; indices past the end of
                    # params are clipped and then discarded by the mask, so
                    # the valid prefix is always params[start:start + size].
                    group_params = jnp.take(
                        params, start + offsets, mode="clip"
                    )
                    mask = offsets < size
                    # Masked population variance over the valid elements
                    n = jnp.maximum(size, 1)  # avoid division by zero
                    mean_val = jnp.sum(jnp.where(mask, group_params, 0.0)) / n
                    sq_diff = jnp.where(
                        mask, (group_params - mean_val) ** 2, 0.0
                    )
                    group_var = jnp.sum(sq_diff) / n
                    return penalty + group_var

                variance_penalty = jax.lax.fori_loop(
                    0, n_groups, var_body, jnp.array(0.0)
                )
                return base_loss + var_lambda * variance_penalty

            return base_loss

        return loss_fn

    def _select_initial_step_size(
        self, relative_loss: float, telemetry: Any
    ) -> tuple[float, str]:
        """Select initial step size based on relative loss (Layer 2).

        Parameters
        ----------
        relative_loss : float
            Initial loss relative to data variance.
        telemetry : DefenseLayerTelemetry
            Telemetry collector.

        Returns
        -------
        initial_step : float
            Selected initial step size.
        lr_mode : str
            Mode name for logging: "refinement", "careful", "exploration",
            or "fixed" when adaptive selection is disabled.
        """
        if self.config.enable_adaptive_warmup_lr:
            if relative_loss < 0.1:
                initial_step = self.config.lbfgs_refinement_step_size
                lr_mode = "refinement"
            elif relative_loss < 1.0:
                initial_step = 0.5
                lr_mode = "careful"
            else:
                initial_step = self.config.lbfgs_exploration_step_size
                lr_mode = "exploration"

            telemetry.record_layer2_lr_mode(
                mode=lr_mode, relative_loss=relative_loss
            )

            if self.config.verbose >= 2:
                _logger.debug(
                    f"L-BFGS adaptive: mode={lr_mode}, step={initial_step:.2f}, "
                    f"rel_loss={relative_loss:.4e}"
                )
        else:
            initial_step = self.config.lbfgs_initial_step_size
            lr_mode = "fixed"
            telemetry.record_layer2_lr_mode(
                mode=lr_mode, relative_loss=relative_loss
            )

        return initial_step, lr_mode

    def _create_lbfgs_optimizer(
        self, params: jnp.ndarray, initial_step_size: float
    ) -> tuple[optax.GradientTransformationExtraArgs, optax.OptState]:
        """Create L-BFGS optimizer with optax.

        Parameters
        ----------
        params : array_like
            Initial parameters in normalized space.
        initial_step_size : float
            Initial step size for L-BFGS line search.

        Returns
        -------
        optimizer : optax.GradientTransformationExtraArgs
            L-BFGS optimizer instance, preceded by global-norm gradient
            clipping when ``config.gradient_clip_value`` is set.
        opt_state : optax.OptState
            Initial optimizer state.
        """
        line_search_type = self.config.lbfgs_line_search

        if line_search_type == "backtracking":
            linesearch = optax.scale_by_backtracking_linesearch(
                max_backtracking_steps=20,
                slope_rtol=1e-4,
                decrease_factor=0.8,
                increase_factor=1.5,
                max_learning_rate=initial_step_size,
            )
        else:
            # Any other value selects the zoom line search.
            linesearch = optax.scale_by_zoom_linesearch(
                max_linesearch_steps=20,
                initial_guess_strategy="one",
            )

        optimizer = optax.lbfgs(
            learning_rate=initial_step_size,
            memory_size=self.config.lbfgs_history_size,
            scale_init_precond=True,
            linesearch=linesearch,
        )

        if self.config.gradient_clip_value is not None:
            optimizer = optax.chain(
                optax.clip_by_global_norm(self.config.gradient_clip_value),
                optimizer,
            )

        opt_state = optimizer.init(params)
        return optimizer, opt_state

    def _lbfgs_step(
        self,
        params: jnp.ndarray,
        opt_state: optax.OptState,
        optimizer: optax.GradientTransformationExtraArgs,
        loss_fn: Callable,
        x_batch: jnp.ndarray,
        y_batch: jnp.ndarray,
        iteration: int,
        best_tracker: dict[str, Any],
    ) -> tuple[jnp.ndarray, float, float, optax.OptState, bool]:
        """Perform single L-BFGS optimization step.

        Parameters
        ----------
        params : array_like
            Current parameters in normalized space.
        opt_state : optax.OptState
            Current optimizer state.
        optimizer : optax.GradientTransformationExtraArgs
            L-BFGS optimizer instance.
        loss_fn : callable
            Loss function.
        x_batch : array_like
            Independent variable batch.
        y_batch : array_like
            Dependent variable batch.
        iteration : int
            Current iteration number.
        best_tracker : dict
            Dictionary tracking best_params_global and best_cost_global.

        Returns
        -------
        new_params : array_like
            Updated parameters.
        loss : float
            Loss value before update (i.e. evaluated at ``params``).
        grad_norm : float
            L2 norm of gradient.
        new_opt_state : optax.OptState
            Updated optimizer state.
        line_search_failed : bool
            True if line search failed.

        Raises
        ------
        ValueError
            If numerical validation fails and fault tolerance is disabled.
        """
        # Validate input parameters
        if not self._validate_numerics(params, context="at L-BFGS step input"):
            if self.config.enable_fault_tolerance:
                return params, float("inf"), float("inf"), opt_state, True
            raise ValueError("Numerical issues detected in L-BFGS step input")

        # Compute loss and gradient
        loss_value, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch)

        # Validate loss and gradients
        if not self._validate_numerics(
            params,
            loss=float(loss_value),
            gradients=grads,
            context="in L-BFGS step",
        ):
            if self.config.enable_fault_tolerance:
                return params, float("inf"), float("inf"), opt_state, True
            raise ValueError("Numerical issues detected in L-BFGS step")

        grad_norm = jnp.linalg.norm(grads)

        def value_fn(p):
            return loss_fn(p, x_batch, y_batch)

        try:
            updates, new_opt_state = optimizer.update(
                grads,
                opt_state,
                params,
                value=loss_value,
                grad=grads,
                value_fn=value_fn,
            )
            line_search_failed = False
        except Exception as e:
            # Deliberate best-effort: fall back to a plain gradient step and
            # keep the previous optimizer state.
            if self.config.verbose >= 2:
                _logger.warning(f"L-BFGS line search failed: {e}")
            updates = -self.config.lbfgs_initial_step_size * grads
            new_opt_state = opt_state
            line_search_failed = True
            telemetry = get_defense_telemetry()
            telemetry.record_lbfgs_line_search_failure(iteration, str(e))

        # Layer 4: Trust Region Constraint
        if self.config.enable_step_clipping:
            original_update_norm = float(jnp.linalg.norm(updates))
            max_norm = self.config.max_warmup_step_size
            updates = self._clip_update_norm(updates, max_norm)

            if original_update_norm > max_norm:
                self._clip_count += 1
                telemetry = get_defense_telemetry()
                telemetry.record_layer4_clip(
                    original_norm=original_update_norm, max_norm=max_norm
                )

        new_params = optax.apply_updates(params, updates)

        # Validate updated parameters
        if not self._validate_numerics(new_params, context="after L-BFGS update"):
            if self.config.enable_fault_tolerance:
                return params, float(loss_value), float(grad_norm), opt_state, True
            raise ValueError("NaN/Inf in parameters after L-BFGS update")

        # Track best parameters globally. loss_value was measured at `params`
        # (before the update), so that is the point stored alongside it.
        if float(loss_value) < best_tracker.get("best_cost_global", float("inf")):
            best_tracker["best_cost_global"] = float(loss_value)
            best_tracker["best_params_global"] = params

        # Record history buffer fill event
        if iteration == self.config.lbfgs_history_size:
            telemetry = get_defense_telemetry()
            telemetry.record_lbfgs_history_fill(iteration)

        return (
            new_params,
            float(loss_value),
            float(grad_norm),
            new_opt_state,
            line_search_failed,
        )

    @staticmethod
    def _clip_update_norm(updates: jnp.ndarray, max_norm: float) -> jnp.ndarray:
        """Clip parameter update vector to maximum L2 norm (JIT-compatible).

        Parameters
        ----------
        updates : array_like
            Parameter updates from optimizer.
        max_norm : float
            Maximum allowed L2 norm.

        Returns
        -------
        clipped_updates : array_like
            Updates with L2 norm <= max_norm.
        """
        update_norm = jnp.linalg.norm(updates)
        # The small constant keeps the ratio finite for a zero update.
        scale = jnp.minimum(1.0, max_norm / (update_norm + 1e-10))
        return updates * scale

    def _validate_numerics(
        self,
        params: jnp.ndarray,
        loss: float | None = None,
        gradients: jnp.ndarray | None = None,
        context: str = "",
    ) -> bool:
        """Validate numerical stability.

        Parameters
        ----------
        params : array_like
            Parameters to validate.
        loss : float, optional
            Loss value to validate.
        gradients : array_like, optional
            Gradient values to validate.
        context : str, optional
            Context string for error messages (currently unused here).

        Returns
        -------
        is_valid : bool
            True if all values are finite, or if ``config.validate_numerics``
            is absent or falsy (validation disabled).
        """
        if not getattr(self.config, "validate_numerics", False):
            return True

        if not jnp.all(jnp.isfinite(params)):
            return False

        if loss is not None and not jnp.isfinite(loss):
            return False

        return gradients is None or bool(jnp.all(jnp.isfinite(gradients)))

    def _check_switch_criteria(
        self,
        iteration: int,
        current_loss: float,
        prev_loss: float,
        grad_norm: float,
    ) -> tuple[bool, str]:
        """Check if Phase 1 should switch to Phase 2.

        Parameters
        ----------
        iteration : int
            Current iteration number.
        current_loss : float
            Current loss value.
        prev_loss : float
            Previous loss value.
        grad_norm : float
            Current gradient norm.

        Returns
        -------
        should_switch : bool
            Whether to switch to Phase 2.
        reason : str
            Reason for switching (empty string when not switching).
        """
        active_criteria = self.config.active_switching_criteria or []

        if "max_iter" in active_criteria:
            if iteration >= self.config.max_warmup_iterations:
                return True, "Maximum warmup iterations reached"

        if "gradient" in active_criteria:
            if grad_norm < self.config.gradient_norm_threshold:
                return True, "Gradient norm below threshold"

        if "plateau" in active_criteria:
            # Plain Python float arithmetic: the inputs are host scalars, so
            # this avoids allocating device arrays on every iteration and
            # keeps double precision even when JAX x64 is disabled.
            eps = float(jnp.finfo(jnp.float64).eps)
            relative_change = abs(current_loss - prev_loss) / (
                abs(prev_loss) + eps
            )
            if relative_change < self.config.loss_plateau_threshold:
                return True, "Loss plateau detected"

        return False, ""

    def _check_cost_guard(
        self,
        iteration: int,
        loss_value: float,
        best_params: jnp.ndarray,
        best_loss: float,
        lr_mode: str,
        relative_loss: float,
        phase_history: list[dict[str, Any]],
        telemetry: Any,
        cost_history: list[float],
    ) -> dict[str, Any] | None:
        """Check cost-increase guard (Layer 3).

        Parameters
        ----------
        iteration : int
            Current iteration.
        loss_value : float
            Current loss.
        best_params : array_like
            Best parameters found.
        best_loss : float
            Best loss found.
        lr_mode : str
            Learning rate mode.
        relative_loss : float
            Initial relative loss.
        phase_history : list
            Phase history list.
        telemetry : DefenseLayerTelemetry
            Telemetry collector.
        cost_history : list
            Cost history.

        Returns
        -------
        result : dict or None
            Result dict if guard triggered, None otherwise.
        """
        assert self._initial_loss is not None
        initial_loss = self._initial_loss

        # A zero initial loss would make the ratio undefined: any positive
        # loss is then an unbounded relative increase, and an unchanged zero
        # loss is no increase at all.
        if initial_loss != 0.0:
            cost_increase_ratio = loss_value / initial_loss
        elif loss_value > 0.0:
            cost_increase_ratio = float("inf")
        else:
            cost_increase_ratio = 1.0
        cost_threshold = 1.0 + self.config.cost_increase_tolerance

        if cost_increase_ratio > cost_threshold:
            telemetry.record_layer3_trigger(
                cost_ratio=cost_increase_ratio,
                tolerance=self.config.cost_increase_tolerance,
                iteration=iteration,
            )

            if self.config.verbose >= 1:
                _logger.warning(
                    f"Phase 1: Cost increase guard triggered at iteration "
                    f"{iteration + 1}. Loss {loss_value:.6e} > "
                    f"{self._initial_loss:.6e} * {cost_threshold:.2f}. "
                    f"Reverting to best params (loss={best_loss:.6e})."
                )

            phase_record = {
                "phase": 1,
                "name": "lbfgs_warmup",
                "iterations": iteration + 1,
                "final_loss": loss_value,
                "best_loss": best_loss,
                "switch_reason": (
                    f"Cost increase guard triggered (ratio={cost_increase_ratio:.4f})"
                ),
                "timestamp": time.time(),
                "cost_guard_triggered": True,
                "lr_mode": lr_mode,
                "relative_loss": relative_loss,
            }
            phase_history.append(phase_record)

            # The warmup result reports the best point, not the diverged one.
            warmup_result = WarmupResult(
                params=best_params,
                cost=best_loss,
                iterations=iteration + 1,
                converged=False,
                cost_history=cost_history,
            )

            return {
                "final_params": best_params,
                "best_params": best_params,
                "best_loss": best_loss,
                "final_loss": loss_value,
                "iterations": iteration + 1,
                "switch_reason": "Cost increase guard triggered",
                "cost_guard_triggered": True,
                "cost_increase_ratio": cost_increase_ratio,
                "lr_mode": lr_mode,
                "relative_loss": relative_loss,
                "warmup_result": warmup_result,
            }

        return None
# Names exported by ``from nlsq.streaming.phases.warmup import *``.
__all__ = ["WarmupPhase", "WarmupResult"]