nlsq.compilation_cache module
JIT compilation cache for optimization functions.
This module provides caching of compiled JAX functions to avoid recompilation overhead.
Phase 3 Optimizations (Task Group 9):
LRU eviction with max_cache_size parameter (default 512)
Function hash race condition fix with composite key (id(func), id(func.__code__))
- class nlsq.caching.compilation_cache.CompilationCache(enable_stats=True, max_cache_size=512)
Bases: object
Cache for JIT-compiled functions with LRU eviction.
Caches compiled versions of functions based on their signature to avoid repeated JIT compilation overhead.
Phase 3 Optimizations (2.2a, 2.1a):
Uses OrderedDict for LRU tracking with move_to_end() on hits
Evicts oldest entry with popitem(last=False) when at capacity
Uses composite key (id(func), id(func.__code__)) to prevent cache poisoning when functions are redefined with same name
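The LRU behaviour listed above can be illustrated with a minimal, standard-library sketch. The class name LRUSketch, the method get_or_build, and the hits/misses counters are hypothetical names chosen for illustration; this is not the nlsq implementation, only the OrderedDict pattern the notes describe.

```python
from collections import OrderedDict


class LRUSketch:
    """Illustrative LRU cache (hypothetical; not the nlsq class)."""

    def __init__(self, max_cache_size=512):
        self.max_cache_size = max_cache_size
        self.cache = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get_or_build(self, key, build):
        if key in self.cache:
            self.cache.move_to_end(key)  # mark as most recently used
            self.hits += 1
            return self.cache[key]
        self.misses += 1
        if len(self.cache) >= self.max_cache_size:
            self.cache.popitem(last=False)  # evict the least recently used entry
        value = build()
        self.cache[key] = value
        return value


lru = LRUSketch(max_cache_size=2)
lru.get_or_build("a", lambda: 1)
lru.get_or_build("b", lambda: 2)
lru.get_or_build("a", lambda: 1)  # hit: "a" becomes the most recent entry
lru.get_or_build("c", lambda: 3)  # at capacity: "b" is evicted
remaining_keys = list(lru.cache)
```

In this sketch the hit on "a" protects it from eviction, so "b" is the entry dropped when "c" arrives.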
Thread Safety
All public methods are guarded by a per-instance threading.Lock. The lock is held only for dict operations and released during jax.jit() compilation to avoid blocking concurrent threads.
- cache
OrderedDict mapping function signatures to compiled functions (enables LRU eviction)
- Type:
OrderedDict
- max_cache_size
Maximum number of compiled functions to cache (default 512)
- Type:
- stats
Compilation cache statistics
- Type:
- _func_hash_cache
Memoization cache for function code hashes using composite key (id(func), id(func.__code__)) for correctness in notebooks
- Type:
- __init__(enable_stats=True, max_cache_size=512)
Initialize compilation cache.
- compile(func, static_argnums=(), donate_argnums=())
Compile function with JIT and cache result.
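The Thread Safety note for this class says the lock is held only for dict operations and released during compilation. A minimal sketch of that locking pattern follows, using only the standard library. LockedCacheSketch and compile_fn are hypothetical names, and compile_fn stands in for the jax.jit() step; how nlsq handles two threads compiling the same function at once is not stated in this page, so keeping the first stored result is an assumption of the sketch.

```python
import threading


class LockedCacheSketch:
    """Illustrative cache that never holds its lock while compiling."""

    def __init__(self):
        self._lock = threading.Lock()
        self._cache = {}

    def get_or_compile(self, key, compile_fn):
        with self._lock:  # short critical section: lookup only
            cached = self._cache.get(key)
        if cached is not None:
            return cached
        compiled = compile_fn()  # slow step runs with the lock released
        with self._lock:  # short critical section: store only
            # Another thread may have stored this key meanwhile (assumption:
            # keep whichever result was stored first).
            return self._cache.setdefault(key, compiled)


cache = LockedCacheSketch()
first = cache.get_or_compile("f", lambda: object())
second = cache.get_or_compile("f", lambda: object())
```

The trade-off in this design is that two threads missing on the same key may both run the slow step; in exchange, no thread waits on the lock while another compiles.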
- get_or_compile(func, *args, static_argnums=(), **kwargs)
Get cached compiled function or compile if not cached.
- Parameters:
- Returns:
compiled_func (callable) – Compiled function
signature (str) – Function signature
- Return type:
- clear()
Clear compilation cache, function hash memoization, and reset stats.
This method clears all cached data and resets statistics counters to zero, allowing accurate hit/miss tracking after the clear operation.
- get_stats()
Get cache statistics.
- Returns:
stats – Cache hit rate and other statistics
- Return type:
- __enter__()
Context manager entry.
- __exit__(exc_type, exc_val, exc_tb)
Context manager exit.
- nlsq.caching.compilation_cache.get_global_compilation_cache()
Get or create global compilation cache (thread-safe).
Uses double-checked locking to avoid acquiring the lock on every call once the singleton has been initialized.
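As a rough illustration of double-checked locking, the sketch below creates a module-level singleton lazily. The names _global_cache, _global_cache_lock, get_global_cache_sketch, and the factory parameter are hypothetical; a plain dict stands in for the cache object.

```python
import threading

_global_cache = None
_global_cache_lock = threading.Lock()


def get_global_cache_sketch(factory=dict):
    """Return the shared instance, creating it at most once."""
    global _global_cache
    if _global_cache is None:  # first check, without the lock
        with _global_cache_lock:
            if _global_cache is None:  # second check, under the lock
                _global_cache = factory()
    return _global_cache


a = get_global_cache_sketch()
b = get_global_cache_sketch()
```

Once the instance exists, every later call returns at the first check and never touches the lock.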
- Returns:
cache – Global compilation cache instance
- Return type:
- nlsq.caching.compilation_cache.cached_jit(func=None, static_argnums=(), donate_argnums=())
Decorator for caching JIT-compiled functions.
- Parameters:
- Returns:
decorated – Decorated function with cached compilation
- Return type:
callable
Examples
>>> @cached_jit
... def my_function(x):
...     return x ** 2
>>> @cached_jit(static_argnums=(1,))
... def my_function_with_static(x, n):
...     return x ** n
- nlsq.caching.compilation_cache.clear_compilation_cache()
Clear the global compilation cache, function hash cache, and reset stats.
This function is useful in interactive environments (Jupyter notebooks, IPython, REPL) where model functions may be redefined during development. When a function is redefined with the same name but different code, the compilation cache may return stale compiled versions. Calling this function clears the cache and allows fresh compilation of redefined functions.
The function also resets all statistics counters (hits, misses, compilations) to zero, enabling accurate cache performance tracking after the clear.
Examples
In a Jupyter notebook, after redefining a model function:
>>> import jax.numpy as jnp
>>> from nlsq.caching.compilation_cache import clear_compilation_cache
>>> from nlsq import curve_fit
>>>
>>> # Initial model definition
>>> def model(x, a, b):
...     return a * jnp.exp(-b * x)
>>>
>>> # ... some fitting work ...
>>>
>>> # Redefine the model with different implementation
>>> def model(x, a, b):
...     return a * jnp.exp(-b * x) + 0.1  # Added offset
>>>
>>> # Clear cache to ensure new model is compiled
>>> clear_compilation_cache()
>>>
>>> # Now curve_fit will use the updated model
>>> popt, pcov = curve_fit(model, xdata, ydata, p0=[1.0, 0.1])
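To see why a name alone cannot identify a redefined function, the following standard-library sketch builds the composite key (id(func), id(func.__code__)) described for CompilationCache. The helper make_key is a hypothetical name, and the old function is deliberately kept alive through old_model so that both objects exist at the same time.

```python
def make_key(func):
    """Composite key of the form described for CompilationCache."""
    return (id(func), id(func.__code__))


def model(x, a, b):
    return a * x + b


old_model = model  # keep the first definition alive
old_key = make_key(old_model)


def model(x, a, b):  # same name, different code
    return a * x + b + 0.1


new_key = make_key(model)
same_name = old_model.__name__ == model.__name__
```

Python guarantees id() values are unique only among objects that exist simultaneously, so once the old definition is garbage-collected its id may be reused. That may be one reason an explicit clear is still the safe choice after redefining a model; this page does not confirm the exact cause.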
Overview
The compilation_cache module provides JIT compilation caching to avoid recompilation overhead.
Note
This module is being consolidated into the nlsq.unified_cache module. For new code, use the unified cache for better performance and features.
Key Features
JIT compilation caching for function reuse
Hash-based cache keys for function identification
Automatic cache invalidation when functions change
Memory-efficient storage with weak references
Example Usage
from nlsq.caching.compilation_cache import CompilationCache
cache = CompilationCache(max_cache_size=100)
# Cache will automatically be used by curve_fit
# to avoid recompiling the same function
See Also
nlsq.unified_cache module - Unified cache (recommended)
nlsq.smart_cache module - Smart cache with adaptive features