"""Model registry for NLSQ CLI.
This module provides model function resolution for curve fitting workflows:
- Builtin models from the nlsq.core.functions module
- Custom models loaded from external Python files
- Polynomial models generated by degree
The ModelRegistry class handles discovery, loading, and validation of
model functions for use in curve fitting workflows.
Security Features
-----------------
Custom models are validated before loading to reduce the risk of arbitrary code
execution (the AST check is not enforced for models explicitly marked as trusted):
- AST-based pattern detection for dangerous operations
- Path traversal prevention
- Resource limits (timeout, memory) for model execution
- Audit logging for model loading attempts
Example Usage
-------------
>>> from nlsq.cli.model_registry import ModelRegistry
>>>
>>> registry = ModelRegistry()
>>>
>>> # Get a builtin model
>>> linear = registry.get_model("linear", {"type": "builtin", "name": "linear"})
>>>
>>> # Get a custom model from file (with security validation)
>>> custom = registry.get_model(
... "/path/to/model.py",
... {"type": "custom", "path": "/path/to/model.py", "function": "my_model"}
... )
>>>
>>> # Get a polynomial model
>>> poly3 = registry.get_model("poly", {"type": "polynomial", "degree": 3})
"""
import importlib.util
import inspect
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any
from nlsq.cli.errors import ModelError
from nlsq.cli.model_validation import (
get_audit_logger,
validate_model,
validate_path,
)
logger = logging.getLogger("nlsq.cli.model_registry")

# Type alias for model functions
ModelFunction = Callable[..., Any]

# String spellings accepted as an affirmative "trusted" flag. Any other string
# (e.g. "false", "no", "0") is treated as untrusted so the security check
# fails closed instead of being bypassed by a truthy non-empty string.
_TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"})

# Maximum case-insensitive edit distance for "did you mean" suggestions.
_MAX_SUGGESTION_DISTANCE = 2


