| """ |
| ECFlow — Amortized Bayesian Inference for Electrochemistry & Catalysis |
| |
| Gradio web interface for mechanism classification and parameter inference |
| from cyclic voltammetry (CV) and temperature-programmed desorption (TPD) data. |
| """ |
|
|
| import os |
| import sys |
| import json |
| import tempfile |
| from pathlib import Path |
|
|
| import numpy as np |
| import gradio as gr |
|
|
|
|
|
|
| from inference import ECFlowPredictor |
| from preprocessing import ( |
| nondimensionalize_cv, |
| estimate_E0, |
| parse_cv_csv, |
| parse_tpd_csv, |
| ) |
| from plotting import ( |
| plot_mechanism_probs, |
| plot_posteriors, |
| plot_parameter_table, |
| plot_reconstruction, |
| plot_concentration_profiles, |
| ) |
|
|
| |
| |
| |
# Filesystem layout, resolved relative to this file so the app does not depend
# on the current working directory.
REPO_ROOT = Path(__file__).resolve().parent
DEMO_DIR = REPO_ROOT / "demo_data"
DEMO_RENDERS = REPO_ROOT / "demo_renders"

# Model checkpoints. Each default can be overridden through an environment
# variable. `or` is used instead of a `.get()` default so that a variable that
# is set but empty falls back to the default rather than becoming Path(""),
# which resolves to "." and would pass a later `.exists()` check.
EC_CHECKPOINT = Path(
    os.environ.get("ECFLOW_EC_CHECKPOINT")
    or REPO_ROOT / "checkpoints" / "ec_best.pt"
)
TPD_CHECKPOINT = Path(
    os.environ.get("ECFLOW_TPD_CHECKPOINT")
    or REPO_ROOT / "checkpoints" / "tpd_best.pt"
)
|
|
| |
| |
| |
|
|
def _discover_examples():
    """Scan ``DEMO_DIR`` for ``*_metadata.json`` files and build example catalogs.

    Metadata files whose name starts with ``ec_`` become CV examples and those
    starting with ``tpd_`` become TPD examples; any other prefix is ignored.
    Each metadata file must contain ``"mechanism"`` and ``"csv_files"`` keys
    (a missing key raises ``KeyError``).

    Returns:
        ``(cv_examples, tpd_examples)``: two dicts keyed by display label.
        CV entries hold the CSV paths, a comma-separated scan-rate string and
        the physical parameters (with defaults for everything except
        ``E0_V``). TPD entries hold the CSV paths and a comma-separated
        heating-rate string. Both dicts are empty when ``DEMO_DIR`` is missing.
    """
    cv_examples = {}
    tpd_examples = {}
    if not DEMO_DIR.is_dir():
        return cv_examples, tpd_examples
    for meta_path in sorted(DEMO_DIR.glob("*_metadata.json")):
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding when reading it.
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)
        mech = meta["mechanism"]
        csv_files = [str(DEMO_DIR / fn) for fn in meta["csv_files"]]
        if meta_path.name.startswith("ec_"):
            rates = meta.get("scan_rates_Vs", [])
            rates_str = ", ".join(f"{r:.4g}" for r in rates)
            phys = meta.get("physical_params", {})
            cv_examples[f"CV — {mech}"] = {
                "files": csv_files,
                "scan_rates": rates_str,
                "E0_V": phys.get("E0_V"),
                "T_K": phys.get("T_K", 298.15),
                "A_cm2": phys.get("A_cm2", 0.0707),
                "C_mM": phys.get("C_mM", 1.0),
                "D_cm2s": phys.get("D_cm2s", 1e-5),
                "n_electrons": phys.get("n_electrons", 1),
            }
        elif meta_path.name.startswith("tpd_"):
            betas = meta.get("betas_Ks", [])
            betas_str = ", ".join(f"{b:.4g}" for b in betas)
            tpd_examples[f"TPD — {mech}"] = {
                "files": csv_files,
                "heating_rates": betas_str,
            }
    return cv_examples, tpd_examples


# Built once at import time; the UI only offers examples when these are non-empty.
CV_EXAMPLES, TPD_EXAMPLES = _discover_examples()
|
|
|
|
def _discover_image_examples():
    """Index the pre-rendered example images found in ``DEMO_RENDERS``.

    Returns:
        A 4-tuple ``(cv_inputs, tpd_inputs, cv_renders, tpd_renders)``:

        * ``cv_inputs``: mechanism -> list of ``(image_path, scan_rate_V_per_s)``
          parsed from ``ec_<mech>_<N>mVs_physical.png``.
        * ``tpd_inputs``: mechanism -> list of ``(image_path, index_string)``
          parsed from ``tpd_<mech>_<N>_physical.png``.
        * ``cv_renders`` / ``tpd_renders``: mechanism -> ``{suffix: path}`` for
          the output renders that exist on disk; mechanisms without any render
          are omitted.

        All four dicts are empty when the directory does not exist.
    """
    import re

    cv_inputs, tpd_inputs = {}, {}
    cv_renders, tpd_renders = {}, {}

    if not DEMO_RENDERS.is_dir():
        return cv_inputs, tpd_inputs, cv_renders, tpd_renders

    def _existing_renders(prefix, mech, suffixes):
        # Map each suffix to its render path, keeping only files that exist.
        found = {}
        for suffix in suffixes:
            candidate = DEMO_RENDERS / f"{prefix}_{mech}_{suffix}.png"
            if candidate.exists():
                found[suffix] = str(candidate)
        return found

    # Input images: the file name encodes the mechanism and the scan rate (mV/s).
    for image in sorted(DEMO_RENDERS.glob("ec_*_physical.png")):
        hit = re.match(r"ec_(\w+?)_(\d+)mVs_physical\.png", image.name)
        if hit:
            mechanism = hit.group(1)
            rate_mVs = int(hit.group(2))
            cv_inputs.setdefault(mechanism, []).append(
                (str(image), rate_mVs / 1000.0)
            )

    # Input images: the file name encodes the mechanism and a numeric token.
    for image in sorted(DEMO_RENDERS.glob("tpd_*_physical.png")):
        hit = re.match(r"tpd_(\w+?)_(\d+)_physical\.png", image.name)
        if hit:
            mechanism, token = hit.group(1), hit.group(2)
            tpd_inputs.setdefault(mechanism, []).append((str(image), token))

    for mechanism in cv_inputs:
        found = _existing_renders(
            "ec", mechanism,
            ["classification", "posteriors", "reconstruction", "concentration"],
        )
        if found:
            cv_renders[mechanism] = found

    for mechanism in tpd_inputs:
        found = _existing_renders(
            "tpd", mechanism,
            ["classification", "posteriors", "reconstruction"],
        )
        if found:
            tpd_renders[mechanism] = found

    return cv_inputs, tpd_inputs, cv_renders, tpd_renders


# Built once at import time from whatever is present in DEMO_RENDERS.
(CV_IMG_EXAMPLES, TPD_IMG_EXAMPLES,
 CV_OUTPUT_RENDERS, TPD_OUTPUT_RENDERS) = _discover_image_examples()
