"""Workflow Configuration and Selection Module.
This module provides memory-based optimizer selection and adaptive tolerance
calculation for NLSQ curve fitting operations.
Key Components
--------------
- ``OptimizationGoal`` enum: Defines optimization priorities (FAST, ROBUST, QUALITY, etc.)
- ``MemoryBudget`` dataclass: Computes memory requirements for optimizer selection
- ``MemoryBudgetSelector`` class: Selects optimal optimizer strategy based on memory
- ``calculate_adaptive_tolerances()``: Returns size-appropriate convergence tolerances
- ``ClusterDetector`` class: Detects HPC cluster environments (PBS Pro)
Examples
--------
Memory-based optimizer selection:
>>> from nlsq.core.workflow import MemoryBudgetSelector
>>> selector = MemoryBudgetSelector(safety_factor=0.75)
>>> strategy, config = selector.select(n_points=5_000_000, n_params=10)
>>> if strategy == "streaming":
...     pass  # Use HybridStreamingOptimizer
... elif strategy == "chunked":
...     pass  # Use LargeDatasetFitter
... else:
...     pass  # Use standard curve_fit()
Adaptive tolerance calculation:
>>> from nlsq.core.workflow import calculate_adaptive_tolerances, OptimizationGoal
>>> tols = calculate_adaptive_tolerances(n_points=5_000_000, goal=OptimizationGoal.QUALITY)
>>> tols['gtol'] # Returns tighter tolerance for QUALITY goal
1e-08
Cluster detection for HPC environments:
>>> from nlsq.core.workflow import ClusterDetector
>>> detector = ClusterDetector()
>>> cluster_info = detector.detect()
>>> if cluster_info:
...     print(f"Running on cluster: {cluster_info.total_gpus} GPUs")
"""
import os
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from nlsq.streaming.hybrid_config import HybridStreamingConfig
from nlsq.streaming.large_dataset import LDMemoryConfig
[docs]
class OptimizationGoal(Enum):
    """Priority a caller requests for a curve-fitting workflow.

    Within this module the goal only shifts the convergence tolerances
    returned by ``calculate_adaptive_tolerances``; the multi-start and
    memory/speed behaviours described below are the intended semantics and
    are implemented by other components.

    Attributes
    ----------
    FAST : auto
        Favour speed: local optimization only, tolerances one tier looser.
        Intended for quick exploration and well-conditioned problems.
    ROBUST : auto
        Dataset-appropriate tolerances, intended to be combined with
        multi-start for a better chance at the global optimum.
    GLOBAL : auto
        Semantic synonym for ROBUST. It is a distinct enum member (not a
        true alias); use :meth:`normalize` to fold it into ROBUST.
    MEMORY_EFFICIENT : auto
        Standard tolerances, intended for memory-constrained environments
        and very large datasets.
    QUALITY : auto
        Precision first: tolerances one tier tighter. Intended for
        publication-quality results.
    """

    FAST = auto()
    ROBUST = auto()
    GLOBAL = auto()  # Separate member; normalize() maps it onto ROBUST
    MEMORY_EFFICIENT = auto()
    QUALITY = auto()

    @classmethod
    def normalize(cls, goal: "OptimizationGoal") -> "OptimizationGoal":
        """Collapse GLOBAL onto ROBUST, leaving every other goal untouched.

        Parameters
        ----------
        goal : OptimizationGoal
            The goal to normalize.

        Returns
        -------
        OptimizationGoal
            ``ROBUST`` when ``goal`` is ``GLOBAL``; otherwise ``goal`` itself.
        """
        return cls.ROBUST if goal == cls.GLOBAL else goal
