#!/usr/bin/env python3
"""
TD3B Inference Script

Generate directional binders for target proteins using a finetuned TD3B model.

Usage:
    python inference.py \
        --ckpt_path checkpoints/td3b.ckpt \
        --val_csv data/test.csv \
        --save_path results/ \
        --seed 42
"""
import argparse
import os
import sys
import logging
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import torch
# Directory containing this script. It is put at the front of sys.path (once)
# so the imports that follow resolve regardless of the working directory.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
from diffusion import Diffusion
from configs.finetune_config import (
DiffusionConfig, RoFormerConfig, NoiseConfig,
TrainingConfig, SamplingConfig, EvalConfig, OptimConfig, MCTSConfig,
)
from finetune_utils import load_tokenizer, create_reward_function
from td3b.direction_oracle import DirectionalOracle
from td3b.td3b_scoring import create_td3b_reward_function
from utils.app import PeptideAnalyzer
# Configure root logging at import time, then take this module's named logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# ─── Defaults ─────────────────────────────────────────────────────────────────
# Default hyperparameter values, keyed by setting name.
DEFAULTS = {
    "seq_length": 200,
    "sampling_eps": 1e-3,
    "total_num_steps": 128,
    "hidden_dim": 768,
    "num_layers": 8,
    "num_heads": 8,
    "alpha": 0.1,
    "min_affinity_threshold": 0.0,
    "sigmoid_temperature": 0.1,
    "num_pool": 32,
    "val_samples_per_target": 8,
}
def load_model(ckpt_path: str, device: torch.device):
    """Load a finetuned TD3B model and its tokenizer from a checkpoint.

    Args:
        ckpt_path: Path to the checkpoint file, read with ``torch.load``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in eval mode and has the
        tokenizer attached as ``model.tokenizer``.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # so only load checkpoints that come from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Weights may sit under either key, or the checkpoint may itself be the
    # state dict.
    state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
    config = ckpt.get("config") or {}

    tokenizer = load_tokenizer(ROOT_DIR)

    cfg = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=config.get("hidden_dim", 768),
            n_layers=config.get("num_layers", 8),
            n_heads=config.get("num_heads", 8),
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=1e-3),
        sampling=SamplingConfig(steps=128, sampling_eps=1e-3),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=3e-4),
        mcts=MCTSConfig(),
    )

    model = Diffusion(config=cfg, tokenizer=tokenizer, device=device).to(device)

    # strict=False tolerates key mismatches; report them instead of hiding
    # them, since a badly mismatched checkpoint leaves layers at their
    # initial values.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing or unexpected:
        logging.getLogger(__name__).warning(
            "Checkpoint %s loaded non-strictly: %d missing key(s) %s, "
            "%d unexpected key(s) %s",
            ckpt_path,
            len(missing),
            missing[:5],
            len(unexpected),
            unexpected[:5],
        )

    model.eval()
    model.tokenizer = tokenizer
    return model, tokenizer
def sample_sequences(model, batch_size: int, seq_length: int, num_steps: int, eps: float = 1e-5):
    """Sample token sequences from the diffusion model by reverse diffusion.

    Starts from ``model.sample_prior(batch_size, seq_length)`` and applies
    ``model.single_reverse_step`` once per timestep, with timesteps spaced
    linearly from 1 down to ``eps``. If any position still equals
    ``model.mask_index`` afterwards, one ``model.single_noise_removal`` pass
    is applied.

    Args:
        model: Object providing ``device``, ``mask_index``, ``sample_prior``,
            ``single_reverse_step`` and ``single_noise_removal``.
        batch_size: Number of sequences to sample.
        seq_length: Length passed to ``model.sample_prior``.
        num_steps: Number of reverse steps; must be at least 1.
        eps: Final (smallest) timestep value.

    Returns:
        The sampled token tensor, on ``model.device``.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    device = model.device
    # Sampling is inference only: skip autograd bookkeeping for every step.
    with torch.no_grad():
        x = model.sample_prior(batch_size, seq_length).to(device, dtype=torch.long)
        timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
        dt = torch.tensor((1 - eps) / num_steps, device=device)

        for i in range(num_steps):
            t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
            _, x = model.single_reverse_step(x, t=t, dt=dt)
            x = x.to(device)

        # Remove remaining masks, reusing the last timestep that was stepped.
        if (x == model.mask_index).any():
            t = timesteps[-2] * torch.ones(x.shape[0], 1, device=device)
            _, x = model.single_noise_removal(x, t=t, dt=dt)
            x = x.to(device)

    return x
def _to_numpy(values):
    """Convert a tensor, array or sequence to a NumPy array."""
    if isinstance(values, torch.Tensor):
        # np.asarray cannot convert tensors that require grad or live on GPU.
        values = values.detach().cpu().numpy()
    return np.asarray(values)


