# Source code for nlsq.global_optimization.bipop

"""BIPOP restart strategy for CMA-ES.

Implements the Bi-Population restart strategy where large and small
population runs alternate to balance exploration and exploitation.

References
----------
Hansen, N. (2009). Benchmarking a BI-Population CMA-ES on the BBOB-2009
Function Testbed. GECCO Workshop on Black-Box Optimization Benchmarking.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    import jax

# Public API of this module: only the restart manager is exported.
__all__ = ["BIPOPRestarter"]

# Module-level logger; BIPOPRestarter reports its progress through
# logger.debug calls only.
logger = logging.getLogger(__name__)


@dataclass
class BIPOPRestarter:
    """BIPOP restart manager for CMA-ES optimization.

    Manages alternating large/small population restarts following the
    BIPOP strategy. Large populations explore broadly while small
    populations exploit local regions more intensively.

    This is a simplified variant of Hansen (2009): runs strictly alternate
    between a fixed large population (``2 * base_popsize``) and a randomly
    drawn small population, starting with a large run. It does not grow the
    large population geometrically or balance evaluation budgets.

    Parameters
    ----------
    base_popsize : int
        Base population size (typically 4 + floor(3 * ln(n))).
    n_params : int
        Number of parameters being optimized.
    max_restarts : int, optional
        Maximum number of restarts before giving up. Default is 9.
    min_fitness_spread : float, optional
        Minimum fitness spread threshold for stagnation detection.
        Default is 1e-12.
    seed : int | None, optional
        Seed forwarded to ``numpy.random.default_rng`` for drawing small
        population sizes. ``None`` (default) uses fresh OS entropy, so the
        sequence is not reproducible. Pass an integer to make it
        reproducible, including across :meth:`reset`.

    Attributes
    ----------
    restart_count : int
        Number of restarts performed so far.
    exhausted : bool
        True if max_restarts has been reached.
    best_solution : jax.Array | None
        Best solution found across all restarts.
    best_fitness : float
        Best fitness found across all restarts.

    Examples
    --------
    >>> restarter = BIPOPRestarter(base_popsize=8, n_params=3)
    >>> popsize = restarter.get_next_popsize()
    >>> # ... run CMA-ES with popsize ...
    >>> if restarter.check_stagnation(fitness_spread=1e-15):
    ...     restarter.register_restart()
    ...     popsize = restarter.get_next_popsize()
    """

    base_popsize: int
    n_params: int
    max_restarts: int = 9
    min_fitness_spread: float = 1e-12
    seed: int | None = None

    # Internal state
    restart_count: int = field(default=0, init=False)
    _use_large_pop: bool = field(default=True, init=False)
    _best_solution: jax.Array | None = field(default=None, init=False)
    _best_fitness: float = field(default=-float("inf"), init=False)
    # Excluded from repr/eq: a Generator has no useful repr and compares by
    # identity, which would make otherwise-identical restarters unequal.
    _rng: np.random.Generator = field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Create the random generator from ``seed``."""
        self._rng = np.random.default_rng(self.seed)

    @property
    def exhausted(self) -> bool:
        """Whether maximum restarts have been reached."""
        return self.restart_count >= self.max_restarts

    @property
    def best_solution(self) -> jax.Array | None:
        """Best solution found across all restarts."""
        return self._best_solution

    @property
    def best_fitness(self) -> float:
        """Best fitness found across all restarts."""
        return self._best_fitness

    def get_next_popsize(self) -> int:
        """Get population size for the next run.

        Returns
        -------
        int
            Population size to use for next CMA-ES run. Large runs use
            ``2 * base_popsize``. Small runs draw uniformly from the closed
            range ``[max(4, base_popsize // 2), max(that, base_popsize)]``.
        """
        if self._use_large_pop:
            # Large population: doubled base
            popsize = self.base_popsize * 2
            logger.debug("BIPOP: Using large population (popsize=%s)", popsize)
        else:
            # Small population: random between base/2 and base
            min_pop = max(4, self.base_popsize // 2)  # At least 4
            # Clamp the upper bound so the range is never empty: when
            # base_popsize < 4 the floor of 4 would exceed it and
            # Generator.integers would raise ValueError (low >= high).
            max_pop = max(min_pop, self.base_popsize)
            popsize = int(self._rng.integers(min_pop, max_pop + 1))
            logger.debug("BIPOP: Using small population (popsize=%s)", popsize)
        return popsize

    def register_restart(self) -> None:
        """Register that a restart has occurred.

        Call this after a restart to update internal state. It increments
        ``restart_count`` and flips between large and small population runs.
        """
        self.restart_count += 1
        self._use_large_pop = not self._use_large_pop  # Alternate
        logger.debug(
            "BIPOP: Restart %s/%s, next run: %s",
            self.restart_count,
            self.max_restarts,
            "large" if self._use_large_pop else "small",
        )

    def check_stagnation(self, fitness_spread: float) -> bool:
        """Check if optimization has stagnated.

        Parameters
        ----------
        fitness_spread : float
            Difference between max and min fitness in current population.

        Returns
        -------
        bool
            True if stagnation detected (fitness spread below threshold).
            A NaN spread is never treated as stagnation.
        """
        # bool() so numpy/jax scalar inputs still yield a plain Python bool.
        return bool(fitness_spread < self.min_fitness_spread)

    def update_best(self, solution: jax.Array, fitness: float) -> None:
        """Update best solution if the new one is better.

        Parameters
        ----------
        solution : jax.Array
            Candidate solution.
        fitness : float
            Fitness value (CMA-ES maximizes, so higher is better). It is
            converted to a Python float. NaN never replaces the current best
            because ``nan > x`` is False.
        """
        # Normalize 0-d numpy/jax scalars so best_fitness is always a float.
        fitness = float(fitness)
        if fitness > self._best_fitness:
            self._best_solution = solution
            self._best_fitness = fitness
            logger.debug("BIPOP: New best fitness: %.6e", fitness)

    def get_best(self) -> tuple[jax.Array | None, float]:
        """Get the best solution found across all restarts.

        Returns
        -------
        tuple[jax.Array | None, float]
            Best solution and its fitness, or (None, -inf) if none found.
        """
        return self._best_solution, self._best_fitness

    def reset(self) -> None:
        """Reset the restarter to initial state.

        The random generator is re-created from ``seed``, so a seeded
        restarter replays the same sequence of small population sizes.
        """
        self.restart_count = 0
        self._use_large_pop = True
        self._best_solution = None
        self._best_fitness = -float("inf")
        self._rng = np.random.default_rng(self.seed)