Source code for nlsq.global_optimization.tournament

"""
Tournament Selector for Global Optimization
============================================

This module provides the TournamentSelector class for progressive elimination
in multi-start optimization on streaming/large datasets. Tournament selection
is memory-efficient as it evaluates candidates on data batches without loading
the full dataset.

Key Features
------------
- Progressive elimination: N -> N/2 -> N/4 -> ... -> top M candidates
- Memory-efficient evaluation on streaming data batches
- Configurable elimination fraction and batches per round
- Checkpoint/resume support for fault tolerance
- Graceful handling of numerical failures

Examples
--------
Basic tournament selection:

>>> from nlsq.global_optimization import TournamentSelector, GlobalOptimizationConfig
>>> import numpy as np
>>>
>>> # Generate candidates
>>> candidates = np.random.randn(20, 3)  # 20 candidates, 3 parameters
>>>
>>> config = GlobalOptimizationConfig(
...     n_starts=20,
...     elimination_rounds=3,
...     elimination_fraction=0.5,
...     batches_per_round=50,
... )
>>>
>>> selector = TournamentSelector(candidates, config)
>>>
>>> def model(x, a, b, c):
...     return a * x**2 + b * x + c
>>>
>>> def data_batch_generator():
...     for _ in range(200):
...         x_batch = np.random.randn(100)
...         y_batch = 1.0 * x_batch**2 + 2.0 * x_batch + 3.0
...         yield x_batch, y_batch
>>>
>>> best_candidates = selector.run_tournament(
...     data_batch_iterator=data_batch_generator(),
...     model=model,
...     top_m=1,
... )

See Also
--------
GlobalOptimizationConfig : Configuration for multi-start optimization
MultiStartOrchestrator : Orchestrates multi-start for standard datasets
AdaptiveHybridStreamingOptimizer : Streaming optimizer with tournament integration
"""

from collections.abc import Callable, Iterator
from typing import Any

import jax.numpy as jnp
import numpy as np

from nlsq.global_optimization.config import GlobalOptimizationConfig
from nlsq.utils.logging import get_logger

# Public API: the only name exported by a star-import of this module.
__all__ = ["TournamentSelector"]


