Source code for nlsq.config

"""Central configuration management for NLSQ package."""

from __future__ import annotations

import copy
import json
import logging
import os
import threading
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field


[docs] @dataclass(slots=True) class MemoryConfig: """Configuration for memory management and GPU settings. Attributes ---------- memory_limit_gb : float Maximum memory limit in GB (default: 8.0) gpu_memory_fraction : float | None Fraction of GPU memory to use (0.0-1.0, None for automatic) chunk_size_mb : int | None Default chunk size in MB for data processing out_of_memory_strategy : str Strategy when out of memory: 'fallback', 'reduce', 'error' safety_factor : float Safety factor for memory calculations (0.0-1.0) auto_chunk_threshold_gb : float Automatically enable chunking above this memory threshold progress_reporting : bool Enable progress reporting for large operations min_chunk_size : int Minimum chunk size in data points max_chunk_size : int Maximum chunk size in data points """ memory_limit_gb: float = 8.0 gpu_memory_fraction: float | None = None chunk_size_mb: int | None = None out_of_memory_strategy: str = "fallback" safety_factor: float = 0.8 auto_chunk_threshold_gb: float = 4.0 progress_reporting: bool = True min_chunk_size: int = 1000 max_chunk_size: int = 1_000_000
[docs] def __post_init__(self): """Validate configuration values.""" if not 0.1 <= self.memory_limit_gb <= 1024: raise ValueError( f"memory_limit_gb must be between 0.1 and 1024, got {self.memory_limit_gb}" ) if self.gpu_memory_fraction is not None: if not 0.0 < self.gpu_memory_fraction <= 1.0: raise ValueError( f"gpu_memory_fraction must be between 0.0 and 1.0, got {self.gpu_memory_fraction}" ) if not 0.1 <= self.safety_factor <= 1.0: raise ValueError( f"safety_factor must be between 0.1 and 1.0, got {self.safety_factor}" ) if self.out_of_memory_strategy not in ["fallback", "reduce", "error"]: raise ValueError( f"out_of_memory_strategy must be 'fallback', 'reduce', or 'error', got {self.out_of_memory_strategy}" ) if self.min_chunk_size > self.max_chunk_size: raise ValueError( f"min_chunk_size ({self.min_chunk_size}) cannot be larger than max_chunk_size ({self.max_chunk_size})" )
[docs] @dataclass(slots=True) class LargeDatasetConfig: """Configuration for large dataset processing. Attributes ---------- enable_automatic_solver_selection : bool Automatically select optimal solver based on dataset size solver_selection_thresholds : Dict[str, int] Thresholds for automatic solver selection Notes ----- As of v0.2.0, all subsampling parameters have been removed. Use streaming optimization instead for unlimited datasets. See MIGRATION_V0.2.0.md for migration instructions. """ enable_automatic_solver_selection: bool = True solver_selection_thresholds: dict[str, int] = field( default_factory=lambda: { "direct": 100_000, # Use direct solver below this size "iterative": 10_000_000, # Use iterative solver below this size "chunked": 100_000_000, # Use chunked processing above this size } )
[docs] def __post_init__(self) -> None: """Validate configuration values.""" if not self.solver_selection_thresholds: return # Empty dict means "apply no threshold overrides" — valid. required_keys = {"direct", "iterative", "chunked"} missing = required_keys - set(self.solver_selection_thresholds) if missing: raise ValueError( f"solver_selection_thresholds missing required keys: {missing}" ) thresholds = self.solver_selection_thresholds if not (thresholds["direct"] < thresholds["iterative"] < thresholds["chunked"]): raise ValueError( "solver_selection_thresholds must be monotonically increasing: " f"direct={thresholds['direct']} < iterative={thresholds['iterative']} " f"< chunked={thresholds['chunked']}" ) for key, val in thresholds.items(): if val <= 0: raise ValueError( f"solver_selection_thresholds['{key}'] must be positive, " f"got {val!r}" )
class JAXConfig:
    """Singleton configuration manager for JAX and memory settings.

    This class ensures that JAX configuration is set once and consistently
    across all NLSQ modules, avoiding duplicate configuration calls. It also
    manages memory settings and large dataset configuration.

    Every ``JAXConfig()`` call returns the same shared instance; the JAX,
    memory and large-dataset setup runs only on the first construction.
    """

    _instance: JAXConfig | None = None
    # Class-wide lock guarding singleton creation and configuration swaps.
    _lock: threading.RLock = threading.RLock()
    _x64_enabled: bool = False
    _initialized: bool = False
    _memory_config: MemoryConfig | None = None
    _large_dataset_config: LargeDatasetConfig | None = None
    _gpu_memory_configured: bool = False

    def __new__(cls) -> JAXConfig:
        """Return the shared instance, creating it on first use (thread-safe)."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def __init__(self):
        """Initialize JAX and memory configuration if not already done."""
        with self._lock:
            if not self._initialized:
                self._initialize_jax()
                self._initialize_memory_config()
                self._initialize_large_dataset_config()
                self._initialized = True

    def _initialize_jax(self):
        """Initialize JAX with default NLSQ settings.

        Honors ``NLSQ_FORCE_CPU=1`` / ``JAX_PLATFORM_NAME=cpu`` (CPU backend)
        and ``NLSQ_DISABLE_X64=1`` (skip enabling 64-bit precision), then
        configures the persistent compilation cache and GPU memory settings.
        """
        # Import here to avoid circular imports
        from jax import config

        # Force CPU backend if requested (useful for testing)
        if (
            os.getenv("NLSQ_FORCE_CPU", "0") == "1"
            or os.getenv("JAX_PLATFORM_NAME") == "cpu"
        ):
            config.update("jax_platform_name", "cpu")

        # Enable 64-bit precision by default for NLSQ
        if not self._x64_enabled and os.getenv("NLSQ_DISABLE_X64") != "1":
            config.update("jax_enable_x64", True)
            self._x64_enabled = True

        # Configure persistent compilation cache (eliminates 2-10s cold start)
        self._configure_persistent_cache(config)

        # Configure GPU memory if specified
        self._configure_gpu_memory(config)

    def _configure_persistent_cache(self, config):
        """Configure JAX persistent compilation cache.

        This enables caching of compiled functions across Python sessions,
        eliminating cold-start overhead of 2-10 seconds.

        Controlled by ``NLSQ_DISABLE_PERSISTENT_CACHE`` (``"1"`` skips it),
        ``NLSQ_JAX_CACHE_DIR`` (default ``~/.cache/nlsq/jax_cache``) and
        ``NLSQ_CACHE_MIN_COMPILE_TIME_SECS`` (default ``1``). Any failure is
        logged as a warning and otherwise ignored.
        """
        # Skip if explicitly disabled
        if os.getenv("NLSQ_DISABLE_PERSISTENT_CACHE") == "1":
            return

        # Set cache directory
        cache_dir = os.getenv(
            "NLSQ_JAX_CACHE_DIR", os.path.expanduser("~/.cache/nlsq/jax_cache")
        )

        try:
            # Create cache directory if it doesn't exist
            os.makedirs(cache_dir, exist_ok=True)

            # Enable persistent compilation cache
            config.update("jax_compilation_cache_dir", cache_dir)

            # Only cache compilations that take at least 1 second
            min_compile_time = float(os.getenv("NLSQ_CACHE_MIN_COMPILE_TIME_SECS", "1"))
            config.update(
                "jax_persistent_cache_min_compile_time_secs", min_compile_time
            )

            logging.debug(f"JAX persistent cache enabled at {cache_dir}")
        except Exception as e:
            # Non-fatal: log warning and continue without persistent cache
            logging.warning(
                f"Failed to enable JAX persistent compilation cache: {e}. "
                "Cold-start may be slower."
            )

    def _configure_gpu_memory(self, config):
        """Configure GPU memory settings via XLA environment variables.

        Sets XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION
        to prevent XLA from consuming all GPU memory. These must be set before
        JAX initializes the GPU backend, so we set them as early as possible.

        Without these defaults, XLA preallocates 75% of GPU memory at startup
        and never releases it, which combined with unbounded compilation cache
        growth can cause OOM for long-running sessions.
        """
        if self._gpu_memory_configured:
            return

        # Default: disable preallocation to allow memory to grow/shrink.
        # This prevents XLA from grabbing 75% of GPU memory at startup.
        # Users can override via XLA_PYTHON_CLIENT_PREALLOCATE=true.
        if "XLA_PYTHON_CLIENT_PREALLOCATE" not in os.environ:
            os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
            logging.debug("Set XLA_PYTHON_CLIENT_PREALLOCATE=false (default for NLSQ)")

        # Check for NLSQ-specific GPU memory fraction setting
        gpu_memory_fraction = os.getenv("NLSQ_GPU_MEMORY_FRACTION")
        if gpu_memory_fraction:
            try:
                fraction = float(gpu_memory_fraction)
                if 0.0 < fraction <= 1.0:
                    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)
                    logging.info(f"Set XLA_PYTHON_CLIENT_MEM_FRACTION={fraction}")
                else:
                    warnings.warn(
                        f"Invalid NLSQ_GPU_MEMORY_FRACTION: {gpu_memory_fraction}. Must be between 0.0 and 1.0.",
                        stacklevel=2,
                    )
            except ValueError:
                warnings.warn(
                    f"Invalid NLSQ_GPU_MEMORY_FRACTION: {gpu_memory_fraction}. Must be a number.",
                    stacklevel=2,
                )

        # Configure memory preallocation via JAX config if explicitly requested
        if os.getenv("NLSQ_DISABLE_GPU_PREALLOCATION") == "1":
            try:
                config.update("jax_preallocate_gpu_memory", False)
                logging.info("Disabled GPU memory preallocation via JAX config")
            except AttributeError:
                # JAX version may not support this option
                logging.warning(
                    "JAX version does not support jax_preallocate_gpu_memory option"
                )

        self._gpu_memory_configured = True

    def _initialize_memory_config(self):
        """Initialize memory configuration from environment variables.

        Reads NLSQ_MEMORY_LIMIT_GB, NLSQ_GPU_MEMORY_FRACTION,
        NLSQ_CHUNK_SIZE_MB, NLSQ_OOM_STRATEGY, NLSQ_SAFETY_FACTOR and
        NLSQ_DISABLE_PROGRESS_REPORTING. Values that cannot be parsed, or that
        fail ``MemoryConfig`` validation, trigger a warning and are ignored.
        Because this runs from the module-level ``JAXConfig()``, a bad
        environment value must not abort the package import.
        """
        if self._memory_config is not None:
            return

        # Collect env overrides, then construct MemoryConfig in one shot
        # (safe if MemoryConfig ever becomes frozen=True)
        overrides: dict = {}

        limit = os.getenv("NLSQ_MEMORY_LIMIT_GB")
        if limit:
            try:
                overrides["memory_limit_gb"] = float(limit)
            except ValueError:
                warnings.warn(f"Invalid NLSQ_MEMORY_LIMIT_GB: {limit}", stacklevel=2)

        fraction = os.getenv("NLSQ_GPU_MEMORY_FRACTION")
        if fraction:
            try:
                overrides["gpu_memory_fraction"] = float(fraction)
            except ValueError:
                warnings.warn(
                    f"Invalid NLSQ_GPU_MEMORY_FRACTION: {fraction}", stacklevel=2
                )

        chunk_size = os.getenv("NLSQ_CHUNK_SIZE_MB")
        if chunk_size:
            try:
                overrides["chunk_size_mb"] = int(chunk_size)
            except ValueError:
                warnings.warn(f"Invalid NLSQ_CHUNK_SIZE_MB: {chunk_size}", stacklevel=2)

        strategy = os.getenv("NLSQ_OOM_STRATEGY")
        if strategy:
            if strategy in ["fallback", "reduce", "error"]:
                overrides["out_of_memory_strategy"] = strategy
            else:
                warnings.warn(
                    f"Invalid NLSQ_OOM_STRATEGY: {strategy}. Must be 'fallback', 'reduce', or 'error'.",
                    stacklevel=2,
                )

        safety_factor = os.getenv("NLSQ_SAFETY_FACTOR")
        if safety_factor:
            try:
                overrides["safety_factor"] = float(safety_factor)
            except ValueError:
                warnings.warn(
                    f"Invalid NLSQ_SAFETY_FACTOR: {safety_factor}", stacklevel=2
                )

        if os.getenv("NLSQ_DISABLE_PROGRESS_REPORTING") == "1":
            overrides["progress_reporting"] = False

        # Validate each override on its own against the defaults. Without
        # this, an out-of-range value (e.g. NLSQ_GPU_MEMORY_FRACTION=1.5,
        # which _configure_gpu_memory only warns about) makes MemoryConfig
        # raise ValueError and breaks the import of the whole package.
        accepted: dict = {}
        for key, value in overrides.items():
            try:
                MemoryConfig(**{key: value})
            except ValueError as exc:
                warnings.warn(
                    f"Ignoring invalid environment override for {key}: {exc}",
                    stacklevel=2,
                )
            else:
                accepted[key] = value

        self._memory_config = MemoryConfig(**accepted)

    def _initialize_large_dataset_config(self):
        """Initialize large dataset configuration from environment variables.

        Only ``NLSQ_DISABLE_AUTO_SOLVER_SELECTION=1`` is recognized; it turns
        off automatic solver selection.
        """
        if self._large_dataset_config is not None:
            return

        # Load defaults
        large_dataset_config = LargeDatasetConfig()

        # Override from environment variables
        if os.getenv("NLSQ_DISABLE_AUTO_SOLVER_SELECTION") == "1":
            large_dataset_config.enable_automatic_solver_selection = False

        self._large_dataset_config = large_dataset_config

    @classmethod
    def enable_x64(cls, enable: bool = True):
        """Enable or disable 64-bit precision.

        Parameters
        ----------
        enable : bool, optional
            If True, enable 64-bit precision. If False, use 32-bit.
            Default is True.
        """
        from jax import config

        instance = cls()
        if enable and not instance._x64_enabled:
            config.update("jax_enable_x64", True)
            instance._x64_enabled = True
        elif not enable and instance._x64_enabled:
            config.update("jax_enable_x64", False)
            instance._x64_enabled = False

    @classmethod
    def is_x64_enabled(cls) -> bool:
        """Check if 64-bit precision is enabled.

        Returns
        -------
        bool
            True if 64-bit precision is enabled, False otherwise.
        """
        instance = cls()
        return instance._x64_enabled

    @classmethod
    @contextmanager
    def precision_context(cls, use_x64: bool):
        """Context manager for temporarily changing precision.

        Parameters
        ----------
        use_x64 : bool
            If True, use 64-bit precision within context.
            If False, use 32-bit precision.

        Examples
        --------
        >>> with JAXConfig.precision_context(use_x64=False):
        ...     # Code here runs with 32-bit precision
        ...     result = some_computation()
        >>> # Back to previous precision setting
        """
        instance = cls()
        original_state = instance._x64_enabled
        try:
            cls.enable_x64(use_x64)
            yield
        finally:
            cls.enable_x64(original_state)

    # Memory configuration methods

    @classmethod
    def get_memory_config(cls) -> MemoryConfig:
        """Get the current memory configuration.

        Returns
        -------
        MemoryConfig
            A copy of the current memory configuration (all of its fields
            are immutable scalars, so a shallow copy is sufficient).
        """
        instance = cls()
        if instance._memory_config is None:
            instance._initialize_memory_config()
        # Explicit validation after initialization (not assert - can be optimized away)
        if instance._memory_config is None:
            raise RuntimeError("Memory config initialization failed")
        return copy.copy(instance._memory_config)

    @classmethod
    def set_memory_config(cls, config: MemoryConfig):
        """Set the memory configuration.

        Parameters
        ----------
        config : MemoryConfig
            New memory configuration
        """
        instance = cls()
        with cls._lock:
            instance._memory_config = config
            if config.gpu_memory_fraction is not None:
                # JAX memory fraction handling varies by version and backend,
                # so the value is only stored for downstream components.
                logging.info(
                    f"Updated GPU memory fraction to {config.gpu_memory_fraction} (stored for downstream use)"
                )

    @classmethod
    def get_large_dataset_config(cls) -> LargeDatasetConfig:
        """Get the current large dataset configuration.

        Returns
        -------
        LargeDatasetConfig
            An independent copy of the current configuration. A deep copy is
            required because ``solver_selection_thresholds`` is a mutable
            dict; a shallow copy would let callers mutate global state.
        """
        instance = cls()
        if instance._large_dataset_config is None:
            instance._initialize_large_dataset_config()
        # Explicit validation after initialization (not assert - can be optimized away)
        if instance._large_dataset_config is None:
            raise RuntimeError("Large dataset config initialization failed")
        return copy.deepcopy(instance._large_dataset_config)

    @classmethod
    def set_large_dataset_config(cls, config: LargeDatasetConfig):
        """Set the large dataset configuration.

        Parameters
        ----------
        config : LargeDatasetConfig
            New large dataset configuration
        """
        instance = cls()
        with cls._lock:
            instance._large_dataset_config = config

    @classmethod
    @contextmanager
    def memory_context(cls, memory_config: MemoryConfig):
        """Context manager for temporarily changing memory configuration.

        Parameters
        ----------
        memory_config : MemoryConfig
            Temporary memory configuration

        Examples
        --------
        >>> from nlsq.config import JAXConfig, MemoryConfig
        >>> temp_config = MemoryConfig(memory_limit_gb=16.0)
        >>> with JAXConfig.memory_context(temp_config):
        ...     # Code here runs with increased memory limit
        ...     result = fit_large_dataset(func, x, y)
        >>> # Back to previous memory settings
        """
        instance = cls()
        with cls._lock:
            original_config = copy.deepcopy(instance._memory_config)
        try:
            cls.set_memory_config(memory_config)
            yield
        finally:
            if original_config is not None:
                cls.set_memory_config(original_config)
            else:
                # Re-derive from the environment, atomically w.r.t. other setters.
                with cls._lock:
                    instance._memory_config = None
                    instance._initialize_memory_config()

    @classmethod
    @contextmanager
    def large_dataset_context(cls, large_dataset_config: LargeDatasetConfig):
        """Context manager for temporarily changing large dataset configuration.

        Parameters
        ----------
        large_dataset_config : LargeDatasetConfig
            Temporary large dataset configuration
        """
        instance = cls()
        with cls._lock:
            original_config = copy.deepcopy(instance._large_dataset_config)
        try:
            cls.set_large_dataset_config(large_dataset_config)
            yield
        finally:
            if original_config is not None:
                cls.set_large_dataset_config(original_config)
            else:
                # Re-derive from the environment, atomically w.r.t. other setters.
                with cls._lock:
                    instance._large_dataset_config = None
                    instance._initialize_large_dataset_config()
# Initialize configuration on module import _config = JAXConfig() # Convenience functions
def enable_x64(enable: bool = True):
    """Switch JAX between 64-bit and 32-bit floating point precision.

    Module-level shortcut for :meth:`JAXConfig.enable_x64`.

    Parameters
    ----------
    enable : bool, optional
        True (the default) turns 64-bit precision on; False falls back to
        32-bit precision.
    """
    JAXConfig.enable_x64(enable=enable)
def is_x64_enabled() -> bool:
    """Report whether NLSQ has 64-bit precision switched on.

    Module-level shortcut for :meth:`JAXConfig.is_x64_enabled`.

    Returns
    -------
    bool
        True when 64-bit precision is active, False otherwise.
    """
    enabled = JAXConfig.is_x64_enabled()
    return enabled
def precision_context(use_x64: bool):
    """Return a context manager that temporarily overrides JAX precision.

    Module-level shortcut for :meth:`JAXConfig.precision_context`. The
    previous precision setting is restored when the ``with`` block exits.

    Parameters
    ----------
    use_x64 : bool
        True selects 64-bit precision inside the block; False selects 32-bit.

    Examples
    --------
    >>> from nlsq.config import precision_context
    >>> with precision_context(use_x64=False):
    ...     # Code here runs with 32-bit precision
    ...     result = some_computation()
    >>> # Back to previous precision setting
    """
    manager = JAXConfig.precision_context(use_x64=use_x64)
    return manager
# Memory management convenience functions
def get_memory_config() -> MemoryConfig:
    """Fetch a copy of the active memory configuration.

    Module-level shortcut for :meth:`JAXConfig.get_memory_config`.

    Returns
    -------
    MemoryConfig
        Copy of the memory configuration currently held by ``JAXConfig``.
    """
    snapshot = JAXConfig.get_memory_config()
    return snapshot
def set_memory_limits(
    memory_limit_gb: float,
    gpu_memory_fraction: float | None = None,
    safety_factor: float = 0.8,
):
    """Set memory limits for NLSQ operations.

    All other fields of the current memory configuration (chunk sizes, OOM
    strategy, progress reporting, ...) are carried over unchanged.

    Parameters
    ----------
    memory_limit_gb : float
        Maximum memory to use in GB
    gpu_memory_fraction : float, optional
        Fraction of GPU memory to use (0.0-1.0). None keeps the current value.
    safety_factor : float, optional
        Safety factor for memory calculations (default: 0.8). Note that this
        is always applied; it does not inherit the current value.

    Raises
    ------
    ValueError
        If the resulting configuration fails ``MemoryConfig`` validation.

    Examples
    --------
    >>> from nlsq.config import set_memory_limits
    >>> # Set 16GB memory limit with 80% GPU memory usage
    >>> set_memory_limits(16.0, gpu_memory_fraction=0.8)
    """
    from dataclasses import replace

    current_config = get_memory_config()
    # ``replace`` copies every field not named here and re-runs
    # MemoryConfig.__post_init__ validation. Unlike a hand-written
    # field-by-field copy, it cannot silently reset a field that was
    # forgotten or added to MemoryConfig later.
    new_config = replace(
        current_config,
        memory_limit_gb=memory_limit_gb,
        gpu_memory_fraction=(
            gpu_memory_fraction
            if gpu_memory_fraction is not None
            else current_config.gpu_memory_fraction
        ),
        safety_factor=safety_factor,
    )
    JAXConfig.set_memory_config(new_config)
def configure_for_large_datasets(
    memory_limit_gb: float = 8.0,
    enable_chunking: bool = True,
    progress_reporting: bool = True,
):
    """Apply a preset tuned for processing large datasets.

    Replaces the memory configuration and the large-dataset configuration
    with fresh instances built from the arguments below (other fields take
    their class defaults), then logs a short summary.

    Parameters
    ----------
    memory_limit_gb : float, optional
        Maximum memory to use in GB (default: 8.0).
    enable_chunking : bool, optional
        When True, automatic chunking kicks in above half of
        ``memory_limit_gb``; when False the threshold is infinite, so
        automatic chunking never triggers (default: True).
    progress_reporting : bool, optional
        Report progress during long operations (default: True).

    Notes
    -----
    All large datasets use streaming optimization for 100% data utilization.

    Examples
    --------
    >>> from nlsq.config import configure_for_large_datasets
    >>> # Configure for large datasets with 16GB memory limit
    >>> configure_for_large_datasets(
    ...     memory_limit_gb=16.0,
    ...     progress_reporting=True
    ... )
    """
    if enable_chunking:
        chunk_threshold_gb = memory_limit_gb * 0.5
    else:
        chunk_threshold_gb = float("inf")

    JAXConfig.set_memory_config(
        MemoryConfig(
            memory_limit_gb=memory_limit_gb,
            auto_chunk_threshold_gb=chunk_threshold_gb,
            progress_reporting=progress_reporting,
        )
    )
    JAXConfig.set_large_dataset_config(
        LargeDatasetConfig(enable_automatic_solver_selection=True)
    )

    chunking_state = "enabled" if enable_chunking else "disabled"
    progress_state = "enabled" if progress_reporting else "disabled"
    summary = (
        "Configured NLSQ for large datasets:",
        f" Memory limit: {memory_limit_gb} GB",
        " Streaming: enabled (always available)",
        f" Chunking: {chunking_state}",
        f" Progress reporting: {progress_state}",
    )
    for line in summary:
        logging.info(line)
def get_large_dataset_config() -> LargeDatasetConfig:
    """Fetch a copy of the active large-dataset configuration.

    Module-level shortcut for :meth:`JAXConfig.get_large_dataset_config`.

    Returns
    -------
    LargeDatasetConfig
        Copy of the configuration currently held by ``JAXConfig``.
    """
    snapshot = JAXConfig.get_large_dataset_config()
    return snapshot
def memory_context(memory_config: MemoryConfig):
    """Return a context manager that swaps in a memory configuration.

    Module-level shortcut for :meth:`JAXConfig.memory_context`. The prior
    configuration is restored when the ``with`` block exits.

    Parameters
    ----------
    memory_config : MemoryConfig
        Configuration to use inside the ``with`` block.

    Examples
    --------
    >>> from nlsq.config import memory_context, MemoryConfig
    >>> temp_config = MemoryConfig(memory_limit_gb=16.0)
    >>> with memory_context(temp_config):
    ...     # Code here runs with increased memory limit
    ...     result = fit_large_dataset(func, x, y)
    >>> # Back to previous memory settings
    """
    manager = JAXConfig.memory_context(memory_config=memory_config)
    return manager
def large_dataset_context(large_dataset_config: LargeDatasetConfig):
    """Return a context manager that swaps in a large-dataset configuration.

    Module-level shortcut for :meth:`JAXConfig.large_dataset_context`. The
    prior configuration is restored when the ``with`` block exits.

    Parameters
    ----------
    large_dataset_config : LargeDatasetConfig
        Configuration to use inside the ``with`` block.

    Examples
    --------
    >>> from nlsq.config import large_dataset_context, LargeDatasetConfig
    >>> temp_config = LargeDatasetConfig(enable_automatic_solver_selection=True)
    >>> with large_dataset_context(temp_config):
    ...     # Code here uses automatic solver selection
    ...     result = fit_large_dataset(func, x, y)
    """
    manager = JAXConfig.large_dataset_context(
        large_dataset_config=large_dataset_config
    )
    return manager
# Jacobian mode configuration functions
def get_jacobian_mode() -> tuple[str, str]:
    """Get Jacobian mode from configuration sources.

    Configuration precedence (highest to lowest):
    1. Environment variable (NLSQ_JACOBIAN_MODE)
    2. Config file (~/.nlsq/config.json, a UTF-8 JSON object)
    3. Auto-default

    Invalid values at any level produce a ``UserWarning`` and fall through
    to the next source; this function never raises for bad configuration.

    Returns
    -------
    mode : str
        Jacobian mode ('auto', 'fwd', or 'rev')
    source : str
        Source of the configuration ('environment variable', 'config file', 'auto-default')

    Examples
    --------
    >>> from nlsq.config import get_jacobian_mode
    >>> mode, source = get_jacobian_mode()
    >>> print(f"Using {mode} mode from {source}")
    Using auto mode from auto-default

    Notes
    -----
    Valid jacobian_mode values:
    - 'auto': Automatically select based on problem dimensions
    - 'fwd': Force forward-mode automatic differentiation (jacfwd)
    - 'rev': Force reverse-mode automatic differentiation (jacrev)
    """
    valid_modes = ("auto", "fwd", "rev")

    # Check environment variable (highest priority)
    env_mode = os.environ.get("NLSQ_JACOBIAN_MODE")
    if env_mode:
        if env_mode in valid_modes:
            return env_mode, "environment variable"
        # Execution continues to the config file, so say so in the warning.
        warnings.warn(
            f"Invalid NLSQ_JACOBIAN_MODE: {env_mode}. Must be 'auto', 'fwd', or 'rev'. Falling back to config file or auto-default.",
            stacklevel=2,
        )

    # Check config file
    config_path = os.path.expanduser("~/.nlsq/config.json")
    if os.path.exists(config_path):
        try:
            with open(config_path, encoding="utf-8") as f:
                file_config = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            warnings.warn(
                f"Failed to read Jacobian mode from config file: {e}. Using auto-default.",
                stacklevel=2,
            )
        else:
            if not isinstance(file_config, dict):
                # e.g. a bare number or list; indexing it would raise TypeError.
                warnings.warn(
                    f"Ignoring config file {config_path}: expected a JSON object, got {type(file_config).__name__}. Using auto-default.",
                    stacklevel=2,
                )
            elif "jacobian_mode" in file_config:
                mode = file_config["jacobian_mode"]
                if mode in valid_modes:
                    return mode, "config file"
                warnings.warn(
                    f"Invalid jacobian_mode in config file: {mode}. Must be 'auto', 'fwd', or 'rev'. Using auto-default.",
                    stacklevel=2,
                )

    # Default to auto
    return "auto", "auto-default"
def set_jacobian_mode(mode: str):
    """Select the Jacobian mode for the current process.

    The choice is stored in the ``NLSQ_JACOBIAN_MODE`` environment variable,
    so it only lasts for this process. For a persistent setting, put
    ``jacobian_mode`` in ~/.nlsq/config.json instead.

    Parameters
    ----------
    mode : str
        One of 'auto', 'fwd' or 'rev'.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the accepted values.

    Examples
    --------
    >>> from nlsq.config import set_jacobian_mode
    >>> set_jacobian_mode('rev')  # Force reverse-mode AD for all fits
    """
    accepted = ("auto", "fwd", "rev")
    if mode in accepted:
        os.environ["NLSQ_JACOBIAN_MODE"] = mode
        logging.info(f"Set Jacobian mode to '{mode}' via environment variable")
        return
    raise ValueError(
        f"Invalid jacobian_mode: {mode}. Must be 'auto', 'fwd', or 'rev'."
    )