# Protein Sequence-Level Prediction with Multiple Token Aggregation Methods

Extract residue embeddings from a **frozen ESM2** backbone and aggregate them into sequence-level representations using **6 different strategies** for downstream tasks such as subcellular localization prediction.

## Aggregation Methods

| # | Method | Class | Output Dim | Description |
|---|--------|-------|-----------|-------------|
| 1 | **Mean** | `MeanPooling` | `d` | Average over non-padded residue embeddings |
| 2 | **Max** | `MaxPooling` | `d` | Element-wise max over non-padded residue embeddings |
| 3 | **CLS** | `CLSPooling` | `d` | ESM2's `<cls>` token representation (position 0) |
| 4 | **GLOT** | `GLOTPooling` | `p*(K+1)` | Cosine-similarity token graph → GAT GNN → attention readout ([arXiv:2603.03389](https://arxiv.org/abs/2603.03389)) |
| 5 | **GLOT-Residue** | `GLOTResidueGraphPooling` | `p*(K+1)` | Protein 3D residue contact graph (via [graphein](https://graphein.ai/)) → GAT GNN → attention readout |
| 6 | **Covariance** | `CovariancePooling` | `d_proj*(d_proj+1)/2` | Second-order covariance pooling with power normalization ([ref](https://www.goodfire.ai/research/covariance-pooling)) |

Here `d` is the ESM2 hidden dimension (e.g. 480 for the 35M model), `p` is the GNN hidden dimension (default 128), and `K` is the number of GNN layers (default 2).
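The mean, max, CLS, and covariance rows of the table reduce to a few tensor operations. The following is a minimal sketch, not the package's `MeanPooling`/`MaxPooling`/`CLSPooling`/`CovariancePooling` code. It assumes hidden states `h` of shape `[B, L, d]` and a 0/1 `mask` of shape `[B, L]`; the helper names (`masked_mean`, `masked_max`, `cls_pool`, `CovarianceSketch`), the linear `d → d_proj` projection, and the signed square-root form of power normalization are all illustrative assumptions. Whether the real classes also exclude the `<cls>`/`<eos>` positions is not stated here, so the sketch simply pools over whatever the mask marks as valid.

```python
import torch
import torch.nn as nn


def masked_mean(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average h [B, L, d] over positions where mask [B, L] is 1."""
    m = mask.unsqueeze(-1).to(h.dtype)                  # [B, L, 1]
    return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)


def masked_max(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Element-wise max of h over positions where mask is 1."""
    neg = torch.finfo(h.dtype).min                      # padding never wins the max
    return h.masked_fill(mask.unsqueeze(-1) == 0, neg).max(dim=1).values


def cls_pool(h: torch.Tensor) -> torch.Tensor:
    """Return the <cls> embedding, which the ESM2 tokenizer places at position 0."""
    return h[:, 0]


class CovarianceSketch(nn.Module):
    """Hypothetical covariance pooling: project, masked covariance, upper triangle."""

    def __init__(self, d: int, d_proj: int = 64, eps: float = 1e-8):
        super().__init__()
        self.proj = nn.Linear(d, d_proj)                # assumed d -> d_proj projection
        self.eps = eps
        # (row, col) indices of the upper triangle incl. diagonal: d_proj*(d_proj+1)/2 entries
        self.register_buffer("iu", torch.triu_indices(d_proj, d_proj))

    def forward(self, h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        x = self.proj(h)                                # [B, L, d_proj]
        m = mask.unsqueeze(-1).to(x.dtype)              # [B, L, 1]
        n = m.sum(dim=1, keepdim=True).clamp(min=1.0)   # valid tokens per sequence, [B, 1, 1]
        mu = (x * m).sum(dim=1, keepdim=True) / n       # masked mean, [B, 1, d_proj]
        xc = (x - mu) * m                               # center, then zero out padding
        cov = xc.transpose(1, 2) @ xc / n               # [B, d_proj, d_proj]
        flat = cov[:, self.iu[0], self.iu[1]]           # [B, d_proj*(d_proj+1)/2]
        # Signed square-root power normalization (one common variant; an assumption here)
        return torch.sign(flat) * torch.sqrt(flat.abs() + self.eps)
```

Because the covariance matrix is symmetric, keeping only its upper triangle loses nothing; with `d_proj=64` the sketch emits 2080 features, matching the table's `d_proj*(d_proj+1)/2`.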
## Supported ESM2 Backbones

The backbone is configurable: pass a different model name via `esm2_model_name`.

| Model | Params | Hidden Dim |
|-------|--------|-----------|
| `facebook/esm2_t6_8M_UR50D` | 8M | 320 |
| `facebook/esm2_t12_35M_UR50D` | 35M | 480 (default) |
| `facebook/esm2_t30_150M_UR50D` | 150M | 640 |
| `facebook/esm2_t33_650M_UR50D` | 650M | 1280 |
| `facebook/esm2_t36_3B_UR50D` | 3B | 2560 |

## Quick Start

```python
import torch
from protein_aggregator import ProteinSequenceClassifier

# Build model: frozen ESM2 + GLOT aggregation + 10-class head
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",
    aggregation="glot",  # "mean", "max", "cls", "glot", "glot_residue", "covariance"
    num_classes=10,
    aggregator_kwargs={"p": 128, "K": 2, "tau": 0.6},
    classifier_hidden=256,
    dropout=0.1,
).cuda()

sequences = ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"]

# Get sequence-level embeddings
embeddings = model.encode(sequences)
print(embeddings.shape)  # [2, 384]  (p*(K+1) = 128*3)

# Or run a full forward pass with loss
inputs = model.tokenizer(
    sequences, padding=True, truncation=True, return_tensors="pt"
).to("cuda")
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
loss = outputs["loss"]
logits = outputs["logits"]
```

## Using GLOT-Residue with PDB Files

When 3D structure is available, `glot_residue` builds the token graph from the protein's Cα-Cα contact map (8 Å threshold) using [graphein](https://graphein.ai/):

```python
from protein_aggregator import ProteinSequenceClassifier

model = ProteinSequenceClassifier(
    aggregation="glot_residue",
    num_classes=10,
    aggregator_kwargs={
        "contact_threshold": 8.0,  # Cα-Cα distance cutoff in Å
        "seq_neighbor_k": 5,       # fallback: connect ±k sequence neighbors if no PDB
    },
)

# Tokenize the sequences (placeholders here) to get ids and mask
inputs = model.tokenizer(["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"], padding=True, return_tensors="pt")
ids, mask = inputs["input_ids"], inputs["attention_mask"]

# With PDB files
outputs = model(input_ids=ids, attention_mask=mask, pdb_paths=["1abc.pdb", "2def.pdb"])

# Without PDB files (falls back to a sequence-distance graph)
outputs = model(input_ids=ids, attention_mask=mask)
```

## Using Covariance Pooling

Captures second-order statistics (feature co-activations) across residue positions:

```python
from protein_aggregator import ProteinSequenceClassifier

model = ProteinSequenceClassifier(
    aggregation="covariance",
    num_classes=10,
    aggregator_kwargs={"d_proj": 64},  # output dim = 64*65/2 = 2080
)
```

The `d_proj` parameter controls the output dimensionality, `d_proj*(d_proj+1)/2`:

- `d_proj=32` → 528 dims
- `d_proj=64` → 2080 dims (default)
- `d_proj=128` → 8256 dims

## Architecture Overview

```
Protein Sequence
        ↓
┌──────────────────┐
│  ESM2 (frozen)   │  Extracts per-residue embeddings [B, L, d]
└──────────────────┘
        ↓
┌──────────────────┐
│   Aggregator     │  Compresses token-level → sequence-level [B, agg_dim]
│   (one of 6)     │
└──────────────────┘
        ↓
┌──────────────────┐
│ Classifier Head  │  Linear (+ optional hidden layer) → [B, num_classes]
└──────────────────┘
```

Only the aggregator and classifier are trained; ESM2 always stays frozen.

## GLOT Details

The GLOT aggregation (methods 4 and 5) follows [arXiv:2603.03389](https://arxiv.org/abs/2603.03389):

1. **Token Graph Construction**: For standard GLOT, compute pairwise cosine similarity between residue embeddings and threshold it at τ (default 0.6) to get a binary adjacency. For GLOT-Residue, use the Cα-Cα distance contact map from the 3D structure.
2. **Token-GNN**: K layers of GATConv (Graph Attention Network) with ReLU, followed by Jumping Knowledge concatenation of all layer outputs.
3. **Attention Readout**: Learned per-token importance scores → softmax → weighted sum, producing the sequence vector.

Default hyperparameters (from the paper): `p=128`, `K=2`, `tau=0.6`, `n_heads=4`, `lr=2e-4`, no weight decay, Adam optimizer.
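The three GLOT steps above can be illustrated in plain PyTorch. This is a hypothetical sketch, not the package's `GLOTPooling`: to avoid a `torch_geometric` dependency it swaps PyG's `GATConv` for a hand-written single-head graph attention layer on a dense adjacency (the paper default is `n_heads=4`), and it assumes a linear `d → p` input projection whose output is included in the Jumping Knowledge concatenation, which is one way to arrive at the `p*(K+1)` output dimension. `cosine_graph`, `DenseGATLayer`, and `GLOTSketch` are illustrative names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cosine_graph(h: torch.Tensor, mask: torch.Tensor, tau: float = 0.6) -> torch.Tensor:
    """Boolean adjacency [B, L, L]: edge where cosine similarity >= tau."""
    z = F.normalize(h, dim=-1)
    sim = z @ z.transpose(1, 2)                          # pairwise cosine similarity
    valid = mask.bool()
    pair = valid.unsqueeze(2) & valid.unsqueeze(1)       # both endpoints are real tokens
    eye = torch.eye(h.size(1), dtype=torch.bool, device=h.device)
    return ((sim >= tau) & pair) | eye                   # self-loops keep every row non-empty


class DenseGATLayer(nn.Module):
    """Single-head graph attention on a dense adjacency (stand-in for PyG GATConv)."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.W = nn.Linear(d_in, d_out, bias=False)
        self.a_dst = nn.Parameter(0.1 * torch.randn(d_out))
        self.a_src = nn.Parameter(0.1 * torch.randn(d_out))

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        wx = self.W(x)                                   # [B, L, d_out]
        # e_ij = LeakyReLU(a_dst . Wx_i + a_src . Wx_j), kept only on existing edges
        e = (wx @ self.a_dst).unsqueeze(2) + (wx @ self.a_src).unsqueeze(1)
        e = F.leaky_relu(e, 0.2).masked_fill(~adj, float("-inf"))
        return torch.softmax(e, dim=-1) @ wx             # attention-weighted neighbor sum


class GLOTSketch(nn.Module):
    """Hypothetical GLOT-style pooling: token graph -> GNN with JK concat -> attention readout."""

    def __init__(self, d: int, p: int = 128, K: int = 2, tau: float = 0.6):
        super().__init__()
        self.tau = tau
        self.inp = nn.Linear(d, p)                       # assumed d -> p input projection
        self.layers = nn.ModuleList([DenseGATLayer(p, p) for _ in range(K)])
        self.score = nn.Linear(p * (K + 1), 1)           # per-token importance score

    def forward(self, h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        adj = cosine_graph(h, mask, self.tau)            # step 1: token graph
        x = self.inp(h)
        outs = [x]
        for layer in self.layers:                        # step 2: K GNN layers with ReLU
            x = F.relu(layer(x, adj))
            outs.append(x)
        hjk = torch.cat(outs, dim=-1)                    # Jumping Knowledge: [B, L, p*(K+1)]
        # step 3: attention readout over real tokens only
        s = self.score(hjk).squeeze(-1).masked_fill(mask == 0, float("-inf"))
        w = torch.softmax(s, dim=-1)                     # [B, L]
        return (w.unsqueeze(-1) * hjk).sum(dim=1)        # [B, p*(K+1)]
```

Padding tokens receive only a self-loop and a `-inf` readout score, so they never influence the real tokens' representations or the pooled vector. A dense `[B, L, L]` adjacency keeps the sketch short but costs memory quadratic in sequence length; a sparse edge list, as PyG uses, scales better for long proteins.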
## Dependencies

```bash
pip install torch torch-geometric transformers

# For GLOT-Residue with PDB files:
pip install graphein biopython
```

## File Structure

```
protein_aggregator/
├── __init__.py          # Package exports
├── aggregators.py       # All 6 aggregation method implementations
└── model.py             # ProteinSequenceClassifier (ESM2 + aggregator + head)
example_localization.py  # Usage example for subcellular localization
```

## References

- **GLOT**: Mantri et al., "Towards Improved Sentence Representations using Token Graphs", arXiv:2603.03389 (2025). [Paper](https://arxiv.org/abs/2603.03389) | [Code](https://github.com/ipsitmantri/GLOT)
- **ESM2**: Lin et al., "Evolutionary-scale prediction of atomic-level protein structure with a language model", Science 2023. [Models](https://huggingface.co/facebook/esm2_t12_35M_UR50D)
- **Covariance Pooling**: [Goodfire Research](https://www.goodfire.ai/research/covariance-pooling)
- **Graphein**: Jamasb et al., "Graphein - a Python Library for Geometric Deep Learning and Network Analysis on Biomolecular Structures and Interaction Networks", NeurIPS 2022. [Docs](https://graphein.ai/)