"""CMA-ES configuration for global optimization.
This module provides configuration dataclasses and utilities for CMA-ES
optimization using the evosax library.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Literal
# Public API of this module: the names exported by a star-import.
# NOTE(review): "auto_configure_cmaes_memory" is exported here but is not
# defined in the part of the file reviewed - confirm it exists further down.
__all__ = [
    "CMAES_PRESETS",
    "CMAESConfig",
    "auto_configure_cmaes_memory",
    "estimate_cmaes_memory_gb",
    "get_evosax_import_error",
    "is_evosax_available",
]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Global cache for evosax availability check.
# None means the import probe has not run yet; is_evosax_available() replaces
# it with True or False on its first call and reuses that value afterwards.
_EVOSAX_AVAILABLE: bool | None = None
# Text of the ImportError captured by the probe; stays None when the import
# succeeded (or when the probe has not run yet).
_EVOSAX_IMPORT_ERROR: str | None = None
def is_evosax_available() -> bool:
    """Check if evosax is available for import.

    The import is attempted lazily on the first call only; the outcome is
    stored in the module-level cache and returned directly by every later
    call. When the import fails, the error text is kept in
    ``_EVOSAX_IMPORT_ERROR`` and a single INFO message is logged.

    Returns
    -------
    bool
        True if evosax can be imported successfully, False otherwise.
    """
    global _EVOSAX_AVAILABLE, _EVOSAX_IMPORT_ERROR  # noqa: PLW0603

    # Fast path: the probe already ran, reuse its cached verdict.
    if _EVOSAX_AVAILABLE is not None:
        return _EVOSAX_AVAILABLE

    try:
        # Import to verify package is available and functional.
        # We need to test actual imports, not just find_spec, because
        # evosax may be installed but have missing dependencies.
        import evosax.algorithms  # type: ignore[import-not-found,import-untyped]
        import evosax.core.restart  # type: ignore[import-not-found,import-untyped]
    except ImportError as exc:
        _EVOSAX_AVAILABLE = False
        _EVOSAX_IMPORT_ERROR = str(exc)
        logger.info(
            "evosax not available - CMA-ES will fall back to multi-start. "
            "Install with: pip install 'nlsq[global]'"
        )
    else:
        _EVOSAX_AVAILABLE = True
        _EVOSAX_IMPORT_ERROR = None

    return _EVOSAX_AVAILABLE
def get_evosax_import_error() -> str | None:
    """Return the reason evosax could not be imported, if any.

    Returns
    -------
    str | None
        The message of the import error recorded by the availability
        check, or None when no import error was recorded.
    """
    # Run the (cached) availability probe first so that the error slot
    # reflects an actual import attempt rather than its initial value.
    is_evosax_available()
    return _EVOSAX_IMPORT_ERROR
[docs]
@dataclass(slots=True)
class CMAESConfig:
"""Configuration for CMA-ES global optimization.
Attributes
----------
popsize : int | None
Population size. If None, uses CMA-ES default: int(4 + 3 * log(n)).
max_generations : int
Maximum number of generations before stopping. Default: 100.
sigma : float
Initial step size (standard deviation). Default: 0.5.
tol_fun : float
Function value tolerance for convergence. Default: 1e-8.
tol_x : float
Parameter tolerance for convergence. Default: 1e-8.
restart_strategy : Literal['none', 'bipop']
Restart strategy. Default: 'bipop'.
max_restarts : int
Maximum restart attempts for BIPOP. Default: 9.
population_batch_size : int | None
Batch size for population evaluation. If None, evaluates all candidates
in parallel. Set to smaller values (e.g., 4) to reduce memory usage.
data_chunk_size : int | None
Chunk size for data streaming. If None, processes full dataset at once.
Set to smaller values (e.g., 50000) for datasets >10M points to avoid OOM.
Must be >= 1024 for numerical stability.
refine_with_nlsq : bool
Whether to refine best solution with NLSQ TRF. Default: True.
seed : int | None
Random seed for reproducibility. If None, uses random seed.
Examples
--------
>>> config = CMAESConfig(popsize=32, max_generations=200)
>>> config = CMAESConfig.from_preset('cmaes-global')
Memory-efficient configuration for large datasets:
>>> config = CMAESConfig(
... population_batch_size=4, # Evaluate 4 candidates at a time
... data_chunk_size=50000, # Process 50K points per chunk
... )
"""
# Population and generations
popsize: int | None = None # None = auto: int(4 + 3 * log(n))
max_generations: int = 100
# Step size and tolerances
sigma: float = 0.5
tol_fun: float = 1e-8
tol_x: float = 1e-8
# Restart strategy (BIPOP enabled by default per spec)
restart_strategy: Literal["none", "bipop"] = "bipop"
max_restarts: int = 9
# Memory management
population_batch_size: int | None = None
data_chunk_size: int | None = None
# NLSQ refinement
refine_with_nlsq: bool = True
# Reproducibility
seed: int | None = None
[docs]
def __post_init__(self) -> None:
"""Validate configuration after initialization."""
self._validate()
def _validate(self) -> None:
"""Validate configuration parameters.
Raises
------
ValueError
If any parameter is invalid.
"""
if self.popsize is not None and self.popsize < 4:
raise ValueError(f"popsize must be >= 4, got {self.popsize}")
if self.max_generations < 1:
raise ValueError(
f"max_generations must be >= 1, got {self.max_generations}"
)
if self.sigma <= 0:
raise ValueError(f"sigma must be > 0, got {self.sigma}")
if self.tol_fun <= 0:
raise ValueError(f"tol_fun must be > 0, got {self.tol_fun}")
if self.tol_x <= 0:
raise ValueError(f"tol_x must be > 0, got {self.tol_x}")
if self.max_restarts < 0:
raise ValueError(f"max_restarts must be >= 0, got {self.max_restarts}")
if self.restart_strategy not in ("none", "bipop"):
raise ValueError(
f"restart_strategy must be 'none' or 'bipop', "
f"got '{self.restart_strategy}'"
)
if self.population_batch_size is not None and self.population_batch_size < 1:
raise ValueError(
f"population_batch_size must be >= 1, got {self.population_batch_size}"
)
if self.data_chunk_size is not None and self.data_chunk_size < 1024:
raise ValueError(
f"data_chunk_size must be >= 1024 for numerical stability, "
f"got {self.data_chunk_size}"
)
[docs]
@classmethod
def from_preset(cls, preset_name: str) -> CMAESConfig:
"""Create a CMAESConfig from a named preset.
Parameters
----------
preset_name : str
Name of the preset. One of 'cmaes-fast', 'cmaes', 'cmaes-global'.
Returns
-------
CMAESConfig
Configuration for the specified preset.
Raises
------
ValueError
If preset_name is not recognized.
Examples
--------
>>> config = CMAESConfig.from_preset('cmaes-fast')
>>> config.max_generations
50
"""
if preset_name not in CMAES_PRESETS:
available = ", ".join(sorted(CMAES_PRESETS.keys()))
raise ValueError(
f"Unknown preset '{preset_name}'. Available presets: {available}"
)
preset_config = CMAES_PRESETS[preset_name]
return cls(**preset_config)
# CMA-ES presets - source of truth for CMA-ES preset configurations.
# Each table row is (name, max_generations, restart_strategy, max_restarts).
# "popsize" is None (automatic sizing) for every preset; per the original
# note, the optimizer is expected to double the automatic size for
# "cmaes-global" (2x auto) - that logic is not part of this mapping.
CMAES_PRESETS: dict[str, dict] = {
    name: {
        "popsize": None,
        "max_generations": generations,
        "restart_strategy": strategy,
        "max_restarts": restarts,
    }
    for name, generations, strategy, restarts in (
        ("cmaes-fast", 50, "none", 0),
        ("cmaes", 100, "bipop", 9),
        ("cmaes-global", 200, "bipop", 9),
    )
}
def estimate_cmaes_memory_gb(
    n_data: int,
    popsize: int,
    population_batch_size: int | None = None,
    data_chunk_size: int | None = None,
) -> float:
    """Estimate peak memory usage for CMA-ES fitness evaluation in GB.

    The estimate assumes three float64 arrays (predictions, residuals and
    intermediates) per candidate evaluated at once, each as long as the data
    processed at once. When the data is not chunked, the input arrays
    (xdata and ydata) are added on top.

    Parameters
    ----------
    n_data : int
        Number of data points.
    popsize : int
        Population size.
    population_batch_size : int | None, optional
        Batch size for population evaluation. If None, uses full popsize.
    data_chunk_size : int | None, optional
        Chunk size for data streaming. If None, uses full dataset.

    Returns
    -------
    float
        Estimated peak memory usage in GB (1 GB = 1024**3 bytes).

    Examples
    --------
    >>> round(estimate_cmaes_memory_gb(10_000_000, popsize=16), 3)
    3.725
    >>> round(
    ...     estimate_cmaes_memory_gb(10_000_000, popsize=16, population_batch_size=4),
    ...     3,
    ... )
    1.043
    >>> round(
    ...     estimate_cmaes_memory_gb(10_000_000, popsize=16, data_chunk_size=50000),
    ...     3,
    ... )
    0.018
    """
    # How many candidates / data points are held in memory simultaneously.
    candidates_at_once = (
        popsize if population_batch_size is None else population_batch_size
    )
    points_at_once = n_data if data_chunk_size is None else data_chunk_size

    # Memory per candidate during fitness evaluation:
    # - predictions array: points_at_once * 8 bytes (float64)
    # - residuals array: points_at_once * 8 bytes (float64)
    # - intermediate computations: ~points_at_once * 8 bytes
    bytes_per_candidate = points_at_once * 8 * 3  # 3 arrays
    peak_bytes = candidates_at_once * bytes_per_candidate

    # Input data (xdata + ydata) is only counted for unchunked evaluation.
    # Per the original note, chunked evaluation slices the inputs
    # (dynamic_slice) without copying - the evaluator is not visible here.
    if data_chunk_size is None:
        peak_bytes += n_data * 8 * 2  # xdata + ydata

    return peak_bytes / (1024**3)