# Source: ecflow/app.py — commit 1e4169d ("Add image examples to From Image tabs") by Bing Yan
"""
ECFlow — Amortized Bayesian Inference for Electrochemistry & Catalysis
Gradio web interface for mechanism classification and parameter inference
from cyclic voltammetry (CV) and temperature-programmed desorption (TPD) data.
"""
import os
import sys
import json
import tempfile
from pathlib import Path
import numpy as np
import gradio as gr
from inference import ECFlowPredictor
from preprocessing import (
nondimensionalize_cv,
estimate_E0,
parse_cv_csv,
parse_tpd_csv,
)
from plotting import (
plot_mechanism_probs,
plot_posteriors,
plot_parameter_table,
plot_reconstruction,
plot_concentration_profiles,
)
# ---------------------------------------------------------------------------
# Model paths (relative to repo root)
# ---------------------------------------------------------------------------
REPO_ROOT = Path(__file__).resolve().parent  # directory containing this file
DEMO_DIR = REPO_ROOT / "demo_data"           # CSV demo examples + *_metadata.json
DEMO_RENDERS = REPO_ROOT / "demo_renders"    # pre-rendered example plot images
EC_CHECKPOINT = REPO_ROOT / "checkpoints" / "ec_best.pt"
TPD_CHECKPOINT = REPO_ROOT / "checkpoints" / "tpd_best.pt"
# Allow override via environment variables
EC_CHECKPOINT = Path(os.environ.get("ECFLOW_EC_CHECKPOINT", str(EC_CHECKPOINT)))
TPD_CHECKPOINT = Path(os.environ.get("ECFLOW_TPD_CHECKPOINT", str(TPD_CHECKPOINT)))
# ---------------------------------------------------------------------------
# Demo examples
# ---------------------------------------------------------------------------
def _discover_examples():
    """Scan demo_data/ for metadata files and build example catalogs.

    Returns a pair of dicts (cv_examples, tpd_examples) keyed by a display
    label; each value holds the file paths and pre-formatted rate strings
    that the example-loading callbacks feed back into the UI.
    """
    cv_catalog, tpd_catalog = {}, {}
    if not DEMO_DIR.is_dir():
        # No demo data shipped — the UI simply hides the example accordions.
        return cv_catalog, tpd_catalog
    for meta_path in sorted(DEMO_DIR.glob("*_metadata.json")):
        with open(meta_path) as fh:
            meta = json.load(fh)
        mechanism = meta["mechanism"]
        file_paths = [str(DEMO_DIR / fname) for fname in meta["csv_files"]]
        if meta_path.name.startswith("ec_"):
            rates_str = ", ".join(
                f"{r:.4g}" for r in meta.get("scan_rates_Vs", []))
            phys = meta.get("physical_params", {})
            # Defaults mirror the UI defaults for missing physical params.
            cv_catalog[f"CV — {mechanism}"] = {
                "files": file_paths,
                "scan_rates": rates_str,
                "E0_V": phys.get("E0_V"),
                "T_K": phys.get("T_K", 298.15),
                "A_cm2": phys.get("A_cm2", 0.0707),
                "C_mM": phys.get("C_mM", 1.0),
                "D_cm2s": phys.get("D_cm2s", 1e-5),
                "n_electrons": phys.get("n_electrons", 1),
            }
        elif meta_path.name.startswith("tpd_"):
            betas_str = ", ".join(
                f"{b:.4g}" for b in meta.get("betas_Ks", []))
            tpd_catalog[f"TPD — {mechanism}"] = {
                "files": file_paths,
                "heating_rates": betas_str,
            }
    return cv_catalog, tpd_catalog
CV_EXAMPLES, TPD_EXAMPLES = _discover_examples()
def _discover_image_examples():
    """Build image example catalogs from demo_renders/ directory.

    Returns four dicts keyed by mechanism name: CV input images (with scan
    rates), TPD input images (with indices), and pre-rendered output figures
    for CV and TPD respectively.
    """
    import re

    cv_inputs, tpd_inputs = {}, {}
    cv_renders, tpd_renders = {}, {}
    if not DEMO_RENDERS.is_dir():
        return cv_inputs, tpd_inputs, cv_renders, tpd_renders

    # CV inputs are named ec_<mech>_<rate>mVs_physical.png; the rate is
    # stored in mV/s in the filename and converted to V/s here.
    for img_path in sorted(DEMO_RENDERS.glob("ec_*_physical.png")):
        hit = re.match(r"ec_(\w+?)_(\d+)mVs_physical\.png", img_path.name)
        if hit:
            mech = hit.group(1)
            rate_Vs = int(hit.group(2)) / 1000.0
            cv_inputs.setdefault(mech, []).append((str(img_path), rate_Vs))

    # TPD inputs are named tpd_<mech>_<idx>_physical.png; the index string
    # is kept verbatim.
    for img_path in sorted(DEMO_RENDERS.glob("tpd_*_physical.png")):
        hit = re.match(r"tpd_(\w+?)_(\d+)_physical\.png", img_path.name)
        if hit:
            tpd_inputs.setdefault(hit.group(1), []).append(
                (str(img_path), hit.group(2)))

    def _collect_renders(prefix, mechanisms, suffixes):
        """Gather existing <prefix>_<mech>_<suffix>.png renders per mechanism."""
        collected = {}
        for mech in mechanisms:
            found = {}
            for suffix in suffixes:
                render_path = DEMO_RENDERS / f"{prefix}_{mech}_{suffix}.png"
                if render_path.exists():
                    found[suffix] = str(render_path)
            if found:
                collected[mech] = found
        return collected

    cv_renders = _collect_renders(
        "ec", cv_inputs,
        ["classification", "posteriors", "reconstruction", "concentration"])
    tpd_renders = _collect_renders(
        "tpd", tpd_inputs,
        ["classification", "posteriors", "reconstruction"])
    return cv_inputs, tpd_inputs, cv_renders, tpd_renders
# Image example catalogs (input images + pre-rendered outputs), scanned once
# at import time.
(CV_IMG_EXAMPLES, TPD_IMG_EXAMPLES,
 CV_OUTPUT_RENDERS, TPD_OUTPUT_RENDERS) = _discover_image_examples()
