""" ESM2 backbone + pluggable aggregation head + classification head. The ESM2 backbone is always frozen. Only the aggregation module and the classifier head are trained. ESM2 model variants (all from facebook): esm2_t6_8M_UR50D -> d=320, 8M params esm2_t12_35M_UR50D -> d=480, 35M params (default) esm2_t30_150M_UR50D -> d=640, 150M params esm2_t33_650M_UR50D -> d=1280, 650M params esm2_t36_3B_UR50D -> d=2560, 3B params """ from typing import Dict, List, Optional, Union import torch import torch.nn as nn from transformers import AutoTokenizer, EsmModel from .aggregators import ( CLSPooling, CovariancePooling, GLOTPooling, GLOTResidueGraphPooling, MaxPooling, MeanPooling, ) # Map of aggregation method names to classes AGGREGATOR_REGISTRY = { "mean": MeanPooling, "max": MaxPooling, "cls": CLSPooling, "glot": GLOTPooling, "glot_residue": GLOTResidueGraphPooling, "covariance": CovariancePooling, } # ESM2 hidden dimensions by model name ESM2_HIDDEN_DIMS = { "facebook/esm2_t6_8M_UR50D": 320, "facebook/esm2_t12_35M_UR50D": 480, "facebook/esm2_t30_150M_UR50D": 640, "facebook/esm2_t33_650M_UR50D": 1280, "facebook/esm2_t36_3B_UR50D": 2560, } class ProteinSequenceClassifier(nn.Module): """End-to-end model: frozen ESM2 -> aggregation -> classification. Args: esm2_model_name: HuggingFace model ID for ESM2. aggregation: Name of aggregation method (see AGGREGATOR_REGISTRY). num_classes: Number of output classes. aggregator_kwargs: Extra arguments passed to the aggregator constructor. classifier_hidden: If >0, adds a hidden layer in the classifier head. dropout: Dropout rate before the classifier. strip_special_tokens: If True (default for mean/max/glot/glot_residue/covariance), strips the and tokens from the ESM2 output before aggregation. CLS pooling operates on the raw output. """ def __init__( self, esm2_model_name: str = "facebook/esm2_t12_35M_UR50D", aggregation: str = "mean", num_classes: int = 10, aggregator_kwargs: Optional[Dict] = None, classifier_hidden: int = 0, dropout: float = 0.1, ): super().__init__() self.esm2_model_name = esm2_model_name self.aggregation_name = aggregation # ---- ESM2 backbone (frozen) ---- self.esm2 = EsmModel.from_pretrained(esm2_model_name) for param in self.esm2.parameters(): param.requires_grad = False self.esm2.eval() # ---- Determine hidden size ---- self.d_esm2 = ESM2_HIDDEN_DIMS.get( esm2_model_name, self.esm2.config.hidden_size ) # ---- Aggregation head ---- if aggregation not in AGGREGATOR_REGISTRY: raise ValueError( f"Unknown aggregation '{aggregation}'. " f"Choose from: {list(AGGREGATOR_REGISTRY.keys())}" ) agg_cls = AGGREGATOR_REGISTRY[aggregation] agg_kwargs = aggregator_kwargs or {} self.aggregator = agg_cls(d_in=self.d_esm2, **agg_kwargs) # Whether to strip / before aggregation self.strip_special = aggregation != "cls" # ---- Classification head ---- agg_dim = self.aggregator.out_dim if classifier_hidden > 0: self.classifier = nn.Sequential( nn.Dropout(dropout), nn.Linear(agg_dim, classifier_hidden), nn.ReLU(), nn.Dropout(dropout), nn.Linear(classifier_hidden, num_classes), ) else: self.classifier = nn.Sequential( nn.Dropout(dropout), nn.Linear(agg_dim, num_classes), ) @property def tokenizer(self): """Lazy-load tokenizer.""" if not hasattr(self, "_tokenizer"): self._tokenizer = AutoTokenizer.from_pretrained(self.esm2_model_name) return self._tokenizer def get_residue_embeddings( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> tuple: """Extract per-residue embeddings from frozen ESM2. 

    def get_residue_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract per-residue embeddings from frozen ESM2.

        Returns:
            token_embeddings: [B, L, d]. If special tokens are stripped,
                the <cls> row is removed and the <eos> row is masked out.
            mask: [B, L]
        """
        with torch.no_grad():
            outputs = self.esm2(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
        hidden_states = outputs.last_hidden_state  # [B, L_full, d]

        if self.strip_special:
            # Strip <cls> (pos 0) and <eos> (last valid position).
            # For ESM2: input is [<cls>, AA1, AA2, ..., AAN, <eos>, <pad>, ...]
            token_embeddings = hidden_states[:, 1:, :]  # remove <cls>
            mask = attention_mask[:, 1:].clone()  # adjust mask

            # Now remove the <eos> token for each sequence.
            # The <eos> is the last 1 in the mask (before padding).
            B, L = mask.shape
            # Find the position of the last 1 in each row
            lengths = mask.sum(dim=1).long()  # valid tokens after removing <cls>
            for i in range(B):
                if lengths[i] > 0:
                    mask[i, lengths[i] - 1] = 0  # zero out the <eos> position
        else:
            token_embeddings = hidden_states
            mask = attention_mask

        return token_embeddings, mask

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        pdb_paths: Optional[List[Optional[str]]] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            input_ids: [B, L] tokenized protein sequences.
            attention_mask: [B, L] attention mask.
            labels: [B] class labels (optional, for loss computation).
            pdb_paths: List of PDB file paths (only for glot_residue
                aggregation).

        Returns:
            Dict with keys 'logits' and 'embeddings', plus 'loss' when
            labels are given.
        """
        # Extract residue embeddings from frozen ESM2
        token_embeddings, mask = self.get_residue_embeddings(
            input_ids, attention_mask
        )

        # Aggregate to sequence-level
        extra_kwargs = {}
        if pdb_paths is not None:
            extra_kwargs["pdb_paths"] = pdb_paths
        sequence_embedding = self.aggregator(
            token_embeddings, mask, **extra_kwargs
        )  # [B, agg_dim]

        # Classify
        logits = self.classifier(sequence_embedding)  # [B, num_classes]

        result = {"logits": logits, "embeddings": sequence_embedding}
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            result["loss"] = loss_fn(logits, labels)
        return result

    def encode(
        self,
        sequences: Union[str, List[str]],
        pdb_paths: Optional[List[Optional[str]]] = None,
        max_length: int = 1024,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Convenience method: tokenize + forward to get sequence embeddings.

        Args:
            sequences: Single protein sequence or list of sequences.
            pdb_paths: Optional PDB paths for glot_residue aggregation.
            max_length: Maximum sequence length (ESM2 supports up to 1026
                positions, including the <cls> and <eos> tokens).
            device: Device to run on.

        Returns:
            Sequence-level embeddings [B, agg_dim].
        """
        if isinstance(sequences, str):
            sequences = [sequences]
        if device is None:
            device = next(self.parameters()).device

        inputs = self.tokenizer(
            sequences,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pdb_paths=pdb_paths,
            )
        return outputs["embeddings"]
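

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, guarded so importing this module stays
# side-effect free). The checkpoint name comes from the variant table above;
# the input sequences are arbitrary toy strings, not real proteins, and the
# label value is made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = ProteinSequenceClassifier(
        esm2_model_name="facebook/esm2_t6_8M_UR50D",  # smallest variant, d=320
        aggregation="mean",
        num_classes=5,
    )

    # Embedding extraction via the convenience method.
    embeddings = model.encode(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "MKV"])
    print(embeddings.shape)  # [2, agg_dim] for mean pooling

    # Full forward pass with labels to obtain a classification loss.
    inputs = model.tokenizer(
        ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"], return_tensors="pt"
    )
    out = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        labels=torch.tensor([2]),
    )
    print(out["logits"].shape, out["loss"].item())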