Bing Yan commited on
Commit
7fcacad
·
1 Parent(s): e854c2f

Add example loader for CV and TPD demo data

Browse files

- Add dropdown + "Load Example" button in both CV and TPD tabs
so users can try the demo without uploading their own data
- Auto-discovers 6 CV and 6 TPD examples from demo_data/ metadata
- Pre-fills file uploads, scan rates/heating rates, and physical params
- Remove "in milliseconds" from subtitle (demo runs on CPU)

Made-with: Cursor

Files changed (2) hide show
  1. app.py +113 -1
  2. multi_mechanism_model.py +8 -1
app.py CHANGED
@@ -35,6 +35,7 @@ from plotting import (
35
  # Model paths (relative to repo root)
36
  # ---------------------------------------------------------------------------
37
  REPO_ROOT = Path(__file__).resolve().parent
 
38
 
39
  EC_CHECKPOINT = REPO_ROOT / "checkpoints" / "ec_best.pt"
40
  TPD_CHECKPOINT = REPO_ROOT / "checkpoints" / "tpd_best.pt"
@@ -43,6 +44,76 @@ TPD_CHECKPOINT = REPO_ROOT / "checkpoints" / "tpd_best.pt"
43
  EC_CHECKPOINT = Path(os.environ.get("ECFLOW_EC_CHECKPOINT", str(EC_CHECKPOINT)))
44
  TPD_CHECKPOINT = Path(os.environ.get("ECFLOW_TPD_CHECKPOINT", str(TPD_CHECKPOINT)))
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  predictor = None
47
 
48
 
@@ -611,7 +682,7 @@ def build_app():
611
  "<div class='main-header'>"
612
  "<h1>⚡ ECFlow</h1>"
613
  "<p>Upload cyclic voltammetry or TPD data to <strong>identify the reaction mechanism</strong> "
614
- "and <strong>infer kinetic parameters</strong> with full Bayesian uncertainty — in milliseconds.</p>"
615
  "</div>"
616
  )
617
 
@@ -631,6 +702,19 @@ def build_app():
631
  "If a **Time (s)** column is present, the scan rate is "
632
  "detected automatically. Otherwise, enter scan rates below."
633
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
634
  cv_files = gr.File(
635
  label="CSV files (one per scan rate)",
636
  file_count="multiple",
@@ -687,6 +771,15 @@ def build_app():
687
  inputs=[cv_mech_dd, cv_state],
688
  outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc],
689
  )
 
 
 
 
 
 
 
 
 
690
 
691
  # --- Image mode ---
692
  with gr.Tab("From Image"):
@@ -783,6 +876,19 @@ def build_app():
783
  "**You must provide the correct β for each file** — "
784
  "the model uses β to condition inference."
785
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
786
  tpd_files = gr.File(
787
  label="CSV files (one per heating rate)",
788
  file_count="multiple",
@@ -818,6 +924,12 @@ def build_app():
818
  inputs=[tpd_mech_dd, tpd_state],
819
  outputs=[tpd_posteriors, tpd_param_table, tpd_recon],
820
  )
 
 
 
 
 
 
821
 
822
  # --- Image mode ---
823
  with gr.Tab("From Image"):
 
35
  # Model paths (relative to repo root)
36
  # ---------------------------------------------------------------------------
37
  REPO_ROOT = Path(__file__).resolve().parent
38
+ DEMO_DIR = REPO_ROOT / "demo_data"
39
 
40
  EC_CHECKPOINT = REPO_ROOT / "checkpoints" / "ec_best.pt"
41
  TPD_CHECKPOINT = REPO_ROOT / "checkpoints" / "tpd_best.pt"
 
44
  EC_CHECKPOINT = Path(os.environ.get("ECFLOW_EC_CHECKPOINT", str(EC_CHECKPOINT)))
45
  TPD_CHECKPOINT = Path(os.environ.get("ECFLOW_TPD_CHECKPOINT", str(TPD_CHECKPOINT)))
46
 
