#!/usr/bin/env python3 """ One-off cache builder for the Streamlit explorer. Run from the repository root: python scripts/precompute_streamlit_cache.py python scripts/precompute_streamlit_cache.py --skip-attention # faster: reuse objects/fi_shift_*.pkl only for df_features if attention_summary exists """ from __future__ import annotations import argparse import os import pickle import sys from pathlib import Path import numpy as np import pandas as pd import torch import umap ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) os.chdir(ROOT) from data import create_dataset # noqa: E402 from interpretation import attentions as att # noqa: E402 from interpretation import latentspace as ls # noqa: E402 from interpretation import predictions as prds # noqa: E402 CACHE = ROOT / "streamlit_hf" / "cache" CACHE.mkdir(parents=True, exist_ok=True) def replace_fold_results_path(fold_results, ckp_root: str = "ckp"): """Point checkpoints at flat `ckp/multi_seed0_fold{k}.pth` layout in this repo.""" for fold in fold_results: ckpt_name = os.path.basename(fold["best_model_path"]) fold_token = next((part for part in ckpt_name.split("_") if part.startswith("fold")), "") fold_idx = "".join(ch for ch in fold_token if ch.isdigit()) if fold_idx: clean_ckpt_name = f"multi_seed0_fold{fold_idx}.pth" else: clean_ckpt_name = ckpt_name fold["best_model_path"] = os.path.join(ckp_root, clean_ckpt_name) return fold_results def load_training_context(): with open(ROOT / "objects" / "mutlimodal_dataset.pkl", "rb") as f: md = pickle.load(f) X, y_label = md["X"], md["y_label"] b, df_indices, pcts = md["b"], md["df_indices"], md["pcts"] y_number = torch.tensor( [{"reprogramming": 1, "dead-end": 0}[i] for i in list(y_label)], dtype=torch.float32, ) multimodal_dataset = create_dataset.MultiModalDataset( X, b, y_number, df_indices, pcts, y_label ) with open(ROOT / "objects" / "fold_results_multi.pkl", "rb") as f: fold_results = pickle.load(f) fold_results = 
replace_fold_results_path(fold_results) share_config = { "d_model": 128, "d_ff": 16, "n_heads": 8, "n_encoder_layers": 2, "n_batches": 3, "dropout_rate": 0.0, } model_config_rna = {"vocab_size": 5914, "seq_len": X[0].shape[1]} model_config_atac = {"vocab_size": 1, "seq_len": X[1].shape[1]} model_config_flux = {"vocab_size": 1, "seq_len": X[2].shape[1]} model_config_multi = {"d_model": 128, "n_heads_cls": 8, "d_ff_cls": 16} model_config = { "Share": share_config, "RNA": model_config_rna, "ATAC": model_config_atac, "Flux": model_config_flux, "Multi": model_config_multi, } feature_names = ( list(X[0].columns) + ["batch_rna"] + list(X[1].columns) + ["batch_atac"] + list(X[2].columns) + ["batch_flux"] ) adata_RNA_labelled = None rna_pkl = ROOT / "data" / "datasets" / "rna_labelled.pkl" try: with open(rna_pkl, "rb") as f: adata_RNA_labelled = pickle.load(f) except Exception as e: print( f"Warning: could not load {rna_pkl} ({e}). " "Sample table will omit AnnData-derived metadata (e.g. clone_id)." 
) return ( multimodal_dataset, fold_results, model_config, feature_names, adata_RNA_labelled, ) def build_latent_umap(multimodal_dataset, fold_results, model_config, common_samples: bool = False): device = "cuda" if torch.cuda.is_available() else "cpu" ls_v, labels, preds = ls.get_latent_space( "Multi", fold_results, multimodal_dataset, model_config, device, common_samples=common_samples, ) reducer = umap.UMAP(n_components=2, random_state=0, n_neighbors=30, min_dist=1.0) xy = reducer.fit_transform(ls_v) ordered_indices: list[int] = [] fold_ids: list[int] = [] from interpretation.attentions import filter_idx # noqa: PLC0415 from torch.utils.data import Subset # noqa: PLC0415 for fold_idx, fold in enumerate(fold_results): val_idx = fold["val_idx"] if common_samples: val_idx = filter_idx(multimodal_dataset, val_idx) ordered_indices.extend(val_idx) fold_ids.extend([fold_idx + 1] * len(val_idx)) labels = np.asarray(labels).ravel() preds = np.asarray(preds).ravel().astype(int) label_name = np.where(labels > 0.5, "reprogramming", "dead-end") pred_name = np.where(preds > 0.5, "reprogramming", "dead-end") correct = (preds == labels.astype(int)).astype(np.int8) ds = multimodal_dataset batch_no = np.array([int(ds.batch_no[i].item()) for i in ordered_indices], dtype=np.int32) pcts = np.array([float(ds.pcts[i]) for i in ordered_indices], dtype=np.float64) modalities = [] for i in ordered_indices: has_r = (ds.rna_data[i] != 0).any().item() has_a = (ds.atac_data[i] != 0).any().item() has_f = (ds.flux_data[i] != 0).any().item() s = "".join(c for c, h in (("R", has_r), ("A", has_a), ("F", has_f)) if h) modalities.append(s or "None") return { "umap_x": xy[:, 0].astype(np.float32), "umap_y": xy[:, 1].astype(np.float32), "label_name": label_name, "pred_name": pred_name, "correct": correct, "fold": np.array(fold_ids, dtype=np.int32), "batch_no": batch_no, "pct": pcts, "modality": modalities, "dataset_idx": np.array(ordered_indices, dtype=np.int32), "common_samples": common_samples, } 
def create_combined_feature_dataframe(
    fi_shift_rna,
    fi_shift_atac,
    fi_shift_flux,
    fi_att_rna,
    fi_att_atac,
    fi_att_flux,
    df_rna_degs=None,
    df_atac_degs=None,
    df_flux_degs=None,
    remove_batch=True,
):
    """Merge shift- and attention-based feature importances into one table.

    Args:
        fi_shift_rna, fi_shift_atac, fi_shift_flux: Per-modality sequences of
            ``(feature, importance)`` pairs; the position in the sequence
            becomes the 1-based ``rank_shift_in_modal``.
        fi_att_rna, fi_att_atac, fi_att_flux: Same layout for the attention
            importances (``rank_att_in_modal``).
        df_rna_degs, df_atac_degs, df_flux_degs: Optional frames with a
            ``"feature"`` column, left-joined onto the matching modality.
        remove_batch: Drop rows whose feature name contains ``"batch"``.

    Returns:
        DataFrame sorted by ``mean_rank`` with a fixed column set; columns not
        supplied by the inputs are added as NaN.
    """

    def process_modality(shift_list, att_list, degs_df, modality_name):
        # reset_index() exposes the list position, which is the in-modality rank.
        shift_df = pd.DataFrame(shift_list, columns=["feature", "importance_shift"]).reset_index()
        shift_df = shift_df.rename(columns={"index": "rank_shift_in_modal"})
        shift_df["rank_shift_in_modal"] += 1

        att_df = pd.DataFrame(att_list, columns=["feature", "importance_att"]).reset_index()
        att_df = att_df.rename(columns={"index": "rank_att_in_modal"})
        att_df["rank_att_in_modal"] += 1

        # Outer join keeps features that appear in only one of the two lists.
        combined_df = pd.merge(shift_df, att_df, on="feature", how="outer")
        if degs_df is not None:
            combined_df = pd.merge(combined_df, degs_df, on="feature", how="left")
        combined_df["modality"] = modality_name
        return combined_df

    rna_df = process_modality(fi_shift_rna, fi_att_rna, df_rna_degs, "RNA")
    atac_df = process_modality(fi_shift_atac, fi_att_atac, df_atac_degs, "ATAC")
    flux_df = process_modality(fi_shift_flux, fi_att_flux, df_flux_degs, "Flux")
    all_features_df = pd.concat([rna_df, atac_df, flux_df], ignore_index=True)

    if remove_batch:
        is_batch = all_features_df["feature"].str.contains("batch", na=False)
        # Copy so the column assignments below work on an independent frame
        # instead of a filtered slice (avoids chained-assignment ambiguity).
        all_features_df = all_features_df[~is_batch].copy()

    # Features missing from one list get a rank one past the largest observed
    # in-modality rank. The maximum skips NaN, so a rank column that is
    # entirely missing no longer poisons the fill value.
    rank_cols = ["rank_att_in_modal", "rank_shift_in_modal"]
    modal_max = pd.Series(
        [all_features_df[col].max() for col in rank_cols], dtype="float64"
    ).max()
    fill_rank = 1 if pd.isna(modal_max) else int(modal_max) + 1
    for col in rank_cols:
        all_features_df[col] = all_features_df[col].fillna(fill_rank).astype("int32")

    # A missing importance counts as zero.
    for col in ("importance_att", "importance_shift"):
        all_features_df[col] = all_features_df[col].fillna(0).astype("float64")

    # Global ranks across all modalities; ties are broken by row order.
    all_features_df["rank_shift"] = (
        all_features_df["importance_shift"].rank(ascending=False, method="first").astype("int32")
    )
    all_features_df["rank_att"] = (
        all_features_df["importance_att"].rank(ascending=False, method="first").astype("int32")
    )
    all_features_df["mean_rank"] = all_features_df[["rank_att", "rank_shift"]].mean(axis=1)

    # Label membership in the top 10% (by global rank) of either importance.
    top_th = int(all_features_df.shape[0] * 0.1) + 1
    in_top_shift = (all_features_df["rank_shift"] <= top_th).to_numpy()
    in_top_att = (all_features_df["rank_att"] <= top_th).to_numpy()
    # First matching condition wins, so "both" must come first.
    all_features_df["top_10_pct"] = np.select(
        [in_top_shift & in_top_att, in_top_shift, in_top_att],
        ["both", "shift", "att"],
        default="None",
    )

    float_cols = [
        col
        for col in all_features_df.columns
        if col.startswith(("log_fc", "mean_", "std_", "pval_"))
    ]
    if float_cols:
        all_features_df[float_cols] = all_features_df[float_cols].round(6)
    all_features_df["importance_att"] = all_features_df["importance_att"].round(6)
    all_features_df["importance_shift"] = all_features_df["importance_shift"].round(6)

    all_features_df = all_features_df.sort_values(by="mean_rank", ascending=True)

    cols = [
        "mean_rank",
        "feature",
        "rank_shift",
        "rank_att",
        "rank_shift_in_modal",
        "rank_att_in_modal",
        "modality",
        "importance_shift",
        "importance_att",
        "top_10_pct",
        "mean_de",
        "mean_re",
        "std_de",
        "std_re",
        "pval",
        "pval_adj",
        "log_fc",
        "group",
        "pval_adj_log",
        "mean_diff",
        "pathway",
        "module",
    ]
    # Guarantee a stable schema even when no DEG tables were supplied.
    for c in cols:
        if c not in all_features_df.columns:
            all_features_df[c] = np.nan
    return all_features_df[cols]


