"""CMA-ES global optimizer with NLSQ refinement.
This module provides the CMAESOptimizer class that runs CMA-ES global search
using evosax followed by NLSQ Trust Region Reflective refinement for proper
parameter covariance estimation.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import jax
import jax.numpy as jnp
import numpy as np
from nlsq.global_optimization.bounds_transform import (
compute_default_popsize,
transform_from_bounds,
transform_to_bounds,
)
from nlsq.global_optimization.cmaes_config import (
CMAESConfig,
is_evosax_available,
)
from nlsq.global_optimization.cmaes_diagnostics import CMAESDiagnostics
if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray
__all__ = ["CMAESOptimizer"]
logger = logging.getLogger(__name__)
def _create_fitness_function( # noqa: C901
model_func: Callable,
xdata: jax.Array,
ydata: jax.Array,
lower_bounds: jax.Array,
upper_bounds: jax.Array,
sigma: jax.Array | None = None,
population_batch_size: int | None = None,
data_chunk_size: int | None = None,
) -> Callable[[jax.Array], jax.Array]:
"""Create a fitness function for CMA-ES optimization.
CMA-ES maximizes fitness, so we return negative SSR (sum of squared residuals).
Parameters
----------
model_func : Callable
Model function f(x, *params) -> y.
xdata : jax.Array
Independent variable data.
ydata : jax.Array
Dependent variable data.
lower_bounds : jax.Array
Lower bounds for parameters.
upper_bounds : jax.Array
Upper bounds for parameters.
sigma : jax.Array | None, optional
Standard deviation of ydata for weighted residuals.
population_batch_size : int | None, optional
Batch size for population evaluation to avoid OOM.
data_chunk_size : int | None, optional
Chunk size for data streaming to avoid OOM on large datasets.
Returns
-------
Callable[[jax.Array], jax.Array]
Fitness function that takes unbounded parameters and returns fitness.
"""
n_data = xdata.shape[0]
# Determine if we need data streaming
use_data_streaming = data_chunk_size is not None and n_data > data_chunk_size
if use_data_streaming:
# Mypy doesn't infer not-None from the boolean flag
assert data_chunk_size is not None
# Calculate number of full chunks and remainder
n_full_chunks = n_data // data_chunk_size
remainder = n_data % data_chunk_size
# Pad data to exact multiple of chunk_size for efficient slicing
if remainder > 0:
pad_size = data_chunk_size - remainder
xdata_padded = jnp.pad(xdata, (0, pad_size), constant_values=0.0)
ydata_padded = jnp.pad(ydata, (0, pad_size), constant_values=0.0)
if sigma is not None:
# Pad sigma with 1.0 to avoid division issues (residual will be 0)
sigma_padded = jnp.pad(sigma, (0, pad_size), constant_values=1.0)
else:
sigma_padded = None
n_chunks = n_full_chunks + 1
else:
xdata_padded = xdata
ydata_padded = ydata
sigma_padded = sigma
n_chunks = n_full_chunks
# Reshape data into chunks for efficient access
xdata_chunked = xdata_padded.reshape(n_chunks, data_chunk_size)
ydata_chunked = ydata_padded.reshape(n_chunks, data_chunk_size)
if sigma_padded is not None:
sigma_chunked = sigma_padded.reshape(n_chunks, data_chunk_size)
else:
sigma_chunked = None
# Create validity mask for the last chunk (handles padding)
if remainder > 0:
last_chunk_mask = jnp.arange(data_chunk_size) < remainder
else:
last_chunk_mask = jnp.ones(data_chunk_size, dtype=bool)
@jax.jit
def compute_chunk_ssr(
params_bounded: jax.Array,
x_chunk: jax.Array,
y_chunk: jax.Array,
sigma_chunk: jax.Array | None,
valid_mask: jax.Array,
) -> jax.Array:
"""Compute SSR for one data chunk."""
predictions = model_func(x_chunk, *params_bounded)
residuals = y_chunk - predictions
if sigma_chunk is not None:
residuals = residuals / sigma_chunk
# Apply validity mask to handle padding in last chunk
residuals_sq = jnp.where(valid_mask, residuals**2, 0.0)
return jnp.sum(residuals_sq)
def fitness_single_streaming(params_unbounded: jax.Array) -> jax.Array:
"""Compute fitness by streaming over data chunks."""
params_bounded = transform_to_bounds(
params_unbounded, lower_bounds, upper_bounds
)
# Accumulate SSR over chunks
ssr_total = jnp.array(0.0)
for chunk_idx in range(n_chunks):
x_chunk = xdata_chunked[chunk_idx]
y_chunk = ydata_chunked[chunk_idx]
sigma_chunk = (
sigma_chunked[chunk_idx] if sigma_chunked is not None else None
)
# Use appropriate mask for last chunk
if chunk_idx == n_chunks - 1 and remainder > 0:
valid_mask = last_chunk_mask
else:
valid_mask = jnp.ones(data_chunk_size, dtype=bool)
ssr_total = ssr_total + compute_chunk_ssr(
params_bounded, x_chunk, y_chunk, sigma_chunk, valid_mask
)
return jnp.where(jnp.isfinite(ssr_total), -ssr_total, -jnp.inf)
fitness_single = fitness_single_streaming
logger.debug(
f"Data streaming enabled: {n_data} points -> {n_chunks} chunks of {data_chunk_size}"
)
else:
# Original non-streaming fitness function
@jax.jit
def fitness_single(params_unbounded: jax.Array) -> jax.Array:
"""Compute fitness for a single parameter set."""
# Transform to bounded space
params_bounded = transform_to_bounds(
params_unbounded, lower_bounds, upper_bounds
)
# Compute predictions
predictions = model_func(xdata, *params_bounded)
# Compute residuals
residuals = ydata - predictions
# Weight by sigma if provided
if sigma is not None:
residuals = residuals / sigma
# Sum of squared residuals
ssr = jnp.sum(residuals**2)
# Handle NaN/Inf (assign worst fitness)
fitness = jnp.where(jnp.isfinite(ssr), -ssr, -jnp.inf)
return fitness
@jax.jit
def fitness_population_jit(population: jax.Array) -> jax.Array:
"""Compute fitness for entire population (vectorized)."""
return jax.vmap(fitness_single)(population)
if population_batch_size is None:
return fitness_population_jit
def fitness_population_batched(population: jax.Array) -> jax.Array:
"""Compute fitness for population in batches (sequential loop)."""
n = population.shape[0]
# If population fits in one batch, run directly
if n <= population_batch_size:
return fitness_population_jit(population)
results = []
for i in range(0, n, population_batch_size):
batch = population[i : i + population_batch_size]
results.append(fitness_population_jit(batch))
return jnp.concatenate(results)
return fitness_population_batched
class CMAESOptimizer:
    """CMA-ES global optimizer with NLSQ refinement using evosax.

    Uses evosax's CMA-ES implementation for gradient-free global optimization,
    followed by NLSQ Trust Region Reflective refinement for proper parameter
    covariance estimation.

    Parameters
    ----------
    config : CMAESConfig | None, optional
        Configuration for CMA-ES optimization. If None, uses default config.

    Attributes
    ----------
    config : CMAESConfig
        Configuration for CMA-ES optimization.

    Examples
    --------
    >>> from nlsq.global_optimization import CMAESOptimizer, CMAESConfig
    >>> import jax.numpy as jnp
    >>>
    >>> def model(x, a, b):
    ...     return a * jnp.exp(-b * x)
    >>>
    >>> x = jnp.linspace(0, 5, 100)
    >>> y = 2.5 * jnp.exp(-0.5 * x)
    >>> bounds = ([0.1, 0.01], [10.0, 2.0])
    >>>
    >>> optimizer = CMAESOptimizer()
    >>> result = optimizer.fit(model, x, y, bounds=bounds)
    >>> print(f"Optimal params: {result['popt']}")
    """
def __init__(self, config: CMAESConfig | None = None) -> None:
    """Store the configuration and verify that evosax can be used.

    Parameters
    ----------
    config : CMAESConfig | None, optional
        Configuration for CMA-ES optimization. When omitted, a
        ``CMAESConfig()`` built with its own defaults is used.

    Raises
    ------
    ImportError
        If ``is_evosax_available()`` reports that evosax is missing.
    """
    if config is None:
        config = CMAESConfig()
    self.config = config

    # Fail at construction time rather than deep inside fit().
    if is_evosax_available():
        return
    raise ImportError(
        "evosax is required for CMA-ES optimization. "
        "Install with: pip install 'nlsq[global]'"
    )
@classmethod
def from_preset(cls, preset_name: str) -> CMAESOptimizer:
    """Build an optimizer whose configuration comes from a named preset.

    The preset lookup itself is delegated to ``CMAESConfig.from_preset``.

    Parameters
    ----------
    preset_name : str
        Name of the preset. One of 'cmaes-fast', 'cmaes', 'cmaes-global'.

    Returns
    -------
    CMAESOptimizer
        Optimizer configured with the specified preset.

    Examples
    --------
    >>> optimizer = CMAESOptimizer.from_preset('cmaes-fast')
    >>> optimizer.config.max_generations
    50
    """
    return cls(config=CMAESConfig.from_preset(preset_name))
def fit(
    self,
    f: Callable,
    xdata: ArrayLike,
    ydata: ArrayLike,
    p0: ArrayLike | None = None,
    bounds: tuple[ArrayLike, ArrayLike] | None = None,
    sigma: ArrayLike | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Run CMA-ES global optimization followed by NLSQ refinement.

    Parameters
    ----------
    f : Callable
        Model function ``f(x, *params) -> y``.
    xdata : ArrayLike
        Independent variable data.
    ydata : ArrayLike
        Dependent variable data.
    p0 : ArrayLike | None, optional
        Initial parameter guess. If None, the search starts from the zero
        vector in the unbounded space (documented in this module as the
        midpoint of the bounds).
    bounds : tuple[ArrayLike, ArrayLike] | None
        Lower and upper bounds for parameters. Required for CMA-ES.
    sigma : ArrayLike | None, optional
        Standard deviation of ydata for weighted residuals.
    **kwargs : Any
        Additional keyword arguments (passed to NLSQ refinement).

    Returns
    -------
    dict[str, Any]
        Result dictionary containing:

        - popt: Optimal parameters
        - pcov: Parameter covariance matrix (from NLSQ refinement, or a
          placeholder when refinement is disabled or failed)
        - cmaes_diagnostics: dictionary form of the run diagnostics
        - Additional fields from NLSQ result when refinement succeeded

    Raises
    ------
    ValueError
        If bounds are not provided (required for CMA-ES).
    """
    import time

    if bounds is None:
        raise ValueError(
            "CMA-ES requires explicit bounds. "
            "Provide bounds as (lower_bounds, upper_bounds)."
        )

    # Convert inputs to JAX arrays
    xdata_jax = jnp.asarray(xdata)
    ydata_jax = jnp.asarray(ydata)
    lower_bounds = jnp.asarray(bounds[0])
    upper_bounds = jnp.asarray(bounds[1])
    sigma_jax = jnp.asarray(sigma) if sigma is not None else None
    n_params = len(lower_bounds)
    n_data = len(ydata_jax)

    # Debug messages below convert device arrays to NumPy; only pay for
    # that when debug logging is actually enabled.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)

    logger.info(
        f"CMA-ES optimizer initialized: n_params={n_params}, n_data={n_data}, "
        f"restart_strategy={self.config.restart_strategy}"
    )
    if debug_enabled:
        logger.debug(
            f"CMA-ES bounds: lower={np.asarray(lower_bounds)}, "
            f"upper={np.asarray(upper_bounds)}"
        )

    # Determine population size
    popsize = self.config.popsize
    if popsize is None:
        popsize = compute_default_popsize(n_params)
        # Double population for cmaes-global preset
        # (detected by max_generations == 200 and bipop)
        if (
            self.config.max_generations == 200
            and self.config.restart_strategy == "bipop"
        ):
            popsize = popsize * 2
            logger.debug("CMA-ES: Using 2x population for cmaes-global preset")

    # Log memory optimization settings
    if self.config.population_batch_size is not None:
        logger.info(
            f"CMA-ES memory optimization: population_batch_size="
            f"{self.config.population_batch_size}"
        )
    if self.config.data_chunk_size is not None:
        logger.info(
            f"CMA-ES memory optimization: data_chunk_size="
            f"{self.config.data_chunk_size} (data streaming enabled)"
        )

    # Determine the starting point in the unbounded search space
    if p0 is not None:
        p0_jax = jnp.asarray(p0)
        initial_solution = transform_from_bounds(p0_jax, lower_bounds, upper_bounds)
        if debug_enabled:
            logger.debug(f"CMA-ES starting from p0={np.asarray(p0_jax)}")
    else:
        # Start at center of bounds (x=0 in unbounded space = midpoint)
        initial_solution = jnp.zeros(n_params)
        if debug_enabled:
            midpoint = (lower_bounds + upper_bounds) / 2
            logger.debug(
                f"CMA-ES starting from bounds midpoint={np.asarray(midpoint)}"
            )

    fitness_fn = _create_fitness_function(
        f,
        xdata_jax,
        ydata_jax,
        lower_bounds,
        upper_bounds,
        sigma_jax,
        population_batch_size=self.config.population_batch_size,
        data_chunk_size=self.config.data_chunk_size,
    )

    start_time = time.perf_counter()
    diagnostics = CMAESDiagnostics()

    # Run CMA-ES optimization (diagnostics updated in place)
    best_params_unbounded, best_fitness, generations = self._run_cmaes(
        fitness_fn, initial_solution, popsize, n_params, diagnostics
    )

    diagnostics.total_generations = generations
    diagnostics.best_fitness = float(best_fitness)
    diagnostics.wall_time = time.perf_counter() - start_time

    # Transform best solution back to bounded space
    best_params = transform_to_bounds(
        best_params_unbounded, lower_bounds, upper_bounds
    )
    logger.info(
        f"CMA-ES optimization completed: {generations} generations, "
        f"best_fitness={float(best_fitness):.6e}, "
        f"wall_time={diagnostics.wall_time:.2f}s"
    )

    # NLSQ refinement phase for proper pcov estimation
    if self.config.refine_with_nlsq:
        result = self._nlsq_refinement(
            f, xdata, ydata, best_params, bounds, sigma, **kwargs
        )
        # _nlsq_refinement falls back to the raw CMA-ES estimate (with no
        # "nlsq_result" entry) when curve_fit raises, so only report the
        # refinement as applied when it actually succeeded.
        diagnostics.nlsq_refinement = "nlsq_result" in result
    else:
        # Return CMA-ES result without refinement
        result = {
            "popt": np.asarray(best_params),
            "pcov": self._estimate_pcov_from_cmaes(n_params),
        }
        diagnostics.nlsq_refinement = False

    result["cmaes_diagnostics"] = diagnostics.to_dict()
    return result
