"""Adaptive Hybrid Streaming Optimizer with Parameter Normalization.
This module implements a four-phase hybrid optimizer that solves three fundamental
issues in streaming optimization:
1. Weak gradient signals from parameter scale imbalance (via normalization)
2. Slow convergence near optimum (via Gauss-Newton)
3. Crude covariance estimation (via exact J^T J accumulation)
The optimizer operates in four phases:
- **Phase 0**: Parameter normalization setup
- **Phase 1**: L-BFGS warmup with adaptive switching
- **Phase 2**: Streaming Gauss-Newton with exact J^T J accumulation
- **Phase 3**: Denormalization and covariance transform
This implementation focuses on Phase 0 setup logic and phase tracking infrastructure.
"""
# mypy: disable-error-code="assignment,misc,return-value,valid-type,union-attr,operator,import-untyped,attr-defined,arg-type"
# Note: mypy errors are mostly assignment/return-value/union-attr issues from
# complex streaming state management. These require deeper refactoring.
from __future__ import annotations
import time
from collections import deque
from collections.abc import Callable
from pathlib import Path
from typing import Any
import jax
import jax.numpy as jnp
import numpy as np
import optax # type: ignore[import-not-found]
from nlsq.global_optimization.config import GlobalOptimizationConfig
from nlsq.global_optimization.sampling import (
get_sampler,
)
from nlsq.global_optimization.tournament import TournamentSelector
from nlsq.precision.parameter_normalizer import (
NormalizedModelWrapper,
ParameterNormalizer,
)
from nlsq.stability.guard import NumericalStabilityGuard
from nlsq.streaming.hybrid_config import HybridStreamingConfig
from nlsq.streaming.telemetry import (
DefenseLayerTelemetry,
get_defense_telemetry,
reset_defense_telemetry,
)
from nlsq.utils.logging import get_logger
# Lazy import cache for CheckpointManager to avoid circular imports
# Import happens at first checkpoint operation
_lazy_imports: dict = {}
# Module-level logger for warmup defense diagnostics
_logger = get_logger("adaptive_hybrid_streaming")
# Re-export telemetry classes for backwards compatibility
__all__ = [
"AdaptiveHybridStreamingOptimizer",
"DefenseLayerTelemetry",
"get_defense_telemetry",
"reset_defense_telemetry",
]
[docs]
class AdaptiveHybridStreamingOptimizer:
"""Adaptive hybrid streaming optimizer with four-phase optimization.
This optimizer combines parameter normalization, L-BFGS warmup, streaming
Gauss-Newton, and exact covariance computation to provide:
- Fast convergence for parameters with different scales
- Accurate uncertainty estimates on large datasets
- Memory-efficient streaming for unlimited dataset sizes
- Production-ready fault tolerance
The optimization proceeds through four phases:
- **Phase 0**: Setup parameter normalization and bounds transformation
- **Phase 1**: L-BFGS warmup with adaptive switching to Phase 2
- **Phase 2**: Streaming Gauss-Newton with exact J^T J accumulation
- **Phase 3**: Denormalize parameters and transform covariance matrix
Parameters
----------
config : HybridStreamingConfig, optional
Configuration for all phases of optimization. If None, uses default
configuration. See HybridStreamingConfig for details.
Attributes
----------
config : HybridStreamingConfig
Configuration object controlling all phases
current_phase : int
Current optimization phase (0, 1, 2, or 3)
phase_history : collections.deque
Bounded history (at most 100 entries) of phase transitions with timing information
phase_start_time : float or None
Start time of current phase (seconds since epoch)
normalized_params : jax.Array or None
Current parameters in normalized space
normalizer : ParameterNormalizer or None
Parameter normalizer instance (created in Phase 0)
normalized_model : NormalizedModelWrapper or None
Wrapped model function operating in normalized space
normalized_bounds : tuple of jax.Array or None
Bounds transformed to normalized space
normalization_jacobian : jax.Array or None
Denormalization Jacobian for covariance transform
Examples
--------
Basic usage with default configuration:
>>> from nlsq import AdaptiveHybridStreamingOptimizer, HybridStreamingConfig
>>> import jax.numpy as jnp
>>> config = HybridStreamingConfig()
>>> optimizer = AdaptiveHybridStreamingOptimizer(config)
With bounds-based normalization:
>>> config = HybridStreamingConfig(
... normalize=True,
... normalization_strategy='bounds'
... )
>>> optimizer = AdaptiveHybridStreamingOptimizer(config)
With custom warmup settings:
>>> config = HybridStreamingConfig(
... warmup_iterations=300,
... lbfgs_initial_step_size=0.5,
... gauss_newton_tol=1e-10
... )
>>> optimizer = AdaptiveHybridStreamingOptimizer(config)
See Also
--------
HybridStreamingConfig : Configuration for all phases
ParameterNormalizer : Parameter normalization implementation
curve_fit : High-level interface with method='hybrid_streaming'
Notes
-----
Based on Adaptive Hybrid Streaming Optimizer specification:
``agent-os/specs/2025-12-18-adaptive-hybrid-streaming-optimizer/spec.md``
"""
[docs]
def __init__(self, config: HybridStreamingConfig | None = None):
    """Initialize adaptive hybrid streaming optimizer.

    Only bookkeeping state is created here. Normalization, the compiled
    Jacobian/cost functions and the multi-start candidates are filled in
    later by ``_setup_normalization`` and the helpers it calls.

    Parameters
    ----------
    config : HybridStreamingConfig, optional
        Configuration for all phases. If None, uses default configuration.
    """
    # Store configuration
    self.config = config if config is not None else HybridStreamingConfig()

    # Phase tracking infrastructure (history is bounded to 100 records)
    self.current_phase: int = 0
    self.phase_history: deque[dict[str, Any]] = deque(maxlen=100)
    self.phase_start_time: float | None = None
    self.normalized_params: jnp.ndarray | None = None

    # Phase 0: Normalization components (created during setup)
    self.normalizer: ParameterNormalizer | None = None
    self.normalized_model: NormalizedModelWrapper | None = None
    self.normalized_bounds: tuple[jnp.ndarray, jnp.ndarray] | None = None
    self.normalization_jacobian: jnp.ndarray | None = None

    # Store original model and parameters (for reference)
    self.original_model: Callable | None = None
    self.original_p0: jnp.ndarray | None = None
    self.original_bounds: tuple[jnp.ndarray, jnp.ndarray] | None = None

    # Fault tolerance components
    self.stability_guard = NumericalStabilityGuard()
    self.best_params_global: jnp.ndarray | None = None
    self.best_cost_global: float = float("inf")
    self.checkpoint_counter: int = 0
    self.retry_count: int = 0

    # Phase-specific state for checkpointing
    self.phase1_optimizer_state: optax.OptState | None = None
    self.phase2_JTJ_accumulator: jnp.ndarray | None = None
    self.phase2_JTr_accumulator: jnp.ndarray | None = None

    # Checkpoint manager (lazy initialized on first use)
    self._checkpoint_manager = None

    # Multi-device support
    self.device_info: dict[str, Any] | None = None
    self.multi_device_config: dict[str, Any] | None = None

    # Multi-start optimization with tournament selection
    self.multistart_candidates: jnp.ndarray | None = None
    self.tournament_selector = None
    self.multistart_best_candidate: jnp.ndarray | None = None
    self.multistart_diagnostics: dict[str, Any] | None = None

    # 4-Layer Defense Strategy state for warmup divergence prevention
    self._warmup_initial_loss: float | None = None
    self._warmup_relative_loss: float | None = None
    self._warmup_lr_mode: str | None = None
    self._warmup_clip_count: int = 0

    # Residual weighting state for weighted least squares (per-group weights)
    self._residual_weights_jax: jnp.ndarray | None = None

    # Jacobian function wrapped in jax.jit (set up in _setup_jacobian_fn)
    self._jacobian_fn_compiled: Callable | None = None

    # Cost-only function wrapped in jax.jit (set up in _setup_cost_fn)
    self._cost_fn_compiled: Callable | None = None

    # Flags flipped to True by _setup_scan_functions once the scan-based
    # accumulation paths may be used
    self._jtj_scan_body_compiled: bool = False
    self._cost_scan_body_compiled: bool = False

    # Cache of padded data arrays used by _get_padded_data / clear_cache.
    # Declared here so every instance attribute is created in __init__
    # rather than appearing lazily on first use.
    self._padded_cache: tuple | None = None
def _setup_normalization(
    self,
    model: Callable,
    p0: jnp.ndarray,
    bounds: tuple[jnp.ndarray, jnp.ndarray] | None = None,
) -> None:
    """Run Phase 0: build the normalizer and everything derived from it.

    The method stores the user inputs, picks the normalization strategy,
    creates the ``ParameterNormalizer`` and the ``NormalizedModelWrapper``,
    maps the bounds and the initial guess into normalized space, optionally
    generates multi-start candidates, records the phase in
    ``phase_history`` and finally triggers the Jacobian and residual-weight
    setup.

    Parameters
    ----------
    model : Callable
        User model function with signature ``model(x, *params) -> predictions``
    p0 : array_like
        Initial parameter guess of shape (n_params,)
    bounds : tuple of array_like, optional
        Parameter bounds as (lb, ub). If None, no bounds are applied and
        ``normalized_bounds`` is set to None.

    Notes
    -----
    Attributes written: ``original_model``, ``original_p0``,
    ``original_bounds``, ``normalizer``, ``normalized_model``,
    ``normalized_bounds``, ``normalization_jacobian`` and
    ``normalized_params``.

    Strategy selection: ``'none'`` when ``config.normalize`` is false,
    otherwise ``config.normalization_strategy`` is forwarded unchanged to
    ``ParameterNormalizer``.
    """
    has_bounds = bounds is not None

    # Keep references to the caller's inputs
    self.original_model = model
    self.original_p0 = jnp.asarray(p0, dtype=jnp.float64)
    self.original_bounds = bounds

    # Identity transform when normalization is switched off
    strategy = self.config.normalization_strategy if self.config.normalize else "none"

    normalizer = ParameterNormalizer(
        p0=self.original_p0, bounds=bounds, strategy=strategy
    )
    self.normalizer = normalizer
    self.normalized_model = NormalizedModelWrapper(
        model_fn=model, normalizer=normalizer
    )

    # Bounds expressed in normalized space (only when bounds were given)
    if has_bounds:
        lower_norm, upper_norm = normalizer.transform_bounds()
        self.normalized_bounds = (lower_norm, upper_norm)
    else:
        self.normalized_bounds = None

    # Kept for the Phase 3 covariance transform
    self.normalization_jacobian = normalizer.normalization_jacobian

    # Starting point in normalized space
    self.normalized_params = normalizer.normalize(self.original_p0)

    if self.config.enable_multistart:
        self._generate_multistart_candidates(bounds)

    # Log completion of Phase 0
    self.phase_history.append(
        {
            "phase": 0,
            "name": "normalization_setup",
            "strategy": normalizer.strategy,
            "timestamp": time.time(),
            "normalized_params_shape": self.normalized_params.shape,
            "has_bounds": has_bounds,
        }
    )

    # Downstream setup that depends on self.normalized_model
    self._setup_jacobian_fn()
    self._setup_residual_weights()
def _setup_residual_weights(self) -> None:
"""Initialize residual weights from config.
This method sets up per-group weights for weighted least squares
optimization. When enabled, residuals are weighted during loss
computation, allowing users to assign different importance to
different groups of data points.
The weights are stored as JAX arrays for efficient lookup during
loss computation.
Notes
-----
Residual weighting is useful for:
- Heteroscedastic data (varying noise levels)
- Emphasizing certain regions of the data
- Domain-specific weighting schemes (e.g., XPCS shear-sensitivity)
"""
if not self.config.enable_residual_weighting:
return
if self.config.residual_weights is None:
_logger.warning(
"Residual weighting enabled but no weights provided. "
"Residual weighting will be disabled."
)
return
# Convert to JAX arrays for efficient computation
self._residual_weights_jax = jnp.asarray(
self.config.residual_weights, dtype=jnp.float64
)
_logger.info(
f"Residual weighting enabled: n_weights={len(self._residual_weights_jax)}, "
f"weight_range=[{float(self._residual_weights_jax.min()):.3f}, "
f"{float(self._residual_weights_jax.max()):.3f}]"
)
[docs]
def set_residual_weights(self, weights: np.ndarray) -> None:
    """Replace the residual weights used for weighted least squares.

    Can be called at any time, for example when weights are recomputed
    from the current parameter estimates.

    Parameters
    ----------
    weights : np.ndarray
        Per-group weights. They are converted to a float64-requested JAX
        array and stored in ``_residual_weights_jax``.

    Notes
    -----
    No validation is performed here; callers are expected to pass
    positive weights. The new weight range is logged at debug level.
    """
    weights_jax = jnp.asarray(weights, dtype=jnp.float64)
    self._residual_weights_jax = weights_jax

    _logger.debug(
        f"Updated residual weights: range=[{float(self._residual_weights_jax.min()):.3f}, "
        f"{float(self._residual_weights_jax.max()):.3f}]"
    )
