"""Run sampling baselines against target proteins and score the results.

For each target (a single ``--prot_seq`` or every usable row of
``--targets_csv``) the script samples sequences for both the agonist
(``d_star=+1``) and antagonist (``d_star=-1``) directions with the selected
baseline, scores them with the reward models and the direction oracle, and
writes per-sample, per-batch timing and per-batch metric CSVs.
"""

import argparse
import os
import sys
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np

# Make the repository root importable before the project imports below.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

import torch  # noqa: E402
from hydra import compose, initialize_config_dir  # noqa: E402
from hydra.core.global_hydra import GlobalHydra  # noqa: E402

from diffusion import Diffusion  # noqa: E402
from scoring.scoring_functions import ScoringFunctions  # noqa: E402
from scoring.functions.binding import MultiTargetBindingAffinity  # noqa: E402
from td3b.direction_oracle import DirectionalOracle, resolve_device  # noqa: E402
from td3b.data_utils import peptide_seq_to_smiles, smiles_token_length  # noqa: E402
from baselines.baselines import (  # noqa: E402
    RewardInputs,
    RewardWrapper,
    classifier_guidance,
    peptune_mctg_sampling,
    sequential_monte_carlo,
    twisted_diffusion_sampler,
    unguided_sampling,
)

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


@dataclass
class ProteinTokenizer:
    """Map one-letter amino-acid codes to integer ids.

    ``default()`` assigns ids 1..20 in ``AMINO_ACIDS`` order and reserves 0
    for padding. Letters not in the vocabulary are encoded as ``pad_id``.
    """

    aa_to_id: Dict[str, int]  # residue letter -> token id
    pad_id: int = 0  # id used for padding and for unknown residues

    @classmethod
    def default(cls) -> "ProteinTokenizer":
        """Build the standard 20-residue tokenizer (ids start at 1)."""
        aa_to_id = {aa: idx + 1 for idx, aa in enumerate(AMINO_ACIDS)}
        return cls(aa_to_id=aa_to_id, pad_id=0)

    def encode(self, seq: str) -> torch.Tensor:
        """Encode ``seq`` as a ``(1, len(seq))`` long tensor."""
        ids = [self.aa_to_id.get(aa, self.pad_id) for aa in seq]
        return torch.tensor([ids], dtype=torch.long)


def load_base_model(
    ckpt_path: str,
    device: str,
    config_name: str = "peptune_config.yaml",
) -> Diffusion:
    """Load the pretrained ``Diffusion`` model in eval mode.

    ``Diffusion.load_from_checkpoint`` is tried first. If it raises, a fresh
    ``Diffusion`` is built from the Hydra config and a raw state dict is
    loaded non-strictly; missing/unexpected key counts are printed.

    Args:
        ckpt_path: Path to the checkpoint file.
        device: Device string used for ``map_location`` and model placement.
        config_name: Hydra config file inside ``<ROOT_DIR>/configs``.

    Returns:
        The loaded model, switched to ``eval()``.

    Raises:
        ValueError: If the fallback checkpoint object is not a dict.
    """
    GlobalHydra.instance().clear()
    # initialize_config_dir requires an absolute path; ROOT_DIR is already
    # absolute, so this does not depend on how __file__ was populated.
    config_dir = os.path.join(ROOT_DIR, "configs")
    initialize_config_dir(config_dir=config_dir, job_name="load_model")
    cfg = compose(config_name=config_name)

    try:
        model = Diffusion.load_from_checkpoint(
            ckpt_path,
            config=cfg,
            mode="eval",
            device=device,
            map_location=device,
        )
        model.eval()
        return model
    except Exception as exc:  # deliberate best-effort: fall back below
        print(f"[load_base_model] Lightning load failed, falling back to raw state_dict: {exc}")

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    if not isinstance(checkpoint, dict):
        raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        # Assume the dict itself is the state dict.
        state_dict = checkpoint

    model = Diffusion(
        config=cfg,
        mode="eval",
        device=device,
    )
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"[load_base_model] Missing keys: {len(missing)}")
    if unexpected:
        print(f"[load_base_model] Unexpected keys: {len(unexpected)}")
    model.eval()
    model.to(device)
    return model


def load_reward_models(
    prot_seq: Optional[str],
    device: str,
    base_model: Optional[Diffusion] = None,
    base_path: Optional[str] = None,
    multi_target: bool = False,
    score_func_names: Optional[List[str]] = None,
):
    """Build the reward/scoring object used by ``RewardWrapper``.

    Args:
        prot_seq: Target protein sequence (required when not multi-target).
        device: Device string forwarded to the scorer.
        base_model: Diffusion model whose tokenizer/backbone the multi-target
            scorer reuses (required when ``multi_target``).
        base_path: Base directory forwarded to ``MultiTargetBindingAffinity``
            (required when ``multi_target``).
        multi_target: Return a ``MultiTargetBindingAffinity`` instead of a
            single-target ``ScoringFunctions``.
        score_func_names: Objectives for ``ScoringFunctions``; defaults to
            binding affinity, solubility, hemolysis, nonfouling, permeability.

    Raises:
        ValueError: If a required argument for the selected mode is missing.
    """
    if multi_target:
        if base_model is None or base_path is None:
            raise ValueError("base_model and base_path are required for multi-target affinity.")
        return MultiTargetBindingAffinity(
            tokenizer=base_model.tokenizer,
            base_path=base_path,
            device=device,
            emb_model=base_model.backbone,
        )

    if score_func_names is None:
        score_func_names = [
            "binding_affinity1",
            "solubility",
            "hemolysis",
            "nonfouling",
            "permeability",
        ]
    if prot_seq is None:
        raise ValueError("prot_seq is required for single-target scoring.")
    return ScoringFunctions(score_func_names, prot_seqs=[prot_seq], device=device)


def load_direction_oracle(args, device: str) -> DirectionalOracle:
    """Construct the direction oracle from ``--direction_oracle_*`` CLI args, in eval mode."""
    oracle = DirectionalOracle(
        model_ckpt=args.direction_oracle_ckpt,
        tr2d2_checkpoint=args.direction_oracle_tr2d2_checkpoint,
        tokenizer_vocab=args.direction_oracle_tokenizer_vocab,
        tokenizer_splits=args.direction_oracle_tokenizer_splits,
        esm_name=args.direction_oracle_esm_name,
        d_model=args.direction_oracle_d_model,
        n_heads=args.direction_oracle_n_heads,
        n_self_attn_layers=args.direction_oracle_n_self_attn_layers,
        n_bmca_layers=args.direction_oracle_n_bmca_layers,
        dropout=args.direction_oracle_dropout,
        max_ligand_length=args.direction_oracle_max_ligand_length,
        max_protein_length=args.direction_oracle_max_protein_length,
        device=device,
        esm_cache_dir=args.direction_oracle_esm_cache_dir,
        esm_local_files_only=args.direction_oracle_esm_local_files_only,
    )
    oracle.eval()
    return oracle