# ============================================================================
# Memory Budget API (new unified memory-based optimizer selection)
# ============================================================================
[docs]
@dataclass(slots=True, frozen=True)
class MemoryBudget:
"""Computed memory budget for optimizer selection.
This immutable dataclass represents the computed memory requirements
and available resources for automatic optimizer strategy selection.
Use the `compute()` factory method to create instances.
Attributes
----------
available_gb : float
Available system memory in GB (CPU or GPU depending on target).
threshold_gb : float
Safe memory threshold = available_gb × safety_factor.
data_gb : float
Memory required for data arrays (x_data, y_data).
jacobian_gb : float
Memory required for full Jacobian matrix.
peak_gb : float
Estimated peak memory = data_gb + 1.3 × jacobian_gb + solver overhead.
Examples
--------
>>> budget = MemoryBudget.compute(n_points=10_000_000, n_params=10)
>>> print(f"Available: {budget.available_gb:.1f} GB")
>>> print(f"Peak estimate: {budget.peak_gb:.2f} GB")
>>> print(f"Fits in memory: {budget.fits_in_memory}")
"""
available_gb: float
threshold_gb: float
data_gb: float
jacobian_gb: float
peak_gb: float
@property
def fits_in_memory(self) -> bool:
"""Check if estimated peak memory fits within safe threshold.
Returns
-------
bool
True if peak_gb <= threshold_gb.
"""
return self.peak_gb <= self.threshold_gb
@property
def data_fits(self) -> bool:
"""Check if data arrays alone fit within safe threshold.
Returns
-------
bool
True if data_gb <= threshold_gb.
"""
return self.data_gb <= self.threshold_gb
[docs]
@classmethod
def compute(
cls,
n_points: int,
n_params: int,
n_features: int = 1,
dtype_bytes: int = 8,
safety_factor: float = 0.75,
memory_limit_gb: float | None = None,
use_gpu: bool = False,
) -> "MemoryBudget":
"""Compute memory budget for a given dataset size.
Parameters
----------
n_points : int
Number of data points.
n_params : int
Number of fit parameters.
n_features : int, default=1
Number of features in x_data (dimensions).
dtype_bytes : int, default=8
Bytes per element (8 for float64, 4 for float32).
safety_factor : float, default=0.75
Memory safety factor (0.75 means use 75% of available).
memory_limit_gb : float | None, default=None
Override memory limit in GB. If None, auto-detect.
use_gpu : bool, default=False
If True, use GPU memory instead of CPU memory.
Returns
-------
MemoryBudget
Computed memory budget with all fields populated.
Raises
------
ValueError
If n_points <= 0, n_params <= 0, or safety_factor not in (0, 1].
Examples
--------
>>> budget = MemoryBudget.compute(n_points=1_000_000, n_params=5)
>>> budget.fits_in_memory
True
"""
# Validation
if n_points <= 0:
raise ValueError("n_points must be positive")
if n_params <= 0:
raise ValueError("n_params must be positive")
if safety_factor <= 0 or safety_factor > 1.0:
raise ValueError("safety_factor must be in (0, 1]")
# Get available memory (lazy import to avoid circular dependency)
if memory_limit_gb is not None:
available_gb = memory_limit_gb
else:
available_gb = cls._detect_available_memory(use_gpu)
# Compute memory estimates
# Data arrays: x_data (n_points x n_features) + y_data (n_points)
data_bytes = n_points * (n_features + 1) * dtype_bytes
data_gb = data_bytes / (1024**3)
# Jacobian: n_points x n_params
jacobian_bytes = n_points * n_params * dtype_bytes
jacobian_gb = jacobian_bytes / (1024**3)
# Peak estimate: data + 1.3*jacobian (SVD working memory) + solver overhead
# 1.3 factor accounts for SVD U, S, V matrices
solver_overhead_gb = 0.1 # Fixed solver overhead (workspace, etc.)
peak_gb = data_gb + 1.3 * jacobian_gb + solver_overhead_gb
# Compute threshold
threshold_gb = available_gb * safety_factor
return cls(
available_gb=available_gb,
threshold_gb=threshold_gb,
data_gb=data_gb,
jacobian_gb=jacobian_gb,
peak_gb=peak_gb,
)
@staticmethod
def _detect_available_memory(use_gpu: bool = False) -> float:
"""Detect available system memory.
Parameters
----------
use_gpu : bool, default=False
If True, detect GPU memory; otherwise detect CPU memory.
Returns
-------
float
Available memory in GB. Falls back to 8.0 GB if detection fails.
"""
_FALLBACK_MEMORY_GB = 8.0 # Conservative fallback (FR-009)
if use_gpu:
try:
from nlsq.streaming.large_dataset import GPUMemoryEstimator
estimator = GPUMemoryEstimator()
gpu_memory = estimator.get_available_gpu_memory_gb()
if gpu_memory > 0:
return gpu_memory
# No GPU available, fall back to CPU
return MemoryBudget._detect_available_memory(use_gpu=False)
except Exception:
return _FALLBACK_MEMORY_GB
else:
try:
from nlsq.streaming.large_dataset import MemoryEstimator
return MemoryEstimator.get_available_memory_gb()
except Exception:
return _FALLBACK_MEMORY_GB
[docs]
class MemoryBudgetSelector:
"""Selects optimal optimizer strategy based on memory budget.
This class computes memory requirements and selects between STREAMING,
CHUNKED, and STANDARD strategies based on three sequential memory
comparisons.
Decision Tree:
1. data_gb > threshold_gb → STREAMING (data doesn't fit)
2. peak_gb > threshold_gb → CHUNKED (Jacobian doesn't fit)
3. else → STANDARD (everything fits)
Parameters
----------
safety_factor : float, default=0.75
Memory safety factor (0.75 means use 75% of available memory).
Examples
--------
>>> selector = MemoryBudgetSelector(safety_factor=0.75)
>>> strategy, config = selector.select(n_points=5_000_000, n_params=10)
>>> if strategy == "streaming":
... # Use HybridStreamingOptimizer with config
... pass
>>> elif strategy == "chunked":
... # Use LargeDatasetFitter with config
... pass
>>> else:
... # Use standard curve_fit()
... pass
"""
[docs]
def __init__(self, safety_factor: float = 0.75) -> None:
"""Initialize selector with safety factor.
Parameters
----------
safety_factor : float, default=0.75
Memory safety factor (0.75 means use 75% of available memory).
"""
self.safety_factor = safety_factor
[docs]
def select(
self,
n_points: int,
n_params: int,
n_features: int = 1,
memory_limit_gb: float | None = None,
goal: "OptimizationGoal | None" = None,
use_gpu: bool = False,
verbose: bool = False,
) -> tuple[str, "HybridStreamingConfig | LDMemoryConfig | None"]:
"""Select optimal optimizer strategy based on memory budget.
Parameters
----------
n_points : int
Number of data points.
n_params : int
Number of fit parameters.
n_features : int, default=1
Number of features in x_data.
memory_limit_gb : float | None, default=None
Override memory limit in GB. If None, auto-detect.
goal : OptimizationGoal | None, default=None
Optimization goal (affects tolerances, not strategy selection).
use_gpu : bool, default=False
If True, use GPU memory instead of CPU memory.
verbose : bool, default=False
If True, log memory budget details and strategy selection reason.
Returns
-------
tuple[str, config]
- strategy: "streaming", "chunked", or "standard"
- config: HybridStreamingConfig, LDMemoryConfig, or None
Raises
------
ValueError
If n_points <= 0 or n_params <= 0.
"""
import logging
logger = logging.getLogger("nlsq")
# Compute memory budget
budget = MemoryBudget.compute(
n_points=n_points,
n_params=n_params,
n_features=n_features,
safety_factor=self.safety_factor,
memory_limit_gb=memory_limit_gb,
use_gpu=use_gpu,
)
# Log memory budget details if verbose
if verbose:
logger.info(
f"[NLSQ] Memory budget: available={budget.available_gb:.1f} GB, "
f"threshold={budget.threshold_gb:.1f} GB"
)
logger.info(
f"[NLSQ] Estimates: data={budget.data_gb:.3f} GB, "
f"jacobian={budget.jacobian_gb:.3f} GB, peak={budget.peak_gb:.3f} GB"
)
# Decision tree (FR-003):
# 1. data_gb > threshold_gb → STREAMING
if not budget.data_fits:
if verbose:
logger.info(
f"[NLSQ] Strategy: streaming (data {budget.data_gb:.2f} GB > "
f"threshold {budget.threshold_gb:.2f} GB)"
)
return self._create_streaming_config(budget, n_params, goal)
# 2. peak_gb > threshold_gb → CHUNKED (but data fits)
# Also apply 10% safety margin (FR-010)
safety_margin_threshold = budget.threshold_gb * 0.9
if budget.peak_gb > safety_margin_threshold:
if verbose:
logger.info(
f"[NLSQ] Strategy: chunked (peak {budget.peak_gb:.2f} GB > "
f"safety threshold {safety_margin_threshold:.2f} GB)"
)
return self._create_chunked_config(budget, n_params, goal)
# 3. else → STANDARD
if verbose:
logger.info(
f"[NLSQ] Strategy: standard (peak {budget.peak_gb:.2f} GB < "
f"threshold {budget.threshold_gb:.2f} GB)"
)
return ("standard", None)
def _create_streaming_config(
self,
budget: MemoryBudget,
n_params: int,
goal: "OptimizationGoal | None",
) -> tuple[str, "HybridStreamingConfig"]:
"""Create configuration for streaming strategy.
Parameters
----------
budget : MemoryBudget
Computed memory budget.
n_params : int
Number of fit parameters.
goal : OptimizationGoal | None
Optimization goal.
Returns
-------
tuple[str, HybridStreamingConfig]
Strategy name and configuration.
"""
from nlsq.streaming.hybrid_config import HybridStreamingConfig
# Compute batch size based on available memory
batch_size = self._compute_streaming_batch_size(budget, n_params)
return (
"streaming",
HybridStreamingConfig(
chunk_size=batch_size, # HybridStreamingConfig uses chunk_size
normalize=True, # Always normalize for streaming
),
)
def _create_chunked_config(
self,
budget: MemoryBudget,
n_params: int,
goal: "OptimizationGoal | None",
) -> tuple[str, "LDMemoryConfig"]:
"""Create configuration for chunked strategy.
Parameters
----------
budget : MemoryBudget
Computed memory budget.
n_params : int
Number of fit parameters.
goal : OptimizationGoal | None
Optimization goal.
Returns
-------
tuple[str, LDMemoryConfig]
Strategy name and configuration.
"""
from nlsq.streaming.large_dataset import LDMemoryConfig
# Compute chunk size based on available memory
chunk_size = self._compute_chunk_size(budget, n_params)
return (
"chunked",
LDMemoryConfig(
memory_limit_gb=budget.threshold_gb,
safety_factor=self.safety_factor,
min_chunk_size=1_000,
max_chunk_size=1_000_000,
streaming_batch_size=chunk_size,
),
)
def _compute_chunk_size(self, budget: MemoryBudget, n_params: int) -> int:
"""Compute optimal chunk size based on memory budget.
Parameters
----------
budget : MemoryBudget
Computed memory budget.
n_params : int
Number of fit parameters.
Returns
-------
int
Optimal chunk size in data points.
"""
# Target: use ~75% of threshold for chunk processing
target_memory_gb = budget.threshold_gb * 0.75
# Memory per point: data + jacobian row
# data: (1 + 1) * 8 bytes = 16 bytes (x, y)
# jacobian: n_params * 8 bytes
bytes_per_point = (2 * 8) + (n_params * 8)
gb_per_point = bytes_per_point / (1024**3)
if gb_per_point > 0:
computed_chunk = int(target_memory_gb / gb_per_point)
else:
computed_chunk = 100_000 # Default
# Clamp to bounds (FR-007: 1K-1M range)
return max(1_000, min(computed_chunk, 1_000_000))
def _compute_streaming_batch_size(self, budget: MemoryBudget, n_params: int) -> int:
"""Compute optimal streaming batch size based on memory budget.
Parameters
----------
budget : MemoryBudget
Computed memory budget.
n_params : int
Number of fit parameters.
Returns
-------
int
Optimal batch size in data points.
"""
# Streaming needs smaller batches than chunked
# Target: use ~50% of threshold for batch processing
target_memory_gb = budget.threshold_gb * 0.5
# Memory per point in streaming: lighter than chunked
# data: (1 + 1) * 8 bytes = 16 bytes
# gradient accumulation: n_params * 8 bytes
bytes_per_point = (2 * 8) + (n_params * 8)
gb_per_point = bytes_per_point / (1024**3)
if gb_per_point > 0:
computed_batch = int(target_memory_gb / gb_per_point)
else:
computed_batch = 50_000 # Default
# Clamp to bounds
return max(1_000, min(computed_batch, 1_000_000))
# Dataset-size tiers and the convergence tolerance used for each.
# Each entry is (max_points, tolerance); max_points is an EXCLUSIVE upper
# bound, so a dataset belongs to the first tier whose bound it is below.
# Entries must stay sorted by max_points: the lookup returns the first match.
_SIZE_TOLERANCE_TABLE: list[tuple[int, float]] = [
    (1_000, 1e-12),  # TINY: < 1K points
    (10_000, 1e-10),  # SMALL: 1K - 10K points
    (100_000, 1e-9),  # MEDIUM: 10K - 100K points
    (1_000_000, 1e-8),  # LARGE: 100K - 1M points
    (10_000_000, 1e-7),  # VERY_LARGE: 1M - 10M points
    (100_000_000, 1e-6),  # HUGE: 10M - 100M points
]
# Tolerance for datasets past the last tier (MASSIVE: >= 100M points).
_MASSIVE_TOLERANCE = 1e-5
def _get_size_tier_index(n_points: int) -> int:
    """Map a dataset size to its tier position in ``_SIZE_TOLERANCE_TABLE``.

    Parameters
    ----------
    n_points : int
        Number of data points.

    Returns
    -------
    int
        Index of the first tier whose exclusive upper bound is greater than
        ``n_points``; ``len(_SIZE_TOLERANCE_TABLE)`` (the MASSIVE tier) when
        no bound is.
    """
    matching_tiers = (
        tier
        for tier, (upper_bound, _tolerance) in enumerate(_SIZE_TOLERANCE_TABLE)
        if n_points < upper_bound
    )
    # Falling off the end of the table means the dataset is MASSIVE.
    return next(matching_tiers, len(_SIZE_TOLERANCE_TABLE))
