"""
ESM2 backbone + pluggable aggregation head + classification head.
The ESM2 backbone is always frozen. Only the aggregation module and the
classifier head are trained.
ESM2 model variants (all from facebook):
esm2_t6_8M_UR50D -> d=320, 8M params
esm2_t12_35M_UR50D -> d=480, 35M params (default)
esm2_t30_150M_UR50D -> d=640, 150M params
esm2_t33_650M_UR50D -> d=1280, 650M params
esm2_t36_3B_UR50D -> d=2560, 3B params
"""
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoTokenizer, EsmModel
from .aggregators import (
CLSPooling,
CovariancePooling,
GLOTPooling,
GLOTResidueGraphPooling,
MaxPooling,
MeanPooling,
)
# Map of aggregation method names to classes
AGGREGATOR_REGISTRY = {
"mean": MeanPooling,
"max": MaxPooling,
"cls": CLSPooling,
"glot": GLOTPooling,
"glot_residue": GLOTResidueGraphPooling,
"covariance": CovariancePooling,
}
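# A custom aggregator can be registered the same way (hypothetical example):
#   AGGREGATOR_REGISTRY["my_pool"] = MyPooling
# The class must accept d_in= in its constructor and expose an out_dim
# attribute, since both are used when the classifier head is built below.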
# ESM2 hidden dimensions by model name
ESM2_HIDDEN_DIMS = {
"facebook/esm2_t6_8M_UR50D": 320,
"facebook/esm2_t12_35M_UR50D": 480,
"facebook/esm2_t30_150M_UR50D": 640,
"facebook/esm2_t33_650M_UR50D": 1280,
"facebook/esm2_t36_3B_UR50D": 2560,
}
class ProteinSequenceClassifier(nn.Module):
"""End-to-end model: frozen ESM2 -> aggregation -> classification.
Args:
esm2_model_name: HuggingFace model ID for ESM2.
aggregation: Name of aggregation method (see AGGREGATOR_REGISTRY).
num_classes: Number of output classes.
aggregator_kwargs: Extra arguments passed to the aggregator constructor.
classifier_hidden: If >0, adds a hidden layer in the classifier head.
dropout: Dropout rate before the classifier.
    Note:
        Special tokens are handled automatically: for every aggregation except
        "cls", the <cls> and <eos> tokens are stripped from the ESM2 output
        before aggregation. CLS pooling operates on the raw output.
    """
def __init__(
self,
esm2_model_name: str = "facebook/esm2_t12_35M_UR50D",
aggregation: str = "mean",
num_classes: int = 10,
aggregator_kwargs: Optional[Dict] = None,
classifier_hidden: int = 0,
dropout: float = 0.1,
):
super().__init__()
self.esm2_model_name = esm2_model_name
self.aggregation_name = aggregation
# ---- ESM2 backbone (frozen) ----
self.esm2 = EsmModel.from_pretrained(esm2_model_name)
for param in self.esm2.parameters():
param.requires_grad = False
self.esm2.eval()
# ---- Determine hidden size ----
self.d_esm2 = ESM2_HIDDEN_DIMS.get(
esm2_model_name, self.esm2.config.hidden_size
)
# ---- Aggregation head ----
if aggregation not in AGGREGATOR_REGISTRY:
raise ValueError(
f"Unknown aggregation '{aggregation}'. "
f"Choose from: {list(AGGREGATOR_REGISTRY.keys())}"
)
agg_cls = AGGREGATOR_REGISTRY[aggregation]
agg_kwargs = aggregator_kwargs or {}
self.aggregator = agg_cls(d_in=self.d_esm2, **agg_kwargs)
# Whether to strip <cls>/<eos> before aggregation
self.strip_special = aggregation != "cls"
# ---- Classification head ----
agg_dim = self.aggregator.out_dim
if classifier_hidden > 0:
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(agg_dim, classifier_hidden),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(classifier_hidden, num_classes),
)
else:
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(agg_dim, num_classes),
)
@property
def tokenizer(self):
"""Lazy-load tokenizer."""
if not hasattr(self, "_tokenizer"):
self._tokenizer = AutoTokenizer.from_pretrained(self.esm2_model_name)
return self._tokenizer
def get_residue_embeddings(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract per-residue embeddings from the frozen ESM2 backbone.
        Returns:
            token_embeddings: [B, L, d]; when special tokens are stripped,
                the <cls> column is dropped (L shrinks by one).
            mask: [B, L]; the <eos> position is zeroed as well when special
                tokens are stripped.
        """
with torch.no_grad():
outputs = self.esm2(
input_ids=input_ids,
attention_mask=attention_mask,
)
hidden_states = outputs.last_hidden_state # [B, L_full, d]
if self.strip_special:
# Strip <cls> (pos 0) and <eos> (last valid position)
# For ESM2: input is [<cls>, AA1, AA2, ..., AAN, <eos>, <pad>, ...]
token_embeddings = hidden_states[:, 1:, :] # remove <cls>
mask = attention_mask[:, 1:].clone() # adjust mask
            # Now mask out the <eos> token for each sequence: <eos> sits at
            # the last 1 in each row of the mask (just before padding)
            lengths = mask.sum(dim=1).long()  # valid tokens after removing <cls>
            rows = torch.arange(mask.size(0), device=mask.device)
            valid = lengths > 0  # guard against fully-padded rows
            mask[rows[valid], lengths[valid] - 1] = 0  # zero out <eos> position
else:
token_embeddings = hidden_states
mask = attention_mask
return token_embeddings, mask
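    # Worked example of the stripping in get_residue_embeddings, assuming a
    # hypothetical 4-residue sequence padded to 8 tokens:
    #   tokens:                [<cls>, M, K, T, A, <eos>, <pad>, <pad>]
    #   attention_mask:        [1, 1, 1, 1, 1, 1, 0, 0]
    #   drop <cls>  -> mask =  [1, 1, 1, 1, 1, 0, 0]   (sum = 5)
    #   zero <eos>  -> mask =  [1, 1, 1, 1, 0, 0, 0]   (4 residues remain)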
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
pdb_paths: Optional[List[Optional[str]]] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Args:
input_ids: [B, L] tokenized protein sequences.
attention_mask: [B, L] attention mask.
labels: [B] class labels (optional, for loss computation).
pdb_paths: List of PDB file paths (only for glot_residue aggregation).
Returns:
Dict with keys: 'logits', optionally 'loss', 'embeddings'.
"""
# Extract residue embeddings from frozen ESM2
token_embeddings, mask = self.get_residue_embeddings(input_ids, attention_mask)
# Aggregate to sequence-level
extra_kwargs = {}
if pdb_paths is not None:
extra_kwargs["pdb_paths"] = pdb_paths
sequence_embedding = self.aggregator(
token_embeddings, mask, **extra_kwargs
) # [B, agg_dim]
# Classify
logits = self.classifier(sequence_embedding) # [B, num_classes]
result = {"logits": logits, "embeddings": sequence_embedding}
if labels is not None:
loss_fn = nn.CrossEntropyLoss()
result["loss"] = loss_fn(logits, labels)
return result
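    # Typical training-step usage (a sketch; `batch` is assumed to come from a
    # DataLoader of tokenized sequences with integer class labels):
    #   out = model(input_ids=batch["input_ids"],
    #               attention_mask=batch["attention_mask"],
    #               labels=batch["labels"])
    #   out["loss"].backward()  # gradients reach only aggregator + classifier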
def encode(
self,
sequences: Union[str, List[str]],
pdb_paths: Optional[List[Optional[str]]] = None,
max_length: int = 1024,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Convenience method: tokenize + forward to get sequence embeddings.
Args:
sequences: Single protein sequence or list of sequences.
pdb_paths: Optional PDB paths for glot_residue aggregation.
            max_length: Maximum length in tokens. ESM2 supports up to 1026
                tokens (1024 residues plus <cls> and <eos>).
device: Device to run on.
Returns:
Sequence-level embeddings [B, agg_dim].
"""
if isinstance(sequences, str):
sequences = [sequences]
if device is None:
device = next(self.parameters()).device
inputs = self.tokenizer(
sequences,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
).to(device)
self.eval()
with torch.no_grad():
outputs = self.forward(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
pdb_paths=pdb_paths,
)
return outputs["embeddings"]
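
if __name__ == "__main__":
    # Minimal smoke test (a sketch: downloads the smallest 8M-parameter
    # checkpoint; the sequence below is arbitrary, not real data).
    model = ProteinSequenceClassifier(
        esm2_model_name="facebook/esm2_t6_8M_UR50D",
        aggregation="mean",
        num_classes=5,
    )
    emb = model.encode("MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT")
    # expected: (1, agg_dim); presumably (1, 320) if mean pooling keeps d_in
    print("embedding shape:", tuple(emb.shape))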