# Source code for nlsq.streaming.phases.gauss_newton

"""Phase 2: Streaming Gauss-Newton optimization.

This module contains the GaussNewtonPhase class that encapsulates
the streaming Gauss-Newton logic for the AdaptiveHybridStreamingOptimizer.

The Gauss-Newton phase streams over the full dataset in chunks,
accumulating J^T J and J^T r explicitly, and solves the regularized
normal equations with an SVD-based, trust-region-limited step.
"""

from __future__ import annotations

import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import jax
import jax.numpy as jnp

from nlsq.streaming.large_dataset import get_bucket_size
from nlsq.utils.logging import get_logger

if TYPE_CHECKING:
    from jax import Array

    from nlsq.precision.parameter_normalizer import NormalizedModelWrapper
    from nlsq.streaming.hybrid_config import HybridStreamingConfig

# Module-level logger used for the Phase 2 progress messages emitted below.
_logger = get_logger("gauss_newton_phase")


def _pad_chunk_to_bucket(
    x_chunk: jnp.ndarray, y_chunk: jnp.ndarray, actual_size: int
) -> tuple[jnp.ndarray, jnp.ndarray, int]:
    """Zero-pad a chunk along its leading axis up to its bucket size.

    The bucket size is whatever ``get_bucket_size(actual_size)`` returns
    (described upstream as the nearest power of two -- TODO confirm against
    ``nlsq.streaming.large_dataset``). Padding every short chunk to a bucket
    keeps array shapes stable, which is meant to prevent XLA recompilation
    when the last chunk has fewer points than the configured chunk size.

    Parameters
    ----------
    x_chunk : array_like
        Independent variable chunk; the leading axis indexes data points and
        any number of trailing axes is allowed.
    y_chunk : array_like
        Dependent variable chunk; the leading axis indexes data points and
        any number of trailing axes is allowed.
    actual_size : int
        Number of valid (non-padded) points in the chunk.

    Returns
    -------
    x_padded : array_like
        ``x_chunk`` itself when no padding is needed, otherwise a new array
        of shape ``(bucket_size, *x_chunk.shape[1:])`` whose tail is zero.
    y_padded : array_like
        ``y_chunk`` itself when no padding is needed, otherwise a new array
        of shape ``(bucket_size, *y_chunk.shape[1:])`` whose tail is zero.
    actual_size : int
        Number of valid points (unchanged, so callers can mask the padding).

    Notes
    -----
    Assumes ``get_bucket_size`` never returns a value smaller than
    ``actual_size``; a smaller bucket would make the indexed update fail.
    """
    bucket_size = get_bucket_size(actual_size)
    if actual_size == bucket_size:
        # Already bucket-aligned: hand the inputs back untouched.
        return x_chunk, y_chunk, actual_size

    def _zero_extend(chunk: jnp.ndarray) -> jnp.ndarray:
        # Preserve dtype and every trailing dimension; only the leading
        # (data point) axis grows to the bucket size.
        padded = jnp.zeros((bucket_size, *chunk.shape[1:]), dtype=chunk.dtype)
        return padded.at[:actual_size].set(chunk)

    return _zero_extend(x_chunk), _zero_extend(y_chunk), actual_size


