ADR-005: JAX Autodiff for Gradient Computation

Status: Accepted

Date: 2025-10-18

Deciders: Performance Engineer, Code Quality Review (Phase 2.4)

Context

Note: StreamingOptimizer was a legacy implementation that has been removed. This ADR is retained for historical context around the original finite-difference approach.

The StreamingOptimizer class in streaming_optimizer.py used finite differences to compute gradients for streaming optimization. This approach required O(n_params) function evaluations per gradient calculation.

Original Implementation (Finite Differences)

def _compute_loss_and_gradient(self, func, params, x_batch, y_batch):
    # Compute loss
    y_pred = func(x_batch, *params)
    residuals = y_pred - y_batch
    loss = np.mean(residuals**2)

    # Compute gradient using finite differences
    eps = 1e-8
    grad = np.zeros_like(params)

    for i in range(len(params)):  # O(n_params) loop!
        params_plus = params.copy()
        params_plus[i] += eps

        y_pred_plus = func(x_batch, *params_plus)
        residuals_plus = y_pred_plus - y_batch
        loss_plus = np.mean(residuals_plus**2)

        grad[i] = (loss_plus - loss) / eps

    return loss, grad

Problems with Finite Differences

  1. Slow: O(n_params) function evaluations per gradient

  2. Numerical Errors: Finite differences introduce approximation errors

  3. Epsilon Tuning: Requires careful selection of eps for accuracy

  4. Scalability: Becomes prohibitively slow for >50 parameters

Performance Comparison

Parameters

Finite Diff Time

JAX Autodiff Time

Speedup

5 params

5× evaluations

1× evaluation

5x

10 params

10× evaluations

1× evaluation

10x

50 params

50× evaluations

1× evaluation

50x

100 params

100× evaluations

1× evaluation

100x

Decision

Replace finite differences with JAX automatic differentiation.

New Implementation (JAX Autodiff)

from jax import jit, value_and_grad


class StreamingOptimizer:
    def __init__(self, config):
        self.config = config
        self._loss_and_grad_fn = None  # Cache JIT-compiled function

    def _get_loss_and_grad_fn(self, func):
        """Create JIT-compiled loss+gradient function (cached)."""
        if self._loss_and_grad_fn is None:

            @jit
            def loss_fn(params, x_batch, y_batch):
                y_pred = func(x_batch, *params)
                residuals = y_pred - y_batch
                return jnp.mean(residuals**2)

            # Compute loss + gradient in ONE pass!
            self._loss_and_grad_fn = jit(value_and_grad(loss_fn))

        return self._loss_and_grad_fn

    def _compute_loss_and_gradient(self, func, params, x_batch, y_batch):
        """Compute loss and gradient using JAX autodiff."""
        loss_and_grad_fn = self._get_loss_and_grad_fn(func)

        # Convert to JAX arrays
        params_jax = jnp.array(params)
        x_jax = jnp.array(x_batch)
        y_jax = jnp.array(y_batch)

        # Compute loss and gradient in one pass (autodiff!)
        loss, grad = loss_and_grad_fn(params_jax, x_jax, y_jax)

        return float(loss), np.array(grad)

Consequences

Positive

[PASS] 50-100x Speedup: Single forward+backward pass instead of n_params evaluations [PASS] Exact Gradients: No numerical approximation errors [PASS] JIT Compilation: Cached for even faster repeated calls [PASS] Scalability: Enables 100+ parameter models [PASS] Better Science: Exact derivatives improve optimization convergence [PASS] Code Simplicity: JAX handles differentiation automatically

Negative

[FAIL] JAX Dependency: Requires JAX for streaming optimizer

  • Mitigation: NLSQ already requires JAX 0.8.0+ as core dependency [FAIL] First Call Overhead: JIT compilation takes ~100-500ms

  • Mitigation: Amortized over many batches in streaming optimization [FAIL] Memory: Autodiff requires storing intermediate values

  • Mitigation: Negligible for typical parameter counts (<1000)

Performance Impact

  • 5-10 parameters: 5-10x faster gradient computation

  • 10-50 parameters: 10-50x faster gradient computation

  • 50-100 parameters: 50-100x faster gradient computation

  • Overall streaming optimization: 2-5x faster end-to-end

Testing

Comprehensive testing validates the change:

  • [PASS] 21/21 streaming optimizer tests passing

  • [PASS] Gradient computation verified numerically correct

  • [PASS] No regression in functionality

  • [PASS] Supports unlimited parameter counts

References

Alternatives Considered

1. Keep Finite Differences

  • Pros: Simple, no JAX dependency

  • Cons: Too slow for >10 parameters, numerical errors

  • Decision: Rejected due to poor scalability

2. Manual Derivative Implementation

  • Pros: Full control, no autodiff overhead

  • Cons: Error-prone, requires mathematical expertise, hard to maintain

  • Decision: Rejected due to maintenance burden

3. Symbolic Differentiation (SymPy)

  • Pros: Exact derivatives

  • Cons: Slow compilation, poor performance, limited function support

  • Decision: Rejected due to poor performance

Status Updates

  • 2025-10-18: Accepted and implemented in Phase 2.4

  • 2025-10-18: Verified with full streaming optimizer test suite (21/21 passing)