# Source code for nlsq.core.trf_jit

"""JIT-compiled functions for Trust Region Reflective optimization.

This module contains JAX JIT-compiled helper functions for the TRF algorithm,
providing GPU/TPU-accelerated implementations of core mathematical operations.

XLA Memory Optimization
-----------------------
All JIT functions are defined at module level as singletons. This ensures each
function is compiled only once per unique combination of input shapes and
dtypes, regardless of how many TrustRegionJITFunctions instances are created.
Previously, each instance created 10+ new @jit closures, causing unbounded XLA
compilation cache growth.
"""

from __future__ import annotations

import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.linalg import svd as jax_svd

from nlsq.stability.svd_fallback import compute_svd_with_fallback, is_gpu_error

# Public API: only the class facade is exported; the module-level JIT helpers
# below are underscore-prefixed and considered private.
__all__ = ["TrustRegionJITFunctions"]

# ---- Algorithm constants -------------------------------------------------

# Scale factor in the least-squares objective 0.5 * ||f||^2.
LOSS_FUNCTION_COEFF = 0.5

# Magnitudes at or below this are treated as numerically zero (used as a
# floor for the CG curvature term p^T A p).
NUMERICAL_ZERO_THRESHOLD = 1.0e-14

# Default tolerance for the iterative (CG) solver; per the original note it
# matches the outer ftol/gtol/xtol defaults.
DEFAULT_TOLERANCE = 1.0e-8

# Hard cap on CG iterations. Keeping it a fixed constant (rather than deriving
# it from the problem size) avoids shape-dependent recompilation; the
# while_loop convergence check still allows early exit.
CG_MAX_ITERATIONS = 100


# ---------------------------------------------------------------------------
# Module-level JIT-compiled functions (singletons — compiled once per shape)
# ---------------------------------------------------------------------------


@jit
def _default_loss_func(f: jnp.ndarray) -> jnp.ndarray:
    """Return the default least-squares loss ``LOSS_FUNCTION_COEFF * (f . f)``.

    With the module constant at 0.5 this is ``0.5 * ||f||^2`` for a real
    residual vector ``f``.
    """
    squared_norm = jnp.dot(f, f)
    return LOSS_FUNCTION_COEFF * squared_norm


@jit
def _compute_grad(J: jnp.ndarray, f: jnp.ndarray) -> jnp.ndarray:
    """Compute gradient of loss function: f^T J."""
    return f.dot(J)


@jit
def _compute_grad_hat(g: jnp.ndarray, d: jnp.ndarray) -> jnp.ndarray:
    """Compute gradient in hat space: d * g."""
    return d * g


