Source code for nlsq.caching.memory_manager

"""Memory management for NLSQ optimization.

This module provides intelligent memory management capabilities including
prediction, monitoring, pooling, and automatic garbage collection.

Phase 3 Optimizations (Task Group 9):

- Telemetry Circular Buffer (1.3a): Uses deque(maxlen=1000) for _safety_telemetry
  to prevent memory leak in multi-day optimization runs
"""

import gc
import logging
import threading
import time
import warnings
from collections import OrderedDict, deque
from contextlib import contextmanager

import jax.numpy as jnp
import numpy as np

# Module logger for debug output
logger = logging.getLogger(__name__)

# psutil is an optional dependency.  HAS_PSUTIL records whether the import
# succeeded so the rest of the module can branch on it instead of
# re-attempting the import.
try:
    import psutil

    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False
    # Emitted once, at import time, so users know monitoring is degraded.
    warnings.warn(
        "psutil not installed, memory monitoring will be limited", UserWarning
    )


[docs] class MemoryManager: """Intelligent memory management for optimization algorithms. This class provides: - Memory usage monitoring and prediction - Array pooling to reduce allocations with LRU eviction - Automatic garbage collection triggers - Context managers for memory-safe operations LRU Memory Pool (Task Group 7 - 1.2a) ------------------------------------- The memory pool uses an OrderedDict to track access order, enabling true LRU (Least Recently Used) eviction when at capacity. This improves cache utilization for frequently accessed array shapes by 5-10%. Telemetry Circular Buffer (Task Group 9 - 1.3a) ------------------------------------------------ The safety telemetry uses a deque with maxlen=1000 to prevent memory leak in multi-day optimization runs. This maintains the last 1000 telemetry records for adaptive safety factor calculation. Attributes ---------- memory_pool : OrderedDict Pool of reusable arrays indexed by (shape, dtype) with LRU tracking allocation_history : list History of memory allocations gc_threshold : float Memory usage threshold (0-1) for triggering garbage collection safety_factor : float Safety factor for memory predictions """
[docs] def __init__( self, gc_threshold: float = 0.8, safety_factor: float = 1.2, enable_adaptive_safety: bool = False, disable_padding: bool = False, memory_cache_ttl: float = 1.0, adaptive_ttl: bool = True, ): """Initialize memory manager. Parameters ---------- gc_threshold : float Trigger GC when memory usage exceeds this fraction (0-1) safety_factor : float Multiply memory requirements by this factor for safety enable_adaptive_safety : bool Enable adaptive safety factor reduction (1.2 -> 1.05 after warmup) disable_padding : bool Disable padding/bucketing for strict memory environments (Task 5.6). When True: uses exact shapes, sets safety_factor=1.0. Use case: cloud quotas, strict memory limits. memory_cache_ttl : float TTL in seconds for cached memory info (default: 1.0). Reduces psutil system call overhead by 90%. adaptive_ttl : bool Enable adaptive TTL based on call frequency (default: True). High-frequency callers (>100 calls/sec) get 15s effective TTL. Medium-frequency callers (>10 calls/sec) get 10s effective TTL. Low-frequency callers use the default TTL. Reduces psutil overhead in streaming optimization by 15-20%. """ # Task Group 7 (1.2a): Use OrderedDict for LRU memory pool # This enables move_to_end() for recently used arrays and # popitem(last=False) for LRU eviction when at capacity. 
self.memory_pool: OrderedDict[tuple, np.ndarray] = OrderedDict() self.allocation_history: deque[dict] = deque(maxlen=1000) self.gc_threshold = gc_threshold self.disable_padding = disable_padding # If disable_padding is True, force safety_factor to 1.0 if disable_padding: self.safety_factor = 1.0 self._initial_safety_factor = 1.0 self.enable_adaptive_safety = False # No adaptation when padding disabled else: self.safety_factor = safety_factor self._initial_safety_factor = safety_factor self.enable_adaptive_safety = enable_adaptive_safety self._peak_memory: float = 0.0 # TTL-based cache for psutil calls (reduces overhead by 90%) self._memory_cache_ttl = memory_cache_ttl self._available_memory_cache: float | None = None self._available_memory_cache_time: float = 0.0 self._memory_usage_cache: float | None = None self._memory_usage_cache_time: float = 0.0 self._memory_fraction_cache: float | None = None self._memory_fraction_cache_time: float = 0.0 # Adaptive TTL feature (Task 3 - 1.1a) # Tracks timestamps of last 100 calls to compute call frequency self._adaptive_ttl = adaptive_ttl self._call_frequency_tracker: deque[float] = deque(maxlen=100) self._initial_memory = self.get_memory_usage_bytes() # Task 9.4 (1.3a): Telemetry Circular Buffer # Use deque with maxlen=1000 to prevent memory leak in multi-day runs # Maintains last 1000 telemetry records for adaptive safety factor calculation self._safety_telemetry: deque[dict] = deque(maxlen=1000) self._warmup_runs = 10 # Number of runs before adapting self._min_safety_factor = 1.05 # Target minimum safety factor
def _get_effective_ttl(self) -> float: """Calculate effective TTL based on call frequency. Returns ------- effective_ttl : float The effective TTL in seconds based on call frequency: - 15.0s for high-frequency callers (>100 calls/sec) - 10.0s for medium-frequency callers (>10 calls/sec) - default TTL for low-frequency callers Notes ----- This method is only used when adaptive_ttl is enabled. The call frequency is computed from the time span of the last 100 calls. """ if not self._adaptive_ttl: return self._memory_cache_ttl # Need at least 2 calls to compute frequency if len(self._call_frequency_tracker) < 2: return self._memory_cache_ttl # Compute time span of tracked calls oldest_call = self._call_frequency_tracker[0] newest_call = self._call_frequency_tracker[-1] time_span = newest_call - oldest_call if time_span <= 0: # All calls happened at the same time, assume very high frequency return 15.0 # Compute calls per second num_calls = len(self._call_frequency_tracker) calls_per_sec = num_calls / time_span # Determine effective TTL based on frequency thresholds # Memory availability changes slowly, so aggressive caching is safe if calls_per_sec > 100: # High frequency: use 15s TTL (memory is stable) return 15.0 elif calls_per_sec > 10: # Medium frequency: use 10s TTL return 10.0 else: # Low frequency: use default TTL return self._memory_cache_ttl
[docs] def get_available_memory(self) -> float: """Get available memory in bytes. Returns ------- available : float Available memory in bytes Notes ----- Uses TTL-based caching to reduce psutil system call overhead by 90%. When adaptive_ttl is enabled, the effective TTL is adjusted based on call frequency to further reduce overhead for streaming optimization. """ now = time.time() # Track call timestamp for adaptive TTL if self._adaptive_ttl: self._call_frequency_tracker.append(now) # Calculate effective TTL (adaptive or default) effective_ttl = self._get_effective_ttl() # Return cached value if still valid if ( self._available_memory_cache is not None and now - self._available_memory_cache_time < effective_ttl ): return self._available_memory_cache # Fetch fresh value if HAS_PSUTIL: try: mem = psutil.virtual_memory() self._available_memory_cache = mem.available self._available_memory_cache_time = now return mem.available except Exception as e: # Fallback if psutil fails - log for debugging logger.debug(f"psutil memory check failed (non-critical): {e}") # Conservative fallback estimate (4 GB) return 4.0 * 1024**3
[docs] def get_memory_usage_bytes(self) -> float: """Get current memory usage in bytes. Returns ------- usage : float Current memory usage in bytes Notes ----- Uses TTL-based caching to reduce psutil system call overhead by 90%. """ now = time.time() # Return cached value if still valid if ( self._memory_usage_cache is not None and now - self._memory_usage_cache_time < self._memory_cache_ttl ): return self._memory_usage_cache # Fetch fresh value if HAS_PSUTIL: try: process = psutil.Process() usage = process.memory_info().rss self._memory_usage_cache = usage self._memory_usage_cache_time = now return usage except Exception as e: logger.debug(f"psutil process memory check failed (non-critical): {e}") # Fallback: try to estimate from Python's view import sys return sys.getsizeof(self.memory_pool) + sum( arr.nbytes for arr in self.memory_pool.values() )
[docs] def get_memory_usage_fraction(self) -> float: """Get current memory usage as fraction of total. Returns ------- fraction : float Memory usage fraction (0-1) Notes ----- Uses TTL-based caching to reduce psutil system call overhead by 90%. """ now = time.time() # Return cached value if still valid if ( self._memory_fraction_cache is not None and now - self._memory_fraction_cache_time < self._memory_cache_ttl ): return self._memory_fraction_cache # Fetch fresh value if HAS_PSUTIL: try: mem = psutil.virtual_memory() fraction = mem.percent / 100.0 self._memory_fraction_cache = fraction self._memory_fraction_cache_time = now return fraction except Exception as e: logger.debug(f"psutil memory fraction check failed (non-critical): {e}") # Conservative estimate return 0.5
[docs] def predict_memory_requirement( self, n_points: int, n_params: int, algorithm: str = "trf", dtype: jnp.dtype = jnp.float64, ) -> int: """Predict memory requirement for optimization. Parameters ---------- n_points : int Number of data points n_params : int Number of parameters algorithm : str Algorithm name ('trf', 'lm', 'dogbox') dtype : jnp.dtype, optional Data type for computations (default: jnp.float64). Affects memory calculations: float32 uses 4 bytes, float64 uses 8 bytes. Returns ------- bytes_needed : int Estimated memory requirement in bytes Notes ----- Memory requirements scale linearly with precision: - float32: 4 bytes per element (50% memory savings) - float64: 8 bytes per element (default, higher precision) """ # Size of float depends on dtype (4 bytes for float32, 8 bytes for float64) float_size = 4 if dtype == jnp.float32 else 8 # Base arrays: x, y, params base_memory = float_size * (2 * n_points + n_params) # Jacobian matrix jacobian_memory = float_size * n_points * n_params # Algorithm-specific memory if algorithm == "trf": # Trust Region Reflective # Needs: SVD decomposition, working arrays svd_memory = float_size * min(n_points, n_params) ** 2 working_memory = float_size * (3 * n_points + 5 * n_params) total = base_memory + jacobian_memory + svd_memory + working_memory elif algorithm == "lm": # Levenberg-Marquardt # Needs: Normal equations, working arrays normal_memory = float_size * n_params**2 working_memory = float_size * (2 * n_points + 3 * n_params) total = base_memory + jacobian_memory + normal_memory + working_memory elif algorithm == "dogbox": # Dogbox # Similar to TRF but with additional bound constraints svd_memory = float_size * min(n_points, n_params) ** 2 working_memory = float_size * (4 * n_points + 6 * n_params) total = base_memory + jacobian_memory + svd_memory + working_memory else: # Conservative estimate for unknown algorithms total = base_memory + jacobian_memory * 2 # Apply safety factor return int(total * 
self.safety_factor)
[docs] def check_memory_availability(self, bytes_needed: int) -> tuple[bool, str]: """Check if enough memory is available. Parameters ---------- bytes_needed : int Memory required in bytes Returns ------- available : bool Whether enough memory is available message : str Descriptive message """ available = self.get_available_memory() if available >= bytes_needed: return ( True, f"Memory available: {available / 1e9:.2f}GB >= {bytes_needed / 1e9:.2f}GB needed", ) # Try garbage collection gc.collect() available = self.get_available_memory() if available >= bytes_needed: return True, "Memory available after garbage collection" return False, ( f"Insufficient memory: need {bytes_needed / 1e9:.2f}GB, " f"have {available / 1e9:.2f}GB available" )
[docs] @contextmanager def memory_guard(self, bytes_needed: int): """Context manager to ensure memory availability. Parameters ---------- bytes_needed : int Required memory in bytes Raises ------ MemoryError If insufficient memory is available """ # Check availability is_available, message = self.check_memory_availability(bytes_needed) if not is_available: # Last resort: clear memory pool self.clear_pool() is_available, message = self.check_memory_availability(bytes_needed) if not is_available: raise MemoryError(message) initial_memory = self.get_memory_usage_bytes() try: yield finally: # Track peak memory current_memory = self.get_memory_usage_bytes() self._peak_memory = max(self._peak_memory, current_memory) # Check if we should trigger GC if self.get_memory_usage_fraction() > self.gc_threshold: gc.collect() # Log allocation self.allocation_history.append( { "bytes_requested": bytes_needed, "bytes_used": current_memory - initial_memory, "peak_memory": self._peak_memory, } ) # Track telemetry for adaptive safety factor (Task 5.2) self._record_safety_telemetry(bytes_needed, current_memory - initial_memory)
def _record_safety_telemetry(self, bytes_predicted: int, bytes_actual: int | float): """Record telemetry for adaptive safety factor calculation. Parameters ---------- bytes_predicted : int Predicted memory requirement (with current safety factor) bytes_actual : int | float Actual memory used Notes ----- This method collects safety_factor_needed = actual / (predicted / safety_factor) which represents the minimum safety factor needed for this allocation. After warmup, we calculate p95(safety_factor_needed) to adaptively reduce the default safety factor. Task 9.4 (1.3a): Uses deque with maxlen=1000 to prevent unbounded growth. """ if not self.enable_adaptive_safety: return # Calculate base prediction (without safety factor) bytes_predicted_base = bytes_predicted / self.safety_factor # Calculate minimum safety factor needed for this allocation if bytes_predicted_base > 0: safety_factor_needed = bytes_actual / bytes_predicted_base else: safety_factor_needed = 1.0 # Record telemetry (deque automatically evicts oldest if at maxlen) self._safety_telemetry.append( { "bytes_predicted": bytes_predicted, "bytes_actual": bytes_actual, "safety_factor_needed": safety_factor_needed, "current_safety_factor": self.safety_factor, } ) # Update adaptive safety factor after warmup if len(self._safety_telemetry) >= self._warmup_runs: self._update_adaptive_safety_factor() def _update_adaptive_safety_factor(self): """Update safety factor based on telemetry (after warmup). Calculates p95(safety_factor_needed) from telemetry and uses it to gradually reduce safety factor from initial value (1.2) to target minimum (1.05). 
""" if ( not self.enable_adaptive_safety or len(self._safety_telemetry) < self._warmup_runs ): return # Extract safety factors needed from telemetry safety_factors_needed = [ entry["safety_factor_needed"] for entry in self._safety_telemetry ] # Calculate p95 (95th percentile) - conservative estimate p95_safety = np.percentile(safety_factors_needed, 95) # Adaptive safety factor: max(min_safety_factor, p95_safety) # This ensures we never go below 1.05, but use higher if needed adaptive_safety = max(self._min_safety_factor, p95_safety) # Gradually reduce safety factor (don't jump abruptly) if adaptive_safety < self.safety_factor: # Reduce by at most 0.05 per update to avoid sudden changes self.safety_factor = max(adaptive_safety, self.safety_factor - 0.05) logger.debug( f"Adaptive safety factor: {self.safety_factor:.3f} " f"(p95_needed={p95_safety:.3f}, runs={len(self._safety_telemetry)})" )
[docs] def get_safety_telemetry(self) -> dict: """Get safety factor telemetry statistics. Returns ------- telemetry : dict Safety factor telemetry with: - current_safety_factor: Current safety factor - initial_safety_factor: Initial safety factor (1.2) - min_safety_factor: Target minimum (1.05) - telemetry_entries: Number of telemetry entries collected - p95_safety_needed: 95th percentile of safety factors needed (if data available) - safety_factor_history: List of safety factors over time """ telemetry: dict[str, object] = { "current_safety_factor": self.safety_factor, "initial_safety_factor": self._initial_safety_factor, "min_safety_factor": self._min_safety_factor, "telemetry_entries": len(self._safety_telemetry), "adaptive_enabled": self.enable_adaptive_safety, } if self._safety_telemetry: safety_factors_needed = [ entry["safety_factor_needed"] for entry in self._safety_telemetry ] telemetry["p95_safety_needed"] = float( np.percentile(safety_factors_needed, 95) ) telemetry["mean_safety_needed"] = float(np.mean(safety_factors_needed)) telemetry["max_safety_needed"] = float(np.max(safety_factors_needed)) telemetry["safety_factor_history"] = [ entry["current_safety_factor"] for entry in self._safety_telemetry ] return telemetry
[docs] def allocate_array( self, shape: tuple[int, ...], dtype: type = np.float64, zero: bool = True ) -> np.ndarray: """Allocate array with memory pooling and LRU tracking. Parameters ---------- shape : tuple Shape of array to allocate dtype : type Data type of array zero : bool Whether to zero-initialize the array Returns ------- array : np.ndarray Allocated array Raises ------ MemoryError If allocation fails Notes ----- Task Group 7 (1.2a): Uses LRU tracking via OrderedDict. When an array is reused from the pool, it is moved to the end (most recently used) to enable proper LRU eviction. """ key = (shape, dtype) # Check pool for existing array if key in self.memory_pool: arr = self.memory_pool[key] # Task Group 7 (1.2a): Move to end for LRU tracking # This marks the array as recently used self.memory_pool.move_to_end(key) if zero: arr.fill(0) return arr # Calculate memory needed bytes_needed = int(np.prod(shape) * np.dtype(dtype).itemsize) # Allocate with memory guard with self.memory_guard(bytes_needed): if zero: arr = np.zeros(shape, dtype=dtype) else: arr = np.empty(shape, dtype=dtype) # Add to pool (at end, as most recently used) self.memory_pool[key] = arr return arr
[docs] def free_array(self, arr: np.ndarray): """Return array to pool for reuse. Parameters ---------- arr : np.ndarray Array to free Notes ----- Task Group 7 (1.2a): Uses LRU tracking via OrderedDict. The returned array is added/moved to the end of the pool, marking it as recently used. """ key = (arr.shape, arr.dtype) if key in self.memory_pool: # Already in pool, just move to end (mark as recently used) self.memory_pool.move_to_end(key) else: # Add new entry at end self.memory_pool[key] = arr
[docs] def clear_pool(self): """Clear memory pool and run garbage collection.""" self.memory_pool.clear() gc.collect()
[docs] def get_memory_stats(self) -> dict: """Get memory usage statistics. Returns ------- stats : dict Memory statistics including current usage, peak, pool size """ current_memory = self.get_memory_usage_bytes() available_memory = self.get_available_memory() pool_memory = sum(arr.nbytes for arr in self.memory_pool.values()) pool_arrays = len(self.memory_pool) stats: dict[str, object] = { "current_usage_gb": current_memory / 1e9, "available_gb": available_memory / 1e9, "peak_usage_gb": self._peak_memory / 1e9, "usage_fraction": self.get_memory_usage_fraction(), "pool_memory_gb": pool_memory / 1e9, "pool_arrays": pool_arrays, "allocations": len(self.allocation_history), } if self.allocation_history: total_requested = sum(a["bytes_requested"] for a in self.allocation_history) total_used = sum(a["bytes_used"] for a in self.allocation_history) stats["total_requested_gb"] = total_requested / 1e9 stats["total_used_gb"] = total_used / 1e9 stats["efficiency"] = ( total_used / total_requested if total_requested > 0 else 1.0 ) # Include safety factor telemetry (Task 5.2) if self.enable_adaptive_safety: stats["safety_telemetry"] = self.get_safety_telemetry() # Include padding configuration (Task 5.6) stats["disable_padding"] = self.disable_padding return stats
[docs] def optimize_memory_pool(self, max_arrays: int = 100): """Optimize memory pool using LRU eviction. Parameters ---------- max_arrays : int Maximum number of arrays to keep in pool Notes ----- Task Group 7 (1.2a): Uses LRU eviction via popitem(last=False). Arrays are evicted in order of least recent use, keeping the most recently used arrays in the pool. """ if len(self.memory_pool) <= max_arrays: return # Task Group 7 (1.2a): Use LRU eviction # popitem(last=False) removes the oldest (least recently used) entry while len(self.memory_pool) > max_arrays: self.memory_pool.popitem(last=False) gc.collect()
[docs] @contextmanager def temporary_allocation(self, shape: tuple[int, ...], dtype: type = np.float64): """Context manager for temporary array allocation. Parameters ---------- shape : tuple Shape of array dtype : type Data type Yields ------ array : np.ndarray Temporary array that will be returned to pool on exit """ arr = self.allocate_array(shape, dtype) try: yield arr finally: # Return array to pool for reuse self.free_array(arr)
[docs] def estimate_chunking_strategy( self, n_points: int, n_params: int, algorithm: str = "trf", memory_limit_gb: float | None = None, ) -> dict: """Estimate optimal chunking strategy for large datasets. Parameters ---------- n_points : int Number of data points n_params : int Number of parameters algorithm : str Algorithm to use memory_limit_gb : float, optional Memory limit in GB (uses available memory if None) Returns ------- strategy : dict Chunking strategy with chunk_size and n_chunks """ if memory_limit_gb is None: memory_limit = self.get_available_memory() * 0.8 # Use 80% of available else: memory_limit = memory_limit_gb * 1e9 # Calculate memory per point memory_per_point = self.predict_memory_requirement(1, n_params, algorithm) # Calculate maximum points that fit in memory max_points = int(memory_limit / memory_per_point) if max_points >= n_points: # No chunking needed return { "needs_chunking": False, "chunk_size": n_points, "n_chunks": 1, "memory_per_chunk_gb": self.predict_memory_requirement( n_points, n_params, algorithm ) / 1e9, } # Calculate chunking parameters chunk_size = min( max_points, max(100, n_points // 100) ) # At least 100 points per chunk n_chunks = (n_points + chunk_size - 1) // chunk_size return { "needs_chunking": True, "chunk_size": chunk_size, "n_chunks": n_chunks, "memory_per_chunk_gb": self.predict_memory_requirement( chunk_size, n_params, algorithm ) / 1e9, "total_points": n_points, }
# Process-wide MemoryManager singleton, built lazily by get_memory_manager().
# The lock serialises first-time construction.
_memory_manager: MemoryManager | None = None
_memory_manager_lock = threading.Lock()


def get_memory_manager() -> MemoryManager:
    """Return the shared MemoryManager, creating it on first use.

    Returns
    -------
    manager : MemoryManager
        Global memory manager instance
    """
    global _memory_manager  # noqa: PLW0603

    # Fast path: already built, no locking needed.
    if _memory_manager is not None:
        return _memory_manager

    with _memory_manager_lock:
        # Re-check under the lock: another thread may have built the
        # instance while this one was waiting.
        if _memory_manager is None:
            _memory_manager = MemoryManager()
        return _memory_manager
def clear_memory_pool():
    """Empty the array pool of the global memory manager."""
    get_memory_manager().clear_pool()
def get_memory_stats() -> dict:
    """Report memory statistics from the global memory manager.

    Returns
    -------
    stats : dict
        Memory statistics
    """
    return get_memory_manager().get_memory_stats()