{ "cells": [ { "cell_type": "markdown", "id": "da3151cf", "metadata": {}, "source": [ "# Conditional Flow Matching Method\n", "\n", "This notebook implements a thesis-specific method for single-gene perturbation prediction on single-cell RNA-seq data.\n", "\n", "Core idea:\n", "\n", "1. Encode control and perturbed cells into a latent space.\n", "2. Use perturbation gene identity plus batch background as the condition signal.\n", "3. Build pseudo pairs with mini-batch optimal transport.\n", "4. Train a FiLM-conditioned flow matching model that transports control-cell states toward the perturbation distribution.\n", "5. Decode generated latent states back to expression space and evaluate with distribution-aware metrics.\n", "6. **CellOT-style coupling**: besides matching the flow endpoint to the Sinkhorn latent barycenter, the loss also aligns **decode(z1_bar)** with the perturbed batch (MMD + pseudo-bulk) and applies a light **displacement prior** on `pred_z1 − z0` (optional **velocity L2**). Mini-batches resample controls if a spurious index collision occurs.\n", "\n", "By default, the notebook uses a `seen_cell_split`, which matches the algorithm design based on learned perturbation-gene embeddings. If you switch to the original unseen-gene split, unseen perturbation genes are mapped to an `__unk__` token and should be interpreted as an ablation rather than the main setting.\n" ] }, { "cell_type": "markdown", "id": "9e3b1b7d", "metadata": {}, "source": [ "## Method Summary\n", "\n", "**条件(与代码一致):「基因 + batch」** — `ConditionEncoder` 仅包含扰动基因嵌入与技术 batch 嵌入(经 `MLP` 得到条件向量 `h`)。不包含单独的扰动类型(如 KO/activation)或细胞类型嵌入。\n", "\n", "Let `x` be the preprocessed gene-expression vector of a control cell, and let `p` be a single-gene perturbation. 
The model learns:\n", "\n", "`x -> z = E(x) -> z_hat(1) via conditional flow -> x_hat = D(z_hat(1))`\n", "\n", "The vector field is conditioned on **gene + batch** (`h`), plus **continuous time embedding** `phi(t)` (sinusoidal features + small MLP).\n", "\n", "### Training objective(总损失,与 `compute_loss_terms` 一致)\n", "\n", "```\n", "L = RECON_WEIGHT * L_recon\n", " + FLOW_WEIGHT * L_flow\n", " + ENDPOINT_WEIGHT * L_endpoint\n", " + MMD_WEIGHT * L_mmd\n", " + MEAN_WEIGHT * L_mean\n", " + OT_DECODE_MMD_WEIGHT * L_ot_decode_mmd\n", " + OT_DECODE_MEAN_WEIGHT * L_ot_decode_mean\n", " + DISPLACEMENT_REG_WEIGHT * L_disp_reg\n", " + VELOCITY_L2_WEIGHT * L_vel_l2\n", " + DELTA_MSE_WEIGHT * L_delta_mse\n", " + DELTA_COS_WEIGHT * L_delta_cos\n", " + DELTA_MMD_WEIGHT * L_delta_mmd\n", " + DELTA_BULK_WEIGHT * L_delta_bulk\n", "```\n", "\n", "- **L_recon(重构损失)**: latent autoencoder 对 **source 与 target** 表达的重构 MSE,即 `MSE(D(E(x)), x)` 在 control 与 perturbed 样本上求和(见 `source_recon` / `target_recon`)。默认权重下调,避免重构主导、挤压对扰动效应的学习。\n", "- **L_flow(流匹配损失)**: 预测速度场与 OT 重心目标速度 `z1_bar - z0` 的 MSE。\n", "- **L_endpoint(潜空间 endpoint 损失)**: 由 `z_t` 与预测速度一步推到 `t=1` 的潜向量 `pred_z1`,与 Sinkhorn 重心 `z1_bar` 的 MSE。\n", "- **L_mmd**: 解码后表达空间上,预测 batch 与真实 perturbed 细胞的 RBF-MMD(与扰动后分布对齐)。\n", "- **L_mean**: 解码后按细胞维平均的 pseudo-bulk 向量与真实 perturbed 均值的 MSE。\n", "- **L_ot_decode_mmd / L_ot_decode_mean**: 将 **Sinkhorn 潜空间重心** `z1_bar` 经解码器映射到表达空间,与当前 batch 的真实 perturbed 细胞做 **MMD** 与 **pseudo-bulk MSE**;约束「OT 传输几何」经解码器落到观测空间(与仅监督 `pred_z1` 路径互补)。\n", "- **L_disp_reg**: `mean(||pred_z1 - z0||^2)`,抑制过大的 latent 位移(CellOT-style 运输幅度先验)。\n", "- **L_vel_l2**: `mean(||v||^2)`,可选,惩罚过强的速度场。\n", "- **L_delta_mse / L_delta_cos**: 以配对 control 为基线,`pred_x1 - source_x` 与 `target_x - source_x` 的逐细胞 MSE 与 `1 - cos`(强化扰动方向与幅度,而非仅拟合绝对表达)。\n", "- **L_delta_mmd**: 预测 delta 与真实 delta 在基因维上的 RBF-MMD(与扰动一致的分布形状)。\n", "- **L_delta_bulk**: 基因维平均 delta(pseudo-bulk 扰动向量)的 MSE,与评估中的 delta 指标更直接相关。\n" ] }, { "cell_type": "code", 
"execution_count": 1, "id": "1c11f06f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0m" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Project root: /root/final\n", "Input h5ad: /root/final/data/processed/scPerturb/prediction_ready/replogle_k562_essential_single_gene_prediction_ready_full.h5ad\n", "Output dir: /root/final/baseline/outputs/conditional_flow_matching_method\n", "Device: cuda\n" ] } ], "source": [ "! pip install -q anndata numpy pandas scipy torch matplotlib\n", "from pathlib import Path\n", "from copy import deepcopy\n", "import json\n", "import math\n", "import random\n", "\n", "import anndata as ad\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import scipy.sparse as sp\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "def find_project_root(start: Path | None = None) -> Path:\n", " start = (start or Path.cwd()).resolve()\n", " for path in [start, *start.parents]:\n", " if (path / \"src\" / \"data\").exists() and (path / \"baseline\").exists():\n", " return path\n", " return start\n", "\n", "\n", "PROJECT_ROOT = find_project_root()\n", "DATASET_NAME = \"ReplogleWeissman2022_K562_essential\"\n", "METHOD_NAME = \"conditional_flow_matching_method\"\n", "RANDOM_SEED = 42\n", "\n", "SPLIT_STRATEGY = \"seen_cell_split\" # choices: \"seen_cell_split\", \"use_existing\"\n", "TRAIN_FRAC = 0.80\n", "VAL_FRAC = 0.10\n", "EVAL_SPLIT = \"test\"\n", "CONTROL_SPLIT = \"train\"\n", "TRAIN_SPLIT 
= \"train\"\n", "VAL_SPLIT = \"val\"\n", "\n", "USE_FEATURE_GENES = True\n", "FEATURE_GENE_COLUMNS = [\"is_feature_gene\", \"highly_variable\"]\n", "\n", "# Optional runtime caps. Set to None for the full run.\n", "MAX_TRAIN_CONDITIONS = None\n", "MAX_EVAL_GENES = None\n", "MAX_EVAL_CELLS_PER_GENE = None\n", "\n", "LATENT_DIM = 128\n", "HIDDEN_DIM = 512\n", "CONDITION_DIM = 256\n", "TIME_DIM = 128\n", "DROPOUT = 0.08\n", "NUM_FLOW_BLOCKS = 6\n", "\n", "BATCH_SIZE = 128\n", "EPOCHS = 20\n", "STEPS_PER_EPOCH = 200\n", "VALIDATION_STEPS = 48\n", "LEARNING_RATE = 1e-3\n", "WEIGHT_DECAY = 1e-5\n", "GRAD_CLIP_NORM = 1.0\n", "\n", "RECON_WEIGHT = 1.0\n", "FLOW_WEIGHT = 1.5\n", "ENDPOINT_WEIGHT = 0.5\n", "MMD_WEIGHT = 0.1\n", "MEAN_WEIGHT = 1.0\n", "DELTA_MSE_WEIGHT = 0.0\n", "DELTA_COS_WEIGHT = 0.0\n", "DELTA_MMD_WEIGHT = 0.0\n", "DELTA_BULK_WEIGHT = 0.0\n", "\n", "# --- CellOT-style OT coupling (same mini-batch Sinkhorn z1_bar; extra decode + transport priors) ---\n", "OT_DECODE_MMD_WEIGHT = 0.05 # MMD(decode(z1_bar)), target_x) ties decoder to OT geometry\n", "OT_DECODE_MEAN_WEIGHT = 0.25 # pseudo-bulk for decode(z1_bar) vs target\n", "DISPLACEMENT_REG_WEIGHT = 0.01 # mean ||pred_z1 - z0||^2 — discourages extreme latent jumps\n", "VELOCITY_L2_WEIGHT = 0.0 # mean ||v||^2; set e.g. 
1e-4 to penalize large speeds\n", "\n", "NOISE_STD = 0.05\n", "SINKHORN_EPSILON = 0.5\n", "SINKHORN_ITERS = 30\n", "INFERENCE_STEPS = 24\n", "PRED_BATCH_SIZE = 256\n", "\n", "# --- Optimization & validation (tuning) ---\n", "LR_SCHEDULER = \"plateau\" # \"none\", \"cosine\", \"plateau\", \"onecycle\"\n", "PLATEAU_FACTOR = 0.5\n", "PLATEAU_PATIENCE = 3\n", "PLATEAU_MIN_LR = 1e-6\n", "COSINE_MIN_LR = 1e-6\n", "EARLY_STOP_PATIENCE = 5 # 0 = disabled\n", "FIXED_VAL_SUBSET = True\n", "FIXED_VAL_BATCH_SEED = True\n", "VAL_BATCH_SEED = RANDOM_SEED + 777\n", "FLOW_LR_MULT = 1.5\n", "UNCERTAINTY_SAMPLES = 1 # >1: stochastic control resampling; metrics get pred_stochastic_std_mean\n", "\n", "INPUT_CANDIDATES = [\n", " PROJECT_ROOT / \"data\" / \"processed\" / \"scPerturb\" / \"prediction_ready\" / \"replogle_k562_essential_single_gene_prediction_ready_full.h5ad\",\n", " PROJECT_ROOT / \"data\" / \"processed\" / \"scPerturb\" / \"prediction_ready\" / \"replogle_k562_essential_single_gene_prediction_ready.h5ad\",\n", " PROJECT_ROOT / \"data\" / \"processed\" / \"scPerturb\" / \"prediction_ready\" / \"replogle_k562_essential_single_gene_prediction_ready_fast_debug.h5ad\",\n", "]\n", "INPUT_PATH = next((path for path in INPUT_CANDIDATES if path.exists()), None)\n", "if INPUT_PATH is None:\n", " raise FileNotFoundError(\"Could not find a prediction-ready Replogle h5ad. 
Checked:\\n\" + \"\\n\".join(map(str, INPUT_CANDIDATES)))\n", "\n", "OUTPUT_DIR = PROJECT_ROOT / \"baseline\" / \"outputs\" / METHOD_NAME\n", "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n", "MODEL_PATH = OUTPUT_DIR / \"conditional_flow_matching_model.pt\"\n", "\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "np.random.seed(RANDOM_SEED)\n", "random.seed(RANDOM_SEED)\n", "torch.manual_seed(RANDOM_SEED)\n", "if torch.cuda.is_available():\n", " torch.cuda.manual_seed_all(RANDOM_SEED)\n", "if torch.backends.cudnn.is_available():\n", " torch.backends.cudnn.deterministic = True\n", " torch.backends.cudnn.benchmark = False\n", "\n", "print(f\"Project root: {PROJECT_ROOT}\")\n", "print(f\"Input h5ad: {INPUT_PATH}\")\n", "print(f\"Output dir: {OUTPUT_DIR}\")\n", "print(f\"Device: {DEVICE}\")\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "e4871faf", "metadata": {}, "outputs": [], "source": [ "SCIENCE_COLORS = {\n", " \"blue\": \"#0072B2\",\n", " \"orange\": \"#E69F00\",\n", " \"green\": \"#009E73\",\n", " \"magenta\": \"#CC79A7\",\n", " \"dark\": \"#333333\",\n", " \"gray\": \"#999999\",\n", "}\n", "\n", "\n", "def setup_science_style() -> None:\n", " mpl.rcParams.update(\n", " {\n", " \"figure.dpi\": 110,\n", " \"savefig.dpi\": 300,\n", " \"font.family\": \"sans-serif\",\n", " \"font.sans-serif\": [\"DejaVu Sans\", \"Helvetica\", \"Arial\"],\n", " \"font.size\": 9,\n", " \"axes.titlesize\": 10,\n", " \"axes.labelsize\": 9,\n", " \"axes.linewidth\": 0.6,\n", " \"axes.edgecolor\": SCIENCE_COLORS[\"dark\"],\n", " \"axes.labelcolor\": SCIENCE_COLORS[\"dark\"],\n", " \"xtick.labelsize\": 8,\n", " \"ytick.labelsize\": 8,\n", " \"xtick.major.width\": 0.6,\n", " \"ytick.major.width\": 0.6,\n", " \"legend.frameon\": False,\n", " \"legend.fontsize\": 8,\n", " \"axes.spines.top\": False,\n", " \"axes.spines.right\": False,\n", " \"figure.facecolor\": \"white\",\n", " \"axes.facecolor\": \"white\",\n", " \"grid.alpha\": 0.3,\n", 
" \"grid.linewidth\": 0.4,\n", " \"lines.linewidth\": 1.25,\n", " \"pdf.fonttype\": 42,\n", " \"ps.fonttype\": 42,\n", " }\n", " )\n", "\n", "\n", "FIGURE_DIR = OUTPUT_DIR / \"figures\"\n", "FIGURE_DIR.mkdir(parents=True, exist_ok=True)\n", "setup_science_style()\n", "\n", "\n", "def plot_training_history(history: pd.DataFrame, tag: str) -> None:\n", " setup_science_style()\n", " h = history.copy()\n", " if h.empty:\n", " return\n", " fig, axes = plt.subplots(1, 2, figsize=(7.0, 2.45), constrained_layout=True)\n", " e = h[\"epoch\"].to_numpy()\n", " ax = axes[0]\n", " if \"train_total\" in h.columns:\n", " ax.plot(\n", " e,\n", " h[\"train_total\"],\n", " \"o-\",\n", " ms=3,\n", " lw=1.1,\n", " color=SCIENCE_COLORS[\"blue\"],\n", " label=\"Train (total)\",\n", " )\n", " if \"val_total\" in h.columns and h[\"val_total\"].notna().any():\n", " ax.plot(\n", " e,\n", " h[\"val_total\"],\n", " \"s-\",\n", " ms=3,\n", " lw=1.1,\n", " color=SCIENCE_COLORS[\"orange\"],\n", " label=\"Validation (total)\",\n", " )\n", " best_i = int(h[\"val_total\"].idxmin())\n", " best_ep = float(h.loc[best_i, \"epoch\"])\n", " best_v = float(h.loc[best_i, \"val_total\"])\n", " ax.scatter(\n", " [best_ep],\n", " [best_v],\n", " s=55,\n", " zorder=5,\n", " facecolors=\"none\",\n", " edgecolors=SCIENCE_COLORS[\"dark\"],\n", " linewidths=1.1,\n", " label=\"Best val\",\n", " )\n", " ax.set_xlabel(\"Epoch\")\n", " ax.set_ylabel(\"Loss\")\n", " ax.legend(loc=\"upper right\", handlelength=2.0)\n", "\n", " ax2 = axes[1]\n", " for col, lab, c in (\n", " (\"train_flow\", \"Train flow\", SCIENCE_COLORS[\"green\"]),\n", " (\"train_recon\", \"Train reconstruction\", SCIENCE_COLORS[\"magenta\"]),\n", " (\"train_ot_decode_mmd\", \"Train OT-decode MMD\", \"#0072B2\"),\n", " (\"train_disp_reg\", \"Train displacement reg\", \"#009E73\"),\n", " (\"train_vel_l2\", \"Train velocity L2\", \"#CC79A7\"),\n", " (\"train_delta_mse\", \"Train delta MSE\", \"#D55E00\"),\n", " (\"train_delta_mmd\", \"Train delta MMD\", 
\"#6A3D9A\"),\n", " ):\n", " if col in h.columns:\n", " ax2.plot(e, h[col], \"o-\", ms=2.5, lw=1.0, color=c, label=lab)\n", " ax2.set_xlabel(\"Epoch\")\n", " ax2.set_ylabel(\"Component loss\")\n", " ax2.set_yscale(\"symlog\", linthresh=1e-4)\n", " ax2.legend(loc=\"upper right\", handlelength=2.0)\n", "\n", " fig.suptitle(\n", " \"Training dynamics — conditional flow matching\",\n", " fontsize=10.5,\n", " color=SCIENCE_COLORS[\"dark\"],\n", " y=1.02,\n", " )\n", " for ext in (\"png\", \"pdf\"):\n", " fig.savefig(FIGURE_DIR / f\"training_history_{tag}.{ext}\", bbox_inches=\"tight\", facecolor=\"white\")\n", " plt.show()\n", "\n", "\n", "def plot_evaluation_metrics(\n", " metrics: pd.DataFrame,\n", " tag: str,\n", " train_genes: list[str] | None = None,\n", ") -> None:\n", " setup_science_style()\n", " if metrics.empty:\n", " return\n", " m = metrics.copy()\n", " if train_genes is not None:\n", " seen = set(train_genes)\n", " m[\"cohort\"] = np.where(m[\"perturbation_gene\"].isin(seen), \"Seen at train\", \"Unseen at train\")\n", " else:\n", " m[\"cohort\"] = \"All\"\n", "\n", " fig = plt.figure(figsize=(7.2, 5.9), constrained_layout=True)\n", " gs = fig.add_gridspec(2, 2, height_ratios=[1.05, 1.0], hspace=0.28, wspace=0.28)\n", "\n", " ax1 = fig.add_subplot(gs[0, 0])\n", " if m[\"cohort\"].nunique() > 1:\n", " palette = {\"Seen at train\": SCIENCE_COLORS[\"blue\"], \"Unseen at train\": SCIENCE_COLORS[\"orange\"]}\n", " for lab, g in m.groupby(\"cohort\"):\n", " s = 14 + np.clip(g[\"n_cells\"].to_numpy(dtype=float), 0, 120) * 0.12\n", " ax1.scatter(\n", " g[\"mse_mean\"],\n", " g[\"pearson_mean\"],\n", " s=s,\n", " alpha=0.78,\n", " edgecolors=\"none\",\n", " c=palette.get(lab, SCIENCE_COLORS[\"gray\"]),\n", " label=lab,\n", " )\n", " ax1.legend(loc=\"lower left\")\n", " else:\n", " ax1.scatter(\n", " m[\"mse_mean\"],\n", " m[\"pearson_mean\"],\n", " s=22,\n", " alpha=0.75,\n", " edgecolors=\"none\",\n", " c=SCIENCE_COLORS[\"blue\"],\n", " )\n", " ax1.set_xlabel(\"MSE 
(predicted vs true, gene means)\")\n", " ax1.set_ylabel(\"Pearson r (gene means)\")\n", " ax1.set_xscale(\"log\")\n", " ax1.text(0.02, 0.98, \"a\", transform=ax1.transAxes, fontsize=11, fontweight=\"bold\", va=\"top\")\n", "\n", " ax2 = fig.add_subplot(gs[0, 1])\n", " m_sorted = m.sort_values(\"pearson_mean\", ascending=True)\n", " max_rows = 36\n", " if len(m_sorted) > max_rows:\n", " head_n = max_rows // 2\n", " plot_df = pd.concat([m_sorted.head(head_n), m_sorted.tail(head_n)], axis=0)\n", " else:\n", " plot_df = m_sorted\n", " y = np.arange(len(plot_df))\n", " ax2.barh(y, plot_df[\"pearson_mean\"], height=0.68, color=SCIENCE_COLORS[\"blue\"], alpha=0.88)\n", " ax2.set_yticks(y)\n", " ax2.set_yticklabels(plot_df[\"perturbation_gene\"], fontsize=5.8)\n", " ax2.set_xlabel(\"Pearson r (gene means)\")\n", " ax2.set_xlim(0, 1.02)\n", " ax2.text(0.02, 0.98, \"b\", transform=ax2.transAxes, fontsize=11, fontweight=\"bold\", va=\"top\")\n", "\n", " ax3 = fig.add_subplot(gs[1, 0])\n", " d = m[\"delta_l2\"].dropna().to_numpy()\n", " if d.size:\n", " nbin = int(np.clip(len(d) // 2, 10, 32))\n", " ax3.hist(\n", " d,\n", " bins=nbin,\n", " color=SCIENCE_COLORS[\"green\"],\n", " alpha=0.88,\n", " edgecolor=\"white\",\n", " linewidth=0.35,\n", " )\n", " ax3.set_xlabel(\"L2 distance (pred delta vs true delta)\")\n", " ax3.set_ylabel(\"Perturbations\")\n", " ax3.text(0.02, 0.98, \"c\", transform=ax3.transAxes, fontsize=11, fontweight=\"bold\", va=\"top\")\n", "\n", " ax4 = fig.add_subplot(gs[1, 1])\n", " ax4.scatter(\n", " m[\"pearson_mean\"],\n", " m[\"pearson_delta\"],\n", " s=20,\n", " c=SCIENCE_COLORS[\"orange\"],\n", " alpha=0.72,\n", " edgecolors=\"none\",\n", " )\n", " pm = m[\"pearson_mean\"].to_numpy(dtype=float)\n", " pd_ = m[\"pearson_delta\"].to_numpy(dtype=float)\n", " lo = float(np.nanmin([np.nanmin(pm), np.nanmin(pd_), 0.0]))\n", " hi = float(np.nanmax([np.nanmax(pm), np.nanmax(pd_), 1.0]))\n", " pad = 0.03 * (hi - lo + 1e-6)\n", " ax4.plot([lo, hi], [lo, hi], 
ls=\"--\", color=SCIENCE_COLORS[\"gray\"], lw=0.9, label=\"Identity\")\n", " ax4.set_xlim(lo - pad, hi + pad)\n", " ax4.set_ylim(lo - pad, hi + pad)\n", " ax4.set_xlabel(\"Pearson r (means)\")\n", " ax4.set_ylabel(\"Pearson r (delta vs control)\")\n", " ax4.legend(loc=\"lower right\")\n", " ax4.text(0.02, 0.98, \"d\", transform=ax4.transAxes, fontsize=11, fontweight=\"bold\", va=\"top\")\n", "\n", " fig.suptitle(\n", " \"Held-out evaluation — per-perturbation metrics\",\n", " fontsize=10.5,\n", " color=SCIENCE_COLORS[\"dark\"],\n", " )\n", " for ext in (\"png\", \"pdf\"):\n", " fig.savefig(FIGURE_DIR / f\"evaluation_metrics_{tag}.{ext}\", bbox_inches=\"tight\", facecolor=\"white\")\n", " plt.show()\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "b3c3b561", "metadata": {}, "outputs": [], "source": [ "def normalize_bool(series: pd.Series) -> pd.Series:\n", " if pd.api.types.is_bool_dtype(series):\n", " return series.fillna(False).astype(bool)\n", " text = series.astype(str).str.lower().str.strip()\n", " return text.isin([\"true\", \"1\", \"yes\", \"y\"])\n", "\n", "\n", "def require_columns(obs: pd.DataFrame, columns: list[str]) -> None:\n", " missing = [col for col in columns if col not in obs.columns]\n", " if missing:\n", " raise KeyError(f\"Missing required obs columns: {missing}\")\n", "\n", "\n", "def take_dense_rows(X, idx: np.ndarray) -> np.ndarray:\n", " out = X[idx]\n", " if sp.issparse(out):\n", " out = out.toarray()\n", " return np.asarray(out, dtype=np.float32)\n", "\n", "\n", "def mean_vector(X) -> np.ndarray:\n", " if sp.issparse(X):\n", " return np.asarray(X.mean(axis=0)).ravel().astype(np.float32)\n", " return np.asarray(X, dtype=np.float32).mean(axis=0)\n", "\n", "\n", "def safe_pearson(x: np.ndarray, y: np.ndarray) -> float:\n", " x = np.asarray(x).ravel()\n", " y = np.asarray(y).ravel()\n", " if x.size < 2 or np.std(x) == 0 or np.std(y) == 0:\n", " return np.nan\n", " return float(np.corrcoef(x, y)[0, 1])\n", "\n", "\n", "def 
split_array_indices(indices: np.ndarray, rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n", " indices = np.asarray(indices)\n", " if len(indices) == 0:\n", " return indices, indices, indices\n", " shuffled = indices.copy()\n", " rng.shuffle(shuffled)\n", " if len(shuffled) < 3:\n", " return np.sort(shuffled), np.array([], dtype=indices.dtype), np.array([], dtype=indices.dtype)\n", " n_train = max(1, int(round(len(shuffled) * TRAIN_FRAC)))\n", " n_val = int(round(len(shuffled) * VAL_FRAC))\n", " n_train = min(n_train, len(shuffled) - 2)\n", " n_val = max(1, min(n_val, len(shuffled) - n_train - 1))\n", " train = np.sort(shuffled[:n_train])\n", " val = np.sort(shuffled[n_train : n_train + n_val])\n", " test = np.sort(shuffled[n_train + n_val :])\n", " if len(test) == 0:\n", " test = val[-1:].copy()\n", " val = val[:-1]\n", " if len(val) == 0:\n", " val = train[-1:].copy()\n", " train = train[:-1]\n", " return train, val, test\n", "\n", "\n", "def assign_seen_cell_split(obs: pd.DataFrame, seed: int) -> pd.Categorical:\n", " rng = np.random.default_rng(seed)\n", " split = np.full(len(obs), \"train\", dtype=object)\n", " is_control = normalize_bool(obs[\"is_control\"]).to_numpy(dtype=bool)\n", " genes = obs[\"perturbation_gene\"].astype(str).to_numpy()\n", "\n", " control_idx = np.where(is_control)[0]\n", " train_idx, val_idx, test_idx = split_array_indices(control_idx, rng)\n", " split[train_idx] = \"train\"\n", " split[val_idx] = \"val\"\n", " split[test_idx] = \"test\"\n", "\n", " for gene in sorted(set(genes[~is_control])):\n", " idx = np.where((~is_control) & (genes == gene))[0]\n", " train_idx, val_idx, test_idx = split_array_indices(idx, rng)\n", " split[train_idx] = \"train\"\n", " split[val_idx] = \"val\"\n", " split[test_idx] = \"test\"\n", "\n", " return pd.Categorical(split, categories=[\"train\", \"val\", \"test\"], ordered=True)\n", "\n", "\n", "def choose_feature_mask(var: pd.DataFrame) -> np.ndarray:\n", " if not 
USE_FEATURE_GENES:\n", " return np.ones(var.shape[0], dtype=bool)\n", " for column in FEATURE_GENE_COLUMNS:\n", " if column in var.columns:\n", " mask = var[column].astype(bool).to_numpy()\n", " if mask.any():\n", " print(f\"Using feature mask from var['{column}']: {int(mask.sum())} genes\")\n", " return mask\n", " print(\"Feature-gene columns not found or empty; using all genes.\")\n", " return np.ones(var.shape[0], dtype=bool)\n", "\n", "\n", "def median_heuristic_sigma(x: torch.Tensor, y: torch.Tensor, max_points: int = 256) -> float:\n", " z = torch.cat([x[:max_points], y[:max_points]], dim=0)\n", " if z.shape[0] < 2:\n", " return 1.0\n", " d2 = torch.cdist(z, z, p=2).pow(2)\n", " mask = ~torch.eye(d2.shape[0], dtype=torch.bool, device=d2.device)\n", " valid = d2[mask]\n", " valid = valid[valid > 0]\n", " if valid.numel() == 0:\n", " return 1.0\n", " return float(torch.sqrt(valid.median()).item())\n", "\n", "\n", "def rbf_mmd_loss(x: torch.Tensor, y: torch.Tensor, sigma: float | None = None) -> torch.Tensor:\n", " if x.shape[0] < 2 or y.shape[0] < 2:\n", " return x.new_tensor(0.0)\n", " sigma = sigma or median_heuristic_sigma(x, y)\n", " gamma = 1.0 / (2.0 * sigma * sigma + 1e-8)\n", " d_xx = torch.cdist(x, x, p=2).pow(2)\n", " d_yy = torch.cdist(y, y, p=2).pow(2)\n", " d_xy = torch.cdist(x, y, p=2).pow(2)\n", " k_xx = torch.exp(-gamma * d_xx)\n", " k_yy = torch.exp(-gamma * d_yy)\n", " k_xy = torch.exp(-gamma * d_xy)\n", " return k_xx.mean() + k_yy.mean() - 2.0 * k_xy.mean()\n", "\n", "\n", "def sinkhorn_barycentric_targets(x0: torch.Tensor, x1: torch.Tensor, epsilon: float, n_iters: int) -> torch.Tensor:\n", " cost = torch.cdist(x0, x1, p=2).pow(2)\n", " kernel = torch.exp(-cost / max(epsilon, 1e-4)).clamp_min(1e-8)\n", "\n", " a = torch.full((x0.shape[0],), 1.0 / x0.shape[0], device=x0.device, dtype=x0.dtype)\n", " b = torch.full((x1.shape[0],), 1.0 / x1.shape[0], device=x1.device, dtype=x1.dtype)\n", " u = torch.ones_like(a)\n", " v = torch.ones_like(b)\n", 
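 "    # Sinkhorn scaling iterations below: alternately rescale u and v so that the\n", "    # transport plan diag(u) @ kernel @ diag(v) matches the uniform marginals a (rows)\n", "    # and b (columns); row-normalizing the plan then yields, for each source point,\n", "    # a barycentric target inside x1.\n",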
"\n", " for _ in range(n_iters):\n", " u = a / (kernel @ v + 1e-8)\n", " v = b / (kernel.t() @ u + 1e-8)\n", "\n", " plan = u[:, None] * kernel * v[None, :]\n", " row_mass = plan.sum(dim=1, keepdim=True).clamp_min(1e-8)\n", " return (plan / row_mass) @ x1\n", "\n", "\n", "def build_aggregate_frame(metrics: pd.DataFrame) -> pd.DataFrame:\n", " numeric = metrics.select_dtypes(include=[np.number])\n", " if numeric.empty:\n", " return pd.DataFrame()\n", " return numeric.agg([\"mean\", \"median\"]).T\n", "\n", "\n", "@torch.no_grad()\n", "def rk2_integrate_latent(model, z0, gene_idx, batch_idx, steps: int) -> torch.Tensor:\n", " z = z0\n", " dt = 1.0 / steps\n", " for step in range(steps):\n", " t0 = torch.full((z.shape[0],), step / steps, device=z.device, dtype=z.dtype)\n", " k1 = model.velocity(z, t0, gene_idx, batch_idx)\n", " z_mid = z + 0.5 * dt * k1\n", " t_mid = torch.full((z.shape[0],), (step + 0.5) / steps, device=z.device, dtype=z.dtype)\n", " k2 = model.velocity(z_mid, t_mid, gene_idx, batch_idx)\n", " z = z + dt * k2\n", " return z\n" ] }, { "cell_type": "markdown", "id": "1296b782", "metadata": {}, "source": [ "## Load And Prepare Data\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "b1ead7e2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using feature mask from var['is_feature_gene']: 4334 genes\n", "adata: AnnData object with n_obs × n_vars = 310385 × 4334\n", " obs: 'batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id', 'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex', 'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type', 'ncounts', 'ngenes', 'nperts', 'percent_ribo', 'perturbation_label', 'perturbation_gene', 'is_control', 'is_single_perturbation', 'split', 'condition', 'batch_model', 'source_split'\n", " var: 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 
'std', 'cv', 'fano', 'ensembl_id', 'ncounts', 'ncells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'is_target_gene', 'is_feature_gene'\n", " uns: 'hvg', 'log1p', 'prediction_ready'\n", " layers: 'counts'\n", "split counts:\n", "split\n", "train 248311\n", "test 31046\n", "val 31028\n", "Name: count, dtype: int64\n", "train conditions: 2056\n", "val conditions: 2056\n", "eval conditions: 2056\n", "eval genes missing from training: 0\n" ] } ], "source": [ "adata = ad.read_h5ad(INPUT_PATH)\n", "require_columns(adata.obs, [\"perturbation_gene\", \"is_control\", \"split\"])\n", "\n", "adata.obs[\"is_control\"] = normalize_bool(adata.obs[\"is_control\"])\n", "adata.obs[\"perturbation_gene\"] = adata.obs[\"perturbation_gene\"].astype(str).str.strip()\n", "adata.obs[\"perturbation_gene\"] = adata.obs[\"perturbation_gene\"].mask(adata.obs[\"is_control\"], \"__ctrl__\")\n", "adata.obs[\"condition\"] = adata.obs[\"perturbation_gene\"].astype(str)\n", "adata.obs[\"batch_model\"] = adata.obs[\"batch\"].astype(str) if \"batch\" in adata.obs.columns else \"global\"\n", "\n", "feature_mask = choose_feature_mask(adata.var)\n", "adata = adata[:, feature_mask].copy()\n", "\n", "if SPLIT_STRATEGY == \"seen_cell_split\":\n", " adata.obs[\"source_split\"] = adata.obs[\"split\"].astype(str)\n", " adata.obs[\"split\"] = assign_seen_cell_split(adata.obs, RANDOM_SEED)\n", "elif SPLIT_STRATEGY == \"use_existing\":\n", " adata.obs[\"split\"] = pd.Categorical(\n", " adata.obs[\"split\"].astype(str),\n", " categories=[\"train\", \"val\", \"test\"],\n", " ordered=True,\n", " )\n", "else:\n", " raise ValueError(f\"Unknown SPLIT_STRATEGY: {SPLIT_STRATEGY}\")\n", "\n", "is_control = adata.obs[\"is_control\"].to_numpy(dtype=bool)\n", "split = adata.obs[\"split\"].astype(str).to_numpy()\n", "genes = adata.obs[\"perturbation_gene\"].astype(str).to_numpy()\n", "batches = adata.obs[\"batch_model\"].astype(str).to_numpy()\n", "\n", "control_train_idx = 
np.where(is_control & (split == CONTROL_SPLIT))[0]\n", "if len(control_train_idx) == 0:\n", " print(f\"No controls found in CONTROL_SPLIT={CONTROL_SPLIT}; falling back to all controls.\")\n", " control_train_idx = np.where(is_control)[0]\n", "if len(control_train_idx) == 0:\n", " raise ValueError(\"No control cells found.\")\n", "\n", "train_conditions = sorted(set(genes[(~is_control) & (split == TRAIN_SPLIT)]))\n", "val_conditions = sorted(set(genes[(~is_control) & (split == VAL_SPLIT)]))\n", "eval_conditions = sorted(set(genes[(~is_control) & (split == EVAL_SPLIT)]))\n", "\n", "if MAX_TRAIN_CONDITIONS is not None:\n", " train_conditions = train_conditions[:MAX_TRAIN_CONDITIONS]\n", " train_condition_set = set(train_conditions)\n", " keep_mask = is_control | np.isin(genes, list(train_condition_set)) | np.isin(genes, eval_conditions) | np.isin(genes, val_conditions)\n", " adata = adata[keep_mask].copy()\n", " is_control = adata.obs[\"is_control\"].to_numpy(dtype=bool)\n", " split = adata.obs[\"split\"].astype(str).to_numpy()\n", " genes = adata.obs[\"perturbation_gene\"].astype(str).to_numpy()\n", " batches = adata.obs[\"batch_model\"].astype(str).to_numpy()\n", " control_train_idx = np.where(is_control & (split == CONTROL_SPLIT))[0]\n", " train_conditions = sorted(set(genes[(~is_control) & (split == TRAIN_SPLIT)]))\n", " val_conditions = sorted(set(genes[(~is_control) & (split == VAL_SPLIT)]))\n", " eval_conditions = sorted(set(genes[(~is_control) & (split == EVAL_SPLIT)]))\n", "\n", "if MAX_EVAL_GENES is not None:\n", " eval_conditions = eval_conditions[:MAX_EVAL_GENES]\n", "\n", "indices_by_split_gene = {}\n", "for split_name in [TRAIN_SPLIT, VAL_SPLIT, EVAL_SPLIT]:\n", " indices_by_split_gene[split_name] = {}\n", " split_mask = (~is_control) & (split == split_name)\n", " for gene in sorted(set(genes[split_mask])):\n", " indices = np.where(split_mask & (genes == gene))[0]\n", " if len(indices) > 0:\n", " indices_by_split_gene[split_name][gene] = indices\n", 
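 "# indices_by_split_gene[split][gene] now holds the obs-row indices of perturbed\n", "# cells for that gene in that split (genes with no cells were skipped above).\n",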
"\n", "control_idx_by_batch = {}\n", "for batch_name in sorted(set(batches[control_train_idx])):\n", " batch_idx = control_train_idx[batches[control_train_idx] == batch_name]\n", " if len(batch_idx) > 0:\n", " control_idx_by_batch[batch_name] = batch_idx\n", "\n", "gene_vocab = [\"__ctrl__\", \"__unk__\", *train_conditions]\n", "gene_to_idx = {gene: idx for idx, gene in enumerate(gene_vocab)}\n", "batch_vocab = [\"__unk__\", *sorted(set(batches))]\n", "batch_to_idx = {batch: idx for idx, batch in enumerate(batch_vocab)}\n", "\n", "gene_index_per_obs = np.array([gene_to_idx.get(gene, gene_to_idx[\"__unk__\"]) for gene in genes], dtype=np.int64)\n", "batch_index_per_obs = np.array([batch_to_idx.get(batch, batch_to_idx[\"__unk__\"]) for batch in batches], dtype=np.int64)\n", "\n", "control_mean = mean_vector(adata.X[control_train_idx])\n", "unseen_eval_genes = sorted(set(eval_conditions) - set(train_conditions))\n", "\n", "print(\"adata:\", adata)\n", "print(\"split counts:\")\n", "print(adata.obs[\"split\"].astype(str).value_counts())\n", "print(\"train conditions:\", len(train_conditions))\n", "print(\"val conditions:\", len(val_conditions))\n", "print(\"eval conditions:\", len(eval_conditions))\n", "print(\"eval genes missing from training:\", len(unseen_eval_genes))\n", "if unseen_eval_genes[:10]:\n", " print(\"first unseen eval genes:\", unseen_eval_genes[:10])\n" ] }, { "cell_type": "markdown", "id": "3fa220a2", "metadata": {}, "source": [ "## Model Definition\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "ec5b92e5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ConditionalFlowMatchingModel(\n", " (encoder): MLPEncoder(\n", " (net): Sequential(\n", " (0): Linear(in_features=4334, out_features=1024, bias=True)\n", " (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.08, inplace=False)\n", " (4): Linear(in_features=1024, out_features=1024, 
bias=True)\n", " (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", " (6): GELU(approximate='none')\n", " (7): Dropout(p=0.08, inplace=False)\n", " (8): Linear(in_features=1024, out_features=512, bias=True)\n", " (9): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (10): GELU(approximate='none')\n", " (11): Dropout(p=0.08, inplace=False)\n", " (12): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", " )\n", " (decoder): MLPDecoder(\n", " (net): Sequential(\n", " (0): Linear(in_features=128, out_features=512, bias=True)\n", " (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.08, inplace=False)\n", " (4): Linear(in_features=512, out_features=1024, bias=True)\n", " (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", " (6): GELU(approximate='none')\n", " (7): Dropout(p=0.08, inplace=False)\n", " (8): Linear(in_features=1024, out_features=1024, bias=True)\n", " (9): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", " (10): GELU(approximate='none')\n", " (11): Dropout(p=0.08, inplace=False)\n", " (12): Linear(in_features=1024, out_features=4334, bias=True)\n", " )\n", " )\n", " (condition_encoder): ConditionEncoder(\n", " (gene_embedding): Embedding(2058, 128)\n", " (batch_embedding): Embedding(49, 64)\n", " (mlp): Sequential(\n", " (0): Linear(in_features=192, out_features=256, bias=True)\n", " (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.08, inplace=False)\n", " (4): Linear(in_features=256, out_features=256, bias=True)\n", " )\n", " )\n", " (time_encoder): TimeEmbedding(\n", " (proj): Sequential(\n", " (0): Linear(in_features=128, out_features=128, bias=True)\n", " (1): SiLU()\n", " (2): Linear(in_features=128, out_features=128, bias=True)\n", " )\n", " )\n", " (vector_field): ConditionalVectorField(\n", " (input_proj): Linear(in_features=256, out_features=512, 
bias=True)\n", " (blocks): ModuleList(\n", " (0-5): 6 x FiLMResidualBlock(\n", " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (fc1): Linear(in_features=512, out_features=512, bias=True)\n", " (fc2): Linear(in_features=512, out_features=512, bias=True)\n", " (film1): Linear(in_features=256, out_features=1024, bias=True)\n", " (film2): Linear(in_features=256, out_features=1024, bias=True)\n", " (dropout): Dropout(p=0.08, inplace=False)\n", " )\n", " )\n", " (out): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", ")\n", "LR scheduler: ReduceLROnPlateau | VAL_EVAL_SUBSET: 48 | flow_lr_mult: 1.5\n" ] } ], "source": [ "class MLPEncoder(nn.Module):\n", " def __init__(self, input_dim: int, latent_dim: int, hidden_dim: int, dropout: float):\n", " super().__init__()\n", " h2 = hidden_dim * 2\n", " self.net = nn.Sequential(\n", " nn.Linear(input_dim, h2),\n", " nn.LayerNorm(h2),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(h2, h2),\n", " nn.LayerNorm(h2),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(h2, hidden_dim),\n", " nn.LayerNorm(hidden_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_dim, latent_dim),\n", " )\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " return self.net(x)\n", "\n", "\n", "class MLPDecoder(nn.Module):\n", " def __init__(self, latent_dim: int, output_dim: int, hidden_dim: int, dropout: float):\n", " super().__init__()\n", " h2 = hidden_dim * 2\n", " self.net = nn.Sequential(\n", " nn.Linear(latent_dim, hidden_dim),\n", " nn.LayerNorm(hidden_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_dim, h2),\n", " nn.LayerNorm(h2),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(h2, h2),\n", " nn.LayerNorm(h2),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(h2, output_dim),\n", " )\n", "\n", " def forward(self, z: torch.Tensor) -> 
torch.Tensor:\n", "        return F.softplus(self.net(z))\n", "\n", "\n", "# Condition h = MLP([gene_emb; batch_emb]) — \"gene + batch\" only (no perturbation-type / cell-type).\n", "class ConditionEncoder(nn.Module):\n", "    def __init__(self, n_genes: int, n_batches: int, condition_dim: int, dropout: float):\n", "        super().__init__()\n", "        gene_dim = max(32, condition_dim // 2)\n", "        batch_dim = max(16, condition_dim // 4)\n", "        self.gene_embedding = nn.Embedding(n_genes, gene_dim)\n", "        self.batch_embedding = nn.Embedding(n_batches, batch_dim)\n", "        self.mlp = nn.Sequential(\n", "            nn.Linear(gene_dim + batch_dim, condition_dim),\n", "            nn.LayerNorm(condition_dim),\n", "            nn.GELU(),\n", "            nn.Dropout(dropout),\n", "            nn.Linear(condition_dim, condition_dim),\n", "        )\n", "\n", "    def forward(self, gene_idx: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor:\n", "        gene_emb = self.gene_embedding(gene_idx)\n", "        batch_emb = self.batch_embedding(batch_idx)\n", "        return self.mlp(torch.cat([gene_emb, batch_emb], dim=-1))\n", "\n", "\n", "class TimeEmbedding(nn.Module):\n", "    def __init__(self, time_dim: int):\n", "        super().__init__()\n", "        self.time_dim = time_dim\n", "        self.proj = nn.Sequential(\n", "            nn.Linear(time_dim, time_dim),\n", "            nn.SiLU(),\n", "            nn.Linear(time_dim, time_dim),\n", "        )\n", "\n", "    def forward(self, t: torch.Tensor) -> torch.Tensor:\n", "        half = self.time_dim // 2\n", "        freqs = torch.exp(\n", "            -math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / max(half, 1)\n", "        )\n", "        angles = t[:, None] * freqs[None, :]\n", "        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)\n", "        if emb.shape[1] < self.time_dim:\n", "            emb = F.pad(emb, (0, self.time_dim - emb.shape[1]))\n", "        return self.proj(emb)\n", "\n", "\n", "class FiLMResidualBlock(nn.Module):\n", "    def __init__(self, hidden_dim: int, condition_dim: int, dropout: float):\n", "        super().__init__()\n", "        self.norm1 = nn.LayerNorm(hidden_dim)\n", "        self.norm2 = 
nn.LayerNorm(hidden_dim)\n", " self.fc1 = nn.Linear(hidden_dim, hidden_dim)\n", " self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n", " self.film1 = nn.Linear(condition_dim, hidden_dim * 2)\n", " self.film2 = nn.Linear(condition_dim, hidden_dim * 2)\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:\n", " gamma1, beta1 = self.film1(condition).chunk(2, dim=-1)\n", " h = self.norm1(x)\n", " h = h * (1 + gamma1) + beta1\n", " h = F.gelu(self.fc1(h))\n", " h = self.dropout(h)\n", "\n", " gamma2, beta2 = self.film2(condition).chunk(2, dim=-1)\n", " h = self.norm2(h)\n", " h = h * (1 + gamma2) + beta2\n", " h = self.fc2(h)\n", " h = self.dropout(h)\n", " return x + h\n", "\n", "\n", "class ConditionalVectorField(nn.Module):\n", " def __init__(self, latent_dim: int, time_dim: int, condition_dim: int, hidden_dim: int, num_blocks: int, dropout: float):\n", " super().__init__()\n", " self.input_proj = nn.Linear(latent_dim + time_dim, hidden_dim)\n", " self.blocks = nn.ModuleList(\n", " [FiLMResidualBlock(hidden_dim, condition_dim, dropout) for _ in range(num_blocks)]\n", " )\n", " self.out = nn.Linear(hidden_dim, latent_dim)\n", "\n", " def forward(self, z_t: torch.Tensor, t_emb: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:\n", " h = self.input_proj(torch.cat([z_t, t_emb], dim=-1))\n", " for block in self.blocks:\n", " h = block(h, condition)\n", " return self.out(h)\n", "\n", "\n", "class ConditionalFlowMatchingModel(nn.Module):\n", " def __init__(self, input_dim: int, latent_dim: int, hidden_dim: int, condition_dim: int, time_dim: int, n_genes: int, n_batches: int, num_blocks: int, dropout: float):\n", " super().__init__()\n", " self.encoder = MLPEncoder(input_dim=input_dim, latent_dim=latent_dim, hidden_dim=hidden_dim, dropout=dropout)\n", " self.decoder = MLPDecoder(latent_dim=latent_dim, output_dim=input_dim, hidden_dim=hidden_dim, dropout=dropout)\n", " self.condition_encoder = 
ConditionEncoder(n_genes=n_genes, n_batches=n_batches, condition_dim=condition_dim, dropout=dropout)\n", " self.time_encoder = TimeEmbedding(time_dim=time_dim)\n", " self.vector_field = ConditionalVectorField(\n", " latent_dim=latent_dim,\n", " time_dim=time_dim,\n", " condition_dim=condition_dim,\n", " hidden_dim=hidden_dim,\n", " num_blocks=num_blocks,\n", " dropout=dropout,\n", " )\n", "\n", " def encode_x(self, x: torch.Tensor) -> torch.Tensor:\n", " return self.encoder(x)\n", "\n", " def decode_z(self, z: torch.Tensor) -> torch.Tensor:\n", " return self.decoder(z)\n", "\n", " def reconstruct(self, x: torch.Tensor) -> torch.Tensor:\n", " return self.decode_z(self.encode_x(x))\n", "\n", " def condition(self, gene_idx: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor:\n", " return self.condition_encoder(gene_idx, batch_idx)\n", "\n", " def velocity(self, z_t: torch.Tensor, t: torch.Tensor, gene_idx: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor:\n", " condition = self.condition(gene_idx, batch_idx)\n", " t_emb = self.time_encoder(t)\n", " return self.vector_field(z_t, t_emb, condition)\n", "\n", "\n", "model = ConditionalFlowMatchingModel(\n", " input_dim=adata.n_vars,\n", " latent_dim=LATENT_DIM,\n", " hidden_dim=HIDDEN_DIM,\n", " condition_dim=CONDITION_DIM,\n", " time_dim=TIME_DIM,\n", " n_genes=len(gene_vocab),\n", " n_batches=len(batch_vocab),\n", " num_blocks=NUM_FLOW_BLOCKS,\n", " dropout=DROPOUT,\n", ").to(DEVICE)\n", "\n", "ae_params = list(model.encoder.parameters()) + list(model.decoder.parameters())\n", "flow_params = (\n", " list(model.vector_field.parameters())\n", " + list(model.time_encoder.parameters())\n", " + list(model.condition_encoder.parameters())\n", ")\n", "optimizer = torch.optim.AdamW(\n", " [\n", " {\"params\": ae_params, \"lr\": LEARNING_RATE},\n", " {\"params\": flow_params, \"lr\": LEARNING_RATE * FLOW_LR_MULT},\n", " ],\n", " weight_decay=WEIGHT_DECAY,\n", ")\n", "\n", "if val_conditions:\n", " _n = 
min(len(val_conditions), VALIDATION_STEPS)\n", " VAL_EVAL_SUBSET = sorted(val_conditions)[:_n]\n", "else:\n", " VAL_EVAL_SUBSET = []\n", "\n", "lr_scheduler = None\n", "if LR_SCHEDULER == \"plateau\":\n", " lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", " optimizer, mode=\"min\", factor=PLATEAU_FACTOR, patience=PLATEAU_PATIENCE, min_lr=PLATEAU_MIN_LR\n", " )\n", "elif LR_SCHEDULER == \"cosine\":\n", " lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=COSINE_MIN_LR)\n", "elif LR_SCHEDULER == \"onecycle\":\n", " total_steps = EPOCHS * STEPS_PER_EPOCH\n", " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", " optimizer,\n", " max_lr=[LEARNING_RATE, LEARNING_RATE * FLOW_LR_MULT],\n", " total_steps=total_steps,\n", " pct_start=0.1,\n", " )\n", "\n", "print(model)\n", "print(\n", " \"LR scheduler:\",\n", " type(lr_scheduler).__name__ if lr_scheduler is not None else \"none\",\n", " \"| VAL_EVAL_SUBSET:\",\n", " len(VAL_EVAL_SUBSET),\n", " \"| flow_lr_mult:\",\n", " FLOW_LR_MULT,\n", ")\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "91749254", "metadata": {}, "outputs": [], "source": [ "def sample_controls_like_targets(target_idx: np.ndarray, rng: np.random.Generator) -> np.ndarray:\n", " sampled = []\n", " for idx in target_idx:\n", " batch_name = batches[idx]\n", " pool = control_idx_by_batch.get(batch_name, control_train_idx)\n", " sampled.append(rng.choice(pool))\n", " return np.asarray(sampled, dtype=np.int64)\n", "\n", "\n", "def resample_unpaired_controls(\n", " source_idx: np.ndarray, target_idx: np.ndarray, rng: np.random.Generator\n", ") -> np.ndarray:\n", " \"\"\"Ensure source (control) indices never equal target (perturbed) indices.\"\"\"\n", " source_idx = source_idx.copy()\n", " for _ in range(64):\n", " clash = source_idx == target_idx\n", " if not np.any(clash):\n", " break\n", " for i in np.flatnonzero(clash):\n", " batch_name = batches[target_idx[i]]\n", " pool = 
control_idx_by_batch.get(batch_name, control_train_idx)\n", " source_idx[i] = rng.choice(pool)\n", " return source_idx\n", "\n", "\n", "def make_condition_batch(condition: str, split_name: str, rng: np.random.Generator) -> dict:\n", " target_pool = indices_by_split_gene[split_name][condition]\n", " size = min(BATCH_SIZE, len(target_pool))\n", " replace = len(target_pool) < size\n", " target_idx = rng.choice(target_pool, size=size, replace=replace)\n", " source_idx = sample_controls_like_targets(target_idx, rng)\n", " source_idx = resample_unpaired_controls(source_idx, target_idx, rng)\n", "\n", " if MAX_EVAL_CELLS_PER_GENE is not None and len(target_idx) > MAX_EVAL_CELLS_PER_GENE:\n", " target_idx = target_idx[:MAX_EVAL_CELLS_PER_GENE]\n", " source_idx = source_idx[:MAX_EVAL_CELLS_PER_GENE]\n", "\n", " source_x = torch.from_numpy(take_dense_rows(adata.X, source_idx)).to(DEVICE)\n", " target_x = torch.from_numpy(take_dense_rows(adata.X, target_idx)).to(DEVICE)\n", " gene_idx = torch.full(\n", " (len(target_idx),),\n", " fill_value=gene_to_idx.get(condition, gene_to_idx[\"__unk__\"]),\n", " device=DEVICE,\n", " dtype=torch.long,\n", " )\n", " batch_idx = torch.as_tensor(batch_index_per_obs[source_idx], device=DEVICE, dtype=torch.long)\n", "\n", " return {\n", " \"condition\": condition,\n", " \"source_idx\": source_idx,\n", " \"target_idx\": target_idx,\n", " \"source_x\": source_x,\n", " \"target_x\": target_x,\n", " \"gene_idx\": gene_idx,\n", " \"batch_idx\": batch_idx,\n", " }\n", "\n", "\n", "def compute_loss_terms(batch: dict) -> dict[str, torch.Tensor]:\n", " source_x = batch[\"source_x\"]\n", " target_x = batch[\"target_x\"]\n", " gene_idx = batch[\"gene_idx\"]\n", " batch_idx = batch[\"batch_idx\"]\n", "\n", " z0 = model.encode_x(source_x)\n", " z1 = model.encode_x(target_x)\n", " z1_bar = sinkhorn_barycentric_targets(\n", " x0=z0,\n", " x1=z1,\n", " epsilon=SINKHORN_EPSILON,\n", " n_iters=SINKHORN_ITERS,\n", " )\n", "\n", " t = torch.rand(z0.shape[0], 
device=DEVICE)\n", " sigma = NOISE_STD * torch.sqrt((t * (1.0 - t)).clamp_min(1e-6)).unsqueeze(-1)\n", " z_t = (1.0 - t).unsqueeze(-1) * z0 + t.unsqueeze(-1) * z1_bar + sigma * torch.randn_like(z0)\n", "\n", " target_velocity = z1_bar - z0\n", " pred_velocity = model.velocity(z_t, t, gene_idx, batch_idx)\n", " pred_z1 = z_t + (1.0 - t).unsqueeze(-1) * pred_velocity\n", " pred_x1 = model.decode_z(pred_z1)\n", " x_ot = model.decode_z(z1_bar)\n", "\n", " source_recon = model.reconstruct(source_x)\n", " target_recon = model.reconstruct(target_x)\n", "\n", " recon_loss = F.mse_loss(source_recon, source_x) + F.mse_loss(target_recon, target_x)\n", " flow_loss = F.mse_loss(pred_velocity, target_velocity)\n", " endpoint_loss = F.mse_loss(pred_z1, z1_bar)\n", " mmd_loss = rbf_mmd_loss(pred_x1, target_x)\n", " mean_loss = F.mse_loss(pred_x1.mean(dim=0), target_x.mean(dim=0))\n", " ot_decode_mmd = rbf_mmd_loss(x_ot, target_x)\n", " ot_decode_mean = F.mse_loss(x_ot.mean(dim=0), target_x.mean(dim=0))\n", " disp_reg = torch.mean((pred_z1 - z0) ** 2)\n", " vel_l2 = torch.mean(pred_velocity ** 2)\n", "\n", " total_loss = (\n", " RECON_WEIGHT * recon_loss\n", " + FLOW_WEIGHT * flow_loss\n", " + ENDPOINT_WEIGHT * endpoint_loss\n", " + MMD_WEIGHT * mmd_loss\n", " + MEAN_WEIGHT * mean_loss\n", " + OT_DECODE_MMD_WEIGHT * ot_decode_mmd\n", " + OT_DECODE_MEAN_WEIGHT * ot_decode_mean\n", " + DISPLACEMENT_REG_WEIGHT * disp_reg\n", " + VELOCITY_L2_WEIGHT * vel_l2\n", " )\n", "\n", " return {\n", " \"total\": total_loss,\n", " \"recon\": recon_loss.detach(),\n", " \"flow\": flow_loss.detach(),\n", " \"endpoint\": endpoint_loss.detach(),\n", " \"mmd\": mmd_loss.detach(),\n", " \"mean\": mean_loss.detach(),\n", " \"ot_decode_mmd\": ot_decode_mmd.detach(),\n", " \"ot_decode_mean\": ot_decode_mean.detach(),\n", " \"disp_reg\": disp_reg.detach(),\n", " \"vel_l2\": vel_l2.detach(),\n", " }\n", "\n", "\n", "@torch.no_grad()\n", "def evaluate_split_loss(\n", " split_name: str,\n", " conditions: 
list[str],\n", " seed: int,\n", " steps: int,\n", " fixed_condition_subset: list[str] | None = None,\n", " batch_rng_seed: int | None = None,\n", ") -> dict[str, float]:\n", " if fixed_condition_subset is not None:\n", " if len(fixed_condition_subset) == 0:\n", " return {}\n", " elif not conditions:\n", " return {}\n", "\n", " model.eval()\n", " if fixed_condition_subset is not None:\n", " subset = fixed_condition_subset[: min(len(fixed_condition_subset), steps)]\n", " else:\n", " rng = np.random.default_rng(seed)\n", " subset = conditions.copy()\n", " rng.shuffle(subset)\n", " subset = subset[: min(len(subset), steps)]\n", "\n", " brng = batch_rng_seed if batch_rng_seed is not None else seed\n", " rng = np.random.default_rng(brng)\n", "\n", " rows = []\n", " for condition in subset:\n", " batch = make_condition_batch(condition, split_name=split_name, rng=rng)\n", " losses = compute_loss_terms(batch)\n", " rows.append({key: float(value.item()) for key, value in losses.items()})\n", "\n", " return pd.DataFrame(rows).mean().to_dict()\n", "\n", "\n", "@torch.no_grad()\n", "def predict_gene_expression(\n", " condition: str,\n", " split_name: str,\n", " rng: np.random.Generator,\n", " n_stochastic: int = 1,\n", ") -> tuple[np.ndarray, np.ndarray, pd.DataFrame, np.ndarray]:\n", " target_idx_all = indices_by_split_gene[split_name][condition]\n", " if MAX_EVAL_CELLS_PER_GENE is not None and len(target_idx_all) > MAX_EVAL_CELLS_PER_GENE:\n", " target_idx_all = target_idx_all[:MAX_EVAL_CELLS_PER_GENE]\n", "\n", " pred_blocks = []\n", " pred_std_blocks = []\n", " real_blocks = []\n", " obs_blocks = []\n", " gene_id = gene_to_idx.get(condition, gene_to_idx[\"__unk__\"])\n", " condition_source = \"trained_gene_embedding\" if condition in gene_to_idx else \"unk_gene_embedding\"\n", "\n", " for start in range(0, len(target_idx_all), PRED_BATCH_SIZE):\n", " target_idx = target_idx_all[start : start + PRED_BATCH_SIZE]\n", " if n_stochastic <= 1:\n", " source_idx = 
sample_controls_like_targets(target_idx, rng)\n", " source_idx = resample_unpaired_controls(source_idx, target_idx, rng)\n", " source_x = torch.from_numpy(take_dense_rows(adata.X, source_idx)).to(DEVICE)\n", " batch_idx = torch.as_tensor(batch_index_per_obs[source_idx], device=DEVICE, dtype=torch.long)\n", " gene_idx = torch.full((len(target_idx),), gene_id, device=DEVICE, dtype=torch.long)\n", " z0 = model.encode_x(source_x)\n", " z1_hat = rk2_integrate_latent(model, z0, gene_idx, batch_idx, steps=INFERENCE_STEPS)\n", " pred_x = model.decode_z(z1_hat).cpu().numpy().astype(np.float32)\n", " pred_std_blocks.append(np.zeros_like(pred_x, dtype=np.float32))\n", " else:\n", " runs = []\n", " for s in range(n_stochastic):\n", " sub_rng = np.random.default_rng(int(rng.integers(0, 2**31 - 1)) + s * 1000003)\n", " source_idx = sample_controls_like_targets(target_idx, sub_rng)\n", " source_idx = resample_unpaired_controls(source_idx, target_idx, sub_rng)\n", " source_x = torch.from_numpy(take_dense_rows(adata.X, source_idx)).to(DEVICE)\n", " batch_idx = torch.as_tensor(batch_index_per_obs[source_idx], device=DEVICE, dtype=torch.long)\n", " gene_idx = torch.full((len(target_idx),), gene_id, device=DEVICE, dtype=torch.long)\n", " z0 = model.encode_x(source_x)\n", " z1_hat = rk2_integrate_latent(model, z0, gene_idx, batch_idx, steps=INFERENCE_STEPS)\n", " runs.append(model.decode_z(z1_hat).cpu().numpy().astype(np.float32))\n", " stacked = np.stack(runs, axis=0)\n", " pred_x = stacked.mean(axis=0)\n", " pred_std_blocks.append(stacked.std(axis=0))\n", "\n", " real_x = take_dense_rows(adata.X, target_idx)\n", " pred_blocks.append(pred_x)\n", " real_blocks.append(real_x)\n", "\n", " obs_chunk = adata.obs.iloc[target_idx][[\"perturbation_gene\", \"condition\", \"split\"]].copy()\n", " obs_chunk[\"target_cell_id\"] = adata.obs_names[target_idx].astype(str)\n", " if n_stochastic <= 1:\n", " obs_chunk[\"sampled_control_cell_id\"] = adata.obs_names[source_idx].astype(str)\n", " 
else:\n", " obs_chunk[\"sampled_control_cell_id\"] = \"stochastic_mean\"\n", " obs_chunk[\"condition_source\"] = condition_source\n", " obs_blocks.append(obs_chunk)\n", "\n", " pred_gene = np.vstack(pred_blocks).astype(np.float32)\n", " pred_std_gene = np.vstack(pred_std_blocks).astype(np.float32)\n", " real_gene = np.vstack(real_blocks).astype(np.float32)\n", " obs_gene = pd.concat(obs_blocks, axis=0)\n", " return pred_gene, real_gene, obs_gene, pred_std_gene\n", "\n", "\n", "@torch.no_grad()\n", "def evaluate_gene_metrics(split_name: str, conditions: list[str], seed: int) -> tuple[pd.DataFrame, ad.AnnData, ad.AnnData]:\n", " rng = np.random.default_rng(seed)\n", "\n", " pred_blocks = []\n", " real_blocks = []\n", " obs_blocks = []\n", " metrics_rows = []\n", "\n", " for condition in conditions:\n", " pred_gene, real_gene, obs_gene, pred_std_gene = predict_gene_expression(\n", " condition=condition, split_name=split_name, rng=rng, n_stochastic=UNCERTAINTY_SAMPLES\n", " )\n", "\n", " pred_blocks.append(pred_gene)\n", " real_blocks.append(real_gene)\n", " obs_blocks.append(obs_gene)\n", "\n", " pred_mean = pred_gene.mean(axis=0)\n", " real_mean = real_gene.mean(axis=0)\n", " pred_delta = pred_mean - control_mean\n", " real_delta = real_mean - control_mean\n", "\n", " row = {\n", " \"perturbation_gene\": condition,\n", " \"n_cells\": int(real_gene.shape[0]),\n", " \"mse_mean\": float(np.mean((pred_mean - real_mean) ** 2)),\n", " \"mae_mean\": float(np.mean(np.abs(pred_mean - real_mean))),\n", " \"pearson_mean\": safe_pearson(pred_mean, real_mean),\n", " \"pearson_delta\": safe_pearson(pred_delta, real_delta),\n", " \"delta_l2\": float(np.linalg.norm(pred_delta - real_delta)),\n", " }\n", " if UNCERTAINTY_SAMPLES > 1:\n", " row[\"pred_stochastic_std_mean\"] = float(np.mean(pred_std_gene))\n", " metrics_rows.append(row)\n", "\n", " pred = ad.AnnData(\n", " X=np.vstack(pred_blocks).astype(np.float32),\n", " obs=pd.concat(obs_blocks, axis=0),\n", " 
var=adata.var.copy(),\n", " )\n", " real = ad.AnnData(\n", " X=np.vstack(real_blocks).astype(np.float32),\n", " obs=pd.concat(obs_blocks, axis=0),\n", " var=adata.var.copy(),\n", " )\n", " metrics = pd.DataFrame(metrics_rows).sort_values(\"perturbation_gene\").reset_index(drop=True)\n", " return metrics, pred, real\n" ] }, { "cell_type": "markdown", "id": "e2bf6516", "metadata": {}, "source": [ "## Train The Conditional Flow Matching Model\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "27f82aeb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'epoch': 1, 'train_total': 0.7261, 'train_flow': 0.2356, 'train_recon': 0.2876, 'val_total': 0.3572, 'lr': 0.001}\n", "{'epoch': 2, 'train_total': 0.322, 'train_flow': 0.0056, 'train_recon': 0.2768, 'val_total': 0.3224, 'lr': 0.001}\n", "{'epoch': 3, 'train_total': 0.3096, 'train_flow': 0.0009, 'train_recon': 0.2731, 'val_total': 0.3219, 'lr': 0.001}\n", "{'epoch': 4, 'train_total': 0.3087, 'train_flow': 0.0006, 'train_recon': 0.2731, 'val_total': 0.3197, 'lr': 0.001}\n", "{'epoch': 5, 'train_total': 0.3108, 'train_flow': 0.0004, 'train_recon': 0.2751, 'val_total': 0.3196, 'lr': 0.001}\n", "{'epoch': 6, 'train_total': 0.3125, 'train_flow': 0.0003, 'train_recon': 0.2764, 'val_total': 0.3186, 'lr': 0.001}\n", "{'epoch': 7, 'train_total': 0.3083, 'train_flow': 0.0002, 'train_recon': 0.2729, 'val_total': 0.3191, 'lr': 0.001}\n", "{'epoch': 8, 'train_total': 0.3069, 'train_flow': 0.0002, 'train_recon': 0.272, 'val_total': 0.3208, 'lr': 0.001}\n", "{'epoch': 9, 'train_total': 0.3083, 'train_flow': 0.0001, 'train_recon': 0.2734, 'val_total': 0.3183, 'lr': 0.001}\n", "{'epoch': 10, 'train_total': 0.312, 'train_flow': 0.0001, 'train_recon': 0.2762, 'val_total': 0.3204, 'lr': 0.001}\n", "{'epoch': 11, 'train_total': 0.3084, 'train_flow': 0.0001, 'train_recon': 0.2737, 'val_total': 0.3186, 'lr': 0.001}\n", "{'epoch': 12, 'train_total': 0.3104, 'train_flow': 0.0001, 'train_recon': 
0.2757, 'val_total': 0.3184, 'lr': 0.001}\n", "{'epoch': 13, 'train_total': 0.3096, 'train_flow': 0.0001, 'train_recon': 0.2749, 'val_total': 0.319, 'lr': 0.0005}\n", "{'epoch': 14, 'train_total': 0.3072, 'train_flow': 0.0, 'train_recon': 0.2732, 'val_total': 0.3175, 'lr': 0.0005}\n", "{'epoch': 15, 'train_total': 0.3046, 'train_flow': 0.0, 'train_recon': 0.2713, 'val_total': 0.3193, 'lr': 0.0005}\n", "{'epoch': 16, 'train_total': 0.3084, 'train_flow': 0.0, 'train_recon': 0.2745, 'val_total': 0.3177, 'lr': 0.0005}\n", "{'epoch': 17, 'train_total': 0.3057, 'train_flow': 0.0, 'train_recon': 0.2722, 'val_total': 0.3182, 'lr': 0.0005}\n", "{'epoch': 18, 'train_total': 0.3075, 'train_flow': 0.0, 'train_recon': 0.2735, 'val_total': 0.3174, 'lr': 0.0005}\n", "{'epoch': 19, 'train_total': 0.3051, 'train_flow': 0.0, 'train_recon': 0.2718, 'val_total': 0.318, 'lr': 0.0005}\n", "{'epoch': 20, 'train_total': 0.3029, 'train_flow': 0.0, 'train_recon': 0.2694, 'val_total': 0.3179, 'lr': 0.0005}\n" ] }, { "data": { "text/html": [ "
|   | epoch | train_total | train_recon | train_flow | train_endpoint | train_mmd | train_mean | train_ot_decode_mmd | train_ot_decode_mean | train_disp_reg | ... | val_total | val_recon | val_flow | val_endpoint | val_mmd | val_mean | val_ot_decode_mmd | val_ot_decode_mean | val_disp_reg | val_vel_l2 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.726086 | 0.287597 | 0.235618 | 0.078830 | 0.190431 | 0.013061 | 0.190711 | 0.012874 | 0.078863 | ... | 0.357233 | 0.270025 | 0.019285 | 0.006376 | 0.232606 | 0.016578 | 0.220882 | 0.016581 | 0.006376 | 0.019285 |
| 1 | 2 | 0.321951 | 0.276832 | 0.005564 | 0.002245 | 0.182267 | 0.006630 | 0.182260 | 0.006632 | 0.002248 | ... | 0.322358 | 0.268141 | 0.000624 | 0.000624 | 0.222793 | 0.015618 | 0.223224 | 0.015618 | 0.000624 | 0.000624 |
| 2 | 3 | 0.309613 | 0.273109 | 0.000928 | 0.000722 | 0.181945 | 0.005960 | 0.181987 | 0.005957 | 0.000724 | ... | 0.321867 | 0.268512 | 0.000247 | 0.000500 | 0.221805 | 0.015897 | 0.213580 | 0.015897 | 0.000500 | 0.000247 |
| 3 | 4 | 0.308655 | 0.273063 | 0.000554 | 0.000601 | 0.181111 | 0.005832 | 0.181094 | 0.005827 | 0.000600 | ... | 0.319691 | 0.267235 | 0.000109 | 0.000452 | 0.219989 | 0.015224 | 0.220651 | 0.015224 | 0.000452 | 0.000109 |
| 4 | 5 | 0.310805 | 0.275149 | 0.000363 | 0.000534 | 0.181921 | 0.006041 | 0.181920 | 0.006042 | 0.000534 | ... | 0.319558 | 0.267634 | 0.000108 | 0.000458 | 0.218250 | 0.015324 | 0.210994 | 0.015324 | 0.000458 | 0.000108 |
| 5 | 6 | 0.312528 | 0.276359 | 0.000281 | 0.000505 | 0.182827 | 0.006453 | 0.182823 | 0.006456 | 0.000505 | ... | 0.318567 | 0.266999 | 0.000080 | 0.000446 | 0.217384 | 0.015067 | 0.212976 | 0.015067 | 0.000446 | 0.000080 |
| 6 | 7 | 0.308339 | 0.272867 | 0.000211 | 0.000483 | 0.183340 | 0.005928 | 0.183333 | 0.005926 | 0.000483 | ... | 0.319122 | 0.267365 | 0.000057 | 0.000427 | 0.218600 | 0.015263 | 0.210317 | 0.015263 | 0.000427 | 0.000057 |
| 7 | 8 | 0.306890 | 0.272027 | 0.000159 | 0.000465 | 0.181842 | 0.005690 | 0.181823 | 0.005690 | 0.000465 | ... | 0.320815 | 0.268177 | 0.000036 | 0.000421 | 0.218680 | 0.015580 | 0.220539 | 0.015580 | 0.000421 | 0.000036 |
| 8 | 9 | 0.308317 | 0.273363 | 0.000132 | 0.000460 | 0.182320 | 0.005737 | 0.182330 | 0.005741 | 0.000460 | ... | 0.318322 | 0.266749 | 0.000034 | 0.000415 | 0.217338 | 0.014970 | 0.217291 | 0.014970 | 0.000415 | 0.000034 |
| 9 | 10 | 0.311956 | 0.276195 | 0.000119 | 0.000449 | 0.182801 | 0.006346 | 0.182800 | 0.006349 | 0.000449 | ... | 0.320359 | 0.268012 | 0.000032 | 0.000431 | 0.218044 | 0.015470 | 0.218747 | 0.015470 | 0.000431 | 0.000032 |
| 10 | 11 | 0.308387 | 0.273684 | 0.000094 | 0.000442 | 0.180344 | 0.005828 | 0.180341 | 0.005828 | 0.000442 | ... | 0.318582 | 0.266860 | 0.000024 | 0.000408 | 0.216553 | 0.015041 | 0.220423 | 0.015041 | 0.000408 | 0.000024 |
| 11 | 12 | 0.310445 | 0.275663 | 0.000077 | 0.000438 | 0.180843 | 0.005852 | 0.180887 | 0.005853 | 0.000438 | ... | 0.318373 | 0.266954 | 0.000014 | 0.000422 | 0.216261 | 0.015009 | 0.215913 | 0.015009 | 0.000422 | 0.000014 |
| 12 | 13 | 0.309572 | 0.274860 | 0.000064 | 0.000431 | 0.180777 | 0.005823 | 0.180775 | 0.005824 | 0.000431 | ... | 0.319046 | 0.266958 | 0.000013 | 0.000408 | 0.216601 | 0.015054 | 0.227646 | 0.015054 | 0.000408 | 0.000013 |
| 13 | 14 | 0.307190 | 0.273207 | 0.000048 | 0.000427 | 0.180640 | 0.005275 | 0.180690 | 0.005275 | 0.000427 | ... | 0.317530 | 0.266552 | 0.000006 | 0.000409 | 0.215501 | 0.014852 | 0.212916 | 0.014852 | 0.000409 | 0.000006 |
| 14 | 15 | 0.304604 | 0.271330 | 0.000043 | 0.000424 | 0.179007 | 0.004914 | 0.179017 | 0.004912 | 0.000424 | ... | 0.319289 | 0.267379 | 0.000005 | 0.000415 | 0.217476 | 0.015219 | 0.218394 | 0.015219 | 0.000415 | 0.000005 |
| 15 | 16 | 0.308400 | 0.274548 | 0.000040 | 0.000424 | 0.179647 | 0.005304 | 0.179647 | 0.005305 | 0.000424 | ... | 0.317750 | 0.266600 | 0.000005 | 0.000409 | 0.216036 | 0.014851 | 0.215316 | 0.014851 | 0.000409 | 0.000005 |
| 16 | 17 | 0.305699 | 0.272213 | 0.000037 | 0.000423 | 0.178551 | 0.005146 | 0.178549 | 0.005146 | 0.000423 | ... | 0.318228 | 0.266857 | 0.000005 | 0.000408 | 0.216236 | 0.014949 | 0.216927 | 0.014949 | 0.000408 | 0.000005 |
| 17 | 18 | 0.307527 | 0.273461 | 0.000034 | 0.000419 | 0.179587 | 0.005491 | 0.179581 | 0.005489 | 0.000419 | ... | 0.317383 | 0.266766 | 0.000005 | 0.000411 | 0.215717 | 0.014897 | 0.204130 | 0.014897 | 0.000411 | 0.000005 |
| 18 | 19 | 0.305071 | 0.271768 | 0.000032 | 0.000421 | 0.178401 | 0.005025 | 0.178407 | 0.005023 | 0.000421 | ... | 0.317969 | 0.267077 | 0.000005 | 0.000407 | 0.216272 | 0.015069 | 0.204290 | 0.015069 | 0.000407 | 0.000005 |
| 19 | 20 | 0.302873 | 0.269402 | 0.000030 | 0.000421 | 0.179407 | 0.005041 | 0.179403 | 0.005039 | 0.000421 | ... | 0.317881 | 0.266701 | 0.000007 | 0.000412 | 0.215959 | 0.014903 | 0.214677 | 0.014903 | 0.000412 | 0.000007 |

20 rows × 21 columns
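The loss and inference code above call `sinkhorn_barycentric_targets` and `rbf_mmd_loss`, which are defined in earlier cells not shown here. As a rough, self-contained NumPy sketch of what such helpers typically compute — uniform marginals, a squared-Euclidean cost, and this particular bandwidth mixture are assumptions for illustration, not the notebook's actual settings:

```python
import numpy as np

def sinkhorn_barycentric_targets(x0, x1, epsilon=0.1, n_iters=50):
    """Map each row of x0 to a barycentric target in the support of x1."""
    # Entropic OT between the two mini-batches: Gibbs kernel of the pairwise cost.
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)
    K = np.exp(-cost / epsilon)
    # Uniform marginals over the two batches; Sinkhorn scaling iterations.
    a = np.full(len(x0), 1.0 / len(x0))
    b = np.full(len(x1), 1.0 / len(x1))
    u = np.ones(len(x0))
    v = np.ones(len(x1))
    for _ in range(n_iters):
        u = a / (K @ v + 1e-12)
        v = b / (K.T @ u + 1e-12)
    plan = u[:, None] * K * v[None, :]
    # Barycentric projection: row-normalise the plan so each source point
    # receives a convex combination of the target batch.
    return (plan / plan.sum(axis=1, keepdims=True)) @ x1

def rbf_mmd_loss(x, y, bandwidths=(1.0, 2.0, 4.0)):
    """Biased MMD^2 estimate between samples x (n, d) and y (m, d)."""
    def kernel(a, b):
        # Pairwise squared distances under a mixture of RBF kernels.
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return sum(np.exp(-d2 / (2.0 * s ** 2)) for s in bandwidths) / len(bandwidths)
    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()
```

In `compute_loss_terms` these operations run in `torch` on the latent batches `z0`/`z1` (with `SINKHORN_EPSILON` / `SINKHORN_ITERS` as hyperparameters) and on decoded expression matrices, but the mechanics are the same.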
|   | perturbation_gene | n_cells | mse_mean | mae_mean | pearson_mean | pearson_delta | delta_l2 |
|---|---|---|---|---|---|---|---|
| 0 | AAAS | 21 | 0.007128 | 0.064098 | 0.990898 | -0.084128 | 5.558319 |
| 1 | AAMP | 10 | 0.018754 | 0.103177 | 0.974846 | 0.319923 | 9.015438 |
| 2 | AARS | 4 | 0.043423 | 0.158615 | 0.945173 | 0.384644 | 13.718369 |
| 3 | AARS2 | 18 | 0.007762 | 0.067804 | 0.989657 | -0.031809 | 5.800062 |
| 4 | AASDHPPT | 5 | 0.029958 | 0.133616 | 0.964445 | 0.052611 | 11.394714 |