"""
Visualization utilities for ECFlow web app.

Generates matplotlib figures for mechanism classification, parameter
posteriors, and signal reconstruction overlays.
"""

import numpy as np
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402
from matplotlib.gridspec import GridSpec  # noqa: E402,F401  (unused in this module)

COLORS = {
    "primary": "#2563EB",
    "secondary": "#7C3AED",
    "accent": "#059669",
    "warm": "#DC2626",
    "neutral": "#6B7280",
    "bg": "#F9FAFB",
}

MECH_COLORS_EC = {
    "Nernst": "#3B82F6",
    "BV": "#8B5CF6",
    "MHC": "#EC4899",
    "Ads": "#F59E0B",
    "EC": "#10B981",
    "LH": "#EF4444",
}

MECH_COLORS_TPD = {
    "FirstOrder": "#3B82F6",
    "SecondOrder": "#8B5CF6",
    "LH_Surface": "#EC4899",
    "MvK": "#F59E0B",
    "FirstOrderCovDep": "#10B981",
    "DiffLimited": "#EF4444",
}

MECH_FULL_NAMES_EC = {
    "BV": "Butler–Volmer",
    "MHC": "Marcus–Hush–Chidsey",
    "Nernst": "Nernstian (reversible)",
    "Ads": "Adsorption-coupled",
    "EC": "EC mechanism",
    "LH": "Langmuir–Hinshelwood",
}

MECH_FULL_NAMES_TPD = {
    "FirstOrder": "1st-Order",
    "SecondOrder": "2nd-Order",
    "LH_Surface": "LH Surface",
    "MvK": "Mars–van Krevelen",
    "FirstOrderCovDep": "1st-Order Cov-Dep",
    "DiffLimited": "Diff-Limited",
}

# Maximum number of side-by-side panels in the multi-curve figures.
_MAX_PANELS = 4


def _has_items(seq):
    """Return True if *seq* is not None and contains at least one element.

    ``bool(seq)`` raises ``ValueError`` for NumPy arrays with more than one
    element, so optional sequence arguments are tested via ``len`` instead.
    """
    return seq is not None and len(seq) > 0


def plot_mechanism_probs(probs_dict, domain="ec"):
    """
    Horizontal bar chart of mechanism classification probabilities.

    Bars are sorted so that the most probable mechanism is drawn at the top.
    Bars with probability <= 5% get no percentage label.

    Args:
        probs_dict: {mechanism_name: probability}
        domain: 'ec' or 'tpd' (any value other than 'ec' uses the TPD tables)

    Returns:
        matplotlib Figure
    """
    colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
    full_names = MECH_FULL_NAMES_EC if domain == "ec" else MECH_FULL_NAMES_TPD

    names = list(probs_dict.keys())
    probs = [probs_dict[n] for n in names]
    # Ascending order: barh places index 0 at the bottom, so the largest
    # probability ends up at the top of the chart.
    sorted_idx = np.argsort(probs)
    names = [names[i] for i in sorted_idx]
    probs = [probs[i] for i in sorted_idx]

    bar_colors = [colors.get(n, COLORS["neutral"]) for n in names]
    display_names = [f"{n} ({full_names.get(n, n)})" for n in names]

    fig, ax = plt.subplots(figsize=(9, max(3, len(names) * 0.7)))
    bars = ax.barh(range(len(names)), probs, color=bar_colors,
                   edgecolor="white", linewidth=0.5, height=0.7)

    ax.set_yticks(range(len(names)))
    ax.set_yticklabels(display_names, fontsize=11, fontweight="medium")
    ax.set_xlim(0, 1.05)
    ax.set_xlabel("Probability", fontsize=12)
    ax.set_title("Mechanism Classification", fontsize=14, fontweight="bold", pad=15)

    for bar, prob in zip(bars, probs):
        # Skip labels on tiny bars to keep the chart uncluttered.
        if prob > 0.05:
            ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height() / 2,
                    f"{prob:.1%}", va="center", fontsize=11, fontweight="bold")

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis="x", alpha=0.3, linestyle="--")
    fig.tight_layout()
    return fig


