# Source code for nlsq.global_optimization.multi_start

"""
Multi-Start Orchestrator for Global Optimization
=================================================

This module provides the MultiStartOrchestrator class which evaluates multiple
starting points generated by Latin Hypercube Sampling (or other quasi-random
samplers) and selects the best one based on loss value.

The orchestrator integrates with the existing NLSQ curve fitting infrastructure
and supports automatic bounds inference and parameter estimation.

Key Features
------------
- Evaluate N starting points and select best by loss value
- Parallel evaluation using ThreadPoolExecutor with per-thread CurveFit isolation
- Adaptive worker count based on hardware (GPU count, CPU cores)
- Preset configurations: 'fast', 'robust', 'global', 'thorough', 'streaming'
- Integration with parameter estimation for centering LHS samples
- Automatic bounds inference when bounds not provided
- Multi-start diagnostics in result object

Examples
--------
Basic usage:

>>> from nlsq.global_optimization import MultiStartOrchestrator, GlobalOptimizationConfig
>>> import numpy as np
>>> import jax.numpy as jnp
>>>
>>> def model(x, a, b, c):
...     return a * jnp.exp(-b * x) + c
>>>
>>> x = np.linspace(0, 5, 50)
>>> y = 3 * np.exp(-0.5 * x) + 1 + np.random.normal(0, 0.1, 50)
>>>
>>> config = GlobalOptimizationConfig(n_starts=10)
>>> orchestrator = MultiStartOrchestrator(config=config)
>>> result = orchestrator.fit(model, x, y, bounds=([0, 0, 0], [10, 5, 10]))

Using presets:

>>> orchestrator = MultiStartOrchestrator.from_preset('robust')
>>> result = orchestrator.fit(model, x, y)

See Also
--------
GlobalOptimizationConfig : Configuration for multi-start optimization
latin_hypercube_sample : LHS sampling function
TournamentSelector : Tournament selection for large datasets (Task Group 4)
"""

import logging
import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import jax
import numpy as np

from nlsq.global_optimization.config import GlobalOptimizationConfig

# Type-only import to avoid circular dependency
if TYPE_CHECKING:
    from nlsq.core.minpack import CurveFit
from nlsq.global_optimization.sampling import (
    center_samples_around_p0,
    get_sampler,
    scale_samples_to_bounds,
)
from nlsq.precision.bound_inference import infer_bounds_for_multistart
from nlsq.precision.parameter_estimation import estimate_initial_parameters
from nlsq.result import CurveFitResult
from nlsq.utils.logging import get_logger

# Public API of this module: only the orchestrator class is exported.
__all__ = ["MultiStartOrchestrator"]

# Standard-library logger used by the module-level thread worker
# (_fit_single_start) to report failed starting points at DEBUG level.
_worker_logger = logging.getLogger("nlsq.multi_start")


def _fit_single_start(
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray,
    bounds: tuple,
    kwargs: dict,
) -> tuple[np.ndarray, float, Any]:
    """Thread worker -- run one fit from one starting point.

    A fresh ``CurveFit()`` is created on every call so that concurrent
    workers never share a fitter instance (the original author notes this
    isolates mutable per-fit state in LeastSquares, while JIT-compiled
    functions and the compilation cache remain shared across threads).

    Parameters
    ----------
    f : Callable
        Model function ``f(x, *params) -> y``.
    xdata : np.ndarray
        Independent variable data.
    ydata : np.ndarray
        Dependent variable data.
    p0 : np.ndarray
        Initial parameter guess for this starting point.
    bounds : tuple
        Parameter bounds (lower, upper).
    kwargs : dict
        Additional arguments passed to curve_fit.

    Returns
    -------
    tuple
        ``(params, loss, result)``. On success ``params`` is ``result.popt``
        and ``loss`` is a plain Python ``float``: twice ``result.cost`` when
        available, otherwise the sum of squared residuals. A NaN loss is
        reported as ``inf`` so callers can order results by loss. If the fit
        raises, ``(p0, inf, None)`` is returned instead of propagating.
    """
    from nlsq.core.minpack import CurveFit

    cf = CurveFit()
    try:
        result: CurveFitResult = cf.curve_fit(  # type: ignore[assignment]
            f, xdata, ydata, p0=p0, bounds=bounds, **kwargs
        )
        if hasattr(result, "cost"):
            # cost is half sum of squares; coerce so every branch yields a
            # Python float (cost may be an array scalar).
            loss = float(2 * result.cost)
        elif hasattr(result, "fun"):
            loss = float(np.sum(np.asarray(result.fun) ** 2))
        else:
            # Fallback: compute from residuals
            predictions = f(xdata, *result.popt)
            loss = float(np.sum((ydata - np.asarray(predictions)) ** 2))

        # NaN compares False against everything, which would make sorting
        # by loss ill-defined; rank such fits alongside failures.
        if np.isnan(loss):
            loss = float("inf")
        return (result.popt, loss, result)
    except Exception as e:
        # Deliberate best-effort: one bad starting point must not abort the
        # whole multi-start run.
        _worker_logger.debug("Starting point failed: %s", e, exc_info=True)
        return (p0, float("inf"), None)