def _get_tolerance_by_index(index: int) -> float:
    """Look up the tolerance for a tier index, saturating at both ends.

    Parameters
    ----------
    index : int
        Tier index; values outside the table are clamped.

    Returns
    -------
    float
        ``_MASSIVE_TOLERANCE`` (loosest) for indices at or past the end of
        the table, the first entry's tolerance (tightest) for negative
        indices, and the matching entry's tolerance otherwise.
    """
    if index >= len(_SIZE_TOLERANCE_TABLE):
        return _MASSIVE_TOLERANCE  # Loosest
    # Negative indices saturate at the first (tightest) tier.
    _max_points, tolerance = _SIZE_TOLERANCE_TABLE[max(index, 0)]
    return tolerance
[docs]
def calculate_adaptive_tolerances(
    n_points: int,
    goal: OptimizationGoal | None = None,
) -> dict[str, float]:
    """Pick convergence tolerances from the dataset size and the goal.

    The dataset size selects a base tier; the goal may then move one tier:

    - QUALITY: one tier tighter (smaller tolerance)
    - FAST: one tier looser (larger tolerance)
    - ROBUST / GLOBAL / MEMORY_EFFICIENT / None: base tier unchanged

    Shifts past either end of the tier table saturate at the tightest or
    loosest tolerance.

    Parameters
    ----------
    n_points : int
        Number of data points in the dataset.
    goal : OptimizationGoal, optional
        Optimization goal used to shift the tier. Default: None (no shift).

    Returns
    -------
    dict[str, float]
        Keys ``'gtol'``, ``'ftol'`` and ``'xtol'``, all set to the same
        tolerance value.

    Examples
    --------
    >>> tols = calculate_adaptive_tolerances(5_000_000)
    >>> tols['gtol']
    1e-07
    >>> tols = calculate_adaptive_tolerances(5_000_000, goal=OptimizationGoal.QUALITY)
    >>> tols['gtol']  # One tier tighter
    1e-08
    >>> tols = calculate_adaptive_tolerances(5_000_000, goal=OptimizationGoal.FAST)
    >>> tols['gtol']  # One tier looser
    1e-06
    """
    tier = _get_size_tier_index(n_points)

    if goal is not None:
        # GLOBAL behaves exactly like ROBUST, so fold it in first.
        effective_goal = OptimizationGoal.normalize(goal)
        if effective_goal == OptimizationGoal.QUALITY:
            tier -= 1  # Borrow the tolerance of the next-smaller size tier
        elif effective_goal == OptimizationGoal.FAST:
            tier += 1  # Borrow the tolerance of the next-larger size tier

    # Out-of-range tiers are clamped inside the lookup.
    shared_tolerance = _get_tolerance_by_index(tier)
    return dict.fromkeys(("gtol", "ftol", "xtol"), shared_tolerance)
