kaveh committed
Commit 4d886f4 · 1 Parent(s): 756cec6

Updated. First version.

.dockerignore ADDED
@@ -0,0 +1,12 @@
+ .git
+ .venv
+ venv
+ __pycache__
+ *.py[cod]
+ *.egg-info
+ .pytest_cache
+ .mypy_cache
+ .ruff_cache
+ .cursor
+ *.ipynb_checkpoints
+ notebooks
.gitignore CHANGED
@@ -1,3 +1,10 @@
  __pycache__/
 
- .DS_Store
+ .DS_Store
+ .venv/
+ venv/
+
+ # Precomputed explorer artifacts (regenerate with scripts/precompute_streamlit_cache.py)
+ streamlit_hf/cache/*.pkl
+ streamlit_hf/cache/*.parquet
+ streamlit_hf/cache/*.csv
Dockerfile ADDED
@@ -0,0 +1,17 @@
+ # Hugging Face Spaces (Streamlit + Docker). Port 7860.
+ # Build context: repository root. Upload `streamlit_hf/cache/*` (pickles + parquet) via Git LFS or CI.
+
+ FROM python:3.11-slim-bookworm
+
+ WORKDIR /app
+
+ COPY streamlit_hf/requirements-docker.txt /app/requirements-docker.txt
+ RUN pip install --no-cache-dir -r /app/requirements-docker.txt
+
+ COPY . /app
+
+ ENV PYTHONPATH=/app
+ ENV STREAMLIT_SERVER_HEADLESS=true
+ EXPOSE 7860
+
+ CMD ["streamlit", "run", "streamlit_hf/app.py", "--server.port", "7860", "--server.address", "0.0.0.0", "--browser.gatherUsageStats", "false"]
README.md ADDED
@@ -0,0 +1,59 @@
+ ---
+ title: FateFormer Explorer
+ short_description: Streamlit app to explore multimodal single-cell fate modeling (RNA, ATAC, metabolic flux, attention, and rankings).
+ emoji: 🧬
+ colorFrom: violet
+ colorTo: indigo
+ tags:
+ - streamlit
+ - single-cell
+ - multi-omics
+ - genomics
+ - atac-seq
+ - rna-seq
+ - metabolic-modeling
+ - deep-learning
+ - biology
+ license: mit
+ sdk: docker
+ app_port: 7860
+ ---
+
+ # FateFormerApp
+
+ ## Interactive explorer (Streamlit)
+
+ From the repo root, with the project virtualenv activated:
+
+ ```bash
+ PYTHONPATH=. streamlit run streamlit_hf/app.py
+ ```
+
+ The default local port is **8501**. The **Dockerfile** (and Hugging Face Space card above) use **7860** to match Spaces.
+
+ ### Updating results after new experiments (no code changes)
+
+ The app reads **fixed paths**. Replace files under `streamlit_hf/cache/` using the **same filenames**; then **restart Streamlit** (or do a hard refresh) so the new data loads.
+
+ | File | What it drives |
+ |------|----------------|
+ | `streamlit_hf/cache/latent_umap.pkl` | Single-Cell Explorer (UMAP) |
+ | `streamlit_hf/cache/df_features.parquet` | Feature insights + Flux analysis |
+ | `streamlit_hf/cache/attention_summary.pkl` | “Attention vs prediction” in Feature insights |
+ | `streamlit_hf/cache/attention_feature_ranks.pkl` | Optional; attention lists also live inside `attention_summary.pkl` |
+
+ You can also keep `analysis/df_features.csv` in sync for your own workflows; the UI **prefers** `streamlit_hf/cache/df_features.parquet` when present.
+
+ ### Regenerating caches from this repo
+
+ If you updated checkpoints, fold splits, shift pickles, or DEG tables **inside this project**, run:
+
+ ```bash
+ python scripts/precompute_streamlit_cache.py
+ ```
+
+ That script expects (among others) `ckp/*.pth`, `objects/fold_results_multi.pkl`, `objects/mutlimodal_dataset.pkl`, `objects/fi_shift_*.pkl`, and `objects/degs.pkl`. Point those inputs at your new experiment outputs **before** running the script, or copy your new pickles/CSVs into `streamlit_hf/cache/` manually as in the table above.
+
+ ### Docker / Hugging Face
+
+ See `streamlit_hf/HUGGINGFACE.md` and the root `Dockerfile`.
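The parquet-over-CSV preference described in the README above can be sketched as follows. This is an illustrative loader, not the app's actual code; `load_features` and its arguments are hypothetical names.

```python
import tempfile
from pathlib import Path

import pandas as pd


def load_features(cache_dir: Path, csv_fallback: Path) -> pd.DataFrame:
    """Illustrative loader: prefer the parquet cache, fall back to the CSV copy."""
    parquet = cache_dir / "df_features.parquet"
    if parquet.is_file():
        return pd.read_parquet(parquet)
    return pd.read_csv(csv_fallback)


# Demo with a throwaway CSV standing in for analysis/df_features.csv (toy row, not real data).
tmp = Path(tempfile.mkdtemp())
pd.DataFrame({"feature": ["Gata1"], "mean_rank": [1.0]}).to_csv(tmp / "df_features.csv", index=False)
df = load_features(tmp, tmp / "df_features.csv")  # no parquet present, so the CSV path is taken
```

Keeping the lookup path-based like this is what makes the "replace files, restart, done" update flow work without code changes.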
analysis/df_features.csv CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/analysis_plots.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ # FateFormerApp — training, precompute, and local Streamlit dev
+ torch>=2.1.0
+ numpy>=1.24.0
+ pandas>=2.0.0
+ scipy>=1.11.0
+ scikit-learn>=1.3.0
+ umap-learn>=0.5.5
+ tqdm>=4.66.0
+ anndata>=0.10.0
+ scanpy>=1.9.0
+ statsmodels>=0.14.0
+ tabulate>=0.9.0
+ streamlit>=1.40.0
+ plotly>=5.22.0
+ pyarrow>=14.0.0
scripts/precompute_streamlit_cache.py ADDED
@@ -0,0 +1,486 @@
+ #!/usr/bin/env python3
+ """
+ One-off cache builder for the Streamlit explorer.
+ Run from the repository root:
+     python scripts/precompute_streamlit_cache.py
+     python scripts/precompute_streamlit_cache.py --skip-attention  # faster: reuse objects/fi_shift_*.pkl only for df_features if attention_summary exists
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import pickle
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import umap
+
+ ROOT = Path(__file__).resolve().parents[1]
+ sys.path.insert(0, str(ROOT))
+ os.chdir(ROOT)
+
+ from data import create_dataset  # noqa: E402
+ from interpretation import attentions as att  # noqa: E402
+ from interpretation import latentspace as ls  # noqa: E402
+ from interpretation import predictions as prds  # noqa: E402
+ CACHE = ROOT / "streamlit_hf" / "cache"
+ CACHE.mkdir(parents=True, exist_ok=True)
+
+
+ def replace_fold_results_path(fold_results, ckp_root: str = "ckp"):
+     """Point checkpoints at flat `ckp/multi_seed0_fold{k}.pth` layout in this repo."""
+     for fold in fold_results:
+         ckpt_name = os.path.basename(fold["best_model_path"])
+         fold_token = next((part for part in ckpt_name.split("_") if part.startswith("fold")), "")
+         fold_idx = "".join(ch for ch in fold_token if ch.isdigit())
+         if fold_idx:
+             clean_ckpt_name = f"multi_seed0_fold{fold_idx}.pth"
+         else:
+             clean_ckpt_name = ckpt_name
+         fold["best_model_path"] = os.path.join(ckp_root, clean_ckpt_name)
+     return fold_results
+
+
+ def load_training_context():
+     with open(ROOT / "objects" / "mutlimodal_dataset.pkl", "rb") as f:
+         md = pickle.load(f)
+     X, y_label = md["X"], md["y_label"]
+     b, df_indices, pcts = md["b"], md["df_indices"], md["pcts"]
+
+     y_number = torch.tensor(
+         [{"reprogramming": 1, "dead-end": 0}[i] for i in list(y_label)],
+         dtype=torch.float32,
+     )
+     multimodal_dataset = create_dataset.MultiModalDataset(
+         X, b, y_number, df_indices, pcts, y_label
+     )
+
+     with open(ROOT / "objects" / "fold_results_multi.pkl", "rb") as f:
+         fold_results = pickle.load(f)
+     fold_results = replace_fold_results_path(fold_results)
+
+     share_config = {
+         "d_model": 128,
+         "d_ff": 16,
+         "n_heads": 8,
+         "n_encoder_layers": 2,
+         "n_batches": 3,
+         "dropout_rate": 0.0,
+     }
+     model_config_rna = {"vocab_size": 5914, "seq_len": X[0].shape[1]}
+     model_config_atac = {"vocab_size": 1, "seq_len": X[1].shape[1]}
+     model_config_flux = {"vocab_size": 1, "seq_len": X[2].shape[1]}
+     model_config_multi = {"d_model": 128, "n_heads_cls": 8, "d_ff_cls": 16}
+     model_config = {
+         "Share": share_config,
+         "RNA": model_config_rna,
+         "ATAC": model_config_atac,
+         "Flux": model_config_flux,
+         "Multi": model_config_multi,
+     }
+
+     feature_names = (
+         list(X[0].columns)
+         + ["batch_rna"]
+         + list(X[1].columns)
+         + ["batch_atac"]
+         + list(X[2].columns)
+         + ["batch_flux"]
+     )
+
+     adata_RNA_labelled = None
+     rna_pkl = ROOT / "data" / "datasets" / "rna_labelled.pkl"
+     try:
+         with open(rna_pkl, "rb") as f:
+             adata_RNA_labelled = pickle.load(f)
+     except Exception as e:
+         print(
+             f"Warning: could not load {rna_pkl} ({e}). "
+             "Sample table will omit AnnData-derived metadata (e.g. clone_id)."
+         )
+
+     return (
+         multimodal_dataset,
+         fold_results,
+         model_config,
+         feature_names,
+         adata_RNA_labelled,
+     )
+
+
+ def build_latent_umap(multimodal_dataset, fold_results, model_config, common_samples: bool = False):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     ls_v, labels, preds = ls.get_latent_space(
+         "Multi",
+         fold_results,
+         multimodal_dataset,
+         model_config,
+         device,
+         common_samples=common_samples,
+     )
+     reducer = umap.UMAP(n_components=2, random_state=0, n_neighbors=30, min_dist=1.0)
+     xy = reducer.fit_transform(ls_v)
+
+     ordered_indices: list[int] = []
+     fold_ids: list[int] = []
+     from interpretation.attentions import filter_idx  # noqa: PLC0415
+     from torch.utils.data import Subset  # noqa: PLC0415
+
+     for fold_idx, fold in enumerate(fold_results):
+         val_idx = fold["val_idx"]
+         if common_samples:
+             val_idx = filter_idx(multimodal_dataset, val_idx)
+         ordered_indices.extend(val_idx)
+         fold_ids.extend([fold_idx + 1] * len(val_idx))
+
+     labels = np.asarray(labels).ravel()
+     preds = np.asarray(preds).ravel().astype(int)
+     label_name = np.where(labels > 0.5, "reprogramming", "dead-end")
+     pred_name = np.where(preds > 0.5, "reprogramming", "dead-end")
+     correct = (preds == labels.astype(int)).astype(np.int8)
+
+     ds = multimodal_dataset
+     batch_no = np.array([int(ds.batch_no[i].item()) for i in ordered_indices], dtype=np.int32)
+     pcts = np.array([float(ds.pcts[i]) for i in ordered_indices], dtype=np.float64)
+
+     modalities = []
+     for i in ordered_indices:
+         has_r = (ds.rna_data[i] != 0).any().item()
+         has_a = (ds.atac_data[i] != 0).any().item()
+         has_f = (ds.flux_data[i] != 0).any().item()
+         s = "".join(c for c, h in (("R", has_r), ("A", has_a), ("F", has_f)) if h)
+         modalities.append(s or "None")
+
+     return {
+         "umap_x": xy[:, 0].astype(np.float32),
+         "umap_y": xy[:, 1].astype(np.float32),
+         "label_name": label_name,
+         "pred_name": pred_name,
+         "correct": correct,
+         "fold": np.array(fold_ids, dtype=np.int32),
+         "batch_no": batch_no,
+         "pct": pcts,
+         "modality": modalities,
+         "dataset_idx": np.array(ordered_indices, dtype=np.int32),
+         "common_samples": common_samples,
+     }
+
+
+ def create_combined_feature_dataframe(
+     fi_shift_rna,
+     fi_shift_atac,
+     fi_shift_flux,
+     fi_att_rna,
+     fi_att_atac,
+     fi_att_flux,
+     df_rna_degs=None,
+     df_atac_degs=None,
+     df_flux_degs=None,
+     remove_batch=True,
+ ):
+     def process_modality(shift_list, att_list, degs_df, modality_name):
+         shift_df = pd.DataFrame(shift_list, columns=["feature", "importance_shift"]).reset_index()
+         shift_df.rename(columns={"index": "rank_shift_in_modal"}, inplace=True)
+         shift_df["rank_shift_in_modal"] += 1
+
+         att_df = pd.DataFrame(att_list, columns=["feature", "importance_att"]).reset_index()
+         att_df.rename(columns={"index": "rank_att_in_modal"}, inplace=True)
+         att_df["rank_att_in_modal"] += 1
+
+         combined_df = pd.merge(shift_df, att_df, on="feature", how="outer")
+         if degs_df is not None:
+             combined_df = pd.merge(combined_df, degs_df, on="feature", how="left")
+         combined_df["modality"] = modality_name
+         return combined_df
+
+     rna_df = process_modality(fi_shift_rna, fi_att_rna, df_rna_degs, "RNA")
+     atac_df = process_modality(fi_shift_atac, fi_att_atac, df_atac_degs, "ATAC")
+     flux_df = process_modality(fi_shift_flux, fi_att_flux, df_flux_degs, "Flux")
+     all_features_df = pd.concat([rna_df, atac_df, flux_df], ignore_index=True)
+
+     if remove_batch:
+         all_features_df = all_features_df[~all_features_df["feature"].str.contains("batch", na=False)]
+
+     max_rank_modal = max(
+         all_features_df["rank_att_in_modal"].max(), all_features_df["rank_shift_in_modal"].max()
+     )
+     all_features_df[["rank_att_in_modal", "rank_shift_in_modal"]] = all_features_df[
+         ["rank_att_in_modal", "rank_shift_in_modal"]
+     ].fillna(max_rank_modal + 1)
+     all_features_df[["rank_att_in_modal", "rank_shift_in_modal"]] = all_features_df[
+         ["rank_att_in_modal", "rank_shift_in_modal"]
+     ].astype("int32")
+
+     all_features_df[["importance_att", "importance_shift"]] = (
+         all_features_df[["importance_att", "importance_shift"]].fillna(0).astype("float64")
+     )
+
+     all_features_df["rank_shift"] = (
+         all_features_df["importance_shift"].rank(ascending=False, method="first").astype("int32")
+     )
+     all_features_df["rank_att"] = (
+         all_features_df["importance_att"].rank(ascending=False, method="first").astype("int32")
+     )
+     all_features_df["mean_rank"] = all_features_df[["rank_att", "rank_shift"]].mean(axis=1)
+
+     top_th = int(all_features_df.shape[0] * 0.1) + 1
+     all_features_df["top_10_pct"] = all_features_df.apply(
+         lambda row: "both"
+         if row["rank_shift"] <= top_th and row["rank_att"] <= top_th
+         else (
+             "shift"
+             if row["rank_shift"] <= top_th
+             else ("att" if row["rank_att"] <= top_th else "None")
+         ),
+         axis=1,
+     )
+
+     float_cols = [
+         col for col in all_features_df.columns if col.startswith(("log_fc", "mean_", "std_", "pval_"))
+     ]
+     if float_cols:
+         all_features_df[float_cols] = all_features_df[float_cols].round(6)
+     all_features_df["importance_att"] = all_features_df["importance_att"].round(6)
+     all_features_df["importance_shift"] = all_features_df["importance_shift"].round(6)
+     all_features_df = all_features_df.sort_values(by="mean_rank", ascending=True)
+
+     cols = [
+         "mean_rank",
+         "feature",
+         "rank_shift",
+         "rank_att",
+         "rank_shift_in_modal",
+         "rank_att_in_modal",
+         "modality",
+         "importance_shift",
+         "importance_att",
+         "top_10_pct",
+         "mean_de",
+         "mean_re",
+         "std_de",
+         "std_re",
+         "pval",
+         "pval_adj",
+         "log_fc",
+         "group",
+         "pval_adj_log",
+         "mean_diff",
+         "pathway",
+         "module",
+     ]
+     for c in cols:
+         if c not in all_features_df.columns:
+             all_features_df[c] = np.nan
+     return all_features_df[cols]
+
+
+ def run_attention_and_fi(
+     multimodal_dataset,
+     fold_results,
+     model_config,
+     feature_names,
+     device: str,
+     adata_rna,
+ ):
+     df_samples = prds.get_sample_predictions_dataframe(
+         model_type="Multi",
+         multimodal_dataset=multimodal_dataset,
+         fold_results=fold_results,
+         model_config=model_config,
+         device=device,
+         batch_size=32,
+         threshold=0.5,
+         adata_rna=adata_rna,
+     )
+     all_indices = df_samples["ind"].tolist()
+     de_preds_indices = df_samples[df_samples["predicted_class"] == "dead-end"]["ind"].tolist()
+     re_preds_indices = df_samples[df_samples["predicted_class"] == "reprogramming"]["ind"].tolist()
+
+     print("Running flow attention (all validation)…")
+     all_layers_all = att.analyze_cls_attention(
+         "Multi",
+         fold_results,
+         multimodal_dataset,
+         model_config,
+         device=device,
+         indices=all_indices,
+         average_heads=False,
+         return_flow_attention=True,
+     )
+     print("Running flow attention (predicted dead-end)…")
+     all_layers_de = att.analyze_cls_attention(
+         "Multi",
+         fold_results,
+         multimodal_dataset,
+         model_config,
+         device=device,
+         indices=de_preds_indices,
+         average_heads=False,
+         return_flow_attention=True,
+     )
+     print("Running flow attention (predicted reprogramming)…")
+     all_layers_re = att.analyze_cls_attention(
+         "Multi",
+         fold_results,
+         multimodal_dataset,
+         model_config,
+         device=device,
+         indices=re_preds_indices,
+         average_heads=False,
+         return_flow_attention=True,
+     )
+
+     rollout_all = att.multimodal_attention_rollout(all_layers_all)
+     rollout_de = att.multimodal_attention_rollout(all_layers_de)
+     rollout_re = att.multimodal_attention_rollout(all_layers_re)
+     rollout_all = rollout_all / rollout_all.sum(dim=-1, keepdim=True)
+     rollout_de = rollout_de / rollout_de.sum(dim=-1, keepdim=True)
+     rollout_re = rollout_re / rollout_re.sum(dim=-1, keepdim=True)
+
+     # Explicit splits (notebook): RNA [:945], ATAC [945:945+884], flux rest
+     i0, i1, i2 = 0, 945, 945 + 884
+
+     def mean_vec(t):
+         return t.mean(dim=0).detach().cpu().numpy()
+
+     rollout_mean = {
+         "all": mean_vec(rollout_all),
+         "dead_end": mean_vec(rollout_de),
+         "reprogramming": mean_vec(rollout_re),
+     }
+
+     top_n_get = None
+     fi = {"all": {}, "dead_end": {}, "reprogramming": {}}
+     for name, tensor in (
+         ("all", rollout_all),
+         ("dead_end", rollout_de),
+         ("reprogramming", rollout_re),
+     ):
+         fi[name]["rna"] = att.get_top_features(
+             tensor[:, i0:i1], feature_names[i0:i1], modality="RNA", top_n=top_n_get
+         )
+         fi[name]["atac"] = att.get_top_features(
+             tensor[:, i1:i2], feature_names[i1:i2], modality="ATAC", top_n=top_n_get
+         )
+         fi[name]["flux"] = att.get_top_features(
+             tensor[:, i2:], feature_names[i2:], modality="Flux", top_n=top_n_get
+         )
+
+     summary = {
+         "feature_names": feature_names,
+         "slices": {
+             "RNA": {"start": i0, "stop": i1},
+             "ATAC": {"start": i1, "stop": i2},
+             "Flux": {"start": i2, "stop": len(feature_names)},
+         },
+         "rollout_mean": rollout_mean,
+         "fi_att": fi,
+     }
+     return summary, df_samples
+
+
+ def main():
+     ap = argparse.ArgumentParser()
+     ap.add_argument("--skip-attention", action="store_true", help="Skip attention if summary exists")
+     ap.add_argument(
+         "--common-samples",
+         action="store_true",
+         help="Use common-samples filter for latent UMAP (default: False, notebook-style)",
+     )
+     args = ap.parse_args()
+     common_samples = args.common_samples
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Device: {device}")
+
+     (
+         multimodal_dataset,
+         fold_results,
+         model_config,
+         feature_names,
+         adata_RNA_labelled,
+     ) = load_training_context()
+
+     print("Building latent UMAP bundle…")
+     latent = build_latent_umap(
+         multimodal_dataset, fold_results, model_config, common_samples=common_samples
+     )
+     with open(CACHE / "latent_umap.pkl", "wb") as f:
+         pickle.dump(latent, f)
+
+     att_path = CACHE / "attention_summary.pkl"
+     df_samples_path = CACHE / "samples.parquet"
+
+     if args.skip_attention and att_path.is_file():
+         print("Skipping attention (--skip-attention, file exists).")
+         with open(att_path, "rb") as f:
+             summary = pickle.load(f)
+     else:
+         print("Computing attention + rollout (slow)…")
+         summary, df_samples = run_attention_and_fi(
+             multimodal_dataset,
+             fold_results,
+             model_config,
+             feature_names,
+             device,
+             adata_RNA_labelled,
+         )
+         with open(att_path, "wb") as f:
+             pickle.dump(summary, f)
+         with open(CACHE / "attention_feature_ranks.pkl", "wb") as f:
+             pickle.dump(summary["fi_att"], f)
+         df_samples.to_parquet(df_samples_path, index=False)
+
+     if args.skip_attention and att_path.is_file() and not df_samples_path.is_file():
+         df_samples = prds.get_sample_predictions_dataframe(
+             model_type="Multi",
+             multimodal_dataset=multimodal_dataset,
+             fold_results=fold_results,
+             model_config=model_config,
+             device=device,
+             batch_size=32,
+             threshold=0.5,
+             adata_rna=adata_RNA_labelled,
+         )
+         df_samples.to_parquet(df_samples_path, index=False)
+
+     for name in ["fi_shift_rna.pkl", "fi_shift_atac.pkl", "fi_shift_flux.pkl"]:
+         src = ROOT / "objects" / name
+         if not src.is_file():
+             print(f"Warning: missing {src}")
+
+     with open(ROOT / "objects" / "fi_shift_rna.pkl", "rb") as f:
+         fi_shift_rna = pickle.load(f)
+     with open(ROOT / "objects" / "fi_shift_atac.pkl", "rb") as f:
+         fi_shift_atac = pickle.load(f)
+     with open(ROOT / "objects" / "fi_shift_flux.pkl", "rb") as f:
+         fi_shift_flux = pickle.load(f)
+
+     with open(ROOT / "objects" / "degs.pkl", "rb") as f:
+         degs = pickle.load(f)
+     df_rna_degs, df_atac_degs, df_flux_degs = degs[0], degs[1], degs[2]
+
+     fi = summary["fi_att"]
+     df_features = create_combined_feature_dataframe(
+         fi_shift_rna,
+         fi_shift_atac,
+         fi_shift_flux,
+         fi["all"]["rna"],
+         fi["all"]["atac"],
+         fi["all"]["flux"],
+         df_rna_degs,
+         df_atac_degs,
+         df_flux_degs,
+     )
+     df_features.to_parquet(CACHE / "df_features.parquet", index=False)
+     df_features.to_csv(ROOT / "analysis" / "df_features.csv", index=False)
+     print(f"Wrote {CACHE / 'df_features.parquet'} and analysis/df_features.csv")
+     print("Done.")
+
+
+ if __name__ == "__main__":
+     main()
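The `latent_umap.pkl` bundle written by `build_latent_umap` above is a plain dict of equal-length arrays, one entry per validation cell. A consumer can round-trip it through pickle and map it straight onto a DataFrame; the values below are synthetic stand-ins, not real model outputs.

```python
import pickle

import numpy as np
import pandas as pd

# Synthetic stand-in for the dict that build_latent_umap() pickles (toy values only).
bundle = {
    "umap_x": np.array([0.1, 0.9], dtype=np.float32),
    "umap_y": np.array([1.2, -0.3], dtype=np.float32),
    "label_name": np.array(["reprogramming", "dead-end"]),
    "pred_name": np.array(["reprogramming", "reprogramming"]),
    "correct": np.array([1, 0], dtype=np.int8),
    "fold": np.array([1, 2], dtype=np.int32),
}

# Round-trip through pickle, mirroring how the Streamlit side reads latent_umap.pkl.
loaded = pickle.loads(pickle.dumps(bundle))

# Every per-cell key has the same length, so the bundle maps onto one row per cell.
df = pd.DataFrame({k: v for k, v in loaded.items() if np.ndim(v) == 1})
```

Scalar metadata like the `common_samples` flag stays outside the DataFrame, which is why the comprehension keeps only the 1-D entries.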
streamlit_hf/.streamlit/config.toml ADDED
@@ -0,0 +1,13 @@
+ [theme]
+ primaryColor = "#2563eb"
+ backgroundColor = "#f8fafc"
+ secondaryBackgroundColor = "#ffffff"
+ textColor = "#0f172a"
+ font = "sans-serif"
+
+ [server]
+ headless = true
+ # Default CORS + XSRF settings avoid the "enableCORS=false vs XSRF" conflict on localhost.
+
+ [browser]
+ gatherUsageStats = false
streamlit_hf/README.md ADDED
@@ -0,0 +1,36 @@
+ # Hugging Face Space (Docker + Streamlit)
+
+ The **root `README.md`** starts with the YAML card Hugging Face reads for the Space (title, tags, colours, `sdk: docker`, `app_port: 7860`). Copy that block if you maintain a separate Space README.
+
+ ```yaml
+ ---
+ title: FateFormer Explorer
+ short_description: Streamlit app to explore multimodal single-cell fate modeling (RNA, ATAC, metabolic flux, attention, and rankings).
+ emoji: 🧬
+ colorFrom: violet
+ colorTo: indigo
+ tags:
+ - streamlit
+ - single-cell
+ - multi-omics
+ - genomics
+ - atac-seq
+ - rna-seq
+ - metabolic-modeling
+ - deep-learning
+ - biology
+ license: mit
+ sdk: docker
+ app_port: 7860
+ ---
+ ```
+
+ `app_port` **7860** matches the root **`Dockerfile`** (`streamlit ... --server.port 7860`). Local runs use Streamlit’s default **8501** unless you pass `--server.port`.
+
+ ## Before first deploy
+
+ 1. Run locally: `python scripts/precompute_streamlit_cache.py` (requires GPU/CPU time for attention).
+ 2. Commit **`streamlit_hf/cache/`** contents (`latent_umap.pkl`, `attention_summary.pkl`, `attention_feature_ranks.pkl`, `df_features.parquet`, and optionally `samples.parquet` if you use it elsewhere) or attach via **Git LFS** if files are large. These paths are listed in `.gitignore`; use `git add -f streamlit_hf/cache/*` when you want them in the remote.
+ 3. Keep **`ckp/`** model weights available only if you run precompute in CI; the slim Docker image does **not** include PyTorch and expects precomputed caches.
+
+ The repository **`Dockerfile`** at the root builds the Space.
streamlit_hf/__init__.py ADDED
@@ -0,0 +1 @@
+ # Streamlit explorer package (run with PYTHONPATH=repo root).
streamlit_hf/app.py ADDED
@@ -0,0 +1,35 @@
+ """
+ FateFormer: interactive analysis explorer.
+ Run from repository root: PYTHONPATH=. streamlit run streamlit_hf/app.py
+ """
+
+ from pathlib import Path
+
+ import streamlit as st
+
+ _APP_DIR = Path(__file__).resolve().parent
+ _ICON_PATH = _APP_DIR / "static" / "app_icon.svg"
+ _page_icon_kw = {"page_icon": str(_ICON_PATH)} if _ICON_PATH.is_file() else {}
+
+ st.set_page_config(
+     page_title="FateFormer Explorer",
+     layout="wide",
+     initial_sidebar_state="expanded",
+     **_page_icon_kw,
+ )
+
+ _home = str(_APP_DIR / "home.py")
+ _p1 = str(_APP_DIR / "pages" / "1_Single_Cell_Explorer.py")
+ _p2 = str(_APP_DIR / "pages" / "2_Feature_insights.py")
+ _p3 = str(_APP_DIR / "pages" / "3_Flux_analysis.py")
+ _p4 = str(_APP_DIR / "pages" / "4_Gene_expression_analysis.py")
+
+ pages = [
+     st.Page(_home, title="Home", icon=":material/home:", default=True),
+     st.Page(_p1, title="Single-Cell Explorer", icon=":material/scatter_plot:"),
+     st.Page(_p2, title="Feature Insights", icon=":material/analytics:"),
+     st.Page(_p3, title="Flux Analysis", icon=":material/account_tree:"),
+     st.Page(_p4, title="Gene Expression & TF Activity", icon=":material/genetics:"),
+ ]
+ nav = st.navigation(pages)
+ nav.run()
streamlit_hf/cache/.gitkeep ADDED
File without changes
streamlit_hf/home.py ADDED
@@ -0,0 +1,58 @@
+ """Landing content for the FateFormer Streamlit hub."""
+
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path
+
+ import streamlit as st
+
+ _REPO = Path(__file__).resolve().parents[1]
+ if str(_REPO) not in sys.path:
+     sys.path.insert(0, str(_REPO))
+
+ from streamlit_hf.lib import ui
+
+ _CACHE = Path(__file__).resolve().parent / "cache"
+ _HAS_CACHE = (_CACHE / "latent_umap.pkl").is_file() and (_CACHE / "df_features.parquet").is_file()
+
+ ui.inject_app_styles()
+
+ st.title("FateFormer: interactive analysis")
+ st.caption("Choose a workspace below or use the sidebar. All views use the same precomputed validation results.")
+
+ if not _HAS_CACHE:
+     st.warning(
+         "This deployment does not have precomputed results yet. Ask the maintainer to publish data, then reload."
+     )
+ else:
+     st.success("Precomputed results are available. After a server-side update, refresh the browser to load new plots.")
+
+ st.subheader("Open a page")
+ r1a, r1b, r1c = st.columns(3)
+ with r1a:
+     with st.container(border=True):
+         st.page_link("pages/1_Single_Cell_Explorer.py", label="Single-Cell Explorer", icon=":material/scatter_plot:")
+         st.caption("Latent UMAP: colour by fate, prediction, fold, batch, modalities, or dominant fate emphasis.")
+ with r1b:
+     with st.container(border=True):
+         st.page_link("pages/2_Feature_insights.py", label="Feature Insights", icon=":material/analytics:")
+         st.caption("Shift and attention rankings, cohort comparisons, and full feature tables.")
+ with r1c:
+     with st.container(border=True):
+         st.page_link("pages/3_Flux_analysis.py", label="Flux Analysis", icon=":material/account_tree:")
+         st.caption("Reaction pathways, differential flux, rankings, and model metadata.")
+ r2a, _, _ = st.columns(3)
+ with r2a:
+     with st.container(border=True):
+         st.page_link(
+             "pages/4_Gene_expression_analysis.py",
+             label="Gene Expression & TF Activity",
+             icon=":material/genetics:",
+         )
+         st.caption("Pathway enrichment, motif activity, and gene / motif tables.")
+
+ st.markdown("---")
+ st.markdown(
+     "**Tips:** use chart toolbars for pan/zoom and lasso selection where offered. Tables support search and column sort from the header row."
+ )
streamlit_hf/lib/__init__.py ADDED
File without changes
streamlit_hf/lib/formatters.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Human-readable labels for compact codes used in cached tables."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ # Matches interpretation.predictions._get_modality_info letter codes (R/A/F order).
9
+ # Short table-friendly labels (no long parentheticals).
10
+ _MODALITY_LONG: dict[str, str] = {
11
+ "RAF": "RNA + ATAC + Flux",
12
+ "RA": "RNA + ATAC",
13
+ "RF": "RNA + Flux",
14
+ "AF": "ATAC + Flux",
15
+ "R": "RNA only",
16
+ "A": "ATAC only",
17
+ "F": "Flux only",
18
+ "None": "No modality data",
19
+ "none": "No modality data",
20
+ "nan": "No modality data",
21
+ }
22
+
23
+ # Rename row fields in inspector tables for display.
24
+ _FIELD_DISPLAY: dict[str, str] = {
25
+ "label": "CellTag-Multi label",
26
+ }
27
+
28
+ # Latent explorer: table headers and key–value inspector (exclude non-meaningful / internal cols).
29
+ LATENT_TABLE_RENAME: dict[str, str] = {
30
+ "label": "CellTag-Multi label",
31
+ "predicted_class": "Predicted fate",
32
+ "predicted_value": "Prediction score",
33
+    "correct": "Prediction correct",
+    "pct": "Dominant fate (%)",
+    "modality_label": "Available modalities",
+    "dataset_idx": "Dataset index",
+    "batch_no": "Batch",
+    "fold": "CV fold",
+    "clone_id": "Clone ID",
+    "clone_size": "Clone size",
+    "cell_type": "Cell type",
+}
+
+LATENT_DROP_FROM_TABLES: frozenset[str] = frozenset({"umap_x", "umap_y", "modality", "pct_decile"})
+
+_NAME_MAP = {**_FIELD_DISPLAY, **LATENT_TABLE_RENAME}
+
+
+def _format_scalar(v) -> str:
+    if v is None:
+        return ""
+    if isinstance(v, bool):
+        return "Yes" if v else "No"
+    try:
+        if pd.isna(v):
+            return ""
+    except (ValueError, TypeError):
+        pass
+    if isinstance(v, (float, np.floating)) and np.isnan(v):
+        return ""
+    return str(v)
+
+
+def _field_label(name: str, *, fallback_field_display: bool) -> str:
+    k = str(name)
+    if fallback_field_display:
+        return _NAME_MAP.get(k, _FIELD_DISPLAY.get(k, k))
+    return _NAME_MAP.get(k, k)
+
+
+def expand_modality(code) -> str:
+    """Map R/A/F codes (e.g. RAF, RA) to full names."""
+    if code is None:
+        return _MODALITY_LONG["None"]
+    try:
+        if pd.isna(code):
+            return _MODALITY_LONG["None"]
+    except (ValueError, TypeError):
+        pass
+    if isinstance(code, (float, np.floating)) and np.isnan(code):
+        return _MODALITY_LONG["None"]
+    key = str(code).strip()
+    if not key or key.lower() == "nan":
+        return _MODALITY_LONG["None"]
+    return _MODALITY_LONG.get(key, key)
+
+
+def annotate_modality_column(df, code_col: str = "modality", label_col: str = "modality_label"):
+    """Add human-readable modality column; returns a copy."""
+    out = df.copy()
+    out[label_col] = out[code_col].map(expand_modality)
+    return out
+
+
+def prepare_latent_display_dataframe(df: pd.DataFrame) -> pd.DataFrame:
+    """Drop UMAP / internal columns and rename headers for Selected-points style tables."""
+    drop = [c for c in df.columns if c in LATENT_DROP_FROM_TABLES or str(c).startswith("umap_")]
+    out = df.drop(columns=drop, errors="ignore")
+    return out.rename(columns=LATENT_TABLE_RENAME)
+
+
+def latent_inspector_key_value(series: pd.Series) -> pd.DataFrame:
+    """Key–value inspector row: human names, no UMAP coordinates."""
+    s = series.drop(
+        labels=[c for c in series.index if c in LATENT_DROP_FROM_TABLES or str(c).startswith("umap_")],
+        errors="ignore",
+    )
+    idx = [_field_label(i, fallback_field_display=False) for i in s.index]
+    vals = [_format_scalar(v) for v in s.values]
+    return pd.DataFrame({"Field": idx, "Value": vals})
+
+
+def dataframe_to_arrow_safe_kv(series: pd.Series) -> pd.DataFrame:
+    """Two string columns for Streamlit/PyArrow (avoids mixed-type single column)."""
+    s = series.copy()
+    idx = [_field_label(i, fallback_field_display=True) for i in s.index]
+    vals = [_format_scalar(v) for v in s.values]
+    return pd.DataFrame({"field": idx, "value": vals})
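The helper above exists because a single pandas column holding bools, floats, and NaNs trips up PyArrow serialization when Streamlit renders the table. A self-contained sketch of the same idea (the `to_arrow_safe_kv` name and sample `row` are illustrative; the `_field_label` renaming from above is omitted):

```python
import numpy as np
import pandas as pd


def to_arrow_safe_kv(series: pd.Series) -> pd.DataFrame:
    # Casting every value to a string gives Arrow two homogeneous
    # string columns instead of one mixed-type object column.
    def fmt(v):
        if v is None:
            return ""
        if isinstance(v, bool):          # bool before float: bool is not NaN-able
            return "Yes" if v else "No"
        if isinstance(v, float) and np.isnan(v):
            return ""
        return str(v)

    return pd.DataFrame(
        {
            "field": [str(i) for i in series.index],
            "value": [fmt(v) for v in series.values],
        }
    )


row = pd.Series({"correct": True, "pct": 87.5, "clone_id": None})
kv = to_arrow_safe_kv(row)
```

Both output columns are plain strings, so `st.dataframe(kv)` never hits a mixed-type Arrow conversion error.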
streamlit_hf/lib/io.py ADDED
@@ -0,0 +1,171 @@
+"""Load precomputed explorer artifacts (no torch required at runtime)."""
+
+from __future__ import annotations
+
+import pickle
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from streamlit_hf.lib.formatters import annotate_modality_column
+from streamlit_hf.lib.reactions import normalize_reaction_key
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+CACHE_DIR = REPO_ROOT / "streamlit_hf" / "cache"
+METABOLIC_MODEL_METADATA = REPO_ROOT / "data" / "datasets" / "metabolic_model_metadata.csv"
+
+
+def _is_valid_features_csv(path: Path) -> bool:
+    if not path.is_file():
+        return False
+    try:
+        head = pd.read_csv(path, nrows=2)
+    except Exception:
+        return False
+    return "feature" in head.columns and "importance_shift" in head.columns
+
+
+def load_latent_bundle():
+    path = CACHE_DIR / "latent_umap.pkl"
+    if not path.is_file():
+        return None
+    with open(path, "rb") as f:
+        return pickle.load(f)
+
+
+def load_attention_summary():
+    path = CACHE_DIR / "attention_summary.pkl"
+    if not path.is_file():
+        return None
+    with open(path, "rb") as f:
+        return pickle.load(f)
+
+
+def load_samples_df() -> pd.DataFrame | None:
+    pq = CACHE_DIR / "samples.parquet"
+    if pq.is_file():
+        df = pd.read_parquet(pq)
+        return annotate_modality_column(df) if "modality" in df.columns else df
+    return None
+
+
+def _add_within_modality_orders(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Align scatter / table columns with the notebook.
+
+    Parquet from precompute already has rank_shift_in_modal / rank_att_in_modal from the same
+    merge-of-sorted-lists logic as the notebook; do not overwrite those with pandas ranks on
+    rounded importances (tie order can differ and changes the RNA cloud).
+    """
+    out = df.copy()
+    if "modality" not in out.columns:
+        return out
+    if "rank_shift_in_modal" in out.columns and "rank_att_in_modal" in out.columns:
+        out["shift_order_mod"] = out["rank_shift_in_modal"].astype(int)
+        out["attention_order_mod"] = out["rank_att_in_modal"].astype(int)
+    else:
+        g = out.groupby("modality", observed=True)
+        out["shift_order_mod"] = g["importance_shift"].rank(ascending=False, method="first").astype(int)
+        out["attention_order_mod"] = g["importance_att"].rank(ascending=False, method="first").astype(int)
+        out["rank_shift_in_modal"] = out["shift_order_mod"]
+        out["rank_att_in_modal"] = out["attention_order_mod"]
+    if "combined_order_mod" not in out.columns:
+        g = out.groupby("modality", observed=True)
+        out["combined_order_mod"] = g["mean_rank"].rank(ascending=True, method="first").astype(int)
+    return out
+
+
+def load_metabolic_model_metadata() -> pd.DataFrame | None:
+    """Directed reaction edges: substrate → product, grouped by supermodule (see CSV headers)."""
+    if not METABOLIC_MODEL_METADATA.is_file():
+        return None
+    return pd.read_csv(METABOLIC_MODEL_METADATA)
+
+
+def build_metabolic_model_table(
+    meta: pd.DataFrame,
+    flux_df: pd.DataFrame,
+    supermodule_id: int | None = None,
+) -> pd.DataFrame:
+    """
+    Static edge list: substrate → product, reaction label, module class, plus DE / model columns when the
+    reaction string matches a row in the flux feature table.
+    """
+    need = {"Compound_IN_name", "Compound_OUT_name", "rxnName", "Supermodule_id", "Super.Module.class"}
+    if not need.issubset(set(meta.columns)):
+        return pd.DataFrame()
+    m = meta.copy()
+    if supermodule_id is not None:
+        m = m[m["Supermodule_id"] == int(supermodule_id)]
+    if m.empty:
+        return pd.DataFrame()
+
+    fd = flux_df.copy()
+    fd["_rk"] = fd["feature"].map(normalize_reaction_key)
+    fd = fd.drop_duplicates("_rk", keep="first").set_index("_rk", drop=False)
+
+    rows: list[dict] = []
+    for _, r in m.iterrows():
+        k = normalize_reaction_key(str(r["rxnName"]))
+        base = {
+            "Supermodule": r.get("Super.Module.class"),
+            "Module_id": r.get("Module_id"),
+            "Substrate": r["Compound_IN_name"],
+            "Product": r["Compound_OUT_name"],
+            "Reaction": r["rxnName"],
+        }
+        if k in fd.index:
+            row = fd.loc[k]
+            if isinstance(row, pd.DataFrame):
+                row = row.iloc[0]
+            base["log_fc"] = row["log_fc"] if "log_fc" in row.index else None
+            base["pval_adj"] = row["pval_adj"] if "pval_adj" in row.index else None
+            base["mean_rank"] = row["mean_rank"] if "mean_rank" in row.index else None
+            base["pathway"] = row["pathway"] if "pathway" in row.index else None
+        else:
+            base["log_fc"] = None
+            base["pval_adj"] = None
+            base["mean_rank"] = None
+            base["pathway"] = None
+        rows.append(base)
+    return pd.DataFrame(rows)
+
+
+def load_df_features() -> pd.DataFrame | None:
+    pq = CACHE_DIR / "df_features.parquet"
+    if pq.is_file():
+        return _add_within_modality_orders(pd.read_parquet(pq))
+    csv_cache = CACHE_DIR / "df_features.csv"
+    if csv_cache.is_file():
+        return _add_within_modality_orders(pd.read_csv(csv_cache))
+    analysis_csv = REPO_ROOT / "analysis" / "df_features.csv"
+    if _is_valid_features_csv(analysis_csv):
+        return _add_within_modality_orders(pd.read_csv(analysis_csv))
+    return None
+
+
+def latent_join_samples(bundle: dict, samples: pd.DataFrame | None) -> pd.DataFrame:
+    """One row per UMAP point, aligned with bundle arrays."""
+    n = len(bundle["umap_x"])
+    df = pd.DataFrame(
+        {
+            "umap_x": bundle["umap_x"],
+            "umap_y": bundle["umap_y"],
+            "label": bundle["label_name"],
+            "predicted_class": bundle["pred_name"],
+            "correct": bundle["correct"].astype(bool),
+            "fold": bundle["fold"].astype(int),
+            "batch_no": bundle["batch_no"].astype(int),
+            "pct": bundle["pct"],
+            "modality": bundle["modality"],
+            "dataset_idx": bundle["dataset_idx"].astype(int),
+        }
+    )
+    if samples is not None and not samples.empty:
+        s = samples.drop_duplicates(subset=["ind"], keep="first").set_index("ind")
+        extra = s.reindex(df["dataset_idx"].values)
+        for col in ["predicted_value", "clone_id", "clone_size", "cell_type"]:
+            if col in extra.columns:
+                df[col] = extra[col].values
+    return annotate_modality_column(df)
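`latent_join_samples` attaches per-sample metadata by indexing the samples table on `ind`, then reindexing it against `dataset_idx`, so the extra columns stay row-aligned with the UMAP arrays even when indices repeat or are out of order. A minimal sketch of that pattern with made-up data (`umap` and `samples` here are illustrative only):

```python
import pandas as pd

# UMAP rows reference samples by dataset index; the same sample may appear twice.
umap = pd.DataFrame({"dataset_idx": [2, 0, 2]})

# Per-sample metadata keyed by "ind"; duplicate "ind" rows can occur upstream.
samples = pd.DataFrame(
    {
        "ind": [0, 1, 2, 2],
        "cell_type": ["MEF", "iEP", "Fib", "dup"],
    }
)

# Keep the first row per "ind", then reindex to the order of dataset_idx:
# the result has exactly one row per UMAP point, in UMAP row order.
s = samples.drop_duplicates(subset=["ind"], keep="first").set_index("ind")
extra = s.reindex(umap["dataset_idx"].values)
umap["cell_type"] = extra["cell_type"].values
```

Unlike `merge`, `reindex` never reorders or drops the left-hand rows, which is what keeps the joined columns aligned with the pickled UMAP arrays.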
streamlit_hf/lib/pathways.py ADDED
@@ -0,0 +1,133 @@
+"""Pathway enrichment tables (DAVID-style exports) for Reactome and KEGG panels."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+DE_TSV = REPO_ROOT / "analysis" / "de_all_48.tsv"
+RE_TSV = REPO_ROOT / "analysis" / "re_all_48.tsv"
+
+
+def load_de_re_tsv() -> tuple[pd.DataFrame, pd.DataFrame] | None:
+    if not DE_TSV.is_file() or not RE_TSV.is_file():
+        return None
+    return pd.read_csv(DE_TSV, sep="\t"), pd.read_csv(RE_TSV, sep="\t")
+
+
+def preprocess_pathway_file(df: pd.DataFrame, splitter: str) -> pd.DataFrame:
+    out = df.copy()
+    out["Term"] = out["Term"].astype(str).str.split(splitter).str[-1]
+    if splitter == "-":
+        out["Term"] = out["Term"].astype(str).str.split("~").str[-1]
+    out = out[out["Benjamini"] < 0.05].copy()
+    out["Gene Ratio"] = out["Count"] / out["List Total"]
+    return out
+
+
+def merged_reactome_kegg_bubble_frames(
+    de_all: pd.DataFrame, re_all: pd.DataFrame
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+    """Rows for bubble plot (Gene Ratio, Count, Benjamini, Library, Term) per notebook cell 31."""
+    reactome_de = de_all[de_all["Category"] == "REACTOME_PATHWAY"]
+    reactome_re = re_all[re_all["Category"] == "REACTOME_PATHWAY"]
+    kegg_de = de_all[de_all["Category"] == "KEGG_PATHWAY"]
+    kegg_re = re_all[re_all["Category"] == "KEGG_PATHWAY"]
+
+    rde = preprocess_pathway_file(reactome_de, "~")
+    rde["Library"] = "Reactome"
+    rre = preprocess_pathway_file(reactome_re, "~")
+    rre["Library"] = "Reactome"
+    kde = preprocess_pathway_file(kegg_de, ":")
+    kde["Library"] = "KEGG"
+    kre = preprocess_pathway_file(kegg_re, ":")
+    kre["Library"] = "KEGG"
+
+    merged_dead = pd.concat([rde, kde], ignore_index=True)
+    merged_re = pd.concat([rre, kre], ignore_index=True)
+    return merged_dead, merged_re
+
+
+def _preprocess_exploded(df: pd.DataFrame, pval_threshold: float, splitter: str, label: str) -> pd.DataFrame:
+    d = df.copy()
+    d["Term"] = d["Term"].astype(str).str.split(splitter).str[-1]
+    if splitter == "-":
+        d["Term"] = d["Term"].astype(str).str.split("~").str[-1]
+
+    def _trunc(x: str) -> str:
+        return x[:60] + "..." if len(x) > 60 else x
+
+    d["Term"] = d["Term"].map(_trunc)
+    d = d[d["Benjamini"] < pval_threshold]
+    sub = d[["Term", "Genes", "Benjamini"]].copy()
+    sub["Label"] = label
+    exploded = (
+        sub.set_index(["Term", "Benjamini", "Label"])["Genes"].str.split(", ").explode().reset_index()
+    )
+    return exploded
+
+
+def _binary_matrix(data: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
+    binary = pd.crosstab(data["Term"], data["Genes"])
+    labels = data.groupby("Term")["Label"].first()
+    pvals = data.groupby("Term")["Benjamini"].first()
+    return binary, labels, pvals
+
+
+def _sort_matrix(matrix: pd.DataFrame) -> pd.DataFrame:
+    sp = matrix.sum(axis=1).sort_values(ascending=False).index
+    sg = matrix.sum(axis=0).sort_values(ascending=False).index
+    return matrix.loc[sp, sg]
+
+
+def build_merged_pathway_membership(
+    de_all: pd.DataFrame, re_all: pd.DataFrame, pval_threshold: float = 0.05
+) -> tuple[np.ndarray, list[str], list[str]] | None:
+    """
+    Numeric grid for heatmap: values 0=white, 1=dead-end gene, 2=reprogramming gene,
+    3=Reactome library stripe, 4=KEGG library stripe (notebook cell 29).
+    """
+    reactome_de = de_all[de_all["Category"] == "REACTOME_PATHWAY"]
+    reactome_re = re_all[re_all["Category"] == "REACTOME_PATHWAY"]
+    kegg_de = de_all[de_all["Category"] == "KEGG_PATHWAY"]
+    kegg_re = re_all[re_all["Category"] == "KEGG_PATHWAY"]
+
+    rde = _preprocess_exploded(reactome_de, pval_threshold, "~", "Dead-end")
+    rre = _preprocess_exploded(reactome_re, pval_threshold, "~", "Reprogramming")
+    rcomb = pd.concat([rde, rre], ignore_index=True)
+    kde = _preprocess_exploded(kegg_de, pval_threshold, ":", "Dead-end")
+    kre = _preprocess_exploded(kegg_re, pval_threshold, ":", "Reprogramming")
+    kcomb = pd.concat([kde, kre], ignore_index=True)
+
+    rm, rlab, _ = _binary_matrix(rcomb)
+    km, klab, _ = _binary_matrix(kcomb)
+    rm = _sort_matrix(rm)
+    km = _sort_matrix(km)
+
+    reactome_lib = pd.Series("Reactome", index=rm.index)
+    kegg_lib = pd.Series("KEGG", index=km.index)
+    merged = pd.concat([rm, km], axis=0, sort=False).fillna(0)
+    if merged.empty or merged.shape[1] == 0:
+        return None
+    merged_labels = pd.concat([rlab, klab])
+    merged_library = pd.concat([reactome_lib, kegg_lib])
+
+    label_code = {"Dead-end": 1, "Reprogramming": 2}
+    lib_code = {"Reactome": 3, "KEGG": 4}
+
+    gene_cols = list(merged.columns)
+    z = np.zeros((len(merged), len(gene_cols) + 1), dtype=float)
+    for i, term in enumerate(merged.index):
+        lc = label_code.get(str(merged_labels.loc[term]), 0)
+        for j, g in enumerate(gene_cols):
+            v = float(merged.loc[term, g])
+            if v > 0 and lc:
+                z[i, j] = v * lc
+        z[i, -1] = lib_code.get(str(merged_library.loc[term]), 0)
+
+    row_labels = [str(t) for t in merged.index]
+    col_labels = gene_cols + ["Library"]
+    return z, row_labels, col_labels
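The `_preprocess_exploded` / `_binary_matrix` pair relies on the pandas split–explode–crosstab pattern to turn DAVID's comma-joined `Genes` strings into a term × gene membership matrix. A small sketch of that pattern with invented terms and genes (the pathway names here are examples, not rows from the real TSVs):

```python
import pandas as pd

# DAVID-style export: one row per enriched term, genes joined by ", ".
df = pd.DataFrame(
    {
        "Term": ["Glycolysis", "TCA cycle"],
        "Benjamini": [0.01, 0.03],
        "Genes": ["HK1, PKM", "IDH1, PKM, SDHA"],
    }
)

# Split the gene string and explode to one row per (term, gene) pair;
# set_index keeps the per-term metadata attached through the explode.
exploded = (
    df.set_index(["Term", "Benjamini"])["Genes"]
    .str.split(", ")
    .explode()
    .reset_index()
)

# Crosstab gives the 0/1 term x gene membership matrix used for the heatmap.
binary = pd.crosstab(exploded["Term"], exploded["Genes"])
```

From here, sorting rows and columns by their sums (as `_sort_matrix` does) pushes the densest terms and most shared genes to the top-left of the heatmap.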
streamlit_hf/lib/plots.py ADDED
@@ -0,0 +1,1421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Plotly helpers for the explorer UI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import plotly.express as px
10
+ import plotly.graph_objects as go
11
+ from plotly.subplots import make_subplots
12
+
13
+ from streamlit_hf.lib.reactions import normalize_reaction_key
14
+
15
+ # Matches Streamlit theme primary + slate text; used across Plotly layouts.
16
+ PLOT_FONT = dict(family="Inter, system-ui, sans-serif", size=12)
17
+
18
+ PALETTE = (
19
+ "#2563eb",
20
+ "#dc2626",
21
+ "#059669",
22
+ "#d97706",
23
+ "#7c3aed",
24
+ "#db2777",
25
+ "#0d9488",
26
+ "#4f46e5",
27
+ )
28
+
29
+ MODALITY_COLOR = {"RNA": "#E64B35", "ATAC": "#4DBBD5", "Flux": "#00A087"}
30
+ # Global modality pie only: edit here to try other hues (bars/scatter use MODALITY_COLOR).
31
+ MODALITY_PIE_COLOR = dict(MODALITY_COLOR)
32
+ # Log₂FC heatmaps/sunburst: colours like ggplot2 scale_colour_gradient2 (mid grey at 0).
33
+ LOG_FC_COLOR_MIN = -0.5
34
+ LOG_FC_COLOR_MAX = 0.5
35
+ LOG_FC_DIVERGING_SCALE: list[list] = [
36
+ [0.0, "#1C86EE"],
37
+ [0.5, "#FAFAFA"],
38
+ [1.0, "#FF0000"],
39
+ ]
40
+ # Unicode minus (U+2212) and subscript ₁₀ / ₂ for axes/colorbars.
41
+ LABEL_NEG_LOG10_ADJ_P = "\u2212log\u2081\u2080 adj. p"
42
+ LABEL_LOG2FC = "Log\u2082FC"
43
+ # Cached attention dict uses lowercase modality keys.
44
+ FI_ATT_MOD_KEY = {"RNA": "rna", "ATAC": "atac", "Flux": "flux"}
45
+ # Model appends one batch-embedding token per modality; hide from attention rankings in the UI.
46
+ BATCH_EMBEDDING_FEATURE_NAMES = frozenset({"batch_rna", "batch_atac", "batch_flux"})
47
+
48
+
49
+ def _attention_pairs_skip_batch(pairs: list) -> list:
50
+ return [(n, s) for n, s in pairs if str(n) not in BATCH_EMBEDDING_FEATURE_NAMES]
51
+
52
+
53
+ def rollout_top_features_table(feature_names, vec, top_n: int) -> pd.DataFrame:
54
+ """Top `top_n` rollout weights per modality slice, excluding batch-embedding tokens."""
55
+ names = [str(x) for x in feature_names]
56
+ v = np.asarray(vec, dtype=float)
57
+ rows = [
58
+ (names[i], float(v[i]))
59
+ for i in range(len(names))
60
+ if names[i] not in BATCH_EMBEDDING_FEATURE_NAMES
61
+ ]
62
+ rows.sort(key=lambda x: -x[1])
63
+ rows = rows[:top_n]
64
+ if not rows:
65
+ return pd.DataFrame(columns=["feature", "mean_attention"])
66
+ feat, val = zip(*rows)
67
+ return pd.DataFrame({"feature": list(feat), "mean_attention": list(val)})
68
+
69
+ # Themed continuous scale for dominant-fate % on UMAP (low → high emphasis).
70
+ UMAP_PCT_COLORSCALE: list[list] = [
71
+ [0.0, "#eff6ff"],
72
+ [0.25, "#bfdbfe"],
73
+ [0.55, "#3b82f6"],
74
+ [0.82, "#2563eb"],
75
+ [1.0, "#1e3a8a"],
76
+ ]
77
+
78
+ # Okabe–Ito–style distinct colours (colourblind-friendly) for categorical UMAP hues.
79
+ LATENT_DISCRETE_PALETTE = (
80
+ "#0072B2",
81
+ "#E69F00",
82
+ "#009E73",
83
+ "#CC79A7",
84
+ "#56B4E9",
85
+ "#D55E00",
86
+ "#F0E442",
87
+ "#000000",
88
+ )
89
+
90
+
91
+ def latent_scatter(
92
+ df,
93
+ color_col: str,
94
+ title: str,
95
+ width: int = 720,
96
+ height: int = 520,
97
+ marker_size: float = 5.0,
98
+ marker_opacity: float = 0.78,
99
+ ):
100
+ d = df.copy()
101
+ hover_spec = {
102
+ "umap_x": ":.3f",
103
+ "umap_y": ":.3f",
104
+ "dataset_idx": True,
105
+ "fold": True,
106
+ "batch_no": True,
107
+ "predicted_class": True,
108
+ "label": True,
109
+ "correct": True,
110
+ "pct": ":.2f",
111
+ "modality_label": True,
112
+ "modality": True,
113
+ "predicted_value": ":.3f",
114
+ "clone_id": True,
115
+ "clone_size": True,
116
+ "cell_type": True,
117
+ }
118
+ if "modality_label" in d.columns:
119
+ hover_spec.pop("modality", None)
120
+ hover_data = {k: v for k, v in hover_spec.items() if k in d.columns}
121
+ _disp = {
122
+ "label": "CellTag-Multi label",
123
+ "predicted_class": "Predicted fate",
124
+ "pct": "Dominant fate (%)",
125
+ "modality_label": "Available modalities",
126
+ "dataset_idx": "Dataset index",
127
+ "batch_no": "Batch",
128
+ "fold": "CV fold",
129
+ }
130
+ labels_map = {c: _disp[c] for c in _disp if c in d.columns}
131
+
132
+ continuous = color_col == "pct"
133
+ if color_col == "fold":
134
+ d["_color"] = d["fold"].astype(str)
135
+ color_arg = "_color"
136
+ labels_map["_color"] = "Fold"
137
+ continuous = False
138
+ elif color_col == "batch_no":
139
+ d["_color"] = d["batch_no"].astype(str)
140
+ color_arg = "_color"
141
+ labels_map["_color"] = "Batch"
142
+ continuous = False
143
+ elif color_col == "correct":
144
+ d["_color"] = d["correct"].map({True: "Correct", False: "Wrong"})
145
+ color_arg = "_color"
146
+ labels_map["_color"] = "Prediction"
147
+ continuous = False
148
+ else:
149
+ color_arg = color_col
150
+
151
+ common = dict(
152
+ x="umap_x",
153
+ y="umap_y",
154
+ hover_data=hover_data,
155
+ labels=labels_map,
156
+ title=title,
157
+ width=width,
158
+ height=height,
159
+ )
160
+ if continuous:
161
+ fig = px.scatter(
162
+ d,
163
+ color=color_arg,
164
+ color_continuous_scale=UMAP_PCT_COLORSCALE,
165
+ **common,
166
+ )
167
+ else:
168
+ fig = px.scatter(
169
+ d,
170
+ color=color_arg,
171
+ color_discrete_sequence=list(LATENT_DISCRETE_PALETTE),
172
+ **common,
173
+ )
174
+ fig.update_traces(
175
+ marker=dict(size=marker_size, opacity=marker_opacity, line=dict(width=0.25, color="rgba(255,255,255,0.4)"))
176
+ )
177
+ fig.update_layout(
178
+ template="plotly_white",
179
+ font=PLOT_FONT,
180
+ title_font_size=16,
181
+ margin=dict(l=28, r=20, t=56, b=28),
182
+ legend_title_text="",
183
+ xaxis_title="",
184
+ yaxis_title="",
185
+ )
186
+ fig.update_xaxes(showticklabels=False, showgrid=True, gridcolor="rgba(0,0,0,0.06)", zeroline=False)
187
+ fig.update_yaxes(showticklabels=False, showgrid=True, gridcolor="rgba(0,0,0,0.06)", zeroline=False)
188
+ return fig
189
+
190
+
191
+ def rank_scatter_shift_vs_attention(df_mod, modality: str, width: int = 420, height: int = 440):
192
+ """Attention rank on x, shift rank on y, least-squares trend line, discrete point colours."""
193
+ need = ("shift_order_mod", "attention_order_mod")
194
+ if not all(c in df_mod.columns for c in need):
195
+ return go.Figure()
196
+ sub = df_mod.dropna(subset=list(need)).copy()
197
+ if sub.empty:
198
+ return go.Figure()
199
+ x = sub["attention_order_mod"].astype(float).to_numpy()
200
+ y = sub["shift_order_mod"].astype(float).to_numpy()
201
+ fig = px.scatter(
202
+ sub,
203
+ x="attention_order_mod",
204
+ y="shift_order_mod",
205
+ color="top_10_pct",
206
+ hover_name="feature",
207
+ hover_data={
208
+ "mean_rank": True,
209
+ "importance_shift": ":.4f",
210
+ "importance_att": ":.4f",
211
+ },
212
+ labels={
213
+ "attention_order_mod": "Attention rank",
214
+ "shift_order_mod": "Shift rank",
215
+ },
216
+ width=width,
217
+ height=height,
218
+ color_discrete_map={
219
+ "both": PALETTE[0],
220
+ "shift": PALETTE[1],
221
+ "att": PALETTE[2],
222
+ "None": "#94a3b8",
223
+ },
224
+ )
225
+ fig.update_traces(marker=dict(size=7, opacity=0.62, line=dict(width=0.5, color="rgba(15,23,42,0.28)")))
226
+ if len(x) >= 2 and float(np.ptp(x)) > 0:
227
+ coef = np.polyfit(x, y, 1)
228
+ poly = np.poly1d(coef)
229
+ xs = np.linspace(float(np.min(x)), float(np.max(x)), 100)
230
+ fig.add_trace(
231
+ go.Scatter(
232
+ x=xs,
233
+ y=poly(xs),
234
+ mode="lines",
235
+ name=f"y = {coef[0]:.2f}x + {coef[1]:.2f}",
236
+ line=dict(color="#2563eb", width=2, dash="dash"),
237
+ showlegend=True,
238
+ )
239
+ )
240
+ fig.update_layout(
241
+ template="plotly_white",
242
+ font=PLOT_FONT,
243
+ title=dict(
244
+ text=f"{modality}: shift vs attention (ranks)",
245
+ x=0.5,
246
+ xanchor="center",
247
+ y=0.98,
248
+ yanchor="top",
249
+ font=dict(size=14, family=PLOT_FONT["family"]),
250
+ ),
251
+ margin=dict(l=48, r=20, t=52, b=72),
252
+ legend=dict(orientation="h", yanchor="top", y=-0.2, xanchor="center", x=0.5),
253
+ )
254
+ return fig
255
+
256
+
257
+ def _truncate_label(s: str, max_len: int = 36) -> str:
258
+ s = str(s)
259
+ return s if len(s) <= max_len else s[: max_len - 1] + "…"
260
+
261
+
262
+ def joint_shift_attention_top_features(df_mod, modality: str, top_n: int):
263
+ """
264
+ Top features by mean_rank (lowest = strongest joint shift+attention ranking).
265
+ Shift and attention importances are min–max scaled within this top-N slice for side-by-side comparison.
266
+ """
267
+ need = ("mean_rank", "importance_shift", "importance_att", "feature")
268
+ if not all(c in df_mod.columns for c in need):
269
+ return go.Figure()
270
+ sub = df_mod.nsmallest(top_n, "mean_rank").copy()
271
+ if sub.empty:
272
+ return go.Figure()
273
+
274
+ def _mm(s: pd.Series) -> pd.Series:
275
+ lo, hi = float(s.min()), float(s.max())
276
+ if hi <= lo:
277
+ return pd.Series(0.5, index=s.index)
278
+ return (s.astype(float) - lo) / (hi - lo)
279
+
280
+ sub["_zs"] = _mm(sub["importance_shift"])
281
+ sub["_za"] = _mm(sub["importance_att"])
282
+ # Best (lowest mean_rank) at top of chart; matches shift/attention rows below.
283
+ sub = sub.sort_values("mean_rank", ascending=True)
284
+ feats_full = sub["feature"].astype(str)
285
+ y_disp = feats_full.map(lambda s: _truncate_label(s, 40))
286
+ base = MODALITY_COLOR.get(modality, PALETTE[0])
287
+ att_c = "#475569" if base != "#475569" else "#64748b"
288
+
289
+ margin_l = int(min(380, 64 + 5.8 * max((len(t) for t in y_disp), default=10)))
290
+ h = min(720, 52 + 22 * len(sub))
291
+
292
+ fig = go.Figure()
293
+ fig.add_trace(
294
+ go.Bar(
295
+ name="Shift (scaled)",
296
+ y=y_disp,
297
+ x=sub["_zs"],
298
+ orientation="h",
299
+ marker_color=base,
300
+ customdata=feats_full,
301
+ hovertemplate="<b>%{customdata}</b><br>Shift (scaled): %{x:.3f}<extra></extra>",
302
+ )
303
+ )
304
+ fig.add_trace(
305
+ go.Bar(
306
+ name="Attention (scaled)",
307
+ y=y_disp,
308
+ x=sub["_za"],
309
+ orientation="h",
310
+ marker_color=att_c,
311
+ customdata=feats_full,
312
+ hovertemplate="<b>%{customdata}</b><br>Attention (scaled): %{x:.3f}<extra></extra>",
313
+ )
314
+ )
315
+ fig.update_layout(
316
+ template="plotly_white",
317
+ font=PLOT_FONT,
318
+ title=dict(
319
+ text=f"{modality} · top {top_n}",
320
+ x=0.5,
321
+ xanchor="center",
322
+ y=0.98,
323
+ yanchor="top",
324
+ font=dict(size=14, family=PLOT_FONT["family"]),
325
+ ),
326
+ barmode="group",
327
+ bargap=0.15,
328
+ bargroupgap=0.05,
329
+ width=680,
330
+ height=h,
331
+ margin=dict(l=margin_l, r=12, t=44, b=72),
332
+ xaxis_title="Scaled 0-1 within selection",
333
+ yaxis_title="",
334
+ legend=dict(orientation="h", yanchor="top", y=-0.14, xanchor="center", x=0.5),
335
+ )
336
+ fig.update_yaxes(autorange="reversed", tickfont=dict(size=10))
337
+ return fig
338
+
339
+
340
+ def modality_shift_attention_rank_stats(df_mod) -> dict[str, Any]:
341
+ """Pearson / Spearman between per-modality shift and attention ordinal ranks."""
342
+ from scipy.stats import pearsonr, spearmanr
343
+
344
+ need = ("shift_order_mod", "attention_order_mod")
345
+ if not all(c in df_mod.columns for c in need):
346
+ return {"n": 0}
347
+ sub = df_mod.dropna(subset=list(need))
348
+ n = len(sub)
349
+ if n < 3:
350
+ return {"n": n}
351
+ xs = sub["attention_order_mod"].astype(float)
352
+ ys = sub["shift_order_mod"].astype(float)
353
+ pr, pp = pearsonr(xs, ys)
354
+ sr, sp = spearmanr(xs, ys)
355
+ return {
356
+ "n": n,
357
+ "pearson_r": float(pr),
358
+ "pearson_p": float(pp),
359
+ "spearman_r": float(sr),
360
+ "spearman_p": float(sp),
361
+ }
362
+
363
+
364
+ def rank_bar(
365
+ df_top,
366
+ xcol: str,
367
+ ycol: str,
368
+ title: str,
369
+ color: str = PALETTE[0],
370
+ xaxis_title: str | None = None,
371
+ ):
372
+ d = df_top.sort_values(xcol, ascending=True)
373
+ y_raw = d[ycol].astype(str)
374
+ y_show = y_raw.map(lambda s: _truncate_label(s, 42))
375
+ margin_l = int(min(420, 80 + 5.8 * max((len(s) for s in y_show), default=12)))
376
+ fig = go.Figure(
377
+ go.Bar(
378
+ y=y_show,
379
+ x=d[xcol],
380
+ orientation="h",
381
+ marker_color=color,
382
+ customdata=y_raw,
383
+ hovertemplate="<b>%{customdata}</b><br>%{x:.4g}<extra></extra>",
384
+ )
385
+ )
386
+ xt = xaxis_title if xaxis_title is not None else xcol.replace("_", " ")
387
+ fig.update_layout(
388
+ template="plotly_white",
389
+ font=PLOT_FONT,
390
+ title=title,
391
+ width=680,
392
+ height=min(620, 38 + 20 * len(d)),
393
+ margin=dict(l=margin_l, r=24, t=48, b=40),
394
+ xaxis_title=xt,
395
+ yaxis_title="",
396
+ )
397
+ fig.update_yaxes(tickfont=dict(size=10))
398
+ return fig
399
+
400
+
401
+ def attention_top_comparison(fi_lists: dict, modality: str, top_n: int = 18):
402
+ """fi_lists: cohort -> {rna|atac|flux: [(name, score), ...]}."""
403
+ mk = FI_ATT_MOD_KEY.get(modality, str(modality).lower())
404
+ traces = []
405
+ for key, name, color in (
406
+ ("all", "All validation samples", PALETTE[0]),
407
+ ("dead_end", "Predicted dead-end", PALETTE[1]),
408
+ ("reprogramming", "Predicted reprogramming", PALETTE[2]),
409
+ ):
410
+ cohort = fi_lists.get(key) or {}
411
+ items = _attention_pairs_skip_batch(list(cohort.get(mk, [])))[:top_n]
412
+ if not items:
413
+ continue
414
+ feats, scores = zip(*items)
415
+ traces.append(
416
+ go.Bar(
417
+ name=name,
418
+ x=list(scores),
419
+ y=[f[:52] + ("…" if len(f) > 52 else "") for f in feats],
420
+ orientation="h",
421
+ marker_color=color,
422
+ )
423
+ )
424
+ fig = go.Figure(traces)
425
+ bar_h = max(320, 36 + min(top_n, 20) * 22 * max(1, len(traces)))
426
+    fig.update_layout(
+        barmode="group",
+        template="plotly_white",
+        font=PLOT_FONT,
+        title=f"Top attention (rollout): {modality}",
+        width=520,
+        height=bar_h,
+        margin=dict(l=220, r=24, t=56, b=40),
+        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
+    )
+    if not traces:
+        fig.update_layout(
+            annotations=[
+                dict(
+                    text="No attention list for this modality (re-run precompute).",
+                    xref="paper",
+                    yref="paper",
+                    x=0.5,
+                    y=0.5,
+                    showarrow=False,
+                )
+            ]
+        )
+    else:
+        fig.update_yaxes(autorange="reversed")
+    return fig
+
+
+def attention_cohort_view(
+    fi_lists: dict,
+    modality: str,
+    top_n: int,
+    mode: str,
+):
+    """
+    mode: 'compare': grouped bars for all three cohorts;
+    'all' | 'dead_end' | 'reprogramming': single cohort only.
+    """
+    if mode == "compare":
+        return attention_top_comparison(fi_lists, modality, top_n)
+    mk = FI_ATT_MOD_KEY.get(modality, str(modality).lower())
+    cohort = fi_lists.get(mode) or {}
+    items = _attention_pairs_skip_batch(list(cohort.get(mk, [])))[:top_n]
+    label = {
+        "all": "All validation samples",
+        "dead_end": "Predicted dead-end",
+        "reprogramming": "Predicted reprogramming",
+    }.get(mode, mode)
+    if not items:
+        fig = go.Figure()
+        fig.update_layout(
+            template="plotly_white",
+            font=PLOT_FONT,
+            title=f"{modality} · {label}",
+            annotations=[
+                dict(
+                    text="No items for this cohort.",
+                    xref="paper",
+                    yref="paper",
+                    x=0.5,
+                    y=0.5,
+                    showarrow=False,
+                )
+            ],
+        )
+        return fig
+    feats, scores = zip(*items)
+    fig = go.Figure(
+        go.Bar(
+            x=list(scores),
+            y=[f[:52] + ("…" if len(f) > 52 else "") for f in feats],
+            orientation="h",
+            marker_color=PALETTE[0],
+        )
+    )
+    h = max(280, 40 + min(top_n, 25) * 20)
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        title=f"{modality} · {label}",
+        width=520,
+        height=h,
+        margin=dict(l=220, r=24, t=56, b=40),
+        xaxis_title="Attention weight",
+    )
+    fig.update_yaxes(autorange="reversed")
+    return fig
+
+
+def global_rank_triple_panel(df_features, top_n: int = 20, top_n_pie: int = 100):
+    """
+    Global top-N by latent-shift and by attention (min–max scaled), plus pie of modality mix
+    among the top `top_n_pie` features by mean rank.
+    """
+    d = df_features.copy()
+    for col in ("importance_shift", "importance_att"):
+        min_v, max_v = d[col].min(), d[col].max()
+        if max_v > min_v:
+            d[col + "_norm"] = (d[col] - min_v) / (max_v - min_v)
+        else:
+            d[col + "_norm"] = 0.0
+
+    shift_top = d.nlargest(top_n, "importance_shift")
+    att_top = d.nlargest(top_n, "importance_att")
+    pie_pool = d.nsmallest(top_n_pie, "mean_rank")
+
+    fig = make_subplots(
+        rows=1,
+        cols=3,
+        column_widths=[0.36, 0.36, 0.28],
+        specs=[[{}, {}, {"type": "domain"}]],
+        subplot_titles=(
+            f"Top {top_n} by latent shift (ranked)",
+            f"Top {top_n} by attention (ranked)",
+            f"Top {top_n_pie} by mean rank (modality mix)",
+        ),
+        horizontal_spacing=0.06,
+    )
+
+    fig.add_trace(
+        go.Bar(
+            x=shift_top["importance_shift_norm"],
+            y=shift_top["feature"],
+            orientation="h",
+            marker_color=[MODALITY_COLOR.get(m, "#64748b") for m in shift_top["modality"]],
+            marker_line=dict(color="rgba(15,23,42,0.12)", width=1),
+            showlegend=False,
+            hovertemplate="%{y}<br>scaled shift: %{x:.3f}<extra></extra>",
+        ),
+        row=1,
+        col=1,
+    )
+    fig.add_trace(
+        go.Bar(
+            x=att_top["importance_att_norm"],
+            y=att_top["feature"],
+            orientation="h",
+            marker_color=[MODALITY_COLOR.get(m, "#64748b") for m in att_top["modality"]],
+            marker_line=dict(color="rgba(15,23,42,0.12)", width=1),
+            showlegend=False,
+            hovertemplate="%{y}<br>scaled attention: %{x:.3f}<extra></extra>",
+        ),
+        row=1,
+        col=2,
+    )
+
+    pie_labels = ["RNA", "ATAC", "Flux"]
+    counts = pie_pool["modality"].value_counts()
+    pie_vals = [int(counts.get(lab, 0)) for lab in pie_labels]
+    if sum(pie_vals) == 0:
+        pie_vals = [1, 1, 1]
+
+    fig.add_trace(
+        go.Pie(
+            labels=pie_labels,
+            values=pie_vals,
+            marker=dict(
+                colors=[MODALITY_PIE_COLOR.get(l, "#64748b") for l in pie_labels],
+                line=dict(color="#1e293b", width=1.2),
+            ),
+            textinfo="label+percent",
+            textfont_size=12,
+            hole=0.0,
+            showlegend=False,
+        ),
+        row=1,
+        col=3,
+    )
+
+    fig.update_xaxes(title_text="Min-max scaled shift", row=1, col=1)
+    fig.update_xaxes(title_text="Min-max scaled attention", row=1, col=2)
+    fig.update_yaxes(autorange="reversed", row=1, col=1)
+    fig.update_yaxes(autorange="reversed", row=1, col=2)
+
+    h = max(480, 40 + top_n * 18)
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        height=h,
+        width=min(1280, 400 + top_n * 14),
+        margin=dict(l=40, r=40, t=80, b=40),
+        title_text="Global feature ranking (all modalities)",
+        title_x=0.5,
+    )
+    return fig
+
+
+def _flux_prepare_top_ranked(flux_df: pd.DataFrame, top_n: int, metric: str = "mean_rank") -> pd.DataFrame:
+    sub = flux_df[~flux_df["feature"].astype(str).str.contains("batch", case=False, na=False)].copy()
+    if metric not in sub.columns:
+        metric = "mean_rank"
+    sub = sub.sort_values(metric, ascending=True).head(int(top_n)).copy()
+    if "pathway" in sub.columns:
+        pc = sub["pathway"].value_counts()
+        sub["_pw_n"] = sub["pathway"].map(pc)
+        sub.sort_values(["_pw_n", "pathway"], ascending=[False, True], inplace=True)
+    return sub
+
+
+def flux_pathway_sunburst(flux_df: pd.DataFrame, max_features: int = 55) -> go.Figure:
+    sub = flux_df.dropna(subset=["pathway"]).copy()
+    if sub.empty:
+        return go.Figure()
+    sub = sub.nsmallest(int(max_features), "mean_rank")
+    sub["pathway"] = sub["pathway"].astype(str)
+    sub["_uid"] = np.arange(len(sub))
+    sub["rxn"] = sub.apply(
+        lambda r: f"{_truncate_label(str(r['feature']), 36)} ·{int(r['_uid'])}",
+        axis=1,
+    )
+    mr = sub["mean_rank"].astype(float)
+    sub["w"] = (mr.max() - mr + 1.0).clip(lower=0.5)
+    color_col = "log_fc" if "log_fc" in sub.columns and sub["log_fc"].notna().any() else "mean_rank"
+    sb_kw: dict[str, Any] = {
+        "path": ["pathway", "rxn"],
+        "values": "w",
+        "color": color_col,
+        "hover_data": {"mean_rank": ":.2f", "pval_adj": ":.2e", "feature": True, "w": False, "_uid": False},
+    }
+    if color_col == "log_fc":
+        sb_kw["color_continuous_scale"] = LOG_FC_DIVERGING_SCALE
+        sb_kw["range_color"] = [LOG_FC_COLOR_MIN, LOG_FC_COLOR_MAX]
+    else:
+        sb_kw["color_continuous_scale"] = "Viridis_r"
+    fig = px.sunburst(sub, **sb_kw)
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        margin=dict(l=8, r=8, t=100, b=16),
+        height=min(820, 520 + int(max_features) * 5),
+        title=dict(
+            text="Top flux reactions by model rank, nested under pathway",
+            x=0,
+            xanchor="left",
+            y=0.99,
+            yanchor="top",
+            font=dict(size=13, family=PLOT_FONT["family"]),
+            pad=dict(b=16, l=4),
+        ),
+    )
+    if color_col == "log_fc":
+        fig.update_layout(
+            coloraxis=dict(
+                cmin=LOG_FC_COLOR_MIN,
+                cmax=LOG_FC_COLOR_MAX,
+                colorbar=dict(
+                    title=dict(text=LABEL_LOG2FC, side="right"),
+                    tickformat=".2f",
+                    len=0.38,
+                    thickness=12,
+                    y=0.52,
+                    yanchor="middle",
+                ),
+            )
+        )
+    return fig
+
+
+def flux_volcano(flux_df: pd.DataFrame) -> go.Figure:
+    if "log_fc" not in flux_df.columns:
+        return go.Figure()
+    d = flux_df.dropna(subset=["log_fc"]).copy()
+    if d.empty:
+        return go.Figure()
+    # Drop degenerate rows: ~zero fold-change with exactly-zero adjusted p (numeric artifact / noise).
+    lf = d["log_fc"].astype(float)
+    if "pval_adj" in d.columns:
+        pa = d["pval_adj"].astype(float)
+        bad = np.isfinite(lf) & np.isfinite(pa) & (np.abs(lf) < 1e-10) & (pa <= 0.0)
+        d = d[~bad]
+        if d.empty:
+            return go.Figure()
+    if "pval_adj_log" in d.columns:
+        y = d["pval_adj_log"].astype(float)
+    else:
+        p = d["pval_adj"].astype(float).clip(lower=1e-300)
+        y = -np.log10(p.to_numpy())
+    d = d.assign(_neglogp=y)
+    fig = px.scatter(
+        d,
+        x="log_fc",
+        y="_neglogp",
+        color="mean_rank",
+        color_continuous_scale="Viridis_r",
+        hover_name="feature",
+        hover_data=["pathway", "pval_adj", "group"],
+        labels={
+            "log_fc": LABEL_LOG2FC,
+            "_neglogp": LABEL_NEG_LOG10_ADJ_P,
+            "mean_rank": "Mean rank",
+        },
+    )
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        title="Differential flux vs statistical significance",
+        height=520,
+        margin=dict(l=52, r=24, t=52, b=48),
+        coloraxis_colorbar=dict(
+            title=dict(text="Mean rank", side="right"),
+            thickness=12,
+            len=0.55,
+        ),
+    )
+    return fig
+
+
+def motif_tf_mean_rank_bars(atac_df: pd.DataFrame, top_n: int = 22) -> go.Figure:
+    """Aggregate motif features by TF name (prefix before ``_<motif_id>``); show lowest mean joint rank."""
+    if atac_df.empty or "feature" not in atac_df.columns:
+        return go.Figure()
+
+    def _tf_prefix(feat: str) -> str:
+        s = str(feat)
+        if "_" in s:
+            head, tail = s.rsplit("_", 1)
+            if tail.isdigit():
+                return head
+        return s
+
+    d = atac_df.copy()
+    d["_tf"] = d["feature"].map(_tf_prefix)
+    agg = d.groupby("_tf", as_index=False)["mean_rank"].mean()
+    agg = agg.nsmallest(int(top_n), "mean_rank").sort_values("mean_rank", ascending=True)
+    if agg.empty:
+        return go.Figure()
+    y_show = agg["_tf"].astype(str).map(lambda s: _truncate_label(s, 36))
+    fig = go.Figure(
+        go.Bar(
+            y=y_show,
+            x=agg["mean_rank"],
+            orientation="h",
+            marker_color=MODALITY_COLOR.get("ATAC", PALETTE[0]),
+            customdata=agg["_tf"],
+            hovertemplate="<b>%{customdata}</b><br>Mean mean_rank (across motifs): %{x:.2f}<extra></extra>",
+        )
+    )
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        title=f"TFs by average motif rank (top {top_n} by lowest mean rank)",
+        height=min(640, 48 + 22 * len(agg)),
+        margin=dict(l=160, r=24, t=52, b=40),
+        xaxis_title="Mean of mean_rank over motif instances (lower = stronger)",
+        yaxis_title="",
+    )
+    fig.update_yaxes(autorange="reversed", tickfont=dict(size=10))
+    return fig
+
+
+def motif_chromvar_volcano(atac_df: pd.DataFrame) -> go.Figure:
+    """Motif differential view: mean activity difference (reprogramming − dead-end) vs significance."""
+    need = ("mean_diff", "pval_adj")
+    if not all(c in atac_df.columns for c in need):
+        return go.Figure()
+    d = atac_df.dropna(subset=["mean_diff", "pval_adj"]).copy()
+    if d.empty:
+        return go.Figure()
+    md = d["mean_diff"].astype(float)
+    pa = d["pval_adj"].astype(float)
+    bad = np.isfinite(md) & np.isfinite(pa) & (np.abs(md) < 1e-12) & (pa <= 0.0)
+    d = d[~bad]
+    if d.empty:
+        return go.Figure()
+    if "pval_adj_log" in d.columns:
+        y = d["pval_adj_log"].astype(float)
+    else:
+        p = d["pval_adj"].astype(float).clip(lower=1e-300)
+        y = -np.log10(p.to_numpy())
+    d = d.assign(_y=y)
+    hover_cols = [c for c in ("group", "pval_adj", "mean_rank", "mean_de", "mean_re") if c in d.columns]
+    fig = px.scatter(
+        d,
+        x="mean_diff",
+        y="_y",
+        color="mean_rank",
+        color_continuous_scale="Viridis_r",
+        hover_name="feature",
+        hover_data=hover_cols if hover_cols else None,
+        labels={
+            "mean_diff": "Mean difference (reprogramming − dead-end)",
+            "_y": LABEL_NEG_LOG10_ADJ_P,
+            "mean_rank": "Mean rank",
+        },
+    )
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        title="TF motif differential activity (mean difference vs significance)",
+        height=520,
+        margin=dict(l=52, r=24, t=52, b=48),
+        coloraxis_colorbar=dict(title=dict(text="Mean rank", side="right"), thickness=12, len=0.55),
+    )
+    return fig
+
+
+def notebook_style_activity_scatter(
+    df: pd.DataFrame,
+    title: str,
+    x_title: str,
+    y_title: str,
+) -> go.Figure:
+    """mean_de vs mean_re, colour = pval_adj_log (Reds), marker size ∝ inverse mean_rank."""
+    need = ("mean_de", "mean_re", "mean_rank", "pval_adj_log", "feature", "group")
+    if not all(c in df.columns for c in need):
+        return go.Figure()
+    d = df.dropna(subset=["mean_de", "mean_re", "mean_rank", "pval_adj_log"]).copy()
+    if d.empty:
+        return go.Figure()
+    mx = float(d["mean_rank"].max())
+    d = d.assign(_inv=(mx - d["mean_rank"].astype(float)).clip(lower=0))
+    inv = d["_inv"].astype(float)
+    lo, hi = float(inv.min()), float(inv.max())
+    if hi <= lo:
+        d["_sz"] = 6.0
+    else:
+        d["_sz"] = 3.5 + (inv - lo) / (hi - lo) * 9.0
+
+    fig = px.scatter(
+        d,
+        x="mean_de",
+        y="mean_re",
+        color="pval_adj_log",
+        color_continuous_scale="Reds",
+        size="_sz",
+        size_max=14,
+        hover_name="feature",
+        hover_data={
+            "mean_rank": ":.2f",
+            "group": True,
+            "pval_adj_log": ":.2f",
+            "_inv": False,
+            "_sz": False,
+        },
+        labels={
+            "mean_de": x_title,
+            "mean_re": y_title,
+            "pval_adj_log": "Adj. p-value (log)",
+        },
+    )
+    fig.update_traces(
+        marker=dict(line=dict(width=0.45, color="rgba(255,255,255,0.75)"), opacity=0.9),
+        selector=dict(mode="markers"),
+    )
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        title=title,
+        height=520,
+        margin=dict(l=52, r=24, t=52, b=48),
+        coloraxis_colorbar=dict(title=dict(text="Adj. p (log)", side="right"), thickness=12, len=0.55),
+    )
+    return fig
+
+
+def pathway_bubble_suggested_height(n_paths: int) -> int:
+    """Total figure height for pathway bubble panels (use the max of both cohorts so legends line up)."""
+    n = max(int(n_paths), 1)
+    return max(520, min(1100, 22 * n + 200))
+
+
+def pathway_enrichment_bubble_panel(
+    df: pd.DataFrame,
+    title: str,
+    *,
+    show_colorbar: bool = True,
+    layout_height: int | None = None,
+) -> go.Figure:
+    """Single cohort: Reactome (circle) vs KEGG (square), colour = −log₁₀ Benjamini (scale per panel)."""
+    fig = go.Figure()
+    if df.empty:
+        fig.update_layout(
+            template="plotly_white",
+            font=PLOT_FONT,
+            title=dict(text=title, x=0.5, xanchor="center"),
+            annotations=[
+                dict(
+                    text="No significant pathways (Benjamini–Hochberg q < 0.05)",
+                    xref="paper",
+                    yref="paper",
+                    x=0.5,
+                    y=0.5,
+                    showarrow=False,
+                    font=dict(size=13, color="#64748b"),
+                )
+            ],
+            height=320,
+            margin=dict(l=40, r=40, t=56, b=40),
+        )
+        return fig
+
+    # More genes in the overlap first, then stronger gene ratio (matches enrichment table emphasis).
+    d = df.sort_values(by=["Count", "Gene Ratio"], ascending=[False, False]).reset_index(drop=True)
+    d = d.assign(
+        _neglog=-np.log10(d["Benjamini"].astype(float).clip(lower=1e-300)),
+        _y=np.arange(len(d), dtype=float),
+    )
+    nl = d["_neglog"].astype(float)
+    cmin = float(nl.min())
+    cmax = float(nl.max())
+    if cmax <= cmin:
+        cmax = cmin + 1e-6
+
+    # Single trace: per-panel cmin/cmax so Viridis uses the cohort’s range (shared global max clusters at one hue).
+    sym_map = {"Reactome": "circle", "KEGG": "square"}
+    symbols = [sym_map.get(str(x), "circle") for x in d["Library"].tolist()]
+    sz = np.sqrt(d["Count"].astype(float).clip(lower=1)) * 4.8
+    customdata = np.stack(
+        [d["Count"].to_numpy(), d["_neglog"].to_numpy(), d["Library"].astype(str).to_numpy()],
+        axis=1,
+    )
+    fig.add_trace(
+        go.Scatter(
+            x=d["Gene Ratio"],
+            y=d["_y"],
+            mode="markers",
+            name="Pathways",
+            showlegend=False,
+            marker=dict(
+                size=sz,
+                sizemode="diameter",
+                sizemin=4,
+                symbol=symbols,
+                color=d["_neglog"],
+                cmin=cmin,
+                cmax=cmax,
+                colorscale="Viridis",
+                showscale=bool(show_colorbar),
+                colorbar=dict(
+                    title=dict(
+                        text="\u2212log\u2081\u2080 q",
+                        side="right",
+                    ),
+                    len=0.72,
+                    thickness=12,
+                    y=0.45,
+                    yanchor="middle",
+                    outlinewidth=0,
+                )
+                if show_colorbar
+                else None,
+                line=dict(width=0.75, color="rgba(0,0,0,0.5)"),
+                opacity=0.92,
+            ),
+            text=d["Term"],
+            customdata=customdata,
+            hovertemplate=(
+                "<b>%{text}</b><br>%{customdata[2]}<br>Gene ratio: %{x:.3f}<br>Count: %{customdata[0]}"
+                "<br>\u2212log\u2081\u2080 Benjamini: %{customdata[1]:.2f}<extra></extra>"
+            ),
+        )
+    )
+    for lib, sym in (("Reactome", "circle"), ("KEGG", "square")):
+        if lib not in set(d["Library"].astype(str)):
+            continue
+        fig.add_trace(
+            go.Scatter(
+                x=[None],
+                y=[None],
+                mode="markers",
+                name=lib,
+                marker=dict(
+                    symbol=sym,
+                    size=11,
+                    color="#475569",
+                    line=dict(width=1, color="rgba(0,0,0,0.45)"),
+                ),
+                showlegend=True,
+            )
+        )
+
+    ticktext = [_truncate_label(str(t), 52) for t in d["Term"]]
+    h = int(layout_height) if layout_height is not None else pathway_bubble_suggested_height(len(d))
+    fig.update_yaxes(
+        tickmode="array",
+        tickvals=d["_y"].tolist(),
+        ticktext=ticktext,
+        autorange="reversed",
+        title="",
+    )
+    fig.update_xaxes(title_text="Gene ratio (count ÷ list total)")
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        title=dict(
+            text=title,
+            x=0.5,
+            xanchor="center",
+            yanchor="top",
+            y=0.985,
+            pad=dict(b=0),
+        ),
+        height=h,
+        margin=dict(l=215, r=132, t=48, b=108),
+        legend=dict(
+            orientation="h",
+            yanchor="top",
+            y=-0.11,
+            xanchor="center",
+            x=0.5,
+            bgcolor="rgba(255,255,255,0.92)",
+            bordercolor="rgba(0,0,0,0.08)",
+            borderwidth=1,
+        ),
+        showlegend=True,
+    )
+    return fig
+
+
+def pathway_gene_membership_heatmap(
+    z: np.ndarray, row_labels: list[str], col_labels: list[str]
+) -> go.Figure:
+    """Pathway × gene grid; empty cells transparent; light gaps; legend for category colours."""
+    if z.size == 0:
+        return go.Figure()
+    # Discrete codes 0–4 must not use z/4 (3→0.75 landed in the KEGG band). Map to fixed slots.
+    _z_plot = {0: 0.04, 1: 0.24, 2: 0.44, 3: 0.64, 4: 0.84}
+    zn = np.vectorize(lambda v: _z_plot.get(int(v), 0.04))(z).astype(float)
+    transparent = "rgba(0,0,0,0)"
+    colorscale = [
+        [0.0, transparent],
+        [0.14, transparent],
+        [0.15, "#e69138"],
+        [0.33, "#e69138"],
+        [0.34, "#7eb6d9"],
+        [0.53, "#7eb6d9"],
+        [0.54, "#9ccc65"],
+        [0.73, "#9ccc65"],
+        [0.74, "#283593"],
+        [1.0, "#283593"],
+    ]
+
+    def _cell_hint(v: float) -> str:
+        k = int(round(float(v)))
+        return {
+            0: "",
+            1: "Gene enriched in dead-end contrast",
+            2: "Gene enriched in reprogramming contrast",
+            3: "Reactome pathway set",
+            4: "KEGG pathway set",
+        }.get(k, "")
+
+    z_int = z.astype(int)
+    text_grid = [[_cell_hint(z_int[i, j]) for j in range(z.shape[1])] for i in range(z.shape[0])]
+
+    heat = go.Heatmap(
+        z=zn,
+        x=col_labels,
+        y=row_labels,
+        text=text_grid,
+        colorscale=colorscale,
+        zmin=0,
+        zmax=1,
+        showscale=False,
+        xgap=1,
+        ygap=1,
+        hovertemplate="%{y}<br>%{x}<br>%{text}<extra></extra>",
+    )
+
+    fig = go.Figure(data=[heat])
+
+    n_rows, n_cols = z.shape
+    cell_w = 10
+    cell_h = 20
+    w = int(min(1000, max(460, n_cols * cell_w + 272)))
+    h = int(min(960, max(460, n_rows * cell_h + 128)))
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        title=dict(text="Pathway–gene membership", x=0.5, xanchor="center"),
+        height=h,
+        width=w,
+        margin=dict(l=4, r=168, t=52, b=108),
+        paper_bgcolor="rgba(0,0,0,0)",
+        plot_bgcolor="#f4f6f9",
+        xaxis=dict(side="bottom", tickangle=-50, showgrid=False, zeroline=False),
+        yaxis=dict(
+            tickfont=dict(size=9),
+            showgrid=False,
+            zeroline=False,
+            autorange="reversed",
+        ),
+    )
+
+    legend_markers = [
+        ("Empty cell", "#f1f5f9", "square"),
+        ("Dead-end–linked gene", "#e69138", "square"),
+        ("Reprogramming–linked gene", "#7eb6d9", "square"),
+        ("Reactome (column tag)", "#9ccc65", "square"),
+        ("KEGG (column tag)", "#283593", "square"),
+    ]
+    for name, color, sym in legend_markers:
+        fig.add_trace(
+            go.Scatter(
+                x=[None],
+                y=[None],
+                mode="markers",
+                name=name,
+                marker=dict(size=11, color=color, symbol=sym, line=dict(width=1, color="rgba(0,0,0,0.25)")),
+                showlegend=True,
+            )
+        )
+
+    fig.update_layout(
+        legend=dict(
+            orientation="v",
+            yanchor="top",
+            y=0.98,
+            xanchor="left",
+            x=1.02,
+            bgcolor="rgba(255,255,255,0.92)",
+            bordercolor="rgba(0,0,0,0.08)",
+            borderwidth=1,
+            font=dict(size=11),
+        )
+    )
+    return fig
+
+
+def flux_dead_end_vs_reprogram_scatter(flux_df: pd.DataFrame, max_pathway_colors: int = 12) -> go.Figure:
+    need = ("mean_de", "mean_re")
+    if not all(c in flux_df.columns for c in need):
+        return go.Figure()
+    d = flux_df.dropna(subset=list(need)).copy()
+    if d.empty:
+        return go.Figure()
+    imp = (
+        d["importance_shift"].astype(float).clip(lower=0) * d["importance_att"].astype(float).clip(lower=0)
+    ) ** 0.5
+    q = float(imp.quantile(0.95)) if len(imp) else 1.0
+    d = d.assign(_s=(imp / (q or 1.0)).clip(upper=1) * 20 + 5)
+    pw = d["pathway"].fillna("Unknown").astype(str) if "pathway" in d.columns else pd.Series(
+        ["Unknown"] * len(d), index=d.index
+    )
+    top_pw = pw.value_counts().head(int(max_pathway_colors)).index
+    d = d.assign(_pw_col=pw.where(pw.isin(top_pw), "Other"))
+    uniq = sorted(d["_pw_col"].astype(str).unique(), key=lambda x: (x == "Other", x))
+    pal = list(LATENT_DISCRETE_PALETTE)
+    pw_cmap: dict[str, str] = {}
+    j = 0
+    for name in uniq:
+        if name == "Other":
+            pw_cmap[name] = "#94a3b8"
+        else:
+            pw_cmap[name] = pal[j % len(pal)]
+            j += 1
+    fig = px.scatter(
+        d,
+        x="mean_de",
+        y="mean_re",
+        color="_pw_col",
+        color_discrete_map=pw_cmap,
+        size="_s",
+        hover_name="feature",
+        hover_data=["mean_rank", "log_fc", "pathway"],
+        labels={
+            "mean_de": "Mean flux · dead-end",
+            "mean_re": "Mean flux · reprogramming",
+            "_pw_col": "Pathway",
+        },
+    )
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        height=540,
+        margin=dict(l=52, r=20, t=52, b=40),
+        title="Average measured flux by fate label (each point is one reaction)",
+        legend=dict(orientation="h", yanchor="top", y=-0.28, xanchor="center", x=0.5),
+    )
+    fig.update_traces(marker=dict(opacity=0.75, line=dict(width=0.35, color="rgba(0,0,0,0.3)")))
+    return fig
+
+
+def flux_pathway_mean_rank_violin(flux_df: pd.DataFrame, top_pathways: int = 12) -> go.Figure:
+    sub = flux_df.dropna(subset=["pathway"]).copy()
+    if sub.empty:
+        return go.Figure()
+    top_p = sub["pathway"].astype(str).value_counts().head(int(top_pathways)).index
+    sub = sub[sub["pathway"].astype(str).isin(top_p)]
+    top_list = list(top_p)
+    v_cmap = {p: LATENT_DISCRETE_PALETTE[i % len(LATENT_DISCRETE_PALETTE)] for i, p in enumerate(top_list)}
+    fig = px.violin(
+        sub,
+        x="pathway",
+        y="mean_rank",
+        box=True,
+        points=False,
+        color="pathway",
+        color_discrete_map=v_cmap,
+        labels={"mean_rank": "Mean rank (lower = stronger model focus)", "pathway": "Pathway"},
+    )
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        showlegend=False,
+        height=420,
+        xaxis_tickangle=-32,
+        margin=dict(l=48, r=24, t=48, b=140),
+        title="How joint model rank spreads within high-coverage pathways",
+    )
+    return fig
+
+
+def flux_reaction_annotation_panel(flux_df: pd.DataFrame, top_n: int = 26, metric: str = "mean_rank") -> go.Figure:
+    """Three heatmap columns: pathway (categorical), DE Log₂FC, −log₁₀ adjusted p."""
+    top = _flux_prepare_top_ranked(flux_df, top_n, metric)
+    if top.empty:
+        return go.Figure()
+    n = len(top)
+    pathways = top["pathway"].fillna("Unknown").astype(str).tolist() if "pathway" in top.columns else ["Unknown"] * n
+    uniq = list(dict.fromkeys(pathways))
+    code_map = {u: i for i, u in enumerate(uniq)}
+    codes = np.array([code_map[p] for p in pathways], dtype=float)
+    k = max(len(uniq), 1)
+    qual = list(px.colors.qualitative.Safe) + list(px.colors.qualitative.Dark24) + list(px.colors.qualitative.Light24)
+    if k <= 1:
+        disc_scale = [[0, qual[0]], [1, qual[0]]]
+    else:
+        disc_scale = [[j / (k - 1), qual[j % len(qual)]] for j in range(k)]
+    log_fc = top["log_fc"].fillna(0).astype(float).to_numpy() if "log_fc" in top.columns else np.zeros(n)
+    if "pval_adj_log" in top.columns:
+        pv = top["pval_adj_log"].fillna(0).astype(float).to_numpy()
+    else:
+        pv = -np.log10(top["pval_adj"].astype(float).clip(lower=1e-300).to_numpy())
+    full_features = top["feature"].astype(str).tolist()
+    y_labels = [_truncate_label(str(f), 44) for f in full_features]
+    z_path = codes.reshape(-1, 1)
+    # hovertext (not customdata): subplot heatmaps often render %{customdata[0]} as "-" in the browser.
+    hover_path = [[f"<b>{fn}</b><br>pathway: {pw}"] for fn, pw in zip(full_features, pathways)]
+    hover_lfc = [
+        [f"<b>{fn}</b><br>{LABEL_LOG2FC}: {float(log_fc[i]):.4f}"]
+        for i, fn in enumerate(full_features)
+    ]
+    hover_pv = [
+        [f"<b>{fn}</b><br>{LABEL_NEG_LOG10_ADJ_P}: {float(pv[i]):.2f}"]
+        for i, fn in enumerate(full_features)
+    ]
+    fig = make_subplots(
+        rows=1,
+        cols=3,
+        shared_yaxes=True,
+        horizontal_spacing=0.06,
+        column_widths=[0.24, 0.24, 0.24],
+    )
+    fig.add_trace(
+        go.Heatmap(
+            z=z_path,
+            x=[""],
+            y=y_labels,
+            colorscale=disc_scale,
+            zmin=0,
+            zmax=max(k - 1, 0),
+            showscale=False,
+            hovertext=hover_path,
+            hovertemplate="%{hovertext}<extra></extra>",
+        ),
+        row=1,
+        col=1,
+    )
+    fig.add_trace(
+        go.Heatmap(
+            z=log_fc.reshape(-1, 1),
+            x=[""],
+            y=y_labels,
+            colorscale=LOG_FC_DIVERGING_SCALE,
+            zmin=LOG_FC_COLOR_MIN,
+            zmax=LOG_FC_COLOR_MAX,
+            showscale=True,
+            colorbar=dict(
+                title=dict(text=LABEL_LOG2FC, side="right"),
+                tickformat=".2f",
+                len=0.22,
+                y=0.71,
+                yanchor="middle",
+                x=1.0,
+                xanchor="left",
+                xref="paper",
+                yref="paper",
+                thickness=12,
+            ),
+            hovertext=hover_lfc,
+            hovertemplate="%{hovertext}<extra></extra>",
+        ),
+        row=1,
+        col=2,
+    )
+    fig.add_trace(
+        go.Heatmap(
+            z=pv.reshape(-1, 1),
+            x=[""],
+            y=y_labels,
+            colorscale="Viridis",
+            showscale=True,
+            colorbar=dict(
+                title=dict(text=LABEL_NEG_LOG10_ADJ_P, side="right"),
+                len=0.22,
+                y=0.29,
+                yanchor="middle",
+                x=1.0,
+                xanchor="left",
+                xref="paper",
+                yref="paper",
+                thickness=12,
+            ),
+            hovertext=hover_pv,
+            hovertemplate="%{hovertext}<extra></extra>",
+        ),
+        row=1,
+        col=3,
+    )
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        height=min(820, 120 + n * 22),
+        width=900,
+        margin=dict(l=8, r=108, t=56, b=72),
+        title=dict(
+            text=f"Pathway, {LABEL_LOG2FC}, and significance",
+            x=0,
+            xanchor="left",
+            y=0.995,
+            yanchor="top",
+            font=dict(size=13, family=PLOT_FONT["family"]),
+            pad=dict(b=8, l=4),
+        ),
+    )
+    fig.update_xaxes(side="bottom", title_standoff=8)
+    fig.update_xaxes(title_text="Pathway", row=1, col=1)
+    fig.update_xaxes(title_text=LABEL_LOG2FC, row=1, col=2)
+    fig.update_xaxes(title_text=LABEL_NEG_LOG10_ADJ_P, row=1, col=3)
+    fig.update_yaxes(autorange="reversed")
+    return fig
+
+
+def flux_model_metric_profile(flux_df: pd.DataFrame, top_n: int = 22, metric: str = "mean_rank") -> go.Figure:
+    """Matrix view: scaled shift, attention, model priority, and fate flux contrast."""
+    top = _flux_prepare_top_ranked(flux_df, top_n, metric)
+    if top.empty:
+        return go.Figure()
+
+    def mm(s: pd.Series) -> np.ndarray:
+        v = s.astype(float).to_numpy()
+        lo, hi = float(np.nanmin(v)), float(np.nanmax(v))
+        if hi <= lo or not np.isfinite(lo):
+            return np.zeros_like(v, dtype=float)
+        return (v - lo) / (hi - lo)
+
+    cols: list[np.ndarray] = []
+    labels: list[str] = []
+    for c, lab in (("importance_shift", "Latent shift impact"), ("importance_att", "Attention (rollout)")):
+        if c in top.columns:
+            cols.append(mm(top[c]))
+            labels.append(lab)
+    cols.append(1.0 - mm(top["mean_rank"]))
+    labels.append("Joint priority (1 - scaled mean rank)")
+    if "mean_de" in top.columns and "mean_re" in top.columns:
+        de = top["mean_de"].astype(float).replace(0, np.nan)
+        ratio = (top["mean_re"].astype(float) / (de + 1e-12)).fillna(0)
+        cols.append(mm(ratio))
+        labels.append("RE / DE mean flux (scaled)")
+    z = np.column_stack(cols)
+    full_rxn = top["feature"].astype(str).tolist()
+    x_labels = [_truncate_label(str(f), 34) for f in full_rxn]
+    fig = px.imshow(
+        z.T,
+        x=x_labels,
+        y=labels,
+        aspect="auto",
+        color_continuous_scale="Tealrose",
+        labels=dict(x="Reaction", y="Metric", color="Scaled 0-1 per metric"),
+    )
+    n_met, n_rxn = z.T.shape
+    hover_cd = np.broadcast_to(np.array(full_rxn, dtype=object), (n_met, n_rxn))
+    fig.update_traces(
+        customdata=hover_cd,
+        hovertemplate="<b>%{customdata}</b><br>%{y}<br>scaled: %{z:.3f}<extra></extra>",
+    )
+    fig.update_xaxes(tickangle=-50, side="bottom", title_standoff=12)
+    fig.update_layout(
+        template="plotly_white",
+        font=PLOT_FONT,
+        height=min(380, 140 + len(labels) * 36),
+        margin=dict(l=200, r=28, t=64, b=200),
+        title=dict(
+            text="Reaction profile",
+            x=0,
+            xanchor="left",
+            y=0.98,
+            yanchor="top",
+            font=dict(size=13, family=PLOT_FONT["family"]),
+            pad=dict(b=10, l=4),
+        ),
+    )
+    return fig
+
+
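The min–max scaling used throughout these panels (e.g. in `global_rank_triple_panel` and the `mm` helper of `flux_model_metric_profile`) can be sketched standalone. This is a minimal sketch, not part of the diff; the helper name `min_max_scale` is hypothetical, but the logic mirrors the normalisation above, including the fallback to zeros for a constant column to avoid division by zero.

```python
import pandas as pd


def min_max_scale(s: pd.Series) -> pd.Series:
    """Scale a numeric Series to [0, 1]; constant columns map to 0.0."""
    lo, hi = s.min(), s.max()
    if hi > lo:
        return (s - lo) / (hi - lo)
    # Degenerate case: all values equal, so there is no spread to scale.
    return pd.Series(0.0, index=s.index)


print(min_max_scale(pd.Series([2.0, 4.0, 6.0])).tolist())  # [0.0, 0.5, 1.0]
```

Because each metric is scaled independently, values are only comparable within a column, not across metrics.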
streamlit_hf/lib/reactions.py ADDED
@@ -0,0 +1,12 @@
+"""Shared reaction-string normalisation (flux features vs metabolic metadata)."""
+
+from __future__ import annotations
+
+import re
+
+
+def normalize_reaction_key(name: str) -> str:
+    """Map `A→B` style names to the same key as metadata `A -> B` (case-insensitive)."""
+    t = str(name).strip().replace("→", " -> ")
+    t = re.sub(r"\s+", " ", t)
+    return t.lower()
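To illustrate how `normalize_reaction_key` lets flux feature names and metabolic metadata meet on one key, here is a small usage sketch (the reaction strings are made-up examples, not taken from the dataset):

```python
import re


def normalize_reaction_key(name: str) -> str:
    """Map `A→B` style names to the same key as metadata `A -> B` (case-insensitive)."""
    t = str(name).strip().replace("→", " -> ")
    t = re.sub(r"\s+", " ", t)
    return t.lower()


# An arrow-glyph feature name and a spaced metadata name collapse to the same key.
a = normalize_reaction_key("Glucose→Pyruvate")
b = normalize_reaction_key("  glucose  ->  pyruvate ")
print(a)       # glucose -> pyruvate
print(a == b)  # True
```

Lower-casing plus whitespace collapse means the join is robust to formatting drift between the two sources, at the cost of treating case-distinct names as identical.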
streamlit_hf/lib/ui.py ADDED
@@ -0,0 +1,24 @@
+"""Light shared styles (no heavy themes; keeps default Streamlit + plotly_white)."""
+
+from __future__ import annotations
+
+import streamlit as st
+
+
+def inject_app_styles() -> None:
+    """Panel labels and home cards; safe to call on every rerun (small CSS block)."""
+    st.markdown(
+        """
+        <style>
+        .latent-panel-title {
+            font-size: 0.82rem;
+            font-weight: 600;
+            color: #475569;
+            margin: 0 0 0.35rem 0;
+            letter-spacing: 0.02em;
+        }
+        .latent-panel-title-gap { margin-top: 0.85rem; }
+        </style>
+        """,
+        unsafe_allow_html=True,
+    )
streamlit_hf/pages/1_Single_Cell_Explorer.py ADDED
@@ -0,0 +1,158 @@
+"""Interactive UMAP of multimodal latent space (validation folds)."""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+import pandas as pd
+import streamlit as st
+
+_REPO = Path(__file__).resolve().parents[2]
+if str(_REPO) not in sys.path:
+    sys.path.insert(0, str(_REPO))
+
+from streamlit_hf.lib import formatters
+from streamlit_hf.lib import io
+from streamlit_hf.lib import plots
+from streamlit_hf.lib import ui
+
+ui.inject_app_styles()
+
+st.title("Single-Cell Explorer")
+st.caption("Explore validation cells in 2-D UMAP space: colour and filter to compare fates, predictions, and modalities.")
+
+bundle = io.load_latent_bundle()
+if bundle is None:
+    st.error("Latent maps are not available in this session. Ask the maintainer to publish results, then reload.")
+    st.stop()
+
+samples = io.load_samples_df()
+df = io.latent_join_samples(bundle, samples)
+
+left, right = st.columns([0.36, 0.64], gap="large")
+
+with left:
+    st.markdown('<p class="latent-panel-title">Colour by</p>', unsafe_allow_html=True)
+    color_opt = st.selectbox(
+        "Hue",
+        [
+            "label",
+            "predicted_class",
+            "correct",
+            "fold",
+            "batch_no",
+            "modality_label",
+            "pct",
+        ],
+        format_func=lambda x: {
+            "label": "CellTag-Multi label",
+            "predicted_class": "Predicted fate",
+            "correct": "Prediction correct",
+            "fold": "CV fold",
+            "batch_no": "Batch",
+            "modality_label": "Available modalities",
+            "pct": "Dominant fate %",
+        }[x],
+        label_visibility="collapsed",
+        help="Which variable sets the colour of each point on the UMAP.",
+    )
+
+    st.markdown('<p class="latent-panel-title latent-panel-title-gap">Filters</p>', unsafe_allow_html=True)
+    mod_labels = sorted(df["modality_label"].astype(str).unique())
+    mod_pick = st.multiselect(
+        "Available modalities",
+        mod_labels,
+        default=mod_labels,
+        help="Keep cells whose modality combination matches your selection (RNA/ATAC measured where present; flux inferred).",
+    )
+    only_correct = st.selectbox(
+        "Prediction outcome",
+        ["All", "Correct only", "Wrong only"],
+        help="Restrict to cells where the model was correct, incorrect, or show all.",
+    )
+    folds = sorted(df["fold"].unique())
+    fold_pick = st.multiselect(
76
+ "CV folds",
77
+ folds,
78
+ default=folds,
79
+ help="Validation cross-validation folds to include (each fold’s held-out cells).",
80
+ )
81
+ pct_rng = st.slider(
82
+ "Dominant fate % range",
83
+ 0.0,
84
+ 100.0,
85
+ (0.0, 100.0),
86
+ 1.0,
87
+ help="Keep cells whose dominant lineage probability (percent) falls in this range.",
88
+ )
89
+
90
+ plot_df = df[df["fold"].isin(fold_pick) & df["modality_label"].isin(mod_pick)].copy()
91
+ plot_df = plot_df[(plot_df["pct"] >= pct_rng[0]) & (plot_df["pct"] <= pct_rng[1])]
92
+ if only_correct == "Correct only":
93
+ plot_df = plot_df[plot_df["correct"]]
94
+ elif only_correct == "Wrong only":
95
+ plot_df = plot_df[~plot_df["correct"]]
96
+
97
+ if plot_df.empty:
98
+ st.warning("No points after filters. Relax the filters and try again.")
99
+ st.stop()
100
+
101
+ with right:
102
+ fig = plots.latent_scatter(
103
+ plot_df,
104
+ color_opt,
105
+ title="Validation latent space (UMAP)",
106
+ width=900,
107
+ height=560,
108
+ marker_size=5.8,
109
+ marker_opacity=0.74,
110
+ )
111
+ st.plotly_chart(fig, width="stretch", on_select="rerun", key="latent_pick")
112
+
113
+ st.subheader("Selected points")
114
+ state = st.session_state.get("latent_pick")
115
+ points = []
116
+ if isinstance(state, dict):
117
+ sel = state.get("selection") or {}
118
+ if isinstance(sel, dict):
119
+ points = sel.get("points") or []
120
+ if points:
121
+ idxs = [int(p["point_index"]) for p in points if "point_index" in p]
122
+ idxs = [i for i in idxs if 0 <= i < len(plot_df)]
123
+ if idxs:
124
+ sub = plot_df.iloc[idxs]
125
+ disp = formatters.prepare_latent_display_dataframe(sub)
126
+ st.dataframe(
127
+ disp,
128
+ width="stretch",
129
+ hide_index=True,
130
+ )
131
+ else:
132
+ st.warning(
133
+ "A selection was reported but no valid points matched the current filtered view. "
134
+ "Try selecting again after changing filters, or pick a row via **Inspect by dataset index**."
135
+ )
136
+ else:
137
+ st.info(
138
+ "This table fills in when you **select points on the UMAP**. "
139
+ "In the chart’s top-right toolbar, choose **Box select** or **Lasso select**, "
140
+ "then drag over the dots; the page reruns and rows for those cells appear here. "
141
+ "To inspect one cell without using the lasso, scroll down to **Inspect by dataset index**."
142
+ )
143
+
144
+ st.subheader("Inspect by dataset index")
145
+ pick = st.number_input(
146
+ "Dataset index",
147
+ min_value=int(df["dataset_idx"].min()),
148
+ max_value=int(df["dataset_idx"].max()),
149
+ value=int(df["dataset_idx"].iloc[0]),
150
+ help="Index `ind` in your sample table; aligns one validation cell to this row.",
151
+ )
152
+ row = df[df["dataset_idx"] == pick]
153
+ if not row.empty:
154
+ st.dataframe(
155
+ formatters.latent_inspector_key_value(row.iloc[0]),
156
+ width="stretch",
157
+ hide_index=True,
158
+ )
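
The "Selected points" logic in the page above parses the selection payload Streamlit stores under the chart's key when `on_select="rerun"` is used. The defensive walk (state → `selection` → `points` → in-range `point_index` values) can be isolated as a pure function; the helper name `extract_point_indices` and the sample payload are illustrative, not part of the app.

```python
def extract_point_indices(state, n_rows):
    """Mirror the page's parsing: keep only in-range positional indices
    into the currently filtered dataframe; tolerate missing/None pieces."""
    points = []
    if isinstance(state, dict):
        sel = state.get("selection") or {}
        if isinstance(sel, dict):
            points = sel.get("points") or []
    idxs = [int(p["point_index"]) for p in points if "point_index" in p]
    return [i for i in idxs if 0 <= i < n_rows]


# Rough shape of a selection payload; point 99 is dropped as out of range.
payload = {"selection": {"points": [{"point_index": 2}, {"point_index": 99}]}}
print(extract_point_indices(payload, n_rows=10))  # [2]
print(extract_point_indices(None, n_rows=10))     # []
```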
streamlit_hf/pages/2_Feature_insights.py ADDED
@@ -0,0 +1,294 @@
+ """Multimodal feature importance: ranks, attention by prediction, tables."""
+
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path
+
+ import pandas as pd
+ import streamlit as st
+
+ _REPO = Path(__file__).resolve().parents[2]
+ if str(_REPO) not in sys.path:
+     sys.path.insert(0, str(_REPO))
+
+ from streamlit_hf.lib import io
+ from streamlit_hf.lib import plots
+ from streamlit_hf.lib import ui
+
+ ui.inject_app_styles()
+
+ st.title("Feature Insights")
+ st.caption("Latent-shift probes, attention rollout, and combined rankings across RNA, ATAC, and Flux.")
+
+ df = io.load_df_features()
+ att = io.load_attention_summary()
+
+ if df is None:
+     st.error(
+         "Feature data are not loaded. Ask the maintainer to publish results for this app, then reload."
+     )
+     st.stop()
+
+ tab1, tab2, tab3, tab4, tab5 = st.tabs(
+     [
+         "Global overview",
+         "Modality spotlight",
+         "Shift vs attention",
+         "Attention vs prediction",
+         "Full table",
+     ]
+ )
+
+ # ----- Tab 1 -----
+ with tab1:
+     c1, c2 = st.columns(2)
+     with c1:
+         top_n_bars = st.slider(
+             "Top N (shift & attention bars)",
+             10,
+             45,
+             20,
+             key="t1_topn_bars",
+         )
+     with c2:
+         top_n_pie = st.slider(
+             "Pool size (mean-rank pie)",
+             50,
+             250,
+             100,
+             key="t1_topn_pie",
+         )
+     st.plotly_chart(
+         plots.global_rank_triple_panel(df, top_n=top_n_bars, top_n_pie=top_n_pie),
+         width="stretch",
+     )
+     st.caption(
+         "Bars: **global** top features by shift impact and by mean attention (min-max scaled); "
+         "colour = modality. Pie: RNA / ATAC / Flux mix among the lowest mean-rank features in that pool."
+     )
+
+ # ----- Tab 2: RNA / ATAC / Flux columns -----
+ with tab2:
+     st.caption(
+         "**Modality spotlight:** three columns (**RNA**, **ATAC**, **Flux**). Each column only shows features "
+         "from that modality so you can compare shift impact, attention, and joint ranking **within** RNA, ATAC, or flux."
+     )
+     top_n_rank = st.slider("Top N per chart", 10, 55, 20, key="t2_topn")
+     st.subheader("Joint top markers (by mean rank)")
+     st.caption(
+         "The **strongest combined** markers by mean rank (lower mean rank = higher joint shift + attention priority). "
+         "Shift and attention bars are **min-max scaled within this top-N list** (0 to 1) so you can compare them on one axis. "
+         "Hover a bar for the full feature name."
+     )
+     r1a, r1b, r1c = st.columns(3)
+     for col, mod in zip((r1a, r1b, r1c), ("RNA", "ATAC", "Flux")):
+         sm = df[df["modality"] == mod]
+         if sm.empty:
+             continue
+         with col:
+             st.plotly_chart(
+                 plots.joint_shift_attention_top_features(sm, mod, top_n_rank),
+                 width="stretch",
+             )
+     st.subheader("Shift importance")
+     r2a, r2b, r2c = st.columns(3)
+     for col, mod in zip((r2a, r2b, r2c), ("RNA", "ATAC", "Flux")):
+         sm = df[df["modality"] == mod]
+         if sm.empty:
+             continue
+         colc = plots.MODALITY_COLOR.get(mod, plots.PALETTE[0])
+         sub = sm.nlargest(top_n_rank, "importance_shift").sort_values("importance_shift", ascending=True)
+         with col:
+             st.plotly_chart(
+                 plots.rank_bar(
+                     sub,
+                     "importance_shift",
+                     "feature",
+                     f"{mod}: shift · top {top_n_rank}",
+                     colc,
+                     xaxis_title="Latent shift importance",
+                 ),
+                 width="stretch",
+             )
+     st.subheader("Attention importance")
+     r3a, r3b, r3c = st.columns(3)
+     for col, mod in zip((r3a, r3b, r3c), ("RNA", "ATAC", "Flux")):
+         sm = df[df["modality"] == mod]
+         if sm.empty:
+             continue
+         colc = plots.MODALITY_COLOR.get(mod, plots.PALETTE[0])
+         sub = sm.nlargest(top_n_rank, "importance_att").sort_values("importance_att", ascending=True)
+         with col:
+             st.plotly_chart(
+                 plots.rank_bar(
+                     sub,
+                     "importance_att",
+                     "feature",
+                     f"{mod}: attention · top {top_n_rank}",
+                     colc,
+                     xaxis_title="Attention importance",
+                 ),
+                 width="stretch",
+             )
+
+ # ----- Tab 3 -----
+ with tab3:
+     st.caption(
+         "Each point is **one feature** within its modality. **Attention rank** is on the horizontal axis and **shift rank** "
+         "on the vertical axis (1 = strongest in that modality for that metric). Features near the diagonal rank similarly "
+         "for both; the **red dashed line** is a straight-line trend (least-squares fit) through the cloud."
+     )
+     corr_rows = []
+     for mod in ("RNA", "ATAC", "Flux"):
+         sm = df[df["modality"] == mod]
+         if sm.empty:
+             continue
+         cor = plots.modality_shift_attention_rank_stats(sm)
+         if cor.get("n", 0) >= 3:
+             corr_rows.append(
+                 {
+                     "Modality": mod,
+                     "# features": cor["n"],
+                     "Pearson r": f"{cor['pearson_r']:.3f}",
+                     "Pearson p": f"{cor['pearson_p']:.2e}",
+                     "Spearman ρ": f"{cor['spearman_r']:.3f}",
+                     "Spearman p": f"{cor['spearman_p']:.2e}",
+                 }
+             )
+     if corr_rows:
+         st.dataframe(pd.DataFrame(corr_rows), hide_index=True, width="stretch")
+     rc1, rc2, rc3 = st.columns(3)
+     for col, mod in zip((rc1, rc2, rc3), ("RNA", "ATAC", "Flux")):
+         with col:
+             sub_m = df[df["modality"] == mod]
+             st.plotly_chart(
+                 plots.rank_scatter_shift_vs_attention(sub_m, mod),
+                 width="stretch",
+             )
+
+ # ----- Tab 4 -----
+ with tab4:
+     with st.expander("What is this?", expanded=False):
+         st.markdown(
+             "Bars show **mean attention weights** (from rollout) averaged over validation cells, split by **what the "
+             "model predicted** for each cell: all validation cells together, only cells called **dead-end**, or only "
+             "cells called **reprogramming**. This reflects **model behaviour**, not the true fate label."
+         )
+     cohort_mode = st.selectbox(
+         "Cohort view",
+         [
+             "compare",
+             "all",
+             "dead_end",
+             "reprogramming",
+         ],
+         format_func=lambda x: {
+             "compare": "Compare cohorts (grouped bars)",
+             "all": "All validation samples (mean attention)",
+             "dead_end": "Mean attention when prediction = dead-end",
+             "reprogramming": "Mean attention when prediction = reprogramming",
+         }[x],
+         key="t4_cohort",
+         help=(
+             "Choose which validation cells contribute to the average. **All validation samples** uses every validation "
+             "cell. The prediction-specific options use only cells where the model output was dead-end or reprogramming, "
+             "so you can see which features receive more weight when the model leans each way."
+         ),
+     )
+     top_n_att = st.slider("Top N", 6, 28, 15, key="t4_topn")
+     if not att or "fi_att" not in att:
+         st.warning(
+             "Attention summaries are not available in this session. That view needs a full publish from the maintainer."
+         )
+     else:
+         ac1, ac2, ac3 = st.columns(3)
+         for col, mod in zip((ac1, ac2, ac3), ("RNA", "ATAC", "Flux")):
+             with col:
+                 st.plotly_chart(
+                     plots.attention_cohort_view(att["fi_att"], mod, top_n=top_n_att, mode=cohort_mode),
+                     width="stretch",
+                 )
+         if "rollout_mean" in att and "slices" in att:
+             st.subheader("Mean rollout weight")
+             if cohort_mode == "compare":
+                 roll_cohort = st.selectbox(
+                     "Rollout table: average over",
+                     ["all", "dead_end", "reprogramming"],
+                     format_func=lambda x: {
+                         "all": "All validation samples",
+                         "dead_end": "Cells predicted dead-end",
+                         "reprogramming": "Cells predicted reprogramming",
+                     }[x],
+                     key="t4_roll",
+                     help="Pick which validation subset is used for the mean rollout vector in the tables below.",
+                 )
+             else:
+                 roll_cohort = cohort_mode
+                 st.caption(
+                     "Rollout tables use the **same cohort** as the bar charts above (batch-embedding tokens are omitted)."
+                 )
+             rc1, rc2, rc3 = st.columns(3)
+             for col, mod in zip((rc1, rc2, rc3), ("RNA", "ATAC", "Flux")):
+                 with col:
+                     rm = att["rollout_mean"]
+                     vec_all = rm.get(roll_cohort)
+                     if vec_all is None:
+                         vec_all = rm["all"]
+                     sl = att["slices"][mod]
+                     vec = vec_all[sl["start"] : sl["stop"]]
+                     names = att["feature_names"][sl["start"] : sl["stop"]]
+                     mini = plots.rollout_top_features_table(names, vec, top_n_att)
+                     st.caption(mod)
+                     st.dataframe(mini, hide_index=True, width="stretch")
+
+ # ----- Tab 5 -----
+ with tab5:
+     scope = st.radio(
+         "Table scope",
+         ["All modalities", "Single modality"],
+         horizontal=True,
+         key="t5_scope",
+     )
+     mod_tbl = "all"
+     if scope == "Single modality":
+         mod_tbl = st.selectbox("Modality", ["RNA", "ATAC", "Flux"], key="t5_mod")
+         tbl = df[df["modality"] == mod_tbl].copy()
+     else:
+         tbl = df.copy()
+     show_cols = [
+         c
+         for c in [
+             "mean_rank",
+             "feature",
+             "modality",
+             "rank_shift_in_modal",
+             "rank_att_in_modal",
+             "combined_order_mod",
+             "rank_shift",
+             "rank_att",
+             "importance_shift",
+             "importance_att",
+             "top_10_pct",
+             "group",
+             "log_fc",
+             "pval_adj",
+             "pathway",
+             "module",
+         ]
+         if c in tbl.columns
+     ]
+     st.caption(
+         "All rows for the chosen scope, sorted by **mean rank** (lower = stronger joint shift + attention priority). "
+         "Use the dataframe search / sort in the table toolbar to narrow down."
+     )
+     full_view = tbl[show_cols].sort_values("mean_rank")
+     st.dataframe(full_view, width="stretch", hide_index=True)
+     suffix = mod_tbl if scope == "Single modality" else "all"
+     st.download_button(
+         "Download table (CSV)",
+         full_view.to_csv(index=False).encode("utf-8"),
+         file_name=f"fateformer_features_{suffix}.csv",
+         mime="text/csv",
+         key="t5_dl",
+     )
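
The captions in this page describe the combined ranking as a **mean rank** over shift rank and attention rank (lower = stronger on both). A minimal sketch of that idea in plain Python follows; the column names `importance_shift` / `importance_att` come from the table above, but the exact tie-handling used by the app is not shown in this diff, so the `dense_rank_desc` helper and the sample values are assumptions for illustration only.

```python
def dense_rank_desc(values):
    """Rank 1 = largest value; ties share a rank (lower rank = stronger)."""
    order = sorted(set(values), reverse=True)
    pos = {v: i + 1 for i, v in enumerate(order)}
    return [pos[v] for v in values]


features = ["featA", "featB", "featC"]          # hypothetical features
shift = [0.9, 0.2, 0.5]                          # importance_shift (made up)
att = [0.9, 0.1, 0.6]                            # importance_att (made up)

rank_shift = dense_rank_desc(shift)              # [1, 3, 2]
rank_att = dense_rank_desc(att)                  # [1, 3, 2]
mean_rank = [(a + b) / 2 for a, b in zip(rank_shift, rank_att)]

# Sort ascending: the lowest mean rank is the strongest joint marker.
ranked = sorted(zip(features, mean_rank), key=lambda t: t[1])
print(ranked)  # [('featA', 1.0), ('featC', 2.0), ('featB', 3.0)]
```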
streamlit_hf/pages/3_Flux_analysis.py ADDED
@@ -0,0 +1,161 @@
+ """Metabolic flux: pathway map, differential views, reaction ranking table, metabolic model metadata."""
+
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path
+
+ import streamlit as st
+
+ _REPO = Path(__file__).resolve().parents[2]
+ if str(_REPO) not in sys.path:
+     sys.path.insert(0, str(_REPO))
+
+ from streamlit_hf.lib import io
+ from streamlit_hf.lib import plots
+ from streamlit_hf.lib import ui
+
+ ui.inject_app_styles()
+
+ st.title("Flux Analysis")
+ st.caption(
+     "Reaction-level flux: how pathways, statistics, and model rankings line up. "
+     "For global rank bars and shift vs. attention scatter, open **Feature insights**."
+ )
+
+ df = io.load_df_features()
+ if df is None:
+     st.error(
+         "Flux and feature data are not loaded in this session. Reload the app after the maintainer has published "
+         "fresh results, or ask them to check the deployment."
+     )
+     st.stop()
+
+ flux = df[df["modality"] == "Flux"].copy()
+ if flux.empty:
+     st.warning("There are no flux reactions in the current results.")
+     st.stop()
+
+ meta = io.load_metabolic_model_metadata()
+
+ tab_map, tab_bio, tab_rank, tab_meta = st.tabs(
+     [
+         "Pathway map",
+         "Differential & fate",
+         "Reaction ranking",
+         "Metabolic model metadata",
+     ]
+ )
+
+ with tab_map:
+     st.caption(
+         "**Left:** sunburst of the strongest reactions by mean rank, grouped by pathway. **Right:** heatmaps for the "
+         "same reactions: pathway, differential Log₂FC, and statistical significance, aligned row by row. "
+         "Ranked reaction table: **Reaction ranking**. Curated model edges: **Metabolic model metadata**."
+     )
+     try:
+         c1, c2 = st.columns([1.05, 0.95], gap="medium", vertical_alignment="top")
+     except TypeError:
+         c1, c2 = st.columns([1.05, 0.95], gap="medium")
+     with c1:
+         n_sb = st.slider("Reactions in sunburst", 25, 90, 52, key="flux_sb_n")
+         st.plotly_chart(plots.flux_pathway_sunburst(flux, max_features=n_sb), width="stretch")
+     with c2:
+         top_n_nb = st.slider("Reactions in annotation + profile", 12, 40, 26, key="flux_nb_n")
+         st.plotly_chart(
+             plots.flux_reaction_annotation_panel(flux, top_n=top_n_nb, metric="mean_rank"),
+             width="stretch",
+         )
+         st.plotly_chart(
+             plots.flux_model_metric_profile(flux, top_n=min(top_n_nb, 24), metric="mean_rank"),
+             width="stretch",
+         )
+
+ with tab_bio:
+     st.caption(
+         "**Volcano:** differential Log₂FC versus significance (−log₁₀ adjusted p); colour shows overall mean rank. "
+         "Points with essentially no fold change and a zero adjusted p-value are removed as unreliable. "
+         "**Scatter:** average measured flux in dead-end versus reprogramming cells; point size reflects combined shift "
+         "and attention strength; colours mark pathway (largest groups shown, others grouped as *Other*)."
+     )
+     b1, b2 = st.columns(2)
+     with b1:
+         st.plotly_chart(plots.flux_volcano(flux), width="stretch")
+     with b2:
+         st.plotly_chart(plots.flux_dead_end_vs_reprogram_scatter(flux), width="stretch")
+
+ with tab_rank:
+     st.caption("Filter by reaction name or pathway, then inspect or download the ranked flux table.")
+     q = st.text_input("Substring filter (reaction name)", "", key="flux_q")
+     pw_f = st.multiselect(
+         "Pathway",
+         sorted(flux["pathway"].dropna().unique().astype(str)),
+         default=[],
+         key="flux_pw_f",
+     )
+     show = flux
+     if q.strip():
+         show = show[show["feature"].astype(str).str.contains(q, case=False, na=False)]
+     if pw_f:
+         show = show[show["pathway"].astype(str).isin(pw_f)]
+     cols = [
+         c
+         for c in [
+             "mean_rank",
+             "feature",
+             "rank_shift_in_modal",
+             "rank_att_in_modal",
+             "combined_order_mod",
+             "rank_shift",
+             "rank_att",
+             "importance_shift",
+             "importance_att",
+             "top_10_pct",
+             "mean_de",
+             "mean_re",
+             "group",
+             "log_fc",
+             "pval_adj",
+             "pathway",
+             "module",
+         ]
+         if c in show.columns
+     ]
+     st.dataframe(show[cols].sort_values("mean_rank"), width="stretch", hide_index=True)
+     st.download_button(
+         "Download Flux table (CSV)",
+         show[cols].sort_values("mean_rank").to_csv(index=False).encode("utf-8"),
+         file_name="fateformer_flux_filtered.csv",
+         mime="text/csv",
+         key="flux_dl",
+     )
+
+ with tab_meta:
+     st.caption(
+         "Directed substrate-to-product steps from the reference model, merged with this flux table where reaction names match."
+     )
+     if meta is None or meta.empty:
+         st.warning("Metabolic model metadata is not available in this build.")
+     else:
+         sm_ids = sorted(meta["Supermodule_id"].dropna().unique().astype(int).tolist())
+         graph_labels = ["All modules"]
+         for sid in sm_ids:
+             cls = str(meta.loc[meta["Supermodule_id"] == sid, "Super.Module.class"].iloc[0])
+             graph_labels.append(f"{sid}: {cls}")
+         tix = st.selectbox(
+             "Model scope",
+             range(len(graph_labels)),
+             format_func=lambda i: graph_labels[i],
+             key="flux_model_scope",
+             help="Show every step in the model, or restrict to one functional module.",
+         )
+         supermodule_id = None if tix == 0 else sm_ids[tix - 1]
+         tbl = io.build_metabolic_model_table(meta, flux, supermodule_id=supermodule_id)
+         st.dataframe(tbl, width="stretch", hide_index=True)
+         st.download_button(
+             "Download metabolic model metadata (CSV)",
+             tbl.to_csv(index=False).encode("utf-8"),
+             file_name="fateformer_metabolic_model_edges.csv",
+             mime="text/csv",
+             key="flux_model_dl",
+         )
streamlit_hf/pages/4_Gene_expression_analysis.py ADDED
@@ -0,0 +1,168 @@
+ """Gene expression and TF motif activity: pathway enrichment, chromVAR-style motifs, and tables."""
+
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path
+
+ import pandas as pd
+ import streamlit as st
+
+ _REPO = Path(__file__).resolve().parents[2]
+ if str(_REPO) not in sys.path:
+     sys.path.insert(0, str(_REPO))
+
+ from streamlit_hf.lib import io
+ from streamlit_hf.lib import pathways as pathway_data
+ from streamlit_hf.lib import plots
+ from streamlit_hf.lib import ui
+
+ ui.inject_app_styles()
+
+ st.title("Gene Expression & TF Activity")
+
+ df = io.load_df_features()
+ if df is None:
+     st.error("Feature data could not be loaded. Reload after results are published, or contact the maintainer.")
+     st.stop()
+
+ rna = df[df["modality"] == "RNA"].copy()
+ atac = df[df["modality"] == "ATAC"].copy()
+ if rna.empty and atac.empty:
+     st.warning("No RNA gene or ATAC motif features are available in the current results.")
+     st.stop()
+
+ st.caption(
+     "Pathway enrichment (Reactome / KEGG) and a pathway–gene map; chromVAR-style motif deviations and activity by "
+     "fate; sortable gene and motif tables. Use **Feature Insights** for global shift and attention rankings across modalities."
+ )
+
+ TABLE_COLS = [
+     "mean_rank",
+     "feature",
+     "rank_shift_in_modal",
+     "rank_att_in_modal",
+     "combined_order_mod",
+     "rank_shift",
+     "rank_att",
+     "importance_shift",
+     "importance_att",
+     "top_10_pct",
+     "mean_de",
+     "mean_re",
+     "group",
+     "log_fc",
+     "pval_adj",
+     "mean_diff",
+     "pval_adj_log",
+ ]
+
+
+ def _table_cols(show: pd.DataFrame) -> list[str]:
+     return [c for c in TABLE_COLS if c in show.columns]
+
+
+ tab_path, tab_motif, tab_gene_tbl, tab_motif_tbl = st.tabs(
+     ["Gene Pathway Enrichment", "Motif Activity", "Gene Table", "Motif Table"]
+ )
+
+ with tab_path:
+     st.caption(
+         "Over-representation of Reactome and KEGG pathways (Benjamini–Hochberg *q* < 0.05). "
+         "The lower panel maps leading genes to pathways; empty grid positions are left clear."
+     )
+     raw = pathway_data.load_de_re_tsv()
+     if raw is None:
+         st.info("Pathway enrichment views are not available in this deployment.")
+     else:
+         de_all, re_all = raw
+         mde, mre = pathway_data.merged_reactome_kegg_bubble_frames(de_all, re_all)
+         bubble_h = max(
+             plots.pathway_bubble_suggested_height(len(mde)),
+             plots.pathway_bubble_suggested_height(len(mre)),
+         )
+         c1, c2 = st.columns(2, gap="medium")
+         with c1:
+             st.plotly_chart(
+                 plots.pathway_enrichment_bubble_panel(
+                     mde,
+                     "Pathway enrichment — dead-end",
+                     show_colorbar=True,
+                     layout_height=bubble_h,
+                 ),
+                 width="stretch",
+             )
+         with c2:
+             st.plotly_chart(
+                 plots.pathway_enrichment_bubble_panel(
+                     mre,
+                     "Pathway enrichment — reprogramming",
+                     show_colorbar=True,
+                     layout_height=bubble_h,
+                 ),
+                 width="stretch",
+             )
+         hm = pathway_data.build_merged_pathway_membership(de_all, re_all)
+         if hm is None:
+             st.info("No pathway–gene matrix could be built from the current enrichment results.")
+         else:
+             z, ylabs, xlabs = hm
+             st.plotly_chart(plots.pathway_gene_membership_heatmap(z, ylabs, xlabs), width="stretch")
+
+ with tab_motif:
+     if atac.empty:
+         st.warning("No motif-level ATAC features are available in the current results.")
+     else:
+         st.caption(
+             "Left: mean motif score difference (reprogramming − dead-end) versus significance. "
+             "Right: mean activity in each fate; colour and size follow the same encoding as in **Feature Insights**."
+         )
+         a1, a2 = st.columns(2, gap="medium")
+         with a1:
+             st.plotly_chart(plots.motif_chromvar_volcano(atac), width="stretch")
+         with a2:
+             st.plotly_chart(
+                 plots.notebook_style_activity_scatter(
+                     atac,
+                     title="TF activity (z-score) by fate",
+                     x_title="Dead-end (TF activity)",
+                     y_title="Reprogramming (TF activity)",
+                 ),
+                 width="stretch",
+             )
+
+ with tab_gene_tbl:
+     if rna.empty:
+         st.warning("No RNA gene features are available in the current results.")
+     else:
+         q = st.text_input("Filter by gene name", "", key="ge_tbl_q")
+         show = rna
+         if q.strip():
+             show = show[show["feature"].astype(str).str.contains(q, case=False, na=False)]
+         cols = _table_cols(show)
+         st.dataframe(show[cols].sort_values("mean_rank"), width="stretch", hide_index=True)
+         st.download_button(
+             "Download table (CSV)",
+             show[cols].sort_values("mean_rank").to_csv(index=False).encode("utf-8"),
+             file_name="gene_expression_table.csv",
+             mime="text/csv",
+             key="ge_tbl_dl",
+         )
+
+ with tab_motif_tbl:
+     if atac.empty:
+         st.warning("No motif-level ATAC features are available in the current results.")
+     else:
+         q = st.text_input("Filter by motif or TF", "", key="tf_tbl_q")
+         show = atac
+         if q.strip():
+             show = show[show["feature"].astype(str).str.contains(q, case=False, na=False)]
+         cols = _table_cols(show)
+         st.dataframe(show[cols].sort_values("mean_rank"), width="stretch", hide_index=True)
+         st.download_button(
+             "Download table (CSV)",
+             show[cols].sort_values("mean_rank").to_csv(index=False).encode("utf-8"),
+             file_name="tf_motif_table.csv",
+             mime="text/csv",
+             key="tf_tbl_dl",
+         )
streamlit_hf/requirements-docker.txt ADDED
@@ -0,0 +1,6 @@
+ # Hugging Face / production image: precomputed cache only (no torch)
+ streamlit>=1.40.0
+ plotly>=5.22.0
+ pandas>=2.0.0
+ numpy>=1.24.0
+ pyarrow>=14.0.0
streamlit_hf/static/app_icon.svg ADDED