"""
GPCR Agonist Classifier - TR2-D2 Inference Script
"""

import argparse
import logging
import os
import sys
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import EsmModel, EsmTokenizer

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from roformer import Roformer

logger = logging.getLogger(__name__)


def resolve_device(requested: Optional[str]) -> torch.device:
    """Resolve a requested device string ('auto', 'cpu', 'cuda', 'cuda:N', ...) to a usable torch.device."""
    if requested is None or str(requested).lower() == "auto":
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return torch.device("cuda:0")
        return torch.device("cpu")

    try:
        device = torch.device(requested)
    except Exception as exc:
        logger.warning("Invalid device '%s': %s. Falling back to CPU.", requested, exc)
        return torch.device("cpu")

    if device.type != "cuda":
        return device

    if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
        logger.warning("CUDA requested but not available; falling back to CPU")
        return torch.device("cpu")

    index = device.index if device.index is not None else 0
    count = torch.cuda.device_count()
    if index < 0 or index >= count:
        logger.warning(
            "CUDA device %s requested but only %d visible; using cuda:0",
            index,
            count
        )
        return torch.device("cuda:0")

    return torch.device(f"cuda:{index}")


def peptide_to_smiles(seq: str) -> str:
    """Convert a peptide (one-letter amino-acid codes) to a canonical SMILES string via RDKit."""
    from rdkit import Chem

    seq = seq.strip().upper()
    mol = Chem.MolFromSequence(seq)
    if mol is None:
        raise ValueError(f"RDKit failed to convert peptide '{seq}' to SMILES")
    return Chem.MolToSmiles(mol)
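# Example (illustrative): RDKit builds the peptide from one-letter codes, e.g.
#   peptide_to_smiles("AG")   # -> canonical SMILES for Ala-Gly, e.g. "C[C@H](N)C(=O)NCC(=O)O"
# (the exact string depends on the installed RDKit version's canonicalization).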


class SelfAttentionBlock(nn.Module):
    """Residual self-attention block with post-LayerNorm."""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        attn_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm(x + self.dropout(attn_out))
        return x


class BiMultiHeadCrossAttention(nn.Module):
    """Bidirectional cross-attention: protein attends to ligand and ligand attends to protein."""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.prot_to_lig = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.lig_to_prot = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.prot_ln = nn.LayerNorm(d_model)
        self.lig_ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, prot_h, lig_h, prot_kpm=None, lig_kpm=None):
        # Protein queries attend over ligand keys/values (and vice versa), each with residual + LayerNorm.
        prot_ctx, _ = self.prot_to_lig(prot_h, lig_h, lig_h, key_padding_mask=lig_kpm)
        prot_h_out = self.prot_ln(prot_h + self.dropout(prot_ctx))

        lig_ctx, _ = self.lig_to_prot(lig_h, prot_h, prot_h, key_padding_mask=prot_kpm)
        lig_h_out = self.lig_ln(lig_h + self.dropout(lig_ctx))

        return prot_h_out, lig_h_out
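# Shape sketch (illustrative): with batch B, protein length Lp, ligand length Ll and width
# d_model, prot_h is (B, Lp, d_model) and lig_h is (B, Ll, d_model). prot_to_lig attends
# protein queries over ligand keys/values and lig_to_prot does the reverse; both outputs keep
# their input shapes, which is what allows these blocks to be stacked.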


class TR2D2RoFormerEncoder(nn.Module):
    """Frozen TR2-D2 RoFormer ligand encoder that returns per-token hidden states."""

    def __init__(self, config, tokenizer, checkpoint_path=None, device="cpu"):
        super().__init__()
        self.device = device
        self.encoder = Roformer(config, tokenizer, device=device)

        if checkpoint_path:
            print(" Loading TR2-D2 checkpoint...")
            ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
            state_dict = ckpt.get("state_dict", ckpt)
            roformer_state = {
                k.replace("model.", "").replace("backbone.", ""): v
                for k, v in state_dict.items()
                if "roformer" in k or "encoder" in k or "backbone" in k
            }
            self.encoder.model.load_state_dict(roformer_state, strict=False)
            print(" TR2-D2 checkpoint loaded")

        # The ligand encoder stays frozen; only the downstream heads are trained.
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()

    def forward(self, input_ids, attention_mask, inputs_embeds=None):
        if attention_mask is None:
            raise ValueError("attention_mask is required for ligand encoding.")
        attention_mask = attention_mask.to(self.device)
        if inputs_embeds is not None:
            # Soft (relaxed) inputs: keep the graph so gradients can flow back to inputs_embeds.
            inputs_embeds = inputs_embeds.to(self.device)
            out = self.encoder.model.roformer(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask
            )
        else:
            input_ids = input_ids.to(self.device)
            with torch.no_grad():
                out = self.encoder.model.roformer(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
        return out.last_hidden_state


class ESM_TR2D2_GPCRClassifier(nn.Module):
    """
    GPCR Agonist Classifier with TR2-D2

    Architecture:
        1. ESM2 (protein) + TR2-D2 RoFormer (ligand) encoders
        2. Projections to a common dimension
        3. Self-attention (1 layer per modality)
        4. Bidirectional cross-attention (2 stacked layers)
        5. Masked average pooling
        6. MLP classifier
    """

    def __init__(
        self,
        esm_name,
        tr2d2_config,
        lig_tokenizer,
        tr2d2_checkpoint=None,
        d_model=256,
        n_heads=4,
        n_self_attn_layers=1,
        n_bmca_layers=2,
        dropout=0.3,
        device="cuda",
        esm_cache_dir=None,
        esm_local_files_only=False
    ):
        super().__init__()
        self.device = device

        # Frozen ESM2 protein encoder.
        print("Loading ESM2 protein encoder...")
        self.esm = EsmModel.from_pretrained(
            esm_name,
            cache_dir=esm_cache_dir,
            local_files_only=esm_local_files_only
        )
        for p in self.esm.parameters():
            p.requires_grad = False
        self.esm.eval()

        # Frozen TR2-D2 RoFormer ligand encoder.
        print("Loading TR2-D2 ligand encoder...")
        self.ligand_encoder = TR2D2RoFormerEncoder(
            tr2d2_config, lig_tokenizer, tr2d2_checkpoint, device
        )

        esm_dim = self.esm.config.hidden_size
        lig_dim = tr2d2_config.roformer.hidden_size

        # Project both encoders into a shared d_model space.
        self.prot_proj = nn.Linear(esm_dim, d_model)
        self.lig_proj = nn.Linear(lig_dim, d_model)

        # Per-modality self-attention.
        self.prot_self_attn_layers = nn.ModuleList([
            SelfAttentionBlock(d_model, n_heads, dropout)
            for _ in range(n_self_attn_layers)
        ])
        self.lig_self_attn_layers = nn.ModuleList([
            SelfAttentionBlock(d_model, n_heads, dropout)
            for _ in range(n_self_attn_layers)
        ])

        # Stacked bidirectional cross-attention.
        self.bmca_layers = nn.ModuleList([
            BiMultiHeadCrossAttention(d_model, n_heads, dropout)
            for _ in range(n_bmca_layers)
        ])

        # MLP head over the concatenated pooled representations.
        self.classifier = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2)
        )

    def forward(self, prot_tokens, lig_tokens, lig_inputs_embeds=None):
        prot_kpm = prot_tokens["attention_mask"].eq(0)
        lig_kpm = lig_tokens["attention_mask"].eq(0)

        with torch.no_grad():
            prot_out = self.esm(**prot_tokens).last_hidden_state

        lig_out = self.ligand_encoder(
            lig_tokens["input_ids"],
            lig_tokens["attention_mask"],
            inputs_embeds=lig_inputs_embeds
        )

        prot_h = self.prot_proj(prot_out)
        lig_h = self.lig_proj(lig_out)

        # Self-attention within each modality.
        for self_attn in self.prot_self_attn_layers:
            prot_h = self_attn(prot_h, key_padding_mask=prot_kpm)
        for self_attn in self.lig_self_attn_layers:
            lig_h = self_attn(lig_h, key_padding_mask=lig_kpm)

        # Bidirectional cross-attention between protein and ligand tokens.
        for bmca in self.bmca_layers:
            prot_h, lig_h = bmca(prot_h, lig_h, prot_kpm=prot_kpm, lig_kpm=lig_kpm)

        # Masked average pooling over non-padding positions.
        prot_mask = prot_tokens["attention_mask"].unsqueeze(-1)
        lig_mask = lig_tokens["attention_mask"].unsqueeze(-1)

        prot_repr = (prot_h * prot_mask).sum(dim=1) / prot_mask.sum(dim=1).clamp(min=1)
        lig_repr = (lig_h * lig_mask).sum(dim=1) / lig_mask.sum(dim=1).clamp(min=1)

        return self.classifier(torch.cat([prot_repr, lig_repr], dim=-1))
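# Minimal usage sketch (illustrative; `cfg`, `prot_tok`, `lig_tok` and the sequences are
# assumed to come from create_tr2d2_config and the tokenizer helpers defined below):
#   model = ESM_TR2D2_GPCRClassifier("facebook/esm2_t33_650M_UR50D", cfg, lig_tok, device="cpu")
#   prot_tokens = tokenize_protein(protein_seq, prot_tok, "cpu")   # dict of (1, Lp) tensors
#   lig_tokens = tokenize_ligand(smiles, lig_tok, 768, "cpu")      # dict of (1, 768) tensors
#   logits = model(prot_tokens, lig_tokens)                        # (1, 2): [non-agonist, agonist]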


def create_tr2d2_config(vocab_size):
    """Minimal TR2-D2 RoFormer config; hidden sizes must match the pretrained checkpoint."""
    return SimpleNamespace(
        roformer=SimpleNamespace(
            vocab_size=vocab_size,
            hidden_size=768,
            n_layers=8,
            n_heads=8,
            max_position_embeddings=1035
        )
    )


def _load_state_dict_flexible(model: nn.Module, state_dict: Dict, strict: bool = True) -> None:
    """Load a state dict; on a strict-load failure, retry non-strictly with only the matching keys."""
    try:
        model.load_state_dict(state_dict, strict=strict)
        return
    except RuntimeError as exc:
        model_keys = set(model.state_dict().keys())
        filtered = {k: v for k, v in state_dict.items() if k in model_keys}
        logger.warning("Strict load failed: %s", exc)
        logger.warning(
            "Retrying with filtered keys (%d/%d) and strict=False",
            len(filtered),
            len(state_dict)
        )
        incompatible = model.load_state_dict(filtered, strict=False)
        if incompatible.missing_keys:
            logger.warning("Missing keys (first 10): %s", incompatible.missing_keys[:10])
        if incompatible.unexpected_keys:
            logger.warning("Unexpected keys (first 10): %s", incompatible.unexpected_keys[:10])


def tokenize_protein(seq, tokenizer, device):
    """Tokenize a protein sequence with the ESM tokenizer and move all tensors to the target device."""
    out = tokenizer(
        seq,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
        add_special_tokens=True
    )
    return {k: v.to(device) for k, v in out.items()}


def tokenize_ligand(smiles, tokenizer, max_len, device):
    """Tokenize a SMILES string and right-pad input_ids/attention_mask to a fixed max_len."""
    enc = tokenizer(
        smiles,
        return_tensors="pt",
        truncation=True,
        max_length=max_len,
        add_special_tokens=True
    )
    ids = enc["input_ids"].squeeze(0)
    att = enc["attention_mask"].squeeze(0)

    pad = max_len - ids.numel()
    if pad > 0:
        # Pad with the tokenizer's pad id and keep the original integer dtypes.
        ids = torch.cat([ids, torch.full((pad,), tokenizer.pad_token_id, dtype=ids.dtype)])
        att = torch.cat([att, torch.zeros(pad, dtype=att.dtype)])

    return {
        "input_ids": ids.unsqueeze(0).to(device),
        "attention_mask": att.unsqueeze(0).to(device)
    }
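# Example (illustrative): short SMILES are right-padded to the fixed ligand length, so
#   lig = tokenize_ligand("CCO", lig_tok, 768, "cpu")
# gives lig["input_ids"].shape == (1, 768) with attention_mask 1 over real tokens and 0 over padding.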


class DirectionalOracle(nn.Module):
    """
    Batch-capable oracle wrapper with a TD3B-compatible predict_with_confidence().

    This class is intended for training integration where peptide/protein tokens
    are provided directly (batched) and the oracle runs in inference-only mode.
    """

    def __init__(
        self,
        model_ckpt: str,
        tr2d2_checkpoint: str,
        tokenizer_vocab: str,
        tokenizer_splits: str,
        esm_name: str = "facebook/esm2_t33_650M_UR50D",
        d_model: int = 256,
        n_heads: int = 4,
        n_self_attn_layers: int = 1,
        n_bmca_layers: int = 2,
        dropout: float = 0.3,
        max_ligand_length: int = 768,
        max_protein_length: int = 1024,
        device: Optional[str] = None,
        esm_cache_dir: Optional[str] = None,
        esm_local_files_only: bool = False
    ):
        super().__init__()

        if isinstance(device, torch.device):
            device = str(device)
        self.device = resolve_device(device)

        self.max_ligand_length = max_ligand_length
        self.max_protein_length = max_protein_length
        self._warned_ligand_truncation = False
        self._warned_protein_truncation = False

        # Ligand (SMILES) and protein (ESM) tokenizers.
        self.lig_tokenizer = SMILES_SPE_Tokenizer(tokenizer_vocab, tokenizer_splits)
        self.prot_tokenizer = EsmTokenizer.from_pretrained(
            esm_name,
            cache_dir=esm_cache_dir,
            local_files_only=esm_local_files_only
        )

        tr2d2_cfg = create_tr2d2_config(self.lig_tokenizer.vocab_size)
        self.model = ESM_TR2D2_GPCRClassifier(
            esm_name=esm_name,
            tr2d2_config=tr2d2_cfg,
            lig_tokenizer=self.lig_tokenizer,
            tr2d2_checkpoint=tr2d2_checkpoint,
            d_model=d_model,
            n_heads=n_heads,
            n_self_attn_layers=n_self_attn_layers,
            n_bmca_layers=n_bmca_layers,
            dropout=dropout,
            device=self.device,
            esm_cache_dir=esm_cache_dir,
            esm_local_files_only=esm_local_files_only
        )

        # Load the trained classifier weights (the checkpoint may wrap them in "model_state_dict").
        state_dict = torch.load(model_ckpt, map_location=self.device, weights_only=False)
        if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        _load_state_dict_flexible(self.model, state_dict, strict=True)
        self.model.to(self.device).eval()

        # The oracle is inference-only: freeze every parameter.
        for param in self.model.parameters():
            param.requires_grad = False

        self._lig_pad_token_id = self.lig_tokenizer.pad_token_id
        if self._lig_pad_token_id is None:
            self._lig_pad_token_id = 0
        self._prot_pad_token_id = self.prot_tokenizer.pad_token_id
        if self._prot_pad_token_id is None:
            self._prot_pad_token_id = 0

    def encode_protein(self, protein_seq: str) -> torch.Tensor:
        """Tokenize a single protein sequence and return its input_ids on the oracle device."""
        tokens = self.prot_tokenizer(
            protein_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_protein_length,
            add_special_tokens=True
        )
        return tokens["input_ids"].to(self.device)

    def _normalize_token_dict(
        self,
        tokens: torch.Tensor,
        pad_token_id: int,
        max_length: int,
        warned_attr: str
    ) -> Dict[str, torch.Tensor]:
        """Accept a token dict or a raw input_ids tensor; return a batched dict on the oracle device."""
        if isinstance(tokens, dict):
            input_ids = tokens.get("input_ids")
            if input_ids is None:
                raise ValueError("Token dict must include input_ids.")
            attention_mask = tokens.get("attention_mask")
            input_ids = input_ids.to(self.device)
            if attention_mask is None:
                attention_mask = (input_ids != pad_token_id).long()
            else:
                attention_mask = attention_mask.to(self.device)
        else:
            input_ids = tokens
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            input_ids = input_ids.to(self.device)
            attention_mask = (input_ids != pad_token_id).long()

        # Truncate over-long inputs, warning only once per input type.
        if max_length is not None and input_ids.size(1) > max_length:
            if not getattr(self, warned_attr):
                logger.warning(
                    "Truncating input from length %d to max_length=%d",
                    input_ids.size(1),
                    max_length
                )
                setattr(self, warned_attr, True)
            input_ids = input_ids[:, :max_length]
            attention_mask = attention_mask[:, :max_length]

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _normalize_prob_inputs(
        self,
        probs: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        max_length: int,
        warned_attr: str,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Batch, device-place, and truncate a (batch, length, vocab) probability tensor and its mask."""
        if probs.dim() == 2:
            probs = probs.unsqueeze(0)
        probs = probs.to(self.device)
        if attention_mask is None:
            attention_mask = torch.ones(
                probs.size(0), probs.size(1), device=self.device, dtype=torch.long
            )
        else:
            if attention_mask.dim() == 1:
                attention_mask = attention_mask.unsqueeze(0)
            attention_mask = attention_mask.to(self.device).long()

        if max_length is not None and probs.size(1) > max_length:
            if not getattr(self, warned_attr):
                logger.warning(
                    "Truncating input from length %d to max_length=%d",
                    probs.size(1),
                    max_length
                )
                setattr(self, warned_attr, True)
            probs = probs[:, :max_length]
            attention_mask = attention_mask[:, :max_length]

        return probs, attention_mask

    @torch.no_grad()
    def predict_with_confidence(
        self,
        peptide_tokens: torch.Tensor,
        protein_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (agonist probability, confidence) per example; a single protein is broadcast over the peptide batch."""
        lig_tokens = self._normalize_token_dict(
            peptide_tokens,
            self._lig_pad_token_id,
            self.max_ligand_length,
            "_warned_ligand_truncation"
        )
        prot_tokens = self._normalize_token_dict(
            protein_tokens,
            self._prot_pad_token_id,
            self.max_protein_length,
            "_warned_protein_truncation"
        )

        # Broadcast a single protein across the peptide batch if needed.
        lig_batch = lig_tokens["input_ids"].size(0)
        prot_batch = prot_tokens["input_ids"].size(0)
        if prot_batch == 1 and lig_batch > 1:
            prot_tokens = {k: v.expand(lig_batch, -1) for k, v in prot_tokens.items()}
        elif prot_batch != lig_batch:
            raise ValueError(
                f"Batch size mismatch: peptide_tokens={lig_batch}, protein_tokens={prot_batch}"
            )

        logits = self.model(prot_tokens, lig_tokens)
        probs = F.softmax(logits, dim=-1)
        p_agonist = probs[:, 1]
        confidence = torch.max(probs, dim=-1).values
        return p_agonist, confidence
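    # Usage sketch (illustrative; `oracle` is a constructed DirectionalOracle and the token
    # tensors come from the caller, e.g. a training loop):
    #   p_agonist, confidence = oracle.predict_with_confidence(peptide_ids, protein_ids)
    #   # p_agonist: (B,) probability of the agonist class; confidence: (B,) max class probability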

    def predict_from_probs(
        self,
        ligand_probs: torch.Tensor,
        protein_tokens: torch.Tensor,
        ligand_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Differentiable scoring path: ligand token probabilities are mapped to soft embeddings before scoring."""
        lig_probs, lig_attention = self._normalize_prob_inputs(
            ligand_probs,
            ligand_attention_mask,
            self.max_ligand_length,
            "_warned_ligand_truncation",
        )
        prot_tokens = self._normalize_token_dict(
            protein_tokens,
            self._prot_pad_token_id,
            self.max_protein_length,
            "_warned_protein_truncation"
        )

        lig_batch = lig_probs.size(0)
        prot_batch = prot_tokens["input_ids"].size(0)
        if prot_batch == 1 and lig_batch > 1:
            prot_tokens = {k: v.expand(lig_batch, -1) for k, v in prot_tokens.items()}
        elif prot_batch != lig_batch:
            raise ValueError(
                f"Batch size mismatch: ligand_probs={lig_batch}, protein_tokens={prot_batch}"
            )

        # Expected soft embeddings: probabilities times the RoFormer word-embedding matrix.
        emb_weight = self.model.ligand_encoder.encoder.model.roformer.embeddings.word_embeddings.weight
        if lig_probs.size(-1) != emb_weight.size(0):
            raise ValueError(
                f"Ligand vocab mismatch: probs={lig_probs.size(-1)} vs oracle={emb_weight.size(0)}"
            )
        lig_inputs_embeds = lig_probs @ emb_weight
        # Placeholder input_ids; the encoder ignores them when inputs_embeds is provided.
        lig_input_ids = torch.zeros(
            lig_probs.size(0), lig_probs.size(1), device=lig_probs.device, dtype=torch.long
        )
        lig_tokens = {"input_ids": lig_input_ids, "attention_mask": lig_attention}
        logits = self.model(prot_tokens, lig_tokens, lig_inputs_embeds=lig_inputs_embeds)
        probs = F.softmax(logits, dim=-1)
        return probs[:, 1]
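    # Note (illustrative): with one-hot `ligand_probs`, `lig_probs @ emb_weight` reduces to an
    # ordinary embedding lookup, so this path should closely match predict_with_confidence, while
    # soft probabilities keep the score differentiable w.r.t. ligand_probs:
    #   probs_onehot = F.one_hot(peptide_ids, num_classes=oracle.lig_tokenizer.vocab_size).float()
    #   p_agonist = oracle.predict_from_probs(probs_onehot, protein_ids)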


@torch.no_grad()
def predict(model, prot_tok, lig_tok, protein_seq, peptide_seq, device, threshold=0.5):
    """
    Predict agonist activity for a single protein/peptide pair.

    Returns:
        dict with keys: smiles, non_agonist_prob, agonist_prob, prediction, confidence
    """
    # Convert the peptide ligand to SMILES so it can be tokenized with the SMILES tokenizer.
    smiles = peptide_to_smiles(peptide_seq)

    prot_tokens = tokenize_protein(protein_seq, prot_tok, device)
    lig_tokens = tokenize_ligand(smiles, lig_tok, 768, device)

    logits = model(prot_tokens, lig_tokens)
    probs = F.softmax(logits, dim=-1).squeeze(0)

    p_non_agonist = probs[0].item()
    p_agonist = probs[1].item()
    prediction = "agonist" if p_agonist >= threshold else "non-agonist"

    return {
        "smiles": smiles,
        "non_agonist_prob": p_non_agonist,
        "agonist_prob": p_agonist,
        "prediction": prediction,
        "confidence": max(p_non_agonist, p_agonist)
    }


def main():
    parser = argparse.ArgumentParser(
        description="GPCR Agonist Classifier - TR2-D2 Inference"
    )
    parser.add_argument("--model_ckpt", required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--tr2d2_checkpoint", required=True,
                        help="Path to TR2-D2 pretrained checkpoint")
    parser.add_argument("--tokenizer_vocab", required=True,
                        help="Path to tokenizer vocabulary")
    parser.add_argument("--tokenizer_splits", required=True,
                        help="Path to tokenizer splits")
    parser.add_argument("--protein_seq", required=True,
                        help="GPCR protein sequence")
    parser.add_argument("--ligand_peptide", required=True,
                        help="Ligand peptide sequence")
    parser.add_argument("--threshold", type=float, default=0.5,
                        help="Classification threshold (default: 0.5)")
    parser.add_argument("--d_model", type=int, default=256,
                        help="Hidden dimension (must match training)")
    parser.add_argument("--n_heads", type=int, default=4,
                        help="Number of attention heads (must match training)")
    parser.add_argument("--n_self_attn_layers", type=int, default=1,
                        help="Number of self-attention layers (must match training)")
    parser.add_argument("--n_bmca_layers", type=int, default=2,
                        help="Number of cross-attention layers (must match training)")
    parser.add_argument("--dropout", type=float, default=0.3,
                        help="Dropout rate (must match training)")
    parser.add_argument("--device", default=None,
                        help="Device (cuda/cpu, default: auto)")
    parser.add_argument("--esm_name", default="facebook/esm2_t33_650M_UR50D",
                        help="ESM model name or local path")
    parser.add_argument("--esm_cache_dir", default=None,
                        help="Optional cache directory for ESM model")
    parser.add_argument("--esm_local_files_only", action="store_true",
                        help="Load ESM from local cache only (no network)")

    args = parser.parse_args()

    device = resolve_device(args.device)

    print(f"Device: {device}")
    print("")

    print("Loading tokenizers...")
    prot_tok = EsmTokenizer.from_pretrained(
        args.esm_name,
        cache_dir=args.esm_cache_dir,
        local_files_only=args.esm_local_files_only
    )
    lig_tok = SMILES_SPE_Tokenizer(args.tokenizer_vocab, args.tokenizer_splits)
    print(f" Vocab size: {lig_tok.vocab_size}")
    print("")

    tr2d2_cfg = create_tr2d2_config(lig_tok.vocab_size)

    print("Loading model...")
    model = ESM_TR2D2_GPCRClassifier(
        esm_name=args.esm_name,
        tr2d2_config=tr2d2_cfg,
        lig_tokenizer=lig_tok,
        tr2d2_checkpoint=args.tr2d2_checkpoint,
        d_model=args.d_model,
        n_heads=args.n_heads,
        n_self_attn_layers=args.n_self_attn_layers,
        n_bmca_layers=args.n_bmca_layers,
        dropout=args.dropout,
        device=device,
        esm_cache_dir=args.esm_cache_dir,
        esm_local_files_only=args.esm_local_files_only
    )

    print(" Loading trained weights...")
    # Mirror the checkpoint handling used by DirectionalOracle above.
    state_dict = torch.load(args.model_ckpt, map_location=device, weights_only=False)
    if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
        state_dict = state_dict["model_state_dict"]
    _load_state_dict_flexible(model, state_dict, strict=True)
    model.to(device).eval()
    print(" Model ready.")
    print("")

    print("Running inference...")
    result = predict(
        model, prot_tok, lig_tok,
        args.protein_seq, args.ligand_peptide,
        device, args.threshold
    )

    print("")
    print("=" * 70)
    print("RESULTS")
    print("=" * 70)
    print(f"Protein: {args.protein_seq[:50]}{'...' if len(args.protein_seq) > 50 else ''}")
    print(f"Ligand: {args.ligand_peptide}")
    print(f"SMILES: {result['smiles']}")
    print("")
    print(f"Non-agonist probability: {result['non_agonist_prob']:.4f}")
    print(f"Agonist probability: {result['agonist_prob']:.4f}")
    print("")
    print(f"Prediction (threshold={args.threshold}): {result['prediction'].upper()}")
    print(f"Confidence: {result['confidence']:.4f}")
    print("=" * 70)


if __name__ == "__main__":
    main()