nlsq.unified_cache module

Unified JAX JIT compilation cache for NLSQ.

This module consolidates three independent cache implementations (compilation_cache.py, caching.py, smart_cache.py) into a single unified cache with shape-relaxed keys, comprehensive statistics tracking, and optimized memory management.

Key Features

  • Shape-relaxed cache keys: (func_hash, dtype, rank) instead of full shapes

  • Comprehensive statistics: hits, misses, compile_time_ms, hit_rate

  • LRU eviction with configurable maxsize

  • Optional two-tier caching (memory + disk)

  • Weak references to avoid memory leaks

  • Thread-safe operations via per-instance threading.Lock

  • Async disk writes (deferred to Phase 2)
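
To make the idea of a shape-relaxed key concrete, here is a minimal sketch under stated assumptions. The helper `relaxed_key` and the `FakeArray` stand-in are hypothetical names introduced for illustration; how NLSQ actually builds its keys is not shown on this page. The sketch only demonstrates that a key built from (func_hash, dtype, rank) is the same for inputs of different lengths but the same rank.

```python
import hashlib


def relaxed_key(func, args):
    """Build a (func_hash, dtypes, ranks) key that ignores exact shapes."""
    # Hash the qualified name plus bytecode (an illustrative choice).
    ident = f"{func.__module__}.{func.__qualname__}".encode()
    func_hash = hashlib.sha256(ident + func.__code__.co_code).hexdigest()[:16]
    # Array-like arguments expose .dtype and .ndim; plain scalars fall back
    # to their Python type name and rank 0.
    dtypes = tuple(str(getattr(a, "dtype", type(a).__name__)) for a in args)
    ranks = tuple(getattr(a, "ndim", 0) for a in args)
    return (func_hash, dtypes, ranks)


class FakeArray:
    """Stand-in for an array; carries only the attributes the key reads."""

    def __init__(self, shape, dtype="float32"):
        self.shape = shape
        self.ndim = len(shape)
        self.dtype = dtype


def model(x, a):
    return a * x**2


key_small = relaxed_key(model, (FakeArray((100,)), 1.0))
key_large = relaxed_key(model, (FakeArray((5000,)), 1.0))
key_2d = relaxed_key(model, (FakeArray((10, 10)), 1.0))

print(key_small == key_large)  # same rank and dtype, different length
print(key_small == key_2d)     # different rank
```

With a full-shape key, the 100-element and 5000-element inputs would occupy two cache entries; with the relaxed key they share one.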

Design Goals

  1. 80%+ cache hit rate on typical batch fitting workflows

  2. 2-5x reduction in cold-start compile time through better cache reuse

  3. Backward compatibility with existing cache APIs (gradual migration)

  4. Zero breaking changes to curve_fit API

class nlsq.caching.unified_cache.UnifiedCache(maxsize=128, enable_stats=True, disk_cache_enabled=False)[source]

Bases: object

Unified compilation cache merging three legacy cache patterns.

This cache provides:

  • Shape-relaxed keys: cache based on (func_hash, dtype, rank), not exact shapes

  • Comprehensive stats: hits, misses, compile_time_ms, hit_rate, cache_size

  • LRU eviction when maxsize is exceeded

  • Optional disk caching (two-tier architecture)

  • Weak references to functions to prevent memory leaks

  • Thread-safe operations via a per-instance lock
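
The LRU eviction behaviour mentioned above can be sketched with the standard library. `TinyLRU` is a hypothetical class written for this illustration and is not the UnifiedCache implementation; it assumes an `OrderedDict` is used to track recency.

```python
from collections import OrderedDict


class TinyLRU:
    """Toy least-recently-used store with an eviction counter."""

    def __init__(self, maxsize=2):
        self.maxsize = maxsize
        self.evictions = 0
        self._store = OrderedDict()

    def get(self, key):
        if key not in self._store:
            return None
        # A hit makes the entry the most recently used one.
        self._store.move_to_end(key)
        return self._store[key]

    def put(self, key, value):
        self._store[key] = value
        self._store.move_to_end(key)
        if len(self._store) > self.maxsize:
            # Drop the entry at the least recently used end.
            self._store.popitem(last=False)
            self.evictions += 1

    def __contains__(self, key):
        return key in self._store

    def __len__(self):
        return len(self._store)


lru = TinyLRU(maxsize=2)
lru.put("a", 1)
lru.put("b", 2)
lru.get("a")     # "a" is now more recent than "b"
lru.put("c", 3)  # exceeds maxsize, so "b" is evicted
```

After the last call the store holds "a" and "c", and the eviction counter is 1.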

maxsize

Maximum number of compiled functions to cache

Type:

int

enable_stats

Whether to track cache statistics

Type:

bool

disk_cache_enabled

Whether to enable disk caching (default: False for Phase 1)

Type:

bool

Examples

>>> import jax.numpy as jnp
>>> cache = UnifiedCache(maxsize=128, enable_stats=True)
>>> def my_func(x, a):
...     return a * x ** 2
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> compiled = cache.get_or_compile(my_func, (x, 1.0), {}, static_argnums=(1,))
>>> result = compiled(x, 1.0)
>>> stats = cache.get_stats()
>>> print(f"Hit rate: {stats['hit_rate']:.2%}")

__init__(maxsize=128, enable_stats=True, disk_cache_enabled=False)[source]

Initialize unified cache.

Parameters:
  • maxsize (int, default=128) – Maximum number of compiled functions to cache (LRU eviction)

  • enable_stats (bool, default=True) – Track cache statistics (hits, misses, compile_time_ms)

  • disk_cache_enabled (bool, default=False) – Enable disk caching tier (Phase 2 feature)

get_or_compile(func, args, kwargs, static_argnums=(), donate_argnums=())[source]

Get cached compiled function or compile if not cached.

Thread-safe: the lock is held for dict operations only and released during expensive JIT compilation to avoid serializing parallel compilations. A re-check after compilation prevents duplicate stores when multiple threads miss on the same key simultaneously.

Parameters:
  • func (Callable) – Function to compile

  • args (tuple) – Arguments to function (for signature generation)

  • kwargs (dict) – Keyword arguments

  • static_argnums (tuple of int, default=()) – Indices of static arguments for JIT

  • donate_argnums (tuple of int, default=()) – Indices of arguments to donate for memory efficiency

Returns:

compiled_func – JIT-compiled function (from cache or newly compiled)

Return type:

Callable
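
The locking strategy described for get_or_compile (lock held only for dict operations, released during the expensive step, then a re-check before storing) can be sketched as follows. `CompileOnceCache`, `get_or_build`, and `slow_build` are hypothetical names for this illustration, and `time.sleep` stands in for JIT compilation.

```python
import threading
import time


class CompileOnceCache:
    """Toy cache that releases its lock while the slow build runs."""

    def __init__(self):
        self._lock = threading.Lock()
        self._store = {}
        self.compilations = 0

    def get_or_build(self, key, build):
        with self._lock:
            if key in self._store:
                return self._store[key]
        # Lock released: several threads may build the same key in parallel.
        value = build()
        with self._lock:
            self.compilations += 1
            # Re-check: keep whichever value was stored first, so every
            # caller ends up with the same object.
            return self._store.setdefault(key, value)


def slow_build():
    time.sleep(0.05)  # stand-in for an expensive compilation
    return object()


cache = CompileOnceCache()
results = []
results_lock = threading.Lock()


def worker():
    value = cache.get_or_build("k", slow_build)
    with results_lock:
        results.append(value)


threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The trade-off in this pattern is that simultaneous misses on one key may each pay for a build, but unrelated keys never wait on each other and only one result per key is ever stored.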

get_stats()[source]

Get cache statistics.

Returns:

stats – Cache statistics, including:

  • hits (int) – Number of cache hits

  • misses (int) – Number of cache misses

  • compilations (int) – Number of JIT compilations performed

  • evictions (int) – Number of LRU evictions

  • cache_size (int) – Current cache size

  • hit_rate (float) – Cache hit rate (hits / total_requests)

  • compile_time_ms (float) – Total compilation time in milliseconds

Return type:

dict
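
As a worked illustration of the hit_rate field, the sketch below computes hits / total_requests from raw counters. `summarize` is a hypothetical helper, and returning 0.0 when no requests have been made is an assumption of this sketch, not documented behaviour.

```python
def summarize(hits, misses, compile_time_ms):
    """Derive a small stats dict from raw counters."""
    total_requests = hits + misses
    # Assumption: report 0.0 rather than dividing by zero on an unused cache.
    hit_rate = hits / total_requests if total_requests else 0.0
    return {
        "hits": hits,
        "misses": misses,
        "hit_rate": hit_rate,
        "compile_time_ms": compile_time_ms,
    }


warm = summarize(hits=8, misses=2, compile_time_ms=350.0)
empty = summarize(hits=0, misses=0, compile_time_ms=0.0)
print(f"Hit rate: {warm['hit_rate']:.2%}")
```

Eight hits out of ten requests gives a hit rate of 0.8, which is the threshold named in the design goals.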

clear()[source]

Clear all cached compilations and reset statistics.

__repr__()[source]

Return a string representation of the cache.

nlsq.caching.unified_cache.get_global_cache()[source]

Get or create global unified cache instance (thread-safe).

Returns:

cache – Global unified cache instance

Return type:

UnifiedCache
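
A common way to implement a thread-safe "get or create" accessor is double-checked locking around a module-level variable. The sketch below shows that pattern; `get_shared_cache` and the `dict` factory are hypothetical, and whether get_global_cache uses exactly this approach is not stated on this page.

```python
import threading

_shared_cache = None
_shared_lock = threading.Lock()


def get_shared_cache(factory=dict):
    """Return the single shared instance, creating it on first use."""
    global _shared_cache
    if _shared_cache is None:  # fast path once the instance exists
        with _shared_lock:
            if _shared_cache is None:  # re-check while holding the lock
                _shared_cache = factory()
    return _shared_cache


first = get_shared_cache()
second = get_shared_cache()
```

Both calls return the same object, so state written through one handle is visible through the other.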

nlsq.caching.unified_cache.clear_cache()[source]

Clear the global unified cache.

nlsq.caching.unified_cache.cached_jit(func=None, static_argnums=(), donate_argnums=())[source]

Decorator for cached JIT compilation using unified cache.

This decorator provides automatic caching of JIT-compiled functions with shape-relaxed keys for better cache reuse.

Parameters:
  • func (Callable, optional) – Function to decorate

  • static_argnums (tuple of int, default=()) – Indices of static arguments

  • donate_argnums (tuple of int, default=()) – Indices of arguments to donate

Returns:

decorated – Decorated function with cached compilation

Return type:

Callable
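
The optional func parameter is what lets a decorator be applied both bare and with arguments. The following sketch shows that pattern with a plain memoizing decorator; `memoized` is a hypothetical stand-in that performs no JIT compilation and only mirrors the calling convention.

```python
import functools


def memoized(func=None, static_argnums=()):
    """Hypothetical decorator usable as @memoized or @memoized(...)."""

    def decorate(f):
        results = {}

        @functools.wraps(f)
        def wrapper(*args):
            # Positional arguments must be hashable in this toy version.
            if args not in results:
                results[args] = f(*args)
            return results[args]

        # Expose the option so callers can inspect how f was wrapped.
        wrapper.static_argnums = tuple(static_argnums)
        return wrapper

    if func is None:
        # Called with options: @memoized(static_argnums=(1,))
        return decorate
    # Called bare: @memoized
    return decorate(func)


@memoized(static_argnums=(1,))
def power(x, n):
    return x**n


@memoized
def square(x):
    return x**2
```

When the decorator is written with parentheses, Python calls it first with func left as None and then applies the returned `decorate` to the function; when written bare, the function itself arrives as func.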

Examples

>>> @cached_jit(static_argnums=(1,))
... def my_function(x, n):
...     return x ** n

>>> @cached_jit
... def simple_function(x):
...     return x ** 2

nlsq.caching.unified_cache.get_cache_stats()[source]

Get statistics from the global unified cache.

Returns:

stats – Cache statistics

Return type:

dict

class nlsq.caching.unified_cache.CompilationCacheCompat(enable_stats=True)[source]

Bases: object

Backward compatibility wrapper for compilation_cache.py API.

This allows gradual migration of existing code using CompilationCache to the new UnifiedCache without breaking changes.

__init__(enable_stats=True)[source]

Initialize compatibility wrapper.

compile(func, static_argnums=(), donate_argnums=())[source]

Compile function with JIT and cache result (compatibility wrapper).

get_stats()[source]

Get cache statistics (compatibility wrapper).

clear()[source]

Clear cache (compatibility wrapper).

class nlsq.caching.unified_cache.FunctionCacheCompat(maxsize=128)[source]

Bases: object

Backward compatibility wrapper for caching.py API.

__init__(maxsize=128)[source]

Initialize compatibility wrapper.

get_function_hash(func)[source]

Generate stable hash for a function (compatibility wrapper).
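
One possible way to derive a hash that does not depend on object identity is to digest the function's qualified name and bytecode. This is a sketch under that assumption; `function_hash` and `make_model` are hypothetical, and the scheme used by get_function_hash is not documented here.

```python
import hashlib


def function_hash(func):
    """Digest of module, qualified name, bytecode, and constants."""
    code = func.__code__
    payload = "|".join(
        [
            func.__module__,
            func.__qualname__,
            code.co_code.hex(),
            # Note: constants that are nested code objects have
            # address-based reprs, so this is only suitable for
            # simple functions.
            repr(code.co_consts),
        ]
    )
    return hashlib.sha256(payload.encode()).hexdigest()


def make_model():
    def model(x):
        return x + 1

    return model


def other(x):
    return x * 2


m1 = make_model()
m2 = make_model()
```

Here `m1` and `m2` are distinct objects, so an identity-based key would treat them as different, while the content-based digest treats them as the same function.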

get_stats()[source]

Get cache statistics (compatibility wrapper).

clear()[source]

Clear cache (compatibility wrapper).

property hit_rate: float

Get cache hit rate (compatibility wrapper).

Overview

The unified_cache module consolidates three independent cache implementations into a single unified cache with shape-relaxed keys, comprehensive statistics tracking, and optimized memory management.

Key Features

  • Shape-relaxed cache keys: Cache based on (func_hash, dtype, rank) instead of full shapes

  • Comprehensive statistics: Track hits, misses, compile time, hit rate

  • LRU eviction: Configurable maxsize with automatic eviction

  • Two-tier caching: Optional memory + disk caching

  • Weak references: Prevents memory leaks

  • Thread-safe operations: Safe for concurrent use

Performance Goals

  • 80%+ cache hit rate on typical batch fitting workflows

  • 2-5x reduction in cold-start compile time through better cache reuse

  • Backward compatibility with existing cache APIs

  • Zero breaking changes to curve_fit API

Example Usage

from nlsq.unified_cache import UnifiedCache
import jax.numpy as jnp

# Create cache with statistics tracking
cache = UnifiedCache(maxsize=128, enable_stats=True)


# Define function to cache
def my_func(x, a):
    return a * x**2


# Use cache for JIT compilation
x = jnp.array([1.0, 2.0, 3.0])
compiled = cache.get_or_compile(my_func, (x, 1.0), {}, static_argnums=(1,))
result = compiled(x, 1.0)

# Check cache statistics
stats = cache.get_stats()
print(f"Hit rate: {stats['hit_rate']:.2%}")
print(f"Cache size: {stats['cache_size']}")

See Also