nlsq.parameter_normalizer module
Parameter normalization for improved optimization convergence.
This module provides automatic parameter scaling to address gradient signal weakness caused by parameter scale imbalance. Parameters spanning many orders of magnitude can cause slow convergence and numerical instability.
The ParameterNormalizer class supports multiple normalization strategies:
- Bounds-based: Normalize to [0, 1] using parameter bounds
- p0-based: Scale by initial parameter magnitudes
- None: Identity transform (no normalization)
The NormalizedModelWrapper wraps user model functions to work transparently in normalized parameter space while maintaining JAX JIT compatibility.
- class nlsq.precision.parameter_normalizer.NormalizedModelWrapper(model_fn, normalizer)[source]
Bases: object
Wraps a user model function to work with normalized parameters.
This wrapper allows optimization algorithms to work in normalized parameter space while the user model function operates in original parameter space. The wrapper is JAX JIT-compatible and preserves gradients correctly.
- Parameters:
model_fn (callable) – User model function with signature: model_fn(x, *params) -> predictions
normalizer (ParameterNormalizer) – Parameter normalizer instance
Examples
>>> import jax.numpy as jnp
>>> from nlsq.precision.parameter_normalizer import ParameterNormalizer, NormalizedModelWrapper
>>> def model(x, a, b):
...     return a * x + b
>>> p0 = jnp.array([5.0, 10.0])
>>> normalizer = ParameterNormalizer(p0, bounds=None, strategy='p0')
>>> wrapped_model = NormalizedModelWrapper(model, normalizer)
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> normalized_params = normalizer.normalize(p0)
>>> output = wrapped_model(x, *normalized_params)
>>> print(output)
[15. 20. 25.]
JIT compilation:
>>> import jax
>>> @jax.jit
... def optimized_model(x, a_norm, b_norm):
...     return wrapped_model(x, a_norm, b_norm)
>>> output = optimized_model(x, *normalized_params)
See also
ParameterNormalizer : Handles parameter normalization
- __init__(model_fn, normalizer)[source]
Initialize normalized model wrapper.
- Parameters:
model_fn (callable) – User model function
normalizer (ParameterNormalizer) – Parameter normalizer
- class nlsq.precision.parameter_normalizer.ParameterNormalizer(p0, bounds=None, strategy='auto')[source]
Bases: object
Normalizes parameters to improve optimization convergence.
This class handles automatic parameter scaling to address gradient signal weakness from parameter scale imbalance. It supports multiple strategies:
bounds: Normalize to [0, 1] using parameter bounds (lb, ub)
p0: Scale by initial parameter magnitudes
auto: Use bounds if provided, else p0-based
none: Identity transform (no normalization)
The normalizer computes and stores the normalization Jacobian analytically, which is needed for transforming covariance matrices back to original space.
- Parameters:
p0 (array_like) – Initial parameter guess of shape (n_params,)
bounds (tuple of array_like, optional) – Parameter bounds as (lb, ub) where lb and ub are arrays of shape (n_params,). If None and strategy is 'auto', p0-based scaling is used.
strategy (str, default='auto') –
Normalization strategy. Options:
'auto': Use bounds if provided, else p0-based
'bounds': Normalize to [0, 1] using bounds
'p0': Scale by initial parameter magnitudes
'none': Identity transform (no normalization)
- Attributes:
strategy (str) – Selected normalization strategy
scales – Scaling factors for each parameter (the diagonal of the Jacobian)
offsets – Offset for each parameter (used by the bounds-based strategy)
normalization_jacobian – Denormalization Jacobian matrix (diagonal) of shape (n_params, n_params). For the covariance transform: Cov_orig = J @ Cov_norm @ J.T
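Each strategy can be read as a per-parameter affine map built from `scales` and `offsets`. The NumPy sketch below illustrates that reading of the attributes; `scales_and_offsets`, `normalize`, `denormalize` and `jacobian` are hypothetical standalone helpers, not the nlsq implementation, and the sketch assumes zero offsets for the 'p0' and 'none' strategies and no zero entries in p0.

```python
import numpy as np


def scales_and_offsets(p0, bounds, strategy):
    """Hypothetical helper: per-parameter (scale, offset) such that
    normalized = (params - offset) / scale."""
    p0 = np.asarray(p0, dtype=float)
    if strategy == "bounds":
        if bounds is None:
            raise ValueError("strategy='bounds' requires bounds")
        lb, ub = (np.asarray(b, dtype=float) for b in bounds)
        return ub - lb, lb                      # maps [lb, ub] onto [0, 1]
    if strategy == "p0":
        return np.abs(p0), np.zeros_like(p0)    # assumes p0 has no zero entries
    if strategy == "none":
        return np.ones_like(p0), np.zeros_like(p0)
    raise ValueError(f"unknown strategy: {strategy!r}")


def normalize(params, scales, offsets):
    return (np.asarray(params, dtype=float) - offsets) / scales


def denormalize(normalized, scales, offsets):
    # Algebraic inverse of normalize()
    return offsets + scales * np.asarray(normalized, dtype=float)


def jacobian(scales):
    # d(original)/d(normalized) is diagonal, with the scales on the diagonal
    return np.diag(scales)


# Reproduce the bounds-based example from this page
p0 = np.array([50.0, 0.5])
bounds = (np.array([10.0, 0.0]), np.array([100.0, 1.0]))
scales, offsets = scales_and_offsets(p0, bounds, "bounds")
normalized = normalize(p0, scales, offsets)  # (50-10)/90 and (0.5-0)/1
print(np.allclose(denormalize(normalized, scales, offsets), p0))  # True
```

Keeping a single (scale, offset) pair per parameter means normalization, denormalization and the Jacobian all derive from the same two arrays, which is why the Jacobian can be stored analytically.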
Examples
Bounds-based normalization:
>>> import jax.numpy as jnp
>>> from nlsq.precision.parameter_normalizer import ParameterNormalizer
>>> p0 = jnp.array([50.0, 0.5])
>>> bounds = (jnp.array([10.0, 0.0]), jnp.array([100.0, 1.0]))
>>> normalizer = ParameterNormalizer(p0, bounds, strategy='bounds')
>>> normalized = normalizer.normalize(p0)
>>> print(normalized)
[0.444... 0.5]
>>> denormalized = normalizer.denormalize(normalized)
>>> print(jnp.allclose(denormalized, p0))
True
p0-based normalization:
>>> p0 = jnp.array([1000.0, 1.0, 0.001])
>>> normalizer = ParameterNormalizer(p0, bounds=None, strategy='p0')
>>> normalized = normalizer.normalize(p0)
>>> print(normalized)
[1. 1. 1.]
No normalization:
>>> p0 = jnp.array([5.0, 15.0])
>>> normalizer = ParameterNormalizer(p0, bounds=None, strategy='none')
>>> normalized = normalizer.normalize(p0)
>>> print(jnp.allclose(normalized, p0))
True
See also
NormalizedModelWrapper : Wraps model functions for normalized parameters
HybridStreamingConfig : Configuration with normalization_strategy parameter
Notes
Implements Phase 0 (Parameter Normalization Setup) of the Adaptive Hybrid Streaming Optimizer specification.
- __init__(p0, bounds=None, strategy='auto')[source]
Initialize parameter normalizer.
- normalize(params)[source]
Normalize parameters to scaled space.
- Parameters:
params (array_like) – Parameters in original space of shape (n_params,)
- Returns:
Normalized parameters of shape (n_params,)
- denormalize(normalized_params)[source]
Denormalize parameters back to original space.
This is the exact inverse of normalize().
- Parameters:
normalized_params (array_like) – Parameters in normalized space of shape (n_params,)
- Returns:
Parameters in original space of shape (n_params,)
- transform_bounds()[source]
Transform bounds to normalized space.
- Returns:
lb_normalized (jax.Array) – Lower bounds in normalized space of shape (n_params,)
ub_normalized (jax.Array) – Upper bounds in normalized space of shape (n_params,)
- Return type:
tuple[jax.Array, jax.Array]
Examples
>>> import jax.numpy as jnp
>>> from nlsq.precision.parameter_normalizer import ParameterNormalizer
>>> p0 = jnp.array([50.0])
>>> bounds = (jnp.array([10.0]), jnp.array([100.0]))
>>> normalizer = ParameterNormalizer(p0, bounds, strategy='bounds')
>>> lb_norm, ub_norm = normalizer.transform_bounds()
>>> print(lb_norm, ub_norm)
[0.] [1.]
Overview
The nlsq.parameter_normalizer module provides automatic parameter scaling to
address gradient signal weakness caused by parameter scale imbalance. Parameters
spanning many orders of magnitude can cause slow convergence and numerical
instability in optimization.
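To make the imbalance concrete, the NumPy sketch below builds the Jacobian of the exponential model amplitude * exp(-decay * x) at p0 = [100, 0.1], first with respect to the original parameters and then with respect to p0-scaled parameters. The analytic derivatives and the use of the condition number of J^T J as a rough measure of conditioning are illustrative choices, not taken from nlsq.

```python
import numpy as np

x = np.linspace(0.0, 10.0, 100)
p0 = np.array([100.0, 0.1])  # amplitude, decay: three orders of magnitude apart


def model_jacobian(x, amplitude, decay):
    """Analytic d(model)/d(params) for model = amplitude * exp(-decay * x)."""
    e = np.exp(-decay * x)
    return np.column_stack([e, -amplitude * x * e])


J_orig = model_jacobian(x, *p0)

# p0-based scaling: params = |p0| * normalized, so by the chain rule
# d(model)/d(normalized) = d(model)/d(params) @ diag(|p0|)
J_norm = J_orig @ np.diag(np.abs(p0))

for name, J in [("original", J_orig), ("normalized", J_norm)]:
    col_norms = np.linalg.norm(J, axis=0)
    print(name, "column norms:", col_norms, "cond(J^T J):", np.linalg.cond(J.T @ J))
```

In this setup the column norms differ by more than two orders of magnitude in original space but only by a small factor after scaling, and the condition number of J^T J drops by several orders of magnitude.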
New in version 0.3.0: Parameter normalization for adaptive hybrid streaming.
Key Features
Bounds-based normalization: Normalize to [0, 1] using parameter bounds
p0-based normalization: Scale by initial parameter magnitudes
Identity transform: No normalization option
JAX JIT compatibility: All operations are JIT-compilable
Automatic Jacobian computation: Analytical denormalization Jacobian for covariance transform
Transparent model wrapping: User model operates in original parameter space
Normalization Strategies
Bounds-Based Normalization
Normalizes parameters to [0, 1] using provided bounds. Best when you have meaningful parameter bounds:
import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import ParameterNormalizer
# Parameters: amplitude in [10, 100], decay in [0, 1]
p0 = jnp.array([50.0, 0.5])
bounds = (jnp.array([10.0, 0.0]), jnp.array([100.0, 1.0]))
normalizer = ParameterNormalizer(p0, bounds, strategy="bounds")
# Normalize: (50-10)/(100-10) = 0.444, (0.5-0)/(1-0) = 0.5
normalized = normalizer.normalize(p0)
print(normalized) # [0.444... 0.5]
# Denormalize back to original
denormalized = normalizer.denormalize(normalized)
print(jnp.allclose(denormalized, p0)) # True
p0-Based Normalization
Scales parameters by their initial magnitudes. Best when parameters have vastly different scales but no clear bounds:
# Parameters: large, medium, small
p0 = jnp.array([1000.0, 1.0, 0.001])
normalizer = ParameterNormalizer(p0, bounds=None, strategy="p0")
# All parameters normalized to ~1.0
normalized = normalizer.normalize(p0)
print(normalized) # [1. 1. 1.]
# Other parameter values are divided by the same |p0| factors
params = jnp.array([500.0, 2.0, 0.002])
normalized = normalizer.normalize(params)
print(normalized) # [0.5 2. 2.]
No Normalization
Identity transform when normalization is not needed:
p0 = jnp.array([5.0, 15.0])
normalizer = ParameterNormalizer(p0, bounds=None, strategy="none")
normalized = normalizer.normalize(p0)
print(jnp.allclose(normalized, p0)) # True
Auto Strategy
Selects bounds-based normalization when bounds are provided and p0-based normalization otherwise:
p0 = jnp.array([50.0, 0.5])
bounds = (jnp.array([10.0, 0.0]), jnp.array([100.0, 1.0]))
# With bounds: uses bounds-based
normalizer = ParameterNormalizer(p0, bounds=bounds, strategy="auto")
print(normalizer.strategy)  # bounds
# Without bounds: uses p0-based
normalizer = ParameterNormalizer(p0, bounds=None, strategy="auto")
print(normalizer.strategy)  # p0
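The selection rule can be sketched as a small standalone function. `resolve_strategy` is hypothetical (not part of nlsq), and rejecting unknown strategy names is an assumption about sensible behavior rather than documented nlsq behavior.

```python
def resolve_strategy(strategy, bounds):
    """Hypothetical sketch of the 'auto' rule: bounds if given, else p0."""
    allowed = {"auto", "bounds", "p0", "none"}
    if strategy not in allowed:
        raise ValueError(f"strategy must be one of {sorted(allowed)}, got {strategy!r}")
    if strategy == "auto":
        return "bounds" if bounds is not None else "p0"
    # Explicit strategies are passed through unchanged
    return strategy


print(resolve_strategy("auto", ([10.0], [100.0])))  # bounds
print(resolve_strategy("auto", None))               # p0
```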
Usage Examples
Model Wrapping for Optimization
Use NormalizedModelWrapper to transparently work in normalized space:
import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import (
    ParameterNormalizer,
    NormalizedModelWrapper,
)
# Define model in original parameter space
def model(x, amplitude, decay):
    return amplitude * jnp.exp(-decay * x)
# Setup normalization
p0 = jnp.array([100.0, 0.1]) # Very different scales
bounds = (jnp.array([10.0, 0.01]), jnp.array([200.0, 1.0]))
normalizer = ParameterNormalizer(p0, bounds, strategy="bounds")
# Wrap model
wrapped_model = NormalizedModelWrapper(model, normalizer)
# Use wrapped model with normalized parameters
x = jnp.linspace(0, 10, 100)
normalized_p0 = normalizer.normalize(p0)
# Wrapped model internally denormalizes before calling original model
predictions = wrapped_model(x, *normalized_p0)
print(predictions.shape) # (100,)
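Conceptually, the wrapper only has to denormalize its arguments before delegating to the user model. The sketch below shows that idea as a plain NumPy closure (`make_normalized_model` is a hypothetical name, not the nlsq API, and it hard-codes the bounds-based map) and checks numerically that derivatives in normalized space equal original-space derivatives multiplied by the denormalization scale, which is the chain rule behind gradient preservation.

```python
import numpy as np


def model(x, amplitude, decay):
    return amplitude * np.exp(-decay * x)


def make_normalized_model(model_fn, lb, ub):
    """Hypothetical bounds-based wrapper: accepts normalized params,
    denormalizes them, then calls model_fn in original space."""
    lb, ub = np.asarray(lb, dtype=float), np.asarray(ub, dtype=float)

    def wrapped(x, *normalized_params):
        params = lb + (ub - lb) * np.asarray(normalized_params, dtype=float)
        return model_fn(x, *params)

    return wrapped


lb, ub = np.array([10.0, 0.01]), np.array([200.0, 1.0])
p0 = np.array([100.0, 0.1])
x = np.linspace(0.0, 10.0, 100)
wrapped = make_normalized_model(model, lb, ub)
p0_norm = (p0 - lb) / (ub - lb)

# Same predictions in both parameterizations
print(np.allclose(wrapped(x, *p0_norm), model(x, *p0)))  # True

# Chain rule: d(wrapped)/d(amplitude_norm) == d(model)/d(amplitude) * (ub[0] - lb[0])
h = 1e-6
d_norm = (wrapped(x, p0_norm[0] + h, p0_norm[1])
          - wrapped(x, p0_norm[0] - h, p0_norm[1])) / (2 * h)
d_orig = np.exp(-p0[1] * x)  # analytic d(model)/d(amplitude)
print(np.allclose(d_norm, d_orig * (ub[0] - lb[0]), rtol=1e-5))  # True
```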
JIT Compilation
All operations are JAX JIT-compatible:
import jax
import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import (
    ParameterNormalizer,
    NormalizedModelWrapper,
)
def model(x, a, b):
    return a * x + b
p0 = jnp.array([5.0, 10.0])
normalizer = ParameterNormalizer(p0, bounds=None, strategy="p0")
wrapped = NormalizedModelWrapper(model, normalizer)
@jax.jit
def compute_predictions(x, a_norm, b_norm):
    return wrapped(x, a_norm, b_norm)
x = jnp.array([1.0, 2.0, 3.0])
normalized = normalizer.normalize(p0)
result = compute_predictions(x, *normalized)
Bounds Transformation
Transform bounds to normalized space:
p0 = jnp.array([50.0])
bounds = (jnp.array([10.0]), jnp.array([100.0]))
normalizer = ParameterNormalizer(p0, bounds, strategy="bounds")
lb_norm, ub_norm = normalizer.transform_bounds()
print(lb_norm, ub_norm) # [0.] [1.]
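Because each strategy is an increasing affine map per parameter, one plausible way to transform bounds is to push lb and ub through the same map used for the parameters. The sketch below assumes that; `transform_bounds_affine` is a hypothetical helper, and the 'p0' case only illustrates the arithmetic, not how nlsq treats bounds under that strategy.

```python
import numpy as np


def transform_bounds_affine(lb, ub, scales, offsets):
    """Hypothetical: map bounds with the same (x - offset) / scale used for parameters."""
    lb_n = (np.asarray(lb, dtype=float) - offsets) / scales
    ub_n = (np.asarray(ub, dtype=float) - offsets) / scales
    return lb_n, ub_n


lb, ub = np.array([10.0]), np.array([100.0])

# Bounds-based: scale = ub - lb, offset = lb, so the bounds map onto [0, 1]
print(transform_bounds_affine(lb, ub, ub - lb, lb))  # (array([0.]), array([1.]))

# p0-style scaling with p0 = 50 and zero offset: bounds become [0.2, 2.0]
p0 = np.array([50.0])
lb_p0, ub_p0 = transform_bounds_affine(lb, ub, np.abs(p0), np.zeros(1))
```

With positive scales the map preserves order, so any parameter inside [lb, ub] stays inside the transformed bounds.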
Covariance Transform
Transform covariance from normalized to original space using the Jacobian:
# Get denormalization Jacobian
J = normalizer.normalization_jacobian
print(J.shape)  # (n_params, n_params)
# Transform covariance: Cov_orig = J @ Cov_norm @ J.T
n_params = J.shape[0]
cov_normalized = jnp.eye(n_params) * 0.01  # Example normalized covariance
cov_original = J @ cov_normalized @ J.T
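For a diagonal J, the sandwich product reduces to element-wise rescaling, so standard errors scale by |J_ii| and correlations are unchanged. The NumPy check below uses a made-up normalized covariance and assumed bounds-based scales of 90 and 1 (bounds [10, 100] and [0, 1]) purely for illustration.

```python
import numpy as np

# Assumed bounds-based scales (u_i - l_i) for bounds [10, 100] and [0, 1]
scales = np.array([90.0, 1.0])
J = np.diag(scales)

# Made-up covariance in normalized space, for illustration only
cov_normalized = np.array([[0.010, 0.002],
                           [0.002, 0.040]])
cov_original = J @ cov_normalized @ J.T

# For diagonal J the sandwich product is an element-wise rescaling
print(np.allclose(cov_original, np.outer(scales, scales) * cov_normalized))  # True


def correlation(cov):
    sd = np.sqrt(np.diag(cov))
    return cov / np.outer(sd, sd)


# Standard errors scale by |J_ii|; correlations are unchanged (scales are positive)
perr_normalized = np.sqrt(np.diag(cov_normalized))
perr_original = np.sqrt(np.diag(cov_original))
print(np.allclose(perr_original, np.abs(scales) * perr_normalized))  # True
print(np.allclose(correlation(cov_original), correlation(cov_normalized)))  # True
```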
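Mathematical Details
Normalization Transform
For bounds-based normalization with bounds \([l_i, u_i]\):
\[\tilde{\theta}_i = \frac{\theta_i - l_i}{u_i - l_i}\]
For p0-based normalization with initial values \(\theta_0\):
\[\tilde{\theta}_i = \frac{\theta_i}{|\theta_{0,i}|}\]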
Denormalization Jacobian
The denormalization Jacobian \(J = \partial\theta / \partial\tilde{\theta}\) is diagonal:
For bounds-based: \(J_{ii} = u_i - l_i\)
For p0-based: \(J_{ii} = |\theta_{0,i}|\)
For no normalization: \(J = I\)
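The diagonal entries follow from differentiating the inverse (denormalization) maps. Writing \(\tilde{\theta}\) for the normalized parameters, a short derivation:

```latex
\begin{aligned}
\text{bounds:}\quad & \theta_i = l_i + (u_i - l_i)\,\tilde{\theta}_i
  && \Rightarrow\quad \frac{\partial \theta_i}{\partial \tilde{\theta}_j} = (u_i - l_i)\,\delta_{ij} \\
\text{p0:}\quad & \theta_i = |\theta_{0,i}|\,\tilde{\theta}_i
  && \Rightarrow\quad \frac{\partial \theta_i}{\partial \tilde{\theta}_j} = |\theta_{0,i}|\,\delta_{ij}
\end{aligned}
```

Because each original parameter depends only on its own normalized counterpart, every off-diagonal entry vanishes.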
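Covariance Transform
To transform covariance from normalized to original space:
\[\Sigma_{\text{orig}} = J \, \Sigma_{\text{norm}} \, J^\top\]
Since \(J\) is diagonal, this simplifies to:
\[(\Sigma_{\text{orig}})_{ij} = J_{ii} \, J_{jj} \, (\Sigma_{\text{norm}})_{ij}\]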
See Also
nlsq.adaptive_hybrid_streaming module : Uses this for Phase 0 normalization
nlsq.hybrid_streaming_config module : Configuration with normalization_strategy
Advanced Customization Guide : Advanced optimization features