def plot_posteriors(samples, param_names, mechanism_name, domain="ec"):
    """
    Violin plots of posterior distributions for each parameter.

    Each panel shows one parameter's violin with its mean (black) and
    median (COLORS["warm"]) lines. Dashed guides mark the 5th and 95th
    percentiles, and the mean is printed as text.

    Args:
        samples: [n_samples, D] array-like of posterior samples; a 1-D
            array-like is treated as a single parameter (D = 1)
        param_names: list of parameter names; column i of ``samples`` is
            plotted under ``param_names[i]``
        mechanism_name: name of the mechanism (also selects the violin colour)
        domain: 'ec' or 'tpd'

    Returns:
        matplotlib Figure

    Raises:
        ValueError: if ``param_names`` is empty.
    """
    n_params = len(param_names)
    if n_params == 0:
        raise ValueError("plot_posteriors requires at least one parameter name")

    samples = np.asarray(samples)
    if samples.ndim == 1:
        samples = samples[:, np.newaxis]

    # squeeze=False always returns a 2-D array of Axes, so one parameter needs
    # no special casing.
    fig, axes = plt.subplots(1, n_params, figsize=(max(4, 3 * n_params), 4.5),
                             squeeze=False)

    colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
    color = colors.get(mechanism_name, COLORS["primary"])

    for i, (ax, name) in enumerate(zip(axes[0], param_names)):
        data = samples[:, i]
        parts = ax.violinplot(data, positions=[0], showmeans=True,
                              showmedians=True, showextrema=False)
        for pc in parts["bodies"]:
            pc.set_facecolor(color)
            pc.set_alpha(0.6)
        parts["cmeans"].set_color("black")
        parts["cmedians"].set_color(COLORS["warm"])

        q05, q95 = np.quantile(data, [0.05, 0.95])
        ax.axhline(q05, color=COLORS["neutral"], linestyle="--", alpha=0.5, linewidth=0.8)
        ax.axhline(q95, color=COLORS["neutral"], linestyle="--", alpha=0.5, linewidth=0.8)

        ax.set_title(_format_param_name(name), fontsize=11, fontweight="medium")
        ax.set_xticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_visible(False)

        mean_val = data.mean()
        ax.text(0.5, 0.02, f"mean={mean_val:.3f}", transform=ax.transAxes,
                ha="center", fontsize=9, color=COLORS["neutral"])

    fig.suptitle(f"Parameter Posteriors — {mechanism_name}", fontsize=14, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig


def plot_reconstruction(observed_curves, recon_curves, domain="ec",
                        nrmses=None, r2s=None, scan_labels=None):
    """
    Overlay of observed vs reconstructed signals with optional metrics.

    Only the first four curves are drawn, one panel each. The averages in
    the figure title are computed over every finite entry of ``nrmses`` and
    ``r2s``, including curves that are not drawn, and only when both are
    given.

    Args:
        observed_curves: list of dicts with 'x' and 'y' arrays
        recon_curves: list of dicts with 'x' and 'y' arrays (same length)
        domain: 'ec' or 'tpd'
        nrmses: optional sequence (list or NumPy array) of NRMSE values per curve
        r2s: optional sequence (list or NumPy array) of R2 values per curve
        scan_labels: optional sequence of label strings per curve

    Returns:
        matplotlib Figure

    Raises:
        ValueError: if ``observed_curves`` is empty.
    """
    n_curves = len(observed_curves)
    if n_curves == 0:
        raise ValueError("plot_reconstruction requires at least one curve")

    n_panels = min(n_curves, _MAX_PANELS)
    fig, axes = plt.subplots(1, n_panels, figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)

    xlabel = "Potential (\u03b8)" if domain == "ec" else "Temperature (K)"
    ylabel = "Flux" if domain == "ec" else "Rate"

    has_nrmse = _has_items(nrmses)
    has_r2 = _has_items(r2s)
    has_labels = _has_items(scan_labels)

    # One panel per curve; n_panels <= n_curves, so every axis gets a curve.
    for i, ax in enumerate(axes[0]):
        obs = observed_curves[i]
        rec = recon_curves[i]

        ax.plot(obs["x"], obs["y"], color=COLORS["neutral"], linewidth=1.5,
                label="Observed", alpha=0.8)
        ax.plot(rec["x"], rec["y"], color=COLORS["primary"], linewidth=1.5,
                label="Reconstructed", linestyle="--")

        ax.set_xlabel(xlabel, fontsize=10)
        if i == 0:
            ax.set_ylabel(ylabel, fontsize=10)
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        if has_labels and i < len(scan_labels):
            title = scan_labels[i]
        elif domain == "ec":
            title = f"Scan rate {i + 1}"
        else:
            title = f"Heating rate {i + 1}"
        ax.set_title(title, fontsize=10)

        # Per-panel metrics box; non-finite values are omitted.
        metrics_parts = []
        if has_nrmse and i < len(nrmses) and np.isfinite(nrmses[i]):
            metrics_parts.append(f"NRMSE={nrmses[i]:.4f}")
        if has_r2 and i < len(r2s) and np.isfinite(r2s[i]):
            metrics_parts.append(f"R\u00b2={r2s[i]:.4f}")
        if metrics_parts:
            ax.text(0.02, 0.98, "  ".join(metrics_parts), transform=ax.transAxes,
                    fontsize=8, va="top", color=COLORS["accent"], fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white",
                              alpha=0.8, edgecolor=COLORS["accent"]))

    suptitle = "Signal Reconstruction"
    if has_nrmse and has_r2:
        valid_nrmse = [v for v in nrmses if np.isfinite(v)]
        valid_r2 = [v for v in r2s if np.isfinite(v)]
        if valid_nrmse and valid_r2:
            avg_nrmse = np.mean(valid_nrmse)
            avg_r2 = np.mean(valid_r2)
            suptitle += f"  (avg NRMSE={avg_nrmse:.4f}, avg R\u00b2={avg_r2:.4f})"

    fig.suptitle(suptitle, fontsize=12, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig


def _add_sweep_arrows(ax, pot, y_ox, y_red, mid, show_labels=False):
    """Add direction arrows for forward/reverse sweeps on both species.

    The curves are split at index ``mid``. For each species and each sweep
    segment, a short arrow is drawn along the curve. It sits 35% into the
    first segment and 65% into the second. Segments with fewer than 10
    points get no arrow.

    Args:
        ax: matplotlib Axes to draw on
        pot: potential values (indexable, sliceable)
        y_ox: values for the first species, drawn in COLORS["primary"]
        y_red: values for the second species, drawn in COLORS["warm"]
        mid: index separating the first and second sweep segments
        show_labels: accepted but currently unused
    """
    sweep_specs = [
        (slice(None, mid), 0.35),
        (slice(mid, None), 0.65),
    ]
    curves = [
        (y_ox, COLORS["primary"]),
        (y_red, COLORS["warm"]),
    ]
    for y_data, color in curves:
        for segment, frac in sweep_specs:
            x_seg = pot[segment]
            y_seg = y_data[segment]
            n = len(x_seg)
            if n < 10:
                continue
            # Keep the arrow anchor away from the segment ends.
            idx = int(n * frac)
            idx = max(2, min(idx, n - 3))
            # Arrow length scales with segment length (about 1/15 of it).
            step = max(1, n // 30)
            i0 = max(0, idx - step)
            i1 = min(n - 1, idx + step)
            ax.annotate(
                "",
                xy=(x_seg[i1], y_seg[i1]),
                xytext=(x_seg[i0], y_seg[i0]),
                arrowprops=dict(arrowstyle="-|>", color=color, lw=1.8,
                                mutation_scale=14),
            )


def plot_concentration_profiles(conc_curves, scan_labels=None):
    """
    Plot surface concentration profiles (C_A and C_B) vs potential.

    Each curve's samples are split in half. The first half is treated as the
    forward sweep and the second half as the reverse sweep, and direction
    arrows are added to both. Only the first four curves are drawn; panels
    whose entry is None are hidden.

    Args:
        conc_curves: list of dicts with 'x' (potential), 'c_ox', 'c_red',
            or None for failed reconstructions
        scan_labels: optional sequence of label strings per curve

    Returns:
        matplotlib Figure, or None if no valid data
    """
    if all(c is None for c in conc_curves):
        return None

    n_curves = len(conc_curves)
    n_panels = min(n_curves, _MAX_PANELS)
    fig, axes = plt.subplots(1, n_panels, figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)
    has_labels = _has_items(scan_labels)

    for i, ax in enumerate(axes[0]):
        c = conc_curves[i]
        if c is None:
            ax.set_visible(False)
            continue

        pot = np.asarray(c["x"])
        c_ox = np.asarray(c["c_ox"])
        c_red = np.asarray(c["c_red"])
        mid = len(pot) // 2

        # Forward sweep (reductive): first half
        ax.plot(pot[:mid], c_ox[:mid], color=COLORS["primary"], linewidth=1.5,
                label="C$_A$ (ox)")
        ax.plot(pot[:mid], c_red[:mid], color=COLORS["warm"], linewidth=1.5,
                label="C$_B$ (red)")
        # Reverse sweep (oxidative): second half. Start one sample early so the
        # line stays connected through the switching potential.
        rev = slice(max(mid - 1, 0), None)
        ax.plot(pot[rev], c_ox[rev], color=COLORS["primary"], linewidth=1.5)
        ax.plot(pot[rev], c_red[rev], color=COLORS["warm"], linewidth=1.5)

        _add_sweep_arrows(ax, pot, c_ox, c_red, mid)

        ax.set_xlabel("Potential (\u03b8)", fontsize=10)
        if i == 0:
            ax.set_ylabel("Surface concentration", fontsize=10)
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        if has_labels and i < len(scan_labels):
            ax.set_title(scan_labels[i], fontsize=10)
        else:
            ax.set_title(f"Scan rate {i + 1}", fontsize=10)

    fig.suptitle("Surface Concentration Profiles", fontsize=12, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig


def plot_parameter_table(param_stats, mechanism_name):
    """
    Create a formatted parameter summary table as a figure.

    Args:
        param_stats: dict with 'names', 'mean', 'std', 'q05', 'q95'
            (parallel sequences, one entry per parameter)
        mechanism_name: name of the mechanism

    Returns:
        matplotlib Figure
    """
    names = param_stats["names"]
    n = len(names)

    fig, ax = plt.subplots(figsize=(8, max(2, 0.6 * n + 1)))
    ax.axis("off")

    col_labels = ["Parameter", "Mean", "Std", "5th %ile", "95th %ile"]
    cell_text = [
        [_format_param_name(name), f"{mean:.4f}", f"{std:.4f}", f"{lo:.4f}", f"{hi:.4f}"]
        for name, mean, std, lo, hi in zip(
            names, param_stats["mean"], param_stats["std"],
            param_stats["q05"], param_stats["q95"],
        )
    ]

    table = ax.table(cellText=cell_text, colLabels=col_labels,
                     loc="center", cellLoc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1.0, 1.5)

    # Row 0 is the header; body rows alternate between two light shades.
    for (row, _col), cell in table.get_celld().items():
        if row == 0:
            cell.set_facecolor("#E5E7EB")
            cell.set_text_props(fontweight="bold")
        else:
            cell.set_facecolor("#F9FAFB" if row % 2 == 0 else "white")

    ax.set_title(f"Parameter Estimates — {mechanism_name}",
                 fontsize=14, fontweight="bold", pad=20)
    fig.tight_layout()
    return fig


# Display labels for known parameter names; unknown names are shown verbatim.
_PARAM_DISPLAY_NAMES = {
    "log10(K0)": "log₁₀(K₀)",
    "log10(dB)": "log₁₀(d_B)",
    "log10(dA)": "log₁₀(d_A)",
    "log10(kc)": "log₁₀(k_c)",
    "log10(reorg_e)": "log₁₀(λ)",
    "log10(Gamma_sat)": "log₁₀(Γ_sat)",
    "log10(KA_eq)": "log₁₀(K_A,eq)",
    "log10(KB_eq)": "log₁₀(K_B,eq)",
    "log10(nu)": "log₁₀(ν)",
    "log10(nu_red)": "log₁₀(ν_red)",
    "log10(D0)": "log₁₀(D₀)",
    "E0_offset": "E₀ offset",
    "alpha": "α",
    "alpha_cov": "α_cov",
    "Ed": "E_d (K)",
    "Ed0": "E_d0 (K)",
    "Ea": "E_a (K)",
    "Ea_red": "E_a,red (K)",
    "Ea_reox": "E_a,reox (K)",
    "E_diff": "E_diff (K)",
    "theta_0": "θ₀",
    "theta_A0": "θ_A0",
    "theta_B0": "θ_B0",
    "theta_O0": "θ_O0",
}


def _format_param_name(name):
    """Format parameter names for display (unknown names are returned unchanged)."""
    return _PARAM_DISPLAY_NAMES.get(name, name)