47
+ # ---------------------------------------------------------------------------
48
+ # Demo examples
49
+ # ---------------------------------------------------------------------------
50
+
51
+ def _discover_examples():
52
+ """Scan demo_data/ for metadata files and build example catalogs."""
53
+ cv_examples = {}
54
+ tpd_examples = {}
55
+ if not DEMO_DIR.is_dir():
56
+ return cv_examples, tpd_examples
57
+ for meta_path in sorted(DEMO_DIR.glob("*_metadata.json")):
58
+ with open(meta_path) as f:
59
+ meta = json.load(f)
60
+ mech = meta["mechanism"]
61
+ csv_files = [str(DEMO_DIR / fn) for fn in meta["csv_files"]]
62
+ if meta_path.name.startswith("ec_"):
63
+ rates = meta.get("scan_rates_Vs", [])
64
+ rates_str = ", ".join(f"{r:.4g}" for r in rates)
65
+ phys = meta.get("physical_params", {})
66
+ cv_examples[f"CV — {mech}"] = {
67
+ "files": csv_files,
68
+ "scan_rates": rates_str,
69
+ "E0_V": phys.get("E0_V"),
70
+ "T_K": phys.get("T_K", 298.15),
71
+ "A_cm2": phys.get("A_cm2", 0.0707),
72
+ "C_mM": phys.get("C_mM", 1.0),
73
+ "D_cm2s": phys.get("D_cm2s", 1e-5),
74
+ "n_electrons": phys.get("n_electrons", 1),
75
+ }
76
+ elif meta_path.name.startswith("tpd_"):
77
+ betas = meta.get("betas_Ks", [])
78
+ betas_str = ", ".join(f"{b:.4g}" for b in betas)
79
+ tpd_examples[f"TPD — {mech}"] = {
80
+ "files": csv_files,
81
+ "heating_rates": betas_str,
82
+ }
83
+ return cv_examples, tpd_examples
84
+
85
+
86
+ CV_EXAMPLES, TPD_EXAMPLES = _discover_examples()
87
+
88
+
89
+ def _load_cv_example(example_name):
90
+ """Return (files, scan_rates, E0, T, A, C, D, n) for the chosen CV example."""
91
+ if not example_name or example_name not in CV_EXAMPLES:
92
+ return [gr.update()] * 8
93
+ ex = CV_EXAMPLES[example_name]
94
+ return (
95
+ ex["files"],
96
+ ex["scan_rates"],
97
+ ex["E0_V"],
98
+ ex["T_K"],
99
+ ex["A_cm2"],
100
+ ex["C_mM"],
101
+ ex["D_cm2s"],
102
+ ex["n_electrons"],
103
+ )
104
+
105
+
106
+ def _load_tpd_example(example_name):
107
+ """Return (files, heating_rates) for the chosen TPD example."""
108
+ if not example_name or example_name not in TPD_EXAMPLES:
109
+ return [gr.update()] * 2
110
+ ex = TPD_EXAMPLES[example_name]
111
+ return (
112
+ ex["files"],
113
+ ex["heating_rates"],
114
+ )
115
+
116
+
117
  predictor = None
118
 
119
 
 
682
  "<div class='main-header'>"
683
  "<h1>⚡ ECFlow</h1>"
684
  "<p>Upload cyclic voltammetry or TPD data to <strong>identify the reaction mechanism</strong> "
685
+ "and <strong>infer kinetic parameters</strong> with full Bayesian uncertainty.</p>"
686
  "</div>"
687
  )
688
 
 
702
  "If a **Time (s)** column is present, the scan rate is "
703
  "detected automatically. Otherwise, enter scan rates below."
704
  )
705
+ if CV_EXAMPLES:
706
+ with gr.Accordion("Try an example (no data needed)", open=True):
707
+ with gr.Row():
708
+ cv_example_dd = gr.Dropdown(
709
+ label="Select example",
710
+ choices=list(CV_EXAMPLES.keys()),
711
+ value=None,
712
+ interactive=True,
713
+ scale=3,
714
+ )
715
+ cv_example_btn = gr.Button(
716
+ "Load Example", variant="secondary", scale=1,
717
+ )
718
  cv_files = gr.File(
719
  label="CSV files (one per scan rate)",
720
  file_count="multiple",
 
771
  inputs=[cv_mech_dd, cv_state],
772
  outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc],
