""" Six token aggregation methods for protein sequence-level representation. All aggregators follow the same interface: Input: token_embeddings [B, L, d], attention_mask [B, L] Output: sequence_embedding [B, out_dim] Optional extra inputs (e.g., PDB paths for GLOTResidueGraphPooling) are passed via keyword arguments. """ import math from typing import List, Optional import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.data import Batch, Data from torch_geometric.nn import GATConv, JumpingKnowledge from torch_geometric.utils import softmax as pyg_softmax # --------------------------------------------------------------------------- # 1. Mean Pooling # --------------------------------------------------------------------------- class MeanPooling(nn.Module): """Average over non-padded token embeddings.""" def __init__(self, d_in: int, **kwargs): super().__init__() self.out_dim = d_in def forward( self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor, **kwargs, ) -> torch.Tensor: mask = attention_mask.unsqueeze(-1).float() # [B, L, 1] summed = (token_embeddings * mask).sum(dim=1) # [B, d] counts = mask.sum(dim=1).clamp(min=1) # [B, 1] return summed / counts # --------------------------------------------------------------------------- # 2. Max Pooling # --------------------------------------------------------------------------- class MaxPooling(nn.Module): """Element-wise max over non-padded token embeddings.""" def __init__(self, d_in: int, **kwargs): super().__init__() self.out_dim = d_in def forward( self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor, **kwargs, ) -> torch.Tensor: # Set padded positions to -inf so they don't affect max mask = attention_mask.unsqueeze(-1).bool() # [B, L, 1] filled = token_embeddings.masked_fill(~mask, float("-inf")) return filled.max(dim=1).values # [B, d] # --------------------------------------------------------------------------- # 3. 
# CLS Token Pooling
# ---------------------------------------------------------------------------