|
|
|
|
def _load_cv_image_example(mech_name):
    """Look up the CV image example for *mech_name*.

    Returns ``(image_paths, scan_rates_str)`` where the rates are joined with
    ``", "`` in the same order as the paths. For an empty or unknown name, two
    no-op ``gr.update()`` objects are returned so the targets stay unchanged.
    """
    if mech_name and mech_name in CV_IMG_EXAMPLES:
        paths = []
        rate_labels = []
        for path, rate in CV_IMG_EXAMPLES[mech_name]:
            paths.append(path)
            rate_labels.append(f"{rate}")
        return paths, ", ".join(rate_labels)
    return [gr.update()] * 2
|
|
|
|
def _load_tpd_image_example(mech_name):
    """Look up the first TPD image example for *mech_name*.

    Returns the first entry's ``(image_path, token)``; only the first image of
    a mechanism is used. For an empty or unknown name, two no-op
    ``gr.update()`` objects are returned.

    NOTE(review): ``token`` is the digit group captured from the file name by
    ``_discover_image_examples``; confirm it really encodes the heating rate.
    """
    known = bool(mech_name) and mech_name in TPD_IMG_EXAMPLES
    if not known:
        return [gr.update()] * 2
    image_path, token = TPD_IMG_EXAMPLES[mech_name][0][:2]
    return image_path, token
|
|
|
|
def _load_cv_example(example_name):
    """Fetch the form values for the selected CV CSV example.

    Returns an 8-tuple ``(files, scan_rates, E0, T, A, C, D, n)`` taken from
    ``CV_EXAMPLES``. For an empty or unknown name, eight no-op ``gr.update()``
    objects are returned so every target component is left untouched.
    """
    if example_name and example_name in CV_EXAMPLES:
        example = CV_EXAMPLES[example_name]
        field_order = (
            "files", "scan_rates", "E0_V", "T_K",
            "A_cm2", "C_mM", "D_cm2s", "n_electrons",
        )
        return tuple(example[field] for field in field_order)
    return [gr.update()] * 8
|
|
|
|
def _load_tpd_example(example_name):
    """Fetch the form values for the selected TPD CSV example.

    Returns ``(files, heating_rates)`` taken from ``TPD_EXAMPLES``. For an
    empty or unknown name, two no-op ``gr.update()`` objects are returned.
    """
    if example_name and example_name in TPD_EXAMPLES:
        example = TPD_EXAMPLES[example_name]
        return example["files"], example["heating_rates"]
    return [gr.update()] * 2
|
|
|
|
# Process-wide predictor instance, created lazily by get_predictor().
predictor = None


def get_predictor():
    """Return the shared ``ECFlowPredictor``, constructing it on first use.

    A checkpoint path is passed to the predictor only if the file exists;
    otherwise ``None`` is passed for that domain. The predictor is always
    created with ``device="cpu"``.
    """
    global predictor
    if predictor is not None:
        return predictor

    def _path_or_none(checkpoint):
        # Hand the predictor a string path only for checkpoints present on disk.
        return str(checkpoint) if checkpoint.exists() else None

    predictor = ECFlowPredictor(
        ec_checkpoint=_path_or_none(EC_CHECKPOINT),
        tpd_checkpoint=_path_or_none(TPD_CHECKPOINT),
        device="cpu",
    )
    return predictor
|
|
|
|
| |
| |
| |
|
|
def analyze_cv(files, scan_rates_text, E0_V, T_K, A_cm2,
               C_mM, D_cm2s, n_electrons, n_samples):
    """Analyze CV data from potentiostat CSV files.

    Accepts CSV files with columns for potential (V) and current (A/mA/µA).
    If the CSV includes a Time (s) column, the scan rate is auto-detected.
    Otherwise, scan rates must be provided.

    Args:
        files: Uploaded files; each item must expose its path as ``.name``.
        scan_rates_text: Optional comma-separated scan rates in V/s. When
            given, there must be exactly one per file and they take precedence
            over any rate parsed from the CSV.
        E0_V: Formal potential in volts, or ``None`` to use the median of the
            per-file ``estimate_E0`` values. An explicit ``0`` is honoured.
        T_K, A_cm2, C_mM, D_cm2s, n_electrons: Physical parameters. An empty
            or zero value falls back to 298.15, 0.0707, 1.0 (mM), 1e-5 and 1
            respectively.
        n_samples: Number of posterior samples forwarded to the predictor.

    Returns:
        The 7-tuple of output updates produced by ``_run_ec_analysis``, or the
        cleared outputs from ``_ec_error`` when validation fails.

    NOTE(review): the UI hint for D mentions a Randles-Ševčík estimate, but
    this function simply falls back to 1e-5 when D is empty.
    """
    if not files:
        return _ec_error("Please upload at least one CSV file.")

    scan_rates_text = scan_rates_text.strip() if scan_rates_text else ""

    user_rates = None
    if scan_rates_text:
        try:
            user_rates = [float(s.strip()) for s in scan_rates_text.split(",")]
        except ValueError:
            return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.")
        if len(files) != len(user_rates):
            return _ec_error(
                f"Number of files ({len(files)}) must match number of "
                f"scan rates ({len(user_rates)}).")

    # mM -> mol/cm^3 is a factor of 1e-6; the fallback corresponds to 1 mM.
    C_molcm3 = float(C_mM) * 1e-6 if C_mM else 1e-6
    n = int(n_electrons) if n_electrons else 1
    T = float(T_K) if T_K else 298.15
    A = float(A_cm2) if A_cm2 else 0.0707
    D = float(D_cm2s) if D_cm2s else 1e-5

    parsed_data = []
    scan_rates = []
    for idx, f in enumerate(files):
        content = Path(f.name).read_text()
        parsed = parse_cv_csv(content)
        parsed_data.append(parsed)

        # User-entered rates win; otherwise use the rate the parser derived.
        if user_rates is not None:
            v = user_rates[idx]
        elif "scan_rate_Vs" in parsed:
            v = parsed["scan_rate_Vs"]
        else:
            return _ec_error(
                f"Cannot determine scan rate for file '{Path(f.name).name}'. "
                "Either provide scan rates (V/s) or upload CSV files that "
                "include a Time (s) column.")
        scan_rates.append(v)

    # Test against None, not truthiness: 0.0 V is a legitimate formal
    # potential and must not silently trigger auto-estimation.
    if E0_V is not None:
        e0 = float(E0_V)
    else:
        e0_estimates = [estimate_E0(p["E_V"], p["i_A"]) for p in parsed_data]
        e0 = float(np.median(e0_estimates))

    potentials, fluxes, sigmas_list = [], [], []
    for parsed, v in zip(parsed_data, scan_rates):
        theta, flux, sigma = nondimensionalize_cv(
            parsed["E_V"], parsed["i_A"], v, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)

    return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples)
|
|
|
|
| _CURRENT_UNIT_SCALES = { |
| "µA": 1e-6, |
| "mA": 1e-3, |
| "A": 1.0, |
| "nA": 1e-9, |
| } |
|
|
|
|
| def _guess_current_unit(i_values): |
| """Guess the current unit from the magnitude of digitized values.""" |
| i_max = np.max(np.abs(i_values)) |
| if i_max > 1e3: |
| return "nA" |
| if i_max > 100: |
| return "µA" |
| if i_max > 0.1: |
| return "µA" |
| if i_max > 1e-4: |
| return "mA" |
| return "A" |
|
|
|
|
def analyze_cv_image(files, scan_rate_text, E0_V, threshold, current_unit,
                     n_samples, x_min, x_max, y_min, y_max):
    """Analyze CV from uploaded plot images (one per scan rate).

    Extracts CV curves via image digitization, converts the current to
    amperes, then nondimensionalizes and runs inference through
    ``_run_ec_analysis`` exactly like the CSV path.

    Args:
        files: Uploaded images; each item must expose its path as ``.name``.
        scan_rate_text: Comma-separated scan rates in V/s, one per image
            (required).
        E0_V: Formal potential in volts. ``None`` or ``0`` means "estimate":
            the median of the per-image ``estimate_E0`` values is used.
        threshold: Binarization threshold, forwarded to ``digitize_plot`` as
            an int.
        current_unit: A key of ``_CURRENT_UNIT_SCALES``, or ``"auto"``/empty to
            use the detected ``y_unit`` if present, else a magnitude guess.
        n_samples: Number of posterior samples forwarded to the predictor.
        x_min, x_max, y_min, y_max: Optional manual axis bounds, applied to
            every image. They are used only when all four are supplied and
            each axis spans a non-zero range; otherwise bounds are
            auto-detected per image.

    Returns:
        The 7-tuple of output updates produced by ``_run_ec_analysis``, or the
        cleared outputs from ``_ec_error`` on any validation or digitization
        failure.
    """
    if not files:
        return _ec_error("Please upload at least one image.")

    # Optional dependencies: imported lazily so the rest of the app still
    # works when the digitization stack is not installed.
    try:
        from digitizer import digitize_plot, auto_detect_axis_bounds
        from PIL import Image as PILImage
    except ImportError:
        return _ec_error("Required libraries not available for image digitization.")

    scan_rate_text = scan_rate_text.strip() if scan_rate_text else ""
    if not scan_rate_text:
        return _ec_error("Please enter the scan rate(s) (V/s), comma-separated.")
    try:
        scan_rates = [float(s.strip()) for s in scan_rate_text.split(",")]
    except ValueError:
        return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.")

    if len(files) != len(scan_rates):
        return _ec_error(
            f"Number of images ({len(files)}) must match number of "
            f"scan rates ({len(scan_rates)}).")

    # A single bound of 0 is a valid axis limit, so "not provided" is detected
    # by None or by a degenerate (zero-width) range rather than by any bound
    # being zero. All-zero fields still mean "auto-detect".
    has_user_bounds = (
        all(v is not None for v in (x_min, x_max, y_min, y_max))
        and x_min != x_max
        and y_min != y_max
    )

    # The image workflow takes no physical-parameter inputs; these match the
    # fallback defaults used by the CSV path.
    D = 1e-5
    T = 298.15
    A = 0.0707
    C_molcm3 = 1e-6
    n = 1

    # Pass 1: digitize every image and collect per-image E0 estimates.
    image_data = []
    e0_estimates = []
    for idx, f in enumerate(files):
        # Context manager releases the underlying file handle promptly.
        with PILImage.open(f.name) as img:
            img_arr = np.array(img.convert("RGB"))
        v_Vs = scan_rates[idx]

        if has_user_bounds:
            bounds = {
                "x_min": float(x_min), "x_max": float(x_max),
                "y_min": float(y_min), "y_max": float(y_max),
            }
        else:
            bounds = auto_detect_axis_bounds(img_arr)
            if bounds is None:
                return _ec_error(
                    f"Could not auto-detect axis bounds for image {idx + 1}. "
                    "Please enter E min, E max, I min, I max under "
                    "'Axis overrides'.")

        try:
            E_V, I_raw = digitize_plot(
                img_arr, bounds["x_min"], bounds["x_max"],
                bounds["y_min"], bounds["y_max"],
                threshold=int(threshold),
                x_ticks=bounds.get("x_ticks"),
                y_ticks=bounds.get("y_ticks"),
            )
        except Exception as exc:
            # Deliberately broad: any digitizer failure is reported per image.
            return _ec_error(f"Digitization failed for image {idx + 1}: {exc}")

        # Unit precedence: explicit user choice > detected unit > magnitude guess.
        if current_unit and current_unit != "auto":
            i_unit = current_unit
        elif "y_unit" in bounds:
            i_unit = bounds["y_unit"]
        else:
            i_unit = _guess_current_unit(I_raw)

        # Unknown unit strings fall back to the µA scale (1e-6).
        i_scale = _CURRENT_UNIT_SCALES.get(i_unit, 1e-6)
        i_A = I_raw * i_scale

        e0_estimates.append(float(estimate_E0(E_V, i_A)))
        image_data.append((E_V, i_A, v_Vs))

    # A shared E0 is needed before any curve can be nondimensionalized.
    if E0_V is not None and E0_V != 0:
        e0 = float(E0_V)
    else:
        e0 = float(np.median(e0_estimates))

    # Pass 2: nondimensionalize every curve against the shared E0.
    potentials, fluxes, sigmas_list = [], [], []
    for E_V, i_A, v_Vs in image_data:
        theta, flux, sigma = nondimensionalize_cv(
            E_V, i_A, v_Vs, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)

    return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples)
|
|
|
|
def _ec_error(msg=""):
    """Report *msg* to the user and return cleared outputs for the EC tabs.

    A non-empty message is shown as a Gradio warning toast; previously it was
    discarded, leaving the user with blank outputs and no explanation.

    Returns:
        A 7-tuple matching the EC output components: every plot and the state
        are ``None`` and the mechanism dropdown is reset to no choices.
    """
    if msg:
        gr.Warning(msg)
    return None, None, gr.Dropdown(choices=[], value=None), None, None, None, None
|
|
|
|
def _run_ec_analysis(potentials, fluxes, sigmas, n_samples):
    """Run EC inference and build all output updates for the top mechanism.

    Calls the shared predictor for classification and posterior sampling,
    reconstructs the signal for the predicted mechanism, and renders the
    figures.

    Returns:
        ``(fig_probs, state, dropdown, fig_post, fig_table, fig_recon,
        fig_conc)``. ``state`` holds the raw result plus list copies of the
        inputs, so the mechanism dropdown callback can re-run reconstruction
        later. The dropdown lists every mechanism as ``"name (probability)"``,
        sorted by descending probability, with the first one selected.
    """
    engine = get_predictor()
    result = engine.predict_ec(potentials, fluxes, sigmas, n_samples=int(n_samples))
    best_mechanism = result["predicted_mechanism"]
    reconstruction = engine.reconstruct_ec(result, potentials, fluxes, sigmas)

    probabilities = result["mechanism_probs"]
    fig_probs = plot_mechanism_probs(probabilities, domain="ec")

    ranked = sorted(probabilities.items(), key=lambda item: -item[1])
    labels = [f"{name} ({prob:.1%})" for name, prob in ranked]

    # Arrays are stored as plain lists so the state holds only basic types.
    state = {
        "result": result,
        "potentials": [arr.tolist() for arr in potentials],
        "fluxes": [arr.tolist() for arr in fluxes],
        "sigmas": sigmas,
    }

    post_fig, table_fig, recon_fig, conc_fig = _render_ec_mechanism(
        best_mechanism, result, reconstruction, sigmas
    )

    dropdown = gr.Dropdown(choices=labels, value=labels[0])
    return fig_probs, state, dropdown, post_fig, table_fig, recon_fig, conc_fig
|
|
|
|
def _render_ec_mechanism(mech, result, recon, sigmas):
    """Build the four detail figures for one EC mechanism.

    Returns ``(posteriors, table, reconstruction, concentration)``. The first
    two are ``None`` unless both parameter stats and posterior samples exist
    for *mech*. The last two are ``None`` when *recon* is ``None``, and the
    concentration figure is also ``None`` when *recon* carries no (or empty)
    ``"concentrations"``.
    """
    stats = result["parameter_stats"].get(mech)
    samples = result["posterior_samples"].get(mech)

    fig_posteriors = fig_table = None
    if stats and samples is not None:
        fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="ec")
        fig_table = plot_parameter_table(stats, mech)

    if recon is None:
        return fig_posteriors, fig_table, None, None

    # One label per scan, showing its dimensionless scan rate.
    labels = None
    if sigmas:
        labels = [f"\u03c3 = {s:.2f}" for s in sigmas]

    fig_recon = plot_reconstruction(
        recon["observed"], recon["reconstructed"], domain="ec",
        nrmses=recon.get("nrmse"), r2s=recon.get("r2"),
        scan_labels=labels,
    )

    fig_conc = None
    profiles = recon.get("concentrations")
    if profiles:
        fig_conc = plot_concentration_profiles(profiles, scan_labels=labels)

    return fig_posteriors, fig_table, fig_recon, fig_conc
|
|
|
|
| def _on_ec_mechanism_change(mech_choice, state): |
| """Callback when user selects a different EC mechanism from the dropdown.""" |
| if not state or not mech_choice: |
| return None, None, None, None |
|
|
| mech = mech_choice.split(" (")[0] |
| result = state["result"] |
| potentials = [np.array(p) for p in state["potentials"]] |
| fluxes = [np.array(f) for f in state["fluxes"]] |
| sigmas = state["sigmas"] |
|
|
| pred = get_predictor() |
| recon = pred.reconstruct_ec(result, potentials, fluxes, sigmas, mechanism=mech) |
| return _render_ec_mechanism(mech, result, recon, sigmas) |
|
|
|
|
| |
| |
| |
|
|
def analyze_tpd(files, heating_rates_text, n_samples):
    """Analyze TPD data from uploaded CSV files (one per heating rate).

    Heating rates come from *heating_rates_text* when it is non-empty (one
    value per file, comma-separated, in K/s). Otherwise the rates parsed from
    the CSV files are used, provided every file supplied one.

    Returns:
        The 6-tuple of output updates from ``_run_tpd_analysis``, or the
        cleared outputs from ``_tpd_error`` when validation fails.
    """
    if not files:
        return _tpd_error("Please upload at least one CSV file.")

    temperatures, rates, csv_betas = [], [], []
    for upload in files:
        parsed = parse_tpd_csv(Path(upload.name).read_text())
        temperatures.append(parsed["T_K"])
        rates.append(parsed["signal"])
        if "beta_Ks" in parsed:
            csv_betas.append(parsed["beta_Ks"])

    heating_rates_text = heating_rates_text.strip() if heating_rates_text else ""
    if not heating_rates_text:
        # No user input: only acceptable if every CSV carried its own rate.
        if len(csv_betas) != len(files):
            return _tpd_error(
                "Please enter the heating rate (β in K/s) for each file. "
                "This value is critical for correct inference. "
                "Alternatively, include a 'Time (s)' column in your CSV so β can be computed automatically.")
        betas = csv_betas
    else:
        try:
            betas = [float(token.strip()) for token in heating_rates_text.split(",")]
        except ValueError:
            return _tpd_error("Invalid heating rates. Enter comma-separated numbers in K/s.")
        if len(files) != len(betas):
            return _tpd_error(
                f"Number of files ({len(files)}) must match heating rates ({len(betas)}).")

    return _run_tpd_analysis(temperatures, rates, betas, n_samples)
|
|
|
|
def analyze_tpd_image(files, heating_rates_text, threshold, n_samples,
                      x_min, x_max, y_min, y_max):
    """Analyze TPD from uploaded plot images (one per heating rate).

    Each image is digitized into an (x, y) curve, which is passed unchanged to
    ``_run_tpd_analysis`` as temperature and signal.

    Args:
        files: Uploaded images; each item must expose its path as ``.name``.
        heating_rates_text: Comma-separated heating rates β in K/s, one per
            image (required).
        threshold: Binarization threshold, forwarded to ``digitize_plot`` as
            an int.
        n_samples: Number of posterior samples forwarded to the predictor.
        x_min, x_max, y_min, y_max: Optional manual axis bounds, applied to
            every image. They are used only when all four are supplied and
            each axis spans a non-zero range; otherwise bounds are
            auto-detected per image.

    Returns:
        The 6-tuple of output updates from ``_run_tpd_analysis``, or the
        cleared outputs from ``_tpd_error`` on any validation or digitization
        failure.
    """
    if not files:
        return _tpd_error("Please upload at least one image.")

    # Optional dependencies: imported lazily so the rest of the app still
    # works when the digitization stack is not installed.
    try:
        from digitizer import digitize_plot, auto_detect_axis_bounds
        from PIL import Image as PILImage
    except ImportError:
        return _tpd_error("Required libraries not available for image digitization.")

    heating_rates_text = heating_rates_text.strip() if heating_rates_text else ""
    if not heating_rates_text:
        return _tpd_error(
            "Please enter the heating rate(s) (β in K/s), comma-separated. "
            "This value is critical for correct inference.")
    try:
        betas = [float(s.strip()) for s in heating_rates_text.split(",")]
    except ValueError:
        return _tpd_error("Invalid heating rates. Enter comma-separated numbers in K/s.")

    if len(files) != len(betas):
        return _tpd_error(
            f"Number of images ({len(files)}) must match number of "
            f"heating rates ({len(betas)}).")

    # A signal minimum of 0 is the common case, so "not provided" is detected
    # by None or by a degenerate (zero-width) range rather than by any bound
    # being zero. All-zero fields still mean "auto-detect".
    has_user_bounds = (
        all(v is not None for v in (x_min, x_max, y_min, y_max))
        and x_min != x_max
        and y_min != y_max
    )

    temperatures, rates = [], []
    for idx, f in enumerate(files):
        # Context manager releases the underlying file handle promptly.
        with PILImage.open(f.name) as img:
            img_arr = np.array(img.convert("RGB"))

        if has_user_bounds:
            bounds = {
                "x_min": float(x_min), "x_max": float(x_max),
                "y_min": float(y_min), "y_max": float(y_max),
            }
        else:
            bounds = auto_detect_axis_bounds(img_arr)
            if bounds is None:
                return _tpd_error(
                    f"Could not auto-detect axis bounds for image {idx + 1}. "
                    "Please enter T min, T max, Signal min, Signal max "
                    "under 'Axis overrides'.")

        try:
            x_data, y_data = digitize_plot(
                img_arr, bounds["x_min"], bounds["x_max"],
                bounds["y_min"], bounds["y_max"],
                threshold=int(threshold),
                x_ticks=bounds.get("x_ticks"),
                y_ticks=bounds.get("y_ticks"),
            )
        except Exception as exc:
            # Deliberately broad: any digitizer failure is reported per image.
            return _tpd_error(f"Digitization failed for image {idx + 1}: {exc}")

        temperatures.append(x_data)
        rates.append(y_data)

    return _run_tpd_analysis(temperatures, rates, betas, n_samples)
|
|
|
|
def _tpd_error(msg=""):
    """Report *msg* to the user and return cleared outputs for the TPD tabs.

    A non-empty message is shown as a Gradio warning toast; previously it was
    discarded, leaving the user with blank outputs and no explanation.

    Returns:
        A 6-tuple matching the TPD output components: every plot and the state
        are ``None`` and the mechanism dropdown is reset to no choices.
    """
    if msg:
        gr.Warning(msg)
    return None, None, gr.Dropdown(choices=[], value=None), None, None, None
|
|
|
|
def _run_tpd_analysis(temperatures, rates, betas, n_samples):
    """Run TPD inference and build all output updates for the top mechanism.

    Returns:
        ``(fig_probs, state, dropdown, fig_post, fig_table, fig_recon)``.
        ``state`` holds the raw result plus list copies of the inputs, so the
        mechanism dropdown callback can re-run reconstruction later. The
        dropdown lists every mechanism as ``"name (probability)"``, sorted by
        descending probability, with the first one selected.
    """
    engine = get_predictor()
    result = engine.predict_tpd(temperatures, rates, betas, n_samples=int(n_samples))
    best_mechanism = result["predicted_mechanism"]
    reconstruction = engine.reconstruct_tpd(result, temperatures, rates, betas)

    probabilities = result["mechanism_probs"]
    fig_probs = plot_mechanism_probs(probabilities, domain="tpd")

    ranked = sorted(probabilities.items(), key=lambda item: -item[1])
    labels = [f"{name} ({prob:.1%})" for name, prob in ranked]

    # Arrays are stored as plain lists so the state holds only basic types.
    state = {
        "result": result,
        "temperatures": [arr.tolist() for arr in temperatures],
        "rates": [arr.tolist() for arr in rates],
        "betas": betas,
    }

    post_fig, table_fig, recon_fig = _render_tpd_mechanism(
        best_mechanism, result, reconstruction, betas
    )

    dropdown = gr.Dropdown(choices=labels, value=labels[0])
    return fig_probs, state, dropdown, post_fig, table_fig, recon_fig
|
|
|
|
def _render_tpd_mechanism(mech, result, recon, betas):
    """Build the three detail figures for one TPD mechanism.

    Returns ``(posteriors, table, reconstruction)``. The first two are ``None``
    unless both parameter stats and posterior samples exist for *mech*; the
    reconstruction is ``None`` when *recon* is ``None``.
    """
    stats = result["parameter_stats"].get(mech)
    samples = result["posterior_samples"].get(mech)

    fig_posteriors = fig_table = None
    if stats and samples is not None:
        fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="tpd")
        fig_table = plot_parameter_table(stats, mech)

    if recon is None:
        return fig_posteriors, fig_table, None

    # One label per curve, showing its heating rate.
    labels = None
    if betas:
        labels = [f"\u03b2 = {b:.2f} K/s" for b in betas]

    fig_recon = plot_reconstruction(
        recon["observed"], recon["reconstructed"], domain="tpd",
        nrmses=recon.get("nrmse"), r2s=recon.get("r2"),
        scan_labels=labels,
    )
    return fig_posteriors, fig_table, fig_recon
|
|
|
|
| def _on_tpd_mechanism_change(mech_choice, state): |
| """Callback when user selects a different TPD mechanism from the dropdown.""" |
| if not state or not mech_choice: |
| return None, None, None |
|
|
| mech = mech_choice.split(" (")[0] |
| result = state["result"] |
| temperatures = [np.array(t) for t in state["temperatures"]] |
| rates = [np.array(r) for r in state["rates"]] |
| betas = state["betas"] |
|
|
| pred = get_predictor() |
| recon = pred.reconstruct_tpd(result, temperatures, rates, betas, mechanism=mech) |
| return _render_tpd_mechanism(mech, result, recon, betas) |
|
|
|
|
| |
| |
| |
|
|
def download_results(result_text):
    """Write *result_text* to a temporary ``.json`` file for download.

    Args:
        result_text: Text to save; it is written verbatim, not re-serialized.

    Returns:
        Path of the created file, or ``None`` when *result_text* is empty.
        The file is created with ``delete=False`` and is not removed here.
    """
    if not result_text:
        return None
    # UTF-8 keeps non-ASCII symbols (µ, σ, β, ...) writable on every platform;
    # the context manager closes the handle even if the write fails.
    with tempfile.NamedTemporaryFile(
        suffix=".json", delete=False, mode="w", encoding="utf-8"
    ) as tmp:
        tmp.write(result_text)
    return tmp.name
|
|
|
|
| |
| |
| |
|
|
def _build_ec_output_section(prefix):
    """Create the output components shared by every EC input tab.

    Must be called inside an active Gradio layout context; components appear
    in the order they are created here. *prefix* is accepted but not used.

    Returns (probs, state, mech_dd, posteriors, param_table, recon, conc).
    """
    gr.Markdown("---")
    classification_plot = gr.Plot(label="Mechanism Classification")
    analysis_state = gr.State(value=None)
    mechanism_picker = gr.Dropdown(
        label="Select mechanism to view details",
        choices=[],
        interactive=True,
    )
    # Posteriors and the parameter table sit side by side.
    with gr.Row():
        posterior_plot = gr.Plot(label="Parameter Posteriors")
        table_plot = gr.Plot(label="Parameter Estimates")
    reconstruction_plot = gr.Plot(label="Signal Reconstruction")
    concentration_plot = gr.Plot(label="Surface Concentration Profiles")
    return (
        classification_plot,
        analysis_state,
        mechanism_picker,
        posterior_plot,
        table_plot,
        reconstruction_plot,
        concentration_plot,
    )
|
|
|
|
def _build_tpd_output_section(prefix):
    """Create the output components shared by every TPD input tab.

    Must be called inside an active Gradio layout context; components appear
    in the order they are created here. *prefix* is accepted but not used.

    Returns (probs, state, mech_dd, posteriors, param_table, recon).
    """
    gr.Markdown("---")
    classification_plot = gr.Plot(label="Mechanism Classification")
    analysis_state = gr.State(value=None)
    mechanism_picker = gr.Dropdown(
        label="Select mechanism to view details",
        choices=[],
        interactive=True,
    )
    # Posteriors and the parameter table sit side by side.
    with gr.Row():
        posterior_plot = gr.Plot(label="Parameter Posteriors")
        table_plot = gr.Plot(label="Parameter Estimates")
    reconstruction_plot = gr.Plot(label="Signal Reconstruction")
    return (
        classification_plot,
        analysis_state,
        mechanism_picker,
        posterior_plot,
        table_plot,
        reconstruction_plot,
    )
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...): styles the page header and the
# section/card classes, and hides the page footer.
CUSTOM_CSS = """
.main-header { text-align: center; padding: 24px 16px 8px 16px; }
.main-header h1 { font-size: 2.2em; margin-bottom: 2px; letter-spacing: -0.5px; }
.main-header p { color: #6B7280; font-size: 1.05em; max-width: 720px; margin: 0 auto; line-height: 1.5; }
.section-heading { margin-top: 24px !important; margin-bottom: 4px !important; }
.preproc-card { background: #F0F9FF; border: 1px solid #BAE6FD; border-radius: 8px; padding: 10px 16px; margin-top: 12px; font-size: 0.93em; color: #0C4A6E; }
.summary-card { border: 1px solid #E5E7EB; border-radius: 10px; padding: 16px 20px; background: #FAFBFC; }
.summary-card table { width: 100%; font-size: 0.92em; }
.summary-card td, .summary-card th { padding: 4px 8px; }
footer { display: none !important; }
"""
|
|
|
|
| def build_app(): |
| with gr.Blocks( |
| title="ECFlow — Bayesian Inference for Electrochemistry & Catalysis", |
| theme=gr.themes.Soft( |
| primary_hue="blue", |
| secondary_hue="slate", |
| font=gr.themes.GoogleFont("Inter"), |
| ), |
| css=CUSTOM_CSS, |
| ) as app: |
| gr.HTML( |
| "<div class='main-header'>" |
| "<h1>⚡ ECFlow</h1>" |
| "<p>Upload cyclic voltammetry or TPD data to <strong>identify the reaction mechanism</strong> " |
| "and <strong>infer kinetic parameters</strong> with full Bayesian uncertainty.</p>" |
| "</div>" |
| ) |
|
|
| with gr.Tabs(): |
| |
| |
| |
| with gr.Tab("CV Analysis"): |
| with gr.Tabs(): |
| |
| with gr.Tab("CSV Data"): |
| gr.Markdown( |
| "Upload CSV files exported from your potentiostat. " |
| "One file per scan rate. For best accuracy, upload multiple scan rates.\n\n" |
| "**Expected columns:** `Potential (V)`, `Current (A/mA/µA)`, and optionally `Time (s)`.\n\n" |
| "Example header: `Time (s), Potential (V), Current (A)`\n\n" |
| "If a **Time (s)** column is present, the scan rate is " |
| "detected automatically. Otherwise, enter scan rates below." |
| ) |
| if CV_EXAMPLES: |
| with gr.Accordion("Try an example (no data needed)", open=True): |
| with gr.Row(): |
| cv_example_dd = gr.Dropdown( |
| label="Select example", |
| choices=list(CV_EXAMPLES.keys()), |
| value=None, |
| interactive=True, |
| scale=3, |
| ) |
| cv_example_btn = gr.Button( |
| "Load Example", variant="secondary", scale=1, |
| ) |
| cv_files = gr.File( |
| label="CSV files (one per scan rate)", |
| file_count="multiple", |
| file_types=[".csv", ".txt"], |
| ) |
| cv_rates = gr.Textbox( |
| label="Scan rates (V/s), comma-separated", |
| placeholder="e.g., 0.01, 0.1, 1.0 (leave empty if CSV has time column)", |
| value="", |
| ) |
| with gr.Accordion("Advanced parameters", open=False): |
| with gr.Row(): |
| cv_E0 = gr.Number( |
| label="Formal potential E₀ (V)", |
| value=None, |
| info="Auto-estimated from peak positions if empty", |
| ) |
| cv_T = gr.Number(label="Temperature (K)", value=298.15) |
| cv_A = gr.Number(label="Electrode area (cm²)", value=0.0707) |
| with gr.Row(): |
| cv_C = gr.Number(label="Concentration (mM)", value=1.0) |
| cv_D = gr.Number( |
| label="Diffusion coeff D (cm²/s)", |
| value=None, |
| info="Estimated via Randles-Ševčík if empty", |
| ) |
| cv_n = gr.Number(label="Number of electrons", value=1, precision=0) |
| with gr.Row(): |
| cv_nsamples = gr.Slider( |
| 100, 2000, value=500, step=100, |
| label="Posterior samples", |
| ) |
| cv_btn = gr.Button("Analyze", variant="primary", scale=2) |
|
|
| (cv_probs, cv_state, |
| cv_mech_dd, cv_posteriors, cv_param_table, |
| cv_recon, cv_conc) = _build_ec_output_section("cv") |
|
|
| ec_outputs = [ |
| cv_probs, cv_state, |
| cv_mech_dd, cv_posteriors, cv_param_table, |
| cv_recon, cv_conc, |
| ] |
| cv_btn.click( |
| analyze_cv, |
| inputs=[ |
| cv_files, cv_rates, cv_E0, cv_T, |
| cv_A, cv_C, cv_D, cv_n, cv_nsamples, |
| ], |
| outputs=ec_outputs, |
| ) |
| cv_mech_dd.change( |
| _on_ec_mechanism_change, |
| inputs=[cv_mech_dd, cv_state], |
| outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc], |
| ) |
| if CV_EXAMPLES: |
| cv_example_btn.click( |
| _load_cv_example, |
| inputs=[cv_example_dd], |
| outputs=[ |
| cv_files, cv_rates, cv_E0, cv_T, |
| cv_A, cv_C, cv_D, cv_n, |
| ], |
| ) |
|
|
| |
| with gr.Tab("From Image"): |
| gr.Markdown( |
| "Upload plot images of CVs (potential in V on x-axis, " |
| "current in A/mA/µA on y-axis). **One image per scan rate.** " |
| "For best accuracy, upload multiple scan rates.\n\n" |
| "Axis bounds are **auto-detected** via OCR — " |
| "override in Advanced if needed." |
| ) |
| if CV_IMG_EXAMPLES: |
| with gr.Accordion("Try an example (click to load)", open=True): |
| with gr.Row(): |
| cv_img_example_dd = gr.Dropdown( |
| label="Select mechanism", |
| choices=list(CV_IMG_EXAMPLES.keys()), |
| value=None, |
| interactive=True, |
| scale=3, |
| ) |
| cv_img_example_btn = gr.Button( |
| "Load Example", variant="secondary", scale=1, |
| ) |
| cv_img_example_gallery = gr.Gallery( |
| label="Example output (classification → posteriors → reconstruction)", |
| columns=3, height=220, object_fit="contain", |
| interactive=False, |
| ) |
| cv_img_files = gr.File( |
| label="Plot images (one per scan rate)", |
| file_count="multiple", |
| file_types=["image"], |
| ) |
| cv_img_scan_rate = gr.Textbox( |
| label="Scan rates (V/s), comma-separated", |
| placeholder="e.g., 0.01, 0.1, 1.0", |
| value="", |
| ) |
| with gr.Accordion("Advanced parameters", open=False): |
| with gr.Row(): |
| cv_img_E0 = gr.Number( |
| label="Formal potential E₀ (V)", |
| value=None, |
| info="Auto-estimated from peaks if empty", |
| ) |
| cv_img_threshold = gr.Slider( |
| 0, 255, value=0, step=1, |
| label="Binarization threshold (0 = auto)", |
| ) |
| cv_img_current_unit = gr.Dropdown( |
| label="Current unit on y-axis", |
| choices=["auto", "µA", "mA", "A", "nA"], |
| value="auto", |
| info="Select the unit shown on the y-axis of your plot", |
| ) |
| with gr.Accordion("Axis overrides", open=False): |
| gr.Markdown( |
| "Leave at 0 to auto-detect from each image. " |
| "Override if OCR detection is inaccurate. " |
| "Overrides apply to **all** images." |
| ) |
| with gr.Row(): |
| cv_img_xmin = gr.Number(label="E min (V)", value=None) |
| cv_img_xmax = gr.Number(label="E max (V)", value=None) |
| cv_img_ymin = gr.Number(label="I min", value=None) |
| cv_img_ymax = gr.Number(label="I max", value=None) |
| with gr.Row(): |
| cv_img_nsamples = gr.Slider( |
| 100, 2000, value=500, step=100, |
| label="Posterior samples", |
| ) |
| cv_img_btn = gr.Button("Analyze", variant="primary", scale=2) |
|
|
| (cv_img_probs, cv_img_state, |
| cv_img_mech_dd, cv_img_posteriors, cv_img_param_table, |
| cv_img_recon, cv_img_conc) = _build_ec_output_section("cv_img") |
|
|
| ec_img_outputs = [ |
| cv_img_probs, cv_img_state, |
| cv_img_mech_dd, cv_img_posteriors, cv_img_param_table, |
| cv_img_recon, cv_img_conc, |
| ] |
| cv_img_btn.click( |
| analyze_cv_image, |
| inputs=[ |
| cv_img_files, cv_img_scan_rate, cv_img_E0, |
| cv_img_threshold, cv_img_current_unit, |
| cv_img_nsamples, |
| cv_img_xmin, cv_img_xmax, |
| cv_img_ymin, cv_img_ymax, |
| ], |
| outputs=ec_img_outputs, |
| ) |
| cv_img_mech_dd.change( |
| _on_ec_mechanism_change, |
| inputs=[cv_img_mech_dd, cv_img_state], |
| outputs=[cv_img_posteriors, cv_img_param_table, cv_img_recon, cv_img_conc], |
| ) |
| if CV_IMG_EXAMPLES: |
| def _on_cv_img_example_select(mech_name): |
| files, rates = _load_cv_image_example(mech_name) |
| renders = CV_OUTPUT_RENDERS.get(mech_name, {}) |
| gallery_imgs = [ |
| renders[k] for k in |
| ["classification", "posteriors", "reconstruction"] |
| if k in renders |
| ] |
| return files, rates, gallery_imgs |
|
|
| cv_img_example_btn.click( |
| _on_cv_img_example_select, |
| inputs=[cv_img_example_dd], |
| outputs=[cv_img_files, cv_img_scan_rate, cv_img_example_gallery], |
| ) |
| cv_img_example_dd.change( |
| lambda m: [ |
| CV_OUTPUT_RENDERS.get(m, {}).get(k) |
| for k in ["classification", "posteriors", "reconstruction"] |
| if CV_OUTPUT_RENDERS.get(m, {}).get(k) |
| ], |
| inputs=[cv_img_example_dd], |
| outputs=[cv_img_example_gallery], |
| ) |
|
|
| |
| |
| |
| with gr.Tab("TPD Analysis"): |
| with gr.Tabs(): |
| |
| with gr.Tab("CSV Data"): |
| gr.Markdown( |
| "Upload CSV files with **Temperature (K)** and **Signal** columns. " |
| "Optionally include a **Time (s)** column for automatic β detection. " |
| "One file per heating rate. For best accuracy, upload multiple heating rates.\n\n" |
| "Example header: `Temperature (K), Signal`\n\n" |
| "**You must provide the correct β for each file** — " |
| "the model uses β to condition inference." |
| ) |
| if TPD_EXAMPLES: |
| with gr.Accordion("Try an example (no data needed)", open=True): |
| with gr.Row(): |
| tpd_example_dd = gr.Dropdown( |
| label="Select example", |
| choices=list(TPD_EXAMPLES.keys()), |
| value=None, |
| interactive=True, |
| scale=3, |
| ) |
| tpd_example_btn = gr.Button( |
| "Load Example", variant="secondary", scale=1, |
| ) |
| tpd_files = gr.File( |
| label="CSV files (one per heating rate)", |
| file_count="multiple", |
| file_types=[".csv", ".txt"], |
| ) |
| tpd_betas = gr.Textbox( |
| label="Heating rates β (K/s), comma-separated", |
| placeholder="e.g., 0.3, 2.6, 22.1 (leave empty if CSV has time column)", |
| value="", |
| ) |
| with gr.Row(): |
| tpd_nsamples = gr.Slider( |
| 100, 2000, value=500, step=100, |
| label="Posterior samples", |
| ) |
| tpd_btn = gr.Button("Analyze", variant="primary", scale=2) |
|
|
| (tpd_probs, tpd_state, |
| tpd_mech_dd, tpd_posteriors, tpd_param_table, |
| tpd_recon) = _build_tpd_output_section("tpd") |
|
|
| tpd_outputs = [ |
| tpd_probs, tpd_state, |
| tpd_mech_dd, tpd_posteriors, tpd_param_table, tpd_recon, |
| ] |
| tpd_btn.click( |
| analyze_tpd, |
| inputs=[tpd_files, tpd_betas, tpd_nsamples], |
| outputs=tpd_outputs, |
| ) |
| tpd_mech_dd.change( |
| _on_tpd_mechanism_change, |
| inputs=[tpd_mech_dd, tpd_state], |
| outputs=[tpd_posteriors, tpd_param_table, tpd_recon], |
| ) |
| if TPD_EXAMPLES: |
| tpd_example_btn.click( |
| _load_tpd_example, |
| inputs=[tpd_example_dd], |
| outputs=[tpd_files, tpd_betas], |
| ) |
|
|
| |
| with gr.Tab("From Image"): |
| gr.Markdown( |
| "Upload plot images of TPD curves (temperature in K on " |
| "x-axis, signal on y-axis). **One image per heating rate.** " |
| "For best accuracy, upload multiple heating rates.\n\n" |
| "Axis bounds are **auto-detected** via OCR — " |
| "override in Advanced if needed." |
| ) |
| if TPD_IMG_EXAMPLES: |
| with gr.Accordion("Try an example (click to load)", open=True): |
| with gr.Row(): |
| tpd_img_example_dd = gr.Dropdown( |
| label="Select mechanism", |
| choices=list(TPD_IMG_EXAMPLES.keys()), |
| value=None, |
| interactive=True, |
| scale=3, |
| ) |
| tpd_img_example_btn = gr.Button( |
| "Load Example", variant="secondary", scale=1, |
| ) |
| tpd_img_example_gallery = gr.Gallery( |
| label="Example output (classification → posteriors → reconstruction)", |
| columns=3, height=220, object_fit="contain", |
| interactive=False, |
| ) |
| tpd_img_files = gr.File( |
| label="Plot images (one per heating rate)", |
| file_count="multiple", |
| file_types=["image"], |
| ) |
| tpd_img_betas = gr.Textbox( |
| label="Heating rates β (K/s), comma-separated", |
| placeholder="e.g., 0.3, 2.6, 22.1", |
| value="", |
| ) |
| with gr.Accordion("Advanced parameters", open=False): |
| with gr.Row(): |
| tpd_img_threshold = gr.Slider( |
| 0, 255, value=0, step=1, |
| label="Binarization threshold (0 = auto)", |
| ) |
| with gr.Accordion("Axis overrides", open=False): |
| gr.Markdown( |
| "Leave at 0 to auto-detect from each image. " |
| "Override if OCR detection is inaccurate. " |
| "Overrides apply to **all** images." |
| ) |
| with gr.Row(): |
| tpd_img_xmin = gr.Number(label="T min (K)", value=None) |
| tpd_img_xmax = gr.Number(label="T max (K)", value=None) |
| tpd_img_ymin = gr.Number(label="Signal min", value=None) |
| tpd_img_ymax = gr.Number(label="Signal max", value=None) |
| with gr.Row(): |
| tpd_img_nsamples = gr.Slider( |
| 100, 2000, value=500, step=100, |
| label="Posterior samples", |
| ) |
| tpd_img_btn = gr.Button("Analyze", variant="primary", scale=2) |
|
|
| (tpd_img_probs, tpd_img_state, |
| tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table, |
| tpd_img_recon) = _build_tpd_output_section("tpd_img") |
|
|
| tpd_img_outputs = [ |
| tpd_img_probs, tpd_img_state, |
| tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table, tpd_img_recon, |
| ] |
| tpd_img_btn.click( |
| analyze_tpd_image, |
| inputs=[ |
| tpd_img_files, tpd_img_betas, |
| tpd_img_threshold, tpd_img_nsamples, |
| tpd_img_xmin, tpd_img_xmax, |
| tpd_img_ymin, tpd_img_ymax, |
| ], |
| outputs=tpd_img_outputs, |
| ) |
| tpd_img_mech_dd.change( |
| _on_tpd_mechanism_change, |
| inputs=[tpd_img_mech_dd, tpd_img_state], |
| outputs=[tpd_img_posteriors, tpd_img_param_table, tpd_img_recon], |
| ) |
| if TPD_IMG_EXAMPLES: |
def _on_tpd_img_example_select(mech_name):
    """Handle a "Load Example" click for the TPD image examples.

    Unpacks the pair returned by ``_load_tpd_image_example(mech_name)`` and
    gathers the entries stored for ``mech_name`` in ``TPD_OUTPUT_RENDERS``.

    Returns:
        A 3-tuple ``(files, betas, gallery_imgs)``. The click wiring sends
        it to ``tpd_img_files``, ``tpd_img_betas`` and
        ``tpd_img_example_gallery`` respectively.
    """
    example_files, example_betas = _load_tpd_image_example(mech_name)
    available = TPD_OUTPUT_RENDERS.get(mech_name, {})

    # Fixed display order; kinds without a registered entry are skipped.
    gallery_imgs = []
    for kind in ("classification", "posteriors", "reconstruction"):
        if kind in available:
            gallery_imgs.append(available[kind])

    return example_files, example_betas, gallery_imgs
|
|
| tpd_img_example_btn.click( |
| _on_tpd_img_example_select, |
| inputs=[tpd_img_example_dd], |
| outputs=[tpd_img_files, tpd_img_betas, tpd_img_example_gallery], |
| ) |
| tpd_img_example_dd.change( |
| lambda m: [ |
| TPD_OUTPUT_RENDERS.get(m, {}).get(k) |
| for k in ["classification", "posteriors", "reconstruction"] |
| if TPD_OUTPUT_RENDERS.get(m, {}).get(k) |
| ], |
| inputs=[tpd_img_example_dd], |
| outputs=[tpd_img_example_gallery], |
| ) |
|
|
| |
| |
| |
| with gr.Tab("About"): |
| gr.Markdown(""" |
| ## How It Works |
| |
| ECFlow uses **conditional normalizing flows** with a **Set Transformer** encoder to perform amortized Bayesian inference. |
| Given one or more experimental curves, it simultaneously classifies the reaction mechanism and produces |
| full posterior distributions over kinetic parameters — in a single forward pass. |
| |
| | | Electrochemistry (CV) | Catalysis (TPD) | |
| |---|---|---| |
| | **Mechanisms** | Nernst, Butler–Volmer, Marcus–Hush–Chidsey, Adsorption, EC, Langmuir–Hinshelwood | First-order, Second-order, LH Surface, Mars–van Krevelen, Coverage-dependent, Diffusion-limited | |
| | **Inference** | ~50 ms on CPU | ~50 ms on CPU | |
| | **Calibration** | 89–94 % coverage at 90 % nominal | Conformal coverage verified | |
| |
| Training data is generated from physics-based simulators (Crank–Nicolson for CV, ODE integrators for TPD). |
| Posteriors are calibrated via a coverage-aware loss with per-parameter inverse-spread weighting. |
| |
| ### Citation |
| |
| ``` |
| Yan, B. (2026). ECFlow: Amortized Bayesian Inference for Mechanism Identification |
| and Parameter Estimation in Electrochemistry and Catalysis via Conditional |
| Normalizing Flows. [Preprint] |
| ``` |
| |
| Built at MIT. Code and paper at [github.com/bingyan/ECFlow](https://github.com/bingyan/ECFlow). |
| """) |
|
|
| return app |
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Script entry point: build the app with build_app() and call its
    # launch() method with a fixed set of options.
    app = build_app()
    launch_options = {
        # NOTE(review): "0.0.0.0" presumably exposes the server on all
        # network interfaces (container / hosted use) — confirm against the
        # Gradio launch() documentation before changing.
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    app.launch(**launch_options)
|
|