def _load_cv_image_example(mech_name):
    """Return (files, scan_rates_str) for a CV image example."""
    entries = CV_IMG_EXAMPLES.get(mech_name) if mech_name else None
    if not entries:
        # Unknown/empty selection: leave both outputs unchanged.
        return [gr.update()] * 2
    image_paths = [path for path, _rate in entries]
    rates_str = ", ".join(str(rate) for _path, rate in entries)
    return image_paths, rates_str
def _load_tpd_image_example(mech_name):
    """Return (image_path, heating_rate_str) for a TPD image example."""
    entries = TPD_IMG_EXAMPLES.get(mech_name) if mech_name else None
    if not entries:
        # Unknown/empty selection: leave both outputs unchanged.
        return [gr.update()] * 2
    image_path, rate_idx = entries[0]
    return image_path, rate_idx
def _load_cv_example(example_name):
    """Return (files, scan_rates, E0, T, A, C, D, n) for the chosen CV example."""
    example = CV_EXAMPLES.get(example_name) if example_name else None
    if example is None:
        # No valid selection: leave all eight outputs unchanged.
        return [gr.update()] * 8
    field_order = (
        "files", "scan_rates", "E0_V", "T_K",
        "A_cm2", "C_mM", "D_cm2s", "n_electrons",
    )
    return tuple(example[field] for field in field_order)
def _load_tpd_example(example_name):
    """Return (files, heating_rates) for the chosen TPD example."""
    example = TPD_EXAMPLES.get(example_name) if example_name else None
    if example is None:
        # No valid selection: leave both outputs unchanged.
        return [gr.update()] * 2
    return example["files"], example["heating_rates"]
predictor = None  # lazily-constructed process-wide singleton (see get_predictor)


def get_predictor():
    """Return the shared ECFlowPredictor, creating it on first use.

    Checkpoints that do not exist on disk are passed as None so the
    predictor can decide how to handle a missing model.
    """
    global predictor
    if predictor is None:
        predictor = ECFlowPredictor(
            ec_checkpoint=str(EC_CHECKPOINT) if EC_CHECKPOINT.exists() else None,
            tpd_checkpoint=str(TPD_CHECKPOINT) if TPD_CHECKPOINT.exists() else None,
            device="cpu",
        )
    return predictor
# =========================================================================
# CV Analysis
# =========================================================================
def analyze_cv(files, scan_rates_text, E0_V, T_K, A_cm2,
               C_mM, D_cm2s, n_electrons, n_samples):
    """Analyze CV data from potentiostat CSV files.

    Accepts CSV files with columns for potential (V) and current (A/mA/µA).
    If the CSV includes a Time (s) column, the scan rate is auto-detected.
    Otherwise, scan rates must be provided.

    Returns the 7-tuple of Gradio outputs produced by _run_ec_analysis
    (or the empty outputs from _ec_error on invalid input).
    """
    if not files:
        return _ec_error("Please upload at least one CSV file.")
    # Parse optional comma-separated scan rates; None means per-file auto-detect.
    scan_rates_text = scan_rates_text.strip() if scan_rates_text else ""
    user_rates = None
    if scan_rates_text:
        try:
            user_rates = [float(s.strip()) for s in scan_rates_text.split(",")]
        except ValueError:
            return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.")
        if len(files) != len(user_rates):
            return _ec_error(
                f"Number of files ({len(files)}) must match number of "
                f"scan rates ({len(user_rates)}).")
    # Physical parameters, falling back to the UI defaults when fields are empty.
    C_molcm3 = float(C_mM) * 1e-6 if C_mM else 1e-6  # mM -> mol/cm^3
    n = int(n_electrons) if n_electrons else 1
    T = float(T_K) if T_K else 298.15
    A = float(A_cm2) if A_cm2 else 0.0707
    parsed_data = []
    scan_rates = []
    for idx, f in enumerate(files):
        content = Path(f.name).read_text()
        parsed = parse_cv_csv(content)
        parsed_data.append(parsed)
        if user_rates is not None:
            v = user_rates[idx]
        elif "scan_rate_Vs" in parsed:
            # Rate derived by the parser (presumably from a Time column).
            v = parsed["scan_rate_Vs"]
        else:
            return _ec_error(
                f"Cannot determine scan rate for file '{Path(f.name).name}'. "
                "Either provide scan rates (V/s) or upload CSV files that "
                "include a Time (s) column.")
        scan_rates.append(v)
    # Formal potential: user override, else median of per-file estimates
    # (median is robust to one badly-estimated file).
    if E0_V:
        e0 = float(E0_V)
    else:
        e0_estimates = [estimate_E0(p["E_V"], p["i_A"]) for p in parsed_data]
        e0 = float(np.median(e0_estimates))
    # NOTE(review): the UI help text says D is "Estimated via Randles-Ševčík
    # if empty", but this code uses a fixed 1e-5 cm²/s default — confirm
    # which behavior is intended.
    D = float(D_cm2s) if D_cm2s else 1e-5
    potentials, fluxes, sigmas_list = [], [], []
    for parsed, v in zip(parsed_data, scan_rates):
        theta, flux, sigma = nondimensionalize_cv(
            parsed["E_V"], parsed["i_A"], v, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)
    return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples)
