"""
ESM2 backbone + pluggable aggregation head + classification head.

The ESM2 backbone is always frozen; only the aggregation module and the
classifier head are trained.

ESM2 model variants (all published under the ``facebook/`` namespace):
    esm2_t6_8M_UR50D    -> d=320,  8M params
    esm2_t12_35M_UR50D  -> d=480,  35M params (default)
    esm2_t30_150M_UR50D -> d=640,  150M params
    esm2_t33_650M_UR50D -> d=1280, 650M params
    esm2_t36_3B_UR50D   -> d=2560, 3B params
"""
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from transformers import AutoTokenizer, EsmModel
from .aggregators import (
CLSPooling,
CovariancePooling,
GLOTPooling,
GLOTResidueGraphPooling,
MaxPooling,
MeanPooling,
)
# Lookup table: user-facing aggregation name -> aggregator class.
# ``ProteinSequenceClassifier`` instantiates the selected class with ``d_in``.
AGGREGATOR_REGISTRY = dict(
    mean=MeanPooling,
    max=MaxPooling,
    cls=CLSPooling,
    glot=GLOTPooling,
    glot_residue=GLOTResidueGraphPooling,
    covariance=CovariancePooling,
)
# Per-residue embedding width of each published ESM2 checkpoint, keyed by
# the full HuggingFace model ID ("facebook/esm2_<variant>_UR50D").
ESM2_HIDDEN_DIMS = {
    f"facebook/esm2_{variant}_UR50D": width
    for variant, width in (
        ("t6_8M", 320),
        ("t12_35M", 480),
        ("t30_150M", 640),
        ("t33_650M", 1280),
        ("t36_3B", 2560),
    )
}
class ProteinSequenceClassifier(nn.Module):
    """End-to-end model: frozen ESM2 -> aggregation -> classification.

    Args:
        esm2_model_name: HuggingFace model ID for ESM2.
        aggregation: Name of aggregation method (see AGGREGATOR_REGISTRY).
        num_classes: Number of output classes.
        aggregator_kwargs: Extra keyword arguments passed to the aggregator
            constructor (in addition to ``d_in``).
        classifier_hidden: If >0, adds a hidden layer of this width (followed
            by ReLU and dropout) to the classifier head.
        dropout: Dropout rate applied before each linear layer of the
            classifier head.
        strip_special_tokens: Whether to drop the ``<cls>`` position and mask
            out the ``<eos>`` position of the ESM2 output before aggregation.
            ``None`` (default) means True for every aggregation except
            ``"cls"``, which operates on the raw output.

    Raises:
        ValueError: If ``aggregation`` is not a key of AGGREGATOR_REGISTRY.
    """

    def __init__(
        self,
        esm2_model_name: str = "facebook/esm2_t12_35M_UR50D",
        aggregation: str = "mean",
        num_classes: int = 10,
        aggregator_kwargs: Optional[Dict] = None,
        classifier_hidden: int = 0,
        dropout: float = 0.1,
        strip_special_tokens: Optional[bool] = None,
    ):
        super().__init__()
        # Validate the aggregation name before loading ESM2 so a typo fails
        # immediately instead of after a potentially large model load.
        if aggregation not in AGGREGATOR_REGISTRY:
            raise ValueError(
                f"Unknown aggregation '{aggregation}'. "
                f"Choose from: {list(AGGREGATOR_REGISTRY.keys())}"
            )
        self.esm2_model_name = esm2_model_name
        self.aggregation_name = aggregation

        # ---- ESM2 backbone (frozen) ----
        self.esm2 = EsmModel.from_pretrained(esm2_model_name)
        for param in self.esm2.parameters():
            param.requires_grad = False
        # The backbone stays in eval mode for the model's lifetime; the
        # ``train`` override below re-applies this after every mode switch.
        self.esm2.eval()

        # ---- Determine hidden size ----
        # Known checkpoints come from the table; anything else falls back to
        # the hidden size reported by the loaded config.
        self.d_esm2 = ESM2_HIDDEN_DIMS.get(
            esm2_model_name, self.esm2.config.hidden_size
        )

        # ---- Aggregation head ----
        agg_cls = AGGREGATOR_REGISTRY[aggregation]
        agg_kwargs = aggregator_kwargs or {}
        self.aggregator = agg_cls(d_in=self.d_esm2, **agg_kwargs)

        # Whether to drop <cls> / mask <eos> before aggregation. CLS pooling
        # needs the raw output, so only "cls" defaults to False.
        if strip_special_tokens is None:
            strip_special_tokens = aggregation != "cls"
        self.strip_special = bool(strip_special_tokens)

        # ---- Classification head ----
        # Layer order is kept fixed so state_dict keys (classifier.0, ...)
        # stay compatible with existing checkpoints.
        agg_dim = self.aggregator.out_dim
        if classifier_hidden > 0:
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(agg_dim, classifier_hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(classifier_hidden, num_classes),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(agg_dim, num_classes),
            )

    @property
    def tokenizer(self):
        """Tokenizer for ``esm2_model_name``, loaded on first access."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = AutoTokenizer.from_pretrained(self.esm2_model_name)
        return self._tokenizer

    def train(self, mode: bool = True):
        """Set training mode on this module while keeping ESM2 in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override ``model.train()`` would switch the frozen backbone (and any
        mode-dependent layers inside it) back to training mode.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        backbone = self._modules.get("esm2")
        if backbone is not None:
            backbone.eval()
        return self

    def get_residue_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple:
        """Extract per-residue embeddings from frozen ESM2.

        When ``self.strip_special`` is set, position 0 (``<cls>``) is sliced
        off and the last valid position of each row (``<eos>``) is masked
        out. The ``<eos>`` embedding itself remains in the tensor, so
        aggregators must honour the mask.

        Args:
            input_ids: [B, L] token IDs.
            attention_mask: [B, L] mask, nonzero for real tokens. Assumes
                right padding, i.e. ``[<cls>, AA1..AAN, <eos>, <pad>, ...]``.

        Returns:
            (token_embeddings, mask) of shapes [B, L', d] and [B, L'], with
            L' = L - 1 when stripping and L' = L otherwise. ``attention_mask``
            itself is never modified.
        """
        with torch.no_grad():
            outputs = self.esm2(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
        hidden_states = outputs.last_hidden_state  # [B, L_full, d]

        if not self.strip_special:
            return hidden_states, attention_mask

        token_embeddings = hidden_states[:, 1:, :]  # drop <cls>
        mask = attention_mask[:, 1:]
        # With right padding, <eos> is at index (num_valid - 1) of each
        # trimmed row. Locate it for all rows at once instead of looping in
        # Python over the batch. Rows with no valid tokens get index -1,
        # which matches no position and so leave the row unchanged.
        eos_index = mask.sum(dim=1).long() - 1  # [B]
        positions = torch.arange(mask.size(1), device=mask.device)  # [L-1]
        is_eos = positions.unsqueeze(0) == eos_index.unsqueeze(1)  # [B, L-1]
        # masked_fill returns a new tensor, so the caller's mask is untouched.
        mask = mask.masked_fill(is_eos, 0)
        return token_embeddings, mask

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        pdb_paths: Optional[List[Optional[str]]] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Run ESM2, aggregate to one vector per sequence, then classify.

        Args:
            input_ids: [B, L] tokenized protein sequences.
            attention_mask: [B, L] attention mask.
            labels: [B] class labels (optional, for loss computation).
            pdb_paths: List of PDB file paths (only for glot_residue
                aggregation); forwarded to the aggregator only when given.
            **kwargs: Accepted and ignored.

        Returns:
            Dict with 'logits' [B, num_classes], 'embeddings' [B, agg_dim],
            and 'loss' (cross-entropy) when ``labels`` is given.
        """
        # Extract residue embeddings from frozen ESM2
        token_embeddings, mask = self.get_residue_embeddings(input_ids, attention_mask)

        # Aggregate to sequence-level; only pass pdb_paths when provided so
        # aggregators that don't accept it keep working.
        extra_kwargs = {}
        if pdb_paths is not None:
            extra_kwargs["pdb_paths"] = pdb_paths
        sequence_embedding = self.aggregator(
            token_embeddings, mask, **extra_kwargs
        )  # [B, agg_dim]

        # Classify
        logits = self.classifier(sequence_embedding)  # [B, num_classes]
        result = {"logits": logits, "embeddings": sequence_embedding}
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            result["loss"] = loss_fn(logits, labels)
        return result

    def encode(
        self,
        sequences: Union[str, List[str]],
        pdb_paths: Optional[List[Optional[str]]] = None,
        max_length: int = 1024,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Convenience method: tokenize + forward to get sequence embeddings.

        Runs in eval mode without gradients, then restores whatever
        training mode the model was in before the call.

        Args:
            sequences: Single protein sequence or list of sequences.
            pdb_paths: Optional PDB paths for glot_residue aggregation.
            max_length: Maximum sequence length (ESM2 supports up to 1026).
            device: Device to run on; defaults to the device of the first
                model parameter.

        Returns:
            Sequence-level embeddings [B, agg_dim].
        """
        if isinstance(sequences, str):
            sequences = [sequences]
        if device is None:
            device = next(self.parameters()).device
        inputs = self.tokenizer(
            sequences,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        # Restore the caller's mode afterwards so calling encode() mid-training
        # does not silently leave the heads (dropout) in eval mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    pdb_paths=pdb_paths,
                )
        finally:
            self.train(was_training)
        return outputs["embeddings"]