# Source code for nlsq.cli.model_registry

"""Model registry for NLSQ CLI.

This module provides model function resolution for curve fitting workflows:
- Builtin models from the nlsq.core.functions module
- Custom models loaded from external Python files
- Polynomial models generated by degree

The ModelRegistry class handles discovery, loading, and validation of
model functions for use in curve fitting workflows.

Security Features
-----------------
Custom models are validated before loading to prevent arbitrary code execution:
- AST-based pattern detection for dangerous operations
- Path traversal prevention
- Resource limits (timeout, memory) for model execution
- Audit logging for model loading attempts

Example Usage
-------------
>>> from nlsq.cli.model_registry import ModelRegistry
>>>
>>> registry = ModelRegistry()
>>>
>>> # Get a builtin model
>>> linear = registry.get_model("linear", {"type": "builtin", "name": "linear"})
>>>
>>> # Get a custom model from file (with security validation)
>>> custom = registry.get_model(
...     "/path/to/model.py",
...     {"type": "custom", "path": "/path/to/model.py", "function": "my_model"}
... )
>>>
>>> # Get a polynomial model
>>> poly3 = registry.get_model("poly", {"type": "polynomial", "degree": 3})
"""

import importlib.util
import inspect
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any

from nlsq.cli.errors import ModelError
from nlsq.cli.model_validation import (
    get_audit_logger,
    validate_model,
    validate_path,
)

logger = logging.getLogger("nlsq.cli.model_registry")

# Type alias for model functions
ModelFunction = Callable[..., Any]