def run_baseline(
    baseline: str,
    base_model: Diffusion,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    guidance_scale: float,
    alpha: float,
    guidance_steps: Optional[int],
    mcts_iterations: int,
    num_children: int,
    sample_prob_weight: float,
    invalid_penalty: float,
    pareto_max_size: Optional[int],
) -> Dict[str, torch.Tensor]:
    """Dispatch to the sampler named by ``baseline`` (case-insensitive).

    Each sampler receives only the keyword arguments it uses. The returned
    dict is the sampler's result; ``main`` reads its ``"tokens"`` entry.

    Raises:
        ValueError: If ``baseline`` is not one of cg/unguided/smc/tds/peptune.
    """
    baseline = baseline.lower()
    if baseline == "cg":
        return classifier_guidance(
            base_model,
            reward_fn,
            batch_size=batch_size,
            seq_length=seq_length,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            guidance_steps=guidance_steps,
        )
    if baseline == "unguided":
        return unguided_sampling(
            base_model,
            batch_size=batch_size,
            seq_length=seq_length,
            num_steps=num_steps,
        )
    if baseline == "smc":
        return sequential_monte_carlo(
            base_model,
            reward_fn,
            batch_size=batch_size,
            seq_length=seq_length,
            num_steps=num_steps,
            alpha=alpha,
        )
    if baseline == "tds":
        return twisted_diffusion_sampler(
            base_model,
            reward_fn,
            batch_size=batch_size,
            seq_length=seq_length,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            alpha=alpha,
            guidance_steps=guidance_steps,
        )
    if baseline == "peptune":
        return peptune_mctg_sampling(
            base_model,
            reward_fn,
            batch_size=batch_size,
            seq_length=seq_length,
            num_steps=num_steps,
            mcts_iterations=mcts_iterations,
            num_children=num_children,
            alpha=alpha,
            sample_prob_weight=sample_prob_weight,
            invalid_penalty=invalid_penalty,
            pareto_max_size=pareto_max_size,
        )
    raise ValueError(f"Unknown baseline: {baseline}")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--baseline", type=str, default="cg", choices=["cg", "smc", "tds", "unguided", "peptune"])
    parser.add_argument("--prot_seq", type=str, default=None)
    parser.add_argument("--targets_csv", type=str, default=None)
    # NOTE(review): --d_star is parsed but main() always samples both
    # directions (+1 agonist and -1 antagonist); kept for CLI compatibility.
    parser.add_argument("--d_star", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--seq_length", type=int, default=200)
    parser.add_argument("--binder_seq", type=str, default=None)
    parser.add_argument("--num_steps", type=int, default=128)
    parser.add_argument("--guidance_scale", type=float, default=1.0)
    parser.add_argument("--alpha", type=float, default=0.1)
    parser.add_argument("--reward_alpha", type=float, default=None)
    parser.add_argument("--mcts_iterations", type=int, default=20)
    parser.add_argument("--num_children", type=int, default=24)
    parser.add_argument("--sample_prob_weight", type=float, default=0.1)
    parser.add_argument("--invalid_penalty", type=float, default=1.0)
    parser.add_argument("--pareto_max_size", type=int, default=None)
    parser.add_argument("--guidance_steps", type=int, default=None)
    parser.add_argument("--fast_direction", action="store_true", default=False)
    parser.add_argument("--num_batches", type=int, default=1)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--shard_id", type=int, default=None)
    parser.add_argument("--num_shards", type=int, default=None)
    parser.add_argument("--direction_oracle_ckpt", type=str, default=None)
    parser.add_argument("--direction_oracle_tr2d2_checkpoint", type=str, default=None)
    parser.add_argument("--direction_oracle_tokenizer_vocab", type=str, default=None)
    parser.add_argument("--direction_oracle_tokenizer_splits", type=str, default=None)
    parser.add_argument("--direction_oracle_esm_name", type=str, default="facebook/esm2_t33_650M_UR50D")
    parser.add_argument("--direction_oracle_esm_cache_dir", type=str, default=None)
    parser.add_argument("--direction_oracle_esm_local_files_only", action="store_true", default=False)
    parser.add_argument("--direction_oracle_max_ligand_length", type=int, default=768)
    parser.add_argument("--direction_oracle_max_protein_length", type=int, default=1024)
    parser.add_argument("--direction_oracle_d_model", type=int, default=256)
    parser.add_argument("--direction_oracle_n_heads", type=int, default=4)
    parser.add_argument("--direction_oracle_n_self_attn_layers", type=int, default=1)
    parser.add_argument("--direction_oracle_n_bmca_layers", type=int, default=2)
    parser.add_argument("--direction_oracle_dropout", type=float, default=0.3)
    return parser


def _resolve_shard(args: argparse.Namespace) -> Tuple[int, int, int]:
    """Return ``(rank, local_rank, world_size)`` for target sharding.

    Launcher environment variables (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``)
    take precedence over ``--shard_id``/``--num_shards``. ``rank`` is the
    global shard index used to split targets. ``local_rank`` is the per-node
    index used to pick a GPU. They differ on multi-node runs, where
    ``LOCAL_RANK`` repeats on every node and must not be used for sharding.

    Raises:
        ValueError: If sharding is active and ``rank`` is outside
            ``[0, world_size)`` (that shard would silently get no targets).
    """
    global_env = os.environ.get("RANK")
    local_env = os.environ.get("LOCAL_RANK")
    world_env = os.environ.get("WORLD_SIZE")
    if global_env is not None or local_env is not None or world_env is not None:
        rank = int(global_env or local_env or 0)
        local_rank = int(local_env) if local_env is not None else rank
        world_size = int(world_env or 1)
    else:
        rank = int(args.shard_id) if args.shard_id is not None else 0
        local_rank = rank
        world_size = int(args.num_shards) if args.num_shards is not None else 1
    world_size = max(world_size, 1)
    if world_size > 1 and not 0 <= rank < world_size:
        raise ValueError(f"Shard rank {rank} is out of range for {world_size} shards.")
    return rank, local_rank, world_size


def _fill_direction_oracle_defaults(args: argparse.Namespace) -> None:
    """Point unset direction-oracle asset paths at their locations under ROOT_DIR (in place)."""
    defaults = {
        "direction_oracle_ckpt": ("direction_oracle.pt",),
        "direction_oracle_tr2d2_checkpoint": ("pretrained", "peptune-pretrained.ckpt"),
        "direction_oracle_tokenizer_vocab": ("tokenizer", "new_vocab.txt"),
        "direction_oracle_tokenizer_splits": ("tokenizer", "new_splits.txt"),
    }
    for attr, parts in defaults.items():
        if getattr(args, attr) is None:
            setattr(args, attr, os.path.join(ROOT_DIR, *parts))


def _load_targets(args: argparse.Namespace, multi_target: bool) -> List[dict]:
    """Return the list of target dicts (``target_seq``, ``binder_seq``, ``row_index``).

    In multi-target mode the rows of ``--targets_csv`` are used. Rows with an
    empty or missing ``Target_Sequence`` are skipped. An empty or missing
    ``Ligand_Sequence`` becomes ``None``. Both sequences are whitespace-stripped.

    Raises:
        ValueError: If the CSV lacks a required column.
    """
    if not multi_target:
        return [{"target_seq": args.prot_seq, "binder_seq": args.binder_seq, "row_index": 0}]

    import pandas as pd

    df = pd.read_csv(args.targets_csv)
    for column in ("Target_Sequence", "Ligand_Sequence"):
        if column not in df.columns:
            raise ValueError(f"targets_csv must contain a '{column}' column.")

    targets = []
    for row_idx, row in df.iterrows():
        raw_target = row["Target_Sequence"]
        target_seq = str(raw_target).strip() if pd.notna(raw_target) else ""
        if not target_seq:
            continue
        raw_binder = row["Ligand_Sequence"]
        binder_seq = str(raw_binder).strip() if pd.notna(raw_binder) else ""
        targets.append(
            {
                "target_seq": target_seq,
                "binder_seq": binder_seq or None,
                "row_index": int(row_idx),
            }
        )
    return targets


def _resolve_seq_length(binder_seq: Optional[str], default_length: int, tokenizer) -> int:
    """Return the generation length implied by ``binder_seq``.

    The length is the SMILES token count of the binder when a tokenizer is
    available, or the raw SMILES character count when ``tokenizer`` is None.
    Falls back to ``default_length`` when there is no binder, no SMILES, or
    the conversion fails (best-effort; the failure is printed).
    """
    if not binder_seq:
        return default_length
    try:
        smiles = peptide_seq_to_smiles(binder_seq)
        if not smiles:
            return default_length
        if tokenizer is None:
            return len(smiles)
        return smiles_token_length(smiles, tokenizer)
    except Exception as exc:
        print(f"Warning: failed to derive seq_length from binder_seq; using {default_length}. Error: {exc}")
        return default_length


def _direction_metrics(
    direction: np.ndarray,
    affinity: np.ndarray,
    gated_reward: np.ndarray,
    valid_mask: np.ndarray,
    d_star: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict[str, float]]:
    """Compute per-sample direction scores and batch summary statistics.

    Returns:
        ``(consistency, direction_correct, success, batch_metrics)``.
        ``consistency`` is ``d_star * (direction - 0.5)``. ``direction_correct``
        is 1 when ``direction >= 0.5`` for ``d_star > 0``, or when
        ``direction < 0.5`` otherwise. ``success`` is correct AND valid.
        ``batch_metrics`` maps metric names to batch mean/std values.
    """
    consistency = d_star * (direction - 0.5)
    if d_star > 0:
        direction_correct = (direction >= 0.5).astype(np.float32)
    else:
        direction_correct = (direction < 0.5).astype(np.float32)
    success = direction_correct * valid_mask
    valid_fraction = float(valid_mask.mean()) if len(valid_mask) else 0.0

    # Key order determines the CSV column order; keep it stable.
    batch_metrics = {
        "direction_mean": float(np.mean(direction)),
        "direction_std": float(np.std(direction)),
        "affinity_mean": float(np.mean(affinity)),
        "affinity_std": float(np.std(affinity)),
        "consistency_mean": float(np.mean(consistency)),
        "consistency_std": float(np.std(consistency)),
        "gated_reward_mean": float(np.mean(gated_reward)),
        "gated_reward_std": float(np.std(gated_reward)),
        "direction_accuracy_mean": float(np.mean(direction_correct)),
        "direction_accuracy_std": float(np.std(direction_correct)),
        "valid_fraction": valid_fraction,
        "success_rate_mean": float(np.mean(success)),
        "success_rate_std": float(np.std(success)),
    }
    return consistency, direction_correct, success, batch_metrics


def _output_paths(output_dir: str, baseline: str, rank: int, world_size: int) -> Tuple[str, str, str]:
    """Return ``(samples_csv, batch_times_csv, metrics_csv)``, rank-suffixed when sharded."""
    suffix = f"_rank{rank}" if world_size > 1 else ""
    return (
        os.path.join(output_dir, f"{baseline}_samples{suffix}.csv"),
        os.path.join(output_dir, f"batch_times{suffix}.csv"),
        os.path.join(output_dir, f"{baseline}_metrics{suffix}.csv"),
    )


def _synchronize(device: torch.device) -> None:
    """Block until queued CUDA work on ``device`` finishes (no-op otherwise).

    CUDA kernels run asynchronously, so wall-clock timing must synchronize
    to measure completed work rather than enqueue time.
    """
    if device.type == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize(device)


def main():
    """CLI entry point: sample with the chosen baseline for every target and write CSVs."""
    args = _build_arg_parser().parse_args()

    rank, local_rank, world_size = _resolve_shard(args)
    if world_size > 1 and str(args.device).lower() in {"cuda", "cuda:0", "auto"}:
        # One GPU per process: pick it by node-local rank, not the global rank.
        args.device = f"cuda:{local_rank}"
    args.device = str(resolve_device(args.device))
    device = torch.device(args.device)

    _fill_direction_oracle_defaults(args)

    multi_target = bool(args.targets_csv)
    if not multi_target and args.prot_seq is None:
        raise ValueError("--prot_seq is required when --targets_csv is not provided.")

    base_model = load_base_model(args.ckpt_path, args.device)
    base_path = os.path.abspath(os.path.join(ROOT_DIR, ".."))
    scoring_fn = load_reward_models(
        None if multi_target else args.prot_seq,
        args.device,
        base_model=base_model,
        base_path=base_path,
        multi_target=multi_target,
    )
    direction_oracle = load_direction_oracle(args, args.device)
    reward_alpha = args.reward_alpha if args.reward_alpha is not None else args.alpha

    targets = _load_targets(args, multi_target)
    if world_size > 1:
        # Round-robin split: shard r gets targets r, r + W, r + 2W, ...
        targets = targets[rank::world_size]
        print(f"[shard] rank {rank}/{world_size}: {len(targets)} targets")

    output_dir = args.output_dir
    if not output_dir:
        output_dir = os.path.join(os.path.dirname(__file__), "outputs")
    os.makedirs(output_dir, exist_ok=True)

    from utils.app import PeptideAnalyzer

    analyzer = PeptideAnalyzer()

    all_rows = []
    batch_rows = []
    metrics_rows = []
    # Multi-target runs draw a single batch per (target, direction) pair.
    num_batches = 1 if multi_target else args.num_batches

    for target_idx, target_info in enumerate(targets):
        target_seq = target_info["target_seq"]
        binder_seq = target_info.get("binder_seq")
        row_index = target_info.get("row_index", target_idx)
        seq_length = _resolve_seq_length(binder_seq, args.seq_length, base_model.tokenizer)
        protein_tokens = direction_oracle.encode_protein(target_seq)

        for direction_name, d_star in (("agonist", 1.0), ("antagonist", -1.0)):
            reward_inputs = RewardInputs(
                protein_tokens=protein_tokens,
                d_star=d_star,
                protein_seq=target_seq,
            )
            reward_fn = RewardWrapper(
                scoring_fn=scoring_fn,
                direction_oracle=direction_oracle,
                base_model=base_model,
                tokenizer=base_model.tokenizer,
                reward_inputs=reward_inputs,
                device=device,
                fast_direction=args.fast_direction,
                reward_alpha=reward_alpha,
            )

            for batch_idx in range(num_batches):
                _synchronize(device)
                start = time.perf_counter()
                result = run_baseline(
                    args.baseline,
                    base_model,
                    reward_fn,
                    batch_size=args.batch_size,
                    seq_length=seq_length,
                    num_steps=args.num_steps,
                    guidance_scale=args.guidance_scale,
                    alpha=args.alpha,
                    guidance_steps=args.guidance_steps,
                    mcts_iterations=args.mcts_iterations,
                    num_children=args.num_children,
                    sample_prob_weight=args.sample_prob_weight,
                    invalid_penalty=args.invalid_penalty,
                    pareto_max_size=args.pareto_max_size,
                )
                _synchronize(device)
                elapsed = time.perf_counter() - start

                tokens = result["tokens"]
                # An all-ones second argument (presumably an attention mask;
                # TODO confirm against RewardWrapper.evaluate_tokens).
                scores = reward_fn.evaluate_tokens(tokens, torch.ones_like(tokens))
                sequences = scores["sequences"]
                affinity = scores["affinity"].detach().cpu().numpy()
                direction = scores["direction"].detach().cpu().numpy()
                gated_reward = scores["gated_reward"].detach().cpu().numpy()
                valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences], dtype=np.float32)

                consistency, direction_correct, success, batch_metrics = _direction_metrics(
                    direction, affinity, gated_reward, valid_mask, d_star
                )

                for i, seq in enumerate(sequences):
                    all_rows.append(
                        {
                            "rank": rank,
                            "sequence": seq,
                            "affinity": float(affinity[i]),
                            "direction": float(direction[i]),
                            "d_star": float(d_star),
                            "direction_name": direction_name,
                            "target_seq": target_seq,
                            "target_index": target_idx,
                            "row_index": row_index,
                            "binder_seq": binder_seq,
                            "seq_length": seq_length,
                            "gated_reward": float(gated_reward[i]),
                            "consistency_reward": float(consistency[i]),
                            "direction_accuracy": float(direction_correct[i]),
                            "valid": float(valid_mask[i]),
                            "success": float(success[i]),
                            "batch_index": batch_idx,
                            "batch_time_sec": elapsed,
                            **batch_metrics,
                        }
                    )
                batch_rows.append(
                    {
                        "rank": rank,
                        "batch_index": batch_idx,
                        "batch_time_sec": elapsed,
                        "target_index": target_idx,
                        "row_index": row_index,
                        "binder_seq": binder_seq,
                        "seq_length": seq_length,
                        "direction_name": direction_name,
                    }
                )
                metrics_rows.append(
                    {
                        "rank": rank,
                        "target_index": target_idx,
                        "target_seq": target_seq,
                        "row_index": row_index,
                        "binder_seq": binder_seq,
                        "seq_length": seq_length,
                        "direction_name": direction_name,
                        "d_star": float(d_star),
                        "batch_index": batch_idx,
                        "num_samples": len(sequences),
                        **batch_metrics,
                    }
                )
                print(
                    f"Target {target_idx} dir {direction_name}: "
                    f"generated {len(sequences)} sequences in {elapsed:.3f}s"
                )

    import pandas as pd

    output_csv, batch_csv, metrics_csv = _output_paths(output_dir, args.baseline, rank, world_size)
    pd.DataFrame(all_rows).to_csv(output_csv, index=False)
    pd.DataFrame(batch_rows).to_csv(batch_csv, index=False)
    pd.DataFrame(metrics_rows).to_csv(metrics_csv, index=False)
    print(f"Saved samples to {output_csv}")


if __name__ == "__main__":
    main()