def _run_cmaes(
self,
fitness_fn: Callable,
initial_solution: jax.Array,
popsize: int,
n_params: int,
diagnostics: CMAESDiagnostics,
) -> tuple[jax.Array, jax.Array, int]:
"""Run CMA-ES optimization loop with optional BIPOP restarts.
Parameters
----------
fitness_fn : Callable
Fitness function for population evaluation.
initial_solution : jax.Array
Initial solution in unbounded space.
popsize : int
Population size (base population for BIPOP).
n_params : int
Number of parameters.
diagnostics : CMAESDiagnostics
Diagnostics object to update with run information.
Returns
-------
tuple[jax.Array, jax.Array, int]
Best solution, best fitness, and total number of generations.
"""
if self.config.restart_strategy == "bipop":
return self._run_cmaes_with_bipop(
fitness_fn, initial_solution, popsize, n_params, diagnostics
)
else:
return self._run_cmaes_single(
fitness_fn, initial_solution, popsize, n_params, diagnostics
)
def _run_cmaes_single(
    self,
    fitness_fn: Callable,
    initial_solution: jax.Array,
    popsize: int,
    n_params: int,
    diagnostics: CMAESDiagnostics,
) -> tuple[jax.Array, jax.Array, int]:
    """Run one CMA-ES optimization without restarts.

    The loop stops after ``config.max_generations`` generations, or earlier
    once the strategy's step size ``state.std`` drops below ``config.tol_x``.

    Parameters
    ----------
    fitness_fn : Callable
        Fitness function for population evaluation.
    initial_solution : jax.Array
        Initial solution in unbounded space.
    popsize : int
        Population size.
    n_params : int
        Number of parameters (not used by this runner).
    diagnostics : CMAESDiagnostics
        Diagnostics object; receives the per-generation best fitness, the
        final step size, the convergence reason and a restart count of 0.

    Returns
    -------
    tuple[jax.Array, jax.Array, int]
        Best solution, best fitness, and number of generations.
    """
    from evosax.algorithms import (  # type: ignore[import-not-found,import-untyped]
        CMA_ES,
    )

    logger.info(
        f"Starting CMA-ES: popsize={popsize}, max_gen={self.config.max_generations}"
    )

    # Build the strategy and override the default initial step size.
    strategy = CMA_ES(population_size=popsize, solution=initial_solution)
    es_params = strategy.default_params.replace(std_init=self.config.sigma)

    # Use the configured seed, or draw one from NumPy's global RNG.
    seed = self.config.seed
    if seed is None:
        seed = np.random.randint(0, 2**31)
    rng = jax.random.key(seed)

    rng, init_key = jax.random.split(rng)
    state = strategy.init(init_key, initial_solution, es_params)

    best_solution = initial_solution
    best_fitness = jnp.array(-jnp.inf)
    convergence_reason = "max_generations"

    # Generation numbers at which progress is logged (25%, 50%, 75%).
    # setdefault keeps the first label when small max_generations values
    # make several percentages map to the same generation.
    milestones: dict[int, str] = {}
    for fraction, label in ((0.25, "25%"), (0.50, "50%"), (0.75, "75%")):
        milestones.setdefault(int(self.config.max_generations * fraction), label)

    gen = -1  # keeps "gen + 1" equal to 0 if the loop body never runs
    for gen in range(self.config.max_generations):
        rng, key_ask, key_tell = jax.random.split(rng, 3)

        # Sample, evaluate, and feed the result back to the strategy.
        population, state = strategy.ask(key_ask, state, es_params)
        fitness = fitness_fn(population)
        state, _metrics = strategy.tell(
            key_tell, population, fitness, state, es_params
        )

        # Larger fitness is treated as better here.
        # NOTE(review): confirm that evosax's state.best_fitness follows the
        # same "larger is better" convention for the installed version.
        if state.best_fitness > best_fitness:
            best_fitness, best_solution = state.best_fitness, state.best_solution

        diagnostics.fitness_history.append(float(best_fitness))

        # Step-size based convergence check.
        if float(state.std) < self.config.tol_x:
            logger.info(
                f"CMA-ES converged at generation {gen + 1}: "
                f"std={float(state.std):.2e} < tol_x={self.config.tol_x:.2e}"
            )
            convergence_reason = "xtol"
            break

        if gen + 1 in milestones:
            logger.info(
                f"CMA-ES progress {milestones[gen + 1]}: "
                f"gen={gen + 1}/{self.config.max_generations}, "
                f"best_fitness={float(best_fitness):.6e}, std={float(state.std):.2e}"
            )

        # Detailed progress every 10th generation, debug level only.
        if (gen + 1) % 10 == 0 and logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f"Generation {gen + 1}/{self.config.max_generations}: "
                f"best_fitness={float(best_fitness):.6e}, std={float(state.std):.6e}"
            )

    diagnostics.final_sigma = float(state.std)
    diagnostics.convergence_reason = convergence_reason
    diagnostics.total_restarts = 0
    return best_solution, best_fitness, gen + 1
