⚡ ECFlow
" "Upload cyclic voltammetry or TPD data to identify the reaction mechanism " "and infer kinetic parameters with full Bayesian uncertainty.
" """" ECFlow — Amortized Bayesian Inference for Electrochemistry & Catalysis Gradio web interface for mechanism classification and parameter inference from cyclic voltammetry (CV) and temperature-programmed desorption (TPD) data. """ import os import sys import json import tempfile from pathlib import Path import numpy as np import gradio as gr from inference import ECFlowPredictor from preprocessing import ( nondimensionalize_cv, estimate_E0, parse_cv_csv, parse_tpd_csv, ) from plotting import ( plot_mechanism_probs, plot_posteriors, plot_parameter_table, plot_reconstruction, plot_concentration_profiles, ) # --------------------------------------------------------------------------- # Model paths (relative to repo root) # --------------------------------------------------------------------------- REPO_ROOT = Path(__file__).resolve().parent DEMO_DIR = REPO_ROOT / "demo_data" DEMO_RENDERS = REPO_ROOT / "demo_renders" EC_CHECKPOINT = REPO_ROOT / "checkpoints" / "ec_best.pt" TPD_CHECKPOINT = REPO_ROOT / "checkpoints" / "tpd_best.pt" # Allow override via environment variables EC_CHECKPOINT = Path(os.environ.get("ECFLOW_EC_CHECKPOINT", str(EC_CHECKPOINT))) TPD_CHECKPOINT = Path(os.environ.get("ECFLOW_TPD_CHECKPOINT", str(TPD_CHECKPOINT))) # --------------------------------------------------------------------------- # Demo examples # --------------------------------------------------------------------------- def _discover_examples(): """Scan demo_data/ for metadata files and build example catalogs.""" cv_examples = {} tpd_examples = {} if not DEMO_DIR.is_dir(): return cv_examples, tpd_examples for meta_path in sorted(DEMO_DIR.glob("*_metadata.json")): with open(meta_path) as f: meta = json.load(f) mech = meta["mechanism"] csv_files = [str(DEMO_DIR / fn) for fn in meta["csv_files"]] if meta_path.name.startswith("ec_"): rates = meta.get("scan_rates_Vs", []) rates_str = ", ".join(f"{r:.4g}" for r in rates) phys = meta.get("physical_params", {}) cv_examples[f"CV — {mech}"] = { "files": csv_files, "scan_rates": rates_str, "E0_V": phys.get("E0_V"), "T_K": phys.get("T_K", 298.15), "A_cm2": phys.get("A_cm2", 0.0707), "C_mM": phys.get("C_mM", 1.0), "D_cm2s": phys.get("D_cm2s", 1e-5), "n_electrons": phys.get("n_electrons", 1), } elif meta_path.name.startswith("tpd_"): betas = meta.get("betas_Ks", []) betas_str = ", ".join(f"{b:.4g}" for b in betas) tpd_examples[f"TPD — {mech}"] = { "files": csv_files, "heating_rates": betas_str, } return cv_examples, tpd_examples CV_EXAMPLES, TPD_EXAMPLES = _discover_examples() def _discover_image_examples(): """Build image example catalogs from demo_renders/ directory.""" import re cv_img_examples = {} tpd_img_examples = {} cv_output_renders = {} tpd_output_renders = {} if not DEMO_RENDERS.is_dir(): return cv_img_examples, tpd_img_examples, cv_output_renders, tpd_output_renders for p in sorted(DEMO_RENDERS.glob("ec_*_physical.png")): m = re.match(r"ec_(\w+?)_(\d+)mVs_physical\.png", p.name) if not m: continue mech, rate_mVs = m.group(1), int(m.group(2)) rate_Vs = rate_mVs / 1000.0 cv_img_examples.setdefault(mech, []).append((str(p), rate_Vs)) for p in sorted(DEMO_RENDERS.glob("tpd_*_physical.png")): m = re.match(r"tpd_(\w+?)_(\d+)_physical\.png", p.name) if not m: continue mech, idx = m.group(1), m.group(2) tpd_img_examples.setdefault(mech, []).append((str(p), idx)) for mech in cv_img_examples: renders = {} for suffix in ["classification", "posteriors", "reconstruction", "concentration"]: rp = DEMO_RENDERS / f"ec_{mech}_{suffix}.png" if rp.exists(): renders[suffix] = str(rp) if renders: cv_output_renders[mech] = renders for mech in tpd_img_examples: renders = {} for suffix in ["classification", "posteriors", "reconstruction"]: rp = DEMO_RENDERS / f"tpd_{mech}_{suffix}.png" if rp.exists(): renders[suffix] = str(rp) if renders: tpd_output_renders[mech] = renders return cv_img_examples, tpd_img_examples, cv_output_renders, tpd_output_renders (CV_IMG_EXAMPLES, TPD_IMG_EXAMPLES, CV_OUTPUT_RENDERS, TPD_OUTPUT_RENDERS) = _discover_image_examples() def _load_cv_image_example(mech_name): """Return (files, scan_rates_str) for a CV image example.""" if not mech_name or mech_name not in CV_IMG_EXAMPLES: return [gr.update()] * 2 entries = CV_IMG_EXAMPLES[mech_name] files = [e[0] for e in entries] rates_str = ", ".join(f"{e[1]}" for e in entries) return files, rates_str def _load_tpd_image_example(mech_name): """Return (image_path, heating_rate_str) for a TPD image example.""" if not mech_name or mech_name not in TPD_IMG_EXAMPLES: return [gr.update()] * 2 entries = TPD_IMG_EXAMPLES[mech_name] return entries[0][0], entries[0][1] def _load_cv_example(example_name): """Return (files, scan_rates, E0, T, A, C, D, n) for the chosen CV example.""" if not example_name or example_name not in CV_EXAMPLES: return [gr.update()] * 8 ex = CV_EXAMPLES[example_name] return ( ex["files"], ex["scan_rates"], ex["E0_V"], ex["T_K"], ex["A_cm2"], ex["C_mM"], ex["D_cm2s"], ex["n_electrons"], ) def _load_tpd_example(example_name): """Return (files, heating_rates) for the chosen TPD example.""" if not example_name or example_name not in TPD_EXAMPLES: return [gr.update()] * 2 ex = TPD_EXAMPLES[example_name] return ( ex["files"], ex["heating_rates"], ) predictor = None def get_predictor(): global predictor if predictor is None: ec_ckpt = str(EC_CHECKPOINT) if EC_CHECKPOINT.exists() else None tpd_ckpt = str(TPD_CHECKPOINT) if TPD_CHECKPOINT.exists() else None predictor = ECFlowPredictor( ec_checkpoint=ec_ckpt, tpd_checkpoint=tpd_ckpt, device="cpu", ) return predictor # ========================================================================= # CV Analysis # ========================================================================= def analyze_cv(files, scan_rates_text, E0_V, T_K, A_cm2, C_mM, D_cm2s, n_electrons, n_samples): """Analyze CV data from potentiostat CSV files. Accepts CSV files with columns for potential (V) and current (A/mA/µA). If the CSV includes a Time (s) column, the scan rate is auto-detected. Otherwise, scan rates must be provided. """ if not files: return _ec_error("Please upload at least one CSV file.") scan_rates_text = scan_rates_text.strip() if scan_rates_text else "" user_rates = None if scan_rates_text: try: user_rates = [float(s.strip()) for s in scan_rates_text.split(",")] except ValueError: return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.") if len(files) != len(user_rates): return _ec_error( f"Number of files ({len(files)}) must match number of " f"scan rates ({len(user_rates)}).") C_molcm3 = float(C_mM) * 1e-6 if C_mM else 1e-6 n = int(n_electrons) if n_electrons else 1 T = float(T_K) if T_K else 298.15 A = float(A_cm2) if A_cm2 else 0.0707 parsed_data = [] scan_rates = [] for idx, f in enumerate(files): content = Path(f.name).read_text() parsed = parse_cv_csv(content) parsed_data.append(parsed) if user_rates is not None: v = user_rates[idx] elif "scan_rate_Vs" in parsed: v = parsed["scan_rate_Vs"] else: return _ec_error( f"Cannot determine scan rate for file '{Path(f.name).name}'. " "Either provide scan rates (V/s) or upload CSV files that " "include a Time (s) column.") scan_rates.append(v) if E0_V: e0 = float(E0_V) e0_source = "user" else: e0_estimates = [estimate_E0(p["E_V"], p["i_A"]) for p in parsed_data] e0 = float(np.median(e0_estimates)) e0_source = "auto" D = float(D_cm2s) if D_cm2s else 1e-5 potentials, fluxes, sigmas_list = [], [], [] for idx, (parsed, v) in enumerate(zip(parsed_data, scan_rates)): E, i_A = parsed["E_V"], parsed["i_A"] theta, flux, sigma = nondimensionalize_cv( E, i_A, v, e0, T, A, C_molcm3, D, n ) potentials.append(theta) fluxes.append(flux) sigmas_list.append(sigma) return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples) _CURRENT_UNIT_SCALES = { "µA": 1e-6, "mA": 1e-3, "A": 1.0, "nA": 1e-9, } def _guess_current_unit(i_values): """Guess the current unit from the magnitude of digitized values.""" i_max = np.max(np.abs(i_values)) if i_max > 1e3: return "nA" if i_max > 100: return "µA" if i_max > 0.1: return "µA" if i_max > 1e-4: return "mA" return "A" def analyze_cv_image(files, scan_rate_text, E0_V, threshold, current_unit, n_samples, x_min, x_max, y_min, y_max): """Analyze CV from uploaded plot images (one per scan rate). Extracts CV curves via image digitization, then nondimensionalizes and runs inference identically to the CSV path. Axis bounds are auto-detected via OCR — override in Advanced if needed. """ if not files: return _ec_error("Please upload at least one image.") try: from digitizer import digitize_plot, auto_detect_axis_bounds from PIL import Image as PILImage except ImportError: return _ec_error("Required libraries not available for image digitization.") scan_rate_text = scan_rate_text.strip() if scan_rate_text else "" if not scan_rate_text: return _ec_error("Please enter the scan rate(s) (V/s), comma-separated.") try: scan_rates = [float(s.strip()) for s in scan_rate_text.split(",")] except ValueError: return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.") if len(files) != len(scan_rates): return _ec_error( f"Number of images ({len(files)}) must match number of " f"scan rates ({len(scan_rates)}).") has_user_bounds = all( v is not None and v != 0 for v in [x_min, x_max, y_min, y_max] ) D = 1e-5 T = 298.15 A = 0.0707 C_molcm3 = 1e-6 n = 1 # First pass: digitize all images and estimate E0 per image image_data = [] e0_estimates = [] for idx, f in enumerate(files): img_arr = np.array(PILImage.open(f.name).convert("RGB")) v_Vs = scan_rates[idx] if has_user_bounds: bounds = { "x_min": float(x_min), "x_max": float(x_max), "y_min": float(y_min), "y_max": float(y_max), } else: bounds = auto_detect_axis_bounds(img_arr) if bounds is None: return _ec_error( f"Could not auto-detect axis bounds for image {idx + 1}. " "Please enter E min, E max, I min, I max under " "'Axis overrides'.") try: E_V, I_raw = digitize_plot( img_arr, bounds["x_min"], bounds["x_max"], bounds["y_min"], bounds["y_max"], threshold=int(threshold), x_ticks=bounds.get("x_ticks"), y_ticks=bounds.get("y_ticks"), ) except Exception as exc: return _ec_error(f"Digitization failed for image {idx + 1}: {exc}") if current_unit and current_unit != "auto": i_unit = current_unit elif "y_unit" in bounds: i_unit = bounds["y_unit"] else: i_unit = _guess_current_unit(I_raw) i_scale = _CURRENT_UNIT_SCALES.get(i_unit, 1e-6) i_A = I_raw * i_scale e0_estimates.append(float(estimate_E0(E_V, i_A))) image_data.append((E_V, i_A, v_Vs, i_unit, bounds)) # Determine E0: user-provided or median of per-image estimates if E0_V is not None and E0_V != 0: e0 = float(E0_V) e0_source = "user" else: e0 = float(np.median(e0_estimates)) e0_source = "auto" # Second pass: nondimensionalize with shared E0 potentials, fluxes, sigmas_list = [], [], [] preproc_parts = [] for E_V, i_A, v_Vs, i_unit, bounds in image_data: theta, flux, sigma = nondimensionalize_cv( E_V, i_A, v_Vs, e0, T, A, C_molcm3, D, n ) potentials.append(theta) fluxes.append(flux) sigmas_list.append(sigma) preproc_parts.append( f"{v_Vs*1000:.1f} mV/s (σ={sigma:.2f}, " f"E=[{bounds['x_min']:.3f}, {bounds['x_max']:.3f}] V, " f"I=[{bounds['y_min']:.2f}, {bounds['y_max']:.2f}] {i_unit})" ) return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples) def _ec_error(msg=""): """Return empty outputs for EC error cases.""" return None, None, gr.Dropdown(choices=[], value=None), None, None, None, None def _run_ec_analysis(potentials, fluxes, sigmas, n_samples): """Core EC analysis: predict + reconstruct for top mechanism.""" pred = get_predictor() result = pred.predict_ec(potentials, fluxes, sigmas, n_samples=int(n_samples)) top_mech = result["predicted_mechanism"] recon = pred.reconstruct_ec(result, potentials, fluxes, sigmas) fig_probs = plot_mechanism_probs(result["mechanism_probs"], domain="ec") sorted_mechs = sorted(result["mechanism_probs"].items(), key=lambda x: -x[1]) mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs] state = { "result": result, "potentials": [p.tolist() for p in potentials], "fluxes": [f.tolist() for f in fluxes], "sigmas": sigmas, } fig_post, fig_table, fig_recon, fig_conc = _render_ec_mechanism( top_mech, result, recon, sigmas ) return ( fig_probs, state, gr.Dropdown(choices=mech_choices, value=mech_choices[0]), fig_post, fig_table, fig_recon, fig_conc, ) def _render_ec_mechanism(mech, result, recon, sigmas): """Render posteriors, param table, reconstruction, and concentration for one EC mechanism.""" stats = result["parameter_stats"].get(mech) samples = result["posterior_samples"].get(mech) fig_posteriors = None fig_table = None if stats and samples is not None: fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="ec") fig_table = plot_parameter_table(stats, mech) fig_recon = None fig_conc = None if recon is not None: scan_labels = [f"\u03c3 = {s:.2f}" for s in sigmas] if sigmas else None fig_recon = plot_reconstruction( recon["observed"], recon["reconstructed"], domain="ec", nrmses=recon.get("nrmse"), r2s=recon.get("r2"), scan_labels=scan_labels, ) conc_curves = recon.get("concentrations") if conc_curves: fig_conc = plot_concentration_profiles(conc_curves, scan_labels=scan_labels) return fig_posteriors, fig_table, fig_recon, fig_conc def _on_ec_mechanism_change(mech_choice, state): """Callback when user selects a different EC mechanism from the dropdown.""" if not state or not mech_choice: return None, None, None, None mech = mech_choice.split(" (")[0] result = state["result"] potentials = [np.array(p) for p in state["potentials"]] fluxes = [np.array(f) for f in state["fluxes"]] sigmas = state["sigmas"] pred = get_predictor() recon = pred.reconstruct_ec(result, potentials, fluxes, sigmas, mechanism=mech) return _render_ec_mechanism(mech, result, recon, sigmas) # ========================================================================= # TPD Analysis # ========================================================================= def analyze_tpd(files, heating_rates_text, n_samples): """Analyze TPD data.""" if not files: return _tpd_error("Please upload at least one CSV file.") temperatures, rates = [], [] csv_betas = [] for f in files: content = Path(f.name).read_text() parsed = parse_tpd_csv(content) temperatures.append(parsed["T_K"]) rates.append(parsed["signal"]) if "beta_Ks" in parsed: csv_betas.append(parsed["beta_Ks"]) heating_rates_text = heating_rates_text.strip() if heating_rates_text else "" if heating_rates_text: try: betas = [float(s.strip()) for s in heating_rates_text.split(",")] except ValueError: return _tpd_error("Invalid heating rates. Enter comma-separated numbers in K/s.") if len(files) != len(betas): return _tpd_error( f"Number of files ({len(files)}) must match heating rates ({len(betas)}).") elif len(csv_betas) == len(files): betas = csv_betas else: return _tpd_error( "Please enter the heating rate (β in K/s) for each file. " "This value is critical for correct inference. " "Alternatively, include a 'Time (s)' column in your CSV so β can be computed automatically.") return _run_tpd_analysis(temperatures, rates, betas, n_samples) def analyze_tpd_image(files, heating_rates_text, threshold, n_samples, x_min, x_max, y_min, y_max): """Analyze TPD from uploaded plot images (one per heating rate). Axis bounds are auto-detected via OCR — override in Advanced if needed. """ if not files: return _tpd_error("Please upload at least one image.") try: from digitizer import digitize_plot, auto_detect_axis_bounds from PIL import Image as PILImage except ImportError: return _tpd_error("Required libraries not available for image digitization.") heating_rates_text = heating_rates_text.strip() if heating_rates_text else "" if not heating_rates_text: return _tpd_error( "Please enter the heating rate(s) (β in K/s), comma-separated. " "This value is critical for correct inference.") try: betas = [float(s.strip()) for s in heating_rates_text.split(",")] except ValueError: return _tpd_error("Invalid heating rates. Enter comma-separated numbers in K/s.") if len(files) != len(betas): return _tpd_error( f"Number of images ({len(files)}) must match number of " f"heating rates ({len(betas)}).") has_user_bounds = all( v is not None and v != 0 for v in [x_min, x_max, y_min, y_max] ) temperatures, rates = [], [] for idx, f in enumerate(files): img_arr = np.array(PILImage.open(f.name).convert("RGB")) if has_user_bounds: bounds = { "x_min": float(x_min), "x_max": float(x_max), "y_min": float(y_min), "y_max": float(y_max), } else: bounds = auto_detect_axis_bounds(img_arr) if bounds is None: return _tpd_error( f"Could not auto-detect axis bounds for image {idx + 1}. " "Please enter T min, T max, Signal min, Signal max " "under 'Axis overrides'.") try: x_data, y_data = digitize_plot( img_arr, bounds["x_min"], bounds["x_max"], bounds["y_min"], bounds["y_max"], threshold=int(threshold), x_ticks=bounds.get("x_ticks"), y_ticks=bounds.get("y_ticks"), ) except Exception as exc: return _tpd_error(f"Digitization failed for image {idx + 1}: {exc}") temperatures.append(x_data) rates.append(y_data) return _run_tpd_analysis(temperatures, rates, betas, n_samples) def _tpd_error(msg=""): """Return empty outputs for TPD error cases.""" return None, None, gr.Dropdown(choices=[], value=None), None, None, None def _run_tpd_analysis(temperatures, rates, betas, n_samples): """Core TPD analysis: predict + reconstruct for top mechanism.""" pred = get_predictor() result = pred.predict_tpd(temperatures, rates, betas, n_samples=int(n_samples)) top_mech = result["predicted_mechanism"] recon = pred.reconstruct_tpd(result, temperatures, rates, betas) fig_probs = plot_mechanism_probs(result["mechanism_probs"], domain="tpd") sorted_mechs = sorted(result["mechanism_probs"].items(), key=lambda x: -x[1]) mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs] state = { "result": result, "temperatures": [t.tolist() for t in temperatures], "rates": [r.tolist() for r in rates], "betas": betas, } fig_post, fig_table, fig_recon = _render_tpd_mechanism(top_mech, result, recon, betas) return ( fig_probs, state, gr.Dropdown(choices=mech_choices, value=mech_choices[0]), fig_post, fig_table, fig_recon, ) def _render_tpd_mechanism(mech, result, recon, betas): """Render posteriors, param table, and reconstruction for one TPD mechanism.""" stats = result["parameter_stats"].get(mech) samples = result["posterior_samples"].get(mech) fig_posteriors = None fig_table = None if stats and samples is not None: fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="tpd") fig_table = plot_parameter_table(stats, mech) fig_recon = None if recon is not None: scan_labels = [f"\u03b2 = {b:.2f} K/s" for b in betas] if betas else None fig_recon = plot_reconstruction( recon["observed"], recon["reconstructed"], domain="tpd", nrmses=recon.get("nrmse"), r2s=recon.get("r2"), scan_labels=scan_labels, ) return fig_posteriors, fig_table, fig_recon def _on_tpd_mechanism_change(mech_choice, state): """Callback when user selects a different TPD mechanism from the dropdown.""" if not state or not mech_choice: return None, None, None mech = mech_choice.split(" (")[0] result = state["result"] temperatures = [np.array(t) for t in state["temperatures"]] rates = [np.array(r) for r in state["rates"]] betas = state["betas"] pred = get_predictor() recon = pred.reconstruct_tpd(result, temperatures, rates, betas, mechanism=mech) return _render_tpd_mechanism(mech, result, recon, betas) # ========================================================================= # Shared helpers # ========================================================================= def download_results(result_text): """Create a downloadable JSON from the summary.""" if not result_text: return None tmp = tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") tmp.write(result_text) tmp.close() return tmp.name # ========================================================================= # Gradio UI # ========================================================================= def _build_ec_output_section(prefix): """Build shared output components for one EC input tab. Returns (probs, state, mech_dd, posteriors, param_table, recon, conc). """ gr.Markdown("---") probs = gr.Plot(label="Mechanism Classification") state = gr.State(value=None) mech_dd = gr.Dropdown( label="Select mechanism to view details", choices=[], interactive=True, ) with gr.Row(): posteriors = gr.Plot(label="Parameter Posteriors") param_table = gr.Plot(label="Parameter Estimates") recon = gr.Plot(label="Signal Reconstruction") conc = gr.Plot(label="Surface Concentration Profiles") return probs, state, mech_dd, posteriors, param_table, recon, conc def _build_tpd_output_section(prefix): """Build shared output components for one TPD input tab. Returns (probs, state, mech_dd, posteriors, param_table, recon). """ gr.Markdown("---") probs = gr.Plot(label="Mechanism Classification") state = gr.State(value=None) mech_dd = gr.Dropdown( label="Select mechanism to view details", choices=[], interactive=True, ) with gr.Row(): posteriors = gr.Plot(label="Parameter Posteriors") param_table = gr.Plot(label="Parameter Estimates") recon = gr.Plot(label="Signal Reconstruction") return probs, state, mech_dd, posteriors, param_table, recon CUSTOM_CSS = """ .main-header { text-align: center; padding: 24px 16px 8px 16px; } .main-header h1 { font-size: 2.2em; margin-bottom: 2px; letter-spacing: -0.5px; } .main-header p { color: #6B7280; font-size: 1.05em; max-width: 720px; margin: 0 auto; line-height: 1.5; } .section-heading { margin-top: 24px !important; margin-bottom: 4px !important; } .preproc-card { background: #F0F9FF; border: 1px solid #BAE6FD; border-radius: 8px; padding: 10px 16px; margin-top: 12px; font-size: 0.93em; color: #0C4A6E; } .summary-card { border: 1px solid #E5E7EB; border-radius: 10px; padding: 16px 20px; background: #FAFBFC; } .summary-card table { width: 100%; font-size: 0.92em; } .summary-card td, .summary-card th { padding: 4px 8px; } footer { display: none !important; } """ def build_app(): with gr.Blocks( title="ECFlow — Bayesian Inference for Electrochemistry & Catalysis", theme=gr.themes.Soft( primary_hue="blue", secondary_hue="slate", font=gr.themes.GoogleFont("Inter"), ), css=CUSTOM_CSS, ) as app: gr.HTML( "
Upload cyclic voltammetry or TPD data to identify the reaction mechanism " "and infer kinetic parameters with full Bayesian uncertainty.
" "