"""
TD3B Inference Script
Generate directional binders for target proteins using a finetuned TD3B model.

Usage:
    python inference.py \
        --ckpt_path checkpoints/td3b.ckpt \
        --val_csv data/test.csv \
        --save_path results/ \
        --seed 42
"""
import argparse
import os
import sys
import logging
from typing import List

import numpy as np
import pandas as pd
import torch

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

from diffusion import Diffusion
from configs.finetune_config import (
    DiffusionConfig, RoFormerConfig, NoiseConfig,
    TrainingConfig, SamplingConfig, EvalConfig, OptimConfig, MCTSConfig,
)
from finetune_utils import load_tokenizer, create_reward_function
from td3b.direction_oracle import DirectionalOracle
from td3b.td3b_scoring import create_td3b_reward_function
from utils.app import PeptideAnalyzer

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

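# Reference hyperparameters for the finetuned TD3B setup; most values mirror
# the hard-coded settings in load_model() and the CLI defaults in main().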
DEFAULTS = dict(
    seq_length=200,
    sampling_eps=1e-3,
    total_num_steps=128,
    hidden_dim=768,
    num_layers=8,
    num_heads=8,
    alpha=0.1,
    min_affinity_threshold=0.0,
    sigmoid_temperature=0.1,
    num_pool=32,
    val_samples_per_target=8,
)


def load_model(ckpt_path: str, device: torch.device):
    """Load finetuned TD3B model from checkpoint."""
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
    config = ckpt.get("config") or {}

    tokenizer = load_tokenizer(ROOT_DIR)

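    # Rebuild the diffusion config; architecture values fall back to the
    # defaults (768 hidden, 8 layers, 8 heads) when the checkpoint has no config.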
    cfg = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=config.get("hidden_dim", 768),
            n_layers=config.get("num_layers", 8),
            n_heads=config.get("num_heads", 8),
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=1e-3),
        sampling=SamplingConfig(steps=128, sampling_eps=1e-3),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=3e-4),
        mcts=MCTSConfig(),
    )

    model = Diffusion(config=cfg, tokenizer=tokenizer, device=device).to(device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.tokenizer = tokenizer
    return model, tokenizer


def sample_sequences(model, batch_size: int, seq_length: int, num_steps: int, eps: float = 1e-5):
    """Sample sequences from the diffusion model."""
    x = model.sample_prior(batch_size, seq_length).to(model.device, dtype=torch.long)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=model.device)
    dt = torch.tensor((1 - eps) / num_steps, device=model.device)

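    # Ancestral reverse diffusion from t = 1 down to t = eps.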
    for i in range(num_steps):
        t = timesteps[i] * torch.ones(x.shape[0], 1, device=model.device)
        _, x = model.single_reverse_step(x, t=t, dt=dt)
        x = x.to(model.device)

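    # If any positions are still masked after the final step, run one extra
    # noise-removal pass to resolve them.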
    mask_pos = (x == model.mask_index)
    if mask_pos.any():
        t = timesteps[-2] * torch.ones(x.shape[0], 1, device=model.device)
        _, x = model.single_noise_removal(x, t=t, dt=dt)
        x = x.to(model.device)

    return x


def score_sequences(reward_model, sequences: List[str]):
    """Score sequences with the TD3B reward function."""
    result = reward_model(sequences)
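    # The reward function may return either (rewards, info_dict) or a bare reward array.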
    if isinstance(result, tuple):
        rewards, info = result
        rewards = np.asarray(rewards)
        return (
            rewards,
            np.asarray(info.get("affinities", rewards)),
            np.asarray(info.get("directions", np.zeros_like(rewards))),
            np.asarray(info.get("confidences", np.ones_like(rewards))),
        )
    rewards = np.asarray(result)
    return rewards, rewards, np.zeros_like(rewards), np.ones_like(rewards)


def main():
    parser = argparse.ArgumentParser(description="TD3B Inference")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Path to TD3B checkpoint")
    parser.add_argument("--val_csv", type=str, required=True,
                        help="CSV with a Target_Sequence column; Target_UniProt_ID, Ligand_Sequence, Ligand_SMILES, and label are used when present")
    parser.add_argument("--save_path", type=str, default="results", help="Output directory")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_pool", type=int, default=32, help="Pool size for candidate generation")
    parser.add_argument("--val_samples_per_target", type=int, default=8, help="Samples to keep per target-direction")
    parser.add_argument("--resample_alpha", type=float, default=0.1, help="Temperature for weighted resampling")
    parser.add_argument("--direction_oracle_ckpt", type=str, default=None)
    parser.add_argument("--direction_oracle_tr2d2_checkpoint", type=str, default=None)
    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.save_path, exist_ok=True)

    analyzer = PeptideAnalyzer()

    logger.info(f"Loading model from {args.ckpt_path}")
    model, tokenizer = load_model(args.ckpt_path, device)

    logger.info(f"Loading targets from {args.val_csv}")
    df = pd.read_csv(args.val_csv)
    targets = []
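    # One generation task per CSV row; the generation length follows the ligand
    # SMILES length, capped at 200 (and defaults to 200 when the column is missing).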
    for _, row in df.iterrows():
        targets.append({
            "target_seq": row["Target_Sequence"],
            "target_uid": row.get("Target_UniProt_ID", ""),
            "binder_seq": row.get("Ligand_Sequence", ""),
            "label": row.get("label", ""),
            "seq_length": min(len(row.get("Ligand_SMILES", "x" * 200)), 200),
        })

    logger.info("Building reward functions...")
    oracle_ckpt = args.direction_oracle_ckpt or os.path.join(ROOT_DIR, "checkpoints", "direction_oracle.pt")
    oracle_tr2d2 = args.direction_oracle_tr2d2_checkpoint or os.path.join(ROOT_DIR, "checkpoints", "pretrained.ckpt")

    records = []

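    # Generate and score a candidate pool for each target under both desired
    # directions (agonist and antagonist).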
    for tidx, target in enumerate(targets):
        for d_star, d_name in [(1.0, "agonist"), (-1.0, "antagonist")]:
            logger.info(f"[{tidx+1}/{len(targets)}] Target {target['target_uid']} direction={d_name}")

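            # Build a direction-conditioned reward model for this target; skip
            # the target/direction pair if construction fails.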
            try:
                reward_model = create_reward_function(
                    base_path=ROOT_DIR,
                    tokenizer=tokenizer,
                    target_protein_seq=target["target_seq"],
                    target_direction="agonist" if d_star > 0 else "antagonist",
                    device=device,
                    emb_model=model.backbone,
                    directional_oracle_checkpoint=oracle_ckpt,
                    direction_oracle_tr2d2_checkpoint=oracle_tr2d2,
                )
            except Exception as e:
                logger.warning(f"Failed to create reward for {target['target_uid']}: {e}")
                continue

            target_length = target.get("seq_length", 200)
            x_pool = sample_sequences(model, args.num_pool, target_length, 128)
            sequences = tokenizer.batch_decode(x_pool)

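            # Keep only decoded outputs that the PeptideAnalyzer accepts as valid peptides.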
            valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences])

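            # Score the pool; a sample counts as directionally correct when the
            # oracle score falls on the requested side of 0.5.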
            gated_rewards, affinities, directions, confidences = score_sequences(reward_model, sequences)
            direction_accuracy = ((directions > 0.5).astype(float) if d_star > 0
                                  else (directions < 0.5).astype(float))

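            # Reward-weighted resampling: softmax over the finite rewards at
            # temperature alpha, sampled with replacement.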
            finite = np.isfinite(gated_rewards)
            if finite.any():
                rewards_t = torch.as_tensor(gated_rewards[finite], device=device)
                alpha = max(args.resample_alpha, 1e-6)
                weights = torch.softmax(rewards_t / alpha, dim=0)
                idx = torch.multinomial(weights, num_samples=args.val_samples_per_target, replacement=True)
                valid_idx = np.where(finite)[0]
                chosen = valid_idx[idx.cpu().numpy()]
            else:
                chosen = np.arange(min(args.val_samples_per_target, len(sequences)))

            for i in chosen:
                is_valid = bool(valid_mask[i]) if valid_mask.size else False
                if not is_valid:
                    continue

                records.append({
                    "target": target["target_seq"][:20],
                    "target_uid": target["target_uid"],
                    "sequence": sequences[i],
                    "target_direction": d_star,
                    "direction_name": d_name,
                    "is_valid": True,
                    "affinity": float(affinities[i]),
                    "gated_reward": float(gated_rewards[i]),
                    "direction_oracle": float(directions[i]),
                    "direction_accuracy": float(direction_accuracy[i]),
                })

    out_df = pd.DataFrame(records)
    out_path = os.path.join(args.save_path, f"td3b_results_seed{args.seed}.csv")
    out_df.to_csv(out_path, index=False)

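    # Log summary metrics split by requested direction (d* = +1 agonist, d* = -1 antagonist).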
    if len(out_df) > 0:
        dp = out_df[out_df["target_direction"] == 1.0]
        dm = out_df[out_df["target_direction"] == -1.0]
        logger.info(f"\n{'='*60}")
        logger.info(f"Results saved to {out_path} ({len(out_df)} valid samples)")
        if len(dp):
            logger.info(f"  Aff(d*=+1) = {dp['affinity'].mean():.2f}")
            logger.info(f"  DA(d*=+1)  = {dp['direction_accuracy'].mean():.3f}")
        else:
            logger.info("  No agonist samples")
        if len(dm):
            logger.info(f"  Aff(d*=-1) = {dm['affinity'].mean():.2f}")
            logger.info(f"  DA(d*=-1)  = {dm['direction_accuracy'].mean():.3f}")
        else:
            logger.info("  No antagonist samples")
        logger.info(f"  Gated Reward = {out_df['gated_reward'].mean():.2f}")
        logger.info(f"{'='*60}")
    else:
        logger.warning("No valid samples generated.")


if __name__ == "__main__":
    main()