def _run_cmaes_with_bipop(
    self,
    fitness_fn: Callable,
    initial_solution: jax.Array,
    base_popsize: int,
    n_params: int,
    diagnostics: CMAESDiagnostics,
) -> tuple[jax.Array, jax.Array, int]:
    """Run CMA-ES with BIPOP restart strategy.

    Runs repeated CMA-ES optimizations, with the population size of each
    run supplied by ``BIPOPRestarter.get_next_popsize()``, and tracks the
    global best across all runs. A run ends when it hits
    ``config.max_generations``, when its step size drops below
    ``config.tol_x``, or after 5 consecutive stagnating generations.
    Restarts stop when the restarter is exhausted or a run converges on
    step size without stagnating.

    Each restart begins, with equal probability, either from the best
    solution found so far or from the starting point the caller supplied.

    Parameters
    ----------
    fitness_fn : Callable
        Fitness function for population evaluation.
    initial_solution : jax.Array
        Initial solution in unbounded space.
    base_popsize : int
        Base population size for BIPOP.
    n_params : int
        Number of parameters.
    diagnostics : CMAESDiagnostics
        Diagnostics object to update with run information.

    Returns
    -------
    tuple[jax.Array, jax.Array, int]
        Best solution, best fitness, and total number of generations.
    """
    from evosax.algorithms import (  # type: ignore[import-not-found,import-untyped]
        CMA_ES,
    )
    from nlsq.global_optimization.bipop import BIPOPRestarter

    # Consecutive stagnating generations tolerated before a run is ended.
    stagnation_patience = 5

    logger.info(
        f"Starting CMA-ES with BIPOP: base_popsize={base_popsize}, "
        f"max_restarts={self.config.max_restarts}, max_gen={self.config.max_generations}"
    )

    restarter = BIPOPRestarter(
        base_popsize=base_popsize,
        n_params=n_params,
        max_restarts=self.config.max_restarts,
        min_fitness_spread=self.config.tol_fun,
    )

    # Use the configured seed, or draw one from NumPy's global RNG.
    if self.config.seed is not None:
        key = jax.random.key(self.config.seed)
    else:
        key = jax.random.key(np.random.randint(0, 2**31))

    # Remember the caller's start so the "restart from origin" branch can
    # return to it after initial_solution has been replaced by a best point.
    origin_solution = initial_solution

    total_generations = 0
    convergence_reason = "max_restarts"
    final_sigma = self.config.sigma

    while not restarter.exhausted:
        popsize = restarter.get_next_popsize()
        run_type = "large" if popsize >= base_popsize * 2 else "small"
        logger.info(
            f"BIPOP restart #{restarter.restart_count + 1}: "
            f"popsize={popsize} ({run_type}), "
            f"max_gen={self.config.max_generations}"
        )

        # Fresh strategy and state for this run
        es = CMA_ES(population_size=popsize, solution=initial_solution)
        params = es.default_params
        params = params.replace(std_init=self.config.sigma)
        key, subkey = jax.random.split(key)
        state = es.init(subkey, initial_solution, params)

        run_best_solution = initial_solution
        run_best_fitness = jnp.array(-jnp.inf)

        stagnation_counter = 0
        gen = -1  # keeps "gen + 1" equal to 0 if the loop body never runs
        for gen in range(self.config.max_generations):
            key, key_ask, key_tell = jax.random.split(key, 3)
            population, state = es.ask(key_ask, state, params)
            fitness = fitness_fn(population)
            state, _metrics = es.tell(key_tell, population, fitness, state, params)

            # Larger fitness is treated as better here.
            # NOTE(review): confirm that evosax's state.best_fitness follows
            # the same convention for the installed version.
            if state.best_fitness > run_best_fitness:
                run_best_fitness = state.best_fitness
                run_best_solution = state.best_solution

            diagnostics.fitness_history.append(float(run_best_fitness))

            # Count consecutive generations the restarter flags as stagnant
            fitness_spread = float(jnp.max(fitness) - jnp.min(fitness))
            if restarter.check_stagnation(fitness_spread):
                stagnation_counter += 1
            else:
                stagnation_counter = 0

            if stagnation_counter >= stagnation_patience:
                logger.info(
                    f"BIPOP run #{restarter.restart_count + 1}: "
                    f"stagnation at gen {gen + 1}, fitness_spread={fitness_spread:.2e}"
                )
                break

            # Step-size based convergence check
            if float(state.std) < self.config.tol_x:
                logger.info(
                    f"BIPOP run #{restarter.restart_count + 1}: "
                    f"converged at gen {gen + 1}, std={float(state.std):.2e}"
                )
                break

            if logger.isEnabledFor(logging.DEBUG) and (gen + 1) % 10 == 0:
                logger.debug(
                    f"BIPOP Run {restarter.restart_count + 1}: "
                    f"gen {gen + 1}/{self.config.max_generations}, "
                    f"best_fitness={float(run_best_fitness):.6e}, "
                    f"std={float(state.std):.6e}"
                )

        total_generations += gen + 1
        final_sigma = float(state.std)
        logger.info(
            f"BIPOP run #{restarter.restart_count + 1} completed: "
            f"{gen + 1} generations, best_fitness={float(run_best_fitness):.6e}"
        )

        diagnostics.restart_history.append(
            {
                "popsize": popsize,
                "generations": gen + 1,
                "best_fitness": float(run_best_fitness),
                "final_sigma": final_sigma,
            }
        )

        restarter.update_best(run_best_solution, float(run_best_fitness))

        # A run that converged on step size without stagnating is good
        # enough: no further restarts are needed.
        if (
            float(state.std) < self.config.tol_x
            and stagnation_counter < stagnation_patience
        ):
            logger.info("BIPOP: Good convergence achieved, stopping restarts early")
            convergence_reason = "xtol"
            break

        restarter.register_restart()

        # Choose the next start: 50% chance to exploit the global best,
        # 50% chance to go back to the original starting point.
        if restarter.best_solution is not None:
            key, subkey = jax.random.split(key)
            if jax.random.uniform(subkey) > 0.5:
                initial_solution = restarter.best_solution
            else:
                initial_solution = origin_solution

    best_solution, best_fitness = restarter.get_best()
    if best_solution is None:
        # No run recorded a best solution; fall back to the current start.
        best_solution = initial_solution
        best_fitness = float("-inf")

    logger.info(
        f"BIPOP completed: {restarter.restart_count} restarts, "
        f"{total_generations} total generations"
    )

    diagnostics.total_restarts = restarter.restart_count
    diagnostics.final_sigma = final_sigma
    diagnostics.convergence_reason = convergence_reason
    return best_solution, jnp.array(best_fitness), total_generations