def _setup_jacobian_fn(self) -> None:
"""Pre-compile the Jacobian function for efficient Phase 2 computation.
This method creates a JIT-compiled Jacobian function that avoids
recompilation overhead in _compute_jacobian_chunk. The function is
stored in self._jacobian_fn_compiled and reuses the normalized_model.
The Jacobian is computed using reverse-mode AD (jacrev) vectorized
over data points (vmap). Pre-compilation provides 15-25% speedup
in Phase 2 by avoiding repeated function tracing.
Notes
-----
Must be called after _setup_normalization sets self.normalized_model.
The compiled function has signature: (params, x_chunk) -> J_chunk
where J_chunk has shape (n_points, n_params).
"""
if self.normalized_model is None:
raise RuntimeError(
"_setup_jacobian_fn must be called after _setup_normalization"
)
# Capture normalized_model in closure at setup time
normalized_model = self.normalized_model
def compute_jacobian_core(
params: jnp.ndarray, x_chunk: jnp.ndarray
) -> jnp.ndarray:
"""Core Jacobian computation using vmap + jacrev.
Parameters
----------
params : array_like
Parameters in normalized space of shape (n_params,)
x_chunk : array_like
Data chunk of shape (n_points,) or (n_points, n_features)
Returns
-------
J_chunk : array_like
Jacobian matrix of shape (n_points, n_params)
"""
def model_at_point(p, x_single):
return normalized_model(x_single, *p)
# jacrev computes gradient w.r.t. first argument (params)
# vmap over x_chunk to get Jacobian row for each point
return jax.vmap(lambda x: jax.jacrev(model_at_point, argnums=0)(params, x))(
x_chunk
)
# JIT compile the Jacobian function
# Note: We don't use static_argnums here since params shape is fixed
# but values change. The function traces once per unique shape combination.
self._jacobian_fn_compiled = jax.jit(compute_jacobian_core)
# Optionally warm up the compiled function with a small test
# This triggers compilation eagerly rather than on first use
if self.config.verbose >= 2:
_logger.debug("Pre-compiled Jacobian function for Phase 2")
# Also set up the cost-only function
self._setup_cost_fn()
def _setup_cost_fn(self) -> None:
"""Pre-compile the cost-only function for efficient Gauss-Newton iterations.
This method creates a JIT-compiled function that computes only the sum of
squared residuals (cost) without computing the Jacobian. This is used in
`_gauss_newton_iteration` to evaluate the cost at new parameters after
taking a step.
The cost function is significantly faster than re-computing the full
Jacobian, providing 20-30% speedup in Phase 2 by avoiding redundant
model evaluations.
Notes
-----
Must be called after _setup_normalization sets self.normalized_model.
The compiled function has signature:
(params, x_data, y_data, chunk_size) -> total_cost
"""
if self.normalized_model is None:
raise RuntimeError(
"_setup_cost_fn must be called after _setup_normalization"
)
# Capture normalized_model in closure at setup time
normalized_model = self.normalized_model
def compute_chunk_cost(
params: jnp.ndarray, x_chunk: jnp.ndarray, y_chunk: jnp.ndarray
) -> float:
"""Compute cost for a single chunk.
Parameters
----------
params : array_like
Parameters in normalized space of shape (n_params,)
x_chunk : array_like
Data chunk of shape (n_points,) or (n_points, n_features)
y_chunk : array_like
Target chunk of shape (n_points,)
Returns
-------
cost : float
Sum of squared residuals for this chunk
"""
predictions = normalized_model(x_chunk, *params)
residuals = y_chunk - predictions
return jnp.sum(residuals**2)
# JIT compile the chunk cost function
self._cost_fn_compiled = jax.jit(compute_chunk_cost)
if self.config.verbose >= 2:
_logger.debug("Pre-compiled cost function for Phase 2")
# Also set up scan functions for efficient loop-free accumulation
self._setup_scan_functions()
def _setup_scan_functions(self) -> None:
"""Initialize scan function infrastructure for efficient chunk accumulation.
This method prepares the optimizer to use JAX lax.scan for chunk-based
operations, eliminating Python loop overhead and enabling XLA fusion.
The scan approach provides 10-15% speedup by:
- Eliminating Python interpreter overhead in hot loops
- Enabling XLA to fuse operations across chunks
- Reducing memory allocation overhead
Notes
-----
The actual scan body functions are created inline in _accumulate_jtj_jtr_scan
and _compute_cost_scan, as they need to capture the current params in a closure.
This setup method primarily validates that prerequisites are met.
"""
if self.normalized_model is None:
raise RuntimeError(
"_setup_scan_functions must be called after _setup_normalization"
)
# Mark scan functions as available (bodies created inline in each method)
self._jtj_scan_body_compiled = True # Flag indicating scan is available
self._cost_scan_body_compiled = True
if self.config.verbose >= 2:
_logger.debug("Scan-based accumulation enabled for Phase 2")
def _use_scan_for_accumulation(self) -> bool:
"""Determine whether to use JAX scan or Python loops for chunk accumulation.
Returns True if JAX lax.scan should be used, False for Python loops.
The decision is based on:
- config.loop_strategy: 'auto', 'scan', or 'loop'
- For 'auto': GPU/TPU use scan (better XLA fusion), CPU uses loops
(lower tracing overhead)
Returns
-------
use_scan : bool
True to use jax.lax.scan, False to use Python for loops
Notes
-----
Benchmarks show:
- CPU: Python loops are ~10x faster due to scan tracing overhead
- GPU: Scan is faster due to kernel launch overhead in Python loops
"""
strategy = self.config.loop_strategy
if strategy == "scan":
return True
elif strategy == "loop":
return False
else: # 'auto'
# Detect backend from first device
devices = jax.devices()
if devices:
platform = devices[0].platform
# Use scan on GPU/TPU, loops on CPU
return platform in ("gpu", "cuda", "rocm", "tpu")
return False # Default to loops if no devices detected
def _prepare_chunked_data(
self,
x_data: jnp.ndarray,
y_data: jnp.ndarray,
chunk_size: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int]:
"""Prepare data for scan by padding and reshaping into fixed-size chunks.
Parameters
----------
x_data : array_like
Full x data of shape (n_points,) or (n_points, n_features)
y_data : array_like
Full y data of shape (n_points,)
chunk_size : int
Size of each chunk
Returns
-------
x_chunks : array_like
Reshaped x data of shape (n_chunks, chunk_size, ...)
y_chunks : array_like
Reshaped y data of shape (n_chunks, chunk_size)
mask_chunks : array_like
Validity mask of shape (n_chunks, chunk_size), 1.0 for valid points
n_valid_points : int
Number of valid (non-padded) points
Notes
-----
Pads data with zeros to make n_points evenly divisible by chunk_size.
Uses a mask to ensure padded points don't contribute to cost or gradients.
"""
n_points = x_data.shape[0]
n_chunks = (n_points + chunk_size - 1) // chunk_size
padded_size = n_chunks * chunk_size
pad_size = padded_size - n_points
# Create validity mask (1.0 for valid, 0.0 for padded)
mask = jnp.ones(n_points, dtype=jnp.float64)
if pad_size > 0:
# Pad x_data with zeros
if x_data.ndim == 1:
x_padded = jnp.pad(
x_data, (0, pad_size), mode="constant", constant_values=0
)
else:
x_padded = jnp.pad(
x_data, ((0, pad_size), (0, 0)), mode="constant", constant_values=0
)
# Pad y_data with zeros (mask will exclude these)
y_padded = jnp.pad(
y_data, (0, pad_size), mode="constant", constant_values=0
)
# Pad mask with zeros (invalid)
mask = jnp.pad(mask, (0, pad_size), mode="constant", constant_values=0)
else:
x_padded = x_data
y_padded = y_data
# Reshape into chunks
if x_padded.ndim == 1:
x_chunks = x_padded.reshape(n_chunks, chunk_size)
else:
n_features = x_padded.shape[1]
x_chunks = x_padded.reshape(n_chunks, chunk_size, n_features)
y_chunks = y_padded.reshape(n_chunks, chunk_size)
mask_chunks = mask.reshape(n_chunks, chunk_size)
return x_chunks, y_chunks, mask_chunks, n_points
def _get_padded_data(
self,
x_data: jnp.ndarray,
y_data: jnp.ndarray,
chunk_size: int,
) -> tuple[jnp.ndarray, jnp.ndarray, int, int]:
"""Get padded flat arrays for dynamic_slice-based scan, with caching.
B006: Avoids pre-stacking data into (n_chunks, chunk_size, ...) arrays.
Instead, pads flat arrays minimally and caches them for reuse across
iterations (x/y data is constant, only params change).
Returns
-------
x_padded : array_like
Padded x data, flat (padded_size,) or (padded_size, n_features)
y_padded : array_like
Padded y data, flat (padded_size,)
n_points : int
Number of valid (non-padded) points
n_chunks : int
Number of chunks
"""
# Check cache validity using data content hash (not id() which can be recycled by GC)
cache = getattr(self, "_padded_cache", None)
if cache is not None:
c_hash_x, c_hash_y, c_cs, c_xp, c_yp, c_np, c_nc = cache
# Use shape + first/last/sum as a fast content fingerprint
x_flat = x_data.ravel()
x_hash = (
x_data.shape,
float(x_flat[0]),
float(x_flat[-1]),
float(jnp.sum(x_data)),
)
y_hash = (
y_data.shape,
float(y_data.ravel()[0]),
float(y_data.ravel()[-1]),
float(jnp.sum(y_data)),
)
if c_hash_x == x_hash and c_hash_y == y_hash and c_cs == chunk_size:
return c_xp, c_yp, c_np, c_nc
n_points = x_data.shape[0]
n_chunks = (n_points + chunk_size - 1) // chunk_size
pad_size = n_chunks * chunk_size - n_points
if pad_size > 0:
if x_data.ndim == 1:
x_padded = jnp.pad(x_data, (0, pad_size))
else:
x_padded = jnp.pad(x_data, ((0, pad_size), (0, 0)))
y_padded = jnp.pad(y_data, (0, pad_size))
else:
x_padded = x_data
y_padded = y_data
x_flat_w = x_data.ravel()
x_hash = (
x_data.shape,
float(x_flat_w[0]),
float(x_flat_w[-1]),
float(jnp.sum(x_data)),
)
y_flat_w = y_data.ravel()
y_hash = (
y_data.shape,
float(y_flat_w[0]),
float(y_flat_w[-1]),
float(jnp.sum(y_data)),
)
self._padded_cache = (
x_hash,
y_hash,
chunk_size,
x_padded,
y_padded,
n_points,
n_chunks,
)
return x_padded, y_padded, n_points, n_chunks
[docs]
def clear_cache(self) -> None:
    """Release cached padded arrays to free memory.

    Call this after optimization completes or when reusing the optimizer
    with different data. ``_get_padded_data`` rebuilds the cache on its own
    whenever the data fingerprint (shape, first/last element, sum) or the
    chunk size no longer matches, but explicit clearing frees memory sooner.
    """
    # Dropping the tuple releases the references to the padded x/y arrays
    self._padded_cache = None
def _accumulate_jtj_jtr_scan(
self,
x_data: jnp.ndarray,
y_data: jnp.ndarray,
params: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, float]:
"""Accumulate J^T J and J^T r using JAX lax.scan for efficiency.
This is the scan-based version of the chunk accumulation loop,
providing 10-15% speedup by eliminating Python loop overhead.
Parameters
----------
x_data : array_like
Full x data of shape (n_points,) or (n_points, n_features)
y_data : array_like
Full y data of shape (n_points,)
params : array_like
Current parameters in normalized space of shape (n_params,)
Returns
-------
JTJ : array_like
Accumulated J^T J of shape (n_params, n_params)
JTr : array_like
Accumulated J^T r of shape (n_params,)
total_cost : float
Total sum of squared residuals
Notes
-----
Uses jax.lax.scan instead of Python for loop, enabling XLA fusion
across chunks and eliminating interpreter overhead.
Uses masking to handle padding correctly for non-divisible data sizes.
"""
chunk_size = self.config.chunk_size
n_params = len(params)
# B006: Use cached padded flat arrays + dynamic_slice instead of
# pre-stacking into (n_chunks, chunk_size, ...) arrays.
x_padded, y_padded, n_points, n_chunks = self._get_padded_data(
x_data, y_data, chunk_size
)
# Capture functions and data for scan body closure
normalized_model = self.normalized_model
jacobian_fn = self._jacobian_fn_compiled
x_ndim = x_padded.ndim
chunk_indices = jnp.arange(n_chunks)
def scan_body(carry, chunk_idx):
"""Scan body using dynamic_slice from flat padded arrays."""
JTJ, JTr, total_cost = carry
start = chunk_idx * chunk_size
# Dynamic slice from flat padded arrays (no pre-stacking needed)
if x_ndim == 1:
x_chunk = jax.lax.dynamic_slice(x_padded, (start,), (chunk_size,))
else:
x_chunk = jax.lax.dynamic_slice(
x_padded, (start, 0), (chunk_size, x_padded.shape[1])
)
y_chunk = jax.lax.dynamic_slice(y_padded, (start,), (chunk_size,))
# Compute mask on-the-fly (tiny chunk_size array, not full n_points)
mask = jnp.where(start + jnp.arange(chunk_size) < n_points, 1.0, 0.0)
# Compute predictions and residuals
predictions = normalized_model(x_chunk, *params)
residuals = y_chunk - predictions
# Apply mask to residuals (zero out padded points)
masked_residuals = residuals * mask
# Compute Jacobian
if jacobian_fn is not None:
J_chunk = jacobian_fn(params, x_chunk)
else:
def model_at_x(p, x_single):
return normalized_model(x_single, *p)
J_chunk = jax.vmap(
lambda x: jax.jacrev(model_at_x, argnums=0)(params, x)
)(x_chunk)
# Apply mask to Jacobian rows (zero out padded point gradients)
masked_J = J_chunk * mask[:, None]
# Accumulate with masked values
JTJ_new = JTJ + masked_J.T @ masked_J
JTr_new = JTr + masked_J.T @ masked_residuals
cost_new = total_cost + jnp.sum(masked_residuals**2)
return (JTJ_new, JTr_new, cost_new), None
# Initialize carry
init_carry = (
jnp.zeros((n_params, n_params)),
jnp.zeros(n_params),
jnp.array(0.0),
)
# Run scan over chunk indices (not pre-stacked data arrays)
(JTJ, JTr, total_cost), _ = jax.lax.scan(
scan_body,
init_carry,
chunk_indices,
)
# Store accumulators for checkpointing
self.phase2_JTJ_accumulator = JTJ
self.phase2_JTr_accumulator = JTr
return JTJ, JTr, float(total_cost)
def _compute_cost_scan(
self,
params: jnp.ndarray,
x_data: jnp.ndarray,
y_data: jnp.ndarray,
) -> float:
"""Compute total cost using JAX lax.scan for efficiency.
This is the scan-based version of cost computation, providing
10-15% speedup by eliminating Python loop overhead.
Parameters
----------
params : array_like
Parameters in normalized space of shape (n_params,)
x_data : array_like
Full x data of shape (n_points,) or (n_points, n_features)
y_data : array_like
Full y data of shape (n_points,)
Returns
-------
total_cost : float
Total sum of squared residuals
Notes
-----
Uses masking to handle padding correctly for non-divisible data sizes.
"""
chunk_size = self.config.chunk_size
# B006: Use cached padded flat arrays + dynamic_slice
x_padded, y_padded, n_points, n_chunks = self._get_padded_data(
x_data, y_data, chunk_size
)
# Capture model and data for scan body closure
normalized_model = self.normalized_model
x_ndim = x_padded.ndim
chunk_indices = jnp.arange(n_chunks)
def scan_body(carry, chunk_idx):
"""Scan body for masked cost computation using dynamic_slice."""
start = chunk_idx * chunk_size
if x_ndim == 1:
x_chunk = jax.lax.dynamic_slice(x_padded, (start,), (chunk_size,))
else:
x_chunk = jax.lax.dynamic_slice(
x_padded, (start, 0), (chunk_size, x_padded.shape[1])
)
y_chunk = jax.lax.dynamic_slice(y_padded, (start,), (chunk_size,))
mask = jnp.where(start + jnp.arange(chunk_size) < n_points, 1.0, 0.0)
predictions = normalized_model(x_chunk, *params)
residuals = y_chunk - predictions
masked_residuals = residuals * mask
return carry + jnp.sum(masked_residuals**2), None
# Initialize carry
init_carry = jnp.array(0.0)
# Run scan over chunk indices
total_cost, _ = jax.lax.scan(
scan_body,
init_carry,
chunk_indices,
)
return float(total_cost)
def _compute_cost_only(
self,
params: jnp.ndarray,
x_data: jnp.ndarray,
y_data: jnp.ndarray,
) -> float:
"""Compute total cost (sum of squared residuals) without Jacobian.
This method efficiently computes only the cost at the given parameters,
without computing the Jacobian. Used in Gauss-Newton iterations to
evaluate the cost at new parameters after taking a step.
Parameters
----------
params : array_like
Parameters in normalized space of shape (n_params,)
x_data : array_like
Full x data of shape (n_points,) or (n_points, n_features)
y_data : array_like
Full y data of shape (n_points,)
Returns
-------
total_cost : float
Total sum of squared residuals
Performance
-----------
Dispatches between JAX scan (GPU/TPU) and Python loops (CPU) based on
config.loop_strategy for optimal performance on each backend.
"""
# Dispatch based on backend: scan for GPU/TPU, loops for CPU
if self._use_scan_for_accumulation():
# Use JAX scan for GPU/TPU (better XLA fusion, reduced kernel launches)
return self._compute_cost_scan(params, x_data, y_data)
# Use Python loops for CPU (lower tracing overhead)
chunk_size = self.config.chunk_size
n_points = len(x_data)
total_cost = 0.0
for i in range(0, n_points, chunk_size):
x_chunk = x_data[i : i + chunk_size]
y_chunk = y_data[i : i + chunk_size]
# Use pre-compiled cost function if available
if self._cost_fn_compiled is not None:
chunk_cost = self._cost_fn_compiled(params, x_chunk, y_chunk)
else:
# Fallback to inline computation
predictions = self.normalized_model(x_chunk, *params)
residuals = y_chunk - predictions
chunk_cost = float(jnp.sum(residuals**2))
total_cost += chunk_cost
return total_cost
def _compute_cost_with_variance_regularization(
self,
params: jnp.ndarray,
x_data: jnp.ndarray,
y_data: jnp.ndarray,
) -> float:
"""Compute total cost including group variance regularization.
This is a convenience method that computes the MSE cost plus any
group variance regularization penalty.
Parameters
----------
params : array_like
Parameters in normalized space of shape (n_params,)
x_data : array_like
Full x data of shape (n_points,) or (n_points, n_features)
y_data : array_like
Full y data of shape (n_points,)
Returns
-------
total_cost : float
Total cost including MSE and variance regularization
"""
# Base MSE cost
total_cost = self._compute_cost_only(params, x_data, y_data)
# Add group variance regularization if enabled
if (
self.config.enable_group_variance_regularization
and self.config.group_variance_indices
):
n_points = len(x_data)
var_lambda = self.config.group_variance_lambda
for start, end in self.config.group_variance_indices:
group_params = params[start:end]
group_var = jnp.var(group_params)
total_cost += (
var_lambda
* float(jnp.where(jnp.isfinite(group_var), group_var, 0.0))
* n_points
)
return total_cost
def _generate_multistart_candidates(
    self,
    bounds: tuple[jnp.ndarray, jnp.ndarray] | None = None,
) -> None:
    """Draw multi-start candidates and store them on the instance.

    Parameters
    ----------
    bounds : tuple of array_like, optional
        Parameter bounds as (lb, ub).

    Notes
    -----
    ``config.n_starts`` points are drawn with the sampler named by
    ``config.multistart_sampler`` using a fixed PRNG key (seed 42), so
    repeated calls yield the same samples. The samples are then mapped to
    parameter space:

    - bounds given and ``config.center_on_p0``: a box of half-width
      ``(ub - lb) * scale_factor / 2`` around ``p0``, clipped to the bounds
    - bounds given otherwise: the full ``[lb, ub]`` box
    - no bounds: ``p0 +/- (|p0| + 0.1) * scale_factor``

    The mapping assumes the sampler returns values in the unit hypercube.
    The result is written to ``self.multistart_candidates``.
    """
    import jax
    import numpy as np

    cfg = self.config
    dim = len(self.original_p0)
    n_starts = cfg.n_starts

    draw = get_sampler(cfg.multistart_sampler)
    # Fixed seed for reproducibility
    key = jax.random.PRNGKey(42)
    unit_samples = np.asarray(draw(n_starts, dim, rng_key=key))

    if bounds is None:
        # No bounds: spread around p0, using |p0| + 0.1 to avoid a zero range
        center = np.asarray(self.original_p0)
        half_width = (np.abs(center) + 0.1) * cfg.scale_factor
        candidates = center + (unit_samples - 0.5) * 2 * half_width
    else:
        lower, upper = np.asarray(bounds[0]), np.asarray(bounds[1])
        if cfg.center_on_p0:
            # Shrink the box around p0, never leaving the original bounds
            center = np.asarray(self.original_p0)
            half_width = (upper - lower) * cfg.scale_factor / 2
            lower, upper = (
                np.maximum(lower, center - half_width),
                np.minimum(upper, center + half_width),
            )
        candidates = lower + unit_samples * (upper - lower)

    self.multistart_candidates = jnp.asarray(candidates)
