nlsq.caching module
Function caching and JIT compilation utilities.
Caching and memory management modules.
This subpackage contains caching and memory management:
- core: Basic caching utilities
- smart_cache: SmartCache with intelligent invalidation
- unified_cache: Unified cache management
- compilation_cache: JIT compilation caching
- memory_manager: Memory management and tracking
- memory_pool: Memory pooling for optimization
- class nlsq.caching.CompilationCache(enable_stats=True, max_cache_size=512)[source]
Bases: object
Cache for JIT-compiled functions with LRU eviction.
Caches compiled versions of functions based on their signature to avoid repeated JIT compilation overhead.
Phase 3 Optimizations (2.2a, 2.1a):
Uses OrderedDict for LRU tracking with move_to_end() on hits
Evicts oldest entry with popitem(last=False) when at capacity
Uses composite key (id(func), id(func.__code__)) to prevent cache poisoning when functions are redefined with same name
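The three points above can be illustrated with a minimal, standard-library sketch. It is not the CompilationCache source: the class name LRUFunctionCache, the method get_or_build, and the build callback (which stands in for JIT compilation) are hypothetical names chosen for illustration.

```python
from collections import OrderedDict


class LRUFunctionCache:
    """Illustrative LRU cache keyed by function identity (hypothetical class)."""

    def __init__(self, max_size=512):
        self.max_size = max_size
        self._entries = OrderedDict()

    @staticmethod
    def make_key(func):
        # A redefined function is a new object with a new code object,
        # so it cannot collide with the stale entry of the same name.
        return (id(func), id(func.__code__))

    def get_or_build(self, func, build):
        key = self.make_key(func)
        if key in self._entries:
            self._entries.move_to_end(key)      # hit: mark as most recently used
            return self._entries[key][1]
        value = build(func)                     # miss: do the expensive work
        if len(self._entries) >= self.max_size:
            self._entries.popitem(last=False)   # evict the least recently used entry
        # Storing func keeps it alive so its id() cannot be reused while cached
        # (a choice made for this sketch).
        self._entries[key] = (func, value)
        return value
```

With max_size=2, looking up f, g, f, h in that order evicts g rather than f, because the second lookup of f moved it to the most recently used end.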
Thread Safety
All public methods are guarded by a per-instance threading.Lock. The lock is held only for dict operations and released during jax.jit() compilation to avoid blocking concurrent threads.
- cache
OrderedDict mapping function signatures to compiled functions (enables LRU eviction)
- Type:
OrderedDict
- max_cache_size
Maximum number of compiled functions to cache (default 512)
- Type:
- stats
Compilation cache statistics
- Type:
- _func_hash_cache
Memoization cache for function code hashes using composite key (id(func), id(func.__code__)) for correctness in notebooks
- Type:
- __init__(enable_stats=True, max_cache_size=512)[source]
Initialize compilation cache.
- compile(func, static_argnums=(), donate_argnums=())[source]
Compile function with JIT and cache result.
- get_or_compile(func, *args, static_argnums=(), **kwargs)[source]
Get cached compiled function or compile if not cached.
- Parameters:
- Returns:
compiled_func (callable) – Compiled function
signature (str) – Function signature
- Return type:
- clear()[source]
Clear compilation cache, function hash memoization, and reset stats.
This method clears all cached data and resets statistics counters to zero, allowing accurate hit/miss tracking after the clear operation.
- get_stats()[source]
Get cache statistics.
- Returns:
stats – Cache hit rate and other statistics
- Return type:
- __enter__()[source]
Context manager entry.
- __exit__(exc_type, exc_val, exc_tb)[source]
Context manager exit.
- class nlsq.caching.MemoryManager(gc_threshold=0.8, safety_factor=1.2, enable_adaptive_safety=False, disable_padding=False, memory_cache_ttl=1.0, adaptive_ttl=True)[source]
Bases: object
Intelligent memory management for optimization algorithms.
This class provides:
- Memory usage monitoring and prediction
- Array pooling to reduce allocations with LRU eviction
- Automatic garbage collection triggers
- Context managers for memory-safe operations
LRU Memory Pool (Task Group 7 - 1.2a)
The memory pool uses an OrderedDict to track access order, enabling true LRU (Least Recently Used) eviction when at capacity. This improves cache utilization for frequently accessed array shapes by 5-10%.
Telemetry Circular Buffer (Task Group 9 - 1.3a)
The safety telemetry uses a deque with maxlen=1000 to prevent memory leak in multi-day optimization runs. This maintains the last 1000 telemetry records for adaptive safety factor calculation.
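The bounded-buffer behavior described above is what collections.deque provides when given maxlen. The short sketch below shows it; the record fields (step, safety_factor) are hypothetical and only serve the illustration.

```python
from collections import deque

# A deque with maxlen silently drops the oldest item on each append once full.
telemetry = deque(maxlen=1000)

for step in range(2500):
    telemetry.append({"step": step, "safety_factor": 1.2})

print(len(telemetry))         # 1000
print(telemetry[0]["step"])   # 1500 (records 0..1499 were discarded)
print(telemetry[-1]["step"])  # 2499
```

Memory use stays constant however long the run lasts, at the cost of losing records older than the last 1000.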
- memory_pool
Pool of reusable arrays indexed by (shape, dtype) with LRU tracking
- Type:
OrderedDict
- allocation_history
History of memory allocations
- Type:
- gc_threshold
Memory usage threshold (0-1) for triggering garbage collection
- Type:
- safety_factor
Safety factor for memory predictions
- Type:
- __init__(gc_threshold=0.8, safety_factor=1.2, enable_adaptive_safety=False, disable_padding=False, memory_cache_ttl=1.0, adaptive_ttl=True)[source]
Initialize memory manager.
- Parameters:
gc_threshold (float) – Trigger GC when memory usage exceeds this fraction (0-1)
safety_factor (float) – Multiply memory requirements by this factor for safety
enable_adaptive_safety (bool) – Enable adaptive safety factor reduction (1.2 -> 1.05 after warmup)
disable_padding (bool) – Disable padding/bucketing for strict memory environments (Task 5.6). When True: uses exact shapes, sets safety_factor=1.0. Use case: cloud quotas, strict memory limits.
memory_cache_ttl (float) – TTL in seconds for cached memory info (default: 1.0). Reduces psutil system call overhead by 90%.
adaptive_ttl (bool) – Enable adaptive TTL based on call frequency (default: True). High-frequency callers (>100 calls/sec) get 15s effective TTL. Medium-frequency callers (>10 calls/sec) get 10s effective TTL. Low-frequency callers use the default TTL. Reduces psutil overhead in streaming optimization by 15-20%.
- get_available_memory()[source]
Get available memory in bytes.
- Returns:
available – Available memory in bytes
- Return type:
Notes
Uses TTL-based caching to reduce psutil system call overhead by 90%. When adaptive_ttl is enabled, the effective TTL is adjusted based on call frequency to further reduce overhead for streaming optimization.
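As a rough illustration of TTL-based caching, the sketch below wraps an expensive reader so that it is called at most once per TTL window. TTLCachedValue and effective_ttl are hypothetical helpers, not part of nlsq. The thresholds in effective_ttl are copied from the adaptive_ttl parameter description, but how MemoryManager measures call frequency is not shown here.

```python
import time


class TTLCachedValue:
    """Caches the result of an expensive reader for ttl seconds (hypothetical helper)."""

    def __init__(self, reader, ttl=1.0, clock=time.monotonic):
        self._reader = reader
        self._ttl = ttl
        self._clock = clock          # injectable so the behavior can be tested
        self._value = None
        self._expires_at = None

    def get(self):
        now = self._clock()
        if self._expires_at is None or now >= self._expires_at:
            self._value = self._reader()        # the expensive call, e.g. a system query
            self._expires_at = now + self._ttl
        return self._value


def effective_ttl(calls_per_sec, default_ttl=1.0):
    # Thresholds taken from the adaptive_ttl description.
    if calls_per_sec > 100:
        return 15.0
    if calls_per_sec > 10:
        return 10.0
    return default_ttl
```

Calls that arrive inside the TTL window return the stored value without touching the reader, so a tight loop pays for the system call only once per window.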
- get_memory_usage_bytes()[source]
Get current memory usage in bytes.
- Returns:
usage – Current memory usage in bytes
- Return type:
Notes
Uses TTL-based caching to reduce psutil system call overhead by 90%.
- get_memory_usage_fraction()[source]
Get current memory usage as fraction of total.
- Returns:
fraction – Memory usage fraction (0-1)
- Return type:
Notes
Uses TTL-based caching to reduce psutil system call overhead by 90%.
- predict_memory_requirement(n_points, n_params, algorithm='trf', dtype=<class 'jax.numpy.float64'>)[source]
Predict memory requirement for optimization.
- Parameters:
- Returns:
bytes_needed – Estimated memory requirement in bytes
- Return type:
Notes
Memory requirements scale linearly with precision:
- float32: 4 bytes per element (50% memory savings)
- float64: 8 bytes per element (default, higher precision)
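To make the scaling concrete, the arithmetic below counts only the bytes of a dense n_points × n_params Jacobian. It is a worked example rather than the prediction formula used by predict_memory_requirement, which may include further terms. The helper jacobian_bytes is hypothetical.

```python
BYTES_PER_ELEMENT = {"float32": 4, "float64": 8}


def jacobian_bytes(n_points, n_params, dtype="float64"):
    """Bytes for a dense n_points x n_params Jacobian only (hypothetical helper)."""
    return n_points * n_params * BYTES_PER_ELEMENT[dtype]


print(jacobian_bytes(1_000_000, 5, "float64"))  # 40000000
print(jacobian_bytes(1_000_000, 5, "float32"))  # 20000000
```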
- check_memory_availability(bytes_needed)[source]
Check if enough memory is available.
- memory_guard(bytes_needed)[source]
Context manager to ensure memory availability.
- Parameters:
bytes_needed (int) – Required memory in bytes
- Raises:
MemoryError – If insufficient memory is available
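The guard pattern can be sketched with contextlib. This is an assumption about the general shape, not the MemoryManager implementation: memory_guard_sketch and its get_available callback are hypothetical, and the real method may, for example, attempt garbage collection before raising.

```python
from contextlib import contextmanager


@contextmanager
def memory_guard_sketch(bytes_needed, get_available):
    """Refuse to enter the block when the request exceeds available memory."""
    available = get_available()
    if bytes_needed > available:
        raise MemoryError(
            f"need {bytes_needed} bytes, only {available} available"
        )
    yield  # the guarded block runs only if the check passed


# Usage: the body is skipped entirely when the check fails.
try:
    with memory_guard_sketch(2_000, get_available=lambda: 1_000):
        print("never reached")
except MemoryError as exc:
    print(exc)  # need 2000 bytes, only 1000 available
```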
- get_safety_telemetry()[source]
Get safety factor telemetry statistics.
- Returns:
telemetry – Safety factor telemetry with:
- current_safety_factor: Current safety factor
- initial_safety_factor: Initial safety factor (1.2)
- min_safety_factor: Target minimum (1.05)
- telemetry_entries: Number of telemetry entries collected
- p95_safety_needed: 95th percentile of safety factors needed (if data available)
- safety_factor_history: List of safety factors over time
- Return type:
- allocate_array(shape, dtype=<class 'numpy.float64'>, zero=True)[source]
Allocate array with memory pooling and LRU tracking.
- Parameters:
- Returns:
array – Allocated array
- Return type:
np.ndarray
- Raises:
MemoryError – If allocation fails
Notes
Task Group 7 (1.2a): Uses LRU tracking via OrderedDict. When an array is reused from the pool, it is moved to the end (most recently used) to enable proper LRU eviction.
- free_array(arr)[source]
Return array to pool for reuse.
- Parameters:
arr (np.ndarray) – Array to free
Notes
Task Group 7 (1.2a): Uses LRU tracking via OrderedDict. The returned array is added/moved to the end of the pool, marking it as recently used.
- clear_pool()[source]
Clear memory pool and run garbage collection.
- get_memory_stats()[source]
Get memory usage statistics.
- Returns:
stats – Memory statistics including current usage, peak, pool size
- Return type:
- optimize_memory_pool(max_arrays=100)[source]
Optimize memory pool using LRU eviction.
- Parameters:
max_arrays (int) – Maximum number of arrays to keep in pool
Notes
Task Group 7 (1.2a): Uses LRU eviction via popitem(last=False). Arrays are evicted in order of least recent use, keeping the most recently used arrays in the pool.
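The eviction order described in the note can be demonstrated on a plain OrderedDict. The (shape, dtype) keys follow the memory_pool description. The string values stand in for real arrays, and trim_pool is a hypothetical helper rather than the library's code.

```python
from collections import OrderedDict

pool = OrderedDict()
pool[((100,), "float64")] = "buffer A"
pool[((200,), "float64")] = "buffer B"
pool[((300,), "float64")] = "buffer C"

# Reusing buffer A marks it as the most recently used entry.
pool.move_to_end(((100,), "float64"))


def trim_pool(pool, max_arrays):
    """Evict least recently used entries until at most max_arrays remain."""
    evicted = []
    while len(pool) > max_arrays:
        key, _ = pool.popitem(last=False)  # oldest entry sits at the front
        evicted.append(key)
    return evicted


evicted = trim_pool(pool, max_arrays=2)
print(evicted)  # [((200,), 'float64')]
```

Buffer B is evicted even though A was inserted first, because A was touched more recently.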
- temporary_allocation(shape, dtype=<class 'numpy.float64'>)[source]
Context manager for temporary array allocation.
- estimate_chunking_strategy(n_points, n_params, algorithm='trf', memory_limit_gb=None)[source]
Estimate optimal chunking strategy for large datasets.
- class nlsq.caching.MemoryPool(max_pool_size=10, enable_stats=False, enable_bucketing=True)[source]
Bases: object
Memory pool for reusable array buffers.
Pre-allocates buffers for common array shapes to avoid repeated allocations during optimization iterations.
- pools
Dictionary mapping (shape, dtype) to list of available buffers
- Type:
- allocated
Dictionary tracking allocated buffers
- Type:
- max_pool_size
Maximum number of buffers per shape/dtype combination
- Type:
- stats
Statistics on pool usage
- Type:
- __init__(max_pool_size=10, enable_stats=False, enable_bucketing=True)[source]
Initialize memory pool.
- allocate(shape, dtype=<class 'jax.numpy.float64'>)[source]
Allocate array from pool or create new one.
- Parameters:
- Returns:
array – Allocated array (may be reused from pool)
- Return type:
jnp.ndarray
Notes
When bucketing is enabled, arrays are pooled by size classes (1KB/10KB/100KB) for better reuse rates (Task 5.4).
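One way to picture size-class bucketing is a function that maps a request to the smallest class that can hold it. The sketch below is an assumption throughout: the exact byte boundaries, the rounding rule, and the treatment of oversized requests in MemoryPool are not documented here, and size_class is a hypothetical helper.

```python
# Assumed boundaries for the 1KB/10KB/100KB classes, in bytes.
SIZE_CLASSES = (1_024, 10_240, 102_400)


def size_class(nbytes):
    """Return the smallest assumed class that fits nbytes, or None if too large."""
    for limit in SIZE_CLASSES:
        if nbytes <= limit:
            return limit
    return None  # in this sketch, oversized requests are not bucketed


print(size_class(800))      # 1024
print(size_class(5_000))    # 10240
print(size_class(200_000))  # None
```

Requests of 800 and 1,000 bytes land in the same class, so a buffer released by one can serve the other. That sharing is how bucketing can raise reuse rates.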
- release(arr)[source]
Return array to pool for reuse.
- Parameters:
arr (jnp.ndarray) – Array to return to pool
Notes
When bucketing is enabled, arrays are stored in size-class buckets for better reuse (Task 5.4).
- clear()[source]
Clear all pools and reset statistics.
- get_stats()[source]
Get pool usage statistics.
- Returns:
stats – Pool usage statistics including reuse_rate (Task 5.5)
- Return type:
Notes
reuse_rate = reused_allocations / total_allocations
With bucketing enabled, expect 5x higher reuse rates.
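As a worked instance of the formula, the hypothetical helper below computes the rate and guards the empty case. Whether get_stats() reports 0.0 or something else before any allocation is an assumption of this sketch.

```python
def reuse_rate(reused_allocations, total_allocations):
    """reused / total, with 0.0 for an unused pool (assumed convention)."""
    if total_allocations == 0:
        return 0.0
    return reused_allocations / total_allocations


print(reuse_rate(30, 120))  # 0.25
```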
- __enter__()[source]
Context manager entry.
- __exit__(exc_type, exc_val, exc_tb)[source]
Context manager exit - clear pool.
- class nlsq.caching.SmartCache(cache_dir='.nlsq_cache', max_memory_items=1000, disk_cache_enabled=True, enable_stats=True)[source]
Bases: object
Intelligent caching system for optimization computations.
This class provides:
Memory and disk caching with LRU eviction
Automatic cache key generation from function arguments
Cache persistence across sessions
Cache invalidation and warming strategies
Phase 3 Optimizations (3.2a):
Array hash optimization: uses stride-based sampling only for arrays with >10000 elements when xxhash is unavailable
For smaller arrays, hashes full array directly without redundant sampling, providing 15-20% improvement in cache key generation
All dict operations are protected by a per-instance threading.Lock so that concurrent threads can safely call get/set/invalidate.
- cache_dir
Directory for disk cache storage
- Type:
- memory_cache
In-memory cache storage
- Type:
- disk_cache_enabled
Whether disk caching is enabled
- Type:
- max_memory_items
Maximum items in memory cache
- Type:
- cache_stats
Cache hit/miss statistics
- Type:
- __init__(cache_dir='.nlsq_cache', max_memory_items=1000, disk_cache_enabled=True, enable_stats=True)[source]
Initialize smart cache.
- cache_key(*args, **kwargs)[source]
Generate cache key from arguments.
- Parameters:
- Returns:
key – Hash of arguments (xxhash if available, BLAKE2b fallback)
- Return type:
Notes
Uses xxhash (xxh64) when available for ~10x faster hashing compared to SHA256/BLAKE2b. Falls back to BLAKE2b if xxhash is not installed. All cache keys are prefixed with CACHE_VERSION to ensure old cache entries are invalidated when the hash algorithm changes.
Task 9.3 (3.2a): Array hash optimization
- Arrays <= 10000 elements: hash full array directly (no sampling)
- Arrays > 10000 elements: use stride-based sampling for efficiency
- Removes redundant sampling when computing full hash in fallback path
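The size-dependent strategy can be sketched with the standard library's hashlib.blake2b, which corresponds to the fallback path. This is not SmartCache's code: the function array_key, the sample budget TARGET_SAMPLES, and the value of CACHE_VERSION are hypothetical, and only the 10000-element threshold comes from the notes above.

```python
import hashlib
from array import array

CACHE_VERSION = "v1"       # hypothetical version prefix
SAMPLE_THRESHOLD = 10_000  # element count from the notes above
TARGET_SAMPLES = 1_000     # hypothetical sample budget for large arrays


def array_key(values):
    """Hash a sequence of floats, sampling by stride when it is large."""
    data = array("d", values)
    h = hashlib.blake2b(digest_size=16)
    h.update(str(len(data)).encode())      # the length always contributes
    if len(data) > SAMPLE_THRESHOLD:
        stride = len(data) // TARGET_SAMPLES
        data = data[::stride]              # keep every stride-th element only
    h.update(data.tobytes())
    return f"{CACHE_VERSION}:{h.hexdigest()}"


print(array_key([1.0, 2.0, 3.0]) == array_key([1.0, 2.0, 3.0]))  # True
print(array_key([1.0, 2.0, 3.0]) == array_key([1.0, 2.0, 4.0]))  # False
```

In this sketch, sampling trades exactness for speed: two large arrays that differ only at an unsampled index produce the same key.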
- get(key)[source]
Get value from cache.
- Parameters:
key (str) – Cache key
- Returns:
value – Cached value or None if not found
- Return type:
Any or None
- set(key, value)[source]
Set value in cache.
- Parameters:
key (str) – Cache key
value (Any) – Value to cache
- invalidate(key=None)[source]
Invalidate cache entries.
- Parameters:
key (str, optional) – Specific key to invalidate, or None to clear all
- get_stats()[source]
Get cache statistics.
- Returns:
stats – Cache statistics including hit rate
- Return type:
- optimize_cache()[source]
Optimize cache by removing rarely accessed items.
Computes threshold from snapshot, then re-checks live counts under lock before invalidating to avoid evicting keys that became hot between the snapshot and eviction.
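The snapshot-then-recheck pattern might look like the sketch below. The class AccessCountedCache and the use of the mean access count as the threshold are assumptions made for illustration. The actual threshold rule in optimize_cache is not documented here.

```python
import threading


class AccessCountedCache:
    """Toy cache that tracks per-key access counts (hypothetical class)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._data = {}
        self._hits = {}

    def set(self, key, value):
        with self._lock:
            self._data[key] = value
            self._hits.setdefault(key, 0)

    def get(self, key):
        with self._lock:
            if key in self._data:
                self._hits[key] += 1
                return self._data[key]
            return None

    def evict_cold(self):
        with self._lock:
            snapshot = dict(self._hits)        # cheap copy, lock released right after
        if not snapshot:
            return []
        # Assumed rule: keys below the mean access count are candidates.
        threshold = sum(snapshot.values()) / len(snapshot)
        candidates = [k for k, n in snapshot.items() if n < threshold]
        evicted = []
        with self._lock:
            for key in candidates:
                # Re-check the live count: the key may have become hot meanwhile.
                if key in self._data and self._hits[key] < threshold:
                    del self._data[key]
                    del self._hits[key]
                    evicted.append(key)
        return evicted
```

Computing the threshold outside the lock keeps the critical section short, and the second check prevents evicting a key that other threads accessed between the snapshot and the eviction.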
- class nlsq.caching.TRFMemoryPool(m, n, dtype=<class 'jax.numpy.float64'>)[source]
Bases: object
Specialized memory pool for Trust Region Reflective algorithm.
Pre-allocates buffers for common TRF operations.
- Parameters:
- __init__(m, n, dtype=<class 'jax.numpy.float64'>)[source]
Initialize TRF memory pool.
- get_jacobian_buffer()[source]
Get Jacobian buffer (m×n).
- get_residual_buffer()[source]
Get residual buffer (m).
- get_gradient_buffer()[source]
Get gradient buffer (n).
- get_step_buffer()[source]
Get step buffer (n).
- get_x_buffer()[source]
Get parameter buffer (n).
- reset()[source]
Reset all buffers to zero.
- nlsq.caching.cached_function(cache=None, ttl=None)[source]
Decorator for caching function results.
- Parameters:
cache (SmartCache, optional) – Cache instance to use (creates new if None)
ttl (float, optional) – Time-to-live in seconds for cached values
- Returns:
decorator – Decorator function
- Return type:
function
- nlsq.caching.cached_jacobian(cache=None)[source]
Decorator specifically for caching Jacobian evaluations.
- Parameters:
cache (SmartCache, optional) – Cache instance to use
- Returns:
decorator – Decorator function
- Return type:
function
- nlsq.caching.cached_jit(func=None, static_argnums=(), donate_argnums=())[source]
Decorator for caching JIT-compiled functions.
- Parameters:
- Returns:
decorated – Decorated function with cached compilation
- Return type:
callable
Examples
>>> @cached_jit
... def my_function(x):
...     return x ** 2
>>> @cached_jit(static_argnums=(1,))
... def my_function_with_static(x, n):
...     return x ** n
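Both decorator forms in the examples, bare and with arguments, are typically supported through a func=None default, which matches the documented signature. The sketch below shows that general Python pattern with a hypothetical decorator, dual_form_decorator, and a hypothetical label option. It does not perform any JIT compilation.

```python
import functools


def dual_form_decorator(func=None, *, label="default"):
    """Usable as @dual_form_decorator or @dual_form_decorator(label=...)."""

    def decorate(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)

        wrapper.label = label  # record the option so the effect is visible
        return wrapper

    if func is None:
        return decorate        # called with arguments: return the real decorator
    return decorate(func)      # bare form: func is the decorated function


@dual_form_decorator
def square(x):
    return x ** 2


@dual_form_decorator(label="custom")
def cube(x):
    return x ** 3


print(square(3), square.label)  # 9 default
print(cube(2), cube.label)      # 8 custom
```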
- nlsq.caching.clear_all_caches()[source]
Clear all global caches.
- nlsq.caching.clear_compilation_cache()[source]
Clear the global compilation cache, function hash cache, and reset stats.
This function is useful in interactive environments (Jupyter notebooks, IPython, REPL) where model functions may be redefined during development. When a function is redefined with the same name but different code, the compilation cache may return stale compiled versions. Calling this function clears the cache and allows fresh compilation of redefined functions.
The function also resets all statistics counters (hits, misses, compilations) to zero, enabling accurate cache performance tracking after the clear.
Examples
In a Jupyter notebook, after redefining a model function:
>>> import jax.numpy as jnp
>>> from nlsq.caching.compilation_cache import clear_compilation_cache
>>> from nlsq import curve_fit
>>>
>>> # Initial model definition
>>> def model(x, a, b):
...     return a * jnp.exp(-b * x)
>>>
>>> # ... some fitting work ...
>>>
>>> # Redefine the model with different implementation
>>> def model(x, a, b):
...     return a * jnp.exp(-b * x) + 0.1  # Added offset
>>>
>>> # Clear cache to ensure new model is compiled
>>> clear_compilation_cache()
>>>
>>> # Now curve_fit will use the updated model
>>> popt, pcov = curve_fit(model, xdata, ydata, p0=[1.0, 0.1])
- nlsq.caching.clear_global_pool()[source]
Clear the global memory pool.
Notes
For test isolation, this resets the global pool to None, forcing fresh initialization on next access.
- nlsq.caching.clear_memory_pool()[source]
Clear the global memory pool.
- nlsq.caching.get_global_cache()[source]
Get global cache instance.
- Returns:
cache – Global cache instance
- Return type:
- nlsq.caching.get_global_compilation_cache()[source]
Get or create global compilation cache (thread-safe).
Uses double-checked locking to avoid acquiring the lock on every call once the singleton has been initialized.
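Double-checked locking in Python generally takes the shape sketched below. The names _instance, _instance_lock, _Cache, and get_instance are hypothetical stand-ins, not the module's actual globals.

```python
import threading

_instance = None
_instance_lock = threading.Lock()


class _Cache:
    """Stand-in for the object being shared (hypothetical)."""


def get_instance():
    global _instance
    if _instance is None:              # first check: no lock once initialized
        with _instance_lock:
            if _instance is None:      # second check: another thread may have won
                _instance = _Cache()
    return _instance
```

After the first successful call, every later call returns on the unlocked first check, so the lock is contended only during initialization.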
- Returns:
cache – Global compilation cache instance
- Return type:
- nlsq.caching.get_global_pool(enable_stats=False)[source]
Get or create global memory pool.
- Parameters:
enable_stats (bool) – Enable statistics tracking
- Returns:
pool – Global memory pool instance
- Return type:
- nlsq.caching.get_jit_cache()[source]
Get JIT compilation cache.
- Returns:
cache – JIT compilation cache
- Return type:
JITCompilationCache
- nlsq.caching.get_memory_manager()[source]
Get or create global memory manager instance.
- Returns:
manager – Global memory manager instance
- Return type: