nlsq.device module¶
GPU detection and warning utilities, including system CUDA version detection.
This module provides runtime GPU availability checks and system CUDA detection to help users realize when GPU acceleration is available but not being used.
- nlsq.device.get_system_cuda_version()
Detect system CUDA version from nvcc.
- nlsq.device.get_gpu_info()
Detect GPU name and SM version.
- nlsq.device.check_plugin_conflicts()
Check for known JAX CUDA plugin conflicts.
- nlsq.device.get_recommended_package()
Get recommended JAX package based on system CUDA.
- Returns:
Package name like “jax[cuda12-local]” or “jax[cuda13-local]”, or None if no compatible setup found.
- Return type:
str | None
- nlsq.device.check_gpu_availability(warn=True)
Check if GPU is available but not being used by JAX.
Prints a helpful warning if GPU hardware and system CUDA are detected but JAX is running in CPU-only mode.
- Parameters:
warn (bool, default=True) – If True, print warning when GPU available but not used.
- Returns:
bool – True if GPU is being used by JAX, False otherwise.
Environment Variables
NLSQ_SKIP_GPU_CHECK (str, optional) – Set to “1”, “true”, or “yes” (case-insensitive) to suppress GPU warnings. Useful for CI/CD pipelines or users who intentionally use CPU-only JAX.
- Return type:
bool
Notes
This check runs automatically on import, with minimal overhead and no disruption:
- Subprocess call to nvidia-smi (~5ms)
- JAX device query (~1ms)
- Only prints a warning when a mismatch is detected
- Silent failures prevent disruption
- nlsq.device.get_device_info()
Get comprehensive device information.
- Returns:
Dictionary with:
  - jax_version: JAX version string
  - jax_backend: Current backend (cpu, gpu)
  - devices: List of device strings
  - gpu_count: Number of GPU devices
  - using_gpu: Boolean, whether JAX is using a GPU
  - gpu_hardware: GPU name
  - gpu_sm_version: SM version (float)
  - system_cuda_version: System CUDA version string
  - system_cuda_major: System CUDA major version (int)
  - recommended_package: Recommended JAX package
  - plugin_issues: List of detected plugin conflict descriptions
- Return type:
dict
Overview¶
The nlsq.device module provides GPU detection and warning utilities that alert
users when an NVIDIA GPU is present but JAX is running on the CPU, so that
available hardware acceleration is not silently left unused.
Important
GPU acceleration is Linux only. On macOS and Windows,
check_gpu_availability() and get_recommended_package() return
early without running any subprocess detection. The CPU-only backend
is enforced automatically by nlsq/__init__.py.
Key Features¶
Linux-only GPU detection via nvidia-smi hardware query
Platform-aware guards — non-Linux platforms skip detection entirely
JAX device inspection to check current compute backend
Plugin conflict detection for dual cuda12/cuda13 and version mismatches
User-friendly warnings with actionable installation instructions
20-100x speedup recommendations for GPU-enabled configurations
Silent failure handling to avoid disrupting workflow
Environment variable control for CI/CD and intentional CPU-only usage
Minimal overhead (~6ms total: 5ms nvidia-smi + 1ms JAX query)
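One way to detect the dual cuda12/cuda13 plugin conflict listed above is to ask the installed package metadata which plugin distributions are present. The sketch below assumes the distribution names `jax-cuda12-plugin` and `jax-cuda13-plugin` (the names used in the uninstall command later on this page), and it additionally assumes that a plugin whose version differs from `jaxlib` counts as a mismatch. `installed_versions` and `find_plugin_conflicts` are hypothetical helpers, not nlsq's `check_plugin_conflicts()`.

```python
from __future__ import annotations

from collections.abc import Iterable
from importlib import metadata

# Distribution names taken from the uninstall command shown in the GPU warning.
PLUGIN_DISTS = ("jax-cuda12-plugin", "jax-cuda13-plugin")


def installed_versions(names: Iterable[str]) -> dict[str, str]:
    """Return {distribution: version} for each name that is installed."""
    found: dict[str, str] = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # not installed: nothing to report
    return found


def find_plugin_conflicts(installed: dict[str, str]) -> list[str]:
    """Describe conflicts in a {distribution: version} mapping."""
    issues: list[str] = []
    # Conflict 1: both the CUDA 12 and CUDA 13 plugins are present at once.
    if all(name in installed for name in PLUGIN_DISTS):
        issues.append("Both jax-cuda12-plugin and jax-cuda13-plugin are installed")
    # Conflict 2 (assumption): a plugin whose version differs from jaxlib's.
    jaxlib_version = installed.get("jaxlib")
    for name in PLUGIN_DISTS:
        version = installed.get(name)
        if jaxlib_version and version and version != jaxlib_version:
            issues.append(f"{name} {version} does not match jaxlib {jaxlib_version}")
    return issues


if __name__ == "__main__":
    print(find_plugin_conflicts(installed_versions(PLUGIN_DISTS + ("jaxlib",))))
```

Keeping the metadata lookup separate from the conflict rules makes the rules easy to test with a hand-written mapping.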
Functions¶
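| Function | Description |
|---|---|
| check_gpu_availability(warn=True) | Check if GPU is available but not being used by JAX. |
| check_plugin_conflicts() | Check for known JAX CUDA plugin conflicts. |
| get_system_cuda_version() | Detect system CUDA version from nvcc. |
| get_gpu_info() | Detect GPU name and SM version. |
| get_recommended_package() | Get recommended JAX package based on system CUDA. |
| get_device_info() | Get comprehensive device information. |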
Usage Examples¶
Automatic GPU Check on Import¶
The GPU check runs automatically when importing NLSQ:
import nlsq # Automatically checks GPU availability
If an NVIDIA GPU is detected but JAX is using CPU, you’ll see:
GPU AVAILABLE BUT NOT USED
===========================
GPU: Tesla V100-SXM2-16GB (SM 7.0)
System CUDA: 12.6
JAX backend: CPU-only
Fix: make install-jax-gpu
Or: pip uninstall -y jax-cuda13-plugin jax-cuda13-pjrt jax-cuda12-plugin jax-cuda12-pjrt
pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]"
To suppress this warning: export NLSQ_SKIP_GPU_CHECK=1
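In the example warning, a system CUDA of 12.6 leads to the suggestion `jax[cuda12-local]`. A minimal sketch of that mapping follows. It assumes `nvcc --version` prints a line containing `release MAJOR.MINOR` (as in `Cuda compilation tools, release 12.6, V12.6.77`) and, per the `get_recommended_package()` description, that only CUDA 12 and 13 map to a package. `parse_nvcc_version`, `recommend_package`, and `detect_system_cuda_version` are hypothetical helpers, not nlsq's implementation.

```python
from __future__ import annotations

import re
import shutil
import subprocess

# Assumed nvcc output fragment: "release 12.6, V12.6.77".
_RELEASE_RE = re.compile(r"release (\d+)\.(\d+)")


def parse_nvcc_version(output: str) -> str | None:
    """Extract 'MAJOR.MINOR' from `nvcc --version` output, or None."""
    match = _RELEASE_RE.search(output)
    if match is None:
        return None
    return f"{match.group(1)}.{match.group(2)}"


def recommend_package(cuda_version: str | None) -> str | None:
    """Map a system CUDA version to a JAX extra (assumes only 12 and 13 are supported)."""
    if cuda_version is None:
        return None
    major = int(cuda_version.split(".")[0])
    if major in (12, 13):
        return f"jax[cuda{major}-local]"
    return None


def detect_system_cuda_version() -> str | None:
    """Run nvcc if it is on PATH; return None silently on any failure."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    try:
        result = subprocess.run(
            [nvcc, "--version"], capture_output=True, text=True, timeout=5
        )
    except (OSError, subprocess.TimeoutExpired):
        return None
    return parse_nvcc_version(result.stdout)


if __name__ == "__main__":
    version = detect_system_cuda_version()
    print(f"System CUDA: {version}, recommended: {recommend_package(version)}")
```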
Suppressing GPU Warnings¶
For CI/CD pipelines or intentional CPU-only usage:
import os
# Option 1: Set before importing NLSQ
os.environ["NLSQ_SKIP_GPU_CHECK"] = "1"
import nlsq # No GPU warning
Or via shell:
# Option 2: Export environment variable
export NLSQ_SKIP_GPU_CHECK=1
python your_script.py
# Option 3: Inline with command
NLSQ_SKIP_GPU_CHECK=1 python your_script.py
# Option 4: Add to CI/CD environment variables
# GitHub Actions example:
env:
  NLSQ_SKIP_GPU_CHECK: "1"
Accepted values: "1", "true", "yes" (case-insensitive)
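A minimal sketch of how these accepted values could be interpreted, assuming any other value (including an empty string or "0") leaves the check enabled. `skip_gpu_check_requested` is a hypothetical helper; passing a mapping instead of reading `os.environ` directly makes it easy to test.

```python
import os

# Accepted values per the documentation, compared case-insensitively.
_TRUTHY = {"1", "true", "yes"}


def skip_gpu_check_requested(environ=None) -> bool:
    """Return True if NLSQ_SKIP_GPU_CHECK asks to suppress the GPU warning."""
    env = os.environ if environ is None else environ
    return env.get("NLSQ_SKIP_GPU_CHECK", "").lower() in _TRUTHY


if __name__ == "__main__":
    print(f"Skip GPU check: {skip_gpu_check_requested()}")
```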
Manual GPU Check¶
Call the check function directly:
from nlsq.device import check_gpu_availability
# Manually trigger GPU check
check_gpu_availability()
Note: This will respect the NLSQ_SKIP_GPU_CHECK environment variable.
Verifying GPU Usage¶
Check which devices JAX is actually using:
import jax
# List all available devices
devices = jax.devices()
print(f"JAX devices: {devices}")
# Expected with GPU: [cuda(id=0)]
# Expected CPU-only: [CpuDevice(id=0)]
# Check if using GPU
using_gpu = any("cuda" in str(d).lower() or "gpu" in str(d).lower() for d in devices)
print(f"Using GPU: {using_gpu}")
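Matching on the device string works, but JAX also exposes the platform directly: `jax.default_backend()` returns the name of the default backend (for example `"cpu"` or `"gpu"`), and each device has a `platform` attribute, which is `"gpu"` for CUDA devices. The sketch below uses those instead of string matching and, following this module's silent-failure design, returns `False`/`0` when JAX is not installed. `jax_is_using_gpu` and `gpu_device_count` are hypothetical names.

```python
def jax_is_using_gpu() -> bool:
    """True if JAX's default backend is a GPU; False otherwise or if JAX is missing."""
    try:
        import jax
    except ImportError:
        return False  # silent failure, mirroring the module's design
    return jax.default_backend() == "gpu"


def gpu_device_count() -> int:
    """Number of GPU devices on JAX's default backend (0 if JAX is missing)."""
    try:
        import jax
    except ImportError:
        return 0
    # jax.devices() with no argument lists the default backend's devices.
    return sum(1 for d in jax.devices() if d.platform == "gpu")


if __name__ == "__main__":
    print(f"Using GPU: {jax_is_using_gpu()} ({gpu_device_count()} GPU device(s))")
```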
Configuration¶
Environment Variables¶
NLSQ_SKIP_GPU_CHECK
Controls whether the GPU availability check runs on import.
Values: "1", "true", "yes" (case-insensitive)
Default: Not set (check runs)
Effect: Suppresses the GPU availability warning
Use cases: CI/CD pipelines, intentional CPU-only usage, stdout parsing
# Suppress GPU check
export NLSQ_SKIP_GPU_CHECK=1
# Or inline
NLSQ_SKIP_GPU_CHECK=1 python script.py
Performance Characteristics¶
Check Overhead:
Total time: ~6ms per import
nvidia-smi query: ~5ms
JAX device query: ~1ms
Print warning: <1ms (only when GPU available but unused)
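These figures are approximate and depend on the machine. To measure them on your own system, a sketch like the following times the two queries with `time.perf_counter()`. It reuses the `nvidia-smi --query-gpu=name --format=csv,noheader` query from the debugging example on this page and records `None` for a step that cannot run; `time_call` and the query helpers are hypothetical names.

```python
import subprocess
import time


def time_call(fn):
    """Return (elapsed_ms, result); result is None if fn raised."""
    start = time.perf_counter()
    try:
        result = fn()
    except Exception:
        result = None  # e.g. nvidia-smi or JAX not installed
    return (time.perf_counter() - start) * 1000.0, result


def query_nvidia_smi():
    return subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        timeout=5,
    ).stdout.strip()


def query_jax_devices():
    import jax  # the first call also pays for the import and backend setup
    return jax.devices()


if __name__ == "__main__":
    for label, fn in [("nvidia-smi", query_nvidia_smi), ("JAX devices", query_jax_devices)]:
        ms, result = time_call(fn)
        print(f"{label}: {ms:.1f} ms -> {result}")
```

The first JAX measurement includes `import jax` and backend initialization, so it can be much slower than the ~1ms quoted above; timing a second call isolates the device query itself.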
When Check Runs:
Automatically on the first import nlsq
Only once per Python session (not on subsequent imports)
Can be triggered manually with check_gpu_availability()
Failure Behavior:
The check silently fails (no error messages) when:
NVIDIA GPU hardware is not present
nvidia-smi is not installed
JAX is not installed yet
Permission denied errors
Timeout errors (>5 seconds)
Any unexpected exceptions
This design ensures the check never disrupts normal workflow.
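A sketch of a hardware query that follows these failure rules, assuming the same nvidia-smi command as in the debugging example on this page. `query_gpu_name` is a hypothetical name, and treating a non-zero exit code as "no GPU" is an assumption.

```python
from __future__ import annotations

import subprocess


def query_gpu_name(timeout: float = 5.0) -> str | None:
    """Return the first GPU name reported by nvidia-smi, or None on any failure."""
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except (FileNotFoundError, PermissionError, subprocess.TimeoutExpired):
        return None  # not installed, not executable, or hung
    except Exception:
        return None  # any other unexpected error: stay silent
    if result.returncode != 0:
        return None  # nvidia-smi ran but reported an error (e.g. no driver)
    lines = result.stdout.strip().splitlines()
    return lines[0].strip() if lines else None


if __name__ == "__main__":
    print(query_gpu_name() or "No NVIDIA GPU detected")
```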
Best Practices¶
CI/CD Pipelines: Set NLSQ_SKIP_GPU_CHECK=1 to suppress warnings
Production Deployment: Verify GPU usage with jax.devices()
Development: Keep warnings enabled to catch misconfiguration
Performance Testing: Compare CPU vs GPU benchmarks
Documentation: Note GPU requirements in deployment guides
Common Use Cases¶
Testing in CI/CD Without GPU¶
Suppress warnings in continuous integration:
# GitHub Actions example
jobs:
  test:
    runs-on: ubuntu-latest
    env:
      NLSQ_SKIP_GPU_CHECK: "1"  # Suppress GPU warnings
    steps:
      - name: Run tests
        run: pytest tests/
Jupyter Notebooks¶
Reduce output clutter in notebooks:
# At top of notebook
import os
os.environ["NLSQ_SKIP_GPU_CHECK"] = "1"
import nlsq # No warning printed
Programmatic Output Parsing¶
When parsing stdout programmatically:
import os
import subprocess
# Suppress GPU warnings in subprocess
env = os.environ.copy()
env["NLSQ_SKIP_GPU_CHECK"] = "1"
result = subprocess.run(
["python", "my_nlsq_script.py"],
env=env,
capture_output=True,
text=True,
)
# Parse clean output without GPU warnings
output = result.stdout
Debug GPU Detection Issues¶
Manually check GPU availability:
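import subprocess
# Check if nvidia-smi works (subprocess.run raises FileNotFoundError if it is not installed)
try:
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        timeout=5,
    )
except (FileNotFoundError, subprocess.TimeoutExpired):
    result = None
if result is not None and result.returncode == 0:
    print(f"GPU detected: {result.stdout.strip()}")
else:
    print("No GPU detected or nvidia-smi not available")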
# Check JAX backend
import jax
print(f"JAX devices: {jax.devices()}")
Implementation Details¶
GPU Detection Algorithm:
1. Check whether the NLSQ_SKIP_GPU_CHECK environment variable is set
2. Check the platform; return immediately on macOS and Windows (GPU is Linux-only)
3. Query nvidia-smi for GPU hardware (5 second timeout)
4. Parse GPU name from output
5. Query JAX for current device backend
6. Compare hardware availability vs JAX usage
7. Print warning only if mismatch detected
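The decision at the heart of these steps can be separated from the I/O. The sketch below takes already-gathered inputs (the environment, `sys.platform`, the GPU name from nvidia-smi, and whether JAX is on a GPU) and returns the warning text or `None`. The function name `gpu_warning` and the exact message are hypothetical, loosely modelled on the warning shown earlier on this page.

```python
from __future__ import annotations

_TRUTHY = {"1", "true", "yes"}


def gpu_warning(
    environ: dict[str, str],
    platform: str,
    gpu_name: str | None,
    jax_using_gpu: bool,
) -> str | None:
    """Return a warning message if a GPU is present but unused, else None."""
    # Step 1: honour the opt-out environment variable.
    if environ.get("NLSQ_SKIP_GPU_CHECK", "").lower() in _TRUTHY:
        return None
    # Step 2: GPU acceleration is Linux-only; skip detection elsewhere.
    if not platform.startswith("linux"):
        return None
    # Steps 3-4: no GPU name means nvidia-smi found nothing (or failed silently).
    if not gpu_name:
        return None
    # Steps 5-6: warn only on a mismatch between hardware and JAX backend.
    if jax_using_gpu:
        return None
    # Step 7: build the message; printing is left to the caller.
    return (
        "GPU AVAILABLE BUT NOT USED\n"
        f"GPU: {gpu_name}\n"
        "JAX backend: CPU-only\n"
        "To suppress this warning: export NLSQ_SKIP_GPU_CHECK=1"
    )
```

In a real check, `gpu_name` would come from nvidia-smi and `jax_using_gpu` from a JAX device query, as in the Debug GPU Detection Issues example; keeping the decision pure makes every branch testable without a GPU.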
Error Handling:
All exceptions are silently caught to prevent workflow disruption:
FileNotFoundError: nvidia-smi not installed
subprocess.TimeoutExpired: nvidia-smi hung
ImportError: JAX not installed
Exception: Any other unexpected error
GPU Name Sanitization:
Limit to 100 characters maximum
Convert to ASCII (replace non-ASCII characters with ?)
This prevents display issues with special characters.
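The two sanitization rules can be expressed in one line each. This is a sketch assuming truncation and ASCII replacement are the only transformations applied; `sanitize_gpu_name` is a hypothetical name.

```python
MAX_GPU_NAME_LEN = 100


def sanitize_gpu_name(name: str) -> str:
    """Truncate to 100 characters and replace non-ASCII characters with '?'."""
    truncated = name[:MAX_GPU_NAME_LEN]
    # errors="replace" substitutes '?' for each character ASCII cannot encode.
    return truncated.encode("ascii", errors="replace").decode("ascii")
```

Because each character that cannot be encoded becomes exactly one `?`, truncating before or after the conversion gives a result of the same length.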
Security Considerations¶
Command Injection:
The module uses subprocess.run() with a fixed command list (no shell=True),
preventing command injection attacks.
Timeout Protection:
nvidia-smi query has a 5-second timeout to prevent hanging.
Privilege Escalation:
No privileged operations or file writes are performed.
See Also¶
nlsq.config module : Configuration management
Your First Curve Fit : Getting started with GPU setup
Performance Optimization Guide : Performance optimization guide
JAX Installation Guide : JAX GPU setup