def _run_tournament_selection(
    self,
    data_source: tuple[jnp.ndarray, jnp.ndarray],
    model: Callable,
) -> jnp.ndarray:
    """Pick the best multi-start candidate via tournament selection.

    Parameters
    ----------
    data_source : tuple
        Data as (x_data, y_data).
    model : Callable
        Model function. Not used directly here; the tournament is run
        against ``self.normalized_model``.

    Returns
    -------
    best_params : array_like
        Winning candidate in normalized space, or ``self.normalized_params``
        if the tournament raised.

    Notes
    -----
    On success the selector diagnostics are stored in
    ``self.multistart_diagnostics`` and the winner in
    ``self.multistart_best_candidate``. On failure a warning is issued and
    ``self.multistart_diagnostics`` records the error.
    """
    import warnings

    import numpy as np

    cfg = self.config

    # Tournament settings are copied over from the streaming config
    tournament_config = GlobalOptimizationConfig(
        n_starts=cfg.n_starts,
        elimination_rounds=cfg.elimination_rounds,
        elimination_fraction=cfg.elimination_fraction,
        batches_per_round=cfg.batches_per_round,
    )

    # Candidates are stored un-normalized; map each one to normalized space
    to_normalized = self.normalizer.normalize
    normalized_candidates = np.array(
        [
            np.asarray(to_normalized(jnp.asarray(candidate)))
            for candidate in self.multistart_candidates
        ]
    )

    self.tournament_selector = TournamentSelector(
        candidates=normalized_candidates,
        config=tournament_config,
    )

    x_data, y_data = data_source
    x_data = np.asarray(x_data)
    y_data = np.asarray(y_data)
    chunk_size = cfg.chunk_size
    n_points = len(x_data)

    def data_batch_generator():
        order = np.arange(n_points)
        # One initial pass plus one pass per (round, batch) pair; the
        # ordering is reshuffled in place before every pass
        extra_passes = self.config.elimination_rounds * self.config.batches_per_round
        for _ in range(1 + max(0, extra_passes)):
            np.random.shuffle(order)
            for lo in range(0, n_points, chunk_size):
                picked = order[lo : lo + chunk_size]
                yield x_data[picked], y_data[picked]

    try:
        winners = self.tournament_selector.run_tournament(
            data_batch_iterator=data_batch_generator(),
            model=self.normalized_model,
            top_m=1,
        )
        self.multistart_diagnostics = self.tournament_selector.get_diagnostics()
        best_normalized = winners[0]
        self.multistart_best_candidate = best_normalized
        return jnp.asarray(best_normalized)
    except Exception as e:
        # Best-effort feature: fall back to the already-normalized p0
        warnings.warn(
            f"Tournament selection failed: {e}. Using p0 as starting point."
        )
        self.multistart_diagnostics = {"error": str(e), "fallback": True}
        return self.normalized_params
def _create_lbfgs_optimizer(
    self,
    params: jnp.ndarray,
    initial_step_size: float | None = None,
) -> tuple[optax.GradientTransformationExtraArgs, optax.OptState]:
    """Build the optax L-BFGS optimizer (and its state) for Phase 1 warmup.

    Parameters
    ----------
    params : array_like
        Initial parameters in normalized space; used to initialize the state
    initial_step_size : float, optional
        Step size passed to L-BFGS as ``learning_rate`` (and, for the
        backtracking line search, as ``max_learning_rate``).
        If None, ``config.lbfgs_initial_step_size`` is used.

    Returns
    -------
    optimizer : optax.GradientTransformationExtraArgs
        L-BFGS optimizer, optionally preceded by global-norm gradient clipping
    opt_state : optax.OptState
        Initial optimizer state

    Notes
    -----
    ``config.lbfgs_line_search == "backtracking"`` selects an Armijo
    backtracking line search; any other value selects the zoom line search.
    The history size comes from ``config.lbfgs_history_size``.
    When ``config.gradient_clip_value`` is not None, gradients are clipped
    by global norm before they reach L-BFGS.

    Examples
    --------
    Default step size from the config:
    >>> optimizer, state = self._create_lbfgs_optimizer(params)
    Explicit small step:
    >>> optimizer, state = self._create_lbfgs_optimizer(params, initial_step_size=0.1)
    """
    cfg = self.config
    step_size = (
        cfg.lbfgs_initial_step_size
        if initial_step_size is None
        else initial_step_size
    )

    if cfg.lbfgs_line_search == "backtracking":
        # Armijo sufficient-decrease search, capped at the chosen step size
        linesearch = optax.scale_by_backtracking_linesearch(
            max_backtracking_steps=20,
            slope_rtol=1e-4,
            decrease_factor=0.8,
            increase_factor=1.5,
            max_learning_rate=step_size,
        )
    else:
        # Zoom line search is the fallback for every other setting
        linesearch = optax.scale_by_zoom_linesearch(
            max_linesearch_steps=20,
            initial_guess_strategy="one",
        )

    lbfgs = optax.lbfgs(
        learning_rate=step_size,
        memory_size=cfg.lbfgs_history_size,
        scale_init_precond=True,
        linesearch=linesearch,
    )

    clip_value = cfg.gradient_clip_value
    if clip_value is None:
        optimizer = lbfgs
    else:
        optimizer = optax.chain(optax.clip_by_global_norm(clip_value), lbfgs)

    return optimizer, optimizer.init(params)
def _lbfgs_step(
    self,
    params: jnp.ndarray,
    opt_state: optax.OptState,
    optimizer: optax.GradientTransformationExtraArgs,
    loss_fn: Callable,
    x_batch: jnp.ndarray,
    y_batch: jnp.ndarray,
    iteration: int,
) -> tuple[jnp.ndarray, float, float, optax.OptState, bool]:
    """Perform a single L-BFGS optimization step with defensive fallbacks.

    Parameters
    ----------
    params : array_like
        Current parameters in normalized space
    opt_state : optax.OptState
        Current optimizer state
    optimizer : optax.GradientTransformationExtraArgs
        L-BFGS optimizer instance
    loss_fn : Callable
        Loss function with signature ``loss_fn(params, x_batch, y_batch)``
    x_batch : array_like
        Independent variable batch
    y_batch : array_like
        Dependent variable batch
    iteration : int
        Current iteration number (used for telemetry)

    Returns
    -------
    new_params : array_like
        Updated parameters in normalized space
    loss : float
        Loss value at ``params`` (i.e. before the update)
    grad_norm : float
        L2 norm of the gradient at ``params``
    new_opt_state : optax.OptState
        Updated optimizer state
    line_search_failed : bool
        True if the optimizer update raised, or if a numerical problem made
        the step fall back to the unchanged parameters

    Raises
    ------
    ValueError
        If a numerical problem is detected and
        ``config.enable_fault_tolerance`` is missing or falsy.

    Notes
    -----
    If ``optimizer.update`` raises, the step falls back to plain gradient
    descent with step ``config.lbfgs_initial_step_size`` and the optimizer
    state is left unchanged. When ``config.enable_step_clipping`` is set,
    the update is clipped to ``config.max_warmup_step_size`` (Layer 4).

    ``best_cost_global`` / ``best_params_global`` are updated as a pair:
    the recorded parameters are the ones the recorded loss was evaluated at
    (``params``), not the not-yet-evaluated post-step parameters.
    """
    # A missing attribute is treated like a disabled flag
    fault_tolerant = bool(getattr(self.config, "enable_fault_tolerance", False))

    # Input check; its return value is ignored here, the combined check
    # below decides how to proceed
    self._validate_numerics(params, context="at L-BFGS step input")

    # Compute loss and gradient
    loss_value, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch)

    # Validate loss and gradients
    if not self._validate_numerics(
        params, loss=float(loss_value), gradients=grads, context="in L-BFGS step"
    ):
        if fault_tolerant:
            # Return current params unchanged (fallback)
            return params, float("inf"), float("inf"), opt_state, True
        raise ValueError("Numerical issues detected in L-BFGS step")

    grad_norm = jnp.linalg.norm(grads)

    # L-BFGS requires a value_fn for line search
    def value_fn(p):
        return loss_fn(p, x_batch, y_batch)

    # Apply optimizer updates (L-BFGS with line search)
    try:
        updates, new_opt_state = optimizer.update(
            grads,
            opt_state,
            params,
            value=loss_value,
            grad=grads,
            value_fn=value_fn,
        )
        line_search_failed = False
    except Exception as e:
        # Line search can fail in some cases
        if self.config.verbose >= 2:
            _logger.warning(f"L-BFGS line search failed: {e}")
        # Fall back to gradient descent step
        updates = -self.config.lbfgs_initial_step_size * grads
        new_opt_state = opt_state
        line_search_failed = True
        # Record line search failure in telemetry
        telemetry = get_defense_telemetry()
        telemetry.record_lbfgs_line_search_failure(iteration, str(e))

    # Layer 4: Trust Region Constraint - clip update magnitude
    if self.config.enable_step_clipping:
        # Track original norm for telemetry before clipping
        original_update_norm = float(jnp.linalg.norm(updates))
        max_norm = self.config.max_warmup_step_size
        updates = self._clip_update_norm(updates, max_norm)
        if original_update_norm > max_norm:
            telemetry = get_defense_telemetry()
            telemetry.record_layer4_clip(
                original_norm=original_update_norm, max_norm=max_norm
            )

    new_params = optax.apply_updates(params, updates)

    # Validate updated parameters
    if not self._validate_numerics(new_params, context="after L-BFGS update"):
        if fault_tolerant:
            # Fallback: keep old parameters
            return params, float(loss_value), float(grad_norm), opt_state, True
        raise ValueError("NaN/Inf in parameters after L-BFGS update")

    # Track best parameters globally. loss_value was evaluated at `params`,
    # so that is what must be stored alongside it; new_params has not been
    # evaluated yet and may be worse (clipped or fallback steps carry no
    # decrease guarantee).
    if float(loss_value) < self.best_cost_global:
        self.best_cost_global = float(loss_value)
        self.best_params_global = params

    # Store optimizer state for checkpointing
    self.phase1_optimizer_state = new_opt_state

    # Record history buffer fill event (fires on the single iteration whose
    # index equals the configured history size)
    if iteration == self.config.lbfgs_history_size:
        telemetry = get_defense_telemetry()
        telemetry.record_lbfgs_history_fill(iteration)

    return (
        new_params,
        float(loss_value),
        float(grad_norm),
        new_opt_state,
        line_search_failed,
    )
def _create_warmup_loss_fn(self) -> Callable:
    """Create the JIT-compiled loss function used during warmup.

    Returns
    -------
    loss_fn : Callable
        Loss function with signature: loss_fn(params, x_batch, y_batch) -> scalar_loss
        Operates in normalized parameter space.

    Notes
    -----
    Base loss is the mean squared residual. When residual weighting is
    enabled in the config and weights are set, it becomes the weighted mean

        wMSE = sum(w[group_idx] * residuals^2) / sum(w[group_idx])

    where ``group_idx`` is the first column of ``x_batch`` cast to int32.

    When `enable_group_variance_regularization=True` and groups are
    configured, the loss becomes

        L = base_loss + group_variance_lambda * sum(Var(group_i))

    with each group_i given by a ``(start, end)`` pair in
    `group_variance_indices`.

    Config values and residual weights are captured when this method is
    called; call it again to pick up later changes.
    """
    normalized_model = self.normalized_model

    # Capture config values for the closure
    cfg = self.config
    var_lambda = cfg.group_variance_lambda
    var_indices = cfg.group_variance_indices
    use_var_reg = bool(cfg.enable_group_variance_regularization and var_indices)

    residual_weights = self._residual_weights_jax
    use_weighting = bool(
        cfg.enable_residual_weighting and residual_weights is not None
    )

    # Group layout as arrays so the penalty can run in a single fori_loop.
    # Every group is read through a window of max_group_size elements and
    # masked down to its real size.
    if use_var_reg:
        group_starts = jnp.array([s for s, _e in var_indices], dtype=jnp.int32)
        group_sizes = jnp.array([e - s for s, e in var_indices], dtype=jnp.int32)
        n_groups = len(var_indices)
        max_group_size = max(e - s for s, e in var_indices)
    else:
        group_starts = jnp.zeros(0, dtype=jnp.int32)
        group_sizes = jnp.zeros(0, dtype=jnp.int32)
        n_groups = 0
        max_group_size = 0

    @jax.jit
    def loss_fn(
        params: jnp.ndarray, x_batch: jnp.ndarray, y_batch: jnp.ndarray
    ) -> jnp.ndarray:
        predictions = normalized_model(x_batch, *params)
        residuals = y_batch - predictions

        # Base loss: weighted or unweighted MSE
        if use_weighting:
            group_idx = x_batch[:, 0].astype(jnp.int32)
            weights = residual_weights[group_idx]
            base_loss = jnp.sum(weights * residuals**2) / jnp.sum(weights)
        else:
            base_loss = jnp.mean(residuals**2)

        if not use_var_reg:
            return base_loss

        # dynamic_slice clamps the start index so the window fits inside the
        # operand. Without padding, a group shorter than max_group_size that
        # sits near the end of params would be read from a shifted window
        # and masked over the wrong elements. Padding with max_group_size
        # zeros guarantees the window always starts exactly at `start`.
        padded = jnp.concatenate(
            [params, jnp.zeros((max_group_size,), dtype=params.dtype)]
        )

        def var_body(i, penalty):
            start = group_starts[i]
            size = group_sizes[i]
            window = jax.lax.dynamic_slice(padded, (start,), (max_group_size,))
            mask = jnp.arange(max_group_size) < size
            n = jnp.maximum(size, 1)  # avoid division by zero
            mean_val = jnp.sum(jnp.where(mask, window, 0.0)) / n
            sq_diff = jnp.where(mask, (window - mean_val) ** 2, 0.0)
            return penalty + jnp.sum(sq_diff) / n

        variance_penalty = jax.lax.fori_loop(
            0, n_groups, var_body, jnp.array(0.0)
        )
        return base_loss + var_lambda * variance_penalty

    return loss_fn
@staticmethod
def _clip_update_norm(updates: jnp.ndarray, max_norm: float) -> jnp.ndarray:
    """Rescale an update vector so its L2 norm does not exceed ``max_norm``.

    Layer 4 of the warmup defense strategy: bounds the size of a single
    warmup step.

    Parameters
    ----------
    updates : array_like
        Parameter updates from the optimizer
    max_norm : float
        Maximum allowed L2 norm for the update vector

    Returns
    -------
    clipped_updates : array_like
        ``updates`` multiplied by ``min(1, max_norm / (norm + 1e-10))``.
        Updates already within the limit are multiplied by 1.

    Notes
    -----
    Branch-free (``jnp.minimum`` instead of a Python ``if``). The 1e-10 in
    the denominator guards against division by zero for an all-zero update.
    """
    shrink = jnp.minimum(1.0, max_norm / (jnp.linalg.norm(updates) + 1e-10))
    return updates * shrink