def score_sequences(reward_model, sequences: List[str]):
    """Score sequences with the TD3B reward function.

    Args:
        reward_model: Callable taking the list of sequences. It returns
            either the rewards alone or a ``(rewards, info)`` tuple, where
            ``info`` may hold ``"affinities"``, ``"directions"`` and
            ``"confidences"``.
        sequences: Sequences to score.

    Returns:
        A ``(rewards, affinities, directions, confidences)`` tuple of NumPy
        arrays. Entries missing from ``info`` fall back to the rewards
        (affinities), zeros (directions) and ones (confidences).
    """
    result = reward_model(sequences)
    if isinstance(result, tuple):
        rewards, info = result
        if info is None:
            info = {}
    else:
        rewards, info = result, {}

    rewards = _to_numpy(rewards)

    affinities = info.get("affinities")
    directions = info.get("directions")
    confidences = info.get("confidences")
    return (
        rewards,
        rewards if affinities is None else _to_numpy(affinities),
        np.zeros_like(rewards) if directions is None else _to_numpy(directions),
        np.ones_like(rewards) if confidences is None else _to_numpy(confidences),
    )
def _load_targets(csv_path: str, max_length: int = 200) -> List[dict]:
    """Read the validation CSV into one dict per usable target row.

    Rows without a non-empty string ``Target_Sequence`` are skipped with a
    warning. Optional columns that are absent or empty (pandas reads empty
    cells as NaN) fall back to ``""``. ``seq_length`` is the character length
    of ``Ligand_SMILES`` capped at ``max_length``; it falls back to
    ``max_length`` when that column is absent or empty.

    Raises:
        KeyError: If the CSV has no ``Target_Sequence`` column.
    """
    df = pd.read_csv(csv_path)
    if "Target_Sequence" not in df.columns:
        raise KeyError(f"{csv_path} has no 'Target_Sequence' column")

    def _clean(value):
        # None (missing column) and NaN (empty cell) both mean "no value".
        return "" if pd.isna(value) else value

    targets = []
    for row_idx, row in df.iterrows():
        target_seq = row["Target_Sequence"]
        if not isinstance(target_seq, str) or not target_seq:
            logger.warning(f"Skipping row {row_idx}: missing Target_Sequence")
            continue
        smiles = row.get("Ligand_SMILES")
        # len() on a NaN cell would raise TypeError, so only measure strings.
        has_smiles = isinstance(smiles, str) and bool(smiles)
        targets.append({
            "target_seq": target_seq,
            "target_uid": _clean(row.get("Target_UniProt_ID")),
            "binder_seq": _clean(row.get("Ligand_Sequence")),
            "label": _clean(row.get("label")),
            "seq_length": min(len(smiles), max_length) if has_smiles else max_length,
        })
    return targets


def _resample_indices(rewards: np.ndarray, num_samples: int, alpha: float, device) -> np.ndarray:
    """Pick pool indices by weighted resampling (Algorithm 2).

    Indices are drawn with replacement from the entries whose reward is
    finite, with probability ``softmax(reward / alpha)``. When no reward is
    finite, the first ``num_samples`` pool indices are returned instead.
    """
    finite = np.isfinite(rewards)
    if not finite.any():
        return np.arange(min(num_samples, len(rewards)))
    rewards_t = torch.as_tensor(rewards[finite], device=device)
    # Clamp the temperature so a zero or negative alpha cannot divide by zero.
    weights = torch.softmax(rewards_t / max(alpha, 1e-6), dim=0)
    drawn = torch.multinomial(weights, num_samples=num_samples, replacement=True)
    # Map positions within the finite subset back to pool indices.
    return np.flatnonzero(finite)[drawn.cpu().numpy()]


def _log_summary(out_df, out_path: str) -> None:
    """Log mean affinity, direction accuracy and gated reward of the results."""
    if len(out_df) == 0:
        logger.warning("No valid samples generated.")
        return

    dp = out_df[out_df["target_direction"] == 1.0]
    dm = out_df[out_df["target_direction"] == -1.0]
    logger.info(f"\n{'='*60}")
    logger.info(f"Results saved to {out_path} ({len(out_df)} valid samples)")
    if len(dp):
        logger.info(f" Aff(d*=+1) = {dp['affinity'].mean():.2f}")
    else:
        logger.info(" No agonist samples")
    if len(dm):
        logger.info(f" Aff(d*=-1) = {dm['affinity'].mean():.2f}")
    else:
        logger.info(" No antagonist samples")
    # Direction accuracy is only reported for directions that have samples.
    if len(dp):
        logger.info(f" DA(d*=+1) = {dp['direction_accuracy'].mean():.3f}")
    if len(dm):
        logger.info(f" DA(d*=-1) = {dm['direction_accuracy'].mean():.3f}")
    logger.info(f" Gated Reward = {out_df['gated_reward'].mean():.2f}")
    logger.info(f"{'='*60}")


