nlsq.caching module

Function caching and JIT compilation utilities.

Caching and memory management modules.

This subpackage contains caching and memory management:

  • core: Basic caching utilities

  • smart_cache: SmartCache with intelligent invalidation

  • unified_cache: Unified cache management

  • compilation_cache: JIT compilation caching

  • memory_manager: Memory management and tracking

  • memory_pool: Memory pooling for optimization

class nlsq.caching.CompilationCache(enable_stats=True, max_cache_size=512)[source]

Bases: object

Cache for JIT-compiled functions with LRU eviction.

Caches compiled versions of functions based on their signature to avoid repeated JIT compilation overhead.

Phase 3 Optimizations (2.2a, 2.1a):

  • Uses OrderedDict for LRU tracking with move_to_end() on hits

  • Evicts oldest entry with popitem(last=False) when at capacity

  • Uses composite key (id(func), id(func.__code__)) to prevent cache poisoning when functions are redefined with same name

Thread Safety

All public methods are guarded by a per-instance threading.Lock. The lock is held only for dict operations and released during jax.jit() compilation to avoid blocking concurrent threads.
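
The LRU and locking behaviour described above can be sketched with the standard library alone. This is a minimal illustration, not the library's implementation: the class name LRUFunctionCache, the method get_or_build, and the build callback (standing in for jax.jit) are all hypothetical.

```python
import threading
from collections import OrderedDict


class LRUFunctionCache:
    """Illustrative LRU cache keyed by function identity (hypothetical class)."""

    def __init__(self, max_size=512):
        self.max_size = max_size
        self._entries = OrderedDict()
        self._lock = threading.Lock()

    @staticmethod
    def _key(func):
        # Composite key: a redefined function has a new code object, so it
        # does not collide with an older entry that shares its name.
        return (id(func), id(func.__code__))

    def get_or_build(self, func, build):
        key = self._key(func)
        with self._lock:
            if key in self._entries:
                self._entries.move_to_end(key)  # mark as most recently used
                return self._entries[key]
        # Build outside the lock so slow compilation does not block others.
        built = build(func)
        with self._lock:
            if key in self._entries:
                # Another thread finished first; keep its result.
                self._entries.move_to_end(key)
                return self._entries[key]
            self._entries[key] = built
            while len(self._entries) > self.max_size:
                self._entries.popitem(last=False)  # evict least recently used
        return built
```

Because the key is built from id() values, it is only meaningful while the caller keeps the function object alive; the sketch does not hold a reference to func itself.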

cache

OrderedDict mapping function signatures to compiled functions (enables LRU eviction)

Type:

OrderedDict

max_cache_size

Maximum number of compiled functions to cache (default 512)

Type:

int

stats

Compilation cache statistics

Type:

dict

_func_hash_cache

Memoization cache for function code hashes using composite key (id(func), id(func.__code__)) for correctness in notebooks

Type:

dict

__init__(enable_stats=True, max_cache_size=512)[source]

Initialize compilation cache.

Parameters:
  • enable_stats (bool) – Track cache statistics

  • max_cache_size (int) – Maximum number of compiled functions to cache (default 512). Increased from 256 to reduce recompilation frequency. Caps memory usage at approximately 4GB for 512 cached functions.

compile(func, static_argnums=(), donate_argnums=())[source]

Compile function with JIT and cache result.

Parameters:
  • func (callable) – Function to compile

  • static_argnums (tuple of int) – Indices of static arguments

  • donate_argnums (tuple of int) – Indices of arguments to donate

Returns:

compiled_func – JIT-compiled function (may be cached)

Return type:

callable

get_or_compile(func, *args, static_argnums=(), **kwargs)[source]

Get cached compiled function or compile if not cached.

Parameters:
  • func (callable) – Function to compile

  • args (tuple) – Arguments to function (for signature generation)

  • static_argnums (tuple of int) – Indices of static arguments

  • kwargs (dict) – Keyword arguments

Returns:

  • compiled_func (callable) – Compiled function

  • signature (str) – Function signature

Return type:

tuple[Callable, str]

clear()[source]

Clear compilation cache, function hash memoization, and reset stats.

This method clears all cached data and resets statistics counters to zero, allowing accurate hit/miss tracking after the clear operation.

get_stats()[source]

Get cache statistics.

Returns:

stats – Cache hit rate and other statistics

Return type:

dict

__enter__()[source]

Context manager entry.

__exit__(exc_type, exc_val, exc_tb)[source]

Context manager exit.

class nlsq.caching.MemoryManager(gc_threshold=0.8, safety_factor=1.2, enable_adaptive_safety=False, disable_padding=False, memory_cache_ttl=1.0, adaptive_ttl=True)[source]

Bases: object

Intelligent memory management for optimization algorithms.

This class provides:

  • Memory usage monitoring and prediction

  • Array pooling to reduce allocations with LRU eviction

  • Automatic garbage collection triggers

  • Context managers for memory-safe operations

LRU Memory Pool (Task Group 7 - 1.2a)

The memory pool uses an OrderedDict to track access order, enabling true LRU (Least Recently Used) eviction when at capacity. This improves cache utilization for frequently accessed array shapes by 5-10%.

Telemetry Circular Buffer (Task Group 9 - 1.3a)

The safety telemetry uses a deque with maxlen=1000 to prevent unbounded memory growth in multi-day optimization runs. The deque keeps the last 1000 telemetry records for adaptive safety factor calculation.
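
A bounded deque behaves as a circular buffer: once it is full, each append silently drops the oldest record. The snippet below is a small sketch of that behaviour; the record fields (step, safety_needed) are made up for illustration and are not the library's telemetry schema.

```python
from collections import deque

# Bounded history: at most 1000 records are ever retained.
telemetry = deque(maxlen=1000)

for step in range(2500):
    # Hypothetical record layout, for illustration only.
    telemetry.append({"step": step, "safety_needed": 1.0 + (step % 10) / 100})

oldest_step = telemetry[0]["step"]   # records 0..1499 have been dropped
newest_step = telemetry[-1]["step"]
```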

memory_pool

Pool of reusable arrays indexed by (shape, dtype) with LRU tracking

Type:

OrderedDict

allocation_history

History of memory allocations

Type:

list

gc_threshold

Memory usage threshold (0-1) for triggering garbage collection

Type:

float

safety_factor

Safety factor for memory predictions

Type:

float

__init__(gc_threshold=0.8, safety_factor=1.2, enable_adaptive_safety=False, disable_padding=False, memory_cache_ttl=1.0, adaptive_ttl=True)[source]

Initialize memory manager.

Parameters:
  • gc_threshold (float) – Trigger GC when memory usage exceeds this fraction (0-1)

  • safety_factor (float) – Multiply memory requirements by this factor for safety

  • enable_adaptive_safety (bool) – Enable adaptive safety factor reduction (1.2 -> 1.05 after warmup)

  • disable_padding (bool) – Disable padding/bucketing for strict memory environments (Task 5.6). When True: uses exact shapes, sets safety_factor=1.0. Use case: cloud quotas, strict memory limits.

  • memory_cache_ttl (float) – TTL in seconds for cached memory info (default: 1.0). Reduces psutil system call overhead by 90%.

  • adaptive_ttl (bool) – Enable adaptive TTL based on call frequency (default: True). High-frequency callers (>100 calls/sec) get 15s effective TTL. Medium-frequency callers (>10 calls/sec) get 10s effective TTL. Low-frequency callers use the default TTL. Reduces psutil overhead in streaming optimization by 15-20%.

get_available_memory()[source]

Get available memory in bytes.

Returns:

available – Available memory in bytes

Return type:

float

Notes

Uses TTL-based caching to reduce psutil system call overhead by 90%. When adaptive_ttl is enabled, the effective TTL is adjusted based on call frequency to further reduce overhead for streaming optimization.
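
The idea behind TTL-based caching of a system probe can be sketched as follows. This is a simplified illustration with a fixed TTL, not the library's code: TTLCachedValue, probe, and the injectable clock are hypothetical names, and the adaptive TTL adjustment is not modelled.

```python
import time


class TTLCachedValue:
    """Cache the result of an expensive probe for `ttl` seconds (hypothetical)."""

    def __init__(self, probe, ttl=1.0, clock=time.monotonic):
        self._probe = probe
        self._ttl = ttl
        self._clock = clock          # injectable so the behaviour can be tested
        self._value = None
        self._expires_at = None

    def get(self):
        now = self._clock()
        if self._expires_at is None or now >= self._expires_at:
            self._value = self._probe()          # the only place the probe runs
            self._expires_at = now + self._ttl
        return self._value
```

In practice the probe would be something like a call into psutil; any zero-argument callable works in this sketch.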

get_memory_usage_bytes()[source]

Get current memory usage in bytes.

Returns:

usage – Current memory usage in bytes

Return type:

float

Notes

Uses TTL-based caching to reduce psutil system call overhead by 90%.

get_memory_usage_fraction()[source]

Get current memory usage as fraction of total.

Returns:

fraction – Memory usage fraction (0-1)

Return type:

float

Notes

Uses TTL-based caching to reduce psutil system call overhead by 90%.

predict_memory_requirement(n_points, n_params, algorithm='trf', dtype=<class 'jax.numpy.float64'>)[source]

Predict memory requirement for optimization.

Parameters:
  • n_points (int) – Number of data points

  • n_params (int) – Number of parameters

  • algorithm (str) – Algorithm name (‘trf’, ‘lm’, ‘dogbox’)

  • dtype (jnp.dtype, optional) – Data type for computations (default: jnp.float64). Affects memory calculations: float32 uses 4 bytes, float64 uses 8 bytes.

Returns:

bytes_needed – Estimated memory requirement in bytes

Return type:

int

Notes

Memory requirements scale linearly with precision:

  • float32: 4 bytes per element (50% memory savings)

  • float64: 8 bytes per element (default, higher precision)
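
As a rough worked example of this scaling, consider only a dense Jacobian of shape (n_points, n_params). The helper below is hypothetical and deliberately ignores everything else the real estimate may include (work arrays, algorithm-specific terms, the safety factor).

```python
def jacobian_bytes(n_points, n_params, itemsize):
    """Bytes for one dense n_points x n_params Jacobian (illustrative only)."""
    return n_points * n_params * itemsize


f64 = jacobian_bytes(1_000_000, 5, 8)  # float64: 8 bytes per element
f32 = jacobian_bytes(1_000_000, 5, 4)  # float32: 4 bytes per element
```

For one million points and five parameters this gives 40,000,000 bytes in float64 and 20,000,000 bytes in float32, which is the 50% saving quoted above.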

check_memory_availability(bytes_needed)[source]

Check if enough memory is available.

Parameters:

bytes_needed (int) – Memory required in bytes

Returns:

  • available (bool) – Whether enough memory is available

  • message (str) – Descriptive message

Return type:

tuple[bool, str]

memory_guard(bytes_needed)[source]

Context manager to ensure memory availability.

Parameters:

bytes_needed (int) – Required memory in bytes

Raises:

MemoryError – If insufficient memory is available
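
The guard pattern can be sketched with contextlib. This is an assumption-laden illustration rather than the method's actual body: simple_memory_guard and the get_available callback are hypothetical, and whatever the real method does on exit (for example cleanup or garbage collection) is not shown.

```python
from contextlib import contextmanager


@contextmanager
def simple_memory_guard(bytes_needed, get_available):
    """Refuse to enter the block if memory looks insufficient (hypothetical)."""
    available = get_available()
    if bytes_needed > available:
        raise MemoryError(
            f"need {bytes_needed} bytes, only {available} available"
        )
    yield  # the guarded block runs here
```

The check happens once, before the block starts, so it protects against obviously oversized requests but not against memory consumed by other processes while the block runs.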

get_safety_telemetry()[source]

Get safety factor telemetry statistics.

Returns:

telemetry – Safety factor telemetry with:

  • current_safety_factor: Current safety factor

  • initial_safety_factor: Initial safety factor (1.2)

  • min_safety_factor: Target minimum (1.05)

  • telemetry_entries: Number of telemetry entries collected

  • p95_safety_needed: 95th percentile of safety factors needed (if data available)

  • safety_factor_history: List of safety factors over time

Return type:

dict

allocate_array(shape, dtype=<class 'numpy.float64'>, zero=True)[source]

Allocate array with memory pooling and LRU tracking.

Parameters:
  • shape (tuple) – Shape of array to allocate

  • dtype (type) – Data type of array

  • zero (bool) – Whether to zero-initialize the array

Returns:

array – Allocated array

Return type:

np.ndarray

Raises:

MemoryError – If allocation fails

Notes

Task Group 7 (1.2a): Uses LRU tracking via OrderedDict. When an array is reused from the pool, it is moved to the end (most recently used) to enable proper LRU eviction.

free_array(arr)[source]

Return array to pool for reuse.

Parameters:

arr (np.ndarray) – Array to free

Notes

Task Group 7 (1.2a): Uses LRU tracking via OrderedDict. The returned array is added/moved to the end of the pool, marking it as recently used.

clear_pool()[source]

Clear memory pool and run garbage collection.

get_memory_stats()[source]

Get memory usage statistics.

Returns:

stats – Memory statistics including current usage, peak, pool size

Return type:

dict

optimize_memory_pool(max_arrays=100)[source]

Optimize memory pool using LRU eviction.

Parameters:

max_arrays (int) – Maximum number of arrays to keep in pool

Notes

Task Group 7 (1.2a): Uses LRU eviction via popitem(last=False). Arrays are evicted in order of least recent use, keeping the most recently used arrays in the pool.

temporary_allocation(shape, dtype=<class 'numpy.float64'>)[source]

Context manager for temporary array allocation.

Parameters:
  • shape (tuple) – Shape of array

  • dtype (type) – Data type

Yields:

array (np.ndarray) – Temporary array that will be returned to pool on exit
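
The borrow-and-return pattern behind a temporary allocation can be sketched with plain bytearray buffers. BufferPool and its methods are hypothetical stand-ins; the real pool works on arrays keyed by (shape, dtype) and applies LRU eviction, which this sketch omits.

```python
from contextlib import contextmanager


class BufferPool:
    """Tiny pool of zeroed bytearrays keyed by size (hypothetical)."""

    def __init__(self):
        self._free = {}      # size -> list of idle buffers
        self.total = 0       # all allocation requests
        self.reused = 0      # requests served from the pool

    def allocate(self, size):
        self.total += 1
        idle = self._free.get(size)
        if idle:
            self.reused += 1
            buf = idle.pop()
            buf[:] = bytes(size)   # zero the recycled buffer in place
            return buf
        return bytearray(size)

    def release(self, buf):
        self._free.setdefault(len(buf), []).append(buf)

    @contextmanager
    def temporary(self, size):
        buf = self.allocate(size)
        try:
            yield buf
        finally:
            self.release(buf)      # returned to the pool even on error

    def reuse_rate(self):
        return self.reused / self.total if self.total else 0.0
```

A second request for the same size after the first block has exited gets the same underlying buffer back, zeroed.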

estimate_chunking_strategy(n_points, n_params, algorithm='trf', memory_limit_gb=None)[source]

Estimate optimal chunking strategy for large datasets.

Parameters:
  • n_points (int) – Number of data points

  • n_params (int) – Number of parameters

  • algorithm (str) – Algorithm to use

  • memory_limit_gb (float, optional) – Memory limit in GB (uses available memory if None)

Returns:

strategy – Chunking strategy with chunk_size and n_chunks

Return type:

dict
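
One simple way to derive chunk_size and n_chunks from a memory budget is shown below. It is a sketch under stated assumptions: plan_chunks and bytes_per_point are hypothetical, and the library's own estimate may weight the algorithm and safety factor differently.

```python
import math


def plan_chunks(n_points, bytes_per_point, memory_limit_bytes):
    """Split n_points into equal chunks that fit the budget (hypothetical)."""
    # Largest number of points that fits, clamped to the range [1, n_points].
    chunk_size = max(1, min(n_points, memory_limit_bytes // bytes_per_point))
    n_chunks = math.ceil(n_points / chunk_size)
    return {"chunk_size": chunk_size, "n_chunks": n_chunks}


strategy = plan_chunks(n_points=1000, bytes_per_point=8, memory_limit_bytes=2000)
```

With 8 bytes per point and a 2000-byte budget, 250 points fit at a time, so 1000 points are processed in 4 chunks.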

class nlsq.caching.MemoryPool(max_pool_size=10, enable_stats=False, enable_bucketing=True)[source]

Bases: object

Memory pool for reusable array buffers.

Pre-allocates buffers for common array shapes to avoid repeated allocations during optimization iterations.

pools

Dictionary mapping (shape, dtype) to list of available buffers

Type:

dict

allocated

Dictionary tracking allocated buffers

Type:

dict

max_pool_size

Maximum number of buffers per shape/dtype combination

Type:

int

stats

Statistics on pool usage

Type:

dict

__init__(max_pool_size=10, enable_stats=False, enable_bucketing=True)[source]

Initialize memory pool.

Parameters:
  • max_pool_size (int) – Maximum number of buffers to keep per shape/dtype

  • enable_stats (bool) – Track allocation statistics

  • enable_bucketing (bool) – Enable size-class bucketing for better reuse (Task 5.4)

allocate(shape, dtype=<class 'jax.numpy.float64'>)[source]

Allocate array from pool or create new one.

Parameters:
  • shape (tuple) – Shape of array to allocate

  • dtype (type) – Data type of array

Returns:

array – Allocated array (may be reused from pool)

Return type:

jnp.ndarray

Notes

When bucketing is enabled, arrays are pooled by size classes (1KB/10KB/100KB) for better reuse rates (Task 5.4).
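
Size-class bucketing can be pictured as rounding a request up to the nearest class, so that requests of slightly different sizes share a bucket. The sketch below assumes the three classes correspond to 1024, 10240, and 102400 bytes and that larger requests are not bucketed; both are assumptions, and size_class is a hypothetical helper.

```python
# Assumed byte values for the 1KB / 10KB / 100KB classes.
SIZE_CLASSES = (1_024, 10_240, 102_400)


def size_class(nbytes):
    """Return the smallest class that fits nbytes, or None if none does."""
    for limit in SIZE_CLASSES:
        if nbytes <= limit:
            return limit
    return None
```

Two requests of 800 and 1000 bytes both map to the 1024-byte class, so a buffer released by one can serve the other.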

release(arr)[source]

Return array to pool for reuse.

Parameters:

arr (jnp.ndarray) – Array to return to pool

Notes

When bucketing is enabled, arrays are stored in size-class buckets for better reuse (Task 5.4).

clear()[source]

Clear all pools and reset statistics.

get_stats()[source]

Get pool usage statistics.

Returns:

stats – Pool usage statistics including reuse_rate (Task 5.5)

Return type:

dict

Notes

reuse_rate = reused_allocations / total_allocations

With bucketing enabled, expect 5x higher reuse rates.

__enter__()[source]

Context manager entry.

__exit__(exc_type, exc_val, exc_tb)[source]

Context manager exit - clear pool.

class nlsq.caching.SmartCache(cache_dir='.nlsq_cache', max_memory_items=1000, disk_cache_enabled=True, enable_stats=True)[source]

Bases: object

Intelligent caching system for optimization computations.

This class provides:

  • Memory and disk caching with LRU eviction

  • Automatic cache key generation from function arguments

  • Cache persistence across sessions

  • Cache invalidation and warming strategies

Phase 3 Optimizations (3.2a):

  • Array hash optimization: uses stride-based sampling only for arrays with >10000 elements when xxhash is unavailable

  • For smaller arrays, hashes full array directly without redundant sampling, providing 15-20% improvement in cache key generation

All dict operations are protected by a per-instance threading.Lock so that concurrent threads can safely call get/set/invalidate.

cache_dir

Directory for disk cache storage

Type:

str

memory_cache

In-memory cache storage

Type:

dict

disk_cache_enabled

Whether disk caching is enabled

Type:

bool

max_memory_items

Maximum items in memory cache

Type:

int

cache_stats

Cache hit/miss statistics

Type:

dict

__init__(cache_dir='.nlsq_cache', max_memory_items=1000, disk_cache_enabled=True, enable_stats=True)[source]

Initialize smart cache.

Parameters:
  • cache_dir (str) – Directory for disk cache

  • max_memory_items (int) – Maximum items in memory cache

  • disk_cache_enabled (bool) – Enable disk caching

  • enable_stats (bool) – Track cache statistics

cache_key(*args, **kwargs)[source]

Generate cache key from arguments.

Parameters:
  • *args (tuple) – Positional arguments

  • **kwargs (dict) – Keyword arguments

Returns:

key – Hash of arguments (xxhash if available, BLAKE2b fallback)

Return type:

str

Notes

Uses xxhash (xxh64) when available for ~10x faster hashing compared to SHA256/BLAKE2b. Falls back to BLAKE2b if xxhash is not installed. All cache keys are prefixed with CACHE_VERSION to ensure old cache entries are invalidated when the hash algorithm changes.

Task 9.3 (3.2a): Array hash optimization

  • Arrays <= 10000 elements: hash full array directly (no sampling)

  • Arrays > 10000 elements: use stride-based sampling for efficiency

  • Removes redundant sampling when computing full hash in fallback path
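
The threshold rule can be sketched with hashlib.blake2b and the standard array module. This is an illustration of the idea only: hash_values, n_samples, and the choice to mix the length into the digest are assumptions, and the real key also carries a CACHE_VERSION prefix that is not shown.

```python
import hashlib
from array import array

SAMPLE_THRESHOLD = 10_000  # element count above which sampling is used


def hash_values(values, n_samples=1_000):
    """Hash a sequence of floats, sampling only large inputs (hypothetical)."""
    data = array("d", values)
    h = hashlib.blake2b(digest_size=16)
    h.update(str(len(data)).encode())      # the length always contributes
    if len(data) > SAMPLE_THRESHOLD:
        stride = max(1, len(data) // n_samples)
        data = data[::stride]              # stride-based sample
    h.update(data.tobytes())
    return h.hexdigest()
```

The trade-off is visible in the sketch: for large inputs, a change to an element that falls between sampled positions does not change the key, whereas small inputs are always hashed in full.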

get(key)[source]

Get value from cache.

Parameters:

key (str) – Cache key

Returns:

value – Cached value or None if not found

Return type:

Any or None

set(key, value)[source]

Set value in cache.

Parameters:
  • key (str) – Cache key

  • value (Any) – Value to cache

invalidate(key=None)[source]

Invalidate cache entries.

Parameters:

key (str, optional) – Specific key to invalidate, or None to clear all

get_stats()[source]

Get cache statistics.

Returns:

stats – Cache statistics including hit rate

Return type:

dict

optimize_cache()[source]

Optimize cache by removing rarely accessed items.

Computes threshold from snapshot, then re-checks live counts under lock before invalidating to avoid evicting keys that became hot between the snapshot and eviction.
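
The snapshot-then-recheck pattern can be sketched as follows. CountingCache and evict_cold are hypothetical, and using the mean access count as the threshold is an assumption made for the example; the source does not say how the threshold is computed.

```python
import threading


class CountingCache:
    """Dict cache that tracks per-key access counts (hypothetical)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._data = {}
        self._hits = {}

    def set(self, key, value):
        with self._lock:
            self._data[key] = value
            self._hits.setdefault(key, 0)

    def get(self, key):
        with self._lock:
            if key in self._data:
                self._hits[key] += 1
                return self._data[key]
            return None

    def evict_cold(self):
        with self._lock:
            snapshot = dict(self._hits)        # cheap copy taken under the lock
        if not snapshot:
            return []
        # Assumed threshold: mean access count across all keys.
        threshold = sum(snapshot.values()) / len(snapshot)
        candidates = [k for k, n in snapshot.items() if n < threshold]
        evicted = []
        with self._lock:
            for key in candidates:
                # Re-check the live count: the key may have become hot since
                # the snapshot, or may already have been removed.
                if key in self._data and self._hits.get(key, 0) < threshold:
                    del self._data[key]
                    del self._hits[key]
                    evicted.append(key)
        return evicted
```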

class nlsq.caching.TRFMemoryPool(m, n, dtype=<class 'jax.numpy.float64'>)[source]

Bases: object

Specialized memory pool for Trust Region Reflective algorithm.

Pre-allocates buffers for common TRF operations.

Parameters:
  • m (int) – Number of residuals

  • n (int) – Number of parameters

  • dtype (type) – Data type for arrays

__init__(m, n, dtype=<class 'jax.numpy.float64'>)[source]

Initialize TRF memory pool.

Parameters:
  • m (int) – Number of residuals

  • n (int) – Number of parameters

  • dtype (type) – Data type

get_jacobian_buffer()[source]

Get Jacobian buffer (m×n).

get_residual_buffer()[source]

Get residual buffer (m).

get_gradient_buffer()[source]

Get gradient buffer (n).

get_step_buffer()[source]

Get step buffer (n).

get_x_buffer()[source]

Get parameter buffer (n).

reset()[source]

Reset all buffers to zero.

nlsq.caching.cached_function(cache=None, ttl=None)[source]

Decorator for caching function results.

Parameters:
  • cache (SmartCache, optional) – Cache instance to use (creates new if None)

  • ttl (float, optional) – Time-to-live in seconds for cached values

Returns:

decorator – Decorator function

Return type:

function
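
A result-caching decorator with an optional time-to-live can be sketched in a few lines. This is not the library's decorator: ttl_cached is a hypothetical name, the store is a plain dict rather than a SmartCache, and the key requires hashable arguments.

```python
import functools
import time


def ttl_cached(ttl=None, clock=time.monotonic):
    """Cache results per argument tuple; expire after `ttl` seconds if set."""

    def decorator(func):
        store = {}  # key -> (value, time stored)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))
            now = clock()
            entry = store.get(key)
            if entry is not None:
                value, stored_at = entry
                if ttl is None or now - stored_at < ttl:
                    return value               # still fresh
            value = func(*args, **kwargs)      # miss or expired: recompute
            store[key] = (value, now)
            return value

        return wrapper

    return decorator
```

With ttl=None the entry never expires, which mirrors the optional nature of the ttl parameter documented above.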

nlsq.caching.cached_jacobian(cache=None)[source]

Decorator specifically for caching Jacobian evaluations.

Parameters:

cache (SmartCache, optional) – Cache instance to use

Returns:

decorator – Decorator function

Return type:

function

nlsq.caching.cached_jit(func=None, static_argnums=(), donate_argnums=())[source]

Decorator for caching JIT-compiled functions.

Parameters:
  • func (callable, optional) – Function to decorate

  • static_argnums (tuple of int) – Indices of static arguments

  • donate_argnums (tuple of int) – Indices of arguments to donate

Returns:

decorated – Decorated function with cached compilation

Return type:

callable

Examples

>>> @cached_jit
... def my_function(x):
...     return x ** 2
>>> @cached_jit(static_argnums=(1,))
... def my_function_with_static(x, n):
...     return x ** n

nlsq.caching.clear_all_caches()[source]

Clear all global caches.

nlsq.caching.clear_compilation_cache()[source]

Clear the global compilation cache, function hash cache, and reset stats.

This function is useful in interactive environments (Jupyter notebooks, IPython, REPL) where model functions may be redefined during development. When a function is redefined with the same name but different code, the compilation cache may return stale compiled versions. Calling this function clears the cache and allows fresh compilation of redefined functions.

The function also resets all statistics counters (hits, misses, compilations) to zero, enabling accurate cache performance tracking after the clear.

Examples

In a Jupyter notebook, after redefining a model function:

>>> import jax.numpy as jnp
>>> from nlsq.caching.compilation_cache import clear_compilation_cache
>>> from nlsq import curve_fit
>>>
>>> # Initial model definition
>>> def model(x, a, b):
...     return a * jnp.exp(-b * x)
>>>
>>> # ... some fitting work ...
>>>
>>> # Redefine the model with different implementation
>>> def model(x, a, b):
...     return a * jnp.exp(-b * x) + 0.1  # Added offset
>>>
>>> # Clear cache to ensure new model is compiled
>>> clear_compilation_cache()
>>>
>>> # Now curve_fit will use the updated model
>>> popt, pcov = curve_fit(model, xdata, ydata, p0=[1.0, 0.1])

nlsq.caching.clear_global_pool()[source]

Clear the global memory pool.

Notes

For test isolation, this resets the global pool to None, forcing fresh initialization on next access.

nlsq.caching.clear_memory_pool()[source]

Clear the global memory pool.

nlsq.caching.get_global_cache()[source]

Get global cache instance.

Returns:

cache – Global cache instance

Return type:

SmartCache

nlsq.caching.get_global_compilation_cache()[source]

Get or create global compilation cache (thread-safe).

Uses double-checked locking to avoid acquiring the lock on every call once the singleton has been initialized.

Returns:

cache – Global compilation cache instance

Return type:

CompilationCache
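
Double-checked locking for a lazily created singleton generally looks like the sketch below. The names _Registry, _instance, and get_instance are placeholders for illustration and do not reflect the module's internals.

```python
import threading


class _Registry:
    """Placeholder for the object being shared (hypothetical)."""


_instance = None
_instance_lock = threading.Lock()


def get_instance():
    global _instance
    if _instance is None:              # fast path: no lock once initialized
        with _instance_lock:
            if _instance is None:      # second check, now under the lock
                _instance = _Registry()
    return _instance
```

The first check keeps the common case lock-free; the second prevents two threads that both saw None from each creating an instance.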

nlsq.caching.get_global_pool(enable_stats=False)[source]

Get or create global memory pool.

Parameters:

enable_stats (bool) – Enable statistics tracking

Returns:

pool – Global memory pool instance

Return type:

MemoryPool

nlsq.caching.get_jit_cache()[source]

Get JIT compilation cache.

Returns:

cache – JIT compilation cache

Return type:

JITCompilationCache

nlsq.caching.get_memory_manager()[source]

Get or create global memory manager instance.

Returns:

manager – Global memory manager instance

Return type:

MemoryManager

nlsq.caching.get_memory_stats()[source]

Get memory usage statistics.

Returns:

stats – Memory statistics

Return type:

dict