# ADR-005: JAX Autodiff for Gradient Computation **Status**: Accepted **Date**: 2025-10-18 **Deciders**: Performance Engineer, Code Quality Review (Phase 2.4) ## Context **Note:** `StreamingOptimizer` was a legacy implementation that has been removed. This ADR is retained for historical context around the original finite-difference approach. The `StreamingOptimizer` class in `streaming_optimizer.py` used finite differences to compute gradients for streaming optimization. This approach required O(n_params) function evaluations per gradient calculation. ### Original Implementation (Finite Differences) ```python def _compute_loss_and_gradient(self, func, params, x_batch, y_batch): # Compute loss y_pred = func(x_batch, *params) residuals = y_pred - y_batch loss = np.mean(residuals**2) # Compute gradient using finite differences eps = 1e-8 grad = np.zeros_like(params) for i in range(len(params)): # O(n_params) loop! params_plus = params.copy() params_plus[i] += eps y_pred_plus = func(x_batch, *params_plus) residuals_plus = y_pred_plus - y_batch loss_plus = np.mean(residuals_plus**2) grad[i] = (loss_plus - loss) / eps return loss, grad ``` ### Problems with Finite Differences 1. **Slow**: O(n_params) function evaluations per gradient 2. **Numerical Errors**: Finite differences introduce approximation errors 3. **Epsilon Tuning**: Requires careful selection of `eps` for accuracy 4. **Scalability**: Becomes prohibitively slow for >50 parameters ### Performance Comparison | Parameters | Finite Diff Time | JAX Autodiff Time | Speedup | |------------|------------------|-------------------|---------| | 5 params | 5× evaluations | 1× evaluation | 5x | | 10 params | 10× evaluations | 1× evaluation | 10x | | 50 params | 50× evaluations | 1× evaluation | 50x | | 100 params | 100× evaluations | 1× evaluation | 100x | ## Decision **Replace finite differences with JAX automatic differentiation.** ### New Implementation (JAX Autodiff) ```python from jax import jit, value_and_grad class StreamingOptimizer: def __init__(self, config): self.config = config self._loss_and_grad_fn = None # Cache JIT-compiled function def _get_loss_and_grad_fn(self, func): """Create JIT-compiled loss+gradient function (cached).""" if self._loss_and_grad_fn is None: @jit def loss_fn(params, x_batch, y_batch): y_pred = func(x_batch, *params) residuals = y_pred - y_batch return jnp.mean(residuals**2) # Compute loss + gradient in ONE pass! self._loss_and_grad_fn = jit(value_and_grad(loss_fn)) return self._loss_and_grad_fn def _compute_loss_and_gradient(self, func, params, x_batch, y_batch): """Compute loss and gradient using JAX autodiff.""" loss_and_grad_fn = self._get_loss_and_grad_fn(func) # Convert to JAX arrays params_jax = jnp.array(params) x_jax = jnp.array(x_batch) y_jax = jnp.array(y_batch) # Compute loss and gradient in one pass (autodiff!) loss, grad = loss_and_grad_fn(params_jax, x_jax, y_jax) return float(loss), np.array(grad) ``` ## Consequences ### Positive [PASS] **50-100x Speedup**: Single forward+backward pass instead of n_params evaluations [PASS] **Exact Gradients**: No numerical approximation errors [PASS] **JIT Compilation**: Cached for even faster repeated calls [PASS] **Scalability**: Enables 100+ parameter models [PASS] **Better Science**: Exact derivatives improve optimization convergence [PASS] **Code Simplicity**: JAX handles differentiation automatically ### Negative [FAIL] **JAX Dependency**: Requires JAX for streaming optimizer - **Mitigation**: NLSQ already requires JAX 0.8.0+ as core dependency [FAIL] **First Call Overhead**: JIT compilation takes ~100-500ms - **Mitigation**: Amortized over many batches in streaming optimization [FAIL] **Memory**: Autodiff requires storing intermediate values - **Mitigation**: Negligible for typical parameter counts (<1000) ### Performance Impact - **5-10 parameters**: 5-10x faster gradient computation - **10-50 parameters**: 10-50x faster gradient computation - **50-100 parameters**: 50-100x faster gradient computation - **Overall streaming optimization**: 2-5x faster end-to-end ## Testing Comprehensive testing validates the change: - [PASS] 21/21 streaming optimizer tests passing - [PASS] Gradient computation verified numerically correct - [PASS] No regression in functionality - [PASS] Supports unlimited parameter counts ## References - [Adaptive Hybrid Streaming Implementation](../../../nlsq/streaming/adaptive_hybrid.py) - [JAX Autodiff Documentation](https://docs.jax.dev/en/latest/automatic-differentiation.html) - [Commit 2ed084f](https://github.com/imewei/NLSQ/commit/2ed084f) ## Alternatives Considered ### 1. Keep Finite Differences - **Pros**: Simple, no JAX dependency - **Cons**: Too slow for >10 parameters, numerical errors - **Decision**: Rejected due to poor scalability ### 2. Manual Derivative Implementation - **Pros**: Full control, no autodiff overhead - **Cons**: Error-prone, requires mathematical expertise, hard to maintain - **Decision**: Rejected due to maintenance burden ### 3. Symbolic Differentiation (SymPy) - **Pros**: Exact derivatives - **Cons**: Slow compilation, poor performance, limited function support - **Decision**: Rejected due to poor performance ## Status Updates - **2025-10-18**: Accepted and implemented in Phase 2.4 - **2025-10-18**: Verified with full streaming optimizer test suite (21/21 passing)