"""
NLSQ: JAX-accelerated nonlinear least squares curve fitting.
GPU/TPU-accelerated curve fitting with automatic differentiation.
Provides drop-in SciPy compatibility with curve_fit function.
Supports large datasets through automatic chunking and streaming optimization.
Key Features
------------
- Drop-in replacement for scipy.optimize.curve_fit
- GPU/TPU acceleration via JAX
- Automatic memory management for datasets up to 100M+ points
- Streaming optimization for unlimited data
- Smart algorithm selection and numerical stability
- Unified fit() entry point with automatic workflow selection
Examples
--------
>>> import jax.numpy as jnp
>>> from nlsq import curve_fit, fit
>>> def model(x, a, b): return a * jnp.exp(-b * x)
>>> popt, pcov = curve_fit(model, xdata, ydata)
>>> # Or use unified fit() with automatic workflow selection:
>>> popt, pcov = fit(model, xdata, ydata, workflow="auto", goal="quality")
"""
import faulthandler
import os
import sys
# Enable faulthandler FIRST so any SIGBUS/SIGSEGV during import prints a
# traceback to stderr instead of dying silently. This is deliberately before
# the environment lock and all library imports below.
faulthandler.enable()
_NLSQ_DEBUG = bool(os.environ.get("NLSQ_DEBUG"))
def _dbg(msg: str) -> None:
if _NLSQ_DEBUG:
sys.stderr.write(f"[nlsq] {msg}\n")
sys.stderr.flush()
_dbg("faulthandler enabled")
# Suppress noisy XLA/PJRT C++ warnings (e.g. "Assume version compatibility.
# PjRt-IFRT does not track XLA executable versions.") that are harmless on
# JAX >=0.8 and clutter stderr. Level 2 = ERROR-only (suppress INFO+WARNING).
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
# =============================================================================
# PLATFORM-SPECIFIC ENVIRONMENT LOCK
# GPU acceleration (CUDA) is supported ONLY on Linux.
# macOS and Windows enforce CPU-only JAX backend.
# These MUST be set before any JAX, Matplotlib, or Qt modules are initialized.
# =============================================================================
if sys.platform != "linux":
# Non-Linux CPU enforcement (macOS + Windows)
os.environ["NLSQ_FORCE_CPU"] = "1"
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["JAX_PLATFORMS"] = "cpu"
if sys.platform == "darwin":
# macOS-specific: Prevent SIGBUS from Metal/OpenGL/XLA conflicts
if "XLA_FLAGS" not in os.environ:
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MPLBACKEND"] = "Agg"
if "QT_MAC_WANTS_LAYER" not in os.environ:
os.environ["QT_MAC_WANTS_LAYER"] = "1"
os.environ["QT_API"] = "pyside6"
os.environ["PYQTGRAPH_QT_LIB"] = "PySide6"
# Force software OpenGL — prevents Metal/OpenGL context conflicts that
# cause SIGBUS on macOS 26+ Apple Silicon with PySide6 / Qt 6.10.
os.environ.setdefault("QT_OPENGL", "software")
os.environ.setdefault("LIBGL_ALWAYS_SOFTWARE", "1")
_dbg("env lock applied")
import numpy as np
_dbg("numpy imported")
# =============================================================================
# CORE IMPORTS (Always Eager - Required for basic curve_fit usage)
# =============================================================================
# Version information
try:
from nlsq._version import __version__ # type: ignore[import-not-found]
except ImportError:
__version__ = "0.0.0+unknown"
# Standard library imports needed at module level
from typing import TYPE_CHECKING, Any, Literal
from nlsq.callbacks import (
CallbackBase,
CallbackChain,
EarlyStopping,
IterationLogger,
ProgressBar,
StopOptimization,
)
_dbg("callbacks imported")
from nlsq.core import functions
from nlsq.core.least_squares import LeastSquares
from nlsq.core.minpack import CurveFit, curve_fit
_dbg("core imported")
from nlsq.result import OptimizeResult, OptimizeWarning
from nlsq.types import ArrayLike, BoundsTuple, MethodLiteral, ModelFunction
_dbg("types imported")
# =============================================================================
# LAZY IMPORT SYSTEM
# Specialty modules are loaded on-demand to reduce import time.
# This achieves 50%+ reduction in cold import time (SC-001).
# =============================================================================
# Mapping of lazy export names to their source modules
_LAZY_MODULES: dict[str, str] = {
# Streaming & Large Dataset (h5py dependency, specialty use case)
"AdaptiveHybridStreamingOptimizer": "nlsq.streaming.adaptive_hybrid",
"DefenseLayerTelemetry": "nlsq.streaming.adaptive_hybrid",
"get_defense_telemetry": "nlsq.streaming.adaptive_hybrid",
"reset_defense_telemetry": "nlsq.streaming.adaptive_hybrid",
"HybridStreamingConfig": "nlsq.streaming.hybrid_config",
"LargeDatasetFitter": "nlsq.streaming.large_dataset",
"LDMemoryConfig": "nlsq.streaming.large_dataset",
"estimate_memory_requirements": "nlsq.streaming.large_dataset",
"fit_large_dataset": "nlsq.streaming.large_dataset",
# Global Optimization
"GlobalOptimizationConfig": "nlsq.global_optimization",
"MultiStartOrchestrator": "nlsq.global_optimization",
"TournamentSelector": "nlsq.global_optimization",
# Profiling & Visualization
"PerformanceProfiler": "nlsq.utils.profiler",
"ProfileMetrics": "nlsq.utils.profiler",
"clear_profiling_data": "nlsq.utils.profiler",
"get_global_profiler": "nlsq.utils.profiler",
"ProfilerVisualization": "nlsq.utils.profiler_visualization",
"ProfilingDashboard": "nlsq.utils.profiler_visualization",
# Diagnostics
"ConvergenceMonitor": "nlsq.utils.diagnostics",
"OptimizationDiagnostics": "nlsq.utils.diagnostics",
# Memory Management
"MemoryManager": "nlsq.caching.memory_manager",
"clear_memory_pool": "nlsq.caching.memory_manager",
"get_memory_manager": "nlsq.caching.memory_manager",
"get_memory_stats": "nlsq.caching.memory_manager",
"MemoryPool": "nlsq.caching.memory_pool",
"TRFMemoryPool": "nlsq.caching.memory_pool",
"clear_global_pool": "nlsq.caching.memory_pool",
"get_global_pool": "nlsq.caching.memory_pool",
# Fallback & Recovery
"FallbackOrchestrator": "nlsq.stability.fallback",
"FallbackResult": "nlsq.stability.fallback",
"FallbackStrategy": "nlsq.stability.fallback",
"OptimizationRecovery": "nlsq.stability.recovery",
# Sparse Jacobian
"SparseJacobianComputer": "nlsq.core.sparse_jacobian",
"SparseOptimizer": "nlsq.core.sparse_jacobian",
"detect_jacobian_sparsity": "nlsq.core.sparse_jacobian",
# Workflow System (014-unified-memory-strategy)
"MemoryBudget": "nlsq.core.workflow",
"MemoryBudgetSelector": "nlsq.core.workflow",
"OptimizationGoal": "nlsq.core.workflow",
# Algorithm Selection
"AlgorithmSelector": "nlsq.precision.algorithm_selector",
"auto_select_algorithm": "nlsq.precision.algorithm_selector",
# Bounds Inference
"BoundsInference": "nlsq.precision.bound_inference",
"infer_bounds": "nlsq.precision.bound_inference",
"merge_bounds": "nlsq.precision.bound_inference",
# Robust Decomposition
"RobustDecomposition": "nlsq.stability.robust_decomposition",
"robust_decomp": "nlsq.stability.robust_decomposition",
# Parameter Normalizer
"ParameterNormalizer": "nlsq.precision.parameter_normalizer",
# Stability
"NumericalStabilityGuard": "nlsq.stability.guard",
"apply_automatic_fixes": "nlsq.stability.guard",
"check_problem_stability": "nlsq.stability.guard",
"detect_collinearity": "nlsq.stability.guard",
"detect_parameter_scale_mismatch": "nlsq.stability.guard",
"estimate_condition_number": "nlsq.stability.guard",
# Configuration
"LargeDatasetConfig": "nlsq.config",
"MemoryConfig": "nlsq.config",
"configure_for_large_datasets": "nlsq.config",
"get_large_dataset_config": "nlsq.config",
"get_memory_config": "nlsq.config",
"large_dataset_context": "nlsq.config",
"memory_context": "nlsq.config",
"set_memory_limits": "nlsq.config",
# Caching
"SmartCache": "nlsq.caching.smart_cache",
"cached_function": "nlsq.caching.smart_cache",
"cached_jacobian": "nlsq.caching.smart_cache",
"clear_all_caches": "nlsq.caching.smart_cache",
"get_global_cache": "nlsq.caching.smart_cache",
"get_jit_cache": "nlsq.caching.smart_cache",
# Compilation Cache
"CompilationCache": "nlsq.caching.compilation_cache",
"cached_jit": "nlsq.caching.compilation_cache",
"clear_compilation_cache": "nlsq.caching.compilation_cache",
"get_global_compilation_cache": "nlsq.caching.compilation_cache",
# Validators
"InputValidator": "nlsq.utils.validators",
# Model Health Diagnostics
"HealthStatus": "nlsq.diagnostics.types",
"IssueSeverity": "nlsq.diagnostics.types",
"IssueCategory": "nlsq.diagnostics.types",
"DiagnosticLevel": "nlsq.diagnostics.types",
"ModelHealthIssue": "nlsq.diagnostics.types",
"IdentifiabilityReport": "nlsq.diagnostics.types",
"GradientHealthReport": "nlsq.diagnostics.types",
"ParameterSensitivityReport": "nlsq.diagnostics.types",
"ModelHealthReport": "nlsq.diagnostics.types",
"DiagnosticsConfig": "nlsq.diagnostics.types",
"IdentifiabilityAnalyzer": "nlsq.diagnostics.identifiability",
"GradientMonitor": "nlsq.diagnostics.gradient_health",
"ParameterSensitivityAnalyzer": "nlsq.diagnostics.parameter_sensitivity",
"DiagnosticPlugin": "nlsq.diagnostics.plugin",
"PluginRegistry": "nlsq.diagnostics.plugin",
"run_plugins": "nlsq.diagnostics.plugin",
"create_health_report": "nlsq.diagnostics.health_report",
}
# Cache for lazily-loaded attributes
_LAZY_CACHE: dict[str, Any] = {}
def __getattr__(name: str) -> Any:
"""Lazily load specialty modules on first access.
This enables faster import time by deferring loading of specialty
modules until they are actually needed.
"""
# Check if it's a lazy module
if name in _LAZY_MODULES:
# Return cached value if already loaded
if name in _LAZY_CACHE:
return _LAZY_CACHE[name]
# Import the module and get the attribute
module_path = _LAZY_MODULES[name]
try:
import importlib
module = importlib.import_module(module_path)
attr = getattr(module, name)
# Cache for future access (setdefault is atomic under GIL)
return _LAZY_CACHE.setdefault(name, attr)
except ImportError as e:
# Handle missing optional dependencies gracefully
raise ImportError(
f"Cannot import '{name}' from '{module_path}'. "
f"This may require an optional dependency. Error: {e}"
) from e
except AttributeError as e:
raise AttributeError(
f"Module '{module_path}' does not have attribute '{name}'"
) from e
# Not found - raise standard AttributeError
raise AttributeError(f"module 'nlsq' has no attribute '{name}'")
def __dir__() -> list[str]:
"""Return list of module attributes including lazy exports.
This allows tools like IPython and IDEs to see all available exports
even before they are loaded.
"""
# Get standard module attributes
module_attrs = list(globals().keys())
# Add all lazy module exports
lazy_attrs = list(_LAZY_MODULES.keys())
# Combine and deduplicate
return sorted(set(module_attrs + lazy_attrs))
# Public API - only expose main user-facing functions
__all__ = [
"AdaptiveHybridStreamingOptimizer",
"AlgorithmSelector",
"BoundsInference",
"CallbackBase",
"CallbackChain",
"CompilationCache",
"ConvergenceMonitor",
"CurveFit",
"DefenseLayerTelemetry",
"EarlyStopping",
"FallbackOrchestrator",
"FallbackResult",
"FallbackStrategy",
"GlobalOptimizationConfig",
"HybridStreamingConfig",
"InputValidator",
"IterationLogger",
"LDMemoryConfig",
"LargeDatasetConfig",
"LargeDatasetFitter",
"LeastSquares",
"MemoryBudget",
"MemoryBudgetSelector",
"MemoryConfig",
"MemoryManager",
"MemoryPool",
"MultiStartOrchestrator",
"NumericalStabilityGuard",
"OptimizationDiagnostics",
"OptimizationGoal",
"OptimizationRecovery",
"OptimizeResult",
"OptimizeWarning",
"ParameterNormalizer",
"PerformanceProfiler",
"ProfileMetrics",
"ProfilerVisualization",
"ProfilingDashboard",
"ProgressBar",
"RobustDecomposition",
"SmartCache",
"SparseJacobianComputer",
"SparseOptimizer",
"StopOptimization",
"TRFMemoryPool",
"TournamentSelector",
"__version__",
"apply_automatic_fixes",
"auto_select_algorithm",
"cached_function",
"cached_jacobian",
"cached_jit",
"check_problem_stability",
"clear_all_caches",
"clear_compilation_cache",
"clear_global_pool",
"clear_memory_pool",
"clear_profiling_data",
"configure_for_large_datasets",
"curve_fit",
"curve_fit_large",
"detect_collinearity",
"detect_jacobian_sparsity",
"detect_parameter_scale_mismatch",
"estimate_condition_number",
"estimate_memory_requirements",
"fit",
"fit_large_dataset",
"functions",
"get_defense_telemetry",
"get_global_cache",
"get_global_compilation_cache",
"get_global_pool",
"get_global_profiler",
"get_jit_cache",
"get_large_dataset_config",
"get_memory_config",
"get_memory_manager",
"get_memory_stats",
"infer_bounds",
"large_dataset_context",
"memory_context",
"merge_bounds",
"reset_defense_telemetry",
"robust_decomp",
"set_memory_limits",
]
# Preset configurations for the fit() function
_FIT_PRESETS = {
"fast": {
"n_starts": 0,
"multistart": False,
"description": "Single-start optimization for maximum speed",
},
"robust": {
"n_starts": 5,
"multistart": True,
"description": "Multi-start with 5 starts for robustness",
},
"global": {
"n_starts": 20,
"multistart": True,
"description": "Thorough global search with 20 starts",
},
"streaming": {
"n_starts": 10,
"multistart": True,
"use_streaming": True,
"description": "Streaming optimization for large datasets with multi-start",
},
"large": {
"n_starts": 5,
"multistart": True,
"use_large_dataset": True,
"description": "Auto-detect dataset size and use appropriate strategy",
},
}
[docs]
def fit(
f: ModelFunction,
xdata: ArrayLike,
ydata: ArrayLike,
p0: ArrayLike | None = None,
sigma: ArrayLike | None = None,
absolute_sigma: bool = False,
check_finite: bool = True,
bounds: BoundsTuple | tuple[float, float] = (-float("inf"), float("inf")),
method: MethodLiteral | None = None,
# NEW: Workflow parameter (v0.6.3)
workflow: str | None = None,
# Legacy: preset parameter (deprecated)
preset: Literal["fast", "robust", "global", "streaming", "large"] | None = None,
# Multi-start parameters (can override preset)
multistart: bool | None = None,
n_starts: int | None = None,
sampler: Literal["lhs", "sobol", "halton"] = "lhs",
center_on_p0: bool = True,
scale_factor: float = 1.0,
# Large dataset parameters
memory_limit_gb: float | None = None,
size_threshold: int = 1_000_000,
show_progress: bool = False,
chunk_size: int | None = None,
**kwargs: Any,
) -> tuple[np.ndarray, np.ndarray] | OptimizeResult:
"""Unified curve fitting function with preset-based configuration.
This function provides a simplified API for curve fitting with sensible
defaults based on preset configurations. It automatically selects the
appropriate backend (curve_fit, curve_fit_large, or streaming) based on
the preset and dataset characteristics.
Parameters
----------
f : callable
Model function f(x, \\*params) -> y. Must use jax.numpy operations.
xdata : array_like
Independent variable data.
ydata : array_like
Dependent variable data.
p0 : array_like, optional
Initial parameter guess.
sigma : array_like, optional
Uncertainties in ydata for weighted fitting.
absolute_sigma : bool, optional
Whether sigma represents absolute uncertainties.
check_finite : bool, optional
Check for finite input values.
bounds : tuple, optional
Parameter bounds as (lower, upper).
method : str, optional
Optimization algorithm ('trf', 'lm', or None for auto).
preset : {'fast', 'robust', 'global', 'streaming', 'large'}, optional
Preset configuration to use:
- 'fast': Single-start optimization for maximum speed (n_starts=0)
- 'robust': Multi-start with 5 starts for robustness
- 'global': Thorough global search with 20 starts
- 'streaming': Streaming optimization for large datasets with multi-start
- 'large': Auto-detect dataset size and use appropriate strategy
If None, defaults to 'fast' for small datasets or 'large' for datasets
exceeding size_threshold.
multistart : bool, optional
Override preset's multi-start setting.
n_starts : int, optional
Override preset's n_starts setting.
sampler : {'lhs', 'sobol', 'halton'}, optional
Sampling strategy for multi-start. Default: 'lhs'.
center_on_p0 : bool, optional
Center multi-start samples around p0. Default: True.
scale_factor : float, optional
Scale factor for exploration region. Default: 1.0.
memory_limit_gb : float, optional
Maximum memory usage in GB for large datasets.
size_threshold : int, optional
Point threshold for large dataset processing (default: 1M).
show_progress : bool, optional
Display progress bar for long operations.
chunk_size : int, optional
Override automatic chunk size calculation.
**kwargs
Additional optimization parameters (ftol, xtol, gtol, max_nfev, loss).
Returns
-------
result : CurveFitResult or tuple
Optimization result. Contains popt, pcov, and multistart_diagnostics.
Supports tuple unpacking: popt, pcov = fit(...)
Examples
--------
Basic usage with default preset:
>>> popt, pcov = fit(model_func, xdata, ydata, p0=[1, 2, 3])
Using 'robust' preset for multi-start:
>>> result = fit(model_func, xdata, ydata, p0=[1, 2, 3],
... bounds=([0, 0, 0], [10, 10, 10]), preset='robust')
Using 'global' preset for thorough search:
>>> result = fit(model_func, xdata, ydata, p0=[1, 2, 3],
... bounds=([0, 0, 0], [10, 10, 10]), preset='global')
Large dataset with auto-detection:
>>> result = fit(model_func, big_xdata, big_ydata,
... preset='large', show_progress=True)
See Also
--------
curve_fit : Lower-level API with full control
curve_fit_large : Specialized API for large datasets
MemoryBudgetSelector : Memory-based optimizer selection
"""
# NEW WORKFLOW SYSTEM (v0.6.3)
# If workflow parameter is specified, delegate to minpack.fit()
if workflow is not None:
from nlsq.core.minpack import fit as minpack_fit
# Forward to the new workflow-based fit
return minpack_fit(
f=f,
xdata=xdata,
ydata=ydata,
p0=p0,
sigma=sigma,
absolute_sigma=absolute_sigma,
check_finite=check_finite,
bounds=bounds,
method=method,
workflow=workflow,
n_starts=n_starts,
memory_limit_gb=memory_limit_gb,
**kwargs,
)
# LEGACY PATH: preset-based configuration (for backwards compatibility)
# Emit deprecation warning if preset is explicitly provided
if preset is not None:
import warnings
warnings.warn(
"The 'preset' parameter is deprecated and will be removed in v0.8.0. "
"Use the 'workflow' parameter instead: "
"workflow='auto' (default), 'auto_global' (robust), or 'hpc' (streaming). "
"See migration guide: https://nlsq.readthedocs.io/en/latest/migration.html",
FutureWarning,
stacklevel=2,
)
# Input validation
xdata = np.asarray(xdata)
ydata = np.asarray(ydata)
n_points = len(xdata)
# Auto-select preset if not provided
if preset is None:
if n_points >= size_threshold:
preset = "large"
else:
preset = "fast"
# Get preset configuration
preset_config = _FIT_PRESETS.get(preset, _FIT_PRESETS["fast"])
# Apply preset defaults, allowing overrides
effective_multistart: bool = (
multistart
if multistart is not None
else bool(preset_config.get("multistart", False))
)
_n_starts_default: int = preset_config.get("n_starts", 0) # type: ignore[assignment]
effective_n_starts: int = n_starts if n_starts is not None else _n_starts_default
use_streaming = preset_config.get("use_streaming", False)
use_large_dataset = preset_config.get("use_large_dataset", False)
# Determine which backend to use
if use_streaming or preset == "streaming":
# Use AdaptiveHybridStreaming for streaming preset
from nlsq.streaming.adaptive_hybrid import AdaptiveHybridStreamingOptimizer
from nlsq.streaming.hybrid_config import HybridStreamingConfig
# Prepare p0
if p0 is None:
from inspect import signature
sig = signature(f)
args = sig.parameters
if len(args) < 2:
raise ValueError("Unable to determine number of fit parameters.")
n_params = len(args) - 1
p0 = np.ones(n_params)
p0 = np.atleast_1d(p0)
# Prepare bounds
from nlsq.core.least_squares import prepare_bounds
lb, ub = prepare_bounds(bounds, len(p0))
bounds_tuple = (
(lb, ub)
if not (np.all(np.isneginf(lb)) and np.all(np.isposinf(ub)))
else None
)
# Create config with multi-start settings
# Use the correct parameter names from HybridStreamingConfig
config = HybridStreamingConfig(
enable_multistart=effective_multistart and effective_n_starts > 0,
n_starts=effective_n_starts if effective_multistart else 10,
multistart_sampler=sampler,
)
optimizer = AdaptiveHybridStreamingOptimizer(config=config)
result_dict = optimizer.fit(
data_source=(xdata, ydata),
func=f,
p0=p0, # type: ignore[arg-type]
bounds=bounds_tuple, # type: ignore[arg-type]
sigma=sigma, # type: ignore[arg-type]
absolute_sigma=absolute_sigma,
callback=kwargs.get("callback"),
verbose=kwargs.get("verbose", 1),
)
# Convert to standard result format
from nlsq.result import CurveFitResult
result = CurveFitResult(result_dict)
result["pcov"] = result_dict.get("pcov", np.full((len(p0), len(p0)), np.inf))
result["multistart_diagnostics"] = {
"n_starts_configured": effective_n_starts,
"bypassed": not effective_multistart or effective_n_starts == 0,
"preset": preset,
}
return result
elif use_large_dataset or n_points >= size_threshold:
# Use curve_fit_large for large datasets
return curve_fit_large(
f,
xdata,
ydata,
p0=p0,
sigma=sigma,
absolute_sigma=absolute_sigma,
check_finite=check_finite,
bounds=bounds,
method=method,
memory_limit_gb=memory_limit_gb,
size_threshold=size_threshold,
show_progress=show_progress,
chunk_size=chunk_size,
multistart=effective_multistart,
n_starts=effective_n_starts,
sampler=sampler,
center_on_p0=center_on_p0,
scale_factor=scale_factor,
**kwargs,
)
else:
# Use standard curve_fit
return curve_fit(
f,
xdata,
ydata,
p0=p0,
sigma=sigma,
absolute_sigma=absolute_sigma,
check_finite=check_finite,
bounds=bounds,
method=method, # type: ignore[arg-type]
multistart=effective_multistart,
n_starts=effective_n_starts,
sampler=sampler,
center_on_p0=center_on_p0,
scale_factor=scale_factor,
**kwargs,
)
# Convenience function for large dataset curve fitting
[docs]
def curve_fit_large(
f: ModelFunction,
xdata: ArrayLike,
ydata: ArrayLike,
p0: ArrayLike | None = None,
sigma: ArrayLike | None = None,
absolute_sigma: bool = False,
check_finite: bool = True,
bounds: BoundsTuple | tuple[float, float] = (-float("inf"), float("inf")),
method: MethodLiteral | None = None,
# Stability parameters
stability: Literal["auto", "check", False] = False,
rescale_data: bool = True,
max_jacobian_elements_for_svd: int = 10_000_000,
# Large dataset specific parameters
memory_limit_gb: float | None = None,
auto_size_detection: bool = True,
size_threshold: int = 1_000_000, # 1M points
show_progress: bool = False,
chunk_size: int | None = None,
# Multi-start optimization parameters (Task Group 5)
multistart: bool = False,
n_starts: int = 10,
global_search: bool = False,
sampler: Literal["lhs", "sobol", "halton"] = "lhs",
center_on_p0: bool = True,
scale_factor: float = 1.0,
**kwargs: Any,
) -> tuple[np.ndarray, np.ndarray] | OptimizeResult:
"""Curve fitting with automatic memory management for large datasets.
Automatically selects processing strategy based on dataset size:
- Small (< 1M points): Standard curve_fit
- Medium (1M - 100M points): Chunked processing
- Large (> 100M points): Streaming optimization
Parameters
----------
f : callable
Model function f(x, \\*params) -> y. Must use jax.numpy operations.
xdata : array_like
Independent variable data.
ydata : array_like
Dependent variable data.
p0 : array_like, optional
Initial parameter guess.
sigma : array_like, optional
Uncertainties in ydata for weighted fitting.
absolute_sigma : bool, optional
Whether sigma represents absolute uncertainties.
check_finite : bool, optional
Check for finite input values.
bounds : tuple, optional
Parameter bounds as (lower, upper).
method : str, optional
Optimization algorithm ('trf', 'lm', or None for auto).
memory_limit_gb : float, optional
Maximum memory usage in GB.
auto_size_detection : bool, optional
Auto-detect dataset size for processing strategy.
size_threshold : int, optional
Point threshold for large dataset processing (default: 1M).
show_progress : bool, optional
Display progress bar for long operations.
chunk_size : int, optional
Override automatic chunk size calculation.
multistart : bool, optional
Enable multi-start optimization for global search. Default: False.
n_starts : int, optional
Number of starting points for multi-start optimization. Default: 10.
global_search : bool, optional
Shorthand for multistart=True, n_starts=20. Default: False.
sampler : {'lhs', 'sobol', 'halton'}, optional
Sampling strategy for multi-start. Default: 'lhs'.
center_on_p0 : bool, optional
Center multi-start samples around p0. Default: True.
scale_factor : float, optional
Scale factor for exploration region. Default: 1.0.
**kwargs
Additional optimization parameters (ftol, xtol, gtol, max_nfev, loss)
Returns
-------
popt : ndarray
Fitted parameters.
pcov : ndarray
Parameter covariance matrix.
Notes
-----
All large datasets use streaming optimization for 100% data utilization.
**Important: Model Function Requirements for Chunking**
When auto_size_detection triggers chunked processing (>1M points), your model
function MUST respect the size of xdata. Model output shape must match ydata shape.
INCORRECT - Fixed-size output (causes shape errors):
>>> def bad_model(xdata, a, b):
... # WRONG: Returns fixed-size array regardless of xdata
... t_full = jnp.arange(10_000_000)
... return a * jnp.exp(-b * t_full) # Shape mismatch!
CORRECT - Output matches xdata size:
>>> def good_model(xdata, a, b):
... # CORRECT: Uses xdata as indices
... indices = xdata.astype(jnp.int32)
... return a * jnp.exp(-b * indices)
>>> def direct_model(xdata, a, b):
... # CORRECT: Operates on xdata directly
... return a * jnp.exp(-b * xdata)
Examples
--------
Basic usage:
>>> popt, _pcov = curve_fit_large(model_func, xdata, ydata, p0=[1, 2, 3])
Large dataset with progress bar:
>>> popt, _pcov = curve_fit_large(model_func, big_xdata, big_ydata,
... show_progress=True, memory_limit_gb=8)
With multi-start optimization:
>>> popt, _pcov = curve_fit_large(model_func, xdata, ydata,
... p0=[1, 2, 3], bounds=([0, 0, 0], [10, 10, 10]),
... multistart=True, n_starts=10)
Using external logger for diagnostics:
>>> import logging
>>> my_logger = logging.getLogger("myapp")
>>> fitter = LargeDatasetFitter(memory_limit_gb=8, logger=my_logger)
>>> result = fitter.fit(model_func, xdata, ydata, p0=[1, 2])
>>> # Chunk failures now appear in myapp's logs
"""
# Handle global_search shorthand
if global_search:
multistart = True
n_starts = 20
# Reject removed sampling parameters
removed_params = {"enable_sampling", "sampling_threshold", "max_sampled_size"}
for param in removed_params:
if param in kwargs:
raise TypeError(
f"curve_fit_large() got an unexpected keyword argument '{param}'. "
"This parameter was removed in v0.2.0. Use streaming instead."
)
# Input validation
xdata = np.asarray(xdata)
ydata = np.asarray(ydata)
# Check for edge cases
if len(xdata) == 0:
raise ValueError("`xdata` cannot be empty.")
if len(ydata) == 0:
raise ValueError("`ydata` cannot be empty.")
if len(xdata) != len(ydata):
raise ValueError(
f"`xdata` and `ydata` must have the same length: {len(xdata)} vs {len(ydata)}."
)
if len(xdata) < 2:
raise ValueError(f"Need at least 2 data points for fitting, got {len(xdata)}.")
n_points = len(xdata)
# Handle hybrid_streaming method specially
if method == "hybrid_streaming":
from nlsq.streaming.adaptive_hybrid import AdaptiveHybridStreamingOptimizer
from nlsq.streaming.hybrid_config import HybridStreamingConfig
# Extract verbosity from kwargs
verbose = kwargs.pop("verbose", 1)
# Create configuration (allow kwargs to override defaults)
config_overrides = {}
for key in list(kwargs.keys()):
if hasattr(HybridStreamingConfig, key):
config_overrides[key] = kwargs.pop(key)
config = (
HybridStreamingConfig(**config_overrides)
if config_overrides
else HybridStreamingConfig()
)
# Prepare p0 and bounds
if p0 is None:
from inspect import signature
sig = signature(f)
args = sig.parameters
if len(args) < 2:
raise ValueError("Unable to determine number of fit parameters.")
n_params = len(args) - 1
p0 = np.ones(n_params)
p0 = np.atleast_1d(p0)
from nlsq.core.least_squares import prepare_bounds
lb, ub = prepare_bounds(bounds, len(p0))
bounds_tuple = (
(lb, ub)
if not (np.all(np.isneginf(lb)) and np.all(np.isposinf(ub)))
else None
)
# Create optimizer
optimizer = AdaptiveHybridStreamingOptimizer(config=config)
# Run optimization
result_dict = optimizer.fit(
data_source=(xdata, ydata),
func=f,
p0=p0, # type: ignore[arg-type]
bounds=bounds_tuple, # type: ignore[arg-type]
sigma=sigma, # type: ignore[arg-type]
absolute_sigma=absolute_sigma,
callback=kwargs.get("callback"),
verbose=verbose,
)
# Convert to tuple format for backward compatibility
popt = result_dict["x"]
pcov = result_dict.get("pcov", np.full((len(p0), len(p0)), np.inf))
return popt, pcov
# Auto-detect if we should use large dataset processing
if auto_size_detection and n_points < size_threshold:
# Use regular curve_fit for small datasets
# Rebuild kwargs for curve_fit
fit_kwargs = kwargs.copy()
if p0 is not None:
fit_kwargs["p0"] = p0
if sigma is not None:
fit_kwargs["sigma"] = sigma
if bounds != (-float("inf"), float("inf")):
fit_kwargs["bounds"] = bounds
if method is not None:
fit_kwargs["method"] = method
fit_kwargs["absolute_sigma"] = absolute_sigma
fit_kwargs["check_finite"] = check_finite
fit_kwargs["stability"] = stability
fit_kwargs["rescale_data"] = rescale_data
fit_kwargs["max_jacobian_elements_for_svd"] = max_jacobian_elements_for_svd
# Add multi-start parameters
fit_kwargs["multistart"] = multistart
fit_kwargs["n_starts"] = n_starts
fit_kwargs["sampler"] = sampler
fit_kwargs["center_on_p0"] = center_on_p0
fit_kwargs["scale_factor"] = scale_factor
return curve_fit(f, xdata, ydata, **fit_kwargs)
# Use large dataset processing
# Import lazy modules needed for large dataset processing
from nlsq.config import (
LargeDatasetConfig,
MemoryConfig,
large_dataset_context,
memory_context,
)
from nlsq.streaming.large_dataset import LargeDatasetFitter, LDMemoryConfig
# Configure memory settings if provided
if memory_limit_gb is None:
# Auto-detect available memory
try:
import psutil
available_gb = psutil.virtual_memory().available / (1024**3)
memory_limit_gb = min(8.0, available_gb * 0.7) # Use 70% of available
except ImportError:
memory_limit_gb = 8.0 # Conservative default
# Create memory configuration
memory_config = MemoryConfig(
memory_limit_gb=memory_limit_gb,
progress_reporting=show_progress,
min_chunk_size=max(1000, n_points // 10000), # Dynamic min chunk size
max_chunk_size=min(1_000_000, n_points // 10)
if chunk_size is None
else chunk_size,
)
# Create large dataset configuration (v0.2.0: no more sampling params)
large_dataset_config = LargeDatasetConfig(
enable_automatic_solver_selection=True,
)
# Use context managers to temporarily set configuration
with memory_context(memory_config), large_dataset_context(large_dataset_config):
# Create fitter with current configuration
fitter = LargeDatasetFitter(
memory_limit_gb=memory_limit_gb,
config=LDMemoryConfig(
memory_limit_gb=memory_limit_gb,
min_chunk_size=memory_config.min_chunk_size,
max_chunk_size=memory_config.max_chunk_size,
),
)
# Handle sigma parameter by including it in kwargs if provided
if sigma is not None:
kwargs["sigma"] = sigma
kwargs["absolute_sigma"] = absolute_sigma
kwargs["check_finite"] = check_finite
# Add multi-start parameters to kwargs
kwargs["multistart"] = multistart
kwargs["n_starts"] = n_starts
kwargs["sampler"] = sampler
kwargs["center_on_p0"] = center_on_p0
kwargs["scale_factor"] = scale_factor
# Convert p0 to appropriate type for LargeDatasetFitter
# LargeDatasetFitter expects np.ndarray | list | None (no tuple or jnp.ndarray)
converted_p0: np.ndarray | list | None
if p0 is None:
converted_p0 = None
elif isinstance(p0, list):
converted_p0 = p0
else:
# Convert tuple, jnp.ndarray, or np.ndarray to np.ndarray
converted_p0 = np.asarray(p0)
# Provide default method if None
final_method = method if method is not None else "trf"
# Perform the fit
if show_progress:
result = fitter.fit_with_progress(
f,
xdata,
ydata,
p0=converted_p0,
bounds=bounds,
method=final_method,
**kwargs, # type: ignore
)
else:
result = fitter.fit(
f,
xdata,
ydata,
p0=converted_p0,
bounds=bounds,
method=final_method,
**kwargs, # type: ignore
)
# Extract popt and pcov from result
if hasattr(result, "popt") and hasattr(result, "pcov"):
return result.popt, result.pcov
elif hasattr(result, "x"):
# Fallback: construct basic covariance matrix
popt = result.x
# Create identity covariance matrix if not available
pcov = np.eye(len(popt))
return popt, pcov
else:
raise RuntimeError(
f"Unexpected result format from large dataset fitter: {result}"
)
# Optional: Provide convenience access to submodules for advanced users
# Users can still access internal functions via:
# from nlsq.core.loss_functions import LossFunctionsJIT
# from nlsq.core.trf import TrustRegionReflective
# etc.
# Check GPU availability on import (non-intrusive warning)
# This helps users realize when GPU acceleration is available but not being used
from nlsq.device import check_gpu_availability
check_gpu_availability()
_dbg("nlsq import complete")