def _check_phase1_switch_criteria(
    self,
    iteration: int,
    current_loss: float,
    prev_loss: float,
    grad_norm: float,
) -> tuple[bool, str]:
    """Decide whether Phase 1 should hand over to Phase 2.

    Parameters
    ----------
    iteration : int
        Current iteration number
    current_loss : float
        Current loss value
    prev_loss : float
        Previous loss value
    grad_norm : float
        Current gradient norm

    Returns
    -------
    should_switch : bool
        Whether to switch to Phase 2
    reason : str
        Reason for switching (empty string if not switching)

    Notes
    -----
    Only criteria listed in ``config.active_switching_criteria`` are
    evaluated, in this order (first match wins):

    - 'max_iter': ``iteration >= config.max_warmup_iterations``
    - 'gradient': ``grad_norm < config.gradient_norm_threshold``
    - 'plateau': relative loss change below ``config.loss_plateau_threshold``
    """
    cfg = self.config
    enabled = cfg.active_switching_criteria

    if "max_iter" in enabled and iteration >= cfg.max_warmup_iterations:
        return True, "Maximum warmup iterations reached"

    if "gradient" in enabled and grad_norm < cfg.gradient_norm_threshold:
        return True, "Gradient norm below threshold"

    if "plateau" not in enabled:
        return False, ""

    # Relative change |L_t - L_{t-1}| / (|L_{t-1}| + eps); eps keeps the
    # ratio defined when the previous loss is exactly zero
    eps = jnp.finfo(jnp.float64).eps
    loss_delta = jnp.abs(current_loss - prev_loss)
    relative_change = loss_delta / (jnp.abs(prev_loss) + eps)
    if relative_change < cfg.loss_plateau_threshold:
        return True, "Loss plateau detected"

    return False, ""
# =========================================================================
# Phase 1 Helper Methods (extracted for complexity reduction)
# =========================================================================
def _check_warm_start_and_record(
    self,
    current_params: jnp.ndarray,
    loss_fn: Callable,
    x_data: jnp.ndarray,
    y_data: jnp.ndarray,
    telemetry: Any,
) -> tuple[bool, dict[str, Any] | None, float, float]:
    """Evaluate the starting point and detect a warm start (Layer 1 defense).

    Computes the initial loss and its ratio to the variance of ``y_data``.
    If warm start detection is enabled and that ratio is below
    ``config.warm_start_threshold``, the warmup is reported as skippable.

    Parameters
    ----------
    current_params : jnp.ndarray
        Current parameters in normalized space.
    loss_fn : Callable
        Loss function, called as ``loss_fn(current_params, x_data, y_data)``.
    x_data : jnp.ndarray
        Input data.
    y_data : jnp.ndarray
        Output data.
    telemetry : DefenseLayerTelemetry
        Telemetry recorder; ``record_layer1_trigger`` is called on detection.

    Returns
    -------
    should_exit : bool
        True if a warm start was detected.
    result : dict or None
        Phase 1 result dict when ``should_exit`` is True, otherwise None.
    initial_loss : float
        Loss at ``current_params``.
    relative_loss : float
        ``initial_loss / (Var(y_data) + 1e-10)``.

    Notes
    -----
    Always stores ``_warmup_initial_loss`` and ``_warmup_relative_loss`` on
    the instance. On detection a record is appended to ``phase_history``.
    """
    cfg = self.config

    initial_loss = float(loss_fn(current_params, x_data, y_data))
    y_variance = float(jnp.var(y_data))
    # Small offset keeps the ratio finite for constant y_data
    relative_loss = initial_loss / (y_variance + 1e-10)

    # Kept on the instance for the later cost-increase guard
    self._warmup_initial_loss = initial_loss
    self._warmup_relative_loss = relative_loss

    if cfg.verbose >= 2:
        _logger.debug(
            f"Phase 1 initial assessment: loss={initial_loss:.6e}, "
            f"y_var={y_variance:.6e}, relative_loss={relative_loss:.6e}"
        )

    is_warm_start = (
        cfg.enable_warm_start_detection
        and relative_loss < cfg.warm_start_threshold
    )
    if not is_warm_start:
        return False, None, initial_loss, relative_loss

    threshold = cfg.warm_start_threshold
    telemetry.record_layer1_trigger(relative_loss=relative_loss, threshold=threshold)

    self.phase_history.append(
        {
            "phase": 1,
            "name": "lbfgs_warmup",
            "iterations": 0,
            "final_loss": initial_loss,
            "best_loss": initial_loss,
            "switch_reason": (
                f"Warm start detected (relative_loss={relative_loss:.4e} "
                f"< {threshold})"
            ),
            "timestamp": time.time(),
            "skipped": True,
            "warm_start": True,
            "relative_loss": relative_loss,
        }
    )

    if cfg.verbose >= 1:
        _logger.info(
            f"Phase 1: Skipping L-BFGS warmup - warm start detected "
            f"(relative_loss={relative_loss:.4e})"
        )

    # No optimization happened, so final and best both equal the start
    warm_result = {
        "final_params": current_params,
        "best_params": current_params,
        "best_loss": initial_loss,
        "final_loss": initial_loss,
        "iterations": 0,
        "switch_reason": "Warm start detected - skipping L-BFGS warmup",
        "warm_start": True,
        "relative_loss": relative_loss,
    }
    return True, warm_result, initial_loss, relative_loss
def _select_adaptive_step_size(
    self,
    relative_loss: float,
    telemetry: Any,
) -> tuple[float, str]:
    """Choose the initial L-BFGS step size (Layer 2 defense).

    Parameters
    ----------
    relative_loss : float
        Relative loss (loss / y_variance).
    telemetry : DefenseLayerTelemetry
        Telemetry recorder; ``record_layer2_lr_mode`` is always called.

    Returns
    -------
    initial_step : float
        Selected initial step size.
    lr_mode : str
        Mode name ('refinement', 'careful', 'exploration', 'fixed').

    Notes
    -----
    With ``config.enable_adaptive_warmup_lr``:

    - ``relative_loss < 0.1`` -> 'refinement', ``config.lbfgs_refinement_step_size``
    - ``relative_loss < 1.0`` -> 'careful', fixed step 0.5
    - otherwise -> 'exploration', ``config.lbfgs_exploration_step_size``

    Without it the mode is 'fixed' and ``config.lbfgs_initial_step_size``
    is used. The chosen mode is stored in ``self._warmup_lr_mode``.
    """
    cfg = self.config
    adaptive = cfg.enable_adaptive_warmup_lr

    if not adaptive:
        initial_step, lr_mode = cfg.lbfgs_initial_step_size, "fixed"
    elif relative_loss < 0.1:
        initial_step, lr_mode = cfg.lbfgs_refinement_step_size, "refinement"
    elif relative_loss < 1.0:
        initial_step, lr_mode = 0.5, "careful"
    else:
        initial_step, lr_mode = cfg.lbfgs_exploration_step_size, "exploration"

    # Bookkeeping is identical for every mode
    self._warmup_lr_mode = lr_mode
    telemetry.record_layer2_lr_mode(mode=lr_mode, relative_loss=relative_loss)

    # Only the adaptive path emits a debug line
    if adaptive and cfg.verbose >= 2:
        _logger.debug(
            f"Phase 1 L-BFGS adaptive step: mode={lr_mode}, step={initial_step:.2f}, "
            f"relative_loss={relative_loss:.4e}"
        )

    return initial_step, lr_mode
def _check_cost_increase_guard(
    self,
    iteration: int,
    loss_value: float,
    best_loss: float,
    best_params: jnp.ndarray,
    lr_mode: str,
    relative_loss: float,
    telemetry: Any,
) -> tuple[bool, dict[str, Any] | None]:
    """Abort warmup when the loss has grown too far (Layer 3 defense).

    The guard fires when ``loss_value / _warmup_initial_loss`` exceeds
    ``1 + config.cost_increase_tolerance``. It is inactive when disabled in
    the config, on iteration 0 (or below), and when the initial loss is 0.

    Parameters
    ----------
    iteration : int
        Current iteration number.
    loss_value : float
        Current loss value.
    best_loss : float
        Best loss seen so far.
    best_params : jnp.ndarray
        Best parameters seen so far.
    lr_mode : str
        Learning rate mode.
    relative_loss : float
        Initial relative loss.
    telemetry : DefenseLayerTelemetry
        Telemetry recorder; ``record_layer3_trigger`` is called on trigger.

    Returns
    -------
    should_exit : bool
        True if cost guard triggered.
    result : dict or None
        Phase 1 result dict when triggered (with ``best_params`` as the
        final parameters), otherwise None.
    """
    cfg = self.config
    if not cfg.enable_cost_guard or iteration <= 0:
        return False, None

    baseline = self._warmup_initial_loss
    if baseline == 0:
        return False, None

    cost_increase_ratio = loss_value / baseline
    cost_threshold = 1.0 + cfg.cost_increase_tolerance
    # Written as "not >" so a NaN ratio does not trigger the guard
    if not cost_increase_ratio > cost_threshold:
        return False, None

    telemetry.record_layer3_trigger(
        cost_ratio=cost_increase_ratio,
        tolerance=cfg.cost_increase_tolerance,
        iteration=iteration,
    )

    iterations_done = iteration + 1
    if cfg.verbose >= 1:
        _logger.warning(
            f"Phase 1: Cost increase guard triggered at iteration "
            f"{iterations_done}. Loss {loss_value:.6e} > "
            f"{baseline:.6e} * {cost_threshold:.2f}. "
            f"Reverting to best params (loss={best_loss:.6e})."
        )

    self.phase_history.append(
        {
            "phase": 1,
            "name": "lbfgs_warmup",
            "iterations": iterations_done,
            "final_loss": loss_value,
            "best_loss": best_loss,
            "switch_reason": (
                f"Cost increase guard triggered (ratio={cost_increase_ratio:.4f})"
            ),
            "timestamp": time.time(),
            "cost_guard_triggered": True,
            "lr_mode": lr_mode,
            "relative_loss": relative_loss,
        }
    )

    # The best parameters are handed back as the final ones
    guard_result = {
        "final_params": best_params,
        "best_params": best_params,
        "best_loss": best_loss,
        "final_loss": loss_value,
        "iterations": iterations_done,
        "switch_reason": "Cost increase guard triggered",
        "cost_guard_triggered": True,
        "cost_increase_ratio": cost_increase_ratio,
        "lr_mode": lr_mode,
        "relative_loss": relative_loss,
    }
    return True, guard_result
def _build_phase1_result(
    self,
    final_params: jnp.ndarray,
    best_params: jnp.ndarray,
    best_loss: float,
    final_loss: float,
    iterations: int,
    switch_reason: str,
    lr_mode: str,
    relative_loss: float,
    record_history: bool = True,
    **extra_fields: Any,
) -> dict[str, Any]:
    """Assemble the Phase 1 result dict, optionally logging it to history.

    Parameters
    ----------
    final_params : jnp.ndarray
        Final parameters.
    best_params : jnp.ndarray
        Best parameters found.
    best_loss : float
        Best loss value.
    final_loss : float
        Final loss value.
    iterations : int
        Number of iterations performed.
    switch_reason : str
        Reason for switching/ending.
    lr_mode : str
        Learning rate mode used.
    relative_loss : float
        Initial relative loss.
    record_history : bool, optional
        When True (default), a timestamped summary record (without the
        parameter arrays) is appended to ``self.phase_history``.
    **extra_fields : Any
        Additional fields merged into the result; they override the
        standard keys on collision.

    Returns
    -------
    result : dict
        Phase 1 result dictionary.
    """
    # Fields shared by the history record and the returned result
    summary = {
        "iterations": iterations,
        "final_loss": final_loss,
        "best_loss": best_loss,
        "switch_reason": switch_reason,
        "lr_mode": lr_mode,
        "relative_loss": relative_loss,
    }

    if record_history:
        self.phase_history.append(
            {
                "phase": 1,
                "name": "lbfgs_warmup",
                "iterations": summary["iterations"],
                "final_loss": summary["final_loss"],
                "best_loss": summary["best_loss"],
                "switch_reason": summary["switch_reason"],
                "timestamp": time.time(),
                "lr_mode": summary["lr_mode"],
                "relative_loss": summary["relative_loss"],
            }
        )

    outcome = {
        "final_params": final_params,
        "best_params": best_params,
        "best_loss": best_loss,
        "final_loss": final_loss,
        "iterations": iterations,
        "switch_reason": switch_reason,
        "lr_mode": lr_mode,
        "relative_loss": relative_loss,
    }
    outcome.update(extra_fields)
    return outcome
def _run_phase1_warmup(
    self,
    data_source: Any,
    model: Callable,
    p0: jnp.ndarray,
    bounds: tuple[jnp.ndarray, jnp.ndarray] | None = None,
) -> dict[str, Any]:
    """Run Phase 1 L-BFGS warmup.

    Parameters
    ----------
    data_source : tuple
        Data as (x_data, y_data); other source types are not supported here
    model : Callable
        User model function
    p0 : array_like
        Initial parameter guess
    bounds : tuple of array_like, optional
        Parameter bounds

    Returns
    -------
    result : dict
        Phase 1 result with keys:
        - 'final_params': Final parameters in normalized space
        - 'best_params': Parameters at which 'best_loss' was evaluated
        - 'best_loss': Best loss value
        - 'final_loss': Last evaluated loss value
        - 'iterations': Number of iterations performed
        - 'switch_reason': Reason for switching to Phase 2
        plus mode/diagnostic fields depending on the exit path.

    Raises
    ------
    NotImplementedError
        If ``data_source`` is not a 2-tuple.

    Notes
    -----
    Operates entirely in normalized parameter space and uses the full data
    set for every step.

    ``_lbfgs_step`` reports the loss at the parameters it was given, not at
    the parameters it returns. Best-parameter tracking therefore pairs each
    loss with the pre-step parameters, so that a bad step can never be
    recorded as "best" and later handed back by the cost-increase guard.
    """
    cfg = self.config

    # Setup normalization if not already done
    if self.normalizer is None:
        self._setup_normalization(model, p0, bounds)

    # Extract data from source (simple tuple for now)
    if isinstance(data_source, tuple) and len(data_source) == 2:
        x_data, y_data = data_source
        x_data = jnp.asarray(x_data, dtype=jnp.float64)
        y_data = jnp.asarray(y_data, dtype=jnp.float64)
    else:
        raise NotImplementedError(
            "Only tuple data sources (x_data, y_data) supported in Phase 1 warmup"
        )

    # Starting point: tournament winner if multi-start is active, else p0
    if cfg.enable_multistart and self.multistart_candidates is not None:
        current_params = self._run_tournament_selection(data_source, model)
    else:
        current_params = self.normalized_params

    # Create loss function FIRST (needed for warm start detection)
    loss_fn = self._create_warmup_loss_fn()

    telemetry = get_defense_telemetry()
    telemetry.record_warmup_start()

    # LAYER 1: Warm Start Detection
    should_exit, result, initial_loss, relative_loss = (
        self._check_warm_start_and_record(
            current_params=current_params,
            loss_fn=loss_fn,
            x_data=x_data,
            y_data=y_data,
            telemetry=telemetry,
        )
    )
    if should_exit:
        return result

    # LAYER 2: Adaptive Initial Step Size Selection for L-BFGS
    initial_step, lr_mode = self._select_adaptive_step_size(
        relative_loss=relative_loss,
        telemetry=telemetry,
    )

    optimizer, opt_state = self._create_lbfgs_optimizer(
        current_params,
        initial_step_size=initial_step,
    )

    # Best parameter tracking
    best_params = current_params
    best_loss = initial_loss
    prev_loss = initial_loss
    # Defined up front so the final return is valid even when the loop
    # body never runs (max_warmup_iterations == 0)
    loss_value = initial_loss

    # Reset clip counter for diagnostics
    self._warmup_clip_count = 0

    # Checkpoint settings are optional on the config; resolve them once
    checkpoint_enabled = getattr(cfg, "enable_checkpoints", False)
    checkpoint_frequency = getattr(cfg, "checkpoint_frequency", None)
    checkpoint_dir = getattr(cfg, "checkpoint_dir", None)
    do_checkpoints = bool(
        checkpoint_enabled and checkpoint_frequency and checkpoint_dir
    )

    for iteration in range(cfg.max_warmup_iterations):
        # Remember where this step started: the loss returned by
        # _lbfgs_step belongs to these parameters
        params_before_step = current_params

        current_params, loss_value, grad_norm, opt_state, _line_search_failed = (
            self._lbfgs_step(
                params=current_params,
                opt_state=opt_state,
                optimizer=optimizer,
                loss_fn=loss_fn,
                x_batch=x_data,
                y_batch=y_data,
                iteration=iteration,
            )
        )

        # Track best parameters (loss and parameters kept consistent)
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = params_before_step

        # LAYER 3: Cost-Increase Guard
        cost_guard_exit, cost_guard_result = self._check_cost_increase_guard(
            iteration=iteration,
            loss_value=loss_value,
            best_loss=best_loss,
            best_params=best_params,
            lr_mode=lr_mode,
            relative_loss=relative_loss,
            telemetry=telemetry,
        )
        if cost_guard_exit:
            return cost_guard_result

        # Save checkpoint periodically if enabled
        if do_checkpoints and (iteration + 1) % checkpoint_frequency == 0:
            checkpoint_path = (
                Path(checkpoint_dir) / f"checkpoint_phase1_iter{iteration + 1}.h5"
            )
            self.current_phase = 1
            self.normalized_params = current_params
            self._save_checkpoint(checkpoint_path)

        # Check switch criteria after minimum warmup iterations
        if iteration >= cfg.warmup_iterations:
            should_switch, reason = self._check_phase1_switch_criteria(
                iteration=iteration,
                current_loss=loss_value,
                prev_loss=prev_loss,
                grad_norm=grad_norm,
            )
            if should_switch:
                return self._build_phase1_result(
                    final_params=current_params,
                    best_params=best_params,
                    best_loss=best_loss,
                    final_loss=loss_value,
                    iterations=iteration + 1,
                    switch_reason=reason,
                    lr_mode=lr_mode,
                    relative_loss=relative_loss,
                )

        prev_loss = loss_value

    # Loop exhausted without any switch criterion firing
    return self._build_phase1_result(
        final_params=current_params,
        best_params=best_params,
        best_loss=best_loss,
        final_loss=loss_value,
        iterations=cfg.max_warmup_iterations,
        switch_reason="Maximum iterations reached",
        lr_mode=lr_mode,
        relative_loss=relative_loss,
    )
