Source code for nlsq.core.sparse_jacobian

"""Sparse Jacobian support for large-scale optimization.

This module provides sparse matrix support for Jacobian computations,
enabling efficient handling of problems with 20M+ data points.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import jax.numpy as jnp
import numpy as np
from scipy.sparse import coo_matrix, csr_matrix

from nlsq.constants import FINITE_DIFF_REL_STEP
from nlsq.utils.logging import get_logger

if TYPE_CHECKING:
    from nlsq.result import CurveFitResult

logger = get_logger(__name__)


class SparseJacobianComputer:
    """Compute and manage sparse Jacobians for large-scale problems.

    For many curve fitting problems the Jacobian is sparse: each data point
    only depends on a subset of the parameters. This class detects that
    structure from sample evaluations and stores Jacobians in SciPy sparse
    formats so that only the significant entries are kept in memory.
    """

    def __init__(self, sparsity_threshold: float = 0.01):
        """Initialize sparse Jacobian computer.

        Parameters
        ----------
        sparsity_threshold : float
            Elements whose absolute value does not exceed this threshold are
            treated as zero. Default is 0.01.
        """
        self.sparsity_threshold = sparsity_threshold
        # Last pattern produced by ``detect_sparsity_pattern`` (None until run).
        self._sparsity_pattern: np.ndarray | None = None
        # Reserved slot; no method of this class writes it.
        self._sparse_indices: tuple | None = None

    def detect_sparsity_pattern(
        self,
        func: Callable,
        x0: np.ndarray,
        xdata_sample: np.ndarray | list,
        n_samples: int = 100,
    ) -> tuple[np.ndarray, float]:
        """Detect sparsity pattern of Jacobian from sample evaluations.

        The first ``n_samples`` points of ``xdata_sample`` are evaluated at
        ``x0`` and at one forward-perturbed copy of ``x0`` per parameter. An
        entry is marked non-zero when the absolute finite-difference quotient
        exceeds ``self.sparsity_threshold``. The resulting pattern is also
        stored on the instance.

        Parameters
        ----------
        func : Callable
            Model function, called as ``func(xdata, *params)``.
        x0 : np.ndarray
            Initial parameter values.
        xdata_sample : np.ndarray or list
            Sample of x data points. Can be a single array for 1D problems,
            a list/tuple of coordinate arrays ``[X, Y]`` for 2D problems, or a
            2D array with shape (k, N) with k < N for multi-dimensional
            coordinates.
        n_samples : int
            Maximum number of leading data points used for detection.

        Returns
        -------
        pattern : np.ndarray
            Boolean array of shape (n_used, n_params); True marks a non-zero
            element.
        sparsity : float
            Fraction of zero elements (0.0 when the pattern is empty).

        Raises
        ------
        ValueError
            If the length of the function output differs from the number of
            sampled data points.
        """
        n_params = len(x0)

        xdata_sliced: list | np.ndarray
        if isinstance(xdata_sample, (list, tuple)):
            # Sequence of coordinate arrays: slice each one. The sequence is
            # deliberately NOT stacked with np.asarray, which would copy the
            # whole input just to read its length.
            n_data = min(n_samples, len(xdata_sample[0]))
            xdata_sliced = [coord[:n_data] for coord in xdata_sample]
        else:
            xdata_arr = np.asarray(xdata_sample)
            if xdata_arr.ndim == 2 and xdata_arr.shape[0] < xdata_arr.shape[1]:
                # (k, N) layout: data points run along the second axis.
                n_data = min(n_samples, xdata_arr.shape[1])
                xdata_sliced = xdata_arr[:, :n_data]
            else:
                n_data = min(n_samples, len(xdata_arr))
                xdata_sliced = xdata_arr[:n_data]

        pattern = np.zeros((n_data, n_params), dtype=bool)

        eps = FINITE_DIFF_REL_STEP
        f0 = np.asarray(func(xdata_sliced, *x0))

        # Closure-based model functions may ignore xdata and always return
        # full-size output, which makes sparsity detection inapplicable.
        if f0.shape[0] != n_data:
            raise ValueError(
                f"Function output size ({f0.shape[0]}) does not match "
                f"xdata sample size ({n_data}). The function may use "
                f"closure-captured data instead of xdata."
            )

        # Converted once; ``.at[i].add`` returns a new array on every call.
        # NOTE(review): the perturbation is applied in JAX's active float
        # precision - presumably 64-bit is enabled elsewhere; confirm.
        x0_jax = jnp.asarray(x0)
        for i in range(n_params):
            step = eps * max(1.0, abs(float(x0[i])))
            x_perturb = x0_jax.at[i].add(step)
            f_perturb = func(xdata_sliced, *x_perturb)

            # Forward finite-difference estimate of column i.
            jac_col = (f_perturb - f0) / step
            pattern[:, i] = np.abs(jac_col) > self.sparsity_threshold

        if pattern.size > 0:
            sparsity = 1.0 - np.sum(pattern) / pattern.size
        else:
            sparsity = 0.0  # No data means no sparsity information

        self._sparsity_pattern = pattern
        return pattern, sparsity

    def compute_sparse_jacobian(
        self,
        jac_func: Callable,
        x: np.ndarray,
        xdata: np.ndarray,
        ydata: np.ndarray,
        data_mask: np.ndarray | None = None,
        chunk_size: int = 10000,
        func: Callable | None = None,
    ) -> csr_matrix:
        """Compute Jacobian in sparse format with chunking.

        The data are processed ``chunk_size`` rows at a time. For each chunk a
        dense Jacobian is obtained, entries whose absolute value exceeds
        ``self.sparsity_threshold`` are extracted, and all chunks are
        assembled into a single sparse matrix.

        Parameters
        ----------
        jac_func : Callable
            Jacobian function, called as
            ``jac_func(x, x_chunk, y_chunk, mask_chunk, None)`` and expected
            to return a dense 2D array for the chunk. If it is not callable,
            forward finite differences on ``func`` are used instead.
        x : np.ndarray
            Current parameter values.
        xdata : np.ndarray
            Independent variable data. Sliced along its first axis when it
            supports indexing, otherwise passed through unchanged.
        ydata : np.ndarray
            Dependent variable data; its length defines the number of rows.
        data_mask : np.ndarray, optional
            Mask for valid data points. Defaults to all True.
        chunk_size : int
            Number of rows per chunk; must be at least 1.
        func : Callable, optional
            Original model function for the finite-difference fallback.

        Returns
        -------
        J_sparse : csr_matrix
            Sparse Jacobian matrix of shape (len(ydata), len(x)).

        Raises
        ------
        ValueError
            If ``chunk_size`` is smaller than 1, or if the finite-difference
            fallback is needed but ``func`` is None.
        """
        # A non-positive chunk size would either divide by zero or yield a
        # negative chunk count and silently return an all-zero Jacobian.
        if chunk_size < 1:
            raise ValueError(
                f"chunk_size must be a positive integer, got {chunk_size}"
            )

        n_data = len(ydata)
        n_params = len(x)
        n_chunks = (n_data + chunk_size - 1) // chunk_size

        if data_mask is None:
            data_mask = np.ones(n_data, dtype=bool)

        # Loop-invariant: the parameters do not change between chunks.
        x_jax = jnp.asarray(x)

        # COO triplets accumulated per chunk, concatenated once at the end.
        all_rows: list[np.ndarray] = []
        all_cols: list[np.ndarray] = []
        all_values: list[np.ndarray] = []

        for chunk_idx in range(n_chunks):
            start = chunk_idx * chunk_size
            end = min(start + chunk_size, n_data)

            x_chunk = xdata[start:end] if hasattr(xdata, "__getitem__") else xdata
            y_chunk = ydata[start:end]
            mask_chunk = data_mask[start:end]

            if callable(jac_func):
                # NOTE(review): the meaning of the trailing ``None`` argument
                # is defined by the caller-supplied jac_func, not here.
                J_chunk = jac_func(x_jax, x_chunk, y_chunk, mask_chunk, None)
            else:
                if func is None:
                    raise ValueError(
                        "func parameter required for finite difference fallback"
                    )
                J_chunk = self._finite_diff_jacobian(
                    func, x, x_chunk, y_chunk, mask_chunk
                )

            # Objects exposing ``block_until_ready`` (e.g. JAX arrays) are
            # copied into NumPy before the indexing below.
            if hasattr(J_chunk, "block_until_ready"):
                J_chunk = np.array(J_chunk)

            significant = np.abs(J_chunk) > self.sparsity_threshold
            chunk_rows, chunk_cols = np.where(significant)
            chunk_values = J_chunk[chunk_rows, chunk_cols]

            # Shift chunk-local row indices to positions in the full matrix.
            all_rows.append(chunk_rows + start)
            all_cols.append(chunk_cols)
            all_values.append(chunk_values)

        if all_rows:
            rows = np.concatenate(all_rows)
            cols = np.concatenate(all_cols)
            values = np.concatenate(all_values)
            J_sparse = coo_matrix((values, (rows, cols)), shape=(n_data, n_params))
        else:
            # No chunks were processed (ydata is empty).
            J_sparse = coo_matrix((n_data, n_params))

        return J_sparse.tocsr()

    def _finite_diff_jacobian(
        self,
        func: Callable,
        x: np.ndarray,
        xdata: np.ndarray,
        ydata: np.ndarray,
        data_mask: np.ndarray,
        eps: float = FINITE_DIFF_REL_STEP,
    ) -> np.ndarray:
        """Compute Jacobian using finite differences as fallback.

        Differentiates the masked residual ``func(xdata, *x) - ydata`` with
        one forward step per parameter; masked-out rows are zero.

        Parameters
        ----------
        func : Callable
            Model function, called as ``func(xdata, *params)``.
        x : np.ndarray
            Current parameter values.
        xdata : np.ndarray
            Independent variable data.
        ydata : np.ndarray
            Dependent variable data.
        data_mask : np.ndarray
            Mask for valid data; residuals where it is False are set to 0.
        eps : float
            Relative finite difference step; the absolute step for parameter
            j is ``eps * max(1, |x[j]|)``.

        Returns
        -------
        J : np.ndarray
            Dense Jacobian matrix of shape (len(ydata), len(x)).
        """
        n_data = len(ydata)
        n_params = len(x)
        J = np.zeros((n_data, n_params))

        # Masked residual at the unperturbed parameters.
        f0 = np.where(data_mask, func(xdata, *x) - ydata, 0)

        # Converted once; ``.at[j].add`` returns a new array on every call.
        x_jax = jnp.asarray(x)
        for j in range(n_params):
            step = eps * max(1.0, abs(float(x[j])))
            x_perturb = x_jax.at[j].add(step)
            f_perturb = np.where(data_mask, func(xdata, *x_perturb) - ydata, 0)
            J[:, j] = (f_perturb - f0) / step

        return J

    def sparse_matrix_vector_product(
        self, J_sparse: csr_matrix, v: np.ndarray
    ) -> np.ndarray:
        """Efficient sparse matrix-vector product.

        Parameters
        ----------
        J_sparse : csr_matrix
            Sparse Jacobian matrix.
        v : np.ndarray
            Vector to multiply.

        Returns
        -------
        result : np.ndarray
            J @ v
        """
        return J_sparse @ v

    def sparse_normal_equations(
        self, J_sparse: csr_matrix, f: np.ndarray
    ) -> tuple[Callable, np.ndarray]:
        """Set up normal equations with sparse Jacobian.

        Prepares (J^T @ J) @ p = -J^T @ f for a matrix-free solver, without
        forming J^T @ J explicitly.

        Parameters
        ----------
        J_sparse : csr_matrix
            Sparse Jacobian matrix.
        f : np.ndarray
            Residual vector.

        Returns
        -------
        matvec : callable
            Function that computes (J^T @ J) @ v.
        rhs : np.ndarray
            Right-hand side -J^T @ f.
        """

        def matvec(v):
            """Compute (J^T @ J) @ v without forming J^T @ J."""
            Jv = J_sparse @ v
            return J_sparse.T @ Jv

        rhs = -J_sparse.T @ f
        return matvec, rhs

    def estimate_memory_usage(
        self, n_data: int, n_params: int, sparsity: float = 0.99
    ) -> dict:
        """Estimate memory usage for sparse vs dense Jacobian.

        Assumes float64 values and int32 indices in CSR layout.

        Parameters
        ----------
        n_data : int
            Number of data points.
        n_params : int
            Number of parameters.
        sparsity : float
            Fraction of zero elements (0-1).

        Returns
        -------
        memory_info : dict
            Keys ``dense_gb``, ``sparse_gb``, ``savings_percent``,
            ``sparsity``, ``nnz`` and ``reduction_factor``. For a degenerate
            problem with no Jacobian entries ``savings_percent`` is 0.0.
        """
        gib = 1024**3

        dense_bytes = n_data * n_params * 8  # 8 bytes per float64
        dense_gb = dense_bytes / gib

        # CSR stores values, column indices and row pointers.
        nnz = int(n_data * n_params * (1 - sparsity))
        sparse_bytes = nnz * 8  # values
        sparse_bytes += nnz * 4  # column indices (int32)
        sparse_bytes += (n_data + 1) * 4  # row pointers (int32)
        sparse_gb = sparse_bytes / gib

        # Guard the degenerate case (no data points or no parameters), which
        # would otherwise divide by zero.
        if dense_gb > 0:
            savings = (dense_gb - sparse_gb) / dense_gb * 100
        else:
            savings = 0.0

        return {
            "dense_gb": dense_gb,
            "sparse_gb": sparse_gb,
            "savings_percent": savings,
            "sparsity": sparsity,
            "nnz": nnz,
            "reduction_factor": dense_gb / sparse_gb
            if sparse_gb > 0
            else float("inf"),
        }
class SparseOptimizer:
    """Optimizer front-end that chooses between sparse and dense fitting.

    A size-based heuristic decides whether a sparsity probe is worthwhile;
    if the probed Jacobian is sparse enough the sparse path is taken,
    otherwise the call is forwarded to the regular dense ``curve_fit``.
    """

    def __init__(
        self,
        sparsity_threshold: float = 0.01,
        min_sparsity: float = 0.9,
        auto_detect: bool = True,
    ):
        """Initialize sparse optimizer.

        Parameters
        ----------
        sparsity_threshold : float
            Magnitude at or below which a Jacobian element counts as zero.
        min_sparsity : float
            Smallest zero-fraction for which sparse methods are used.
        auto_detect : bool
            Whether the expected-sparsity heuristic is applied automatically.
        """
        self.auto_detect = auto_detect
        self.min_sparsity = min_sparsity
        self.sparsity_threshold = sparsity_threshold
        self.sparse_computer = SparseJacobianComputer(sparsity_threshold)
        # State filled in by ``optimize_with_sparsity``.
        self._detected_sparsity = 0.0
        self._use_sparse = False

    def should_use_sparse(
        self, n_data: int, n_params: int, force_check: bool = False
    ) -> bool:
        """Determine if sparse methods should be used.

        Parameters
        ----------
        n_data : int
            Number of data points.
        n_params : int
            Number of parameters.
        force_check : bool
            Apply the expected-sparsity heuristic even when ``auto_detect``
            is False.

        Returns
        -------
        use_sparse : bool
            Whether to use sparse methods.
        """
        n_elements = n_data * n_params

        # Below one million Jacobian entries dense storage is always used.
        if n_elements < 1e6:
            return False

        heuristic_enabled = self.auto_detect or force_check
        if not heuristic_enabled:
            # Without the heuristic only very large problems go sparse.
            return n_elements > 1e8

        # Heuristic: assume roughly ten parameters influence each data point.
        dense_fraction = min(10.0 / n_params, 1.0)
        expected_sparsity = 1.0 - dense_fraction
        return expected_sparsity >= self.min_sparsity

    def optimize_with_sparsity(
        self,
        func: Callable,
        x0: np.ndarray,
        xdata: np.ndarray,
        ydata: np.ndarray,
        **kwargs: Any,
    ) -> tuple[np.ndarray, np.ndarray] | CurveFitResult:
        """Optimize using sparse Jacobian methods when they pay off.

        Parameters
        ----------
        func : Callable
            Model function.
        x0 : np.ndarray
            Initial parameters.
        xdata : np.ndarray
            Independent variable data; indexed with an integer index array
            when the sparsity probe runs.
        ydata : np.ndarray
            Dependent variable data.
        **kwargs
            Forwarded to the selected optimization routine.

        Returns
        -------
        result
            The value returned by ``_optimize_sparse`` when the sparse path
            is taken, otherwise whatever ``curve_fit`` returns.
        """
        n_params = len(x0)
        n_data = len(ydata)

        self._use_sparse = self.should_use_sparse(n_data, n_params)

        if self._use_sparse:
            logger.info(
                f"Using sparse Jacobian methods for {n_data}×{n_params} problem"
            )

            # Probe the Jacobian on a random subset of at most 1000 points.
            sample_size = min(1000, n_data)
            sample_indices = np.random.choice(n_data, sample_size, replace=False)
            probe_xdata = xdata[sample_indices]
            _pattern, sparsity = self.sparse_computer.detect_sparsity_pattern(
                func, x0, probe_xdata, sample_size
            )
            self._detected_sparsity = sparsity
            logger.info(f"Detected sparsity: {sparsity:.1%}")

            # Report the projected storage for the full-size problem.
            memory_info = self.sparse_computer.estimate_memory_usage(
                n_data, n_params, sparsity
            )
            logger.info(f"Memory savings: {memory_info['savings_percent']:.1f}%")
            logger.info(
                f"Dense: {memory_info['dense_gb']:.2f}GB → Sparse: {memory_info['sparse_gb']:.2f}GB"
            )

            sparse_enough = sparsity >= self.min_sparsity
            if sparse_enough:
                return self._optimize_sparse(func, x0, xdata, ydata, **kwargs)

        # Either the heuristic declined or the probe was not sparse enough.
        logger.info(f"Using standard dense methods for {n_data}×{n_params} problem")
        from nlsq.core.minpack import curve_fit

        return curve_fit(func, xdata, ydata, p0=x0, **kwargs)

    def _optimize_sparse(
        self,
        func: Callable,
        x0: np.ndarray,
        xdata: np.ndarray,
        ydata: np.ndarray,
        **kwargs,
    ):
        """Placeholder for the sparse optimization path.

        No optimization is performed: ``x0`` is returned unchanged inside a
        dict that records the last detected sparsity. The other arguments
        are accepted but not used.
        """
        placeholder = dict(
            x=x0,
            success=True,
            message="Sparse optimization placeholder",
            sparsity=self._detected_sparsity,
            method="sparse",
        )
        return placeholder
def detect_jacobian_sparsity(
    func: Callable,
    x0: np.ndarray,
    xdata_sample: np.ndarray | list,
    threshold: float = 0.01,
) -> tuple[float, dict]:
    """Detect and analyze Jacobian sparsity for a given problem.

    Runs ``SparseJacobianComputer.detect_sparsity_pattern`` on at most the
    first 100 points of ``xdata_sample`` and summarizes the resulting
    boolean pattern.

    Parameters
    ----------
    func : Callable
        Model function, called as ``func(xdata, *params)``.
    x0 : np.ndarray
        Initial parameters.
    xdata_sample : np.ndarray or list
        Sample of x data. Can be a single array for 1D problems, a
        list/tuple of coordinate arrays ``[X, Y]`` for 2D problems, or a 2D
        array with shape (k, N) with k < N for multi-dimensional coordinates.
    threshold : float
        Magnitude at or below which a Jacobian element counts as zero.

    Returns
    -------
    sparsity : float
        Fraction of zero elements.
    info : dict
        Keys ``sparsity``, ``nnz``, ``avg_nnz_per_row``, ``avg_nnz_per_col``,
        ``max_nnz_per_row``, ``max_nnz_per_col``, ``pattern_shape`` and
        ``memory_reduction`` (``sparsity`` expressed in percent). The
        per-row / per-column statistics are 0 when the pattern has no rows
        or no columns.
    """
    if isinstance(xdata_sample, (list, tuple)):
        # Sequence of coordinate arrays: the first one defines the point
        # count. Not stacked with np.asarray, which would copy all of it.
        n_data_points = len(xdata_sample[0])
    else:
        xdata_arr = np.asarray(xdata_sample)
        if xdata_arr.ndim == 2 and xdata_arr.shape[0] < xdata_arr.shape[1]:
            # (k, N) layout: N is the number of data points.
            n_data_points = xdata_arr.shape[1]
        else:
            n_data_points = len(xdata_arr)

    computer = SparseJacobianComputer(threshold)
    pattern, sparsity = computer.detect_sparsity_pattern(
        func, x0, xdata_sample, min(100, n_data_points)
    )

    nnz_per_row = np.sum(pattern, axis=1)
    nnz_per_col = np.sum(pattern, axis=0)

    # An empty sample (or zero parameters) yields an empty count vector;
    # np.max would raise on it and np.mean would return NaN with a warning,
    # so both fall back to zero in that case.
    info = {
        "sparsity": sparsity,
        "nnz": np.sum(pattern),
        "avg_nnz_per_row": np.mean(nnz_per_row) if nnz_per_row.size else 0.0,
        "avg_nnz_per_col": np.mean(nnz_per_col) if nnz_per_col.size else 0.0,
        "max_nnz_per_row": np.max(nnz_per_row, initial=0),
        "max_nnz_per_col": np.max(nnz_per_col, initial=0),
        "pattern_shape": pattern.shape,
        "memory_reduction": sparsity * 100,
    }

    return sparsity, info
def detect_sparsity_at_p0(
    func: Callable,
    p0: np.ndarray,
    xdata: np.ndarray,
    n_residuals: int,
    threshold: float = 0.01,
    sample_size: int = 100,
) -> tuple[float, bool]:
    """Detect sparsity at p0 initialization for automatic sparse solver selection.

    Estimates the Jacobian at the initial guess ``p0`` with finite
    differences on an evenly spaced subsample of the data and reports the
    fraction of (near-)zero entries. Nothing is cached: every call repeats
    the evaluation.

    Parameters
    ----------
    func : Callable
        Model function, called as ``func(x, *params)``.
    p0 : np.ndarray
        Initial parameter guess.
    xdata : np.ndarray
        Independent variable data: a 1D array, a list/tuple of coordinate
        arrays ``[X, Y]``, or a 2D array of shape (k, N) with k < N.
    n_residuals : int
        Number of residuals (data points); caps the sample size.
    threshold : float, optional
        Magnitude at or below which a Jacobian element counts as zero
        (default: 0.01).
    sample_size : int, optional
        Maximum number of data points to sample for detection
        (default: 100). Values above 100 are honored.

    Returns
    -------
    sparsity_ratio : float
        Fraction of zero elements in the sampled Jacobian
        (0.0 = dense, 1.0 = completely sparse).
    is_sparse : bool
        True when ``sparsity_ratio`` is strictly greater than 0.5.

    Notes
    -----
    Detection strategy:

    - Samples ``min(sample_size, n_residuals, n_points)`` evenly spaced
      data points.
    - Uses forward finite differences to estimate the Jacobian at p0.
    - Treats elements with \\|J[i,j]\\| <= threshold as zero.

    NOTE(review): callers presumably combine ``is_sparse`` with a
    problem-size condition before enabling a sparse solver; that logic is
    not part of this function.

    Examples
    --------
    >>> def model(x, a, b):
    ...     # Each parameter affects a different data region
    ...     return jnp.where(x < 0.5, a, b)
    >>> p0 = np.array([1.0, 2.0])
    >>> xdata = np.linspace(0, 1, 1000)
    >>> sparsity_ratio, is_sparse = detect_sparsity_at_p0(
    ...     model, p0, xdata, n_residuals=1000
    ... )  # doctest: +SKIP
    """
    # Classify the xdata layout once instead of repeating the sampling
    # logic per layout.
    is_coord_sequence = isinstance(xdata, (list, tuple))
    coords_first = False
    xdata_arr = None
    if is_coord_sequence:
        # Sequence of coordinate arrays: the first one defines the point
        # count. Not stacked with np.asarray, which would copy all of it.
        n_data_points = len(xdata[0])
    else:
        xdata_arr = np.asarray(xdata)
        # (k, N) layout: k coordinate rows, N data points.
        coords_first = (
            xdata_arr.ndim == 2 and xdata_arr.shape[0] < xdata_arr.shape[1]
        )
        n_data_points = xdata_arr.shape[1] if coords_first else len(xdata_arr)

    actual_sample_size = min(sample_size, n_residuals, n_data_points)

    if actual_sample_size < n_data_points:
        # Evenly spaced indices spanning the whole data range.
        sample_indices = np.linspace(
            0, n_data_points - 1, actual_sample_size, dtype=int
        )
        if is_coord_sequence:
            xdata_sample = [coord[sample_indices] for coord in xdata]
        elif coords_first:
            xdata_sample = xdata_arr[:, sample_indices]
        else:
            xdata_sample = xdata_arr[sample_indices]
    else:
        xdata_sample = xdata

    # Call the pattern detector directly with the real sample size.
    # ``detect_jacobian_sparsity`` caps its input at the first 100 points,
    # which would ignore ``sample_size > 100`` and restrict the probe to the
    # leading part of the evenly spaced sample.
    computer = SparseJacobianComputer(threshold)
    _pattern, sparsity_ratio = computer.detect_sparsity_pattern(
        func, p0, xdata_sample, actual_sample_size
    )

    is_sparse = sparsity_ratio > 0.5

    logger.debug(
        f"Sparsity detection at p0: {sparsity_ratio:.1%} "
        f"({'sparse' if is_sparse else 'dense'})"
    )

    return sparsity_ratio, is_sparse