"""Parameter normalization for improved optimization convergence.
This module provides automatic parameter scaling to address gradient signal
weakness caused by parameter scale imbalance. Parameters spanning many orders
of magnitude can cause slow convergence and numerical instability.
The ParameterNormalizer class supports multiple normalization strategies:
- Bounds-based: Normalize to [0, 1] using parameter bounds
- p0-based: Scale by initial parameter magnitudes
- None: Identity transform (no normalization)
The NormalizedModelWrapper wraps user model functions to work transparently
in normalized parameter space while maintaining JAX JIT compatibility.
"""
from __future__ import annotations
from collections.abc import Callable
import jax.numpy as jnp
__all__ = ["NormalizedModelWrapper", "ParameterNormalizer"]
class ParameterNormalizer:
    """Normalizes parameters to improve optimization convergence.

    This class handles automatic parameter scaling to address gradient signal
    weakness from parameter scale imbalance. It supports multiple strategies:

    - **bounds**: Normalize to [0, 1] using parameter bounds (lb, ub)
    - **p0**: Scale by initial parameter magnitudes
    - **auto**: Use bounds if provided, else p0-based
    - **none**: Identity transform (no normalization)

    The normalizer computes and stores the normalization Jacobian analytically,
    which is needed for transforming covariance matrices back to original space.

    Parameters
    ----------
    p0 : array_like
        Initial parameter guess of shape (n_params,). A scalar is treated as a
        single-parameter vector.
    bounds : tuple of array_like, optional
        Parameter bounds as (lb, ub). Each bound may be an array of shape
        (n_params,) or a scalar that is broadcast to every parameter.
        If None, p0-based scaling is used.
    strategy : str, default='auto'
        Normalization strategy. Options:

        - 'auto': Use bounds if provided, else p0-based
        - 'bounds': Normalize to [0, 1] using bounds
        - 'p0': Scale by initial parameter magnitudes
        - 'none': Identity transform (no normalization)

    Attributes
    ----------
    strategy : str
        Selected normalization strategy ('auto' is resolved to 'bounds' or 'p0')
    scales : jax.Array
        Scaling factors for each parameter (diagonal of Jacobian)
    offsets : jax.Array
        Offset for each parameter (used in bounds-based)
    original_bounds : tuple of array_like or None
        Original parameter bounds (lb, ub), stored exactly as supplied
    normalization_jacobian : jax.Array
        Denormalization Jacobian matrix (diagonal) of shape (n_params, n_params).
        For covariance transform: Cov_orig = J @ Cov_norm @ J.T

    Raises
    ------
    ValueError
        If ``strategy`` is not a recognized name, if ``p0`` has more than one
        dimension, if the 'bounds' strategy is requested without bounds, or if
        a bound cannot be broadcast to the shape of ``p0``.

    Examples
    --------
    Bounds-based normalization:

    >>> import jax.numpy as jnp
    >>> from nlsq.precision.parameter_normalizer import ParameterNormalizer
    >>> p0 = jnp.array([50.0, 0.5])
    >>> bounds = (jnp.array([10.0, 0.0]), jnp.array([100.0, 1.0]))
    >>> normalizer = ParameterNormalizer(p0, bounds, strategy='bounds')
    >>> normalized = normalizer.normalize(p0)
    >>> print(normalized)
    [0.444... 0.5]
    >>> denormalized = normalizer.denormalize(normalized)
    >>> print(jnp.allclose(denormalized, p0))
    True

    p0-based normalization:

    >>> p0 = jnp.array([1000.0, 1.0, 0.001])
    >>> normalizer = ParameterNormalizer(p0, bounds=None, strategy='p0')
    >>> normalized = normalizer.normalize(p0)
    >>> print(normalized)
    [1. 1. 1.]

    No normalization:

    >>> p0 = jnp.array([5.0, 15.0])
    >>> normalizer = ParameterNormalizer(p0, bounds=None, strategy='none')
    >>> normalized = normalizer.normalize(p0)
    >>> print(jnp.allclose(normalized, p0))
    True

    See Also
    --------
    NormalizedModelWrapper : Wraps model functions for normalized parameters
    HybridStreamingConfig : Configuration with normalization_strategy parameter

    Notes
    -----
    Implements Phase 0 (Parameter Normalization Setup) of the Adaptive
    Hybrid Streaming Optimizer specification.

    With the bounds-based strategy, any parameter whose lower or upper bound
    is not finite cannot be mapped onto [0, 1]; such a parameter falls back to
    p0-magnitude scaling with zero offset so the transform stays finite and
    invertible. Parameters with finite bounds are unaffected by this fallback.
    """

    def __init__(
        self,
        p0: jnp.ndarray,
        bounds: tuple[jnp.ndarray, jnp.ndarray] | None = None,
        strategy: str = "auto",
    ):
        """Initialize parameter normalizer.

        Parameters
        ----------
        p0 : array_like
            Initial parameter guess
        bounds : tuple of array_like, optional
            Parameter bounds as (lb, ub)
        strategy : str, default='auto'
            Normalization strategy

        Raises
        ------
        ValueError
            If ``strategy`` is unknown, ``p0`` is not one-dimensional, or the
            bounds are missing/incompatible for the selected strategy.
        """
        # Promote a scalar guess to a one-element vector; reject higher ranks,
        # for which a diagonal Jacobian would be meaningless.
        self.p0 = jnp.atleast_1d(jnp.asarray(p0, dtype=jnp.float64))
        if self.p0.ndim != 1:
            raise ValueError(
                f"p0 must be one-dimensional, got shape {self.p0.shape}"
            )
        self.n_params = self.p0.shape[0]
        self.original_bounds = bounds

        # Validate strategy
        valid_strategies = ("bounds", "p0", "none", "auto")
        if strategy not in valid_strategies:
            raise ValueError(
                f"strategy must be one of {valid_strategies}, got: {strategy}"
            )

        # Resolve 'auto': use bounds if provided, else p0-based
        if strategy == "auto":
            self.strategy = "bounds" if bounds is not None else "p0"
        else:
            self.strategy = strategy

        # Compute scales and offsets based on strategy
        self._compute_normalization_parameters()

        # Compute normalization Jacobian (denormalization Jacobian)
        self._normalization_jacobian = self._compute_jacobian()

    def _magnitude_scales(self) -> jnp.ndarray:
        """Return |p0| per parameter, substituting 1 where p0 is (near) zero."""
        abs_p0 = jnp.abs(self.p0)
        # A zero scale would make normalize() divide by zero.
        eps = jnp.finfo(jnp.float64).eps * 10
        return jnp.where(abs_p0 < eps, jnp.ones_like(abs_p0), abs_p0)

    def _broadcast_bound(self, bound, label: str) -> jnp.ndarray:
        """Convert one bound to an array with the same shape as ``p0``."""
        bound_array = jnp.asarray(bound, dtype=jnp.float64)
        try:
            return jnp.broadcast_to(bound_array, self.p0.shape)
        except (ValueError, TypeError) as err:
            raise ValueError(
                f"{label} bound has shape {bound_array.shape}, which cannot be "
                f"broadcast to p0 shape {self.p0.shape}"
            ) from err

    def _compute_normalization_parameters(self):
        """Compute scaling factors and offsets based on strategy."""
        if self.strategy == "bounds":
            # Bounds-based: normalize to [0, 1]
            if self.original_bounds is None:
                raise ValueError("bounds must be provided for bounds-based strategy")
            lb, ub = self.original_bounds
            # Scalar bounds are broadcast so scales/offsets are always
            # (n_params,) and the Jacobian can be built with jnp.diag.
            lb = self._broadcast_bound(lb, "lower")
            ub = self._broadcast_bound(ub, "upper")

            # Normalized: (params - lb) / (ub - lb)
            span = ub - lb
            # Handle zero range (constant parameter)
            eps = jnp.finfo(jnp.float64).eps
            span = jnp.where(jnp.abs(span) < eps, jnp.ones_like(span), span)

            # A parameter with an infinite bound has no finite range; using
            # (ub - lb) there would yield inf scales and NaN results, so fall
            # back to p0-magnitude scaling with zero offset for it.
            fully_bounded = jnp.isfinite(lb) & jnp.isfinite(ub)
            self.scales = jnp.where(fully_bounded, span, self._magnitude_scales())
            self.offsets = jnp.where(fully_bounded, lb, jnp.zeros_like(lb))
        elif self.strategy == "p0":
            # p0-based: scale by parameter magnitudes
            # Normalized: params / |p0|
            self.scales = self._magnitude_scales()
            self.offsets = jnp.zeros(self.n_params, dtype=jnp.float64)
        elif self.strategy == "none":
            # No normalization: identity transform
            self.scales = jnp.ones(self.n_params, dtype=jnp.float64)
            self.offsets = jnp.zeros(self.n_params, dtype=jnp.float64)
        else:
            raise ValueError(f"Unknown strategy: {self.strategy}")

    def _compute_jacobian(self) -> jnp.ndarray:
        """Compute denormalization Jacobian analytically.

        For our scaling operations, the Jacobian is diagonal with scales on diagonal.

        Returns
        -------
        jax.Array
            Denormalization Jacobian matrix of shape (n_params, n_params)
        """
        # d(denormalized)/d(normalized) = diag(scales)
        return jnp.diag(self.scales)

    @property
    def normalization_jacobian(self) -> jnp.ndarray:
        """Get the denormalization Jacobian matrix.

        This is the Jacobian of the denormalization transform, needed for
        transforming covariance matrices from normalized to original space:
        Cov_orig = J @ Cov_norm @ J.T

        Returns
        -------
        jax.Array
            Denormalization Jacobian matrix of shape (n_params, n_params).
            For our scaling, this is a diagonal matrix with scales on the diagonal.
        """
        return self._normalization_jacobian

    def normalize(self, params: jnp.ndarray) -> jnp.ndarray:
        """Normalize parameters to scaled space.

        Parameters
        ----------
        params : array_like
            Parameters in original space of shape (n_params,)

        Returns
        -------
        jax.Array
            Normalized parameters of shape (n_params,)
        """
        params = jnp.asarray(params, dtype=jnp.float64)
        # Apply normalization: (params - offset) / scale
        return (params - self.offsets) / self.scales

    def denormalize(self, normalized_params: jnp.ndarray) -> jnp.ndarray:
        """Denormalize parameters back to original space.

        This is the exact inverse of normalize().

        Parameters
        ----------
        normalized_params : array_like
            Parameters in normalized space of shape (n_params,)

        Returns
        -------
        jax.Array
            Parameters in original space of shape (n_params,)
        """
        normalized_params = jnp.asarray(normalized_params, dtype=jnp.float64)
        # Apply denormalization: params = normalized * scale + offset
        return normalized_params * self.scales + self.offsets
