"""Phase orchestrator for streaming optimization workflow.
This module provides the PhaseOrchestrator class that coordinates
the multi-phase optimization workflow (setup, warmup, GN, finalize).
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import jax.numpy as jnp
from nlsq.utils.logging import get_logger
if TYPE_CHECKING:
from jax import Array
from nlsq.precision.parameter_normalizer import NormalizedModelWrapper
from nlsq.streaming.hybrid_config import HybridStreamingConfig
from nlsq.streaming.phases.gauss_newton import GaussNewtonPhase, GNResult
from nlsq.streaming.phases.warmup import WarmupPhase, WarmupResult
# Module-level logger; PhaseOrchestrator.run() emits its per-phase progress
# messages through it when the configured verbosity is >= 1.
_logger = get_logger("phase_orchestrator")
[docs]
@dataclass(frozen=True, slots=True)
class PhaseOrchestratorResult:
"""Complete result from phase orchestration.
Attributes:
params: Final optimized parameters in original space.
normalized_params: Final parameters in normalized space.
cost: Final cost value.
warmup_result: Result from Phase 1 warmup.
gn_result: Result from Phase 2 Gauss-Newton.
phase_history: List of phase transition records.
total_time: Total optimization time.
"""
params: Array
normalized_params: Array
cost: float
warmup_result: WarmupResult | None
gn_result: GNResult | None
phase_history: list[dict[str, Any]]
total_time: float
[docs]
class PhaseOrchestrator:
    """Orchestrates the multi-phase streaming optimization workflow.

    The orchestrator coordinates:

    - Phase 0: Setup (conversion of the input data to JAX arrays)
    - Phase 1: L-BFGS warmup (WarmupPhase)
    - Phase 2: Streaming Gauss-Newton (GaussNewtonPhase)
    - Phase 3: Finalization (best-parameter selection, denormalization)

    Phases 1 and 2 are delegated to handler objects created by
    :meth:`initialize_phases`.  A phase whose handler is ``None`` is not an
    error: it is recorded in the history as skipped and the workflow
    continues with the parameters it already has.

    Parameters
    ----------
    config : HybridStreamingConfig
        Configuration for streaming optimization.

    Attributes
    ----------
    config : HybridStreamingConfig
        Configuration object.
    warmup_phase : WarmupPhase or None
        The warmup phase handler (lazy initialized).
    gn_phase : GaussNewtonPhase or None
        The Gauss-Newton phase handler (lazy initialized).
    phase_history : list
        Records of phase transitions.
    current_phase : int
        Index (0-3) of the phase most recently entered by :meth:`run`.
    """

    def __init__(self, config: HybridStreamingConfig) -> None:
        """Initialize PhaseOrchestrator.

        Parameters
        ----------
        config : HybridStreamingConfig
            Configuration for streaming optimization.
        """
        self.config = config

        # Handlers stay None until initialize_phases() is called.
        self.warmup_phase: WarmupPhase | None = None
        self.gn_phase: GaussNewtonPhase | None = None
        self.phase_history: list[dict[str, Any]] = []

        # Best parameter tracking (shared across phases)
        self._best_tracker: dict[str, Any] = {
            "best_params_global": None,
            "best_cost_global": float("inf"),
        }

        # State
        self.current_phase: int = 0
        self._start_time: float | None = None

    def initialize_phases(
        self,
        normalized_model: NormalizedModelWrapper,
        normalized_bounds: tuple[jnp.ndarray, jnp.ndarray] | None = None,
    ) -> None:
        """Initialize phase handlers.

        Parameters
        ----------
        normalized_model : NormalizedModelWrapper
            Model wrapper operating in normalized parameter space.
        normalized_bounds : tuple of array_like or None
            Parameter bounds in normalized space.  Only the Gauss-Newton
            handler receives them.
        """
        # Imported here rather than at module level; at module level these
        # names are only imported under TYPE_CHECKING.
        from nlsq.streaming.phases.gauss_newton import GaussNewtonPhase
        from nlsq.streaming.phases.warmup import WarmupPhase

        self.warmup_phase = WarmupPhase(self.config, normalized_model)
        self.gn_phase = GaussNewtonPhase(
            self.config, normalized_model, normalized_bounds
        )

    def run(
        self,
        data_source: tuple[jnp.ndarray, jnp.ndarray],
        initial_params: jnp.ndarray,
        normalizer: Any | None = None,
    ) -> dict[str, Any]:
        """Run the full multi-phase optimization workflow.

        Each call resets ``phase_history`` and the best-parameter tracker,
        so results from an earlier call are not carried over.

        Parameters
        ----------
        data_source : tuple of array_like
            Data as (x_data, y_data).
        initial_params : array_like
            Initial parameters in normalized space.
        normalizer : ParameterNormalizer or None
            For denormalization in Phase 3.  Ignored if it is ``None`` or
            has no ``denormalize`` attribute.

        Returns
        -------
        result : dict
            Complete optimization result with keys:

            - 'final_params': Final parameters in original space
            - 'normalized_params': Final parameters in normalized space
            - 'best_cost': Best cost achieved
            - 'warmup_result': WarmupResult from Phase 1
            - 'gn_result': GNResult from Phase 2
            - 'phase_history': List of phase records
            - 'JTJ_final': Final J^T J matrix
            - 'residual_sum_sq': Final residual sum of squares
            - 'total_time': Elapsed time of this call in seconds
        """
        # The wall-clock start is kept on the instance, but the elapsed time
        # is measured with a monotonic high-resolution clock so that system
        # clock adjustments cannot make total_time jump or go negative.
        self._start_time = time.time()
        started = time.perf_counter()

        self.phase_history = []
        self._best_tracker = {
            "best_params_global": initial_params,
            "best_cost_global": float("inf"),
        }
        verbose = getattr(self.config, "verbose", 1)

        # =====================================================
        # PHASE 0: Setup
        # =====================================================
        self.current_phase = 0
        if verbose >= 1:
            _logger.info("Phase 0: Setup and validation")

        x_data, y_data = data_source
        data = (
            jnp.asarray(x_data, dtype=jnp.float64),
            jnp.asarray(y_data, dtype=jnp.float64),
        )

        # =====================================================
        # PHASES 1 and 2: each starts from the previous phase's parameters
        # =====================================================
        current_params, warmup_result = self._run_warmup(
            data, initial_params, verbose
        )
        current_params, gn_result, JTJ_final, residual_sum_sq = (
            self._run_gauss_newton(data, current_params, verbose)
        )

        # =====================================================
        # PHASE 3: Finalization
        # =====================================================
        self.current_phase = 3
        if verbose >= 1:
            _logger.info("Phase 3: Finalization")

        final_params, final_normalized_params = self._resolve_final_params(
            current_params, normalizer
        )
        total_time = time.perf_counter() - started

        # Record Phase 3
        self.phase_history.append(
            {
                "phase": 3,
                "name": "finalization",
                "final_cost": self._best_tracker["best_cost_global"],
                "timestamp": time.time(),
                "total_time": total_time,
            }
        )

        if verbose >= 1:
            best_cost = self._best_tracker["best_cost_global"]
            _logger.info(
                f"Optimization complete: cost={best_cost:.6e}, "
                f"total_time={total_time:.1f}s"
            )

        return {
            "final_params": final_params,
            "normalized_params": final_normalized_params,
            "best_cost": self._best_tracker["best_cost_global"],
            "warmup_result": warmup_result,
            "gn_result": gn_result,
            "phase_history": self.phase_history,
            "JTJ_final": JTJ_final,
            "residual_sum_sq": residual_sum_sq,
            "total_time": total_time,
        }

    def _run_warmup(
        self,
        data: tuple[Any, Any],
        initial_params: Any,
        verbose: int,
    ) -> tuple[Any, Any]:
        """Run Phase 1, or record it as skipped when no handler is set.

        Returns
        -------
        params, warmup_result : tuple
            Parameters to hand to Phase 2 and the handler's
            ``'warmup_result'`` entry (``None`` if absent or skipped).
        """
        self.current_phase = 1
        if verbose >= 1:
            _logger.info("Phase 1: L-BFGS warmup")

        if self.warmup_phase is None:
            # Skip warmup if no phase handler
            self.phase_history.append(
                {
                    "phase": 1,
                    "name": "lbfgs_warmup",
                    "iterations": 0,
                    "final_loss": float("inf"),
                    "best_loss": float("inf"),
                    "switch_reason": "Warmup phase not initialized",
                    "timestamp": time.time(),
                    "skipped": True,
                }
            )
            return initial_params, None

        # The handler receives the shared history list and tracker dict;
        # presumably it appends its own record and updates the global best
        # (confirm against WarmupPhase.run).
        phase1_result = self.warmup_phase.run(
            data_source=data,
            initial_params=initial_params,
            phase_history=self.phase_history,
            best_tracker=self._best_tracker,
        )
        params = phase1_result["best_params"]
        warmup_result = phase1_result.get("warmup_result")

        if verbose >= 1:
            _logger.info(
                f"Phase 1 complete: {phase1_result['iterations']} iterations, "
                f"loss={phase1_result['best_loss']:.6e}, "
                f"reason: {phase1_result['switch_reason']}"
            )
        return params, warmup_result

    def _run_gauss_newton(
        self,
        data: tuple[Any, Any],
        initial_params: Any,
        verbose: int,
    ) -> tuple[Any, Any, Any, Any]:
        """Run Phase 2, or record it as skipped when no handler is set.

        Returns
        -------
        params, gn_result, JTJ_final, residual_sum_sq : tuple
            Parameters after the phase plus the handler's optional result
            entries.  When skipped these are the incoming parameters,
            ``None``, ``None`` and ``0.0``.
        """
        self.current_phase = 2
        if verbose >= 1:
            _logger.info("Phase 2: Streaming Gauss-Newton")

        if self.gn_phase is None:
            # Skip GN if no phase handler
            self.phase_history.append(
                {
                    "phase": 2,
                    "name": "gauss_newton",
                    "iterations": 0,
                    "final_cost": float("inf"),
                    "best_cost": float("inf"),
                    "convergence_reason": "GN phase not initialized",
                    "timestamp": time.time(),
                    "skipped": True,
                }
            )
            return initial_params, None, None, 0.0

        phase2_result = self.gn_phase.run(
            data_source=data,
            initial_params=initial_params,
            phase_history=self.phase_history,
            best_tracker=self._best_tracker,
        )
        params = phase2_result["best_params"]
        gn_result = phase2_result.get("gn_result")
        JTJ_final = phase2_result.get("JTJ_final")
        residual_sum_sq = phase2_result.get("residual_sum_sq", 0.0)

        if verbose >= 1:
            _logger.info(
                f"Phase 2 complete: {phase2_result['iterations']} iterations, "
                f"cost={phase2_result['best_cost']:.6e}, "
                f"reason: {phase2_result['convergence_reason']}"
            )
        return params, gn_result, JTJ_final, residual_sum_sq

    def _resolve_final_params(
        self,
        fallback_params: Any,
        normalizer: Any | None,
    ) -> tuple[Any, Any]:
        """Pick the final normalized parameters and denormalize them.

        Returns
        -------
        final_params, final_normalized_params : tuple
            The second item is the tracker's global best, or
            ``fallback_params`` if the tracker holds ``None``.  The first is
            its denormalized form, or the same object when no usable
            normalizer was supplied.
        """
        # Use best parameters found globally
        final_normalized_params = self._best_tracker["best_params_global"]
        if final_normalized_params is None:
            final_normalized_params = fallback_params

        # Denormalize parameters if normalizer available
        if normalizer is not None and hasattr(normalizer, "denormalize"):
            return (
                normalizer.denormalize(final_normalized_params),
                final_normalized_params,
            )
        return final_normalized_params, final_normalized_params

    def get_phase_history(self) -> list[dict[str, Any]]:
        """Get the phase transition history.

        Returns
        -------
        phase_history : list
            List of phase transition records.  This is the orchestrator's
            own list, not a copy.
        """
        return self.phase_history

    def get_best_params(self) -> jnp.ndarray | None:
        """Get the best parameters found across all phases.

        Returns
        -------
        best_params : array_like or None
            Best parameters in normalized space.
        """
        return self._best_tracker.get("best_params_global")

    def get_best_cost(self) -> float:
        """Get the best cost found across all phases.

        Returns
        -------
        best_cost : float
            Best cost value; ``inf`` if the tracker has no entry.
        """
        return self._best_tracker.get("best_cost_global", float("inf"))
# Public API of this module: the orchestrator and its result record.
__all__ = ["PhaseOrchestrator", "PhaseOrchestratorResult"]