Spaces:
Building
Building
Updated. First version.
Browse files- .dockerignore +12 -0
- .gitignore +8 -1
- Dockerfile +17 -0
- README.md +59 -0
- analysis/df_features.csv +0 -0
- notebooks/analysis_plots.ipynb +0 -0
- requirements.txt +15 -0
- scripts/precompute_streamlit_cache.py +486 -0
- streamlit_hf/.streamlit/config.toml +13 -0
- streamlit_hf/README.md +36 -0
- streamlit_hf/__init__.py +1 -0
- streamlit_hf/app.py +35 -0
- streamlit_hf/cache/.gitkeep +0 -0
- streamlit_hf/home.py +58 -0
- streamlit_hf/lib/__init__.py +0 -0
- streamlit_hf/lib/formatters.py +118 -0
- streamlit_hf/lib/io.py +171 -0
- streamlit_hf/lib/pathways.py +133 -0
- streamlit_hf/lib/plots.py +1421 -0
- streamlit_hf/lib/reactions.py +12 -0
- streamlit_hf/lib/ui.py +24 -0
- streamlit_hf/pages/1_Single_Cell_Explorer.py +158 -0
- streamlit_hf/pages/2_Feature_insights.py +294 -0
- streamlit_hf/pages/3_Flux_analysis.py +161 -0
- streamlit_hf/pages/4_Gene_expression_analysis.py +168 -0
- streamlit_hf/requirements-docker.txt +6 -0
- streamlit_hf/static/app_icon.svg +14 -0
.dockerignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.venv
|
| 3 |
+
venv
|
| 4 |
+
__pycache__
|
| 5 |
+
*.py[cod]
|
| 6 |
+
*.egg-info
|
| 7 |
+
.pytest_cache
|
| 8 |
+
.mypy_cache
|
| 9 |
+
.ruff_cache
|
| 10 |
+
.cursor
|
| 11 |
+
*.ipynb_checkpoints
|
| 12 |
+
notebooks
|
.gitignore
CHANGED
|
@@ -1,3 +1,10 @@
|
|
| 1 |
__pycache__/
|
| 2 |
|
| 3 |
-
.DS_Store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
__pycache__/
|
| 2 |
|
| 3 |
+
.DS_Store
|
| 4 |
+
.venv/
|
| 5 |
+
venv/
|
| 6 |
+
|
| 7 |
+
# Precomputed explorer artifacts (regenerate with scripts/precompute_streamlit_cache.py)
|
| 8 |
+
streamlit_hf/cache/*.pkl
|
| 9 |
+
streamlit_hf/cache/*.parquet
|
| 10 |
+
streamlit_hf/cache/*.csv
|
Dockerfile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces (Streamlit + Docker). Port 7860.
# Build context: repository root. Upload `streamlit_hf/cache/*` (pickles + parquet) via Git LFS or CI.

FROM python:3.11-slim-bookworm

# Docker Spaces run the container as UID 1000. Create that user up front so the
# process has a real, writable HOME instead of running without a passwd entry.
RUN useradd --create-home --uid 1000 user

WORKDIR /app

# Install dependencies before copying the sources so this layer stays cached
# when only application code changes.
COPY streamlit_hf/requirements-docker.txt /app/requirements-docker.txt
RUN pip install --no-cache-dir -r /app/requirements-docker.txt

COPY --chown=user:user . /app

USER user

ENV HOME=/home/user
ENV PYTHONPATH=/app
ENV STREAMLIT_SERVER_HEADLESS=true
EXPOSE 7860

CMD ["streamlit", "run", "streamlit_hf/app.py", "--server.port", "7860", "--server.address", "0.0.0.0", "--browser.gatherUsageStats", "false"]
|
README.md
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: FateFormer Explorer
|
| 3 |
+
short_description: Streamlit app to explore multimodal single-cell fate modeling (RNA, ATAC, metabolic flux, attention, and rankings).
|
| 4 |
+
emoji: 🧬
|
| 5 |
+
colorFrom: violet
|
| 6 |
+
colorTo: indigo
|
| 7 |
+
tags:
|
| 8 |
+
- streamlit
|
| 9 |
+
- single-cell
|
| 10 |
+
- multi-omics
|
| 11 |
+
- genomics
|
| 12 |
+
- atac-seq
|
| 13 |
+
- rna-seq
|
| 14 |
+
- metabolic-modeling
|
| 15 |
+
- deep-learning
|
| 16 |
+
- biology
|
| 17 |
+
license: mit
|
| 18 |
+
sdk: docker
|
| 19 |
+
app_port: 7860
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
# FateFormerApp
|
| 23 |
+
|
| 24 |
+
## Interactive explorer (Streamlit)
|
| 25 |
+
|
| 26 |
+
From the repo root, with the project virtualenv activated:
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
PYTHONPATH=. streamlit run streamlit_hf/app.py
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
The default local port is **8501**. The **Dockerfile** (and Hugging Face Space card above) use **7860** to match Spaces.
|
| 33 |
+
|
| 34 |
+
### Updating results after new experiments (no code changes)
|
| 35 |
+
|
| 36 |
+
The app reads **fixed paths**. Replace files under `streamlit_hf/cache/` using the **same filenames**; then **restart Streamlit** (or do a hard refresh) so the new data loads.
|
| 37 |
+
|
| 38 |
+
| File | What it drives |
|
| 39 |
+
|------|----------------|
|
| 40 |
+
| `streamlit_hf/cache/latent_umap.pkl` | Single-Cell Explorer (UMAP) |
|
| 41 |
+
| `streamlit_hf/cache/df_features.parquet` | Feature insights + Flux analysis |
|
| 42 |
+
| `streamlit_hf/cache/attention_summary.pkl` | “Attention vs prediction” in Feature insights |
|
| 43 |
+
| `streamlit_hf/cache/attention_feature_ranks.pkl` | Optional; attention lists also live inside `attention_summary.pkl` |
|
| 44 |
+
|
| 45 |
+
You can also keep `analysis/df_features.csv` in sync for your own workflows; the UI **prefers** `streamlit_hf/cache/df_features.parquet` when present.
|
| 46 |
+
|
| 47 |
+
### Regenerating caches from this repo
|
| 48 |
+
|
| 49 |
+
If you updated checkpoints, fold splits, shift pickles, or deg tables **inside this project**, run:
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
python scripts/precompute_streamlit_cache.py
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
That script expects (among others) `ckp/*.pth`, `objects/fold_results_multi.pkl`, `objects/mutlimodal_dataset.pkl`, `objects/fi_shift_*.pkl`, and `objects/degs.pkl`. Point those inputs at your new experiment outputs **before** running the script, or copy your new pickles/CSVs into `streamlit_hf/cache/` manually as in the table above.
|
| 56 |
+
|
| 57 |
+
### Docker / Hugging Face
|
| 58 |
+
|
| 59 |
+
See `streamlit_hf/README.md` and the root `Dockerfile`.
|
analysis/df_features.csv
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/analysis_plots.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FateFormerApp — training, precompute, and local Streamlit dev
|
| 2 |
+
torch>=2.1.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
pandas>=2.0.0
|
| 5 |
+
scipy>=1.11.0
|
| 6 |
+
scikit-learn>=1.3.0
|
| 7 |
+
umap-learn>=0.5.5
|
| 8 |
+
tqdm>=4.66.0
|
| 9 |
+
anndata>=0.10.0
|
| 10 |
+
scanpy>=1.9.0
|
| 11 |
+
statsmodels>=0.14.0
|
| 12 |
+
tabulate>=0.9.0
|
| 13 |
+
streamlit>=1.40.0
|
| 14 |
+
plotly>=5.22.0
|
| 15 |
+
pyarrow>=14.0.0
|
scripts/precompute_streamlit_cache.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
One-off cache builder for the Streamlit explorer.
|
| 4 |
+
Run from the repository root:
|
| 5 |
+
python scripts/precompute_streamlit_cache.py
|
| 6 |
+
python scripts/precompute_streamlit_cache.py --skip-attention # faster: reuse objects/fi_shift_*.pkl only for df_features if attention_summary exists
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import os
|
| 13 |
+
import pickle
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import torch
|
| 20 |
+
import umap
|
| 21 |
+
|
| 22 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 23 |
+
sys.path.insert(0, str(ROOT))
|
| 24 |
+
os.chdir(ROOT)
|
| 25 |
+
|
| 26 |
+
from data import create_dataset # noqa: E402
|
| 27 |
+
from interpretation import attentions as att # noqa: E402
|
| 28 |
+
from interpretation import latentspace as ls # noqa: E402
|
| 29 |
+
from interpretation import predictions as prds # noqa: E402
|
| 30 |
+
CACHE = ROOT / "streamlit_hf" / "cache"
|
| 31 |
+
CACHE.mkdir(parents=True, exist_ok=True)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def replace_fold_results_path(fold_results, ckp_root: str = "ckp"):
    """Point checkpoints at the flat `ckp/multi_seed0_fold{k}.pth` layout in this repo.

    Each entry's ``"best_model_path"`` is rewritten in place. The fold number is
    taken from the first run of digits inside the first ``_``-separated token of
    the checkpoint file name that starts with ``fold`` and contains a digit
    (``multi_seed3_fold2.pth`` -> ``2``). If no such token exists, the original
    file name is kept and only its directory is replaced by ``ckp_root``.

    Args:
        fold_results: Iterable of dict-like fold records with a
            ``"best_model_path"`` entry.
        ckp_root: Directory that holds the flat checkpoint files.

    Returns:
        The same ``fold_results`` object, after in-place modification.
    """
    for fold in fold_results:
        ckpt_name = os.path.basename(fold["best_model_path"])

        fold_idx = ""
        for part in ckpt_name.split("_"):
            if not part.startswith("fold"):
                continue
            # Keep only the first digit run: "fold3-ep10.pth" must yield "3",
            # not "310", and a digit-free token such as "folder" must not stop
            # the search before a later "fold3" token is seen.
            digits = ""
            for ch in part[len("fold"):]:
                if ch.isdigit():
                    digits += ch
                elif digits:
                    break
            if digits:
                fold_idx = digits
                break

        clean_ckpt_name = f"multi_seed0_fold{fold_idx}.pth" if fold_idx else ckpt_name
        fold["best_model_path"] = os.path.join(ckp_root, clean_ckpt_name)
    return fold_results
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load_training_context():
    """Load everything the cache builders need from the repository's pickles.

    Reads ``objects/mutlimodal_dataset.pkl`` and ``objects/fold_results_multi.pkl``
    (checkpoint paths are redirected via ``replace_fold_results_path``), builds
    the ``MultiModalDataset`` and the model configuration, and tries to load the
    optional labelled RNA object from ``data/datasets/rna_labelled.pkl``.

    Returns:
        Tuple ``(multimodal_dataset, fold_results, model_config, feature_names,
        adata_RNA_labelled)``; the last element is ``None`` when the optional
        pickle could not be loaded.
    """
    objects_dir = ROOT / "objects"

    with open(objects_dir / "mutlimodal_dataset.pkl", "rb") as handle:
        bundle = pickle.load(handle)
    X = bundle["X"]
    y_label = bundle["y_label"]
    b = bundle["b"]
    df_indices = bundle["df_indices"]
    pcts = bundle["pcts"]

    # Binary target: reprogramming -> 1, dead-end -> 0 (any other label raises KeyError).
    label_to_number = {"reprogramming": 1, "dead-end": 0}
    y_number = torch.tensor(
        [label_to_number[name] for name in list(y_label)],
        dtype=torch.float32,
    )
    multimodal_dataset = create_dataset.MultiModalDataset(
        X, b, y_number, df_indices, pcts, y_label
    )

    with open(objects_dir / "fold_results_multi.pkl", "rb") as handle:
        fold_results = replace_fold_results_path(pickle.load(handle))

    model_config = {
        "Share": {
            "d_model": 128,
            "d_ff": 16,
            "n_heads": 8,
            "n_encoder_layers": 2,
            "n_batches": 3,
            "dropout_rate": 0.0,
        },
        "RNA": {"vocab_size": 5914, "seq_len": X[0].shape[1]},
        "ATAC": {"vocab_size": 1, "seq_len": X[1].shape[1]},
        "Flux": {"vocab_size": 1, "seq_len": X[2].shape[1]},
        "Multi": {"d_model": 128, "n_heads_cls": 8, "d_ff_cls": 16},
    }

    # Feature order: each modality's columns followed by that modality's batch token.
    feature_names = []
    for position, batch_token in enumerate(("batch_rna", "batch_atac", "batch_flux")):
        feature_names.extend(X[position].columns)
        feature_names.append(batch_token)

    # Optional metadata source; the builders still work without it.
    rna_pkl = ROOT / "data" / "datasets" / "rna_labelled.pkl"
    adata_RNA_labelled = None
    try:
        with open(rna_pkl, "rb") as handle:
            adata_RNA_labelled = pickle.load(handle)
    except Exception as e:
        print(
            f"Warning: could not load {rna_pkl} ({e}). "
            "Sample table will omit AnnData-derived metadata (e.g. clone_id)."
        )

    return (
        multimodal_dataset,
        fold_results,
        model_config,
        feature_names,
        adata_RNA_labelled,
    )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def build_latent_umap(multimodal_dataset, fold_results, model_config, common_samples: bool = False):
    """Build the latent-space UMAP bundle saved as ``latent_umap.pkl``.

    The latent vectors, labels and predictions come from
    ``ls.get_latent_space`` for the ``"Multi"`` model and are embedded in 2-D
    with a fixed-seed UMAP. Per-sample metadata (fold, batch, pct, modality
    availability, dataset index) is collected by walking each fold's
    ``val_idx`` in fold order.

    NOTE(review): the metadata arrays are only aligned with the UMAP rows if
    ``get_latent_space`` returns samples in that same fold-by-fold ``val_idx``
    order (filtered the same way when ``common_samples`` is set) — confirm
    against ``interpretation.latentspace``.

    Args:
        multimodal_dataset: Dataset exposing ``batch_no``, ``pcts``,
            ``rna_data``, ``atac_data`` and ``flux_data`` indexed by sample.
        fold_results: Sequence of fold records with a ``"val_idx"`` entry.
        model_config: Model configuration forwarded to ``get_latent_space``.
        common_samples: When true, each fold's validation indices are filtered
            with ``filter_idx`` before use.

    Returns:
        Dict of arrays/lists keyed by ``umap_x``, ``umap_y``, ``label_name``,
        ``pred_name``, ``correct``, ``fold``, ``batch_no``, ``pct``,
        ``modality``, ``dataset_idx`` plus the ``common_samples`` flag.
    """
    from interpretation.attentions import filter_idx  # noqa: PLC0415

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ls_v, labels, preds = ls.get_latent_space(
        "Multi",
        fold_results,
        multimodal_dataset,
        model_config,
        device,
        common_samples=common_samples,
    )
    reducer = umap.UMAP(n_components=2, random_state=0, n_neighbors=30, min_dist=1.0)
    xy = reducer.fit_transform(ls_v)

    ordered_indices: list[int] = []
    fold_ids: list[int] = []
    for fold_idx, fold in enumerate(fold_results):
        val_idx = fold["val_idx"]
        if common_samples:
            val_idx = filter_idx(multimodal_dataset, val_idx)
        ordered_indices.extend(val_idx)
        # Folds are reported 1-based.
        fold_ids.extend([fold_idx + 1] * len(val_idx))

    labels = np.asarray(labels).ravel()
    # Predictions are truncated to integers before thresholding, so they are
    # presumably already hard 0/1 classes rather than probabilities — TODO confirm.
    preds = np.asarray(preds).ravel().astype(int)
    label_name = np.where(labels > 0.5, "reprogramming", "dead-end")
    pred_name = np.where(preds > 0.5, "reprogramming", "dead-end")
    correct = (preds == labels.astype(int)).astype(np.int8)

    ds = multimodal_dataset
    batch_no = np.array([int(ds.batch_no[i].item()) for i in ordered_indices], dtype=np.int32)
    pcts = np.array([float(ds.pcts[i]) for i in ordered_indices], dtype=np.float64)

    # A modality counts as present when its row has at least one non-zero value.
    modalities = []
    for i in ordered_indices:
        present = (
            ("R", (ds.rna_data[i] != 0).any().item()),
            ("A", (ds.atac_data[i] != 0).any().item()),
            ("F", (ds.flux_data[i] != 0).any().item()),
        )
        code = "".join(letter for letter, has_data in present if has_data)
        modalities.append(code or "None")

    return {
        "umap_x": xy[:, 0].astype(np.float32),
        "umap_y": xy[:, 1].astype(np.float32),
        "label_name": label_name,
        "pred_name": pred_name,
        "correct": correct,
        "fold": np.array(fold_ids, dtype=np.int32),
        "batch_no": batch_no,
        "pct": pcts,
        "modality": modalities,
        "dataset_idx": np.array(ordered_indices, dtype=np.int32),
        "common_samples": common_samples,
    }
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def create_combined_feature_dataframe(
    fi_shift_rna,
    fi_shift_atac,
    fi_shift_flux,
    fi_att_rna,
    fi_att_atac,
    fi_att_flux,
    df_rna_degs=None,
    df_atac_degs=None,
    df_flux_degs=None,
    remove_batch=True,
):
    """Combine shift- and attention-based feature importances into one table.

    Each ``fi_*`` argument is passed to ``pd.DataFrame(..., columns=["feature",
    <importance>])``; the row position (1-based) becomes the within-modality
    rank, so the inputs are presumably ``(feature, importance)`` pairs already
    sorted from most to least important — verify against the producers.

    Args:
        fi_shift_rna, fi_shift_atac, fi_shift_flux: Shift importances per modality.
        fi_att_rna, fi_att_atac, fi_att_flux: Attention importances per modality.
        df_rna_degs, df_atac_degs, df_flux_degs: Optional per-modality tables
            with a ``feature`` column, left-joined onto the importances.
        remove_batch: Drop every feature whose name contains ``"batch"``.

    Returns:
        DataFrame sorted by ``mean_rank`` with a fixed column set; columns that
        no input provided are filled with NaN. A feature missing from one
        ranking gets within-modality rank ``max observed rank + 1`` and
        importance ``0`` for that ranking.
    """

    def ranked_frame(pairs, value_col, rank_col):
        # Position in the incoming list is the within-modality rank (1-based).
        frame = pd.DataFrame(pairs, columns=["feature", value_col])
        frame.insert(0, rank_col, np.arange(1, len(frame) + 1))
        return frame

    def process_modality(shift_list, att_list, degs_df, modality_name):
        shift_df = ranked_frame(shift_list, "importance_shift", "rank_shift_in_modal")
        att_df = ranked_frame(att_list, "importance_att", "rank_att_in_modal")
        combined_df = pd.merge(shift_df, att_df, on="feature", how="outer")
        if degs_df is not None:
            combined_df = pd.merge(combined_df, degs_df, on="feature", how="left")
        combined_df["modality"] = modality_name
        return combined_df

    all_features_df = pd.concat(
        [
            process_modality(fi_shift_rna, fi_att_rna, df_rna_degs, "RNA"),
            process_modality(fi_shift_atac, fi_att_atac, df_atac_degs, "ATAC"),
            process_modality(fi_shift_flux, fi_att_flux, df_flux_degs, "Flux"),
        ],
        ignore_index=True,
    )

    if remove_batch:
        is_batch = all_features_df["feature"].str.contains("batch", na=False)
        # Copy so the column assignments below never write into a view of the
        # unfiltered frame.
        all_features_df = all_features_df[~is_batch].copy()

    rank_cols = ["rank_att_in_modal", "rank_shift_in_modal"]
    ranks = all_features_df[rank_cols].astype("float64")
    # pandas' max skips NaN; the builtin max(nan, x) does not, which made the
    # fill value NaN (and the int cast fail) when one ranking was entirely missing.
    max_rank_modal = ranks.max().max()
    if pd.isna(max_rank_modal):
        max_rank_modal = 0
    all_features_df[rank_cols] = ranks.fillna(max_rank_modal + 1).astype("int32")

    importance_cols = ["importance_att", "importance_shift"]
    all_features_df[importance_cols] = (
        all_features_df[importance_cols].astype("float64").fillna(0)
    )

    # Global ranks across all modalities; ties are broken by row order.
    all_features_df["rank_shift"] = (
        all_features_df["importance_shift"].rank(ascending=False, method="first").astype("int32")
    )
    all_features_df["rank_att"] = (
        all_features_df["importance_att"].rank(ascending=False, method="first").astype("int32")
    )
    all_features_df["mean_rank"] = all_features_df[["rank_att", "rank_shift"]].mean(axis=1)

    # Flag membership in the top ~10% of each global ranking.
    top_th = int(all_features_df.shape[0] * 0.1) + 1
    in_shift = (all_features_df["rank_shift"] <= top_th).to_numpy()
    in_att = (all_features_df["rank_att"] <= top_th).to_numpy()
    all_features_df["top_10_pct"] = np.select(
        [in_shift & in_att, in_shift, in_att],
        ["both", "shift", "att"],
        default="None",
    )

    float_cols = [
        col for col in all_features_df.columns if col.startswith(("log_fc", "mean_", "std_", "pval_"))
    ]
    if float_cols:
        all_features_df[float_cols] = all_features_df[float_cols].round(6)
    all_features_df["importance_att"] = all_features_df["importance_att"].round(6)
    all_features_df["importance_shift"] = all_features_df["importance_shift"].round(6)
    all_features_df = all_features_df.sort_values(by="mean_rank", ascending=True)

    cols = [
        "mean_rank",
        "feature",
        "rank_shift",
        "rank_att",
        "rank_shift_in_modal",
        "rank_att_in_modal",
        "modality",
        "importance_shift",
        "importance_att",
        "top_10_pct",
        "mean_de",
        "mean_re",
        "std_de",
        "std_re",
        "pval",
        "pval_adj",
        "log_fc",
        "group",
        "pval_adj_log",
        "mean_diff",
        "pathway",
        "module",
    ]
    # Guarantee a stable schema even when no DEG table supplied a column.
    for c in cols:
        if c not in all_features_df.columns:
            all_features_df[c] = np.nan
    return all_features_df[cols]
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def run_attention_and_fi(
    multimodal_dataset,
    fold_results,
    model_config,
    feature_names,
    device: str,
    adata_rna,
):
    """Compute attention rollout summaries and per-modality attention rankings.

    Sample predictions are obtained first; flow attention is then analysed for
    three index sets (all samples, predicted dead-end, predicted
    reprogramming), rolled out, and normalised so each row sums to one.

    Returns:
        ``(summary, df_samples)`` where ``summary`` holds ``feature_names``,
        the modality ``slices``, the per-set mean rollout vectors
        (``rollout_mean``) and the ranked features per set and modality
        (``fi_att``).
    """
    df_samples = prds.get_sample_predictions_dataframe(
        model_type="Multi",
        multimodal_dataset=multimodal_dataset,
        fold_results=fold_results,
        model_config=model_config,
        device=device,
        batch_size=32,
        threshold=0.5,
        adata_rna=adata_rna,
    )
    predicted = df_samples["predicted_class"]
    subsets = (
        (
            "all",
            "Running flow attention (all validation)…",
            df_samples["ind"].tolist(),
        ),
        (
            "dead_end",
            "Running flow attention (predicted dead-end)…",
            df_samples[predicted == "dead-end"]["ind"].tolist(),
        ),
        (
            "reprogramming",
            "Running flow attention (predicted reprogramming)…",
            df_samples[predicted == "reprogramming"]["ind"].tolist(),
        ),
    )

    layers_by_subset = {}
    for key, banner, indices in subsets:
        print(banner)
        layers_by_subset[key] = att.analyze_cls_attention(
            "Multi",
            fold_results,
            multimodal_dataset,
            model_config,
            device=device,
            indices=indices,
            average_heads=False,
            return_flow_attention=True,
        )

    raw_rollouts = {
        key: att.multimodal_attention_rollout(layers_by_subset[key]) for key, _, _ in subsets
    }
    # Normalise along the last dimension so every row sums to one.
    rollouts = {
        key: raw / raw.sum(dim=-1, keepdim=True) for key, raw in raw_rollouts.items()
    }

    # Explicit splits (notebook): RNA [:945], ATAC [945:945+884], flux rest
    i0, i1, i2 = 0, 945, 945 + 884

    rollout_mean = {
        key: rollout.mean(dim=0).detach().cpu().numpy() for key, rollout in rollouts.items()
    }

    top_n_get = None
    fi = {}
    for key, rollout in rollouts.items():
        fi[key] = {
            "rna": att.get_top_features(
                rollout[:, i0:i1], feature_names[i0:i1], modality="RNA", top_n=top_n_get
            ),
            "atac": att.get_top_features(
                rollout[:, i1:i2], feature_names[i1:i2], modality="ATAC", top_n=top_n_get
            ),
            "flux": att.get_top_features(
                rollout[:, i2:], feature_names[i2:], modality="Flux", top_n=top_n_get
            ),
        }

    summary = {
        "feature_names": feature_names,
        "slices": {
            "RNA": {"start": i0, "stop": i1},
            "ATAC": {"start": i1, "stop": i2},
            "Flux": {"start": i2, "stop": len(feature_names)},
        },
        "rollout_mean": rollout_mean,
        "fi_att": fi,
    }
    return summary, df_samples
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def main():
    """Build every cache artifact the Streamlit explorer reads.

    Steps: load the training context, write ``latent_umap.pkl``, compute (or,
    with ``--skip-attention``, reuse) ``attention_summary.pkl`` together with
    ``attention_feature_ranks.pkl`` and ``samples.parquet``, then merge the
    attention ranks with ``objects/fi_shift_*.pkl`` and ``objects/degs.pkl``
    into ``df_features.parquet`` and ``analysis/df_features.csv``.

    Raises:
        FileNotFoundError: If any ``objects/fi_shift_*.pkl`` input is missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--skip-attention", action="store_true", help="Skip attention if summary exists")
    ap.add_argument(
        "--common-samples",
        action="store_true",
        help="Use common-samples filter for latent UMAP (default: False, notebook-style)",
    )
    args = ap.parse_args()
    common_samples = args.common_samples

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    (
        multimodal_dataset,
        fold_results,
        model_config,
        feature_names,
        adata_RNA_labelled,
    ) = load_training_context()

    print("Building latent UMAP bundle…")
    latent = build_latent_umap(
        multimodal_dataset, fold_results, model_config, common_samples=common_samples
    )
    with open(CACHE / "latent_umap.pkl", "wb") as f:
        pickle.dump(latent, f)

    att_path = CACHE / "attention_summary.pkl"
    df_samples_path = CACHE / "samples.parquet"

    if args.skip_attention and att_path.is_file():
        print("Skipping attention (--skip-attention, file exists).")
        # NOTE: pickle is only safe here because the file is a locally produced artifact.
        with open(att_path, "rb") as f:
            summary = pickle.load(f)
        # The sample table is cheap compared with attention; rebuild it if absent.
        if not df_samples_path.is_file():
            df_samples = prds.get_sample_predictions_dataframe(
                model_type="Multi",
                multimodal_dataset=multimodal_dataset,
                fold_results=fold_results,
                model_config=model_config,
                device=device,
                batch_size=32,
                threshold=0.5,
                adata_rna=adata_RNA_labelled,
            )
            df_samples.to_parquet(df_samples_path, index=False)
    else:
        print("Computing attention + rollout (slow)…")
        summary, df_samples = run_attention_and_fi(
            multimodal_dataset,
            fold_results,
            model_config,
            feature_names,
            device,
            adata_RNA_labelled,
        )
        with open(att_path, "wb") as f:
            pickle.dump(summary, f)
        with open(CACHE / "attention_feature_ranks.pkl", "wb") as f:
            pickle.dump(summary["fi_att"], f)
        df_samples.to_parquet(df_samples_path, index=False)

    shift_paths = {
        modality: ROOT / "objects" / f"fi_shift_{modality}.pkl"
        for modality in ("rna", "atac", "flux")
    }
    missing = [path for path in shift_paths.values() if not path.is_file()]
    if missing:
        for path in missing:
            print(f"Warning: missing {path}")
        # Previously the script warned and then failed on the first open();
        # report every missing input in one clear error instead.
        raise FileNotFoundError(
            "Missing shift-importance inputs: " + ", ".join(str(path) for path in missing)
        )

    fi_shift = {}
    for modality, path in shift_paths.items():
        with open(path, "rb") as f:
            fi_shift[modality] = pickle.load(f)

    with open(ROOT / "objects" / "degs.pkl", "rb") as f:
        degs = pickle.load(f)
    df_rna_degs, df_atac_degs, df_flux_degs = degs[0], degs[1], degs[2]

    fi = summary["fi_att"]
    df_features = create_combined_feature_dataframe(
        fi_shift["rna"],
        fi_shift["atac"],
        fi_shift["flux"],
        fi["all"]["rna"],
        fi["all"]["atac"],
        fi["all"]["flux"],
        df_rna_degs,
        df_atac_degs,
        df_flux_degs,
    )
    df_features.to_parquet(CACHE / "df_features.parquet", index=False)
    # The CSV mirror lives outside the cache directory; make sure its folder exists.
    (ROOT / "analysis").mkdir(parents=True, exist_ok=True)
    df_features.to_csv(ROOT / "analysis" / "df_features.csv", index=False)
    print(f"Wrote {CACHE / 'df_features.parquet'} and analysis/df_features.csv")
    print("Done.")
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
if __name__ == "__main__":
|
| 486 |
+
main()
|
streamlit_hf/.streamlit/config.toml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[theme]
|
| 2 |
+
primaryColor = "#2563eb"
|
| 3 |
+
backgroundColor = "#f8fafc"
|
| 4 |
+
secondaryBackgroundColor = "#ffffff"
|
| 5 |
+
textColor = "#0f172a"
|
| 6 |
+
font = "sans-serif"
|
| 7 |
+
|
| 8 |
+
[server]
|
| 9 |
+
headless = true
|
| 10 |
+
# Default CORS + XSRF settings avoid the "enableCORS=false vs XSRF" conflict on localhost.
|
| 11 |
+
|
| 12 |
+
[browser]
|
| 13 |
+
gatherUsageStats = false
|
streamlit_hf/README.md
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Space (Docker + Streamlit)
|
| 2 |
+
|
| 3 |
+
The **root `README.md`** starts with the YAML card Hugging Face reads for the Space (title, tags, colours, `sdk: docker`, `app_port: 7860`). Copy that block if you maintain a separate Space README.
|
| 4 |
+
|
| 5 |
+
```yaml
|
| 6 |
+
---
|
| 7 |
+
title: FateFormer Explorer
|
| 8 |
+
short_description: Streamlit app to explore multimodal single-cell fate modeling (RNA, ATAC, metabolic flux, attention, and rankings).
|
| 9 |
+
emoji: 🧬
|
| 10 |
+
colorFrom: violet
|
| 11 |
+
colorTo: indigo
|
| 12 |
+
tags:
|
| 13 |
+
- streamlit
|
| 14 |
+
- single-cell
|
| 15 |
+
- multi-omics
|
| 16 |
+
- genomics
|
| 17 |
+
- atac-seq
|
| 18 |
+
- rna-seq
|
| 19 |
+
- metabolic-modeling
|
| 20 |
+
- deep-learning
|
| 21 |
+
- biology
|
| 22 |
+
license: mit
|
| 23 |
+
sdk: docker
|
| 24 |
+
app_port: 7860
|
| 25 |
+
---
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
`app_port` **7860** matches the root **`Dockerfile`** (`streamlit ... --server.port 7860`). Local runs use Streamlit’s default **8501** unless you pass `--server.port`.
|
| 29 |
+
|
| 30 |
+
## Before first deploy
|
| 31 |
+
|
| 32 |
+
1. Run locally: `python scripts/precompute_streamlit_cache.py` (requires GPU/CPU time for attention).
|
| 33 |
+
2. Commit **`streamlit_hf/cache/`** contents (`latent_umap.pkl`, `attention_summary.pkl`, `attention_feature_ranks.pkl`, `df_features.parquet`, and optionally `samples.parquet` if you use it elsewhere) or attach via **Git LFS** if files are large. These paths are listed in `.gitignore`; use `git add -f streamlit_hf/cache/*` when you want them in the remote.
|
| 34 |
+
3. Keep **`ckp/`** model weights available only if you run precompute in CI; the slim Docker image does **not** include PyTorch and expects precomputed caches.
|
| 35 |
+
|
| 36 |
+
The repository **`Dockerfile`** at the root builds the Space.
|
streamlit_hf/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Streamlit explorer package (run with PYTHONPATH=repo root).
|
streamlit_hf/app.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FateFormer: interactive analysis explorer.
|
| 3 |
+
Run from repository root: PYTHONPATH=. streamlit run streamlit_hf/app.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import streamlit as st
|
| 9 |
+
|
| 10 |
+
# Everything is resolved relative to this file so the app does not depend on the working directory.
_APP_DIR = Path(__file__).resolve().parent
_ICON_PATH = _APP_DIR.joinpath("static", "app_icon.svg")

# Pass ``page_icon`` only when the SVG exists on disk.
_page_icon_kw = {}
if _ICON_PATH.is_file():
    _page_icon_kw["page_icon"] = str(_ICON_PATH)

st.set_page_config(
    page_title="FateFormer Explorer",
    layout="wide",
    initial_sidebar_state="expanded",
    **_page_icon_kw,
)

# (path parts below this directory, sidebar title, icon, is the default page)
_PAGE_SPECS = (
    (("home.py",), "Home", ":material/home:", True),
    (("pages", "1_Single_Cell_Explorer.py"), "Single-Cell Explorer", ":material/scatter_plot:", False),
    (("pages", "2_Feature_insights.py"), "Feature Insights", ":material/analytics:", False),
    (("pages", "3_Flux_analysis.py"), "Flux Analysis", ":material/account_tree:", False),
    (("pages", "4_Gene_expression_analysis.py"), "Gene Expression & TF Activity", ":material/genetics:", False),
)

pages = [
    st.Page(str(_APP_DIR.joinpath(*parts)), title=title, icon=icon, default=is_default)
    for parts, title, icon, is_default in _PAGE_SPECS
]
nav = st.navigation(pages)
nav.run()
|
streamlit_hf/cache/.gitkeep
ADDED
|
File without changes
|
streamlit_hf/home.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Landing content for the FateFormer Streamlit hub."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import streamlit as st
|
| 9 |
+
|
| 10 |
+
# Make the repository root importable so ``streamlit_hf`` resolves when this page runs as a script.
_REPO = Path(__file__).resolve().parents[1]
_repo_entry = str(_REPO)
if _repo_entry not in sys.path:
    sys.path.insert(0, _repo_entry)

from streamlit_hf.lib import ui

_CACHE = Path(__file__).resolve().parent / "cache"
# Both artifacts must be present before the page reports results as available.
_HAS_CACHE = all((_CACHE / name).is_file() for name in ("latent_umap.pkl", "df_features.parquet"))

ui.inject_app_styles()

st.title("FateFormer: interactive analysis")
st.caption("Choose a workspace below or use the sidebar. All views use the same precomputed validation results.")

if _HAS_CACHE:
    st.success("Precomputed results are available. After a server-side update, refresh the browser to load new plots.")
else:
    st.warning(
        "This deployment does not have precomputed results yet. Ask the maintainer to publish data, then reload."
    )

st.subheader("Open a page")

# (page script, link label, icon, caption) in display order.
_CARDS = (
    (
        "pages/1_Single_Cell_Explorer.py",
        "Single-Cell Explorer",
        ":material/scatter_plot:",
        "Latent UMAP: colour by fate, prediction, fold, batch, modalities, or dominant fate emphasis.",
    ),
    (
        "pages/2_Feature_insights.py",
        "Feature Insights",
        ":material/analytics:",
        "Shift and attention rankings, cohort comparisons, and full feature tables.",
    ),
    (
        "pages/3_Flux_analysis.py",
        "Flux Analysis",
        ":material/account_tree:",
        "Reaction pathways, differential flux, rankings, and model metadata.",
    ),
    (
        "pages/4_Gene_expression_analysis.py",
        "Gene Expression & TF Activity",
        ":material/genetics:",
        "Pathway enrichment, motif activity, and gene / motif tables.",
    ),
)


def _render_card(column, page, label, icon, caption):
    """Draw one bordered card (page link plus caption) inside *column*."""
    with column:
        with st.container(border=True):
            st.page_link(page, label=label, icon=icon)
            st.caption(caption)


# First row holds three cards; the fourth card sits in the first slot of a second three-column row.
for _column, _card in zip(st.columns(3), _CARDS[:3]):
    _render_card(_column, *_card)
_second_row = st.columns(3)
_render_card(_second_row[0], *_CARDS[3])

st.markdown("---")
st.markdown(
    "**Tips:** use chart toolbars for pan/zoom and lasso selection where offered. Tables support search and column sort from the header row."
)
|
streamlit_hf/lib/__init__.py
ADDED
|
File without changes
|
streamlit_hf/lib/formatters.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Human-readable labels for compact codes used in cached tables."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
# Letter codes follow interpretation.predictions._get_modality_info (R/A/F order).
# Labels are deliberately short so they fit in table cells.
_MODALITY_LONG: dict[str, str] = {
    **dict(
        RAF="RNA + ATAC + Flux",
        RA="RNA + ATAC",
        RF="RNA + Flux",
        AF="ATAC + Flux",
        R="RNA only",
        A="ATAC only",
        F="Flux only",
    ),
    # Every spelling of "missing" shares one label.
    **dict.fromkeys(("None", "none", "nan"), "No modality data"),
}

# Display names for row fields shown in inspector tables.
_FIELD_DISPLAY: dict[str, str] = dict(label="CellTag-Multi label")

# Latent explorer: column headers for tables and the key-value inspector.
LATENT_TABLE_RENAME: dict[str, str] = dict(
    label="CellTag-Multi label",
    predicted_class="Predicted fate",
    predicted_value="Prediction score",
    correct="Prediction correct",
    pct="Dominant fate (%)",
    modality_label="Available modalities",
    dataset_idx="Dataset index",
    batch_no="Batch",
    fold="CV fold",
    clone_id="Clone ID",
    clone_size="Clone size",
    cell_type="Cell type",
)

# Columns removed from latent tables before display.
LATENT_DROP_FROM_TABLES: frozenset[str] = frozenset(("umap_x", "umap_y", "modality", "pct_decile"))

# Combined lookup; entries from LATENT_TABLE_RENAME win on key clashes.
_NAME_MAP = dict(_FIELD_DISPLAY)
_NAME_MAP.update(LATENT_TABLE_RENAME)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _format_scalar(v) -> str:
    """Render a single cell value as a display string.

    ``None`` and missing scalars (anything ``pd.isna`` reports as missing) become ``""``,
    booleans become ``"Yes"`` / ``"No"``, and every other value goes through ``str``.
    """
    if v is None:
        return ""
    # ``np.bool_`` is not a subclass of ``bool``; without it NumPy booleans would be
    # rendered as "True"/"False" instead of "Yes"/"No".
    if isinstance(v, (bool, np.bool_)):
        return "Yes" if v else "No"
    try:
        if pd.isna(v):
            return ""
    except (ValueError, TypeError):
        # Array-like values make ``pd.isna`` return an array with no single truth value;
        # fall through and stringify them.
        pass
    return str(v)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _field_label(name: str, *, fallback_field_display: bool) -> str:
    """Return the display label for a field name.

    ``_NAME_MAP`` is consulted first. On a miss the result is the ``_FIELD_DISPLAY``
    entry when *fallback_field_display* is true, otherwise the stringified name itself.
    """
    key = str(name)
    default = _FIELD_DISPLAY.get(key, key) if fallback_field_display else key
    return _NAME_MAP.get(key, default)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def expand_modality(code) -> str:
    """Translate a compact R/A/F modality code (e.g. RAF, RA) into its display label.

    ``None``, missing values, blank strings and the text ``nan`` all map to the
    "no modality" label; codes without a mapping are returned stripped but unchanged.
    """
    no_data = _MODALITY_LONG["None"]
    if code is None:
        return no_data

    try:
        missing = bool(pd.isna(code))
    except (ValueError, TypeError):
        # Non-scalar input: treat as present and let the string lookup below decide.
        missing = False
    if not missing and isinstance(code, (float, np.floating)):
        missing = bool(np.isnan(code))
    if missing:
        return no_data

    key = str(code).strip()
    if key and key.lower() != "nan":
        return _MODALITY_LONG.get(key, key)
    return no_data
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def annotate_modality_column(df, code_col: str = "modality", label_col: str = "modality_label"):
    """Return a copy of *df* with *label_col* holding the expanded form of each code in *code_col*."""
    labels = df[code_col].map(expand_modality)
    annotated = df.copy()
    annotated[label_col] = labels
    return annotated
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def prepare_latent_display_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Return *df* without UMAP / internal columns and with headers renamed for display."""
    hidden = [
        col
        for col in df.columns
        if str(col).startswith("umap_") or col in LATENT_DROP_FROM_TABLES
    ]
    visible = df.drop(columns=hidden, errors="ignore")
    return visible.rename(columns=LATENT_TABLE_RENAME)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def latent_inspector_key_value(series: pd.Series) -> pd.DataFrame:
    """Turn one row into a two-column ``Field`` / ``Value`` table, omitting UMAP and internal entries."""
    hidden = [
        name
        for name in series.index
        if str(name).startswith("umap_") or name in LATENT_DROP_FROM_TABLES
    ]
    visible = series.drop(labels=hidden, errors="ignore")

    fields: list[str] = []
    values: list[str] = []
    for name, raw in zip(visible.index, visible.values):
        fields.append(_field_label(name, fallback_field_display=False))
        values.append(_format_scalar(raw))
    return pd.DataFrame({"Field": fields, "Value": values})
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def dataframe_to_arrow_safe_kv(series: pd.Series) -> pd.DataFrame:
    """Build a ``field`` / ``value`` table of strings so Streamlit/PyArrow never sees a mixed-type column."""
    fields: list[str] = []
    values: list[str] = []
    for name, raw in zip(series.index, series.values):
        fields.append(_field_label(name, fallback_field_display=True))
        values.append(_format_scalar(raw))
    return pd.DataFrame({"field": fields, "value": values})
|
streamlit_hf/lib/io.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load precomputed explorer artifacts (no torch required at runtime)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pickle
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
from streamlit_hf.lib.formatters import annotate_modality_column
|
| 12 |
+
from streamlit_hf.lib.reactions import normalize_reaction_key
|
| 13 |
+
|
| 14 |
+
# ``parents[2]`` of this module's resolved path is used as the repository root.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Directory holding the precomputed explorer artifacts.
CACHE_DIR = REPO_ROOT.joinpath("streamlit_hf", "cache")
# CSV describing the metabolic model reactions.
METABOLIC_MODEL_METADATA = REPO_ROOT.joinpath("data", "datasets", "metabolic_model_metadata.csv")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _is_valid_features_csv(path: Path) -> bool:
    """Return True when *path* is a readable CSV with both ``feature`` and ``importance_shift`` columns."""
    if not path.is_file():
        return False
    try:
        # Two rows are enough to learn the header.
        columns = pd.read_csv(path, nrows=2).columns
    except Exception:
        # Best-effort probe: any read/parse failure simply means "not usable".
        return False
    return all(required in columns for required in ("feature", "importance_shift"))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_latent_bundle():
    """Return the unpickled contents of ``latent_umap.pkl`` from the cache directory, or None if absent."""
    bundle_path = CACHE_DIR / "latent_umap.pkl"
    if bundle_path.is_file():
        with open(bundle_path, "rb") as handle:
            # NOTE: unpickling runs arbitrary code; only load caches produced by a trusted precompute run.
            return pickle.load(handle)
    return None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def load_attention_summary():
    """Return the unpickled contents of ``attention_summary.pkl`` from the cache directory, or None if absent."""
    summary_path = CACHE_DIR / "attention_summary.pkl"
    if summary_path.is_file():
        with open(summary_path, "rb") as handle:
            # NOTE: unpickling runs arbitrary code; only load caches produced by a trusted precompute run.
            return pickle.load(handle)
    return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def load_samples_df() -> pd.DataFrame | None:
    """Read ``samples.parquet`` from the cache directory, or return None when it is absent.

    A ``modality_label`` column is added when the table has a ``modality`` column.
    """
    samples_path = CACHE_DIR / "samples.parquet"
    if not samples_path.is_file():
        return None
    frame = pd.read_parquet(samples_path)
    if "modality" in frame.columns:
        return annotate_modality_column(frame)
    return frame
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _add_within_modality_orders(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with per-modality ordering columns for scatter plots and tables.

    When ``rank_shift_in_modal`` and ``rank_att_in_modal`` are both present they are
    reused as-is: re-ranking the importances with pandas can order ties differently
    from the precomputed ranks. Otherwise the ranks are derived per modality from
    ``importance_shift`` / ``importance_att`` (largest first). ``combined_order_mod``
    is filled from ``mean_rank`` (smallest first) only when it is missing. A table
    without a ``modality`` column is returned as an unmodified copy.
    """
    out = df.copy()
    if "modality" not in out.columns:
        return out

    def _rank_in_modality(column: str, *, ascending: bool) -> pd.Series:
        # ``method="first"`` breaks ties by row order so every rank is a distinct integer.
        grouped = out.groupby("modality", observed=True)[column]
        return grouped.rank(ascending=ascending, method="first").astype(int)

    precomputed = ("rank_shift_in_modal", "rank_att_in_modal")
    if all(col in out.columns for col in precomputed):
        out["shift_order_mod"] = out["rank_shift_in_modal"].astype(int)
        out["attention_order_mod"] = out["rank_att_in_modal"].astype(int)
    else:
        out["shift_order_mod"] = _rank_in_modality("importance_shift", ascending=False)
        out["attention_order_mod"] = _rank_in_modality("importance_att", ascending=False)
        out["rank_shift_in_modal"] = out["shift_order_mod"]
        out["rank_att_in_modal"] = out["attention_order_mod"]

    if "combined_order_mod" not in out.columns:
        out["combined_order_mod"] = _rank_in_modality("mean_rank", ascending=True)
    return out
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def load_metabolic_model_metadata() -> pd.DataFrame | None:
    """Read the metabolic model metadata CSV (directed substrate → product reaction edges), or None if absent."""
    if METABOLIC_MODEL_METADATA.is_file():
        return pd.read_csv(METABOLIC_MODEL_METADATA)
    return None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def build_metabolic_model_table(
    meta: pd.DataFrame,
    flux_df: pd.DataFrame,
    supermodule_id: int | None = None,
) -> pd.DataFrame:
    """
    Build a static edge list (substrate → product, reaction, module class) from *meta*.

    Each edge also carries ``log_fc``, ``pval_adj``, ``mean_rank`` and ``pathway`` taken
    from the *flux_df* row whose normalised ``feature`` equals the normalised ``rxnName``;
    those fields are None when no row matches or the column is absent. An empty frame is
    returned when *meta* lacks a required column or no edge survives the optional
    *supermodule_id* filter.
    """
    required = ("Compound_IN_name", "Compound_OUT_name", "rxnName", "Supermodule_id", "Super.Module.class")
    if any(col not in meta.columns for col in required):
        return pd.DataFrame()

    edges = meta.copy()
    if supermodule_id is not None:
        edges = edges[edges["Supermodule_id"] == int(supermodule_id)]
    if edges.empty:
        return pd.DataFrame()

    # Index the flux table by normalised reaction key, keeping the first row per key.
    lookup = flux_df.copy()
    lookup["_rk"] = lookup["feature"].map(normalize_reaction_key)
    lookup = lookup.drop_duplicates("_rk", keep="first").set_index("_rk", drop=False)

    stat_columns = ("log_fc", "pval_adj", "mean_rank", "pathway")
    records: list[dict] = []
    for _, edge in edges.iterrows():
        key = normalize_reaction_key(str(edge["rxnName"]))
        record = {
            "Supermodule": edge.get("Super.Module.class"),
            "Module_id": edge.get("Module_id"),
            "Substrate": edge["Compound_IN_name"],
            "Product": edge["Compound_OUT_name"],
            "Reaction": edge["rxnName"],
        }
        if key in lookup.index:
            match = lookup.loc[key]
            if isinstance(match, pd.DataFrame):
                match = match.iloc[0]
            for col in stat_columns:
                record[col] = match[col] if col in match.index else None
        else:
            for col in stat_columns:
                record[col] = None
        records.append(record)
    return pd.DataFrame(records)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def load_df_features() -> pd.DataFrame | None:
    """
    Load the feature table with within-modality order columns added.

    Sources are tried in order: cached parquet, cached CSV, then ``analysis/df_features.csv``
    (only when it passes the header check). Returns None when none is usable.
    """
    parquet_path = CACHE_DIR / "df_features.parquet"
    cached_csv = CACHE_DIR / "df_features.csv"
    analysis_csv = REPO_ROOT / "analysis" / "df_features.csv"

    if parquet_path.is_file():
        raw = pd.read_parquet(parquet_path)
    elif cached_csv.is_file():
        raw = pd.read_csv(cached_csv)
    elif _is_valid_features_csv(analysis_csv):
        raw = pd.read_csv(analysis_csv)
    else:
        return None
    return _add_within_modality_orders(raw)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def latent_join_samples(bundle: dict, samples: pd.DataFrame | None) -> pd.DataFrame:
    """
    Build one row per UMAP point from the arrays in *bundle*.

    When *samples* is a non-empty frame, ``predicted_value``, ``clone_id``, ``clone_size``
    and ``cell_type`` (whichever exist) are looked up by ``dataset_idx`` and appended.
    The result also gets a ``modality_label`` column from ``annotate_modality_column``.
    """
    df = pd.DataFrame(
        {
            "umap_x": bundle["umap_x"],
            "umap_y": bundle["umap_y"],
            "label": bundle["label_name"],
            "predicted_class": bundle["pred_name"],
            "correct": bundle["correct"].astype(bool),
            "fold": bundle["fold"].astype(int),
            "batch_no": bundle["batch_no"].astype(int),
            "pct": bundle["pct"],
            "modality": bundle["modality"],
            "dataset_idx": bundle["dataset_idx"].astype(int),
        }
    )
    if samples is not None and not samples.empty:
        # NOTE(review): assumes ``samples["ind"]`` uses the same identifiers as
        # ``bundle["dataset_idx"]`` — confirm against the precompute script.
        by_index = samples.drop_duplicates(subset=["ind"], keep="first").set_index("ind")
        # ``reindex`` leaves NaN for points that have no sample row.
        extra = by_index.reindex(df["dataset_idx"].values)
        for col in ("predicted_value", "clone_id", "clone_size", "cell_type"):
            if col in extra.columns:
                df[col] = extra[col].values
    return annotate_modality_column(df)
|
streamlit_hf/lib/pathways.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pathway enrichment tables (DAVID-style exports) for Reactome and KEGG panels."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
# ``parents[2]`` of this module's resolved path is used as the repository root.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Enrichment exports read by ``load_de_re_tsv``.
DE_TSV = REPO_ROOT.joinpath("analysis", "de_all_48.tsv")
RE_TSV = REPO_ROOT.joinpath("analysis", "re_all_48.tsv")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_de_re_tsv() -> tuple[pd.DataFrame, pd.DataFrame] | None:
    """Read the ``DE_TSV`` and ``RE_TSV`` tab-separated tables; return None unless both files exist."""
    if DE_TSV.is_file() and RE_TSV.is_file():
        de_all = pd.read_csv(DE_TSV, sep="\t")
        re_all = pd.read_csv(RE_TSV, sep="\t")
        return de_all, re_all
    return None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def preprocess_pathway_file(df: pd.DataFrame, splitter: str) -> pd.DataFrame:
    """
    Return the rows of *df* with ``Benjamini`` below 0.05, cleaned up for plotting.

    ``Term`` keeps only the text after the last *splitter*, and a ``Gene Ratio`` column
    (``Count`` / ``List Total``) is added. The input frame is not modified.
    """
    terms = df["Term"].astype(str).str.split(splitter).str[-1]
    if splitter == "-":
        # With a "-" splitter, also strip everything up to the last "~".
        terms = terms.astype(str).str.split("~").str[-1]

    cleaned = df.copy()
    cleaned["Term"] = terms
    significant = cleaned.loc[cleaned["Benjamini"] < 0.05].copy()
    significant["Gene Ratio"] = significant["Count"] / significant["List Total"]
    return significant
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def merged_reactome_kegg_bubble_frames(
    de_all: pd.DataFrame, re_all: pd.DataFrame
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Build the bubble-plot tables (Gene Ratio, Count, Benjamini, Library, Term), per notebook cell 31.

    Returns one frame for *de_all* and one for *re_all*; each stacks the preprocessed
    Reactome rows above the preprocessed KEGG rows and tags them with a ``Library`` column.
    """
    # (Category value, term splitter, Library tag) in stacking order.
    libraries = (
        ("REACTOME_PATHWAY", "~", "Reactome"),
        ("KEGG_PATHWAY", ":", "KEGG"),
    )

    def _stack_libraries(table: pd.DataFrame) -> pd.DataFrame:
        parts = []
        for category, splitter, library in libraries:
            part = preprocess_pathway_file(table[table["Category"] == category], splitter)
            part["Library"] = library
            parts.append(part)
        return pd.concat(parts, ignore_index=True)

    return _stack_libraries(de_all), _stack_libraries(re_all)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _preprocess_exploded(df: pd.DataFrame, pval_threshold: float, splitter: str, label: str) -> pd.DataFrame:
    """
    Return one row per (term, gene) for terms with ``Benjamini`` below *pval_threshold*.

    ``Term`` keeps the text after the last *splitter* and is cut to 60 characters plus
    ``...`` when longer. ``Genes`` is split on ``", "``. Every row is tagged with *label*.
    Output columns: Term, Benjamini, Label, Genes.
    """
    terms = df["Term"].astype(str).str.split(splitter).str[-1]
    if splitter == "-":
        terms = terms.astype(str).str.split("~").str[-1]
    terms = terms.map(lambda text: text if len(text) <= 60 else text[:60] + "...")

    cleaned = df.copy()
    cleaned["Term"] = terms
    cleaned = cleaned[cleaned["Benjamini"] < pval_threshold]

    selected = cleaned[["Term", "Genes", "Benjamini"]].copy()
    selected["Label"] = label
    gene_lists = selected.set_index(["Term", "Benjamini", "Label"])["Genes"].str.split(", ")
    return gene_lists.explode().reset_index()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _binary_matrix(data: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
    """Return the Term × Genes count table plus the first ``Label`` and first ``Benjamini`` seen per term."""
    membership = pd.crosstab(data["Term"], data["Genes"])
    per_term = data.groupby("Term")
    first_label = per_term["Label"].first()
    first_pval = per_term["Benjamini"].first()
    return membership, first_label, first_pval
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _sort_matrix(matrix: pd.DataFrame) -> pd.DataFrame:
    """Reorder rows and columns of *matrix* by their sums, largest first."""
    row_totals = matrix.sum(axis=1)
    col_totals = matrix.sum(axis=0)
    row_order = row_totals.sort_values(ascending=False).index
    col_order = col_totals.sort_values(ascending=False).index
    return matrix.loc[row_order, col_order]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def build_merged_pathway_membership(
    de_all: pd.DataFrame, re_all: pd.DataFrame, pval_threshold: float = 0.05
) -> tuple[np.ndarray, list[str], list[str]] | None:
    """
    Numeric grid for the pathway/gene heatmap (notebook cell 29).

    Cell codes: 0=white, 1=dead-end gene, 2=reprogramming gene; the extra last column
    is the library stripe (3=Reactome, 4=KEGG). Reactome rows come first, then KEGG.

    Returns ``(z, row_labels, col_labels)`` where ``col_labels`` is the gene list plus
    ``"Library"``, or None when no term or gene remains after filtering.

    Rows are addressed by position rather than by term label, so a term name that
    occurs in both libraries no longer breaks the lookup.
    """
    label_code = {"Dead-end": 1, "Reprogramming": 2}
    lib_code = {"Reactome": 3, "KEGG": 4}

    def _library_matrix(category: str, splitter: str) -> tuple[pd.DataFrame, pd.Series]:
        # Sorted membership matrix for one library, with labels aligned to its row order.
        combined = pd.concat(
            [
                _preprocess_exploded(
                    de_all[de_all["Category"] == category], pval_threshold, splitter, "Dead-end"
                ),
                _preprocess_exploded(
                    re_all[re_all["Category"] == category], pval_threshold, splitter, "Reprogramming"
                ),
            ],
            ignore_index=True,
        )
        matrix, labels, _ = _binary_matrix(combined)
        matrix = _sort_matrix(matrix)
        return matrix, labels.reindex(matrix.index)

    rm, rlab = _library_matrix("REACTOME_PATHWAY", "~")
    km, klab = _library_matrix("KEGG_PATHWAY", ":")

    merged = pd.concat([rm, km], axis=0, sort=False).fillna(0)
    if merged.empty or merged.shape[1] == 0:
        return None

    # Per-row label code, in the same positional order as ``merged``; unknown labels map to 0.
    row_codes = pd.concat([rlab, klab]).map(label_code).fillna(0).to_numpy(dtype=float)
    values = merged.to_numpy(dtype=float)

    z = np.zeros((values.shape[0], values.shape[1] + 1), dtype=float)
    # NOTE(review): a membership count above 1 is multiplied by the label code as before,
    # so it can collide with another code — confirm counts are always 0/1 upstream.
    z[:, :-1] = np.where(values > 0, values * row_codes[:, None], 0.0)
    z[:, -1] = np.concatenate(
        [
            np.full(len(rm), lib_code["Reactome"], dtype=float),
            np.full(len(km), lib_code["KEGG"], dtype=float),
        ]
    )

    gene_cols = list(merged.columns)
    row_labels = [str(term) for term in merged.index]
    col_labels = gene_cols + ["Library"]
    return z, row_labels, col_labels
|
streamlit_hf/lib/plots.py
ADDED
|
@@ -0,0 +1,1421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Plotly helpers for the explorer UI."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import plotly.express as px
|
| 10 |
+
import plotly.graph_objects as go
|
| 11 |
+
from plotly.subplots import make_subplots
|
| 12 |
+
|
| 13 |
+
from streamlit_hf.lib.reactions import normalize_reaction_key
|
| 14 |
+
|
| 15 |
+
# Font shared by every Plotly layout; chosen to match the Streamlit theme (primary + slate text).
PLOT_FONT = {"family": "Inter, system-ui, sans-serif", "size": 12}

# General-purpose categorical colours (bars, trend lines, cohort traces).
PALETTE = (
    "#2563eb", "#dc2626", "#059669", "#d97706",
    "#7c3aed", "#db2777", "#0d9488", "#4f46e5",
)

# One colour per modality; bar and scatter plots read from this mapping.
MODALITY_COLOR = {"RNA": "#E64B35", "ATAC": "#4DBBD5", "Flux": "#00A087"}
# Separate copy used only by the global modality pie, so its hues can be edited independently.
MODALITY_PIE_COLOR = MODALITY_COLOR.copy()

# Log₂FC heatmaps/sunburst: diverging colours in the spirit of ggplot2's
# scale_colour_gradient2; the symmetric range puts the light mid colour at 0.
LOG_FC_COLOR_MIN, LOG_FC_COLOR_MAX = -0.5, 0.5
LOG_FC_DIVERGING_SCALE: list[list] = [
    [0.0, "#1C86EE"],
    [0.5, "#FAFAFA"],
    [1.0, "#FF0000"],
]

# Axis / colorbar labels written with a Unicode minus (U+2212) and subscript digits (₁₀, ₂).
LABEL_NEG_LOG10_ADJ_P = "\u2212log\u2081\u2080 adj. p"
LABEL_LOG2FC = "Log\u2082FC"

# Display modality name -> lowercase key used by the cached attention dict.
FI_ATT_MOD_KEY = {"RNA": "rna", "ATAC": "atac", "Flux": "flux"}

# The model appends one batch-embedding token per modality; these names are
# filtered out of attention rankings shown in the UI.
BATCH_EMBEDDING_FEATURE_NAMES = frozenset(("batch_rna", "batch_atac", "batch_flux"))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _attention_pairs_skip_batch(pairs: list) -> list:
|
| 50 |
+
return [(n, s) for n, s in pairs if str(n) not in BATCH_EMBEDDING_FEATURE_NAMES]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def rollout_top_features_table(feature_names, vec, top_n: int) -> pd.DataFrame:
    """Tabulate the ``top_n`` largest rollout weights, skipping batch-embedding tokens.

    ``feature_names`` and ``vec`` are read position by position. The result has
    the columns ``feature`` and ``mean_attention``, ordered from the largest
    weight down; it is an empty frame with those columns when nothing is left.
    """
    names = [str(name) for name in feature_names]
    weights = np.asarray(vec, dtype=float)

    scored = []
    for idx, name in enumerate(names):
        if name in BATCH_EMBEDDING_FEATURE_NAMES:
            continue
        scored.append((name, float(weights[idx])))

    # Negated key => descending by weight, ties keep their input order.
    best = sorted(scored, key=lambda pair: -pair[1])[:top_n]
    if not best:
        return pd.DataFrame(columns=["feature", "mean_attention"])

    features, values = zip(*best)
    return pd.DataFrame({"feature": list(features), "mean_attention": list(values)})
|
| 68 |
+
|
| 69 |
+
# Continuous blue scale for dominant-fate % on the UMAP (pale = low, dark = high).
UMAP_PCT_COLORSCALE: list[list] = [
    list(stop)
    for stop in (
        (0.0, "#eff6ff"),
        (0.25, "#bfdbfe"),
        (0.55, "#3b82f6"),
        (0.82, "#2563eb"),
        (1.0, "#1e3a8a"),
    )
]

# Okabe–Ito–style distinct colours (colourblind-friendly) for categorical UMAP hues.
LATENT_DISCRETE_PALETTE = (
    "#0072B2", "#E69F00", "#009E73", "#CC79A7",
    "#56B4E9", "#D55E00", "#F0E442", "#000000",
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def latent_scatter(
    df,
    color_col: str,
    title: str,
    width: int = 720,
    height: int = 520,
    marker_size: float = 5.0,
    marker_opacity: float = 0.78,
):
    """Build the UMAP scatter (``umap_x`` vs ``umap_y``) coloured by ``color_col``.

    The input frame is copied before any helper column is added. ``pct`` is
    drawn with the continuous ``UMAP_PCT_COLORSCALE``; ``fold``, ``batch_no``
    and ``correct`` are converted to string categories in a ``_color`` column;
    any other column is passed to Plotly Express unchanged together with the
    discrete palette. Hover fields and display labels are limited to columns
    that exist in the frame.
    """
    frame = df.copy()

    # Candidate hover fields (True = default format); filtered by presence below.
    hover_formats = {
        "umap_x": ":.3f",
        "umap_y": ":.3f",
        "dataset_idx": True,
        "fold": True,
        "batch_no": True,
        "predicted_class": True,
        "label": True,
        "correct": True,
        "pct": ":.2f",
        "modality_label": True,
        "modality": True,
        "predicted_value": ":.3f",
        "clone_id": True,
        "clone_size": True,
        "cell_type": True,
    }
    # The readable modality label supersedes the raw modality field.
    if "modality_label" in frame.columns:
        del hover_formats["modality"]
    hover_data = {col: fmt for col, fmt in hover_formats.items() if col in frame.columns}

    display_names = {
        "label": "CellTag-Multi label",
        "predicted_class": "Predicted fate",
        "pct": "Dominant fate (%)",
        "modality_label": "Available modalities",
        "dataset_idx": "Dataset index",
        "batch_no": "Batch",
        "fold": "CV fold",
    }
    labels_map = {col: text for col, text in display_names.items() if col in frame.columns}

    # Columns that must be shown as categories: legend label + conversion to strings.
    categorical = {
        "fold": ("Fold", lambda col: col.astype(str)),
        "batch_no": ("Batch", lambda col: col.astype(str)),
        "correct": ("Prediction", lambda col: col.map({True: "Correct", False: "Wrong"})),
    }
    if color_col in categorical:
        legend_label, to_category = categorical[color_col]
        frame["_color"] = to_category(frame[color_col])
        labels_map["_color"] = legend_label
        color_arg = "_color"
        continuous = False
    else:
        color_arg = color_col
        continuous = color_col == "pct"

    scatter_kwargs = dict(
        x="umap_x",
        y="umap_y",
        color=color_arg,
        hover_data=hover_data,
        labels=labels_map,
        title=title,
        width=width,
        height=height,
    )
    if continuous:
        scatter_kwargs["color_continuous_scale"] = UMAP_PCT_COLORSCALE
    else:
        scatter_kwargs["color_discrete_sequence"] = list(LATENT_DISCRETE_PALETTE)
    fig = px.scatter(frame, **scatter_kwargs)

    marker_outline = dict(width=0.25, color="rgba(255,255,255,0.4)")
    fig.update_traces(marker=dict(size=marker_size, opacity=marker_opacity, line=marker_outline))
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title_font_size=16,
        margin=dict(l=28, r=20, t=56, b=28),
        legend_title_text="",
        xaxis_title="",
        yaxis_title="",
    )
    # UMAP coordinates carry no units: keep a faint grid, hide tick labels and zero lines.
    for update_axis in (fig.update_xaxes, fig.update_yaxes):
        update_axis(showticklabels=False, showgrid=True, gridcolor="rgba(0,0,0,0.06)", zeroline=False)
    return fig
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def rank_scatter_shift_vs_attention(df_mod, modality: str, width: int = 420, height: int = 440):
    """Scatter attention rank (x) against shift rank (y) with a least-squares trend line.

    Returns an empty figure when either rank column is missing or when no row
    has both ranks. The remaining columns are optional and used only when
    present, so a reduced frame still plots instead of raising inside Plotly
    Express: ``top_10_pct`` (discrete point colour), ``feature`` (hover title)
    and ``mean_rank`` / ``importance_shift`` / ``importance_att`` (hover fields).

    The dashed trend line is a degree-1 ``np.polyfit`` over the plotted points
    and is only added when there are at least two points with differing x.
    """
    need = ("shift_order_mod", "attention_order_mod")
    if not all(c in df_mod.columns for c in need):
        return go.Figure()
    sub = df_mod.dropna(subset=list(need)).copy()
    if sub.empty:
        return go.Figure()
    x = sub["attention_order_mod"].astype(float).to_numpy()
    y = sub["shift_order_mod"].astype(float).to_numpy()

    # Optional columns: keep only what this frame actually provides.
    hover_formats = {
        "mean_rank": True,
        "importance_shift": ":.4f",
        "importance_att": ":.4f",
    }
    hover_data = {col: fmt for col, fmt in hover_formats.items() if col in sub.columns}

    fig = px.scatter(
        sub,
        x="attention_order_mod",
        y="shift_order_mod",
        color="top_10_pct" if "top_10_pct" in sub.columns else None,
        hover_name="feature" if "feature" in sub.columns else None,
        hover_data=hover_data or None,
        labels={
            "attention_order_mod": "Attention rank",
            "shift_order_mod": "Shift rank",
        },
        width=width,
        height=height,
        color_discrete_map={
            "both": PALETTE[0],
            "shift": PALETTE[1],
            "att": PALETTE[2],
            "None": "#94a3b8",
        },
    )
    fig.update_traces(marker=dict(size=7, opacity=0.62, line=dict(width=0.5, color="rgba(15,23,42,0.28)")))

    # A line fit needs two or more points that are not all at the same x.
    if len(x) >= 2 and float(np.ptp(x)) > 0:
        coef = np.polyfit(x, y, 1)
        poly = np.poly1d(coef)
        xs = np.linspace(float(np.min(x)), float(np.max(x)), 100)
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=poly(xs),
                mode="lines",
                name=f"y = {coef[0]:.2f}x + {coef[1]:.2f}",
                line=dict(color="#2563eb", width=2, dash="dash"),
                showlegend=True,
            )
        )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=dict(
            text=f"{modality}: shift vs attention (ranks)",
            x=0.5,
            xanchor="center",
            y=0.98,
            yanchor="top",
            font=dict(size=14, family=PLOT_FONT["family"]),
        ),
        margin=dict(l=48, r=20, t=52, b=72),
        legend=dict(orientation="h", yanchor="top", y=-0.2, xanchor="center", x=0.5),
    )
    return fig
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _truncate_label(s: str, max_len: int = 36) -> str:
|
| 258 |
+
s = str(s)
|
| 259 |
+
return s if len(s) <= max_len else s[: max_len - 1] + "…"
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def joint_shift_attention_top_features(df_mod, modality: str, top_n: int):
    """
    Grouped horizontal bars for the ``top_n`` features with the lowest ``mean_rank``.

    Shift and attention importances are each min–max scaled inside that top-N slice
    (a constant column maps to 0.5) so the two can be compared side by side.
    Returns an empty figure when a required column is missing or no rows remain.
    """
    required = ("mean_rank", "importance_shift", "importance_att", "feature")
    if any(col not in df_mod.columns for col in required):
        return go.Figure()
    top = df_mod.nsmallest(top_n, "mean_rank").copy()
    if top.empty:
        return go.Figure()

    def _unit_scale(values: pd.Series) -> pd.Series:
        # Min–max scale to [0, 1]; a degenerate range collapses to the midpoint.
        low = float(values.min())
        high = float(values.max())
        if high <= low:
            return pd.Series(0.5, index=values.index)
        return (values.astype(float) - low) / (high - low)

    top["_zs"] = _unit_scale(top["importance_shift"])
    top["_za"] = _unit_scale(top["importance_att"])
    # Lowest mean_rank first; the reversed y axis below puts it at the top of the chart.
    top = top.sort_values("mean_rank", ascending=True)

    full_names = top["feature"].astype(str)
    tick_labels = full_names.map(lambda name: _truncate_label(name, 40))

    shift_color = MODALITY_COLOR.get(modality, PALETTE[0])
    # Attention bars are slate; switch shade if the modality colour is that same slate.
    attention_color = "#64748b" if shift_color == "#475569" else "#475569"

    longest_label = max(map(len, tick_labels), default=10)
    left_margin = int(min(380, 64 + 5.8 * longest_label))
    fig_height = min(720, 52 + 22 * len(top))

    bar_specs = (
        (
            "Shift (scaled)",
            "_zs",
            shift_color,
            "<b>%{customdata}</b><br>Shift (scaled): %{x:.3f}<extra></extra>",
        ),
        (
            "Attention (scaled)",
            "_za",
            attention_color,
            "<b>%{customdata}</b><br>Attention (scaled): %{x:.3f}<extra></extra>",
        ),
    )
    fig = go.Figure()
    for trace_name, value_col, bar_color, hover in bar_specs:
        fig.add_trace(
            go.Bar(
                name=trace_name,
                y=tick_labels,
                x=top[value_col],
                orientation="h",
                marker_color=bar_color,
                customdata=full_names,
                hovertemplate=hover,
            )
        )

    chart_title = dict(
        text=f"{modality} · top {top_n}",
        x=0.5,
        xanchor="center",
        y=0.98,
        yanchor="top",
        font=dict(size=14, family=PLOT_FONT["family"]),
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=chart_title,
        barmode="group",
        bargap=0.15,
        bargroupgap=0.05,
        width=680,
        height=fig_height,
        margin=dict(l=left_margin, r=12, t=44, b=72),
        xaxis_title="Scaled 0-1 within selection",
        yaxis_title="",
        legend=dict(orientation="h", yanchor="top", y=-0.14, xanchor="center", x=0.5),
    )
    fig.update_yaxes(autorange="reversed", tickfont=dict(size=10))
    return fig
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def modality_shift_attention_rank_stats(df_mod) -> dict[str, Any]:
    """Correlate the per-modality attention and shift ranks (Pearson and Spearman).

    Rows missing either rank are dropped first. The result always contains
    ``n`` (rows used; 0 when a rank column is absent). The ``pearson_r``,
    ``pearson_p``, ``spearman_r`` and ``spearman_p`` entries are included only
    when at least three rows remain.
    """
    from scipy.stats import pearsonr, spearmanr

    rank_cols = ["shift_order_mod", "attention_order_mod"]
    if any(col not in df_mod.columns for col in rank_cols):
        return {"n": 0}

    paired = df_mod.dropna(subset=rank_cols)
    n = len(paired)
    if n < 3:
        return {"n": n}

    attention = paired["attention_order_mod"].astype(float)
    shift = paired["shift_order_mod"].astype(float)
    pearson_r, pearson_p = pearsonr(attention, shift)
    spearman_r, spearman_p = spearmanr(attention, shift)

    stats: dict[str, Any] = {"n": n}
    stats["pearson_r"] = float(pearson_r)
    stats["pearson_p"] = float(pearson_p)
    stats["spearman_r"] = float(spearman_r)
    stats["spearman_p"] = float(spearman_p)
    return stats
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def rank_bar(
    df_top,
    xcol: str,
    ycol: str,
    title: str,
    color: str = PALETTE[0],
    xaxis_title: str | None = None,
):
    """Horizontal bar chart of ``xcol`` per ``ycol`` label, sorted ascending by ``xcol``.

    Tick labels are truncated to 42 characters while the hover text keeps the
    full label. When ``xaxis_title`` is not given, the axis is titled with
    ``xcol`` with underscores replaced by spaces.
    """
    ordered = df_top.sort_values(xcol, ascending=True)
    full_labels = ordered[ycol].astype(str)
    tick_labels = full_labels.map(lambda text: _truncate_label(text, 42))

    # Left margin grows with the longest tick label, capped at 420 px.
    longest_label = max(map(len, tick_labels), default=12)
    left_margin = int(min(420, 80 + 5.8 * longest_label))

    bars = go.Bar(
        y=tick_labels,
        x=ordered[xcol],
        orientation="h",
        marker_color=color,
        customdata=full_labels,
        hovertemplate="<b>%{customdata}</b><br>%{x:.4g}<extra></extra>",
    )
    fig = go.Figure(bars)

    if xaxis_title is None:
        axis_text = xcol.replace("_", " ")
    else:
        axis_text = xaxis_title

    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=title,
        width=680,
        height=min(620, 38 + 20 * len(ordered)),
        margin=dict(l=left_margin, r=24, t=48, b=40),
        xaxis_title=axis_text,
        yaxis_title="",
    )
    fig.update_yaxes(tickfont=dict(size=10))
    return fig
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def attention_top_comparison(fi_lists: dict, modality: str, top_n: int = 18):
    """Grouped bars comparing the top attention features of the three cohorts.

    ``fi_lists`` maps cohort -> {rna|atac|flux: [(name, score), ...]}. For each of
    the ``all``, ``dead_end`` and ``reprogramming`` cohorts, the first ``top_n``
    pairs (batch-embedding tokens removed) become one bar trace; cohorts without
    items are skipped. If no cohort has items, the figure carries an explanatory
    annotation instead of bars.
    """
    modality_key = FI_ATT_MOD_KEY.get(modality, str(modality).lower())
    cohort_specs = (
        ("all", "All validation samples", PALETTE[0]),
        ("dead_end", "Predicted dead-end", PALETTE[1]),
        ("reprogramming", "Predicted reprogramming", PALETTE[2]),
    )

    def _clip(text):
        # Keep the first 52 characters; mark anything cut off with an ellipsis.
        return text[:52] + ("…" if len(text) > 52 else "")

    bars = []
    for cohort_key, legend_name, bar_color in cohort_specs:
        per_modality = fi_lists.get(cohort_key) or {}
        top_pairs = _attention_pairs_skip_batch(list(per_modality.get(modality_key, [])))[:top_n]
        if top_pairs:
            names, weights = zip(*top_pairs)
            bars.append(
                go.Bar(
                    name=legend_name,
                    x=list(weights),
                    y=[_clip(feature) for feature in names],
                    orientation="h",
                    marker_color=bar_color,
                )
            )

    fig = go.Figure(bars)
    # Height scales with rows shown (capped at 20) and with the number of cohorts drawn.
    fig_height = max(320, 36 + min(top_n, 20) * 22 * max(1, len(bars)))
    fig.update_layout(
        barmode="group",
        template="plotly_white",
        font=PLOT_FONT,
        title=f"Top attention (rollout): {modality}",
        width=520,
        height=fig_height,
        margin=dict(l=220, r=24, t=56, b=40),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    if bars:
        fig.update_yaxes(autorange="reversed")
    else:
        empty_note = dict(
            text="No attention list for this modality (re-run precompute).",
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        fig.update_layout(annotations=[empty_note])
    return fig
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def attention_cohort_view(
    fi_lists: dict,
    modality: str,
    top_n: int,
    mode: str,
):
    """
    Attention bar chart for one cohort, or the three-cohort comparison.

    mode: 'compare' delegates to ``attention_top_comparison`` (grouped bars);
          'all' | 'dead_end' | 'reprogramming' draws that single cohort's first
          ``top_n`` features (batch-embedding tokens removed). A cohort without
          items yields a figure that only carries a "no items" annotation.
    """
    if mode == "compare":
        return attention_top_comparison(fi_lists, modality, top_n)

    modality_key = FI_ATT_MOD_KEY.get(modality, str(modality).lower())
    per_modality = fi_lists.get(mode) or {}
    top_pairs = _attention_pairs_skip_batch(list(per_modality.get(modality_key, [])))[:top_n]

    cohort_titles = {
        "all": "All validation samples",
        "dead_end": "Predicted dead-end",
        "reprogramming": "Predicted reprogramming",
    }
    label = cohort_titles.get(mode, mode)

    if not top_pairs:
        placeholder = go.Figure()
        no_items_note = dict(
            text="No items for this cohort.",
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        placeholder.update_layout(
            template="plotly_white",
            font=PLOT_FONT,
            title=f"{modality} · {label}",
            annotations=[no_items_note],
        )
        return placeholder

    names, weights = zip(*top_pairs)
    tick_labels = [name[:52] + ("…" if len(name) > 52 else "") for name in names]
    fig = go.Figure(
        go.Bar(
            x=list(weights),
            y=tick_labels,
            orientation="h",
            marker_color=PALETTE[0],
        )
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=f"{modality} · {label}",
        width=520,
        height=max(280, 40 + min(top_n, 25) * 20),
        margin=dict(l=220, r=24, t=56, b=40),
        xaxis_title="Attention weight",
    )
    fig.update_yaxes(autorange="reversed")
    return fig
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def global_rank_triple_panel(df_features, top_n: int = 20, top_n_pie: int = 100):
    """
    Three-panel overview across all modalities.

    Panels 1 and 2 show the ``top_n`` features by latent-shift and by attention
    importance (each min–max scaled over the whole frame; a constant column
    scales to 0.0), with bars coloured by modality. Panel 3 is a pie of the
    modality mix among the ``top_n_pie`` features with the lowest mean rank.
    """
    frame = df_features.copy()
    for source_col in ("importance_shift", "importance_att"):
        low, high = frame[source_col].min(), frame[source_col].max()
        scaled_col = f"{source_col}_norm"
        if high > low:
            frame[scaled_col] = (frame[source_col] - low) / (high - low)
        else:
            frame[scaled_col] = 0.0

    shift_top = frame.nlargest(top_n, "importance_shift")
    att_top = frame.nlargest(top_n, "importance_att")
    pie_pool = frame.nsmallest(top_n_pie, "mean_rank")

    fig = make_subplots(
        rows=1,
        cols=3,
        column_widths=[0.36, 0.36, 0.28],
        specs=[[{}, {}, {"type": "domain"}]],
        subplot_titles=(
            f"Top {top_n} by latent shift (ranked)",
            f"Top {top_n} by attention (ranked)",
            f"Top {top_n_pie} by mean rank (modality mix)",
        ),
        horizontal_spacing=0.06,
    )

    # (rows, scaled value column, hover template, x-axis title) for the two bar panels.
    bar_panels = (
        (
            shift_top,
            "importance_shift_norm",
            "%{y}<br>scaled shift: %{x:.3f}<extra></extra>",
            "Min-max scaled shift",
        ),
        (
            att_top,
            "importance_att_norm",
            "%{y}<br>scaled attention: %{x:.3f}<extra></extra>",
            "Min-max scaled attention",
        ),
    )
    for panel_col, (rows, value_col, hover, _axis_title) in enumerate(bar_panels, start=1):
        fig.add_trace(
            go.Bar(
                x=rows[value_col],
                y=rows["feature"],
                orientation="h",
                marker_color=[MODALITY_COLOR.get(m, "#64748b") for m in rows["modality"]],
                marker_line=dict(color="rgba(15,23,42,0.12)", width=1),
                showlegend=False,
                hovertemplate=hover,
            ),
            row=1,
            col=panel_col,
        )

    pie_labels = ["RNA", "ATAC", "Flux"]
    modality_counts = pie_pool["modality"].value_counts()
    pie_vals = [int(modality_counts.get(name, 0)) for name in pie_labels]
    # With no counted modality at all, fall back to equal placeholder slices.
    if sum(pie_vals) == 0:
        pie_vals = [1, 1, 1]

    fig.add_trace(
        go.Pie(
            labels=pie_labels,
            values=pie_vals,
            marker=dict(
                colors=[MODALITY_PIE_COLOR.get(name, "#64748b") for name in pie_labels],
                line=dict(color="#1e293b", width=1.2),
            ),
            textinfo="label+percent",
            textfont_size=12,
            hole=0.0,
            showlegend=False,
        ),
        row=1,
        col=3,
    )

    for panel_col, (_rows, _value_col, _hover, axis_title) in enumerate(bar_panels, start=1):
        fig.update_xaxes(title_text=axis_title, row=1, col=panel_col)
    for panel_col in (1, 2):
        fig.update_yaxes(autorange="reversed", row=1, col=panel_col)

    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=max(480, 40 + top_n * 18),
        width=min(1280, 400 + top_n * 14),
        margin=dict(l=40, r=40, t=80, b=40),
        title_text="Global feature ranking (all modalities)",
        title_x=0.5,
    )
    return fig
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def _flux_prepare_top_ranked(flux_df: pd.DataFrame, top_n: int, metric: str = "mean_rank") -> pd.DataFrame:
|
| 614 |
+
sub = flux_df[~flux_df["feature"].astype(str).str.contains("batch", case=False, na=False)].copy()
|
| 615 |
+
if metric not in sub.columns:
|
| 616 |
+
metric = "mean_rank"
|
| 617 |
+
sub = sub.sort_values(metric, ascending=True).head(int(top_n)).copy()
|
| 618 |
+
if "pathway" in sub.columns:
|
| 619 |
+
pc = sub["pathway"].value_counts()
|
| 620 |
+
sub["_pw_n"] = sub["pathway"].map(pc)
|
| 621 |
+
sub.sort_values(["_pw_n", "pathway"], ascending=[False, True], inplace=True)
|
| 622 |
+
return sub
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def flux_pathway_sunburst(flux_df: pd.DataFrame, max_features: int = 55) -> go.Figure:
    """Sunburst of the best-ranked flux reactions nested under their pathway.

    The ``max_features`` rows with the lowest ``mean_rank`` (among rows that
    have a pathway) are shown. Sector size grows as ``mean_rank`` falls.
    Sectors are coloured by ``log_fc`` on the fixed diverging scale when that
    column has any value, otherwise by ``mean_rank``.

    Returns an empty figure when ``pathway``, ``mean_rank`` or ``feature`` is
    missing, or when no rows are left to draw (including ``max_features <= 0``).
    ``pval_adj`` is optional and only appears in the hover when present.
    """
    required = ("pathway", "mean_rank", "feature")
    if any(col not in flux_df.columns for col in required):
        return go.Figure()
    sub = flux_df.dropna(subset=["pathway"]).copy()
    if sub.empty:
        return go.Figure()
    sub = sub.nsmallest(int(max_features), "mean_rank")
    if sub.empty:
        return go.Figure()

    sub["pathway"] = sub["pathway"].astype(str)
    sub["_uid"] = np.arange(len(sub))
    # The running index suffix keeps leaf labels unique even when truncation
    # makes two reaction names identical.
    sub["rxn"] = [
        f"{_truncate_label(str(feature), 36)} ·{uid}"
        for uid, feature in enumerate(sub["feature"])
    ]
    mr = sub["mean_rank"].astype(float)
    # Invert the rank so better-ranked reactions get larger sectors (floor of 0.5).
    sub["w"] = (mr.max() - mr + 1.0).clip(lower=0.5)

    color_col = "log_fc" if "log_fc" in sub.columns and sub["log_fc"].notna().any() else "mean_rank"

    # False hides a helper column from the hover; fields absent from the frame are skipped.
    hover_spec = {"mean_rank": ":.2f", "pval_adj": ":.2e", "feature": True, "w": False, "_uid": False}
    sb_kw: dict[str, Any] = {
        "path": ["pathway", "rxn"],
        "values": "w",
        "color": color_col,
        "hover_data": {col: fmt for col, fmt in hover_spec.items() if col in sub.columns},
    }
    if color_col == "log_fc":
        sb_kw["color_continuous_scale"] = LOG_FC_DIVERGING_SCALE
        sb_kw["range_color"] = [LOG_FC_COLOR_MIN, LOG_FC_COLOR_MAX]
    else:
        sb_kw["color_continuous_scale"] = "Viridis_r"
    fig = px.sunburst(sub, **sb_kw)

    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        margin=dict(l=8, r=8, t=100, b=16),
        height=min(820, 520 + int(max_features) * 5),
        title=dict(
            text="Top flux reactions by model rank, nested under pathway",
            x=0,
            xanchor="left",
            y=0.99,
            yanchor="top",
            font=dict(size=13, family=PLOT_FONT["family"]),
            pad=dict(b=16, l=4),
        ),
    )
    if color_col == "log_fc":
        # Pin the colour axis to the shared Log₂FC range and label the colorbar.
        fig.update_layout(
            coloraxis=dict(
                cmin=LOG_FC_COLOR_MIN,
                cmax=LOG_FC_COLOR_MAX,
                colorbar=dict(
                    title=dict(text=LABEL_LOG2FC, side="right"),
                    tickformat=".2f",
                    len=0.38,
                    thickness=12,
                    y=0.52,
                    yanchor="middle",
                ),
            )
        )
    return fig
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
def flux_volcano(flux_df: pd.DataFrame) -> go.Figure:
    """Volcano plot of flux reactions: ``log_fc`` (x) against −log10 adjusted p (y).

    The y value is read from ``pval_adj_log`` when that column exists;
    otherwise it is computed as ``-log10(pval_adj)`` with p clipped at 1e-300.

    Returns an empty figure when ``log_fc`` is missing, when neither p-value
    column is available, or when no rows remain after filtering. Point colour
    (``mean_rank``), the hover title (``feature``) and the extra hover fields
    (``pathway``, ``pval_adj``, ``group``) are each used only if present.
    """
    if "log_fc" not in flux_df.columns:
        return go.Figure()
    has_pval = "pval_adj" in flux_df.columns
    has_neglog = "pval_adj_log" in flux_df.columns
    if not (has_pval or has_neglog):
        # Nothing to place on the significance axis.
        return go.Figure()

    d = flux_df.dropna(subset=["log_fc"]).copy()
    if d.empty:
        return go.Figure()

    # Drop degenerate rows: ~zero fold-change with exactly-zero adjusted p (numeric artifact / noise).
    if has_pval:
        lf = d["log_fc"].astype(float)
        pa = d["pval_adj"].astype(float)
        bad = np.isfinite(lf) & np.isfinite(pa) & (np.abs(lf) < 1e-10) & (pa <= 0.0)
        d = d[~bad]
        if d.empty:
            return go.Figure()

    if has_neglog:
        y = d["pval_adj_log"].astype(float)
    else:
        # Clip before the log so p == 0 does not become infinity.
        p = d["pval_adj"].astype(float).clip(lower=1e-300)
        y = -np.log10(p.to_numpy())
    d = d.assign(_neglogp=y)

    hover_cols = [c for c in ("pathway", "pval_adj", "group") if c in d.columns]
    fig = px.scatter(
        d,
        x="log_fc",
        y="_neglogp",
        color="mean_rank" if "mean_rank" in d.columns else None,
        color_continuous_scale="Viridis_r",
        hover_name="feature" if "feature" in d.columns else None,
        hover_data=hover_cols or None,
        labels={
            "log_fc": LABEL_LOG2FC,
            "_neglogp": LABEL_NEG_LOG10_ADJ_P,
            "mean_rank": "Mean rank",
        },
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title="Differential flux vs statistical significance",
        height=520,
        margin=dict(l=52, r=24, t=52, b=48),
        coloraxis_colorbar=dict(
            title=dict(text="Mean rank", side="right"),
            thickness=12,
            len=0.55,
        ),
    )
    return fig
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
def motif_tf_mean_rank_bars(atac_df: pd.DataFrame, top_n: int = 22) -> go.Figure:
    """Bar chart of TFs ordered by the average ``mean_rank`` of their motif features.

    Motif features are grouped by TF name: a trailing ``_<digits>`` suffix is
    stripped from the feature name, and names without one are kept whole. The
    ``top_n`` TFs with the lowest average rank are shown, best at the top.

    Returns an empty figure when the frame is empty, when ``feature`` or
    ``mean_rank`` is missing, or when the aggregation yields no rows.
    """
    if atac_df.empty or any(col not in atac_df.columns for col in ("feature", "mean_rank")):
        return go.Figure()

    def _tf_prefix(feat: str) -> str:
        # "NAME_123" -> "NAME"; anything without a purely numeric tail is returned whole.
        s = str(feat)
        head, sep, tail = s.rpartition("_")
        if sep and tail.isdigit():
            return head
        return s

    d = atac_df.copy()
    d["_tf"] = d["feature"].map(_tf_prefix)
    agg = d.groupby("_tf", as_index=False)["mean_rank"].mean()
    agg = agg.nsmallest(int(top_n), "mean_rank").sort_values("mean_rank", ascending=True)
    if agg.empty:
        return go.Figure()

    y_show = agg["_tf"].astype(str).map(lambda s: _truncate_label(s, 36))
    fig = go.Figure(
        go.Bar(
            y=y_show,
            x=agg["mean_rank"],
            orientation="h",
            marker_color=MODALITY_COLOR.get("ATAC", PALETTE[0]),
            customdata=agg["_tf"],
            hovertemplate="<b>%{customdata}</b><br>Mean mean_rank (across motifs): %{x:.2f}<extra></extra>",
        )
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=f"TFs by average motif rank (top {top_n} by lowest mean rank)",
        height=min(640, 48 + 22 * len(agg)),
        margin=dict(l=160, r=24, t=52, b=40),
        xaxis_title="Mean of mean_rank over motif instances (lower = stronger)",
        yaxis_title="",
    )
    # Reversed axis puts the first (lowest-rank) row at the top.
    fig.update_yaxes(autorange="reversed", tickfont=dict(size=10))
    return fig
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def motif_chromvar_volcano(atac_df: pd.DataFrame) -> go.Figure:
    """Motif differential view: mean activity difference (reprogramming − dead-end) vs significance.

    The x axis is ``mean_diff``.  The y axis is ``pval_adj_log`` when that column
    exists, otherwise ``-log10(pval_adj)`` with ``pval_adj`` clipped below at
    ``1e-300``.  Points are coloured by ``mean_rank`` and titled by ``feature``
    on hover when those columns are available.

    Args:
        atac_df: Motif-level table; ``mean_diff`` and ``pval_adj`` are required,
            ``pval_adj_log``, ``mean_rank``, ``feature``, ``group``, ``mean_de``
            and ``mean_re`` are used when present.

    Returns:
        The scatter figure, or an empty ``go.Figure`` when a required column is
        missing or no row survives the filtering.
    """
    need = ("mean_diff", "pval_adj")
    if not all(c in atac_df.columns for c in need):
        return go.Figure()
    d = atac_df.dropna(subset=list(need)).copy()
    if d.empty:
        return go.Figure()

    md = d["mean_diff"].astype(float)
    pa = d["pval_adj"].astype(float)
    # Remove rows that combine a (numerically) zero difference with a non-positive
    # adjusted p-value.
    bad = np.isfinite(md) & np.isfinite(pa) & (np.abs(md) < 1e-12) & (pa <= 0.0)
    d = d[~bad]
    if d.empty:
        return go.Figure()

    if "pval_adj_log" in d.columns:
        y = d["pval_adj_log"].astype(float)
    else:
        # Clip before the log so a p-value of exactly zero does not become infinite.
        p = d["pval_adj"].astype(float).clip(lower=1e-300)
        y = -np.log10(p.to_numpy())
    d = d.assign(_y=y)

    hover_cols = [c for c in ("group", "pval_adj", "mean_rank", "mean_de", "mean_re") if c in d.columns]
    # "mean_rank" and "feature" are optional here: Plotly Express raises when asked
    # to colour or label by a column the frame does not have.
    color_col = "mean_rank" if "mean_rank" in d.columns else None
    name_col = "feature" if "feature" in d.columns else None
    fig = px.scatter(
        d,
        x="mean_diff",
        y="_y",
        color=color_col,
        color_continuous_scale="Viridis_r",
        hover_name=name_col,
        hover_data=hover_cols or None,
        labels={
            "mean_diff": "Mean difference (reprogramming − dead-end)",
            "_y": LABEL_NEG_LOG10_ADJ_P,
            "mean_rank": "Mean rank",
        },
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title="TF motif differential activity (mean difference vs significance)",
        height=520,
        margin=dict(l=52, r=24, t=52, b=48),
        coloraxis_colorbar=dict(title=dict(text="Mean rank", side="right"), thickness=12, len=0.55),
    )
    return fig
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
def notebook_style_activity_scatter(
    df: pd.DataFrame,
    title: str,
    x_title: str,
    y_title: str,
) -> go.Figure:
    """Scatter ``mean_de`` against ``mean_re`` with colour and size encodings.

    Colour follows ``pval_adj_log`` on the "Reds" scale.  Marker size grows as
    ``mean_rank`` falls: the rank is inverted against the largest rank in the
    data and rescaled linearly to the range 3.5–12.5 (a constant 6.0 when every
    row has the same rank).

    Args:
        df: Feature table; needs ``mean_de``, ``mean_re``, ``mean_rank``,
            ``pval_adj_log``, ``feature`` and ``group``.
        title: Figure title.
        x_title: Axis label shown for ``mean_de``.
        y_title: Axis label shown for ``mean_re``.

    Returns:
        The scatter figure, or an empty ``go.Figure`` when a column is missing or
        no row has all four numeric values.
    """
    required = ("mean_de", "mean_re", "mean_rank", "pval_adj_log", "feature", "group")
    if any(col not in df.columns for col in required):
        return go.Figure()

    data = df.dropna(subset=["mean_de", "mean_re", "mean_rank", "pval_adj_log"]).copy()
    if data.empty:
        return go.Figure()

    # Invert the rank so that the best (lowest) rank carries the largest value.
    worst_rank = float(data["mean_rank"].max())
    data = data.assign(_inv=(worst_rank - data["mean_rank"].astype(float)).clip(lower=0))
    inverted = data["_inv"].astype(float)
    smallest = float(inverted.min())
    largest = float(inverted.max())
    if largest > smallest:
        data["_sz"] = 3.5 + (inverted - smallest) / (largest - smallest) * 9.0
    else:
        # No spread in rank: every marker gets the same size.
        data["_sz"] = 6.0

    axis_labels = {
        "mean_de": x_title,
        "mean_re": y_title,
        "pval_adj_log": "Adj. p-value (log)",
    }
    # The helper columns drive the encoding only; keep them out of the hover box.
    hover_spec = {
        "mean_rank": ":.2f",
        "group": True,
        "pval_adj_log": ":.2f",
        "_inv": False,
        "_sz": False,
    }
    fig = px.scatter(
        data,
        x="mean_de",
        y="mean_re",
        color="pval_adj_log",
        color_continuous_scale="Reds",
        size="_sz",
        size_max=14,
        hover_name="feature",
        hover_data=hover_spec,
        labels=axis_labels,
    )

    marker_style = dict(line=dict(width=0.45, color="rgba(255,255,255,0.75)"), opacity=0.9)
    fig.update_traces(marker=marker_style, selector=dict(mode="markers"))

    colorbar = dict(title=dict(text="Adj. p (log)", side="right"), thickness=12, len=0.55)
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=title,
        height=520,
        margin=dict(l=52, r=24, t=52, b=48),
        coloraxis_colorbar=colorbar,
    )
    return fig
|
| 879 |
+
|
| 880 |
+
|
| 881 |
+
def pathway_bubble_suggested_height(n_paths: int) -> int:
    """Return the total figure height for a pathway bubble panel with ``n_paths`` rows.

    The height is 22 units per pathway plus 200, kept within 520–1100.  Counts
    below one are treated as one.  Pass the larger of the two cohort counts to
    give both panels the same height so their legends line up.
    """
    rows = int(n_paths)
    if rows < 1:
        rows = 1
    height = 22 * rows + 200
    if height < 520:
        return 520
    if height > 1100:
        return 1100
    return height
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def pathway_enrichment_bubble_panel(
    df: pd.DataFrame,
    title: str,
    *,
    show_colorbar: bool = True,
    layout_height: int | None = None,
) -> go.Figure:
    """Bubble chart of enriched pathways for one cohort.

    One row per pathway: x is ``Gene Ratio``, marker diameter scales with the
    square root of ``Count``, colour is −log₁₀ of ``Benjamini`` on a Viridis
    scale whose range is taken from this panel alone, and the marker symbol
    follows ``Library`` (Reactome = circle, KEGG = square, anything else =
    circle).  Rows are ordered by ``Count`` then ``Gene Ratio``, both descending.

    Args:
        df: Enrichment table with ``Term``, ``Library``, ``Count``,
            ``Gene Ratio`` and ``Benjamini`` columns.
        title: Panel title, centred above the plot.
        show_colorbar: Whether to draw the colour bar for this panel.
        layout_height: Explicit figure height; when ``None`` the height comes
            from ``pathway_bubble_suggested_height``.  Not applied to the empty
            placeholder panel, which is always 320 high.

    Returns:
        The bubble figure, or a placeholder figure carrying a "no significant
        pathways" annotation when ``df`` is empty.
    """
    fig = go.Figure()

    if df.empty:
        placeholder_note = dict(
            text="No significant pathways (Benjamini–Hochberg q < 0.05)",
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
            font=dict(size=13, color="#64748b"),
        )
        fig.update_layout(
            template="plotly_white",
            font=PLOT_FONT,
            title=dict(text=title, x=0.5, xanchor="center"),
            annotations=[placeholder_note],
            height=320,
            margin=dict(l=40, r=40, t=56, b=40),
        )
        return fig

    # Largest overlaps first, ties broken by the stronger gene ratio.
    ranked = df.sort_values(by=["Count", "Gene Ratio"], ascending=[False, False]).reset_index(drop=True)
    ranked = ranked.assign(
        _neglog=-np.log10(ranked["Benjamini"].astype(float).clip(lower=1e-300)),
        _y=np.arange(len(ranked), dtype=float),
    )

    # Colour limits come from this panel only, so the scale spans the cohort's own range.
    strength = ranked["_neglog"].astype(float)
    color_lo = float(strength.min())
    color_hi = float(strength.max())
    if color_hi <= color_lo:
        # Keep the range non-degenerate when every pathway has the same value.
        color_hi = color_lo + 1e-6

    symbol_by_library = {"Reactome": "circle", "KEGG": "square"}
    marker_symbols = [symbol_by_library.get(str(lib), "circle") for lib in ranked["Library"].tolist()]
    marker_sizes = np.sqrt(ranked["Count"].astype(float).clip(lower=1)) * 4.8
    hover_payload = np.stack(
        [
            ranked["Count"].to_numpy(),
            ranked["_neglog"].to_numpy(),
            ranked["Library"].astype(str).to_numpy(),
        ],
        axis=1,
    )

    if show_colorbar:
        colorbar = dict(
            title=dict(
                text="\u2212log\u2081\u2080 q",
                side="right",
            ),
            len=0.72,
            thickness=12,
            y=0.45,
            yanchor="middle",
            outlinewidth=0,
        )
    else:
        colorbar = None

    bubbles = go.Scatter(
        x=ranked["Gene Ratio"],
        y=ranked["_y"],
        mode="markers",
        name="Pathways",
        showlegend=False,
        marker=dict(
            size=marker_sizes,
            sizemode="diameter",
            sizemin=4,
            symbol=marker_symbols,
            color=ranked["_neglog"],
            cmin=color_lo,
            cmax=color_hi,
            colorscale="Viridis",
            showscale=bool(show_colorbar),
            colorbar=colorbar,
            line=dict(width=0.75, color="rgba(0,0,0,0.5)"),
            opacity=0.92,
        ),
        text=ranked["Term"],
        customdata=hover_payload,
        hovertemplate=(
            "<b>%{text}</b><br>%{customdata[2]}<br>Gene ratio: %{x:.3f}<br>Count: %{customdata[0]}"
            "<br>\u2212log\u2081\u2080 Benjamini: %{customdata[1]:.2f}<extra></extra>"
        ),
    )
    fig.add_trace(bubbles)

    # Point-less traces that exist only to give each library present a legend entry.
    libraries_present = set(ranked["Library"].astype(str))
    for library, symbol in symbol_by_library.items():
        if library in libraries_present:
            fig.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode="markers",
                    name=library,
                    marker=dict(
                        symbol=symbol,
                        size=11,
                        color="#475569",
                        line=dict(width=1, color="rgba(0,0,0,0.45)"),
                    ),
                    showlegend=True,
                )
            )

    tick_labels = [_truncate_label(str(term), 52) for term in ranked["Term"]]
    if layout_height is None:
        figure_height = pathway_bubble_suggested_height(len(ranked))
    else:
        figure_height = int(layout_height)

    fig.update_yaxes(
        tickmode="array",
        tickvals=ranked["_y"].tolist(),
        ticktext=tick_labels,
        autorange="reversed",
        title="",
    )
    fig.update_xaxes(title_text="Gene ratio (count ÷ list total)")
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=dict(
            text=title,
            x=0.5,
            xanchor="center",
            yanchor="top",
            y=0.985,
            pad=dict(b=0),
        ),
        height=figure_height,
        margin=dict(l=215, r=132, t=48, b=108),
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.11,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(255,255,255,0.92)",
            bordercolor="rgba(0,0,0,0.08)",
            borderwidth=1,
        ),
        showlegend=True,
    )
    return fig
|
| 1033 |
+
|
| 1034 |
+
|
| 1035 |
+
def pathway_gene_membership_heatmap(
    z: np.ndarray, row_labels: list[str], col_labels: list[str]
) -> go.Figure:
    """Draw the pathway × gene membership grid as a colour-coded heatmap.

    ``z`` holds one integer code per cell (0–4).  Code 0 is rendered
    transparent; codes 1–4 each get a fixed colour, and the same colours are
    repeated in a legend built from point-less scatter traces.  Codes outside
    0–4 are drawn like code 0.

    Args:
        z: 2-D array of cell codes.
        row_labels: Labels for the y axis.
        col_labels: Labels for the x axis.

    Returns:
        The heatmap figure, or an empty ``go.Figure`` when ``z`` has no elements.
    """
    if z.size == 0:
        return go.Figure()

    # Each code is placed at a fixed position inside its own colour band.  Plain
    # z/4 scaling is not usable: 3 would map to 0.75, which lies in the KEGG band.
    slot_by_code = {0: 0.04, 1: 0.24, 2: 0.44, 3: 0.64, 4: 0.84}
    to_slot = np.vectorize(lambda code: slot_by_code.get(int(code), 0.04))
    scaled = to_slot(z).astype(float)

    clear = "rgba(0,0,0,0)"
    # Stepwise colourscale: every band is flat, with a hard edge to the next one.
    banded_scale = [
        [0.0, clear],
        [0.14, clear],
        [0.15, "#e69138"],
        [0.33, "#e69138"],
        [0.34, "#7eb6d9"],
        [0.53, "#7eb6d9"],
        [0.54, "#9ccc65"],
        [0.73, "#9ccc65"],
        [0.74, "#283593"],
        [1.0, "#283593"],
    ]

    hint_by_code = {
        0: "",
        1: "Gene enriched in dead-end contrast",
        2: "Gene enriched in reprogramming contrast",
        3: "Reactome pathway set",
        4: "KEGG pathway set",
    }
    codes = z.astype(int)
    hover_text = [
        [hint_by_code.get(int(codes[r, c]), "") for c in range(z.shape[1])]
        for r in range(z.shape[0])
    ]

    grid = go.Heatmap(
        z=scaled,
        x=col_labels,
        y=row_labels,
        text=hover_text,
        colorscale=banded_scale,
        zmin=0,
        zmax=1,
        showscale=False,
        xgap=1,
        ygap=1,
        hovertemplate="%{y}<br>%{x}<br>%{text}<extra></extra>",
    )
    fig = go.Figure(data=[grid])

    # Figure size follows the grid dimensions, bounded on both sides.
    n_rows, n_cols = z.shape
    cell_w = 10
    cell_h = 20
    fig_width = int(min(1000, max(460, n_cols * cell_w + 272)))
    fig_height = int(min(960, max(460, n_rows * cell_h + 128)))
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=dict(text="Pathway–gene membership", x=0.5, xanchor="center"),
        height=fig_height,
        width=fig_width,
        margin=dict(l=4, r=168, t=52, b=108),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="#f4f6f9",
        xaxis=dict(side="bottom", tickangle=-50, showgrid=False, zeroline=False),
        yaxis=dict(
            tickfont=dict(size=9),
            showgrid=False,
            zeroline=False,
            autorange="reversed",
        ),
    )

    # The heatmap has no colour bar, so the legend is assembled from empty traces.
    legend_entries = (
        ("Empty cell", "#f1f5f9"),
        ("Dead-end–linked gene", "#e69138"),
        ("Reprogramming–linked gene", "#7eb6d9"),
        ("Reactome (column tag)", "#9ccc65"),
        ("KEGG (column tag)", "#283593"),
    )
    for entry_name, entry_color in legend_entries:
        swatch = dict(
            size=11,
            color=entry_color,
            symbol="square",
            line=dict(width=1, color="rgba(0,0,0,0.25)"),
        )
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=entry_name,
                marker=swatch,
                showlegend=True,
            )
        )

    fig.update_layout(
        legend=dict(
            orientation="v",
            yanchor="top",
            y=0.98,
            xanchor="left",
            x=1.02,
            bgcolor="rgba(255,255,255,0.92)",
            bordercolor="rgba(0,0,0,0.08)",
            borderwidth=1,
            font=dict(size=11),
        )
    )
    return fig
|
| 1143 |
+
|
| 1144 |
+
|
| 1145 |
+
def flux_dead_end_vs_reprogram_scatter(flux_df: pd.DataFrame, max_pathway_colors: int = 12) -> go.Figure:
    """Scatter each reaction's ``mean_de`` against its ``mean_re``.

    Marker size encodes the geometric mean of ``importance_shift`` and
    ``importance_att`` (negative values clipped to zero), divided by its 95th
    percentile, capped at 1 and mapped to the range 5–25.  Colour encodes the
    pathway: the ``max_pathway_colors`` most frequent pathways get their own
    colour and all others are pooled as "Other".

    Args:
        flux_df: Reaction-level table.  ``mean_de`` and ``mean_re`` are required;
            ``importance_shift``, ``importance_att``, ``pathway``, ``feature``,
            ``mean_rank`` and ``log_fc`` are used when present.
        max_pathway_colors: Number of pathways that receive an individual colour.

    Returns:
        The scatter figure, or an empty ``go.Figure`` when a required column is
        missing or no row has both flux means.
    """
    need = ("mean_de", "mean_re")
    if not all(c in flux_df.columns for c in need):
        return go.Figure()
    d = flux_df.dropna(subset=list(need)).copy()
    if d.empty:
        return go.Figure()

    if "importance_shift" in d.columns and "importance_att" in d.columns:
        shift = d["importance_shift"].astype(float).clip(lower=0)
        att = d["importance_att"].astype(float).clip(lower=0)
        imp = (shift * att) ** 0.5
        q = float(imp.quantile(0.95))
        # A zero or NaN quantile would produce inf/NaN sizes; scale by 1 instead.
        if not np.isfinite(q) or q <= 0:
            q = 1.0
        # Rows without an importance value get the smallest marker rather than a
        # NaN size, which Plotly rejects.
        sizes = ((imp / q).clip(upper=1) * 20 + 5).fillna(5.0)
    else:
        # No importance columns: draw every reaction at the same size.
        sizes = pd.Series(5.0, index=d.index)
    d = d.assign(_s=sizes)

    if "pathway" in d.columns:
        pw = d["pathway"].fillna("Unknown").astype(str)
    else:
        pw = pd.Series(["Unknown"] * len(d), index=d.index)
    top_pw = pw.value_counts().head(int(max_pathway_colors)).index
    d = d.assign(_pw_col=pw.where(pw.isin(top_pw), "Other"))

    # Alphabetical order with "Other" last; "Other" is grey and does not use up a
    # palette entry.
    uniq = sorted(d["_pw_col"].astype(str).unique(), key=lambda x: (x == "Other", x))
    pal = list(LATENT_DISCRETE_PALETTE)
    pw_cmap: dict[str, str] = {}
    next_color = 0
    for name in uniq:
        if name == "Other":
            pw_cmap[name] = "#94a3b8"
            continue
        pw_cmap[name] = pal[next_color % len(pal)]
        next_color += 1

    # Only request hover columns that exist: "pathway" is optional above, and
    # Plotly Express raises on any column name the frame does not contain.
    hover_cols = [c for c in ("mean_rank", "log_fc", "pathway") if c in d.columns]
    fig = px.scatter(
        d,
        x="mean_de",
        y="mean_re",
        color="_pw_col",
        color_discrete_map=pw_cmap,
        size="_s",
        hover_name="feature" if "feature" in d.columns else None,
        hover_data=hover_cols or None,
        labels={
            "mean_de": "Mean flux · dead-end",
            "mean_re": "Mean flux · reprogramming",
            "_pw_col": "Pathway",
        },
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=540,
        margin=dict(l=52, r=20, t=52, b=40),
        title="Average measured flux by fate label (each point is one reaction)",
        legend=dict(orientation="h", yanchor="top", y=-0.28, xanchor="center", x=0.5),
    )
    fig.update_traces(marker=dict(opacity=0.75, line=dict(width=0.35, color="rgba(0,0,0,0.3)")))
    return fig
|
| 1197 |
+
|
| 1198 |
+
|
| 1199 |
+
def flux_pathway_mean_rank_violin(flux_df: pd.DataFrame, top_pathways: int = 12) -> go.Figure:
    """Violin plot of ``mean_rank`` for the pathways with the most rows.

    Rows without a pathway are dropped, the ``top_pathways`` most frequent
    pathways are kept, and one violin (with an inner box, no individual points)
    is drawn per pathway, each in its own palette colour.

    Args:
        flux_df: Reaction-level table with ``pathway`` and ``mean_rank`` columns.
        top_pathways: Number of most frequent pathways to show.

    Returns:
        The violin figure, or an empty ``go.Figure`` when a column is missing or
        no row remains to plot.
    """
    # Return an empty figure instead of raising KeyError on a missing column.
    if any(col not in flux_df.columns for col in ("pathway", "mean_rank")):
        return go.Figure()
    sub = flux_df.dropna(subset=["pathway"]).copy()
    if sub.empty:
        return go.Figure()

    # Plot the string form of the pathway so the values in the column are exactly
    # the keys of the colour map built below.
    sub["pathway"] = sub["pathway"].astype(str)
    top_list = list(sub["pathway"].value_counts().head(int(top_pathways)).index)
    sub = sub[sub["pathway"].isin(top_list)]
    if sub.empty:
        return go.Figure()

    v_cmap = {p: LATENT_DISCRETE_PALETTE[i % len(LATENT_DISCRETE_PALETTE)] for i, p in enumerate(top_list)}
    fig = px.violin(
        sub,
        x="pathway",
        y="mean_rank",
        box=True,
        points=False,
        color="pathway",
        color_discrete_map=v_cmap,
        labels={"mean_rank": "Mean rank (lower = stronger model focus)", "pathway": "Pathway"},
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        showlegend=False,
        height=420,
        xaxis_tickangle=-32,
        margin=dict(l=48, r=24, t=48, b=140),
        title="How joint model rank spreads within high-coverage pathways",
    )
    return fig
|
| 1227 |
+
|
| 1228 |
+
|
| 1229 |
+
def flux_reaction_annotation_panel(flux_df: pd.DataFrame, top_n: int = 26, metric: str = "mean_rank") -> go.Figure:
    """Three single-column heatmaps for the top reactions: pathway, log fold change, significance.

    The rows are the reactions returned by ``_flux_prepare_top_ranked``.  Column 1
    colours each row by pathway (one qualitative colour per distinct pathway),
    column 2 shows ``log_fc`` on the shared diverging scale, column 3 shows the
    −log10 adjusted p-value on Viridis.  Missing optional columns fall back to
    neutral values ("Unknown" pathway, 0 for the numeric columns).

    Args:
        flux_df: Reaction-level table passed to ``_flux_prepare_top_ranked``.
        top_n: Number of reactions requested from ``_flux_prepare_top_ranked``.
        metric: Ranking column passed to ``_flux_prepare_top_ranked``.

    Returns:
        The three-panel figure, or an empty ``go.Figure`` when no reaction is
        selected.
    """
    top = _flux_prepare_top_ranked(flux_df, top_n, metric)
    if top.empty:
        return go.Figure()
    n = len(top)

    if "pathway" in top.columns:
        pathways = top["pathway"].fillna("Unknown").astype(str).tolist()
    else:
        pathways = ["Unknown"] * n
    # Integer code per pathway, in order of first appearance.
    uniq = list(dict.fromkeys(pathways))
    code_map = {u: i for i, u in enumerate(uniq)}
    codes = np.array([code_map[p] for p in pathways], dtype=float)
    k = max(len(uniq), 1)
    qual = list(px.colors.qualitative.Safe) + list(px.colors.qualitative.Dark24) + list(px.colors.qualitative.Light24)
    if k <= 1:
        # A colourscale needs two stops even when there is a single pathway.
        disc_scale = [[0, qual[0]], [1, qual[0]]]
    else:
        disc_scale = [[j / (k - 1), qual[j % len(qual)]] for j in range(k)]

    log_fc = top["log_fc"].fillna(0).astype(float).to_numpy() if "log_fc" in top.columns else np.zeros(n)
    if "pval_adj_log" in top.columns:
        pv = top["pval_adj_log"].fillna(0).astype(float).to_numpy()
    elif "pval_adj" in top.columns:
        # Missing p-values count as 1 (no evidence), matching the fillna(0) used
        # for the precomputed column.  Adding 0.0 turns -0.0 into 0.0 so the hover
        # text never reads "-0.00".
        p = top["pval_adj"].astype(float).fillna(1.0).clip(lower=1e-300)
        pv = -np.log10(p.to_numpy()) + 0.0
    else:
        # Neither significance column is available: show a flat zero column
        # instead of raising KeyError.
        pv = np.zeros(n)

    full_features = top["feature"].astype(str).tolist()
    y_labels = [_truncate_label(str(f), 44) for f in full_features]
    z_path = codes.reshape(-1, 1)

    # hovertext (not customdata): subplot heatmaps often render %{customdata[0]} as "-" in the browser.
    hover_path = [[f"<b>{fn}</b><br>pathway: {pw}"] for fn, pw in zip(full_features, pathways)]
    hover_lfc = [
        [f"<b>{fn}</b><br>{LABEL_LOG2FC}: {float(log_fc[i]):.4f}"]
        for i, fn in enumerate(full_features)
    ]
    hover_pv = [
        [f"<b>{fn}</b><br>{LABEL_NEG_LOG10_ADJ_P}: {float(pv[i]):.2f}"]
        for i, fn in enumerate(full_features)
    ]

    def _side_colorbar(label: str, y_center: float, **extra) -> dict:
        # Both colour bars sit at the right edge of the paper, stacked vertically.
        return dict(
            title=dict(text=label, side="right"),
            len=0.22,
            y=y_center,
            yanchor="middle",
            x=1.0,
            xanchor="left",
            xref="paper",
            yref="paper",
            thickness=12,
            **extra,
        )

    fig = make_subplots(
        rows=1,
        cols=3,
        shared_yaxes=True,
        horizontal_spacing=0.06,
        column_widths=[0.24, 0.24, 0.24],
    )
    fig.add_trace(
        go.Heatmap(
            z=z_path,
            x=[""],
            y=y_labels,
            colorscale=disc_scale,
            zmin=0,
            zmax=max(k - 1, 0),
            showscale=False,
            hovertext=hover_path,
            hovertemplate="%{hovertext}<extra></extra>",
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Heatmap(
            z=log_fc.reshape(-1, 1),
            x=[""],
            y=y_labels,
            colorscale=LOG_FC_DIVERGING_SCALE,
            zmin=LOG_FC_COLOR_MIN,
            zmax=LOG_FC_COLOR_MAX,
            showscale=True,
            colorbar=_side_colorbar(LABEL_LOG2FC, 0.71, tickformat=".2f"),
            hovertext=hover_lfc,
            hovertemplate="%{hovertext}<extra></extra>",
        ),
        row=1,
        col=2,
    )
    fig.add_trace(
        go.Heatmap(
            z=pv.reshape(-1, 1),
            x=[""],
            y=y_labels,
            colorscale="Viridis",
            showscale=True,
            colorbar=_side_colorbar(LABEL_NEG_LOG10_ADJ_P, 0.29),
            hovertext=hover_pv,
            hovertemplate="%{hovertext}<extra></extra>",
        ),
        row=1,
        col=3,
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=min(820, 120 + n * 22),
        width=900,
        margin=dict(l=8, r=108, t=56, b=72),
        title=dict(
            text=f"Pathway, {LABEL_LOG2FC}, and significance",
            x=0,
            xanchor="left",
            y=0.995,
            yanchor="top",
            font=dict(size=13, family=PLOT_FONT["family"]),
            pad=dict(b=8, l=4),
        ),
    )
    fig.update_xaxes(side="bottom", title_standoff=8)
    fig.update_xaxes(title_text="Pathway", row=1, col=1)
    fig.update_xaxes(title_text=LABEL_LOG2FC, row=1, col=2)
    fig.update_xaxes(title_text=LABEL_NEG_LOG10_ADJ_P, row=1, col=3)
    # Reversed axis puts the first selected reaction at the top.
    fig.update_yaxes(autorange="reversed")
    return fig
|
| 1358 |
+
|
| 1359 |
+
|
| 1360 |
+
def flux_model_metric_profile(flux_df: pd.DataFrame, top_n: int = 22, metric: str = "mean_rank") -> go.Figure:
    """Heatmap of per-metric min–max scaled scores for the top reactions.

    Rows are metrics and columns are the reactions returned by
    ``_flux_prepare_top_ranked``.  Each metric is rescaled to 0–1 across the
    selected reactions: ``importance_shift`` and ``importance_att`` when
    present, one minus the scaled ``mean_rank``, and, when both flux means are
    present, the ``mean_re`` / ``mean_de`` ratio (rows with ``mean_de`` equal to
    zero get a ratio of 0).

    Args:
        flux_df: Reaction-level table passed to ``_flux_prepare_top_ranked``.
        top_n: Number of reactions requested from ``_flux_prepare_top_ranked``.
        metric: Ranking column passed to ``_flux_prepare_top_ranked``.

    Returns:
        The heatmap figure, or an empty ``go.Figure`` when no reaction is selected.
    """
    ranked = _flux_prepare_top_ranked(flux_df, top_n, metric)
    if ranked.empty:
        return go.Figure()

    def _unit_scale(series: pd.Series) -> np.ndarray:
        # Min–max scale ignoring NaN; constant or all-NaN input becomes all zeros.
        values = series.astype(float).to_numpy()
        low = float(np.nanmin(values))
        high = float(np.nanmax(values))
        if not np.isfinite(low) or high <= low:
            return np.zeros_like(values, dtype=float)
        return (values - low) / (high - low)

    # (row label, scaled values) pairs, in display order.
    profile: list[tuple[str, np.ndarray]] = []
    optional_metrics = (
        ("importance_shift", "Latent shift impact"),
        ("importance_att", "Attention (rollout)"),
    )
    for column, row_label in optional_metrics:
        if column in ranked.columns:
            profile.append((row_label, _unit_scale(ranked[column])))

    # Invert the rank so that a better (lower) rank scores closer to 1.
    profile.append(("Joint priority (1 - scaled mean rank)", 1.0 - _unit_scale(ranked["mean_rank"])))

    if "mean_de" in ranked.columns and "mean_re" in ranked.columns:
        # Zero denominators become NaN and the resulting ratio is then set to 0.
        dead_end = ranked["mean_de"].astype(float).replace(0, np.nan)
        flux_ratio = (ranked["mean_re"].astype(float) / (dead_end + 1e-12)).fillna(0)
        profile.append(("RE / DE mean flux (scaled)", _unit_scale(flux_ratio)))

    labels = [row_label for row_label, _ in profile]
    z = np.column_stack([scaled for _, scaled in profile])
    grid = z.T

    reaction_names = ranked["feature"].astype(str).tolist()
    tick_names = [_truncate_label(str(name), 34) for name in reaction_names]
    fig = px.imshow(
        grid,
        x=tick_names,
        y=labels,
        aspect="auto",
        color_continuous_scale="Tealrose",
        labels=dict(x="Reaction", y="Metric", color="Scaled 0-1 per metric"),
    )

    # Tick labels are truncated; repeat the full reaction name on every metric row for hover.
    n_met, n_rxn = grid.shape
    full_names = np.broadcast_to(np.array(reaction_names, dtype=object), (n_met, n_rxn))
    fig.update_traces(
        customdata=full_names,
        hovertemplate="<b>%{customdata}</b><br>%{y}<br>scaled: %{z:.3f}<extra></extra>",
    )
    fig.update_xaxes(tickangle=-50, side="bottom", title_standoff=12)

    title_spec = dict(
        text="Reaction profile",
        x=0,
        xanchor="left",
        y=0.98,
        yanchor="top",
        font=dict(size=13, family=PLOT_FONT["family"]),
        pad=dict(b=10, l=4),
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=min(380, 140 + len(labels) * 36),
        margin=dict(l=200, r=28, t=64, b=200),
        title=title_spec,
    )
    return fig
|
| 1420 |
+
|
| 1421 |
+
|
streamlit_hf/lib/reactions.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared reaction-string normalisation (flux features vs metabolic metadata)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def normalize_reaction_key(name: str) -> str:
    """Map `A→B` style names to the same key as metadata `A -> B` (case-insensitive).

    The Unicode arrow is rewritten as ``->`` surrounded by spaces, every run of
    whitespace is collapsed to one space, the edges are trimmed and the result
    is lower-cased.

    Args:
        name: Reaction name; any value is accepted and converted with ``str``.

    Returns:
        The normalised lookup key.
    """
    t = str(name).replace("→", " -> ")
    t = re.sub(r"\s+", " ", t)
    # Trim after the arrow substitution: a leading or trailing arrow otherwise
    # leaves a padding space in the key and makes the function non-idempotent.
    return t.strip().lower()
|
streamlit_hf/lib/ui.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Light shared styles (no heavy themes; keeps default Streamlit + plotly_white)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def inject_app_styles() -> None:
    """Emit the shared CSS for panel titles through ``st.markdown``.

    The block defines the ``latent-panel-title`` class and its
    ``latent-panel-title-gap`` spacing variant.  It is a small static string, so
    calling this on every rerun is fine.
    """
    css_lines = (
        "",
        "<style>",
        ".latent-panel-title {",
        "font-size: 0.82rem;",
        "font-weight: 600;",
        "color: #475569;",
        "margin: 0 0 0.35rem 0;",
        "letter-spacing: 0.02em;",
        "}",
        ".latent-panel-title-gap { margin-top: 0.85rem; }",
        "</style>",
        "",
    )
    # unsafe_allow_html is required for the <style> element to be passed through.
    st.markdown("\n".join(css_lines), unsafe_allow_html=True)
|
streamlit_hf/pages/1_Single_Cell_Explorer.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Interactive UMAP of multimodal latent space (validation folds)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import streamlit as st
|
| 10 |
+
|
| 11 |
+
_REPO = Path(__file__).resolve().parents[2]
|
| 12 |
+
if str(_REPO) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(_REPO))
|
| 14 |
+
|
| 15 |
+
from streamlit_hf.lib import formatters
|
| 16 |
+
from streamlit_hf.lib import io
|
| 17 |
+
from streamlit_hf.lib import plots
|
| 18 |
+
from streamlit_hf.lib import ui
|
| 19 |
+
|
| 20 |
+
ui.inject_app_styles()
|
| 21 |
+
|
| 22 |
+
st.title("Single-Cell Explorer")
|
| 23 |
+
st.caption("Explore validation cells in 2-D UMAP space: colour and filter to compare fates, predictions, and modalities.")
|
| 24 |
+
|
| 25 |
+
bundle = io.load_latent_bundle()
|
| 26 |
+
if bundle is None:
|
| 27 |
+
st.error("Latent maps are not available in this session. Ask the maintainer to publish results, then reload.")
|
| 28 |
+
st.stop()
|
| 29 |
+
|
| 30 |
+
samples = io.load_samples_df()
|
| 31 |
+
df = io.latent_join_samples(bundle, samples)
|
| 32 |
+
|
| 33 |
+
left, right = st.columns([0.36, 0.64], gap="large")
|
| 34 |
+
|
| 35 |
+
with left:
|
| 36 |
+
st.markdown('<p class="latent-panel-title">Colour by</p>', unsafe_allow_html=True)
|
| 37 |
+
color_opt = st.selectbox(
|
| 38 |
+
"Hue",
|
| 39 |
+
[
|
| 40 |
+
"label",
|
| 41 |
+
"predicted_class",
|
| 42 |
+
"correct",
|
| 43 |
+
"fold",
|
| 44 |
+
"batch_no",
|
| 45 |
+
"modality_label",
|
| 46 |
+
"pct",
|
| 47 |
+
],
|
| 48 |
+
format_func=lambda x: {
|
| 49 |
+
"label": "CellTag-Multi label",
|
| 50 |
+
"predicted_class": "Predicted fate",
|
| 51 |
+
"correct": "Prediction correct",
|
| 52 |
+
"fold": "CV fold",
|
| 53 |
+
"batch_no": "Batch",
|
| 54 |
+
"modality_label": "Available modalities",
|
| 55 |
+
"pct": "Dominant fate %",
|
| 56 |
+
}[x],
|
| 57 |
+
label_visibility="collapsed",
|
| 58 |
+
help="Which variable sets the colour of each point on the UMAP.",
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
st.markdown('<p class="latent-panel-title latent-panel-title-gap">Filters</p>', unsafe_allow_html=True)
|
| 62 |
+
mod_labels = sorted(df["modality_label"].astype(str).unique())
|
| 63 |
+
mod_pick = st.multiselect(
|
| 64 |
+
"Available modalities",
|
| 65 |
+
mod_labels,
|
| 66 |
+
default=mod_labels,
|
| 67 |
+
help="Keep cells whose modality combination matches your selection (RNA/ATAC measured where present; flux inferred).",
|
| 68 |
+
)
|
| 69 |
+
only_correct = st.selectbox(
|
| 70 |
+
"Prediction outcome",
|
| 71 |
+
["All", "Correct only", "Wrong only"],
|
| 72 |
+
help="Restrict to cells where the model was correct, incorrect, or show all.",
|
| 73 |
+
)
|
| 74 |
+
folds = sorted(df["fold"].unique())
|
| 75 |
+
fold_pick = st.multiselect(
|
| 76 |
+
"CV folds",
|
| 77 |
+
folds,
|
| 78 |
+
default=folds,
|
| 79 |
+
help="Validation cross-validation folds to include (each fold’s held-out cells).",
|
| 80 |
+
)
|
| 81 |
+
pct_rng = st.slider(
|
| 82 |
+
"Dominant fate % range",
|
| 83 |
+
0.0,
|
| 84 |
+
100.0,
|
| 85 |
+
(0.0, 100.0),
|
| 86 |
+
1.0,
|
| 87 |
+
help="Keep cells whose dominant lineage probability (percent) falls in this range.",
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
plot_df = df[df["fold"].isin(fold_pick) & df["modality_label"].isin(mod_pick)].copy()
|
| 91 |
+
plot_df = plot_df[(plot_df["pct"] >= pct_rng[0]) & (plot_df["pct"] <= pct_rng[1])]
|
| 92 |
+
if only_correct == "Correct only":
|
| 93 |
+
plot_df = plot_df[plot_df["correct"]]
|
| 94 |
+
elif only_correct == "Wrong only":
|
| 95 |
+
plot_df = plot_df[~plot_df["correct"]]
|
| 96 |
+
|
| 97 |
+
if plot_df.empty:
|
| 98 |
+
st.warning("No points after filters. Relax the filters and try again.")
|
| 99 |
+
st.stop()
|
| 100 |
+
|
| 101 |
+
with right:
|
| 102 |
+
fig = plots.latent_scatter(
|
| 103 |
+
plot_df,
|
| 104 |
+
color_opt,
|
| 105 |
+
title="Validation latent space (UMAP)",
|
| 106 |
+
width=900,
|
| 107 |
+
height=560,
|
| 108 |
+
marker_size=5.8,
|
| 109 |
+
marker_opacity=0.74,
|
| 110 |
+
)
|
| 111 |
+
st.plotly_chart(fig, width="stretch", on_select="rerun", key="latent_pick")
|
| 112 |
+
|
| 113 |
+
st.subheader("Selected points")
|
| 114 |
+
state = st.session_state.get("latent_pick")
|
| 115 |
+
points = []
|
| 116 |
+
if isinstance(state, dict):
|
| 117 |
+
sel = state.get("selection") or {}
|
| 118 |
+
if isinstance(sel, dict):
|
| 119 |
+
points = sel.get("points") or []
|
| 120 |
+
if points:
|
| 121 |
+
idxs = [int(p["point_index"]) for p in points if "point_index" in p]
|
| 122 |
+
idxs = [i for i in idxs if 0 <= i < len(plot_df)]
|
| 123 |
+
if idxs:
|
| 124 |
+
sub = plot_df.iloc[idxs]
|
| 125 |
+
disp = formatters.prepare_latent_display_dataframe(sub)
|
| 126 |
+
st.dataframe(
|
| 127 |
+
disp,
|
| 128 |
+
width="stretch",
|
| 129 |
+
hide_index=True,
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
st.warning(
|
| 133 |
+
"A selection was reported but no valid points matched the current filtered view. "
|
| 134 |
+
"Try selecting again after changing filters, or pick a row via **Inspect by dataset index**."
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
st.info(
|
| 138 |
+
"This table fills in when you **select points on the UMAP**. "
|
| 139 |
+
"In the chart’s top-right toolbar, choose **Box select** or **Lasso select**, "
|
| 140 |
+
"then drag over the dots; the page reruns and rows for those cells appear here. "
|
| 141 |
+
"To inspect one cell without using the lasso, scroll down to **Inspect by dataset index**."
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
st.subheader("Inspect by dataset index")
|
| 145 |
+
pick = st.number_input(
|
| 146 |
+
"Dataset index",
|
| 147 |
+
min_value=int(df["dataset_idx"].min()),
|
| 148 |
+
max_value=int(df["dataset_idx"].max()),
|
| 149 |
+
value=int(df["dataset_idx"].iloc[0]),
|
| 150 |
+
help="Index `ind` in your sample table; aligns one validation cell to this row.",
|
| 151 |
+
)
|
| 152 |
+
row = df[df["dataset_idx"] == pick]
|
| 153 |
+
if not row.empty:
|
| 154 |
+
st.dataframe(
|
| 155 |
+
formatters.latent_inspector_key_value(row.iloc[0]),
|
| 156 |
+
width="stretch",
|
| 157 |
+
hide_index=True,
|
| 158 |
+
)
|
streamlit_hf/pages/2_Feature_insights.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multimodal feature importance: ranks, attention by prediction, tables."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import streamlit as st
|
| 10 |
+
|
| 11 |
+
_REPO = Path(__file__).resolve().parents[2]
|
| 12 |
+
if str(_REPO) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(_REPO))
|
| 14 |
+
|
| 15 |
+
from streamlit_hf.lib import io
|
| 16 |
+
from streamlit_hf.lib import plots
|
| 17 |
+
from streamlit_hf.lib import ui
|
| 18 |
+
|
| 19 |
+
# Page bootstrap: styles, header, data loading, and the tab strip.
ui.inject_app_styles()

st.title("Feature Insights")
st.caption("Latent-shift probes, attention rollout, and combined rankings across RNA, ATAC, and Flux.")

# Both artefacts are requested up front. Only the feature table is required to
# render the page; the attention summary is checked again where it is used.
df = io.load_df_features()
att = io.load_attention_summary()

if df is None:
    st.error("Feature data are not loaded. Ask the maintainer to publish results for this app, then reload.")
    st.stop()

# The order of the labels is the order of the tabs in the UI.
_tab_labels = [
    "Global overview",
    "Modality spotlight",
    "Shift vs attention",
    "Attention vs prediction",
    "Full table",
]
tab1, tab2, tab3, tab4, tab5 = st.tabs(_tab_labels)
|
| 42 |
+
|
| 43 |
+
# ----- Tab 1 -----
|
| 44 |
+
with tab1:
    # Two controls side by side: depth of the bar charts and the size of the
    # pool that feeds the modality pie.
    bars_slot, pie_slot = st.columns(2)
    with bars_slot:
        top_n_bars = st.slider(
            "Top N (shift & attention bars)",
            min_value=10,
            max_value=45,
            value=20,
            key="t1_topn_bars",
        )
    with pie_slot:
        top_n_pie = st.slider(
            "Pool size (mean-rank pie)",
            min_value=50,
            max_value=250,
            value=100,
            key="t1_topn_pie",
        )

    overview_fig = plots.global_rank_triple_panel(df, top_n=top_n_bars, top_n_pie=top_n_pie)
    st.plotly_chart(overview_fig, width="stretch")
    st.caption(
        "Bars: **global** top features by shift impact and by mean attention (min-max scaled); "
        "colour = modality. Pie: RNA / ATAC / Flux mix among the lowest mean-rank features in that pool."
    )
|
| 70 |
+
|
| 71 |
+
# ----- Tab 2: RNA / ATAC / Flux columns -----
|
| 72 |
+
with tab2:
    st.caption(
        "**Modality spotlight:** three columns (**RNA**, **ATAC**, **Flux**). Each column only shows features "
        "from that modality so you can compare shift impact, attention, and joint ranking **within** RNA, ATAC, or flux."
    )
    top_n_rank = st.slider("Top N per chart", min_value=10, max_value=55, value=20, key="t2_topn")
    spotlight_modalities = ("RNA", "ATAC", "Flux")

    # Row 1: joint ranking. A modality with no rows leaves its column empty.
    st.subheader("Joint top markers (by mean rank)")
    st.caption(
        "The **strongest combined** markers by mean rank (lower mean rank = higher joint shift + attention priority). "
        "Shift and attention bars are **min-max scaled within this top-N list** (0 to 1) so you can compare them on one axis. "
        "Hover a bar for the full feature name."
    )
    for slot, mod in zip(st.columns(3), spotlight_modalities):
        sm = df[df["modality"] == mod]
        if not sm.empty:
            with slot:
                joint_fig = plots.joint_shift_attention_top_features(sm, mod, top_n_rank)
                st.plotly_chart(joint_fig, width="stretch")

    # Row 2: top-N by latent-shift importance, sorted ascending for the bar plot.
    st.subheader("Shift importance")
    for slot, mod in zip(st.columns(3), spotlight_modalities):
        sm = df[df["modality"] == mod]
        if not sm.empty:
            colc = plots.MODALITY_COLOR.get(mod, plots.PALETTE[0])
            top_shift = sm.nlargest(top_n_rank, "importance_shift")
            sub = top_shift.sort_values("importance_shift", ascending=True)
            with slot:
                shift_fig = plots.rank_bar(
                    sub,
                    "importance_shift",
                    "feature",
                    f"{mod}: shift · top {top_n_rank}",
                    colc,
                    xaxis_title="Latent shift importance",
                )
                st.plotly_chart(shift_fig, width="stretch")

    # Row 3: same layout, ranked by attention importance instead.
    st.subheader("Attention importance")
    for slot, mod in zip(st.columns(3), spotlight_modalities):
        sm = df[df["modality"] == mod]
        if not sm.empty:
            colc = plots.MODALITY_COLOR.get(mod, plots.PALETTE[0])
            top_att = sm.nlargest(top_n_rank, "importance_att")
            sub = top_att.sort_values("importance_att", ascending=True)
            with slot:
                att_fig = plots.rank_bar(
                    sub,
                    "importance_att",
                    "feature",
                    f"{mod}: attention · top {top_n_rank}",
                    colc,
                    xaxis_title="Attention importance",
                )
                st.plotly_chart(att_fig, width="stretch")
|
| 134 |
+
|
| 135 |
+
# ----- Tab 3 -----
|
| 136 |
+
with tab3:
    st.caption(
        "Each point is **one feature** within its modality. **Attention rank** is on the horizontal axis and **shift rank** "
        "on the vertical axis (1 = strongest in that modality for that metric). Features near the diagonal rank similarly "
        "for both; the **red dashed line** is a straight-line trend (least-squares fit) through the cloud."
    )

    # Correlation summary: one row per modality that has data and at least
    # three features in the stats returned by the plotting helper.
    corr_rows = []
    for mod in ("RNA", "ATAC", "Flux"):
        sm = df[df["modality"] == mod]
        if sm.empty:
            continue
        cor = plots.modality_shift_attention_rank_stats(sm)
        if cor.get("n", 0) < 3:
            continue
        summary_row = {
            "Modality": mod,
            "# features": cor["n"],
            "Pearson r": f"{cor['pearson_r']:.3f}",
            "Pearson p": f"{cor['pearson_p']:.2e}",
            "Spearman ρ": f"{cor['spearman_r']:.3f}",
            "Spearman p": f"{cor['spearman_p']:.2e}",
        }
        corr_rows.append(summary_row)
    if corr_rows:
        st.dataframe(pd.DataFrame(corr_rows), hide_index=True, width="stretch")

    # One rank-vs-rank scatter per modality, drawn unconditionally.
    rc1, rc2, rc3 = st.columns(3)
    for slot, mod in zip((rc1, rc2, rc3), ("RNA", "ATAC", "Flux")):
        with slot:
            sub_m = df[df["modality"] == mod]
            scatter_fig = plots.rank_scatter_shift_vs_attention(sub_m, mod)
            st.plotly_chart(scatter_fig, width="stretch")
|
| 169 |
+
|
| 170 |
+
# ----- Tab 4 -----
|
| 171 |
+
with tab4:
    with st.expander("What is this?", expanded=False):
        st.markdown(
            "Bars show **mean attention weights** (from rollout) averaged over validation cells, split by **what the "
            "model predicted** for each cell: all validation cells together, only cells called **dead-end**, or only "
            "cells called **reprogramming**. This reflects **model behaviour**, not the true fate label."
        )

    # Option value -> display label. Built once instead of on every
    # format_func call; dict order is the option order.
    _cohort_labels = {
        "compare": "Compare cohorts (grouped bars)",
        "all": "All validation samples (mean attention)",
        "dead_end": "Mean attention when prediction = dead-end",
        "reprogramming": "Mean attention when prediction = reprogramming",
    }
    cohort_mode = st.selectbox(
        "Cohort view",
        list(_cohort_labels),
        format_func=_cohort_labels.__getitem__,
        key="t4_cohort",
        help=(
            "Choose which validation cells contribute to the average. **All validation samples** uses every validation "
            "cell. The prediction-specific options use only cells where the model output was dead-end or reprogramming, "
            "so you can see which features receive more weight when the model leans each way."
        ),
    )
    top_n_att = st.slider("Top N", 6, 28, 15, key="t4_topn")

    if not att or "fi_att" not in att:
        st.warning(
            "Attention summaries are not available in this session. That view needs a full publish from the maintainer."
        )
    else:
        ac1, ac2, ac3 = st.columns(3)
        for col, mod in zip((ac1, ac2, ac3), ("RNA", "ATAC", "Flux")):
            with col:
                st.plotly_chart(
                    plots.attention_cohort_view(att["fi_att"], mod, top_n=top_n_att, mode=cohort_mode),
                    width="stretch",
                )

        # The rollout tables read three keys, so gate on all three. Previously
        # "feature_names" was indexed without being checked and a partial
        # summary raised KeyError.
        if all(k in att for k in ("rollout_mean", "slices", "feature_names")):
            st.subheader("Mean rollout weight")
            if cohort_mode == "compare":
                _roll_labels = {
                    "all": "All validation samples",
                    "dead_end": "Cells predicted dead-end",
                    "reprogramming": "Cells predicted reprogramming",
                }
                roll_cohort = st.selectbox(
                    "Rollout table: average over",
                    list(_roll_labels),
                    format_func=_roll_labels.__getitem__,
                    key="t4_roll",
                    help="Pick which validation subset is used for the mean rollout vector in the tables below.",
                )
            else:
                roll_cohort = cohort_mode
                st.caption(
                    "Rollout tables use the **same cohort** as the bar charts above (batch-embedding tokens are omitted)."
                )

            # Loop-invariant lookups: the cohort vector does not depend on the
            # modality, only the slice applied to it does.
            rm = att["rollout_mean"]
            vec_all = rm.get(roll_cohort)
            if vec_all is None:
                # Fall back to the all-cells vector when the cohort has none.
                vec_all = rm["all"]
            slices = att["slices"]
            feature_names = att["feature_names"]

            rc1, rc2, rc3 = st.columns(3)
            for col, mod in zip((rc1, rc2, rc3), ("RNA", "ATAC", "Flux")):
                if mod not in slices:
                    # No slice recorded for this modality: leave the column
                    # empty rather than failing the whole page.
                    continue
                with col:
                    sl = slices[mod]
                    vec = vec_all[sl["start"] : sl["stop"]]
                    names = feature_names[sl["start"] : sl["stop"]]
                    mini = plots.rollout_top_features_table(names, vec, top_n_att)
                    st.caption(mod)
                    st.dataframe(mini, hide_index=True, width="stretch")
|
| 244 |
+
|
| 245 |
+
# ----- Tab 5 -----
|
| 246 |
+
with tab5:
    scope = st.radio(
        "Table scope",
        ["All modalities", "Single modality"],
        horizontal=True,
        key="t5_scope",
    )

    # mod_tbl doubles as the download-file suffix: "all" unless one modality
    # is picked, in which case it is that modality's name.
    mod_tbl = "all"
    if scope != "Single modality":
        tbl = df.copy()
    else:
        mod_tbl = st.selectbox("Modality", ["RNA", "ATAC", "Flux"], key="t5_mod")
        tbl = df[df["modality"] == mod_tbl].copy()

    # Preferred column order; columns absent from the table are dropped.
    wanted_cols = (
        "mean_rank",
        "feature",
        "modality",
        "rank_shift_in_modal",
        "rank_att_in_modal",
        "combined_order_mod",
        "rank_shift",
        "rank_att",
        "importance_shift",
        "importance_att",
        "top_10_pct",
        "group",
        "log_fc",
        "pval_adj",
        "pathway",
        "module",
    )
    show_cols = [name for name in wanted_cols if name in tbl.columns]

    st.caption(
        "All rows for the chosen scope, sorted by **mean rank** (lower = stronger joint shift + attention priority). "
        "Use the dataframe search / sort in the table toolbar to narrow down."
    )
    full_view = tbl[show_cols].sort_values("mean_rank")
    st.dataframe(full_view, width="stretch", hide_index=True)

    suffix = mod_tbl
    csv_bytes = full_view.to_csv(index=False).encode("utf-8")
    st.download_button(
        "Download table (CSV)",
        csv_bytes,
        file_name=f"fateformer_features_{suffix}.csv",
        mime="text/csv",
        key="t5_dl",
    )
|
streamlit_hf/pages/3_Flux_analysis.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Metabolic flux: pathway map, differential views, reaction ranking table, metabolic model metadata."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import streamlit as st
|
| 9 |
+
|
| 10 |
+
_REPO = Path(__file__).resolve().parents[2]
|
| 11 |
+
if str(_REPO) not in sys.path:
|
| 12 |
+
sys.path.insert(0, str(_REPO))
|
| 13 |
+
|
| 14 |
+
from streamlit_hf.lib import io
|
| 15 |
+
from streamlit_hf.lib import plots
|
| 16 |
+
from streamlit_hf.lib import ui
|
| 17 |
+
|
| 18 |
+
# Page bootstrap: styles, header, data loading, and the tab strip.
ui.inject_app_styles()

st.title("Flux Analysis")
st.caption(
    "Reaction-level flux: how pathways, statistics, and model rankings line up. "
    "For global rank bars and shift vs. attention scatter, open **Feature insights**."
)

df = io.load_df_features()
if df is None:
    st.error(
        "Flux and feature data are not loaded in this session. Reload the app after the maintainer has published "
        "fresh results, or ask them to check the deployment."
    )
    st.stop()

# Everything below works on the Flux rows only; stop early when there are none.
is_flux = df["modality"] == "Flux"
flux = df[is_flux].copy()
if flux.empty:
    st.warning("There are no flux reactions in the current results.")
    st.stop()

# May be None or empty; the metadata tab checks before using it.
meta = io.load_metabolic_model_metadata()

_flux_tab_labels = [
    "Pathway map",
    "Differential & fate",
    "Reaction ranking",
    "Metabolic model metadata",
]
tab_map, tab_bio, tab_rank, tab_meta = st.tabs(_flux_tab_labels)
|
| 49 |
+
|
| 50 |
+
with tab_map:
    st.caption(
        "**Left:** sunburst of the strongest reactions by mean rank, grouped by pathway. **Right:** heatmaps for the "
        "same reactions: pathway, differential Log₂FC, and statistical significance, aligned row by row. "
        "Ranked reaction table: **Reaction Ranking**. Curated model edges: **Metabolic model metadata**."
    )

    # Prefer top-aligned columns; retry without the keyword if this Streamlit
    # build rejects it with a TypeError.
    column_widths = [1.05, 0.95]
    try:
        c1, c2 = st.columns(column_widths, gap="medium", vertical_alignment="top")
    except TypeError:
        c1, c2 = st.columns(column_widths, gap="medium")

    with c1:
        n_sb = st.slider("Reactions in sunburst", min_value=25, max_value=90, value=52, key="flux_sb_n")
        sunburst_fig = plots.flux_pathway_sunburst(flux, max_features=n_sb)
        st.plotly_chart(sunburst_fig, width="stretch")

    with c2:
        top_n_nb = st.slider(
            "Reactions in annotation + profile",
            min_value=12,
            max_value=40,
            value=26,
            key="flux_nb_n",
        )
        annotation_fig = plots.flux_reaction_annotation_panel(flux, top_n=top_n_nb, metric="mean_rank")
        st.plotly_chart(annotation_fig, width="stretch")
        # The profile panel is capped at 24 reactions whatever the slider says.
        profile_n = min(top_n_nb, 24)
        profile_fig = plots.flux_model_metric_profile(flux, top_n=profile_n, metric="mean_rank")
        st.plotly_chart(profile_fig, width="stretch")
|
| 73 |
+
|
| 74 |
+
with tab_bio:
    st.caption(
        "**Volcano:** differential Log₂FC versus significance (−log₁₀ adjusted p); colour shows overall mean rank. "
        "Points with essentially no fold change and a zero adjusted p-value are removed as unreliable. "
        "**Scatter:** average measured flux in dead-end versus reprogramming cells; point size reflects combined shift "
        "and attention strength; colours mark pathway (largest groups shown, others grouped as *Other*)."
    )
    # Volcano on the left, dead-end vs reprogramming scatter on the right.
    b1, b2 = st.columns(2)
    with b1:
        volcano_fig = plots.flux_volcano(flux)
        st.plotly_chart(volcano_fig, width="stretch")
    with b2:
        fate_fig = plots.flux_dead_end_vs_reprogram_scatter(flux)
        st.plotly_chart(fate_fig, width="stretch")
|
| 86 |
+
|
| 87 |
+
with tab_rank:
    st.caption("Filter by reaction name or pathway, then inspect or download the ranked flux table.")
    q = st.text_input("Substring filter (reaction name)", "", key="flux_q")
    pw_f = st.multiselect(
        "Pathway",
        sorted(flux["pathway"].dropna().unique().astype(str)),
        default=[],
        key="flux_pw_f",
    )

    show = flux
    # Filter with the same trimmed text that decides whether to filter at all.
    needle = q.strip()
    if needle:
        # regex=False: the box is a plain substring filter. With the default
        # regex=True, characters such as "(", "[" or "+" in a query either
        # raised re.error or matched something other than the literal text.
        name_hits = show["feature"].astype(str).str.contains(needle, case=False, na=False, regex=False)
        show = show[name_hits]
    if pw_f:
        show = show[show["pathway"].astype(str).isin(pw_f)]

    # Preferred column order; columns absent from the table are dropped.
    cols = [
        c
        for c in [
            "mean_rank",
            "feature",
            "rank_shift_in_modal",
            "rank_att_in_modal",
            "combined_order_mod",
            "rank_shift",
            "rank_att",
            "importance_shift",
            "importance_att",
            "top_10_pct",
            "mean_de",
            "mean_re",
            "group",
            "log_fc",
            "pval_adj",
            "pathway",
            "module",
        ]
        if c in show.columns
    ]

    # Sort once and reuse the frame for both the on-screen table and the CSV.
    ranked = show[cols].sort_values("mean_rank")
    st.dataframe(ranked, width="stretch", hide_index=True)
    st.download_button(
        "Download Flux table (CSV)",
        ranked.to_csv(index=False).encode("utf-8"),
        file_name="fateformer_flux_filtered.csv",
        mime="text/csv",
        key="flux_dl",
    )
|
| 132 |
+
|
| 133 |
+
with tab_meta:
    st.caption(
        "Directed substrate-to-product steps from the reference model, merged with this flux table where reaction names match."
    )
    if meta is not None and not meta.empty:
        # Distinct supermodule ids, sorted, each labelled with the class text
        # of the first metadata row carrying that id.
        sm_ids = sorted(meta["Supermodule_id"].dropna().unique().astype(int).tolist())
        module_classes = [
            str(meta.loc[meta["Supermodule_id"] == sid, "Super.Module.class"].iloc[0])
            for sid in sm_ids
        ]
        graph_labels = ["All modules"]
        graph_labels.extend(f"{sid}: {cls}" for sid, cls in zip(sm_ids, module_classes))

        # The selectbox works on positions; position 0 is the "All modules" entry.
        tix = st.selectbox(
            "Model scope",
            range(len(graph_labels)),
            format_func=lambda i: graph_labels[i],
            key="flux_model_scope",
            help="Show every step in the model, or restrict to one functional module.",
        )
        if tix == 0:
            supermodule_id = None
        else:
            supermodule_id = sm_ids[tix - 1]

        tbl = io.build_metabolic_model_table(meta, flux, supermodule_id=supermodule_id)
        st.dataframe(tbl, width="stretch", hide_index=True)
        model_csv = tbl.to_csv(index=False).encode("utf-8")
        st.download_button(
            "Download metabolic model metadata (CSV)",
            model_csv,
            file_name="fateformer_metabolic_model_edges.csv",
            mime="text/csv",
            key="flux_model_dl",
        )
    else:
        st.warning("Metabolic model metadata is not available in this build.")
|
streamlit_hf/pages/4_Gene_expression_analysis.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gene expression and TF motif activity: pathway enrichment, chromVAR-style motifs, and tables."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import streamlit as st
|
| 10 |
+
|
| 11 |
+
_REPO = Path(__file__).resolve().parents[2]
|
| 12 |
+
if str(_REPO) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(_REPO))
|
| 14 |
+
|
| 15 |
+
from streamlit_hf.lib import io
|
| 16 |
+
from streamlit_hf.lib import pathways as pathway_data
|
| 17 |
+
from streamlit_hf.lib import plots
|
| 18 |
+
from streamlit_hf.lib import ui
|
| 19 |
+
|
| 20 |
+
# Page bootstrap: styles, header, data loading, and the intro caption.
ui.inject_app_styles()

st.title("Gene Expression & TF Activity")

df = io.load_df_features()
if df is None:
    st.error("Feature data could not be loaded. Reload after results are published, or contact the maintainer.")
    st.stop()

# Split once by modality; the page needs at least one of the two to be non-empty.
modality = df["modality"]
rna = df[modality == "RNA"].copy()
atac = df[modality == "ATAC"].copy()
if rna.empty and atac.empty:
    st.warning("No RNA gene or ATAC motif features are available in the current results.")
    st.stop()

st.caption(
    "Pathway enrichment (Reactome / KEGG) and a pathway–gene map; chromVAR-style motif deviations and activity by "
    "fate; sortable gene and motif tables. Use **Feature Insights** for global shift and attention rankings across modalities."
)
|
| 39 |
+
|
| 40 |
+
TABLE_COLS = [
|
| 41 |
+
"mean_rank",
|
| 42 |
+
"feature",
|
| 43 |
+
"rank_shift_in_modal",
|
| 44 |
+
"rank_att_in_modal",
|
| 45 |
+
"combined_order_mod",
|
| 46 |
+
"rank_shift",
|
| 47 |
+
"rank_att",
|
| 48 |
+
"importance_shift",
|
| 49 |
+
"importance_att",
|
| 50 |
+
"top_10_pct",
|
| 51 |
+
"mean_de",
|
| 52 |
+
"mean_re",
|
| 53 |
+
"group",
|
| 54 |
+
"log_fc",
|
| 55 |
+
"pval_adj",
|
| 56 |
+
"mean_diff",
|
| 57 |
+
"pval_adj_log",
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _table_cols(show: pd.DataFrame) -> list[str]:
|
| 62 |
+
return [c for c in TABLE_COLS if c in show.columns]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
_gene_tab_labels = ["Gene Pathway Enrichment", "Motif Activity", "Gene Table", "Motif Table"]
tab_path, tab_motif, tab_gene_tbl, tab_motif_tbl = st.tabs(_gene_tab_labels)

with tab_path:
    st.caption(
        "Over-representation of Reactome and KEGG pathways (Benjamini–Hochberg *q* < 0.05). "
        "The lower panel maps leading genes to pathways; empty grid positions are left clear."
    )
    raw = pathway_data.load_de_re_tsv()
    if raw is not None:
        de_all, re_all = raw
        mde, mre = pathway_data.merged_reactome_kegg_bubble_frames(de_all, re_all)

        # Both bubble panels share the taller of the two suggested heights.
        suggested_heights = [plots.pathway_bubble_suggested_height(len(frame)) for frame in (mde, mre)]
        bubble_h = max(suggested_heights)

        c1, c2 = st.columns(2, gap="medium")
        panels = (
            (c1, mde, "Pathway enrichment — dead-end"),
            (c2, mre, "Pathway enrichment — reprogramming"),
        )
        for slot, frame, panel_title in panels:
            with slot:
                bubble_fig = plots.pathway_enrichment_bubble_panel(
                    frame,
                    panel_title,
                    show_colorbar=True,
                    layout_height=bubble_h,
                )
                st.plotly_chart(bubble_fig, width="stretch")

        # Lower panel: pathway-by-gene membership; the builder may return None.
        hm = pathway_data.build_merged_pathway_membership(de_all, re_all)
        if hm is not None:
            z, ylabs, xlabs = hm
            st.plotly_chart(plots.pathway_gene_membership_heatmap(z, ylabs, xlabs), width="stretch")
        else:
            st.info("No pathway–gene matrix could be built from the current enrichment results.")
    else:
        st.info("Pathway enrichment views are not available in this deployment.")
|
| 111 |
+
|
| 112 |
+
with tab_motif:
    if not atac.empty:
        st.caption(
            "Left: mean motif score difference (reprogramming − dead-end) versus significance. "
            "Right: mean activity in each fate; colour and size follow the same encoding as in **Feature Insights**."
        )
        a1, a2 = st.columns(2, gap="medium")
        with a1:
            motif_volcano_fig = plots.motif_chromvar_volcano(atac)
            st.plotly_chart(motif_volcano_fig, width="stretch")
        with a2:
            activity_fig = plots.notebook_style_activity_scatter(
                atac,
                title="TF activity (z-score) by fate",
                x_title="Dead-end (TF activity)",
                y_title="Reprogramming (TF activity)",
            )
            st.plotly_chart(activity_fig, width="stretch")
    else:
        # Nothing to plot without ATAC rows.
        st.warning("No motif-level ATAC features are available in the current results.")
|
| 133 |
+
|
| 134 |
+
with tab_gene_tbl:
    if rna.empty:
        st.warning("No RNA gene features are available in the current results.")
    else:
        q = st.text_input("Filter by gene name", "", key="ge_tbl_q")
        show = rna
        # Filter with the same trimmed text that decides whether to filter at all.
        needle = q.strip()
        if needle:
            # regex=False: treat the query as literal text. The default
            # regex=True made metacharacters such as "(" or "*" raise re.error
            # or match something other than what was typed.
            name_hits = show["feature"].astype(str).str.contains(needle, case=False, na=False, regex=False)
            show = show[name_hits]
        cols = _table_cols(show)
        # Sort once and reuse the frame for both the table and the CSV export.
        ranked = show[cols].sort_values("mean_rank")
        st.dataframe(ranked, width="stretch", hide_index=True)
        st.download_button(
            "Download table (CSV)",
            ranked.to_csv(index=False).encode("utf-8"),
            file_name="gene_expression_table.csv",
            mime="text/csv",
            key="ge_tbl_dl",
        )
|
| 151 |
+
|
| 152 |
+
with tab_motif_tbl:
    if atac.empty:
        st.warning("No motif-level ATAC features are available in the current results.")
    else:
        q = st.text_input("Filter by motif or TF", "", key="tf_tbl_q")
        show = atac
        # Filter with the same trimmed text that decides whether to filter at all.
        needle = q.strip()
        if needle:
            # regex=False: treat the query as literal text, so names containing
            # "(", ")", "." or "+" can be searched as typed. The default
            # regex=True interpreted them as pattern syntax and could raise
            # re.error.
            name_hits = show["feature"].astype(str).str.contains(needle, case=False, na=False, regex=False)
            show = show[name_hits]
        cols = _table_cols(show)
        # Sort once and reuse the frame for both the table and the CSV export.
        ranked = show[cols].sort_values("mean_rank")
        st.dataframe(ranked, width="stretch", hide_index=True)
        st.download_button(
            "Download table (CSV)",
            ranked.to_csv(index=False).encode("utf-8"),
            file_name="tf_motif_table.csv",
            mime="text/csv",
            key="tf_tbl_dl",
        )
|
streamlit_hf/requirements-docker.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face / production image: precomputed cache only (no torch)
|
| 2 |
+
streamlit>=1.40.0
|
| 3 |
+
plotly>=5.22.0
|
| 4 |
+
pandas>=2.0.0
|
| 5 |
+
numpy>=1.24.0
|
| 6 |
+
pyarrow>=14.0.0
|
streamlit_hf/static/app_icon.svg
ADDED
|
|