"""Central configuration management for NLSQ package."""
from __future__ import annotations
import copy
import json
import logging
import os
import threading
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
[docs]
@dataclass(slots=True)
class MemoryConfig:
"""Configuration for memory management and GPU settings.
Attributes
----------
memory_limit_gb : float
Maximum memory limit in GB (default: 8.0)
gpu_memory_fraction : float | None
Fraction of GPU memory to use (0.0-1.0, None for automatic)
chunk_size_mb : int | None
Default chunk size in MB for data processing
out_of_memory_strategy : str
Strategy when out of memory: 'fallback', 'reduce', 'error'
safety_factor : float
Safety factor for memory calculations (0.0-1.0)
auto_chunk_threshold_gb : float
Automatically enable chunking above this memory threshold
progress_reporting : bool
Enable progress reporting for large operations
min_chunk_size : int
Minimum chunk size in data points
max_chunk_size : int
Maximum chunk size in data points
"""
memory_limit_gb: float = 8.0
gpu_memory_fraction: float | None = None
chunk_size_mb: int | None = None
out_of_memory_strategy: str = "fallback"
safety_factor: float = 0.8
auto_chunk_threshold_gb: float = 4.0
progress_reporting: bool = True
min_chunk_size: int = 1000
max_chunk_size: int = 1_000_000
[docs]
def __post_init__(self):
"""Validate configuration values."""
if not 0.1 <= self.memory_limit_gb <= 1024:
raise ValueError(
f"memory_limit_gb must be between 0.1 and 1024, got {self.memory_limit_gb}"
)
if self.gpu_memory_fraction is not None:
if not 0.0 < self.gpu_memory_fraction <= 1.0:
raise ValueError(
f"gpu_memory_fraction must be between 0.0 and 1.0, got {self.gpu_memory_fraction}"
)
if not 0.1 <= self.safety_factor <= 1.0:
raise ValueError(
f"safety_factor must be between 0.1 and 1.0, got {self.safety_factor}"
)
if self.out_of_memory_strategy not in ["fallback", "reduce", "error"]:
raise ValueError(
f"out_of_memory_strategy must be 'fallback', 'reduce', or 'error', got {self.out_of_memory_strategy}"
)
if self.min_chunk_size > self.max_chunk_size:
raise ValueError(
f"min_chunk_size ({self.min_chunk_size}) cannot be larger than max_chunk_size ({self.max_chunk_size})"
)
[docs]
@dataclass(slots=True)
class LargeDatasetConfig:
"""Configuration for large dataset processing.
Attributes
----------
enable_automatic_solver_selection : bool
Automatically select optimal solver based on dataset size
solver_selection_thresholds : Dict[str, int]
Thresholds for automatic solver selection
Notes
-----
As of v0.2.0, all subsampling parameters have been removed. Use streaming
optimization instead for unlimited datasets. See MIGRATION_V0.2.0.md for
migration instructions.
"""
enable_automatic_solver_selection: bool = True
solver_selection_thresholds: dict[str, int] = field(
default_factory=lambda: {
"direct": 100_000, # Use direct solver below this size
"iterative": 10_000_000, # Use iterative solver below this size
"chunked": 100_000_000, # Use chunked processing above this size
}
)
[docs]
def __post_init__(self) -> None:
"""Validate configuration values."""
if not self.solver_selection_thresholds:
return # Empty dict means "apply no threshold overrides" — valid.
required_keys = {"direct", "iterative", "chunked"}
missing = required_keys - set(self.solver_selection_thresholds)
if missing:
raise ValueError(
f"solver_selection_thresholds missing required keys: {missing}"
)
thresholds = self.solver_selection_thresholds
if not (thresholds["direct"] < thresholds["iterative"] < thresholds["chunked"]):
raise ValueError(
"solver_selection_thresholds must be monotonically increasing: "
f"direct={thresholds['direct']} < iterative={thresholds['iterative']} "
f"< chunked={thresholds['chunked']}"
)
for key, val in thresholds.items():
if val <= 0:
raise ValueError(
f"solver_selection_thresholds['{key}'] must be positive, "
f"got {val!r}"
)
[docs]
class JAXConfig:
    """Singleton configuration manager for JAX and memory settings.

    This class ensures that JAX configuration is set once and consistently
    across all NLSQ modules, avoiding duplicate configuration calls. It also
    manages memory settings and large dataset configuration.

    Notes
    -----
    ``_x64_enabled`` records only the precision changes made through this
    class; it is not re-read from JAX.
    """

    _instance: JAXConfig | None = None
    _lock: threading.RLock = threading.RLock()
    # Module-named logger: the ``logging.info(...)`` style root-logger helpers
    # call ``logging.basicConfig()`` when the root logger has no handlers,
    # which would configure the host application's logging at import time.
    _logger: logging.Logger = logging.getLogger(__name__)
    _x64_enabled: bool = False
    _initialized: bool = False
    _memory_config: MemoryConfig | None = None
    _large_dataset_config: LargeDatasetConfig | None = None
    _gpu_memory_configured: bool = False

    def __new__(cls) -> JAXConfig:
        """Ensure singleton pattern (thread-safe)."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def __init__(self) -> None:
        """Initialize JAX and memory configuration if not already done."""
        with self._lock:
            if not self._initialized:
                self._initialize_jax()
                self._initialize_memory_config()
                self._initialize_large_dataset_config()
                self._initialized = True

    def _initialize_jax(self) -> None:
        """Initialize JAX with default NLSQ settings.

        Honors ``NLSQ_FORCE_CPU``, ``JAX_PLATFORM_NAME`` and
        ``NLSQ_DISABLE_X64``, then configures the persistent compilation
        cache and the GPU memory environment.
        """
        # Import here to avoid circular imports
        from jax import config

        # Force CPU backend if requested (useful for testing)
        if (
            os.getenv("NLSQ_FORCE_CPU", "0") == "1"
            or os.getenv("JAX_PLATFORM_NAME") == "cpu"
        ):
            config.update("jax_platform_name", "cpu")

        # Enable 64-bit precision by default for NLSQ
        if not self._x64_enabled and os.getenv("NLSQ_DISABLE_X64") != "1":
            config.update("jax_enable_x64", True)
            self._x64_enabled = True

        # Configure persistent compilation cache (eliminates 2-10s cold start)
        self._configure_persistent_cache(config)

        # Configure GPU memory if specified
        self._configure_gpu_memory(config)

    def _configure_persistent_cache(self, config) -> None:
        """Configure JAX persistent compilation cache.

        This enables caching of compiled functions across Python sessions,
        eliminating cold-start overhead of 2-10 seconds.

        Parameters
        ----------
        config
            The ``jax.config`` object to update.

        Notes
        -----
        Skipped when ``NLSQ_DISABLE_PERSISTENT_CACHE=1``. The directory comes
        from ``NLSQ_JAX_CACHE_DIR`` (default ``~/.cache/nlsq/jax_cache``) and
        the minimum compile time from ``NLSQ_CACHE_MIN_COMPILE_TIME_SECS``
        (default 1). Any failure is logged and otherwise ignored.
        """
        # Skip if explicitly disabled
        if os.getenv("NLSQ_DISABLE_PERSISTENT_CACHE") == "1":
            return

        cache_dir = os.getenv(
            "NLSQ_JAX_CACHE_DIR", os.path.expanduser("~/.cache/nlsq/jax_cache")
        )
        try:
            # Create cache directory if it doesn't exist
            os.makedirs(cache_dir, exist_ok=True)
            # Enable persistent compilation cache
            config.update("jax_compilation_cache_dir", cache_dir)
            # Only cache compilations that take at least 1 second
            min_compile_time = float(
                os.getenv("NLSQ_CACHE_MIN_COMPILE_TIME_SECS", "1")
            )
            config.update(
                "jax_persistent_cache_min_compile_time_secs", min_compile_time
            )
            self._logger.debug("JAX persistent cache enabled at %s", cache_dir)
        except Exception as e:
            # Deliberately broad: the cache is an optimization, so any
            # failure is non-fatal. Log a warning and continue without it.
            self._logger.warning(
                "Failed to enable JAX persistent compilation cache: %s. "
                "Cold-start may be slower.",
                e,
            )

    def _configure_gpu_memory(self, config) -> None:
        """Configure GPU memory settings via XLA environment variables.

        Sets XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION
        to prevent XLA from consuming all GPU memory. These must be set before
        JAX initializes the GPU backend, so we set them as early as possible.

        Without these defaults, XLA preallocates 75% of GPU memory at startup
        and never releases it, which combined with unbounded compilation cache
        growth can cause OOM for long-running sessions.

        Parameters
        ----------
        config
            The ``jax.config`` object to update.
        """
        if self._gpu_memory_configured:
            return

        # Default: disable preallocation to allow memory to grow/shrink
        # This prevents XLA from grabbing 75% of GPU memory at startup.
        # Users can override via XLA_PYTHON_CLIENT_PREALLOCATE=true.
        if "XLA_PYTHON_CLIENT_PREALLOCATE" not in os.environ:
            os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
            self._logger.debug(
                "Set XLA_PYTHON_CLIENT_PREALLOCATE=false (default for NLSQ)"
            )

        # Check for NLSQ-specific GPU memory fraction setting
        gpu_memory_fraction = os.getenv("NLSQ_GPU_MEMORY_FRACTION")
        if gpu_memory_fraction:
            try:
                fraction = float(gpu_memory_fraction)
            except ValueError:
                warnings.warn(
                    f"Invalid NLSQ_GPU_MEMORY_FRACTION: {gpu_memory_fraction}. Must be a number.",
                    stacklevel=2,
                )
            else:
                if 0.0 < fraction <= 1.0:
                    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)
                    self._logger.info(
                        "Set XLA_PYTHON_CLIENT_MEM_FRACTION=%s", fraction
                    )
                else:
                    warnings.warn(
                        f"Invalid NLSQ_GPU_MEMORY_FRACTION: {gpu_memory_fraction}. Must be between 0.0 and 1.0.",
                        stacklevel=2,
                    )

        # Configure memory preallocation via JAX config if explicitly requested
        if os.getenv("NLSQ_DISABLE_GPU_PREALLOCATION") == "1":
            try:
                config.update("jax_preallocate_gpu_memory", False)
                self._logger.info(
                    "Disabled GPU memory preallocation via JAX config"
                )
            except AttributeError:
                # JAX version may not support this option
                self._logger.warning(
                    "JAX version does not support jax_preallocate_gpu_memory option"
                )

        self._gpu_memory_configured = True

    def _initialize_memory_config(self) -> None:
        """Initialize memory configuration from environment variables.

        Reads ``NLSQ_MEMORY_LIMIT_GB``, ``NLSQ_GPU_MEMORY_FRACTION``,
        ``NLSQ_CHUNK_SIZE_MB``, ``NLSQ_SAFETY_FACTOR``, ``NLSQ_OOM_STRATEGY``
        and ``NLSQ_DISABLE_PROGRESS_REPORTING``. A value that cannot be parsed
        or that ``MemoryConfig`` rejects as out of range produces a warning
        and is ignored, so a bad environment variable cannot make importing
        the package fail.
        """
        if self._memory_config is not None:
            return

        # Collect env overrides, then construct MemoryConfig in one shot
        # (safe if MemoryConfig ever becomes frozen=True)
        overrides: dict = {}

        numeric_sources = (
            ("NLSQ_MEMORY_LIMIT_GB", "memory_limit_gb", float),
            ("NLSQ_GPU_MEMORY_FRACTION", "gpu_memory_fraction", float),
            ("NLSQ_CHUNK_SIZE_MB", "chunk_size_mb", int),
            ("NLSQ_SAFETY_FACTOR", "safety_factor", float),
        )
        for env_name, attr_name, parse in numeric_sources:
            raw = os.getenv(env_name)
            if not raw:
                continue
            try:
                overrides[attr_name] = parse(raw)
            except ValueError:
                warnings.warn(f"Invalid {env_name}: {raw}", stacklevel=2)

        strategy = os.getenv("NLSQ_OOM_STRATEGY")
        if strategy:
            if strategy in ("fallback", "reduce", "error"):
                overrides["out_of_memory_strategy"] = strategy
            else:
                warnings.warn(
                    f"Invalid NLSQ_OOM_STRATEGY: {strategy}. Must be 'fallback', 'reduce', or 'error'.",
                    stacklevel=2,
                )

        if os.getenv("NLSQ_DISABLE_PROGRESS_REPORTING") == "1":
            overrides["progress_reporting"] = False

        # Validate each override on its own so one out-of-range value is
        # dropped with a warning instead of raising during import. The fields
        # settable from the environment are range-checked independently.
        accepted: dict = {}
        for attr_name, value in overrides.items():
            try:
                MemoryConfig(**{attr_name: value})
            except ValueError as exc:
                warnings.warn(
                    f"Ignoring {attr_name} override from environment: {exc}",
                    stacklevel=2,
                )
            else:
                accepted[attr_name] = value

        self._memory_config = MemoryConfig(**accepted)

    def _initialize_large_dataset_config(self) -> None:
        """Initialize large dataset configuration from environment variables.

        Starts from the ``LargeDatasetConfig`` defaults and turns automatic
        solver selection off when ``NLSQ_DISABLE_AUTO_SOLVER_SELECTION=1``.
        """
        if self._large_dataset_config is not None:
            return

        # Load defaults
        large_dataset_config = LargeDatasetConfig()

        # Override from environment variables
        if os.getenv("NLSQ_DISABLE_AUTO_SOLVER_SELECTION") == "1":
            large_dataset_config.enable_automatic_solver_selection = False

        self._large_dataset_config = large_dataset_config

    @classmethod
    def enable_x64(cls, enable: bool = True):
        """Enable or disable 64-bit precision.

        Parameters
        ----------
        enable : bool, optional
            If True, enable 64-bit precision. If False, use 32-bit.
            Default is True.

        Notes
        -----
        JAX is only updated when the requested state differs from the state
        recorded by this class.
        """
        from jax import config

        wanted = bool(enable)
        # Hold the (re-entrant) class lock so the check and the update of the
        # recorded state cannot interleave with another thread.
        with cls._lock:
            instance = cls()
            if wanted != instance._x64_enabled:
                config.update("jax_enable_x64", wanted)
                instance._x64_enabled = wanted

    @classmethod
    def is_x64_enabled(cls) -> bool:
        """Check if 64-bit precision is enabled.

        Returns
        -------
        bool
            True if 64-bit precision is enabled, False otherwise.
        """
        instance = cls()
        return instance._x64_enabled

    @classmethod
    @contextmanager
    def precision_context(cls, use_x64: bool):
        """Context manager for temporarily changing precision.

        Parameters
        ----------
        use_x64 : bool
            If True, use 64-bit precision within context.
            If False, use 32-bit precision.

        Examples
        --------
        >>> with JAXConfig.precision_context(use_x64=False):
        ...     # Code here runs with 32-bit precision
        ...     result = some_computation()
        >>> # Back to previous precision setting
        """
        instance = cls()
        original_state = instance._x64_enabled
        try:
            cls.enable_x64(use_x64)
            yield
        finally:
            cls.enable_x64(original_state)

    # Memory configuration methods

    @classmethod
    def get_memory_config(cls) -> MemoryConfig:
        """Get the current memory configuration.

        Returns
        -------
        MemoryConfig
            A copy of the current memory configuration; mutating it does not
            change the stored configuration.

        Raises
        ------
        RuntimeError
            If the configuration is still unset after initialization.
        """
        instance = cls()
        if instance._memory_config is None:
            instance._initialize_memory_config()
        # Explicit validation after initialization (not assert - can be optimized away)
        if instance._memory_config is None:
            raise RuntimeError("Memory config initialization failed")
        # A shallow copy is enough here: every MemoryConfig field is immutable.
        return copy.copy(instance._memory_config)

    @classmethod
    def set_memory_config(cls, config: MemoryConfig):
        """Set the memory configuration.

        Parameters
        ----------
        config : MemoryConfig
            New memory configuration

        Notes
        -----
        ``gpu_memory_fraction`` is only stored (and logged) here for use by
        downstream components; this method does not reconfigure JAX.
        """
        instance = cls()
        with cls._lock:
            instance._memory_config = config

        if config.gpu_memory_fraction is not None:
            # Note: JAX memory fraction handling varies by version and backend
            # This is stored in our config for use by downstream components
            cls._logger.info(
                "Updated GPU memory fraction to %s (stored for downstream use)",
                config.gpu_memory_fraction,
            )

    @classmethod
    def get_large_dataset_config(cls) -> LargeDatasetConfig:
        """Get the current large dataset configuration.

        Returns
        -------
        LargeDatasetConfig
            An independent copy of the current large dataset configuration.

        Raises
        ------
        RuntimeError
            If the configuration is still unset after initialization.
        """
        instance = cls()
        if instance._large_dataset_config is None:
            instance._initialize_large_dataset_config()
        # Explicit validation after initialization (not assert - can be optimized away)
        if instance._large_dataset_config is None:
            raise RuntimeError("Large dataset config initialization failed")
        # Deep copy: solver_selection_thresholds is a mutable dict, and a
        # shallow copy would let callers modify the stored thresholds.
        return copy.deepcopy(instance._large_dataset_config)

    @classmethod
    def set_large_dataset_config(cls, config: LargeDatasetConfig):
        """Set the large dataset configuration.

        Parameters
        ----------
        config : LargeDatasetConfig
            New large dataset configuration
        """
        instance = cls()
        with cls._lock:
            instance._large_dataset_config = config

    @classmethod
    @contextmanager
    def memory_context(cls, memory_config: MemoryConfig):
        """Context manager for temporarily changing memory configuration.

        Parameters
        ----------
        memory_config : MemoryConfig
            Temporary memory configuration

        Examples
        --------
        >>> from nlsq.config import JAXConfig, MemoryConfig
        >>> temp_config = MemoryConfig(memory_limit_gb=16.0)
        >>> with JAXConfig.memory_context(temp_config):
        ...     # Code here runs with increased memory limit
        ...     result = fit_large_dataset(func, x, y)
        >>> # Back to previous memory settings
        """
        instance = cls()
        with cls._lock:
            original_config = copy.deepcopy(instance._memory_config)
        try:
            cls.set_memory_config(memory_config)
            yield
        finally:
            if original_config is not None:
                cls.set_memory_config(original_config)
            else:
                # Nothing was stored on entry: rebuild from the environment.
                with cls._lock:
                    instance._memory_config = None
                    instance._initialize_memory_config()

    @classmethod
    @contextmanager
    def large_dataset_context(cls, large_dataset_config: LargeDatasetConfig):
        """Context manager for temporarily changing large dataset configuration.

        Parameters
        ----------
        large_dataset_config : LargeDatasetConfig
            Temporary large dataset configuration
        """
        instance = cls()
        with cls._lock:
            original_config = copy.deepcopy(instance._large_dataset_config)
        try:
            cls.set_large_dataset_config(large_dataset_config)
            yield
        finally:
            if original_config is not None:
                cls.set_large_dataset_config(original_config)
            else:
                # Nothing was stored on entry: rebuild from the environment.
                with cls._lock:
                    instance._large_dataset_config = None
                    instance._initialize_large_dataset_config()
# Initialize configuration on module import: creating the singleton runs the
# JAX, memory, and large-dataset initialization in JAXConfig.__init__.
_config = JAXConfig()
# Convenience functions
[docs]
def enable_x64(enable: bool = True):
    """Switch NLSQ between 64-bit and 32-bit precision.

    Module-level shortcut that forwards to ``JAXConfig.enable_x64``.

    Parameters
    ----------
    enable : bool, optional
        True (the default) turns 64-bit precision on; False selects 32-bit.
    """
    requested = enable
    JAXConfig.enable_x64(requested)
[docs]
def is_x64_enabled() -> bool:
    """Report whether 64-bit precision is currently enabled.

    Module-level shortcut that forwards to ``JAXConfig.is_x64_enabled``.

    Returns
    -------
    bool
        True when 64-bit precision is enabled, False otherwise.
    """
    enabled = JAXConfig.is_x64_enabled()
    return enabled
[docs]
def precision_context(use_x64: bool):
    """Return a context manager that temporarily changes precision.

    Module-level shortcut that forwards to ``JAXConfig.precision_context``.

    Parameters
    ----------
    use_x64 : bool
        True to run the ``with`` body in 64-bit precision, False for 32-bit.

    Examples
    --------
    >>> from nlsq.config import precision_context
    >>> with precision_context(use_x64=False):
    ...     # Code here runs with 32-bit precision
    ...     result = some_computation()
    >>> # Back to previous precision setting
    """
    manager = JAXConfig.precision_context(use_x64)
    return manager
# Memory management convenience functions
[docs]
def get_memory_config() -> MemoryConfig:
    """Fetch the memory configuration currently in effect.

    Module-level shortcut that forwards to ``JAXConfig.get_memory_config``.

    Returns
    -------
    MemoryConfig
        Current memory configuration
    """
    current = JAXConfig.get_memory_config()
    return current
[docs]
def set_memory_limits(
    memory_limit_gb: float,
    gpu_memory_fraction: float | None = None,
    safety_factor: float = 0.8,
):
    """Set memory limits for NLSQ operations.

    Builds a new ``MemoryConfig`` from the current one, replacing only the
    limits given here, and installs it via ``JAXConfig.set_memory_config``.
    Every other field of the current configuration is carried over unchanged.

    Parameters
    ----------
    memory_limit_gb : float
        Maximum memory to use in GB
    gpu_memory_fraction : float, optional
        Fraction of GPU memory to use (0.0-1.0). When None, the current
        configuration's value is kept.
    safety_factor : float, optional
        Safety factor for memory calculations (default: 0.8). This value is
        always applied, even if the current configuration holds a different
        one.

    Raises
    ------
    ValueError
        If the resulting configuration fails ``MemoryConfig`` validation.

    Examples
    --------
    >>> from nlsq.config import set_memory_limits
    >>> # Set 16GB memory limit with 80% GPU memory usage
    >>> set_memory_limits(16.0, gpu_memory_fraction=0.8)
    """
    # Local import: only ``dataclass`` and ``field`` are imported at file level.
    from dataclasses import replace

    changes: dict = {
        "memory_limit_gb": memory_limit_gb,
        "safety_factor": safety_factor,
    }
    if gpu_memory_fraction is not None:
        changes["gpu_memory_fraction"] = gpu_memory_fraction

    # replace() copies every remaining field, so fields added to MemoryConfig
    # later are preserved without having to list them here. It re-runs
    # __post_init__, so the new values are validated.
    new_config = replace(get_memory_config(), **changes)
    JAXConfig.set_memory_config(new_config)
[docs]
def get_large_dataset_config() -> LargeDatasetConfig:
    """Fetch the large dataset configuration currently in effect.

    Module-level shortcut that forwards to
    ``JAXConfig.get_large_dataset_config``.

    Returns
    -------
    LargeDatasetConfig
        Current large dataset configuration
    """
    current = JAXConfig.get_large_dataset_config()
    return current
[docs]
def memory_context(memory_config: MemoryConfig):
    """Return a context manager that temporarily swaps the memory configuration.

    Module-level shortcut that forwards to ``JAXConfig.memory_context``.

    Parameters
    ----------
    memory_config : MemoryConfig
        Configuration to use inside the ``with`` body.

    Examples
    --------
    >>> from nlsq.config import memory_context, MemoryConfig
    >>> temp_config = MemoryConfig(memory_limit_gb=16.0)
    >>> with memory_context(temp_config):
    ...     # Code here runs with increased memory limit
    ...     result = fit_large_dataset(func, x, y)
    >>> # Back to previous memory settings
    """
    manager = JAXConfig.memory_context(memory_config)
    return manager
[docs]
def large_dataset_context(large_dataset_config: LargeDatasetConfig):
    """Return a context manager that temporarily swaps the large dataset configuration.

    Module-level shortcut that forwards to ``JAXConfig.large_dataset_context``.

    Parameters
    ----------
    large_dataset_config : LargeDatasetConfig
        Configuration to use inside the ``with`` body.

    Examples
    --------
    >>> from nlsq.config import large_dataset_context, LargeDatasetConfig
    >>> temp_config = LargeDatasetConfig(enable_automatic_solver_selection=True)
    >>> with large_dataset_context(temp_config):
    ...     # Code here uses automatic solver selection
    ...     result = fit_large_dataset(func, x, y)
    """
    manager = JAXConfig.large_dataset_context(large_dataset_config)
    return manager
# Jacobian mode configuration functions
[docs]
def get_jacobian_mode() -> tuple[str, str]:
    """Get Jacobian mode from configuration sources.

    Configuration precedence (highest to lowest):

    1. Environment variable (NLSQ_JACOBIAN_MODE)
    2. Config file (~/.nlsq/config.json)
    3. Auto-default

    Returns
    -------
    mode : str
        Jacobian mode ('auto', 'fwd', or 'rev')
    source : str
        Source of the configuration ('environment variable', 'config file', 'auto-default')

    Warns
    -----
    UserWarning
        If the environment variable or the config file holds an invalid mode,
        or the config file cannot be read, decoded, or is not a JSON object.
        The function never raises for a malformed config file.

    Examples
    --------
    >>> from nlsq.config import get_jacobian_mode
    >>> mode, source = get_jacobian_mode()
    >>> print(f"Using {mode} mode from {source}")
    Using auto mode from auto-default

    Notes
    -----
    Valid jacobian_mode values:

    - 'auto': Automatically select based on problem dimensions
    - 'fwd': Force forward-mode automatic differentiation (jacfwd)
    - 'rev': Force reverse-mode automatic differentiation (jacrev)
    """
    valid_modes = ("auto", "fwd", "rev")
    default = ("auto", "auto-default")

    # Check environment variable (highest priority)
    env_mode = os.environ.get("NLSQ_JACOBIAN_MODE")
    if env_mode:
        if env_mode in valid_modes:
            return env_mode, "environment variable"
        # An invalid value is reported, then lookup continues with the config
        # file before settling on the auto default.
        warnings.warn(
            f"Invalid NLSQ_JACOBIAN_MODE: {env_mode}. Must be 'auto', 'fwd', or 'rev'. Using auto-default.",
            stacklevel=2,
        )

    # Check config file. Open it directly instead of testing for existence
    # first, so a file removed in between cannot slip past the check.
    config_path = os.path.expanduser("~/.nlsq/config.json")
    try:
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        # No config file: not an error, fall through to the default.
        return default
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        warnings.warn(
            f"Failed to read Jacobian mode from config file: {e}. Using auto-default.",
            stacklevel=2,
        )
        return default

    if not isinstance(config, dict):
        # Valid JSON whose top level is not an object (e.g. a number or a
        # string) cannot carry settings; key lookup on it would raise.
        warnings.warn(
            "Failed to read Jacobian mode from config file: expected a JSON "
            f"object, got {type(config).__name__}. Using auto-default.",
            stacklevel=2,
        )
        return default

    if "jacobian_mode" in config:
        mode = config["jacobian_mode"]
        if mode in valid_modes:
            return mode, "config file"
        warnings.warn(
            f"Invalid jacobian_mode in config file: {mode}. Must be 'auto', 'fwd', or 'rev'. Using auto-default.",
            stacklevel=2,
        )

    # Default to auto
    return default
[docs]
def set_jacobian_mode(mode: str):
    """Set Jacobian mode via environment variable.

    This sets the NLSQ_JACOBIAN_MODE environment variable for the current process.
    To persist the setting, use a config file at ~/.nlsq/config.json.

    Parameters
    ----------
    mode : str
        Jacobian mode ('auto', 'fwd', or 'rev')

    Raises
    ------
    ValueError
        If mode is not one of 'auto', 'fwd', 'rev'. The environment is left
        unchanged in that case.

    Examples
    --------
    >>> from nlsq.config import set_jacobian_mode
    >>> set_jacobian_mode('rev')  # Force reverse-mode AD for all fits
    """
    if mode not in ("auto", "fwd", "rev"):
        raise ValueError(
            f"Invalid jacobian_mode: {mode}. Must be 'auto', 'fwd', or 'rev'."
        )
    os.environ["NLSQ_JACOBIAN_MODE"] = mode
    # Log through a module-named logger: the root-level ``logging.info``
    # helper calls ``logging.basicConfig()`` when the root logger has no
    # handlers, which would configure the host application's logging.
    logging.getLogger(__name__).info(
        "Set Jacobian mode to '%s' via environment variable", mode
    )