AliSaadatV committed (verified)
Commit 411e478 · Parent(s): 0eb73db

Add model module and example script

Files changed (1):
  protein_aggregator/model.py  +246  -0
protein_aggregator/model.py ADDED
@@ -0,0 +1,246 @@
"""
ESM2 backbone + pluggable aggregation head + classification head.

The ESM2 backbone is always frozen. Only the aggregation module and the
classifier head are trained.

ESM2 model variants (all from facebook):
    esm2_t6_8M_UR50D    -> d=320,  8M params
    esm2_t12_35M_UR50D  -> d=480,  35M params (default)
    esm2_t30_150M_UR50D -> d=640,  150M params
    esm2_t33_650M_UR50D -> d=1280, 650M params
    esm2_t36_3B_UR50D   -> d=2560, 3B params
"""

from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from transformers import AutoTokenizer, EsmModel

from .aggregators import (
    CLSPooling,
    CovariancePooling,
    GLOTPooling,
    GLOTResidueGraphPooling,
    MaxPooling,
    MeanPooling,
)

# Map of aggregation method names to classes
AGGREGATOR_REGISTRY = {
    "mean": MeanPooling,
    "max": MaxPooling,
    "cls": CLSPooling,
    "glot": GLOTPooling,
    "glot_residue": GLOTResidueGraphPooling,
    "covariance": CovariancePooling,
}

# ESM2 hidden dimensions by model name
ESM2_HIDDEN_DIMS = {
    "facebook/esm2_t6_8M_UR50D": 320,
    "facebook/esm2_t12_35M_UR50D": 480,
    "facebook/esm2_t30_150M_UR50D": 640,
    "facebook/esm2_t33_650M_UR50D": 1280,
    "facebook/esm2_t36_3B_UR50D": 2560,
}


class ProteinSequenceClassifier(nn.Module):
    """End-to-end model: frozen ESM2 -> aggregation -> classification.

    Args:
        esm2_model_name: HuggingFace model ID for ESM2.
        aggregation: Name of aggregation method (see AGGREGATOR_REGISTRY).
        num_classes: Number of output classes.
        aggregator_kwargs: Extra arguments passed to the aggregator constructor.
        classifier_hidden: If > 0, adds a hidden layer of this size in the classifier head.
        dropout: Dropout rate before the classifier.

    Note:
        For every aggregation except "cls" (mean/max/glot/glot_residue/covariance),
        the <cls> and <eos> tokens are stripped from the ESM2 output before
        aggregation. CLS pooling operates on the raw output.
    """

    def __init__(
        self,
        esm2_model_name: str = "facebook/esm2_t12_35M_UR50D",
        aggregation: str = "mean",
        num_classes: int = 10,
        aggregator_kwargs: Optional[Dict] = None,
        classifier_hidden: int = 0,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.esm2_model_name = esm2_model_name
        self.aggregation_name = aggregation

        # ---- ESM2 backbone (frozen) ----
        self.esm2 = EsmModel.from_pretrained(esm2_model_name)
        for param in self.esm2.parameters():
            param.requires_grad = False
        self.esm2.eval()

        # ---- Determine hidden size ----
        self.d_esm2 = ESM2_HIDDEN_DIMS.get(
            esm2_model_name, self.esm2.config.hidden_size
        )

        # ---- Aggregation head ----
        if aggregation not in AGGREGATOR_REGISTRY:
            raise ValueError(
                f"Unknown aggregation '{aggregation}'. "
                f"Choose from: {list(AGGREGATOR_REGISTRY.keys())}"
            )

        agg_cls = AGGREGATOR_REGISTRY[aggregation]
        agg_kwargs = aggregator_kwargs or {}
        self.aggregator = agg_cls(d_in=self.d_esm2, **agg_kwargs)

        # Whether to strip <cls>/<eos> before aggregation
        self.strip_special = aggregation != "cls"

        # ---- Classification head ----
        agg_dim = self.aggregator.out_dim
        if classifier_hidden > 0:
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(agg_dim, classifier_hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(classifier_hidden, num_classes),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(agg_dim, num_classes),
            )

    @property
    def tokenizer(self):
        """Lazy-load tokenizer."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = AutoTokenizer.from_pretrained(self.esm2_model_name)
        return self._tokenizer

    def get_residue_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple:
        """Extract per-residue embeddings from frozen ESM2.

        Returns:
            token_embeddings: [B, L, d] (with <cls> removed and <eos> masked out
                when special-token stripping is enabled)
            mask: [B, L]
        """
        with torch.no_grad():
            outputs = self.esm2(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

        hidden_states = outputs.last_hidden_state  # [B, L_full, d]

        if self.strip_special:
            # Strip <cls> (pos 0) and <eos> (last valid position).
            # For ESM2 the input is [<cls>, AA1, AA2, ..., AAN, <eos>, <pad>, ...]
            token_embeddings = hidden_states[:, 1:, :]  # remove <cls>
            mask = attention_mask[:, 1:].clone()  # adjust mask

            # Now mask out the <eos> token for each sequence.
            # The <eos> is the last 1 in the mask (before padding).
            B, L = mask.shape
            # Find the position of the last 1 in each row
            lengths = mask.sum(dim=1).long()  # number of valid tokens after removing <cls>
            for i in range(B):
                if lengths[i] > 0:
                    mask[i, lengths[i] - 1] = 0  # zero out <eos> position
        else:
            token_embeddings = hidden_states
            mask = attention_mask

        return token_embeddings, mask

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        pdb_paths: Optional[List[Optional[str]]] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            input_ids: [B, L] tokenized protein sequences.
            attention_mask: [B, L] attention mask.
            labels: [B] class labels (optional, for loss computation).
            pdb_paths: List of PDB file paths (only for glot_residue aggregation).

        Returns:
            Dict with keys: 'logits', 'embeddings', and optionally 'loss'.
        """
        # Extract residue embeddings from frozen ESM2
        token_embeddings, mask = self.get_residue_embeddings(input_ids, attention_mask)

        # Aggregate to sequence-level
        extra_kwargs = {}
        if pdb_paths is not None:
            extra_kwargs["pdb_paths"] = pdb_paths

        sequence_embedding = self.aggregator(
            token_embeddings, mask, **extra_kwargs
        )  # [B, agg_dim]

        # Classify
        logits = self.classifier(sequence_embedding)  # [B, num_classes]

        result = {"logits": logits, "embeddings": sequence_embedding}

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            result["loss"] = loss_fn(logits, labels)

        return result

    def encode(
        self,
        sequences: Union[str, List[str]],
        pdb_paths: Optional[List[Optional[str]]] = None,
        max_length: int = 1024,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Convenience method: tokenize + forward to get sequence embeddings.

        Args:
            sequences: Single protein sequence or list of sequences.
            pdb_paths: Optional PDB paths for glot_residue aggregation.
            max_length: Maximum sequence length (ESM2 supports up to 1026).
            device: Device to run on.

        Returns:
            Sequence-level embeddings [B, agg_dim].
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        if device is None:
            device = next(self.parameters()).device

        inputs = self.tokenizer(
            sequences,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pdb_paths=pdb_paths,
            )

        return outputs["embeddings"]
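
The committed file ends here. As orientation, below is an illustrative usage sketch written for this note; it is not the committed module or the example script mentioned in the commit message. It assumes the protein_aggregator package is importable and that its aggregators module provides the pooling classes with the constructor signature model.py expects (d_in keyword, out_dim attribute); the sequences and labels are toy values.

# Illustrative only; not part of the commit above.
import torch

from protein_aggregator.model import ProteinSequenceClassifier

model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",  # 480-dim backbone (default)
    aggregation="mean",
    num_classes=2,
)

# Toy amino-acid sequences (made up for illustration)
sequences = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
    "MENSDSNDKGSDQSAAQRRSQMDRLDREEAFYQ",
]

# Sequence-level embeddings via the convenience method: shape [2, agg_dim]
embeddings = model.encode(sequences)
print(embeddings.shape)

# Full forward pass with labels to get logits and a cross-entropy loss
inputs = model.tokenizer(sequences, padding=True, return_tensors="pt")
labels = torch.tensor([0, 1])
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=labels,
)
print(outputs["logits"].shape, outputs["loss"].item())

# Only the aggregator and classifier carry trainable parameters; the ESM2
# backbone stays frozen, so an optimizer would typically be built like this:
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)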
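
A second small sketch, also illustrative: how the special-token handling in get_residue_embeddings plays out on a toy attention mask, assuming the [<cls>, residues..., <eos>, <pad>...] layout documented in the code.

import torch

# Row 0: <cls> + 3 residues + <eos> + 1 pad; row 1: <cls> + 4 residues + <eos>
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0],
                               [1, 1, 1, 1, 1, 1]])
mask = attention_mask[:, 1:].clone()      # drop the <cls> column
lengths = mask.sum(dim=1).long()          # valid tokens remaining: [4, 5]
for i in range(mask.shape[0]):
    if lengths[i] > 0:
        mask[i, lengths[i] - 1] = 0       # zero out the <eos> position
# mask is now [[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]: only real residues are counted
# by the aggregator, matching the behaviour described in the method's docstring.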