773
  )
774
+ if CV_EXAMPLES:
775
+ cv_example_btn.click(
776
+ _load_cv_example,
777
+ inputs=[cv_example_dd],
778
+ outputs=[
779
+ cv_files, cv_rates, cv_E0, cv_T,
780
+ cv_A, cv_C, cv_D, cv_n,
781
+ ],
782
+ )
783
 
784
  # --- Image mode ---
785
  with gr.Tab("From Image"):
 
876
  "**You must provide the correct β for each file** — "
877
  "the model uses β to condition inference."
878
  )
879
+ if TPD_EXAMPLES:
880
+ with gr.Accordion("Try an example (no data needed)", open=True):
881
+ with gr.Row():
882
+ tpd_example_dd = gr.Dropdown(
883
+ label="Select example",
884
+ choices=list(TPD_EXAMPLES.keys()),
885
+ value=None,
886
+ interactive=True,
887
+ scale=3,
888
+ )
889
+ tpd_example_btn = gr.Button(
890
+ "Load Example", variant="secondary", scale=1,
891
+ )
892
  tpd_files = gr.File(
893
  label="CSV files (one per heating rate)",
894
  file_count="multiple",
 
924
  inputs=[tpd_mech_dd, tpd_state],
925
  outputs=[tpd_posteriors, tpd_param_table, tpd_recon],
926
  )
927
+ if TPD_EXAMPLES:
928
+ tpd_example_btn.click(
929
+ _load_tpd_example,
930
+ inputs=[tpd_example_dd],
931
+ outputs=[tpd_files, tpd_betas],
932
+ )
933
 
934
  # --- Image mode ---
935
  with gr.Tab("From Image"):
multi_mechanism_model.py CHANGED
@@ -130,6 +130,11 @@ class MultiScanEncoder(nn.Module):
130
  if aggregation == 'set_transformer':
131
  self.sab = SAB(d_context, n_heads=n_heads)
132
  self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)
 
 
 
 
 
133
  elif aggregation == 'mean_pool':
134
  pass
135
  else:
@@ -177,7 +182,9 @@ class MultiScanEncoder(nn.Module):
177
  h = self.sab(h, key_padding_mask=cv_invalid)
178
  h = self.pma(h, key_padding_mask=cv_invalid) # [B, 1, d_context]
179
  h = h.squeeze(1) # [B, d_context]
180
- elif self.aggregation == 'mean_pool':
 
 
181
  if cv_invalid is not None:
182
  cv_valid = (~cv_invalid).unsqueeze(-1).float() # [B, N, 1]
183
  h = (h * cv_valid).sum(dim=1) / cv_valid.sum(dim=1).clamp(min=1)
 
130
  if aggregation == 'set_transformer':
131
  self.sab = SAB(d_context, n_heads=n_heads)
132
  self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)
133
+ elif aggregation == 'deepsets':
134
+ self.phi = nn.Sequential(
135
+ nn.Linear(d_context, d_context),
136
+ nn.GELU(),
137
+ )
138
  elif aggregation == 'mean_pool':
139
  pass
140
  else:
 
182
  h = self.sab(h, key_padding_mask=cv_invalid)
183
  h = self.pma(h, key_padding_mask=cv_invalid) # [B, 1, d_context]
184
  h = h.squeeze(1) # [B, d_context]
185
+ elif self.aggregation in ('deepsets', 'mean_pool'):
186
+ if self.aggregation == 'deepsets':
187
+ h = self.phi(h)
188
  if cv_invalid is not None:
189
  cv_valid = (~cv_invalid).unsqueeze(-1).float() # [B, N, 1]
190
  h = (h * cv_valid).sum(dim=1) / cv_valid.sum(dim=1).clamp(min=1)