| # Protein Sequence-Level Prediction with Multiple Token Aggregation Methods |
|
|
| Extract residue embeddings from a **frozen ESM2** backbone and aggregate them into sequence-level representations using **6 different strategies** for downstream tasks like subcellular localization prediction. |
|
|
| ## Aggregation Methods |
|
|
| | # | Method | Class | Output Dim | Description | |
| |---|--------|-------|-----------|-------------| |
| | 1 | **Mean** | `MeanPooling` | `d` | Average over non-padded residue embeddings | |
| | 2 | **Max** | `MaxPooling` | `d` | Element-wise max over non-padded residue embeddings | |
| | 3 | **CLS** | `CLSPooling` | `d` | ESM2's `<cls>` token representation (position 0) | |
| 4 | **GLOT** | `GLOTPooling` | `p*(K+1)` | Cosine-similarity token graph → GAT GNN → attention readout ([arXiv:2603.03389](https://arxiv.org/abs/2603.03389)) |
| 5 | **GLOT-Residue** | `GLOTResidueGraphPooling` | `p*(K+1)` | Protein 3D residue contact graph (via [graphein](https://graphein.ai/)) → GAT GNN → attention readout |
| | 6 | **Covariance** | `CovariancePooling` | `d_proj*(d_proj+1)/2` | Second-order covariance pooling with power normalization ([ref](https://www.goodfire.ai/research/covariance-pooling)) | |
|
|
| Where `d` = ESM2 hidden dimension (e.g. 480 for 35M model), `p` = GNN hidden dim (default 128), `K` = GNN layers (default 2). |
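
For reference, here is a minimal sketch of the masked first-order poolings (methods 1 and 2). The function bodies are illustrative, not the package's exact implementation:

```python
import torch

def mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """h: [B, L, d] residue embeddings; mask: [B, L], 1 for real residues."""
    m = mask.unsqueeze(-1).to(h.dtype)      # [B, L, 1]
    summed = (h * m).sum(dim=1)             # zero out padding, sum over L
    counts = m.sum(dim=1).clamp(min=1.0)    # residues per sequence
    return summed / counts                  # [B, d]

def max_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Element-wise max over non-padded positions."""
    neg_inf = torch.finfo(h.dtype).min
    h = h.masked_fill(mask.unsqueeze(-1) == 0, neg_inf)  # exclude padding
    return h.max(dim=1).values                           # [B, d]
```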
|
|
| ## Supported ESM2 Backbones |
|
|
The backbone is interchangeable; just pass a different model name:
|
|
| | Model | Params | Hidden Dim | |
| |-------|--------|-----------| |
| | `facebook/esm2_t6_8M_UR50D` | 8M | 320 | |
| | `facebook/esm2_t12_35M_UR50D` | 35M | 480 (default) | |
| | `facebook/esm2_t30_150M_UR50D` | 150M | 640 | |
| | `facebook/esm2_t33_650M_UR50D` | 650M | 1280 | |
| | `facebook/esm2_t36_3B_UR50D` | 3B | 2560 | |
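
To check a backbone's hidden dimension before wiring it up, you can read it off the Hugging Face config (standard `transformers` API):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("facebook/esm2_t33_650M_UR50D")
print(cfg.hidden_size)  # 1280
```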
|
|
| ## Quick Start |
|
|
| ```python |
import torch

from protein_aggregator import ProteinSequenceClassifier
| |
| # Build model: frozen ESM2 + GLOT aggregation + 10-class head |
| model = ProteinSequenceClassifier( |
| esm2_model_name="facebook/esm2_t12_35M_UR50D", |
| aggregation="glot", # "mean", "max", "cls", "glot", "glot_residue", "covariance" |
| num_classes=10, |
| aggregator_kwargs={"p": 128, "K": 2, "tau": 0.6}, |
| classifier_hidden=256, |
| dropout=0.1, |
| ).cuda() |
| |
| # Get sequence-level embeddings |
| embeddings = model.encode(["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"]) |
| print(embeddings.shape) # [2, 384] (p*(K+1) = 128*3) |
| |
| # Or full forward pass with loss |
| inputs = model.tokenizer( |
| ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"], |
| padding=True, truncation=True, return_tensors="pt" |
| ).to("cuda") |
| |
| outputs = model( |
| input_ids=inputs["input_ids"], |
| attention_mask=inputs["attention_mask"], |
| labels=torch.tensor([0, 3]).cuda(), |
| ) |
| loss = outputs["loss"] |
| logits = outputs["logits"] |
| ``` |
|
|
| ## Using GLOT-Residue with PDB Files |
|
|
When 3D structure is available, `glot_residue` builds the token graph from the protein's Cα-Cα contact map (8 Å threshold) using [graphein](https://graphein.ai/):
|
|
| ```python |
| model = ProteinSequenceClassifier( |
| aggregation="glot_residue", |
| num_classes=10, |
| aggregator_kwargs={ |
| "contact_threshold": 8.0, # CΞ±-CΞ± distance in Γ
|
| "seq_neighbor_k": 5, # fallback: Β±k sequence neighbors if no PDB |
| }, |
| ) |
| |
| # With PDB files |
| outputs = model(input_ids=ids, attention_mask=mask, pdb_paths=["1abc.pdb", "2def.pdb"]) |
| |
| # Without PDB files (falls back to sequence-distance graph) |
| outputs = model(input_ids=ids, attention_mask=mask) |
| ``` |
|
|
| ## Using Covariance Pooling |
|
|
| Captures second-order statistics (feature co-activations) across residue positions: |
|
|
| ```python |
| model = ProteinSequenceClassifier( |
| aggregation="covariance", |
| num_classes=10, |
| aggregator_kwargs={"d_proj": 64}, # output dim = 64*65/2 = 2080 |
| ) |
| ``` |
|
|
| The `d_proj` parameter controls the output dimensionality: |
- `d_proj=32` → 528 dims
- `d_proj=64` → 2080 dims (default)
- `d_proj=128` → 8256 dims
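
Each value is `d_proj*(d_proj+1)/2`, the size of the upper triangle (diagonal included) of a `d_proj × d_proj` matrix. Below is a minimal sketch of the pooling itself, with a plain signed square root standing in for the power normalization; this is our reading of the description above, not the package's code:

```python
import torch
import torch.nn as nn

class CovariancePoolingSketch(nn.Module):
    """Illustrative second-order pooling: [B, L, d] -> [B, d_proj*(d_proj+1)/2]."""

    def __init__(self, d: int, d_proj: int = 64):
        super().__init__()
        self.proj = nn.Linear(d, d_proj)
        # indices of the upper triangle (diagonal included)
        self.register_buffer("tri", torch.triu_indices(d_proj, d_proj))

    def forward(self, h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        z = self.proj(h)                                        # [B, L, d_proj]
        m = mask.unsqueeze(-1).to(z.dtype)                      # [B, L, 1]
        n = m.sum(dim=1).clamp(min=1.0)                         # [B, 1] residue counts
        mu = (z * m).sum(dim=1, keepdim=True) / n.unsqueeze(1)  # masked mean
        zc = (z - mu) * m                                       # center, re-zero padding
        cov = zc.transpose(1, 2) @ zc / n.unsqueeze(-1)         # [B, d_proj, d_proj]
        cov = cov.sign() * cov.abs().sqrt()                     # power normalization
        return cov[:, self.tri[0], self.tri[1]]                 # flatten upper triangle
```

With `d_proj=64` this returns a `[B, 2080]` tensor, matching the table in the first section.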
|
|
| ## Architecture Overview |
|
|
| ``` |
Protein Sequence
        │
┌──────────────────┐
│  ESM2 (frozen)   │  Extracts per-residue embeddings [B, L, d]
└──────────────────┘
        │
┌──────────────────┐
│   Aggregator     │  Compresses token-level → sequence-level [B, agg_dim]
│   (one of 6)     │
└──────────────────┘
        │
┌──────────────────┐
│ Classifier Head  │  Linear (+ optional hidden layer) → [B, num_classes]
└──────────────────┘
| ``` |
|
|
| Only the aggregator and classifier are trained. ESM2 is always frozen. |
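
A minimal training-step sketch under this scheme, reusing `model` and `inputs` from the Quick Start. Because the backbone's parameters have `requires_grad=False`, filtering on that flag leaves exactly the aggregator and classifier head:

```python
import torch

labels = torch.tensor([0, 3]).cuda()  # example labels from the Quick Start

# Only the aggregator and classifier head remain trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=2e-4)  # paper defaults: Adam, no weight decay

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=labels,
)
outputs["loss"].backward()  # gradients flow only into trainable parameters
optimizer.step()
optimizer.zero_grad()
```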
|
|
| ## GLOT Details |
|
|
| The GLOT aggregation (methods 4 & 5) follows [arXiv:2603.03389](https://arxiv.org/abs/2603.03389): |
|
|
1. **Token Graph Construction.** For standard GLOT: pairwise cosine similarity between residue embeddings → threshold at τ (default 0.6) → binary adjacency. For GLOT-Residue: Cα-Cα distance contact map from the 3D structure.
2. **Token-GNN.** K layers of GATConv (Graph Attention Network) with ReLU, followed by Jumping Knowledge concatenation of all layer outputs.
3. **Attention Readout.** Learned per-token importance scores → softmax → weighted sum to produce the sequence vector.
|
|
Default hyperparameters (from the paper): `p=128`, `K=2`, `tau=0.6`, `n_heads=4`; training uses the Adam optimizer with `lr=2e-4` and no weight decay.
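
As an illustration of step 1 for standard GLOT, here is a minimal sketch of the cosine-similarity graph for a single sequence. The edge-index layout matches what `torch_geometric`'s `GATConv` expects; the function name is ours, not the package's:

```python
import torch
import torch.nn.functional as F

def cosine_token_graph(h: torch.Tensor, tau: float = 0.6) -> torch.Tensor:
    """h: [L, d] residue embeddings for one sequence -> [2, E] edge index."""
    z = F.normalize(h, dim=-1)               # unit-normalize rows
    sim = z @ z.t()                          # [L, L] pairwise cosine similarity
    adj = sim >= tau                         # threshold at tau -> binary adjacency
    adj.fill_diagonal_(False)                # no self-loops
    return adj.nonzero(as_tuple=False).t()   # [2, E], ready for GATConv
```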
|
|
| ## Dependencies |
|
|
| ```bash |
| pip install torch torch-geometric transformers |
| # For GLOT-Residue with PDB files: |
| pip install graphein biopython |
| ``` |
|
|
| ## File Structure |
|
|
| ``` |
| protein_aggregator/ |
├── __init__.py           # Package exports
├── aggregators.py        # All 6 aggregation method implementations
└── model.py              # ProteinSequenceClassifier (ESM2 + aggregator + head)
| example_localization.py # Usage example for subcellular localization |
| ``` |
|
|
| ## References |
|
|
| - **GLOT**: Mantri et al., "Towards Improved Sentence Representations using Token Graphs", arXiv:2603.03389 (2025). [Paper](https://arxiv.org/abs/2603.03389) | [Code](https://github.com/ipsitmantri/GLOT) |
| - **ESM2**: Lin et al., "Evolutionary-scale prediction of atomic-level protein structure with a language model", Science 2023. [Models](https://huggingface.co/facebook/esm2_t12_35M_UR50D) |
| - **Covariance Pooling**: [Goodfire Research](https://www.goodfire.ai/research/covariance-pooling) |
| - **Graphein**: Jamasb et al., "Graphein - a Python Library for Geometric Deep Learning and Network Analysis on Biomolecular Structures and Interaction Networks", NeurIPS 2022. [Docs](https://graphein.ai/) |
|
|