# =============================================================================
# Cluster Detection and Distributed Processing
# =============================================================================
[docs]
@dataclass(slots=True)
class ClusterInfo:
"""Information about detected cluster environment.
This dataclass contains information about the cluster configuration,
including node count, GPUs per node, and total resources available.
Parameters
----------
node_count : int
Number of nodes in the cluster.
gpus_per_node : int
Number of GPUs per node.
total_gpus : int
Total number of GPUs across all nodes.
node_list : list[str]
List of node hostnames.
scheduler : str
Cluster scheduler type ('pbs', 'local', or 'unknown').
job_id : str | None
PBS job ID if available.
interconnect : str | None
Interconnect type if detectable (e.g., 'infiniband').
Examples
--------
>>> cluster_info = ClusterInfo(
... node_count=6,
... gpus_per_node=8,
... total_gpus=48,
... node_list=["node01", "node02", "node03", "node04", "node05", "node06"],
... scheduler="pbs",
... job_id="12345.pbs_server",
... )
>>> cluster_info.total_gpus
48
"""
node_count: int
gpus_per_node: int
total_gpus: int
node_list: list[str]
scheduler: str = "unknown"
job_id: str | None = None
interconnect: str | None = None
[docs]
def to_dict(self) -> dict[str, Any]:
"""Serialize cluster info to dictionary.
Returns
-------
dict
Dictionary representation of cluster info.
"""
return {
"node_count": self.node_count,
"gpus_per_node": self.gpus_per_node,
"total_gpus": self.total_gpus,
"node_list": self.node_list,
"scheduler": self.scheduler,
"job_id": self.job_id,
"interconnect": self.interconnect,
}
[docs]
@classmethod
def from_dict(cls, d: dict[str, Any]) -> "ClusterInfo":
"""Create ClusterInfo from dictionary.
Parameters
----------
d : dict
Dictionary with cluster info fields.
Returns
-------
ClusterInfo
ClusterInfo instance.
"""
return cls(
node_count=d.get("node_count", 1),
gpus_per_node=d.get("gpus_per_node", 0),
total_gpus=d.get("total_gpus", 0),
node_list=d.get("node_list", []),
scheduler=d.get("scheduler", "unknown"),
job_id=d.get("job_id"),
interconnect=d.get("interconnect"),
)
[docs]
class ClusterDetector:
"""Detector for cluster environments and GPU configurations.
This class auto-detects PBS cluster environments via $PBS_NODEFILE
and single-node multi-GPU configurations via JAX's device API.
Supports:
- PBS Pro cluster manager
- Single-node multi-GPU (2-8 GPUs)
- Multi-node HPC clusters (10-100 nodes, 8x A100 GPUs per node)
Examples
--------
>>> detector = ClusterDetector()
>>> cluster_info = detector.detect()
>>> if cluster_info is not None:
... print(f"Cluster detected: {cluster_info.node_count} nodes")
... print(f"Total GPUs: {cluster_info.total_gpus}")
... else:
... print("Not in cluster environment")
Check for PBS specifically:
>>> if detector.is_pbs_environment():
... cluster_info = detector.detect_pbs()
... print(f"PBS Job ID: {cluster_info.job_id}")
"""
# Default GPUs per node for HPC environments (A100 nodes)
DEFAULT_GPUS_PER_NODE = 8
[docs]
def __init__(self, default_gpus_per_node: int = 8) -> None:
"""Initialize ClusterDetector.
Parameters
----------
default_gpus_per_node : int, optional
Default number of GPUs per node when not auto-detectable.
Default: 8 (for A100 HPC nodes).
"""
self._default_gpus_per_node = default_gpus_per_node
[docs]
def detect(self) -> ClusterInfo | None:
"""Auto-detect cluster environment.
Tries PBS first, then falls back to local multi-GPU detection.
Returns None if not in a cluster environment (single CPU-only machine).
Returns
-------
ClusterInfo or None
ClusterInfo if cluster detected, None otherwise.
Examples
--------
>>> detector = ClusterDetector()
>>> info = detector.detect()
>>> if info:
... print(f"Running on {info.scheduler} with {info.total_gpus} GPUs")
"""
# Try PBS environment first
if self.is_pbs_environment():
return self.detect_pbs()
# Try local multi-GPU
local_info = self.detect_local_gpus()
if local_info and local_info.total_gpus > 0:
return local_info
# Not in cluster environment
return None
[docs]
def is_pbs_environment(self) -> bool:
"""Check if running in PBS cluster environment.
Returns
-------
bool
True if PBS_NODEFILE environment variable is set.
"""
return "PBS_NODEFILE" in os.environ
[docs]
def detect_pbs(self) -> ClusterInfo | None:
"""Detect PBS Pro cluster configuration.
Parses PBS_NODEFILE to determine node count and list.
GPU count per node is either auto-detected via JAX or uses default.
Returns
-------
ClusterInfo or None
ClusterInfo with PBS configuration, or None if not in PBS environment.
Notes
-----
PBS_NODEFILE contains one line per allocated processor slot.
For GPU jobs, typically each GPU gets one line per node.
"""
nodefile_path = os.environ.get("PBS_NODEFILE")
if not nodefile_path:
return None
try:
# Parse PBS_NODEFILE
nodefile = Path(nodefile_path)
if not nodefile.exists():
return None
with open(nodefile) as f:
lines = f.read().strip().split("\n")
if not lines or not lines[0]:
return None
# Get unique nodes (PBS lists each slot, often duplicates)
unique_nodes = list(dict.fromkeys(lines)) # Preserves order
node_count = len(unique_nodes)
# Try to detect GPUs per node via JAX
gpus_per_node = self._detect_gpus_per_node()
if gpus_per_node == 0:
# Fallback to default
gpus_per_node = self._default_gpus_per_node
# Get PBS job ID
job_id = os.environ.get("PBS_JOBID")
# Detect interconnect (heuristic based on common setups)
interconnect = self._detect_interconnect()
return ClusterInfo(
node_count=node_count,
gpus_per_node=gpus_per_node,
total_gpus=node_count * gpus_per_node,
node_list=unique_nodes,
scheduler="pbs",
job_id=job_id,
interconnect=interconnect,
)
except (OSError, ValueError):
return None
[docs]
def detect_local_gpus(self) -> ClusterInfo | None:
"""Detect local multi-GPU configuration.
Uses JAX's device API to enumerate available GPUs on the local node.
Returns
-------
ClusterInfo or None
ClusterInfo with local GPU configuration, or None if detection fails.
"""
try:
gpu_count = self._detect_gpus_per_node()
if gpu_count == 0:
return None
import socket
hostname = socket.gethostname()
return ClusterInfo(
node_count=1,
gpus_per_node=gpu_count,
total_gpus=gpu_count,
node_list=[hostname],
scheduler="local",
job_id=None,
interconnect=None,
)
except Exception:
return None
def _detect_gpus_per_node(self) -> int:
"""Detect number of GPUs on the local node via JAX.
Returns
-------
int
Number of GPU devices, or 0 if no GPUs or detection fails.
"""
try:
import jax
devices = jax.devices()
gpu_count = sum(
1 for d in devices if getattr(d, "platform", "cpu") != "cpu"
)
return gpu_count
except Exception:
return 0
def _detect_interconnect(self) -> str | None:
"""Detect interconnect type (heuristic).
Returns
-------
str or None
Interconnect type ('infiniband', 'ethernet') or None.
"""
# Check for Infiniband indicators
if Path("/sys/class/infiniband").exists():
return "infiniband"
# Check for common IB environment variables (OpenMPI)
# Note: Environment variable names are case-sensitive and this one uses lowercase
if os.environ.get("OMPI_MCA_btl_openib_allow_ib"): # noqa: SIM112
return "infiniband"
return None
[docs]
@dataclass(slots=True)
class MultiGPUConfig:
"""Configuration for multi-GPU data parallelism.
This class holds configuration for distributing data across multiple GPUs
using JAX's pmap/pjit primitives.
Parameters
----------
n_devices : int
Number of GPU devices to use.
shard_axis : int
Axis along which to shard data. Default: 0 (batch dimension).
use_pmap : bool
Use pmap for data parallelism. Default: True.
use_pjit : bool
Use pjit for more flexible sharding. Default: False.
per_device_batch_size : int
Batch size per device. Default: 10000.
Examples
--------
>>> config = MultiGPUConfig(n_devices=4, per_device_batch_size=5000)
>>> config.total_batch_size
20000
"""
n_devices: int
shard_axis: int = 0
use_pmap: bool = True
use_pjit: bool = False
per_device_batch_size: int = 10000
@property
def total_batch_size(self) -> int:
"""Total batch size across all devices."""
return self.n_devices * self.per_device_batch_size
[docs]
def to_dict(self) -> dict[str, Any]:
"""Serialize to dictionary."""
return {
"n_devices": self.n_devices,
"shard_axis": self.shard_axis,
"use_pmap": self.use_pmap,
"use_pjit": self.use_pjit,
"per_device_batch_size": self.per_device_batch_size,
}
[docs]
def get_multi_gpu_config(
    cluster_info: ClusterInfo | None = None,
) -> MultiGPUConfig | None:
    """Derive a multi-GPU sharding configuration from cluster information.

    Parameters
    ----------
    cluster_info : ClusterInfo, optional
        Cluster information from ClusterDetector. If None, a fresh
        ``ClusterDetector().detect()`` is run.

    Returns
    -------
    MultiGPUConfig or None
        None when nothing is detected or ``total_gpus`` is 0. Otherwise a
        config with ``n_devices`` set to the GPUs of one node; a single node
        gets pmap and a per-device batch of 10000, any other node count gets
        a per-device batch of 50000 (with pjit enabled when there is more
        than one node).

    Examples
    --------
    >>> config = get_multi_gpu_config()
    >>> if config:
    ...     print(f"Using {config.n_devices} GPUs with batch size {config.total_batch_size}")
    """
    if cluster_info is None:
        cluster_info = ClusterDetector().detect()

    if cluster_info is None or cluster_info.total_gpus == 0:
        return None

    single_node = cluster_info.node_count == 1

    # Either way only the current node's GPUs are addressed directly; for
    # multi-node runs the original notes that pjit handles distribution.
    return MultiGPUConfig(
        n_devices=cluster_info.gpus_per_node,
        shard_axis=0,
        use_pmap=single_node,  # pmap for single-node
        use_pjit=cluster_info.node_count > 1,  # pjit for multi-node
        # Larger batches for distributed runs.
        per_device_batch_size=10000 if single_node else 50000,
    )