# Conversion factors from a plot's current unit to amperes; used to rescale
# digitized y-axis values before nondimensionalization.
_CURRENT_UNIT_SCALES = {
    "µA": 1e-6,
    "mA": 1e-3,
    "A": 1.0,
    "nA": 1e-9,
}
def _guess_current_unit(i_values):
"""Guess the current unit from the magnitude of digitized values."""
i_max = np.max(np.abs(i_values))
if i_max > 1e3:
return "nA"
if i_max > 100:
return "µA"
if i_max > 0.1:
return "µA"
if i_max > 1e-4:
return "mA"
return "A"
def analyze_cv_image(files, scan_rate_text, E0_V, threshold, current_unit,
                     n_samples, x_min, x_max, y_min, y_max):
    """Analyze CV from uploaded plot images (one per scan rate).

    Extracts CV curves via image digitization, then nondimensionalizes
    and runs inference identically to the CSV path.
    Axis bounds are auto-detected via OCR — override in Advanced if needed.
    """
    if not files:
        return _ec_error("Please upload at least one image.")
    # Heavy optional dependencies are imported lazily so the app can start
    # (and the CSV path can run) even when they are not installed.
    try:
        from digitizer import digitize_plot, auto_detect_axis_bounds
        from PIL import Image as PILImage
    except ImportError:
        return _ec_error("Required libraries not available for image digitization.")
    scan_rate_text = scan_rate_text.strip() if scan_rate_text else ""
    if not scan_rate_text:
        return _ec_error("Please enter the scan rate(s) (V/s), comma-separated.")
    try:
        scan_rates = [float(s.strip()) for s in scan_rate_text.split(",")]
    except ValueError:
        return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.")
    if len(files) != len(scan_rates):
        return _ec_error(
            f"Number of images ({len(files)}) must match number of "
            f"scan rates ({len(scan_rates)}).")
    # Manual axis overrides are used only when all four fields are set and
    # non-zero (0 acts as the "auto-detect" sentinel); they then apply to
    # every uploaded image.
    has_user_bounds = all(
        v is not None and v != 0 for v in [x_min, x_max, y_min, y_max]
    )
    # Fixed physical defaults for the image path (not user-configurable here,
    # unlike the CSV path).
    D = 1e-5
    T = 298.15
    A = 0.0707
    C_molcm3 = 1e-6
    n = 1
    # First pass: digitize all images and estimate E0 per image
    image_data = []
    e0_estimates = []
    for idx, f in enumerate(files):
        img_arr = np.array(PILImage.open(f.name).convert("RGB"))
        v_Vs = scan_rates[idx]
        if has_user_bounds:
            bounds = {
                "x_min": float(x_min), "x_max": float(x_max),
                "y_min": float(y_min), "y_max": float(y_max),
            }
        else:
            bounds = auto_detect_axis_bounds(img_arr)
            if bounds is None:
                return _ec_error(
                    f"Could not auto-detect axis bounds for image {idx + 1}. "
                    "Please enter E min, E max, I min, I max under "
                    "'Axis overrides'.")
        try:
            E_V, I_raw = digitize_plot(
                img_arr, bounds["x_min"], bounds["x_max"],
                bounds["y_min"], bounds["y_max"],
                threshold=int(threshold),
                x_ticks=bounds.get("x_ticks"),
                y_ticks=bounds.get("y_ticks"),
            )
        except Exception as exc:
            return _ec_error(f"Digitization failed for image {idx + 1}: {exc}")
        # Current-unit resolution: explicit user choice > OCR-detected axis
        # label > magnitude heuristic.
        if current_unit and current_unit != "auto":
            i_unit = current_unit
        elif "y_unit" in bounds:
            i_unit = bounds["y_unit"]
        else:
            i_unit = _guess_current_unit(I_raw)
        i_scale = _CURRENT_UNIT_SCALES.get(i_unit, 1e-6)
        i_A = I_raw * i_scale
        e0_estimates.append(float(estimate_E0(E_V, i_A)))
        image_data.append((E_V, i_A, v_Vs, i_unit, bounds))
    # Determine E0: user-provided or median of per-image estimates
    if E0_V is not None and E0_V != 0:
        e0 = float(E0_V)
        e0_source = "user"  # NOTE(review): never read afterwards — dead?
    else:
        e0 = float(np.median(e0_estimates))
        e0_source = "auto"
    # Second pass: nondimensionalize with shared E0
    potentials, fluxes, sigmas_list = [], [], []
    preproc_parts = []
    for E_V, i_A, v_Vs, i_unit, bounds in image_data:
        theta, flux, sigma = nondimensionalize_cv(
            E_V, i_A, v_Vs, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)
        # NOTE(review): preproc_parts is built but never returned or shown —
        # presumably intended for a preprocessing-summary panel; confirm.
        preproc_parts.append(
            f"{v_Vs*1000:.1f} mV/s (σ={sigma:.2f}, "
            f"E=[{bounds['x_min']:.3f}, {bounds['x_max']:.3f}] V, "
            f"I=[{bounds['y_min']:.2f}, {bounds['y_max']:.2f}] {i_unit})"
        )
    return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples)
def _ec_error(msg=""):
    """Surface *msg* to the user and return empty outputs for EC error cases.

    The previous implementation discarded the message entirely, so failed
    analyses gave no feedback. gr.Warning shows it as a non-blocking toast
    without changing the 7-tuple output contract.
    """
    if msg:
        gr.Warning(msg)
    return None, None, gr.Dropdown(choices=[], value=None), None, None, None, None
def _run_ec_analysis(potentials, fluxes, sigmas, n_samples):
    """Core EC analysis: predict + reconstruct for top mechanism."""
    pred = get_predictor()
    result = pred.predict_ec(potentials, fluxes, sigmas, n_samples=int(n_samples))
    recon = pred.reconstruct_ec(result, potentials, fluxes, sigmas)
    fig_probs = plot_mechanism_probs(result["mechanism_probs"], domain="ec")
    # Dropdown choices ordered by descending posterior probability.
    ranked = sorted(result["mechanism_probs"].items(), key=lambda kv: -kv[1])
    mech_choices = [f"{name} ({prob:.1%})" for name, prob in ranked]
    # Arrays go into the state as plain lists so Gradio can round-trip the
    # state between callbacks.
    state = {
        "result": result,
        "potentials": [p.tolist() for p in potentials],
        "fluxes": [f.tolist() for f in fluxes],
        "sigmas": sigmas,
    }
    fig_post, fig_table, fig_recon, fig_conc = _render_ec_mechanism(
        result["predicted_mechanism"], result, recon, sigmas
    )
    return (
        fig_probs, state,
        gr.Dropdown(choices=mech_choices, value=mech_choices[0]),
        fig_post, fig_table, fig_recon, fig_conc,
    )
