"""Robust loss functions for outlier-resistant curve fitting.
This module provides JIT-compiled implementations of robust loss functions including
Huber, Cauchy, soft_l1, and arctan for nonlinear least squares optimization.
These functions reduce the influence of outliers compared to standard least squares.
Robust loss functions replace the squared residual z with a function ρ(z) that
grows more slowly for large residuals, making the optimization less sensitive to
outliers while preserving accuracy for well-behaved data.
Available Loss Functions:
- 'linear' (default): Standard least squares, ρ(z) = z
- 'huber': Quadratic for small residuals, linear for large (recommended)
- 'soft_l1': Smooth approximation to L1 loss
- 'cauchy': Heavy-tailed, very robust to outliers
- 'arctan': Bounded loss function
Example:
>>> from nlsq import curve_fit
>>> import jax.numpy as jnp
>>>
>>> def model(x, a, b): return a * jnp.exp(-b * x)
>>>
>>> # Fit with Huber loss to handle outliers
>>> popt, pcov = curve_fit(model, x, y, p0=[2.0, 0.5], loss='huber')
See Also:
nlsq.curve_fit : Main fitting function that uses these loss functions
nlsq.least_squares : Lower-level interface with loss function control
"""
# Initialize JAX configuration through central config
from nlsq.config import JAXConfig
_jax_config = JAXConfig()
import jax.numpy as jnp
from jax import jit
[docs]
class LossFunctionsJIT:
    """JIT-compiled robust loss functions for nonlinear least squares.

    Each robust loss replaces the scaled squared residual ``z`` by a function
    ``ρ(z)`` that grows more slowly for large ``z``, which reduces the
    influence of outliers on the fit. The numerical kernels are compiled with
    ``jax.jit``; the thin Python wrappers around them only select between the
    "cost only" and "cost plus derivatives" code paths.

    Robust Loss Function Theory
    ----------------------------
    Standard least squares minimizes the sum of squared residuals::

        Standard LS:  min Σ f_i²
        Robust LS:    min Σ σ² ρ(f_i² / σ²)

    where σ is the scale parameter (``f_scale``) and ``z = (f / σ)²``.

    Available Loss Functions
    -------------------------
    1. **linear**: ``ρ(z) = z`` (standard least squares, no outlier
       protection). ``get_loss_function('linear')`` returns ``None``.
    2. **huber**: ``ρ(z) = z`` if ``z ≤ 1``, else ``2√z - 1``. Quadratic in the
       residual for small residuals, linear for large ones.
    3. **soft_l1**: ``ρ(z) = 2(√(1 + z) - 1)``. Smooth approximation to the L1
       norm, differentiable everywhere.
    4. **cauchy**: ``ρ(z) = ln(1 + z)``. Strongly down-weights outliers.
    5. **arctan**: ``ρ(z) = arctan(z)``. Bounded loss.

    Mathematical Implementation
    ----------------------------
    Every loss yields three stacked quantities (first axis of length 3):

    - ``rho[0]``: ``ρ(z)``, multiplied by ``f_scale**2``
    - ``rho[1]``: ``ρ'(z)``
    - ``rho[2]``: ``ρ''(z)``, divided by ``f_scale**2``

    With these scaled values, the usual Gauss-Newton quantities are::

        Gradient:  g = J^T (rho[1] ⊙ f)
        Hessian:   H ≈ J^T diag(rho[1] + 2 rho[2] ⊙ f²) J

    Usage Example
    -------------
    ::

        from nlsq.core.loss_functions import LossFunctionsJIT

        loss_jit = LossFunctionsJIT()
        huber_loss = loss_jit.get_loss_function('huber')

        residuals = jnp.array([0.1, 5.0, 0.2, 10.0])  # contains outliers
        f_scale = 1.0
        data_mask = jnp.ones_like(residuals, dtype=bool)

        # rho[0] = loss values, rho[1] / rho[2] = first / second derivatives
        rho = huber_loss(residuals, f_scale, data_mask, cost_only=False)

        # Scalar total cost only (derivatives are not computed)
        cost = huber_loss(residuals, f_scale, data_mask, cost_only=True)

    Loss Function Selection Guidelines
    -----------------------------------
    - **Clean data**: 'linear'
    - **Few outliers**: 'huber'
    - **Many outliers**: 'soft_l1' or 'cauchy'
    - **Extreme outliers**: 'cauchy' or 'arctan'

    Scale Parameter (f_scale)
    --------------------------
    ``f_scale`` sets the residual magnitude at which a loss switches from
    quadratic to robust behaviour: too small and every residual is treated as
    an outlier, too large and there is effectively no outlier protection.

    Implementation Notes
    --------------------
    - ``cost_only=True`` skips the derivative kernels and fills ``rho[1]`` and
      ``rho[2]`` with zeros shaped like ``z``.
    - ``data_mask`` is optional; ``None`` means every residual contributes to
      the cost.
    - The Huber kernels never evaluate a fractional power on the branch that
      is not selected, so non-finite intermediates cannot leak through
      ``jnp.where`` (including when differentiating through it).
    """

    def __init__(self):
        """Build all jitted kernels and the name -> loss-function table."""
        # Shared helpers used by every loss wrapper.
        self.stack_rhos = self.create_stack_rhos()
        self.create_zscale()
        self.create_calculate_cost()
        self.create_scale_rhos()

        # Per-loss jitted kernels (attached as attributes by each create_*).
        self.create_huber_funcs()
        self.create_soft_l1_funcs()
        self.create_cauchy_funcs()
        self.create_arctan_funcs()

        # "linear" maps to None: plain least squares needs no rho function.
        self.IMPLEMENTED_LOSSES = {
            "linear": None,
            "huber": self.huber,
            "soft_l1": self.soft_l1,
            "cauchy": self.cauchy,
            "arctan": self.arctan,
        }
        self.loss_funcs = self.construct_all_loss_functions()

    def create_stack_rhos(self):
        """Create the jitted helper that stacks rho values.

        Returns
        -------
        callable
            ``stack_rhos(rho0, rho1, rho2)`` returning ``jnp.stack`` of the
            three inputs along a new leading axis of length 3.
        """

        @jit
        def stack_rhos(rho0, rho1, rho2):
            return jnp.stack([rho0, rho1, rho2])

        return stack_rhos

    def get_empty_rhos(self, z):
        """Return zero placeholders for rho1 and rho2.

        Used on the ``cost_only=True`` path, where derivatives are not needed
        but the stacked result must keep its 3-row layout.

        Parameters
        ----------
        z : array
            Scaled squared residuals; only its shape and dtype are used.

        Returns
        -------
        tuple
            ``(rho1, rho2)``, both zeros with the shape and dtype of ``z`` so
            they can always be stacked with ``rho0`` (scalar and N-d inputs
            included).
        """
        zeros = jnp.zeros_like(z)
        # JAX arrays are immutable, so handing out the same array twice is safe.
        return zeros, zeros

    def create_huber_funcs(self):
        """Create the jitted Huber kernels.

        Sets ``self.huber1`` (returns ``rho0`` and the ``z <= 1`` mask) and
        ``self.huber2`` (returns ``rho1`` and ``rho2`` given ``z`` and that
        mask). Huber is ``z`` for ``z <= 1`` and ``2*sqrt(z) - 1`` otherwise.
        """

        @jit
        def huber1(z):
            mask = z <= 1
            # Evaluate the sqrt branch on a harmless value where it is not
            # selected; otherwise differentiating through jnp.where at z == 0
            # multiplies an infinite slope by zero and yields NaN.
            safe_z = jnp.where(mask, 1.0, z)
            return jnp.where(mask, z, 2 * safe_z**0.5 - 1), mask

        @jit
        def huber2(z, mask):
            # Same guard: negative powers of z are only formed where z > 1.
            safe_z = jnp.where(mask, 1.0, z)
            rho1 = jnp.where(mask, 1, safe_z**-0.5)
            rho2 = jnp.where(mask, 0, -0.5 * safe_z**-1.5)
            return rho1, rho2

        self.huber1 = huber1
        self.huber2 = huber2

    def huber(self, z, cost_only):
        """Return stacked ``[rho0, rho1, rho2]`` for the Huber loss.

        When ``cost_only`` is true, ``rho1`` and ``rho2`` are zeros.
        """
        rho0, mask = self.huber1(z)
        if cost_only:
            rho1, rho2 = self.get_empty_rhos(z)
        else:
            rho1, rho2 = self.huber2(z, mask)
        return self.stack_rhos(rho0, rho1, rho2)

    def create_soft_l1_funcs(self):
        """Create the jitted soft-L1 kernels.

        Sets ``self.soft_l1_1`` (returns ``rho0 = 2*(sqrt(1+z) - 1)`` and the
        intermediate ``t = 1 + z``) and ``self.soft_l1_2`` (returns ``rho1``
        and ``rho2`` from ``t``).
        """

        @jit
        def soft_l1_1(z):
            t = 1 + z
            return 2 * (t**0.5 - 1), t

        @jit
        def soft_l1_2(t):
            rho1 = t**-0.5
            rho2 = -0.5 * t**-1.5
            return rho1, rho2

        self.soft_l1_1 = soft_l1_1
        self.soft_l1_2 = soft_l1_2

    def soft_l1(self, z, cost_only):
        """Return stacked ``[rho0, rho1, rho2]`` for the soft-L1 loss.

        When ``cost_only`` is true, ``rho1`` and ``rho2`` are zeros.
        """
        rho0, t = self.soft_l1_1(z)
        if cost_only:
            rho1, rho2 = self.get_empty_rhos(z)
        else:
            rho1, rho2 = self.soft_l1_2(t)
        return self.stack_rhos(rho0, rho1, rho2)

    def create_cauchy_funcs(self):
        """Create the jitted Cauchy (Lorentzian) kernels.

        Sets ``self.cauchy1`` (``rho0 = log1p(z)``) and ``self.cauchy2``
        (``rho1 = 1/(1+z)``, ``rho2 = -1/(1+z)**2``).
        """

        @jit
        def cauchy1(z):
            return jnp.log1p(z)

        @jit
        def cauchy2(z):
            t = 1 + z
            rho1 = 1 / t
            rho2 = -1 / t**2
            return rho1, rho2

        self.cauchy1 = cauchy1
        self.cauchy2 = cauchy2

    def cauchy(self, z, cost_only):
        """Return stacked ``[rho0, rho1, rho2]`` for the Cauchy loss.

        When ``cost_only`` is true, ``rho1`` and ``rho2`` are zeros.
        """
        rho0 = self.cauchy1(z)
        if cost_only:
            rho1, rho2 = self.get_empty_rhos(z)
        else:
            rho1, rho2 = self.cauchy2(z)
        return self.stack_rhos(rho0, rho1, rho2)

    def create_arctan_funcs(self):
        """Create the jitted arctan kernels.

        Sets ``self.arctan1`` (``rho0 = arctan(z)``) and ``self.arctan2``
        (``rho1 = 1/(1+z**2)``, ``rho2 = -2*z/(1+z**2)**2``).
        """

        @jit
        def arctan1(z):
            return jnp.arctan(z)

        @jit
        def arctan2(z):
            t = 1 + z**2
            return 1 / t, -2 * z / t**2

        self.arctan1 = arctan1
        self.arctan2 = arctan2

    def arctan(self, z, cost_only):
        """Return stacked ``[rho0, rho1, rho2]`` for the arctan loss.

        When ``cost_only`` is true, ``rho1`` and ``rho2`` are zeros.
        """
        rho0 = self.arctan1(z)
        if cost_only:
            rho1, rho2 = self.get_empty_rhos(z)
        else:
            rho1, rho2 = self.arctan2(z)
        return self.stack_rhos(rho0, rho1, rho2)

    def create_zscale(self):
        """Create ``self.zscale``, the jitted ``z = (f / f_scale)**2`` helper."""

        @jit
        def zscale(f, f_scale):
            return (f / f_scale) ** 2

        self.zscale = zscale

    def create_calculate_cost(self):
        """Create ``self.calculate_cost``, the jitted total-cost helper.

        ``calculate_cost(f_scale, rho, data_mask)`` returns
        ``0.5 * f_scale**2 * sum(rho[0])``. Entries where ``data_mask`` is
        false contribute zero; ``data_mask=None`` means no masking.
        """

        @jit
        def calculate_cost(f_scale, rho, data_mask):
            rho0 = rho[0]
            # None is resolved at trace time, so the unmasked variant simply
            # compiles without the jnp.where.
            if data_mask is not None:
                rho0 = jnp.where(data_mask, rho0, 0)
            return 0.5 * f_scale**2 * jnp.sum(rho0)

        self.calculate_cost = calculate_cost

    def create_scale_rhos(self):
        """Create ``self.scale_rhos``, the jitted f_scale rescaling helper.

        ``scale_rhos(rho, f_scale)`` returns the stack of
        ``rho[0] * f_scale**2``, ``rho[1]`` and ``rho[2] / f_scale**2``.
        """

        @jit
        def scale_rhos(rho, f_scale):
            rho0 = rho[0] * f_scale**2
            rho1 = rho[1]
            rho2 = rho[2] / f_scale**2
            return self.stack_rhos(rho0, rho1, rho2)

        self.scale_rhos = scale_rhos

    def construct_single_loss_function(self, loss):
        """Wrap a rho evaluator into a residual-level loss function.

        Parameters
        ----------
        loss : callable
            Called as ``loss(z, cost_only=...)``; must return something whose
            items ``[0]``, ``[1]`` and ``[2]`` are rho0, rho1 and rho2.

        Returns
        -------
        callable
            ``loss_function(f, f_scale, data_mask=None, cost_only=False)``.
            With ``cost_only=True`` it returns the scalar total cost (masked
            by ``data_mask`` when one is given); otherwise it returns the
            rho stack rescaled by ``f_scale``.
        """

        def loss_function(f, f_scale, data_mask=None, cost_only=False):
            z = self.zscale(f, f_scale)
            rho = loss(z, cost_only=cost_only)
            if cost_only:
                return self.calculate_cost(f_scale, rho, data_mask)
            return self.scale_rhos(rho, f_scale)

        return loss_function

    def construct_all_loss_functions(self):
        """Return ``{name: loss_function}`` for every non-linear built-in loss."""
        return {
            name: self.construct_single_loss_function(rho_func)
            for name, rho_func in self.IMPLEMENTED_LOSSES.items()
            if rho_func is not None
        }

    def get_loss_function(self, loss):
        """Look up or build the loss function for ``loss``.

        Parameters
        ----------
        loss : str or callable
            Name of a built-in loss, or a callable ``loss(z)`` returning
            rho0, rho1 and rho2 as items ``[0]``, ``[1]``, ``[2]``.

        Returns
        -------
        callable or None
            ``None`` for ``'linear'``; otherwise a function
            ``loss_function(f, f_scale, data_mask=None, cost_only=False)``.

        Raises
        ------
        KeyError
            If ``loss`` is a name that is not a built-in loss.
        """
        if loss == "linear":
            return None

        if callable(loss):
            # User-supplied losses take only z; adapt them to the internal
            # (z, cost_only) calling convention and reuse the common wrapper.
            def custom_rho(z, cost_only=False):
                return loss(z)

            return self.construct_single_loss_function(custom_rho)

        try:
            return self.loss_funcs[loss]
        except KeyError:
            valid = ", ".join(repr(name) for name in self.IMPLEMENTED_LOSSES)
            raise KeyError(
                f"Unknown loss {loss!r}; expected one of {valid} or a callable."
            ) from None