def main():
    """Run TD3B inference for every target in the validation CSV.

    For each target and each direction (agonist, antagonist) this builds a
    reward function, samples a pool of candidates, scores them, resamples
    ``--val_samples_per_target`` of them by reward, and keeps the resampled
    candidates that ``PeptideAnalyzer.is_peptide`` accepts. All kept samples
    are written to ``td3b_results_seed<seed>.csv`` under ``--save_path`` and
    a summary is logged.
    """
    parser = argparse.ArgumentParser(description="TD3B Inference")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Path to TD3B checkpoint")
    parser.add_argument("--val_csv", type=str, required=True, help="CSV with Target_Sequence, Ligand_Sequence, label columns")
    parser.add_argument("--save_path", type=str, default="results", help="Output directory")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_pool", type=int, default=32, help="Pool size for candidate generation")
    parser.add_argument("--val_samples_per_target", type=int, default=8, help="Samples to keep per target-direction")
    parser.add_argument("--resample_alpha", type=float, default=0.1, help="Temperature for weighted resampling")
    parser.add_argument("--direction_oracle_ckpt", type=str, default=None)
    parser.add_argument("--direction_oracle_tr2d2_checkpoint", type=str, default=None)
    args = parser.parse_args()

    # Fail early: a non-positive count would only surface deep inside sampling.
    if args.num_pool < 1:
        parser.error("--num_pool must be >= 1")
    if args.val_samples_per_target < 1:
        parser.error("--val_samples_per_target must be >= 1")

    # Setup
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.save_path, exist_ok=True)
    analyzer = PeptideAnalyzer()

    # Load model
    logger.info(f"Loading model from {args.ckpt_path}")
    model, tokenizer = load_model(args.ckpt_path, device)

    # Load targets
    logger.info(f"Loading targets from {args.val_csv}")
    targets = _load_targets(args.val_csv)

    # Build reward function for each target
    logger.info("Building reward functions...")
    oracle_ckpt = args.direction_oracle_ckpt or os.path.join(ROOT_DIR, "checkpoints", "direction_oracle.pt")
    oracle_tr2d2 = args.direction_oracle_tr2d2_checkpoint or os.path.join(ROOT_DIR, "checkpoints", "pretrained.ckpt")

    num_steps = 128  # reverse-diffusion steps per sampled pool
    records = []
    for tidx, target in enumerate(targets):
        for d_star, d_name in [(1.0, "agonist"), (-1.0, "antagonist")]:
            logger.info(f"[{tidx+1}/{len(targets)}] Target {target['target_uid']} direction={d_name}")

            # Create reward function; a failure skips only this target-direction.
            try:
                reward_model = create_reward_function(
                    base_path=ROOT_DIR,
                    tokenizer=tokenizer,
                    target_protein_seq=target["target_seq"],
                    target_direction=d_name,
                    device=device,
                    emb_model=model.backbone,
                    directional_oracle_checkpoint=oracle_ckpt,
                    direction_oracle_tr2d2_checkpoint=oracle_tr2d2,
                )
            except Exception as e:
                logger.warning(f"Failed to create reward for {target['target_uid']}: {e}")
                continue

            # Generation and scoring are inference only, so no autograd graph.
            with torch.no_grad():
                # Generate pool of candidates
                x_pool = sample_sequences(model, args.num_pool, target["seq_length"], num_steps)
                sequences = tokenizer.batch_decode(x_pool)
                # Check validity
                valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences])
                # Score all
                gated_rewards, affinities, directions, _ = score_sequences(reward_model, sequences)

            if d_star > 0:
                direction_accuracy = (directions > 0.5).astype(float)
            else:
                direction_accuracy = (directions < 0.5).astype(float)

            chosen = _resample_indices(
                gated_rewards, args.val_samples_per_target, args.resample_alpha, device
            )

            # Save only VALID resampled samples
            for i in chosen:
                if not valid_mask[i]:
                    continue
                records.append({
                    "target": target["target_seq"][:20],
                    "target_uid": target["target_uid"],
                    "sequence": sequences[i],
                    "target_direction": d_star,
                    "direction_name": d_name,
                    "is_valid": True,
                    "affinity": float(affinities[i]),
                    "gated_reward": float(gated_rewards[i]),
                    "direction_oracle": float(directions[i]),
                    "direction_accuracy": float(direction_accuracy[i]),
                })

    # Save results. Explicit columns keep the CSV header even with no records.
    out_df = pd.DataFrame(records, columns=[
        "target", "target_uid", "sequence", "target_direction", "direction_name",
        "is_valid", "affinity", "gated_reward", "direction_oracle", "direction_accuracy",
    ])
    out_path = os.path.join(args.save_path, f"td3b_results_seed{args.seed}.csv")
    out_df.to_csv(out_path, index=False)

    # Print summary
    _log_summary(out_df, out_path)
# Run inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()