#!/usr/bin/env python3
"""
TD3B Inference Script
Generate directional binders for target proteins using a finetuned TD3B model.
Usage:
    python inference.py \
        --ckpt_path checkpoints/td3b.ckpt \
        --val_csv data/test.csv \
        --save_path results/ \
        --seed 42
"""
import argparse
import os
import sys
import logging
from typing import List
import numpy as np
import pandas as pd
import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
from diffusion import Diffusion
from configs.finetune_config import (
    DiffusionConfig, RoFormerConfig, NoiseConfig,
    TrainingConfig, SamplingConfig, EvalConfig, OptimConfig, MCTSConfig,
)
from finetune_utils import load_tokenizer, create_reward_function
from td3b.direction_oracle import DirectionalOracle
from td3b.td3b_scoring import create_td3b_reward_function
from utils.app import PeptideAnalyzer
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# ─── Defaults ─────────────────────────────────────────────────────────────────
DEFAULTS = dict(
    seq_length=200,
    sampling_eps=1e-3,
    total_num_steps=128,
    hidden_dim=768,
    num_layers=8,
    num_heads=8,
    alpha=0.1,
    min_affinity_threshold=0.0,
    sigmoid_temperature=0.1,
    num_pool=32,
    val_samples_per_target=8,
)
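# NOTE: DEFAULTS collects the finetuning-time settings in one place; the same
# values recur as hard-coded literals and argparse defaults below (128 sampling
# steps, pool size 32, resampling temperature alpha=0.1, 8 kept samples per
# target-direction). If you change one side, keep the other in sync.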
def load_model(ckpt_path: str, device: torch.device):
"""Load finetuned TD3B model from checkpoint."""
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
config = ckpt.get("config") or {}
tokenizer = load_tokenizer(ROOT_DIR)
cfg = DiffusionConfig(
roformer=RoFormerConfig(
hidden_size=config.get("hidden_dim", 768),
n_layers=config.get("num_layers", 8),
n_heads=config.get("num_heads", 8),
),
noise=NoiseConfig(),
training=TrainingConfig(sampling_eps=1e-3),
sampling=SamplingConfig(steps=128, sampling_eps=1e-3),
eval_cfg=EvalConfig(),
optim=OptimConfig(lr=3e-4),
mcts=MCTSConfig(),
)
model = Diffusion(config=cfg, tokenizer=tokenizer, device=device).to(device)
model.load_state_dict(state_dict, strict=False)
model.eval()
model.tokenizer = tokenizer
return model, tokenizer
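# Ancestral sampling for the masked diffusion model: start from the fully
# masked prior at t = 1 and integrate the reverse process down to t ~ eps in
# num_steps Euler steps of size dt = (1 - eps) / num_steps, i.e.
#
#     x ~ prior (all masks);  for i in 0..num_steps-1:  x <- reverse_step(x, t_i, dt)
#
# Any position still masked after the loop gets one extra denoising pass.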
def sample_sequences(model, batch_size: int, seq_length: int, num_steps: int, eps: float = 1e-5):
"""Sample sequences from the diffusion model."""
x = model.sample_prior(batch_size, seq_length).to(model.device, dtype=torch.long)
timesteps = torch.linspace(1, eps, num_steps + 1, device=model.device)
dt = torch.tensor((1 - eps) / num_steps, device=model.device)
for i in range(num_steps):
t = timesteps[i] * torch.ones(x.shape[0], 1, device=model.device)
_, x = model.single_reverse_step(x, t=t, dt=dt)
x = x.to(model.device)
# Remove remaining masks
mask_pos = (x == model.mask_index)
if mask_pos.any():
t = timesteps[-2] * torch.ones(x.shape[0], 1, device=model.device)
_, x = model.single_noise_removal(x, t=t, dt=dt)
x = x.to(model.device)
return x
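# The reward callable is duck-typed: it may return a bare reward array, or a
# (rewards, info) tuple whose info dict optionally carries "affinities",
# "directions" (a per-sequence oracle score that main() thresholds at 0.5),
# and "confidences". Missing fields fall back to neutral defaults so callers
# always get four aligned arrays back.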
def score_sequences(reward_model, sequences: List[str]):
"""Score sequences with the TD3B reward function."""
result = reward_model(sequences)
if isinstance(result, tuple):
rewards, info = result
return (
np.asarray(rewards),
np.asarray(info.get("affinities", rewards)),
np.asarray(info.get("directions", np.zeros_like(rewards))),
np.asarray(info.get("confidences", np.ones_like(rewards))),
)
rewards = np.asarray(result)
return rewards, rewards, np.zeros_like(rewards), np.ones_like(rewards)
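# main() flow: for every CSV target and both desired directions d* in
# {+1 (agonist), -1 (antagonist)}, build a direction-aware reward, sample a
# pool of num_pool candidates from the diffusion model, score them, then keep
# val_samples_per_target draws via softmax-weighted resampling (probability
# proportional to exp(reward / alpha)). Only candidates that pass the
# PeptideAnalyzer validity check are written to the output CSV.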
def main():
    parser = argparse.ArgumentParser(description="TD3B Inference")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Path to TD3B checkpoint")
    parser.add_argument("--val_csv", type=str, required=True,
                        help="CSV with a Target_Sequence column; Target_UniProt_ID, "
                             "Ligand_Sequence, Ligand_SMILES, and label are optional")
    parser.add_argument("--save_path", type=str, default="results", help="Output directory")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_pool", type=int, default=32, help="Pool size for candidate generation")
    parser.add_argument("--val_samples_per_target", type=int, default=8, help="Samples to keep per target-direction")
    parser.add_argument("--resample_alpha", type=float, default=0.1, help="Temperature for weighted resampling")
    parser.add_argument("--direction_oracle_ckpt", type=str, default=None)
    parser.add_argument("--direction_oracle_tr2d2_checkpoint", type=str, default=None)
    args = parser.parse_args()
    # Setup
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.save_path, exist_ok=True)
    analyzer = PeptideAnalyzer()
    # Load model
    logger.info(f"Loading model from {args.ckpt_path}")
    model, tokenizer = load_model(args.ckpt_path, device)
    # Load targets
    logger.info(f"Loading targets from {args.val_csv}")
    df = pd.read_csv(args.val_csv)
    targets = []
    for _, row in df.iterrows():
        # Size the binder from the ligand SMILES length (capped at 200);
        # fall back to 200 when the column is missing or the cell is NaN.
        smiles = row.get("Ligand_SMILES", "")
        seq_length = min(len(smiles), 200) if isinstance(smiles, str) and smiles else 200
        targets.append({
            "target_seq": row["Target_Sequence"],
            "target_uid": row.get("Target_UniProt_ID", ""),
            "binder_seq": row.get("Ligand_Sequence", ""),
            "label": row.get("label", ""),
            "seq_length": seq_length,
        })
    # Build a reward function for each target
    logger.info("Building reward functions...")
    oracle_ckpt = args.direction_oracle_ckpt or os.path.join(ROOT_DIR, "checkpoints", "direction_oracle.pt")
    oracle_tr2d2 = args.direction_oracle_tr2d2_checkpoint or os.path.join(ROOT_DIR, "checkpoints", "pretrained.ckpt")
    records = []
    for tidx, target in enumerate(targets):
        for d_star, d_name in [(1.0, "agonist"), (-1.0, "antagonist")]:
            logger.info(f"[{tidx+1}/{len(targets)}] Target {target['target_uid']} direction={d_name}")
            # Create reward function
            try:
                reward_model = create_reward_function(
                    base_path=ROOT_DIR,
                    tokenizer=tokenizer,
                    target_protein_seq=target["target_seq"],
                    target_direction="agonist" if d_star > 0 else "antagonist",
                    device=device,
                    emb_model=model.backbone,
                    directional_oracle_checkpoint=oracle_ckpt,
                    direction_oracle_tr2d2_checkpoint=oracle_tr2d2,
                )
            except Exception as e:
                logger.warning(f"Failed to create reward for {target['target_uid']}: {e}")
                continue
            # Generate pool of candidates
            target_length = target.get("seq_length", 200)
            x_pool = sample_sequences(model, args.num_pool, target_length, DEFAULTS["total_num_steps"])
            sequences = tokenizer.batch_decode(x_pool)
            # Check validity
            valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences])
            # Score all candidates
            gated_rewards, affinities, directions, confidences = score_sequences(reward_model, sequences)
            # A sample is direction-accurate when the oracle score lands on the
            # requested side of 0.5
            direction_accuracy = ((directions > 0.5).astype(float) if d_star > 0
                                  else (directions < 0.5).astype(float))
            # Weighted resampling (Algorithm 2): draw indices with probability
            # proportional to exp(reward / alpha); smaller alpha is greedier
            finite = np.isfinite(gated_rewards)
            if finite.any():
                rewards_t = torch.as_tensor(gated_rewards[finite], device=device)
                alpha = max(args.resample_alpha, 1e-6)
                weights = torch.softmax(rewards_t / alpha, dim=0)
                idx = torch.multinomial(weights, num_samples=args.val_samples_per_target, replacement=True)
                valid_idx = np.where(finite)[0]
                chosen = valid_idx[idx.cpu().numpy()]
            else:
                chosen = np.arange(min(args.val_samples_per_target, len(sequences)))
            # Save only VALID resampled samples
            for i in chosen:
                is_valid = bool(valid_mask[i]) if valid_mask.size else False
                if not is_valid:
                    continue  # Skip invalid samples
                records.append({
                    "target": target["target_seq"][:20],
                    "target_uid": target["target_uid"],
                    "sequence": sequences[i],
                    "target_direction": d_star,
                    "direction_name": d_name,
                    "is_valid": True,
                    "affinity": float(affinities[i]),
                    "gated_reward": float(gated_rewards[i]),
                    "direction_oracle": float(directions[i]),
                    "direction_accuracy": float(direction_accuracy[i]),
                })
    # Save results
    out_df = pd.DataFrame(records)
    out_path = os.path.join(args.save_path, f"td3b_results_seed{args.seed}.csv")
    out_df.to_csv(out_path, index=False)
    # Print summary
    if len(out_df) > 0:
        dp = out_df[out_df["target_direction"] == 1.0]
        dm = out_df[out_df["target_direction"] == -1.0]
        logger.info(f"\n{'='*60}")
        logger.info(f"Results saved to {out_path} ({len(out_df)} valid samples)")
        logger.info(f"  Aff(d*=+1)   = {dp['affinity'].mean():.2f}" if len(dp) else "  No agonist samples")
        logger.info(f"  Aff(d*=-1)   = {dm['affinity'].mean():.2f}" if len(dm) else "  No antagonist samples")
        if len(dp):
            logger.info(f"  DA(d*=+1)    = {dp['direction_accuracy'].mean():.3f}")
        if len(dm):
            logger.info(f"  DA(d*=-1)    = {dm['direction_accuracy'].mean():.3f}")
        logger.info(f"  Gated Reward = {out_df['gated_reward'].mean():.2f}")
        logger.info(f"{'='*60}")
    else:
        logger.warning("No valid samples generated.")

if __name__ == "__main__":
    main()