nlsq.unified_cache module
Unified JAX JIT compilation cache for NLSQ.
This module consolidates three independent cache implementations (compilation_cache.py, caching.py, smart_cache.py) into a single unified cache with shape-relaxed keys, comprehensive statistics tracking, and optimized memory management.
Key Features
Shape-relaxed cache keys: (func_hash, dtype, rank) instead of full shapes
Comprehensive statistics: hits, misses, compile_time_ms, hit_rate
LRU eviction with configurable maxsize
Optional two-tier caching (memory + disk)
Weak references to avoid memory leaks
Thread-safe operations via per-instance threading.Lock
Async disk writes (deferred to Phase 2)
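To make the first feature concrete, here is a minimal sketch of a shape-relaxed key. It assumes array arguments can be recognised by their .dtype and .ndim attributes, which NumPy and JAX arrays both expose. The helper relaxed_key is hypothetical and not part of the NLSQ API, and id(func) is only a stand-in for func_hash.

```python
from typing import Callable, Hashable


def relaxed_key(func: Callable, args: tuple) -> Hashable:
    """Hypothetical helper: build a (func_hash, dtypes, ranks) key.

    Array-like arguments are recognised by duck typing on .dtype and .ndim;
    other values are keyed by their Python type name with rank 0.
    Exact shapes are deliberately ignored.
    """
    # Identity-based stand-in for func_hash. A real cache must guard against
    # id reuse after the function object is garbage-collected.
    func_hash = id(func)
    dtypes = tuple(str(getattr(a, "dtype", type(a).__name__)) for a in args)
    ranks = tuple(getattr(a, "ndim", 0) for a in args)
    return (func_hash, dtypes, ranks)
```

With this key, calls on float32 vectors of length 3 and length 1000 share one cache entry, while a float32 matrix or a float64 vector gets an entry of its own.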
Design Goals
80%+ cache hit rate on typical batch fitting workflows
2-5x reduction in cold-start compile time through better cache reuse
Backward compatibility with existing cache APIs (gradual migration)
Zero breaking changes to the curve_fit API
- class nlsq.caching.unified_cache.UnifiedCache(maxsize=128, enable_stats=True, disk_cache_enabled=False)
Bases: object
Unified compilation cache merging three legacy cache patterns.
This cache provides:
- Shape-relaxed keys: cache based on (func_hash, dtype, rank) rather than exact shapes
- Comprehensive stats: hits, misses, compile_time_ms, hit_rate, cache_size
- LRU eviction when maxsize is exceeded
- Optional disk caching (two-tier architecture)
- Weak references to functions to prevent memory leaks
- Thread-safe operations via a per-instance lock
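LRU eviction can be sketched with collections.OrderedDict: every lookup moves the entry to the end, and inserting past maxsize drops the entry at the front. The class LRUStore below is a hypothetical illustration of the eviction policy only; it is not the UnifiedCache implementation and omits locking.

```python
from collections import OrderedDict
from typing import Any, Hashable


class LRUStore:
    """Hypothetical LRU map: the least recently used entry is evicted past maxsize."""

    def __init__(self, maxsize: int = 128) -> None:
        self.maxsize = maxsize
        self._data: "OrderedDict[Hashable, Any]" = OrderedDict()
        self.evictions = 0

    def get(self, key: Hashable, default: Any = None) -> Any:
        if key not in self._data:
            return default
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key: Hashable, value: Any) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        while len(self._data) > self.maxsize:
            self._data.popitem(last=False)  # drop the least recently used entry
            self.evictions += 1

    def __len__(self) -> int:
        return len(self._data)
```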
- maxsize
Maximum number of compiled functions to cache
- Type:
int
- enable_stats
Whether to track cache statistics
- Type:
bool
- disk_cache_enabled
Whether to enable disk caching (default: False for Phase 1)
- Type:
bool
Examples
>>> import jax.numpy as jnp
>>> from nlsq.caching.unified_cache import UnifiedCache
>>> cache = UnifiedCache(maxsize=128, enable_stats=True)
>>> def my_func(x, a):
...     return a * x ** 2
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> compiled = cache.get_or_compile(my_func, (x, 1.0), {}, static_argnums=(1,))
>>> result = compiled(x, 1.0)
>>> stats = cache.get_stats()
>>> print(f"Hit rate: {stats['hit_rate']:.2%}")
- __init__(maxsize=128, enable_stats=True, disk_cache_enabled=False)
Initialize unified cache.
- get_or_compile(func, args, kwargs, static_argnums=(), donate_argnums=())
Get cached compiled function or compile if not cached.
Thread-safe: the lock is held for dict operations only and released during expensive JIT compilation to avoid serializing parallel compilations. A re-check after compilation prevents duplicate stores when multiple threads miss on the same key simultaneously.
- Parameters:
func (Callable) – Function to compile
args (tuple) – Arguments to function (for signature generation)
kwargs (dict) – Keyword arguments
static_argnums (tuple of int, default=()) – Indices of static arguments for JIT
donate_argnums (tuple of int, default=()) – Indices of arguments to donate for memory efficiency
- Returns:
compiled_func – JIT-compiled function (from cache or newly compiled)
- Return type:
Callable
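The locking strategy described for get_or_compile (lock held only around dictionary operations, compilation done outside the lock, a re-check before storing) can be sketched as follows. LockedCompileCache and its compile_fn parameter are hypothetical; in a JAX setting compile_fn would typically be something like jax.jit, but it is injected here so the sketch runs without JAX, and the key is passed in directly instead of being derived from args.

```python
import threading
import time
from typing import Callable, Dict, Hashable


class LockedCompileCache:
    """Hypothetical sketch of the lock, release, re-check pattern."""

    def __init__(self, compile_fn: Callable[[Callable], Callable]) -> None:
        self._compile_fn = compile_fn  # e.g. jax.jit when JAX is available
        self._lock = threading.Lock()
        self._store: Dict[Hashable, Callable] = {}
        self.hits = 0
        self.misses = 0
        self.compile_time_ms = 0.0

    def get_or_compile(self, key: Hashable, func: Callable) -> Callable:
        # 1. Lookup under the lock: dictionary operations only.
        with self._lock:
            cached = self._store.get(key)
            if cached is not None:
                self.hits += 1
                return cached
            self.misses += 1

        # 2. Compile with the lock released so unrelated keys are not serialized.
        start = time.perf_counter()
        compiled = self._compile_fn(func)
        elapsed_ms = (time.perf_counter() - start) * 1000.0

        # 3. Re-check: another thread may have stored this key in the meantime.
        with self._lock:
            self.compile_time_ms += elapsed_ms
            existing = self._store.get(key)
            if existing is not None:
                return existing  # keep the first stored result, drop ours
            self._store[key] = compiled
            return compiled
```

If two threads miss on the same key, both compile, but only the first result is stored and both callers receive that stored object. The design accepts occasional duplicate compilation work in exchange for never holding the lock during a slow compile.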
- get_stats()
Get cache statistics.
- Returns:
stats – Cache statistics including:
- hits : int - Number of cache hits
- misses : int - Number of cache misses
- compilations : int - Number of JIT compilations performed
- evictions : int - Number of LRU evictions
- cache_size : int - Current cache size
- hit_rate : float - Cache hit rate (hits / total_requests)
- compile_time_ms : float - Total compilation time in milliseconds
- Return type:
dict
- clear()
Clear all cached compilations and reset statistics.
- __repr__()
String representation of cache.
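The statistics returned by get_stats() can be modelled with a small dataclass in which hit_rate follows the documented definition hits / total_requests. The CacheStats name is hypothetical, and reporting 0.0 before any request has been made is an assumption, since the documentation does not say how an empty cache reports its hit rate.

```python
from dataclasses import asdict, dataclass


@dataclass
class CacheStats:
    """Hypothetical counters mirroring the keys listed for get_stats()."""

    hits: int = 0
    misses: int = 0
    compilations: int = 0
    evictions: int = 0
    compile_time_ms: float = 0.0

    def as_dict(self, cache_size: int) -> dict:
        total_requests = self.hits + self.misses
        stats = asdict(self)
        stats["cache_size"] = cache_size
        # hit_rate = hits / total_requests, assumed to be 0.0 with no requests yet.
        stats["hit_rate"] = self.hits / total_requests if total_requests else 0.0
        return stats
```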
- nlsq.caching.unified_cache.get_global_cache()
Get or create global unified cache instance (thread-safe).
- Returns:
cache – Global unified cache instance
- Return type:
UnifiedCache
- nlsq.caching.unified_cache.clear_cache()
Clear the global unified cache.
- nlsq.caching.unified_cache.cached_jit(func=None, static_argnums=(), donate_argnums=())
Decorator for cached JIT compilation using unified cache.
This decorator provides automatic caching of JIT-compiled functions with shape-relaxed keys for better cache reuse.
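- Parameters:
func (Callable or None, default=None) – Function to decorate; None when the decorator is called with arguments, as in @cached_jit(static_argnums=(1,))
static_argnums (tuple of int, default=()) – Indices of static arguments for JIT
donate_argnums (tuple of int, default=()) – Indices of arguments to donate for memory efficiency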
- Returns:
decorated – Decorated function with cached compilation
- Return type:
Callable
Examples
>>> from nlsq.caching.unified_cache import cached_jit
>>> @cached_jit(static_argnums=(1,))
... def my_function(x, n):
...     return x ** n
>>> @cached_jit
... def simple_function(x):
...     return x ** 2
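The examples above use the decorator both with and without arguments. A common way to support both forms is to check whether func was supplied, as in this hypothetical cached_jit_sketch. The helper _get_or_compile is a stand-in that returns the function unchanged; a real version would consult a cache (for example something like UnifiedCache.get_or_compile) on each call.

```python
import functools
from typing import Any, Callable, Optional, Tuple


def _get_or_compile(f, args, kwargs, static_argnums, donate_argnums) -> Callable:
    """Hypothetical stand-in for a cache lookup; returns f unchanged."""
    return f


def cached_jit_sketch(
    func: Optional[Callable] = None,
    static_argnums: Tuple[int, ...] = (),
    donate_argnums: Tuple[int, ...] = (),
):
    """Decorator usable bare (@cached_jit_sketch) or called (@cached_jit_sketch(...))."""

    def decorate(f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Look up (or build) the compiled version for these arguments.
            compiled = _get_or_compile(f, args, kwargs, static_argnums, donate_argnums)
            return compiled(*args, **kwargs)

        return wrapper

    if func is not None:
        return decorate(func)  # bare form: func is the decorated function
    return decorate  # called form: return the actual decorator
```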
- nlsq.caching.unified_cache.get_cache_stats()
Get statistics from the global unified cache.
- Returns:
stats – Cache statistics
- Return type:
dict
- class nlsq.caching.unified_cache.CompilationCacheCompat(enable_stats=True)
Bases: object
Backward compatibility wrapper for the compilation_cache.py API.
This allows gradual migration of existing code using CompilationCache to the new UnifiedCache without breaking changes.
- __init__(enable_stats=True)
Initialize compatibility wrapper.
- compile(func, static_argnums=(), donate_argnums=())
Compile function with JIT and cache result (compatibility wrapper).
- get_stats()
Get cache statistics (compatibility wrapper).
- clear()
Clear cache (compatibility wrapper).
- class nlsq.caching.unified_cache.FunctionCacheCompat(maxsize=128)
Bases: object
Backward compatibility wrapper for the caching.py API.
- __init__(maxsize=128)
Initialize compatibility wrapper.
- get_function_hash(func)
Generate stable hash for a function (compatibility wrapper).
- get_stats()
Get cache statistics (compatibility wrapper).
- clear()
Clear cache (compatibility wrapper).
- property hit_rate: float
Get cache hit rate (compatibility wrapper).
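One way to produce a function hash that, unlike id() or hash(func), does not depend on object identity is to digest the function's qualified name together with its bytecode, referenced names and constants. The sketch below assumes that approach; stable_function_hash is hypothetical and not necessarily how get_function_hash works. It ignores closure contents and default argument values, and it is only reproducible for the same Python version and source, provided the constants have a deterministic repr.

```python
import hashlib
import types
from typing import Any, Callable


def _feed_code(code: types.CodeType, h: Any) -> None:
    """Add bytecode, referenced names and constants of a code object to h."""
    h.update(code.co_code)
    h.update(repr(code.co_names).encode())
    for const in code.co_consts:
        if isinstance(const, types.CodeType):
            _feed_code(const, h)  # nested function or lambda: recurse
        else:
            h.update(repr(const).encode())


def stable_function_hash(func: Callable) -> str:
    """Hypothetical: SHA-256 over qualified name, bytecode and constants."""
    h = hashlib.sha256()
    h.update(f"{func.__module__}.{func.__qualname__}".encode())
    _feed_code(func.__code__, h)
    return h.hexdigest()
```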
Overview
The unified_cache module consolidates three independent cache implementations into a single unified cache with shape-relaxed keys, comprehensive statistics tracking, and optimized memory management.
Key Features
Shape-relaxed cache keys: Cache based on (func_hash, dtype, rank) instead of full shapes
Comprehensive statistics: Track hits, misses, compile time, hit rate
LRU eviction: Configurable maxsize with automatic eviction
Two-tier caching: Optional memory + disk caching
Weak references: Prevents memory leaks
Thread-safe operations: Safe for concurrent use
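The weak-reference idea can be illustrated with weakref.WeakKeyDictionary keyed by the user function: once the function is no longer referenced elsewhere, its entries disappear automatically. WeakFunctionCache is hypothetical. This only frees memory if the stored values do not themselves hold a strong reference to the function (a wrapper built with functools.wraps, for example, keeps the original in __wrapped__), so treat it as an illustration of the mechanism rather than a complete design.

```python
import weakref
from typing import Any, Callable, Dict, Hashable


class WeakFunctionCache:
    """Hypothetical per-function store whose entries vanish with the function."""

    def __init__(self) -> None:
        # Keys (user functions) are held weakly, so the cache never keeps them alive.
        self._by_func = weakref.WeakKeyDictionary()

    def put(self, func: Callable, key: Hashable, value: Any) -> None:
        entries: Dict[Hashable, Any] = self._by_func.setdefault(func, {})
        entries[key] = value

    def get(self, func: Callable, key: Hashable) -> Any:
        return self._by_func.get(func, {}).get(key)

    def __len__(self) -> int:
        return len(self._by_func)  # number of functions still alive
```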
Performance Goals
80%+ cache hit rate on typical batch fitting workflows
2-5x reduction in cold-start compile time through better cache reuse
Backward compatibility with existing cache APIs
Zero breaking changes to the curve_fit API
Example Usage
from nlsq.unified_cache import UnifiedCache
import jax.numpy as jnp
# Create cache with statistics tracking
cache = UnifiedCache(maxsize=128, enable_stats=True)
# Define function to cache
def my_func(x, a):
    return a * x**2
# Use cache for JIT compilation
x = jnp.array([1.0, 2.0, 3.0])
compiled = cache.get_or_compile(my_func, (x, 1.0), {}, static_argnums=(1,))
result = compiled(x, 1.0)
# Check cache statistics
stats = cache.get_stats()
print(f"Hit rate: {stats['hit_rate']:.2%}")
print(f"Cache size: {stats['cache_size']}")
See Also
nlsq.compilation_cache module - Legacy compilation cache
nlsq.caching module - General caching utilities
nlsq.smart_cache module - Smart cache with adaptive features