def _render_ec_mechanism(mech, result, recon, sigmas):
    """Render posteriors, param table, reconstruction, and concentration for one EC mechanism.

    Any figure that cannot be produced (missing stats/samples or no
    reconstruction) is returned as None.
    """
    stats = result["parameter_stats"].get(mech)
    samples = result["posterior_samples"].get(mech)
    fig_posteriors = fig_table = None
    if stats and samples is not None:
        fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="ec")
        fig_table = plot_parameter_table(stats, mech)
    fig_recon = fig_conc = None
    if recon is not None:
        # One label per scan, using the dimensionless scan-rate parameter σ.
        labels = [f"\u03c3 = {s:.2f}" for s in sigmas] if sigmas else None
        fig_recon = plot_reconstruction(
            recon["observed"], recon["reconstructed"], domain="ec",
            nrmses=recon.get("nrmse"), r2s=recon.get("r2"),
            scan_labels=labels,
        )
        profiles = recon.get("concentrations")
        if profiles:
            fig_conc = plot_concentration_profiles(profiles, scan_labels=labels)
    return fig_posteriors, fig_table, fig_recon, fig_conc
def _on_ec_mechanism_change(mech_choice, state):
    """Callback when user selects a different EC mechanism from the dropdown."""
    if not state or not mech_choice:
        return None, None, None, None
    # Dropdown labels look like "Mechanism (12.3%)" — keep only the name.
    mech = mech_choice.split(" (")[0]
    result = state["result"]
    sigmas = state["sigmas"]
    potentials = [np.asarray(p) for p in state["potentials"]]
    fluxes = [np.asarray(f) for f in state["fluxes"]]
    recon = get_predictor().reconstruct_ec(
        result, potentials, fluxes, sigmas, mechanism=mech)
    return _render_ec_mechanism(mech, result, recon, sigmas)
# =========================================================================
# TPD Analysis
# =========================================================================
def analyze_tpd(files, heating_rates_text, n_samples):
    """Analyze TPD data.

    Heating rates come either from the comma-separated text box or, when
    that is empty, from a per-file value computed by the CSV parser.
    """
    if not files:
        return _tpd_error("Please upload at least one CSV file.")
    temperatures, rates, csv_betas = [], [], []
    for f in files:
        parsed = parse_tpd_csv(Path(f.name).read_text())
        temperatures.append(parsed["T_K"])
        rates.append(parsed["signal"])
        if "beta_Ks" in parsed:
            csv_betas.append(parsed["beta_Ks"])
    text = heating_rates_text.strip() if heating_rates_text else ""
    if text:
        try:
            betas = [float(token.strip()) for token in text.split(",")]
        except ValueError:
            return _tpd_error("Invalid heating rates. Enter comma-separated numbers in K/s.")
        if len(files) != len(betas):
            return _tpd_error(
                f"Number of files ({len(files)}) must match heating rates ({len(betas)}).")
    elif len(csv_betas) == len(files):
        # Every file carried its own β, so no manual entry is needed.
        betas = csv_betas
    else:
        return _tpd_error(
            "Please enter the heating rate (β in K/s) for each file. "
            "This value is critical for correct inference. "
            "Alternatively, include a 'Time (s)' column in your CSV so β can be computed automatically.")
    return _run_tpd_analysis(temperatures, rates, betas, n_samples)
def analyze_tpd_image(files, heating_rates_text, threshold, n_samples,
                      x_min, x_max, y_min, y_max):
    """Analyze TPD from uploaded plot images (one per heating rate).

    Axis bounds are auto-detected via OCR — override in Advanced if needed.
    """
    if not files:
        return _tpd_error("Please upload at least one image.")
    # Heavy optional dependencies are imported lazily so the app can start
    # (and the CSV path can run) even when they are not installed.
    try:
        from digitizer import digitize_plot, auto_detect_axis_bounds
        from PIL import Image as PILImage
    except ImportError:
        return _tpd_error("Required libraries not available for image digitization.")
    heating_rates_text = heating_rates_text.strip() if heating_rates_text else ""
    if not heating_rates_text:
        return _tpd_error(
            "Please enter the heating rate(s) (β in K/s), comma-separated. "
            "This value is critical for correct inference.")
    try:
        betas = [float(s.strip()) for s in heating_rates_text.split(",")]
    except ValueError:
        return _tpd_error("Invalid heating rates. Enter comma-separated numbers in K/s.")
    if len(files) != len(betas):
        return _tpd_error(
            f"Number of images ({len(files)}) must match number of "
            f"heating rates ({len(betas)}).")
    # Manual axis overrides are used only when all four fields are set and
    # non-zero (0 acts as the "auto-detect" sentinel); they then apply to
    # every uploaded image.
    has_user_bounds = all(
        v is not None and v != 0 for v in [x_min, x_max, y_min, y_max]
    )
    temperatures, rates = [], []
    for idx, f in enumerate(files):
        img_arr = np.array(PILImage.open(f.name).convert("RGB"))
        if has_user_bounds:
            bounds = {
                "x_min": float(x_min), "x_max": float(x_max),
                "y_min": float(y_min), "y_max": float(y_max),
            }
        else:
            bounds = auto_detect_axis_bounds(img_arr)
            if bounds is None:
                return _tpd_error(
                    f"Could not auto-detect axis bounds for image {idx + 1}. "
                    "Please enter T min, T max, Signal min, Signal max "
                    "under 'Axis overrides'.")
        try:
            x_data, y_data = digitize_plot(
                img_arr, bounds["x_min"], bounds["x_max"],
                bounds["y_min"], bounds["y_max"],
                threshold=int(threshold),
                x_ticks=bounds.get("x_ticks"),
                y_ticks=bounds.get("y_ticks"),
            )
        except Exception as exc:
            return _tpd_error(f"Digitization failed for image {idx + 1}: {exc}")
        temperatures.append(x_data)
        rates.append(y_data)
    return _run_tpd_analysis(temperatures, rates, betas, n_samples)
def _tpd_error(msg=""):
    """Surface *msg* to the user and return empty outputs for TPD error cases.

    The previous implementation discarded the message entirely, so failed
    analyses gave no feedback. gr.Warning shows it as a non-blocking toast
    without changing the 6-tuple output contract.
    """
    if msg:
        gr.Warning(msg)
    return None, None, gr.Dropdown(choices=[], value=None), None, None, None
def _run_tpd_analysis(temperatures, rates, betas, n_samples):
    """Core TPD analysis: predict + reconstruct for top mechanism."""
    pred = get_predictor()
    result = pred.predict_tpd(temperatures, rates, betas, n_samples=int(n_samples))
    recon = pred.reconstruct_tpd(result, temperatures, rates, betas)
    fig_probs = plot_mechanism_probs(result["mechanism_probs"], domain="tpd")
    # Dropdown choices ordered by descending posterior probability.
    ranked = sorted(result["mechanism_probs"].items(), key=lambda kv: -kv[1])
    mech_choices = [f"{name} ({prob:.1%})" for name, prob in ranked]
    # Arrays go into the state as plain lists so Gradio can round-trip the
    # state between callbacks.
    state = {
        "result": result,
        "temperatures": [t.tolist() for t in temperatures],
        "rates": [r.tolist() for r in rates],
        "betas": betas,
    }
    fig_post, fig_table, fig_recon = _render_tpd_mechanism(
        result["predicted_mechanism"], result, recon, betas)
    return (
        fig_probs, state,
        gr.Dropdown(choices=mech_choices, value=mech_choices[0]),
        fig_post, fig_table, fig_recon,
    )
def _render_tpd_mechanism(mech, result, recon, betas):
    """Render posteriors, param table, and reconstruction for one TPD mechanism.

    Any figure that cannot be produced (missing stats/samples or no
    reconstruction) is returned as None.
    """
    stats = result["parameter_stats"].get(mech)
    samples = result["posterior_samples"].get(mech)
    fig_posteriors = fig_table = None
    if stats and samples is not None:
        fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="tpd")
        fig_table = plot_parameter_table(stats, mech)
    fig_recon = None
    if recon is not None:
        # One label per curve, using the heating rate β.
        labels = [f"\u03b2 = {b:.2f} K/s" for b in betas] if betas else None
        fig_recon = plot_reconstruction(
            recon["observed"], recon["reconstructed"], domain="tpd",
            nrmses=recon.get("nrmse"), r2s=recon.get("r2"),
            scan_labels=labels,
        )
    return fig_posteriors, fig_table, fig_recon
def _on_tpd_mechanism_change(mech_choice, state):
    """Callback when user selects a different TPD mechanism from the dropdown."""
    if not state or not mech_choice:
        return None, None, None
    # Dropdown labels look like "Mechanism (12.3%)" — keep only the name.
    mech = mech_choice.split(" (")[0]
    result = state["result"]
    betas = state["betas"]
    temperatures = [np.asarray(t) for t in state["temperatures"]]
    rates = [np.asarray(r) for r in state["rates"]]
    recon = get_predictor().reconstruct_tpd(
        result, temperatures, rates, betas, mechanism=mech)
    return _render_tpd_mechanism(mech, result, recon, betas)
# =========================================================================
# Shared helpers
# =========================================================================
def download_results(result_text):
    """Create a downloadable JSON from the summary.

    Writes *result_text* to a persistent temporary .json file and returns
    its path (Gradio serves the file from that path). Returns None when
    there is nothing to download.
    """
    if not result_text:
        return None
    fd, path = tempfile.mkstemp(suffix=".json", text=True)
    with os.fdopen(fd, "w") as fh:
        fh.write(result_text)
    return path
# =========================================================================
# Gradio UI
# =========================================================================
def _build_ec_output_section(prefix):
    """Build shared output components for one EC input tab.

    Returns (probs, state, mech_dd, posteriors, param_table, recon, conc).
    """
    gr.Markdown("---")
    probs_plot = gr.Plot(label="Mechanism Classification")
    analysis_state = gr.State(value=None)
    mech_selector = gr.Dropdown(
        label="Select mechanism to view details",
        choices=[],
        interactive=True,
    )
    # Posteriors and the estimate table share a row; the reconstruction and
    # concentration figures take full width below.
    with gr.Row():
        posteriors_plot = gr.Plot(label="Parameter Posteriors")
        table_plot = gr.Plot(label="Parameter Estimates")
    recon_plot = gr.Plot(label="Signal Reconstruction")
    conc_plot = gr.Plot(label="Surface Concentration Profiles")
    return (probs_plot, analysis_state, mech_selector,
            posteriors_plot, table_plot, recon_plot, conc_plot)
def _build_tpd_output_section(prefix):
    """Build shared output components for one TPD input tab.

    Returns (probs, state, mech_dd, posteriors, param_table, recon).
    """
    gr.Markdown("---")
    probs_plot = gr.Plot(label="Mechanism Classification")
    analysis_state = gr.State(value=None)
    mech_selector = gr.Dropdown(
        label="Select mechanism to view details",
        choices=[],
        interactive=True,
    )
    # Posteriors and the estimate table share a row; the reconstruction
    # figure takes full width below.
    with gr.Row():
        posteriors_plot = gr.Plot(label="Parameter Posteriors")
        table_plot = gr.Plot(label="Parameter Estimates")
    recon_plot = gr.Plot(label="Signal Reconstruction")
    return (probs_plot, analysis_state, mech_selector,
            posteriors_plot, table_plot, recon_plot)
# Styling for the Gradio UI: header typography, info/summary cards, and
# hiding the default Gradio footer.
CUSTOM_CSS = """
.main-header { text-align: center; padding: 24px 16px 8px 16px; }
.main-header h1 { font-size: 2.2em; margin-bottom: 2px; letter-spacing: -0.5px; }
.main-header p { color: #6B7280; font-size: 1.05em; max-width: 720px; margin: 0 auto; line-height: 1.5; }
.section-heading { margin-top: 24px !important; margin-bottom: 4px !important; }
.preproc-card { background: #F0F9FF; border: 1px solid #BAE6FD; border-radius: 8px; padding: 10px 16px; margin-top: 12px; font-size: 0.93em; color: #0C4A6E; }
.summary-card { border: 1px solid #E5E7EB; border-radius: 10px; padding: 16px 20px; background: #FAFBFC; }
.summary-card table { width: 100%; font-size: 0.92em; }
.summary-card td, .summary-card th { padding: 4px 8px; }
footer { display: none !important; }
"""
def build_app():
    """Assemble and return the Gradio Blocks application.

    Layout: a CV tab and a TPD tab, each with a CSV sub-tab and a
    From-Image sub-tab, plus an About tab. Event wiring connects the
    Analyze buttons, mechanism dropdowns, and example loaders to the
    callbacks defined above.
    """
    with gr.Blocks(
        title="ECFlow — Bayesian Inference for Electrochemistry & Catalysis",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="slate",
            font=gr.themes.GoogleFont("Inter"),
        ),
        css=CUSTOM_CSS,
    ) as app:
        gr.HTML(
            "<div class='main-header'>"
            "<h1>⚡ ECFlow</h1>"
            "<p>Upload cyclic voltammetry or TPD data to <strong>identify the reaction mechanism</strong> "
            "and <strong>infer kinetic parameters</strong> with full Bayesian uncertainty.</p>"
            "</div>"
        )
        with gr.Tabs():
            # =================================================================
            # Tab 1: CV Analysis
            # =================================================================
            with gr.Tab("CV Analysis"):
                with gr.Tabs():
                    # --- CSV upload mode (primary) ---
                    with gr.Tab("CSV Data"):
                        gr.Markdown(
                            "Upload CSV files exported from your potentiostat. "
                            "One file per scan rate. For best accuracy, upload multiple scan rates.\n\n"
                            "**Expected columns:** `Potential (V)`, `Current (A/mA/µA)`, and optionally `Time (s)`.\n\n"
                            "Example header: `Time (s), Potential (V), Current (A)`\n\n"
                            "If a **Time (s)** column is present, the scan rate is "
                            "detected automatically. Otherwise, enter scan rates below."
                        )
                        # Example picker is only rendered when demo data shipped.
                        if CV_EXAMPLES:
                            with gr.Accordion("Try an example (no data needed)", open=True):
                                with gr.Row():
                                    cv_example_dd = gr.Dropdown(
                                        label="Select example",
                                        choices=list(CV_EXAMPLES.keys()),
                                        value=None,
                                        interactive=True,
                                        scale=3,
                                    )
                                    cv_example_btn = gr.Button(
                                        "Load Example", variant="secondary", scale=1,
                                    )
                        cv_files = gr.File(
                            label="CSV files (one per scan rate)",
                            file_count="multiple",
                            file_types=[".csv", ".txt"],
                        )
                        cv_rates = gr.Textbox(
                            label="Scan rates (V/s), comma-separated",
                            placeholder="e.g., 0.01, 0.1, 1.0 (leave empty if CSV has time column)",
                            value="",
                        )
                        with gr.Accordion("Advanced parameters", open=False):
                            with gr.Row():
                                cv_E0 = gr.Number(
                                    label="Formal potential E₀ (V)",
                                    value=None,
                                    info="Auto-estimated from peak positions if empty",
                                )
                                cv_T = gr.Number(label="Temperature (K)", value=298.15)
                                cv_A = gr.Number(label="Electrode area (cm²)", value=0.0707)
                            with gr.Row():
                                cv_C = gr.Number(label="Concentration (mM)", value=1.0)
                                cv_D = gr.Number(
                                    label="Diffusion coeff D (cm²/s)",
                                    value=None,
                                    info="Estimated via Randles-Ševčík if empty",
                                )
                                cv_n = gr.Number(label="Number of electrons", value=1, precision=0)
                        with gr.Row():
                            cv_nsamples = gr.Slider(
                                100, 2000, value=500, step=100,
                                label="Posterior samples",
                            )
                            cv_btn = gr.Button("Analyze", variant="primary", scale=2)
                        (cv_probs, cv_state,
                         cv_mech_dd, cv_posteriors, cv_param_table,
                         cv_recon, cv_conc) = _build_ec_output_section("cv")
                        ec_outputs = [
                            cv_probs, cv_state,
                            cv_mech_dd, cv_posteriors, cv_param_table,
                            cv_recon, cv_conc,
                        ]
                        cv_btn.click(
                            analyze_cv,
                            inputs=[
                                cv_files, cv_rates, cv_E0, cv_T,
                                cv_A, cv_C, cv_D, cv_n, cv_nsamples,
                            ],
                            outputs=ec_outputs,
                        )
                        cv_mech_dd.change(
                            _on_ec_mechanism_change,
                            inputs=[cv_mech_dd, cv_state],
                            outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc],
                        )
                        if CV_EXAMPLES:
                            cv_example_btn.click(
                                _load_cv_example,
                                inputs=[cv_example_dd],
                                outputs=[
                                    cv_files, cv_rates, cv_E0, cv_T,
                                    cv_A, cv_C, cv_D, cv_n,
                                ],
                            )
                    # --- Image mode ---
                    with gr.Tab("From Image"):
                        gr.Markdown(
                            "Upload plot images of CVs (potential in V on x-axis, "
                            "current in A/mA/µA on y-axis). **One image per scan rate.** "
                            "For best accuracy, upload multiple scan rates.\n\n"
                            "Axis bounds are **auto-detected** via OCR — "
                            "override in Advanced if needed."
                        )
                        if CV_IMG_EXAMPLES:
                            with gr.Accordion("Try an example (click to load)", open=True):
                                with gr.Row():
                                    cv_img_example_dd = gr.Dropdown(
                                        label="Select mechanism",
                                        choices=list(CV_IMG_EXAMPLES.keys()),
                                        value=None,
                                        interactive=True,
                                        scale=3,
                                    )
                                    cv_img_example_btn = gr.Button(
                                        "Load Example", variant="secondary", scale=1,
                                    )
                                cv_img_example_gallery = gr.Gallery(
                                    label="Example output (classification → posteriors → reconstruction)",
                                    columns=3, height=220, object_fit="contain",
                                    interactive=False,
                                )
                        cv_img_files = gr.File(
                            label="Plot images (one per scan rate)",
                            file_count="multiple",
                            file_types=["image"],
                        )
                        cv_img_scan_rate = gr.Textbox(
                            label="Scan rates (V/s), comma-separated",
                            placeholder="e.g., 0.01, 0.1, 1.0",
                            value="",
                        )
                        with gr.Accordion("Advanced parameters", open=False):
                            with gr.Row():
                                cv_img_E0 = gr.Number(
                                    label="Formal potential E₀ (V)",
                                    value=None,
                                    info="Auto-estimated from peaks if empty",
                                )
                                cv_img_threshold = gr.Slider(
                                    0, 255, value=0, step=1,
                                    label="Binarization threshold (0 = auto)",
                                )
                                cv_img_current_unit = gr.Dropdown(
                                    label="Current unit on y-axis",
                                    choices=["auto", "µA", "mA", "A", "nA"],
                                    value="auto",
                                    info="Select the unit shown on the y-axis of your plot",
                                )
                            with gr.Accordion("Axis overrides", open=False):
                                gr.Markdown(
                                    "Leave at 0 to auto-detect from each image. "
                                    "Override if OCR detection is inaccurate. "
                                    "Overrides apply to **all** images."
                                )
                                with gr.Row():
                                    cv_img_xmin = gr.Number(label="E min (V)", value=None)
                                    cv_img_xmax = gr.Number(label="E max (V)", value=None)
                                    cv_img_ymin = gr.Number(label="I min", value=None)
                                    cv_img_ymax = gr.Number(label="I max", value=None)
                        with gr.Row():
                            cv_img_nsamples = gr.Slider(
                                100, 2000, value=500, step=100,
                                label="Posterior samples",
                            )
                            cv_img_btn = gr.Button("Analyze", variant="primary", scale=2)
                        (cv_img_probs, cv_img_state,
                         cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
                         cv_img_recon, cv_img_conc) = _build_ec_output_section("cv_img")
                        ec_img_outputs = [
                            cv_img_probs, cv_img_state,
                            cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
                            cv_img_recon, cv_img_conc,
                        ]
                        cv_img_btn.click(
                            analyze_cv_image,
                            inputs=[
                                cv_img_files, cv_img_scan_rate, cv_img_E0,
                                cv_img_threshold, cv_img_current_unit,
                                cv_img_nsamples,
                                cv_img_xmin, cv_img_xmax,
                                cv_img_ymin, cv_img_ymax,
                            ],
                            outputs=ec_img_outputs,
                        )
                        cv_img_mech_dd.change(
                            _on_ec_mechanism_change,
                            inputs=[cv_img_mech_dd, cv_img_state],
                            outputs=[cv_img_posteriors, cv_img_param_table, cv_img_recon, cv_img_conc],
                        )
                        if CV_IMG_EXAMPLES:
                            # Button fills inputs AND the preview gallery.
                            def _on_cv_img_example_select(mech_name):
                                files, rates = _load_cv_image_example(mech_name)
                                renders = CV_OUTPUT_RENDERS.get(mech_name, {})
                                gallery_imgs = [
                                    renders[k] for k in
                                    ["classification", "posteriors", "reconstruction"]
                                    if k in renders
                                ]
                                return files, rates, gallery_imgs
                            cv_img_example_btn.click(
                                _on_cv_img_example_select,
                                inputs=[cv_img_example_dd],
                                outputs=[cv_img_files, cv_img_scan_rate, cv_img_example_gallery],
                            )
                            # Selecting a mechanism previews its renders immediately.
                            cv_img_example_dd.change(
                                lambda m: [
                                    CV_OUTPUT_RENDERS.get(m, {}).get(k)
                                    for k in ["classification", "posteriors", "reconstruction"]
                                    if CV_OUTPUT_RENDERS.get(m, {}).get(k)
                                ],
                                inputs=[cv_img_example_dd],
                                outputs=[cv_img_example_gallery],
                            )
            # =================================================================
            # Tab 2: TPD Analysis
            # =================================================================
            with gr.Tab("TPD Analysis"):
                with gr.Tabs():
                    # --- CSV mode ---
                    with gr.Tab("CSV Data"):
                        gr.Markdown(
                            "Upload CSV files with **Temperature (K)** and **Signal** columns. "
                            "Optionally include a **Time (s)** column for automatic β detection. "
                            "One file per heating rate. For best accuracy, upload multiple heating rates.\n\n"
                            "Example header: `Temperature (K), Signal`\n\n"
                            "**You must provide the correct β for each file** — "
                            "the model uses β to condition inference."
                        )
                        if TPD_EXAMPLES:
                            with gr.Accordion("Try an example (no data needed)", open=True):
                                with gr.Row():
                                    tpd_example_dd = gr.Dropdown(
                                        label="Select example",
                                        choices=list(TPD_EXAMPLES.keys()),
                                        value=None,
                                        interactive=True,
                                        scale=3,
                                    )
                                    tpd_example_btn = gr.Button(
                                        "Load Example", variant="secondary", scale=1,
                                    )
                        tpd_files = gr.File(
                            label="CSV files (one per heating rate)",
                            file_count="multiple",
                            file_types=[".csv", ".txt"],
                        )
                        tpd_betas = gr.Textbox(
                            label="Heating rates β (K/s), comma-separated",
                            placeholder="e.g., 0.3, 2.6, 22.1 (leave empty if CSV has time column)",
                            value="",
                        )
                        with gr.Row():
                            tpd_nsamples = gr.Slider(
                                100, 2000, value=500, step=100,
                                label="Posterior samples",
                            )
                            tpd_btn = gr.Button("Analyze", variant="primary", scale=2)
                        (tpd_probs, tpd_state,
                         tpd_mech_dd, tpd_posteriors, tpd_param_table,
                         tpd_recon) = _build_tpd_output_section("tpd")
                        tpd_outputs = [
                            tpd_probs, tpd_state,
                            tpd_mech_dd, tpd_posteriors, tpd_param_table, tpd_recon,
                        ]
                        tpd_btn.click(
                            analyze_tpd,
                            inputs=[tpd_files, tpd_betas, tpd_nsamples],
                            outputs=tpd_outputs,
                        )
                        tpd_mech_dd.change(
                            _on_tpd_mechanism_change,
                            inputs=[tpd_mech_dd, tpd_state],
                            outputs=[tpd_posteriors, tpd_param_table, tpd_recon],
                        )
                        if TPD_EXAMPLES:
                            tpd_example_btn.click(
                                _load_tpd_example,
                                inputs=[tpd_example_dd],
                                outputs=[tpd_files, tpd_betas],
                            )
                    # --- Image mode ---
                    with gr.Tab("From Image"):
                        gr.Markdown(
                            "Upload plot images of TPD curves (temperature in K on "
                            "x-axis, signal on y-axis). **One image per heating rate.** "
                            "For best accuracy, upload multiple heating rates.\n\n"
                            "Axis bounds are **auto-detected** via OCR — "
                            "override in Advanced if needed."
                        )
                        if TPD_IMG_EXAMPLES:
                            with gr.Accordion("Try an example (click to load)", open=True):
                                with gr.Row():
                                    tpd_img_example_dd = gr.Dropdown(
                                        label="Select mechanism",
                                        choices=list(TPD_IMG_EXAMPLES.keys()),
                                        value=None,
                                        interactive=True,
                                        scale=3,
                                    )
                                    tpd_img_example_btn = gr.Button(
                                        "Load Example", variant="secondary", scale=1,
                                    )
                                tpd_img_example_gallery = gr.Gallery(
                                    label="Example output (classification → posteriors → reconstruction)",
                                    columns=3, height=220, object_fit="contain",
                                    interactive=False,
                                )
                        tpd_img_files = gr.File(
                            label="Plot images (one per heating rate)",
                            file_count="multiple",
                            file_types=["image"],
                        )
                        tpd_img_betas = gr.Textbox(
                            label="Heating rates β (K/s), comma-separated",
                            placeholder="e.g., 0.3, 2.6, 22.1",
                            value="",
                        )
                        with gr.Accordion("Advanced parameters", open=False):
                            with gr.Row():
                                tpd_img_threshold = gr.Slider(
                                    0, 255, value=0, step=1,
                                    label="Binarization threshold (0 = auto)",
                                )
                            with gr.Accordion("Axis overrides", open=False):
                                gr.Markdown(
                                    "Leave at 0 to auto-detect from each image. "
                                    "Override if OCR detection is inaccurate. "
                                    "Overrides apply to **all** images."
                                )
                                with gr.Row():
                                    tpd_img_xmin = gr.Number(label="T min (K)", value=None)
                                    tpd_img_xmax = gr.Number(label="T max (K)", value=None)
                                    tpd_img_ymin = gr.Number(label="Signal min", value=None)
                                    tpd_img_ymax = gr.Number(label="Signal max", value=None)
                        with gr.Row():
                            tpd_img_nsamples = gr.Slider(
                                100, 2000, value=500, step=100,
                                label="Posterior samples",
                            )
                            tpd_img_btn = gr.Button("Analyze", variant="primary", scale=2)
                        (tpd_img_probs, tpd_img_state,
                         tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table,
                         tpd_img_recon) = _build_tpd_output_section("tpd_img")
                        tpd_img_outputs = [
                            tpd_img_probs, tpd_img_state,
                            tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table, tpd_img_recon,
                        ]
                        tpd_img_btn.click(
                            analyze_tpd_image,
                            inputs=[
                                tpd_img_files, tpd_img_betas,
                                tpd_img_threshold, tpd_img_nsamples,
                                tpd_img_xmin, tpd_img_xmax,
                                tpd_img_ymin, tpd_img_ymax,
                            ],
                            outputs=tpd_img_outputs,
                        )
                        tpd_img_mech_dd.change(
                            _on_tpd_mechanism_change,
                            inputs=[tpd_img_mech_dd, tpd_img_state],
                            outputs=[tpd_img_posteriors, tpd_img_param_table, tpd_img_recon],
                        )
                        if TPD_IMG_EXAMPLES:
                            # Button fills inputs AND the preview gallery.
                            def _on_tpd_img_example_select(mech_name):
                                files, betas = _load_tpd_image_example(mech_name)
                                renders = TPD_OUTPUT_RENDERS.get(mech_name, {})
                                gallery_imgs = [
                                    renders[k] for k in
                                    ["classification", "posteriors", "reconstruction"]
                                    if k in renders
                                ]
                                return files, betas, gallery_imgs
                            tpd_img_example_btn.click(
                                _on_tpd_img_example_select,
                                inputs=[tpd_img_example_dd],
                                outputs=[tpd_img_files, tpd_img_betas, tpd_img_example_gallery],
                            )
                            # Selecting a mechanism previews its renders immediately.
                            tpd_img_example_dd.change(
                                lambda m: [
                                    TPD_OUTPUT_RENDERS.get(m, {}).get(k)
                                    for k in ["classification", "posteriors", "reconstruction"]
                                    if TPD_OUTPUT_RENDERS.get(m, {}).get(k)
                                ],
                                inputs=[tpd_img_example_dd],
                                outputs=[tpd_img_example_gallery],
                            )
            # =================================================================
            # Tab 3: About
            # =================================================================
            with gr.Tab("About"):
                gr.Markdown("""
## How It Works
ECFlow uses **conditional normalizing flows** with a **Set Transformer** encoder to perform amortized Bayesian inference.
Given one or more experimental curves, it simultaneously classifies the reaction mechanism and produces
full posterior distributions over kinetic parameters — in a single forward pass.
| | Electrochemistry (CV) | Catalysis (TPD) |
|---|---|---|
| **Mechanisms** | Nernst, Butler–Volmer, Marcus–Hush–Chidsey, Adsorption, EC, Langmuir–Hinshelwood | First-order, Second-order, LH Surface, Mars–van Krevelen, Coverage-dependent, Diffusion-limited |
| **Inference** | ~50 ms on CPU | ~50 ms on CPU |
| **Calibration** | 89–94 % coverage at 90 % nominal | Conformal coverage verified |
Training data is generated from physics-based simulators (Crank–Nicolson for CV, ODE integrators for TPD).
Posteriors are calibrated via a coverage-aware loss with per-parameter inverse-spread weighting.
### Citation
```
Yan, B. (2026). ECFlow: Amortized Bayesian Inference for Mechanism Identification
and Parameter Estimation in Electrochemistry and Catalysis via Conditional
Normalizing Flows. [Preprint]
```
Built at MIT. Code and paper at [github.com/bingyan/ECFlow](https://github.com/bingyan/ECFlow).
""")
    return app
# =========================================================================
# Entry point
# =========================================================================
if __name__ == "__main__":
    app = build_app()
    # Bind on all interfaces (container/remote friendly) on the default
    # Gradio port; show_error surfaces callback tracebacks in the UI.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )