#!/usr/bin/env python3
"""
TD3B Inference Script

Generate directional binders for target proteins using a finetuned TD3B model.

Usage:
    python inference.py \
        --ckpt_path checkpoints/td3b.ckpt \
        --val_csv data/test.csv \
        --save_path results/ \
        --seed 42
"""

import argparse
import logging
import os
import sys
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

from diffusion import Diffusion
from configs.finetune_config import (
    DiffusionConfig,
    RoFormerConfig,
    NoiseConfig,
    TrainingConfig,
    SamplingConfig,
    EvalConfig,
    OptimConfig,
    MCTSConfig,
)
from finetune_utils import load_tokenizer, create_reward_function
from td3b.direction_oracle import DirectionalOracle
from td3b.td3b_scoring import create_td3b_reward_function
from utils.app import PeptideAnalyzer

logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# ─── Defaults ─────────────────────────────────────────────────────────────────
# Single source of truth for model/sampling hyperparameters; referenced by
# load_model() and by the argparse defaults in main().
DEFAULTS = dict(
    seq_length=200,
    sampling_eps=1e-3,
    total_num_steps=128,
    hidden_dim=768,
    num_layers=8,
    num_heads=8,
    alpha=0.1,
    min_affinity_threshold=0.0,
    sigmoid_temperature=0.1,
    num_pool=32,
    val_samples_per_target=8,
)


def load_model(ckpt_path: str, device: torch.device):
    """Load a finetuned TD3B model from a checkpoint.

    Args:
        ckpt_path: Path to a ``torch.save``d checkpoint. Accepted layouts:
            ``{"model_state_dict": ...}``, ``{"state_dict": ...}``, or a bare
            state dict. An optional ``"config"`` dict may override the
            architecture hyperparameters in DEFAULTS.
        device: Device to map the weights onto.

    Returns:
        Tuple of ``(model, tokenizer)`` with the model in eval mode.
    """
    # weights_only=False because the checkpoint carries a config dict besides
    # tensors. NOTE(review): this unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Fall through the known layouts; a bare state dict is used as-is.
    state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
    config = ckpt.get("config") or {}

    tokenizer = load_tokenizer(ROOT_DIR)
    cfg = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=config.get("hidden_dim", DEFAULTS["hidden_dim"]),
            n_layers=config.get("num_layers", DEFAULTS["num_layers"]),
            n_heads=config.get("num_heads", DEFAULTS["num_heads"]),
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=DEFAULTS["sampling_eps"]),
        sampling=SamplingConfig(
            steps=DEFAULTS["total_num_steps"],
            sampling_eps=DEFAULTS["sampling_eps"],
        ),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=3e-4),
        mcts=MCTSConfig(),
    )
    model = Diffusion(config=cfg, tokenizer=tokenizer, device=device).to(device)
    # strict=False: finetuned checkpoints may omit buffers/heads that exist in
    # the freshly constructed model (or vice versa).
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.tokenizer = tokenizer
    return model, tokenizer


def sample_sequences(model, batch_size: int, seq_length: int, num_steps: int,
                     eps: float = 1e-5):
    """Sample token sequences from the diffusion model via reverse steps.

    Runs ``num_steps`` reverse-diffusion steps from the prior, then performs a
    final noise-removal pass for any positions still holding the mask token.

    Args:
        model: Diffusion model exposing ``sample_prior``,
            ``single_reverse_step``, ``single_noise_removal``, ``mask_index``
            and ``device``.
        batch_size: Number of sequences to sample.
        seq_length: Length of each sequence.
        num_steps: Number of reverse-diffusion steps.
        eps: Terminal time of the reverse schedule (t goes 1 -> eps).

    Returns:
        LongTensor of token ids, shape ``(batch_size, seq_length)``.
    """
    x = model.sample_prior(batch_size, seq_length).to(model.device, dtype=torch.long)
    # Time grid from t=1 down to t=eps; step i uses timesteps[i].
    timesteps = torch.linspace(1, eps, num_steps + 1, device=model.device)
    dt = torch.tensor((1 - eps) / num_steps, device=model.device)

    for i in range(num_steps):
        t = timesteps[i] * torch.ones(x.shape[0], 1, device=model.device)
        _, x = model.single_reverse_step(x, t=t, dt=dt)
        x = x.to(model.device)

    # Remove remaining masks with one extra denoising pass at the last
    # pre-terminal time.
    mask_pos = (x == model.mask_index)
    if mask_pos.any():
        t = timesteps[-2] * torch.ones(x.shape[0], 1, device=model.device)
        _, x = model.single_noise_removal(x, t=t, dt=dt)
        x = x.to(model.device)
    return x


