AliSaadatV committed
Commit 0eb73db · verified · 1 Parent(s): 28bbe23

Add aggregators module with 6 methods

Files changed (1)
  1. protein_aggregator/aggregators.py +478 -0
protein_aggregator/aggregators.py ADDED
@@ -0,0 +1,478 @@
+ """
+ Six token aggregation methods for protein sequence-level representation.
+
+ All aggregators follow the same interface:
+     Input:  token_embeddings [B, L, d], attention_mask [B, L]
+     Output: sequence_embedding [B, out_dim]
+
+ Optional extra inputs (e.g., PDB paths for GLOTResidueGraphPooling) are passed
+ via keyword arguments.
+ """
+
+ import math
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch_geometric.data import Batch, Data
+ from torch_geometric.nn import GATConv, JumpingKnowledge
+ from torch_geometric.utils import softmax as pyg_softmax
+
+
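+ # Usage sketch (illustrative only; the checkpoint name and shapes below are
+ # assumptions, not part of this module): a frozen ESM2 encoder feeding one of
+ # the aggregators defined in this file.
+ #
+ #   from transformers import AutoTokenizer, EsmModel
+ #
+ #   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+ #   encoder = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+ #   batch = tokenizer(["MKTAYIAKQR", "GSHMLEDP"], return_tensors="pt", padding=True)
+ #   hidden = encoder(**batch).last_hidden_state          # [B, L, 1280]
+ #   pooler = MeanPooling(d_in=hidden.size(-1))
+ #   seq_emb = pooler(hidden, batch["attention_mask"])    # [B, 1280]
+
+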
+ # ---------------------------------------------------------------------------
+ # 1. Mean Pooling
+ # ---------------------------------------------------------------------------
+ class MeanPooling(nn.Module):
+     """Average over non-padded token embeddings."""
+
+     def __init__(self, d_in: int, **kwargs):
+         super().__init__()
+         self.out_dim = d_in
+
+     def forward(
+         self,
+         token_embeddings: torch.Tensor,
+         attention_mask: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         mask = attention_mask.unsqueeze(-1).float()    # [B, L, 1]
+         summed = (token_embeddings * mask).sum(dim=1)  # [B, d]
+         counts = mask.sum(dim=1).clamp(min=1)          # [B, 1]
+         return summed / counts
+
+
+ # ---------------------------------------------------------------------------
+ # 2. Max Pooling
+ # ---------------------------------------------------------------------------
+ class MaxPooling(nn.Module):
+     """Element-wise max over non-padded token embeddings."""
+
+     def __init__(self, d_in: int, **kwargs):
+         super().__init__()
+         self.out_dim = d_in
+
+     def forward(
+         self,
+         token_embeddings: torch.Tensor,
+         attention_mask: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         # Set padded positions to -inf so they don't affect max
+         mask = attention_mask.unsqueeze(-1).bool()  # [B, L, 1]
+         filled = token_embeddings.masked_fill(~mask, float("-inf"))
+         return filled.max(dim=1).values  # [B, d]
+
+
+ # ---------------------------------------------------------------------------
+ # 3. CLS Token Pooling
+ # ---------------------------------------------------------------------------
+ class CLSPooling(nn.Module):
+     """Use the [CLS] token (position 0) representation.
+
+     For ESM2, position 0 is the <cls> token added by the tokenizer.
+     NOTE: This operates on the FULL hidden states (before stripping special
+     tokens), so the caller should pass the raw last_hidden_state with CLS
+     still at position 0.
+     """
+
+     def __init__(self, d_in: int, **kwargs):
+         super().__init__()
+         self.out_dim = d_in
+
+     def forward(
+         self,
+         token_embeddings: torch.Tensor,
+         attention_mask: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         return token_embeddings[:, 0, :]  # [B, d]
+
+
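+ # Caller-contract sketch for CLSPooling (reusing the hypothetical `hidden` and
+ # `batch` names from the module-level example above): unlike the masked
+ # poolers, CLSPooling must see the raw last_hidden_state with <cls> still at
+ # position 0.
+ #
+ #   cls_pooler = CLSPooling(d_in=hidden.size(-1))
+ #   cls_emb = cls_pooler(hidden, batch["attention_mask"])  # == hidden[:, 0, :]
+
+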
+ # ---------------------------------------------------------------------------
+ # 4. GLOT Pooling (cosine-similarity token graph)
+ # ---------------------------------------------------------------------------
+ class GLOTPooling(nn.Module):
+     """Graph-Learning Over Tokens (GLOT) pooling.
+
+     Constructs a token graph based on pairwise cosine similarity of the
+     frozen LLM hidden states. A lightweight GAT-based GNN refines the
+     representations, followed by an attention readout.
+
+     Reference: arXiv 2603.03389, Mantri et al., 2025.
+
+     Args:
+         d_in: Dimensionality of input token embeddings (ESM2 hidden size).
+         p: GNN hidden dimension (default: 128).
+         K: Number of GATConv layers (default: 2).
+         tau: Cosine-similarity threshold for edge creation (default: 0.6).
+         n_heads: Number of GAT attention heads (default: 4).
+     """
+
+     def __init__(
+         self,
+         d_in: int,
+         p: int = 128,
+         K: int = 2,
+         tau: float = 0.6,
+         n_heads: int = 4,
+         **kwargs,
+     ):
+         super().__init__()
+         self.tau = tau
+         self.K = K
+         self.p = p
+
+         # Input projection: d_in -> p
+         self.W_in = nn.Linear(d_in, p)
+
+         # K layers of GATConv
+         self.gat_layers = nn.ModuleList(
+             [
+                 GATConv(p, p // n_heads, heads=n_heads, concat=True)
+                 for _ in range(K)
+             ]
+         )
+
+         # Jumping Knowledge: concatenate ALL layer outputs (input proj + K GNN layers)
+         self.jk = JumpingKnowledge(mode="cat")
+         jk_out_dim = p * (K + 1)
+
+         # Attention readout (Eq. 3 in the paper)
+         self.W_m = nn.Linear(jk_out_dim, p)
+         self.v = nn.Linear(p, 1, bias=False)
+
+         self.out_dim = jk_out_dim
+
+     def _build_graph_batch(
+         self,
+         token_embeddings: torch.Tensor,
+         attention_mask: torch.Tensor,
+     ) -> Batch:
+         """Build a PyG Batch of cosine-similarity token graphs."""
+         graphs = []
+
+         for i in range(token_embeddings.size(0)):
+             valid = attention_mask[i].bool()
+             h_i = token_embeddings[i][valid]  # [L_i, d_in]
+
+             # Pairwise cosine similarity
+             h_norm = F.normalize(h_i, p=2, dim=-1)
+             S = h_norm @ h_norm.T  # [L_i, L_i]
+
+             # Threshold -> binary adjacency (self-loops included since cos(x, x) = 1)
+             A = S > self.tau
+             edge_index = A.nonzero(as_tuple=False).T.contiguous().long()  # [2, E]
+
+             graphs.append(Data(x=h_i, edge_index=edge_index))
+
+         return Batch.from_data_list(graphs)
+
+     def forward(
+         self,
+         token_embeddings: torch.Tensor,
+         attention_mask: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         # Stage 1: Build token graph
+         batch = self._build_graph_batch(token_embeddings, attention_mask)
+         x = batch.x.to(token_embeddings.device)
+         edge_index = batch.edge_index.to(token_embeddings.device)
+         batch_idx = batch.batch.to(token_embeddings.device)
+
+         # Stage 2: Token-GNN with Jumping Knowledge
+         h = self.W_in(x)  # [N_total, p]
+         layer_outputs = [h]
+         for gat in self.gat_layers:
+             h = F.relu(gat(h, edge_index))
+             layer_outputs.append(h)
+
+         U_fused = self.jk(layer_outputs)  # [N_total, p*(K+1)]
+
+         # Stage 3: Attention readout (Eq. 3)
+         m = self.v(torch.tanh(self.W_m(U_fused))).squeeze(-1)  # [N_total]
+         pi = pyg_softmax(m, batch_idx)  # per-graph softmax
+         Z = torch.zeros(
+             token_embeddings.size(0),
+             U_fused.size(-1),
+             device=U_fused.device,
+             dtype=U_fused.dtype,
+         )
+         Z.scatter_add_(
+             0,
+             batch_idx.unsqueeze(-1).expand_as(U_fused),
+             pi.unsqueeze(-1) * U_fused,
+         )
+
+         return Z  # [B, p*(K+1)]
+
+
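+ # Shape sketch (illustrative; `residue_states` and `residue_mask` are
+ # hypothetical names for per-residue ESM2 embeddings and their mask, with
+ # d_in=1280 for ESM2-650M): with p=128 and K=2, JumpingKnowledge("cat")
+ # concatenates the input projection plus the two GAT layers, so
+ # out_dim = 128 * (2 + 1) = 384.
+ #
+ #   glot = GLOTPooling(d_in=1280)             # glot.out_dim == 384
+ #   z = glot(residue_states, residue_mask)    # [B, 384]
+
+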
+ # ---------------------------------------------------------------------------
+ # 5. GLOT with Protein Residue Graph (via graphein)
+ # ---------------------------------------------------------------------------
+ class GLOTResidueGraphPooling(nn.Module):
+     """GLOT pooling where the token graph is a protein residue contact graph
+     constructed from the 3D structure (PDB file) using graphein.
+
+     Uses a Cα-Cα distance threshold (default 8 Å) plus peptide backbone bonds.
+     If no PDB path is provided, falls back to a sequence-distance graph
+     (edges between residues within ±k positions in the primary sequence).
+
+     The GNN and readout are identical to standard GLOT.
+
+     Args:
+         d_in: ESM2 hidden size.
+         p: GNN hidden dimension (default: 128).
+         K: Number of GATConv layers (default: 2).
+         contact_threshold: Cα-Cα distance threshold in Å (default: 8.0).
+         seq_neighbor_k: Fallback sequence-distance window (default: 5).
+         n_heads: GAT attention heads (default: 4).
+     """
+
+     def __init__(
+         self,
+         d_in: int,
+         p: int = 128,
+         K: int = 2,
+         contact_threshold: float = 8.0,
+         seq_neighbor_k: int = 5,
+         n_heads: int = 4,
+         **kwargs,
+     ):
+         super().__init__()
+         self.contact_threshold = contact_threshold
+         self.seq_neighbor_k = seq_neighbor_k
+         self.K = K
+         self.p = p
+
+         # Input projection
+         self.W_in = nn.Linear(d_in, p)
+
+         # GATConv layers
+         self.gat_layers = nn.ModuleList(
+             [
+                 GATConv(p, p // n_heads, heads=n_heads, concat=True)
+                 for _ in range(K)
+             ]
+         )
+
+         # Jumping Knowledge
+         self.jk = JumpingKnowledge(mode="cat")
+         jk_out_dim = p * (K + 1)
+
+         # Readout
+         self.W_m = nn.Linear(jk_out_dim, p)
+         self.v = nn.Linear(p, 1, bias=False)
+
+         self.out_dim = jk_out_dim
+
+     @staticmethod
+     def _build_residue_graph_from_pdb(
+         pdb_path: str,
+         contact_threshold: float,
+     ) -> Tuple[torch.Tensor, int]:
+         """Build edge_index from a PDB file using graphein.
+
+         Returns (edge_index [2, E] with 0-indexed residue indices, n_nodes).
+         """
+         from functools import partial
+
+         from graphein.protein.config import ProteinGraphConfig
+         from graphein.protein.edges.distance import (
+             add_distance_threshold,
+             add_peptide_bonds,
+         )
+         from graphein.protein.graphs import construct_graph
+
+         config = ProteinGraphConfig(
+             edge_construction_functions=[
+                 partial(
+                     add_distance_threshold,
+                     long_interaction_threshold=0,
+                     threshold=contact_threshold,
+                 ),
+                 add_peptide_bonds,
+             ],
+         )
+
+         nx_graph = construct_graph(config=config, pdb_path=pdb_path)
+
+         # Map nodes to sequential 0-based indices in chain / residue-number
+         # order, so graph indices line up with sequence positions.
+         node_list = sorted(
+             nx_graph.nodes(),
+             key=lambda n: (
+                 nx_graph.nodes[n]["chain_id"],
+                 nx_graph.nodes[n]["residue_number"],
+             ),
+         )
+         node_to_idx = {n: i for i, n in enumerate(node_list)}
+
+         edges_src, edges_dst = [], []
+         for u, v in nx_graph.edges():
+             edges_src.append(node_to_idx[u])
+             edges_dst.append(node_to_idx[v])
+             # Undirected: add reverse edge
+             edges_src.append(node_to_idx[v])
+             edges_dst.append(node_to_idx[u])
+
+         # Add self-loops
+         n_nodes = len(node_list)
+         for i in range(n_nodes):
+             edges_src.append(i)
+             edges_dst.append(i)
+
+         edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)
+         return edge_index, n_nodes
+
+     @staticmethod
+     def _build_sequence_distance_graph(seq_len: int, k: int) -> torch.Tensor:
+         """Fallback: build edges between residues within ±k positions."""
+         edges_src, edges_dst = [], []
+         for i in range(seq_len):
+             for j in range(max(0, i - k), min(seq_len, i + k + 1)):
+                 edges_src.append(i)
+                 edges_dst.append(j)
+         edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)
+         return edge_index
+
+     def _build_graph_batch(
+         self,
+         token_embeddings: torch.Tensor,
+         attention_mask: torch.Tensor,
+         pdb_paths: Optional[List[Optional[str]]] = None,
+     ) -> Batch:
+         """Build a PyG Batch using residue graphs (from PDB or sequence fallback)."""
+         graphs = []
+         B = token_embeddings.size(0)
+
+         for i in range(B):
+             valid = attention_mask[i].bool()
+             h_i = token_embeddings[i][valid]  # [L_i, d_in]
+             L_i = h_i.size(0)
+
+             if pdb_paths is not None and pdb_paths[i] is not None:
+                 edge_index, n_nodes = self._build_residue_graph_from_pdb(
+                     pdb_paths[i], self.contact_threshold
+                 )
+                 # Align: the graphein graph may have a different number of
+                 # residues than there are ESM2 tokens. Use min(n_nodes, L_i)
+                 # and truncate.
+                 n = min(n_nodes, L_i)
+                 # Filter edges to only include nodes < n
+                 mask_edges = (edge_index[0] < n) & (edge_index[1] < n)
+                 edge_index = edge_index[:, mask_edges]
+                 h_i = h_i[:n]
+             else:
+                 # Sequence-distance fallback
+                 edge_index = self._build_sequence_distance_graph(
+                     L_i, self.seq_neighbor_k
+                 )
+
+             graphs.append(Data(x=h_i, edge_index=edge_index))
+
+         return Batch.from_data_list(graphs)
+
+     def forward(
+         self,
+         token_embeddings: torch.Tensor,
+         attention_mask: torch.Tensor,
+         pdb_paths: Optional[List[Optional[str]]] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         """
+         Args:
+             token_embeddings: [B, L, d_in] ESM2 residue embeddings.
+             attention_mask: [B, L]; 1 for valid residues, 0 for padding.
+             pdb_paths: Optional list of PDB file paths (one per sequence).
+                 If None, or if an entry is None, the sequence-distance
+                 fallback is used for that sequence.
+         """
+         batch = self._build_graph_batch(token_embeddings, attention_mask, pdb_paths)
+         x = batch.x.to(token_embeddings.device)
+         edge_index = batch.edge_index.to(token_embeddings.device)
+         batch_idx = batch.batch.to(token_embeddings.device)
+
+         # GNN
+         h = self.W_in(x)
+         layer_outputs = [h]
+         for gat in self.gat_layers:
+             h = F.relu(gat(h, edge_index))
+             layer_outputs.append(h)
+
+         U_fused = self.jk(layer_outputs)
+
+         # Readout
+         m = self.v(torch.tanh(self.W_m(U_fused))).squeeze(-1)
+         pi = pyg_softmax(m, batch_idx)
+
+         # Determine number of graphs from batch_idx
+         num_graphs = batch_idx.max().item() + 1 if batch_idx.numel() > 0 else token_embeddings.size(0)
+         Z = torch.zeros(
+             num_graphs,
+             U_fused.size(-1),
+             device=U_fused.device,
+             dtype=U_fused.dtype,
+         )
+         Z.scatter_add_(
+             0,
+             batch_idx.unsqueeze(-1).expand_as(U_fused),
+             pi.unsqueeze(-1) * U_fused,
+         )
+
+         return Z
+
+
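+ # Usage sketch (the path and variable names are placeholders): pdb_paths is
+ # given per sequence, and a None entry falls back to the ±seq_neighbor_k
+ # sequence-distance graph for that sequence.
+ #
+ #   pooler = GLOTResidueGraphPooling(d_in=1280, contact_threshold=8.0)
+ #   z = pooler(residue_states, residue_mask,
+ #              pdb_paths=["structures/P12345.pdb", None])  # [B, 384]
+
+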
+ # ---------------------------------------------------------------------------
+ # 6. Covariance Pooling
+ # ---------------------------------------------------------------------------
+ class CovariancePooling(nn.Module):
+     """Second-order covariance pooling for sequence-level representation.
+
+     Captures pairwise feature co-activation patterns across token positions,
+     providing a richer representation than first-order (mean) statistics.
+
+     The method:
+     1. Projects tokens to a lower dimension d_proj to control output size.
+     2. Mean-centers the projected tokens.
+     3. Computes the covariance matrix C = X_centered^T @ X_centered / (L - 1).
+     4. Applies power normalization (signed sqrt) for training stability.
+     5. Extracts the upper triangle as a flat vector.
+
+     Output dimension = d_proj * (d_proj + 1) / 2.
+
+     Reference: https://www.goodfire.ai/research/covariance-pooling
+
+     Args:
+         d_in: Input embedding dimension (ESM2 hidden size).
+         d_proj: Projection dimension before covariance (default: 64).
+             Controls output size: 64 -> 2080, 32 -> 528, 128 -> 8256.
+     """
+
+     def __init__(self, d_in: int, d_proj: int = 64, **kwargs):
+         super().__init__()
+         self.d_proj = d_proj
+         self.proj = nn.Linear(d_in, d_proj)
+         self.out_dim = d_proj * (d_proj + 1) // 2
+
+         # Pre-compute upper-triangle indices (registered as buffers for device handling)
+         triu_i, triu_j = torch.triu_indices(d_proj, d_proj, offset=0)
+         self.register_buffer("triu_i", triu_i)
+         self.register_buffer("triu_j", triu_j)
+
+     def forward(
+         self,
+         token_embeddings: torch.Tensor,
+         attention_mask: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         # Project to a lower dimension
+         x = self.proj(token_embeddings)  # [B, L, d_proj]
+
+         # Mask padding
+         mask = attention_mask.unsqueeze(-1).float()  # [B, L, 1]
+         x = x * mask
+
+         # Per-sequence token count (avoid division by zero)
+         L_eff = mask.sum(dim=1, keepdim=True).clamp(min=1)  # [B, 1, 1]
+
+         # Mean-center
+         mu = x.sum(dim=1, keepdim=True) / L_eff  # [B, 1, d_proj]
+         x_centered = (x - mu) * mask  # re-apply the mask after centering
+
+         # Covariance matrix: C = X^T X / (L - 1)
+         # Use L_eff - 1 for an unbiased estimate, clamped to avoid division by zero
+         denom = (L_eff.squeeze(-1) - 1).clamp(min=1).unsqueeze(-1)  # [B, 1, 1]
+         C = torch.bmm(x_centered.transpose(1, 2), x_centered) / denom  # [B, d_proj, d_proj]
+
+         # Power normalization (signed square root) for training stability
+         C = torch.sign(C) * (torch.abs(C) + 1e-7).sqrt()
+
+         # Extract the upper triangle -> flat vector
+         out = C[:, self.triu_i, self.triu_j]  # [B, d_proj*(d_proj+1)/2]
+
+         return out
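+
+
+ # Output-size check for CovariancePooling (arithmetic only, matching the
+ # docstring above): the upper triangle of a d_proj x d_proj matrix has
+ # d_proj * (d_proj + 1) / 2 entries, e.g. 64 -> 2080 and 32 -> 528.
+ #
+ #   cov = CovariancePooling(d_in=1280, d_proj=32)   # cov.out_dim == 528
+ #   z = cov(residue_states, residue_mask)           # [B, 528]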