def _compute_jacobian_chunk(
    self,
    x_chunk: jnp.ndarray,
    params: jnp.ndarray,
) -> jnp.ndarray:
    """Return the model Jacobian with respect to ``params`` for a data chunk.

    Row i holds the derivatives of the model output at ``x_chunk[i]`` with
    respect to each parameter, obtained by automatic differentiation.

    Parameters
    ----------
    x_chunk : array_like
        Independent variable chunk; the leading axis indexes data points
    params : array_like
        Parameters in normalized space of shape (n_params,)

    Returns
    -------
    J_chunk : array_like
        Jacobian matrix, one row per data point

    Notes
    -----
    If ``self._jacobian_fn_compiled`` is set it is called as
    ``fn(params, x_chunk)`` and its result returned as-is. Otherwise the
    Jacobian is built on the fly with ``jax.jacrev`` mapped over the data
    points via ``jax.vmap``, differentiating ``self.normalized_model``.
    """
    compiled = self._jacobian_fn_compiled
    if compiled is not None:
        return compiled(params, x_chunk)

    # Fallback path: build the per-point Jacobian function on each call
    normalized_model = self.normalized_model

    def point_prediction(p, x_single):
        return normalized_model(x_single, *p)

    per_point_jacobian = jax.jacrev(point_prediction, argnums=0)
    # Parameters are shared across points; only x is mapped
    return jax.vmap(per_point_jacobian, in_axes=(None, 0))(params, x_chunk)
def _accumulate_jtj_jtr(
    self,
    x_chunk: jnp.ndarray,
    y_chunk: jnp.ndarray,
    params: jnp.ndarray,
    JTJ_prev: jnp.ndarray,
    JTr_prev: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, float]:
    """Fold one data chunk into the running J^T J and J^T r sums.

    Only the (n_params x n_params) and (n_params,) products are carried
    between chunks, so the full Jacobian never has to be held at once.

    Parameters
    ----------
    x_chunk : array_like
        Independent-variable values for this chunk.
    y_chunk : array_like
        Observed values for this chunk.
    params : array_like
        Current parameter vector, unpacked into ``self.normalized_model``.
    JTJ_prev : array_like
        J^T J accumulated over the chunks processed so far.
    JTr_prev : array_like
        J^T r accumulated over the chunks processed so far.

    Returns
    -------
    JTJ_new : array_like
        ``JTJ_prev + J_chunk.T @ J_chunk``.
    JTr_new : array_like
        ``JTr_prev + J_chunk.T @ r_chunk`` with
        ``r_chunk = y_chunk - model(x_chunk, *params)``.
    residual_sum_sq : float
        Sum of squared residuals of this chunk only.

    Notes
    -----
    As a side effect the updated sums are stored on
    ``self.phase2_JTJ_accumulator`` and ``self.phase2_JTr_accumulator``.
    """
    # Residuals of the chunk under the current parameters.
    resid = y_chunk - self.normalized_model(x_chunk, *params)

    jac = self._compute_jacobian_chunk(x_chunk, params)
    jac_t = jac.T

    updated_jtj = JTJ_prev + jac_t @ jac
    updated_jtr = JTr_prev + jac_t @ resid

    # Pulled to a Python float before touching instance state, matching the
    # order in which the values become available to callers.
    chunk_rss = float(jnp.sum(jnp.square(resid)))

    # Keep the latest running sums on the instance.
    self.phase2_JTJ_accumulator = updated_jtj
    self.phase2_JTr_accumulator = updated_jtr

    return updated_jtj, updated_jtr, chunk_rss
def _solve_gauss_newton_step(
    self,
    JTJ: jnp.ndarray,
    JTr: jnp.ndarray,
    trust_radius: float,
    regularization: float = 1e-10,
) -> tuple[jnp.ndarray, float]:
    """Compute a trust-region-limited Gauss-Newton step via SVD.

    Solves the damped normal equations ``(J^T J + lambda*I) delta = J^T r``
    using the SVD of the damped matrix, then shortens ``delta`` to the
    trust radius if it is longer.

    Parameters
    ----------
    JTJ : array_like
        Accumulated J^T J matrix of shape (n_params, n_params).
    JTr : array_like
        Accumulated J^T r vector of shape (n_params,), with
        ``r = y - f(x, p)``.
    trust_radius : float
        Maximum allowed Euclidean norm of the returned step.
    regularization : float, default=1e-10
        Tikhonov damping ``lambda`` added to the diagonal before the SVD.

    Returns
    -------
    step : array_like
        Step ``delta`` of shape (n_params,) with ``||delta|| <= trust_radius``.
    predicted_reduction : float
        ``JTr . delta - 0.5 * delta^T (J^T J) delta``, floored at zero. The
        undamped ``JTJ`` is used here.

    Notes
    -----
    Singular values smaller than ``1e-10 * max(s)`` are replaced by that
    floor before inversion, which bounds the amplification along
    near-degenerate directions.
    """
    dim = JTJ.shape[0]

    # Damped system matrix and its decomposition U diag(s) V^T.
    damped = JTJ + regularization * jnp.eye(dim, dtype=jnp.float64)
    left, sing, right_t = jnp.linalg.svd(damped, full_matrices=False)

    # Clamp tiny singular values relative to the largest one.
    floor = jnp.max(sing) * 1e-10
    sing_clamped = jnp.where(sing > floor, sing, floor)

    # delta = V diag(1/s) U^T (J^T r); note the right-hand side is +J^T r
    # because r is defined as y - f, so -gradient = J^T r.
    raw_step = right_t.T @ ((left.T @ JTr) / sing_clamped)

    # Pull the step back onto the trust-region boundary when it overshoots.
    raw_norm = jnp.linalg.norm(raw_step)
    shrink = jnp.where(raw_norm > trust_radius, trust_radius / raw_norm, 1.0)
    step = raw_step * shrink

    # Quadratic-model decrease for the (possibly shortened) step.
    model_gain = jnp.dot(JTr, step) - 0.5 * jnp.dot(step, JTJ @ step)
    predicted_reduction = jnp.maximum(model_gain, 0.0)
    return step, predicted_reduction
# =========================================================================
# CG-based Gauss-Newton Solver Methods (Task Group 3)
# =========================================================================
def _select_gn_solver(self, n_params: int) -> str:
    """Pick the Gauss-Newton linear solver for a given parameter count.

    Parameters
    ----------
    n_params : int
        Number of parameters in the optimization problem.

    Returns
    -------
    solver_type : str
        ``'materialized'`` when ``n_params`` is strictly below
        ``config.cg_param_threshold``, otherwise ``'cg'``.

    Notes
    -----
    The choice is logged at debug level when ``config.verbose >= 2``.
    """
    threshold = self.config.cg_param_threshold
    # Strictly-below comparison: a count equal to the threshold selects CG.
    solver_type = "materialized" if n_params < threshold else "cg"

    if self.config.verbose >= 2:
        _logger.debug(
            f"GN solver selection: p={n_params}, threshold={threshold}, "
            f"selected={solver_type}"
        )
    return solver_type
def _implicit_jtj_matvec(
    self,
    v: jnp.ndarray,
    params: jnp.ndarray,
    x_data: jnp.ndarray,
    y_data: jnp.ndarray,
) -> jnp.ndarray:
    """Evaluate ``(J^T J) @ v`` chunk by chunk without forming ``J^T J``.

    Parameters
    ----------
    v : array_like
        Vector to multiply, shape (n_params,).
    params : array_like
        Parameters at which the Jacobian is evaluated.
    x_data : array_like
        Full independent-variable data; sliced along its leading axis in
        pieces of ``config.chunk_size``.
    y_data : array_like
        Not read by this method; accepted so the signature matches the
        other data-driven helpers.

    Returns
    -------
    result : array_like
        Sum over chunks of ``J_chunk.T @ (J_chunk @ v)``, shape (n_params,).
    """
    stride = self.config.chunk_size
    total_points = len(x_data)

    product = jnp.zeros(len(v))
    for lo in range(0, total_points, stride):
        jac = self._compute_jacobian_chunk(x_data[lo : lo + stride], params)
        # Two thin products per chunk: J v first, then J^T applied to it.
        product = product + jac.T @ (jac @ v)
    return product
def _compute_jacobi_preconditioner(
    self,
    params: jnp.ndarray,
    x_data: jnp.ndarray,
    y_data: jnp.ndarray,
) -> jnp.ndarray:
    """Return the diagonal of ``J^T J`` for use as a Jacobi preconditioner.

    The diagonal is built as the sum over chunks of the squared Jacobian
    columns, ``diag[j] = sum_i J[i, j]**2``.

    Parameters
    ----------
    params : array_like
        Parameters at which the Jacobian is evaluated, shape (n_params,).
    x_data : array_like
        Full independent-variable data; sliced along its leading axis in
        pieces of ``config.chunk_size``.
    y_data : array_like
        Not read by this method; accepted for signature consistency.

    Returns
    -------
    diag_JTJ : array_like
        Diagonal of ``J^T J``, shape (n_params,), with every entry raised to
        at least ``config.regularization_factor``.
    """
    stride = self.config.chunk_size
    total_points = len(x_data)

    column_sq_sums = jnp.zeros(len(params))
    for lo in range(0, total_points, stride):
        jac = self._compute_jacobian_chunk(x_data[lo : lo + stride], params)
        column_sq_sums = column_sq_sums + jnp.sum(jnp.square(jac), axis=0)

    # Floor the entries so that dividing by them later cannot blow up.
    return jnp.maximum(column_sq_sums, self.config.regularization_factor)
def _cg_solve_implicit(
    self,
    JTr: jnp.ndarray,
    params: jnp.ndarray,
    x_data: jnp.ndarray,
    y_data: jnp.ndarray,
    trust_radius: float,
) -> tuple[jnp.ndarray, int, bool]:
    """Approximately solve ``(J^T J) step = J^T r`` with conjugate gradients.

    The matrix is never formed; every product with it goes through
    ``self._implicit_jtj_matvec``. The iteration starts from the zero
    vector and runs inside ``jax.lax.while_loop``.

    Parameters
    ----------
    JTr : array_like
        Right-hand side ``J^T r``, shape (n_params,).
    params : array_like
        Parameters at which the Jacobian is evaluated.
    x_data : array_like
        Full x data, forwarded to the implicit matvec.
    y_data : array_like
        Full y data, forwarded to the implicit matvec.
    trust_radius : float
        Maximum Euclidean norm of the returned step.

    Returns
    -------
    step : array_like
        Approximate solution, rescaled onto the trust-region boundary when
        its norm exceeds ``trust_radius``.
    iterations : int
        Number of CG iterations executed.
    converged : bool
        True when the residual norm fell below
        ``max(cg_relative_tolerance * ||J^T r||, cg_absolute_tolerance)``,
        or when the initial residual was already below the absolute
        tolerance (in which case no iteration runs).

    Notes
    -----
    Iteration stops at ``config.cg_max_iterations`` even without
    convergence; the partial solution is returned as-is.
    """
    dim = len(JTr)
    iteration_cap = self.config.cg_max_iterations
    rel_tol = self.config.cg_relative_tolerance
    abs_tol = self.config.cg_absolute_tolerance

    # With a zero initial guess the first residual equals the right-hand side.
    residual0 = JTr
    residual0_norm = jnp.linalg.norm(residual0)
    stop_tol = jnp.maximum(rel_tol * residual0_norm, abs_tol)

    def keep_iterating(carry):
        count, done = carry[4], carry[5]
        return jnp.logical_and(count < iteration_cap, jnp.logical_not(done))

    def advance(carry):
        sol, res, direction, res_sq, count, _ = carry

        a_dir = self._implicit_jtj_matvec(direction, params, x_data, y_data)

        # Curvature along the search direction, floored so that a zero or
        # negative value cannot produce a division blow-up.
        curvature = jnp.maximum(jnp.dot(direction, a_dir), 1e-15)
        step_len = res_sq / curvature

        next_sol = sol + step_len * direction
        next_res = res - step_len * a_dir
        next_res_sq = jnp.dot(next_res, next_res)

        done = jnp.sqrt(next_res_sq) < stop_tol

        # Standard CG direction update; the epsilon guards a zero denominator.
        mix = next_res_sq / (res_sq + 1e-15)
        next_direction = next_res + mix * direction

        return (next_sol, next_res, next_direction, next_res_sq, count + 1, done)

    # Carry layout: (solution, residual, direction, ||residual||^2, count, done).
    # Starting with done=True for an already-negligible right-hand side lets
    # the loop condition reject the first pass without a Python branch on a
    # JAX value.
    start = (
        jnp.zeros(dim),
        residual0,
        residual0,
        jnp.dot(residual0, residual0),
        jnp.array(0),
        residual0_norm < abs_tol,
    )
    sol, _, _, _, n_iter, done = jax.lax.while_loop(keep_iterating, advance, start)

    # Enforce the trust region on the concrete result.
    sol_norm = jnp.linalg.norm(sol)
    if sol_norm > trust_radius:
        sol = sol * (trust_radius / sol_norm)

    return sol, int(n_iter), bool(done)
