{ "cells": [ { "cell_type": "markdown", "id": "a2c7b1c0", "metadata": {}, "source": [ "# Foundation-Encoder Conditional Flow Matching\n", "\n", "This notebook implements a **foundation-aware perturbation prediction model** for the project dataset.\n", "\n", "Design:\n", "\n", "1. Load a **frozen foundation embedding** for each cell, such as Geneformer or scFoundation.\n", "2. Optionally add a lightweight **HVG branch** from the project expression matrix.\n", "3. Project these features into a trainable latent space.\n", "4. Learn a **perturbation-conditioned flow matching model** in that latent space.\n", "5. Decode the transported latent state back to expression space.\n", "\n", "This follows the thesis logic: the foundation model encodes **what the current cell state is**, while the flow module learns **how a perturbation moves that state**.\n" ] }, { "cell_type": "markdown", "id": "66dc2c85", "metadata": {}, "source": [ "## Supported Foundation Embedding Inputs\n", "\n", "The notebook is designed to work with project-style data and to **load foundation embeddings that were exported beforehand**. 
This is usually the most stable setup for Geneformer or scFoundation.\n", "\n", "Supported modes:\n", "\n", "- `same_adata_obsm`: embeddings already stored in the same `adata.obsm[...]`\n", "- `paired_h5ad_obsm`: embeddings stored in another `.h5ad`, aligned by `obs_names`\n", "- `npy_with_obs_csv`: embeddings in `.npy` plus a parallel `.csv` of cell ids\n", "- `pt_with_obs_csv`: embeddings in `.pt` plus a parallel `.csv` of cell ids\n", "- `hvg_only`: no foundation encoder, only the raw HVG branch\n", "\n", "Recommended use:\n", "\n", "- export **Geneformer**, **scFoundation**, or **scGPT** cell embeddings with each model’s official code (same `obs_names` as this `adata`)\n", "- store them in a file or paired `.h5ad`\n", "- point the config below to that file\n" ] }, { "cell_type": "markdown", "id": "b11134dd", "metadata": {}, "source": [ "## Foundation Model Notes\n", "\n", "This notebook **does not reimplement Geneformer or scFoundation tokenization/inference internals**. Instead, it assumes you load embeddings produced by their official code or model APIs, which is usually the most reproducible workflow for downstream perturbation experiments.\n", "\n", "Relevant primary references:\n", "\n", "- Geneformer documentation and model card describe extraction of cell embeddings from pretrained models.\n", "- scFoundation README states that pretrained weights and embedding workflows are provided through their model code and platform.\n", "- **scGPT**: run the official embedding / fine-tuning pipeline on your expression matrix, then write vectors to `adata.obsm['X_scGPT']` (or a paired `.h5ad` / `.npy` + cell-id CSV). 
Set `FOUNDATION_MODEL_NAME = \"scgpt\"` (metadata only), `FOUNDATION_EMBED_SOURCE`, and `FOUNDATION_OBSM_KEY` accordingly.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "60ebcf23", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Project root: /root/final\n", "Input h5ad: /root/final/data/processed/scPerturb/prediction_ready/replogle_k562_essential_single_gene_prediction_ready_full.h5ad\n", "Output dir: /root/final/baseline/outputs/foundation_encoder_flow_matching\n", "Device: cuda\n" ] } ], "source": [ "from __future__ import annotations\n", "\n", "from pathlib import Path\n", "from copy import deepcopy\n", "import json\n", "import math\n", "import random\n", "\n", "import anndata as ad\n", "import numpy as np\n", "import pandas as pd\n", "import scipy.sparse as sp\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "def find_project_root(start: Path | None = None) -> Path:\n", " start = (start or Path.cwd()).resolve()\n", " for path in [start, *start.parents]:\n", " if (path / \"src\" / \"data\").exists() and (path / \"baseline\").exists():\n", " return path\n", " return start\n", "\n", "\n", "PROJECT_ROOT = find_project_root()\n", "DATASET_NAME = \"ReplogleWeissman2022_K562_essential\"\n", "METHOD_NAME = \"foundation_encoder_flow_matching\"\n", "RANDOM_SEED = 42\n", "\n", "SPLIT_STRATEGY = \"seen_cell_split\" # choices: \"seen_cell_split\", \"use_existing\"\n", "TRAIN_FRAC = 0.80\n", "VAL_FRAC = 0.10\n", "EVAL_SPLIT = \"test\"\n", "CONTROL_SPLIT = \"train\"\n", "TRAIN_SPLIT = \"train\"\n", "VAL_SPLIT = \"val\"\n", "\n", "USE_FEATURE_GENES = True\n", "FEATURE_GENE_COLUMNS = [\"is_feature_gene\", \"highly_variable\"]\n", "\n", "# Foundation embedding configuration.\n", "FOUNDATION_MODEL_NAME = \"geneformer\" # choices: \"geneformer\", \"scfoundation\", \"custom\", \"none\"\n", "FOUNDATION_EMBED_SOURCE = \"hvg_only\"\n", "# choices:\n", "# - \"same_adata_obsm\"\n", "# - 
\"paired_h5ad_obsm\"\n", "# - \"npy_with_obs_csv\"\n", "# - \"pt_with_obs_csv\"\n", "# - \"hvg_only\"\n", "\n", "FOUNDATION_OBSM_KEY = \"X_geneformer\" # e.g. \"X_scGPT\" if you store scGPT embeddings in obsm\n", "FOUNDATION_EMBED_H5AD_PATH = None\n", "FOUNDATION_EMBED_FILE = None\n", "FOUNDATION_OBS_CSV = None\n", "\n", "# Whether to concatenate a trainable HVG branch with the frozen foundation branch.\n", "USE_DUAL_CHANNEL = True\n", "FOUNDATION_FREEZE = True\n", "\n", "FOUNDATION_PROJ_DIM = 128\n", "HVG_PROJ_DIM = 64\n", "LATENT_DIM = 96\n", "HIDDEN_DIM = 256\n", "CONDITION_DIM = 128\n", "TIME_DIM = 64\n", "DROPOUT = 0.10\n", "NUM_FLOW_BLOCKS = 4\n", "\n", "BATCH_SIZE = 128\n", "EPOCHS = 20\n", "STEPS_PER_EPOCH = 200\n", "VALIDATION_STEPS = 32\n", "LEARNING_RATE = 1e-3\n", "WEIGHT_DECAY = 1e-5\n", "GRAD_CLIP_NORM = 1.0\n", "\n", "RECON_WEIGHT = 1.0\n", "FLOW_WEIGHT = 1.0\n", "ENDPOINT_WEIGHT = 0.5\n", "MMD_WEIGHT = 0.1\n", "MEAN_WEIGHT = 1.0\n", "\n", "NOISE_STD = 0.05\n", "SINKHORN_EPSILON = 0.5\n", "SINKHORN_ITERS = 30\n", "INFERENCE_STEPS = 24\n", "PRED_BATCH_SIZE = 256\n", "\n", "MAX_TRAIN_CONDITIONS = None\n", "MAX_EVAL_GENES = None\n", "MAX_EVAL_CELLS_PER_GENE = None\n", "\n", "INPUT_CANDIDATES = [\n", " PROJECT_ROOT / \"data\" / \"processed\" / \"scPerturb\" / \"prediction_ready\" / \"replogle_k562_essential_single_gene_prediction_ready_full.h5ad\",\n", " PROJECT_ROOT / \"data\" / \"processed\" / \"scPerturb\" / \"prediction_ready\" / \"replogle_k562_essential_single_gene_prediction_ready.h5ad\",\n", " PROJECT_ROOT / \"data\" / \"processed\" / \"scPerturb\" / \"prediction_ready\" / \"replogle_k562_essential_single_gene_prediction_ready_fast_debug.h5ad\",\n", "]\n", "INPUT_PATH = next((path for path in INPUT_CANDIDATES if path.exists()), None)\n", "if INPUT_PATH is None:\n", " raise FileNotFoundError(\"Could not find a prediction-ready Replogle h5ad. 
Checked:\\n\" + \"\\n\".join(map(str, INPUT_CANDIDATES)))\n", "\n", "OUTPUT_DIR = PROJECT_ROOT / \"baseline\" / \"outputs\" / METHOD_NAME\n", "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n", "MODEL_PATH = OUTPUT_DIR / \"foundation_encoder_flow_matching_model.pt\"\n", "\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "np.random.seed(RANDOM_SEED)\n", "random.seed(RANDOM_SEED)\n", "torch.manual_seed(RANDOM_SEED)\n", "if torch.cuda.is_available():\n", " torch.cuda.manual_seed_all(RANDOM_SEED)\n", "if torch.backends.cudnn.is_available():\n", " torch.backends.cudnn.deterministic = True\n", " torch.backends.cudnn.benchmark = False\n", "\n", "print(f\"Project root: {PROJECT_ROOT}\")\n", "print(f\"Input h5ad: {INPUT_PATH}\")\n", "print(f\"Output dir: {OUTPUT_DIR}\")\n", "print(f\"Device: {DEVICE}\")\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "0c7e6c73", "metadata": {}, "outputs": [], "source": [ "def normalize_bool(series: pd.Series) -> pd.Series:\n", " if pd.api.types.is_bool_dtype(series):\n", " return series.fillna(False).astype(bool)\n", " text = series.astype(str).str.lower().str.strip()\n", " return text.isin([\"true\", \"1\", \"yes\", \"y\"])\n", "\n", "\n", "def require_columns(obs: pd.DataFrame, columns: list[str]) -> None:\n", " missing = [col for col in columns if col not in obs.columns]\n", " if missing:\n", " raise KeyError(f\"Missing required obs columns: {missing}\")\n", "\n", "\n", "def take_dense_rows(X, idx: np.ndarray) -> np.ndarray:\n", " out = X[idx]\n", " if sp.issparse(out):\n", " out = out.toarray()\n", " return np.asarray(out, dtype=np.float32)\n", "\n", "\n", "def mean_vector(X) -> np.ndarray:\n", " if sp.issparse(X):\n", " return np.asarray(X.mean(axis=0)).ravel().astype(np.float32)\n", " return np.asarray(X, dtype=np.float32).mean(axis=0)\n", "\n", "\n", "def safe_pearson(x: np.ndarray, y: np.ndarray) -> float:\n", " x = np.asarray(x).ravel()\n", " y = np.asarray(y).ravel()\n", " if 
x.size < 2 or np.std(x) == 0 or np.std(y) == 0:\n", " return np.nan\n", " return float(np.corrcoef(x, y)[0, 1])\n", "\n", "\n", "def split_array_indices(indices: np.ndarray, rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n", " indices = np.asarray(indices)\n", " if len(indices) == 0:\n", " return indices, indices, indices\n", " shuffled = indices.copy()\n", " rng.shuffle(shuffled)\n", " if len(shuffled) < 3:\n", " return np.sort(shuffled), np.array([], dtype=indices.dtype), np.array([], dtype=indices.dtype)\n", " n_train = max(1, int(round(len(shuffled) * TRAIN_FRAC)))\n", " n_val = int(round(len(shuffled) * VAL_FRAC))\n", " n_train = min(n_train, len(shuffled) - 2)\n", " n_val = max(1, min(n_val, len(shuffled) - n_train - 1))\n", " train = np.sort(shuffled[:n_train])\n", " val = np.sort(shuffled[n_train : n_train + n_val])\n", " test = np.sort(shuffled[n_train + n_val :])\n", " if len(test) == 0:\n", " test = val[-1:].copy()\n", " val = val[:-1]\n", " if len(val) == 0:\n", " val = train[-1:].copy()\n", " train = train[:-1]\n", " return train, val, test\n", "\n", "\n", "def assign_seen_cell_split(obs: pd.DataFrame, seed: int) -> pd.Categorical:\n", " rng = np.random.default_rng(seed)\n", " split = np.full(len(obs), \"train\", dtype=object)\n", " is_control = normalize_bool(obs[\"is_control\"]).to_numpy(dtype=bool)\n", " genes = obs[\"perturbation_gene\"].astype(str).to_numpy()\n", "\n", " control_idx = np.where(is_control)[0]\n", " train_idx, val_idx, test_idx = split_array_indices(control_idx, rng)\n", " split[train_idx] = \"train\"\n", " split[val_idx] = \"val\"\n", " split[test_idx] = \"test\"\n", "\n", " for gene in sorted(set(genes[~is_control])):\n", " idx = np.where((~is_control) & (genes == gene))[0]\n", " train_idx, val_idx, test_idx = split_array_indices(idx, rng)\n", " split[train_idx] = \"train\"\n", " split[val_idx] = \"val\"\n", " split[test_idx] = \"test\"\n", "\n", " return pd.Categorical(split, categories=[\"train\", 
\"val\", \"test\"], ordered=True)\n", "\n", "\n", "def choose_feature_mask(var: pd.DataFrame) -> np.ndarray:\n", " if not USE_FEATURE_GENES:\n", " return np.ones(var.shape[0], dtype=bool)\n", " for column in FEATURE_GENE_COLUMNS:\n", " if column in var.columns:\n", " mask = var[column].astype(bool).to_numpy()\n", " if mask.any():\n", " print(f\"Using feature mask from var['{column}']: {int(mask.sum())} genes\")\n", " return mask\n", " print(\"Feature-gene columns not found or empty; using all genes.\")\n", " return np.ones(var.shape[0], dtype=bool)\n", "\n", "\n", "def median_heuristic_sigma(x: torch.Tensor, y: torch.Tensor, max_points: int = 256) -> float:\n", " z = torch.cat([x[:max_points], y[:max_points]], dim=0)\n", " if z.shape[0] < 2:\n", " return 1.0\n", " d2 = torch.cdist(z, z, p=2).pow(2)\n", " mask = ~torch.eye(d2.shape[0], dtype=torch.bool, device=d2.device)\n", " valid = d2[mask]\n", " valid = valid[valid > 0]\n", " if valid.numel() == 0:\n", " return 1.0\n", " return float(torch.sqrt(valid.median()).item())\n", "\n", "\n", "def rbf_mmd_loss(x: torch.Tensor, y: torch.Tensor, sigma: float | None = None) -> torch.Tensor:\n", " if x.shape[0] < 2 or y.shape[0] < 2:\n", " return x.new_tensor(0.0)\n", " sigma = sigma or median_heuristic_sigma(x, y)\n", " gamma = 1.0 / (2.0 * sigma * sigma + 1e-8)\n", " d_xx = torch.cdist(x, x, p=2).pow(2)\n", " d_yy = torch.cdist(y, y, p=2).pow(2)\n", " d_xy = torch.cdist(x, y, p=2).pow(2)\n", " k_xx = torch.exp(-gamma * d_xx)\n", " k_yy = torch.exp(-gamma * d_yy)\n", " k_xy = torch.exp(-gamma * d_xy)\n", " return k_xx.mean() + k_yy.mean() - 2.0 * k_xy.mean()\n", "\n", "\n", "def sinkhorn_barycentric_targets(x0: torch.Tensor, x1: torch.Tensor, epsilon: float, n_iters: int) -> torch.Tensor:\n", " cost = torch.cdist(x0, x1, p=2).pow(2)\n", " kernel = torch.exp(-cost / max(epsilon, 1e-4)).clamp_min(1e-8)\n", "\n", " a = torch.full((x0.shape[0],), 1.0 / x0.shape[0], device=x0.device, dtype=x0.dtype)\n", " b = 
torch.full((x1.shape[0],), 1.0 / x1.shape[0], device=x1.device, dtype=x1.dtype)\n", " u = torch.ones_like(a)\n", " v = torch.ones_like(b)\n", "\n", " for _ in range(n_iters):\n", " u = a / (kernel @ v + 1e-8)\n", " v = b / (kernel.t() @ u + 1e-8)\n", "\n", " plan = u[:, None] * kernel * v[None, :]\n", " row_mass = plan.sum(dim=1, keepdim=True).clamp_min(1e-8)\n", " return (plan / row_mass) @ x1\n", "\n", "\n", "def build_aggregate_frame(metrics: pd.DataFrame) -> pd.DataFrame:\n", " numeric = metrics.select_dtypes(include=[np.number])\n", " if numeric.empty:\n", " return pd.DataFrame()\n", " return numeric.agg([\"mean\", \"median\"]).T\n", "\n", "\n", "@torch.no_grad()\n", "def rk2_integrate_latent(model, batch: dict, steps: int) -> torch.Tensor:\n", " z = model.encode_state(batch[\"source_x\"], batch[\"source_cell_idx\"])\n", " dt = 1.0 / steps\n", " for step in range(steps):\n", " t0 = torch.full((z.shape[0],), step / steps, device=z.device, dtype=z.dtype)\n", " k1 = model.velocity(z, t0, batch[\"gene_idx\"], batch[\"batch_idx\"])\n", " z_mid = z + 0.5 * dt * k1\n", " t_mid = torch.full((z.shape[0],), (step + 0.5) / steps, device=z.device, dtype=z.dtype)\n", " k2 = model.velocity(z_mid, t_mid, batch[\"gene_idx\"], batch[\"batch_idx\"])\n", " z = z + dt * k2\n", " return z\n", "\n", "\n", "def load_embedding_matrix_from_file(path: Path) -> np.ndarray:\n", " suffix = path.suffix.lower()\n", " if suffix == \".npy\":\n", " arr = np.load(path)\n", " elif suffix == \".pt\":\n", " arr = torch.load(path, map_location=\"cpu\")\n", " if isinstance(arr, torch.Tensor):\n", " arr = arr.cpu().numpy()\n", " else:\n", " raise TypeError(\"Expected a tensor in the .pt embedding file.\")\n", " elif suffix in {\".csv\", \".tsv\"}:\n", " sep = \",\" if suffix == \".csv\" else \"\\t\"\n", " arr = pd.read_csv(path, sep=sep).to_numpy()\n", " else:\n", " raise ValueError(f\"Unsupported embedding file type: {path}\")\n", " return np.asarray(arr, dtype=np.float32)\n", "\n", "\n", "def 
align_embedding_matrix(obs_names: pd.Index, embed_obs_names: pd.Index, embed_matrix: np.ndarray) -> np.ndarray:\n", " embed_df = pd.DataFrame(embed_matrix, index=pd.Index(embed_obs_names.astype(str), name=\"cell_id\"))\n", " aligned = embed_df.reindex(pd.Index(obs_names.astype(str), name=\"cell_id\"))\n", " missing = aligned.isna().all(axis=1)\n", " if missing.any():\n", " missing_ids = aligned.index[missing].tolist()[:10]\n", " raise KeyError(f\"Some cells are missing foundation embeddings. First missing ids: {missing_ids}\")\n", " return aligned.to_numpy(dtype=np.float32)\n", "\n", "\n", "def load_foundation_embeddings(adata: ad.AnnData) -> tuple[np.ndarray | None, dict]:\n", " info = {\n", " \"foundation_model_name\": FOUNDATION_MODEL_NAME,\n", " \"foundation_embed_source\": FOUNDATION_EMBED_SOURCE,\n", " \"foundation_obsm_key\": FOUNDATION_OBSM_KEY,\n", " }\n", "\n", " if FOUNDATION_EMBED_SOURCE == \"hvg_only\" or FOUNDATION_MODEL_NAME == \"none\":\n", " print(\"Foundation embeddings disabled; using HVG branch only.\")\n", " return None, info\n", "\n", " if FOUNDATION_EMBED_SOURCE == \"same_adata_obsm\":\n", " if FOUNDATION_OBSM_KEY not in adata.obsm:\n", " raise KeyError(f\"FOUNDATION_OBSM_KEY={FOUNDATION_OBSM_KEY!r} not found in adata.obsm.\")\n", " matrix = np.asarray(adata.obsm[FOUNDATION_OBSM_KEY], dtype=np.float32)\n", " info[\"foundation_embed_dim\"] = int(matrix.shape[1])\n", " return matrix, info\n", "\n", " if FOUNDATION_EMBED_SOURCE == \"paired_h5ad_obsm\":\n", " if FOUNDATION_EMBED_H5AD_PATH is None:\n", " raise ValueError(\"FOUNDATION_EMBED_H5AD_PATH must be set for paired_h5ad_obsm mode.\")\n", " emb_adata = ad.read_h5ad(Path(FOUNDATION_EMBED_H5AD_PATH))\n", " if FOUNDATION_OBSM_KEY not in emb_adata.obsm:\n", " raise KeyError(f\"FOUNDATION_OBSM_KEY={FOUNDATION_OBSM_KEY!r} not found in paired embedding h5ad.\")\n", " matrix = align_embedding_matrix(\n", " obs_names=adata.obs_names,\n", " embed_obs_names=emb_adata.obs_names,\n", " 
embed_matrix=np.asarray(emb_adata.obsm[FOUNDATION_OBSM_KEY], dtype=np.float32),\n", " )\n", " info[\"foundation_embed_dim\"] = int(matrix.shape[1])\n", " return matrix, info\n", "\n", " if FOUNDATION_EMBED_SOURCE in {\"npy_with_obs_csv\", \"pt_with_obs_csv\"}:\n", " if FOUNDATION_EMBED_FILE is None or FOUNDATION_OBS_CSV is None:\n", " raise ValueError(\"FOUNDATION_EMBED_FILE and FOUNDATION_OBS_CSV must both be set.\")\n", " embed_matrix = load_embedding_matrix_from_file(Path(FOUNDATION_EMBED_FILE))\n", " obs_csv = pd.read_csv(Path(FOUNDATION_OBS_CSV))\n", " if \"cell_id\" in obs_csv.columns:\n", " embed_obs_names = obs_csv[\"cell_id\"].astype(str)\n", " elif \"obs_name\" in obs_csv.columns:\n", " embed_obs_names = obs_csv[\"obs_name\"].astype(str)\n", " else:\n", " embed_obs_names = obs_csv.iloc[:, 0].astype(str)\n", " matrix = align_embedding_matrix(\n", " obs_names=adata.obs_names,\n", " embed_obs_names=pd.Index(embed_obs_names),\n", " embed_matrix=embed_matrix,\n", " )\n", " info[\"foundation_embed_dim\"] = int(matrix.shape[1])\n", " return matrix, info\n", "\n", " raise ValueError(f\"Unknown FOUNDATION_EMBED_SOURCE: {FOUNDATION_EMBED_SOURCE}\")\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "a9eed842", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using feature mask from var['is_feature_gene']: 4334 genes\n", "Foundation embeddings disabled; using HVG branch only.\n", "adata: AnnData object with n_obs × n_vars = 310385 × 4334\n", " obs: 'batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id', 'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex', 'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type', 'ncounts', 'ngenes', 'nperts', 'percent_ribo', 'perturbation_label', 'perturbation_gene', 'is_control', 'is_single_perturbation', 'split', 'condition', 'batch_model', 'source_split'\n", " var: 'chr', 
'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano', 'ensembl_id', 'ncounts', 'ncells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'is_target_gene', 'is_feature_gene'\n", " uns: 'hvg', 'log1p', 'prediction_ready'\n", " layers: 'counts'\n", "split counts:\n", "split\n", "train 248311\n", "test 31046\n", "val 31028\n", "Name: count, dtype: int64\n", "train conditions: 2056\n", "val conditions: 2056\n", "eval conditions: 2056\n", "eval genes missing from training: 0\n", "foundation info: {'foundation_model_name': 'geneformer', 'foundation_embed_source': 'hvg_only', 'foundation_obsm_key': 'X_geneformer'}\n" ] } ], "source": [ "adata = ad.read_h5ad(INPUT_PATH)\n", "require_columns(adata.obs, [\"perturbation_gene\", \"is_control\", \"split\"])\n", "\n", "adata.obs[\"is_control\"] = normalize_bool(adata.obs[\"is_control\"])\n", "adata.obs[\"perturbation_gene\"] = adata.obs[\"perturbation_gene\"].astype(str).str.strip()\n", "adata.obs[\"perturbation_gene\"] = adata.obs[\"perturbation_gene\"].mask(adata.obs[\"is_control\"], \"__ctrl__\")\n", "adata.obs[\"condition\"] = adata.obs[\"perturbation_gene\"].astype(str)\n", "adata.obs[\"batch_model\"] = adata.obs[\"batch\"].astype(str) if \"batch\" in adata.obs.columns else \"global\"\n", "\n", "feature_mask = choose_feature_mask(adata.var)\n", "adata = adata[:, feature_mask].copy()\n", "\n", "if SPLIT_STRATEGY == \"seen_cell_split\":\n", " adata.obs[\"source_split\"] = adata.obs[\"split\"].astype(str)\n", " adata.obs[\"split\"] = assign_seen_cell_split(adata.obs, RANDOM_SEED)\n", "elif SPLIT_STRATEGY == \"use_existing\":\n", " adata.obs[\"split\"] = pd.Categorical(\n", " adata.obs[\"split\"].astype(str),\n", " categories=[\"train\", \"val\", \"test\"],\n", " ordered=True,\n", " )\n", "else:\n", " raise ValueError(f\"Unknown SPLIT_STRATEGY: {SPLIT_STRATEGY}\")\n", "\n", "foundation_matrix_np, foundation_info = load_foundation_embeddings(adata)\n", "\n", 
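"# Hedged example: this run uses FOUNDATION_EMBED_SOURCE = \"hvg_only\", so foundation_matrix_np is None.\n", "# To use exported Geneformer embeddings instead, the loader above could be pointed at a\n", "# .npy matrix plus a cell-id CSV (hypothetical paths; the CSV ids must match this adata's obs_names):\n", "#\n", "#   FOUNDATION_EMBED_SOURCE = \"npy_with_obs_csv\"\n", "#   FOUNDATION_EMBED_FILE = PROJECT_ROOT / \"embeddings\" / \"geneformer_cells.npy\"\n", "#   FOUNDATION_OBS_CSV = PROJECT_ROOT / \"embeddings\" / \"geneformer_cell_ids.csv\"\n", "\n", 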
"is_control = adata.obs[\"is_control\"].to_numpy(dtype=bool)\n", "split = adata.obs[\"split\"].astype(str).to_numpy()\n", "genes = adata.obs[\"perturbation_gene\"].astype(str).to_numpy()\n", "batches = adata.obs[\"batch_model\"].astype(str).to_numpy()\n", "cell_indices = np.arange(adata.n_obs, dtype=np.int64)\n", "\n", "control_train_idx = np.where(is_control & (split == CONTROL_SPLIT))[0]\n", "if len(control_train_idx) == 0:\n", " print(f\"No controls found in CONTROL_SPLIT={CONTROL_SPLIT}; falling back to all controls.\")\n", " control_train_idx = np.where(is_control)[0]\n", "if len(control_train_idx) == 0:\n", " raise ValueError(\"No control cells found.\")\n", "\n", "train_conditions = sorted(set(genes[(~is_control) & (split == TRAIN_SPLIT)]))\n", "val_conditions = sorted(set(genes[(~is_control) & (split == VAL_SPLIT)]))\n", "eval_conditions = sorted(set(genes[(~is_control) & (split == EVAL_SPLIT)]))\n", "\n", "if MAX_TRAIN_CONDITIONS is not None:\n", " train_conditions = train_conditions[:MAX_TRAIN_CONDITIONS]\n", "if MAX_EVAL_GENES is not None:\n", " eval_conditions = eval_conditions[:MAX_EVAL_GENES]\n", "\n", "indices_by_split_gene = {split_name: {} for split_name in [TRAIN_SPLIT, VAL_SPLIT, EVAL_SPLIT]}\n", "for split_name in [TRAIN_SPLIT, VAL_SPLIT, EVAL_SPLIT]:\n", " split_mask = (~is_control) & (split == split_name)\n", " split_gene_list = train_conditions if split_name == TRAIN_SPLIT else val_conditions if split_name == VAL_SPLIT else eval_conditions\n", " for gene in split_gene_list:\n", " idx = np.where(split_mask & (genes == gene))[0]\n", " if len(idx) > 0:\n", " indices_by_split_gene[split_name][gene] = idx\n", "\n", "control_idx_by_batch = {}\n", "for batch_name in sorted(set(batches[control_train_idx])):\n", " idx = control_train_idx[batches[control_train_idx] == batch_name]\n", " if len(idx) > 0:\n", " control_idx_by_batch[batch_name] = idx\n", "\n", "gene_vocab = [\"__ctrl__\", \"__unk__\", *train_conditions]\n", "gene_to_idx = {gene: idx for 
idx, gene in enumerate(gene_vocab)}\n", "batch_vocab = [\"__unk__\", *sorted(set(batches))]\n", "batch_to_idx = {batch: idx for idx, batch in enumerate(batch_vocab)}\n", "batch_index_per_obs = np.array([batch_to_idx.get(batch, batch_to_idx[\"__unk__\"]) for batch in batches], dtype=np.int64)\n", "\n", "control_mean = mean_vector(adata.X[control_train_idx])\n", "unseen_eval_genes = sorted(set(eval_conditions) - set(train_conditions))\n", "\n", "print(\"adata:\", adata)\n", "print(\"split counts:\")\n", "print(adata.obs[\"split\"].astype(str).value_counts())\n", "print(\"train conditions:\", len(train_conditions))\n", "print(\"val conditions:\", len(val_conditions))\n", "print(\"eval conditions:\", len(eval_conditions))\n", "print(\"eval genes missing from training:\", len(unseen_eval_genes))\n", "print(\"foundation info:\", foundation_info)\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "2a2de0a0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FoundationAwareFlowModel(\n", " (hvg_encoder): HVGEncoder(\n", " (net): Sequential(\n", " (0): Linear(in_features=4334, out_features=512, bias=True)\n", " (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Linear(in_features=512, out_features=256, bias=True)\n", " (5): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (6): GELU(approximate='none')\n", " (7): Dropout(p=0.1, inplace=False)\n", " (8): Linear(in_features=256, out_features=64, bias=True)\n", " )\n", " )\n", " (state_projector): Projector(\n", " (net): Sequential(\n", " (0): Linear(in_features=64, out_features=256, bias=True)\n", " (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Linear(in_features=256, out_features=96, bias=True)\n", " )\n", " )\n", " (decoder): Decoder(\n", " (net): Sequential(\n", " (0): 
Linear(in_features=96, out_features=256, bias=True)\n", " (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Linear(in_features=256, out_features=512, bias=True)\n", " (5): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (6): GELU(approximate='none')\n", " (7): Dropout(p=0.1, inplace=False)\n", " (8): Linear(in_features=512, out_features=4334, bias=True)\n", " )\n", " )\n", " (condition_encoder): ConditionEncoder(\n", " (gene_embedding): Embedding(2058, 64)\n", " (batch_embedding): Embedding(49, 32)\n", " (mlp): Sequential(\n", " (0): Linear(in_features=96, out_features=128, bias=True)\n", " (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Linear(in_features=128, out_features=128, bias=True)\n", " )\n", " )\n", " (time_encoder): TimeEmbedding(\n", " (proj): Sequential(\n", " (0): Linear(in_features=64, out_features=64, bias=True)\n", " (1): SiLU()\n", " (2): Linear(in_features=64, out_features=64, bias=True)\n", " )\n", " )\n", " (vector_field): ConditionalVectorField(\n", " (input_proj): Linear(in_features=160, out_features=256, bias=True)\n", " (blocks): ModuleList(\n", " (0-3): 4 x FiLMResidualBlock(\n", " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (fc1): Linear(in_features=256, out_features=256, bias=True)\n", " (fc2): Linear(in_features=256, out_features=256, bias=True)\n", " (film1): Linear(in_features=128, out_features=512, bias=True)\n", " (film2): Linear(in_features=128, out_features=512, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (out): Linear(in_features=256, out_features=96, bias=True)\n", " )\n", ")\n" ] } ], "source": [ "class FrozenLookupEncoder(nn.Module):\n", " def __init__(self, matrix: np.ndarray):\n", " 
super().__init__()\n", " weight = torch.as_tensor(matrix, dtype=torch.float32)\n", " self.embedding = nn.Embedding.from_pretrained(weight, freeze=True)\n", "\n", " @property\n", " def output_dim(self) -> int:\n", " return int(self.embedding.embedding_dim)\n", "\n", " def forward(self, cell_idx: torch.Tensor) -> torch.Tensor:\n", " return self.embedding(cell_idx)\n", "\n", "\n", "class HVGEncoder(nn.Module):\n", " def __init__(self, input_dim: int, hidden_dim: int, out_dim: int, dropout: float):\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(input_dim, hidden_dim * 2),\n", " nn.LayerNorm(hidden_dim * 2),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_dim * 2, hidden_dim),\n", " nn.LayerNorm(hidden_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_dim, out_dim),\n", " )\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " return self.net(x)\n", "\n", "\n", "class Projector(nn.Module):\n", " def __init__(self, input_dim: int, latent_dim: int, hidden_dim: int, dropout: float):\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(input_dim, hidden_dim),\n", " nn.LayerNorm(hidden_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_dim, latent_dim),\n", " )\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " return self.net(x)\n", "\n", "\n", "class Decoder(nn.Module):\n", " def __init__(self, latent_dim: int, output_dim: int, hidden_dim: int, dropout: float):\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(latent_dim, hidden_dim),\n", " nn.LayerNorm(hidden_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_dim, hidden_dim * 2),\n", " nn.LayerNorm(hidden_dim * 2),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_dim * 2, output_dim),\n", " )\n", "\n", " def forward(self, z: torch.Tensor) -> torch.Tensor:\n", " return F.softplus(self.net(z))\n", "\n", "\n", "class 
ConditionEncoder(nn.Module):\n", " def __init__(self, n_genes: int, n_batches: int, condition_dim: int, dropout: float):\n", " super().__init__()\n", " gene_dim = max(32, condition_dim // 2)\n", " batch_dim = max(16, condition_dim // 4)\n", " self.gene_embedding = nn.Embedding(n_genes, gene_dim)\n", " self.batch_embedding = nn.Embedding(n_batches, batch_dim)\n", " self.mlp = nn.Sequential(\n", " nn.Linear(gene_dim + batch_dim, condition_dim),\n", " nn.LayerNorm(condition_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(condition_dim, condition_dim),\n", " )\n", "\n", " def forward(self, gene_idx: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor:\n", " gene_emb = self.gene_embedding(gene_idx)\n", " batch_emb = self.batch_embedding(batch_idx)\n", " return self.mlp(torch.cat([gene_emb, batch_emb], dim=-1))\n", "\n", "\n", "class TimeEmbedding(nn.Module):\n", " def __init__(self, time_dim: int):\n", " super().__init__()\n", " self.time_dim = time_dim\n", " self.proj = nn.Sequential(\n", " nn.Linear(time_dim, time_dim),\n", " nn.SiLU(),\n", " nn.Linear(time_dim, time_dim),\n", " )\n", "\n", " def forward(self, t: torch.Tensor) -> torch.Tensor:\n", " half = self.time_dim // 2\n", " freqs = torch.exp(\n", " -math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / max(half, 1)\n", " )\n", " angles = t[:, None] * freqs[None, :]\n", " emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)\n", " if emb.shape[1] < self.time_dim:\n", " emb = F.pad(emb, (0, self.time_dim - emb.shape[1]))\n", " return self.proj(emb)\n", "\n", "\n", "class FiLMResidualBlock(nn.Module):\n", " def __init__(self, hidden_dim: int, condition_dim: int, dropout: float):\n", " super().__init__()\n", " self.norm1 = nn.LayerNorm(hidden_dim)\n", " self.norm2 = nn.LayerNorm(hidden_dim)\n", " self.fc1 = nn.Linear(hidden_dim, hidden_dim)\n", " self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n", " self.film1 = nn.Linear(condition_dim, hidden_dim * 2)\n", " self.film2 
= nn.Linear(condition_dim, hidden_dim * 2)\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:\n", " gamma1, beta1 = self.film1(condition).chunk(2, dim=-1)\n", " h = self.norm1(x)\n", " h = h * (1 + gamma1) + beta1\n", " h = F.gelu(self.fc1(h))\n", " h = self.dropout(h)\n", "\n", " gamma2, beta2 = self.film2(condition).chunk(2, dim=-1)\n", " h = self.norm2(h)\n", " h = h * (1 + gamma2) + beta2\n", " h = self.fc2(h)\n", " h = self.dropout(h)\n", " return x + h\n", "\n", "\n", "class ConditionalVectorField(nn.Module):\n", " def __init__(self, latent_dim: int, time_dim: int, condition_dim: int, hidden_dim: int, num_blocks: int, dropout: float):\n", " super().__init__()\n", " self.input_proj = nn.Linear(latent_dim + time_dim, hidden_dim)\n", " self.blocks = nn.ModuleList([FiLMResidualBlock(hidden_dim, condition_dim, dropout) for _ in range(num_blocks)])\n", " self.out = nn.Linear(hidden_dim, latent_dim)\n", "\n", " def forward(self, z_t: torch.Tensor, t_emb: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:\n", " h = self.input_proj(torch.cat([z_t, t_emb], dim=-1))\n", " for block in self.blocks:\n", " h = block(h, condition)\n", " return self.out(h)\n", "\n", "\n", "class FoundationAwareFlowModel(nn.Module):\n", " def __init__(\n", " self,\n", " input_dim: int,\n", " latent_dim: int,\n", " hidden_dim: int,\n", " foundation_matrix: np.ndarray | None,\n", " foundation_proj_dim: int,\n", " hvg_proj_dim: int,\n", " use_dual_channel: bool,\n", " condition_dim: int,\n", " time_dim: int,\n", " n_genes: int,\n", " n_batches: int,\n", " num_blocks: int,\n", " dropout: float,\n", " ):\n", " super().__init__()\n", " self.use_dual_channel = use_dual_channel\n", " self.has_foundation = foundation_matrix is not None\n", "\n", " if self.has_foundation:\n", " self.foundation_encoder = FrozenLookupEncoder(foundation_matrix)\n", " foundation_in_dim = self.foundation_encoder.output_dim\n", " 
self.foundation_projector = Projector(\n", " input_dim=foundation_in_dim,\n", " latent_dim=foundation_proj_dim,\n", " hidden_dim=hidden_dim,\n", " dropout=dropout,\n", " )\n", " else:\n", " self.foundation_encoder = None\n", " self.foundation_projector = None\n", " foundation_proj_dim = 0\n", "\n", " if self.use_dual_channel or not self.has_foundation:\n", " self.hvg_encoder = HVGEncoder(\n", " input_dim=input_dim,\n", " hidden_dim=hidden_dim,\n", " out_dim=hvg_proj_dim,\n", " dropout=dropout,\n", " )\n", " else:\n", " self.hvg_encoder = None\n", " hvg_proj_dim = 0\n", "\n", " state_input_dim = foundation_proj_dim + hvg_proj_dim\n", " if state_input_dim <= 0:\n", " raise ValueError(\"At least one encoder branch must be active.\")\n", "\n", " self.state_projector = Projector(\n", " input_dim=state_input_dim,\n", " latent_dim=latent_dim,\n", " hidden_dim=hidden_dim,\n", " dropout=dropout,\n", " )\n", " self.decoder = Decoder(latent_dim=latent_dim, output_dim=input_dim, hidden_dim=hidden_dim, dropout=dropout)\n", " self.condition_encoder = ConditionEncoder(n_genes=n_genes, n_batches=n_batches, condition_dim=condition_dim, dropout=dropout)\n", " self.time_encoder = TimeEmbedding(time_dim=time_dim)\n", " self.vector_field = ConditionalVectorField(\n", " latent_dim=latent_dim,\n", " time_dim=time_dim,\n", " condition_dim=condition_dim,\n", " hidden_dim=hidden_dim,\n", " num_blocks=num_blocks,\n", " dropout=dropout,\n", " )\n", "\n", " def encode_state(self, x: torch.Tensor, cell_idx: torch.Tensor) -> torch.Tensor:\n", " parts = []\n", " if self.foundation_encoder is not None:\n", " foundation_emb = self.foundation_encoder(cell_idx)\n", " foundation_proj = self.foundation_projector(foundation_emb)\n", " parts.append(foundation_proj)\n", " if self.hvg_encoder is not None:\n", " hvg_proj = self.hvg_encoder(x)\n", " parts.append(hvg_proj)\n", " return self.state_projector(torch.cat(parts, dim=-1))\n", "\n", " def decode_z(self, z: torch.Tensor) -> torch.Tensor:\n", " return 
self.decoder(z)\n", "\n", " def reconstruct(self, x: torch.Tensor, cell_idx: torch.Tensor) -> torch.Tensor:\n", " return self.decode_z(self.encode_state(x, cell_idx))\n", "\n", " def condition(self, gene_idx: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor:\n", " return self.condition_encoder(gene_idx, batch_idx)\n", "\n", " def velocity(self, z_t: torch.Tensor, t: torch.Tensor, gene_idx: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor:\n", " condition = self.condition(gene_idx, batch_idx)\n", " t_emb = self.time_encoder(t)\n", " return self.vector_field(z_t, t_emb, condition)\n", "\n", "\n", "model = FoundationAwareFlowModel(\n", " input_dim=adata.n_vars,\n", " latent_dim=LATENT_DIM,\n", " hidden_dim=HIDDEN_DIM,\n", " foundation_matrix=foundation_matrix_np,\n", " foundation_proj_dim=FOUNDATION_PROJ_DIM,\n", " hvg_proj_dim=HVG_PROJ_DIM,\n", " use_dual_channel=USE_DUAL_CHANNEL,\n", " condition_dim=CONDITION_DIM,\n", " time_dim=TIME_DIM,\n", " n_genes=len(gene_vocab),\n", " n_batches=len(batch_vocab),\n", " num_blocks=NUM_FLOW_BLOCKS,\n", " dropout=DROPOUT,\n", ").to(DEVICE)\n", "\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n", "print(model)\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "3c23da6a", "metadata": {}, "outputs": [], "source": [ "def sample_controls_like_targets(target_idx: np.ndarray, rng: np.random.Generator) -> np.ndarray:\n", " sampled = []\n", " for idx in target_idx:\n", " batch_name = batches[idx]\n", " pool = control_idx_by_batch.get(batch_name, control_train_idx)\n", " sampled.append(rng.choice(pool))\n", " return np.asarray(sampled, dtype=np.int64)\n", "\n", "\n", "def make_condition_batch(condition: str, split_name: str, rng: np.random.Generator) -> dict:\n", " target_pool = indices_by_split_gene[split_name][condition]\n", " size = min(BATCH_SIZE, len(target_pool))\n", " replace = len(target_pool) < size\n", " target_idx = rng.choice(target_pool, size=size, 
replace=replace)\n", " source_idx = sample_controls_like_targets(target_idx, rng)\n", "\n", " if MAX_EVAL_CELLS_PER_GENE is not None and len(target_idx) > MAX_EVAL_CELLS_PER_GENE:\n", " target_idx = target_idx[:MAX_EVAL_CELLS_PER_GENE]\n", " source_idx = source_idx[:MAX_EVAL_CELLS_PER_GENE]\n", "\n", " source_x = torch.from_numpy(take_dense_rows(adata.X, source_idx)).to(DEVICE)\n", " target_x = torch.from_numpy(take_dense_rows(adata.X, target_idx)).to(DEVICE)\n", " source_cell_idx = torch.as_tensor(cell_indices[source_idx], device=DEVICE, dtype=torch.long)\n", " target_cell_idx = torch.as_tensor(cell_indices[target_idx], device=DEVICE, dtype=torch.long)\n", " gene_idx = torch.full(\n", " (len(target_idx),),\n", " fill_value=gene_to_idx.get(condition, gene_to_idx[\"__unk__\"]),\n", " device=DEVICE,\n", " dtype=torch.long,\n", " )\n", " batch_idx = torch.as_tensor(batch_index_per_obs[source_idx], device=DEVICE, dtype=torch.long)\n", "\n", " return {\n", " \"condition\": condition,\n", " \"source_idx\": source_idx,\n", " \"target_idx\": target_idx,\n", " \"source_x\": source_x,\n", " \"target_x\": target_x,\n", " \"source_cell_idx\": source_cell_idx,\n", " \"target_cell_idx\": target_cell_idx,\n", " \"gene_idx\": gene_idx,\n", " \"batch_idx\": batch_idx,\n", " }\n", "\n", "\n", "def compute_loss_terms(batch: dict) -> dict[str, torch.Tensor]:\n", " z0 = model.encode_state(batch[\"source_x\"], batch[\"source_cell_idx\"])\n", " z1 = model.encode_state(batch[\"target_x\"], batch[\"target_cell_idx\"])\n", "\n", " z1_bar = sinkhorn_barycentric_targets(\n", " x0=z0,\n", " x1=z1,\n", " epsilon=SINKHORN_EPSILON,\n", " n_iters=SINKHORN_ITERS,\n", " )\n", "\n", " t = torch.rand(z0.shape[0], device=DEVICE)\n", " sigma = NOISE_STD * torch.sqrt((t * (1.0 - t)).clamp_min(1e-6)).unsqueeze(-1)\n", " z_t = (1.0 - t).unsqueeze(-1) * z0 + t.unsqueeze(-1) * z1_bar + sigma * torch.randn_like(z0)\n", "\n", " target_velocity = z1_bar - z0\n", " pred_velocity = model.velocity(z_t, t, 
batch[\"gene_idx\"], batch[\"batch_idx\"])\n", " pred_z1 = z_t + (1.0 - t).unsqueeze(-1) * pred_velocity\n", " pred_x1 = model.decode_z(pred_z1)\n", "\n", " source_recon = model.reconstruct(batch[\"source_x\"], batch[\"source_cell_idx\"])\n", " target_recon = model.reconstruct(batch[\"target_x\"], batch[\"target_cell_idx\"])\n", "\n", " recon_loss = F.mse_loss(source_recon, batch[\"source_x\"]) + F.mse_loss(target_recon, batch[\"target_x\"])\n", " flow_loss = F.mse_loss(pred_velocity, target_velocity)\n", " endpoint_loss = F.mse_loss(pred_z1, z1_bar)\n", " mmd_loss = rbf_mmd_loss(pred_x1, batch[\"target_x\"])\n", " mean_loss = F.mse_loss(pred_x1.mean(dim=0), batch[\"target_x\"].mean(dim=0))\n", "\n", " total_loss = (\n", " RECON_WEIGHT * recon_loss\n", " + FLOW_WEIGHT * flow_loss\n", " + ENDPOINT_WEIGHT * endpoint_loss\n", " + MMD_WEIGHT * mmd_loss\n", " + MEAN_WEIGHT * mean_loss\n", " )\n", "\n", " return {\n", " \"total\": total_loss,\n", " \"recon\": recon_loss.detach(),\n", " \"flow\": flow_loss.detach(),\n", " \"endpoint\": endpoint_loss.detach(),\n", " \"mmd\": mmd_loss.detach(),\n", " \"mean\": mean_loss.detach(),\n", " }\n", "\n", "\n", "@torch.no_grad()\n", "def evaluate_split_loss(split_name: str, conditions: list[str], seed: int, steps: int) -> dict[str, float]:\n", " if not conditions:\n", " return {}\n", "\n", " model.eval()\n", " rng = np.random.default_rng(seed)\n", " subset = conditions.copy()\n", " rng.shuffle(subset)\n", " subset = subset[: min(len(subset), steps)]\n", "\n", " rows = []\n", " for condition in subset:\n", " batch = make_condition_batch(condition, split_name=split_name, rng=rng)\n", " losses = compute_loss_terms(batch)\n", " rows.append({key: float(value.item()) for key, value in losses.items()})\n", "\n", " return pd.DataFrame(rows).mean().to_dict()\n", "\n", "\n", "@torch.no_grad()\n", "def predict_gene_expression(condition: str, split_name: str, rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray, pd.DataFrame]:\n", " 
target_idx_all = indices_by_split_gene[split_name][condition]\n", " if MAX_EVAL_CELLS_PER_GENE is not None and len(target_idx_all) > MAX_EVAL_CELLS_PER_GENE:\n", " target_idx_all = target_idx_all[:MAX_EVAL_CELLS_PER_GENE]\n", "\n", " pred_blocks = []\n", " real_blocks = []\n", " obs_blocks = []\n", "\n", " for start in range(0, len(target_idx_all), PRED_BATCH_SIZE):\n", " target_idx = target_idx_all[start : start + PRED_BATCH_SIZE]\n", " source_idx = sample_controls_like_targets(target_idx, rng)\n", "\n", " batch = {\n", " \"source_x\": torch.from_numpy(take_dense_rows(adata.X, source_idx)).to(DEVICE),\n", " \"source_cell_idx\": torch.as_tensor(cell_indices[source_idx], device=DEVICE, dtype=torch.long),\n", " \"gene_idx\": torch.full(\n", " (len(target_idx),),\n", " gene_to_idx.get(condition, gene_to_idx[\"__unk__\"]),\n", " device=DEVICE,\n", " dtype=torch.long,\n", " ),\n", " \"batch_idx\": torch.as_tensor(batch_index_per_obs[source_idx], device=DEVICE, dtype=torch.long),\n", " }\n", "\n", " z1_hat = rk2_integrate_latent(model, batch, steps=INFERENCE_STEPS)\n", " pred_x = model.decode_z(z1_hat).cpu().numpy().astype(np.float32)\n", " real_x = take_dense_rows(adata.X, target_idx)\n", "\n", " pred_blocks.append(pred_x)\n", " real_blocks.append(real_x)\n", "\n", " obs_chunk = adata.obs.iloc[target_idx][[\"perturbation_gene\", \"condition\", \"split\"]].copy()\n", " obs_chunk[\"target_cell_id\"] = adata.obs_names[target_idx].astype(str)\n", " obs_chunk[\"sampled_control_cell_id\"] = adata.obs_names[source_idx].astype(str)\n", " obs_chunk[\"foundation_model_name\"] = FOUNDATION_MODEL_NAME\n", " obs_blocks.append(obs_chunk)\n", "\n", " pred_gene = np.vstack(pred_blocks).astype(np.float32)\n", " real_gene = np.vstack(real_blocks).astype(np.float32)\n", " obs_gene = pd.concat(obs_blocks, axis=0)\n", " return pred_gene, real_gene, obs_gene\n", "\n", "\n", "@torch.no_grad()\n", "def evaluate_gene_metrics(split_name: str, conditions: list[str], seed: int) -> 
tuple[pd.DataFrame, ad.AnnData, ad.AnnData]:\n", " rng = np.random.default_rng(seed)\n", "\n", " pred_blocks = []\n", " real_blocks = []\n", " obs_blocks = []\n", " metrics_rows = []\n", "\n", " for condition in conditions:\n", " pred_gene, real_gene, obs_gene = predict_gene_expression(condition, split_name, rng)\n", "\n", " pred_blocks.append(pred_gene)\n", " real_blocks.append(real_gene)\n", " obs_blocks.append(obs_gene)\n", "\n", " pred_mean = pred_gene.mean(axis=0)\n", " real_mean = real_gene.mean(axis=0)\n", " pred_delta = pred_mean - control_mean\n", " real_delta = real_mean - control_mean\n", "\n", " metrics_rows.append(\n", " {\n", " \"perturbation_gene\": condition,\n", " \"n_cells\": int(real_gene.shape[0]),\n", " \"mse_mean\": float(np.mean((pred_mean - real_mean) ** 2)),\n", " \"mae_mean\": float(np.mean(np.abs(pred_mean - real_mean))),\n", " \"pearson_mean\": safe_pearson(pred_mean, real_mean),\n", " \"pearson_delta\": safe_pearson(pred_delta, real_delta),\n", " \"delta_l2\": float(np.linalg.norm(pred_delta - real_delta)),\n", " }\n", " )\n", "\n", " pred = ad.AnnData(\n", " X=np.vstack(pred_blocks).astype(np.float32),\n", " obs=pd.concat(obs_blocks, axis=0),\n", " var=adata.var.copy(),\n", " )\n", " real = ad.AnnData(\n", " X=np.vstack(real_blocks).astype(np.float32),\n", " obs=pd.concat(obs_blocks, axis=0),\n", " var=adata.var.copy(),\n", " )\n", " metrics = pd.DataFrame(metrics_rows).sort_values(\"perturbation_gene\").reset_index(drop=True)\n", " return metrics, pred, real\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "ace518dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'epoch': 1, 'train_total': 0.3869, 'train_flow': 0.049, 'train_recon': 0.2952, 'val_total': 0.3275}\n", "{'epoch': 2, 'train_total': 0.3046, 'train_flow': 0.0023, 'train_recon': 0.2775, 'val_total': 0.3213}\n", "{'epoch': 3, 'train_total': 0.2992, 'train_flow': 0.0012, 'train_recon': 0.2738, 'val_total': 0.325}\n", "{'epoch': 
4, 'train_total': 0.2984, 'train_flow': 0.0008, 'train_recon': 0.2736, 'val_total': 0.3282}\n", "{'epoch': 5, 'train_total': 0.3005, 'train_flow': 0.0005, 'train_recon': 0.2757, 'val_total': 0.3171}\n", "{'epoch': 6, 'train_total': 0.3021, 'train_flow': 0.0005, 'train_recon': 0.2768, 'val_total': 0.3102}\n", "{'epoch': 7, 'train_total': 0.298, 'train_flow': 0.0004, 'train_recon': 0.2732, 'val_total': 0.3117}\n", "{'epoch': 8, 'train_total': 0.2966, 'train_flow': 0.0003, 'train_recon': 0.2723, 'val_total': 0.3245}\n", "{'epoch': 9, 'train_total': 0.298, 'train_flow': 0.0002, 'train_recon': 0.2736, 'val_total': 0.312}\n", "{'epoch': 10, 'train_total': 0.3014, 'train_flow': 0.0003, 'train_recon': 0.2765, 'val_total': 0.3295}\n", "{'epoch': 11, 'train_total': 0.2981, 'train_flow': 0.0002, 'train_recon': 0.2739, 'val_total': 0.3117}\n", "{'epoch': 12, 'train_total': 0.3001, 'train_flow': 0.0002, 'train_recon': 0.2758, 'val_total': 0.3073}\n", "{'epoch': 13, 'train_total': 0.2992, 'train_flow': 0.0002, 'train_recon': 0.275, 'val_total': 0.317}\n", "{'epoch': 14, 'train_total': 0.2976, 'train_flow': 0.0002, 'train_recon': 0.2738, 'val_total': 0.3246}\n", "{'epoch': 15, 'train_total': 0.2951, 'train_flow': 0.0001, 'train_recon': 0.2718, 'val_total': 0.3031}\n", "{'epoch': 16, 'train_total': 0.2987, 'train_flow': 0.0001, 'train_recon': 0.275, 'val_total': 0.3139}\n", "{'epoch': 17, 'train_total': 0.2961, 'train_flow': 0.0001, 'train_recon': 0.2726, 'val_total': 0.3283}\n", "{'epoch': 18, 'train_total': 0.2978, 'train_flow': 0.0001, 'train_recon': 0.2739, 'val_total': 0.3148}\n", "{'epoch': 19, 'train_total': 0.2954, 'train_flow': 0.0001, 'train_recon': 0.2722, 'val_total': 0.3195}\n", "{'epoch': 20, 'train_total': 0.2931, 'train_flow': 0.0001, 'train_recon': 0.2697, 'val_total': 0.3245}\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " epoch train_total train_recon train_flow train_endpoint train_mmd \\\n", "0 1 0.386918 0.295175 0.049048 0.016469 0.188088 \n", "1 2 0.304552 0.277528 0.002327 0.001193 0.177984 \n", "2 3 0.299156 0.273804 0.001244 0.000820 0.178900 \n", "3 4 0.298364 0.273620 0.000774 0.000668 0.178914 \n", "4 5 0.300532 0.275692 0.000531 0.000586 0.180151 \n", "5 6 0.302061 0.276778 0.000456 0.000560 0.181249 \n", "6 7 0.297986 0.273220 0.000405 0.000539 0.181948 \n", "7 8 0.296608 0.272330 0.000310 0.000502 0.180605 \n", "8 9 0.297956 0.273637 0.000245 0.000486 0.181264 \n", "9 10 0.301448 0.276450 0.000256 0.000483 0.181792 \n", "10 11 0.298062 0.273878 0.000204 0.000465 0.179504 \n", "11 12 0.300077 0.275846 0.000166 0.000457 0.180087 \n", "12 13 0.299209 0.275026 0.000154 0.000452 0.180039 \n", "13 14 0.297629 0.273756 0.000160 0.000449 0.180505 \n", "14 15 0.295055 0.271823 0.000121 0.000440 0.178671 \n", "15 16 0.298679 0.274972 0.000107 0.000433 0.179470 \n", "16 17 0.296067 0.272646 0.000094 0.000430 0.178372 \n", "17 18 0.297759 0.273889 0.000094 0.000427 0.179408 \n", "18 19 0.295428 0.272150 0.000097 0.000425 0.178286 \n", "19 20 0.293098 0.269734 0.000081 0.000422 0.179269 \n", "\n", " train_mean val_total val_recon val_flow val_endpoint val_mmd \\\n", "0 0.015652 0.327461 0.275826 0.001044 0.000779 0.248319 \n", "1 0.006303 0.321279 0.272923 0.000443 0.000544 0.262586 \n", "2 0.005808 0.324960 0.279755 0.000334 0.000498 0.247330 \n", "3 0.005745 0.328190 0.274326 0.000126 0.000449 0.257675 \n", "4 0.006002 0.317066 0.271136 0.000082 0.000442 0.258087 \n", "5 0.006423 0.310193 0.269257 0.000068 0.000448 0.214800 \n", "6 0.005897 0.311689 0.271440 0.000060 0.000411 0.235593 \n", "7 0.005656 0.324516 0.269213 0.000069 0.000435 0.221987 \n", "8 0.005705 0.312006 0.267191 0.000159 0.000434 0.227488 \n", "9 0.006321 0.329498 0.278117 0.000077 0.000427 0.273831 \n", "10 0.005797 0.311661 0.268826 0.000041 0.000430 0.243889 \n", "11 0.005828 0.307287 
0.265851 0.000038 0.000421 0.236916 \n", "12 0.005799 0.316997 0.268163 0.000058 0.000430 0.248900 \n", "13 0.005438 0.324611 0.278086 0.000036 0.000406 0.221637 \n", "14 0.005023 0.303118 0.265120 0.000031 0.000408 0.223578 \n", "15 0.005437 0.313897 0.273964 0.000031 0.000408 0.224554 \n", "16 0.005275 0.328341 0.275019 0.000019 0.000410 0.259904 \n", "17 0.005622 0.314834 0.271382 0.000083 0.000431 0.241252 \n", "18 0.005140 0.319500 0.278122 0.000071 0.000441 0.212998 \n", "19 0.005145 0.324462 0.281956 0.000020 0.000404 0.231868 \n", "\n", " val_mean \n", "0 0.025369 \n", "1 0.021382 \n", "2 0.019889 \n", "3 0.027746 \n", "4 0.019818 \n", "5 0.019164 \n", "6 0.016424 \n", "7 0.032819 \n", "8 0.021690 \n", "9 0.023708 \n", "10 0.018189 \n", "11 0.017495 \n", "12 0.023671 \n", "13 0.024122 \n", "14 0.015406 \n", "15 0.017243 \n", "16 0.027108 \n", "17 0.019028 \n", "18 0.019787 \n", "19 0.019096 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "history_rows = []\n", "best_state = None\n", "best_val_total = float(\"inf\")\n", "rng = np.random.default_rng(RANDOM_SEED)\n", "\n", "train_sampling_conditions = [g for g in train_conditions if g in indices_by_split_gene[TRAIN_SPLIT]]\n", "if not train_sampling_conditions:\n", " raise ValueError(\"No train perturbation conditions are available.\")\n", "\n", "for epoch in range(1, EPOCHS + 1):\n", " model.train()\n", " epoch_rows = []\n", "\n", " for _ in range(STEPS_PER_EPOCH):\n", " condition = rng.choice(train_sampling_conditions)\n", " batch = make_condition_batch(condition=condition, split_name=TRAIN_SPLIT, rng=rng)\n", "\n", " optimizer.zero_grad(set_to_none=True)\n", " losses = compute_loss_terms(batch)\n", " losses[\"total\"].backward()\n", " nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRAD_CLIP_NORM)\n", " optimizer.step()\n", "\n", " epoch_rows.append({key: float(value.item()) for key, value in losses.items()})\n", "\n", " train_summary = 
pd.DataFrame(epoch_rows).mean().to_dict()\n", " val_summary = evaluate_split_loss(\n", " split_name=VAL_SPLIT,\n", " conditions=val_conditions,\n", " seed=RANDOM_SEED + epoch,\n", " steps=VALIDATION_STEPS,\n", " ) if val_conditions else {}\n", "\n", " row = {\"epoch\": epoch, **{f\"train_{k}\": v for k, v in train_summary.items()}}\n", " row.update({f\"val_{k}\": v for k, v in val_summary.items()})\n", " history_rows.append(row)\n", "\n", " current_val_total = val_summary.get(\"total\", train_summary[\"total\"])\n", " if current_val_total < best_val_total:\n", " best_val_total = current_val_total\n", " best_state = deepcopy(model.state_dict())\n", "\n", " print(\n", " {\n", " \"epoch\": epoch,\n", " \"train_total\": round(train_summary[\"total\"], 4),\n", " \"train_flow\": round(train_summary[\"flow\"], 4),\n", " \"train_recon\": round(train_summary[\"recon\"], 4),\n", " \"val_total\": round(current_val_total, 4),\n", " }\n", " )\n", "\n", "if best_state is not None:\n", " model.load_state_dict(best_state)\n", "\n", "history = pd.DataFrame(history_rows)\n", "history\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "52911c87", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved:\n", "- /root/final/baseline/outputs/foundation_encoder_flow_matching/foundation_encoder_flow_matching_model.pt\n", "- /root/final/baseline/outputs/foundation_encoder_flow_matching/pred_foundation_encoder_flow_matching_seen_cell_split_test.h5ad\n", "- /root/final/baseline/outputs/foundation_encoder_flow_matching/real_foundation_encoder_flow_matching_seen_cell_split_test.h5ad\n", "- /root/final/baseline/outputs/foundation_encoder_flow_matching/metrics_by_gene_foundation_encoder_flow_matching_seen_cell_split_test.csv\n", "- /root/final/baseline/outputs/foundation_encoder_flow_matching/aggregate_foundation_encoder_flow_matching_seen_cell_split_test.csv\n", "- 
/root/final/baseline/outputs/foundation_encoder_flow_matching/training_history_foundation_encoder_flow_matching_seen_cell_split_test.csv\n", "- /root/final/baseline/outputs/foundation_encoder_flow_matching/config_foundation_encoder_flow_matching_seen_cell_split_test.json\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " perturbation_gene n_cells mse_mean mae_mean pearson_mean pearson_delta \\\n", "0 AAAS 21 0.007450 0.065200 0.990404 -0.057341 \n", "1 AAMP 10 0.018880 0.103320 0.975036 0.260110 \n", "2 AARS 4 0.042839 0.156731 0.945999 0.365995 \n", "3 AARS2 18 0.008375 0.070168 0.989019 -0.078555 \n", "4 AASDHPPT 5 0.029657 0.131667 0.964218 0.079977 \n", "\n", " delta_l2 \n", "0 5.682287 \n", "1 9.045727 \n", "2 13.625822 \n", "3 6.024816 \n", "4 11.337287 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eval_conditions_to_run = [condition for condition in eval_conditions if condition in indices_by_split_gene[EVAL_SPLIT]]\n", "if not eval_conditions_to_run:\n", " raise ValueError(f\"No conditions found for EVAL_SPLIT={EVAL_SPLIT}.\")\n", "\n", "metrics, pred, real = evaluate_gene_metrics(\n", " split_name=EVAL_SPLIT,\n", " conditions=eval_conditions_to_run,\n", " seed=RANDOM_SEED + 999,\n", ")\n", "aggregate = build_aggregate_frame(metrics)\n", "\n", "tag = f\"{SPLIT_STRATEGY}_{EVAL_SPLIT}\"\n", "pred_path = OUTPUT_DIR / f\"pred_{METHOD_NAME}_{tag}.h5ad\"\n", "real_path = OUTPUT_DIR / f\"real_{METHOD_NAME}_{tag}.h5ad\"\n", "metrics_path = OUTPUT_DIR / f\"metrics_by_gene_{METHOD_NAME}_{tag}.csv\"\n", "aggregate_path = OUTPUT_DIR / f\"aggregate_{METHOD_NAME}_{tag}.csv\"\n", "history_path = OUTPUT_DIR / f\"training_history_{METHOD_NAME}_{tag}.csv\"\n", "config_path = OUTPUT_DIR / f\"config_{METHOD_NAME}_{tag}.json\"\n", "\n", "config = {\n", " \"name\": METHOD_NAME,\n", " \"input_path\": str(INPUT_PATH),\n", " \"split_strategy\": SPLIT_STRATEGY,\n", " \"train_frac\": float(TRAIN_FRAC),\n", " \"val_frac\": float(VAL_FRAC),\n", " \"eval_split\": EVAL_SPLIT,\n", " \"random_seed\": int(RANDOM_SEED),\n", " \"foundation_model_name\": FOUNDATION_MODEL_NAME,\n", " \"foundation_embed_source\": FOUNDATION_EMBED_SOURCE,\n", " \"foundation_obsm_key\": FOUNDATION_OBSM_KEY,\n", " \"foundation_embed_h5ad_path\": None if 
FOUNDATION_EMBED_H5AD_PATH is None else str(FOUNDATION_EMBED_H5AD_PATH),\n", " \"foundation_embed_file\": None if FOUNDATION_EMBED_FILE is None else str(FOUNDATION_EMBED_FILE),\n", " \"foundation_obs_csv\": None if FOUNDATION_OBS_CSV is None else str(FOUNDATION_OBS_CSV),\n", " \"use_dual_channel\": bool(USE_DUAL_CHANNEL),\n", " \"foundation_freeze\": bool(FOUNDATION_FREEZE),\n", " \"foundation_proj_dim\": int(FOUNDATION_PROJ_DIM),\n", " \"hvg_proj_dim\": int(HVG_PROJ_DIM),\n", " \"latent_dim\": int(LATENT_DIM),\n", " \"hidden_dim\": int(HIDDEN_DIM),\n", " \"condition_dim\": int(CONDITION_DIM),\n", " \"time_dim\": int(TIME_DIM),\n", " \"dropout\": float(DROPOUT),\n", " \"num_flow_blocks\": int(NUM_FLOW_BLOCKS),\n", " \"batch_size\": int(BATCH_SIZE),\n", " \"epochs\": int(EPOCHS),\n", " \"steps_per_epoch\": int(STEPS_PER_EPOCH),\n", " \"validation_steps\": int(VALIDATION_STEPS),\n", " \"learning_rate\": float(LEARNING_RATE),\n", " \"weight_decay\": float(WEIGHT_DECAY),\n", " \"recon_weight\": float(RECON_WEIGHT),\n", " \"flow_weight\": float(FLOW_WEIGHT),\n", " \"endpoint_weight\": float(ENDPOINT_WEIGHT),\n", " \"mmd_weight\": float(MMD_WEIGHT),\n", " \"mean_weight\": float(MEAN_WEIGHT),\n", " \"noise_std\": float(NOISE_STD),\n", " \"sinkhorn_epsilon\": float(SINKHORN_EPSILON),\n", " \"sinkhorn_iters\": int(SINKHORN_ITERS),\n", " \"inference_steps\": int(INFERENCE_STEPS),\n", " \"n_train_conditions\": int(len(train_conditions)),\n", " \"n_eval_conditions\": int(len(eval_conditions_to_run)),\n", " \"n_unseen_eval_conditions\": int(len(set(eval_conditions_to_run) - set(train_conditions))),\n", " \"feature_gene_columns\": FEATURE_GENE_COLUMNS,\n", " \"model_path\": str(MODEL_PATH),\n", "}\n", "\n", "pred.uns[\"baseline\"] = config\n", "real.uns[\"baseline\"] = config\n", "\n", "torch.save(model.state_dict(), MODEL_PATH)\n", "pred.write_h5ad(pred_path, compression=\"gzip\")\n", "real.write_h5ad(real_path, compression=\"gzip\")\n", "metrics.to_csv(metrics_path, 
index=False)\n", "aggregate.to_csv(aggregate_path)\n", "history.to_csv(history_path, index=False)\n", "with config_path.open(\"w\") as f:\n", " json.dump(config, f, indent=2)\n", "\n", "print(\"Saved:\")\n", "for path in [MODEL_PATH, pred_path, real_path, metrics_path, aggregate_path, history_path, config_path]:\n", " print(\"-\", path)\n", "\n", "metrics.head()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b5cdfb4d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }
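The `FiLMResidualBlock` used by the vector field conditions each hidden layer by mapping the condition vector to per-feature `(gamma, beta)` pairs and applying `h * (1 + gamma) + beta`. A minimal NumPy sketch of that modulation, independent of the notebook's classes (the random `film` array stands in for the output of a condition-to-`2*hidden` linear map):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim = 16

h = rng.normal(size=(8, hidden_dim))          # hidden activations
film = rng.normal(size=(8, 2 * hidden_dim))   # stand-in for film1(condition)
gamma, beta = film[:, :hidden_dim], film[:, hidden_dim:]

# FiLM modulation as in FiLMResidualBlock: scale around 1, then shift.
h_mod = h * (1.0 + gamma) + beta

# With gamma = beta = 0 the modulation is the identity, which is why the
# block parameterizes the scale as (1 + gamma) rather than gamma alone.
h_id = h * (1.0 + np.zeros_like(h)) + np.zeros_like(h)
```

Parameterizing the scale as `1 + gamma` means a freshly initialized FiLM layer (whose outputs are near zero) starts close to an identity map, which tends to stabilize early training.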
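`compute_loss_terms` builds the interpolant `z_t = (1 - t) * z0 + t * z1_bar` with regression target `z1_bar - z0`, then forms an endpoint estimate `z_t + (1 - t) * v` for the endpoint loss. A small NumPy sketch (bridge noise omitted) verifying that this endpoint identity is exact whenever the predicted velocity matches the target:

```python
import numpy as np

rng = np.random.default_rng(0)
z0 = rng.normal(size=(8, 4))      # encoded source (control) latents
z1 = rng.normal(size=(8, 4))      # barycentric target latents (z1_bar above)
t = rng.uniform(size=(8, 1))      # per-sample times, as in compute_loss_terms

# Linear interpolant used for flow matching (the sigma * noise term is omitted).
z_t = (1.0 - t) * z0 + t * z1

# The regression target for the vector field is the straight-line displacement.
target_velocity = z1 - z0

# Endpoint estimate: exact when the predicted velocity equals the target,
# since z_t + (1 - t) * (z1 - z0) == z1 algebraically for every t.
pred_z1 = z_t + (1.0 - t) * target_velocity
max_err = np.abs(pred_z1 - z1).max()
```

This is why the endpoint loss is a cheap consistency check: it penalizes the model only insofar as its one-step extrapolation disagrees with the transported target.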
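`sinkhorn_barycentric_targets` is defined in an earlier cell and not shown here. Under the assumption that it computes an entropic optimal-transport plan between the two minibatches and replaces each raw target with its barycentric projection, a minimal NumPy sketch of the idea (the function name and hyperparameters below are illustrative):

```python
import numpy as np

def sinkhorn_barycentric(x0, x1, epsilon=0.5, n_iters=100):
    """Entropic OT between two uniform point clouds, then a barycentric
    projection that pairs each source point with a plan-weighted average
    of x1 (a sketch of what sinkhorn_barycentric_targets presumably does)."""
    n, m = x0.shape[0], x1.shape[0]
    a = np.full(n, 1.0 / n)               # uniform source marginal
    b = np.full(m, 1.0 / m)               # uniform target marginal
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)
    K = np.exp(-cost / epsilon)           # Gibbs kernel
    v = np.ones(m)
    for _ in range(n_iters):              # alternating marginal scaling
        u = a / (K @ v)
        v = b / (K.T @ u)
    plan = u[:, None] * K * v[None, :]    # entropic transport plan
    return (plan @ x1) / plan.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
x0 = rng.normal(size=(16, 3))             # "control" latents
x1 = rng.normal(size=(16, 3)) + 2.0       # shifted "perturbed" latents
z1_bar = sinkhorn_barycentric(x0, x1)
```

The barycentric targets straighten the training pairs: each source latent regresses toward a transport-consistent point inside the target cloud rather than toward an arbitrarily sampled target cell.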
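Inference calls `rk2_integrate_latent(model, batch, steps=INFERENCE_STEPS)`, defined in an earlier cell. A self-contained NumPy sketch of midpoint (second-order Runge-Kutta) integration over the unit time interval, which is presumably what that helper applies to `model.velocity`; the function below is illustrative, not the notebook's implementation:

```python
import numpy as np

def rk2_midpoint(z0, field, steps):
    """Integrate dz/dt = field(z, t) from t=0 to t=1 with the midpoint rule."""
    z = z0
    dt = 1.0 / steps
    for k in range(steps):
        t = k * dt
        k1 = field(z, t)                               # slope at the step start
        k2 = field(z + 0.5 * dt * k1, t + 0.5 * dt)    # slope at the midpoint
        z = z + dt * k2
    return z

rng = np.random.default_rng(0)
z0 = rng.normal(size=(4, 2))
z1 = rng.normal(size=(4, 2))

# For the constant field v = z1 - z0 (the flow-matching target velocity),
# integrating from z0 over [0, 1] lands on z1 regardless of the step count.
z_end = rk2_midpoint(z0, lambda z, t: z1 - z0, steps=8)
```

Because flow matching trains the field toward nearly straight trajectories, a low-order integrator with few steps is usually sufficient at inference time.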
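The distributional term `rbf_mmd_loss(pred_x1, batch["target_x"])` is defined in an earlier cell. A minimal single-bandwidth sketch of a biased RBF-kernel MMD^2 estimate in NumPy (the real helper may use multiple bandwidths or an unbiased estimator; this function is illustrative):

```python
import numpy as np

def rbf_mmd(x, y, bandwidth=1.0):
    """Biased MMD^2 estimate with a Gaussian RBF kernel: the squared RKHS
    distance between the mean embeddings of the two samples."""
    def gram(a, b):
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2.0 * bandwidth ** 2))
    return gram(x, x).mean() + gram(y, y).mean() - 2.0 * gram(x, y).mean()

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 5))
y_same = rng.normal(size=(64, 5))          # same distribution as x
y_shifted = rng.normal(size=(64, 5)) + 1.5 # mean-shifted distribution

mmd_same = rbf_mmd(x, y_same)
mmd_shifted = rbf_mmd(x, y_shifted)
```

Unlike the per-gene mean loss, MMD compares whole distributions, so it also penalizes predictions that match the target mean but collapse its cell-to-cell variability.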
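`evaluate_gene_metrics` relies on `safe_pearson`, defined in an earlier cell. A sketch of the guard such a helper presumably needs: Pearson correlation is undefined when either vector is constant, so the function below (illustrative, not the notebook's) returns NaN in that case instead of dividing by a zero standard deviation:

```python
import numpy as np

def pearson_with_guard(a, b):
    """Pearson correlation that returns NaN when either input is constant."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.std() == 0.0 or b.std() == 0.0:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])

r = pearson_with_guard([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])        # perfectly linear
r_const = pearson_with_guard([1.0, 1.0, 1.0], [2.0, 4.0, 6.0])  # degenerate input
```

The degenerate case matters here: `pearson_delta` compares perturbation deltas, which can be near-constant for weak perturbations or very small cell counts.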