class TournamentSelector:
    """Tournament selector for progressive elimination in multi-start optimization.

    This class implements tournament-style progressive elimination for selecting
    the best starting points when optimizing on large/streaming datasets. Instead
    of evaluating all candidates on the full dataset, candidates are evaluated on
    streaming batches and the worst performers are eliminated each round.

    The tournament proceeds as:

    - Round 1: N candidates -> Keep top (1 - elimination_fraction) * N
    - Round 2: Survivors -> Keep top (1 - elimination_fraction) * survivors
    - ...
    - Final: Return top M candidates

    Parameters
    ----------
    candidates : np.ndarray
        Array of shape (n_candidates, n_params) containing candidate starting
        points. A 1-D array is treated as a single candidate.
    config : GlobalOptimizationConfig
        Configuration controlling tournament parameters:

        - elimination_rounds: Number of elimination rounds
        - elimination_fraction: Fraction to eliminate each round (default 0.5)
        - batches_per_round: Number of batches to evaluate per round

    Attributes
    ----------
    n_candidates : int
        Total number of candidates.
    n_params : int
        Number of parameters per candidate.
    survival_mask : np.ndarray
        Boolean mask indicating which candidates are still alive.
    cumulative_losses : np.ndarray
        Sum over rounds of each candidate's mean batch loss (inf for
        eliminated candidates and for candidates that never produced a
        finite loss).
    evaluation_counts : np.ndarray
        Number of batches on which each candidate produced a finite loss.
    current_round : int
        Index of the next round to run (equals the number of completed rounds).
    round_history : list
        History of each round with statistics.
    total_batches_evaluated : int
        Number of data batches consumed so far.
    numerical_failures : int
        Number of candidate evaluations that raised or gave a non-finite loss.

    Examples
    --------
    >>> import numpy as np
    >>> from nlsq.global_optimization import TournamentSelector, GlobalOptimizationConfig
    >>>
    >>> candidates = np.random.randn(16, 3)
    >>> config = GlobalOptimizationConfig(
    ...     n_starts=16,
    ...     elimination_rounds=3,
    ...     elimination_fraction=0.5,
    ...     batches_per_round=10,
    ... )
    >>> selector = TournamentSelector(candidates, config)
    >>> print(f"Starting with {selector.n_candidates} candidates")
    Starting with 16 candidates

    Notes
    -----
    Tournament selection is particularly effective for:

    - Large datasets where full evaluation is expensive
    - Streaming datasets that don't fit in memory
    - High-dimensional parameter spaces with many local minima

    The elimination_fraction parameter controls the aggressiveness of pruning:

    - 0.5 (default): Eliminate half each round (log2(N) rounds to get to 1)
    - 0.25: Eliminate 25% each round (slower, more conservative)
    - 0.75: Eliminate 75% each round (faster, more aggressive)
    """

    def __init__(
        self,
        candidates: np.ndarray,
        config: "GlobalOptimizationConfig",
    ):
        """Initialize tournament selector.

        Parameters
        ----------
        candidates : np.ndarray
            Array of shape (n_candidates, n_params) with candidate starting points.
        config : GlobalOptimizationConfig
            Configuration for tournament parameters.
        """
        self.candidates = np.asarray(candidates)
        self.config = config
        self.logger = get_logger("tournament")

        # A bare 1-D parameter vector is interpreted as a single candidate.
        if self.candidates.ndim == 1:
            self.candidates = self.candidates.reshape(1, -1)

        self.n_candidates = self.candidates.shape[0]
        self.n_params = self.candidates.shape[1]

        # Tournament state
        self.survival_mask = np.ones(self.n_candidates, dtype=bool)
        self.cumulative_losses = np.zeros(self.n_candidates)
        self.evaluation_counts = np.zeros(self.n_candidates, dtype=int)
        self.current_round = 0
        self.round_history: list[dict[str, Any]] = []

        # Tracking for diagnostics
        self.total_batches_evaluated = 0
        self.numerical_failures = 0

    @property
    def n_survivors(self) -> int:
        """Number of currently surviving candidates."""
        return int(np.count_nonzero(self.survival_mask))

    def run_tournament(
        self,
        data_batch_iterator: Iterator[tuple[np.ndarray, np.ndarray]],
        model: Callable,
        top_m: int = 1,
    ) -> list[np.ndarray]:
        """Run full tournament selection.

        Executes the remaining elimination rounds and returns the top M
        surviving candidates. Rounds start at ``current_round``, so a selector
        restored with :meth:`from_checkpoint` continues where it stopped
        instead of replaying completed rounds.

        Parameters
        ----------
        data_batch_iterator : Iterator
            Iterator (or any iterable) yielding (x_batch, y_batch) tuples.
        model : Callable
            Model function with signature ``model(x, *params) -> predictions``.
        top_m : int, default=1
            Number of top candidates to return. Elimination never reduces the
            field below this many survivors.

        Returns
        -------
        list[np.ndarray]
            List of top M candidate parameter arrays, sorted by loss (best first).

        Notes
        -----
        The iterator is consumed during tournament execution. Ensure it yields
        enough batches: elimination_rounds * batches_per_round.

        If the iterator runs out of batches before completing all rounds,
        the tournament will return the best candidates found so far.
        """
        # iter() is a no-op on iterators and lets callers pass plain iterables.
        batches = iter(data_batch_iterator)

        self.logger.info(
            f"Starting tournament with {self.n_candidates} candidates, "
            f"{self.config.elimination_rounds} rounds, "
            f"elimination_fraction={self.config.elimination_fraction}"
        )

        # Handle edge case: no elimination rounds
        if self.config.elimination_rounds == 0:
            self.logger.debug("No elimination rounds configured, evaluating once")
            self._evaluate_initial_round(batches, model)
            return self.get_top_candidates(top_m)

        # Keep at least top_m candidates alive so the requested number can
        # actually be returned.
        min_survivors = max(1, top_m)

        for round_num in range(self.current_round, self.config.elimination_rounds):
            if self.n_survivors <= top_m:
                self.logger.info(
                    f"Round {round_num}: Only {self.n_survivors} survivors, "
                    f"stopping early (need {top_m})"
                )
                break

            try:
                self._run_single_round(
                    batches, model, round_num, min_survivors=min_survivors
                )
            except StopIteration:
                self.logger.warning(
                    f"Data exhausted during round {round_num}, returning best so far"
                )
                break

            self.current_round = round_num + 1

        self.logger.info(
            f"Tournament complete: {self.n_survivors} survivors from "
            f"{self.n_candidates} candidates"
        )

        return self.get_top_candidates(top_m)

    def _run_single_round(
        self,
        data_batch_iterator: Iterator[tuple[np.ndarray, np.ndarray]],
        model: Callable,
        round_number: int,
        min_survivors: int = 1,
    ) -> None:
        """Run a single elimination round.

        Parameters
        ----------
        data_batch_iterator : Iterator
            Data batch iterator.
        model : Callable
            Model function.
        round_number : int
            Current round number (0-indexed).
        min_survivors : int, default=1
            Lower bound on the number of candidates left alive after this round.

        Raises
        ------
        StopIteration
            If the iterator is exhausted before the first batch of the round.
        """
        n_survivors_before = self.n_survivors
        batches_per_round = self.config.batches_per_round

        self.logger.debug(
            f"Round {round_number}: Evaluating {n_survivors_before} survivors "
            f"on {batches_per_round} batches"
        )

        # Per-round accumulators, indexed by candidate.
        round_losses = np.zeros(self.n_candidates)
        round_loss_counts = np.zeros(self.n_candidates, dtype=int)
        batches_evaluated = 0

        for _ in range(batches_per_round):
            try:
                x_batch, y_batch = next(data_batch_iterator)
            except StopIteration:
                if batches_evaluated == 0:
                    raise  # No batches at all
                break  # Finish the round with the batches obtained so far

            batch_losses = self._evaluate_candidates_on_batch(x_batch, y_batch, model)

            # Only finite losses of living candidates contribute.
            usable = self.survival_mask & np.isfinite(batch_losses)
            round_losses[usable] += batch_losses[usable]
            round_loss_counts[usable] += 1

            batches_evaluated += 1
            self.total_batches_evaluated += 1

        if batches_evaluated == 0:
            # batches_per_round <= 0: nothing was measured, so there is no
            # basis for ranking or eliminating anybody.
            self.round_history.append(
                {
                    "round": round_number,
                    "n_survivors_before": n_survivors_before,
                    "n_survivors_after": n_survivors_before,
                    "n_eliminated": 0,
                    "batches_evaluated": 0,
                    "mean_loss": np.inf,
                }
            )
            return

        # Mean loss per candidate; candidates without a finite loss get inf.
        # np.maximum keeps the denominator non-zero for those entries.
        avg_round_losses = np.where(
            round_loss_counts > 0,
            round_losses / np.maximum(round_loss_counts, 1),
            np.inf,
        )

        # Update cumulative losses
        self.cumulative_losses += avg_round_losses
        self.evaluation_counts += round_loss_counts

        # Perform elimination without dropping below the survivor floor.
        floor = max(1, min_survivors)
        n_to_eliminate = int(n_survivors_before * self.config.elimination_fraction)
        n_to_eliminate = max(0, min(n_to_eliminate, n_survivors_before - floor))

        if n_to_eliminate > 0:
            self._eliminate_worst(n_to_eliminate, avg_round_losses)

        n_survivors_after = self.n_survivors
        self.round_history.append(
            {
                "round": round_number,
                "n_survivors_before": n_survivors_before,
                "n_survivors_after": n_survivors_after,
                "n_eliminated": n_survivors_before - n_survivors_after,
                "batches_evaluated": batches_evaluated,
                "mean_loss": float(np.mean(avg_round_losses[self.survival_mask]))
                if n_survivors_after > 0
                else np.inf,
            }
        )

        self.logger.debug(
            f"Round {round_number} complete: {n_survivors_before} -> {self.n_survivors} survivors"
        )

    def _evaluate_initial_round(
        self,
        data_batch_iterator: Iterator[tuple[np.ndarray, np.ndarray]],
        model: Callable,
    ) -> None:
        """Evaluate all candidates on initial batches (no elimination rounds case).

        Each candidate's mean finite batch loss is added to
        ``cumulative_losses``. A candidate that never yields a finite loss is
        assigned ``inf`` so it cannot outrank candidates that were actually
        evaluated. If no batch is available at all, state is left untouched.
        """
        loss_sums = np.zeros(self.n_candidates)
        loss_counts = np.zeros(self.n_candidates, dtype=int)
        batches_evaluated = 0

        for _ in range(self.config.batches_per_round):
            try:
                x_batch, y_batch = next(data_batch_iterator)
            except StopIteration:
                break

            batch_losses = self._evaluate_candidates_on_batch(x_batch, y_batch, model)

            usable = np.isfinite(batch_losses)
            loss_sums[usable] += batch_losses[usable]
            loss_counts[usable] += 1

            batches_evaluated += 1
            self.total_batches_evaluated += 1

        if batches_evaluated == 0:
            return

        self.evaluation_counts += loss_counts
        # Averaging (rather than summing) keeps candidates comparable when
        # some of them failed on a subset of the batches.
        self.cumulative_losses += np.where(
            loss_counts > 0,
            loss_sums / np.maximum(loss_counts, 1),
            np.inf,
        )

    def _evaluate_candidates_on_batch(
        self,
        x_batch: np.ndarray,
        y_batch: np.ndarray,
        model: Callable,
    ) -> np.ndarray:
        """Evaluate all surviving candidates on a single batch.

        Parameters
        ----------
        x_batch : np.ndarray
            Independent variable batch.
        y_batch : np.ndarray
            Dependent variable batch.
        model : Callable
            Model function.

        Returns
        -------
        np.ndarray
            Mean-squared-error loss for each candidate (inf for eliminated or
            failed candidates).
        """
        x_jax = jnp.asarray(x_batch)
        y_jax = jnp.asarray(y_batch)

        losses = np.full(self.n_candidates, np.inf)

        # Eliminated candidates are skipped and keep their inf loss.
        for i in np.flatnonzero(self.survival_mask):
            try:
                predictions = model(x_jax, *self.candidates[i])
                residuals = y_jax - predictions
                loss = float(jnp.mean(residuals**2))
            except Exception as e:
                # Best effort: a failing candidate must not abort the
                # tournament; it simply keeps an inf loss for this batch.
                self.numerical_failures += 1
                self.logger.debug(f"Candidate {i} failed evaluation: {str(e)[:50]}")
                continue

            if np.isfinite(loss):
                losses[i] = loss
            else:
                self.numerical_failures += 1
                self.logger.debug(f"Candidate {i}: Non-finite loss {loss}")

        return losses

    def _eliminate_worst(
        self,
        n_to_eliminate: int,
        round_losses: np.ndarray,
    ) -> None:
        """Eliminate the worst-performing candidates.

        Parameters
        ----------
        n_to_eliminate : int
            Number of candidates to eliminate.
        round_losses : np.ndarray
            Loss values for this round, indexed by candidate.
        """
        if n_to_eliminate <= 0:
            return

        survivor_indices = np.flatnonzero(self.survival_mask)

        # Negating the losses makes argsort return the worst survivors first.
        worst_first = np.argsort(-round_losses[survivor_indices])
        eliminated = survivor_indices[worst_first[:n_to_eliminate]]

        self.survival_mask[eliminated] = False
        # Mark as eliminated in cumulative losses
        self.cumulative_losses[eliminated] = np.inf

    def get_top_candidates(self, top_m: int = 1) -> list[np.ndarray]:
        """Get the top M candidates by cumulative loss.

        Parameters
        ----------
        top_m : int, default=1
            Number of top candidates to return.

        Returns
        -------
        list[np.ndarray]
            List of top candidate parameter arrays, sorted by loss (best first).
            Empty if the selector holds no candidates.
        """
        if self.n_candidates == 0:
            return []

        # Prefer survivors with a finite cumulative loss
        valid_indices = np.flatnonzero(
            self.survival_mask & np.isfinite(self.cumulative_losses)
        )

        if len(valid_indices) == 0:
            # Fall back to any surviving candidate
            valid_indices = np.flatnonzero(self.survival_mask)

        if len(valid_indices) == 0:
            # Fall back to first candidate
            self.logger.warning("No valid survivors, returning first candidate")
            return [self.candidates[0].copy()]

        # Sort by cumulative loss (ascending, best first)
        sorted_order = np.argsort(self.cumulative_losses[valid_indices])
        top_indices = valid_indices[sorted_order[:top_m]]

        return [self.candidates[i].copy() for i in top_indices]

    def to_checkpoint(self) -> dict[str, Any]:
        """Serialize tournament state to a checkpoint dictionary.

        Returns
        -------
        dict
            Checkpoint state that can be serialized and saved. Arrays and
            history entries are copies, so later tournament progress does not
            alter a checkpoint that was already taken.

        Examples
        --------
        >>> from nlsq.utils.safe_serialize import safe_dumps
        >>> checkpoint = selector.to_checkpoint()
        >>> with open('tournament_checkpoint.json', 'w') as f:
        ...     f.write(safe_dumps(checkpoint))
        """
        return {
            "candidates": self.candidates.copy(),
            "survival_mask": self.survival_mask.copy(),
            "cumulative_losses": self.cumulative_losses.copy(),
            "evaluation_counts": self.evaluation_counts.copy(),
            "current_round": self.current_round,
            "round_history": [dict(entry) for entry in self.round_history],
            "total_batches_evaluated": self.total_batches_evaluated,
            "numerical_failures": self.numerical_failures,
            "n_candidates": self.n_candidates,
            "n_params": self.n_params,
        }

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: dict[str, Any],
        config: "GlobalOptimizationConfig",
    ) -> "TournamentSelector":
        """Restore tournament selector from checkpoint.

        Parameters
        ----------
        checkpoint : dict
            Checkpoint state from to_checkpoint(). Array fields may be NumPy
            arrays or plain sequences (e.g. after a JSON round-trip).
        config : GlobalOptimizationConfig
            Configuration (must match original).

        Returns
        -------
        TournamentSelector
            Restored tournament selector.

        Raises
        ------
        ValueError
            If a per-candidate array does not have one entry per candidate.

        Examples
        --------
        >>> from nlsq.utils.safe_serialize import safe_loads
        >>> with open('tournament_checkpoint.json') as f:
        ...     checkpoint = safe_loads(f.read())
        >>> selector = TournamentSelector.from_checkpoint(checkpoint, config)
        """
        selector = cls(candidates=checkpoint["candidates"], config=config)

        # np.array copies and also converts lists from deserialized checkpoints
        # back into typed arrays that the tournament arithmetic relies on.
        per_candidate = {
            "survival_mask": np.array(checkpoint["survival_mask"], dtype=bool),
            "cumulative_losses": np.array(
                checkpoint["cumulative_losses"], dtype=float
            ),
            "evaluation_counts": np.array(checkpoint["evaluation_counts"], dtype=int),
        }
        expected_shape = (selector.n_candidates,)
        for field_name, values in per_candidate.items():
            if values.shape != expected_shape:
                raise ValueError(
                    f"Checkpoint field '{field_name}' has shape {values.shape}, "
                    f"expected {expected_shape}"
                )

        # Restore state
        selector.survival_mask = per_candidate["survival_mask"]
        selector.cumulative_losses = per_candidate["cumulative_losses"]
        selector.evaluation_counts = per_candidate["evaluation_counts"]
        selector.current_round = checkpoint["current_round"]
        selector.round_history = [dict(entry) for entry in checkpoint["round_history"]]
        selector.total_batches_evaluated = checkpoint["total_batches_evaluated"]
        selector.numerical_failures = checkpoint["numerical_failures"]

        return selector

    def get_diagnostics(self) -> dict[str, Any]:
        """Get tournament diagnostics.

        Returns
        -------
        dict
            Dictionary with tournament statistics and history.
            ``mean_survivor_loss`` is None when there are no survivors or the
            mean is not finite.
        """
        survivor_losses = self.cumulative_losses[self.survival_mask]
        # Guard the empty case: np.mean([]) would emit a RuntimeWarning.
        avg_loss = float(np.mean(survivor_losses)) if survivor_losses.size else np.nan

        if self.n_candidates > 0:
            elimination_rate = 1.0 - (self.n_survivors / self.n_candidates)
        else:
            elimination_rate = 0.0

        return {
            "n_candidates_initial": self.n_candidates,
            "n_survivors": self.n_survivors,
            "elimination_rate": elimination_rate,
            "rounds_completed": self.current_round,
            "total_batches_evaluated": self.total_batches_evaluated,
            "numerical_failures": self.numerical_failures,
            "mean_survivor_loss": avg_loss if np.isfinite(avg_loss) else None,
            "round_history": self.round_history,
        }