nlsq.caching.compilation_cache module

JIT compilation cache for optimization functions.

This module provides caching of compiled JAX functions to avoid recompilation overhead.

Phase 3 Optimizations (Task Group 9):

  • LRU eviction controlled by the max_cache_size parameter (default 512, raised from the original 256)

  • Function-hash race condition fix using the composite key (id(func), id(func.__code__))

class nlsq.caching.compilation_cache.CompilationCache(enable_stats=True, max_cache_size=512)

Bases: object

Cache for JIT-compiled functions with LRU eviction.

Caches compiled versions of functions based on their signature to avoid repeated JIT compilation overhead.

Phase 3 Optimizations (2.2a, 2.1a):

  • Uses OrderedDict for LRU tracking with move_to_end() on hits

  • Evicts oldest entry with popitem(last=False) when at capacity

  • Uses composite key (id(func), id(func.__code__)) to prevent cache poisoning when functions are redefined with same name
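The LRU bookkeeping described above can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not nlsq's implementation: the class name LRUCompiledCache and its get/put methods are hypothetical, and the keys stand in for whatever signature nlsq actually builds.

```python
from collections import OrderedDict
from typing import Callable, Hashable, Optional


class LRUCompiledCache:
    """Hypothetical LRU map from signature keys to compiled callables."""

    def __init__(self, max_cache_size: int = 512) -> None:
        self.max_cache_size = max_cache_size
        self.cache: "OrderedDict[Hashable, Callable]" = OrderedDict()

    def get(self, key: Hashable) -> Optional[Callable]:
        if key not in self.cache:
            return None
        # Hit: mark this entry as the most recently used one
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key: Hashable, compiled: Callable) -> None:
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.max_cache_size:
            # At capacity: evict the least recently used entry (the front)
            self.cache.popitem(last=False)
        self.cache[key] = compiled
```

With max_cache_size=2, inserting "a" and "b", reading "a", then inserting "c" evicts "b", because the read moved "a" to the most-recently-used end.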

Thread Safety

All public methods are guarded by a per-instance threading.Lock. The lock is held only for dict operations and released during jax.jit() compilation to avoid blocking concurrent threads.
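A minimal sketch of the locking pattern described above, assuming a lookup, an unlocked compile, and a second locked check before inserting. LockedCompileCache and compile_fn are hypothetical names; compile_fn stands in for jax.jit so the example runs without JAX.

```python
import threading
from collections import OrderedDict
from typing import Callable, Hashable


class LockedCompileCache:
    """Hypothetical sketch: the lock guards dict access, not compilation."""

    def __init__(self, compile_fn: Callable[[Callable], Callable]) -> None:
        self._lock = threading.Lock()
        self._cache: "OrderedDict[Hashable, Callable]" = OrderedDict()
        self._compile_fn = compile_fn  # stands in for jax.jit

    def get_or_compile(self, key: Hashable, func: Callable) -> Callable:
        with self._lock:
            cached = self._cache.get(key)
            if cached is not None:
                self._cache.move_to_end(key)
                return cached
        # Compile outside the lock so other threads are not blocked meanwhile
        compiled = self._compile_fn(func)
        with self._lock:
            # Another thread may have inserted the same key first; keep its entry
            existing = self._cache.get(key)
            if existing is not None:
                return existing
            self._cache[key] = compiled
            return compiled
```

The trade-off of releasing the lock is that two threads missing on the same key may both compile; the second locked check ensures only one result is stored and returned.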

cache

OrderedDict mapping function signatures to compiled functions (enables LRU eviction)

Type:

OrderedDict

max_cache_size

Maximum number of compiled functions to cache (default 512)

Type:

int

stats

Compilation cache statistics

Type:

dict

_func_hash_cache

Memoization cache for function code hashes using composite key (id(func), id(func.__code__)) for correctness in notebooks

Type:

dict

__init__(enable_stats=True, max_cache_size=512)

Initialize compilation cache.

Parameters:
  • enable_stats (bool) – Track cache statistics

  • max_cache_size (int) – Maximum number of compiled functions to cache (default 512, increased from 256 to reduce recompilation frequency). The limit counts entries, not bytes; with all 512 entries filled, memory use is estimated at approximately 4 GB.
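The stated estimate implies an average budget per cached entry; actual sizes presumably vary with the function and its input shapes:

```latex
\frac{4\ \text{GB}}{512\ \text{entries}} = \frac{4096\ \text{MB}}{512} = 8\ \text{MB per compiled function}
```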

compile(func, static_argnums=(), donate_argnums=())

Compile function with JIT and cache result.

Parameters:
  • func (callable) – Function to compile

  • static_argnums (tuple of int) – Indices of static arguments

  • donate_argnums (tuple of int) – Indices of arguments to donate

Returns:

compiled_func – JIT-compiled function (returned from the cache when a matching entry already exists)

Return type:

callable

get_or_compile(func, *args, static_argnums=(), **kwargs)

Get cached compiled function or compile if not cached.

Parameters:
  • func (callable) – Function to compile

  • args (tuple) – Arguments to function (for signature generation)

  • static_argnums (tuple of int) – Indices of static arguments

  • kwargs (dict) – Keyword arguments

Returns:

  • compiled_func (callable) – Compiled function

  • signature (str) – Function signature

Return type:

tuple[Callable, str]
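How get_or_compile derives its signature string is not documented here. The sketch below assumes a common approach for JAX caching: key on the shapes and dtypes of array arguments, on the types of other non-static arguments, and on the values of static arguments, since those are what force a retrace. make_signature is a hypothetical helper, and it uses duck typing (shape/dtype attributes) so it runs without JAX or NumPy.

```python
from typing import Any, Tuple


def make_signature(func_hash: str, args: Tuple[Any, ...],
                   static_argnums: Tuple[int, ...] = (), **kwargs: Any) -> str:
    """Build a cache-key string from argument metadata (hypothetical)."""
    parts = [func_hash]
    for i, arg in enumerate(args):
        if i in static_argnums:
            # Static arguments change the traced program, so their value matters
            parts.append(f"static:{arg!r}")
        elif hasattr(arg, "shape") and hasattr(arg, "dtype"):
            # Array-like: only shape and dtype affect the compiled program
            parts.append(f"array:{tuple(arg.shape)}:{arg.dtype}")
        else:
            parts.append(f"scalar:{type(arg).__name__}")
    for name in sorted(kwargs):
        parts.append(f"{name}={type(kwargs[name]).__name__}")
    return "|".join(parts)
```

Under this scheme, two calls with same-shaped arrays but different scalar values share one compiled entry, while changing an array's shape or a static argument's value produces a new key.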

clear()

Clear compilation cache, function hash memoization, and reset stats.

This method clears all cached data and resets statistics counters to zero, allowing accurate hit/miss tracking after the clear operation.

get_stats()

Get cache statistics.

Returns:

stats – Cache hit rate and other statistics

Return type:

dict
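The exact keys returned by get_stats() are not listed here. This section mentions hits, misses, compilations, and a hit rate, so the sketch below assumes those counters, a "hit_rate" key computed as hits / (hits + misses), and a reset that mirrors what clear() does to the counters. CacheStats is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass
class CacheStats:
    """Hypothetical counters mirroring hits/misses/compilations."""
    hits: int = 0
    misses: int = 0
    compilations: int = 0

    def record(self, hit: bool) -> None:
        if hit:
            self.hits += 1
        else:
            # A miss triggers a compilation
            self.misses += 1
            self.compilations += 1

    def as_dict(self) -> dict:
        total = self.hits + self.misses
        return {
            "hits": self.hits,
            "misses": self.misses,
            "compilations": self.compilations,
            # Guard against division by zero before any lookups
            "hit_rate": self.hits / total if total else 0.0,
        }

    def reset(self) -> None:
        # Start tracking from zero, as clear() does for the real counters
        self.hits = self.misses = self.compilations = 0
```

One miss followed by three hits gives a hit rate of 0.75 with a single compilation.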

__enter__()

Context manager entry.

__exit__(exc_type, exc_val, exc_tb)

Context manager exit.

nlsq.caching.compilation_cache.get_global_compilation_cache()

Get or create global compilation cache (thread-safe).

Uses double-checked locking to avoid acquiring the lock on every call once the singleton has been initialized.

Returns:

cache – Global compilation cache instance

Return type:

CompilationCache
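Double-checked locking for a module-level singleton can be sketched as follows. The names _Cache, _global_cache, and get_global_cache are hypothetical placeholders, not nlsq's internals. The unlocked first check is the fast path after initialization; the second check under the lock handles two threads that both saw None.

```python
import threading
from typing import Optional


class _Cache:
    """Placeholder for the real cache class (hypothetical)."""


_global_cache: Optional[_Cache] = None
_global_lock = threading.Lock()


def get_global_cache() -> _Cache:
    global _global_cache
    # First check without the lock: cheap once the singleton exists
    if _global_cache is None:
        with _global_lock:
            # Second check under the lock: another thread may have created it
            if _global_cache is None:
                _global_cache = _Cache()
    return _global_cache
```

In CPython, rebinding a module global is a single reference assignment, so a reader on the unlocked path sees either None or the fully constructed object.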

nlsq.caching.compilation_cache.cached_jit(func=None, static_argnums=(), donate_argnums=())

Decorator for caching JIT-compiled functions.

Parameters:
  • func (callable, optional) – Function to decorate

  • static_argnums (tuple of int) – Indices of static arguments

  • donate_argnums (tuple of int) – Indices of arguments to donate

Returns:

decorated – Decorated function with cached compilation

Return type:

callable

Examples

>>> @cached_jit
... def my_function(x):
...     return x ** 2
>>> @cached_jit(static_argnums=(1,))
... def my_function_with_static(x, n):
...     return x ** n
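The examples use cached_jit both bare and with keyword arguments. The standard pattern that supports both forms is sketched below; cached_jit_sketch and its compile_fn parameter are hypothetical. compile_fn stands in for jax.jit (which does accept static_argnums and donate_argnums keywords) and defaults to an identity function so the sketch runs without JAX.

```python
import functools
from typing import Callable, Optional, Tuple


def cached_jit_sketch(func: Optional[Callable] = None,
                      static_argnums: Tuple[int, ...] = (),
                      donate_argnums: Tuple[int, ...] = (),
                      compile_fn: Callable = lambda f, **kw: f):
    """Hypothetical decorator usable as @dec and as @dec(...)."""

    def decorate(f: Callable) -> Callable:
        # Compile once at decoration time and reuse the result on every call
        compiled = compile_fn(f, static_argnums=static_argnums,
                              donate_argnums=donate_argnums)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return compiled(*args, **kwargs)

        return wrapper

    if func is not None:
        # Used bare: @cached_jit_sketch
        return decorate(func)
    # Used with arguments: @cached_jit_sketch(static_argnums=(1,))
    return decorate
```

The func=None default is what distinguishes the two forms: a bare decorator receives the function directly, while a parameterized one is called first and must return the actual decorator.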

nlsq.caching.compilation_cache.clear_compilation_cache()

Clear the global compilation cache, function hash cache, and reset stats.

This function is useful in interactive environments (Jupyter notebooks, IPython, REPL) where model functions may be redefined during development. When a function is redefined with the same name but different code, the compilation cache may return stale compiled versions. Calling this function clears the cache and allows fresh compilation of redefined functions.

The function also resets all statistics counters (hits, misses, compilations) to zero, enabling accurate cache performance tracking after the clear.

Examples

In a Jupyter notebook, after redefining a model function:

>>> import numpy as np
>>> import jax.numpy as jnp
>>> from nlsq.caching.compilation_cache import clear_compilation_cache
>>> from nlsq import curve_fit
>>>
>>> # Synthetic data for illustration
>>> xdata = np.linspace(0.0, 4.0, 50)
>>> ydata = 2.0 * np.exp(-0.5 * xdata) + 0.1
>>>
>>> # Initial model definition
>>> def model(x, a, b):
...     return a * jnp.exp(-b * x)
>>>
>>> popt, pcov = curve_fit(model, xdata, ydata, p0=[1.0, 0.1])
>>>
>>> # Redefine the model with a different implementation
>>> def model(x, a, b):
...     return a * jnp.exp(-b * x) + 0.1  # Added offset
>>>
>>> # Clear the cache so the new model is compiled fresh
>>> clear_compilation_cache()
>>>
>>> # curve_fit now uses the updated model
>>> popt, pcov = curve_fit(model, xdata, ydata, p0=[1.0, 0.1])

Overview

The compilation_cache module provides JIT compilation caching to avoid recompilation overhead.

Note

This module is being consolidated into the nlsq.unified_cache module. For new code, prefer the unified cache, which offers better performance and more features.

Key Features

  • JIT compilation caching for function reuse

  • Hash-based cache keys for function identification

  • Automatic cache invalidation when functions change

  • Memory-efficient storage with weak references


Example Usage

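import jax.numpy as jnp
from nlsq.caching.compilation_cache import CompilationCache

cache = CompilationCache(max_cache_size=100)

def model(x, a, b):
    return a * jnp.exp(-b * x)

# Repeated compile() calls for the same function reuse the cached result
jit_model = cache.compile(model)
y = jit_model(jnp.linspace(0.0, 4.0, 5), 2.0, 0.5)
print(cache.get_stats())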

See Also