def _nlsq_refinement(
    self,
    f: Callable,
    xdata: ArrayLike,
    ydata: ArrayLike,
    p0: jax.Array,
    bounds: tuple[ArrayLike, ArrayLike],
    sigma: ArrayLike | None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Polish the CMA-ES solution with NLSQ ``curve_fit``.

    The fit is started from the CMA-ES parameters and always runs with
    ``workflow="auto"``; a caller-supplied ``workflow`` keyword is dropped.
    If the refinement raises, a warning is logged and the CMA-ES parameters
    are returned with a placeholder covariance instead.

    Parameters
    ----------
    f : Callable
        Model function.
    xdata : ArrayLike
        Independent variable data.
    ydata : ArrayLike
        Dependent variable data.
    p0 : jax.Array
        Initial parameters from CMA-ES.
    bounds : tuple[ArrayLike, ArrayLike]
        Parameter bounds.
    sigma : ArrayLike | None
        Standard deviation for weighted residuals.
    **kwargs : Any
        Additional arguments for curve_fit.

    Returns
    -------
    dict[str, Any]
        On success: ``popt``, ``pcov`` and the full ``nlsq_result``.
        On failure: ``popt`` (the CMA-ES parameters) and a placeholder
        ``pcov`` only.
    """
    from nlsq.core.minpack import curve_fit

    # NLSQ is driven with NumPy inputs.
    p0_numpy = np.asarray(p0)
    logger.info(
        f"Starting NLSQ Trust Region Reflective refinement "
        f"(n_params={len(p0_numpy)})"
    )
    logger.debug(f"NLSQ refinement starting from: {p0_numpy}")

    xdata_np, ydata_np = np.asarray(xdata), np.asarray(ydata)
    sigma_np = None if sigma is None else np.asarray(sigma)

    try:
        # workflow is fixed to "auto" below, so a caller-provided value is
        # filtered out to avoid passing the keyword twice.
        refinement_kwargs = {
            name: value for name, value in kwargs.items() if name != "workflow"
        }
        n_points = len(ydata_np)
        logger.debug(
            f"NLSQ refinement using workflow='auto' for {n_points:,} points"
        )
        result = curve_fit(
            f,
            xdata_np,
            ydata_np,
            p0=p0_numpy,
            sigma=sigma_np,
            bounds=bounds,
            workflow="auto",
            **refinement_kwargs,
        )

        # The result exposes the parameters as .x and the covariance as .pcov
        popt = np.asarray(result.x)  # type: ignore[union-attr]
        pcov = np.asarray(result.pcov)  # type: ignore[union-attr]

        # How far the refinement moved away from the CMA-ES solution
        param_change = np.linalg.norm(popt - p0_numpy)
        logger.info(
            f"NLSQ refinement completed: "
            f"parameter adjustment norm={param_change:.6e}"
        )
        logger.debug(f"NLSQ refined popt={popt}")
        return {"popt": popt, "pcov": pcov, "nlsq_result": result}
    except Exception as e:
        # Deliberate best-effort: keep the CMA-ES answer rather than fail.
        logger.warning(f"NLSQ refinement failed: {e}. Using CMA-ES result.")
        return {
            "popt": p0_numpy,
            "pcov": self._estimate_pcov_from_cmaes(len(p0_numpy)),
        }
def _estimate_pcov_from_cmaes(self, n_params: int) -> NDArray[np.floating]:
"""Estimate parameter covariance when NLSQ refinement is disabled.
This is a rough estimate; for proper pcov, use refine_with_nlsq=True.
Parameters
----------
n_params : int
Number of parameters.
Returns
-------
NDArray[np.floating]
Estimated covariance matrix (diagonal approximation).
"""
# Return diagonal matrix with inf variance to indicate unknown covariance
# Proper pcov requires Jacobian from NLSQ
# Use np.full to avoid RuntimeWarning from 0.0 * np.inf = nan in np.eye() * np.inf
pcov = np.full((n_params, n_params), 0.0)
np.fill_diagonal(pcov, np.inf)
return pcov