"""
Multi-Start Orchestrator for Global Optimization
=================================================
This module provides the MultiStartOrchestrator class which evaluates multiple
starting points generated by Latin Hypercube Sampling (or other quasi-random
samplers) and selects the best one based on loss value.
The orchestrator integrates with the existing NLSQ curve fitting infrastructure
and supports automatic bounds inference and parameter estimation.
Key Features
------------
- Evaluate N starting points and select best by loss value
- Parallel evaluation using ThreadPoolExecutor with per-thread CurveFit isolation
- Adaptive worker count based on hardware (GPU count, CPU cores)
- Preset configurations: 'fast', 'robust', 'global', 'thorough', 'streaming'
- Integration with parameter estimation for centering LHS samples
- Automatic bounds inference when bounds not provided
- Multi-start diagnostics in result object
Examples
--------
Basic usage:
>>> from nlsq.global_optimization import MultiStartOrchestrator, GlobalOptimizationConfig
>>> import numpy as np
>>> import jax.numpy as jnp
>>>
>>> def model(x, a, b, c):
...     return a * jnp.exp(-b * x) + c
>>>
>>> x = np.linspace(0, 5, 50)
>>> y = 3 * np.exp(-0.5 * x) + 1 + np.random.normal(0, 0.1, 50)
>>>
>>> config = GlobalOptimizationConfig(n_starts=10)
>>> orchestrator = MultiStartOrchestrator(config=config)
>>> result = orchestrator.fit(model, x, y, bounds=([0, 0, 0], [10, 5, 10]))
Using presets:
>>> orchestrator = MultiStartOrchestrator.from_preset('robust')
>>> result = orchestrator.fit(model, x, y)
See Also
--------
GlobalOptimizationConfig : Configuration for multi-start optimization
latin_hypercube_sample : LHS sampling function
TournamentSelector : Tournament selection for large datasets (Task Group 4)
"""
import logging
import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import jax
import numpy as np
from nlsq.global_optimization.config import GlobalOptimizationConfig
# Type-only import to avoid circular dependency
if TYPE_CHECKING:
from nlsq.core.minpack import CurveFit
from nlsq.global_optimization.sampling import (
center_samples_around_p0,
get_sampler,
scale_samples_to_bounds,
)
from nlsq.precision.bound_inference import infer_bounds_for_multistart
from nlsq.precision.parameter_estimation import estimate_initial_parameters
from nlsq.result import CurveFitResult
from nlsq.utils.logging import get_logger
__all__ = ["MultiStartOrchestrator"]
_worker_logger = logging.getLogger("nlsq.multi_start")
def _fit_single_start(
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray,
    bounds: tuple,
    kwargs: dict,
) -> tuple[np.ndarray, float, Any]:
    """Thread worker -- isolated CurveFit instance per call.

    Each worker creates its own CurveFit() to isolate mutable per-fit state
    in LeastSquares (self.n, self.func_*, self.jac*). JIT-compiled functions
    and the compilation cache are shared safely across threads.

    Parameters
    ----------
    f : Callable
        Model function ``f(x, *params) -> y``.
    xdata : np.ndarray
        Independent variable data.
    ydata : np.ndarray
        Dependent variable data.
    p0 : np.ndarray
        Initial parameter guess for this starting point.
    bounds : tuple
        Parameter bounds (lower, upper).
    kwargs : dict
        Additional arguments passed to curve_fit.

    Returns
    -------
    tuple
        ``(params, loss, result)``. ``loss`` is always a Python ``float``
        holding the sum of squared residuals. If the fit raises, the tuple
        is ``(p0, inf, None)``. If the fit returns but its loss is NaN or
        infinite, the loss is reported as ``inf`` (the fitted parameters and
        result object are still returned) so that callers sorting by loss
        get a well-defined order.
    """
    # Deferred import: the module only imports CurveFit under TYPE_CHECKING
    # to avoid a circular dependency.
    from nlsq.core.minpack import CurveFit

    cf = CurveFit()
    try:
        result: CurveFitResult = cf.curve_fit(  # type: ignore[assignment]
            f, xdata, ydata, p0=p0, bounds=bounds, **kwargs
        )
        if hasattr(result, "cost"):
            # cost is half the sum of squares; coerce so every branch
            # yields a plain float rather than a NumPy/JAX scalar.
            loss = 2.0 * float(result.cost)
        elif hasattr(result, "fun"):
            loss = float(np.sum(np.asarray(result.fun) ** 2))
        else:
            # Fallback: compute from residuals
            predictions = f(xdata, *result.popt)
            loss = float(np.sum((ydata - np.asarray(predictions)) ** 2))

        if not np.isfinite(loss):
            # NaN compares False against everything, which would make a
            # sort-by-loss order undefined; treat it as a failed start.
            _worker_logger.debug(
                "Starting point produced non-finite loss (%s); treating as inf",
                loss,
            )
            loss = float("inf")
        return (result.popt, loss, result)
    except Exception as e:
        # Deliberate best-effort: one bad starting point must not abort
        # the whole multi-start run.
        _worker_logger.debug("Starting point failed: %s", e, exc_info=True)
        return (p0, float("inf"), None)