@jit
def _svd_no_bounds_jit(
    J: jnp.ndarray, d: jnp.ndarray, f: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    J_h = J * d
    U, s, Vt = jax_svd(J_h, full_matrices=False)
    uf = U.T.dot(f)
    return J_h, U, s, Vt.T, uf


@jit
def _svd_bounds_jit(
    f: jnp.ndarray,
    J: jnp.ndarray,
    d: jnp.ndarray,
    J_diag: jnp.ndarray,
    f_zeros: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    J_h = J * d
    J_augmented = jnp.concatenate([J_h, J_diag])
    f_augmented = jnp.concatenate([f, f_zeros])
    U, s, Vt = jax_svd(J_augmented, full_matrices=False)
    uf = U.T.dot(f_augmented)
    return J_h, U, s, Vt.T, uf


@jit
def _conjugate_gradient_solve(
    J: jnp.ndarray,
    f: jnp.ndarray,
    d: jnp.ndarray,
    alpha: float = 0.0,
    tol: float = DEFAULT_TOLERANCE,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Solve (J_s^T J_s + alpha*I) p = -J_s^T f with CG, where J_s = J * d.

    The iteration runs inside ``jax.lax.while_loop`` so the whole solve is a
    single compiled XLA loop.

    Parameters
    ----------
    J : jnp.ndarray
        Jacobian matrix (m x n).
    f : jnp.ndarray
        Residual vector (m,).
    d : jnp.ndarray
        Scaling diagonal (n,); column j of ``J`` is multiplied by ``d[j]``.
    alpha : float
        Regularization parameter added to the diagonal of ``J_s^T J_s``.
    tol : float
        Relative convergence tolerance: iteration stops once
        ``||r_k|| <= tol * ||r_0||`` with ``r_0 = -J_s^T f``.

    Returns
    -------
    p : jnp.ndarray
        Solution vector (n,), in the scaled variables.
    residual_norm : jnp.ndarray
        Final (absolute) residual norm ``||r_k||``.
    n_iter : jnp.ndarray
        Number of CG iterations performed (integer scalar array).

    Notes
    -----
    Uses the fixed CG_MAX_ITERATIONS cap to prevent shape-dependent XLA
    recompilation; the convergence check lets the loop exit earlier (for an
    SPD system CG converges in at most n iterations in exact arithmetic).

    The stopping test is relative to the initial residual. The previous
    absolute test (``||r_k||^2 < tol^2``) was scale-dependent: whenever
    ``||J_s^T f|| < tol`` it returned ``p = 0`` without iterating, and for
    large-magnitude residuals it could never be met. If ``J_s^T f`` is
    exactly zero, ``p = 0`` is returned after zero iterations.
    """
    n = J.shape[1]

    J_scaled = J * d[None, :]
    b = -J_scaled.T @ f

    x0 = jnp.zeros(n, dtype=b.dtype)
    r0 = b
    p0 = r0
    rsold0 = jnp.dot(r0, r0)
    # ||r_k|| <= tol * ||r_0||  <=>  ||r_k||^2 <= tol^2 * ||r_0||^2.
    # Strict ">" in cond_fn makes b == 0 (threshold 0) exit immediately.
    stop_threshold = (tol * tol) * rsold0

    init_state = (x0, r0, p0, rsold0, 0)

    def cond_fn(state: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, int]):
        _x, _r, _p, rsold, i = state
        return (i < CG_MAX_ITERATIONS) & (rsold > stop_threshold)

    def body_fn(state: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, int]):
        x, r, p, rsold, i = state

        # A p = J_s^T (J_s p) + alpha p, without forming J_s^T J_s.
        Jp = J_scaled @ p
        JTJp = J_scaled.T @ Jp
        Ap = JTJp + alpha * p

        # Floor the curvature term to avoid division by ~0.
        pAp = jnp.dot(p, Ap)
        safe_pAp = jnp.where(
            pAp > NUMERICAL_ZERO_THRESHOLD, pAp, NUMERICAL_ZERO_THRESHOLD
        )
        alpha_cg = rsold / safe_pAp

        x_new = x + alpha_cg * p
        r_new = r - alpha_cg * Ap
        rsnew = jnp.dot(r_new, r_new)

        # Guard beta = rsnew / rsold against a vanishing denominator.
        safe_rsold = jnp.where(rsold > 1e-30, rsold, 1.0)
        beta = jnp.where(rsold > 1e-30, rsnew / safe_rsold, 0.0)
        p_new = r_new + beta * p

        return (x_new, r_new, p_new, rsnew, i + 1)

    final_state = lax.while_loop(cond_fn, body_fn, init_state)
    x_final, _r_final, _p_final, rsold_final, n_iter = final_state

    return x_final, jnp.sqrt(rsold_final), n_iter


@jit
def _solve_tr_subproblem_cg(
    J: jnp.ndarray,
    f: jnp.ndarray,
    d: jnp.ndarray,
    Delta: float,
    alpha: float = 0.0,
) -> jnp.ndarray:
    """Approximate trust-region step via conjugate gradient (no bounds).

    A Gauss-Newton step is first obtained by CG with zero regularization.
    If its norm is within the trust radius ``Delta`` it is returned as is.
    Otherwise CG is rerun with regularization ``alpha`` and the resulting
    step is shrunk onto the trust-region boundary if it still exceeds
    ``Delta``. The step lives in the scaled variables (columns of ``J``
    multiplied by ``d``).
    """
    gauss_newton_step = _conjugate_gradient_solve(J, f, d, 0.0)[0]
    fits_in_region = jnp.linalg.norm(gauss_newton_step) <= Delta

    def _keep_gauss_newton():
        return gauss_newton_step

    def _regularized_step():
        reg_step = _conjugate_gradient_solve(J, f, d, alpha)[0]
        # Floor the norm so the division below never hits zero.
        reg_norm = jnp.maximum(jnp.linalg.norm(reg_step), 1e-10)
        # Use the step directly when it fits, otherwise project onto ||p|| = Delta.
        shrink = jnp.where(reg_norm <= Delta, 1.0, Delta / reg_norm)
        return shrink * reg_step

    return lax.cond(fits_in_region, _keep_gauss_newton, _regularized_step)


@jit
def _solve_tr_subproblem_cg_bounds(
    J: jnp.ndarray,
    f: jnp.ndarray,
    d: jnp.ndarray,
    J_diag: jnp.ndarray,
    f_zeros: jnp.ndarray,
    Delta: float,
    alpha: float = 0.0,
) -> jnp.ndarray:
    """Approximate trust-region step via conjugate gradient (bounded case).

    The column-scaled Jacobian ``J * d`` is stacked on top of ``J_diag`` and
    ``f`` is padded with ``f_zeros``; CG then runs on this augmented system
    with unit scaling. The step-selection logic is the same as the unbounded
    variant: keep the Gauss-Newton step if it fits inside ``Delta``,
    otherwise use the ``alpha``-regularized step, shrunk to the boundary
    when necessary.
    """
    stacked_jac = jnp.concatenate([J * d[None, :], J_diag])
    stacked_res = jnp.concatenate([f, f_zeros])
    # Scaling is already folded into stacked_jac, so CG gets a vector of ones.
    unit_scale = jnp.ones(stacked_jac.shape[1], dtype=stacked_jac.dtype)

    gauss_newton_step = _conjugate_gradient_solve(
        stacked_jac, stacked_res, unit_scale, 0.0
    )[0]
    fits_in_region = jnp.linalg.norm(gauss_newton_step) <= Delta

    def _keep_gauss_newton():
        return gauss_newton_step

    def _regularized_step():
        reg_step = _conjugate_gradient_solve(
            stacked_jac, stacked_res, unit_scale, alpha
        )[0]
        # Floor the norm so the division below never hits zero.
        reg_norm = jnp.maximum(jnp.linalg.norm(reg_step), 1e-10)
        # Use the step directly when it fits, otherwise project onto ||p|| = Delta.
        shrink = jnp.where(reg_norm <= Delta, 1.0, Delta / reg_norm)
        return shrink * reg_step

    return lax.cond(fits_in_region, _keep_gauss_newton, _regularized_step)


@jit
def _calculate_cost(rho: jnp.ndarray, data_mask: jnp.ndarray) -> jnp.ndarray:
    """Return ``LOSS_FUNCTION_COEFF * sum(rho[0])`` over masked-in entries.

    Entries where ``data_mask`` is false contribute exactly zero. They are
    selected out rather than multiplied by 0, so non-finite values there do
    not leak into the sum. ``rho[0]`` is presumably the per-residual loss
    value row of a robust-loss output — TODO confirm against callers.
    """
    masked_rho0 = jnp.where(data_mask, rho[0], 0)
    total = masked_rho0.sum()
    return LOSS_FUNCTION_COEFF * total


@jit
def _check_isfinite(f_new: jnp.ndarray) -> jnp.ndarray:
    """Check if all residuals are finite."""
    return jnp.all(jnp.isfinite(f_new))


# ---------------------------------------------------------------------------
# Python wrappers for SVD with GPU/CPU fallback (can't be fully JIT-compiled)
# ---------------------------------------------------------------------------


def _svd_no_bounds(
    J: jnp.ndarray, d: jnp.ndarray, f: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Hat-space SVD (unbounded case), falling back off-JIT on GPU errors.

    Tries the JIT-compiled path first. If it raises and ``is_gpu_error``
    recognises the exception, the same quantities are recomputed with
    ``compute_svd_with_fallback``. Any other exception is re-raised unchanged.

    Returns ``(J_h, U, s, V, uf)`` with ``J_h = J * d`` and ``uf = U^T f``.
    """
    try:
        return _svd_no_bounds_jit(J, d, f)
    except Exception as err:
        if not is_gpu_error(err):
            raise
        # NOTE(review): assumes compute_svd_with_fallback returns V (not V^T),
        # matching the JIT path's Vt.T — confirm in nlsq.stability.svd_fallback.
        scaled_jac = J * d
        left, sing_vals, right = compute_svd_with_fallback(
            scaled_jac, full_matrices=False
        )
        return scaled_jac, left, sing_vals, right, left.T.dot(f)


def _svd_bounds(
    f: jnp.ndarray,
    J: jnp.ndarray,
    d: jnp.ndarray,
    J_diag: jnp.ndarray,
    f_zeros: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Hat-space SVD (bounded, row-augmented case) with GPU-error fallback.

    Tries the JIT-compiled path first. If it raises and ``is_gpu_error``
    recognises the exception, the augmented system ``[J * d; J_diag]`` /
    ``[f; f_zeros]`` is rebuilt and decomposed with
    ``compute_svd_with_fallback``. Any other exception is re-raised unchanged.

    Returns ``(J_h, U, s, V, uf)`` with ``J_h = J * d`` (not augmented) and
    ``uf = U^T [f; f_zeros]``.
    """
    try:
        return _svd_bounds_jit(f, J, d, J_diag, f_zeros)
    except Exception as err:
        if not is_gpu_error(err):
            raise
        # NOTE(review): assumes compute_svd_with_fallback returns V (not V^T),
        # matching the JIT path's Vt.T — confirm in nlsq.stability.svd_fallback.
        scaled_jac = J * d
        stacked_jac = jnp.concatenate([scaled_jac, J_diag])
        stacked_res = jnp.concatenate([f, f_zeros])
        left, sing_vals, right = compute_svd_with_fallback(
            stacked_jac, full_matrices=False
        )
        return scaled_jac, left, sing_vals, right, left.T.dot(stacked_res)


# ---------------------------------------------------------------------------
# Class interface (backward-compatible, delegates to module-level singletons)
# ---------------------------------------------------------------------------


class TrustRegionJITFunctions:
    """JIT-compiled building blocks for the Trust Region Reflective solver.

    This class is a thin, backward-compatible facade. Every attribute it
    exposes is a reference to a module-level function, so constructing many
    instances never creates new ``@jit`` closures. Each JIT function is
    compiled once per unique input signature and shared by all instances.

    Core Operations
    ---------------
    - **Gradient Computation**: ``compute_grad`` (``f^T J``) and
      ``compute_grad_hat`` (hat-space scaling ``d * g``)
    - **SVD Decomposition**: ``svd_no_bounds`` / ``svd_bounds`` compute a thin
      SVD of the scaled (optionally row-augmented) Jacobian, with a non-JIT
      fallback when a GPU error is detected
    - **Conjugate Gradient**: ``conjugate_gradient_solve`` plus the
      trust-region wrappers ``solve_tr_subproblem_cg`` and
      ``solve_tr_subproblem_cg_bounds``
    - **Cost Function Evaluation**: ``default_loss_func`` (``0.5 * ||f||^2``)
      and ``calculate_cost`` (masked sum of ``rho[0]``)
    - **Finite Check**: ``check_isfinite`` on residual vectors

    Performance Characteristics
    ---------------------------
    - **SVD path**: O(mn^2 + n^3) for an (m x n) Jacobian
    - **CG path**: O(k*mn) for k iterations, k capped by CG_MAX_ITERATIONS
    - **Compilation**: module-level singletons prevent per-instance
      recompilation
    """

    def __init__(self) -> None:
        """Expose the shared module-level functions as instance attributes."""
        # Loss and gradient helpers.
        self.default_loss_func = _default_loss_func
        self.compute_grad = _compute_grad
        self.compute_grad_hat = _compute_grad_hat

        # SVD helpers (Python wrappers around the JIT kernels).
        self.svd_no_bounds = _svd_no_bounds
        self.svd_bounds = _svd_bounds

        # Iterative CG solvers.
        self.conjugate_gradient_solve = _conjugate_gradient_solve
        self.solve_tr_subproblem_cg = _solve_tr_subproblem_cg
        self.solve_tr_subproblem_cg_bounds = _solve_tr_subproblem_cg_bounds

        # Cost evaluation and sanity checks.
        self.calculate_cost = _calculate_cost
        self.check_isfinite = _check_isfinite