Bing Yan commited on
Commit ·
7fcacad
1
Parent(s): e854c2f
Add example loader for CV and TPD demo data
Browse files- Add dropdown + "Load Example" button in both CV and TPD tabs
so users can try the demo without uploading their own data
- Auto-discovers 6 CV and 6 TPD examples from demo_data/ metadata
- Pre-fills file uploads, scan rates/heating rates, and physical params
- Remove "in milliseconds" from subtitle (demo runs on CPU)
Made-with: Cursor
- app.py +113 -1
- multi_mechanism_model.py +8 -1
app.py
CHANGED
|
@@ -35,6 +35,7 @@ from plotting import (
|
|
| 35 |
# Model paths (relative to repo root)
|
| 36 |
# ---------------------------------------------------------------------------
|
| 37 |
REPO_ROOT = Path(__file__).resolve().parent
|
|
|
|
| 38 |
|
| 39 |
EC_CHECKPOINT = REPO_ROOT / "checkpoints" / "ec_best.pt"
|
| 40 |
TPD_CHECKPOINT = REPO_ROOT / "checkpoints" / "tpd_best.pt"
|
|
@@ -43,6 +44,76 @@ TPD_CHECKPOINT = REPO_ROOT / "checkpoints" / "tpd_best.pt"
|
|
| 43 |
EC_CHECKPOINT = Path(os.environ.get("ECFLOW_EC_CHECKPOINT", str(EC_CHECKPOINT)))
|
| 44 |
TPD_CHECKPOINT = Path(os.environ.get("ECFLOW_TPD_CHECKPOINT", str(TPD_CHECKPOINT)))
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
predictor = None
|
| 47 |
|
| 48 |
|
|
@@ -611,7 +682,7 @@ def build_app():
|
|
| 611 |
"<div class='main-header'>"
|
| 612 |
"<h1>⚡ ECFlow</h1>"
|
| 613 |
"<p>Upload cyclic voltammetry or TPD data to <strong>identify the reaction mechanism</strong> "
|
| 614 |
-
"and <strong>infer kinetic parameters</strong> with full Bayesian uncertainty
|
| 615 |
"</div>"
|
| 616 |
)
|
| 617 |
|
|
@@ -631,6 +702,19 @@ def build_app():
|
|
| 631 |
"If a **Time (s)** column is present, the scan rate is "
|
| 632 |
"detected automatically. Otherwise, enter scan rates below."
|
| 633 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
cv_files = gr.File(
|
| 635 |
label="CSV files (one per scan rate)",
|
| 636 |
file_count="multiple",
|
|
@@ -687,6 +771,15 @@ def build_app():
|
|
| 687 |
inputs=[cv_mech_dd, cv_state],
|
| 688 |
outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc],
|
| 689 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
|
| 691 |
# --- Image mode ---
|
| 692 |
with gr.Tab("From Image"):
|
|
@@ -783,6 +876,19 @@ def build_app():
|
|
| 783 |
"**You must provide the correct β for each file** — "
|
| 784 |
"the model uses β to condition inference."
|
| 785 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 786 |
tpd_files = gr.File(
|
| 787 |
label="CSV files (one per heating rate)",
|
| 788 |
file_count="multiple",
|
|
@@ -818,6 +924,12 @@ def build_app():
|
|
| 818 |
inputs=[tpd_mech_dd, tpd_state],
|
| 819 |
outputs=[tpd_posteriors, tpd_param_table, tpd_recon],
|
| 820 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 821 |
|
| 822 |
# --- Image mode ---
|
| 823 |
with gr.Tab("From Image"):
|
|
|
|
| 35 |
# Model paths (relative to repo root)
|
| 36 |
# ---------------------------------------------------------------------------
|
| 37 |
REPO_ROOT = Path(__file__).resolve().parent
|
| 38 |
+
DEMO_DIR = REPO_ROOT / "demo_data"
|
| 39 |
|
| 40 |
EC_CHECKPOINT = REPO_ROOT / "checkpoints" / "ec_best.pt"
|
| 41 |
TPD_CHECKPOINT = REPO_ROOT / "checkpoints" / "tpd_best.pt"
|
|
|
|
| 44 |
EC_CHECKPOINT = Path(os.environ.get("ECFLOW_EC_CHECKPOINT", str(EC_CHECKPOINT)))
|
| 45 |
TPD_CHECKPOINT = Path(os.environ.get("ECFLOW_TPD_CHECKPOINT", str(TPD_CHECKPOINT)))
|
| 46 |
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Demo examples
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def _discover_examples():
|
| 52 |
+
"""Scan demo_data/ for metadata files and build example catalogs."""
|
| 53 |
+
cv_examples = {}
|
| 54 |
+
tpd_examples = {}
|
| 55 |
+
if not DEMO_DIR.is_dir():
|
| 56 |
+
return cv_examples, tpd_examples
|
| 57 |
+
for meta_path in sorted(DEMO_DIR.glob("*_metadata.json")):
|
| 58 |
+
with open(meta_path) as f:
|
| 59 |
+
meta = json.load(f)
|
| 60 |
+
mech = meta["mechanism"]
|
| 61 |
+
csv_files = [str(DEMO_DIR / fn) for fn in meta["csv_files"]]
|
| 62 |
+
if meta_path.name.startswith("ec_"):
|
| 63 |
+
rates = meta.get("scan_rates_Vs", [])
|
| 64 |
+
rates_str = ", ".join(f"{r:.4g}" for r in rates)
|
| 65 |
+
phys = meta.get("physical_params", {})
|
| 66 |
+
cv_examples[f"CV — {mech}"] = {
|
| 67 |
+
"files": csv_files,
|
| 68 |
+
"scan_rates": rates_str,
|
| 69 |
+
"E0_V": phys.get("E0_V"),
|
| 70 |
+
"T_K": phys.get("T_K", 298.15),
|
| 71 |
+
"A_cm2": phys.get("A_cm2", 0.0707),
|
| 72 |
+
"C_mM": phys.get("C_mM", 1.0),
|
| 73 |
+
"D_cm2s": phys.get("D_cm2s", 1e-5),
|
| 74 |
+
"n_electrons": phys.get("n_electrons", 1),
|
| 75 |
+
}
|
| 76 |
+
elif meta_path.name.startswith("tpd_"):
|
| 77 |
+
betas = meta.get("betas_Ks", [])
|
| 78 |
+
betas_str = ", ".join(f"{b:.4g}" for b in betas)
|
| 79 |
+
tpd_examples[f"TPD — {mech}"] = {
|
| 80 |
+
"files": csv_files,
|
| 81 |
+
"heating_rates": betas_str,
|
| 82 |
+
}
|
| 83 |
+
return cv_examples, tpd_examples
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
CV_EXAMPLES, TPD_EXAMPLES = _discover_examples()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _load_cv_example(example_name):
|
| 90 |
+
"""Return (files, scan_rates, E0, T, A, C, D, n) for the chosen CV example."""
|
| 91 |
+
if not example_name or example_name not in CV_EXAMPLES:
|
| 92 |
+
return [gr.update()] * 8
|
| 93 |
+
ex = CV_EXAMPLES[example_name]
|
| 94 |
+
return (
|
| 95 |
+
ex["files"],
|
| 96 |
+
ex["scan_rates"],
|
| 97 |
+
ex["E0_V"],
|
| 98 |
+
ex["T_K"],
|
| 99 |
+
ex["A_cm2"],
|
| 100 |
+
ex["C_mM"],
|
| 101 |
+
ex["D_cm2s"],
|
| 102 |
+
ex["n_electrons"],
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _load_tpd_example(example_name):
|
| 107 |
+
"""Return (files, heating_rates) for the chosen TPD example."""
|
| 108 |
+
if not example_name or example_name not in TPD_EXAMPLES:
|
| 109 |
+
return [gr.update()] * 2
|
| 110 |
+
ex = TPD_EXAMPLES[example_name]
|
| 111 |
+
return (
|
| 112 |
+
ex["files"],
|
| 113 |
+
ex["heating_rates"],
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
predictor = None
|
| 118 |
|
| 119 |
|
|
|
|
| 682 |
"<div class='main-header'>"
|
| 683 |
"<h1>⚡ ECFlow</h1>"
|
| 684 |
"<p>Upload cyclic voltammetry or TPD data to <strong>identify the reaction mechanism</strong> "
|
| 685 |
+
"and <strong>infer kinetic parameters</strong> with full Bayesian uncertainty.</p>"
|
| 686 |
"</div>"
|
| 687 |
)
|
| 688 |
|
|
|
|
| 702 |
"If a **Time (s)** column is present, the scan rate is "
|
| 703 |
"detected automatically. Otherwise, enter scan rates below."
|
| 704 |
)
|
| 705 |
+
if CV_EXAMPLES:
|
| 706 |
+
with gr.Accordion("Try an example (no data needed)", open=True):
|
| 707 |
+
with gr.Row():
|
| 708 |
+
cv_example_dd = gr.Dropdown(
|
| 709 |
+
label="Select example",
|
| 710 |
+
choices=list(CV_EXAMPLES.keys()),
|
| 711 |
+
value=None,
|
| 712 |
+
interactive=True,
|
| 713 |
+
scale=3,
|
| 714 |
+
)
|
| 715 |
+
cv_example_btn = gr.Button(
|
| 716 |
+
"Load Example", variant="secondary", scale=1,
|
| 717 |
+
)
|
| 718 |
cv_files = gr.File(
|
| 719 |
label="CSV files (one per scan rate)",
|
| 720 |
file_count="multiple",
|
|
|
|
| 771 |
inputs=[cv_mech_dd, cv_state],
|
| 772 |
outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc],
|
| 773 |
)
|
| 774 |
+
if CV_EXAMPLES:
|
| 775 |
+
cv_example_btn.click(
|
| 776 |
+
_load_cv_example,
|
| 777 |
+
inputs=[cv_example_dd],
|
| 778 |
+
outputs=[
|
| 779 |
+
cv_files, cv_rates, cv_E0, cv_T,
|
| 780 |
+
cv_A, cv_C, cv_D, cv_n,
|
| 781 |
+
],
|
| 782 |
+
)
|
| 783 |
|
| 784 |
# --- Image mode ---
|
| 785 |
with gr.Tab("From Image"):
|
|
|
|
| 876 |
"**You must provide the correct β for each file** — "
|
| 877 |
"the model uses β to condition inference."
|
| 878 |
)
|
| 879 |
+
if TPD_EXAMPLES:
|
| 880 |
+
with gr.Accordion("Try an example (no data needed)", open=True):
|
| 881 |
+
with gr.Row():
|
| 882 |
+
tpd_example_dd = gr.Dropdown(
|
| 883 |
+
label="Select example",
|
| 884 |
+
choices=list(TPD_EXAMPLES.keys()),
|
| 885 |
+
value=None,
|
| 886 |
+
interactive=True,
|
| 887 |
+
scale=3,
|
| 888 |
+
)
|
| 889 |
+
tpd_example_btn = gr.Button(
|
| 890 |
+
"Load Example", variant="secondary", scale=1,
|
| 891 |
+
)
|
| 892 |
tpd_files = gr.File(
|
| 893 |
label="CSV files (one per heating rate)",
|
| 894 |
file_count="multiple",
|
|
|
|
| 924 |
inputs=[tpd_mech_dd, tpd_state],
|
| 925 |
outputs=[tpd_posteriors, tpd_param_table, tpd_recon],
|
| 926 |
)
|
| 927 |
+
if TPD_EXAMPLES:
|
| 928 |
+
tpd_example_btn.click(
|
| 929 |
+
_load_tpd_example,
|
| 930 |
+
inputs=[tpd_example_dd],
|
| 931 |
+
outputs=[tpd_files, tpd_betas],
|
| 932 |
+
)
|
| 933 |
|
| 934 |
# --- Image mode ---
|
| 935 |
with gr.Tab("From Image"):
|
multi_mechanism_model.py
CHANGED
|
@@ -130,6 +130,11 @@ class MultiScanEncoder(nn.Module):
|
|
| 130 |
if aggregation == 'set_transformer':
|
| 131 |
self.sab = SAB(d_context, n_heads=n_heads)
|
| 132 |
self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
elif aggregation == 'mean_pool':
|
| 134 |
pass
|
| 135 |
else:
|
|
@@ -177,7 +182,9 @@ class MultiScanEncoder(nn.Module):
|
|
| 177 |
h = self.sab(h, key_padding_mask=cv_invalid)
|
| 178 |
h = self.pma(h, key_padding_mask=cv_invalid) # [B, 1, d_context]
|
| 179 |
h = h.squeeze(1) # [B, d_context]
|
| 180 |
-
elif self.aggregation
|
|
|
|
|
|
|
| 181 |
if cv_invalid is not None:
|
| 182 |
cv_valid = (~cv_invalid).unsqueeze(-1).float() # [B, N, 1]
|
| 183 |
h = (h * cv_valid).sum(dim=1) / cv_valid.sum(dim=1).clamp(min=1)
|
|
|
|
| 130 |
if aggregation == 'set_transformer':
|
| 131 |
self.sab = SAB(d_context, n_heads=n_heads)
|
| 132 |
self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)
|
| 133 |
+
elif aggregation == 'deepsets':
|
| 134 |
+
self.phi = nn.Sequential(
|
| 135 |
+
nn.Linear(d_context, d_context),
|
| 136 |
+
nn.GELU(),
|
| 137 |
+
)
|
| 138 |
elif aggregation == 'mean_pool':
|
| 139 |
pass
|
| 140 |
else:
|
|
|
|
| 182 |
h = self.sab(h, key_padding_mask=cv_invalid)
|
| 183 |
h = self.pma(h, key_padding_mask=cv_invalid) # [B, 1, d_context]
|
| 184 |
h = h.squeeze(1) # [B, d_context]
|
| 185 |
+
elif self.aggregation in ('deepsets', 'mean_pool'):
|
| 186 |
+
if self.aggregation == 'deepsets':
|
| 187 |
+
h = self.phi(h)
|
| 188 |
if cv_invalid is not None:
|
| 189 |
cv_valid = (~cv_invalid).unsqueeze(-1).float() # [B, N, 1]
|
| 190 |
h = (h * cv_valid).sum(dim=1) / cv_valid.sum(dim=1).clamp(min=1)
|