Source code for nlsq.streaming.large_dataset

"""Large Dataset Fitting Module for NLSQ.

This module provides utilities for efficiently fitting curve parameters to very large datasets
(>10M points) with intelligent memory management, automatic chunking, and progress reporting.
"""

# mypy: disable-error-code="assignment,arg-type,var-annotated,misc,attr-defined,index"
# Note: mypy errors are mostly assignment/index issues from dict-based result
# accumulation in chunked fitting. These require deeper refactoring.

from __future__ import annotations

import gc
import time
from collections import defaultdict
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache
from logging import Logger
from typing import TYPE_CHECKING, Literal

import jax
import numpy as np
import psutil

# Initialize JAX configuration through central config
from nlsq.config import JAXConfig

_jax_config = JAXConfig()


from nlsq.result import OptimizeResult
from nlsq.streaming.adaptive_hybrid import AdaptiveHybridStreamingOptimizer
from nlsq.streaming.hybrid_config import HybridStreamingConfig
from nlsq.utils.logging import get_logger

# Type-only imports to avoid circular dependencies
if TYPE_CHECKING:
    from nlsq.core.minpack import CurveFit

# Default fallback memory in GB when detection fails (per requirements).
# Returned by MemoryEstimator.get_available_memory_gb() when the psutil query
# raises.
_DEFAULT_FALLBACK_MEMORY_GB = 16.0

# Power-of-2 bucket sizes for static array shapes during chunked processing.
# This eliminates JIT recompilation overhead by ensuring all chunks pad to
# a fixed set of sizes, enabling efficient compilation cache reuse.
# See research.md for rationale on power-of-2 buckets.
# The tuple must stay sorted in ascending order: get_bucket_size() scans it
# front to back and returns the first entry that is >= the chunk size, and
# falls back to the chunk size itself when it exceeds the last entry.
CHUNK_BUCKETS: tuple[int, ...] = (1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072)


# Module-level device cache for performance (FR-004)
# Use functools.cache for clean caching without global statement
@cache
def get_cached_devices() -> list:
    """Return the JAX device list, querying JAX only on the first call.

    The function takes no arguments, so ``functools.cache`` stores exactly one
    result: every call after the first returns the same list object without
    invoking ``jax.devices()`` again.

    Returns
    -------
    list
        The value returned by ``jax.devices()`` on the first call.
    """
    devices = jax.devices()
    return devices
def get_bucket_size(chunk_size: int) -> int:
    """Map a chunk size onto the smallest bucket able to hold it.

    Parameters
    ----------
    chunk_size : int
        Actual chunk size in data points.

    Returns
    -------
    int
        The first entry of ``CHUNK_BUCKETS`` that is ``>= chunk_size``. When
        ``chunk_size`` is larger than every bucket it is returned unchanged.

    Examples
    --------
    >>> get_bucket_size(1000)
    1024
    >>> get_bucket_size(5000)
    8192
    >>> get_bucket_size(200000)
    200000
    """
    # Oversized chunks keep their own size; the module comments note that this
    # costs a recompilation, but only for the less common very large chunks.
    return next(
        (candidate for candidate in CHUNK_BUCKETS if candidate >= chunk_size),
        chunk_size,
    )


@dataclass(slots=True)
class ChunkBuffer:
    """Chunk data padded along its first axis to a bucket size.

    Padding every chunk up to one of the ``CHUNK_BUCKETS`` sizes keeps the set
    of array shapes small; ``mask`` and ``valid_length`` record which leading
    entries are real data so the padding can be ignored afterwards.

    Attributes
    ----------
    data : np.ndarray
        Padded data array with shape (bucket_size,) or (bucket_size, features).
    valid_length : int
        Actual number of valid samples (before padding).
    bucket_size : int
        Size of the first axis after padding, as chosen by ``get_bucket_size``.
    mask : np.ndarray
        Boolean array of length ``bucket_size``; True marks valid elements.

    Examples
    --------
    >>> import numpy as np
    >>> chunk = np.array([1.0, 2.0, 3.0])
    >>> buffer = ChunkBuffer.from_array(chunk)
    >>> buffer.bucket_size
    1024
    >>> buffer.valid_length
    3
    >>> np.sum(buffer.mask)
    3
    """

    data: np.ndarray
    valid_length: int
    bucket_size: int
    mask: np.ndarray

    @classmethod
    def from_array(cls, arr: np.ndarray, pad_value: float = 0.0) -> ChunkBuffer:
        """Build a ChunkBuffer by padding ``arr`` up to its bucket size.

        Parameters
        ----------
        arr : np.ndarray
            Input array; padding is applied along its first axis.
        pad_value : float, optional
            Fill value for the padded entries (default: 0.0).

        Returns
        -------
        ChunkBuffer
            Buffer whose ``data`` has ``bucket_size`` entries on the first
            axis. When ``arr`` already has exactly that length it is stored
            as-is (no copy is made).
        """
        n_valid = len(arr)
        n_bucket = get_bucket_size(n_valid)

        # True for the leading n_valid positions, False for the padding.
        valid_mask = np.arange(n_bucket) < n_valid

        n_missing = n_bucket - n_valid
        if n_missing == 0:
            # Already bucket-sized: reuse the caller's array.
            padded = arr
        elif arr.ndim != 1:
            # Multi-dimensional input: append filler rows along axis 0.
            filler_shape = (n_missing, *arr.shape[1:])
            filler = np.full(filler_shape, pad_value, dtype=arr.dtype)
            padded = np.concatenate([arr, filler], axis=0)
        else:
            padded = np.full(n_bucket, pad_value, dtype=arr.dtype)
            padded[:n_valid] = arr

        return cls(
            data=padded,
            valid_length=n_valid,
            bucket_size=n_bucket,
            mask=valid_mask,
        )

    def get_valid_data(self) -> np.ndarray:
        """Return the leading ``valid_length`` entries, i.e. the unpadded data."""
        n_valid = self.valid_length
        return self.data[:n_valid]
class ChunkBufferPool:
    """Reusable pair of x/y scratch arrays for streaming chunk operations (FR-006).

    Two arrays of ``chunk_size`` elements are allocated once at construction.
    Requests that fit are served as views into those arrays, so processing a
    sequence of similarly sized chunks does not allocate per chunk.

    Parameters
    ----------
    chunk_size : int
        Size of each chunk buffer
    dtype : np.dtype
        Data type for buffers (default: np.float64)

    Examples
    --------
    >>> pool = ChunkBufferPool(chunk_size=10000)
    >>> x_buf, y_buf = pool.get_buffers(5000)
    >>> x_buf.shape
    (5000,)
    >>> # Buffers can be reused without reallocation
    >>> x_buf2, y_buf2 = pool.get_buffers(8000)
    >>> x_buf2.shape
    (8000,)
    """

    def __init__(self, chunk_size: int, dtype: np.dtype = np.float64) -> None:
        """Allocate the two backing arrays.

        Parameters
        ----------
        chunk_size : int
            Size of each chunk buffer
        dtype : np.dtype
            Data type for buffers
        """
        self._chunk_size = chunk_size
        self._dtype = dtype
        # np.empty leaves the contents uninitialised; callers are expected to
        # overwrite whatever slice they request.
        self._x_buffer = np.empty(chunk_size, dtype=dtype)
        self._y_buffer = np.empty(chunk_size, dtype=dtype)

    def get_buffers(self, size: int) -> tuple[np.ndarray, np.ndarray]:
        """Return an (x, y) pair of arrays with ``size`` elements each.

        Parameters
        ----------
        size : int
            Number of elements needed.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Views of the first ``size`` elements of the pooled x and y arrays
            when ``size <= chunk_size``; otherwise two freshly allocated
            arrays that are not retained by the pool.
        """
        if size <= self._chunk_size:
            return self._x_buffer[:size], self._y_buffer[:size]

        # Request exceeds the pool capacity: fall back to one-off allocations.
        overflow_x = np.empty(size, dtype=self._dtype)
        overflow_y = np.empty(size, dtype=self._dtype)
        return overflow_x, overflow_y

    @property
    def chunk_size(self) -> int:
        """Capacity, in elements, of each pooled buffer."""
        return self._chunk_size

    @property
    def dtype(self) -> np.dtype:
        """Data type the pool was constructed with."""
        return self._dtype
[docs] @dataclass(slots=True) class LDMemoryConfig: # Renamed to avoid conflict with config.py """Configuration for memory management in large dataset fitting. Attributes ---------- memory_limit_gb : float Maximum memory to use in GB (default: 8.0) safety_factor : float Safety factor for memory calculations (default: 0.8) min_chunk_size : int Minimum chunk size in data points (default: 1000) max_chunk_size : int Maximum chunk size in data points (default: 1000000) use_streaming : bool Use adaptive hybrid streaming optimization for unlimited data (default: True) streaming_batch_size : int Chunk size for adaptive hybrid streaming (default: 50000) streaming_max_epochs : int Maximum Gauss-Newton iterations for adaptive hybrid streaming (default: 10) min_success_rate : float Minimum success rate for chunked fitting (default: 0.5) If success rate falls below this threshold, fitting is considered failed save_diagnostics : bool Whether to compute and save detailed diagnostic statistics (default: False) When False, skips statistical computations for successful chunks (5-10% faster) gc_chunk_interval : int Chunks between gc.collect() calls (default: 10, FR-007) Controls how often garbage collection runs during chunked processing. Higher values reduce GC overhead but may increase memory usage. """ memory_limit_gb: float = 8.0 safety_factor: float = 0.8 min_chunk_size: int = 1000 max_chunk_size: int = 1_000_000 use_streaming: bool = True streaming_batch_size: int = 50000 streaming_max_epochs: int = 10 min_success_rate: float = 0.5 save_diagnostics: bool = False gc_chunk_interval: int = 10
[docs] @dataclass(slots=True) class DatasetStats: """Statistics and information about a dataset. Attributes ---------- n_points : int Total number of data points n_params : int Number of parameters to fit memory_per_point_bytes : float Estimated memory usage per data point in bytes total_memory_estimate_gb : float Estimated total memory requirement in GB recommended_chunk_size : int Recommended chunk size for processing n_chunks : int Number of chunks needed """ n_points: int n_params: int memory_per_point_bytes: float total_memory_estimate_gb: float recommended_chunk_size: int n_chunks: int
class GPUMemoryEstimator:
    """Utilities for estimating GPU memory availability.

    GPU memory is read through each JAX device's ``memory_stats()``. Every
    device whose ``platform`` is not ``"cpu"`` is treated as a GPU, results
    are summed over all such devices, and every query degrades to ``0.0`` /
    ``False`` instead of raising when no GPU is present or detection fails.

    Examples
    --------
    >>> from nlsq.streaming.large_dataset import GPUMemoryEstimator
    >>> estimator = GPUMemoryEstimator()
    >>> available_gb = estimator.get_available_gpu_memory_gb()
    >>> print(f"Available GPU memory: {available_gb:.2f} GB")
    """

    def __init__(self) -> None:
        """Initialize GPUMemoryEstimator."""
        self._logger = get_logger(__name__)

    @staticmethod
    def _iter_gpu_devices():
        """Yield each cached JAX device whose platform is not ``"cpu"``.

        Devices lacking a ``platform`` attribute are treated as CPU devices.
        The device list is only fetched once iteration starts, so callers
        that iterate inside a ``try`` block also catch enumeration errors.
        """
        for device in get_cached_devices():
            if getattr(device, "platform", "cpu") != "cpu":
                yield device

    def get_available_gpu_memory_gb(self) -> float:
        """Get available GPU memory in GB.

        For every non-CPU device, ``bytes_limit - bytes_in_use`` from
        ``device.memory_stats()`` is added to the total when it is positive.

        Returns
        -------
        float
            Available GPU memory in GB, or 0.0 if no GPU or detection fails.

        Notes
        -----
        - The device list comes from ``get_cached_devices()``, but
          ``memory_stats()`` is queried afresh on every call.
        - Handles multiple GPUs by summing available memory.
        - A failure on one device is logged at debug level and that device is
          skipped; a failure while enumerating devices returns 0.0.
        """
        total_available_bytes = 0.0
        try:
            for device in self._iter_gpu_devices():
                try:
                    memory_stats = device.memory_stats()
                    if memory_stats is None:
                        continue

                    # Missing keys count as 0, i.e. contribute no free memory.
                    bytes_limit = memory_stats.get("bytes_limit", 0)
                    bytes_in_use = memory_stats.get("bytes_in_use", 0)
                    available = bytes_limit - bytes_in_use
                    if available > 0:
                        total_available_bytes += available
                        self._logger.debug(
                            f"GPU device {device}: "
                            f"{available / (1024**3):.2f} GB available "
                            f"({bytes_in_use / (1024**3):.2f} GB in use)"
                        )
                except Exception as e:
                    # Individual device query failed - continue with others
                    self._logger.debug(
                        f"GPU memory query failed for device {device}: {e}"
                    )
        except Exception as e:
            # Complete device enumeration failed
            self._logger.debug(f"GPU device enumeration failed: {e}")
            return 0.0
        return total_available_bytes / (1024**3)

    def get_total_gpu_memory_gb(self) -> float:
        """Get total GPU memory capacity in GB.

        Sums ``bytes_limit`` from ``device.memory_stats()`` over all non-CPU
        devices.

        Returns
        -------
        float
            Total GPU memory capacity in GB, or 0.0 if no GPU or detection
            fails. Failures are logged at debug level rather than raised.
        """
        total_capacity_bytes = 0.0
        try:
            for device in self._iter_gpu_devices():
                try:
                    memory_stats = device.memory_stats()
                    if memory_stats is not None:
                        total_capacity_bytes += memory_stats.get("bytes_limit", 0)
                except Exception as e:
                    # Skip this device but leave a trace for troubleshooting.
                    self._logger.debug(
                        f"GPU memory query failed for device {device}: {e}"
                    )
        except Exception as e:
            self._logger.debug(f"GPU device enumeration failed: {e}")
            return 0.0
        return total_capacity_bytes / (1024**3)

    def has_gpu(self) -> bool:
        """Check if any GPU is available.

        Returns
        -------
        bool
            True if at least one non-CPU device is reported; False when there
            is none or when device enumeration fails (logged at debug level).
        """
        try:
            # any() stops at the first non-CPU device.
            return any(True for _ in self._iter_gpu_devices())
        except Exception as e:
            self._logger.debug(f"GPU device enumeration failed: {e}")
            return False
class MemoryEstimator:
    """Helpers for estimating memory usage and choosing chunk sizes.

    CPU memory is read through psutil; when that query raises, the module
    fallback ``_DEFAULT_FALLBACK_MEMORY_GB`` is used instead.
    """

    @staticmethod
    def estimate_memory_per_point(n_params: int, use_jacobian: bool = True) -> float:
        """Estimate memory usage per data point in bytes.

        Parameters
        ----------
        n_params : int
            Number of parameters
        use_jacobian : bool, optional
            Whether Jacobian computation is needed (default: True)

        Returns
        -------
        float
            Estimated memory usage per point in bytes
        """
        float64_bytes = 8
        fixed_overhead = 50  # XLA + GPU overhead

        point_bytes = 3 * float64_bytes  # x, y, residual
        jacobian_bytes = n_params * float64_bytes if use_jacobian else 0
        workspace_bytes = n_params * 2 * float64_bytes  # optimization workspace

        return point_bytes + jacobian_bytes + workspace_bytes + fixed_overhead

    @staticmethod
    def get_available_memory_gb() -> float:
        """Get available system memory in GB.

        Returns
        -------
        float
            ``psutil.virtual_memory().available`` converted to GB, or the
            16 GB fallback constant if that query raises.

        Notes
        -----
        The value is re-read on every call; nothing is cached.
        """
        bytes_per_gb = 1024**3
        try:
            return psutil.virtual_memory().available / bytes_per_gb
        except Exception:
            # Detection failed: use the module-level fallback estimate.
            return _DEFAULT_FALLBACK_MEMORY_GB

    @staticmethod
    def get_total_available_memory_gb() -> float:
        """Get total available memory (CPU + GPU) in GB.

        Returns
        -------
        float
            Sum of ``get_available_memory_gb()`` and
            ``GPUMemoryEstimator().get_available_gpu_memory_gb()``.

        Notes
        -----
        Both parts are re-evaluated on each call. The CPU part falls back to
        16 GB on detection failure; the GPU part is 0 when no GPU is present
        or its detection fails.
        """
        cpu_part_gb = MemoryEstimator.get_available_memory_gb()
        gpu_part_gb = GPUMemoryEstimator().get_available_gpu_memory_gb()
        return cpu_part_gb + gpu_part_gb

    @staticmethod
    def estimate_maximum_memory_usage_gb(
        n_points: int, n_params: int, safety_factor: float = 1.2
    ) -> float:
        """Estimate maximum memory usage to prevent crashes.

        Parameters
        ----------
        n_points : int
            Number of data points
        n_params : int
            Number of parameters
        safety_factor : float, optional
            Multiplier applied to the raw estimate (default: 1.2)

        Returns
        -------
        float
            ``n_points * estimate_memory_per_point(n_params) * safety_factor``
            expressed in GB.
        """
        per_point_bytes = MemoryEstimator.estimate_memory_per_point(n_params)
        peak_bytes = n_points * per_point_bytes * safety_factor
        return peak_bytes / (1024**3)

    @staticmethod
    def calculate_optimal_chunk_size(
        n_points: int, n_params: int, memory_config: LDMemoryConfig
    ) -> tuple[int, DatasetStats]:
        """Calculate optimal chunk size based on memory constraints.

        Parameters
        ----------
        n_points : int
            Total number of data points
        n_params : int
            Number of parameters
        memory_config : LDMemoryConfig
            Memory configuration

        Returns
        -------
        tuple[int, DatasetStats]
            Optimal chunk size and dataset statistics
        """
        per_point_bytes = MemoryEstimator.estimate_memory_per_point(n_params)

        # Budget = the smaller of the configured limit and what the system
        # reports as available, scaled down by the safety factor.
        usable_gb = min(
            memory_config.memory_limit_gb,
            MemoryEstimator.get_available_memory_gb(),
        )
        budget_gb = usable_gb * memory_config.safety_factor
        budget_bytes = budget_gb * (1024**3)

        # Points that fit in the budget, clamped to [min, max] chunk size.
        points_in_budget = int(budget_bytes / per_point_bytes)
        capped = min(memory_config.max_chunk_size, points_in_budget)
        chunk_size = max(memory_config.min_chunk_size, capped)

        if n_points > chunk_size:
            # Ceiling division: the last chunk may be partially filled.
            n_chunks = (n_points + chunk_size - 1) // chunk_size
        else:
            # Everything fits in a single chunk.
            chunk_size = n_points
            n_chunks = 1

        stats = DatasetStats(
            n_points=n_points,
            n_params=n_params,
            memory_per_point_bytes=per_point_bytes,
            total_memory_estimate_gb=(n_points * per_point_bytes) / (1024**3),
            recommended_chunk_size=chunk_size,
            n_chunks=n_chunks,
        )
        return chunk_size, stats
def cleanup_memory() -> None:
    """Release Python garbage and JAX caches between workflow phases.

    Notes
    -----
    - Runs ``gc.collect()`` and then ``jax.clear_caches()``.
    - Each step is attempted independently; a failure in either one is logged
      at debug level and never propagated to the caller.

    Examples
    --------
    >>> from nlsq.streaming.large_dataset import cleanup_memory
    >>> # After completing a workflow phase
    >>> cleanup_memory()
    """
    log = get_logger(__name__)

    # Step 1: Python-level garbage collection.
    try:
        gc.collect()
    except Exception as gc_error:
        log.debug(f"gc.collect() failed (non-critical): {gc_error}")

    # Step 2: drop JAX's caches.
    try:
        jax.clear_caches()
    except Exception as jax_error:
        log.debug(f"jax.clear_caches() failed (non-critical): {jax_error}")
class ProgressReporter:
    """Progress reporting for long-running fits.

    Tracks how many chunks have completed, and logs the completion percentage
    together with an ETA extrapolated from the average time per chunk so far.
    """

    def __init__(self, total_chunks: int, logger=None):
        """Initialize progress reporter.

        Parameters
        ----------
        total_chunks : int
            Total number of chunks to process
        logger : optional
            Logger instance for reporting progress. Defaults to the module
            logger obtained from ``get_logger(__name__)``.
        """
        self.total_chunks = total_chunks
        self.logger = logger or get_logger(__name__)
        # Wall-clock reference point for the ETA estimate.
        self.start_time = time.time()
        self.completed_chunks = 0

    def update(self, chunk_idx: int, chunk_result: dict | None = None):
        """Record completion of a chunk and log the current progress.

        Parameters
        ----------
        chunk_idx : int
            Zero-based index of the chunk that just completed; the completed
            count becomes ``chunk_idx + 1``.
        chunk_result : dict, optional
            Results from chunk processing. Logged at debug level when it is
            non-empty.
        """
        self.completed_chunks = chunk_idx + 1
        elapsed = time.time() - self.start_time

        if self.completed_chunks > 0:
            avg_time_per_chunk = elapsed / self.completed_chunks
            # Clamp at zero so an index at or past the expected total cannot
            # produce a negative ETA.
            remaining_chunks = max(self.total_chunks - self.completed_chunks, 0)
            eta = avg_time_per_chunk * remaining_chunks
        else:
            eta = 0

        if self.total_chunks > 0:
            progress_pct = (self.completed_chunks / self.total_chunks) * 100
        else:
            # No chunks were expected: report completion instead of raising
            # ZeroDivisionError.
            progress_pct = 100.0

        self.logger.info(
            f"Progress: {self.completed_chunks}/{self.total_chunks} chunks "
            f"({progress_pct:.1f}%) - ETA: {eta:.1f}s"
        )

        if chunk_result:
            self.logger.debug(f"Chunk {chunk_idx} result: {chunk_result}")
class DataChunker:
    """Utility for splitting paired x/y data into bucket-padded chunks."""

    @staticmethod
    def create_chunks(
        xdata: np.ndarray,
        ydata: np.ndarray,
        chunk_size: int,
        shuffle: bool = False,
        random_seed: int | None = None,
        buffer_pool: ChunkBufferPool | None = None,
    ) -> Generator[tuple[np.ndarray, np.ndarray, int, int]]:
        """Create data chunks for processing.

        Parameters
        ----------
        xdata : np.ndarray
            Independent variable data
        ydata : np.ndarray
            Dependent variable data
        chunk_size : int
            Size of each chunk
        shuffle : bool, optional
            Whether to shuffle data before chunking (default: False)
        random_seed : int, optional
            Random seed for shuffling
        buffer_pool : ChunkBufferPool, optional
            Pre-allocated buffer pool for chunk reuse (FR-006). If provided,
            each chunk is copied into the buffers the pool hands out instead
            of being returned as a freshly indexed array.

        Yields
        ------
        tuple[np.ndarray, np.ndarray, int, int]
            (x_chunk, y_chunk, chunk_index, valid_length) where valid_length
            is the actual number of data points (before padding)

        Notes
        -----
        Each chunk is padded up to ``get_bucket_size(len(chunk))`` so that the
        yielded arrays come in a small set of lengths. Padding is done on the
        index array with zeros, so every padded slot repeats the element at
        position 0 of ``xdata`` / ``ydata``; consumers should use only the
        first ``valid_length`` entries. When a pool returns views of its own
        buffers (as ``ChunkBufferPool.get_buffers`` does for sizes within its
        capacity), a yielded chunk is overwritten by the next one.
        """
        total = len(xdata)
        order = np.arange(total)
        if shuffle:
            np.random.default_rng(random_seed).shuffle(order)

        n_chunks = (total + chunk_size - 1) // chunk_size
        for chunk_no in range(n_chunks):
            lo = chunk_no * chunk_size
            hi = min(lo + chunk_size, total)
            picked = order[lo:hi]

            n_valid = len(picked)
            target = get_bucket_size(n_valid)
            n_pad = target - n_valid
            if n_pad > 0:
                # Extend with zero indices up to the bucket size; valid_length
                # tells consumers where the real data ends.
                picked = np.concatenate([picked, np.zeros(n_pad, dtype=picked.dtype)])

            if buffer_pool is None:
                yield xdata[picked], ydata[picked], chunk_no, n_valid
                continue

            # FR-006: copy into pooled buffers to reduce allocation overhead.
            pooled_x, pooled_y = buffer_pool.get_buffers(target)
            x_view = pooled_x[:target]
            y_view = pooled_y[:target]
            np.copyto(x_view, xdata[picked])
            np.copyto(y_view, ydata[picked])
            yield x_view, y_view, chunk_no, n_valid