def _select_worker_count(n_starts: int) -> int:
    """Select the number of parallel workers for evaluating starting points.

    Adapts to hardware:
    - Multiple accelerators (GPU/TPU): one fit per device
    - Single accelerator: min(n_starts, 4) to limit memory contention
    - CPU-only: min(n_starts, cpu_count, 16) to prevent XLA oversubscription

    Parameters
    ----------
    n_starts : int
        Number of starting points to evaluate.

    Returns
    -------
    int
        Number of workers to use for parallel evaluation. Never exceeds
        ``n_starts``.
    """
    # jax.devices() lists the devices of the default backend; on a host
    # without accelerators that is a CPU device, not an empty list. Counting
    # devices alone therefore cannot detect the CPU-only case, so classify
    # by each device's platform instead.
    accelerators = [dev for dev in jax.devices() if dev.platform != "cpu"]
    n_accelerators = len(accelerators)

    if n_accelerators > 1:
        return min(n_starts, n_accelerators)
    if n_accelerators == 1:
        return min(n_starts, 4)
    # os.cpu_count() may return None when the count is undeterminable.
    return min(n_starts, os.cpu_count() or 4, 16)


class MultiStartOrchestrator:
    """Orchestrator for multi-start optimization with LHS sampling.

    This class generates multiple starting points using Latin Hypercube
    Sampling (or other quasi-random samplers), evaluates each starting point
    by running a full optimization, and selects the best result based on
    minimum loss.

    Parameters
    ----------
    config : GlobalOptimizationConfig, optional
        Configuration for multi-start optimization. If None, uses default
        config.
    curve_fit_instance : CurveFit, optional
        Custom CurveFit instance to use for optimization. If None, creates a
        new one.

    Attributes
    ----------
    config : GlobalOptimizationConfig
        Configuration settings for multi-start optimization.
    curve_fit : CurveFit
        CurveFit instance used for the single-start bypass and fallback fits.
    logger : NLSQLogger
        Logger for multi-start operations.

    Examples
    --------
    >>> from nlsq.global_optimization import MultiStartOrchestrator, GlobalOptimizationConfig
    >>>
    >>> # Basic usage
    >>> orchestrator = MultiStartOrchestrator()
    >>> result = orchestrator.fit(model, x, y, bounds=([0, 0], [10, 10]))
    >>>
    >>> # With custom config
    >>> config = GlobalOptimizationConfig(n_starts=20, sampler='sobol')
    >>> orchestrator = MultiStartOrchestrator(config=config)
    >>>
    >>> # Using presets
    >>> orchestrator = MultiStartOrchestrator.from_preset('global')
    """

    def __init__(
        self,
        config: GlobalOptimizationConfig | None = None,
        curve_fit_instance: "CurveFit | None" = None,
    ):
        """Initialize MultiStartOrchestrator.

        Parameters
        ----------
        config : GlobalOptimizationConfig, optional
            Configuration for multi-start optimization.
        curve_fit_instance : CurveFit, optional
            Custom CurveFit instance for optimization.
        """
        self.config = config if config is not None else GlobalOptimizationConfig()

        if curve_fit_instance is not None:
            self.curve_fit = curve_fit_instance
        else:
            # Deferred import to avoid circular dependency
            from nlsq.core.minpack import CurveFit

            self.curve_fit = CurveFit()

        self.logger = get_logger("multi_start")

    @classmethod
    def from_preset(
        cls, preset_name: str, curve_fit_instance: "CurveFit | None" = None
    ) -> "MultiStartOrchestrator":
        """Create orchestrator from a named preset.

        Parameters
        ----------
        preset_name : str
            Name of the preset. One of: 'fast', 'robust', 'global',
            'thorough', 'streaming'.
        curve_fit_instance : CurveFit, optional
            Custom CurveFit instance for optimization.

        Returns
        -------
        MultiStartOrchestrator
            Orchestrator instance configured with preset values.

        Raises
        ------
        ValueError
            If preset_name is not a known preset.

        Examples
        --------
        >>> orchestrator = MultiStartOrchestrator.from_preset('robust')
        >>> print(orchestrator.config.n_starts)
        5
        """
        config = GlobalOptimizationConfig.from_preset(preset_name)
        return cls(config=config, curve_fit_instance=curve_fit_instance)

    def _generate_starting_points(
        self,
        n_params: int,
        lb: np.ndarray,
        ub: np.ndarray,
        p0: np.ndarray | None = None,
        rng_key: Any = None,
    ) -> np.ndarray:
        """Generate starting points using the configured sampler.

        Parameters
        ----------
        n_params : int
            Number of parameters.
        lb : np.ndarray
            Lower bounds for parameters.
        ub : np.ndarray
            Upper bounds for parameters.
        p0 : np.ndarray, optional
            Center point for centering samples (if center_on_p0=True).
        rng_key : Any, optional
            JAX random key for reproducibility. Only forwarded to the
            ``"lhs"`` sampler.

        Returns
        -------
        np.ndarray
            Array of shape (n_starts, n_params) with starting points; an
            empty ``(0, n_params)`` array when ``n_starts <= 0``.
        """
        n_starts = self.config.n_starts
        if n_starts <= 0:
            return np.array([]).reshape(0, n_params)

        # Get sampler function and draw samples in [0, 1]^n
        sampler = get_sampler(self.config.sampler)
        if self.config.sampler == "lhs":
            # LHS accepts rng_key
            samples_raw = sampler(n_starts, n_params, rng_key=rng_key)
        else:
            # Sobol and Halton are deterministic
            samples_raw = sampler(n_starts, n_params)

        import jax.numpy as jnp

        # Conversions shared by both scaling strategies.
        samples_jnp = jnp.asarray(samples_raw)
        lb_jnp = jnp.asarray(lb)
        ub_jnp = jnp.asarray(ub)

        if self.config.center_on_p0 and p0 is not None:
            # Center samples around p0
            scaled = center_samples_around_p0(
                samples_jnp,
                jnp.asarray(p0),
                self.config.scale_factor,
                lb_jnp,
                ub_jnp,
            )
        else:
            # Scale to full bounds
            scaled = scale_samples_to_bounds(samples_jnp, lb_jnp, ub_jnp)

        return np.asarray(scaled)

    def evaluate_starting_points(
        self,
        f: Callable,
        xdata: np.ndarray,
        ydata: np.ndarray,
        starting_points: np.ndarray,
        bounds: tuple[np.ndarray, np.ndarray],
        **kwargs: Any,
    ) -> tuple[
        list[tuple[np.ndarray, float, CurveFitResult | None]],
        dict[str, Any],
    ]:
        """Evaluate starting points, in parallel when more than one worker is selected.

        Every starting point is handed to ``_fit_single_start``, which builds
        its own CurveFit instance, so worker threads never share a fitter.
        With a single worker the points are evaluated sequentially in the
        calling thread; otherwise a ThreadPoolExecutor is used.

        Parameters
        ----------
        f : Callable
            Model function ``f(x, *params) -> y``.
        xdata : np.ndarray
            Independent variable data.
        ydata : np.ndarray
            Dependent variable data.
        starting_points : np.ndarray
            Array of shape (n_starts, n_params) with starting points.
        bounds : tuple
            Tuple of (lower_bounds, upper_bounds) for parameters.
        **kwargs
            Additional arguments passed to curve_fit. Each fit receives its
            own shallow copy.

        Returns
        -------
        results : list[tuple[np.ndarray, float, CurveFitResult | None]]
            List of (params, loss, result) tuples sorted by loss (ascending).
            NaN losses are ordered as if they were ``inf``.
        diagnostics : dict[str, Any]
            Parallel diagnostics with keys ``parallel``, ``n_workers``,
            ``wall_time_sec``.
        """
        if len(starting_points) == 0:
            return [], {"parallel": False, "n_workers": 0, "wall_time_sec": 0.0}

        import time as _time
        from concurrent.futures import ThreadPoolExecutor, as_completed

        n_starts = len(starting_points)
        n_workers = _select_worker_count(n_starts)

        self.logger.debug(
            f"Evaluating {n_starts} starting points with {n_workers} workers",
            n_starts=n_starts,
            n_workers=n_workers,
        )

        start_wall = _time.monotonic()

        # Pre-sized so parallel completions can be stored by original index.
        results: list[tuple[np.ndarray, float, CurveFitResult | None]] = [
            None  # type: ignore[list-item]
        ] * n_starts

        if n_workers <= 1:
            # Sequential fallback (1 start or 1 worker)
            for i, p0 in enumerate(starting_points):
                results[i] = _fit_single_start(
                    f, xdata, ydata, p0, bounds, kwargs.copy()
                )
                _params, loss, _ = results[i]
                self.logger.debug(
                    f"Starting point {i + 1}/{n_starts}: loss={loss:.6f}",
                    starting_point_idx=i,
                    loss=loss,
                )
        else:
            # Parallel execution
            futures = {}
            with ThreadPoolExecutor(max_workers=n_workers) as executor:
                for i, p0 in enumerate(starting_points):
                    future = executor.submit(
                        _fit_single_start,
                        f,
                        xdata,
                        ydata,
                        p0,
                        bounds,
                        kwargs.copy(),
                    )
                    futures[future] = i

                for future in as_completed(futures):
                    i = futures[future]
                    results[i] = future.result()
                    _params, loss, _ = results[i]
                    self.logger.debug(
                        f"Starting point {i + 1}/{n_starts}: loss={loss:.6f}",
                        starting_point_idx=i,
                        loss=loss,
                    )

        wall_time = _time.monotonic() - start_wall

        parallel_diagnostics: dict[str, Any] = {
            "parallel": n_workers > 1,
            "n_workers": n_workers,
            "wall_time_sec": wall_time,
        }

        def _loss_key(entry: tuple) -> float:
            # NaN compares False against everything, which leaves sort order
            # undefined; rank NaN losses with the failures (inf) so finite
            # losses always come first.
            value = float(entry[1])
            return float("inf") if np.isnan(value) else value

        # Sort by loss (ascending); the sort is stable, so ties keep the
        # original starting-point order.
        results.sort(key=_loss_key)

        return results, parallel_diagnostics

    def fit(
        self,
        f: Callable,
        xdata: np.ndarray,
        ydata: np.ndarray,
        p0: np.ndarray | None = None,
        bounds: tuple | None = None,
        **kwargs: Any,
    ) -> CurveFitResult:
        """Run multi-start optimization and return best result.

        This is the main entry point for multi-start optimization. It:

        1. Determines the parameter count (from p0, from an automatic p0
           estimate, or from the signature of ``f`` as a last resort)
        2. Infers bounds if not provided, or if any provided bound is infinite
        3. Generates starting points with the configured sampler
        4. Evaluates all starting points
        5. Returns the result with the minimum loss

        When ``config.n_starts == 0`` the method bypasses all of the above
        and runs a single ordinary fit.

        Parameters
        ----------
        f : Callable
            Model function ``f(x, *params) -> y``.
        xdata : np.ndarray
            Independent variable data.
        ydata : np.ndarray
            Dependent variable data.
        p0 : np.ndarray, optional
            Initial parameter guess. Used for centering if center_on_p0=True.
        bounds : tuple, optional
            Tuple of (lower_bounds, upper_bounds) for parameters.
            If None, bounds are inferred from data.
        **kwargs
            Additional arguments passed to curve_fit.

        Returns
        -------
        CurveFitResult
            Best optimization result with multi-start diagnostics stored
            under the ``"multistart_diagnostics"`` key.

        Examples
        --------
        >>> orchestrator = MultiStartOrchestrator.from_preset('robust')
        >>> result = orchestrator.fit(model, x, y, bounds=([0, 0], [10, 10]))
        >>> print(f"Best params: {result.popt}")
        >>> print(f"Multi-start diagnostics: {result.multistart_diagnostics}")
        """
        xdata = np.asarray(xdata)
        ydata = np.asarray(ydata)

        # Diagnostics tracking
        diagnostics: dict[str, Any] = {
            "n_starts_configured": self.config.n_starts,
            "sampler": self.config.sampler,
            "center_on_p0": self.config.center_on_p0,
            "scale_factor": self.config.scale_factor,
            "bounds_inferred": False,
            "p0_estimated": False,
            "bypassed": False,
        }

        # Check if multi-start is disabled (n_starts=0)
        if self.config.n_starts == 0:
            self.logger.debug("Multi-start disabled (n_starts=0), using single-start")
            diagnostics["bypassed"] = True
            diagnostics["n_starts_evaluated"] = 0

            # Run single-start optimization
            result: CurveFitResult = self.curve_fit.curve_fit(  # type: ignore[assignment]
                f,
                xdata,
                ydata,
                p0=p0,
                bounds=bounds or (-np.inf, np.inf),
                **kwargs,
            )

            # Add diagnostics to result
            result["multistart_diagnostics"] = diagnostics
            return result

        # Determine number of parameters
        if p0 is not None:
            p0 = np.atleast_1d(p0)
            n_params = len(p0)
        else:
            # Try to estimate p0 for parameter count
            try:
                p0_estimated = estimate_initial_parameters(f, xdata, ydata, p0="auto")
                p0 = np.atleast_1d(p0_estimated)
                n_params = len(p0)
                diagnostics["p0_estimated"] = True
                self.logger.debug(
                    f"Estimated initial parameters: {p0}",
                    p0=p0.tolist(),
                )
            except Exception as e:
                # Fall back to signature inspection
                from inspect import signature

                sig = signature(f)
                n_params = len(sig.parameters) - 1  # Exclude x
                p0 = np.ones(n_params)
                self.logger.debug(
                    f"Could not estimate p0, using defaults: {e}",
                    n_params=n_params,
                )

        # Handle bounds
        if bounds is None:
            # Infer bounds for multi-start
            self.logger.debug("Inferring bounds for multi-start sampling")
            lb, ub = infer_bounds_for_multistart(xdata, ydata, p0)
            bounds = (lb, ub)
            diagnostics["bounds_inferred"] = True
            diagnostics["inferred_bounds"] = {
                "lower": lb.tolist(),
                "upper": ub.tolist(),
            }
        else:
            # Convert bounds to arrays
            lb = np.atleast_1d(bounds[0])
            ub = np.atleast_1d(bounds[1])

            # Broadcast scalar bounds
            if len(lb) == 1:
                lb = np.full(n_params, lb[0])
            if len(ub) == 1:
                ub = np.full(n_params, ub[0])

            # Sampling needs a finite box, so infinite bounds are replaced
            # by inferred ones.
            if not np.all(np.isfinite(lb)) or not np.all(np.isfinite(ub)):
                self.logger.debug("Some bounds are infinite, inferring finite bounds")
                lb_inferred, ub_inferred = infer_bounds_for_multistart(
                    xdata, ydata, p0, user_bounds=(lb, ub)
                )
                lb = lb_inferred
                ub = ub_inferred
                diagnostics["bounds_inferred"] = True

            bounds = (lb, ub)

        # Generate starting points
        self.logger.info(
            f"Generating {self.config.n_starts} starting points using {self.config.sampler}",
            n_starts=self.config.n_starts,
            sampler=self.config.sampler,
            center_on_p0=self.config.center_on_p0,
        )

        starting_points = self._generate_starting_points(
            n_params=n_params,
            lb=lb,
            ub=ub,
            p0=p0 if self.config.center_on_p0 else None,
        )

        diagnostics["n_starts_generated"] = len(starting_points)

        # Evaluate all starting points
        self.logger.info(f"Evaluating {len(starting_points)} starting points")

        evaluation_results, parallel_diag = self.evaluate_starting_points(
            f=f,
            xdata=xdata,
            ydata=ydata,
            starting_points=starting_points,
            bounds=bounds,
            **kwargs,
        )

        # Merge parallel diagnostics
        diagnostics.update(parallel_diag)

        # Count successful evaluations (failed fits carry an inf loss; a NaN
        # loss also fails this comparison and is counted as unsuccessful)
        n_successful = sum(
            1 for _, loss, _ in evaluation_results if loss < float("inf")
        )
        diagnostics["n_starts_evaluated"] = len(evaluation_results)
        diagnostics["n_starts_successful"] = n_successful

        # Select best result
        if n_successful == 0:
            self.logger.warning(
                "All starting points failed, falling back to single-start"
            )
            # Fall back to single-start with original p0
            result = self.curve_fit.curve_fit(  # type: ignore[assignment]
                f,
                xdata,
                ydata,
                p0=p0,
                bounds=bounds,
                **kwargs,
            )
            diagnostics["fallback_to_single_start"] = True
        else:
            # Get best result (first in sorted list with finite loss)
            best_params, best_loss, best_result = evaluation_results[0]

            self.logger.info(
                f"Best starting point: loss={best_loss:.6f}",
                best_loss=best_loss,
                best_params=best_params.tolist()
                if hasattr(best_params, "tolist")
                else list(best_params),
            )

            diagnostics["best_loss"] = best_loss
            diagnostics["all_losses"] = [
                loss for _, loss, _ in evaluation_results if loss < float("inf")
            ]

            if best_result is not None:
                result = best_result
            else:
                # Shouldn't happen, but fallback
                result = self.curve_fit.curve_fit(  # type: ignore[assignment]
                    f,
                    xdata,
                    ydata,
                    p0=best_params,
                    bounds=bounds,
                    **kwargs,
                )

        # Add diagnostics to result
        result["multistart_diagnostics"] = diagnostics

        # Also add as attribute for convenience
        if not hasattr(result, "multistart_diagnostics"):
            result.multistart_diagnostics = diagnostics  # type: ignore[attr-defined]

        return result