"""JIT compilation cache for optimization functions.
This module provides caching of compiled JAX functions to avoid
recompilation overhead.
Phase 3 Optimizations (Task Group 9):
- LRU eviction with max_cache_size parameter (default 512)
- Function hash memoization keyed on the function object itself via a WeakKeyDictionary, so a reused id() cannot return a stale hash
"""
import hashlib
import threading
import warnings
import weakref
from collections import OrderedDict
from collections.abc import Callable
from functools import wraps
import jax
import jax.numpy as jnp
[docs]
class CompilationCache:
    """Cache for JIT-compiled functions with LRU eviction.

    Compiled (``jax.jit``-wrapped) functions are stored in an
    ``OrderedDict`` so that the least recently used entry can be evicted
    once ``max_cache_size`` entries are held.

    Cache keys identify the *function*, not just its name: they include a
    short hash derived from the function's code object (bytecode,
    constants, referenced names, argument layout) and from the identity of
    any closure cells or bound ``self``. Two functions that merely share a
    name therefore never share a compiled entry, while a function that is
    re-defined with identical source and no closure still hits the cache.

    Thread Safety
    -------------
    All public methods are guarded by a per-instance ``threading.Lock``.
    The lock is held only for dict operations and released during
    ``jax.jit()`` so that concurrent threads are not blocked.

    Attributes
    ----------
    cache : OrderedDict
        Maps cache keys to compiled functions (insertion order = LRU order).
    max_cache_size : int
        Maximum number of compiled functions to keep (default 512).
    enable_stats : bool
        Whether hit/miss statistics are tracked.
    stats : dict
        Hit/miss/compilation counters; only present when
        ``enable_stats`` is true.
    _func_hash_cache : weakref.WeakKeyDictionary
        Memoized function hashes keyed on the function object itself, so
        entries disappear when the function is garbage-collected.
    """

    def __init__(self, enable_stats: bool = True, max_cache_size: int = 512):
        """Initialize compilation cache.

        Parameters
        ----------
        enable_stats : bool
            Track cache statistics.
        max_cache_size : int
            Maximum number of compiled functions to cache (default 512).
            Values below 1 are tolerated: the cache then holds at most the
            single most recent entry.
        """
        self.cache: OrderedDict[str, Callable] = OrderedDict()
        self.enable_stats = enable_stats
        self.max_cache_size = max_cache_size

        # Keyed on the function object (weakly) rather than on id(), so a
        # new object that reuses a dead function's id() cannot see its hash.
        self._func_hash_cache: weakref.WeakKeyDictionary[Callable, str] = (
            weakref.WeakKeyDictionary()
        )

        if enable_stats:
            self.stats = {
                "hits": 0,
                "misses": 0,
                "compilations": 0,
                "cache_size": 0,
            }

        self._lock = threading.Lock()

    @classmethod
    def _feed_code_object(cls, digest, code) -> None:
        """Feed everything that defines a code object's behavior into *digest*.

        ``co_code`` alone is not enough: constants and global/attribute
        names live in side tables, so ``x + 0.1`` and ``x + 0.2`` share the
        same bytecode. Nested code objects (inner functions, lambdas,
        comprehensions) are hashed recursively instead of via ``repr``,
        whose output embeds a memory address.
        """
        digest.update(code.co_code)
        layout = (
            code.co_argcount,
            code.co_kwonlyargcount,
            code.co_varnames,
            code.co_names,
        )
        digest.update(repr(layout).encode())
        for const in code.co_consts:
            digest.update(b"\x00")
            if isinstance(const, type(code)):
                cls._feed_code_object(digest, const)
            else:
                digest.update(repr(const).encode())

    @classmethod
    def _hash_callable(cls, func: Callable) -> str:
        """Compute the (un-memoized) 8-character hash for *func*."""
        code = getattr(func, "__code__", None)
        if code is not None:
            try:
                digest = hashlib.sha256()
                cls._feed_code_object(digest, code)
                # Closures created by the same factory share one code
                # object; the cell identities tell them apart. The cached
                # jit wrapper is expected to keep ``func`` (and thus its
                # cells) alive while the entry exists, so ids stay unique.
                for cell in getattr(func, "__closure__", None) or ():
                    digest.update(f"cell:{id(cell)}".encode())
                # Bound methods of different instances share code as well.
                bound_to = getattr(func, "__self__", None)
                if bound_to is not None:
                    digest.update(f"self:{id(bound_to)}".encode())
                # NOTE(review): differing ``__defaults__`` are not part of
                # the hash; confirm whether callers rely on that.
                return digest.hexdigest()[:8]
            except (AttributeError, TypeError):
                pass  # exotic ``__code__``: fall through to identity hash

        # No usable code object (functools.partial, callable instances,
        # builtins, ...). Hashing b"" here would give every such callable
        # the same key, so fall back to the object's identity.
        name = getattr(
            func, "__qualname__", getattr(func, "__name__", type(func).__name__)
        )
        return hashlib.sha256(f"{name}:{id(func)}".encode()).hexdigest()[:8]

    def _get_function_code_hash(self, func: Callable) -> str:
        """Get memoized hash identifying a function's code.

        The memo lives in a ``WeakKeyDictionary`` keyed on the function
        object and is accessed under ``self._lock``. Callables that cannot
        be weak-referenced or hashed are simply not memoized.

        Note: this method acquires ``self._lock``; do not call it while
        already holding the lock.

        Parameters
        ----------
        func : callable
            Function to hash

        Returns
        -------
        hash : str
            First 8 hex characters of a SHA256 digest
        """
        try:
            with self._lock:
                # Single ``get`` instead of ``in`` + ``[]``: the entry can
                # vanish between the two if GC runs.
                memoized = self._func_hash_cache.get(func)
        except TypeError:
            memoized = None  # not weak-referenceable / not hashable
        if memoized is not None:
            return memoized

        # Computed outside the lock: touches no shared state.
        code_hash = self._hash_callable(func)

        try:
            with self._lock:
                self._func_hash_cache[func] = code_hash
        except TypeError:
            pass  # cannot be memoized; the hash itself is still valid
        return code_hash

    def _get_function_signature(self, func: Callable, *args, **kwargs) -> str:
        """Generate a readable signature for a function and its arguments.

        Parameters
        ----------
        func : callable
            Function to generate signature for
        args : tuple
            Positional arguments
        kwargs : dict
            Keyword arguments

        Returns
        -------
        signature : str
            ``<name>_<arg descriptors>``; collapsed to ``<name>_<hash>``
            when longer than 200 characters.
        """
        try:
            func_name = getattr(func, "__name__", "unknown")

            arg_info = []
            for arg in args:
                if isinstance(arg, jax.Array):
                    arg_info.append(f"{arg.shape}_{arg.dtype}")
                elif isinstance(arg, (int, float, str, bool)):
                    arg_info.append(f"{type(arg).__name__}_{arg}")
                else:
                    arg_info.append(str(type(arg).__name__))

            # Only primitive keyword values are part of the signature.
            arg_info.extend(
                f"{k}={v}"
                for k, v in sorted(kwargs.items())
                if isinstance(v, (int, float, str, bool))
            )

            sig_str = f"{func_name}_{'_'.join(arg_info)}"
            if len(sig_str) > 200:
                sig_hash = hashlib.sha256(sig_str.encode()).hexdigest()[:16]
                sig_str = f"{func_name}_{sig_hash}"
            return sig_str
        except (AttributeError, TypeError) as e:
            warnings.warn(f"Could not generate function signature: {e}")
            return f"fallback_{id(func)}"

    def _evict_if_at_capacity(self):
        """Evict least recently used entries until there is room for one more.

        Note: Caller must hold ``self._lock``.
        """
        # ``self.cache`` guard: with max_cache_size <= 0 the size test is
        # true even for an empty cache, and popitem() would raise KeyError.
        while self.cache and len(self.cache) >= self.max_cache_size:
            self.cache.popitem(last=False)

    def _lookup_or_jit(self, key: str, func: Callable, **jit_options) -> Callable:
        """Return the cached entry for *key*, compiling *func* on a miss.

        ``jit_options`` are forwarded verbatim to ``jax.jit``. The lock is
        released while ``jax.jit`` runs; the cache is re-checked afterwards
        so that a concurrent winner's entry is reused.
        """
        with self._lock:
            if key in self.cache:
                if self.enable_stats:
                    self.stats["hits"] += 1
                self.cache.move_to_end(key)  # mark as most recently used
                return self.cache[key]
            if self.enable_stats:
                self.stats["misses"] += 1
                self.stats["compilations"] += 1

        compiled_func = jax.jit(func, **jit_options)

        with self._lock:
            if key in self.cache:
                self.cache.move_to_end(key)
                return self.cache[key]
            self._evict_if_at_capacity()
            self.cache[key] = compiled_func
            if self.enable_stats:
                self.stats["cache_size"] = len(self.cache)
        return compiled_func

    def compile(
        self,
        func: Callable,
        static_argnums: tuple[int, ...] = (),
        donate_argnums: tuple[int, ...] = (),
    ) -> Callable:
        """Compile function with JIT and cache result.

        Parameters
        ----------
        func : callable
            Function to compile
        static_argnums : tuple of int
            Indices of static arguments
        donate_argnums : tuple of int
            Indices of arguments to donate

        Returns
        -------
        compiled_func : callable
            JIT-compiled function (may be cached)
        """
        try:
            code_hash = self._get_function_code_hash(func)
            func_name = getattr(func, "__name__", "unknown")
            cache_key = f"{func_name}_{code_hash}_s{static_argnums}_d{donate_argnums}"
        except (AttributeError, TypeError):
            cache_key = f"{id(func)}_s{static_argnums}_d{donate_argnums}"

        return self._lookup_or_jit(
            cache_key,
            func,
            static_argnums=static_argnums,
            donate_argnums=donate_argnums,
        )

    def get_or_compile(
        self, func: Callable, *args, static_argnums: tuple[int, ...] = (), **kwargs
    ) -> tuple[Callable, str]:
        """Get cached compiled function or compile if not cached.

        Parameters
        ----------
        func : callable
            Function to compile
        args : tuple
            Arguments to function (for signature generation)
        static_argnums : tuple of int
            Indices of static arguments
        kwargs : dict
            Keyword arguments

        Returns
        -------
        compiled_func : callable
            Compiled function
        signature : str
            Function signature
        """
        sig = self._get_function_signature(func, *args, **kwargs)
        # The signature only carries the function *name*; the code hash is
        # what keeps two different functions with the same name (e.g. two
        # lambdas) from sharing a compiled entry.
        code_hash = self._get_function_code_hash(func)
        full_key = f"{sig}_{code_hash}_s{static_argnums}"

        compiled_func = self._lookup_or_jit(
            full_key, func, static_argnums=static_argnums
        )
        return compiled_func, sig

    def clear(self):
        """Clear compilation cache, function hash memoization, and reset stats.

        All counters are reset to zero so that hit/miss tracking starts
        fresh after the clear.
        """
        with self._lock:
            self.cache.clear()
            self._func_hash_cache.clear()
            if self.enable_stats:
                for counter in ("hits", "misses", "compilations", "cache_size"):
                    self.stats[counter] = 0

    def get_stats(self) -> dict:
        """Get cache statistics.

        Returns
        -------
        stats : dict
            The raw counters plus ``hit_rate``, ``total_requests`` and
            ``max_cache_size``; ``{"enabled": False}`` when statistics
            are disabled.
        """
        if not self.enable_stats:
            return {"enabled": False}

        with self._lock:
            total_requests = self.stats["hits"] + self.stats["misses"]
            hit_rate = (
                self.stats["hits"] / total_requests if total_requests > 0 else 0.0
            )
            return {
                **self.stats,
                "hit_rate": hit_rate,
                "total_requests": total_requests,
                "max_cache_size": self.max_cache_size,
            }

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit; exceptions are not suppressed."""
        return False
# Global compilation cache singleton and the lock that guards its creation.
# The instance starts as None and is built lazily by
# get_global_compilation_cache(), which takes the lock only while the
# singleton has not been created yet.
_global_compilation_cache: CompilationCache | None = None
_global_compilation_cache_lock = threading.Lock()
[docs]
def get_global_compilation_cache() -> CompilationCache:
    """Return the process-wide compilation cache, creating it on first use.

    The common case (singleton already built) returns without touching
    the lock. Only when the singleton is still missing is the lock taken,
    and the check repeated under it, so that concurrent first callers
    construct exactly one instance.

    Returns
    -------
    cache : CompilationCache
        Global compilation cache instance
    """
    global _global_compilation_cache  # noqa: PLW0603

    instance = _global_compilation_cache
    if instance is None:
        with _global_compilation_cache_lock:
            # Re-read under the lock: another thread may have won the race.
            instance = _global_compilation_cache
            if instance is None:
                instance = CompilationCache(enable_stats=True)
                _global_compilation_cache = instance
    return instance
[docs]
def cached_jit(
    func: Callable | None = None,
    static_argnums: tuple[int, ...] = (),
    donate_argnums: tuple[int, ...] = (),
) -> Callable:
    """Decorator for caching JIT-compiled functions.

    The compiled function is looked up in (or added to) the global
    compilation cache on every call of the decorated function.

    Parameters
    ----------
    func : callable, optional
        Function to decorate
    static_argnums : tuple of int
        Indices of static arguments
    donate_argnums : tuple of int
        Indices of arguments to donate. When given, compilation goes
        through ``CompilationCache.compile`` so that the donation request
        actually reaches ``jax.jit``.

    Returns
    -------
    decorated : callable
        Decorated function with cached compilation

    Examples
    --------
    >>> @cached_jit
    ... def my_function(x):
    ...     return x ** 2
    >>> @cached_jit(static_argnums=(1,))
    ... def my_function_with_static(x, n):
    ...     return x ** n
    """
    # An int index (including 0) is a donation request too, even though it
    # is falsy; only an empty sequence / None means "no donation".
    wants_donation = isinstance(donate_argnums, int) or bool(donate_argnums)

    def decorator(f):
        cache = get_global_compilation_cache()

        @wraps(f)
        def wrapper(*args, **kwargs):
            if wants_donation:
                # get_or_compile() has no donate_argnums parameter (it
                # would be swallowed by **kwargs and ignored), so use the
                # entry point that forwards it to jax.jit.
                compiled_func = cache.compile(
                    f,
                    static_argnums=static_argnums,
                    donate_argnums=donate_argnums,
                )
            else:
                compiled_func, _ = cache.get_or_compile(
                    f, *args, static_argnums=static_argnums, **kwargs
                )
            return compiled_func(*args, **kwargs)

        return wrapper

    if func is None:
        # Called with arguments: @cached_jit(static_argnums=(1,))
        return decorator
    # Called without arguments: @cached_jit
    return decorator(func)
[docs]
def clear_compilation_cache():
    """Empty the global compilation cache and zero its statistics.

    Does nothing if the global cache has not been created yet. Otherwise
    it delegates to the cache's ``clear()`` method, which drops the
    compiled functions and the memoized function hashes and resets the
    hit/miss/compilation counters.

    Intended for interactive sessions (Jupyter, IPython, REPL) where a
    model function may be redefined under the same name: clearing the
    cache guarantees the next call compiles the new definition.

    Examples
    --------
    In a Jupyter notebook, after redefining a model function:

    >>> import jax.numpy as jnp
    >>> from nlsq.caching.compilation_cache import clear_compilation_cache
    >>> from nlsq import curve_fit
    >>>
    >>> # Initial model definition
    >>> def model(x, a, b):
    ...     return a * jnp.exp(-b * x)
    >>>
    >>> # ... some fitting work ...
    >>>
    >>> # Redefine the model with a different implementation
    >>> def model(x, a, b):
    ...     return a * jnp.exp(-b * x) + 0.1  # Added offset
    >>>
    >>> # Clear cache to ensure the new model is compiled
    >>> clear_compilation_cache()
    >>>
    >>> # Now curve_fit will use the updated model
    >>> popt, pcov = curve_fit(model, xdata, ydata, p0=[1.0, 0.1])
    """
    existing = _global_compilation_cache
    if existing is None:
        return
    existing.clear()