class NormalizedModelWrapper:
    """Wraps user model function to work with normalized parameters.

    This wrapper allows optimization algorithms to work in normalized parameter
    space while the user model function operates in original parameter space.
    Each call stacks the normalized parameters into one array, denormalizes
    them with the supplied normalizer, and forwards the unpacked result to the
    user model.

    Parameters
    ----------
    model_fn : callable
        User model function with signature: ``model_fn(x, *params) -> predictions``
    normalizer : ParameterNormalizer
        Parameter normalizer instance

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from nlsq.precision.parameter_normalizer import ParameterNormalizer, NormalizedModelWrapper
    >>> def model(x, a, b):
    ...     return a * x + b
    >>> p0 = jnp.array([5.0, 10.0])
    >>> normalizer = ParameterNormalizer(p0, bounds=None, strategy='p0')
    >>> wrapped_model = NormalizedModelWrapper(model, normalizer)
    >>> x = jnp.array([1.0, 2.0, 3.0])
    >>> normalized_params = normalizer.normalize(p0)
    >>> output = wrapped_model(x, *normalized_params)
    >>> print(output)
    [15. 20. 25.]

    JIT compilation:

    >>> import jax
    >>> @jax.jit
    ... def optimized_model(x, a_norm, b_norm):
    ...     return wrapped_model(x, a_norm, b_norm)
    >>> output = optimized_model(x, *normalized_params)

    See Also
    --------
    ParameterNormalizer : Handles parameter normalization
    """

    def __init__(
        self, model_fn: Callable[..., jnp.ndarray], normalizer: ParameterNormalizer
    ):
        """Initialize normalized model wrapper.

        Parameters
        ----------
        model_fn : callable
            User model function
        normalizer : ParameterNormalizer
            Parameter normalizer
        """
        self.model_fn = model_fn
        self.normalizer = normalizer

    def __call__(self, x: jnp.ndarray, *normalized_params: jnp.ndarray) -> jnp.ndarray:
        """Call wrapped model with normalized parameters.

        Parameters
        ----------
        x : array_like
            Independent variable data
        *normalized_params : array_like
            Normalized parameter values (unpacked)

        Returns
        -------
        jax.Array
            Model predictions

        Raises
        ------
        ValueError
            If the normalizer exposes ``n_params`` and the number of supplied
            normalized parameters differs from it.
        """
        # Reject a wrong argument count up front: a short parameter tuple could
        # otherwise broadcast against the normalizer's scales and call the
        # model with values that were never supplied.
        expected_count = getattr(self.normalizer, "n_params", None)
        if expected_count is not None and len(normalized_params) != expected_count:
            raise ValueError(
                f"expected {expected_count} normalized parameters, "
                f"got {len(normalized_params)}"
            )

        # Stack the unpacked values into a single parameter vector
        normalized_params_array = jnp.asarray(normalized_params, dtype=jnp.float64)

        # Map back to the original parameter space
        original_params = self.normalizer.denormalize(normalized_params_array)

        # The user model takes its parameters as separate positional arguments
        return self.model_fn(x, *original_params)