def _select_worker_count(n_starts: int) -> int:
    """Select optimal number of parallel workers.

    Adapts to hardware:

    - CPU-only: min(n_starts, cpu_count, 16) to prevent XLA oversubscription
    - Multiple accelerators: one fit per device
    - Single accelerator: min(n_starts, 4) to limit memory contention

    Parameters
    ----------
    n_starts : int
        Number of starting points to evaluate.

    Returns
    -------
    int
        Number of workers to use for parallel evaluation.
    """
    devices = jax.devices()
    n_devices = len(devices)

    # A CPU-only host still reports at least one device (the CPU itself),
    # so the device *count* alone cannot distinguish CPU from accelerator.
    # Inspect the platform instead.
    cpu_only = n_devices == 0 or all(dev.platform == "cpu" for dev in devices)
    if cpu_only:
        return min(n_starts, os.cpu_count() or 4, 16)
    if n_devices > 1:
        return min(n_starts, n_devices)
    return min(n_starts, 4)
[docs]
class MultiStartOrchestrator:
    """Orchestrator for multi-start optimization with LHS sampling.

    This class generates multiple starting points using Latin Hypercube Sampling
    (or other quasi-random samplers), evaluates each starting point by running
    a full optimization, and selects the best result based on minimum loss.

    Parameters
    ----------
    config : GlobalOptimizationConfig, optional
        Configuration for multi-start optimization. If None, uses default config.
    curve_fit_instance : CurveFit, optional
        Custom CurveFit instance to use for optimization. If None, creates a new one.

    Attributes
    ----------
    config : GlobalOptimizationConfig
        Configuration settings for multi-start optimization.
    curve_fit : CurveFit
        CurveFit instance used for the single-start paths (bypass and
        fallback). Individual starting points are fitted with a fresh
        CurveFit per call.
    logger : NLSQLogger
        Logger for multi-start operations.

    Examples
    --------
    >>> from nlsq.global_optimization import MultiStartOrchestrator, GlobalOptimizationConfig
    >>>
    >>> # Basic usage
    >>> orchestrator = MultiStartOrchestrator()
    >>> result = orchestrator.fit(model, x, y, bounds=([0, 0], [10, 10]))
    >>>
    >>> # With custom config
    >>> config = GlobalOptimizationConfig(n_starts=20, sampler='sobol')
    >>> orchestrator = MultiStartOrchestrator(config=config)
    >>>
    >>> # Using presets
    >>> orchestrator = MultiStartOrchestrator.from_preset('global')
    """

    def __init__(
        self,
        config: GlobalOptimizationConfig | None = None,
        curve_fit_instance: "CurveFit | None" = None,
    ):
        """Initialize MultiStartOrchestrator.

        Parameters
        ----------
        config : GlobalOptimizationConfig, optional
            Configuration for multi-start optimization.
        curve_fit_instance : CurveFit, optional
            Custom CurveFit instance for optimization.
        """
        self.config = config if config is not None else GlobalOptimizationConfig()
        if curve_fit_instance is not None:
            self.curve_fit = curve_fit_instance
        else:
            # Deferred import to avoid circular dependency
            from nlsq.core.minpack import CurveFit

            self.curve_fit = CurveFit()
        self.logger = get_logger("multi_start")

    @classmethod
    def from_preset(
        cls, preset_name: str, curve_fit_instance: "CurveFit | None" = None
    ) -> "MultiStartOrchestrator":
        """Create orchestrator from a named preset.

        Parameters
        ----------
        preset_name : str
            Name of the preset. One of: 'fast', 'robust', 'global',
            'thorough', 'streaming'.
        curve_fit_instance : CurveFit, optional
            Custom CurveFit instance for optimization.

        Returns
        -------
        MultiStartOrchestrator
            Orchestrator instance configured with preset values.

        Raises
        ------
        ValueError
            If preset_name is not a known preset.

        Examples
        --------
        >>> orchestrator = MultiStartOrchestrator.from_preset('robust')
        >>> print(orchestrator.config.n_starts)
        5
        """
        config = GlobalOptimizationConfig.from_preset(preset_name)
        return cls(config=config, curve_fit_instance=curve_fit_instance)

    @staticmethod
    def _loss_sort_key(entry: tuple) -> float:
        """Return a NaN-safe ordering key for a ``(params, loss, result)`` tuple.

        NaN compares False against every value, so sorting on the raw loss
        gives an undefined order. Non-finite losses are mapped to ``inf`` so
        they always sort after every finite loss.
        """
        loss = entry[1]
        return float(loss) if np.isfinite(loss) else float("inf")

    def _generate_starting_points(
        self,
        n_params: int,
        lb: np.ndarray,
        ub: np.ndarray,
        p0: np.ndarray | None = None,
        rng_key: Any = None,
    ) -> np.ndarray:
        """Generate starting points using the configured sampler.

        Parameters
        ----------
        n_params : int
            Number of parameters.
        lb : np.ndarray
            Lower bounds for parameters.
        ub : np.ndarray
            Upper bounds for parameters.
        p0 : np.ndarray, optional
            Center point for centering samples (if center_on_p0=True).
        rng_key : Any, optional
            JAX random key for reproducibility. Only forwarded to the
            ``"lhs"`` sampler.

        Returns
        -------
        np.ndarray
            Array of shape (n_starts, n_params) with starting points.
        """
        n_starts = self.config.n_starts
        if n_starts <= 0:
            return np.empty((0, n_params))

        # Get sampler function
        sampler = get_sampler(self.config.sampler)

        # Generate samples in [0, 1]^n
        if self.config.sampler == "lhs":
            # LHS accepts rng_key
            samples_raw = sampler(n_starts, n_params, rng_key=rng_key)
        else:
            # Sobol and Halton are deterministic
            samples_raw = sampler(n_starts, n_params)

        import jax.numpy as jnp

        samples_jnp = jnp.asarray(samples_raw)
        lb_jnp = jnp.asarray(lb)
        ub_jnp = jnp.asarray(ub)

        if self.config.center_on_p0 and p0 is not None:
            # Center samples around p0
            scaled = center_samples_around_p0(
                samples_jnp,
                jnp.asarray(p0),
                self.config.scale_factor,
                lb_jnp,
                ub_jnp,
            )
        else:
            # Scale to full bounds
            scaled = scale_samples_to_bounds(samples_jnp, lb_jnp, ub_jnp)
        return np.asarray(scaled)

    def evaluate_starting_points(
        self,
        f: Callable,
        xdata: np.ndarray,
        ydata: np.ndarray,
        starting_points: np.ndarray,
        bounds: tuple[np.ndarray, np.ndarray],
        **kwargs: Any,
    ) -> tuple[
        list[tuple[np.ndarray, float, CurveFitResult | None]],
        dict[str, Any],
    ]:
        """Evaluate starting points in parallel using ThreadPoolExecutor.

        Each starting point is evaluated with an isolated CurveFit instance
        to avoid thread-safety issues with mutable per-fit state in
        LeastSquares. JIT-compiled functions and the compilation cache are
        shared safely across threads.

        Parameters
        ----------
        f : Callable
            Model function ``f(x, *params) -> y``.
        xdata : np.ndarray
            Independent variable data.
        ydata : np.ndarray
            Dependent variable data.
        starting_points : np.ndarray
            Array of shape (n_starts, n_params) with starting points.
        bounds : tuple
            Tuple of (lower_bounds, upper_bounds) for parameters.
        **kwargs
            Additional arguments passed to curve_fit. Each worker receives
            its own shallow copy.

        Returns
        -------
        results : list[tuple[np.ndarray, float, CurveFitResult | None]]
            List of (params, loss, result) tuples sorted by loss (ascending).
            Entries with a non-finite loss (failed starts) sort last.
        diagnostics : dict[str, Any]
            Parallel diagnostics with keys ``parallel``, ``n_workers``,
            ``wall_time_sec``.
        """
        if len(starting_points) == 0:
            return [], {"parallel": False, "n_workers": 0, "wall_time_sec": 0.0}

        import time as _time
        from concurrent.futures import ThreadPoolExecutor, as_completed

        n_starts = len(starting_points)
        n_workers = _select_worker_count(n_starts)
        self.logger.debug(
            f"Evaluating {n_starts} starting points with {n_workers} workers",
            n_starts=n_starts,
            n_workers=n_workers,
        )
        start_wall = _time.monotonic()

        # Pre-sized so parallel completions can be stored by original index.
        results: list[tuple[np.ndarray, float, CurveFitResult | None]] = [
            None  # type: ignore[list-item]
        ] * n_starts

        if n_workers <= 1:
            # Sequential fallback (1 start or 1 worker)
            for i, p0 in enumerate(starting_points):
                results[i] = _fit_single_start(
                    f, xdata, ydata, p0, bounds, kwargs.copy()
                )
                _params, loss, _ = results[i]
                self.logger.debug(
                    f"Starting point {i + 1}/{n_starts}: loss={loss:.6f}",
                    starting_point_idx=i,
                    loss=loss,
                )
        else:
            # Parallel execution
            futures = {}
            with ThreadPoolExecutor(max_workers=n_workers) as executor:
                for i, p0 in enumerate(starting_points):
                    future = executor.submit(
                        _fit_single_start,
                        f,
                        xdata,
                        ydata,
                        p0,
                        bounds,
                        kwargs.copy(),
                    )
                    futures[future] = i
                for future in as_completed(futures):
                    i = futures[future]
                    results[i] = future.result()
                    _params, loss, _ = results[i]
                    self.logger.debug(
                        f"Starting point {i + 1}/{n_starts}: loss={loss:.6f}",
                        starting_point_idx=i,
                        loss=loss,
                    )

        wall_time = _time.monotonic() - start_wall
        parallel_diagnostics: dict[str, Any] = {
            "parallel": n_workers > 1,
            "n_workers": n_workers,
            "wall_time_sec": wall_time,
        }

        # Sort by loss (ascending); NaN/inf losses are pushed to the end.
        results.sort(key=self._loss_sort_key)
        return results, parallel_diagnostics

    def fit(
        self,
        f: Callable,
        xdata: np.ndarray,
        ydata: np.ndarray,
        p0: np.ndarray | None = None,
        bounds: tuple | None = None,
        **kwargs: Any,
    ) -> CurveFitResult:
        """Run multi-start optimization and return best result.

        This is the main entry point for multi-start optimization. It:

        1. Determines the parameter count from ``p0``, from an estimated
           ``p0``, or from the signature of ``f`` as a last resort
        2. Infers bounds if not provided (or if any bound is infinite)
        3. Generates starting points with the configured sampler
        4. Evaluates all starting points
        5. Returns the result with the minimum finite loss

        If ``config.n_starts == 0`` the multi-start machinery is bypassed and
        a single fit is run. If every starting point fails, a single fit from
        ``p0`` is run as a fallback.

        Parameters
        ----------
        f : Callable
            Model function ``f(x, *params) -> y``.
        xdata : np.ndarray
            Independent variable data.
        ydata : np.ndarray
            Dependent variable data.
        p0 : np.ndarray, optional
            Initial parameter guess. Used for centering if center_on_p0=True.
        bounds : tuple, optional
            Tuple of (lower_bounds, upper_bounds) for parameters. If None,
            bounds are inferred from data.
        **kwargs
            Additional arguments passed to curve_fit.

        Returns
        -------
        CurveFitResult
            Best optimization result with multi-start diagnostics stored
            under the ``"multistart_diagnostics"`` key.

        Examples
        --------
        >>> orchestrator = MultiStartOrchestrator.from_preset('robust')
        >>> result = orchestrator.fit(model, x, y, bounds=([0, 0], [10, 10]))
        >>> print(f"Best params: {result.popt}")
        >>> print(f"Multi-start diagnostics: {result.multistart_diagnostics}")
        """
        xdata = np.asarray(xdata)
        ydata = np.asarray(ydata)

        # Diagnostics tracking
        diagnostics: dict[str, Any] = {
            "n_starts_configured": self.config.n_starts,
            "sampler": self.config.sampler,
            "center_on_p0": self.config.center_on_p0,
            "scale_factor": self.config.scale_factor,
            "bounds_inferred": False,
            "p0_estimated": False,
            "bypassed": False,
        }

        # Check if multi-start is disabled (n_starts=0)
        if self.config.n_starts == 0:
            self.logger.debug("Multi-start disabled (n_starts=0), using single-start")
            diagnostics["bypassed"] = True
            diagnostics["n_starts_evaluated"] = 0
            # Run single-start optimization. An explicit None check is used
            # because truth-testing array-like bounds would raise.
            result: CurveFitResult = self.curve_fit.curve_fit(  # type: ignore[assignment]
                f,
                xdata,
                ydata,
                p0=p0,
                bounds=bounds if bounds is not None else (-np.inf, np.inf),
                **kwargs,
            )
            # Add diagnostics to result
            result["multistart_diagnostics"] = diagnostics
            return result

        # Determine number of parameters
        if p0 is not None:
            p0 = np.atleast_1d(p0)
            n_params = len(p0)
        else:
            # Try to estimate p0 for parameter count
            try:
                p0_estimated = estimate_initial_parameters(f, xdata, ydata, p0="auto")
                p0 = np.atleast_1d(p0_estimated)
                n_params = len(p0)
                diagnostics["p0_estimated"] = True
                self.logger.debug(
                    f"Estimated initial parameters: {p0}",
                    p0=p0.tolist(),
                )
            except Exception as e:
                # Best-effort: fall back to signature inspection
                from inspect import signature

                sig = signature(f)
                n_params = len(sig.parameters) - 1  # Exclude x
                p0 = np.ones(n_params)
                self.logger.debug(
                    f"Could not estimate p0, using defaults: {e}",
                    n_params=n_params,
                )

        # Handle bounds
        if bounds is None:
            # Infer bounds for multi-start
            self.logger.debug("Inferring bounds for multi-start sampling")
            lb, ub = infer_bounds_for_multistart(xdata, ydata, p0)
            bounds = (lb, ub)
            diagnostics["bounds_inferred"] = True
            diagnostics["inferred_bounds"] = {
                "lower": lb.tolist(),
                "upper": ub.tolist(),
            }
        else:
            # Convert bounds to arrays
            lb = np.atleast_1d(bounds[0])
            ub = np.atleast_1d(bounds[1])
            # Broadcast scalar bounds
            if len(lb) == 1:
                lb = np.full(n_params, lb[0])
            if len(ub) == 1:
                ub = np.full(n_params, ub[0])
            # Sampling needs finite bounds; infer replacements if needed
            if not np.all(np.isfinite(lb)) or not np.all(np.isfinite(ub)):
                self.logger.debug("Some bounds are infinite, inferring finite bounds")
                lb_inferred, ub_inferred = infer_bounds_for_multistart(
                    xdata, ydata, p0, user_bounds=(lb, ub)
                )
                lb = lb_inferred
                ub = ub_inferred
                diagnostics["bounds_inferred"] = True
            bounds = (lb, ub)

        # Generate starting points
        self.logger.info(
            f"Generating {self.config.n_starts} starting points using {self.config.sampler}",
            n_starts=self.config.n_starts,
            sampler=self.config.sampler,
            center_on_p0=self.config.center_on_p0,
        )
        starting_points = self._generate_starting_points(
            n_params=n_params,
            lb=lb,
            ub=ub,
            p0=p0 if self.config.center_on_p0 else None,
        )
        diagnostics["n_starts_generated"] = len(starting_points)

        # Evaluate all starting points
        self.logger.info(f"Evaluating {len(starting_points)} starting points")
        evaluation_results, parallel_diag = self.evaluate_starting_points(
            f=f,
            xdata=xdata,
            ydata=ydata,
            starting_points=starting_points,
            bounds=bounds,
            **kwargs,
        )
        # Merge parallel diagnostics
        diagnostics.update(parallel_diag)

        # Only finite losses count as successful; NaN and inf are failures.
        finite_results = [
            entry for entry in evaluation_results if np.isfinite(entry[1])
        ]
        n_successful = len(finite_results)
        diagnostics["n_starts_evaluated"] = len(evaluation_results)
        diagnostics["n_starts_successful"] = n_successful

        # Select best result
        if n_successful == 0:
            self.logger.warning(
                "All starting points failed, falling back to single-start"
            )
            # Fall back to single-start with original p0
            result = self.curve_fit.curve_fit(  # type: ignore[assignment]
                f,
                xdata,
                ydata,
                p0=p0,
                bounds=bounds,
                **kwargs,
            )
            diagnostics["fallback_to_single_start"] = True
        else:
            # Pick the minimum among finite-loss entries explicitly rather
            # than trusting index 0, so a NaN entry can never be selected.
            best_params, best_loss, best_result = min(
                finite_results, key=self._loss_sort_key
            )
            self.logger.info(
                f"Best starting point: loss={best_loss:.6f}",
                best_loss=best_loss,
                best_params=best_params.tolist()
                if hasattr(best_params, "tolist")
                else list(best_params),
            )
            diagnostics["best_loss"] = best_loss
            diagnostics["all_losses"] = [loss for _, loss, _ in finite_results]
            if best_result is not None:
                result = best_result
            else:
                # Shouldn't happen, but fallback
                result = self.curve_fit.curve_fit(  # type: ignore[assignment]
                    f,
                    xdata,
                    ydata,
                    p0=best_params,
                    bounds=bounds,
                    **kwargs,
                )

        # Add diagnostics to result
        result["multistart_diagnostics"] = diagnostics
        # Also add as attribute for convenience
        if not hasattr(result, "multistart_diagnostics"):
            result.multistart_diagnostics = diagnostics  # type: ignore[attr-defined]
        return result