[docs] class LargeDatasetFitter: """Large dataset curve fitting with automatic memory management and chunking. This class handles datasets with millions to billions of points that exceed available memory through automatic chunking, progressive parameter refinement, and streaming optimization. It maintains fitting accuracy while preventing memory overflow through dynamic memory monitoring and chunk size optimization. Core Capabilities ----------------- - Automatic memory estimation based on data size and parameter count - Dynamic chunk size calculation considering available system memory - Sequential parameter refinement across data chunks with convergence tracking - Streaming optimization for unlimited datasets (no accuracy loss) - Real-time progress monitoring with ETA for long-running fits - Full integration with NLSQ optimization algorithms and GPU acceleration - Multi-start optimization for global search (uses full data) Memory Management Algorithm --------------------------- 1. Estimates total memory requirements from dataset size and parameter count 2. Calculates optimal chunk sizes considering available memory and safety margins 3. Monitors actual memory usage during processing to prevent overflow 4. Uses streaming optimization for extremely large datasets (processes all data) Processing Strategies --------------------- - **Single Pass**: For datasets fitting within memory limits - **Sequential Chunking**: Processes data in optimal-sized chunks with parameter propagation - **Streaming Optimization**: Mini-batch gradient descent for unlimited datasets (no subsampling) Multi-Start Optimization ------------------------ For medium-sized datasets (1M-100M points), multi-start optimization explores multiple starting points on full data, and the best starting point is then used for the full chunked optimization. 
Performance Characteristics --------------------------- - Maintains <1% parameter error for well-conditioned problems using chunking - Achieves 5-50x speedup over naive approaches through memory optimization - Scales to datasets of unlimited size using streaming (processes all data) - Provides linear time complexity with respect to chunk count Model Validation Caching (Task Group 7 - 5.1a) ---------------------------------------------- Model functions are validated once per unique function identity using a cache keyed by (id(func), id(func.__code__)). This avoids redundant validation across chunks, providing 1-5% performance gain in chunked processing. Parameters ---------- memory_limit_gb : float, default 8.0 Maximum memory usage in GB. System memory is auto-detected if None. config : LDMemoryConfig, optional Advanced configuration for fine-tuning memory management behavior. curve_fit_class : nlsq.minpack.CurveFit, optional Custom CurveFit instance for specialized fitting requirements. multistart : bool, default False Enable multi-start optimization for global search. n_starts : int, default 10 Number of starting points for multi-start optimization. sampler : str, default 'lhs' Sampling strategy for multi-start: 'lhs', 'sobol', or 'halton'. Attributes ---------- config : LDMemoryConfig Active memory management configuration curve_fitter : nlsq.minpack.CurveFit Internal curve fitting engine with JAX acceleration logger : Logger Internal logging for performance monitoring and debugging Methods ------- fit : Main fitting method with automatic memory management fit_with_progress : Fitting with real-time progress reporting and ETA get_memory_recommendations : Pre-fitting memory analysis and strategy recommendations Important: Chunking-Compatible Model Functions ----------------------------------------------- When using chunked processing (for datasets > memory limit), your model function MUST respect the size of xdata. 
During chunking, xdata will be a subset of the full dataset, and your model must return output matching that subset size. **INCORRECT - Model ignores xdata size (will cause shape mismatch errors):** >>> def bad_model(xdata, a, b): ... # WRONG: Always returns full array, ignoring xdata size ... t_full = jnp.arange(10_000_000) # Fixed size! ... return a * jnp.exp(-b * t_full) # Shape mismatch during chunking **CORRECT - Model respects xdata size:** >>> def good_model(xdata, a, b): ... # CORRECT: Uses xdata as indices to return only requested subset ... indices = xdata.astype(jnp.int32) ... return a * jnp.exp(-b * indices) # Shape matches xdata **Alternative - Direct computation on xdata:** >>> def direct_model(xdata, a, b): ... # CORRECT: Operates directly on xdata ... return a * jnp.exp(-b * xdata) # Shape automatically matches Examples -------- Basic usage with automatic configuration: >>> import numpy as np >>> import jax.numpy as jnp >>> >>> # 10 million data points >>> x = np.linspace(0, 10, 10_000_000) >>> y = 2.5 * jnp.exp(-1.3 * x) + 0.1 + np.random.normal(0, 0.05, len(x)) >>> >>> fitter = LargeDatasetFitter(memory_limit_gb=4.0) >>> result = fitter.fit( ... lambda x, a, b, c: a * jnp.exp(-b * x) + c, ... x, y, p0=[2, 1, 0] ... ) >>> print(f"Parameters: {result.popt}") >>> print(f"Chunks used: {result.n_chunks}") Multi-start optimization: >>> fitter = LargeDatasetFitter( ... memory_limit_gb=4.0, ... multistart=True, ... n_starts=10, ... sampler='lhs', ... ) >>> result = fitter.fit( ... lambda x, a, b, c: a * jnp.exp(-b * x) + c, ... x, y, p0=[2, 1, 0], ... bounds=([0, 0, 0], [10, 5, 10]) ... ) Advanced configuration with progress monitoring: >>> config = LDMemoryConfig( ... memory_limit_gb=8.0, ... min_chunk_size=10000, ... max_chunk_size=1000000, ... use_streaming=True, ... streaming_batch_size=50000 ... ) >>> fitter = LargeDatasetFitter(config=config) >>> >>> # Fit with progress bar for long-running operation >>> result = fitter.fit_with_progress( ... 
exponential_model, x_huge, y_huge, p0=[2, 1, 0] ... ) Memory analysis before processing: >>> recommendations = fitter.get_memory_recommendations(len(x), n_params=3) >>> print(f"Strategy: {recommendations['processing_strategy']}") >>> print(f"Memory estimate: {recommendations['memory_estimate_gb']:.2f} GB") >>> print(f"Recommended chunks: {recommendations['n_chunks']}") See Also -------- curve_fit_large : High-level function with automatic dataset size detection LDMemoryConfig : Configuration class for memory management parameters estimate_memory_requirements : Standalone function for memory estimation Notes ----- The sequential chunking algorithm maintains parameter accuracy by using each chunk's result as the initial guess for the next chunk. This approach typically maintains fitting accuracy within 0.1% of single-pass results for well-conditioned problems while enabling processing of arbitrarily large datasets. For extremely large datasets, streaming optimization processes all data using mini-batch gradient descent with no subsampling, ensuring zero accuracy loss compared to subsampling approaches (removed in v0.2.0). """
[docs] def __init__( self, memory_limit_gb: float = 8.0, config: LDMemoryConfig | None = None, curve_fit_class: CurveFit | None = None, logger: Logger | None = None, # Multi-start parameters (Task Group 3) multistart: bool = False, n_starts: int = 10, sampler: Literal["lhs", "sobol", "halton"] = "lhs", ) -> None: """Initialize LargeDatasetFitter. Parameters ---------- memory_limit_gb : float, optional Memory limit in GB (default: 8.0) config : LDMemoryConfig, optional Custom memory configuration curve_fit_class : nlsq.minpack.CurveFit, optional Custom CurveFit instance to use logger : logging.Logger, optional External logger instance for integration with application logging. If None, uses NLSQ's internal logger. This allows chunk failure warnings to appear in your application's logs. multistart : bool, optional Enable multi-start optimization for global search (default: False). When enabled, explores multiple starting points on full data before running the full chunked optimization. n_starts : int, optional Number of starting points for multi-start optimization (default: 10). Set to 0 to disable multi-start even when multistart=True. sampler : str, optional Sampling strategy for generating starting points (default: 'lhs'). Options: 'lhs' (Latin Hypercube), 'sobol', 'halton'. 
""" if config is None: config = LDMemoryConfig(memory_limit_gb=memory_limit_gb) self.config = config self.logger = logger or get_logger(__name__) # Multi-start configuration (Task Group 3) self.multistart = multistart # Ensure enough starts for robust exploration on large datasets # Only enforce minimum when multistart is enabled AND n_starts > 0 # (n_starts=0 explicitly disables multistart exploration) self.n_starts = max(n_starts, 8) if (multistart and n_starts > 0) else n_starts self.sampler = sampler # Create GlobalOptimizationConfig if multi-start is enabled self._multistart_config = None if self.multistart and self.n_starts > 0: from nlsq.global_optimization import GlobalOptimizationConfig self._multistart_config = GlobalOptimizationConfig( n_starts=self.n_starts, sampler=self.sampler, center_on_p0=True, scale_factor=1.0, ) # Initialize curve fitting backend if curve_fit_class is None: # Deferred import to avoid circular dependency from nlsq.core.minpack import CurveFit self.curve_fit = CurveFit() else: self.curve_fit = curve_fit_class # Statistics tracking self.last_stats: DatasetStats | None = None self.fit_history: list[dict] = [] self._error_log_timestamps: defaultdict[str, list[float]] = defaultdict(list) # Task Group 7 (5.1a): Model validation caching # Cache validated functions by (id(func), id(func.__code__)) to avoid # redundant validation across chunks. Provides 1-5% performance gain. self._validated_functions: dict[tuple[int, int], bool] = {} # FR-006: Buffer pool for chunk reuse (initialized lazily) self._buffer_pool: ChunkBufferPool | None = None
def _get_or_create_buffer_pool(self, chunk_size: int) -> ChunkBufferPool:
    """Get existing buffer pool or create new one if size changed.

    Parameters
    ----------
    chunk_size : int
        Required chunk size

    Returns
    -------
    ChunkBufferPool
        Buffer pool whose ``chunk_size`` is at least the bucket size
        returned by ``get_bucket_size(chunk_size)``.
    """
    bucket_size = get_bucket_size(chunk_size)
    # Reuse the pool unless it is missing or too small for this bucket.
    if self._buffer_pool is None or self._buffer_pool.chunk_size < bucket_size:
        self._buffer_pool = ChunkBufferPool(bucket_size)
    return self._buffer_pool


def _should_log_error(self, error_signature: str, current_time: float) -> bool:
    """Rate-limit error logging to prevent log flooding (max once per 60s per error type).

    Side effect: timestamps older than 60 seconds are dropped from
    ``self._error_log_timestamps[error_signature]``.

    Parameters
    ----------
    error_signature : str
        Unique signature identifying the error type
    current_time : float
        Current timestamp

    Returns
    -------
    bool
        True if error should be logged, False if rate-limited
    """
    cutoff_time = current_time - 60.0
    recent_timestamps = self._error_log_timestamps[error_signature]

    # Remove timestamps older than 60 seconds
    self._error_log_timestamps[error_signature] = [
        t for t in recent_timestamps if t > cutoff_time
    ]

    # Allow logging only if no recent occurrence in the last 60 seconds
    return len(self._error_log_timestamps[error_signature]) == 0


def _log_validation_error(self, error: Exception) -> None:
    """Log validation error with rate limiting.

    The rate-limit signature is the exception class name, so all errors of
    one type share a single 60-second window.

    Parameters
    ----------
    error : Exception
        The validation error to log
    """
    error_signature = f"{type(error).__name__}"
    current_time = time.time()

    if self._should_log_error(error_signature, current_time):
        self.logger.error(f"Model function validation failed: {error}")
        # Track timestamp for this error type
        self._error_log_timestamps[error_signature].append(current_time)

        # Cleanup old timestamps (older than 5 minutes)
        cutoff_time = current_time - 300
        self._error_log_timestamps[error_signature] = [
            t for t in self._error_log_timestamps[error_signature] if t > cutoff_time
        ]


def _compute_chunk_stats(
    self, x_chunk: np.ndarray, y_chunk: np.ndarray
) -> dict[str, float]:
    """Compute diagnostic statistics for a data chunk.

    Parameters
    ----------
    x_chunk : np.ndarray
        Chunk of independent variable data
    y_chunk : np.ndarray
        Chunk of dependent variable data

    Returns
    -------
    dict
        Mean and standard deviation of both arrays as Python floats.
    """
    return {
        "x_mean": float(np.mean(x_chunk)),
        "x_std": float(np.std(x_chunk)),
        "y_mean": float(np.mean(y_chunk)),
        "y_std": float(np.std(y_chunk)),
    }


def _compute_failed_chunk_stats(
    self, x_chunk: np.ndarray, y_chunk: np.ndarray
) -> dict[str, float | tuple]:
    """Compute detailed statistics for failed chunks (includes ranges).

    Parameters
    ----------
    x_chunk : np.ndarray
        Chunk of independent variable data
    y_chunk : np.ndarray
        Chunk of dependent variable data

    Returns
    -------
    dict
        Mean, standard deviation and ``(min, max)`` range of both arrays.
    """
    return {
        "x_mean": float(np.mean(x_chunk)),
        "x_std": float(np.std(x_chunk)),
        "x_range": (float(np.min(x_chunk)), float(np.max(x_chunk))),
        "y_mean": float(np.mean(y_chunk)),
        "y_std": float(np.std(y_chunk)),
        "y_range": (float(np.min(y_chunk)), float(np.max(y_chunk))),
    }


def _validate_model_function(
    self,
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | list | None,
) -> None:
    """Validate model function shape compatibility before chunked processing.

    Calls the model on at most the first 100 points to catch shape
    mismatches early with clear error messages.

    Parameters
    ----------
    f : callable
        The model function to validate
    xdata : np.ndarray
        Independent variable data
    ydata : np.ndarray
        Dependent variable data
    p0 : np.ndarray | list | None
        Initial parameter guess. When None, a vector of ones is used whose
        length is inferred from the signature of ``f`` (falling back to 2).

    Raises
    ------
    ValueError
        If model function fails execution or returns wrong shape
    TypeError
        If model function returns non-array type

    Notes
    -----
    Task Group 7 (5.1a): successful validations are cached under a
    content-based key (name, bytecode, consts). Callables that expose no
    ``__code__`` (e.g. ``functools.partial`` objects or callable instances)
    cannot be keyed that way; they are validated on every call instead of
    failing with ``AttributeError``.
    """
    # Task Group 7 (5.1a): Check validation cache.
    # The key is content-based rather than id()-based because ids can be
    # recycled by the GC. Not every callable has __code__/__name__, so the
    # key is built defensively and caching is skipped when impossible.
    code = getattr(f, "__code__", None)
    if code is not None:
        func_key = (
            getattr(f, "__name__", type(f).__name__),
            code.co_code,
            code.co_consts,
        )
    else:
        func_key = None

    if func_key is not None and func_key in self._validated_functions:
        self.logger.debug("Model validation skipped (cached)")
        return

    self.logger.debug("Validating model function shape compatibility...")

    try:
        # Test with first 100 points to avoid expensive computation
        test_size = min(100, len(xdata))
        x_test = xdata[:test_size]
        y_test = ydata[:test_size]

        # Get initial parameters for testing
        if p0 is None:
            # Try to infer from function signature
            try:
                from inspect import signature

                sig = signature(f)
                n_params = len(sig.parameters) - 1  # Subtract x parameter
                p0_test = np.ones(n_params)
            except Exception:
                # Fallback to 2 parameters
                p0_test = np.ones(2)
                self.logger.warning(
                    "Could not infer parameter count, using 2 parameters for validation"
                )
        else:
            p0_test = np.array(p0)

        # Call model function with test data
        try:
            output_test = f(x_test, *p0_test)
        except Exception as e:
            raise ValueError(
                f"Model function failed on test data: {type(e).__name__}: {e}\n"
                f"\n"
                f"Model function must be callable as f(xdata, *params) and return array.\n"
                f"Ensure your model:\n"
                f" 1. Uses JAX operations (jax.numpy, not numpy)\n"
                f" 2. Doesn't use Python control flow that breaks JIT\n"
                f" 3. Returns numeric array, not scalar or other type\n"
            ) from e

        # Validate return type - check if it's array-like (numpy or JAX)
        is_array = isinstance(output_test, np.ndarray) or (
            hasattr(output_test, "shape") and hasattr(output_test, "dtype")
        )
        if not is_array:
            raise TypeError(
                f"Model function must return array, got {type(output_test)}\n"
                f"\n"
                f"Your model returned: {type(output_test).__name__}\n"
                f"Expected: numpy.ndarray or jax.Array\n"
            )

        # Validate shapes match
        if output_test.shape != y_test.shape:
            raise ValueError(
                f"Model function SHAPE MISMATCH detected!\n"
                f"\n"
                f" Input xdata shape: {x_test.shape}\n"
                f" Input ydata shape: {y_test.shape}\n"
                f" Model output shape: {output_test.shape}\n"
                f" Expected shape: {y_test.shape}\n"
                f"\n"
                f"ERROR: Model output must match ydata size.\n"
                f"\n"
                f"When using curve_fit_large with chunking, your model function\n"
                f"MUST respect the size of xdata. During chunked processing, xdata\n"
                f"will be a subset (e.g., 1M points) of the full dataset.\n"
                f"\n"
                f"Common cause:\n"
                f" Your model ignores xdata size and always returns the full array.\n"
                f"\n"
                f"Fix: Use xdata as indices to return only the requested subset:\n"
                f"\n"
                f" def model(xdata, *params):\n"
                f" # Compute full output if needed\n"
                f" y_full = compute_full_model(*params) # e.g., shape (N,)\n"
                f" \n"
                f" # Return only requested indices for chunking compatibility\n"
                f" indices = xdata.astype(jnp.int32) # Use JAX operations\n"
                f" return y_full[indices] # Shape matches xdata\n"
                f"\n"
                f"See NLSQ documentation for more details on chunking-compatible models.\n"
            )

        self.logger.debug(
            f"Model validation passed: "
            f"f({x_test.shape}, {len(p0_test)} params) -> {output_test.shape}"
        )

        # Task Group 7 (5.1a): Cache successful validation (when keyable)
        if func_key is not None:
            self._validated_functions[func_key] = True

    except (ValueError, TypeError) as e:
        # Re-raise validation errors with context (rate-limited logging)
        self._log_validation_error(e)
        raise
    except Exception as e:
        # Unexpected error during validation
        self.logger.warning(
            f"Model validation encountered unexpected error: {type(e).__name__}: {e}\n"
            f"Proceeding with chunked fitting, but errors may occur."
        )
        # Don't fail here - let chunking proceed and catch real errors


def _run_multistart_exploration(
    self,
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | None,
    bounds: tuple,
    **kwargs,
) -> tuple[np.ndarray, dict]:
    """Run multi-start exploration on full data to find best starting point.

    Delegates to ``MultiStartOrchestrator`` configured with
    ``self._multistart_config`` and the fitter's own curve-fit backend.

    Parameters
    ----------
    f : Callable
        Model function f(x, *params) -> y
    xdata : np.ndarray
        Full independent variable data
    ydata : np.ndarray
        Full dependent variable data
    p0 : np.ndarray | None
        Initial parameter guess (used for centering if center_on_p0=True)
    bounds : tuple
        Parameter bounds (lower, upper)
    **kwargs
        Additional arguments forwarded to ``orchestrator.fit``

    Returns
    -------
    tuple[np.ndarray, dict]
        (best_starting_point, multistart_diagnostics)
        - best_starting_point: Best parameters found from exploration
        - multistart_diagnostics: the result's ``multistart_diagnostics``
          entry extended with ``exploration_time_seconds`` and
          ``dataset_size``
    """
    from nlsq.global_optimization import MultiStartOrchestrator

    exploration_start_time = time.time()

    # Create orchestrator with our config
    orchestrator = MultiStartOrchestrator(
        config=self._multistart_config,
        curve_fit_instance=self.curve_fit,
    )

    self.logger.info(
        f"Running multi-start exploration with {self.n_starts} starting points "
        f"on {len(xdata):,} points using {self.sampler} sampling"
    )

    # Run exploration on full data
    result = orchestrator.fit(
        f=f,
        xdata=xdata,
        ydata=ydata,
        p0=p0,
        bounds=bounds,
        **kwargs,
    )

    exploration_time = time.time() - exploration_start_time

    # Extract diagnostics from result
    diagnostics = result.get("multistart_diagnostics", {})
    diagnostics["exploration_time_seconds"] = exploration_time
    diagnostics["dataset_size"] = len(xdata)

    # Get the best starting point (attribute first, then mapping lookup,
    # finally the caller's own guess)
    best_params = result.popt if hasattr(result, "popt") else result.get("popt", p0)

    self.logger.info(
        f"Multi-start exploration completed in {exploration_time:.2f}s. "
        f"Best loss: {diagnostics.get('best_loss', 'N/A')}"
    )

    return np.asarray(best_params), diagnostics
def estimate_requirements(self, n_points: int, n_params: int) -> DatasetStats:
    """Estimate memory needs and the chunking plan for a dataset.

    The estimate is delegated to
    ``MemoryEstimator.calculate_optimal_chunk_size`` using this fitter's
    configuration. The resulting statistics are stored on
    ``self.last_stats`` and summarised at INFO level.

    Parameters
    ----------
    n_points : int
        Number of data points
    n_params : int
        Number of parameters to fit

    Returns
    -------
    DatasetStats
        Detailed statistics and recommendations
    """
    estimate = MemoryEstimator.calculate_optimal_chunk_size(
        n_points, n_params, self.config
    )
    # Only the statistics half of the (chunk_size, stats) pair is needed.
    stats = estimate[1]
    self.last_stats = stats

    # Log recommendations
    report = self.logger.info
    report(f"Dataset analysis for {n_points:,} points, {n_params} parameters:")
    report(f"  Estimated memory per point: {stats.memory_per_point_bytes:.1f} bytes")
    report(f"  Total memory estimate: {stats.total_memory_estimate_gb:.2f} GB")
    report(f"  Recommended chunk size: {stats.recommended_chunk_size:,}")
    report(f"  Number of chunks: {stats.n_chunks}")

    return stats
def fit(
    self,
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | list | None = None,
    bounds: tuple = (-np.inf, np.inf),
    method: str = "trf",
    solver: str = "auto",
    **kwargs,
) -> OptimizeResult:
    """Fit curve to large dataset with automatic memory management.

    Thin public entry point: forwards everything to
    ``_fit_implementation`` with progress reporting switched off.

    Parameters
    ----------
    f : callable
        The model function ``f(x, *params) -> y``
    xdata : np.ndarray
        Independent variable data
    ydata : np.ndarray
        Dependent variable data
    p0 : array-like, optional
        Initial parameter guess
    bounds : tuple, optional
        Parameter bounds (lower, upper)
    method : str, optional
        Optimization method (default: 'trf')
    solver : str, optional
        Solver type (default: 'auto')
    **kwargs
        Additional arguments passed to curve_fit

    Returns
    -------
    OptimizeResult
        Optimization result with fitted parameters and statistics
    """
    return self._fit_implementation(
        f=f,
        xdata=xdata,
        ydata=ydata,
        p0=p0,
        bounds=bounds,
        method=method,
        solver=solver,
        show_progress=False,
        **kwargs,
    )
def fit_with_progress(
    self,
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | list | None = None,
    bounds: tuple = (-np.inf, np.inf),
    method: str = "trf",
    solver: str = "auto",
    **kwargs,
) -> OptimizeResult:
    """Fit curve with progress reporting for long-running fits.

    Same as ``fit`` except that ``_fit_implementation`` is invoked with
    ``show_progress=True``.

    Parameters
    ----------
    f : callable
        The model function ``f(x, *params) -> y``
    xdata : np.ndarray
        Independent variable data
    ydata : np.ndarray
        Dependent variable data
    p0 : array-like, optional
        Initial parameter guess
    bounds : tuple, optional
        Parameter bounds (lower, upper)
    method : str, optional
        Optimization method (default: 'trf')
    solver : str, optional
        Solver type (default: 'auto')
    **kwargs
        Additional arguments passed to curve_fit

    Returns
    -------
    OptimizeResult
        Optimization result with fitted parameters and statistics
    """
    return self._fit_implementation(
        f=f,
        xdata=xdata,
        ydata=ydata,
        p0=p0,
        bounds=bounds,
        method=method,
        solver=solver,
        show_progress=True,
        **kwargs,
    )
def _fit_implementation(
    self,
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | list | None,
    bounds: tuple,
    method: str,
    solver: str,
    show_progress: bool,
    **kwargs,
) -> OptimizeResult:
    """Shared implementation behind ``fit`` and ``fit_with_progress``.

    Estimates the chunking plan, optionally runs multi-start exploration
    (only when more than one chunk is needed), then dispatches to the
    single-chunk or chunked fitting path. The returned result always
    carries a ``multistart_diagnostics`` entry.

    Parameters
    ----------
    f : Callable
        Model function ``f(x, *params) -> y``
    xdata, ydata : np.ndarray
        Independent / dependent variable data
    p0 : np.ndarray | list | None
        Initial guess; when None the parameter count is inferred from the
        signature of ``f`` (default 2 if that fails)
    bounds : tuple
        Parameter bounds (lower, upper)
    method, solver : str
        Forwarded to the curve-fit backend
    show_progress : bool
        Whether the chunked path reports progress
    **kwargs
        Additional arguments forwarded to the backend

    Returns
    -------
    OptimizeResult
        Fit result with ``multistart_diagnostics`` attached
    """
    fit_start_time = time.time()

    n_points = len(xdata)

    # Estimate number of parameters from function signature or p0
    if p0 is not None:
        n_params = len(p0)
    else:
        # Try to infer from function signature
        try:
            from inspect import signature

            sig = signature(f)
            n_params = len(sig.parameters) - 1  # Subtract x parameter
        except Exception:
            n_params = 2  # Conservative default

    # Normalize initial guess and apply heuristics for stability
    if p0 is not None:
        p0 = np.asarray(p0, dtype=float)
        # NOTE(review): this treats p0[0] as an amplitude-like parameter and
        # replaces small values with the data's peak-to-peak range -
        # confirm this is intended for arbitrary models.
        if self.multistart and p0.size > 0 and p0[0] < 0.5:
            heuristic_amp = max(float(np.ptp(ydata)), 0.5)
            p0 = p0.copy()  # asarray may alias the caller's array
            p0[0] = heuristic_amp

    # Get processing statistics and strategy
    stats = self.estimate_requirements(n_points, n_params)

    # Determine if chunking is needed
    needs_chunking = stats.n_chunks > 1

    # Initialize multi-start diagnostics
    multistart_diagnostics = {
        "n_starts_configured": self.n_starts if self.multistart else 0,
        "sampler": self.sampler,
        "bypassed": False,
    }

    # Run multi-start exploration if enabled and chunking is needed
    # (for single-chunk datasets, multi-start overhead isn't worth it)
    if self.multistart and self.n_starts > 0 and needs_chunking:
        self.logger.info("Multi-start optimization enabled for chunked dataset")

        # Run multi-start exploration on full data (no subsampling)
        best_p0, exploration_diagnostics = self._run_multistart_exploration(
            f=f,
            xdata=xdata,
            ydata=ydata,
            p0=np.array(p0) if p0 is not None else None,
            bounds=bounds,
            **kwargs,
        )

        # Update p0 with best starting point
        p0 = best_p0
        multistart_diagnostics.update(exploration_diagnostics)
        multistart_diagnostics["best_starting_point"] = (
            best_p0.tolist() if hasattr(best_p0, "tolist") else list(best_p0)
        )

        self.logger.info(
            f"Using best starting point from multi-start exploration: {best_p0}"
        )
    elif self.multistart and self.n_starts == 0:
        # Multi-start enabled but n_starts=0 means skip
        multistart_diagnostics["bypassed"] = True
        multistart_diagnostics["n_starts_evaluated"] = 0
        self.logger.debug("Multi-start disabled (n_starts=0)")
    elif self.multistart and not needs_chunking:
        # Single chunk dataset - skip multi-start overhead
        multistart_diagnostics["bypassed"] = True
        multistart_diagnostics["bypass_reason"] = "single_chunk_dataset"
        self.logger.debug("Multi-start skipped for single-chunk dataset")

    # Handle datasets that fit in memory
    if stats.n_chunks == 1:
        result = self._fit_single_chunk(
            f,
            xdata,
            ydata,
            p0,
            bounds,
            method,
            solver,
            **kwargs,
        )
        # Add multi-start diagnostics to result
        result["multistart_diagnostics"] = multistart_diagnostics
        return result

    # Handle chunked processing
    result = self._fit_chunked(
        f,
        xdata,
        ydata,
        p0,
        bounds,
        method,
        solver,
        show_progress,
        stats,
        **kwargs,
    )

    # Add multi-start diagnostics and timing to result
    multistart_diagnostics["total_fit_time_seconds"] = time.time() - fit_start_time
    result["multistart_diagnostics"] = multistart_diagnostics

    return result


def _fit_single_chunk(
    self,
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | list | None,
    bounds: tuple,
    method: str,
    solver: str,
    **kwargs,
) -> OptimizeResult:
    """Fit data that can be processed in a single chunk.

    Any exception raised by the backend is logged and converted into an
    unsuccessful ``OptimizeResult`` (never re-raised). In that case
    ``popt`` is the initial guess (or ``np.ones(2)``) and ``pcov`` an
    identity matrix of matching size.
    """
    self.logger.info("Fitting dataset in single chunk")

    try:
        popt, _pcov = self.curve_fit.curve_fit(
            f,
            xdata,
            ydata,
            p0=p0,
            bounds=bounds,
            method=method,
            solver=solver,
            **kwargs,
        )

        # Create result object
        result = OptimizeResult(
            x=popt,
            success=True,
            fun=None,  # Could compute final residuals if needed
            nfev=1,  # Approximation
            message="Single-chunk fit completed successfully",
        )

        # Add covariance matrix and parameters
        result["pcov"] = _pcov
        result["popt"] = popt

        return result

    except Exception as e:
        self.logger.error(f"Single-chunk fit failed: {e}")
        result = OptimizeResult(
            x=p0 if p0 is not None else np.ones(2),
            success=False,
            message=f"Fit failed: {e}",
        )
        # Add empty popt and pcov for consistency
        result["popt"] = result.x
        result["pcov"] = np.eye(len(result.x))
        return result


def _fit_with_streaming(
    self,
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | list | None,
    bounds: tuple,
    method: str,
    solver: str,
    show_progress: bool,
    **kwargs,
) -> OptimizeResult:
    """Fit very large dataset using adaptive hybrid streaming optimization.

    Builds a ``HybridStreamingConfig`` from the fitter's streaming
    settings and runs ``AdaptiveHybridStreamingOptimizer`` on the in-memory
    arrays. ``method``, ``solver`` and ``**kwargs`` are accepted for
    signature parity with the other fit paths but are not forwarded.
    Failures are logged and returned as an unsuccessful result.
    """
    self.logger.info(
        "Using adaptive hybrid streaming optimization for unlimited data "
        f"({len(xdata):,} points). "
        f"Chunk size: {self.config.streaming_batch_size:,}, "
        f"Max iterations: {self.config.streaming_max_epochs}"
    )

    # Create adaptive hybrid streaming config
    streaming_config = HybridStreamingConfig(
        chunk_size=self.config.streaming_batch_size,
        gauss_newton_max_iterations=self.config.streaming_max_epochs,
    )

    # Initialize adaptive hybrid streaming optimizer
    optimizer = AdaptiveHybridStreamingOptimizer(config=streaming_config)

    # Convert p0 to array if needed
    if p0 is None:
        p0 = np.ones(2)  # Default 2-parameter model
    elif isinstance(p0, list):
        p0 = np.array(p0)

    # Fit using streaming optimization
    try:
        result_dict = optimizer.fit(
            data_source=(xdata, ydata),
            func=f,
            p0=p0,
            bounds=bounds,
            verbose=2 if show_progress else 1,
        )

        # Convert to OptimizeResult format
        result = OptimizeResult(
            x=result_dict["x"],
            success=result_dict["success"],
            message=result_dict["message"],
            nfev=len(xdata),
            fun=result_dict.get("fun"),
        )

        result["popt"] = result.x
        result["pcov"] = result_dict.get("pcov", np.eye(len(result.x)))

        final_cost = (
            result_dict.get("streaming_diagnostics", {})
            .get("gauss_newton_diagnostics", {})
            .get("final_cost", "N/A")
        )
        self.logger.info(f"Streaming fit completed. Final loss: {final_cost}")

        return result

    except Exception as e:
        self.logger.error(f"Streaming fit failed: {e}")
        result = OptimizeResult(
            x=p0 if p0 is not None else np.ones(2),
            success=False,
            message=f"Streaming fit failed: {e}",
        )
        result["popt"] = result.x
        result["pcov"] = np.eye(len(result.x))
        return result


def _update_parameters_convergence(
    self,
    current_params: np.ndarray | None,
    popt_chunk: np.ndarray,
    param_history: list,
    convergence_metric: float,
    chunk_idx: int,
    n_chunks: int,
) -> tuple[np.ndarray, list, float, bool]:
    """Update parameters with sequential refinement and convergence checking.

    Parameters
    ----------
    current_params : np.ndarray | None
        Current parameter estimates (None on first chunk)
    popt_chunk : np.ndarray
        Newly fitted parameters from current chunk
    param_history : list
        Parameter estimates from previous chunks (not mutated)
    convergence_metric : float
        Current convergence metric value
    chunk_idx : int
        Index of current chunk (0-based)
    n_chunks : int
        Total number of chunks

    Returns
    -------
    tuple
        ``(updated_params, updated_history, new_convergence_metric,
        should_stop)``. The metric is the relative parameter change and is
        only recomputed once the history holds more than two entries.
        ``should_stop`` is True when the metric is below 0.001 and
        ``chunk_idx >= min(n_chunks - 1, 3)``.
    """
    # Initialize on first chunk
    if current_params is None:
        return (
            popt_chunk.copy(),
            [popt_chunk.copy()],
            np.inf,
            False,
        )

    # Update parameters with sequential refinement
    previous_params = current_params.copy()
    updated_params = popt_chunk.copy()

    # Update parameter history
    updated_history = [*param_history, updated_params.copy()]

    # Calculate convergence metric
    new_convergence_metric = convergence_metric
    if len(updated_history) > 2:
        param_change = np.linalg.norm(updated_params - previous_params)
        # 1e-10 guards against division by zero for all-zero parameters
        relative_change = param_change / (np.linalg.norm(updated_params) + 1e-10)
        new_convergence_metric = relative_change

        # Check early stopping criteria
        # Stop if parameters stabilized and we've processed enough chunks
        if new_convergence_metric < 0.001 and chunk_idx >= min(n_chunks - 1, 3):
            self.logger.info(f"Parameters converged after {chunk_idx + 1} chunks")
            return (updated_params, updated_history, new_convergence_metric, True)

    return (updated_params, updated_history, new_convergence_metric, False)


def _initialize_chunked_fit_state(
    self,
    p0: np.ndarray | list | None,
    show_progress: bool,
    stats: DatasetStats,
) -> tuple[
    ProgressReporter | None,
    np.ndarray | None,
    list,
    list,
    float,
]:
    """Initialize state variables for chunked fitting.

    Parameters
    ----------
    p0 : np.ndarray | list | None
        Initial parameter guess
    show_progress : bool
        Whether to show progress updates
    stats : DatasetStats
        Dataset statistics including chunk count

    Returns
    -------
    progress : ProgressReporter | None
        Progress reporter instance or None
    current_params : np.ndarray | None
        Initial parameters (array copy of ``p0``) or None
    chunk_results : list
        Empty list for accumulating chunk results
    param_history : list
        Empty list for tracking parameter evolution
    convergence_metric : float
        Initial convergence metric (infinity)
    """
    # Initialize progress reporter
    progress = ProgressReporter(stats.n_chunks, self.logger) if show_progress else None

    # Initialize parameters
    current_params = np.array(p0) if p0 is not None else None

    # Initialize tracking lists
    chunk_results = []
    param_history = []
    convergence_metric = np.inf

    return (
        progress,
        current_params,
        chunk_results,
        param_history,
        convergence_metric,
    )


def _create_chunk_result(
    self,
    chunk_idx: int,
    x_chunk: np.ndarray,
    y_chunk: np.ndarray,
    chunk_duration: float,
    success: bool = True,
    popt_chunk: np.ndarray | None = None,
    is_retry: bool = False,
    error: Exception | None = None,
    current_params: np.ndarray | None = None,
) -> dict:
    """Create a standardized chunk result dictionary.

    Parameters
    ----------
    chunk_idx : int
        Index of the chunk
    x_chunk, y_chunk : np.ndarray
        Data the chunk was fitted on (used for ``n_points`` and statistics)
    chunk_duration : float
        Time taken to process this chunk
    success : bool
        Whether the chunk fitting succeeded
    popt_chunk : np.ndarray | None
        Fitted parameters (if successful)
    is_retry : bool
        Whether this was a retry attempt
    error : Exception | None
        Exception that occurred (if failed)
    current_params : np.ndarray | None
        Current parameter estimates (for failure diagnostics)

    Returns
    -------
    dict
        Standardized chunk result with metadata
    """
    # Base result structure
    result = {
        "chunk_idx": chunk_idx,
        "n_points": len(x_chunk),
        "success": success,
        "timestamp": time.time(),
        "duration": chunk_duration,
    }

    if success:
        # Success case
        result["parameters"] = popt_chunk
        if is_retry:
            result["retry"] = True

        # Add diagnostics if enabled (5-10% performance gain when disabled)
        if self.config.save_diagnostics:
            result["data_stats"] = self._compute_chunk_stats(x_chunk, y_chunk)
    else:
        # Failure case
        result["error"] = str(error)
        result["error_type"] = type(error).__name__
        result["initial_params"] = (
            current_params.tolist() if current_params is not None else None
        )
        # Always compute detailed stats for failed chunks (debugging critical)
        result["data_stats"] = self._compute_failed_chunk_stats(x_chunk, y_chunk)

    return result


def _retry_failed_chunk(
    self,
    f: Callable,
    x_chunk: np.ndarray,
    y_chunk: np.ndarray,
    chunk_idx: int,
    chunk_start_time: float,
    chunk_times: list,
    current_params: np.ndarray | None,
    initial_error: Exception,
    bounds: tuple,
    method: str,
    solver: str,
    **kwargs,
) -> tuple[dict, np.ndarray | None]:
    """Retry a failed chunk with perturbed parameters.

    No retry is attempted when there is no current estimate to perturb.
    A successful retry is blended into the running estimate with weight
    0.1 rather than replacing it.

    Parameters
    ----------
    f : Callable
        Model function
    x_chunk, y_chunk : np.ndarray
        Data for this chunk
    chunk_idx : int
        Index of the chunk
    chunk_start_time : float
        Start time of chunk processing
    chunk_times : list
        List the chunk duration is appended to (mutated in place)
    current_params : np.ndarray | None
        Current parameter estimates
    initial_error : Exception
        The exception that caused the initial failure; it is what gets
        recorded even when the retry fails with a different error
    bounds : tuple
        Parameter bounds
    method, solver : str
        Forwarded to the curve-fit backend
    **kwargs
        Additional curve_fit arguments

    Returns
    -------
    tuple
        ``(chunk_result dict, updated_params or None)``
    """
    # Only retry if we have current parameter estimates
    if current_params is None:
        chunk_duration = time.time() - chunk_start_time
        chunk_times.append(chunk_duration)

        chunk_result = self._create_chunk_result(
            chunk_idx=chunk_idx,
            x_chunk=x_chunk,
            y_chunk=y_chunk,
            chunk_duration=chunk_duration,
            success=False,
            error=initial_error,
            current_params=current_params,
        )
        return chunk_result, None

    # Attempt retry with perturbed parameters
    try:
        self.logger.info(f"Retrying chunk {chunk_idx} with current parameters")

        # Add small perturbation to avoid local minima
        perturbed_params = current_params * (
            1 + 0.01 * np.random.randn(len(current_params))
        )

        popt_chunk, _pcov_chunk = self.curve_fit.curve_fit(
            f,
            x_chunk,
            y_chunk,
            p0=perturbed_params,
            bounds=bounds,
            method=method,
            solver=solver,
            **kwargs,
        )

        # Retry succeeded - use result with lower weight
        adaptive_lr = 0.1  # Lower weight for retry results
        updated_params = (1 - adaptive_lr) * current_params + adaptive_lr * popt_chunk

        chunk_duration = time.time() - chunk_start_time
        chunk_times.append(chunk_duration)

        chunk_result = self._create_chunk_result(
            chunk_idx=chunk_idx,
            x_chunk=x_chunk,
            y_chunk=y_chunk,
            chunk_duration=chunk_duration,
            success=True,
            popt_chunk=popt_chunk,
            is_retry=True,
        )
        return chunk_result, updated_params

    except Exception as retry_e:
        # Retry also failed
        self.logger.warning(f"Retry for chunk {chunk_idx} also failed: {retry_e}")

        chunk_duration = time.time() - chunk_start_time
        chunk_times.append(chunk_duration)

        chunk_result = self._create_chunk_result(
            chunk_idx=chunk_idx,
            x_chunk=x_chunk,
            y_chunk=y_chunk,
            chunk_duration=chunk_duration,
            success=False,
            error=initial_error,
            current_params=current_params,
        )
        return chunk_result, current_params  # Keep current params unchanged


def _create_failure_summary(
    self,
    chunk_results: list,
    chunk_times: list,
) -> dict:
    """Create comprehensive failure summary for diagnostics.

    Parameters
    ----------
    chunk_results : list
        All chunk result dictionaries
    chunk_times : list
        Chunk processing durations

    Returns
    -------
    dict
        Failure count and rate, failed chunk indices, per-type error
        counts, the three most frequent error types and timing statistics.
    """
    failed_chunks = [r for r in chunk_results if not r.get("success", False)]
    failed_durations = [r.get("duration", 0.0) for r in failed_chunks]

    failure_summary = {
        "total_failures": len(failed_chunks),
        "failure_rate": len(failed_chunks) / len(chunk_results)
        if chunk_results
        else 0.0,
        "failed_chunk_indices": [r["chunk_idx"] for r in failed_chunks],
        "error_types": {},
        "common_errors": [],
        "timing_stats": {
            "mean_chunk_time": float(np.mean(chunk_times)) if chunk_times else 0.0,
            "median_chunk_time": float(np.median(chunk_times))
            if chunk_times
            else 0.0,
            "failed_chunk_times": failed_durations,
            "mean_failed_chunk_time": float(np.mean(failed_durations))
            if failed_chunks
            else 0.0,
        },
    }

    # Aggregate error types
    for failed_chunk in failed_chunks:
        error_type = failed_chunk.get("error_type", "Unknown")
        failure_summary["error_types"][error_type] = (
            failure_summary["error_types"].get(error_type, 0) + 1
        )

    # Identify most common errors (top 3)
    if failure_summary["error_types"]:
        sorted_errors = sorted(
            failure_summary["error_types"].items(), key=lambda x: x[1], reverse=True
        )
        failure_summary["common_errors"] = [
            {"type": err_type, "count": count} for err_type, count in sorted_errors[:3]
        ]

    return failure_summary


def _compute_covariance_from_history(
    self,
    param_history: list,
    current_params: np.ndarray,
) -> np.ndarray:
    """Compute approximate covariance matrix from parameter history.

    In chunked fitting, covariance is estimated from parameter variations
    across chunks rather than from the Jacobian.

    Parameters
    ----------
    param_history : list
        Parameter estimates from previous chunks
    current_params : np.ndarray
        Final parameter estimates (used only for the fallback)

    Returns
    -------
    np.ndarray
        Approximate covariance matrix, always 2-D - including the
        single-parameter case, where ``np.cov`` alone would return a 0-d
        array.
    """
    if len(param_history) > 1:
        # Use last 10 parameter estimates for covariance estimation
        param_variations = np.array(param_history[-10:])
        # np.cov squeezes its output; restore the matrix shape for 1 param
        pcov = np.atleast_2d(np.cov(param_variations.T))
    else:
        # Fallback: identity matrix scaled by parameter magnitudes
        # This provides a reasonable uncertainty estimate when we have no history
        pcov = np.diag(np.abs(current_params) * 0.01 + 0.001)

    return pcov


def _finalize_chunked_results(
    self,
    current_params: np.ndarray,
    chunk_results: list,
    param_history: list,
    success_rate: float,
    stats: DatasetStats,
    chunk_times: list,
) -> OptimizeResult:
    """Assemble final optimization result from chunked fitting.

    Parameters
    ----------
    current_params : np.ndarray
        Final optimized parameters
    chunk_results : list
        List of all chunk result dictionaries
    param_history : list
        History of parameter estimates across chunks
    success_rate : float
        Fraction of successful chunks
    stats : DatasetStats
        Dataset statistics including chunk count
    chunk_times : list
        Processing durations for each chunk

    Returns
    -------
    OptimizeResult
        Final optimization result with parameters, covariance, and diagnostics
    """
    # Log completion
    self.logger.info(f"Chunked fit completed with {success_rate:.1%} success rate")

    # Create failure summary for diagnostics
    failure_summary = self._create_failure_summary(chunk_results, chunk_times)

    # Assemble result
    result = OptimizeResult(
        x=current_params,
        success=True,
        message=f"Chunked fit completed ({stats.n_chunks} chunks, {success_rate:.1%} success)",
    )
    result["popt"] = current_params

    # Create approximate covariance matrix from parameter history
    result["pcov"] = self._compute_covariance_from_history(
        param_history, current_params
    )

    # Add diagnostic information
    result["chunk_results"] = chunk_results
    result["n_chunks"] = stats.n_chunks
    result["success_rate"] = success_rate
    result["failure_summary"] = failure_summary

    return result


def _check_success_rate_and_create_result(
    self,
    chunk_results: list,
    current_params: np.ndarray | None,
    param_history: list,
    stats: DatasetStats,
    chunk_times: list,
) -> OptimizeResult:
    """Check success rate and create appropriate result (success or failure).

    The fit is reported as failed when the fraction of successful chunks is
    below ``config.min_success_rate`` or when no parameter estimate exists
    at all. An empty ``chunk_results`` counts as a 0% success rate instead
    of raising ``ZeroDivisionError``.

    Parameters
    ----------
    chunk_results : list
        Chunk processing results
    current_params : np.ndarray | None
        Final parameter estimates
    param_history : list
        History of parameter updates
    stats : DatasetStats
        Dataset statistics
    chunk_times : list
        Processing time for each chunk

    Returns
    -------
    OptimizeResult
        Result with success or failure status based on success rate
    """
    # Compute final statistics
    successful_chunks = [r for r in chunk_results if r.get("success", False)]
    success_rate = (
        len(successful_chunks) / len(chunk_results) if chunk_results else 0.0
    )

    # Without any parameter estimate there is nothing to report as a fit,
    # whatever the configured threshold is.
    if success_rate < self.config.min_success_rate or current_params is None:
        self.logger.error(
            f"Too many chunks failed ({success_rate:.1%} success rate, "
            f"minimum required: {self.config.min_success_rate:.1%})"
        )
        fallback_params = current_params if current_params is not None else np.ones(2)
        result = OptimizeResult(
            x=fallback_params,
            success=False,
            message=f"Chunked fit failed: {success_rate:.1%} success rate",
        )
        # Add empty popt and pcov for consistency
        result["popt"] = fallback_params
        result["pcov"] = np.eye(len(result["popt"]))
        return result

    # Success - assemble final result
    return self._finalize_chunked_results(
        current_params=current_params,
        chunk_results=chunk_results,
        param_history=param_history,
        success_rate=success_rate,
        stats=stats,
        chunk_times=chunk_times,
    )


def _fit_chunked(
    self,
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | list | None,
    bounds: tuple,
    method: str,
    solver: str,
    show_progress: bool,
    stats: DatasetStats,
    **kwargs,
) -> OptimizeResult:
    """Fit dataset using chunked processing with parameter refinement.

    Each chunk is fitted starting from the previous chunk's parameters.
    Failed chunks are retried once with perturbed parameters. Processing
    stops early once the parameters have converged. Only the first
    ``valid_length`` entries of a chunk are used - for the primary fit,
    the retry and the recorded statistics alike - so padding never leaks
    into results. Any exception escaping the loop is logged and converted
    into an unsuccessful result.
    """
    self.logger.info(f"Fitting dataset using {stats.n_chunks} chunks")

    # Validate model function shape compatibility
    self._validate_model_function(f, xdata, ydata, p0)

    # Initialize state variables
    (
        progress,
        current_params,
        chunk_results,
        param_history,
        convergence_metric,
    ) = self._initialize_chunked_fit_state(p0, show_progress, stats)

    chunk_times = []  # Track processing time per chunk

    # Get gc_chunk_interval from config (FR-007)
    gc_chunk_interval = self.config.gc_chunk_interval

    # FR-006: Initialize buffer pool for chunk reuse
    buffer_pool = self._get_or_create_buffer_pool(stats.recommended_chunk_size)

    try:
        # Process dataset in chunks with sequential parameter refinement
        for x_chunk, y_chunk, chunk_idx, valid_length in DataChunker.create_chunks(
            xdata, ydata, stats.recommended_chunk_size, buffer_pool=buffer_pool
        ):
            chunk_start_time = time.time()

            # Use only valid data points (exclude zero-padding)
            x_valid = x_chunk[:valid_length]
            y_valid = y_chunk[:valid_length]
            should_stop = False

            try:
                popt_chunk, _pcov_chunk = self.curve_fit.curve_fit(
                    f,
                    x_valid,
                    y_valid,
                    p0=current_params,
                    bounds=bounds,
                    method=method,
                    solver=solver,
                    **kwargs,
                )

                # Update parameters with sequential refinement and check convergence
                (
                    current_params,
                    param_history,
                    convergence_metric,
                    should_stop,
                ) = self._update_parameters_convergence(
                    current_params,
                    popt_chunk,
                    param_history,
                    convergence_metric,
                    chunk_idx,
                    stats.n_chunks,
                )

                chunk_duration = time.time() - chunk_start_time
                chunk_times.append(chunk_duration)

                # Create successful chunk result
                chunk_result = self._create_chunk_result(
                    chunk_idx=chunk_idx,
                    x_chunk=x_valid,
                    y_chunk=y_valid,
                    chunk_duration=chunk_duration,
                    success=True,
                    popt_chunk=popt_chunk,
                )

            except Exception as e:
                self.logger.warning(f"Chunk {chunk_idx} failed: {e}")
                # A chunk that ends up on the failure path must not
                # trigger early stopping.
                should_stop = False

                # Retry chunk with helper method
                chunk_result, retry_params = self._retry_failed_chunk(
                    f=f,
                    x_chunk=x_valid,
                    y_chunk=y_valid,
                    chunk_idx=chunk_idx,
                    chunk_start_time=chunk_start_time,
                    chunk_times=chunk_times,
                    current_params=current_params,
                    initial_error=e,
                    bounds=bounds,
                    method=method,
                    solver=solver,
                    **kwargs,
                )

                # Update params if retry succeeded
                if retry_params is not None:
                    current_params = retry_params

            chunk_results.append(chunk_result)

            if progress:
                progress.update(chunk_idx, chunk_result)

            # Early stopping if parameters converged. Checked only after the
            # chunk has been recorded so it counts towards the success rate.
            if should_stop:
                break

            # Memory cleanup - conditional based on gc_chunk_interval (FR-007);
            # a falsy interval disables explicit collection.
            if gc_chunk_interval and chunk_idx % gc_chunk_interval == 0:
                gc.collect()

        # Check success rate and create final result
        return self._check_success_rate_and_create_result(
            chunk_results=chunk_results,
            current_params=current_params,
            param_history=param_history,
            stats=stats,
            chunk_times=chunk_times,
        )

    except Exception as e:
        self.logger.error(f"Chunked fitting failed: {e}")
        result = OptimizeResult(
            x=current_params if current_params is not None else np.ones(2),
            success=False,
            message=f"Chunked fit failed: {e}",
        )
        # Add empty popt and pcov for consistency
        result["popt"] = current_params if current_params is not None else np.ones(2)
        result["pcov"] = np.eye(len(result["popt"]))
        return result
@contextmanager
def memory_monitor(self):
    """Context manager for monitoring memory usage during fits.

    On entry the resident set size (RSS) of the current process is read
    through ``psutil`` and logged at debug level; on exit the final RSS and
    the delta against the entry value are logged the same way.

    Monitoring is strictly best effort: a failure while probing memory
    (on entry or on exit) is logged at debug level and never prevents the
    wrapped block from running, nor does it mask an exception raised by
    the wrapped block.

    Yields
    ------
    None
    """
    process = None
    initial_memory = None

    # Entry probe is guarded so that a monitoring failure cannot abort the
    # fit that this context manager merely observes.
    try:
        process = psutil.Process()
        initial_memory = process.memory_info().rss / (1024**3)  # GB
        self.logger.debug(f"Initial memory usage: {initial_memory:.2f} GB")
    except Exception as e:
        self.logger.debug(f"Memory monitoring failed (non-critical): {e}")

    try:
        yield
    finally:
        # Without a baseline there is nothing meaningful to report.
        if process is not None and initial_memory is not None:
            try:
                final_memory = process.memory_info().rss / (1024**3)  # GB
                memory_delta = final_memory - initial_memory
                self.logger.debug(
                    f"Final memory usage: {final_memory:.2f} GB (delta: {memory_delta:+.2f} GB)"
                )
            except Exception as e:
                # Memory monitoring is best effort - log but don't fail
                self.logger.debug(f"Memory monitoring failed (non-critical): {e}")
def get_memory_recommendations(self, n_points: int, n_params: int) -> dict:
    """Summarize the memory analysis and processing plan for a dataset.

    Parameters
    ----------
    n_points : int
        Number of data points
    n_params : int
        Number of parameters

    Returns
    -------
    dict
        Mapping with the keys ``"dataset_stats"`` (the object returned by
        ``estimate_requirements``), ``"memory_limit_gb"`` (taken from the
        fitter configuration), ``"processing_strategy"``
        (``"single_chunk"`` when exactly one chunk is needed, otherwise
        ``"chunked"``) and ``"recommendations"`` (chunk size, chunk count
        and memory estimates copied from the dataset stats).
    """
    dataset_stats = self.estimate_requirements(n_points, n_params)
    memory_limit = self.config.memory_limit_gb

    # Exactly one chunk means the data can be processed in a single pass.
    if dataset_stats.n_chunks == 1:
        strategy = "single_chunk"
    else:
        strategy = "chunked"

    recommendations = {
        "chunk_size": dataset_stats.recommended_chunk_size,
        "n_chunks": dataset_stats.n_chunks,
        "memory_per_point_bytes": dataset_stats.memory_per_point_bytes,
        "total_memory_estimate_gb": dataset_stats.total_memory_estimate_gb,
    }

    return {
        "dataset_stats": dataset_stats,
        "memory_limit_gb": memory_limit,
        "processing_strategy": strategy,
        "recommendations": recommendations,
    }
# Convenience functions
def fit_large_dataset(
    f: Callable,
    xdata: np.ndarray,
    ydata: np.ndarray,
    p0: np.ndarray | list | None = None,
    memory_limit_gb: float = 8.0,
    show_progress: bool = False,
    logger: Logger | None = None,
    # Multi-start parameters (Task Group 3)
    multistart: bool = False,
    n_starts: int = 10,
    sampler: Literal["lhs", "sobol", "halton"] = "lhs",
    **kwargs,
) -> OptimizeResult:
    """Fit a model to a large dataset with a one-off ``LargeDatasetFitter``.

    A fitter is built from the memory and multi-start settings given here,
    then either ``fit_with_progress`` or ``fit`` is invoked on it depending
    on ``show_progress``.

    Parameters
    ----------
    f : callable
        The model function f(x, \\*params) -> y
    xdata : np.ndarray
        Independent variable data
    ydata : np.ndarray
        Dependent variable data
    p0 : array-like, optional
        Initial parameter guess
    memory_limit_gb : float, optional
        Memory limit in GB (default: 8.0)
    show_progress : bool, optional
        Whether to show progress (default: False)
    logger : logging.Logger, optional
        External logger for application integration (default: None)
    multistart : bool, optional
        Enable multi-start optimization for global search (default: False).
        When enabled, explores multiple starting points on full data before
        running the full chunked optimization.
    n_starts : int, optional
        Number of starting points for multi-start optimization (default: 10).
        Set to 0 to disable multi-start even when multistart=True.
    sampler : str, optional
        Sampling strategy for generating starting points (default: 'lhs').
        Options: 'lhs' (Latin Hypercube), 'sobol', 'halton'.
    **kwargs
        Additional arguments passed to curve_fit

    Returns
    -------
    OptimizeResult
        Optimization result

    Examples
    --------
    >>> from nlsq.streaming.large_dataset import fit_large_dataset
    >>> import numpy as np
    >>> import jax.numpy as jnp
    >>>
    >>> # Generate large dataset
    >>> x_large = np.linspace(0, 10, 5_000_000)
    >>> y_large = 2.5 * np.exp(-1.3 * x_large) + np.random.normal(0, 0.1, len(x_large))
    >>>
    >>> # Fit with automatic memory management
    >>> result = fit_large_dataset(
    ...     lambda x, a, b: a * jnp.exp(-b * x),
    ...     x_large, y_large,
    ...     p0=[2.0, 1.0],
    ...     memory_limit_gb=4.0,
    ...     show_progress=True
    ... )
    >>> print(f"Fitted parameters: {result.popt}")
    >>> print(f"Success rate: {result.success_rate:.1%}")
    >>>
    >>> # Fit with multi-start optimization
    >>> result = fit_large_dataset(
    ...     lambda x, a, b: a * jnp.exp(-b * x),
    ...     x_large, y_large,
    ...     p0=[2.0, 1.0],
    ...     bounds=([0, 0], [10, 5]),
    ...     multistart=True,
    ...     n_starts=10,
    ...     sampler='lhs'
    ... )
    >>>
    >>> # Check failure diagnostics if some chunks failed
    >>> if result.failure_summary['total_failures'] > 0:
    ...     print(f"Failed chunks: {result.failure_summary['failed_chunk_indices']}")
    ...     print(f"Common errors: {result.failure_summary['common_errors']}")
    """
    large_fitter = LargeDatasetFitter(
        memory_limit_gb=memory_limit_gb,
        logger=logger,
        multistart=multistart,
        n_starts=n_starts,
        sampler=sampler,
    )

    # Pick the entry point once; both share the same call signature here.
    run_fit = large_fitter.fit_with_progress if show_progress else large_fitter.fit
    return run_fit(f, xdata, ydata, p0=p0, **kwargs)
def estimate_memory_requirements(n_points: int, n_params: int) -> DatasetStats:
    """Estimate memory requirements for a dataset.

    The estimate is produced by
    ``MemoryEstimator.calculate_optimal_chunk_size`` using a
    default-constructed ``LDMemoryConfig``; only the statistics part of its
    result is returned.

    Parameters
    ----------
    n_points : int
        Number of data points
    n_params : int
        Number of parameters

    Returns
    -------
    DatasetStats
        Memory requirements and processing recommendations

    Examples
    --------
    >>> from nlsq.streaming.large_dataset import estimate_memory_requirements
    >>>
    >>> # Estimate requirements for 50M points, 3 parameters
    >>> stats = estimate_memory_requirements(50_000_000, 3)
    >>> print(f"Estimated memory: {stats.total_memory_estimate_gb:.2f} GB")
    >>> print(f"Recommended chunk size: {stats.recommended_chunk_size:,}")
    >>> print(f"Number of chunks: {stats.n_chunks}")
    """
    default_config = LDMemoryConfig()
    # The first element of the pair is not needed by callers of this helper.
    _unused, dataset_stats = MemoryEstimator.calculate_optimal_chunk_size(
        n_points, n_params, default_config
    )
    return dataset_stats
# Public API of this module: the names exported by
# ``from nlsq.streaming.large_dataset import *``.
__all__ = [ "ChunkBufferPool", "DataChunker", "DatasetStats", "GPUMemoryEstimator", "LDMemoryConfig", "LargeDatasetFitter", "MemoryEstimator", "ProgressReporter", "cleanup_memory", "estimate_memory_requirements", "fit_large_dataset", "get_cached_devices", ]