"""Robust matrix decomposition with multi-level fallback strategies.
This module extends the SVD fallback to provide comprehensive fallback
strategies for all matrix decompositions used in optimization.
"""
from typing import Literal, cast
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import cholesky as jax_cholesky
from jax.scipy.linalg import qr as jax_qr
from jax.scipy.linalg import svd as jax_svd
# Use NLSQ logging system
from nlsq.utils.logging import get_logger
# Import the existing SVD fallback utilities
[docs]
class RobustDecomposition:
"""Multi-level fallback for matrix decompositions.
This class provides robust matrix decomposition methods that
automatically fall back through multiple strategies if the
primary method fails. The fallback chain goes from:
1. JAX on GPU (if available)
2. JAX on CPU
3. SciPy (if available)
4. NumPy
5. Safe mode with regularization
Attributes
----------
fallback_chain : list
Ordered list of (name, method) tuples for fallback strategies
logger : logging.Logger
Logger for debugging decomposition issues
"""
[docs]
def __init__(self, enable_logging: bool = False):
"""Initialize robust decomposition handler.
Parameters
----------
enable_logging : bool
Whether to enable detailed logging of fallback attempts
"""
self.logger = get_logger("robust_decomposition")
if enable_logging:
from nlsq.utils.logging import LogLevel
self.logger.logger.setLevel(LogLevel.DEBUG)
# Build fallback chain
self.fallback_chain = [
("jax_gpu", self._jax_gpu_decomp),
("jax_cpu", self._jax_cpu_decomp),
("scipy", self._scipy_decomp),
("numpy", self._numpy_decomp),
("safe_mode", self._safe_mode_decomp),
]
# Regularization parameters
self.eps = np.finfo(np.float64).eps
self.regularization_factor = 1e-10
[docs]
def svd(
self, matrix: jnp.ndarray, full_matrices: bool = False
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Compute SVD with automatic fallback.
Parameters
----------
matrix : jnp.ndarray
Matrix to decompose
full_matrices : bool
Whether to compute full matrices
Returns
-------
U : jnp.ndarray
Left singular vectors
s : jnp.ndarray
Singular values
Vt : jnp.ndarray
Right singular vectors (transposed)
Raises
------
RuntimeError
If all decomposition methods fail
"""
for name, method in self.fallback_chain:
try:
result = method(matrix, "svd", full_matrices)
if result is not None and self._validate_svd(result):
self.logger.debug(f"SVD succeeded with {name}")
return result
except Exception as e:
self.logger.debug(f"{name} SVD failed: {e}")
continue
raise RuntimeError("All SVD methods failed")
[docs]
def qr(
self, matrix: jnp.ndarray, mode: str = "reduced"
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Compute QR decomposition with fallback.
Parameters
----------
matrix : jnp.ndarray
Matrix to decompose
mode : str
QR mode ('reduced' or 'complete')
Returns
-------
Q : jnp.ndarray
Orthogonal matrix
R : jnp.ndarray
Upper triangular matrix
Raises
------
RuntimeError
If all decomposition methods fail
"""
for name, method in self.fallback_chain:
try:
result = method(matrix, "qr", mode)
if result is not None and self._validate_qr(result):
self.logger.debug(f"QR succeeded with {name}")
return result
except Exception as e:
self.logger.debug(f"{name} QR failed: {e}")
continue
raise RuntimeError("All QR methods failed")
[docs]
def cholesky(self, matrix: jnp.ndarray, lower: bool = True) -> jnp.ndarray:
"""Compute Cholesky decomposition with fallback and regularization.
Parameters
----------
matrix : jnp.ndarray
Positive definite matrix to decompose
lower : bool
Whether to return lower triangular matrix
Returns
-------
L : jnp.ndarray
Cholesky factor (lower or upper triangular)
Raises
------
RuntimeError
If all decomposition methods fail
"""
# First ensure matrix is positive definite
matrix = self._ensure_positive_definite(matrix)
for name, method in self.fallback_chain:
try:
result = method(matrix, "cholesky", lower)
if result is not None:
self.logger.debug(f"Cholesky succeeded with {name}")
return result
except Exception as e:
self.logger.debug(f"{name} Cholesky failed: {e}")
continue
# Last resort: eigendecomposition
return self._cholesky_via_eigen(matrix, lower)
def _jax_gpu_decomp(self, matrix: jnp.ndarray, decomp_type: str, *args):
"""Try decomposition on GPU using JAX."""
try:
gpu_devices = jax.devices("gpu")
except RuntimeError:
return None
if not gpu_devices:
return None
gpu_device = gpu_devices[0]
with jax.default_device(gpu_device):
if decomp_type == "svd":
full_matrices = args[0] if args else False
U, s, Vt = jax_svd(matrix, full_matrices=full_matrices)
return U, s, Vt
elif decomp_type == "qr":
mode = args[0] if args else "reduced"
qr_result = cast(
tuple[jnp.ndarray, jnp.ndarray], jax_qr(matrix, mode=mode)
)
Q, R = qr_result
return Q, R
elif decomp_type == "cholesky":
lower = args[0] if args else True
L = jax_cholesky(matrix, lower=lower)
return L
def _jax_cpu_decomp(self, matrix: jnp.ndarray, decomp_type: str, *args):
"""Try decomposition on CPU using JAX."""
try:
cpu_device = jax.devices("cpu")[0]
except IndexError:
return None
with jax.default_device(cpu_device):
# Move data to CPU
matrix_cpu = jax.device_put(matrix, cpu_device)
if decomp_type == "svd":
full_matrices = args[0] if args else False
U, s, Vt = jax_svd(matrix_cpu, full_matrices=full_matrices)
return U, s, Vt
elif decomp_type == "qr":
mode = args[0] if args else "reduced"
qr_result = cast(
tuple[jnp.ndarray, jnp.ndarray], jax_qr(matrix_cpu, mode=mode)
)
Q, R = qr_result
return Q, R
elif decomp_type == "cholesky":
lower = args[0] if args else True
L = jax_cholesky(matrix_cpu, lower=lower)
return L
def _scipy_decomp(self, matrix: jnp.ndarray, decomp_type: str, *args):
"""Try decomposition using SciPy."""
try:
import scipy.linalg
except ImportError:
return None
# Convert to numpy
matrix_np = np.array(matrix)
if decomp_type == "svd":
full_matrices = args[0] if args else False
U, s, Vt = scipy.linalg.svd(matrix_np, full_matrices=full_matrices)
return jnp.array(U), jnp.array(s), jnp.array(Vt)
elif decomp_type == "qr":
mode = args[0] if args else "reduced"
Q, R = scipy.linalg.qr(matrix_np, mode=mode)
return jnp.array(Q), jnp.array(R)
elif decomp_type == "cholesky":
lower = args[0] if args else True
L = scipy.linalg.cholesky(matrix_np, lower=lower)
return jnp.array(L)
def _numpy_decomp(self, matrix: jnp.ndarray, decomp_type: str, *args):
"""Try decomposition using NumPy."""
# Convert to numpy
matrix_np = np.array(matrix)
if decomp_type == "svd":
full_matrices = args[0] if args else False
U, s, Vt = np.linalg.svd(matrix_np, full_matrices=full_matrices)
return jnp.array(U), jnp.array(s), jnp.array(Vt)
elif decomp_type == "qr":
mode = args[0] if args else "reduced"
mode = cast(Literal["reduced", "complete", "r", "raw"], mode)
Q, R = np.linalg.qr(matrix_np, mode=mode)
return jnp.array(Q), jnp.array(R)
elif decomp_type == "cholesky":
L = np.linalg.cholesky(matrix_np)
return jnp.array(L)
def _safe_mode_decomp(self, matrix: jnp.ndarray, decomp_type: str, *args):
"""Safe mode decomposition with regularization."""
m, n = matrix.shape
if decomp_type == "svd":
# Add regularization to improve conditioning
reg_matrix = matrix + self.regularization_factor * jnp.eye(m, n)
return self._numpy_decomp(reg_matrix, "svd", *args)
elif decomp_type == "qr":
# QR with column pivoting for stability
matrix_np = np.array(matrix)
try:
import scipy.linalg
Q, R, P = scipy.linalg.qr(matrix_np, mode="economic", pivoting=True)
# Reorder to undo pivoting
R_reordered = R[:, np.argsort(P)]
return jnp.array(Q), jnp.array(R_reordered)
except (np.linalg.LinAlgError, ValueError, ImportError):
# Fall back to basic QR with regularization
reg_matrix = matrix + self.eps * jnp.eye(m, n)
return self._numpy_decomp(reg_matrix, "qr", *args)
elif decomp_type == "cholesky":
# Ensure positive definite with stronger regularization
matrix = self._ensure_positive_definite(
matrix, factor=self.regularization_factor * 100
)
return self._numpy_decomp(matrix, "cholesky", *args)
def _ensure_positive_definite(
self, matrix: jnp.ndarray, factor: float | None = None
) -> jnp.ndarray:
"""Make matrix positive definite.
Parameters
----------
matrix : jnp.ndarray
Matrix to make positive definite
factor : float, optional
Regularization factor (uses default if None)
Returns
-------
matrix_pd : jnp.ndarray
Positive definite matrix
"""
if factor is None:
factor = self.regularization_factor
n = matrix.shape[0]
# Ensure symmetry
matrix = 0.5 * (matrix + matrix.T)
try:
# Check minimum eigenvalue
eigenvalues = jnp.linalg.eigvalsh(matrix)
min_eig = jnp.min(eigenvalues)
if min_eig < factor:
# Add diagonal to ensure positive definiteness
shift = factor - min_eig + self.eps
matrix = matrix + shift * jnp.eye(n)
except (np.linalg.LinAlgError, ValueError):
# Fallback: add diagonal regularization
matrix = matrix + factor * jnp.eye(n)
return matrix
def _cholesky_via_eigen(
self, matrix: jnp.ndarray, lower: bool = True
) -> jnp.ndarray:
"""Compute Cholesky via eigendecomposition.
Parameters
----------
matrix : jnp.ndarray
Positive definite matrix
lower : bool
Whether to return lower triangular
Returns
-------
L : jnp.ndarray
Cholesky factor
"""
try:
# Eigendecomposition
eigenvalues, eigenvectors = jnp.linalg.eigh(matrix)
# Ensure all eigenvalues are positive
eigenvalues = jnp.maximum(eigenvalues, self.eps)
# Reconstruct: A = V * diag(lambda) * V^T
# So L = V * sqrt(diag(lambda))
L = eigenvectors @ jnp.diag(jnp.sqrt(eigenvalues))
if lower:
return L
else:
return L.T
except Exception as e:
raise RuntimeError(f"Cholesky via eigendecomposition failed: {e}") from e
def _validate_svd(self, result: tuple) -> bool:
"""Validate SVD result.
Parameters
----------
result : tuple
(U, s, Vt) from SVD
Returns
-------
valid : bool
Whether the result is valid
"""
try:
U, s, Vt = result
return (
bool(jnp.all(jnp.isfinite(U)))
and bool(jnp.all(jnp.isfinite(s)))
and bool(jnp.all(jnp.isfinite(Vt)))
and bool(jnp.all(s >= 0)) # Singular values must be non-negative
)
except (ValueError, TypeError, AttributeError):
return False
def _validate_qr(self, result: tuple) -> bool:
"""Validate QR result.
Parameters
----------
result : tuple
(Q, R) from QR decomposition
Returns
-------
valid : bool
Whether the result is valid
"""
try:
Q, R = result
return bool(jnp.all(jnp.isfinite(Q))) and bool(jnp.all(jnp.isfinite(R)))
except (ValueError, TypeError, AttributeError):
return False
[docs]
def solve_least_squares(self, A: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
"""Solve least squares problem with robust decomposition.
Parameters
----------
A : jnp.ndarray
Coefficient matrix
b : jnp.ndarray
Right-hand side
Returns
-------
x : jnp.ndarray
Solution vector
"""
try:
# Try SVD first (most stable)
U, s, Vt = self.svd(A, full_matrices=False)
# Compute pseudoinverse solution
# x = V @ (S^+ @ (U^T @ b))
s_inv = jnp.where(s > self.eps * jnp.max(s), 1.0 / s, 0.0)
x = Vt.T @ (s_inv * (U.T @ b))
return x
except Exception as e:
self.logger.warning(f"SVD failed for least squares, trying QR: {e}")
try:
# Fall back to QR
Q, R = self.qr(A)
# Solve R @ x = Q^T @ b
y = Q.T @ b
x = jnp.linalg.solve(R, y)
return x
except Exception as e2:
self.logger.warning(
f"QR failed for least squares, using normal equations: {e2}"
)
# Last resort: normal equations (less stable)
AtA = A.T @ A
Atb = A.T @ b
# Regularize if needed
AtA = self._ensure_positive_definite(AtA)
try:
L = self.cholesky(AtA)
# Solve L @ L^T @ x = A^T @ b
y = jnp.linalg.solve(L, Atb)
x = jnp.linalg.solve(L.T, y)
return x
except (np.linalg.LinAlgError, ValueError):
# Ultimate fallback: direct solve with regularization
x = jnp.linalg.solve(AtA, Atb)
return x
# Create global instance
robust_decomp = RobustDecomposition()