class CLSPooling(nn.Module):
    """Use the [CLS] token (position 0) representation.

    For ESM2, position 0 is the special CLS token added by the tokenizer.

    NOTE: This operates on the FULL hidden states (before stripping special
    tokens), so the caller should pass the raw last_hidden_state with CLS
    still at position 0.

    Args:
        d_in: Dimensionality of the input token embeddings; also ``out_dim``.
        **kwargs: Accepted and ignored so all aggregators share a constructor
            signature.
    """

    def __init__(self, d_in: int, **kwargs):
        super().__init__()
        self.out_dim = d_in

    def forward(
        self,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return the embedding at position 0 of every sequence.

        ``attention_mask`` is part of the shared interface but is not read.

        Returns:
            ``[B, d]`` tensor.
        """
        return token_embeddings[:, 0, :]  # [B, d]


# ---------------------------------------------------------------------------
# 4. GLOT Pooling (cosine-similarity token graph)
# ---------------------------------------------------------------------------


class GLOTPooling(nn.Module):
    """Graph-Learning Over Tokens (GLOT) pooling.

    Constructs a token graph based on pairwise cosine similarity of the
    frozen LLM hidden states. A lightweight GAT-based GNN refines the
    representations, followed by an attention readout.

    Reference: arXiv 2603.03389 — Mantri et al., 2025.

    Args:
        d_in: Dimensionality of input token embeddings (ESM2 hidden size).
        p: GNN hidden dimension (default: 128).
        K: Number of GATConv layers (default: 2).
        tau: Cosine-similarity threshold for edge creation (default: 0.6).
        n_heads: Number of GAT attention heads (default: 4).

    Raises:
        ValueError: If ``K > 0`` and ``p`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        d_in: int,
        p: int = 128,
        K: int = 2,
        tau: float = 0.6,
        n_heads: int = 4,
        **kwargs,
    ):
        super().__init__()
        # Each GAT layer emits (p // n_heads) * n_heads features; unless that
        # equals p, the next layer and the readout see the wrong width and
        # forward() fails with an opaque shape error. Fail early instead.
        if K > 0 and p % n_heads != 0:
            raise ValueError(
                f"p ({p}) must be divisible by n_heads ({n_heads})"
            )

        self.tau = tau
        self.K = K
        self.p = p

        # Input projection: d_in -> p
        self.W_in = nn.Linear(d_in, p)

        # K layers of GATConv
        self.gat_layers = nn.ModuleList(
            [
                GATConv(p, p // n_heads, heads=n_heads, concat=True)
                for _ in range(K)
            ]
        )

        # Jumping Knowledge: concatenate ALL layer outputs
        # (input projection + K GNN layers)
        self.jk = JumpingKnowledge(mode="cat")
        jk_out_dim = p * (K + 1)

        # Attention readout (Eq. 3 in the paper)
        self.W_m = nn.Linear(jk_out_dim, p)
        self.v = nn.Linear(p, 1, bias=False)

        self.out_dim = jk_out_dim

    def _build_graph_batch(
        self,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> "Batch":
        """Build a PyG Batch of cosine-similarity token graphs.

        One graph is created per sequence from its non-padded tokens; an edge
        (i, j) exists when the cosine similarity of tokens i and j exceeds
        ``self.tau``.
        """
        graphs = []
        for seq_embeddings, seq_mask in zip(token_embeddings, attention_mask):
            h_i = seq_embeddings[seq_mask.bool()]  # [L_i, d_in]

            # The adjacency is a hard threshold and carries no gradient, so
            # the [L_i, L_i] similarity matrix is kept out of the autograd
            # graph.
            with torch.no_grad():
                h_norm = F.normalize(h_i, p=2, dim=-1)
                # cos(x, x) = 1, so self-loops appear whenever tau < 1 and
                # the token embedding is non-zero.
                adjacency = (h_norm @ h_norm.T) > self.tau  # [L_i, L_i]

            edge_index = (
                adjacency.nonzero(as_tuple=False).T.contiguous().long()
            )  # [2, E]
            graphs.append(Data(x=h_i, edge_index=edge_index))

        return Batch.from_data_list(graphs)

    def forward(
        self,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Pool each sequence into a single vector.

        Args:
            token_embeddings: ``[B, L, d_in]`` token embeddings.
            attention_mask: ``[B, L]`` mask, non-zero for valid tokens.

        Returns:
            ``[B, p * (K + 1)]`` tensor with the dtype of the GNN features.
            Sequences without any valid token produce a zero row.
        """
        device = token_embeddings.device

        # Stage 1: Build token graph
        batch = self._build_graph_batch(token_embeddings, attention_mask)
        x = batch.x.to(device)
        edge_index = batch.edge_index.to(device)
        batch_idx = batch.batch.to(device)

        # Stage 2: Token-GNN with Jumping Knowledge
        h = self.W_in(x)  # [N_total, p]
        layer_outputs = [h]
        for gat in self.gat_layers:
            h = F.relu(gat(h, edge_index))
            layer_outputs.append(h)
        U_fused = self.jk(layer_outputs)  # [N_total, p*(K+1)]

        # Stage 3: Attention readout (Eq. 3)
        scores = self.v(torch.tanh(self.W_m(U_fused))).squeeze(-1)  # [N_total]
        pi = pyg_softmax(scores, batch_idx)  # per-graph softmax
        weighted = pi.unsqueeze(-1) * U_fused

        # Allocate the output from `weighted` so dtype and device always match
        # the source of the accumulation (a default-dtype buffer breaks under
        # half/double precision).
        Z = weighted.new_zeros(token_embeddings.size(0), U_fused.size(-1))
        Z.index_add_(0, batch_idx, weighted)
        return Z  # [B, p*(K+1)]


# ---------------------------------------------------------------------------
# 5.
# GLOT with Protein Residue Graph (via graphein)
# ---------------------------------------------------------------------------


class GLOTResidueGraphPooling(nn.Module):
    """GLOT pooling where the token graph is a protein residue contact graph
    constructed from the 3D structure (PDB file) using graphein.

    Uses Cα-Cα distance threshold (default 8 Å) plus peptide backbone bonds.
    If no PDB path is provided, falls back to a sequence-distance graph
    (edges between residues within ±k positions in the primary sequence).

    The GNN and readout are identical to standard GLOT.

    Args:
        d_in: ESM2 hidden size.
        p: GNN hidden dimension (default: 128).
        K: Number of GATConv layers (default: 2).
        contact_threshold: Cα-Cα distance threshold in Å (default: 8.0).
        seq_neighbor_k: Fallback: sequence-distance window (default: 5).
        n_heads: GAT attention heads (default: 4).

    Raises:
        ValueError: If ``K > 0`` and ``p`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        d_in: int,
        p: int = 128,
        K: int = 2,
        contact_threshold: float = 8.0,
        seq_neighbor_k: int = 5,
        n_heads: int = 4,
        **kwargs,
    ):
        super().__init__()
        # Each GAT layer emits (p // n_heads) * n_heads features; unless that
        # equals p, forward() fails later with an opaque shape error.
        if K > 0 and p % n_heads != 0:
            raise ValueError(
                f"p ({p}) must be divisible by n_heads ({n_heads})"
            )

        self.contact_threshold = contact_threshold
        self.seq_neighbor_k = seq_neighbor_k
        self.K = K
        self.p = p

        # Input projection
        self.W_in = nn.Linear(d_in, p)

        # GATConv layers
        self.gat_layers = nn.ModuleList(
            [
                GATConv(p, p // n_heads, heads=n_heads, concat=True)
                for _ in range(K)
            ]
        )

        # Jumping Knowledge
        self.jk = JumpingKnowledge(mode="cat")
        jk_out_dim = p * (K + 1)

        # Readout
        self.W_m = nn.Linear(jk_out_dim, p)
        self.v = nn.Linear(p, 1, bias=False)

        self.out_dim = jk_out_dim

    @staticmethod
    def _build_residue_graph_from_pdb(
        pdb_path: str,
        contact_threshold: float,
    ) -> "tuple[torch.Tensor, int]":
        """Build edge_index from a PDB file using graphein.

        Returns:
            ``(edge_index, n_nodes)`` where ``edge_index`` is a ``[2, E]``
            long tensor of 0-indexed residue indices (both directions of
            every edge plus one self-loop per node) and ``n_nodes`` is the
            number of residues in the graphein graph.
        """
        # graphein is imported lazily so the module stays importable without it.
        from functools import partial

        from graphein.protein.config import ProteinGraphConfig
        from graphein.protein.edges.distance import (
            add_distance_threshold,
            add_peptide_bonds,
        )
        from graphein.protein.graphs import construct_graph

        # The config field is `edge_construction_functions`; an unrecognised
        # keyword would leave graphein on its default edge set and the
        # distance threshold would never be applied.
        config = ProteinGraphConfig(
            edge_construction_functions=[
                partial(
                    add_distance_threshold,
                    long_interaction_threshold=0,
                    threshold=contact_threshold,
                ),
                add_peptide_bonds,
            ],
        )
        # NOTE(review): recent graphein releases name this argument `path`;
        # confirm against the pinned graphein version.
        nx_graph = construct_graph(config=config, pdb_path=pdb_path)

        # Map node names to 0-based indices in graph insertion order. Sorting
        # the string ids (e.g. "A:ALA:10" < "A:MET:1") would scramble the
        # residue order and break the alignment with the token sequence.
        # NOTE(review): assumes graphein inserts nodes in structure order.
        node_to_idx = {node: i for i, node in enumerate(nx_graph.nodes())}
        n_nodes = len(node_to_idx)

        pairs = [(node_to_idx[u], node_to_idx[v]) for u, v in nx_graph.edges()]
        forward_edges = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).T
        self_loops = torch.arange(n_nodes, dtype=torch.long).expand(2, -1)

        # Undirected graph: forward edges, reversed edges, then self-loops.
        edge_index = torch.cat(
            [forward_edges, forward_edges.flip(0), self_loops], dim=1
        )
        return edge_index, n_nodes

    @staticmethod
    def _build_sequence_distance_graph(
        seq_len: int, k: int
    ) -> torch.Tensor:
        """Fallback: build edges between residues within ±k positions.

        Returns:
            ``[2, E]`` long tensor of directed edges ``(i, j)`` with
            ``|i - j| <= k`` (self-loops included), ordered by ``i`` and then
            by ``j``. Empty (``[2, 0]``) when ``seq_len <= 0`` or ``k < 0``.
        """
        if seq_len <= 0 or k < 0:
            return torch.empty((2, 0), dtype=torch.long)

        window = 2 * k + 1
        src = torch.arange(seq_len, dtype=torch.long).repeat_interleave(window)
        dst = src + torch.arange(-k, k + 1, dtype=torch.long).repeat(seq_len)
        # Drop the neighbours that fall off either end of the sequence.
        in_range = (dst >= 0) & (dst < seq_len)
        return torch.stack([src[in_range], dst[in_range]])

    def _build_graph_batch(
        self,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
        pdb_paths: Optional[List[Optional[str]]] = None,
    ) -> "Batch":
        """Build PyG Batch using residue graphs (from PDB or sequence fallback).

        Raises:
            ValueError: If ``pdb_paths`` is given but has fewer entries than
                the batch has sequences.
        """
        B = token_embeddings.size(0)
        if pdb_paths is not None and len(pdb_paths) < B:
            raise ValueError(
                f"pdb_paths has {len(pdb_paths)} entries for a batch of {B}"
            )

        graphs = []
        for i in range(B):
            h_i = token_embeddings[i][attention_mask[i].bool()]  # [L_i, d_in]
            L_i = h_i.size(0)
            pdb_path = pdb_paths[i] if pdb_paths is not None else None

            if pdb_path is not None:
                edge_index, n_nodes = self._build_residue_graph_from_pdb(
                    pdb_path, self.contact_threshold
                )
                # Align: the graphein graph may have a different number of
                # residues than there are tokens. Keep the first
                # min(n_nodes, L_i) of both and drop edges that leave it.
                n = min(n_nodes, L_i)
                inside = (edge_index < n).all(dim=0)
                edge_index = edge_index[:, inside]
                h_i = h_i[:n]
            else:
                # Sequence-distance fallback
                edge_index = self._build_sequence_distance_graph(
                    L_i, self.seq_neighbor_k
                )

            # Edge lists are built on CPU; keep them with the node features.
            graphs.append(Data(x=h_i, edge_index=edge_index.to(h_i.device)))

        return Batch.from_data_list(graphs)

    def forward(
        self,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
        pdb_paths: Optional[List[Optional[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            token_embeddings: [B, L, d_in] — ESM2 residue embeddings.
            attention_mask: [B, L] — 1 for valid residues, 0 for padding.
            pdb_paths: Optional list of PDB file paths (one per sequence).
                       If None or a path is None, uses sequence-distance
                       fallback.

        Returns:
            [B, p * (K + 1)] tensor with the dtype of the GNN features.
            Sequences whose graph has no node produce a zero row.
        """
        device = token_embeddings.device
        batch = self._build_graph_batch(
            token_embeddings, attention_mask, pdb_paths
        )
        x = batch.x.to(device)
        edge_index = batch.edge_index.to(device)
        batch_idx = batch.batch.to(device)

        # GNN
        h = self.W_in(x)
        layer_outputs = [h]
        for gat in self.gat_layers:
            h = F.relu(gat(h, edge_index))
            layer_outputs.append(h)
        U_fused = self.jk(layer_outputs)

        # Readout
        scores = self.v(torch.tanh(self.W_m(U_fused))).squeeze(-1)
        pi = pyg_softmax(scores, batch_idx)
        weighted = pi.unsqueeze(-1) * U_fused

        # One output row per input sequence. Deriving the count from
        # batch_idx.max() would drop trailing empty graphs and force a
        # host/device sync. Allocating from `weighted` keeps dtype and device
        # consistent with the accumulated values.
        Z = weighted.new_zeros(token_embeddings.size(0), U_fused.size(-1))
        Z.index_add_(0, batch_idx, weighted)
        return Z


# ---------------------------------------------------------------------------
# 6. Covariance Pooling
# ---------------------------------------------------------------------------


class CovariancePooling(nn.Module):
    """Second-order covariance pooling for sequence-level representation.

    Captures pairwise feature co-activation patterns across token positions,
    providing a richer representation than first-order (mean) statistics.

    The method:
    1. Projects tokens to a lower dimension d_proj to control output size.
    2. Mean-centers the projected tokens.
    3. Computes the covariance matrix
       C = X_centered^T @ X_centered / (L - 1), with the divisor clamped to
       at least 1.
    4. Applies power normalization (signed sqrt) for training stability.
    5. Extracts the upper triangle as a flat vector.

    Output dimension = d_proj * (d_proj + 1) / 2.

    Reference: https://www.goodfire.ai/research/covariance-pooling

    Args:
        d_in: Input embedding dimension (ESM2 hidden size).
        d_proj: Projection dimension before covariance (default: 64).
                Controls output size: 64 -> 2080, 32 -> 528, 128 -> 8256.
    """

    def __init__(self, d_in: int, d_proj: int = 64, **kwargs):
        super().__init__()
        self.d_proj = d_proj
        self.proj = nn.Linear(d_in, d_proj)
        self.out_dim = d_proj * (d_proj + 1) // 2

        # Pre-compute upper-triangle indices (registered as buffers so they
        # follow the module across devices).
        triu_i, triu_j = torch.triu_indices(d_proj, d_proj, offset=0)
        self.register_buffer("triu_i", triu_i)
        self.register_buffer("triu_j", triu_j)

    def forward(
        self,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return the flattened, power-normalized covariance of each sequence.

        Args:
            token_embeddings: ``[B, L, d_in]`` token embeddings.
            attention_mask: ``[B, L]`` mask, non-zero for valid tokens.

        Returns:
            ``[B, d_proj * (d_proj + 1) / 2]`` tensor.
        """
        mask = attention_mask.unsqueeze(-1).float()  # [B, L, 1]

        # Project to lower dimension and zero out the padded positions.
        x = self.proj(token_embeddings) * mask  # [B, L, d_proj]

        # Per-sequence token count (clamped to avoid division by zero)
        L_eff = mask.sum(dim=1, keepdim=True).clamp(min=1)  # [B, 1, 1]

        # Mean-center; the mask is re-applied because subtracting the mean
        # makes the padded rows non-zero again.
        mu = x.sum(dim=1, keepdim=True) / L_eff  # [B, 1, d_proj]
        x_centered = (x - mu) * mask

        # Unbiased estimate: divide by (L - 1), clamped so single-token
        # sequences do not divide by zero.
        denom = (L_eff - 1).clamp(min=1)  # [B, 1, 1]
        C = torch.bmm(x_centered.transpose(1, 2), x_centered) / denom

        # Power normalization (signed square root) for training stability;
        # the epsilon keeps the sqrt gradient finite at zero.
        C = torch.sign(C) * (torch.abs(C) + 1e-7).sqrt()

        # Extract upper triangle -> flat vector
        return C[:, self.triu_i, self.triu_j]