def _solve_gauss_newton_step_cg(
    self,
    JTJ: jnp.ndarray,
    JTr: jnp.ndarray,
    trust_radius: float,
    params: jnp.ndarray,
    x_data: jnp.ndarray,
    y_data: jnp.ndarray,
) -> tuple[jnp.ndarray, float]:
    """Compute a Gauss-Newton step with the implicit-matvec CG solver.

    Counterpart of the SVD-based solve for problems where forming
    ``J^T J`` is undesirable.

    Parameters
    ----------
    JTJ : array_like
        Not read by this method; present for signature parity with the
        materialized solver.
    JTr : array_like
        Accumulated ``J^T r`` vector, shape (n_params,).
    trust_radius : float
        Trust region radius forwarded to the CG solve.
    params : array_like
        Parameters at which the Jacobian is evaluated.
    x_data : array_like
        Full x data.
    y_data : array_like
        Full y data.

    Returns
    -------
    step : array_like
        Step returned by ``self._cg_solve_implicit``.
    predicted_reduction : float
        ``JTr . step - 0.5 * step^T (J^T J) step`` floored at zero, with the
        matrix product evaluated through ``self._implicit_jtj_matvec``.

    Notes
    -----
    A non-converged CG result is used unchanged. When CG did not converge
    and used at least 90% of ``config.cg_max_iterations``, a warning is
    logged if ``config.verbose >= 1``.
    """
    step, cg_iterations, converged = self._cg_solve_implicit(
        JTr, params, x_data, y_data, trust_radius
    )

    if self.config.verbose >= 2:
        status = "converged" if converged else "incomplete"
        _logger.debug(f"CG solver: {cg_iterations} iterations, {status}")

    # Diagnostics only: the partial solution is kept either way.
    if (
        not converged
        and cg_iterations >= self.config.cg_max_iterations * 0.9
        and self.config.verbose >= 1
    ):
        _logger.warning(
            f"CG solver hit iteration limit ({cg_iterations}). "
            "Using incomplete solution as descent direction."
        )

    # Quadratic-model decrease, again without materializing J^T J.
    curvature_term = self._implicit_jtj_matvec(step, params, x_data, y_data)
    model_gain = jnp.dot(JTr, step) - 0.5 * jnp.dot(step, curvature_term)
    return step, float(jnp.maximum(model_gain, 0.0))
def _apply_trust_region(
    self,
    step: jnp.ndarray,
    trust_radius: float,
) -> jnp.ndarray:
    """Limit a step to the trust region by uniform rescaling.

    Parameters
    ----------
    step : array_like
        Proposed parameter step of shape (n_params,).
    trust_radius : float
        Trust region radius.

    Returns
    -------
    scaled_step : array_like
        ``step`` itself when ``||step|| <= trust_radius``; otherwise ``step``
        multiplied by ``trust_radius / ||step||`` so that it lands on the
        boundary.
    """
    length = jnp.linalg.norm(step)
    # The direction is preserved; only the length is ever reduced.
    return step if length <= trust_radius else step * (trust_radius / length)
def _gauss_newton_iteration(
    self,
    data_source: tuple[jnp.ndarray, jnp.ndarray],
    current_params: jnp.ndarray,
    trust_radius: float,
) -> dict[str, Any]:
    """Run a single Gauss-Newton iteration over the full dataset.

    The iteration accumulates ``J^T J``, ``J^T r`` and the cost at
    ``current_params``, optionally adds the group-variance penalty, solves
    for a step, evaluates the trial point and derives the next trust
    radius from the ratio of actual to predicted cost reduction.

    Parameters
    ----------
    data_source : tuple of array_like
        Full dataset as ``(x_data, y_data)``.
    current_params : array_like
        Parameters the step is taken from, shape (n_params,).
    trust_radius : float
        Trust region radius used for this step.

    Returns
    -------
    result : dict
        Dictionary with keys:

        - 'new_params': trial parameters (clipped to
          ``self.normalized_bounds`` when those are set)
        - 'new_cost': cost at the trial parameters
        - 'step': step returned by the solver (before any clipping)
        - 'actual_reduction': cost at ``current_params`` minus 'new_cost'
        - 'predicted_reduction': reduction predicted by the solver
        - 'trust_radius': radius proposed for the next iteration
        - 'gradient_norm': ``||J^T r||`` at ``current_params``

    Notes
    -----
    Accumulation uses ``self._accumulate_jtj_jtr_scan`` when
    ``self._use_scan_for_accumulation()`` is true and a Python loop over
    chunks of ``config.chunk_size`` otherwise. The trial point is returned
    whether or not it reduced the cost; accepting it is the caller's call.
    """
    x_data, y_data = data_source
    dim = len(current_params)
    n_points = len(x_data)
    chunk_size = self.config.chunk_size

    # --- accumulate normal-equation terms at the current point ------------
    if not self._use_scan_for_accumulation():
        JTJ = jnp.zeros((dim, dim))
        JTr = jnp.zeros(dim)
        total_cost = 0.0
        for lo in range(0, n_points, chunk_size):
            hi = lo + chunk_size
            JTJ, JTr, chunk_cost = self._accumulate_jtj_jtr(
                x_data[lo:hi], y_data[lo:hi], current_params, JTJ, JTr
            )
            total_cost += chunk_cost
    else:
        JTJ, JTr, total_cost = self._accumulate_jtj_jtr_scan(
            x_data, y_data, current_params
        )

    # --- optional penalty on the variance inside each parameter group -----
    if (
        self.config.enable_group_variance_regularization
        and self.config.group_variance_indices
    ):
        weight = self.config.group_variance_lambda
        for lo, hi in self.config.group_variance_indices:
            members = current_params[lo:hi]
            size = hi - lo

            # dVar/dp_i = (2/n) * (p_i - mean). JTr holds the negative
            # gradient, hence the minus sign on the contribution.
            variance_grad = (2.0 / size) * (members - jnp.mean(members))
            JTr = JTr.at[lo:hi].add(-weight * variance_grad)

            # Hessian of the variance: (2/n) * I - (2/n^2) * ones.
            variance_hess = (2.0 / size) * jnp.eye(size, dtype=jnp.float64) - (
                2.0 / (size * size)
            ) * jnp.ones((size, size), dtype=jnp.float64)
            JTJ = JTJ.at[lo:hi, lo:hi].add(weight * variance_hess)

            # Penalty contribution to the cost; a non-finite variance counts
            # as zero.
            variance = jnp.var(members)
            finite_variance = float(
                jnp.where(jnp.isfinite(variance), variance, 0.0)
            )
            total_cost += weight * finite_variance * n_points

    gradient_norm = float(jnp.linalg.norm(JTr))

    # --- solve for the step and evaluate the trial point -------------------
    step, predicted_reduction = self._solve_gauss_newton_step(
        JTJ, JTr, trust_radius
    )
    new_params = current_params + step
    if self.normalized_bounds is not None:
        lower, upper = self.normalized_bounds
        new_params = jnp.clip(new_params, lower, upper)

    new_cost = self._compute_cost_with_variance_regularization(
        new_params, x_data, y_data
    )
    actual_reduction = total_cost - new_cost

    # --- trust radius for the next iteration -------------------------------
    ratio = (
        actual_reduction / predicted_reduction if predicted_reduction > 0 else 0.0
    )
    radius_floor = getattr(self.config, "min_trust_radius", 1e-8)
    radius_cap = getattr(self.config, "max_trust_radius", 1000.0)
    step_norm = float(jnp.linalg.norm(step))

    updated_radius = trust_radius  # default: model agreed acceptably
    if ratio < 0.25:
        # Poor agreement between model and cost: halve the radius.
        updated_radius = trust_radius * 0.5
        # If that collapses the radius while the gradient is still large,
        # restart from a gradient-scaled radius instead of staying stuck.
        if updated_radius < radius_floor and gradient_norm > 1e-4:
            updated_radius = min(
                0.1 * gradient_norm / max(1.0, gradient_norm), 1.0
            )
    elif ratio > 0.75 and step_norm >= 0.9 * trust_radius:
        # Good agreement with the step near the boundary: double, capped.
        updated_radius = min(trust_radius * 2.0, radius_cap)

    updated_radius = max(updated_radius, radius_floor)

    return {
        "new_params": new_params,
        "new_cost": new_cost,
        "step": step,
        "actual_reduction": actual_reduction,
        "predicted_reduction": predicted_reduction,
        "trust_radius": updated_radius,
        "gradient_norm": gradient_norm,
    }
def _gn_iteration_with_retry(
    self,
    data_source: tuple[jnp.ndarray, jnp.ndarray],
    current_params: jnp.ndarray,
    trust_radius: float,
    best_params: jnp.ndarray,
    best_cost: float,
) -> tuple[dict[str, Any], float]:
    """Run one Gauss-Newton iteration, retrying on failure.

    An attempt fails when ``self._gauss_newton_iteration`` raises or when
    it returns non-finite parameters or cost. Each failed attempt that has
    retries left halves the trust radius and tries again.

    Parameters
    ----------
    data_source : tuple of array_like
        Full dataset as ``(x_data, y_data)``.
    current_params : array_like
        Parameters the step is taken from.
    trust_radius : float
        Trust region radius for the first attempt.
    best_params : array_like
        Best parameters found so far; used as the fallback result.
    best_cost : float
        Cost of ``best_params``; used as the fallback result.

    Returns
    -------
    iter_result : dict
        Result of the successful attempt, or a fallback carrying
        ``best_params``/``best_cost`` with 'gradient_norm' and
        'actual_reduction' set to 0.0.
    trust_radius : float
        Radius that was passed to the last attempt (after any halvings).

    Raises
    ------
    Exception
        Whatever the last attempt raised, when retries are exhausted and
        ``config.enable_fault_tolerance`` is missing or falsy.

    Notes
    -----
    The number of retries is ``config.max_retries_per_batch`` (0 when the
    attribute is absent).
    """
    retry_budget = getattr(self.config, "max_retries_per_batch", 0)

    def fallback_result():
        # Minimal result that makes the caller stay at the best known point.
        return {
            "new_params": best_params,
            "new_cost": best_cost,
            "gradient_norm": 0.0,
            "actual_reduction": 0.0,
            "trust_radius": trust_radius,
        }

    for attempt in range(retry_budget + 1):
        retries_left = attempt < retry_budget
        try:
            outcome = self._gauss_newton_iteration(
                data_source, current_params, trust_radius
            )
            trial_params = outcome["new_params"]
            trial_cost = outcome["new_cost"]

            if jnp.all(jnp.isfinite(trial_params)) and jnp.isfinite(trial_cost):
                return outcome, trust_radius

            if not retries_left:
                # Out of retries with a non-finite trial: overwrite the
                # trial fields in place and keep the remaining entries.
                outcome["new_params"] = best_params
                outcome["new_cost"] = best_cost
                outcome["gradient_norm"] = 0.0
                outcome["actual_reduction"] = 0.0
                return outcome, trust_radius
        except Exception:
            if not retries_left:
                if not getattr(self.config, "enable_fault_tolerance", False):
                    raise
                return fallback_result(), trust_radius

        # Reached only after a failed attempt with retries remaining.
        trust_radius *= 0.5

    # Only reachable when the loop body never ran (negative retry budget).
    return fallback_result(), trust_radius
def _should_save_checkpoint(self, iteration: int) -> bool:
    """Tell whether a checkpoint is due at the given iteration.

    Parameters
    ----------
    iteration : int
        Current iteration number (0-indexed).

    Returns
    -------
    should_save : bool
        True only when ``config.enable_checkpoints`` and
        ``config.checkpoint_dir`` are both truthy,
        ``config.checkpoint_frequency`` is positive, and ``iteration + 1``
        is a multiple of that frequency. Missing config attributes count
        as disabled.
    """
    cfg = self.config

    checkpointing_configured = getattr(
        cfg, "enable_checkpoints", False
    ) and getattr(cfg, "checkpoint_dir", None)
    if not checkpointing_configured:
        return False

    every = getattr(cfg, "checkpoint_frequency", 0)
    if every <= 0:
        return False

    # Iterations are 0-indexed, so the first checkpoint falls on index every-1.
    return (iteration + 1) % every == 0
