"""Smart caching system for NLSQ optimization.
This module provides intelligent caching for expensive computations,
particularly Jacobian evaluations and function calls.
Note: This module uses safe serialization only: JSON files and NumPy ``.npz``
archives, which are always loaded with ``allow_pickle=False``. Cached data is
never unpickled.
Phase 3 Optimizations (Task Group 9):
- Array hash optimization: stride-based sampling only for >10000 elements
- For smaller arrays, hash full array directly (no redundant sampling)
"""
import hashlib
import json
import logging
import os
import threading
import time
from collections import OrderedDict
from collections.abc import Callable
from functools import wraps
from typing import Any
import numpy as np
from nlsq.config import JAXConfig
_jax_config = JAXConfig()
_logger = logging.getLogger(__name__)
import contextlib
import jax.numpy as jnp
# Cache version for invalidating old cache entries when hash algorithm changes.
# SmartCache.cache_key prefixes every generated key with this string.
CACHE_VERSION = "v2"

# Threshold for using stride-based sampling (Task 9.3).
# When xxhash is unavailable, arrays with more elements than this are hashed
# from a strided sample instead of their full contents.
LARGE_ARRAY_THRESHOLD = 10000

# Try to use xxhash for faster hashing (10x faster than SHA256).
# HAS_XXHASH records whether the optional dependency could be imported.
try:
    import xxhash  # type: ignore[import-not-found]

    HAS_XXHASH = True
except ImportError:
    HAS_XXHASH = False