def score_sequences(reward_model, sequences: List[str]):
    """Score sequences with the TD3B reward function.

    Args:
        reward_model: Callable returning either a bare array of rewards or a
            ``(rewards, info)`` tuple where ``info`` may contain
            ``"affinities"``, ``"directions"`` and ``"confidences"``.
        sequences: Decoded candidate sequences.

    Returns:
        Tuple of numpy arrays ``(rewards, affinities, directions,
        confidences)``. Missing info fields fall back to rewards / zeros /
        ones respectively.
    """
    result = reward_model(sequences)
    if isinstance(result, tuple):
        rewards, info = result
        return (
            np.asarray(rewards),
            np.asarray(info.get("affinities", rewards)),
            np.asarray(info.get("directions", np.zeros_like(rewards))),
            np.asarray(info.get("confidences", np.ones_like(rewards))),
        )
    rewards = np.asarray(result)
    return rewards, rewards, np.zeros_like(rewards), np.ones_like(rewards)


def _target_seq_length(row) -> int:
    """Derive the sampling length for a CSV row, capped at the default.

    Uses the length of ``Ligand_SMILES`` when it is a non-empty string;
    otherwise falls back to DEFAULTS["seq_length"]. Guards against pandas
    yielding NaN (a float) for missing cells, which would crash ``len()``.
    """
    smiles = row.get("Ligand_SMILES", "")
    if not isinstance(smiles, str) or not smiles:
        return DEFAULTS["seq_length"]
    return min(len(smiles), DEFAULTS["seq_length"])