def _run_phase2_gauss_newton(
    self,
    data_source: tuple[jnp.ndarray, jnp.ndarray],
    initial_params: jnp.ndarray,
) -> dict[str, Any]:
    """Run Phase 2 streaming Gauss-Newton optimization.

    Repeats trust-region Gauss-Newton iterations (through
    ``self._gn_iteration_with_retry``) until a convergence test passes or
    ``config.gauss_newton_max_iterations`` is reached. A trial step is
    accepted when it strictly reduces the cost; rejected trials leave the
    current parameters untouched.

    Parameters
    ----------
    data_source : tuple of array_like
        Full dataset as ``(x_data, y_data)``.
    initial_params : array_like
        Starting parameters, shape (n_params,).

    Returns
    -------
    result : dict
        Phase 2 result with keys:

        - 'final_params': on convergence, the last accepted parameters (the
          point 'JTJ_final' was evaluated at); at the iteration limit, the
          best parameters
        - 'best_params': best parameters found
        - 'best_cost': cost of 'best_params'
        - 'final_cost': cost associated with 'final_params'
        - 'iterations': number of Gauss-Newton iterations performed
        - 'convergence_reason': why the loop stopped
        - 'gradient_norm': gradient norm reported by the last iteration
          (``inf`` if no iteration ran)
        - 'JTJ_final': J^T J accumulated at the last accepted parameters
        - 'residual_sum_sq': residual sum of squares at those parameters

    Notes
    -----
    Convergence criteria, checked after every iteration:

    - gradient norm < ``config.gauss_newton_tol``
    - relative cost change < ``config.gauss_newton_tol``
    - maximum iterations reached

    The cost "before the step" is the cost of the last accepted point or,
    before any acceptance, ``new_cost + actual_reduction`` as reported by
    the iteration. Using that value for both accepted and rejected trials
    keeps a rejected first trial from producing a zero cost change (and a
    spurious convergence), and lets the starting point compete as the
    best point so that a worse rejected trial is never returned.

    A summary record is appended to ``self.phase_history`` on exit, and
    ``self.best_cost_global`` / ``self.best_params_global`` are updated
    whenever a new best cost also beats the global best.
    """
    x_data, y_data = data_source
    current_params = initial_params
    trust_radius = self.config.trust_region_initial

    n_params = len(current_params)
    n_points = len(x_data)
    chunk_size = self.config.chunk_size
    n_chunks = (n_points + chunk_size - 1) // chunk_size
    max_iterations = self.config.gauss_newton_max_iterations
    tol = self.config.gauss_newton_tol

    # Best point seen so far; the starting point is the initial candidate.
    best_params = current_params
    best_cost = jnp.inf
    # Cost of the last accepted point (inf until a step is accepted).
    prev_cost = jnp.inf
    # Defined up front so the final summary works even with zero iterations.
    gradient_norm = jnp.inf

    verbose = getattr(self.config, "verbose", 1)
    log_frequency = getattr(self.config, "log_frequency", 1)

    def accumulate_all_chunks(params, report_progress):
        """Accumulate J^T J and the residual sum of squares over all chunks."""
        jtj = jnp.zeros((n_params, n_params))
        jtr = jnp.zeros(n_params)
        rss = 0.0
        pass_start = time.time()
        progress_stride = max(1, n_chunks // 10)
        for chunk_idx, lo in enumerate(range(0, n_points, chunk_size)):
            jtj, jtr, chunk_rss = self._accumulate_jtj_jtr(
                x_data[lo : lo + chunk_size],
                y_data[lo : lo + chunk_size],
                params,
                jtj,
                jtr,
            )
            rss += chunk_rss
            done = chunk_idx + 1
            # Roughly every 10% of the chunks, and always on the last one.
            if report_progress and (done % progress_stride == 0 or done == n_chunks):
                elapsed = time.time() - pass_start
                pct = done / n_chunks * 100
                print(
                    f" Initial JTJ: {done}/{n_chunks} chunks "
                    f"({pct:.0f}%), elapsed={elapsed:.1f}s"
                )
        return jtj, rss

    def finish(reason, final_params, final_cost, iterations):
        """Record the Phase 2 summary and build the result dictionary."""
        self.phase_history.append(
            {
                "phase": 2,
                "name": "gauss_newton",
                "iterations": iterations,
                "final_cost": final_cost,
                "best_cost": best_cost,
                "convergence_reason": reason,
                "gradient_norm": gradient_norm,
                "timestamp": time.time(),
            }
        )
        return {
            "final_params": final_params,
            "best_params": best_params,
            "best_cost": best_cost,
            "final_cost": final_cost,
            "iterations": iterations,
            "convergence_reason": reason,
            "gradient_norm": gradient_norm,
            "JTJ_final": final_JTJ,
            "residual_sum_sq": final_residual_sum_sq,
        }

    # J^T J at the starting point, so Phase 3 has a matrix even if no step
    # is ever accepted.
    init_start_time = time.time()
    if verbose >= 1:
        print(
            f" Computing initial JTJ ({n_chunks} chunks, {n_points:,} points)..."
        )
    final_JTJ, final_residual_sum_sq = accumulate_all_chunks(
        current_params, verbose >= 1
    )
    if verbose >= 1:
        init_elapsed = time.time() - init_start_time
        print(
            f" Initial JTJ complete: cost={final_residual_sum_sq:.6e}, time={init_elapsed:.1f}s"
        )

    # Consecutive rejected trials, used for stall detection.
    self._consecutive_rejections = 0

    for iteration in range(max_iterations):
        iter_start_time = time.time()

        iter_result, trust_radius = self._gn_iteration_with_retry(
            data_source, current_params, trust_radius, best_params, best_cost
        )
        new_params = iter_result["new_params"]
        new_cost = iter_result["new_cost"]
        gradient_norm = iter_result["gradient_norm"]
        actual_reduction = iter_result["actual_reduction"]
        # The iteration already proposed the next radius; adopt it.
        trust_radius = iter_result["trust_radius"]

        iter_time = time.time() - iter_start_time

        if verbose >= 1 and (iteration + 1) % log_frequency == 0:
            print(
                f" GN iter {iteration + 1}/{self.config.gauss_newton_max_iterations}: "
                f"cost={new_cost:.6e}, grad_norm={gradient_norm:.6e}, "
                f"reduction={actual_reduction:.6e}, Δ={trust_radius:.4f}, "
                f"time={iter_time:.1f}s"
            )

        # Cost at current_params: the last accepted cost, or (before any
        # acceptance) reconstructed from this iteration's own bookkeeping.
        cost_before_step = (
            prev_cost if jnp.isfinite(prev_cost) else new_cost + actual_reduction
        )

        # Let the starting point compete: while nothing has been recorded,
        # best_params is still the starting point, so give it its cost.
        if not jnp.isfinite(best_cost) and jnp.isfinite(cost_before_step):
            best_cost = cost_before_step

        if new_cost < best_cost:
            best_cost = new_cost
            best_params = new_params
            if new_cost < self.best_cost_global:
                self.best_cost_global = new_cost
                self.best_params_global = new_params

        if self._should_save_checkpoint(iteration):
            checkpoint_path = (
                Path(self.config.checkpoint_dir)
                / f"checkpoint_phase2_iter{iteration + 1}.h5"
            )
            self.current_phase = 2
            self.normalized_params = current_params
            self._save_checkpoint(checkpoint_path)

        step_accepted = actual_reduction > 0
        if step_accepted:
            current_params = new_params
            prev_cost = new_cost
            self._consecutive_rejections = 0
            # Refresh J^T J so Phase 3 sees it at the accepted parameters.
            final_JTJ, final_residual_sum_sq = accumulate_all_chunks(
                current_params, False
            )
        else:
            # The radius was already adjusted by the iteration itself; only
            # the stall counter is maintained here.
            self._consecutive_rejections += 1
            if self._consecutive_rejections >= 10 and gradient_norm > 1e-4:
                # Many rejections with a large gradient: restart the radius.
                trust_radius = self.config.trust_region_initial
                self._consecutive_rejections = 0
                if verbose >= 1:
                    print(
                        f" Stall detected: resetting trust radius to "
                        f"{trust_radius:.4f}"
                    )

        # Cost of the point that would be reported if we stop now.
        current_cost = new_cost if step_accepted else cost_before_step

        if gradient_norm < tol:
            return finish(
                "Gradient norm below tolerance",
                current_params,
                current_cost,
                iteration + 1,
            )

        cost_change = abs(cost_before_step - new_cost)
        relative_change = cost_change / (abs(cost_before_step) + 1e-10)
        if relative_change < tol:
            return finish(
                "Cost change below tolerance",
                current_params,
                current_cost,
                iteration + 1,
            )

    # Iteration limit: report the best point. If nothing was ever accepted
    # there is no accepted cost, so fall back to the best known cost.
    limit_cost = prev_cost if jnp.isfinite(prev_cost) else best_cost
    return finish(
        "Maximum iterations reached",
        best_params,
        limit_cost,
        self.config.gauss_newton_max_iterations,
    )
def _denormalize_params(self, normalized_params: jnp.ndarray) -> jnp.ndarray:
    """Map parameters from normalized space back to original space.

    Parameters
    ----------
    normalized_params : array_like
        Parameters in normalized space of shape (n_params,).

    Returns
    -------
    params_original : array_like
        Result of ``self.normalizer.denormalize(normalized_params)``.

    Raises
    ------
    RuntimeError
        If ``self.normalizer`` is None, i.e. normalization has not been
        set up yet.

    Examples
    --------
    >>> # After optimization completes
    >>> normalized_result = jnp.array([0.6, 0.7])
    >>> original_result = optimizer._denormalize_params(normalized_result)
    """
    normalizer = self.normalizer
    if normalizer is not None:
        return normalizer.denormalize(normalized_params)

    raise RuntimeError(
        "Normalizer not initialized. Call _setup_normalization first."
    )
def _compute_normalized_covariance(self, JTJ: jnp.ndarray) -> jnp.ndarray:
"""Compute covariance matrix in normalized space from J^T J.
The covariance is the inverse of the Hessian approximation:
Cov_norm = (J^T J)^(-1)
Uses pseudo-inverse for numerical stability and to handle rank-deficient
or ill-conditioned matrices gracefully.
Parameters
----------
JTJ : array_like
Accumulated J^T J matrix from Phase 2 of shape (n_params, n_params)
Returns
-------
cov_norm : array_like
Covariance matrix in normalized space of shape (n_params, n_params)
Notes
-----
Uses jnp.linalg.pinv (pseudo-inverse) which is numerically stable and
handles singular matrices via SVD. The result is guaranteed to be
symmetric positive semi-definite.
The pseudo-inverse uses SVD: J^T J = U S V^T
Then: (J^T J)^(-1) = V S^(-1) U^T where small singular values are zeroed.
"""
# Compute pseudo-inverse for numerical stability
# This handles rank-deficient and ill-conditioned matrices
cov_norm = jnp.linalg.pinv(JTJ)
# Ensure symmetry (should already be symmetric, but enforce for numerical reasons)
cov_norm = 0.5 * (cov_norm + cov_norm.T)
return cov_norm
def _transform_covariance(self, cov_norm: jnp.ndarray) -> jnp.ndarray:
"""Transform covariance from normalized to original parameter space.
Uses the chain rule for covariance transformation:
Cov_orig = D @ Cov_norm @ D^T
where D is the denormalization Jacobian (stored in Phase 0).
For our diagonal normalization (element-wise scaling), D is diagonal:
D = diag(scales)
This simplifies to:
Cov_orig[i,j] = scale_i * scale_j * Cov_norm[i,j]
Parameters
----------
cov_norm : array_like
Covariance matrix in normalized space of shape (n_params, n_params)
Returns
-------
cov_orig : array_like
Covariance matrix in original space of shape (n_params, n_params)
Notes
-----
The transformation preserves:
- Symmetry: Cov_orig is symmetric if Cov_norm is symmetric
- Positive semi-definiteness: Eigenvalues remain non-negative
- Shape: (n_params, n_params)
Mathematical derivation:
Let p_orig = denormalize(p_norm) = D @ p_norm + offset
Then Jacobian of denormalization is D (constant, diagonal)
By chain rule: Cov(p_orig) = D @ Cov(p_norm) @ D^T
"""
if self.normalization_jacobian is None:
raise RuntimeError(
"Normalization Jacobian not available. Call _setup_normalization first."
)
# Get denormalization Jacobian (diagonal matrix with scales)
D = self.normalization_jacobian
# Transform covariance: Cov_orig = D @ Cov_norm @ D^T
cov_orig = D @ cov_norm @ D.T
# Ensure symmetry (should be preserved, but enforce for numerical stability)
cov_orig = 0.5 * (cov_orig + cov_orig.T)
return cov_orig
def _apply_residual_variance(
self,
cov_orig: jnp.ndarray,
residual_sum_sq: float,
n_points: int,
) -> tuple[jnp.ndarray, float]:
"""Apply residual variance scaling to covariance matrix.
Scales the covariance by the residual variance estimate:
sigma^2 = residual_sum_sq / (n - p)
Cov_final = sigma^2 * Cov_orig
This provides the final covariance estimate accounting for the
goodness-of-fit of the model to the data.
Parameters
----------
cov_orig : array_like
Covariance matrix in original space (before variance scaling)
of shape (n_params, n_params)
residual_sum_sq : float
Total sum of squared residuals from Phase 2
n_points : int
Total number of data points
Returns
-------
cov_final : array_like
Final scaled covariance matrix of shape (n_params, n_params)
sigma_sq : float
Residual variance estimate (sigma^2)
Notes
-----
The residual variance is computed as:
sigma^2 = sum(residuals^2) / (n - p)
where n is the number of data points and p is the number of parameters.
This is an unbiased estimator of the error variance.
Special cases:
- If n <= p: sigma^2 set to infinity (covariance undefined)
- If residual_sum_sq = 0: Perfect fit, sigma^2 = 0
This matches scipy.optimize.curve_fit behavior when absolute_sigma=False.
"""
n_params = cov_orig.shape[0]
# Compute degrees of freedom
dof = n_points - n_params
if dof <= 0:
# Not enough data points to estimate variance
# Set to infinity (covariance undefined)
sigma_sq = jnp.inf
cov_final = jnp.full_like(cov_orig, jnp.inf)
else:
# Compute residual variance
sigma_sq = residual_sum_sq / dof
# Scale covariance by residual variance
cov_final = sigma_sq * cov_orig
return cov_final, float(sigma_sq)
def _compute_standard_errors(self, pcov: jnp.ndarray) -> jnp.ndarray:
"""Compute standard errors from covariance matrix diagonal.
Standard errors are the square root of the diagonal elements of
the covariance matrix:
perr[i] = sqrt(pcov[i, i])
Parameters
----------
pcov : array_like
Covariance matrix of shape (n_params, n_params)
Returns
-------
perr : array_like
Standard errors of shape (n_params,)
Notes
-----
Standard errors represent the uncertainty in each parameter estimate
(1-sigma confidence interval).
Special cases:
- If pcov[i,i] < 0: Sets perr[i] = NaN (indicates numerical issues)
- If pcov[i,i] = inf: Sets perr[i] = inf (undefined uncertainty)
The standard errors can be used to compute confidence intervals:
95% CI: param ± 1.96 * perr
"""
# Extract diagonal elements
variances = jnp.diag(pcov)
# Compute standard errors (sqrt of variances)
# Handle negative variances by setting to NaN
perr = jnp.where(variances >= 0, jnp.sqrt(variances), jnp.nan)
return perr
def _run_phase3_finalize(
self,
optimized_params_normalized: jnp.ndarray,
JTJ_final: jnp.ndarray,
residual_sum_sq: float,
n_points: int,
) -> dict[str, Any]:
"""Run Phase 3: Denormalization and covariance transform.
This is the final phase that:
1. Denormalizes optimized parameters to original space
2. Computes covariance in normalized space from J^T J
3. Transforms covariance to original space
4. Applies residual variance scaling
5. Computes standard errors
Parameters
----------
optimized_params_normalized : array_like
Optimized parameters in normalized space from Phase 2
JTJ_final : array_like
Final accumulated J^T J matrix from Phase 2
residual_sum_sq : float
Total sum of squared residuals
n_points : int
Total number of data points
Returns
-------
result : dict
Phase 3 result with keys:
- 'popt': Optimized parameters in original space
- 'pcov': Full covariance matrix in original space
- 'perr': Standard errors (1-sigma)
- 'sigma_sq': Residual variance estimate
- 'diagnostics': Phase 3 diagnostics
Notes
-----
This method orchestrates all Phase 3 operations to produce
final results compatible with scipy.optimize.curve_fit:
- popt: Optimized parameters
- pcov: Covariance matrix for uncertainty analysis
- perr: Standard errors (not returned by scipy but useful)
The covariance transformation follows:
1. Cov_norm = (J^T J)^(-1) [in normalized space]
2. Cov_orig = D @ Cov_norm @ D^T [transform to original space]
3. Cov_final = sigma^2 * Cov_orig [scale by residual variance]
"""
phase_start = time.time()
# Step 1: Denormalize parameters to original space
popt = self._denormalize_params(optimized_params_normalized)
# Step 2: Compute covariance in normalized space
cov_norm = self._compute_normalized_covariance(JTJ_final)
# Step 3: Transform covariance to original space
cov_orig = self._transform_covariance(cov_norm)
# Step 4: Apply residual variance scaling
pcov, sigma_sq = self._apply_residual_variance(
cov_orig, residual_sum_sq, n_points
)
# Step 5: Compute standard errors
perr = self._compute_standard_errors(pcov)
# Record Phase 3 completion
phase_duration = time.time() - phase_start
phase_record = {
"phase": 3,
"name": "denormalization_covariance",
"duration": phase_duration,
"sigma_sq": sigma_sq,
"cov_condition": float(jnp.linalg.cond(pcov))
if jnp.isfinite(pcov).all()
else jnp.inf,
"timestamp": time.time(),
}
self.phase_history.append(phase_record)
return {
"popt": popt,
"pcov": pcov,
"perr": perr,
"sigma_sq": sigma_sq,
"diagnostics": phase_record,
}
def _validate_numerics(
self,
params: jnp.ndarray,
loss: float | None = None,
gradients: jnp.ndarray | None = None,
context: str = "",
) -> bool:
"""Validate numerical stability of parameters, loss, and gradients.
Parameters
----------
params : array_like
Parameters to validate
loss : float, optional
Loss value to validate
gradients : array_like, optional
Gradient values to validate
context : str, optional
Context string for error messages
Returns
-------
is_valid : bool
True if all values are finite, False otherwise
Raises
------
ValueError
If validate_numerics is enabled and non-finite values detected
"""
if (
not hasattr(self.config, "validate_numerics")
or not self.config.validate_numerics
):
return True
# Check parameters
if not jnp.all(jnp.isfinite(params)):
if (
hasattr(self.config, "enable_fault_tolerance")
and self.config.enable_fault_tolerance
):
# Log warning but continue
return False
else:
raise ValueError(f"NaN/Inf detected in parameters {context}")
# Check loss
if loss is not None and not jnp.isfinite(loss):
if (
hasattr(self.config, "enable_fault_tolerance")
and self.config.enable_fault_tolerance
):
return False
else:
raise ValueError(f"NaN/Inf detected in loss {context}")
# Check gradients
if gradients is not None and not jnp.all(jnp.isfinite(gradients)):
if (
hasattr(self.config, "enable_fault_tolerance")
and self.config.enable_fault_tolerance
):
return False
else:
raise ValueError(f"NaN/Inf detected in gradients {context}")
return True
def _get_checkpoint_manager(self):
"""Get or create the checkpoint manager (lazy initialization).
Returns
-------
CheckpointManager
The checkpoint manager instance.
"""
if self._checkpoint_manager is None:
if "CheckpointManager" not in _lazy_imports:
from nlsq.streaming.phases.checkpoint import CheckpointManager
_lazy_imports["CheckpointManager"] = CheckpointManager
self._checkpoint_manager = _lazy_imports["CheckpointManager"](self.config)
return self._checkpoint_manager
def _create_checkpoint_state(self):
    """Snapshot the optimizer's current state into a ``CheckpointState``.

    Each listed attribute of ``self`` is passed to ``CheckpointState`` as a
    keyword argument of the same name.

    Returns
    -------
    CheckpointState
        Container holding the phase number, normalized parameters, Phase 1
        optimizer state, Phase 2 accumulators, global best parameters/cost,
        phase history, normalizer, tournament selector and multi-start
        candidates.
    """
    from nlsq.streaming.phases.checkpoint import CheckpointState

    snapshot_fields = (
        "current_phase",
        "normalized_params",
        "phase1_optimizer_state",
        "phase2_JTJ_accumulator",
        "phase2_JTr_accumulator",
        "best_params_global",
        "best_cost_global",
        "phase_history",
        "normalizer",
        "tournament_selector",
        "multistart_candidates",
    )
    snapshot = {field: getattr(self, field) for field in snapshot_fields}
    return CheckpointState(**snapshot)
def _restore_from_checkpoint_state(self, state) -> None:
"""Restore optimizer state from a CheckpointState.
Parameters
----------
state : CheckpointState
State container to restore from.
"""
self.current_phase = state.current_phase
self.normalized_params = state.normalized_params
self.phase1_optimizer_state = state.phase1_optimizer_state
self.phase2_JTJ_accumulator = state.phase2_JTJ_accumulator
self.phase2_JTr_accumulator = state.phase2_JTr_accumulator
self.best_params_global = state.best_params_global
self.best_cost_global = state.best_cost_global
self.phase_history = state.phase_history
# Note: normalizer is NOT restored from checkpoint - recreated in _setup_normalization
if state.tournament_selector is not None:
self.tournament_selector = state.tournament_selector
if state.multistart_candidates is not None:
self.multistart_candidates = state.multistart_candidates
def _save_checkpoint(self, checkpoint_path: str | Path) -> None:
"""Save checkpoint with phase-specific state to HDF5 file.
Parameters
----------
checkpoint_path : str or Path
Path to checkpoint file (.h5)
Notes
-----
Checkpoint format version 3.0 includes:
- current_phase: Current phase number
- normalized_params: Parameters in normalized space
- phase1_optimizer_state: Optax L-BFGS state (history + params)
- phase2_jtj_accumulator: Accumulated J^T J matrix
- phase2_jtr_accumulator: Accumulated J^T r vector
- best_params_global: Best parameters found globally
- best_cost_global: Best cost value globally
- phase_history: Complete phase history
"""
manager = self._get_checkpoint_manager()
state = self._create_checkpoint_state()
manager.save(checkpoint_path, state)
def _load_checkpoint(self, checkpoint_path: str | Path) -> None:
"""Load checkpoint and restore phase-specific state.
Parameters
----------
checkpoint_path : str or Path
Path to checkpoint file (.h5)
Raises
------
FileNotFoundError
If checkpoint file does not exist
ValueError
If checkpoint version is incompatible
"""
manager = self._get_checkpoint_manager()
# Create GlobalOptimizationConfig for tournament reconstruction if needed
global_config = None
if self.config.enable_multistart:
global_config = GlobalOptimizationConfig(
n_starts=self.config.n_starts,
elimination_rounds=self.config.elimination_rounds,
elimination_fraction=self.config.elimination_fraction,
batches_per_round=self.config.batches_per_round,
)
state = manager.load(checkpoint_path, global_config)
self._restore_from_checkpoint_state(state)
def _detect_available_devices(self) -> dict[str, Any]:
"""Detect available GPU/TPU devices using JAX.
Returns
-------
device_info : dict
Dictionary with keys:
- 'device_count': Number of available devices (int)
- 'device_type': Type of devices ('cpu', 'gpu', 'tpu')
- 'devices': List of JAX device objects
Notes
-----
Uses jax.devices() to detect all available devices.
Device type is determined from the first device's platform.
Examples
--------
>>> optimizer = AdaptiveHybridStreamingOptimizer()
>>> info = optimizer._detect_available_devices()
>>> print(info['device_count'])
1
>>> print(info['device_type'])
'cpu'
"""
# Get all available devices
devices = jax.devices()
device_count = len(devices)
# Determine device type from first device
if device_count > 0:
platform = devices[0].platform
# Map JAX platform names to our device types
if platform == "cpu":
device_type = "cpu"
elif platform in ["gpu", "cuda", "rocm"]:
device_type = "gpu"
elif platform == "tpu":
device_type = "tpu"
else:
device_type = "cpu" # Default fallback
else:
device_type = "cpu"
device_info = {
"device_count": device_count,
"device_type": device_type,
"devices": devices,
}
# Store for later use
self.device_info = device_info
return device_info
def _should_use_multi_device(self, device_info: dict[str, Any]) -> bool:
"""Determine if multi-device should be used based on config and availability.
Parameters
----------
device_info : dict
Device information from _detect_available_devices()
Returns
-------
should_use : bool
True if multi-device should be used, False otherwise
Notes
-----
Multi-device is used only if:
1. config.enable_multi_device is True
2. More than 1 device is available
3. Devices are not CPU (CPU multi-device not beneficial)
"""
# Check if enabled in config
if not self.config.enable_multi_device:
return False
# Check if multiple devices available
if device_info["device_count"] <= 1:
return False
# Don't use multi-device for CPU (no benefit)
return device_info["device_type"] != "cpu"
def _setup_multi_device(self, device_info: dict[str, Any]) -> dict[str, Any]:
"""Setup multi-device configuration for data-parallel computation.
Parameters
----------
device_info : dict
Device information from _detect_available_devices()
Returns
-------
multi_device_config : dict
Configuration with keys:
- 'use_multi_device': Whether multi-device is enabled (bool)
- 'device_count': Number of devices to use (int)
- 'devices': List of JAX device objects
- 'axis_name': Axis name for pmap/psum ('devices')
Notes
-----
If multi-device cannot be used, returns configuration for single device.
Sets up axis_name='devices' for pmap and psum coordination.
Examples
--------
>>> optimizer = AdaptiveHybridStreamingOptimizer()
>>> device_info = optimizer._detect_available_devices()
>>> config = optimizer._setup_multi_device(device_info)
>>> print(config['use_multi_device'])
False # On single-device system
"""
should_use = self._should_use_multi_device(device_info)
if should_use:
# Multi-device configuration
multi_device_config = {
"use_multi_device": True,
"device_count": device_info["device_count"],
"devices": device_info["devices"],
"axis_name": "devices",
}
else:
# Single-device fallback
multi_device_config = {
"use_multi_device": False,
"device_count": 1,
"devices": device_info["devices"][:1] if device_info["devices"] else [],
"axis_name": None,
}
# Store for later use
self.multi_device_config = multi_device_config
return multi_device_config
def _pmap_jacobian_computation(
self,
x_chunks: list[jnp.ndarray],
params: jnp.ndarray,
) -> list[jnp.ndarray]:
"""Compute Jacobian across multiple devices using pmap (NOT IMPLEMENTED).
This method would use jax.pmap for data-parallel Jacobian computation
across multiple GPUs/TPUs. However, due to the complexity of properly
sharding data and handling device communication, we defer this to
future work.
Parameters
----------
x_chunks : list of array_like
Data chunks to distribute across devices
params : array_like
Parameters in normalized space
Returns
-------
J_chunks : list of array_like
Jacobian matrices computed on each device
Notes
-----
This is a placeholder for future multi-device support.
Currently falls back to single-device computation.
For proper pmap implementation, we would need:
1. Data sharding strategy matching device count
2. Replicated parameters across devices
3. pmap-compatible Jacobian function
4. Device mesh configuration
Raises
------
NotImplementedError
Always raised - pmap not yet implemented
"""
raise NotImplementedError(
"pmap Jacobian computation not yet implemented. "
"Use single-device computation via _compute_jacobian_chunk()."
)
def _aggregate_jtj_across_devices(self, JTJ_local: jnp.ndarray) -> jnp.ndarray:
"""Aggregate J^T J matrix across devices using psum.
On single device, this is a no-op and returns the input matrix unchanged.
On multi-device systems, this would use jax.lax.psum to sum matrices
across all devices.
Parameters
----------
JTJ_local : array_like
Local J^T J matrix of shape (n_params, n_params)
Returns
-------
JTJ_global : array_like
Aggregated J^T J matrix of shape (n_params, n_params)
Notes
-----
Single-device case: Returns input unchanged
Multi-device case (future): Would use jax.lax.psum with axis_name='devices'
Mathematical operation:
JTJ_global = sum_over_devices(JTJ_local)
For proper multi-device aggregation:
```python
if self.multi_device_config["use_multi_device"]:
JTJ_global = jax.lax.psum(JTJ_local, axis_name="devices")
else:
JTJ_global = JTJ_local
```
Examples
--------
>>> optimizer = AdaptiveHybridStreamingOptimizer()
>>> JTJ = jnp.array([[10.0, 2.0], [2.0, 8.0]])
>>> JTJ_agg = optimizer._aggregate_jtj_across_devices(JTJ)
>>> assert jnp.allclose(JTJ_agg, JTJ) # Same on single device
"""
# Check if multi-device is configured and enabled
if self.multi_device_config is not None and self.multi_device_config.get(
"use_multi_device", False
):
# Multi-device aggregation would use psum
# For now, we just return the local matrix (single-device fallback)
# Future implementation:
# return jax.lax.psum(JTJ_local, axis_name='devices')
# Log warning about fallback
import warnings
warnings.warn(
"Multi-device aggregation not yet fully implemented. "
"Falling back to single-device computation.",
UserWarning,
)
# Single-device case or fallback: return unchanged
return JTJ_local
# Public entry point: fit()
def fit(
self,
data_source: Any,
func: Callable,
p0: jnp.ndarray,
bounds: tuple[jnp.ndarray, jnp.ndarray] | None = None,
sigma: jnp.ndarray | None = None,
absolute_sigma: bool = False,
callback: Callable | None = None,
verbose: int = 1,
) -> dict[str, Any]:
"""Fit model parameters using four-phase hybrid optimization.
This method orchestrates all four phases:
- Phase 0: Setup normalization
- Phase 1: L-BFGS warmup
- Phase 2: Streaming Gauss-Newton
- Phase 3: Denormalization and covariance
Parameters
----------
data_source : various types
Data source for optimization. Can be:
- Tuple of arrays: (x_data, y_data)
- Generator yielding (x_batch, y_batch)
- HDF5 file path with datasets
func : Callable
Model function with signature: ``func(x, *params) -> predictions``
p0 : array_like
Initial parameter guess of shape (n_params,)
bounds : tuple of array_like, optional
Parameter bounds as (lb, ub)
sigma : array_like, optional
Uncertainties in y_data for weighted least squares
absolute_sigma : bool, default=False
If True, sigma is used in absolute sense (pcov not scaled)
callback : Callable, optional
Callback with signature callback(params, loss, iteration)
Called every config.callback_frequency iterations
verbose : int, default=1
Verbosity level (0=silent, 1=progress, 2=debug)
Returns
-------
result : dict
Optimization result dictionary with keys:
- 'x': Optimized parameters in original space
- 'success': Boolean indicating success
- 'message': Status message
- 'fun': Final residuals
- 'pcov': Covariance matrix (Phase 3)
- 'perr': Standard errors (Phase 3)
- 'streaming_diagnostics': Phase information, timing, etc.
Notes
-----
The result dictionary is compatible with scipy.optimize.curve_fit
and can be used interchangeably.
"""
# Track total optimization time
total_start_time = time.time()
# Phase timing storage
phase_timings = {}
phase_iterations = {}
# Convert p0 to JAX array
p0_array = jnp.asarray(p0, dtype=jnp.float64)
# Extract data from source (currently only supports tuple)
if isinstance(data_source, tuple) and len(data_source) == 2:
x_data, y_data = data_source
x_data = jnp.asarray(x_data, dtype=jnp.float64)
y_data = jnp.asarray(y_data, dtype=jnp.float64)
n_points = len(x_data)
else:
raise NotImplementedError(
"Only tuple data sources (x_data, y_data) currently supported"
)
# ============================================================
# Phase 0: Setup Normalization
# ============================================================
if verbose >= 1:
print("=" * 60)
print("Adaptive Hybrid Streaming Optimizer")
print("=" * 60)
print(f"Dataset size: {n_points:,} points")
print(f"Parameters: {len(p0_array)}")
print(f"Normalization: {self.config.normalization_strategy}")
print()
phase0_start = time.time()
self._setup_normalization(func, p0_array, bounds)
phase0_duration = time.time() - phase0_start
phase_timings["phase0_normalization"] = phase0_duration
if verbose >= 1:
print(f"Phase 0: Normalization setup complete ({phase0_duration:.3f}s)")
print(f" Strategy: {self.normalizer.strategy}")
print()
# ============================================================
# Phase 1: L-BFGS warmup
# ============================================================
if verbose >= 1:
print("Phase 1: L-BFGS warmup...")
phase1_start = time.time()
phase1_result = self._run_phase1_warmup(
data_source=(x_data, y_data),
model=func,
p0=p0_array,
bounds=bounds,
)
phase1_duration = time.time() - phase1_start
phase_timings["phase1_warmup"] = phase1_duration
phase_iterations["phase1"] = phase1_result["iterations"]
if verbose >= 1:
print(
f"Phase 1 complete: {phase1_result['iterations']} iterations ({phase1_duration:.3f}s)"
)
print(f" Best loss: {phase1_result['best_loss']:.6e}")
print(f" Switch reason: {phase1_result['switch_reason']}")
print()
# Use best parameters from Phase 1 as starting point for Phase 2
warmup_params = phase1_result["best_params"]
# ============================================================
# Phase 2: Streaming Gauss-Newton
# ============================================================
if verbose >= 1:
print("Phase 2: Streaming Gauss-Newton...")
phase2_start = time.time()
phase2_result = self._run_phase2_gauss_newton(
data_source=(x_data, y_data),
initial_params=warmup_params,
)
phase2_duration = time.time() - phase2_start
phase_timings["phase2_gauss_newton"] = phase2_duration
phase_iterations["phase2"] = phase2_result["iterations"]
if verbose >= 1:
print(
f"Phase 2 complete: {phase2_result['iterations']} iterations ({phase2_duration:.3f}s)"
)
print(f" Final cost: {phase2_result['final_cost']:.6e}")
print(f" Convergence: {phase2_result['convergence_reason']}")
print(f" Gradient norm: {phase2_result['gradient_norm']:.6e}")
print()
# ============================================================
# Phase 3: Denormalization and Covariance
# ============================================================
if verbose >= 1:
print("Phase 3: Computing covariance...")
phase3_start = time.time()
phase3_result = self._run_phase3_finalize(
optimized_params_normalized=phase2_result["final_params"],
JTJ_final=phase2_result["JTJ_final"],
residual_sum_sq=phase2_result["residual_sum_sq"],
n_points=n_points,
)
phase3_duration = time.time() - phase3_start
phase_timings["phase3_finalize"] = phase3_duration
if verbose >= 1:
print(f"Phase 3 complete ({phase3_duration:.3f}s)")
print(f" Residual variance (σ²): {phase3_result['sigma_sq']:.6e}")
print()
# ============================================================
# Assemble Final Result
# ============================================================
total_duration = time.time() - total_start_time
# Compute final residuals for 'fun' field
final_predictions = func(x_data, *phase3_result["popt"])
final_residuals = y_data - final_predictions
# Build streaming diagnostics
streaming_diagnostics = {
"phase_timings": phase_timings,
"phase_iterations": phase_iterations,
"total_time": total_duration,
"warmup_diagnostics": {
"best_loss": phase1_result["best_loss"],
"final_loss": phase1_result["final_loss"],
"switch_reason": phase1_result["switch_reason"],
},
"gauss_newton_diagnostics": {
"best_cost": phase2_result["best_cost"],
"final_cost": phase2_result["final_cost"],
"gradient_norm": phase2_result["gradient_norm"],
"convergence_reason": phase2_result["convergence_reason"],
},
"phase_history": self.phase_history,
}
# Format result dictionary (scipy-compatible + NLSQ extensions)
result = {
"x": phase3_result["popt"], # Optimized parameters
"success": True, # Always True if we reach here
"message": phase2_result["convergence_reason"],
"fun": final_residuals, # Final residuals
"pcov": phase3_result["pcov"], # Covariance matrix
"perr": phase3_result["perr"], # Standard errors (NLSQ extension)
"streaming_diagnostics": streaming_diagnostics, # NLSQ extension
}
if verbose >= 1:
print("=" * 60)
print("Optimization Complete")
print("=" * 60)
print(f"Total time: {total_duration:.3f}s")
print(f"Final parameters: {result['x']}")
print(f"Parameter std errors: {result['perr']}")
print("=" * 60)
return result
@property
def phase_status(self) -> dict[str, Any]:
    """Report the current phase together with the recorded phase history.

    Returns
    -------
    status : dict
        Keys:

        - 'current_phase': value of ``self.current_phase``
        - 'phase_name': human-readable label for that phase, or
          ``"Unknown"`` if the value is not one of 0-3
        - 'phase_history': ``self.phase_history`` (same object, not a copy)
        - 'total_phases': always 4

    Examples
    --------
    >>> config = HybridStreamingConfig()
    >>> optimizer = AdaptiveHybridStreamingOptimizer(config)
    >>> status = optimizer.phase_status
    >>> print(status['phase_name'])
    Phase 0: Normalization Setup
    """
    current = self.current_phase
    # Labels are keyed by phase number 0..3 in declaration order.
    labels = dict(
        enumerate(
            (
                "Phase 0: Normalization Setup",
                "Phase 1: L-BFGS Warmup",
                "Phase 2: Streaming Gauss-Newton",
                "Phase 3: Denormalization and Covariance",
            )
        )
    )
    return {
        "current_phase": current,
        "phase_name": labels.get(current, "Unknown"),
        "phase_history": self.phase_history,
        "total_phases": 4,
    }