nlsq.streaming.large_dataset.fit_large_dataset
- nlsq.streaming.large_dataset.fit_large_dataset(f, xdata, ydata, p0=None, memory_limit_gb=8.0, show_progress=False, logger=None, multistart=False, n_starts=10, sampler='lhs', **kwargs)
Convenience function for fitting large datasets.
- Parameters:
f (callable) – The model function f(x, *params) -> y
xdata (np.ndarray) – Independent variable data
ydata (np.ndarray) – Dependent variable data
p0 (array-like, optional) – Initial parameter guess
memory_limit_gb (float, optional) – Memory limit in GB (default: 8.0)
show_progress (bool, optional) – Whether to show progress (default: False)
logger (logging.Logger, optional) – External logger for application integration (default: None)
multistart (bool, optional) – Enable multi-start optimization for global search (default: False). When enabled, the function explores multiple starting points on the full dataset before running the complete chunked optimization.
n_starts (int, optional) – Number of starting points for multi-start optimization (default: 10). Set to 0 to disable multi-start even when multistart=True.
sampler (str, optional) – Sampling strategy for generating starting points (default: 'lhs'). Options: 'lhs' (Latin Hypercube), 'sobol' (Sobol sequence), 'halton' (Halton sequence).
**kwargs – Additional keyword arguments passed to curve_fit (for example, bounds)
- Returns:
Optimization result
- Return type:
Examples
>>> from nlsq.streaming.large_dataset import fit_large_dataset
>>> import numpy as np
>>> import jax.numpy as jnp
>>>
>>> # Generate large dataset
>>> x_large = np.linspace(0, 10, 5_000_000)
>>> y_large = 2.5 * np.exp(-1.3 * x_large) + np.random.normal(0, 0.1, len(x_large))
>>>
>>> # Fit with automatic memory management
>>> result = fit_large_dataset(
...     lambda x, a, b: a * jnp.exp(-b * x),
...     x_large, y_large,
...     p0=[2.0, 1.0],
...     memory_limit_gb=4.0,
...     show_progress=True
... )
>>> print(f"Fitted parameters: {result.popt}")
>>> print(f"Success rate: {result.success_rate:.1%}")
>>>
>>> # Fit with multi-start optimization
>>> result = fit_large_dataset(
...     lambda x, a, b: a * jnp.exp(-b * x),
...     x_large, y_large,
...     p0=[2.0, 1.0],
...     bounds=([0, 0], [10, 5]),
...     multistart=True,
...     n_starts=10,
...     sampler='lhs'
... )
>>>
>>> # Check failure diagnostics if some chunks failed
>>> if result.failure_summary['total_failures'] > 0:
...     print(f"Failed chunks: {result.failure_summary['failed_chunk_indices']}")
...     print(f"Common errors: {result.failure_summary['common_errors']}")
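The memory_limit_gb parameter bounds how much data is processed at once. The heuristic nlsq actually uses to turn this limit into a chunk size is not described on this page, so the following is only a minimal sketch under stated assumptions: each data point is assumed to need float64 storage for x, y, one residual, and one Jacobian row of n_params entries, and only half of the limit (a hypothetical safety_factor) is budgeted for those arrays. The helpers estimate_chunk_size and chunk_bounds are hypothetical and are not part of nlsq.

```python
def estimate_chunk_size(n_points, n_params, memory_limit_gb=8.0,
                        bytes_per_value=8, safety_factor=0.5):
    """Hypothetical heuristic: the largest chunk whose working arrays fit the budget.

    Assumption (not nlsq's documented rule): per data point we store x, y,
    one residual and one Jacobian row of n_params values, each float64.
    """
    values_per_point = 3 + n_params
    bytes_per_point = values_per_point * bytes_per_value
    # Only a fraction of the limit is budgeted, leaving headroom for other arrays.
    budget_bytes = memory_limit_gb * 1024**3 * safety_factor
    max_points = int(budget_bytes // bytes_per_point)
    # Never exceed the dataset itself, and always process at least one point.
    return max(1, min(n_points, max_points))


def chunk_bounds(n_points, chunk_size):
    """Yield (start, stop) index pairs that cover range(n_points) in order."""
    for start in range(0, n_points, chunk_size):
        yield start, min(start + chunk_size, n_points)
```

Under these assumptions, the 5,000,000-point, two-parameter example above fits in a single chunk with memory_limit_gb=4.0, whereas memory_limit_gb=0.01 would split it into 38 chunks of at most 134,217 points.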
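When multistart=True, the sampler option selects how candidate starting points are spread over the parameter bounds. The sketch below illustrates two of the named strategies in pure Python: Latin Hypercube sampling, which places exactly one sample in each of n_starts equal-width strata along every parameter axis, and the Halton sequence, a deterministic low-discrepancy sequence built from radical inverses in prime bases. Sobol is omitted because its direction-number construction is considerably longer. This illustrates the sampling techniques only and is not nlsq's implementation; the function names are hypothetical, and how nlsq combines these points with p0 or scores them is not shown.

```python
import random


def _scale(unit_point, lower, upper):
    # Map a point in [0, 1)^d onto the box [lower, upper).
    return [lo + u * (hi - lo) for u, lo, hi in zip(unit_point, lower, upper)]


def latin_hypercube(n_starts, lower, upper, seed=None):
    """Latin Hypercube: each axis is cut into n_starts equal strata and
    every stratum is used exactly once per axis."""
    rng = random.Random(seed)
    columns = []
    for _ in range(len(lower)):
        strata = list(range(n_starts))
        rng.shuffle(strata)  # random pairing of strata across axes
        # Jitter each sample uniformly inside its stratum.
        columns.append([(s + rng.random()) / n_starts for s in strata])
    # Transpose the per-axis columns into per-point rows, then scale.
    return [_scale([col[i] for col in columns], lower, upper)
            for i in range(n_starts)]


def _radical_inverse(index, base):
    # Van der Corput: mirror the base-`base` digits of index about the radix point.
    result, fraction = 0.0, 1.0 / base
    while index > 0:
        index, digit = divmod(index, base)
        result += digit * fraction
        fraction /= base
    return result


_PRIMES = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)


def halton(n_starts, lower, upper):
    """Halton sequence: axis k uses the radical inverse in the k-th prime base."""
    dim = len(lower)
    if dim > len(_PRIMES):
        raise ValueError("this sketch supports at most 10 parameters")
    # Start at index 1 to skip the all-zeros point.
    return [_scale([_radical_inverse(i, _PRIMES[k]) for k in range(dim)],
                   lower, upper)
            for i in range(1, n_starts + 1)]
```

For the bounds in the multi-start example, latin_hypercube(10, [0, 0], [10, 5], seed=0) yields ten (a, b) pairs with exactly one a value in each unit interval of [0, 10) and one b value in each 0.5-wide interval of [0, 5). If a library implementation is preferred, SciPy's scipy.stats.qmc module provides LatinHypercube, Sobol, and Halton samplers, together with qmc.scale for mapping samples onto bounds.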
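The logger parameter accepts a standard logging.Logger, so fitting messages can be routed through an application's existing logging configuration. A minimal sketch, assuming nlsq emits its messages through the standard methods of the logger it is given; make_fit_logger and the logger name 'myapp.fitting' are hypothetical choices for illustration.

```python
import io
import logging


def make_fit_logger(name="myapp.fitting", level=logging.INFO, stream=None):
    """Build a named application logger to pass as fit_large_dataset(logger=...)."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Avoid stacking duplicate handlers if this is called twice with the same name.
    if not logger.handlers:
        handler = logging.StreamHandler(stream)  # stream=None means sys.stderr
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
        )
        logger.addHandler(handler)
    # Keep records from also reaching handlers on the root logger.
    logger.propagate = False
    return logger
```

The resulting logger is then passed in the call, for example fit_large_dataset(f, xdata, ydata, p0=p0, logger=make_fit_logger()). Setting propagate to False prevents the same records from being printed a second time by handlers attached to the root logger.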