def main():
    parser = argparse.ArgumentParser(description="TD3B Inference")
    parser.add_argument("--ckpt_path", type=str, required=True,
                        help="Path to TD3B checkpoint")
    parser.add_argument("--val_csv", type=str, required=True,
                        help="CSV with Target_Sequence, Ligand_Sequence, label columns")
    parser.add_argument("--save_path", type=str, default="results",
                        help="Output directory")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_pool", type=int, default=DEFAULTS["num_pool"],
                        help="Pool size for candidate generation")
    parser.add_argument("--val_samples_per_target", type=int,
                        default=DEFAULTS["val_samples_per_target"],
                        help="Samples to keep per target-direction")
    parser.add_argument("--resample_alpha", type=float, default=DEFAULTS["alpha"],
                        help="Temperature for weighted resampling")
    parser.add_argument("--direction_oracle_ckpt", type=str, default=None)
    parser.add_argument("--direction_oracle_tr2d2_checkpoint", type=str, default=None)
    args = parser.parse_args()

    # Setup
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.save_path, exist_ok=True)
    analyzer = PeptideAnalyzer()

    # Load model
    logger.info("Loading model from %s", args.ckpt_path)
    model, tokenizer = load_model(args.ckpt_path, device)

    # Load targets
    logger.info("Loading targets from %s", args.val_csv)
    df = pd.read_csv(args.val_csv)
    targets = []
    for _, row in df.iterrows():
        targets.append({
            "target_seq": row["Target_Sequence"],
            "target_uid": row.get("Target_UniProt_ID", ""),
            "binder_seq": row.get("Ligand_Sequence", ""),
            "label": row.get("label", ""),
            "seq_length": _target_seq_length(row),
        })

    # Build reward function for each target
    logger.info("Building reward functions...")
    oracle_ckpt = args.direction_oracle_ckpt or os.path.join(
        ROOT_DIR, "checkpoints", "direction_oracle.pt")
    oracle_tr2d2 = args.direction_oracle_tr2d2_checkpoint or os.path.join(
        ROOT_DIR, "checkpoints", "pretrained.ckpt")

    records = []
    for tidx, target in enumerate(targets):
        # Each target is evaluated in both directions (agonist / antagonist).
        for d_star, d_name in [(1.0, "agonist"), (-1.0, "antagonist")]:
            logger.info("[%d/%d] Target %s direction=%s",
                        tidx + 1, len(targets), target["target_uid"], d_name)

            # Create reward function; a failure for one target must not abort
            # the whole run.
            try:
                reward_model = create_reward_function(
                    base_path=ROOT_DIR,
                    tokenizer=tokenizer,
                    target_protein_seq=target["target_seq"],
                    target_direction="agonist" if d_star > 0 else "antagonist",
                    device=device,
                    emb_model=model.backbone,
                    directional_oracle_checkpoint=oracle_ckpt,
                    direction_oracle_tr2d2_checkpoint=oracle_tr2d2,
                )
            except Exception as e:
                logger.warning("Failed to create reward for %s: %s",
                               target["target_uid"], e)
                continue

            # Generate pool of candidates
            target_length = target.get("seq_length", DEFAULTS["seq_length"])
            x_pool = sample_sequences(model, args.num_pool, target_length,
                                      DEFAULTS["total_num_steps"])
            sequences = tokenizer.batch_decode(x_pool)

            # Check validity
            valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences])

            # Score all
            gated_rewards, affinities, directions, confidences = \
                score_sequences(reward_model, sequences)
            # Direction accuracy: oracle output above/below 0.5 depending on
            # the requested direction sign.
            direction_accuracy = ((directions > 0.5).astype(float)
                                  if d_star > 0
                                  else (directions < 0.5).astype(float))

            # Weighted resampling (Algorithm 2): softmax over finite rewards,
            # then multinomial with replacement.
            finite = np.isfinite(gated_rewards)
            if finite.any():
                rewards_t = torch.as_tensor(gated_rewards[finite], device=device)
                alpha = max(args.resample_alpha, 1e-6)  # avoid div-by-zero
                weights = torch.softmax(rewards_t / alpha, dim=0)
                idx = torch.multinomial(weights,
                                        num_samples=args.val_samples_per_target,
                                        replacement=True)
                valid_idx = np.where(finite)[0]
                chosen = valid_idx[idx.cpu().numpy()]
            else:
                # No finite rewards: fall back to the first few candidates.
                chosen = np.arange(min(args.val_samples_per_target, len(sequences)))

            # Save only VALID resampled samples
            for i in chosen:
                is_valid = bool(valid_mask[i]) if valid_mask.size else False
                if not is_valid:
                    continue  # Skip invalid samples
                records.append({
                    "target": target["target_seq"][:20],
                    "target_uid": target["target_uid"],
                    "sequence": sequences[i],
                    "target_direction": d_star,
                    "direction_name": d_name,
                    "is_valid": True,
                    "affinity": float(affinities[i]),
                    "gated_reward": float(gated_rewards[i]),
                    "direction_oracle": float(directions[i]),
                    "direction_accuracy": float(direction_accuracy[i]),
                })

    # Save results
    out_df = pd.DataFrame(records)
    out_path = os.path.join(args.save_path, f"td3b_results_seed{args.seed}.csv")
    out_df.to_csv(out_path, index=False)

    # Print summary
    if len(out_df) > 0:
        dp = out_df[out_df["target_direction"] == 1.0]
        dm = out_df[out_df["target_direction"] == -1.0]
        logger.info(f"\n{'='*60}")
        logger.info(f"Results saved to {out_path} ({len(out_df)} valid samples)")
        logger.info(f"  Aff(d*=+1) = {dp['affinity'].mean():.2f}" if len(dp) else "  No agonist samples")
        logger.info(f"  Aff(d*=-1) = {dm['affinity'].mean():.2f}" if len(dm) else "  No antagonist samples")
        logger.info(f"  DA(d*=+1)  = {dp['direction_accuracy'].mean():.3f}" if len(dp) else "")
        logger.info(f"  DA(d*=-1)  = {dm['direction_accuracy'].mean():.3f}" if len(dm) else "")
        logger.info(f"  Gated Reward = {out_df['gated_reward'].mean():.2f}")
        logger.info(f"{'='*60}")
    else:
        logger.warning("No valid samples generated.")


if __name__ == "__main__":
    main()