[docs]
def create_distributed_config(cluster_info: ClusterInfo) -> dict[str, Any]:
    """Build a distributed-processing configuration for an HPC cluster.

    Parameters
    ----------
    cluster_info : ClusterInfo
        Cluster information from ClusterDetector.

    Returns
    -------
    dict
        Configuration dictionary with a fixed tier
        (``"STREAMING_CHECKPOINT"``), goal (``"ROBUST"``), tolerances of
        ``1e-6`` and multi-start enabled, plus values derived from
        ``cluster_info``: ``n_starts`` (total GPUs, capped at 20),
        ``chunk_size``, ``enable_checkpoints`` and the cluster dimensions.

    Examples
    --------
    >>> detector = ClusterDetector()
    >>> cluster_info = detector.detect()
    >>> if cluster_info:
    ...     dist_config = create_distributed_config(cluster_info)
    ...     print(f"Chunk size: {dist_config['chunk_size']}")
    """
    # Per-node GPU memory estimate. Assumes 80 GB per GPU (the larger A100
    # variant), so this is an upper estimate for 40 GB cards.
    node_gpu_memory_gb = cluster_info.gpus_per_node * 80

    # Points that fit when each point is costed at 8 * 100 = 800 bytes.
    memory_based_points = int(node_gpu_memory_gb * 1e9 / (8 * 100))

    # Keep chunks within [100K, 1M] points: large enough to amortize
    # communication, capped to avoid overflowing GPU memory. With the
    # estimate above, any node with at least one GPU lands on the 1M cap.
    chunk_size = min(1_000_000, max(100_000, memory_based_points))

    # Checkpoint for fault tolerance on multi-node jobs or more than 4 GPUs.
    needs_checkpoints = cluster_info.node_count > 1 or cluster_info.total_gpus > 4

    config: dict[str, Any] = {
        "tier": "STREAMING_CHECKPOINT",
        "goal": "ROBUST",
        "enable_multistart": True,
        "n_starts": min(cluster_info.total_gpus, 20),  # Scale with GPUs
        "chunk_size": chunk_size,
        "enable_checkpoints": needs_checkpoints,
        "checkpoint_frequency": 50,  # Checkpoint every 50 iterations
        "gtol": 1e-6,
        "ftol": 1e-6,
        "xtol": 1e-6,
        "distributed": True,
        "n_devices": cluster_info.total_gpus,
        "nodes": cluster_info.node_count,
        "gpus_per_node": cluster_info.gpus_per_node,
        "scheduler": cluster_info.scheduler,
    }
    return config
