{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "T4"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation\n",
    "\n",
    "This notebook demonstrates **TD3B inference** — generating peptide binders with specified agonist or antagonist behavior for GPCR targets.\n",
    "\n",
    "**What TD3B does:**\n",
    "- Takes a target protein sequence + desired direction (agonist / antagonist)\n",
    "- Generates peptide binder sequences using a finetuned discrete diffusion model\n",
    "- Scores them with a Direction Oracle and binding affinity predictor\n",
    "- Returns high-reward candidates via reward-weighted resampling (Algorithm 2)\n",
    "\n",
    "**Requirements:** GPU runtime (T4 or better). Click **Runtime → Change runtime type → GPU**."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1. Setup"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install dependencies (%pip installs into the running kernel's environment).\n",
    "# `rdkit-pypi` is deprecated and lacks wheels for recent Python versions; `rdkit` is its successor.\n",
    "%pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121\n",
    "%pip install -q transformers fair-esm SmilesPE rdkit scipy pandas numpy xgboost pytorch-lightning lightning hydra-core loguru timm huggingface_hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clone TD3B repository and download checkpoints from HuggingFace.\n",
    "# Safe to re-run: the clone is skipped if the repo directory exists, and we never\n",
    "# descend into a nested TD3B/TD3B when the cell is executed from inside the repo.\n",
    "import os\n",
    "import subprocess\n",
    "\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "REPO_URL = \"https://github.com/chq1155/TD3B_ICML.git\"\n",
    "REPO_DIR = \"TD3B\"\n",
    "\n",
    "if os.path.basename(os.getcwd()) != REPO_DIR:\n",
    "    if not os.path.isdir(REPO_DIR):\n",
    "        subprocess.run([\"git\", \"clone\", REPO_URL, REPO_DIR], check=True)\n",
    "    os.chdir(REPO_DIR)\n",
    "print(f\"Working directory: {os.getcwd()}\")\n",
    "\n",
    "REPO_ID = \"ChatterjeeLab/TD3B\"\n",
    "os.makedirs(\"checkpoints\", exist_ok=True)\n",
    "os.makedirs(\"data\", exist_ok=True)\n",
    "\n",
    "# Download checkpoints and data files (this may take a few minutes)\n",
    "FILES_TO_DOWNLOAD = [\n",
    "    \"checkpoints/td3b.ckpt\",\n",
    "    \"checkpoints/pretrained.ckpt\",\n",
    "    \"checkpoints/direction_oracle.pt\",\n",
    "    \"scoring/functions/classifiers/binding-affinity.pt\",\n",
    "    \"data/test.csv\",\n",
    "    \"data/train.csv\",\n",
    "]\n",
    "for fname in FILES_TO_DOWNLOAD:\n",
    "    print(f\"Downloading {fname}...\")\n",
    "    hf_hub_download(repo_id=REPO_ID, filename=fname, local_dir=\".\")\n",
    "\n",
    "print(\"\\nAll files downloaded!\")\n",
    "!ls -lh checkpoints/"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 2. Load Model and Oracle"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \".\")\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from diffusion import Diffusion\n",
    "from configs.finetune_config import (\n",
    "    DiffusionConfig, RoFormerConfig, NoiseConfig,\n",
    "    TrainingConfig, SamplingConfig, EvalConfig, OptimConfig, MCTSConfig,\n",
    ")\n",
    "from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer\n",
    "from td3b.direction_oracle import DirectionalOracle\n",
    "from td3b.td3b_scoring import TD3BRewardFunction, create_td3b_reward_function\n",
    "from scoring.functions.binding import BindingAffinity\n",
    "from utils.app import PeptideAnalyzer\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "    # The CUDA device-properties field is `total_memory` (bytes); `total_mem` does not exist.\n",
    "    print(f\"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load tokenizer\n",
    "tokenizer = SMILES_SPE_Tokenizer(\"tokenizer/new_vocab.txt\", \"tokenizer/new_splits.txt\")\n",
    "print(f\"Tokenizer vocab size: {len(tokenizer)}\")\n",
    "\n",
    "# Load diffusion model\n",
    "print(\"\\nLoading TD3B model...\")\n",
    "cfg = DiffusionConfig(\n",
    "    roformer=RoFormerConfig(hidden_size=768, n_layers=8, n_heads=8),\n",
    "    noise=NoiseConfig(),\n",
    "    training=TrainingConfig(sampling_eps=1e-3),\n",
    "    sampling=SamplingConfig(steps=128, sampling_eps=1e-3),\n",
    "    eval_cfg=EvalConfig(), optim=OptimConfig(lr=3e-4), mcts=MCTSConfig(),\n",
    ")\n",
    "model = Diffusion(config=cfg, tokenizer=tokenizer, device=device).to(device)\n",
    "\n",
    "# NOTE: weights_only=False unpickles arbitrary objects -- only load checkpoints from a trusted source.\n",
    "ckpt = torch.load(\"checkpoints/td3b.ckpt\", map_location=device, weights_only=False)\n",
    "# Accept a wrapper dict (`model_state_dict` / `state_dict` key) or a bare state dict.\n",
    "state_dict = ckpt.get(\"model_state_dict\") or ckpt.get(\"state_dict\") or ckpt\n",
    "# strict=False tolerates key mismatches; report them so a silently empty load is visible.\n",
    "load_report = model.load_state_dict(state_dict, strict=False)\n",
    "missing_keys = getattr(load_report, \"missing_keys\", [])\n",
    "unexpected_keys = getattr(load_report, \"unexpected_keys\", [])\n",
    "print(f\"Missing keys: {len(missing_keys)}, unexpected keys: {len(unexpected_keys)}\")\n",
    "model.eval()\n",
    "model.tokenizer = tokenizer\n",
    "print(\"TD3B model loaded!\")\n",
    "\n",
    "# Load Direction Oracle\n",
    "print(\"\\nLoading Direction Oracle...\")\n",
    "oracle = DirectionalOracle(\n",
    "    model_ckpt=\"checkpoints/direction_oracle.pt\",\n",
    "    tr2d2_checkpoint=\"checkpoints/pretrained.ckpt\",\n",
    "    tokenizer_vocab=\"tokenizer/new_vocab.txt\",\n",
    "    tokenizer_splits=\"tokenizer/new_splits.txt\",\n",
    "    device=device,\n",
    ")\n",
    "oracle.eval()\n",
    "print(\"Direction Oracle loaded!\")\n",
    "\n",
    "# Load peptide analyzer (used for peptide-validity checks; the per-target\n",
    "# binding-affinity predictor is built inside `generate_binders`)\n",
    "print(\"\\nLoading Peptide Analyzer...\")\n",
    "analyzer = PeptideAnalyzer()\n",
    "print(\"\\nAll models loaded!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 3. Define Helper Functions"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_sequences(model, batch_size, seq_length, num_steps=128, eps=1e-5):\n",
    "    \"\"\"Sample token sequences from the diffusion model by iterative reverse steps.\n",
    "\n",
    "    Starts from ``model.sample_prior``, applies ``num_steps`` calls to\n",
    "    ``model.single_reverse_step`` on a linear time grid from 1 down to ``eps``,\n",
    "    then, if any position still equals ``model.mask_index``, applies one\n",
    "    ``model.single_noise_removal`` step.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of token ids on ``model.device`` that can be passed to\n",
    "        ``tokenizer.batch_decode`` (shape presumably (batch_size, seq_length) -- TODO confirm).\n",
    "    \"\"\"\n",
    "    x = model.sample_prior(batch_size, seq_length).to(model.device, dtype=torch.long)\n",
    "    timesteps = torch.linspace(1, eps, num_steps + 1, device=model.device)\n",
    "    dt = torch.tensor((1 - eps) / num_steps, device=model.device)\n",
    "\n",
    "    for i in range(num_steps):\n",
    "        t = timesteps[i] * torch.ones(x.shape[0], 1, device=model.device)\n",
    "        _, x = model.single_reverse_step(x, t=t, dt=dt)\n",
    "        x = x.to(model.device)\n",
    "\n",
    "    # Resolve any positions that are still masked after the reverse process.\n",
    "    if (x == model.mask_index).any():\n",
    "        t = timesteps[-2] * torch.ones(x.shape[0], 1, device=model.device)\n",
    "        _, x = model.single_noise_removal(x, t=t, dt=dt)\n",
    "    return x\n",
    "\n",
    "\n",
    "VALID_DIRECTIONS = (\"agonist\", \"antagonist\")\n",
    "\n",
    "\n",
    "def generate_binders(target_seq, direction=\"agonist\", num_pool=32,\n",
    "                     num_keep=8, alpha=0.1, seq_length=200):\n",
    "    \"\"\"\n",
    "    Generate directional binders for a target protein.\n",
    "\n",
    "    Samples ``num_pool`` candidates, scores them with the TD3B reward function\n",
    "    (binding affinity + Direction Oracle), then draws ``num_keep`` candidates\n",
    "    *with replacement* with probability softmax(reward / alpha) (Algorithm 2),\n",
    "    so the returned rows may contain duplicates.\n",
    "\n",
    "    Args:\n",
    "        target_seq: Target protein amino acid sequence\n",
    "        direction: 'agonist' or 'antagonist'\n",
    "        num_pool: Number of candidates to generate (>= 1)\n",
    "        num_keep: Number of final samples after resampling (>= 1)\n",
    "        alpha: Temperature for weighted resampling (lower = greedier)\n",
    "        seq_length: Binder sequence length (in SMILES tokens)\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per resampled binder and columns: sequence,\n",
    "        direction, is_valid, affinity, gated_reward, p_agonist, direction_accuracy.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``direction`` is not 'agonist' or 'antagonist', or if\n",
    "            ``num_pool`` / ``num_keep`` is smaller than 1.\n",
    "    \"\"\"\n",
    "    # Validate up front: otherwise a typo such as 'Agonist' would silently be scored as antagonist.\n",
    "    if direction not in VALID_DIRECTIONS:\n",
    "        raise ValueError(f\"direction must be one of {VALID_DIRECTIONS}, got {direction!r}\")\n",
    "    if num_pool < 1 or num_keep < 1:\n",
    "        raise ValueError(f\"num_pool and num_keep must be >= 1, got {num_pool} and {num_keep}\")\n",
    "    is_agonist = direction == \"agonist\"\n",
    "\n",
    "    # Build the reward function for this target / direction\n",
    "    affinity_pred = BindingAffinity(\n",
    "        prot_seq=target_seq, tokenizer=tokenizer,\n",
    "        base_path=\".\", device=device, emb_model=model.backbone\n",
    "    )\n",
    "    reward_fn = create_td3b_reward_function(\n",
    "        affinity_predictor=affinity_pred,\n",
    "        directional_oracle=oracle,\n",
    "        target_protein_seq=target_seq,\n",
    "        target_direction=direction,\n",
    "        peptide_tokenizer=tokenizer,\n",
    "        device=device,\n",
    "    )\n",
    "\n",
    "    # Generate candidates\n",
    "    with torch.no_grad():\n",
    "        x_pool = sample_sequences(model, num_pool, seq_length)\n",
    "    sequences = tokenizer.batch_decode(x_pool)\n",
    "\n",
    "    # Score all candidates\n",
    "    rewards, info = reward_fn(sequences)\n",
    "    affinities = info[\"affinities\"]\n",
    "    directions = info[\"directions\"]  # per-candidate Direction Oracle score, reported as p_agonist\n",
    "\n",
    "    # Weighted resampling (Algorithm 2): indices drawn with replacement, p = softmax(reward / alpha)\n",
    "    rewards_t = torch.as_tensor(rewards, device=device)\n",
    "    weights = torch.softmax(rewards_t / max(alpha, 1e-6), dim=0)\n",
    "    chosen = torch.multinomial(weights, num_samples=num_keep, replacement=True).tolist()\n",
    "\n",
    "    # Annotate each resampled candidate; invalid peptides are flagged via `is_valid`, not removed.\n",
    "    results = []\n",
    "    for i in chosen:\n",
    "        p_agonist = float(directions[i])\n",
    "        # 1.0 when the oracle's call agrees with the requested direction\n",
    "        direction_hit = p_agonist > 0.5 if is_agonist else p_agonist < 0.5\n",
    "        results.append({\n",
    "            \"sequence\": sequences[i],\n",
    "            \"direction\": direction,\n",
    "            \"is_valid\": analyzer.is_peptide(sequences[i]),\n",
    "            \"affinity\": float(affinities[i]),\n",
    "            \"gated_reward\": float(rewards[i]),\n",
    "            \"p_agonist\": p_agonist,\n",
    "            \"direction_accuracy\": float(direction_hit),\n",
    "        })\n",
    "\n",
    "    return pd.DataFrame(results)"
   ]
},
  {
   "cell_type": "markdown",
   "source": [
    "## 4. Generate Binders\n",
    "\n",
    "Let's generate **agonist** and **antagonist** binders for a test target and compare the Direction Oracle predictions."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load test targets\n",
    "test_df = pd.read_csv(\"data/test.csv\")\n",
    "print(f\"Test set: {len(test_df)} target-binder pairs\")\n",
    "\n",
    "# Pick first target for demo\n",
    "target_row = test_df.iloc[0]\n",
    "TARGET_SEQ = target_row[\"Target_Sequence\"]\n",
    "TARGET_UID = target_row[\"Target_UniProt_ID\"]\n",
    "print(f\"\\nTarget: {TARGET_UID}\")\n",
    "print(f\"Sequence length: {len(TARGET_SEQ)} aa\")\n",
    "print(f\"Sequence: {TARGET_SEQ[:60]}...\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Generate AGONIST binders\n",
    "print(\"Generating agonist binders (d*=+1)...\")\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "df_agonist = generate_binders(TARGET_SEQ, direction=\"agonist\", num_pool=32, num_keep=8)\n",
    "\n",
    "print(f\"\\nGenerated {len(df_agonist)} samples ({df_agonist['is_valid'].sum()} valid)\")\n",
    "print(f\"Mean p(agonist): {df_agonist['p_agonist'].mean():.3f}\")\n",
    "print(f\"Mean affinity: {df_agonist['affinity'].mean():.2f}\")\n",
    "print(f\"Mean gated reward: {df_agonist['gated_reward'].mean():.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Generate ANTAGONIST binders\n",
    "print(\"Generating antagonist binders (d*=-1)...\")\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "df_antagonist = generate_binders(TARGET_SEQ, direction=\"antagonist\", num_pool=32, num_keep=8)\n",
    "\n",
    "print(f\"\\nGenerated {len(df_antagonist)} samples ({df_antagonist['is_valid'].sum()} valid)\")\n",
    "print(f\"Mean p(agonist): {df_antagonist['p_agonist'].mean():.3f}\")\n",
    "print(f\"Mean affinity: {df_antagonist['affinity'].mean():.2f}\")\n",
    "print(f\"Mean gated reward: {df_antagonist['gated_reward'].mean():.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 5. Compare Directional Control"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# One panel per metric: (column, x-label, title, (agonist legend, antagonist legend))\n",
    "PANELS = [\n",
    "    (\"p_agonist\", \"p(agonist)\", \"Direction Oracle Predictions\", (\"d*=+1 (agonist)\", \"d*=-1 (antagonist)\")),\n",
    "    (\"affinity\", \"Predicted Binding Affinity\", \"Binding Affinity Distribution\", (\"Agonist\", \"Antagonist\")),\n",
    "    (\"gated_reward\", \"Gated Reward\", \"Gated Reward Distribution\", (\"Agonist\", \"Antagonist\")),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
    "for ax, (col, xlabel, title, (agonist_label, antagonist_label)) in zip(axes, PANELS):\n",
    "    # Share bin edges between the two overlaid histograms; otherwise each call\n",
    "    # picks its own edges and the bars are not comparable.\n",
    "    bins = np.histogram_bin_edges(\n",
    "        np.concatenate([df_agonist[col].to_numpy(), df_antagonist[col].to_numpy()]), bins=20\n",
    "    )\n",
    "    ax.hist(df_agonist[col], bins=bins, alpha=0.7, label=agonist_label, color=\"#e74c3c\")\n",
    "    ax.hist(df_antagonist[col], bins=bins, alpha=0.7, label=antagonist_label, color=\"#3498db\")\n",
    "    if col == \"p_agonist\":\n",
    "        ax.axvline(0.5, color=\"gray\", linestyle=\"--\", label=\"threshold\")\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_ylabel(\"Count\")\n",
    "    ax.set_title(title)\n",
    "    ax.legend()\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"td3b_results.png\", dpi=150, bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "\n",
    "print(\"\\nSummary:\")\n",
    "print(f\"  Agonist mode:    p(agonist)={df_agonist['p_agonist'].mean():.3f}  Affinity={df_agonist['affinity'].mean():.2f}  Gated={df_agonist['gated_reward'].mean():.2f}\")\n",
    "print(f\"  Antagonist mode: p(agonist)={df_antagonist['p_agonist'].mean():.3f}  Affinity={df_antagonist['affinity'].mean():.2f}  Gated={df_antagonist['gated_reward'].mean():.2f}\")\n",
    "print(f\"  Directional gap: Δp = {df_agonist['p_agonist'].mean() - df_antagonist['p_agonist'].mean():.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 6. Run on Multiple Targets\n",
    "\n",
    "Generate agonist and antagonist binders for the first `N_TARGETS` (default 5) unique test targets and compute aggregate metrics."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_TARGETS = 5  # Number of unique targets to evaluate (increase for full benchmark)\n",
    "\n",
    "all_results = []\n",
    "targets = test_df.drop_duplicates(\"Target_UniProt_ID\").head(N_TARGETS)\n",
    "n_targets = len(targets)  # may be smaller than N_TARGETS if the test set has fewer unique targets\n",
    "\n",
    "for i, (_, row) in enumerate(targets.iterrows()):\n",
    "    uid = row[\"Target_UniProt_ID\"]\n",
    "    seq = row[\"Target_Sequence\"]\n",
    "    print(f\"[{i+1}/{n_targets}] {uid} (len={len(seq)})\")\n",
    "\n",
    "    for direction in [\"agonist\", \"antagonist\"]:\n",
    "        torch.manual_seed(42)\n",
    "        np.random.seed(42)\n",
    "        df_target = generate_binders(seq, direction=direction, num_pool=32, num_keep=8)\n",
    "        df_target[\"target_uid\"] = uid\n",
    "        all_results.append(df_target)\n",
    "\n",
    "        da = df_target[\"direction_accuracy\"].mean()\n",
    "        print(f\"  {direction:>10s}: DA={da:.2f}  Aff={df_target['affinity'].mean():.2f}  Gated={df_target['gated_reward'].mean():.2f}  valid={df_target['is_valid'].sum()}/{len(df_target)}\")\n",
    "\n",
    "combined = pd.concat(all_results, ignore_index=True)\n",
    "\n",
    "print(f\"\\n{'='*60}\")\n",
    "print(f\"AGGREGATE METRICS ({n_targets} targets)\")\n",
    "print(f\"{'='*60}\")\n",
    "for d_name, d_val in [(\"Agonist (d*=+1)\", \"agonist\"), (\"Antagonist (d*=-1)\", \"antagonist\")]:\n",
    "    sub = combined[combined[\"direction\"] == d_val]\n",
    "    valid = sub[sub[\"is_valid\"].astype(bool)]\n",
    "    print(f\"  {d_name}:\")\n",
    "    print(f\"    Affinity:              {sub['affinity'].mean():.2f}\")\n",
    "    print(f\"    Direction Accuracy:    {sub['direction_accuracy'].mean():.3f}\")\n",
    "    print(f\"    Gated Reward (all):    {sub['gated_reward'].mean():.2f}\")\n",
    "    if len(valid) > 0:\n",
    "        print(f\"    Gated Reward (valid):  {valid['gated_reward'].mean():.2f}\")\n",
    "    print(f\"    Valid:                 {sub['is_valid'].sum()}/{len(sub)}\")\n",
    "\n",
    "# Save\n",
    "combined.to_csv(\"td3b_demo_results.csv\", index=False)\n",
    "print(\"\\nResults saved to td3b_demo_results.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Citation\n",
    "\n",
    "```bibtex\n",
    "@article{caotd3b,\n",
    "  title={TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation},\n",
    "  author={Cao, Hanqun and Pal, Aastha and Tang, Sophia and Zhang, Yinuo and Zhang, Jingjie and Heng, Pheng-Ann and Chatterjee, Pranam}\n",
    "}\n",
    "```"
   ],
   "metadata": {}
  }
 ]
}