class ModelRegistry:
    """Registry for model functions used in curve fitting.

    The ModelRegistry handles three types of models:

    1. **Builtin**: Models from nlsq.core.functions (linear, gaussian, etc.)
    2. **Custom**: Models loaded from external Python files
    3. **Polynomial**: Dynamically generated polynomial functions

    Each model function may have optional methods attached:

    - `estimate_p0(xdata, ydata)`: Estimate initial parameters
    - `bounds()`: Return default parameter bounds

    Attributes
    ----------
    _builtin_cache : dict
        Cache of loaded builtin model functions, keyed by model name.

    Examples
    --------
    >>> registry = ModelRegistry()
    >>> model = registry.get_model("gaussian", {"type": "builtin", "name": "gaussian"})
    >>> print(model.estimate_p0([1,2,3], [1,4,9]))
    """

    def __init__(self) -> None:
        """Initialize ModelRegistry with an empty builtin-model cache."""
        self._builtin_cache: dict[str, ModelFunction] = {}

    @staticmethod
    def _functions_module() -> Any:
        """Import and return ``nlsq.core.functions``.

        The import is deferred to call time, as in the original code, so that
        importing this module does not import the builtin model library.
        """
        import nlsq.core.functions

        return nlsq.core.functions

    def list_builtin_models(self) -> list[str]:
        """List all available builtin model names.

        Discovers models by introspecting nlsq.core.functions.__all__.

        Returns
        -------
        list[str]
            Names of available builtin models.

        Examples
        --------
        >>> registry = ModelRegistry()
        >>> models = registry.list_builtin_models()
        >>> print("linear" in models)
        True
        """
        return list(self._functions_module().__all__)

    def get_model(self, name_or_path: str, config: dict[str, Any]) -> ModelFunction:
        """Get a model function by name or path.

        Dispatch is driven entirely by ``config``; ``name_or_path`` is
        accepted for API compatibility but is not consulted by this method.

        Parameters
        ----------
        name_or_path : str
            For builtin models: the model name (e.g., "linear", "gaussian").
            For custom models: path to the Python file.
            For polynomial: any identifier (degree is in config).
        config : dict
            Model configuration with keys:
            - type: str - Model type ("builtin", "custom", "polynomial");
              defaults to "builtin" when absent
            - name: str - Model name (for builtin)
            - path: str - File path (for custom)
            - function: str - Function name (for custom)
            - trusted: bool - Skip AST security enforcement (for custom)
            - degree: int - Polynomial degree (for polynomial)

        Returns
        -------
        ModelFunction
            The model function with optional estimate_p0 and bounds methods.

        Raises
        ------
        ModelError
            If the model cannot be resolved.

        Examples
        --------
        >>> registry = ModelRegistry()
        >>> # Builtin model
        >>> linear = registry.get_model("linear", {"type": "builtin", "name": "linear"})
        >>> # Custom model
        >>> custom = registry.get_model(
        ...     "path/model.py",
        ...     {"type": "custom", "path": "path/model.py", "function": "my_func"}
        ... )
        >>> # Polynomial model
        >>> poly = registry.get_model("poly", {"type": "polynomial", "degree": 2})
        """
        model_type = config.get("type", "builtin")
        if model_type == "builtin":
            return self._get_builtin_model(config)
        elif model_type == "custom":
            return self._get_custom_model(config)
        elif model_type == "polynomial":
            return self._get_polynomial_model(config)
        else:
            raise ModelError(
                f"Unknown model type: {model_type!r}",
                model_type=model_type,
                context={"valid_types": ["builtin", "custom", "polynomial"]},
                suggestion="Use type 'builtin', 'custom', or 'polynomial'",
            )

    def _get_builtin_model(self, config: dict[str, Any]) -> ModelFunction:
        """Get a builtin model from nlsq.core.functions.

        Parameters
        ----------
        config : dict
            Configuration with 'name' key specifying the model name.

        Returns
        -------
        ModelFunction
            The builtin model function (cached after the first lookup).

        Raises
        ------
        ModelError
            If the name is missing, is not a string, or is not found in
            nlsq.core.functions.__all__.
        """
        model_name = config.get("name", "")
        if not model_name:
            raise ModelError(
                "Model name is required for builtin models",
                model_type="builtin",
                suggestion="Specify 'name' in the model configuration",
            )
        # A non-string (e.g. a list from a malformed config) would otherwise
        # surface as a bare TypeError from the dict lookup below.
        if not isinstance(model_name, str):
            raise ModelError(
                f"Builtin model name must be a string, got {type(model_name).__name__}",
                model_type="builtin",
                context={"name_type": type(model_name).__name__},
                suggestion="Specify 'name' as a string in the model configuration",
            )

        # Check cache first
        if model_name in self._builtin_cache:
            return self._builtin_cache[model_name]

        available = self.list_builtin_models()
        if model_name not in available:
            raise ModelError(
                f"Builtin model '{model_name}' not found",
                model_name=model_name,
                model_type="builtin",
                context={"available_models": available},
                suggestion=self._suggest_similar_model(model_name, available),
            )

        # Note: "polynomial" is a factory (it returns a model when called
        # with a degree); it is cached and returned unchanged like the rest.
        model = getattr(self._functions_module(), model_name)
        self._builtin_cache[model_name] = model
        return model

    # =========================================================================
    # Custom Model Helper Methods (extracted for complexity reduction)
    # =========================================================================

    @staticmethod
    def _is_trusted_flag(value: Any) -> bool:
        """Interpret a config-supplied ``trusted`` value as a strict boolean.

        Strings are matched case-insensitively against a small allow-list
        ("1", "true", "yes", "on"), so values such as "false" or "no" do not
        accidentally disable security enforcement. Non-string values use
        ordinary truthiness.
        """
        if isinstance(value, str):
            return value.strip().lower() in _TRUTHY_STRINGS
        return bool(value)

    def _validate_custom_model_config(
        self,
        config: dict[str, Any],
    ) -> tuple[str, str]:
        """Validate and extract custom model configuration.

        Parameters
        ----------
        config : dict
            Configuration with path, function, and optional trusted keys.

        Returns
        -------
        file_path : str
            The file path to the model.
        function_name : str
            The function name to load.

        Raises
        ------
        ModelError
            If file_path or function_name is missing.
        """
        file_path = config.get("path", "")
        function_name = config.get("function", "")

        if not file_path:
            raise ModelError(
                "File path is required for custom models",
                model_type="custom",
                suggestion="Specify 'path' in the model configuration",
            )
        if not function_name:
            raise ModelError(
                "Function name is required for custom models",
                model_type="custom",
                context={"path": file_path},
                suggestion="Specify 'function' in the model configuration",
            )
        return file_path, function_name

    def _validate_model_file_exists(
        self,
        file_path: str,
        function_name: str,
    ) -> Path:
        """Validate that the model file exists and is a regular file.

        Parameters
        ----------
        file_path : str
            Path to the model file.
        function_name : str
            Function name (for error messages).

        Returns
        -------
        path : Path
            Validated Path object.

        Raises
        ------
        ModelError
            If file doesn't exist or isn't a regular file.
        """
        path = Path(file_path)
        if not path.exists():
            raise ModelError(
                f"Custom model file not found: {file_path}",
                model_name=function_name,
                model_type="custom",
                context={"path": file_path},
                suggestion="Check the file path and ensure the file exists",
            )
        if not path.is_file():
            raise ModelError(
                f"Path is not a file: {file_path}",
                model_name=function_name,
                model_type="custom",
                context={"path": file_path},
                suggestion="Provide a path to a Python file",
            )
        return path

    def _check_path_security(
        self,
        path: Path,
        function_name: str,
        is_trusted: bool,
    ) -> None:
        """Reject paths that fail ``validate_path`` (path traversal check).

        This check is enforced even for trusted models; ``is_trusted`` is
        only recorded in the audit log entry.

        Parameters
        ----------
        path : Path
            Path to validate.
        function_name : str
            Function name (for error messages).
        is_trusted : bool
            Whether the model is trusted (recorded in the audit log).

        Raises
        ------
        ModelError
            If path traversal is detected.
        """
        if not validate_path(path):
            audit_logger = get_audit_logger()
            from nlsq.cli.model_validation import ModelValidationResult

            result = ModelValidationResult(
                path=path,
                is_valid=False,
                is_trusted=is_trusted,
                violations=["Path traversal detected"],
            )
            audit_logger.log_load_attempt(path, result)
            raise ModelError(
                f"Path traversal detected: {path}",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path)},
                suggestion="Model files must be within the current working directory",
            )

    def _validate_model_security(
        self,
        path: Path,
        function_name: str,
        is_trusted: bool,
    ) -> None:
        """Run AST-based security validation on the model file.

        Every attempt is written to the audit log. Violations raise unless
        the model is trusted, in which case they are logged as a warning.

        Parameters
        ----------
        path : Path
            Path to the model file.
        function_name : str
            Function name (for error messages).
        is_trusted : bool
            Whether to skip enforcement of validation failures.

        Raises
        ------
        ModelError
            If security validation fails and the model is not trusted.
        """
        validation_result = validate_model(path, trusted=is_trusted)
        audit_logger = get_audit_logger()
        audit_logger.log_load_attempt(path, validation_result)

        if not validation_result.is_valid and not is_trusted:
            violations_msg = "; ".join(validation_result.violations[:3])
            if len(validation_result.violations) > 3:
                violations_msg += f" (and {len(validation_result.violations) - 3} more)"
            raise ModelError(
                f"Security validation failed for {path}: {violations_msg}",
                model_name=function_name,
                model_type="custom",
                context={
                    "path": str(path),
                    "violations": validation_result.violations,
                },
                suggestion="Remove dangerous patterns or use --trust flag to skip validation",
            )

        if is_trusted and validation_result.violations:
            logger.warning(
                "Loading model %s with --trust flag, bypassing security checks. "
                "Violations that would have blocked: %s",
                path,
                "; ".join(validation_result.violations[:3]),
            )

    def _load_module_from_path(
        self,
        path: Path,
        function_name: str,
    ) -> Any:
        """Load (execute) a Python module from a file path.

        Parameters
        ----------
        path : Path
            Path to the Python file.
        function_name : str
            Function name (for error messages).

        Returns
        -------
        module : module
            The loaded Python module (not registered in ``sys.modules``).

        Raises
        ------
        ModelError
            If the module cannot be loaded or raises while executing.
        """
        try:
            spec = importlib.util.spec_from_file_location("custom_model", path)
            if spec is None or spec.loader is None:
                raise ModelError(
                    f"Failed to load custom model file: {path}",
                    model_name=function_name,
                    model_type="custom",
                    context={"path": str(path)},
                    suggestion="Ensure the file is a valid Python module",
                )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return module
        except SyntaxError as e:
            raise ModelError(
                f"Syntax error in custom model file: {e}",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path), "line": e.lineno},
                suggestion="Fix the syntax error in the custom model file",
            ) from e
        except ModelError:
            raise
        except Exception as e:
            raise ModelError(
                f"Error loading custom model file: {e}",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path)},
                suggestion="Ensure the file is a valid Python module",
            ) from e

    def _resolve_function_from_module(
        self,
        module: Any,
        function_name: str,
        path: Path,
    ) -> ModelFunction:
        """Resolve and validate the model function from a loaded module.

        Parameters
        ----------
        module : module
            The loaded Python module.
        function_name : str
            Name of the function to extract.
        path : Path
            Path to the module file (for error messages).

        Returns
        -------
        model : ModelFunction
            The resolved model function.

        Raises
        ------
        ModelError
            If the function is not found, not callable, or its introspectable
            signature has no parameters.
        """
        missing = object()
        model = getattr(module, function_name, missing)
        if model is missing:
            available_functions = [
                name
                for name, obj in inspect.getmembers(module)
                if callable(obj) and not name.startswith("_")
            ]
            raise ModelError(
                f"Function '{function_name}' not found in {path}",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path), "available_functions": available_functions},
                suggestion=f"Available functions: {', '.join(available_functions) or 'none'}",
            )

        if not callable(model):
            raise ModelError(
                f"'{function_name}' is not a callable function",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path)},
                suggestion="Ensure the model is defined as a function",
            )

        # Some callables (C builtins, certain compiled/extension objects) have
        # no introspectable signature; inspect raises ValueError/TypeError for
        # them. Accept such models rather than leaking a non-ModelError.
        try:
            sig = inspect.signature(model)
        except (TypeError, ValueError):
            logger.debug(
                "Cannot introspect signature of %r in %s; skipping parameter check",
                function_name,
                path,
            )
            return model

        # Validate signature has at least one parameter (x)
        params = list(sig.parameters)
        if not params:
            raise ModelError(
                f"Model function '{function_name}' must have at least one parameter (x)",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path), "parameters": params},
                suggestion="Define model as f(x, *params) where x is the independent variable",
            )
        return model

    def _attach_model_utilities(
        self,
        model: ModelFunction,
        module: Any,
    ) -> ModelFunction:
        """Attach callable ``estimate_p0`` and ``bounds`` from module to model.

        If the module defines neither utility, ``model`` is returned
        unchanged. If ``model`` is a plain function defined in ``module``
        itself, the utilities are set directly on it. Otherwise (e.g. a
        function imported from another module, or a callable that does not
        accept attributes) a thin forwarding wrapper is created so that
        objects shared with other code are never mutated.

        Parameters
        ----------
        model : ModelFunction
            The model function.
        module : module
            The module containing optional utilities.

        Returns
        -------
        model : ModelFunction
            The model (or a forwarding wrapper) with attached utilities.
        """
        utilities = {
            attr: getattr(module, attr)
            for attr in ("estimate_p0", "bounds")
            if callable(getattr(module, attr, None))
        }
        if not utilities:
            return model

        # A plain function defined in this file shares the module's globals.
        if getattr(model, "__globals__", None) is vars(module):
            target = model
        else:
            import functools

            # wraps() copies metadata and sets __wrapped__, so
            # inspect.signature(target) still reports the model's parameters.
            @functools.wraps(model)
            def target(*args: Any, **kwargs: Any) -> Any:
                return model(*args, **kwargs)

        for attr, func in utilities.items():
            setattr(target, attr, func)
        return target

    def _get_custom_model(
        self,
        config: dict[str, Any],
        *,
        trusted: bool = False,
    ) -> ModelFunction:
        """Load a custom model from an external Python file.

        Parameters
        ----------
        config : dict
            Configuration with:
            - path: str - Path to the Python file
            - function: str - Name of the model function
            - trusted: bool - Skip AST security enforcement (default: False).
              Strings are parsed strictly ("true"/"yes"/"on"/"1" only).
        trusted : bool, default=False
            If True, the model is treated as trusted regardless of
            config["trusted"]. If False, config["trusted"] decides.

        Returns
        -------
        ModelFunction
            The custom model function with optional estimate_p0 and bounds.

        Raises
        ------
        ModelError
            If the file or function cannot be found, the function is invalid,
            or the model fails security validation.
        """
        # Step 1: Validate configuration
        file_path, function_name = self._validate_custom_model_config(config)
        is_trusted = bool(trusted) or self._is_trusted_flag(config.get("trusted", False))

        # Step 2: Validate file exists
        path = self._validate_model_file_exists(file_path, function_name)

        # Step 3: Security validation - path traversal (enforced even if trusted)
        self._check_path_security(path, function_name, is_trusted)

        # Step 4: Security validation - AST-based
        self._validate_model_security(path, function_name, is_trusted)

        # Step 5: Load module dynamically (executes the file's top-level code)
        module = self._load_module_from_path(path, function_name)

        # Step 6: Resolve and validate function
        model = self._resolve_function_from_module(module, function_name, path)

        # Step 7: Attach utility methods
        return self._attach_model_utilities(model, module)

    def _get_polynomial_model(self, config: dict[str, Any]) -> ModelFunction:
        """Generate a polynomial model of given degree.

        Parameters
        ----------
        config : dict
            Configuration with 'degree' key specifying the polynomial degree.
            Integers, integer-valued floats (e.g. 2.0), and integer strings
            are accepted.

        Returns
        -------
        ModelFunction
            The model returned by ``nlsq.core.functions.polynomial(degree)``.

        Raises
        ------
        ModelError
            If the degree is missing, not an integer value, or negative.
        """
        degree = config.get("degree")
        if degree is None:
            raise ModelError(
                "Polynomial degree is required",
                model_type="polynomial",
                suggestion="Specify 'degree' in the model configuration",
            )

        # int() would silently turn True into 1 and truncate 2.7 to 2.
        if isinstance(degree, bool) or (
            isinstance(degree, float) and not degree.is_integer()
        ):
            raise ModelError(
                f"Invalid polynomial degree: {degree!r}",
                model_type="polynomial",
                context={"degree": degree},
                suggestion="Degree must be a non-negative integer",
            )

        try:
            degree_value = int(degree)
        except (ValueError, TypeError, OverflowError) as e:
            raise ModelError(
                f"Invalid polynomial degree: {degree!r}",
                model_type="polynomial",
                context={"degree": degree},
                suggestion="Degree must be a non-negative integer",
            ) from e

        if degree_value < 0:
            raise ModelError(
                f"Polynomial degree must be non-negative, got {degree_value}",
                model_type="polynomial",
                context={"degree": degree_value},
                suggestion="Use a non-negative integer for degree (0, 1, 2, ...)",
            )

        # Use nlsq.core.functions.polynomial factory
        return self._functions_module().polynomial(degree_value)

    @staticmethod
    def _edit_distance(first: str, second: str) -> int:
        """Return the Levenshtein (insert/delete/substitute) distance."""
        if len(first) < len(second):
            first, second = second, first
        previous = list(range(len(second) + 1))
        for i, char_a in enumerate(first, start=1):
            current = [i]
            for j, char_b in enumerate(second, start=1):
                current.append(
                    min(
                        previous[j] + 1,  # deletion
                        current[j - 1] + 1,  # insertion
                        previous[j - 1] + (char_a != char_b),  # substitution
                    )
                )
            previous = current
        return previous[-1]

    def _suggest_similar_model(self, name: str, available: list[str]) -> str:
        """Suggest similar model names for an unknown model name.

        First collects case-insensitive substring matches in either
        direction. If there are none, falls back to names within a
        case-insensitive Levenshtein distance of 2 (catching typos such as
        "gausian" or transposed letters).

        Parameters
        ----------
        name : str
            The invalid model name.
        available : list[str]
            List of available model names.

        Returns
        -------
        str
            "Did you mean: ...?" listing matches in ``available`` order, or
            "Available models: ..." when nothing is close.
        """
        name_lower = name.lower()
        suggestions = [
            model_name
            for model_name in available
            if name_lower in model_name.lower() or model_name.lower() in name_lower
        ]

        if not suggestions:
            suggestions = [
                model_name
                for model_name in available
                if abs(len(model_name) - len(name)) <= _MAX_SUGGESTION_DISTANCE
                and self._edit_distance(name_lower, model_name.lower())
                <= _MAX_SUGGESTION_DISTANCE
            ]

        if suggestions:
            return f"Did you mean: {', '.join(suggestions)}?"
        return f"Available models: {', '.join(available)}"