[docs]
class SmartCache:
"""Intelligent caching system for optimization computations.
This class provides:
- Memory and disk caching with LRU eviction
- Automatic cache key generation from function arguments
- Cache persistence across sessions
- Cache invalidation and warming strategies
Phase 3 Optimizations (3.2a):
- Array hash optimization: uses stride-based sampling only for
arrays with >10000 elements when xxhash is unavailable
- For smaller arrays, hashes full array directly without redundant
sampling, providing 15-20% improvement in cache key generation
All dict operations are protected by a per-instance ``threading.Lock``
so that concurrent threads can safely call ``get``/``set``/``invalidate``.
Attributes
----------
cache_dir : str
Directory for disk cache storage
memory_cache : dict
In-memory cache storage
disk_cache_enabled : bool
Whether disk caching is enabled
max_memory_items : int
Maximum items in memory cache
cache_stats : dict
Cache hit/miss statistics
"""
[docs]
def __init__(
self,
cache_dir: str = ".nlsq_cache",
max_memory_items: int = 1000,
disk_cache_enabled: bool = True,
enable_stats: bool = True,
):
"""Initialize smart cache.
Parameters
----------
cache_dir : str
Directory for disk cache
max_memory_items : int
Maximum items in memory cache
disk_cache_enabled : bool
Enable disk caching
enable_stats : bool
Track cache statistics
"""
self._lock = threading.Lock()
self.cache_dir = cache_dir
self.memory_cache: dict[str, tuple[Any, float]] = {} # value, timestamp
self.access_count: dict[str, int] = {} # Track access frequency
self.disk_cache_enabled = disk_cache_enabled
self.max_memory_items = max_memory_items
self.enable_stats = enable_stats
# Statistics
self.cache_stats = {
"hits": 0,
"misses": 0,
"memory_hits": 0,
"disk_hits": 0,
"evictions": 0,
}
# Create cache directory if needed
if disk_cache_enabled and not os.path.exists(cache_dir):
try:
os.makedirs(cache_dir)
except OSError:
_logger.debug("Could not create cache directory %s", cache_dir)
self.disk_cache_enabled = False
[docs]
def cache_key(self, *args, **kwargs) -> str:
"""Generate cache key from arguments.
Parameters
----------
*args : tuple
Positional arguments
**kwargs : dict
Keyword arguments
Returns
-------
key : str
Hash of arguments (xxhash if available, BLAKE2b fallback)
Notes
-----
Uses xxhash (xxh64) when available for ~10x faster hashing compared
to SHA256/BLAKE2b. Falls back to BLAKE2b if xxhash is not installed.
All cache keys are prefixed with CACHE_VERSION to ensure old cache
entries are invalidated when the hash algorithm changes.
Task 9.3 (3.2a): Array hash optimization
- Arrays <= 10000 elements: hash full array directly (no sampling)
- Arrays > 10000 elements: use stride-based sampling for efficiency
- Removes redundant sampling when computing full hash in fallback path
"""
key_parts = []
for arg in args:
if isinstance(arg, (np.ndarray, jnp.ndarray)):
# For arrays, use shape, dtype, and fast hash of values
arr = np.asarray(arg)
if HAS_XXHASH:
# Fast path: xxhash on contiguous data (10x faster than SHA256)
if arr.flags["C_CONTIGUOUS"]:
data_hash = xxhash.xxh64(arr).hexdigest()[:16]
else:
data_hash = xxhash.xxh64(np.ascontiguousarray(arr)).hexdigest()[
:16
]
key_parts.append(f"array_{arg.shape}_{arg.dtype}_{data_hash}")
else:
# Task 9.3: Optimized fallback path
# Use stride-based sampling ONLY for very large arrays (>10000 elements)
arr_flat = arr.flatten()
arr_size = len(arr_flat)
if arr_size > LARGE_ARRAY_THRESHOLD:
# Large array: use stride-based sampling for efficiency
# Calculate stride to sample approximately 1000 elements
stride = max(1, arr_size // 1000)
sample = arr_flat[::stride]
# Use BLAKE2b for the sample hash
sample_hash = hashlib.blake2b(
sample.tobytes(), digest_size=16
).hexdigest()
key_parts.append(f"array_{arg.shape}_{arg.dtype}_{sample_hash}")
else:
# Small/medium array: hash full array directly (no sampling overhead)
# This is the optimized path - removes redundant sampling
full_hash = hashlib.blake2b(
arr_flat.tobytes(), digest_size=16
).hexdigest()
key_parts.append(f"array_{arg.shape}_{arg.dtype}_{full_hash}")
elif callable(arg):
# For functions, use their name and module
key_parts.append(f"func_{arg.__module__}_{arg.__name__}")
else:
key_parts.append(str(arg))
# Add kwargs with same type-aware hashing as positional args
for k, v in sorted(kwargs.items()):
if isinstance(v, (np.ndarray, jnp.ndarray)):
arr = np.asarray(v)
if HAS_XXHASH:
raw = (
arr if arr.flags["C_CONTIGUOUS"] else np.ascontiguousarray(arr)
)
h = xxhash.xxh64(raw).hexdigest()[:16]
else:
flat = arr.flatten()
if len(flat) > LARGE_ARRAY_THRESHOLD:
h = hashlib.blake2b(
flat[:: max(1, len(flat) // 1000)].tobytes(), digest_size=16
).hexdigest()
else:
h = hashlib.blake2b(flat.tobytes(), digest_size=16).hexdigest()
key_parts.append(f"{k}=array_{v.shape}_{v.dtype}_{h}")
elif callable(v):
key_parts.append(f"{k}=func_{v.__module__}_{v.__name__}")
else:
key_parts.append(f"{k}={v}")
key_str = "|".join(key_parts)
# Use xxhash for final key if available, BLAKE2b as fallback
if HAS_XXHASH:
hash_hex = xxhash.xxh64(key_str.encode()).hexdigest()
else:
# Use BLAKE2b instead of MD5 for better security and collision resistance
hash_hex = hashlib.blake2b(key_str.encode(), digest_size=16).hexdigest()
# Prefix with cache version to invalidate old cache entries
return f"{CACHE_VERSION}_{hash_hex}"
[docs]
def get(self, key: str) -> Any | None:
"""Get value from cache.
Parameters
----------
key : str
Cache key
Returns
-------
value : Any or None
Cached value or None if not found
"""
# Check memory cache first (under lock for atomic LRU update)
with self._lock:
if key in self.memory_cache:
value, timestamp = self.memory_cache[key]
self.access_count[key] = self.access_count.get(key, 0) + 1
if self.enable_stats:
self.cache_stats["hits"] += 1
self.cache_stats["memory_hits"] += 1
# Move to end (LRU)
del self.memory_cache[key]
self.memory_cache[key] = (value, timestamp)
return value
# Check disk cache (disk I/O outside lock)
if self.disk_cache_enabled:
cache_file = os.path.join(self.cache_dir, f"{key}.npz")
if os.path.exists(cache_file):
try:
value = self._load_from_disk(cache_file)
with self._lock:
if self.enable_stats:
self.cache_stats["hits"] += 1
self.cache_stats["disk_hits"] += 1
# Add to memory cache, preserving file mtime as timestamp
# so TTL checks reflect the original cache time, not load time
disk_mtime = os.path.getmtime(cache_file)
self._add_to_memory_cache(key, value, timestamp=disk_mtime)
return value
except Exception as e:
_logger.debug("Could not load from disk cache: %s", e)
# Remove corrupted cache file
with contextlib.suppress(OSError):
os.remove(cache_file)
with self._lock:
if self.enable_stats:
self.cache_stats["misses"] += 1
return None
[docs]
def set(self, key: str, value: Any):
"""Set value in cache.
Parameters
----------
key : str
Cache key
value : Any
Value to cache
"""
# Add to memory cache (lock acquired inside _add_to_memory_cache)
self._add_to_memory_cache(key, value)
# Save to disk cache (disk I/O outside lock)
if self.disk_cache_enabled:
cache_file = os.path.join(self.cache_dir, f"{key}.npz")
try:
self._save_to_disk(cache_file, value)
except Exception as e:
_logger.debug("Could not save to disk cache: %s", e)
def _add_to_memory_cache(
self, key: str, value: Any, timestamp: float | None = None
):
"""Add item to memory cache with LRU eviction.
Parameters
----------
key : str
Cache key
value : Any
Value to cache
timestamp : float, optional
Timestamp to associate with the entry. When loading from disk,
pass the file's mtime so that TTL checks reflect the original
cache time rather than the load time. Defaults to ``time.time()``.
"""
with self._lock:
# Check if we need to evict
if len(self.memory_cache) >= self.max_memory_items:
# Evict least recently used item
if self.memory_cache:
oldest_key = next(iter(self.memory_cache))
del self.memory_cache[oldest_key]
if oldest_key in self.access_count:
del self.access_count[oldest_key]
if self.enable_stats:
self.cache_stats["evictions"] += 1
self.memory_cache[key] = (
value,
timestamp if timestamp is not None else time.time(),
)
self.access_count[key] = self.access_count.get(key, 0) + 1
[docs]
def invalidate(self, key: str | None = None):
"""Invalidate cache entries.
Parameters
----------
key : str, optional
Specific key to invalidate, or None to clear all
"""
if key is None:
# Clear all caches (dict ops under lock, disk I/O outside)
with self._lock:
self.memory_cache.clear()
self.access_count.clear()
if self.disk_cache_enabled and os.path.isdir(self.cache_dir):
try:
for file in os.listdir(self.cache_dir):
if file.endswith(".npz"):
os.remove(os.path.join(self.cache_dir, file))
except OSError as e:
_logger.debug("Could not clear disk cache: %s", e)
else:
# Clear specific key
with self._lock:
if key in self.memory_cache:
del self.memory_cache[key]
if key in self.access_count:
del self.access_count[key]
if self.disk_cache_enabled:
cache_file = os.path.join(self.cache_dir, f"{key}.npz")
if os.path.exists(cache_file):
with contextlib.suppress(OSError):
os.remove(cache_file)
[docs]
def get_stats(self) -> dict:
"""Get cache statistics.
Returns
-------
stats : dict
Cache statistics including hit rate
"""
with self._lock:
total_accesses = self.cache_stats["hits"] + self.cache_stats["misses"]
if total_accesses > 0:
hit_rate = self.cache_stats["hits"] / total_accesses
else:
hit_rate = 0.0
return {
**self.cache_stats,
"hit_rate": hit_rate,
"memory_size": len(self.memory_cache),
"total_accesses": total_accesses,
}
[docs]
def optimize_cache(self):
"""Optimize cache by removing rarely accessed items.
Computes threshold from snapshot, then re-checks live counts under
lock before invalidating to avoid evicting keys that became hot
between the snapshot and eviction.
"""
with self._lock:
if not self.access_count:
return
# Snapshot under lock
access_snapshot = dict(self.access_count)
# Calculate average access count (no lock needed for snapshot)
avg_access = np.mean(list(access_snapshot.values()))
threshold = avg_access * 0.5
# Re-check live count under lock before invalidating each key
keys_to_remove = []
with self._lock:
for key in access_snapshot:
live_count = self.access_count.get(key, 0)
if live_count < threshold:
keys_to_remove.append(key)
# Invalidate outside the lock (invalidate acquires its own lock)
for key in keys_to_remove:
self.invalidate(key)
def _save_to_disk(self, cache_file: str, value: Any):
"""Save value to disk using safe serialization.
Uses numpy.savez for arrays and JSON for other data types.
This is safe as it does not use pickle or execute arbitrary code.
Parameters
----------
cache_file : str
Path to cache file
value : Any
Value to save
"""
# Check if value is array-like (numpy or JAX array)
if isinstance(value, (np.ndarray, jnp.ndarray)):
# Convert JAX array to numpy for saving
if isinstance(value, jnp.ndarray):
value = np.asarray(value)
np.savez_compressed(cache_file, data=value)
elif isinstance(value, (dict, list, str, int, float, bool, type(None))):
# Use JSON for simple data types
json_file = cache_file.replace(".npz", ".json")
with open(json_file, "w") as f:
json.dump(value, f)
elif isinstance(value, tuple) and all(
isinstance(v, (np.ndarray, jnp.ndarray)) for v in value
):
# Handle tuple of arrays (common for multi-output functions)
arrays_dict: dict[str, Any] = {
f"arr_{i}": np.asarray(v) for i, v in enumerate(value)
}
arrays_dict["_is_tuple"] = np.array([True])
arrays_dict["_length"] = np.array([len(value)])
np.savez_compressed(cache_file, **arrays_dict)
else:
# For other types, convert to numpy array if possible
try:
arr = np.asarray(value)
np.savez_compressed(cache_file, data=arr)
except (ValueError, TypeError):
_logger.debug(
"Cannot safely cache type %s, skipping disk cache",
type(value).__name__,
)
def _load_from_disk(self, cache_file: str) -> Any:
"""Load value from disk using safe deserialization.
Uses numpy.load for arrays and JSON for other data types.
This is safe as allow_pickle=False prevents code execution.
Parameters
----------
cache_file : str
Path to cache file
Returns
-------
value : Any
Loaded value
"""
# Check if JSON file exists
json_file = cache_file.replace(".npz", ".json")
if os.path.exists(json_file):
with open(json_file) as f:
return json.load(f)
# Load from numpy file (safe: allow_pickle=False)
with np.load(cache_file, allow_pickle=False) as data:
# Check if it's a tuple of arrays
if "_is_tuple" in data.files:
length = int(data["_length"])
return tuple(data[f"arr_{i}"] for i in range(length))
# Single array
elif "data" in data.files:
return data["data"]
else:
# Legacy format or unknown structure
raise ValueError(f"Unknown cache file structure: {data.files}")
def cached_function(cache: SmartCache | None = None, ttl: float | None = None):
    """Decorator for caching function results.

    The cache key is built from the decorated function and its call
    arguments via ``cache.cache_key``. Entries older than ``ttl`` seconds
    are invalidated (memory and disk) and recomputed; this applies to
    entries found in memory as well as entries loaded from disk, whose
    timestamp is the one recorded by the cache when it loaded them.

    A result of ``None`` is never served from the cache because ``None`` is
    the miss sentinel of ``cache.get``; such calls are recomputed each time.

    Parameters
    ----------
    cache : SmartCache, optional
        Cache instance to use (creates new if None)
    ttl : float, optional
        Time-to-live in seconds for cached values; None disables expiry

    Returns
    -------
    decorator : function
        Decorator function. The wrapped function gains ``cache``,
        ``invalidate`` and ``get_stats`` attributes.
    """
    if cache is None:
        cache = SmartCache()

    def is_expired(key: str) -> bool:
        """Return True if ``key`` is in the memory tier and older than ``ttl``."""
        if ttl is None:
            return False
        with cache._lock:
            entry = cache.memory_cache.get(key)
        return entry is not None and time.time() - entry[1] > ttl

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Generate cache key
            key = cache.cache_key(func, *args, **kwargs)

            # Drop a stale in-memory entry *before* looking it up; otherwise
            # cache.get() would happily return the expired value.
            if is_expired(key):
                cache.invalidate(key)

            cached_result = cache.get(key)

            # cache.get() may have just promoted an old entry from disk into
            # memory; apply the same TTL to it.
            if cached_result is not None and is_expired(key):
                cache.invalidate(key)
                cached_result = None

            if cached_result is None:
                # Compute and cache
                result = func(*args, **kwargs)
                cache.set(key, result)
                return result

            return cached_result

        # Add cache management methods to wrapper function
        wrapper.cache = cache  # type: ignore[attr-defined]
        wrapper.invalidate = cache.invalidate  # type: ignore[attr-defined]
        wrapper.get_stats = cache.get_stats  # type: ignore[attr-defined]
        return wrapper

    return decorator
def cached_jacobian(cache: SmartCache | None = None):
    """Decorator specifically for caching Jacobian evaluations.

    The wrapped function must be called positionally as ``f(x, *params)``.
    A result of ``None`` is treated as a miss and recomputed.

    Parameters
    ----------
    cache : SmartCache, optional
        Cache instance to use. When omitted, a dedicated ``SmartCache``
        limited to 100 in-memory items is created.

    Returns
    -------
    decorator : function
        Decorator function. The wrapped function gains ``cache``,
        ``invalidate`` and ``get_stats`` attributes, matching
        ``cached_function``.
    """
    if cache is None:
        cache = SmartCache(max_memory_items=100)  # Jacobians can be large

    def decorator(func):
        @wraps(func)
        def wrapper(x, *params):
            # Include func in key so two different functions sharing one cache
            # instance don't collide on identical (x, params) inputs.
            key = cache.cache_key(func, x, *params)

            cached_result = cache.get(key)
            if cached_result is None:
                # Compute and cache
                cached_result = func(x, *params)
                cache.set(key, cached_result)
            return cached_result

        wrapper.cache = cache  # type: ignore[attr-defined]
        wrapper.invalidate = cache.invalidate  # type: ignore[attr-defined]
        # Exposed for parity with cached_function
        wrapper.get_stats = cache.get_stats  # type: ignore[attr-defined]
        return wrapper

    return decorator
class JITCompilationCache:
    """Cache for JAX JIT-compiled functions with LRU eviction.

    This cache stores the objects returned by ``jax.jit`` so that repeated
    requests for the same function and static arguments reuse one wrapper.
    Uses OrderedDict for LRU eviction so the number of retained wrappers is
    bounded.

    Entries are keyed by ``(module, name, static_argnums)``. Because
    different function objects can share a module and name (lambdas,
    closures, redefinitions), each entry also remembers the function object
    it was built from; a request with a different object under the same key
    replaces the entry instead of returning the wrong compiled function.

    All dict operations are protected by a per-instance ``threading.Lock``
    so that concurrent threads can safely call ``get_or_compile``/``clear``.

    Parameters
    ----------
    max_cache_size : int
        Maximum number of compiled functions to cache (default 256).
        Oldest entries are evicted when capacity is reached.
    """

    def __init__(self, max_cache_size: int = 256):
        """Initialize JIT compilation cache with LRU eviction."""
        self._lock = threading.Lock()
        self.compiled_functions: OrderedDict = OrderedDict()
        self.compilation_times: OrderedDict = OrderedDict()
        # key -> the original (uncompiled) callable the entry was built from
        self._sources: dict = {}
        self.max_cache_size = max_cache_size

    def _lookup(self, key: tuple, func: Callable) -> Callable | None:
        """Return the cached wrapper for ``key`` if it was built from ``func``.

        The caller must hold ``self._lock``. A hit marks the entry as most
        recently used.
        """
        if key in self.compiled_functions and self._sources.get(key) is func:
            self.compiled_functions.move_to_end(key)
            return self.compiled_functions[key]
        return None

    def get_or_compile(self, func: Callable, static_argnums: tuple = ()) -> Callable:
        """Get cached compilation or compile and cache.

        Parameters
        ----------
        func : callable
            Function to compile
        static_argnums : tuple
            Static argument numbers for JIT. Other sequences (e.g. a list)
            are converted to a tuple so they can be used in the cache key.

        Returns
        -------
        compiled_func : callable
            JIT-compiled function
        """
        from jax import jit

        # Normalise to something hashable for the key
        if static_argnums is None or isinstance(static_argnums, (int, tuple)):
            static_key = static_argnums
        else:
            static_key = tuple(static_argnums)

        # Create key from function and static args; callables without
        # __name__ (e.g. functools.partial) fall back to their type name.
        key = (
            getattr(func, "__module__", None),
            getattr(func, "__name__", type(func).__name__),
            static_key,
        )

        # Check cache under lock
        with self._lock:
            cached = self._lookup(key, func)
            if cached is not None:
                return cached

        # Build outside the lock (may be slow).
        # NOTE(review): jax.jit presumably defers tracing/compilation to the
        # first call, so this likely times wrapper creation only - confirm.
        start_time = time.perf_counter()
        compiled_func = jit(func, static_argnums=static_key)
        compilation_time = time.perf_counter() - start_time

        # Store under lock (double-check to avoid overwriting a concurrent compile)
        with self._lock:
            cached = self._lookup(key, func)
            if cached is not None:
                # Another thread already stored it; use that one
                return cached

            if key in self.compiled_functions:
                # Same key but built from a different function object: replace
                del self.compiled_functions[key]
                self.compilation_times.pop(key, None)
                self._sources.pop(key, None)

            # Evict oldest entries if at capacity
            while (
                self.compiled_functions
                and len(self.compiled_functions) >= self.max_cache_size
            ):
                evicted_key, _ = self.compiled_functions.popitem(last=False)
                self.compilation_times.pop(evicted_key, None)
                self._sources.pop(evicted_key, None)

            self.compiled_functions[key] = compiled_func
            self.compilation_times[key] = compilation_time
            self._sources[key] = func

        return compiled_func

    def clear(self):
        """Clear compilation cache."""
        with self._lock:
            self.compiled_functions.clear()
            self.compilation_times.clear()
            self._sources.clear()

    def get_stats(self) -> dict:
        """Get compilation statistics.

        Returns
        -------
        stats : dict
            ``cached_functions`` (entry count), ``max_cache_size``,
            ``total_compilation_time`` (sum of recorded times) and
            ``functions`` (list of cache keys, oldest first)
        """
        with self._lock:
            return {
                "cached_functions": len(self.compiled_functions),
                "max_cache_size": self.max_cache_size,
                "total_compilation_time": sum(self.compilation_times.values()),
                "functions": list(self.compiled_functions.keys()),
            }
# Global cache instances, created once at import time and shared by every
# caller of get_global_cache() / get_jit_cache().
# NOTE: SmartCache() is built with its defaults (disk caching enabled,
# cache_dir=".nlsq_cache"), so importing this module attempts to create that
# directory relative to the current working directory.
_global_cache = SmartCache()
_jit_cache = JITCompilationCache()
def get_global_cache() -> SmartCache:
    """Return the module-wide ``SmartCache`` singleton.

    Returns
    -------
    cache : SmartCache
        The shared instance created when this module was imported; every
        call returns the same object.
    """
    return _global_cache
def get_jit_cache() -> JITCompilationCache:
    """Return the module-wide ``JITCompilationCache`` singleton.

    Returns
    -------
    cache : JITCompilationCache
        The shared instance created when this module was imported; every
        call returns the same object.
    """
    return _jit_cache
def clear_all_caches():
    """Empty both module-wide caches.

    Invalidates every entry of the global ``SmartCache`` (called without a
    key, so all entries are cleared) and then clears the global JIT
    compilation cache.
    """
    _global_cache.invalidate()
    _jit_cache.clear()