[docs] @dataclass(frozen=True, slots=True) class GNResult: """Result from streaming Gauss-Newton phase. Attributes: params: Final optimized parameters. cost: Final cost value. iterations: Number of GN iterations. converged: Whether GN converged. jacobian: Final Jacobian matrix (optional). cov: Parameter covariance matrix (optional). """ params: Array cost: float iterations: int converged: bool jacobian: Array | None = None cov: Array | None = None
class GaussNewtonPhase:
    """Phase 2: Streaming Gauss-Newton with chunk-accumulated J^T J.

    This class encapsulates the streaming Gauss-Newton logic for large
    dataset optimization. It streams over the full dataset in chunks,
    accumulating J^T J and J^T r, then solving the normal equations.

    Parameters
    ----------
    config : HybridStreamingConfig
        Configuration for streaming optimization.
    normalized_model : NormalizedModelWrapper
        Model wrapper operating in normalized parameter space.
    normalized_bounds : tuple of array_like or None
        Parameter bounds in normalized space.

    Attributes
    ----------
    config : HybridStreamingConfig
        Configuration object.
    normalized_model : NormalizedModelWrapper
        Normalized model wrapper.
    normalized_bounds : tuple of array_like or None
        Bounds in normalized space.
    phase2_JTJ_accumulator : array_like or None
        Most recent running J^T J sum (kept for checkpointing).
    phase2_JTr_accumulator : array_like or None
        Most recent running J^T r sum (kept for checkpointing).

    Notes
    -----
    The Gauss-Newton method iteratively solves::

        (J^T J) @ step = J^T r

    where J is the Jacobian and r is the residual vector.
    """

    def __init__(
        self,
        config: HybridStreamingConfig,
        normalized_model: NormalizedModelWrapper,
        normalized_bounds: tuple[jnp.ndarray, jnp.ndarray] | None = None,
    ) -> None:
        """Initialize GaussNewtonPhase.

        Parameters
        ----------
        config : HybridStreamingConfig
            Configuration for streaming optimization.
        normalized_model : NormalizedModelWrapper
            Model wrapper operating in normalized parameter space.
        normalized_bounds : tuple of array_like or None
            Parameter bounds in normalized space.
        """
        self.config = config
        self.normalized_model = normalized_model
        self.normalized_bounds = normalized_bounds

        # Pre-compiled functions (set externally or lazily initialized)
        self._jacobian_fn_compiled: Callable | None = None
        self._cost_fn_compiled: Callable | None = None

        # Accumulators for checkpointing
        self.phase2_JTJ_accumulator: jnp.ndarray | None = None
        self.phase2_JTr_accumulator: jnp.ndarray | None = None

        # Stall detection
        self._consecutive_rejections: int = 0

    def set_jacobian_fn(self, fn: Callable | None) -> None:
        """Set pre-compiled Jacobian function.

        The function is called as ``fn(params, x_chunk)``; passing ``None``
        restores the built-in ``jax.vmap``/``jax.jacrev`` fallback.
        """
        self._jacobian_fn_compiled = fn

    def set_cost_fn(self, fn: Callable | None) -> None:
        """Set pre-compiled cost function.

        The function is called as ``fn(params, x_chunk, y_chunk)`` on
        unpadded chunks; passing ``None`` restores the built-in computation.
        """
        self._cost_fn_compiled = fn

    def run(
        self,
        data_source: tuple[jnp.ndarray, jnp.ndarray],
        initial_params: jnp.ndarray,
        phase_history: list[dict[str, Any]],
        best_tracker: dict[str, Any],
    ) -> dict[str, Any]:
        """Run Phase 2 streaming Gauss-Newton optimization.

        Parameters
        ----------
        data_source : tuple of array_like
            Full dataset as (x_data, y_data).
        initial_params : array_like
            Starting parameters in normalized space (from Phase 1).
        phase_history : list
            Phase history list to append records to.
        best_tracker : dict
            Dictionary tracking best_params_global and best_cost_global.

        Returns
        -------
        result : dict
            Phase 2 optimization result with keys:

            - 'final_params': Final parameters in normalized space
            - 'best_params': Best parameters found
            - 'best_cost': Best cost achieved
            - 'final_cost': Final cost value
            - 'iterations': Number of Gauss-Newton iterations
            - 'convergence_reason': Why optimization stopped
            - 'gradient_norm': Final gradient norm
            - 'JTJ_final': Final accumulated J^T J matrix (for Phase 3)
            - 'residual_sum_sq': Total residual sum of squares (for Phase 3)
            - 'gn_result': GNResult dataclass instance

        Notes
        -----
        On convergence, 'final_params' and 'final_cost' describe the last
        *accepted* point, so they stay consistent with 'JTJ_final'; a
        rejected trial step is never returned as the final point.
        """
        current_params = initial_params
        trust_radius = self.config.trust_region_initial

        # Track best parameters
        best_params = current_params
        best_cost = jnp.inf

        x_data, y_data = data_source
        chunk_size = self.config.chunk_size
        n_points = len(x_data)
        max_iterations = self.config.gauss_newton_max_iterations
        tolerance = self.config.gauss_newton_tol

        verbose = getattr(self.config, "verbose", 1)
        log_frequency = getattr(self.config, "log_frequency", 1)

        # Compute initial JTJ with progress reporting
        n_chunks = (n_points + chunk_size - 1) // chunk_size
        init_start_time = time.time()

        if verbose >= 1:
            _logger.info(
                f" Computing initial JTJ ({n_chunks} chunks, {n_points:,} points)..."
            )

        def _report_initial_progress(chunk_idx: int) -> None:
            # Log roughly every 10% of the chunks and always on the last one.
            done = chunk_idx + 1
            if done % max(1, n_chunks // 10) == 0 or done == n_chunks:
                elapsed = time.time() - init_start_time
                pct = done / n_chunks * 100
                _logger.info(
                    f" Initial JTJ: {done}/{n_chunks} chunks "
                    f"({pct:.0f}%), elapsed={elapsed:.1f}s"
                )

        final_JTJ, _, final_residual_sum_sq = self._accumulate_full_pass(
            x_data,
            y_data,
            current_params,
            progress=_report_initial_progress if verbose >= 1 else None,
        )

        if verbose >= 1:
            init_elapsed = time.time() - init_start_time
            _logger.info(
                f" Initial JTJ complete: cost={final_residual_sum_sq:.6e}, "
                f"time={init_elapsed:.1f}s"
            )

        # Initialize stall detection
        self._consecutive_rejections = 0
        gradient_norm = 0.0
        # Cost at ``current_params``; reported if the loop never accepts a step.
        current_cost = float(final_residual_sum_sq)

        # Gauss-Newton loop
        for iteration in range(max_iterations):
            iter_start_time = time.time()

            # Perform one Gauss-Newton iteration
            iter_result = self._gauss_newton_iteration(
                data_source, current_params, trust_radius
            )
            new_params = iter_result["new_params"]
            new_cost = iter_result["new_cost"]
            gradient_norm = iter_result["gradient_norm"]
            actual_reduction = iter_result["actual_reduction"]
            trust_radius = iter_result["trust_radius"]
            # Cost at the point the step started from. Using this value
            # (rather than falling back to ``new_cost``) prevents a rejected
            # first step from producing a zero cost change and a false
            # "converged" verdict.
            cost_before_step = iter_result["current_cost"]

            iter_time = time.time() - iter_start_time

            # Progress logging
            if verbose >= 1 and (iteration + 1) % log_frequency == 0:
                _logger.info(
                    f" GN iter {iteration + 1}/{max_iterations}: "
                    f"cost={new_cost:.6e}, grad={gradient_norm:.6e}, "
                    f"red={actual_reduction:.6e}, Δ={trust_radius:.4f}, "
                    f"time={iter_time:.1f}s"
                )

            # Update best parameters. The starting point of the step is a
            # candidate too, so a worse rejected trial can never be "best".
            if cost_before_step < best_cost:
                best_cost = cost_before_step
                best_params = current_params
            if new_cost < best_cost:
                best_cost = new_cost
                best_params = new_params

            # Track global best
            if new_cost < best_tracker.get("best_cost_global", float("inf")):
                best_tracker["best_cost_global"] = new_cost
                best_tracker["best_params_global"] = new_params

            # Accept step if cost decreased
            if actual_reduction > 0:
                current_params = new_params
                current_cost = new_cost
                self._consecutive_rejections = 0

                # Recompute J^T J at new params for Phase 3
                final_JTJ, _, final_residual_sum_sq = self._accumulate_full_pass(
                    x_data, y_data, current_params
                )
            else:
                current_cost = cost_before_step
                self._consecutive_rejections += 1

                # Stall detection
                if self._consecutive_rejections >= 10 and gradient_norm > 1e-4:
                    trust_radius = self.config.trust_region_initial
                    self._consecutive_rejections = 0
                    if verbose >= 1:
                        _logger.info(
                            f" Stall detected: resetting trust radius to "
                            f"{trust_radius:.4f}"
                        )

            # Check convergence: gradient norm
            if gradient_norm < tolerance:
                return self._finalize_phase(
                    phase_history,
                    final_params=current_params,
                    final_cost=current_cost,
                    gn_cost=current_cost,
                    best_params=best_params,
                    best_cost=best_cost,
                    iterations=iteration + 1,
                    convergence_reason="Gradient norm below tolerance",
                    gradient_norm=gradient_norm,
                    JTJ_final=final_JTJ,
                    residual_sum_sq=final_residual_sum_sq,
                    converged=True,
                )

            # Check convergence: cost change
            cost_change = abs(cost_before_step - new_cost)
            relative_change = cost_change / (abs(cost_before_step) + 1e-10)
            if relative_change < tolerance:
                return self._finalize_phase(
                    phase_history,
                    final_params=current_params,
                    final_cost=current_cost,
                    gn_cost=current_cost,
                    best_params=best_params,
                    best_cost=best_cost,
                    iterations=iteration + 1,
                    convergence_reason="Cost change below tolerance",
                    gradient_norm=gradient_norm,
                    JTJ_final=final_JTJ,
                    residual_sum_sq=final_residual_sum_sq,
                    converged=True,
                )

        # Maximum iterations reached: report the best point found together
        # with the cost at the last accepted point.
        return self._finalize_phase(
            phase_history,
            final_params=best_params,
            final_cost=current_cost,
            gn_cost=best_cost,
            best_params=best_params,
            best_cost=best_cost,
            iterations=max_iterations,
            convergence_reason="Maximum iterations reached",
            gradient_norm=gradient_norm,
            JTJ_final=final_JTJ,
            residual_sum_sq=final_residual_sum_sq,
            converged=False,
        )

    def _finalize_phase(
        self,
        phase_history: list[dict[str, Any]],
        *,
        final_params: jnp.ndarray,
        final_cost: float,
        gn_cost: float,
        best_params: jnp.ndarray,
        best_cost: float,
        iterations: int,
        convergence_reason: str,
        gradient_norm: float,
        JTJ_final: jnp.ndarray,
        residual_sum_sq: float,
        converged: bool,
    ) -> dict[str, Any]:
        """Append the phase record to the history and build the result dict.

        ``gn_cost`` is the cost stored on the ``GNResult``; it differs from
        ``final_cost`` only when the iteration limit was reached.
        """
        phase_history.append(
            {
                "phase": 2,
                "name": "gauss_newton",
                "iterations": iterations,
                "final_cost": final_cost,
                "best_cost": best_cost,
                "convergence_reason": convergence_reason,
                "gradient_norm": gradient_norm,
                "timestamp": time.time(),
            }
        )
        gn_result = GNResult(
            params=final_params,
            cost=gn_cost,
            iterations=iterations,
            converged=converged,
        )
        return {
            "final_params": final_params,
            "best_params": best_params,
            "best_cost": best_cost,
            "final_cost": final_cost,
            "iterations": iterations,
            "convergence_reason": convergence_reason,
            "gradient_norm": gradient_norm,
            "JTJ_final": JTJ_final,
            "residual_sum_sq": residual_sum_sq,
            "gn_result": gn_result,
        }

    def _accumulate_full_pass(
        self,
        x_data: jnp.ndarray,
        y_data: jnp.ndarray,
        params: jnp.ndarray,
        progress: Callable | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray, float]:
        """Stream over the whole dataset and accumulate J^T J, J^T r and cost.

        Parameters
        ----------
        x_data : array_like
            Full x data.
        y_data : array_like
            Full y data.
        params : array_like
            Parameters in normalized space at which to linearize.
        progress : callable, optional
            Called as ``progress(chunk_idx)`` after each chunk is processed.

        Returns
        -------
        JTJ : array_like
            Accumulated J^T J of shape (n_params, n_params).
        JTr : array_like
            Accumulated J^T r of shape (n_params,).
        residual_sum_sq : float
            Sum of squared residuals over all valid points.
        """
        n_params = len(params)
        chunk_size = self.config.chunk_size

        JTJ = jnp.zeros((n_params, n_params))
        JTr = jnp.zeros(n_params)
        residual_sum_sq = 0.0

        for chunk_idx, start in enumerate(range(0, len(x_data), chunk_size)):
            x_chunk = x_data[start : start + chunk_size]
            y_chunk = y_data[start : start + chunk_size]
            # Pad short chunks to bucket size to prevent XLA recompilation
            x_chunk, y_chunk, valid_size = _pad_chunk_to_bucket(
                x_chunk, y_chunk, len(x_chunk)
            )
            JTJ, JTr, chunk_cost = self._accumulate_jtj_jtr(
                x_chunk, y_chunk, params, JTJ, JTr, valid_size=valid_size
            )
            residual_sum_sq += chunk_cost
            if progress is not None:
                progress(chunk_idx)

        return JTJ, JTr, residual_sum_sq

    def _gauss_newton_iteration(
        self,
        data_source: tuple[jnp.ndarray, jnp.ndarray],
        current_params: jnp.ndarray,
        trust_radius: float,
    ) -> dict[str, Any]:
        """Perform one complete Gauss-Newton iteration.

        Parameters
        ----------
        data_source : tuple of array_like
            Full dataset as (x_data, y_data).
        current_params : array_like
            Current parameters in normalized space.
        trust_radius : float
            Current trust region radius.

        Returns
        -------
        result : dict
            Iteration result with keys: new_params, new_cost, current_cost,
            step, actual_reduction, predicted_reduction, trust_radius,
            gradient_norm. ``current_cost`` is the cost at ``current_params``
            and ``new_cost`` the cost at the trial point ``new_params``.
        """
        x_data, y_data = data_source
        n_points = len(x_data)

        # Accumulate J^T J and J^T r
        JTJ, JTr, total_cost = self._accumulate_full_pass(
            x_data, y_data, current_params
        )

        # Add group variance regularization if enabled
        if (
            self.config.enable_group_variance_regularization
            and self.config.group_variance_indices
        ):
            var_lambda = self.config.group_variance_lambda
            # NOTE(review): the cost term below is scaled by n_points while
            # the gradient and Hessian terms are not -- confirm this scaling
            # difference is intended.
            for start, end in self.config.group_variance_indices:
                group_params = current_params[start:end]
                n_group = end - start
                group_mean = jnp.mean(group_params)

                # Gradient of the group variance w.r.t. the group parameters
                grad_var = (2.0 / n_group) * (group_params - group_mean)
                JTr = JTr.at[start:end].add(-var_lambda * grad_var)

                # Hessian of the group variance: (2/n) I - (2/n^2) 1 1^T
                diag_term = (2.0 / n_group) * jnp.eye(n_group)
                off_diag_term = (2.0 / (n_group * n_group)) * jnp.ones(
                    (n_group, n_group)
                )
                H_var = diag_term - off_diag_term
                JTJ = JTJ.at[start:end, start:end].add(var_lambda * H_var)

                group_var = jnp.var(group_params)
                total_cost += var_lambda * float(group_var) * n_points

        gradient_norm = float(jnp.linalg.norm(JTr))

        # Solve for Gauss-Newton step
        step, predicted_reduction = self._solve_gauss_newton_step(
            JTJ, JTr, trust_radius
        )

        # Apply step
        new_params = current_params + step

        # Clip to bounds if available
        if self.normalized_bounds is not None:
            lb, ub = self.normalized_bounds
            new_params = jnp.clip(new_params, lb, ub)

        # Evaluate cost at new parameters
        new_cost = self._compute_cost(new_params, x_data, y_data)

        # Compute actual reduction
        actual_reduction = total_cost - new_cost

        # Update trust region
        if predicted_reduction > 0:
            reduction_ratio = actual_reduction / predicted_reduction
        else:
            reduction_ratio = 0.0

        min_trust_radius = getattr(self.config, "min_trust_radius", 1e-8)
        max_trust_radius = getattr(self.config, "max_trust_radius", 1000.0)
        step_norm = float(jnp.linalg.norm(step))

        if reduction_ratio < 0.25:
            # Poor agreement with the quadratic model: shrink the region.
            new_trust_radius = trust_radius * 0.5
            if new_trust_radius < min_trust_radius and gradient_norm > 1e-4:
                # Radius collapsed while the gradient is still significant:
                # restart from a gradient-scaled radius instead.
                scaled_grad = 0.1 * gradient_norm / max(1.0, gradient_norm)
                new_trust_radius = min(scaled_grad, 1.0)
        elif reduction_ratio > 0.75 and step_norm >= 0.9 * trust_radius:
            # Good agreement and the step hit the boundary: expand.
            new_trust_radius = min(trust_radius * 2.0, max_trust_radius)
        else:
            new_trust_radius = trust_radius

        new_trust_radius = max(new_trust_radius, min_trust_radius)

        return {
            "new_params": new_params,
            "new_cost": new_cost,
            "current_cost": total_cost,
            "step": step,
            "actual_reduction": actual_reduction,
            "predicted_reduction": predicted_reduction,
            "trust_radius": new_trust_radius,
            "gradient_norm": gradient_norm,
        }

    def _accumulate_jtj_jtr(
        self,
        x_chunk: jnp.ndarray,
        y_chunk: jnp.ndarray,
        params: jnp.ndarray,
        JTJ_prev: jnp.ndarray,
        JTr_prev: jnp.ndarray,
        valid_size: int | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray, float]:
        """Accumulate J^T J and J^T r for a data chunk.

        Parameters
        ----------
        x_chunk : array_like
            Independent variable chunk (may be padded to bucket size).
        y_chunk : array_like
            Dependent variable chunk (may be padded to bucket size).
        params : array_like
            Current parameters in normalized space.
        JTJ_prev : array_like
            Previous accumulated J^T J.
        JTr_prev : array_like
            Previous accumulated J^T r.
        valid_size : int, optional
            Number of valid (non-padded) points. If None, all points are valid.

        Returns
        -------
        JTJ_new : array_like
            Updated J^T J accumulation.
        JTr_new : array_like
            Updated J^T r accumulation.
        chunk_cost : float
            Sum of squared residuals for this chunk.
        """
        # Compute predictions and residuals
        predictions = self.normalized_model(x_chunk, *params)
        residuals = y_chunk - predictions

        # Compute Jacobian for this chunk
        J_chunk = self._compute_jacobian_chunk(x_chunk, params)

        # Mask out padded elements if chunk was padded to bucket size
        if valid_size is not None and valid_size < len(x_chunk):
            mask = jnp.arange(len(x_chunk)) < valid_size
            residuals = jnp.where(mask, residuals, 0.0)
            J_chunk = jnp.where(mask[:, None], J_chunk, 0.0)

        # Accumulate
        JTJ_new = JTJ_prev + J_chunk.T @ J_chunk
        JTr_new = JTr_prev + J_chunk.T @ residuals
        chunk_cost = float(jnp.sum(residuals**2))

        # Store accumulators for checkpointing
        self.phase2_JTJ_accumulator = JTJ_new
        self.phase2_JTr_accumulator = JTr_new

        return JTJ_new, JTr_new, chunk_cost

    def _compute_jacobian_chunk(
        self,
        x_chunk: jnp.ndarray,
        params: jnp.ndarray,
    ) -> jnp.ndarray:
        """Compute exact Jacobian for a data chunk.

        Uses the pre-compiled Jacobian function when one has been set,
        otherwise differentiates the model per point with ``jax.jacrev``
        mapped over the chunk with ``jax.vmap``.

        Parameters
        ----------
        x_chunk : array_like
            Independent variable chunk.
        params : array_like
            Parameters in normalized space.

        Returns
        -------
        J_chunk : array_like
            Jacobian matrix of shape (n_points, n_params).
        """
        if self._jacobian_fn_compiled is not None:
            return self._jacobian_fn_compiled(params, x_chunk)

        normalized_model = self.normalized_model

        def model_at_x(p, x_single):
            return normalized_model(x_single, *p)

        jac_fn = jax.vmap(lambda x: jax.jacrev(model_at_x, argnums=0)(params, x))
        return jac_fn(x_chunk)

    def _solve_gauss_newton_step(
        self,
        JTJ: jnp.ndarray,
        JTr: jnp.ndarray,
        trust_radius: float,
        regularization: float = 1e-10,
    ) -> tuple[jnp.ndarray, float]:
        """Solve Gauss-Newton step using SVD.

        Parameters
        ----------
        JTJ : array_like
            Accumulated J^T J matrix.
        JTr : array_like
            Accumulated J^T r vector.
        trust_radius : float
            Trust region radius.
        regularization : float
            Tikhonov regularization parameter.

        Returns
        -------
        step : array_like
            Gauss-Newton step, rescaled to the trust radius when longer.
        predicted_reduction : float
            Predicted reduction in cost function, floored at zero.
        """
        n_params = JTJ.shape[0]

        # Add Tikhonov regularization
        JTJ_reg = JTJ + regularization * jnp.eye(n_params)

        # Compute SVD
        U, s, Vt = jnp.linalg.svd(JTJ_reg, full_matrices=False)

        # Solve using SVD; singular values below a relative threshold are
        # floored so the division cannot blow up.
        UTb = U.T @ JTr
        s_threshold = jnp.max(s) * 1e-10
        s_safe = jnp.where(s > s_threshold, s, s_threshold)
        step_hat = UTb / s_safe
        step = Vt.T @ step_hat

        # Apply trust region constraint
        step_norm = jnp.linalg.norm(step)
        if step_norm > 0 and step_norm > trust_radius:
            step = step * (trust_radius / step_norm)

        # Compute predicted reduction
        pred_red = jnp.dot(JTr, step) - 0.5 * jnp.dot(step, JTJ @ step)
        predicted_reduction = float(max(float(pred_red), 0.0))

        return step, predicted_reduction

    def _compute_cost(
        self,
        params: jnp.ndarray,
        x_data: jnp.ndarray,
        y_data: jnp.ndarray,
    ) -> float:
        """Compute total cost (sum of squared residuals).

        Adds the group variance regularization term when it is enabled in
        the configuration. The pre-compiled cost function, when set, is used
        only for chunks that needed no padding; padded chunks are evaluated
        with an explicit mask so padded points contribute nothing.

        Parameters
        ----------
        params : array_like
            Parameters in normalized space.
        x_data : array_like
            Full x data.
        y_data : array_like
            Full y data.

        Returns
        -------
        total_cost : float
            Total sum of squared residuals.
        """
        chunk_size = self.config.chunk_size
        n_points = len(x_data)
        total_cost = 0.0

        for start in range(0, n_points, chunk_size):
            x_chunk = x_data[start : start + chunk_size]
            y_chunk = y_data[start : start + chunk_size]
            x_chunk, y_chunk, valid_size = _pad_chunk_to_bucket(
                x_chunk, y_chunk, len(x_chunk)
            )
            is_padded = valid_size < len(x_chunk)

            if self._cost_fn_compiled is not None and not is_padded:
                # Convert to a Python float so the return type matches the
                # masked path below.
                chunk_cost = float(self._cost_fn_compiled(params, x_chunk, y_chunk))
            else:
                predictions = self.normalized_model(x_chunk, *params)
                residuals = y_chunk - predictions
                if is_padded:
                    # Mask padded elements for correct cost computation
                    mask = jnp.arange(len(x_chunk)) < valid_size
                    residuals = jnp.where(mask, residuals, 0.0)
                chunk_cost = float(jnp.sum(residuals**2))

            total_cost += chunk_cost

        # Add group variance regularization if enabled
        if (
            self.config.enable_group_variance_regularization
            and self.config.group_variance_indices
        ):
            var_lambda = self.config.group_variance_lambda
            for start, end in self.config.group_variance_indices:
                group_params = params[start:end]
                group_var = jnp.var(group_params)
                total_cost += var_lambda * float(group_var) * n_points

        return total_cost
# Names exported by ``from ... import *``.
__all__ = ["GNResult", "GaussNewtonPhase"]