class ModelRegistry:
    """Registry for model functions used in curve fitting.

    The ModelRegistry handles three types of models:

    1. **Builtin**: Models from nlsq.core.functions (linear, gaussian, etc.)
    2. **Custom**: Models loaded from external Python files
    3. **Polynomial**: Dynamically generated polynomial functions

    Each model function may have optional methods attached:

    - `estimate_p0(xdata, ydata)`: Estimate initial parameters
    - `bounds()`: Return default parameter bounds

    Attributes
    ----------
    _builtin_cache : dict
        Cache of loaded builtin model functions, keyed by model name.

    Examples
    --------
    >>> registry = ModelRegistry()
    >>> model = registry.get_model("gaussian", {"type": "builtin", "name": "gaussian"})
    >>> print(model.estimate_p0([1,2,3], [1,4,9]))
    """

    def __init__(self) -> None:
        """Initialize ModelRegistry with an empty builtin-model cache."""
        self._builtin_cache: dict[str, ModelFunction] = {}

    def list_builtin_models(self) -> list[str]:
        """List all available builtin model names.

        Discovers models by reading ``nlsq.core.functions.__all__``.

        Returns
        -------
        list[str]
            Names of available builtin models (a fresh list on every call).

        Examples
        --------
        >>> registry = ModelRegistry()
        >>> models = registry.list_builtin_models()
        >>> print("linear" in models)
        True
        """
        # Imported lazily so that constructing a registry does not pull in
        # the functions module until it is actually needed.
        import nlsq.core.functions

        return list(nlsq.core.functions.__all__)

    def get_model(self, name_or_path: str, config: dict[str, Any]) -> ModelFunction:
        """Get a model function by name or path.

        Parameters
        ----------
        name_or_path : str
            For builtin models: the model name (e.g., "linear", "gaussian").
            For custom models: path to the Python file.
            For polynomial: any identifier (degree is in config).
            This argument is currently not consulted during resolution; the
            model is resolved entirely from ``config``.
        config : dict
            Model configuration with keys:

            - type: str - Model type ("builtin", "custom", "polynomial");
              defaults to "builtin" when absent
            - name: str - Model name (for builtin)
            - path: str - File path (for custom)
            - function: str - Function name (for custom)
            - degree: int - Polynomial degree (for polynomial)

        Returns
        -------
        ModelFunction
            The model function with optional estimate_p0 and bounds methods.

        Raises
        ------
        ModelError
            If the model cannot be resolved.

        Examples
        --------
        >>> registry = ModelRegistry()
        >>> # Builtin model
        >>> linear = registry.get_model("linear", {"type": "builtin", "name": "linear"})
        >>> # Custom model
        >>> custom = registry.get_model(
        ...     "path/model.py",
        ...     {"type": "custom", "path": "path/model.py", "function": "my_func"}
        ... )
        >>> # Polynomial model
        >>> poly = registry.get_model("poly", {"type": "polynomial", "degree": 2})
        """
        model_type = config.get("type", "builtin")

        if model_type == "builtin":
            return self._get_builtin_model(config)
        if model_type == "custom":
            return self._get_custom_model(config)
        if model_type == "polynomial":
            return self._get_polynomial_model(config)

        raise ModelError(
            f"Unknown model type: {model_type!r}",
            model_type=model_type,
            context={"valid_types": ["builtin", "custom", "polynomial"]},
            suggestion="Use type 'builtin', 'custom', or 'polynomial'",
        )

    def _get_builtin_model(self, config: dict[str, Any]) -> ModelFunction:
        """Get a builtin model from nlsq.core.functions.

        Parameters
        ----------
        config : dict
            Configuration with 'name' key specifying the model name.

        Returns
        -------
        ModelFunction
            The builtin model function. The attribute is returned as-is, so
            a factory exported by the module (such as ``polynomial``, which
            ``_get_polynomial_model`` calls with a degree) comes back
            un-invoked.

        Raises
        ------
        ModelError
            If the name is missing or not listed in
            ``nlsq.core.functions.__all__``.
        """
        model_name = config.get("name", "")
        if not model_name:
            raise ModelError(
                "Model name is required for builtin models",
                model_type="builtin",
                suggestion="Specify 'name' in the model configuration",
            )

        # Check cache first
        if model_name in self._builtin_cache:
            return self._builtin_cache[model_name]

        available = self.list_builtin_models()
        if model_name not in available:
            raise ModelError(
                f"Builtin model '{model_name}' not found",
                model_name=model_name,
                model_type="builtin",
                context={"available_models": available},
                suggestion=self._suggest_similar_model(model_name, available),
            )

        import nlsq.core.functions

        model = getattr(nlsq.core.functions, model_name)
        self._builtin_cache[model_name] = model
        return model

    # =========================================================================
    # Custom Model Helper Methods (extracted for complexity reduction)
    # =========================================================================

    def _validate_custom_model_config(
        self,
        config: dict[str, Any],
    ) -> tuple[str, str]:
        """Validate and extract custom model configuration.

        Parameters
        ----------
        config : dict
            Configuration with path, function, and optional trusted keys.

        Returns
        -------
        file_path : str
            The file path to the model.
        function_name : str
            The function name to load.

        Raises
        ------
        ModelError
            If file_path or function_name is missing or empty.
        """
        file_path = config.get("path", "")
        function_name = config.get("function", "")

        if not file_path:
            raise ModelError(
                "File path is required for custom models",
                model_type="custom",
                suggestion="Specify 'path' in the model configuration",
            )

        if not function_name:
            raise ModelError(
                "Function name is required for custom models",
                model_type="custom",
                context={"path": file_path},
                suggestion="Specify 'function' in the model configuration",
            )

        return file_path, function_name

    def _validate_model_file_exists(
        self,
        file_path: str,
        function_name: str,
    ) -> Path:
        """Validate that the model file exists and is a regular file.

        Parameters
        ----------
        file_path : str
            Path to the model file.
        function_name : str
            Function name (for error messages).

        Returns
        -------
        path : Path
            Validated Path object.

        Raises
        ------
        ModelError
            If file doesn't exist or isn't a regular file.
        """
        path = Path(file_path)

        if not path.exists():
            raise ModelError(
                f"Custom model file not found: {file_path}",
                model_name=function_name,
                model_type="custom",
                context={"path": file_path},
                suggestion="Check the file path and ensure the file exists",
            )

        if not path.is_file():
            raise ModelError(
                f"Path is not a file: {file_path}",
                model_name=function_name,
                model_type="custom",
                context={"path": file_path},
                suggestion="Provide a path to a Python file",
            )

        return path

    def _check_path_security(
        self,
        path: Path,
        function_name: str,
        is_trusted: bool,
    ) -> None:
        """Reject paths that fail the ``validate_path`` check.

        A rejected path is recorded through the audit logger before the
        error is raised.

        Parameters
        ----------
        path : Path
            Path to validate.
        function_name : str
            Function name (for error messages).
        is_trusted : bool
            Whether the model is trusted. Only recorded in the audit entry;
            a trusted model is still rejected when the path check fails.

        Raises
        ------
        ModelError
            If ``validate_path`` rejects the path.
        """
        if validate_path(path):
            return

        audit_logger = get_audit_logger()
        from nlsq.cli.model_validation import ModelValidationResult

        result = ModelValidationResult(
            path=path,
            is_valid=False,
            is_trusted=is_trusted,
            violations=["Path traversal detected"],
        )
        audit_logger.log_load_attempt(path, result)
        raise ModelError(
            f"Path traversal detected: {path}",
            model_name=function_name,
            model_type="custom",
            context={"path": str(path)},
            suggestion="Model files must be within the current working directory",
        )

    def _validate_model_security(
        self,
        path: Path,
        function_name: str,
        is_trusted: bool,
    ) -> None:
        """Run security validation on the model file via ``validate_model``.

        Every attempt is written to the audit log, whether or not it passes.

        Parameters
        ----------
        path : Path
            Path to the model file.
        function_name : str
            Function name (for error messages).
        is_trusted : bool
            Whether validation failures should be tolerated. Forwarded to
            ``validate_model`` as ``trusted``.

        Raises
        ------
        ModelError
            If security validation fails and not trusted.
        """
        validation_result = validate_model(path, trusted=is_trusted)
        violations = validation_result.violations

        audit_logger = get_audit_logger()
        audit_logger.log_load_attempt(path, validation_result)

        if not validation_result.is_valid and not is_trusted:
            # Only the first three violations go into the message; the full
            # list is preserved in the error context.
            violations_msg = "; ".join(violations[:3])
            if len(violations) > 3:
                violations_msg += f" (and {len(violations) - 3} more)"

            raise ModelError(
                f"Security validation failed for {path}: {violations_msg}",
                model_name=function_name,
                model_type="custom",
                context={
                    "path": str(path),
                    "violations": violations,
                },
                suggestion="Remove dangerous patterns or use --trust flag to skip validation",
            )

        if is_trusted and violations:
            logger.warning(
                "Loading model %s with --trust flag, bypassing security checks. "
                "Violations that would have blocked: %s",
                path,
                "; ".join(violations[:3]),
            )

    def _load_module_from_path(
        self,
        path: Path,
        function_name: str,
    ) -> Any:
        """Load a Python module from a file path.

        The file is executed as a module named ``custom_model``.

        Parameters
        ----------
        path : Path
            Path to the Python file.
        function_name : str
            Function name (for error messages).

        Returns
        -------
        module : module
            The loaded Python module.

        Raises
        ------
        ModelError
            If no import spec/loader can be created, the file has a syntax
            error, or executing the module raises any other ``Exception``.
        """
        try:
            spec = importlib.util.spec_from_file_location("custom_model", path)
            if spec is None or spec.loader is None:
                raise ModelError(
                    f"Failed to load custom model file: {path}",
                    model_name=function_name,
                    model_type="custom",
                    context={"path": str(path)},
                    suggestion="Ensure the file is a valid Python module",
                )

            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return module

        except SyntaxError as e:
            raise ModelError(
                f"Syntax error in custom model file: {e}",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path), "line": e.lineno},
                suggestion="Fix the syntax error in the custom model file",
            ) from e
        except ModelError:
            # Raised above for a missing spec/loader; must not be re-wrapped
            # by the generic handler below.
            raise
        except Exception as e:
            raise ModelError(
                f"Error loading custom model file: {e}",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path)},
                suggestion="Ensure the file is a valid Python module",
            ) from e

    def _resolve_function_from_module(
        self,
        module: Any,
        function_name: str,
        path: Path,
    ) -> ModelFunction:
        """Resolve and validate the model function from a loaded module.

        Parameters
        ----------
        module : module
            The loaded Python module.
        function_name : str
            Name of the function to extract.
        path : Path
            Path to the module file (for error messages).

        Returns
        -------
        model : ModelFunction
            The resolved model function.

        Raises
        ------
        ModelError
            If the function is not found, is not callable, has a signature
            that cannot be inspected, or declares no parameters.
        """
        if not hasattr(module, function_name):
            available_functions = [
                name
                for name, obj in inspect.getmembers(module)
                if callable(obj) and not name.startswith("_")
            ]
            raise ModelError(
                f"Function '{function_name}' not found in {path}",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path), "available_functions": available_functions},
                suggestion=f"Available functions: {', '.join(available_functions) or 'none'}",
            )

        model = getattr(module, function_name)

        if not callable(model):
            raise ModelError(
                f"'{function_name}' is not a callable function",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path)},
                suggestion="Ensure the model is defined as a function",
            )

        # inspect.signature raises ValueError/TypeError for callables it
        # cannot introspect; surface that as a ModelError rather than
        # letting a raw exception escape.
        try:
            params = list(inspect.signature(model).parameters)
        except (TypeError, ValueError) as e:
            raise ModelError(
                f"Cannot inspect signature of '{function_name}': {e}",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path)},
                suggestion="Define model as f(x, *params) where x is the independent variable",
            ) from e

        # Validate signature has at least one parameter (x)
        if not params:
            raise ModelError(
                f"Model function '{function_name}' must have at least one parameter (x)",
                model_name=function_name,
                model_type="custom",
                context={"path": str(path), "parameters": params},
                suggestion="Define model as f(x, *params) where x is the independent variable",
            )

        return model

    def _attach_model_utilities(
        self,
        model: ModelFunction,
        module: Any,
    ) -> ModelFunction:
        """Attach estimate_p0 and bounds methods from module to model.

        A module-level attribute is attached only when it exists and is
        callable; the model object is modified in place.

        Parameters
        ----------
        model : ModelFunction
            The model function.
        module : module
            The module containing optional utilities.

        Returns
        -------
        model : ModelFunction
            The same model object, with any available utilities attached.
        """
        for utility_name in ("estimate_p0", "bounds"):
            utility = getattr(module, utility_name, None)
            if callable(utility):
                setattr(model, utility_name, utility)

        return model

    def _get_custom_model(
        self,
        config: dict[str, Any],
        *,
        trusted: bool = False,
    ) -> ModelFunction:
        """Load a custom model from an external Python file.

        Parameters
        ----------
        config : dict
            Configuration with:

            - path: str - Path to the Python file
            - function: str - Name of the model function
            - trusted: bool - Skip security validation (default: False)
        trusted : bool, default=False
            Keyword-only trust flag. The model is treated as trusted when
            either this argument or ``config["trusted"]`` is truthy; passing
            ``trusted=False`` does not revoke trust granted by the config.

        Returns
        -------
        ModelFunction
            The custom model function with optional estimate_p0 and bounds.

        Raises
        ------
        ModelError
            If the file or function cannot be found, the function is invalid,
            or the model fails security validation.
        """
        # Step 1: Validate configuration
        file_path, function_name = self._validate_custom_model_config(config)
        # Any truthy config value enables trust; normalise to a real bool so
        # the validation and audit helpers always receive True/False.
        is_trusted = bool(trusted or config.get("trusted", False))

        # Step 2: Validate file exists
        path = self._validate_model_file_exists(file_path, function_name)

        # Step 3: Security validation - path traversal
        self._check_path_security(path, function_name, is_trusted)

        # Step 4: Security validation - AST-based
        self._validate_model_security(path, function_name, is_trusted)

        # Step 5: Load module dynamically
        module = self._load_module_from_path(path, function_name)

        # Step 6: Resolve and validate function
        model = self._resolve_function_from_module(module, function_name, path)

        # Step 7: Attach utility methods
        return self._attach_model_utilities(model, module)

    def _get_polynomial_model(self, config: dict[str, Any]) -> ModelFunction:
        """Generate a polynomial model of given degree.

        Parameters
        ----------
        config : dict
            Configuration with 'degree' key specifying the polynomial degree.
            Integer-valued floats and numeric strings accepted by ``int()``
            are converted; non-integral floats are rejected.

        Returns
        -------
        ModelFunction
            The callable returned by ``nlsq.core.functions.polynomial(degree)``.

        Raises
        ------
        ModelError
            If the degree is missing, not convertible to an integer,
            non-integral, or negative.
        """
        raw_degree = config.get("degree")

        if raw_degree is None:
            raise ModelError(
                "Polynomial degree is required",
                model_type="polynomial",
                suggestion="Specify 'degree' in the model configuration",
            )

        # int() would silently truncate 2.7 to 2 and raises OverflowError for
        # infinity; reject non-integral floats (including nan/inf) explicitly.
        if isinstance(raw_degree, float) and not raw_degree.is_integer():
            raise ModelError(
                f"Polynomial degree must be an integer, got {raw_degree!r}",
                model_type="polynomial",
                context={"degree": raw_degree},
                suggestion="Degree must be a non-negative integer",
            )

        try:
            degree = int(raw_degree)
        except (ValueError, TypeError, OverflowError) as e:
            raise ModelError(
                f"Invalid polynomial degree: {raw_degree!r}",
                model_type="polynomial",
                context={"degree": raw_degree},
                suggestion="Degree must be a non-negative integer",
            ) from e

        if degree < 0:
            raise ModelError(
                f"Polynomial degree must be non-negative, got {degree}",
                model_type="polynomial",
                context={"degree": degree},
                suggestion="Use a non-negative integer for degree (0, 1, 2, ...)",
            )

        # Use nlsq.core.functions.polynomial factory
        import nlsq.core.functions

        return nlsq.core.functions.polynomial(degree)

    def _suggest_similar_model(self, name: str, available: list[str]) -> str:
        """Build a suggestion message for an unknown model name.

        Matching is case-insensitive and happens in two passes:

        1. Substring match in either direction (``"gauss"`` matches
           ``"gaussian"``).
        2. If nothing matched, same-length names that differ in at most two
           character positions (catches transpositions such as ``"linaer"``).

        Parameters
        ----------
        name : str
            The invalid model name.
        available : list[str]
            List of available model names.

        Returns
        -------
        str
            ``"Did you mean: ...?"`` when candidates were found, otherwise a
            listing of all available models.
        """
        name_lower = name.lower()
        candidates = [(model_name, model_name.lower()) for model_name in available]

        suggestions = [
            model_name
            for model_name, lowered in candidates
            if name_lower in lowered or lowered in name_lower
        ]

        # Fallback for common typos: count positional mismatches between
        # equal-length names. Names of different length that contain one
        # another were already handled by the substring pass above.
        if not suggestions:
            suggestions = [
                model_name
                for model_name, lowered in candidates
                if len(lowered) == len(name_lower)
                and sum(a != b for a, b in zip(name_lower, lowered)) <= 2
            ]

        if suggestions:
            return f"Did you mean: {', '.join(suggestions)}?"
        return f"Available models: {', '.join(available)}"