"""Publication-quality visualization module for NLSQ CLI.
This module provides the FitVisualizer class for generating publication-quality
plots of curve fitting results, including:
- Combined main plot + residuals layout
- Separate histogram of residuals
- Confidence bands from covariance matrix error propagation
- Multiple style presets (publication, presentation, nature, science, minimal)
- Multi-format output (PDF vector, PNG raster)
- Fit statistics annotation (R-squared, RMSE)
- Colorblind-safe palette support
Example Usage
-------------
>>> from nlsq.cli.visualization import FitVisualizer
>>>
>>> visualizer = FitVisualizer()
>>> result = {"popt": [1.0, 0.5, 0.1], "pcov": [[0.01, 0, 0], [0, 0.02, 0], [0, 0, 0.005]], ...}
>>> data = {"xdata": x, "ydata": y, "sigma": sigma}
>>> config = {"visualization": {"enabled": True, "output_dir": "figures", ...}}
>>> output_paths = visualizer.generate(result, data, model, config)
"""
import warnings
from collections.abc import Callable
from pathlib import Path
from typing import Any
import matplotlib
matplotlib.use("Agg") # Non-interactive backend for CLI use
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# =============================================================================
# Style Presets Dictionary
# =============================================================================
STYLE_PRESETS: dict[str, dict[str, Any]] = {
"publication": {
# Clean serif fonts, 300 DPI, standard figure size
"font_family": "serif",
"font_size": 10,
"math_fontset": "cm",
"dpi": 300,
"figure_size": [6.0, 4.5],
"grid": True,
"grid_alpha": 0.3,
"spine_visibility": {"top": True, "right": True, "bottom": True, "left": True},
"linewidth": 1.5,
},
"presentation": {
# Larger sans-serif fonts, lower DPI for slides
"font_family": "sans-serif",
"font_size": 14,
"math_fontset": "dejavusans",
"dpi": 150,
"figure_size": [10.0, 7.5],
"grid": True,
"grid_alpha": 0.4,
"spine_visibility": {"top": True, "right": True, "bottom": True, "left": True},
"linewidth": 2.0,
},
"nature": {
# Nature journal specs: 3.5" width (single column), Arial font
"font_family": "sans-serif",
"font_size": 8,
"math_fontset": "dejavusans",
"dpi": 300,
"figure_size": [3.5, 2.625], # Single column width, 4:3 aspect
"grid": False,
"grid_alpha": 0.0,
"spine_visibility": {
"top": False,
"right": False,
"bottom": True,
"left": True,
},
"linewidth": 1.0,
},
"science": {
# Science journal specifications
"font_family": "sans-serif",
"font_size": 9,
"math_fontset": "dejavusans",
"dpi": 300,
"figure_size": [3.5, 2.625],
"grid": False,
"grid_alpha": 0.0,
"spine_visibility": {
"top": False,
"right": False,
"bottom": True,
"left": True,
},
"linewidth": 1.0,
},
"minimal": {
# No top/right spines, no grid, clean look
"font_family": "sans-serif",
"font_size": 10,
"math_fontset": "dejavusans",
"dpi": 300,
"figure_size": [6.0, 4.5],
"grid": False,
"grid_alpha": 0.0,
"spine_visibility": {
"top": False,
"right": False,
"bottom": True,
"left": True,
},
"linewidth": 1.5,
},
}
# =============================================================================
# FitVisualizer Class
# =============================================================================
[docs]
class FitVisualizer:
"""Visualizer for curve fitting results.
Generates publication-quality plots including combined fit + residuals
layouts, histograms, and confidence bands.
Attributes
----------
None
Methods
-------
generate(result, data, model, config)
Generate all configured visualizations and save to files.
Examples
--------
>>> visualizer = FitVisualizer()
>>> result = {"popt": [1.0, 0.5], "pcov": [[0.01, 0], [0, 0.02]], ...}
>>> data = {"xdata": x, "ydata": y}
>>> config = {"visualization": {"enabled": True, "output_dir": "figures"}}
>>> output_paths = visualizer.generate(result, data, model, config)
"""
[docs]
def generate(
self,
result: dict[str, Any],
data: dict[str, Any],
model: Callable,
config: dict[str, Any],
) -> list[str]:
"""Generate visualizations based on configuration.
Parameters
----------
result : dict
Fit result dictionary containing:
- popt: Fitted parameters
- pcov: Covariance matrix
- fun: Residuals (optional)
- statistics: Dict with r_squared, rmse, etc.
data : dict
Data dictionary containing:
- xdata: Independent variable array
- ydata: Dependent variable array
- sigma: Uncertainties (optional)
model : callable
Model function ``f(x, *params)``.
config : dict
Configuration dictionary with visualization section.
Returns
-------
list[str]
List of output file paths that were generated.
"""
vis_config = config.get("visualization", {})
if not vis_config.get("enabled", True):
return []
output_paths: list[str] = []
output_dir = Path(vis_config.get("output_dir", "figures"))
output_dir.mkdir(parents=True, exist_ok=True)
filename_prefix = vis_config.get("filename_prefix", "fit")
formats = vis_config.get("formats", ["pdf", "png"])
# Apply style preset
self._apply_style_preset(vis_config)
# Generate combined plot (main + residuals)
fig_combined = self._create_combined_figure(result, data, model, vis_config)
combined_paths = self._save_figure(
fig_combined, output_dir, f"{filename_prefix}_combined", formats, vis_config
)
output_paths.extend(combined_paths)
plt.close(fig_combined)
# Generate histogram if enabled
histogram_config = vis_config.get("histogram", {})
if histogram_config.get("enabled", False):
residuals = self._get_residuals(result, data, model)
if residuals is not None:
fig_hist = self._create_histogram_figure(residuals, vis_config)
hist_paths = self._save_figure(
fig_hist,
output_dir,
f"{filename_prefix}_histogram",
formats,
vis_config,
)
output_paths.extend(hist_paths)
plt.close(fig_hist)
return output_paths
def _apply_style_preset(self, config: dict[str, Any]) -> None:
"""Apply style preset to matplotlib rcParams.
Parameters
----------
config : dict
Visualization configuration containing style preset name.
"""
style_name = config.get("style", "publication")
preset = STYLE_PRESETS.get(style_name, STYLE_PRESETS["publication"])
# Apply font settings
plt.rcParams["font.family"] = preset.get("font_family", "serif")
plt.rcParams["font.size"] = preset.get("font_size", 10)
plt.rcParams["mathtext.fontset"] = preset.get("math_fontset", "cm")
# Apply line settings
plt.rcParams["lines.linewidth"] = preset.get("linewidth", 1.5)
# Override with config-specific font settings if provided
font_config = config.get("font", {})
if "family" in font_config:
plt.rcParams["font.family"] = font_config["family"]
if "size" in font_config:
plt.rcParams["font.size"] = font_config["size"]
if "math_fontset" in font_config:
plt.rcParams["mathtext.fontset"] = font_config["math_fontset"]
def _get_color_scheme(self, config: dict[str, Any]) -> dict[str, str]:
"""Get the active color scheme from configuration.
Parameters
----------
config : dict
Visualization configuration.
Returns
-------
dict
Dictionary mapping element names to hex colors.
"""
active_scheme = config.get("active_scheme", "default")
color_schemes = config.get("color_schemes", {})
# Default fallback colors
default_colors = {
"data": "#1f77b4",
"fit": "#d62728",
"residuals": "#2ca02c",
"confidence": "#ff7f0e",
}
return color_schemes.get(active_scheme, default_colors)
def _create_combined_figure(
self,
result: dict[str, Any],
data: dict[str, Any],
model: Callable,
config: dict[str, Any],
) -> plt.Figure:
"""Create combined figure with main plot and residuals.
Parameters
----------
result : dict
Fit result dictionary.
data : dict
Data dictionary.
model : callable
Model function.
config : dict
Visualization configuration.
Returns
-------
matplotlib.figure.Figure
The combined figure with two subplots.
"""
style_name = config.get("style", "publication")
preset = STYLE_PRESETS.get(style_name, STYLE_PRESETS["publication"])
# Get figure size from config or preset
figure_size = config.get("figure_size", preset.get("figure_size", [6.0, 4.5]))
# Check if residuals plot is enabled
residuals_config = config.get("residuals_plot", {})
show_residuals = residuals_config.get("enabled", True)
if show_residuals:
# Create figure with 2 subplots (3:1 height ratio)
# Use constrained layout for better handling of shared axes
fig, (ax_main, ax_residuals) = plt.subplots(
2,
1,
figsize=figure_size,
gridspec_kw={"height_ratios": [3, 1], "hspace": 0.1},
sharex=True,
layout="constrained",
)
else:
fig, ax_main = plt.subplots(figsize=figure_size, layout="constrained")
ax_residuals = None
# Get color scheme
colors = self._get_color_scheme(config)
# Extract data
xdata = np.asarray(data.get("xdata", []))
ydata = np.asarray(data.get("ydata", []))
sigma = data.get("sigma")
if sigma is not None:
sigma = np.asarray(sigma)
popt = np.asarray(result.get("popt", []))
pcov = result.get("pcov")
if pcov is not None:
pcov = np.asarray(pcov)
# Plot main figure
self._plot_main(
ax_main, xdata, ydata, sigma, popt, pcov, model, config, colors, result
)
# Plot residuals if enabled
if ax_residuals is not None:
residuals = self._get_residuals(result, data, model)
if residuals is not None:
self._plot_residuals(ax_residuals, xdata, residuals, config, colors)
# Apply spine visibility
spine_vis = preset.get("spine_visibility", {})
for spine, visible in spine_vis.items():
if spine in ax_main.spines:
ax_main.spines[spine].set_visible(visible)
if ax_residuals is not None and spine in ax_residuals.spines:
ax_residuals.spines[spine].set_visible(visible)
return fig
def _plot_main(
self,
ax: plt.Axes,
xdata: np.ndarray,
ydata: np.ndarray,
sigma: np.ndarray | None,
popt: np.ndarray,
pcov: np.ndarray | None,
model: Callable,
config: dict[str, Any],
colors: dict[str, str],
result: dict[str, Any],
) -> None:
"""Plot main fit data and curve.
Parameters
----------
ax : matplotlib.axes.Axes
The axes to plot on.
xdata : ndarray
X data points.
ydata : ndarray
Y data points.
sigma : ndarray or None
Y uncertainties.
popt : ndarray
Fitted parameters.
pcov : ndarray or None
Covariance matrix.
model : callable
Model function.
config : dict
Visualization configuration.
colors : dict
Color scheme dictionary.
result : dict
Fit result dictionary (for statistics annotation).
"""
main_config = config.get("main_plot", {})
data_config = main_config.get("data", {})
fit_config = main_config.get("fit", {})
confidence_config = main_config.get("confidence_band", {})
# Get data color from config or color scheme
data_color = data_config.get("color", colors.get("data", "#1f77b4"))
fit_color = fit_config.get("color", colors.get("fit", "#d62728"))
# Plot data points
marker = data_config.get("marker", "o")
size = data_config.get("size", 20)
alpha = data_config.get("alpha", 0.7)
data_label = data_config.get("label", "Data")
if sigma is not None and data_config.get("show_errorbars", True):
ax.errorbar(
xdata,
ydata,
yerr=sigma,
fmt=marker,
color=data_color,
markersize=np.sqrt(size),
alpha=alpha,
label=data_label,
capsize=data_config.get("capsize", 2),
)
else:
ax.scatter(
xdata,
ydata,
marker=marker,
s=size,
c=data_color,
alpha=alpha,
label=data_label,
)
# Generate fit curve
n_fit_points = fit_config.get("n_points", 500)
x_fit = np.linspace(xdata.min(), xdata.max(), n_fit_points)
if len(popt) > 0:
y_fit = model(x_fit, *popt)
# Plot confidence band if enabled
if confidence_config.get("enabled", False) and pcov is not None:
confidence_level = confidence_config.get("level", 0.95)
confidence_color = confidence_config.get(
"color", colors.get("confidence", fit_color)
)
confidence_alpha = confidence_config.get("alpha", 0.2)
lower, upper = self._calculate_confidence_band(
model, x_fit, popt, pcov, confidence_level
)
ax.fill_between(
x_fit,
lower,
upper,
color=confidence_color,
alpha=confidence_alpha,
label=f"{confidence_level * 100:.0f}% CI",
)
# Plot fit curve
ax.plot(
x_fit,
y_fit,
color=fit_color,
linewidth=fit_config.get("linewidth", 1.5),
linestyle=fit_config.get("linestyle", "-"),
label=fit_config.get("label", "Fit"),
)
# Set labels
ax.set_xlabel(main_config.get("x_label", "x"))
ax.set_ylabel(main_config.get("y_label", "y"))
if main_config.get("title"):
ax.set_title(main_config["title"])
# Grid
if main_config.get("show_grid", True):
ax.grid(True, alpha=main_config.get("grid_alpha", 0.3))
# Legend
legend_config = main_config.get("legend", {})
if legend_config.get("enabled", True):
ax.legend(
loc=legend_config.get("location", "best"),
frameon=legend_config.get("frameon", True),
fontsize=legend_config.get("fontsize"),
)
# Annotation (fit statistics)
annotation_config = main_config.get("annotation", {})
if annotation_config.get("enabled", False):
self._add_statistics_annotation(ax, result, annotation_config)
def _plot_residuals(
self,
ax: plt.Axes,
xdata: np.ndarray,
residuals: np.ndarray,
config: dict[str, Any],
colors: dict[str, str],
) -> None:
"""Plot residuals subplot.
Parameters
----------
ax : matplotlib.axes.Axes
The axes to plot on.
xdata : ndarray
X data points.
residuals : ndarray
Residual values.
config : dict
Visualization configuration.
colors : dict
Color scheme dictionary.
"""
residuals_config = config.get("residuals_plot", {})
# Get color from config or scheme
residual_color = residuals_config.get(
"color", colors.get("residuals", "#2ca02c")
)
# Plot residuals
plot_type = residuals_config.get("type", "scatter")
marker = residuals_config.get("marker", "o")
size = residuals_config.get("size", 15)
alpha = residuals_config.get("alpha", 0.7)
if plot_type == "scatter":
ax.scatter(
xdata, residuals, marker=marker, s=size, c=residual_color, alpha=alpha
)
elif plot_type == "stem":
markerline, stemlines, baseline = ax.stem(xdata, residuals)
plt.setp(markerline, color=residual_color, markersize=np.sqrt(size))
plt.setp(stemlines, color=residual_color, alpha=alpha)
plt.setp(baseline, visible=False)
else: # line
ax.plot(xdata, residuals, marker=marker, color=residual_color, alpha=alpha)
# Zero reference line
if residuals_config.get("show_zero_line", True):
ax.axhline(
0,
linestyle=residuals_config.get("zero_line_style", "--"),
color=residuals_config.get("zero_line_color", "gray"),
linewidth=residuals_config.get("zero_line_width", 1.0),
)
# Standard deviation bands
std_config = residuals_config.get("std_bands", {})
if std_config.get("enabled", False):
std_residual = np.std(residuals)
levels = std_config.get("levels", [1, 2])
band_colors = std_config.get("colors", ["#fff3cd", "#ffe69c"])
band_alpha = std_config.get("alpha", 0.4)
for i, level in enumerate(reversed(levels)):
color = band_colors[min(i, len(band_colors) - 1)]
ax.axhspan(
-level * std_residual,
level * std_residual,
color=color,
alpha=band_alpha,
zorder=0,
)
# Labels
ax.set_xlabel(residuals_config.get("x_label", "x"))
ax.set_ylabel(residuals_config.get("y_label", "Residual"))
if residuals_config.get("title"):
ax.set_title(residuals_config["title"])
def _create_histogram_figure(
self,
residuals: np.ndarray,
config: dict[str, Any],
) -> plt.Figure:
"""Create histogram of residuals.
Parameters
----------
residuals : ndarray
Residual values.
config : dict
Visualization configuration.
Returns
-------
matplotlib.figure.Figure
The histogram figure.
"""
style_name = config.get("style", "publication")
preset = STYLE_PRESETS.get(style_name, STYLE_PRESETS["publication"])
figure_size = config.get("figure_size", preset.get("figure_size", [6.0, 4.5]))
fig, ax = plt.subplots(figsize=figure_size, layout="constrained")
histogram_config = config.get("histogram", {})
# Get bins
bins = histogram_config.get("bins", "auto")
if bins == "sqrt":
bins = int(np.sqrt(len(residuals)))
elif bins == "sturges":
bins = int(np.ceil(np.log2(len(residuals))) + 1)
# Get colors
colors = self._get_color_scheme(config)
bar_color = histogram_config.get("color", colors.get("residuals", "#9467bd"))
bar_alpha = histogram_config.get("alpha", 0.7)
edgecolor = histogram_config.get("edgecolor", "white")
# Plot histogram
_n, _bin_edges, _patches = ax.hist(
residuals,
bins=bins,
color=bar_color,
alpha=bar_alpha,
edgecolor=edgecolor,
density=True,
)
# Overlay normal distribution fit
if histogram_config.get("show_normal_fit", True):
mu, std = stats.norm.fit(residuals)
# Guard against std == 0 (e.g. perfect fit, all residuals identical),
# which would cause norm.pdf to return inf/nan.
if std > 0:
x_norm = np.linspace(residuals.min(), residuals.max(), 100)
y_norm = stats.norm.pdf(x_norm, mu, std)
normal_color = histogram_config.get(
"normal_color", colors.get("fit", "#d62728")
)
ax.plot(
x_norm, y_norm, color=normal_color, linewidth=2, label="Normal fit"
)
ax.legend()
# Labels
ax.set_xlabel(histogram_config.get("x_label", "Residual"))
ax.set_ylabel(histogram_config.get("y_label", "Frequency"))
if histogram_config.get("title"):
ax.set_title(histogram_config["title"])
return fig
def _calculate_confidence_band(
self,
model: Callable,
x: np.ndarray,
popt: np.ndarray,
pcov: np.ndarray,
confidence_level: float = 0.95,
) -> tuple[np.ndarray, np.ndarray]:
"""Calculate confidence bands using error propagation.
Uses the Jacobian of the model with respect to parameters and
the covariance matrix to compute prediction uncertainties.
Parameters
----------
model : callable
Model function ``f(x, *params)``.
x : ndarray
X values for computing the band.
popt : ndarray
Fitted parameters.
pcov : ndarray
Parameter covariance matrix.
confidence_level : float
Confidence level (default 0.95 for 95% CI).
Returns
-------
tuple[ndarray, ndarray]
Lower and upper bounds of the confidence band.
"""
n_points = len(x)
n_params = len(popt)
# Compute Jacobian numerically using finite differences
eps = 1e-8
jacobian = np.zeros((n_points, n_params))
y0 = model(x, *popt)
for i in range(n_params):
params_plus = popt.copy()
params_plus[i] += eps
y_plus = model(x, *params_plus)
jacobian[:, i] = (y_plus - y0) / eps
# Compute variance of predictions: var(y) = J @ pcov @ J.T (diagonal elements)
# For efficiency, compute element-wise
variance = np.zeros(n_points)
for i in range(n_points):
j_i = jacobian[i, :]
variance[i] = j_i @ pcov @ j_i
std_prediction = np.sqrt(np.maximum(variance, 0))
# Compute confidence interval using t-distribution
# For large samples, use normal distribution
alpha = 1 - confidence_level
z = stats.norm.ppf(1 - alpha / 2)
lower = y0 - z * std_prediction
upper = y0 + z * std_prediction
return lower, upper
def _get_residuals(
self,
result: dict[str, Any],
data: dict[str, Any],
model: Callable,
) -> np.ndarray | None:
"""Extract or compute residuals from result.
Parameters
----------
result : dict
Fit result dictionary.
data : dict
Data dictionary.
model : callable
Model function.
Returns
-------
ndarray or None
Residual values, or None if cannot be computed.
"""
# Try to get from result
if "fun" in result and result["fun"] is not None:
return np.asarray(result["fun"])
# Compute from data and fit
popt = result.get("popt")
if popt is None:
return None
xdata = data.get("xdata")
ydata = data.get("ydata")
if xdata is None or ydata is None:
return None
xdata = np.asarray(xdata)
ydata = np.asarray(ydata)
popt = np.asarray(popt)
try:
y_fit = model(xdata, *popt)
return ydata - y_fit
except Exception:
return None
def _add_statistics_annotation(
self,
ax: plt.Axes,
result: dict[str, Any],
annotation_config: dict[str, Any],
) -> None:
"""Add fit statistics annotation to plot.
Parameters
----------
ax : matplotlib.axes.Axes
The axes to annotate.
result : dict
Fit result dictionary containing statistics.
annotation_config : dict
Annotation configuration.
"""
lines = []
statistics = result.get("statistics", {})
if annotation_config.get("show_r_squared", True):
r_squared = statistics.get("r_squared")
if r_squared is not None:
lines.append(f"$R^2 = {r_squared:.4f}$")
if annotation_config.get("show_rmse", False):
rmse = statistics.get("rmse")
if rmse is not None:
lines.append(f"RMSE = {rmse:.4g}")
if annotation_config.get("show_chi_squared", False):
chi_sq = statistics.get("chi_squared")
if chi_sq is not None:
lines.append(f"$\\chi^2 = {chi_sq:.4g}$")
if not lines:
return
text = "\n".join(lines)
fontsize = annotation_config.get("fontsize", 9)
location = annotation_config.get("location", "upper right")
# Map location string to axes coordinates
location_map = {
"upper right": (0.95, 0.95),
"upper left": (0.05, 0.95),
"lower right": (0.95, 0.05),
"lower left": (0.05, 0.05),
"center": (0.5, 0.5),
}
coords = location_map.get(location, (0.95, 0.95))
# Determine alignment based on position
ha = "right" if coords[0] > 0.5 else "left"
va = "top" if coords[1] > 0.5 else "bottom"
ax.text(
coords[0],
coords[1],
text,
transform=ax.transAxes,
fontsize=fontsize,
verticalalignment=va,
horizontalalignment=ha,
bbox={"boxstyle": "round,pad=0.3", "facecolor": "white", "alpha": 0.8},
)
def _save_figure(
self,
fig: plt.Figure,
output_dir: Path,
filename_base: str,
formats: list[str],
config: dict[str, Any],
) -> list[str]:
"""Save figure to multiple formats.
Parameters
----------
fig : matplotlib.figure.Figure
The figure to save.
output_dir : Path
Output directory path.
filename_base : str
Base filename without extension.
formats : list[str]
List of format extensions (e.g., ["pdf", "png"]).
config : dict
Visualization configuration.
Returns
-------
list[str]
List of saved file paths.
"""
style_name = config.get("style", "publication")
preset = STYLE_PRESETS.get(style_name, STYLE_PRESETS["publication"])
dpi = config.get("dpi", preset.get("dpi", 300))
output_paths = []
for fmt in formats:
output_path = output_dir / f"{filename_base}.{fmt}"
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=r".*Glyph 65534.*",
category=UserWarning,
)
fig.savefig(
output_path,
format=fmt,
dpi=dpi,
bbox_inches="tight",
facecolor="white",
edgecolor="none",
)
output_paths.append(str(output_path))
return output_paths