nlsq.device module

GPU detection and warning utilities (System CUDA version).

This module provides runtime GPU availability checks and system CUDA detection to help users realize when GPU acceleration is available but not being used.

nlsq.device.get_system_cuda_version()[source]

Detect system CUDA version from nvcc.

Returns:

Tuple of (full_version, major_version) or (None, None) if not found. Example: (“12.6”, 12) or (“13.0”, 13)

Return type:

tuple[str | None, int | None]
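The return shape can be illustrated with a minimal sketch. The helper names `parse_nvcc_release` and `detect_cuda_version` are hypothetical, and the sketch assumes the version is read from the `release X.Y` text that `nvcc --version` prints; the module's actual parsing may differ.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(text):
    """Extract (full_version, major_version) from `nvcc --version` output."""
    match = re.search(r"release (\d+)\.(\d+)", text)
    if match is None:
        return None, None
    major, minor = match.groups()
    return f"{major}.{minor}", int(major)


def detect_cuda_version(timeout=5.0):
    """Run nvcc if it is on PATH; return (None, None) on any failure."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None, None
    try:
        result = subprocess.run(
            [nvcc, "--version"], capture_output=True, text=True, timeout=timeout
        )
    except (OSError, subprocess.TimeoutExpired):
        return None, None
    if result.returncode != 0:
        return None, None
    return parse_nvcc_release(result.stdout)


# Typical last line of `nvcc --version` output
sample = "Cuda compilation tools, release 12.6, V12.6.77"
print(parse_nvcc_release(sample))  # ('12.6', 12)
```

Keeping the parser separate from the subprocess call makes the parsing testable on machines without a CUDA toolkit.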

nlsq.device.get_gpu_info()[source]

Detect GPU name and SM version.

Returns:

Tuple of (gpu_name, sm_version) or (None, None) if not found. Example: (“NVIDIA GeForce RTX 4090”, 8.9)

Return type:

tuple[str | None, float | None]
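As a rough illustration of the returned tuple, the sketch below parses one CSV line of the form `name, compute capability`. It assumes the data comes from a query such as `nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader` (the `compute_cap` field is only available on reasonably recent drivers); `parse_gpu_line` is a hypothetical helper, not part of nlsq.

```python
def parse_gpu_line(line):
    """Parse 'GPU name, 8.9' into (gpu_name, sm_version)."""
    name, sep, cap = line.strip().rpartition(",")
    if not sep:
        # No comma: treat the whole line as the name, SM version unknown
        only_name = line.strip()
        return (only_name or None), None
    try:
        sm_version = float(cap)
    except ValueError:
        # Compute capability missing or not numeric (e.g. "[N/A]")
        sm_version = None
    return (name.strip() or None), sm_version


print(parse_gpu_line("NVIDIA GeForce RTX 4090, 8.9"))
```

Splitting on the last comma keeps the parse correct even if a GPU name ever contains a comma.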

nlsq.device.check_plugin_conflicts()[source]

Check for known JAX CUDA plugin conflicts.

Returns:

List of issue descriptions (empty means no issues detected).

Return type:

list[str]
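A minimal sketch of this kind of check is shown below. It assumes that conflicts are identified from installed distribution names (`jax-cuda12-plugin`, `jax-cuda13-plugin`) and from a mismatch with the system CUDA major version; `installed_distribution_names` and `find_plugin_conflicts` are hypothetical names, and the real function may inspect more than this.

```python
from importlib import metadata


def installed_distribution_names():
    """Collect lower-cased names of installed distributions."""
    names = set()
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if name:  # skip broken installs without a Name field
            names.add(name.lower())
    return names


def find_plugin_conflicts(installed, system_cuda_major=None):
    """Return a list of issue descriptions; empty means nothing detected."""
    names = {n.lower() for n in installed}
    majors = [m for m in (12, 13) if f"jax-cuda{m}-plugin" in names]
    issues = []
    if len(majors) > 1:
        issues.append("Both jax-cuda12-plugin and jax-cuda13-plugin are installed")
    if system_cuda_major is not None:
        for m in majors:
            if m != system_cuda_major:
                issues.append(
                    f"jax-cuda{m}-plugin does not match system CUDA {system_cuda_major}"
                )
    return issues


# Example with an explicit package list instead of the live environment
print(find_plugin_conflicts({"jax", "jax-cuda12-plugin", "jax-cuda13-plugin"}, 12))
```

To inspect the current environment, pass `installed_distribution_names()` as the first argument.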

nlsq.device.get_recommended_package()[source]

Get recommended JAX package based on system CUDA.

Returns:

Package name like “jax[cuda12-local]” or “jax[cuda13-local]”, or None if no compatible setup found.

Return type:

str | None
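The mapping from CUDA major version to package name can be sketched as a simple lookup. This assumes only the two packages named above are ever recommended; `recommend_package` is a hypothetical stand-in, and the real function may also take the GPU's SM version or the platform into account.

```python
def recommend_package(cuda_major):
    """Map a system CUDA major version to a JAX extra, or None."""
    packages = {
        12: "jax[cuda12-local]",
        13: "jax[cuda13-local]",
    }
    # Unknown or missing versions fall through to None
    return packages.get(cuda_major)


print(recommend_package(12))  # jax[cuda12-local]
```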

nlsq.device.check_gpu_availability(warn=True)[source]

Check if GPU is available but not being used by JAX.

Prints a helpful warning if GPU hardware and system CUDA are detected but JAX is running in CPU-only mode.

Parameters:

warn (bool, default=True) – If True, print warning when GPU available but not used.

Returns:

True if GPU is being used by JAX, False otherwise.

Return type:

bool

Environment Variables:

NLSQ_SKIP_GPU_CHECK (str, optional) – Set to “1”, “true”, or “yes” (case-insensitive) to suppress GPU warnings. Useful for CI/CD pipelines or users who intentionally use CPU-only JAX.

Notes

This check runs automatically on import but has minimal overhead:

  • Subprocess call to nvidia-smi (~5ms)

  • JAX device query (~1ms)

  • Only prints warning when mismatch detected

  • Silent failures prevent disruption

nlsq.device.get_device_info()[source]

Get comprehensive device information.

Returns:

Dictionary with the following keys:

  • jax_version: JAX version string

  • jax_backend: Current backend (cpu, gpu)

  • devices: List of device strings

  • gpu_count: Number of GPU devices

  • using_gpu: Boolean

  • gpu_hardware: GPU name

  • gpu_sm_version: SM version (float)

  • system_cuda_version: System CUDA version string

  • system_cuda_major: System CUDA major version (int)

  • recommended_package: Recommended JAX package

  • plugin_issues: List of detected plugin conflict descriptions

Return type:

dict
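One way to consume this dictionary is sketched below. The `summarize_device_info` helper and the sample values are hypothetical; only the key names come from the list above.

```python
def summarize_device_info(info):
    """Turn a device-info dictionary into a short, readable report."""
    lines = [
        f"JAX {info.get('jax_version')} on backend {info.get('jax_backend')}",
        f"GPU devices: {info.get('gpu_count', 0)} "
        f"(using GPU: {info.get('using_gpu', False)})",
    ]
    # Hardware is present but JAX is not using it
    if info.get("gpu_hardware") and not info.get("using_gpu"):
        lines.append(
            f"Unused hardware: {info['gpu_hardware']}; "
            f"suggested package: {info.get('recommended_package')}"
        )
    for issue in info.get("plugin_issues") or []:
        lines.append(f"Plugin issue: {issue}")
    return "\n".join(lines)


# Made-up sample values for illustration only
sample_info = {
    "jax_version": "0.0.0",
    "jax_backend": "cpu",
    "devices": ["CpuDevice(id=0)"],
    "gpu_count": 0,
    "using_gpu": False,
    "gpu_hardware": "Tesla V100-SXM2-16GB",
    "gpu_sm_version": 7.0,
    "system_cuda_version": "12.6",
    "system_cuda_major": 12,
    "recommended_package": "jax[cuda12-local]",
    "plugin_issues": [],
}
print(summarize_device_info(sample_info))
```

With nlsq installed, the same helper could be called as `summarize_device_info(get_device_info())`.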

Overview

The nlsq.device module detects NVIDIA GPU hardware and the system CUDA installation, and warns when JAX is running in CPU-only mode even though GPU acceleration is available. The warning includes installation instructions so that the unused hardware can be put to work.

Important

GPU acceleration is Linux only. On macOS and Windows, check_gpu_availability() and get_recommended_package() return early without running any subprocess detection. The CPU-only backend is enforced automatically by nlsq/__init__.py.

Key Features

  • Linux-only GPU detection via nvidia-smi hardware query

  • Platform-aware guards — non-Linux platforms skip detection entirely

  • JAX device inspection to check current compute backend

  • Plugin conflict detection for dual cuda12/cuda13 and version mismatches

  • User-friendly warnings with actionable installation instructions

  • Recommendations for GPU-enabled configurations, quoted at a 20-100x speedup

  • Silent failure handling to avoid disrupting workflow

  • Environment variable control for CI/CD and intentional CPU-only usage

  • Minimal overhead (~6ms total: 5ms nvidia-smi + 1ms JAX query)

Functions

check_gpu_availability([warn])

Check if GPU is available but not being used by JAX.

check_plugin_conflicts()

Check for known JAX CUDA plugin conflicts.

get_system_cuda_version()

Detect system CUDA version from nvcc.

get_gpu_info()

Detect GPU name and SM version.

get_recommended_package()

Get recommended JAX package based on system CUDA.

get_device_info()

Get comprehensive device information.

Usage Examples

Automatic GPU Check on Import

The GPU check runs automatically when importing NLSQ:

import nlsq  # Automatically checks GPU availability

If an NVIDIA GPU is detected but JAX is using CPU, you’ll see:

GPU AVAILABLE BUT NOT USED
===========================
  GPU: Tesla V100-SXM2-16GB (SM 7.0)
  System CUDA: 12.6
  JAX backend: CPU-only

  Fix: make install-jax-gpu
  Or:  pip uninstall -y jax-cuda13-plugin jax-cuda13-pjrt jax-cuda12-plugin jax-cuda12-pjrt
       pip uninstall -y jax jaxlib
       pip install "jax[cuda12-local]"

To suppress this warning: export NLSQ_SKIP_GPU_CHECK=1

Suppressing GPU Warnings

For CI/CD pipelines or intentional CPU-only usage:

import os

# Option 1: Set before importing NLSQ
os.environ["NLSQ_SKIP_GPU_CHECK"] = "1"
import nlsq  # No GPU warning

Or via shell:

# Option 2: Export environment variable
export NLSQ_SKIP_GPU_CHECK=1
python your_script.py

# Option 3: Inline with command
NLSQ_SKIP_GPU_CHECK=1 python your_script.py

Or via CI/CD environment variables (GitHub Actions example):

# Option 4: Add an env block to the workflow or job
env:
  NLSQ_SKIP_GPU_CHECK: "1"

Accepted values: "1", "true", "yes" (case-insensitive)
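The case-insensitive matching of these values can be sketched as follows. `skip_gpu_check` is a hypothetical helper that mirrors the documented behavior; it is not the module's own code.

```python
import os


def skip_gpu_check(environ=None):
    """Return True if NLSQ_SKIP_GPU_CHECK is set to an accepted value."""
    env = os.environ if environ is None else environ
    value = env.get("NLSQ_SKIP_GPU_CHECK", "")
    return value.lower() in {"1", "true", "yes"}


print(skip_gpu_check({"NLSQ_SKIP_GPU_CHECK": "TRUE"}))  # True
print(skip_gpu_check({"NLSQ_SKIP_GPU_CHECK": "0"}))     # False
```

Under this reading, values such as "0", "false", or an empty string leave the check enabled.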

Manual GPU Check

Call the check function directly:

from nlsq.device import check_gpu_availability

# Manually trigger GPU check
check_gpu_availability()

Note: A manual call also respects the NLSQ_SKIP_GPU_CHECK environment variable, so no warning is printed while that variable is set.

Verifying GPU Usage

Check which devices JAX is actually using:

import jax

# List all available devices
devices = jax.devices()
print(f"JAX devices: {devices}")

# Expected with GPU: [cuda(id=0)]
# Expected CPU-only: [CpuDevice(id=0)]

# Check if using GPU
using_gpu = any("cuda" in str(d).lower() or "gpu" in str(d).lower() for d in devices)
print(f"Using GPU: {using_gpu}")

Configuration

Environment Variables

NLSQ_SKIP_GPU_CHECK

Controls whether GPU availability check runs on import.

  • Values: "1", "true", "yes" (case-insensitive)

  • Default: Not set (check runs)

  • Effect: Suppresses GPU availability warning

  • Use cases: CI/CD pipelines, intentional CPU-only usage, stdout parsing

# Suppress GPU check
export NLSQ_SKIP_GPU_CHECK=1

# Or inline
NLSQ_SKIP_GPU_CHECK=1 python script.py

Performance Characteristics

Check Overhead:

  • Total time: ~6ms per import

  • nvidia-smi query: ~5ms

  • JAX device query: ~1ms

  • Print warning: <1ms (only when GPU available but unused)

When Check Runs:

  • Automatically on first import nlsq

  • Only once per Python session (not on subsequent imports)

  • Can be manually triggered with check_gpu_availability()

Failure Behavior:

The check silently fails (no error messages) when:

  • NVIDIA GPU hardware is not present

  • nvidia-smi is not installed

  • JAX is not installed yet

  • Permission denied errors

  • Timeout errors (>5 seconds)

  • Any unexpected exceptions

This design ensures the check never disrupts normal workflow.

Best Practices

  1. CI/CD Pipelines: Set NLSQ_SKIP_GPU_CHECK=1 to suppress warnings

  2. Production Deployment: Verify GPU usage with jax.devices()

  3. Development: Keep warnings enabled to catch misconfiguration

  4. Performance Testing: Compare CPU vs GPU benchmarks

  5. Documentation: Note GPU requirements in deployment guides

Common Use Cases

Testing in CI/CD Without GPU

Suppress warnings in continuous integration:

# GitHub Actions example
jobs:
  test:
    runs-on: ubuntu-latest
    env:
      NLSQ_SKIP_GPU_CHECK: "1"  # Suppress GPU warnings
    steps:
      - name: Run tests
        run: pytest tests/

Jupyter Notebooks

Reduce output clutter in notebooks:

# At top of notebook
import os

os.environ["NLSQ_SKIP_GPU_CHECK"] = "1"
import nlsq  # No warning printed

Programmatic Output Parsing

When parsing stdout programmatically:

import os
import subprocess

# Suppress GPU warnings in subprocess
env = os.environ.copy()
env["NLSQ_SKIP_GPU_CHECK"] = "1"

result = subprocess.run(
    ["python", "my_nlsq_script.py"],
    env=env,
    capture_output=True,
    text=True,
)

# Parse clean output without GPU warnings
output = result.stdout

Debug GPU Detection Issues

Manually check GPU availability:

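import subprocess

# Check if nvidia-smi works
try:
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        timeout=5,
    )
except (FileNotFoundError, subprocess.TimeoutExpired):
    # nvidia-smi is missing from PATH or did not respond in time
    result = None

if result is not None and result.returncode == 0:
    print(f"GPU detected: {result.stdout.strip()}")
else:
    print("No GPU detected or nvidia-smi not available")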

# Check JAX backend
import jax

print(f"JAX devices: {jax.devices()}")

Implementation Details

GPU Detection Algorithm:

  1. Check if NLSQ_SKIP_GPU_CHECK environment variable is set

  2. Check platform — return immediately on macOS and Windows (GPU is Linux-only)

  3. Query nvidia-smi for GPU hardware (5 second timeout)

  4. Parse GPU name from output

  5. Query JAX for current device backend

  6. Compare hardware availability vs JAX usage

  7. Print warning only if mismatch detected
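The steps above can be sketched as a pure decision function. This is an illustration under stated assumptions, not the module's implementation: `should_warn` and its parameters are hypothetical, with `gpu_name` standing in for the parsed nvidia-smi result and `jax_platforms` for the platform strings of the JAX devices.

```python
def should_warn(environ, platform, gpu_name, jax_platforms):
    """Decide whether a 'GPU available but not used' warning is due."""
    # Step 1: the user opted out via the environment variable
    if environ.get("NLSQ_SKIP_GPU_CHECK", "").lower() in {"1", "true", "yes"}:
        return False
    # Step 2: GPU support is Linux-only, so skip other platforms
    if not platform.startswith("linux"):
        return False
    # Steps 3-4: no GPU hardware was found by the nvidia-smi query
    if not gpu_name:
        return False
    # Steps 5-6: compare hardware availability with the JAX backend
    using_gpu = any(p in ("gpu", "cuda") for p in jax_platforms)
    # Step 7: warn only on a mismatch
    return not using_gpu


print(should_warn({}, "linux", "Tesla V100-SXM2-16GB", ["cpu"]))  # True
```

In a live session the arguments might be filled from `os.environ`, `sys.platform`, and `[d.platform for d in jax.devices()]`.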

Error Handling:

All exceptions are silently caught to prevent workflow disruption:

  • FileNotFoundError: nvidia-smi not installed

  • subprocess.TimeoutExpired: nvidia-smi hung

  • ImportError: JAX not installed

  • Exception: Any other unexpected error
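The silent-failure pattern listed above can be sketched as a small wrapper around `subprocess.run`. `run_quietly` is a hypothetical helper; the 5-second default mirrors the timeout stated in this document.

```python
import subprocess
import sys


def run_quietly(cmd, timeout=5.0):
    """Run a command and return its stripped stdout, or None on any failure."""
    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout
        )
    except FileNotFoundError:
        return None  # executable is not installed
    except subprocess.TimeoutExpired:
        return None  # command hung past the timeout
    except Exception:
        return None  # permission errors and anything unexpected
    if result.returncode != 0:
        return None
    return result.stdout.strip()


# A command that exists and one that does not
print(run_quietly([sys.executable, "-c", "print('ok')"]))  # ok
print(run_quietly(["nlsq-no-such-binary-xyz"]))            # None
```

Returning None for every failure mode lets the caller treat "no GPU", "no nvidia-smi", and "nvidia-smi hung" identically.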

GPU Name Sanitization:

  • Limit to 100 characters maximum

  • Convert to ASCII (replace non-ASCII with ?)

  • Prevents display issues with special characters
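A minimal sketch of these sanitization rules, using a hypothetical `sanitize_gpu_name` helper:

```python
def sanitize_gpu_name(name, max_length=100):
    """Replace non-ASCII characters with '?' and cap the length."""
    # errors="replace" substitutes one '?' per unencodable character
    ascii_name = name.encode("ascii", errors="replace").decode("ascii")
    return ascii_name[:max_length]


print(sanitize_gpu_name("Tesla\u00ae V100"))  # Tesla? V100
```

Because each non-ASCII character becomes exactly one `?`, the order of the two steps does not change the result.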

Security Considerations

Command Injection:

The module uses subprocess.run() with a fixed command list (no shell=True), preventing command injection attacks.

Timeout Protection:

nvidia-smi query has a 5-second timeout to prevent hanging.

Privilege Escalation:

No privileged operations or file writes are performed.

See Also