def run_attention_and_fi(
    multimodal_dataset,
    fold_results,
    model_config,
    feature_names,
    device: str,
    adata_rna,
):
    """Compute attention rollout and attention-based feature rankings.

    The rollout is computed three times: for all samples in the prediction
    table, for those predicted "dead-end", and for those predicted
    "reprogramming".

    Args:
        multimodal_dataset: Dataset forwarded to the project helpers.
        fold_results: Per-fold results forwarded to the project helpers.
        model_config: Model configuration forwarded to the project helpers.
        feature_names: Names aligned with the last axis of the rollout.
        device: Torch device string.
        adata_rna: Optional object forwarded to the sample-prediction helper.

    Returns:
        Tuple ``(summary, df_samples)``. ``summary`` holds ``feature_names``,
        the modality ``slices``, the per-subset ``rollout_mean`` vectors and
        the per-subset, per-modality ``fi_att`` rankings.
    """
    df_samples = prds.get_sample_predictions_dataframe(
        model_type="Multi",
        multimodal_dataset=multimodal_dataset,
        fold_results=fold_results,
        model_config=model_config,
        device=device,
        batch_size=32,
        threshold=0.5,
        adata_rna=adata_rna,
    )

    predicted = df_samples["predicted_class"]
    # (result key, wording used in the progress message, sample indices)
    subsets = (
        ("all", "all validation", df_samples["ind"].tolist()),
        ("dead_end", "predicted dead-end", df_samples[predicted == "dead-end"]["ind"].tolist()),
        (
            "reprogramming",
            "predicted reprogramming",
            df_samples[predicted == "reprogramming"]["ind"].tolist(),
        ),
    )

    flow_layers = {}
    for key, description, indices in subsets:
        print(f"Running flow attention ({description})…")
        flow_layers[key] = att.analyze_cls_attention(
            "Multi",
            fold_results,
            multimodal_dataset,
            model_config,
            device=device,
            indices=indices,
            average_heads=False,
            return_flow_attention=True,
        )

    rollouts = {}
    for key, layers in flow_layers.items():
        rollout = att.multimodal_attention_rollout(layers)
        # Normalise along the last axis so each row sums to one.
        rollouts[key] = rollout / rollout.sum(dim=-1, keepdim=True)

    # Explicit splits (notebook): RNA [:945], ATAC [945:945+884], flux rest
    # NOTE(review): hard-coded; presumably tied to the dataset used in the
    # notebook - confirm before running on data with other feature counts.
    i0, i1, i2 = 0, 945, 945 + 884

    rollout_mean = {
        key: tensor.mean(dim=0).detach().cpu().numpy() for key, tensor in rollouts.items()
    }

    # top_n=None is passed through as in the notebook (no explicit cut-off).
    fi = {}
    for key, tensor in rollouts.items():
        fi[key] = {
            "rna": att.get_top_features(
                tensor[:, i0:i1], feature_names[i0:i1], modality="RNA", top_n=None
            ),
            "atac": att.get_top_features(
                tensor[:, i1:i2], feature_names[i1:i2], modality="ATAC", top_n=None
            ),
            "flux": att.get_top_features(
                tensor[:, i2:], feature_names[i2:], modality="Flux", top_n=None
            ),
        }

    summary = {
        "feature_names": feature_names,
        "slices": {
            "RNA": {"start": i0, "stop": i1},
            "ATAC": {"start": i1, "stop": i2},
            "Flux": {"start": i2, "stop": len(feature_names)},
        },
        "rollout_mean": rollout_mean,
        "fi_att": fi,
    }
    return summary, df_samples


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def main():
    """Build every cache artefact consumed by the Streamlit explorer.

    Writes ``latent_umap.pkl``, ``attention_summary.pkl``,
    ``attention_feature_ranks.pkl``, ``samples.parquet`` and
    ``df_features.parquet`` under ``CACHE`` and ``analysis/df_features.csv``
    under ``ROOT``.

    Raises:
        FileNotFoundError: If one of the required pickles under ``objects/``
            is missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--skip-attention", action="store_true", help="Skip attention if summary exists")
    ap.add_argument(
        "--common-samples",
        action="store_true",
        help="Use common-samples filter for latent UMAP (default: False, notebook-style)",
    )
    args = ap.parse_args()
    common_samples = args.common_samples

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    (
        multimodal_dataset,
        fold_results,
        model_config,
        feature_names,
        adata_RNA_labelled,
    ) = load_training_context()

    print("Building latent UMAP bundle…")
    latent = build_latent_umap(
        multimodal_dataset, fold_results, model_config, common_samples=common_samples
    )
    with open(CACHE / "latent_umap.pkl", "wb") as f:
        pickle.dump(latent, f)

    att_path = CACHE / "attention_summary.pkl"
    df_samples_path = CACHE / "samples.parquet"

    if args.skip_attention and att_path.is_file():
        print("Skipping attention (--skip-attention, file exists).")
        summary = _load_pickle(att_path)
        # The sample table is cheap compared with attention; rebuild it alone
        # when only that file is missing.
        if not df_samples_path.is_file():
            df_samples = prds.get_sample_predictions_dataframe(
                model_type="Multi",
                multimodal_dataset=multimodal_dataset,
                fold_results=fold_results,
                model_config=model_config,
                device=device,
                batch_size=32,
                threshold=0.5,
                adata_rna=adata_RNA_labelled,
            )
            df_samples.to_parquet(df_samples_path, index=False)
    else:
        print("Computing attention + rollout (slow)…")
        summary, df_samples = run_attention_and_fi(
            multimodal_dataset,
            fold_results,
            model_config,
            feature_names,
            device,
            adata_RNA_labelled,
        )
        with open(att_path, "wb") as f:
            pickle.dump(summary, f)
        with open(CACHE / "attention_feature_ranks.pkl", "wb") as f:
            pickle.dump(summary["fi_att"], f)
        df_samples.to_parquet(df_samples_path, index=False)

    objects_dir = ROOT / "objects"
    for name in ["fi_shift_rna.pkl", "fi_shift_atac.pkl", "fi_shift_flux.pkl"]:
        src = objects_dir / name
        if not src.is_file():
            print(f"Warning: missing {src}")

    # These files are required: a missing one raises FileNotFoundError here.
    fi_shift_rna = _load_pickle(objects_dir / "fi_shift_rna.pkl")
    fi_shift_atac = _load_pickle(objects_dir / "fi_shift_atac.pkl")
    fi_shift_flux = _load_pickle(objects_dir / "fi_shift_flux.pkl")
    degs = _load_pickle(objects_dir / "degs.pkl")
    df_rna_degs, df_atac_degs, df_flux_degs = degs[0], degs[1], degs[2]

    fi = summary["fi_att"]
    df_features = create_combined_feature_dataframe(
        fi_shift_rna,
        fi_shift_atac,
        fi_shift_flux,
        fi["all"]["rna"],
        fi["all"]["atac"],
        fi["all"]["flux"],
        df_rna_degs,
        df_atac_degs,
        df_flux_degs,
    )
    df_features.to_parquet(CACHE / "df_features.parquet", index=False)

    # Only CACHE is created at import time; make sure the CSV target exists
    # too so the run does not fail at its very last step.
    analysis_dir = ROOT / "analysis"
    analysis_dir.mkdir(parents=True, exist_ok=True)
    df_features.to_csv(analysis_dir / "df_features.csv", index=False)
    print(f"Wrote {CACHE / 'df_features.parquet'} and analysis/df_features.csv")
    print("Done.")


if __name__ == "__main__":
    main()