| """ |
| ESM2 backbone + pluggable aggregation head + classification head. |
| |
| The ESM2 backbone is always frozen. Only the aggregation module and the |
| classifier head are trained. |
| |
| ESM2 model variants (all from facebook): |
| esm2_t6_8M_UR50D -> d=320, 8M params |
| esm2_t12_35M_UR50D -> d=480, 35M params (default) |
| esm2_t30_150M_UR50D -> d=640, 150M params |
| esm2_t33_650M_UR50D -> d=1280, 650M params |
| esm2_t36_3B_UR50D -> d=2560, 3B params |
| """ |
|
|

from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoTokenizer, EsmModel

from .aggregators import (
    CLSPooling,
    CovariancePooling,
    GLOTPooling,
    GLOTResidueGraphPooling,
    MaxPooling,
    MeanPooling,
)


AGGREGATOR_REGISTRY = {
    "mean": MeanPooling,
    "max": MaxPooling,
    "cls": CLSPooling,
    "glot": GLOTPooling,
    "glot_residue": GLOTResidueGraphPooling,
    "covariance": CovariancePooling,
}
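

# Illustrative sketch of the aggregator interface this registry assumes (inferred
# from how ProteinSequenceClassifier builds and calls its aggregator, not from a
# documented contract): the constructor takes `d_in` plus optional kwargs, the
# module exposes `out_dim`, and its forward maps (token_embeddings [B, L, d],
# mask [B, L]) -> [B, out_dim]. The class and registry key below are hypothetical
# examples, not part of the aggregators module.
class _ExampleSumPooling(nn.Module):
    """Masked sum pooling, only meant to illustrate the aggregator interface."""

    def __init__(self, d_in: int, **kwargs):
        super().__init__()
        self.out_dim = d_in

    def forward(
        self, token_embeddings: torch.Tensor, mask: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        # Zero out masked (padding / special-token) positions, then sum over length.
        return (token_embeddings * mask.unsqueeze(-1)).sum(dim=1)


# To make it selectable by name (hypothetical):
# AGGREGATOR_REGISTRY["example_sum"] = _ExampleSumPooling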


ESM2_HIDDEN_DIMS = {
    "facebook/esm2_t6_8M_UR50D": 320,
    "facebook/esm2_t12_35M_UR50D": 480,
    "facebook/esm2_t30_150M_UR50D": 640,
    "facebook/esm2_t33_650M_UR50D": 1280,
    "facebook/esm2_t36_3B_UR50D": 2560,
}


class ProteinSequenceClassifier(nn.Module):
    """End-to-end model: frozen ESM2 -> aggregation -> classification.

    Args:
        esm2_model_name: HuggingFace model ID for ESM2.
        aggregation: Name of the aggregation method (see AGGREGATOR_REGISTRY).
        num_classes: Number of output classes.
        aggregator_kwargs: Extra keyword arguments passed to the aggregator constructor.
        classifier_hidden: If > 0, adds a hidden layer of this size to the classifier head.
        dropout: Dropout rate applied before the classifier layers.

    Note:
        For every aggregation except "cls", the <cls> and <eos> tokens are stripped
        from the ESM2 output before aggregation. CLS pooling operates on the raw
        output, since it needs the <cls> position.
    """

    def __init__(
        self,
        esm2_model_name: str = "facebook/esm2_t12_35M_UR50D",
        aggregation: str = "mean",
        num_classes: int = 10,
        aggregator_kwargs: Optional[Dict] = None,
        classifier_hidden: int = 0,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.esm2_model_name = esm2_model_name
        self.aggregation_name = aggregation

        # Frozen ESM2 backbone: disable gradients and put it in eval mode.
        # (Note: a later .train() call on this wrapper also flips the backbone back
        # to train mode; get_residue_embeddings still runs under no_grad.)
        self.esm2 = EsmModel.from_pretrained(esm2_model_name)
        for param in self.esm2.parameters():
            param.requires_grad = False
        self.esm2.eval()

        # Hidden size of the chosen ESM2 variant (fall back to the config value
        # for model names not listed in ESM2_HIDDEN_DIMS).
        self.d_esm2 = ESM2_HIDDEN_DIMS.get(
            esm2_model_name, self.esm2.config.hidden_size
        )

        # Build the aggregation module from the registry.
        if aggregation not in AGGREGATOR_REGISTRY:
            raise ValueError(
                f"Unknown aggregation '{aggregation}'. "
                f"Choose from: {list(AGGREGATOR_REGISTRY.keys())}"
            )

        agg_cls = AGGREGATOR_REGISTRY[aggregation]
        agg_kwargs = aggregator_kwargs or {}
        self.aggregator = agg_cls(d_in=self.d_esm2, **agg_kwargs)

        # CLS pooling needs the <cls> token; every other aggregation works on
        # residue embeddings with special tokens masked out.
        self.strip_special = aggregation != "cls"

        # Classification head on top of the aggregated embedding.
        agg_dim = self.aggregator.out_dim
        if classifier_hidden > 0:
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(agg_dim, classifier_hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(classifier_hidden, num_classes),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(agg_dim, num_classes),
            )

    @property
    def tokenizer(self):
        """Lazy-load tokenizer."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = AutoTokenizer.from_pretrained(self.esm2_model_name)
        return self._tokenizer

    def get_residue_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract per-residue embeddings from the frozen ESM2 backbone.

        Returns:
            token_embeddings: [B, L, d] hidden states (with <cls> dropped when
                special tokens are stripped).
            mask: [B, L] mask over valid residue positions.
        """
        with torch.no_grad():
            outputs = self.esm2(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

        hidden_states = outputs.last_hidden_state

        if self.strip_special:
            # Drop the <cls> token at position 0.
            token_embeddings = hidden_states[:, 1:, :]
            mask = attention_mask[:, 1:].clone()

            # The <eos> token sits at the last non-padding position of each
            # sequence; zero it out of the mask so aggregators ignore it.
            B, _ = mask.shape
            lengths = mask.sum(dim=1).long()
            for i in range(B):
                if lengths[i] > 0:
                    mask[i, lengths[i] - 1] = 0
        else:
            token_embeddings = hidden_states
            mask = attention_mask

        return token_embeddings, mask
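
    # Worked illustration of the stripping above (added example, not original code):
    # for a 3-residue sequence "MKV" padded to length 7, the ESM2 tokenizer yields
    #   tokens:          <cls>  M  K  V  <eos>  <pad>  <pad>
    #   attention_mask:    1    1  1  1    1      0      0
    # After dropping position 0 and zeroing the <eos> slot, aggregators see
    #   token_embeddings for:   M  K  V  <eos>  <pad>  <pad>
    #   mask:                   1  1  1    0      0      0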

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        pdb_paths: Optional[List[Optional[str]]] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            input_ids: [B, L] tokenized protein sequences.
            attention_mask: [B, L] attention mask.
            labels: [B] class labels (optional, for loss computation).
            pdb_paths: List of PDB file paths (only used by glot_residue aggregation).

        Returns:
            Dict with keys 'logits' and 'embeddings', plus 'loss' when labels are given.
        """
        # Frozen backbone: per-residue embeddings and a residue-level mask.
        token_embeddings, mask = self.get_residue_embeddings(input_ids, attention_mask)

        # Only forward pdb_paths to aggregators that expect them.
        extra_kwargs = {}
        if pdb_paths is not None:
            extra_kwargs["pdb_paths"] = pdb_paths

        sequence_embedding = self.aggregator(
            token_embeddings, mask, **extra_kwargs
        )

        # Trainable classification head.
        logits = self.classifier(sequence_embedding)

        result = {"logits": logits, "embeddings": sequence_embedding}

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            result["loss"] = loss_fn(logits, labels)

        return result

    def encode(
        self,
        sequences: Union[str, List[str]],
        pdb_paths: Optional[List[Optional[str]]] = None,
        max_length: int = 1024,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Convenience method: tokenize + forward to get sequence embeddings.

        Args:
            sequences: Single protein sequence or list of sequences.
            pdb_paths: Optional PDB paths for glot_residue aggregation.
            max_length: Maximum tokenized sequence length (ESM2 supports up to
                1026 tokens, i.e. 1024 residues plus <cls>/<eos>).
            device: Device to run on. Defaults to the model's current device.

        Returns:
            Sequence-level embeddings [B, agg_dim].
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        if device is None:
            device = next(self.parameters()).device

        inputs = self.tokenizer(
            sequences,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pdb_paths=pdb_paths,
            )

        return outputs["embeddings"]
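

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The sequences and the 3-class setup
    # are hypothetical; running this downloads the facebook/esm2_t12_35M_UR50D
    # weights from HuggingFace, and the module must be run with `python -m ...`
    # so the relative .aggregators import resolves.
    model = ProteinSequenceClassifier(
        esm2_model_name="facebook/esm2_t12_35M_UR50D",
        aggregation="mean",
        num_classes=3,
    )

    demo_sequences = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "MGSSHHHHHHSSGLVPRGSH"]

    # Sequence-level embeddings via the convenience path: [2, aggregator.out_dim].
    embeddings = model.encode(demo_sequences)
    print("embeddings:", tuple(embeddings.shape))

    # Training-style forward pass with labels -> logits plus cross-entropy loss.
    batch = model.tokenizer(demo_sequences, padding=True, return_tensors="pt")
    labels = torch.tensor([0, 2])
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=labels,
    )
    print("logits:", tuple(outputs["logits"].shape), "loss:", float(outputs["loss"]))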