nlsq.precision.parameter_normalizer module

Parameter normalization for improved optimization convergence.

This module provides automatic parameter scaling to counter weak gradient signals caused by imbalanced parameter scales. When parameters span many orders of magnitude, optimization can converge slowly and become numerically unstable.

The ParameterNormalizer class supports multiple normalization strategies:

  • Bounds-based: Normalize to [0, 1] using parameter bounds

  • p0-based: Scale by initial parameter magnitudes

  • None: Identity transform (no normalization)

The NormalizedModelWrapper wraps user model functions to work transparently in normalized parameter space while maintaining JAX JIT compatibility.

class nlsq.precision.parameter_normalizer.NormalizedModelWrapper(model_fn, normalizer)

Bases: object

Wraps user model function to work with normalized parameters.

This wrapper lets optimization algorithms work in normalized parameter space while the user model function operates in original parameter space. The wrapper is JAX JIT-compatible and differentiable: gradients with respect to the normalized parameters flow through the denormalization step.
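Conceptually, the wrapper only has to map normalized parameters back to original space before calling the user model. The following is a minimal NumPy sketch of that idea, not nlsq's implementation: `SketchWrapper` is a hypothetical name, and it assumes an affine denormalization theta = offsets + scales * theta_norm, consistent with the formulas under Mathematical Details.

```python
import numpy as np


class SketchWrapper:
    """Hypothetical stand-in for NormalizedModelWrapper (not the nlsq class)."""

    def __init__(self, model_fn, scales, offsets):
        self.model_fn = model_fn
        self.scales = np.asarray(scales, dtype=float)
        self.offsets = np.asarray(offsets, dtype=float)

    def __call__(self, x, *normalized_params):
        # Denormalize with the assumed affine map: theta = offset + scale * theta_norm
        theta = self.offsets + self.scales * np.asarray(normalized_params, dtype=float)
        # The user model only ever sees original-space parameters
        return self.model_fn(x, *theta)


def model(x, a, b):
    return a * x + b


# p0-based scaling for p0 = [5, 10]: scales = |p0|, offsets = 0
wrapper = SketchWrapper(model, scales=[5.0, 10.0], offsets=[0.0, 0.0])
x = np.array([1.0, 2.0, 3.0])
output = wrapper(x, 1.0, 1.0)  # normalized p0 is [1, 1]
print(output)  # [15. 20. 25.]
```

This reproduces the doctest output above, because normalizing p0 under the p0 strategy gives ones and the wrapper undoes that before evaluating the model.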

Parameters:
  • model_fn (callable) – User model function with signature: model_fn(x, *params) -> predictions

  • normalizer (ParameterNormalizer) – Parameter normalizer instance

Examples

>>> import jax.numpy as jnp
>>> from nlsq.precision.parameter_normalizer import ParameterNormalizer, NormalizedModelWrapper
>>> def model(x, a, b):
...     return a * x + b
>>> p0 = jnp.array([5.0, 10.0])
>>> normalizer = ParameterNormalizer(p0, bounds=None, strategy='p0')
>>> wrapped_model = NormalizedModelWrapper(model, normalizer)
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> normalized_params = normalizer.normalize(p0)
>>> output = wrapped_model(x, *normalized_params)
>>> print(output)
[15. 20. 25.]

JIT compilation:

>>> import jax
>>> @jax.jit
... def optimized_model(x, a_norm, b_norm):
...     return wrapped_model(x, a_norm, b_norm)
>>> output = optimized_model(x, *normalized_params)

See also

ParameterNormalizer

Handles parameter normalization

__init__(model_fn, normalizer)

Initialize normalized model wrapper.

Parameters:
  • model_fn (callable) – User model function

  • normalizer (ParameterNormalizer) – Parameter normalizer

__call__(x, *normalized_params)

Call wrapped model with normalized parameters.

Parameters:
  • x (array_like) – Independent variable data

  • *normalized_params (array_like) – Normalized parameter values (unpacked)

Returns:

Model predictions

Return type:

jax.Array

class nlsq.precision.parameter_normalizer.ParameterNormalizer(p0, bounds=None, strategy='auto')

Bases: object

Normalizes parameters to improve optimization convergence.

This class handles automatic parameter scaling to counter weak gradient signals caused by imbalanced parameter scales. It supports multiple strategies:

  • bounds: Normalize to [0, 1] using parameter bounds (lb, ub)

  • p0: Scale by initial parameter magnitudes

  • auto: Use bounds if provided, else p0-based

  • none: Identity transform (no normalization)

The normalizer computes the diagonal denormalization Jacobian analytically and stores it as normalization_jacobian; it is needed to transform covariance matrices back to the original parameter space.
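All three explicit strategies can be viewed as one affine map with per-parameter scales and offsets, which makes both the inverse and the diagonal Jacobian immediate. A NumPy sketch under that assumption; the helpers `make_affine_map`, `normalize` and `denormalize` here are hypothetical, do not reproduce nlsq's code, and omit the 'auto' selection:

```python
import numpy as np


def make_affine_map(p0, bounds, strategy):
    """Hypothetical helper: per-parameter (scales, offsets) for one strategy."""
    p0 = np.asarray(p0, dtype=float)
    if strategy == "bounds":
        lb, ub = (np.asarray(b, dtype=float) for b in bounds)
        return ub - lb, lb  # theta_norm = (theta - lb) / (ub - lb)
    if strategy == "p0":
        return np.abs(p0), np.zeros_like(p0)  # theta_norm = theta / |p0|
    if strategy == "none":
        return np.ones_like(p0), np.zeros_like(p0)  # identity transform
    raise ValueError(f"unsupported strategy: {strategy!r}")


def normalize(params, scales, offsets):
    return (np.asarray(params, dtype=float) - offsets) / scales


def denormalize(normalized, scales, offsets):
    # Inverse map; the diagonal of its Jacobian is simply `scales`
    return offsets + scales * np.asarray(normalized, dtype=float)


p0 = np.array([50.0, 0.5])
bounds = (np.array([10.0, 0.0]), np.array([100.0, 1.0]))
for strategy in ("bounds", "p0", "none"):
    scales, offsets = make_affine_map(p0, bounds, strategy)
    z = normalize(p0, scales, offsets)
    # Round trip back to original space should recover p0
    print(strategy, z, np.allclose(denormalize(z, scales, offsets), p0))
```

Keeping scales and offsets as plain arrays is what makes such a transform cheap to trace and JIT-compile: both directions are elementwise arithmetic.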

Parameters:
  • p0 (array_like) – Initial parameter guess of shape (n_params,)

  • bounds (tuple of array_like, optional) – Parameter bounds as (lb, ub) where lb and ub are arrays of shape (n_params,). If None and strategy is 'auto', p0-based scaling is used.

  • strategy (str, default='auto') –

    Normalization strategy. Options:

    • 'auto': Use bounds if provided, else p0-based

    • 'bounds': Normalize to [0, 1] using bounds

    • 'p0': Scale by initial parameter magnitudes

    • 'none': Identity transform (no normalization)

strategy

Selected normalization strategy

Type:

str

scales

Scaling factors for each parameter (diagonal of the denormalization Jacobian)

Type:

jax.Array

offsets

Offset for each parameter (used by the bounds-based strategy)

Type:

jax.Array

original_bounds

Original parameter bounds (lb, ub)

Type:

tuple of jax.Array or None

normalization_jacobian

Denormalization Jacobian matrix (diagonal) of shape (n_params, n_params). For covariance transform: Cov_orig = J @ Cov_norm @ J.T

Type:

jax.Array

Examples

Bounds-based normalization:

>>> import jax.numpy as jnp
>>> from nlsq.precision.parameter_normalizer import ParameterNormalizer
>>> p0 = jnp.array([50.0, 0.5])
>>> bounds = (jnp.array([10.0, 0.0]), jnp.array([100.0, 1.0]))
>>> normalizer = ParameterNormalizer(p0, bounds, strategy='bounds')
>>> normalized = normalizer.normalize(p0)
>>> print(normalized)
[0.444... 0.5]
>>> denormalized = normalizer.denormalize(normalized)
>>> print(jnp.allclose(denormalized, p0))
True

p0-based normalization:

>>> p0 = jnp.array([1000.0, 1.0, 0.001])
>>> normalizer = ParameterNormalizer(p0, bounds=None, strategy='p0')
>>> normalized = normalizer.normalize(p0)
>>> print(normalized)
[1. 1. 1.]

No normalization:

>>> p0 = jnp.array([5.0, 15.0])
>>> normalizer = ParameterNormalizer(p0, bounds=None, strategy='none')
>>> normalized = normalizer.normalize(p0)
>>> print(jnp.allclose(normalized, p0))
True

See also

NormalizedModelWrapper

Wraps model functions for normalized parameters

HybridStreamingConfig

Configuration with normalization_strategy parameter

Notes

Implements Phase 0 (Parameter Normalization Setup) of the Adaptive Hybrid Streaming Optimizer specification.

__init__(p0, bounds=None, strategy='auto')

Initialize parameter normalizer.

Parameters:
  • p0 (array_like) – Initial parameter guess

  • bounds (tuple of array_like, optional) – Parameter bounds as (lb, ub)

  • strategy (str, default='auto') – Normalization strategy

normalize(params)

Normalize parameters to scaled space.

Parameters:

params (array_like) – Parameters in original space of shape (n_params,)

Returns:

Normalized parameters of shape (n_params,)

Return type:

jax.Array

denormalize(normalized_params)

Denormalize parameters back to original space.

This is the inverse of normalize(), exact up to floating-point rounding.

Parameters:

normalized_params (array_like) – Parameters in normalized space of shape (n_params,)

Returns:

Parameters in original space of shape (n_params,)

Return type:

jax.Array

transform_bounds()

Transform bounds to normalized space.

Returns:

  • lb_normalized (jax.Array) – Lower bounds in normalized space of shape (n_params,)

  • ub_normalized (jax.Array) – Upper bounds in normalized space of shape (n_params,)

Return type:

tuple[Array, Array]

Examples

>>> import jax.numpy as jnp
>>> from nlsq.precision.parameter_normalizer import ParameterNormalizer
>>> p0 = jnp.array([50.0])
>>> bounds = (jnp.array([10.0]), jnp.array([100.0]))
>>> normalizer = ParameterNormalizer(p0, bounds, strategy='bounds')
>>> lb_norm, ub_norm = normalizer.transform_bounds()
>>> print(lb_norm, ub_norm)
[0.] [1.]
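One way to read transform_bounds() is that the bounds are pushed through the same forward map as the parameters, which for the bounds strategy always yields 0 and 1. The sketch below assumes exactly that; `transform_bounds_sketch` is hypothetical, and the second case (a p0-style scale of |p0| = 50) only illustrates what the assumed map would produce, not documented nlsq behavior.

```python
import numpy as np


def transform_bounds_sketch(lb, ub, scales, offsets):
    """Hypothetical: push (lb, ub) through theta_norm = (theta - offset) / scale."""
    lb = np.asarray(lb, dtype=float)
    ub = np.asarray(ub, dtype=float)
    return (lb - offsets) / scales, (ub - offsets) / scales


lb, ub = np.array([10.0]), np.array([100.0])

# Bounds-based map: scale = ub - lb, offset = lb, so the result is always [0, 1]
lb_n, ub_n = transform_bounds_sketch(lb, ub, scales=ub - lb, offsets=lb)
print(lb_n, ub_n)  # [0.] [1.]

# Assumed p0-style map with p0 = [50]: scale = |p0|, offset = 0
lb_p, ub_p = transform_bounds_sketch(lb, ub, scales=np.array([50.0]), offsets=np.zeros(1))
print(lb_p, ub_p)  # [0.2] [2.]
```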

Overview

The nlsq.precision.parameter_normalizer module provides automatic parameter scaling to counter weak gradient signals caused by imbalanced parameter scales. When parameters span many orders of magnitude, optimization can converge slowly and become numerically unstable.

New in version 0.3.0: Parameter normalization for adaptive hybrid streaming.

Key Features

  • Bounds-based normalization: Normalize to [0, 1] using parameter bounds

  • p0-based normalization: Scale by initial parameter magnitudes

  • Identity transform: No normalization option

  • JAX JIT compatibility: All operations are JIT-compilable

  • Automatic Jacobian computation: Analytical denormalization Jacobian for covariance transform

  • Transparent model wrapping: User model operates in original parameter space


Normalization Strategies

Bounds-Based Normalization

Normalizes parameters to [0, 1] using provided bounds. Best when you have meaningful parameter bounds:

import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import ParameterNormalizer

# Parameters: amplitude in [10, 100], decay in [0, 1]
p0 = jnp.array([50.0, 0.5])
bounds = (jnp.array([10.0, 0.0]), jnp.array([100.0, 1.0]))

normalizer = ParameterNormalizer(p0, bounds, strategy="bounds")

# Normalize: (50-10)/(100-10) = 0.444, (0.5-0)/(1-0) = 0.5
normalized = normalizer.normalize(p0)
print(normalized)  # [0.444... 0.5]

# Denormalize back to original
denormalized = normalizer.denormalize(normalized)
print(jnp.allclose(denormalized, p0))  # True

p0-Based Normalization

Scales parameters by their initial magnitudes. Best when parameters have vastly different scales but no clear bounds:

import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import ParameterNormalizer

# Parameters: large, medium, small
p0 = jnp.array([1000.0, 1.0, 0.001])

normalizer = ParameterNormalizer(p0, bounds=None, strategy="p0")

# All parameters normalized to ~1.0
normalized = normalizer.normalize(p0)
print(normalized)  # [1. 1. 1.]

# Works with different values
params = jnp.array([500.0, 2.0, 0.002])
normalized = normalizer.normalize(params)
print(normalized)  # [0.5 2. 2.]
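Because the documented p0-based formula divides by |θ0,i| rather than θ0,i, the sign of each parameter survives normalization: a negative initial value maps to -1, not 1. A NumPy sketch of that formula alone (not the nlsq code path; how nlsq treats a zero entry in p0 is not stated on this page, so that case is left out):

```python
import numpy as np

p0 = np.array([1000.0, -1.0, 0.001])
scales = np.abs(p0)  # p0-based scale: |theta_0,i|

normalized_p0 = p0 / scales
print(normalized_p0)  # values: 1.0, -1.0, 1.0

params = np.array([500.0, -2.0, 0.002])
print(params / scales)  # values: 0.5, -2.0, 2.0
```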

No Normalization

Identity transform when normalization is not needed:

import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import ParameterNormalizer

p0 = jnp.array([5.0, 15.0])
normalizer = ParameterNormalizer(p0, bounds=None, strategy="none")

normalized = normalizer.normalize(p0)
print(jnp.allclose(normalized, p0))  # True

Auto Strategy

Selects bounds-based normalization when bounds are provided and p0-based normalization otherwise:

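import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import ParameterNormalizer

p0 = jnp.array([50.0, 0.5])
bounds = (jnp.array([10.0, 0.0]), jnp.array([100.0, 1.0]))

# With bounds: uses bounds-based
normalizer = ParameterNormalizer(p0, bounds=bounds, strategy="auto")
print(normalizer.strategy)  # bounds

# Without bounds: uses p0-based
normalizer = ParameterNormalizer(p0, bounds=None, strategy="auto")
print(normalizer.strategy)  # p0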
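The documented 'auto' rule fits in a few lines. A hypothetical sketch: `resolve_strategy` is not an nlsq function, and rejecting unknown names with ValueError is an assumption, since this page does not say how invalid strategies are handled.

```python
VALID_STRATEGIES = ("auto", "bounds", "p0", "none")


def resolve_strategy(strategy, bounds):
    """Hypothetical helper mirroring the documented 'auto' rule."""
    if strategy not in VALID_STRATEGIES:
        raise ValueError(f"unknown strategy {strategy!r}; expected one of {VALID_STRATEGIES}")
    if strategy == "auto":
        # Bounds take priority when supplied; otherwise fall back to p0 scaling
        return "bounds" if bounds is not None else "p0"
    # Explicit strategies are used as given
    return strategy


print(resolve_strategy("auto", ([10.0], [100.0])))  # bounds
print(resolve_strategy("auto", None))  # p0
print(resolve_strategy("none", ([10.0], [100.0])))  # none
```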

Usage Examples

Model Wrapping for Optimization

Use NormalizedModelWrapper to transparently work in normalized space:

import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import (
    ParameterNormalizer,
    NormalizedModelWrapper,
)


# Define model in original parameter space
def model(x, amplitude, decay):
    return amplitude * jnp.exp(-decay * x)


# Setup normalization
p0 = jnp.array([100.0, 0.1])  # Very different scales
bounds = (jnp.array([10.0, 0.01]), jnp.array([200.0, 1.0]))
normalizer = ParameterNormalizer(p0, bounds, strategy="bounds")

# Wrap model
wrapped_model = NormalizedModelWrapper(model, normalizer)

# Use wrapped model with normalized parameters
x = jnp.linspace(0, 10, 100)
normalized_p0 = normalizer.normalize(p0)

# Wrapped model internally denormalizes before calling original model
predictions = wrapped_model(x, *normalized_p0)
print(predictions.shape)  # (100,)
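Working in normalized space rescales the gradient: by the chain rule, ∂f/∂θ_norm,i = (∂f/∂θ_i) · J_ii, so each parameter's gradient is multiplied by its own scale, which is one way to see how normalization can rebalance gradient signals. The NumPy sketch below checks this for the exponential model above using central finite differences; it re-implements the bounds-based map by hand instead of calling nlsq, and `wrapped` is an illustrative name.

```python
import numpy as np


def model(x, amplitude, decay):
    return amplitude * np.exp(-decay * x)


lb = np.array([10.0, 0.01])
ub = np.array([200.0, 1.0])
scales, offsets = ub - lb, lb  # bounds-based map: theta = lb + (ub - lb) * z


def wrapped(x, z):
    theta = offsets + scales * z  # denormalize before calling the model
    return model(x, *theta)


x = np.linspace(0.0, 10.0, 5)
theta = np.array([100.0, 0.1])
z = (theta - offsets) / scales

# Analytic gradient in original space, shape (n_points, n_params)
grad_orig = np.stack(
    [np.exp(-theta[1] * x), -theta[0] * x * np.exp(-theta[1] * x)], axis=1
)

# Central finite differences with respect to the normalized parameters
eps = 1e-6
grad_norm = np.stack(
    [(wrapped(x, z + eps * e) - wrapped(x, z - eps * e)) / (2 * eps) for e in np.eye(2)],
    axis=1,
)

# Chain rule: df/dz_i = (df/dtheta_i) * scale_i
print(np.allclose(grad_norm, grad_orig * scales, rtol=1e-5, atol=1e-6))  # True
```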

JIT Compilation

All operations are JAX JIT-compatible:

import jax
import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import (
    ParameterNormalizer,
    NormalizedModelWrapper,
)


def model(x, a, b):
    return a * x + b


p0 = jnp.array([5.0, 10.0])
normalizer = ParameterNormalizer(p0, bounds=None, strategy="p0")
wrapped = NormalizedModelWrapper(model, normalizer)


@jax.jit
def compute_predictions(x, a_norm, b_norm):
    return wrapped(x, a_norm, b_norm)


x = jnp.array([1.0, 2.0, 3.0])
normalized = normalizer.normalize(p0)
result = compute_predictions(x, *normalized)

Bounds Transformation

Transform bounds to normalized space:

import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import ParameterNormalizer

p0 = jnp.array([50.0])
bounds = (jnp.array([10.0]), jnp.array([100.0]))
normalizer = ParameterNormalizer(p0, bounds, strategy="bounds")

lb_norm, ub_norm = normalizer.transform_bounds()
print(lb_norm, ub_norm)  # [0.] [1.]

Covariance Transform

Transform covariance from normalized to original space using the Jacobian:

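import jax.numpy as jnp
from nlsq.precision.parameter_normalizer import ParameterNormalizer

p0 = jnp.array([100.0, 0.1])
bounds = (jnp.array([10.0, 0.01]), jnp.array([200.0, 1.0]))
normalizer = ParameterNormalizer(p0, bounds, strategy="bounds")

# Get denormalization Jacobian (diagonal, shape (n_params, n_params))
J = normalizer.normalization_jacobian
print(J.shape)  # (2, 2)

# Transform covariance: Cov_orig = J @ Cov_norm @ J.T
cov_normalized = jnp.eye(2) * 0.01  # Example normalized covariance
cov_original = J @ cov_normalized @ J.T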
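Because J is diagonal, parameter standard errors transform by simple rescaling: σ_orig,i = J_ii · σ_norm,i. A NumPy sketch, assuming J holds the bounds-based scales u_i - l_i for the amplitude/decay bounds from the model-wrapping example (190 and 0.99):

```python
import numpy as np

# Assumed diagonal Jacobian: bounds-based scales u_i - l_i for the
# amplitude/decay bounds ([10, 200] and [0.01, 1.0])
J = np.diag([190.0, 0.99])
cov_normalized = np.eye(2) * 0.01  # example normalized covariance

cov_original = J @ cov_normalized @ J.T

# Standard errors are the square roots of the covariance diagonal
stderr_norm = np.sqrt(np.diag(cov_normalized))
stderr_orig = np.sqrt(np.diag(cov_original))

# Diagonal J: each standard error is rescaled by its own J_ii
print(np.allclose(stderr_orig, np.diag(J) * stderr_norm))  # True
```

Here the normalized standard errors are both 0.1, so the original-space values are 19.0 and 0.099.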

Mathematical Details

Normalization Transform

For bounds-based normalization with bounds \([l_i, u_i]\):

\[\theta_{\text{norm},i} = \frac{\theta_i - l_i}{u_i - l_i}\]

For p0-based normalization with initial values \(\theta_0\):

\[\theta_{\text{norm},i} = \frac{\theta_i}{|\theta_{0,i}|}\]

Denormalization Jacobian

The denormalization Jacobian is diagonal:

\[J_{ii} = \text{scale}_i\]

For bounds-based: \(J_{ii} = u_i - l_i\)

For p0-based: \(J_{ii} = |\theta_{0,i}|\)
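Both diagonal entries follow from inverting the forward maps above and differentiating. Each original parameter depends only on its own normalized counterpart, so the off-diagonal entries vanish:

\[\theta_i = l_i + (u_i - l_i)\,\theta_{\text{norm},i} \;\Rightarrow\; J_{ij} = \frac{\partial \theta_i}{\partial \theta_{\text{norm},j}} = (u_i - l_i)\,\delta_{ij}\]

\[\theta_i = |\theta_{0,i}|\,\theta_{\text{norm},i} \;\Rightarrow\; J_{ij} = \frac{\partial \theta_i}{\partial \theta_{\text{norm},j}} = |\theta_{0,i}|\,\delta_{ij}\]

For the identity strategy ('none'), \(\text{scale}_i = 1\) and \(J = I\).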

Covariance Transform

To transform covariance from normalized to original space:

\[\Sigma_{\text{orig}} = J \, \Sigma_{\text{norm}} \, J^T\]

Since \(J\) is diagonal, this simplifies to:

\[\Sigma_{\text{orig},ij} = J_{ii} \, \Sigma_{\text{norm},ij} \, J_{jj}\]
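The elementwise form means the full matrix product can be replaced by scaling with the outer product of the diagonal, which costs O(n²) instead of O(n³). A NumPy check with an arbitrary, made-up scale vector and a random symmetric positive-definite covariance:

```python
import numpy as np

rng = np.random.default_rng(0)
scales = np.array([190.0, 0.99, 3.5])  # assumed diagonal of J
J = np.diag(scales)

# Random symmetric positive-definite covariance in normalized space
A = rng.normal(size=(3, 3))
cov_norm = A @ A.T + 3 * np.eye(3)

full = J @ cov_norm @ J.T
elementwise = cov_norm * np.outer(scales, scales)  # Sigma_norm,ij * J_ii * J_jj
print(np.allclose(full, elementwise))  # True
```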

See Also