#!/usr/bin/env python3
"""
One-off cache builder for the Streamlit explorer.
Run from the repository root:
python scripts/precompute_streamlit_cache.py
python scripts/precompute_streamlit_cache.py --skip-attention # faster: reuse objects/fi_shift_*.pkl only for df_features if attention_summary exists
"""
from __future__ import annotations
import argparse
import os
import pickle
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import umap
# Make the repository root importable and the working directory, no matter
# where the script is launched from (all data paths below are repo-relative).
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
os.chdir(ROOT)
# Project-local imports must come after the sys.path tweak above (noqa: E402).
from data import create_dataset # noqa: E402
from interpretation import attentions as att # noqa: E402
from interpretation import latentspace as ls # noqa: E402
from interpretation import predictions as prds # noqa: E402
# All precomputed artifacts land here; the Streamlit app reads them at startup.
CACHE = ROOT / "streamlit_hf" / "cache"
CACHE.mkdir(parents=True, exist_ok=True)
def replace_fold_results_path(fold_results, ckp_root: str = "ckp"):
    """Point checkpoints at flat `ckp/multi_seed0_fold{k}.pth` layout in this repo."""
    for entry in fold_results:
        original_name = os.path.basename(entry["best_model_path"])
        # Pull the digits out of the first "fold…" token in the filename.
        digits = ""
        for token in original_name.split("_"):
            if token.startswith("fold"):
                digits = "".join(filter(str.isdigit, token))
                break
        # No recognizable fold token: keep the original basename as-is.
        new_name = f"multi_seed0_fold{digits}.pth" if digits else original_name
        entry["best_model_path"] = os.path.join(ckp_root, new_name)
    return fold_results
def load_training_context():
    """Load the pickled dataset, fold results, model config and feature names.

    Returns a 5-tuple ``(multimodal_dataset, fold_results, model_config,
    feature_names, adata_RNA_labelled)``. The last item is ``None`` when the
    labelled-RNA pickle cannot be read (a warning is printed instead).
    """
    with open(ROOT / "objects" / "mutlimodal_dataset.pkl", "rb") as fh:
        bundle = pickle.load(fh)
    X = bundle["X"]
    y_label = bundle["y_label"]
    # Map the string labels onto the binary target used at training time.
    targets = torch.tensor(
        [{"reprogramming": 1, "dead-end": 0}[lbl] for lbl in list(y_label)],
        dtype=torch.float32,
    )
    multimodal_dataset = create_dataset.MultiModalDataset(
        X, bundle["b"], targets, bundle["df_indices"], bundle["pcts"], y_label
    )
    with open(ROOT / "objects" / "fold_results_multi.pkl", "rb") as fh:
        fold_results = replace_fold_results_path(pickle.load(fh))
    # Architecture hyper-parameters; these must match the checkpoints in ckp/.
    model_config = {
        "Share": {
            "d_model": 128,
            "d_ff": 16,
            "n_heads": 8,
            "n_encoder_layers": 2,
            "n_batches": 3,
            "dropout_rate": 0.0,
        },
        "RNA": {"vocab_size": 5914, "seq_len": X[0].shape[1]},
        "ATAC": {"vocab_size": 1, "seq_len": X[1].shape[1]},
        "Flux": {"vocab_size": 1, "seq_len": X[2].shape[1]},
        "Multi": {"d_model": 128, "n_heads_cls": 8, "d_ff_cls": 16},
    }
    # Token order: each modality's features followed by its batch token.
    feature_names = []
    for frame, batch_token in zip(X, ("batch_rna", "batch_atac", "batch_flux")):
        feature_names.extend(list(frame.columns))
        feature_names.append(batch_token)
    adata_RNA_labelled = None
    rna_pkl = ROOT / "data" / "datasets" / "rna_labelled.pkl"
    try:
        with open(rna_pkl, "rb") as fh:
            adata_RNA_labelled = pickle.load(fh)
    except Exception as e:
        print(
            f"Warning: could not load {rna_pkl} ({e}). "
            "Sample table will omit AnnData-derived metadata (e.g. clone_id)."
        )
    return (
        multimodal_dataset,
        fold_results,
        model_config,
        feature_names,
        adata_RNA_labelled,
    )
