Source code for nlsq.caching.compilation_cache

"""JIT compilation cache for optimization functions.

This module provides caching of compiled JAX functions to avoid
recompilation overhead.

Phase 3 Optimizations (Task Group 9):
- LRU eviction with max_cache_size parameter (default 256)
- Function hash race condition fix with composite key (id(func), id(func.__code__))
"""

import hashlib
import threading
import warnings
import weakref
from collections import OrderedDict
from collections.abc import Callable
from functools import wraps

import jax
import jax.numpy as jnp


[docs] class CompilationCache: """Cache for JIT-compiled functions with LRU eviction. Caches compiled versions of functions based on their signature to avoid repeated JIT compilation overhead. Phase 3 Optimizations (2.2a, 2.1a): - Uses OrderedDict for LRU tracking with move_to_end() on hits - Evicts oldest entry with popitem(last=False) when at capacity - Uses composite key (id(func), id(func.__code__)) to prevent cache poisoning when functions are redefined with same name Thread Safety ------------- All public methods are guarded by a per-instance ``threading.Lock``. The lock is held only for dict operations and released during ``jax.jit()`` compilation to avoid blocking concurrent threads. Attributes ---------- cache : OrderedDict OrderedDict mapping function signatures to compiled functions (enables LRU eviction) max_cache_size : int Maximum number of compiled functions to cache (default 512) stats : dict Compilation cache statistics _func_hash_cache : dict Memoization cache for function code hashes using composite key (id(func), id(func.__code__)) for correctness in notebooks """
[docs] def __init__(self, enable_stats: bool = True, max_cache_size: int = 512): """Initialize compilation cache. Parameters ---------- enable_stats : bool Track cache statistics max_cache_size : int Maximum number of compiled functions to cache (default 512). Increased from 256 to reduce recompilation frequency. Caps memory usage at approximately 4GB for 512 cached functions. """ # Task 9.2: Use OrderedDict for LRU eviction self.cache: OrderedDict[str, Callable] = OrderedDict() self.enable_stats = enable_stats self.max_cache_size = max_cache_size # WeakKeyDictionary: entries are auto-removed when the function # object is garbage-collected, preventing stale id() reuse self._func_hash_cache: weakref.WeakKeyDictionary[Callable, str] = ( weakref.WeakKeyDictionary() ) if enable_stats: self.stats = { "hits": 0, "misses": 0, "compilations": 0, "cache_size": 0, } self._lock = threading.Lock()
def _get_function_code_hash(self, func: Callable) -> str: """Get memoized hash of function code. This method caches function code hashes using a WeakKeyDictionary keyed on the function object itself. This prevents stale entries when functions are garbage-collected and a new object reuses the same ``id()``. The WeakKeyDictionary is protected by ``self._lock`` to prevent TOCTOU races where GC fires between ``__contains__`` and ``__getitem__``. Parameters ---------- func : callable Function to hash Returns ------- hash : str SHA256 hash of function code (first 8 chars) """ # Check memoization cache under lock (GC can fire between # __contains__ and __getitem__ on WeakKeyDictionary) with self._lock: if func in self._func_hash_cache: return self._func_hash_cache[func] # Compute hash outside lock (no shared state access) try: func_code = func.__code__.co_code if hasattr(func, "__code__") else b"" code_hash = hashlib.sha256(func_code).hexdigest()[:8] except (AttributeError, TypeError): # Fallback: hash the function's qualified name name = getattr(func, "__qualname__", getattr(func, "__name__", "unknown")) code_hash = hashlib.sha256(name.encode()).hexdigest()[:8] # Store under lock with self._lock: self._func_hash_cache[func] = code_hash return code_hash def _get_function_signature(self, func: Callable, *args, **kwargs) -> str: """Generate unique signature for function and arguments. Parameters ---------- func : callable Function to generate signature for args : tuple Positional arguments kwargs : dict Keyword arguments Returns ------- signature : str Unique signature string """ try: # Get function name and code func_name = func.__name__ if hasattr(func, "__name__") else "unknown" # Get argument shapes and dtypes arg_info = [] for arg in args: if isinstance(arg, jax.Array): arg_info.append(f"{arg.shape}_{arg.dtype}") elif isinstance(arg, (int, float, str, bool)): arg_info.append(f"{type(arg).__name__}_{arg}") else: arg_info.append(str(type(arg).__name__)) # Include static arguments from kwargs for k, v in sorted(kwargs.items()): if isinstance(v, (int, float, str, bool)): arg_info.append(f"{k}={v}") # Create signature sig_str = f"{func_name}_{'_'.join(arg_info)}" # Hash if too long if len(sig_str) > 200: sig_hash = hashlib.sha256(sig_str.encode()).hexdigest()[:16] sig_str = f"{func_name}_{sig_hash}" return sig_str except (AttributeError, TypeError) as e: warnings.warn(f"Could not generate function signature: {e}") return f"fallback_{id(func)}" def _evict_if_at_capacity(self): """Evict oldest entry if cache is at capacity. Uses LRU eviction with popitem(last=False) to remove the least recently used entry. Note: Caller must hold ``self._lock``. """ while len(self.cache) >= self.max_cache_size: # Task 9.2: LRU eviction using popitem(last=False) self.cache.popitem(last=False)
[docs] def compile( self, func: Callable, static_argnums: tuple[int, ...] = (), donate_argnums: tuple[int, ...] = (), ) -> Callable: """Compile function with JIT and cache result. Parameters ---------- func : callable Function to compile static_argnums : tuple of int Indices of static arguments donate_argnums : tuple of int Indices of arguments to donate Returns ------- compiled_func : callable JIT-compiled function (may be cached) """ # Create cache key based on function and compilation options # Uses memoized function hash for 95% faster repeated lookups try: code_hash = self._get_function_code_hash(func) func_name = func.__name__ if hasattr(func, "__name__") else "unknown" cache_key = f"{func_name}_{code_hash}_s{static_argnums}_d{donate_argnums}" except (AttributeError, TypeError): cache_key = f"{id(func)}_s{static_argnums}_d{donate_argnums}" # Check cache and record miss stats in single lock acquisition with self._lock: if cache_key in self.cache: if self.enable_stats: self.stats["hits"] += 1 # Task 9.2: Move to end for LRU tracking (most recently used) self.cache.move_to_end(cache_key) return self.cache[cache_key] # Cache miss -- update stats while we still hold the lock if self.enable_stats: self.stats["misses"] += 1 self.stats["compilations"] += 1 # Compile function outside lock (jax.jit can be slow) compiled_func = jax.jit( func, static_argnums=static_argnums, donate_argnums=donate_argnums ) # Store in cache under lock (re-check to avoid duplicate) with self._lock: if cache_key in self.cache: self.cache.move_to_end(cache_key) return self.cache[cache_key] # Task 9.2: Evict oldest entry if at capacity before adding new one self._evict_if_at_capacity() # Store in cache (at end, as most recently used) self.cache[cache_key] = compiled_func if self.enable_stats: self.stats["cache_size"] = len(self.cache) return compiled_func
[docs] def get_or_compile( self, func: Callable, *args, static_argnums: tuple[int, ...] = (), **kwargs ) -> tuple[Callable, str]: """Get cached compiled function or compile if not cached. Parameters ---------- func : callable Function to compile args : tuple Arguments to function (for signature generation) static_argnums : tuple of int Indices of static arguments kwargs : dict Keyword arguments Returns ------- compiled_func : callable Compiled function signature : str Function signature """ sig = self._get_function_signature(func, *args, **kwargs) full_key = f"{sig}_s{static_argnums}" # Check cache and record miss stats in single lock acquisition with self._lock: if full_key in self.cache: if self.enable_stats: self.stats["hits"] += 1 # Task 9.2: Move to end for LRU tracking self.cache.move_to_end(full_key) return self.cache[full_key], sig # Cache miss -- update stats while we still hold the lock if self.enable_stats: self.stats["misses"] += 1 self.stats["compilations"] += 1 # Compile outside lock (jax.jit can be slow) compiled_func = jax.jit(func, static_argnums=static_argnums) # Store in cache under lock (re-check to avoid duplicate) with self._lock: if full_key in self.cache: self.cache.move_to_end(full_key) return self.cache[full_key], sig # Task 9.2: Evict if at capacity before storing self._evict_if_at_capacity() # Store with full signature self.cache[full_key] = compiled_func if self.enable_stats: self.stats["cache_size"] = len(self.cache) return compiled_func, sig
[docs] def clear(self): """Clear compilation cache, function hash memoization, and reset stats. This method clears all cached data and resets statistics counters to zero, allowing accurate hit/miss tracking after the clear operation. """ with self._lock: self.cache.clear() self._func_hash_cache.clear() if self.enable_stats: self.stats["hits"] = 0 self.stats["misses"] = 0 self.stats["compilations"] = 0 self.stats["cache_size"] = 0
[docs] def get_stats(self) -> dict: """Get cache statistics. Returns ------- stats : dict Cache hit rate and other statistics """ if not self.enable_stats: return {"enabled": False} with self._lock: total_requests = self.stats["hits"] + self.stats["misses"] hit_rate = ( self.stats["hits"] / total_requests if total_requests > 0 else 0.0 ) return { **self.stats, "hit_rate": hit_rate, "total_requests": total_requests, "max_cache_size": self.max_cache_size, }
[docs] def __enter__(self): """Context manager entry.""" return self
[docs] def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit.""" return False
# Global compilation cache (thread-safe double-checked locking) _global_compilation_cache: CompilationCache | None = None _global_compilation_cache_lock = threading.Lock()
[docs] def get_global_compilation_cache() -> CompilationCache: """Get or create global compilation cache (thread-safe). Uses double-checked locking to avoid acquiring the lock on every call once the singleton has been initialized. Returns ------- cache : CompilationCache Global compilation cache instance """ global _global_compilation_cache # noqa: PLW0603 if _global_compilation_cache is not None: return _global_compilation_cache with _global_compilation_cache_lock: if _global_compilation_cache is None: _global_compilation_cache = CompilationCache(enable_stats=True) return _global_compilation_cache
[docs] def cached_jit( func: Callable | None = None, static_argnums: tuple[int, ...] = (), donate_argnums: tuple[int, ...] = (), ) -> Callable: """Decorator for caching JIT-compiled functions. Parameters ---------- func : callable, optional Function to decorate static_argnums : tuple of int Indices of static arguments donate_argnums : tuple of int Indices of arguments to donate Returns ------- decorated : callable Decorated function with cached compilation Examples -------- >>> @cached_jit ... def my_function(x): ... return x ** 2 >>> @cached_jit(static_argnums=(1,)) ... def my_function_with_static(x, n): ... return x ** n """ def decorator(f): cache = get_global_compilation_cache() @wraps(f) def wrapper(*args, **kwargs): compiled_func, _ = cache.get_or_compile( f, *args, static_argnums=static_argnums, **kwargs ) return compiled_func(*args, **kwargs) return wrapper if func is None: # Called with arguments: @cached_jit(static_argnums=(1,)) return decorator else: # Called without arguments: @cached_jit return decorator(func)
[docs] def clear_compilation_cache(): """Clear the global compilation cache, function hash cache, and reset stats. This function is useful in interactive environments (Jupyter notebooks, IPython, REPL) where model functions may be redefined during development. When a function is redefined with the same name but different code, the compilation cache may return stale compiled versions. Calling this function clears the cache and allows fresh compilation of redefined functions. The function also resets all statistics counters (hits, misses, compilations) to zero, enabling accurate cache performance tracking after the clear. Examples -------- In a Jupyter notebook, after redefining a model function: >>> import jax.numpy as jnp >>> from nlsq.caching.compilation_cache import clear_compilation_cache >>> from nlsq import curve_fit >>> >>> # Initial model definition >>> def model(x, a, b): ... return a * jnp.exp(-b * x) >>> >>> # ... some fitting work ... >>> >>> # Redefine the model with different implementation >>> def model(x, a, b): ... return a * jnp.exp(-b * x) + 0.1 # Added offset >>> >>> # Clear cache to ensure new model is compiled >>> clear_compilation_cache() >>> >>> # Now curve_fit will use the updated model >>> popt, pcov = curve_fit(model, xdata, ydata, p0=[1.0, 0.1]) """ global _global_compilation_cache # noqa: PLW0602 if _global_compilation_cache is not None: _global_compilation_cache.clear()