# Source code for nlsq.global_optimization.bounds_transform

"""Sigmoid bound transformation for CMA-ES optimization.

CMA-ES operates in unbounded space. This module provides smooth, differentiable
transformations between bounded parameter space and unbounded CMA-ES space.

The sigmoid transformation is used:
- Forward: x_bounded = lb + (ub - lb) * sigmoid(x_unbounded)
- Inverse: x_unbounded = logit((x_bounded - lb) / (ub - lb))
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import jax.numpy as jnp

if TYPE_CHECKING:
    from jax import Array

# Public API of this module: the names exported by a star-import.
__all__ = [
    "compute_default_popsize",
    "transform_from_bounds",
    "transform_to_bounds",
]


def transform_to_bounds(
    x_unbounded: Array,
    lower_bounds: Array,
    upper_bounds: Array,
) -> Array:
    """Transform unbounded CMA-ES samples to bounded parameter space.

    Uses sigmoid transformation for smooth, differentiable bounds handling:
    x_bounded = lb + (ub - lb) * sigmoid(x_unbounded)

    The sigmoid is evaluated so that ``exp`` only ever receives non-positive
    arguments. This keeps both the values and the reverse-mode gradients
    finite for arbitrarily large ``|x_unbounded|``.

    Parameters
    ----------
    x_unbounded : Array
        Unbounded parameter values from CMA-ES, shape (n_params,) or
        (population_size, n_params).
    lower_bounds : Array
        Lower bounds for each parameter, shape (n_params,).
    upper_bounds : Array
        Upper bounds for each parameter, shape (n_params,).

    Returns
    -------
    Array
        Bounded parameter values, same shape as x_unbounded.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> x = jnp.array([0.0, 1.0, -1.0])
    >>> lb = jnp.array([0.0, 0.0, 0.0])
    >>> ub = jnp.array([1.0, 10.0, 100.0])
    >>> transform_to_bounds(x, lb, ub)
    Array([0.5, 7.310586, 26.894142], dtype=float32)
    """
    is_nonneg = x_unbounded >= 0

    # -|x|, built with ``where`` rather than ``abs`` so that the derivative at
    # x == 0 follows the non-negative branch and the sigmoid slope there is
    # the correct 1/4.
    neg_abs_x = jnp.where(is_nonneg, -x_unbounded, x_unbounded)

    # exp(-|x|) lies in [0, 1] and cannot overflow. ``jnp.where`` evaluates
    # both of its branches, so an overflowing exp() in the unselected branch
    # would produce inf/nan that leaks into gradients as NaN (0 * inf).
    exp_neg_abs = jnp.exp(neg_abs_x)

    # sigmoid(x) = 1 / (1 + exp(-x))        for x >= 0
    #            = exp(x) / (1 + exp(x))    for x <  0
    sigmoid_x = jnp.where(
        is_nonneg,
        1.0 / (1.0 + exp_neg_abs),
        exp_neg_abs / (1.0 + exp_neg_abs),
    )

    # Map to bounds: lb + (ub - lb) * sigmoid(x)
    return lower_bounds + (upper_bounds - lower_bounds) * sigmoid_x


def transform_from_bounds(
    x_bounded: Array,
    lower_bounds: Array,
    upper_bounds: Array,
    epsilon: float = 1e-8,
) -> Array:
    """Transform bounded parameters to unbounded CMA-ES space.

    Uses logit (inverse sigmoid) transformation:
    x_unbounded = log(ratio / (1 - ratio)) where ratio = (x - lb) / (ub - lb)

    Parameters
    ----------
    x_bounded : Array
        Bounded parameter values, shape (n_params,) or (population_size, n_params).
    lower_bounds : Array
        Lower bounds for each parameter, shape (n_params,).
    upper_bounds : Array
        Upper bounds for each parameter, shape (n_params,).
    epsilon : float, optional
        Small value to prevent log(0) or log(1). Default: 1e-8.
        The normalized ratio is clamped to at least ``epsilon``. On the upper
        side it is clamped to ``1 - epsilon``, or to the largest value below 1
        that the working dtype can represent if ``1 - epsilon`` would round
        to 1 (as happens with the default in float32).

    Returns
    -------
    Array
        Unbounded parameter values for CMA-ES, same shape as x_bounded.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> x = jnp.array([0.5, 5.0, 50.0])
    >>> lb = jnp.array([0.0, 0.0, 0.0])
    >>> ub = jnp.array([1.0, 10.0, 100.0])
    >>> transform_from_bounds(x, lb, ub)
    Array([0., 0., 0.], dtype=float32)
    """
    # Normalize to [0, 1]; a zero-width range is replaced by 1 to avoid 0/0.
    range_val = upper_bounds - lower_bounds
    safe_range = jnp.where(range_val == 0, 1.0, range_val)
    ratio = (x_bounded - lower_bounds) / safe_range

    # ``1 - epsilon`` must stay strictly below 1 in the working precision,
    # otherwise the clamp is a no-op and a value at the upper bound maps to
    # log(1 / 0) = inf. ``epsneg`` is the gap between 1 and the next smaller
    # representable number, so 1 - epsneg is exactly representable.
    epsneg = float(jnp.finfo(ratio.dtype).epsneg)
    upper_limit = 1.0 - jnp.maximum(epsilon, epsneg)

    # Clamp to avoid log(0) or log(inf)
    ratio_clamped = jnp.clip(ratio, epsilon, upper_limit)

    # Logit: log(ratio / (1 - ratio))
    return jnp.log(ratio_clamped / (1.0 - ratio_clamped))


def compute_default_popsize(n_params: int) -> int:
    """Compute default CMA-ES population size.

    Uses the standard CMA-ES formula: int(4 + 3 * log(n))

    Parameters
    ----------
    n_params : int
        Number of parameters being optimized. Values below 1 are treated
        as 1.

    Returns
    -------
    int
        Default population size, minimum 4.

    Examples
    --------
    >>> compute_default_popsize(5)
    8
    >>> compute_default_popsize(20)
    12
    """
    import math

    # Guard the logarithm against n_params <= 0 by flooring the argument at 1.
    effective_n = max(1, n_params)
    return max(4, int(4 + 3 * math.log(effective_n)))