[docs]
def create_checkpoint_directory(base_dir: str | Path | None = None) -> str:
"""Create a checkpoint directory with timestamp.
Creates a directory at ./nlsq_checkpoints/YYYYMMDD_HHMMSS/ for storing
optimization checkpoints. Integrates with HybridStreamingConfig.enable_checkpoints.
Parameters
----------
base_dir : str or Path, optional
Base directory for checkpoints. Default: ./nlsq_checkpoints
Returns
-------
str
Absolute path to the created checkpoint directory.
Examples
--------
>>> checkpoint_dir = create_checkpoint_directory()
>>> # Returns path like './nlsq_checkpoints/20251219_143052/'
"""
if base_dir is None:
base_dir = Path.cwd() / "nlsq_checkpoints"
else:
base_dir = Path(base_dir)
# Create timestamp-based subdirectory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = base_dir / timestamp
# Create directory (including parents if needed)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
return str(checkpoint_dir)
# Public API of this module, in sorted order. Names defined above but not
# listed here (e.g. the underscore-prefixed tolerance helpers) are private.
__all__ = [
    "ClusterDetector",
    "ClusterInfo",
    "MemoryBudget",
    "MemoryBudgetSelector",
    "MultiGPUConfig",
    "OptimizationGoal",
    "calculate_adaptive_tolerances",
    "create_checkpoint_directory",
    "create_distributed_config",
    "get_multi_gpu_config",
]