def build_latent_umap(multimodal_dataset, fold_results, model_config, common_samples: bool = False):
    """Embed the fused ("Multi") latent space with UMAP plus per-sample metadata.

    Runs every fold's validation samples through the trained model, projects
    the latent vectors to 2-D, and returns a dict of parallel arrays (one
    entry per validation sample, in fold order) for the Streamlit explorer.

    Parameters
    ----------
    multimodal_dataset : dataset produced by ``load_training_context``.
    fold_results : per-fold dicts with ``best_model_path`` / ``val_idx``.
    model_config : architecture config matching the checkpoints.
    common_samples : if True, restrict each fold's validation set via
        ``filter_idx`` to samples present in all modalities (must match what
        ``ls.get_latent_space`` did, so the arrays stay aligned).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ls_v, labels, preds = ls.get_latent_space(
        "Multi",
        fold_results,
        multimodal_dataset,
        model_config,
        device,
        common_samples=common_samples,
    )
    # Fixed random_state keeps the cached embedding reproducible across runs.
    reducer = umap.UMAP(n_components=2, random_state=0, n_neighbors=30, min_dist=1.0)
    xy = reducer.fit_transform(ls_v)
    # Rebuild the dataset-index order fold by fold so the metadata arrays line
    # up row-for-row with the latent vectors returned above.
    ordered_indices: list[int] = []
    fold_ids: list[int] = []
    from interpretation.attentions import filter_idx  # noqa: PLC0415
    for fold_idx, fold in enumerate(fold_results):
        val_idx = fold["val_idx"]
        if common_samples:
            val_idx = filter_idx(multimodal_dataset, val_idx)
        ordered_indices.extend(val_idx)
        fold_ids.extend([fold_idx + 1] * len(val_idx))
    labels = np.asarray(labels).ravel()
    preds = np.asarray(preds).ravel().astype(int)
    label_name = np.where(labels > 0.5, "reprogramming", "dead-end")
    pred_name = np.where(preds > 0.5, "reprogramming", "dead-end")
    correct = (preds == labels.astype(int)).astype(np.int8)
    ds = multimodal_dataset
    batch_no = np.array([int(ds.batch_no[i].item()) for i in ordered_indices], dtype=np.int32)
    pcts = np.array([float(ds.pcts[i]) for i in ordered_indices], dtype=np.float64)
    # A sample "has" a modality when its tensor is not all-zero padding.
    modalities = []
    for i in ordered_indices:
        has_r = (ds.rna_data[i] != 0).any().item()
        has_a = (ds.atac_data[i] != 0).any().item()
        has_f = (ds.flux_data[i] != 0).any().item()
        s = "".join(c for c, h in (("R", has_r), ("A", has_a), ("F", has_f)) if h)
        modalities.append(s or "None")
    return {
        "umap_x": xy[:, 0].astype(np.float32),
        "umap_y": xy[:, 1].astype(np.float32),
        "label_name": label_name,
        "pred_name": pred_name,
        "correct": correct,
        "fold": np.array(fold_ids, dtype=np.int32),
        "batch_no": batch_no,
        "pct": pcts,
        "modality": modalities,
        "dataset_idx": np.array(ordered_indices, dtype=np.int32),
        "common_samples": common_samples,
    }
def create_combined_feature_dataframe(
    fi_shift_rna,
    fi_shift_atac,
    fi_shift_flux,
    fi_att_rna,
    fi_att_atac,
    fi_att_flux,
    df_rna_degs=None,
    df_atac_degs=None,
    df_flux_degs=None,
    remove_batch=True,
):
    """Merge shift- and attention-based feature importances into one table.

    Each ``fi_*`` argument is a sequence of ``(feature, importance)`` pairs
    ordered best-first; positional order defines the within-modality rank.
    Optional ``df_*_degs`` frames are left-joined on ``feature`` to attach
    DEG statistics. Returns a DataFrame sorted by ``mean_rank`` with a fixed
    column set (missing columns are filled with NaN).

    Parameters
    ----------
    remove_batch : drop features whose name contains "batch" before ranking.
    """
    def process_modality(shift_list, att_list, degs_df, modality_name):
        # The lists are importance-ordered, so positional index + 1 is the
        # within-modality rank.
        shift_df = (
            pd.DataFrame(shift_list, columns=["feature", "importance_shift"])
            .reset_index()
            .rename(columns={"index": "rank_shift_in_modal"})
        )
        shift_df["rank_shift_in_modal"] += 1
        att_df = (
            pd.DataFrame(att_list, columns=["feature", "importance_att"])
            .reset_index()
            .rename(columns={"index": "rank_att_in_modal"})
        )
        att_df["rank_att_in_modal"] += 1
        # Outer merge keeps features that appear in only one of the rankings.
        combined_df = pd.merge(shift_df, att_df, on="feature", how="outer")
        if degs_df is not None:
            combined_df = pd.merge(combined_df, degs_df, on="feature", how="left")
        combined_df["modality"] = modality_name
        return combined_df

    all_features_df = pd.concat(
        [
            process_modality(fi_shift_rna, fi_att_rna, df_rna_degs, "RNA"),
            process_modality(fi_shift_atac, fi_att_atac, df_atac_degs, "ATAC"),
            process_modality(fi_shift_flux, fi_att_flux, df_flux_degs, "Flux"),
        ],
        ignore_index=True,
    )
    if remove_batch:
        all_features_df = all_features_df[
            ~all_features_df["feature"].str.contains("batch", na=False)
        ]
    # Features missing from one ranking are pushed one past the worst rank.
    rank_modal_cols = ["rank_att_in_modal", "rank_shift_in_modal"]
    max_rank_modal = max(
        all_features_df["rank_att_in_modal"].max(), all_features_df["rank_shift_in_modal"].max()
    )
    all_features_df[rank_modal_cols] = (
        all_features_df[rank_modal_cols].fillna(max_rank_modal + 1).astype("int32")
    )
    all_features_df[["importance_att", "importance_shift"]] = (
        all_features_df[["importance_att", "importance_shift"]].fillna(0).astype("float64")
    )
    # Global (cross-modality) ranks; method="first" breaks ties by row order.
    all_features_df["rank_shift"] = (
        all_features_df["importance_shift"].rank(ascending=False, method="first").astype("int32")
    )
    all_features_df["rank_att"] = (
        all_features_df["importance_att"].rank(ascending=False, method="first").astype("int32")
    )
    all_features_df["mean_rank"] = all_features_df[["rank_att", "rank_shift"]].mean(axis=1)
    # Flag top-decile membership; vectorized replacement for the original
    # row-wise apply (same precedence: both > shift > att > None).
    top_th = int(all_features_df.shape[0] * 0.1) + 1
    in_shift = all_features_df["rank_shift"] <= top_th
    in_att = all_features_df["rank_att"] <= top_th
    all_features_df["top_10_pct"] = np.select(
        [in_shift & in_att, in_shift, in_att], ["both", "shift", "att"], default="None"
    )
    float_cols = [
        col for col in all_features_df.columns if col.startswith(("log_fc", "mean_", "std_", "pval_"))
    ]
    if float_cols:
        all_features_df[float_cols] = all_features_df[float_cols].round(6)
    all_features_df["importance_att"] = all_features_df["importance_att"].round(6)
    all_features_df["importance_shift"] = all_features_df["importance_shift"].round(6)
    all_features_df = all_features_df.sort_values(by="mean_rank", ascending=True)
    # Fixed output schema; columns absent from the DEG joins come back as NaN.
    cols = [
        "mean_rank",
        "feature",
        "rank_shift",
        "rank_att",
        "rank_shift_in_modal",
        "rank_att_in_modal",
        "modality",
        "importance_shift",
        "importance_att",
        "top_10_pct",
        "mean_de",
        "mean_re",
        "std_de",
        "std_re",
        "pval",
        "pval_adj",
        "log_fc",
        "group",
        "pval_adj_log",
        "mean_diff",
        "pathway",
        "module",
    ]
    for c in cols:
        if c not in all_features_df.columns:
            all_features_df[c] = np.nan
    return all_features_df[cols]
def run_attention_and_fi(
    multimodal_dataset,
    fold_results,
    model_config,
    feature_names,
    device: str,
    adata_rna,
):
    """Compute flow-attention rollouts and attention-based feature rankings.

    Returns ``(summary, df_samples)`` where ``summary`` bundles the mean
    rollout vector and the per-modality top-feature lists for three sample
    groups (all / predicted dead-end / predicted reprogramming), and
    ``df_samples`` is the per-sample prediction table.
    """
    df_samples = prds.get_sample_predictions_dataframe(
        model_type="Multi",
        multimodal_dataset=multimodal_dataset,
        fold_results=fold_results,
        model_config=model_config,
        device=device,
        batch_size=32,
        threshold=0.5,
        adata_rna=adata_rna,
    )
    pred_col = df_samples["predicted_class"]
    index_groups = {
        "all": df_samples["ind"].tolist(),
        "dead_end": df_samples.loc[pred_col == "dead-end", "ind"].tolist(),
        "reprogramming": df_samples.loc[pred_col == "reprogramming", "ind"].tolist(),
    }
    progress_messages = {
        "all": "Running flow attention (all validation)…",
        "dead_end": "Running flow attention (predicted dead-end)…",
        "reprogramming": "Running flow attention (predicted reprogramming)…",
    }
    # One normalized rollout tensor per sample group.
    rollouts = {}
    for group, indices in index_groups.items():
        print(progress_messages[group])
        layers = att.analyze_cls_attention(
            "Multi",
            fold_results,
            multimodal_dataset,
            model_config,
            device=device,
            indices=indices,
            average_heads=False,
            return_flow_attention=True,
        )
        rolled = att.multimodal_attention_rollout(layers)
        rollouts[group] = rolled / rolled.sum(dim=-1, keepdim=True)
    # Explicit splits (notebook): RNA [:945], ATAC [945:945+884], flux rest
    i0, i1, i2 = 0, 945, 945 + 884
    modality_slices = (
        ("rna", "RNA", slice(i0, i1)),
        ("atac", "ATAC", slice(i1, i2)),
        ("flux", "Flux", slice(i2, None)),
    )
    top_n_get = None
    rollout_mean = {}
    fi = {}
    for group, tensor in rollouts.items():
        rollout_mean[group] = tensor.mean(dim=0).detach().cpu().numpy()
        fi[group] = {
            key: att.get_top_features(
                tensor[:, sl], feature_names[sl], modality=modality, top_n=top_n_get
            )
            for key, modality, sl in modality_slices
        }
    summary = {
        "feature_names": feature_names,
        "slices": {
            "RNA": {"start": i0, "stop": i1},
            "ATAC": {"start": i1, "stop": i2},
            "Flux": {"start": i2, "stop": len(feature_names)},
        },
        "rollout_mean": rollout_mean,
        "fi_att": fi,
    }
    return summary, df_samples
def main():
    """Build every cache artifact the Streamlit explorer reads at startup."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--skip-attention", action="store_true", help="Skip attention if summary exists")
    ap.add_argument(
        "--common-samples",
        action="store_true",
        help="Use common-samples filter for latent UMAP (default: False, notebook-style)",
    )
    args = ap.parse_args()
    common_samples = args.common_samples
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    (
        multimodal_dataset,
        fold_results,
        model_config,
        feature_names,
        adata_RNA_labelled,
    ) = load_training_context()
    print("Building latent UMAP bundle…")
    latent = build_latent_umap(
        multimodal_dataset, fold_results, model_config, common_samples=common_samples
    )
    with open(CACHE / "latent_umap.pkl", "wb") as f:
        pickle.dump(latent, f)
    att_path = CACHE / "attention_summary.pkl"
    df_samples_path = CACHE / "samples.parquet"
    # Single flag instead of re-evaluating the same condition twice below.
    skipping_attention = args.skip_attention and att_path.is_file()
    if skipping_attention:
        print("Skipping attention (--skip-attention, file exists).")
        with open(att_path, "rb") as f:
            summary = pickle.load(f)
    else:
        print("Computing attention + rollout (slow)…")
        summary, df_samples = run_attention_and_fi(
            multimodal_dataset,
            fold_results,
            model_config,
            feature_names,
            device,
            adata_RNA_labelled,
        )
        with open(att_path, "wb") as f:
            pickle.dump(summary, f)
        with open(CACHE / "attention_feature_ranks.pkl", "wb") as f:
            pickle.dump(summary["fi_att"], f)
        df_samples.to_parquet(df_samples_path, index=False)
    if skipping_attention and not df_samples_path.is_file():
        # Attention was skipped, so the sample table must be rebuilt here.
        df_samples = prds.get_sample_predictions_dataframe(
            model_type="Multi",
            multimodal_dataset=multimodal_dataset,
            fold_results=fold_results,
            model_config=model_config,
            device=device,
            batch_size=32,
            threshold=0.5,
            adata_rna=adata_RNA_labelled,
        )
        df_samples.to_parquet(df_samples_path, index=False)
    # Warn up front about missing pickles (the open() below will still raise).
    fi_shift = {}
    for key in ("rna", "atac", "flux"):
        src = ROOT / "objects" / f"fi_shift_{key}.pkl"
        if not src.is_file():
            print(f"Warning: missing {src}")
        with open(src, "rb") as f:
            fi_shift[key] = pickle.load(f)
    with open(ROOT / "objects" / "degs.pkl", "rb") as f:
        degs = pickle.load(f)
    df_rna_degs, df_atac_degs, df_flux_degs = degs[0], degs[1], degs[2]
    fi = summary["fi_att"]
    df_features = create_combined_feature_dataframe(
        fi_shift["rna"],
        fi_shift["atac"],
        fi_shift["flux"],
        fi["all"]["rna"],
        fi["all"]["atac"],
        fi["all"]["flux"],
        df_rna_degs,
        df_atac_degs,
        df_flux_degs,
    )
    df_features.to_parquet(CACHE / "df_features.parquet", index=False)
    # Bug fix: to_csv raises FileNotFoundError if analysis/ does not exist yet.
    analysis_dir = ROOT / "analysis"
    analysis_dir.mkdir(parents=True, exist_ok=True)
    df_features.to_csv(analysis_dir / "df_features.csv", index=False)
    print(f"Wrote {CACHE / 'df_features.parquet'} and analysis/df_features.csv")
    print("Done.")
if __name__ == "__main__":
    main()