# Protein Sequence-Level Prediction with Multiple Token Aggregation Methods
Extract residue embeddings from a **frozen ESM2** backbone and aggregate them into sequence-level representations using **6 different strategies** for downstream tasks like subcellular localization prediction.
## Aggregation Methods
| # | Method | Class | Output Dim | Description |
|---|--------|-------|-----------|-------------|
| 1 | **Mean** | `MeanPooling` | `d` | Average over non-padded residue embeddings |
| 2 | **Max** | `MaxPooling` | `d` | Element-wise max over non-padded residue embeddings |
| 3 | **CLS** | `CLSPooling` | `d` | ESM2's `<cls>` token representation (position 0) |
| 4 | **GLOT** | `GLOTPooling` | `p*(K+1)` | Cosine-similarity token graph → GAT GNN → attention readout ([arXiv:2603.03389](https://arxiv.org/abs/2603.03389)) |
| 5 | **GLOT-Residue** | `GLOTResidueGraphPooling` | `p*(K+1)` | Protein 3D residue contact graph (via [graphein](https://graphein.ai/)) → GAT GNN → attention readout |
| 6 | **Covariance** | `CovariancePooling` | `d_proj*(d_proj+1)/2` | Second-order covariance pooling with power normalization ([ref](https://www.goodfire.ai/research/covariance-pooling)) |
Where `d` = ESM2 hidden dimension (e.g. 480 for the 35M model), `p` = GNN hidden dim (default 128), `K` = number of GNN layers (default 2), and `d_proj` = covariance projection dim (default 64).
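The first two methods can be sketched in a few lines of torch. This is an illustrative version with explicit attention-mask handling; the package's `MeanPooling`/`MaxPooling` classes may differ in detail:

```python
import torch

def mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average residue embeddings, ignoring padded positions.

    h:    [B, L, d] per-residue embeddings
    mask: [B, L]    1 for real residues, 0 for padding
    """
    m = mask.unsqueeze(-1).float()
    return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)

def max_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Element-wise max over non-padded residues (padding masked to -inf)."""
    h = h.masked_fill(mask.unsqueeze(-1) == 0, float("-inf"))
    return h.max(dim=1).values

h = torch.randn(2, 12, 480)
mask = torch.ones(2, 12)
mask[1, 8:] = 0  # second sequence is shorter
print(mean_pool(h, mask).shape, max_pool(h, mask).shape)  # both [2, 480]
```

Masking matters: naively averaging over the full `[B, L, d]` tensor would let padding tokens dilute short sequences.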
## Supported ESM2 Backbones
The backbone is interchangeable: just pass a different model name.
| Model | Params | Hidden Dim |
|-------|--------|-----------|
| `facebook/esm2_t6_8M_UR50D` | 8M | 320 |
| `facebook/esm2_t12_35M_UR50D` | 35M | 480 (default) |
| `facebook/esm2_t30_150M_UR50D` | 150M | 640 |
| `facebook/esm2_t33_650M_UR50D` | 650M | 1280 |
| `facebook/esm2_t36_3B_UR50D` | 3B | 2560 |
## Quick Start
```python
import torch

from protein_aggregator import ProteinSequenceClassifier

# Build model: frozen ESM2 + GLOT aggregation + 10-class head
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",
    aggregation="glot",  # "mean", "max", "cls", "glot", "glot_residue", "covariance"
    num_classes=10,
    aggregator_kwargs={"p": 128, "K": 2, "tau": 0.6},
    classifier_hidden=256,
    dropout=0.1,
).cuda()

# Get sequence-level embeddings
embeddings = model.encode(["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"])
print(embeddings.shape)  # [2, 384] (p*(K+1) = 128*3)

# Or full forward pass with loss
inputs = model.tokenizer(
    ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"],
    padding=True, truncation=True, return_tensors="pt",
).to("cuda")
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
loss = outputs["loss"]
logits = outputs["logits"]
```
## Using GLOT-Residue with PDB Files
When 3D structure is available, `glot_residue` builds the token graph from the protein's Cα-Cα contact map (8 Å threshold) using [graphein](https://graphein.ai/):
```python
model = ProteinSequenceClassifier(
    aggregation="glot_residue",
    num_classes=10,
    aggregator_kwargs={
        "contact_threshold": 8.0,  # Cα-Cα distance in Å
        "seq_neighbor_k": 5,       # fallback: ±k sequence neighbors if no PDB
    },
)
```
# With PDB files
outputs = model(input_ids=ids, attention_mask=mask, pdb_paths=["1abc.pdb", "2def.pdb"])
# Without PDB files (falls back to sequence-distance graph)
outputs = model(input_ids=ids, attention_mask=mask)
```
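The `seq_neighbor_k` fallback is easy to picture: each residue connects to its ±k neighbors along the chain. A minimal, illustrative sketch (not the package's implementation):

```python
import torch

def seq_neighbor_adjacency(length: int, k: int = 5) -> torch.Tensor:
    """Binary adjacency connecting each residue to its +/-k sequence neighbors.

    Illustrative stand-in for the fallback graph used when no PDB is given.
    """
    idx = torch.arange(length)
    dist = (idx[:, None] - idx[None, :]).abs()  # |i - j| for all residue pairs
    return ((dist <= k) & (dist > 0)).float()   # connect within k, no self-loops

adj = seq_neighbor_adjacency(8, k=2)
print(adj.sum(dim=1))  # interior residues have degree 4 (2 on each side)
```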
## Using Covariance Pooling
Captures second-order statistics (feature co-activations) across residue positions:
```python
model = ProteinSequenceClassifier(
    aggregation="covariance",
    num_classes=10,
    aggregator_kwargs={"d_proj": 64},  # output dim = 64*65/2 = 2080
)
```
```
The `d_proj` parameter controls the output dimensionality:
- `d_proj=32` β 528 dims
- `d_proj=64` β 2080 dims (default)
- `d_proj=128` β 8256 dims
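The dimensionality comes from flattening the upper triangle of a `d_proj × d_proj` covariance matrix. A minimal sketch of masked covariance pooling with signed-square-root power normalization, assuming a learned projection matrix (names here are illustrative, not the package's `CovariancePooling` API):

```python
import torch

def covariance_pool(h: torch.Tensor, mask: torch.Tensor, proj: torch.Tensor) -> torch.Tensor:
    """Second-order pooling: project, compute masked covariance,
    apply signed-sqrt power normalization, flatten the upper triangle.

    h:    [B, L, d]      residue embeddings
    mask: [B, L]         1 for real residues, 0 for padding
    proj: [d, d_proj]    projection matrix (learned in the real module)
    """
    x = h @ proj                                          # [B, L, d_proj]
    m = mask.unsqueeze(-1).float()
    n = m.sum(dim=1).clamp(min=1.0)                       # residues per sequence [B, 1]
    mu = (x * m).sum(dim=1, keepdim=True) / n.unsqueeze(1)
    xc = (x - mu) * m                                     # center, re-zero padding
    cov = xc.transpose(1, 2) @ xc / n.unsqueeze(-1)       # [B, d_proj, d_proj]
    cov = torch.sign(cov) * torch.sqrt(cov.abs() + 1e-8)  # power normalization
    iu = torch.triu_indices(cov.size(1), cov.size(2))
    return cov[:, iu[0], iu[1]]                           # [B, d_proj*(d_proj+1)/2]

h = torch.randn(2, 10, 480)
out = covariance_pool(h, torch.ones(2, 10), torch.randn(480, 32))
print(out.shape)  # torch.Size([2, 528]) = 32*33/2
```

Only the upper triangle is kept because the covariance matrix is symmetric; the lower triangle carries no extra information.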
## Architecture Overview
```
Protein Sequence
        │
┌──────────────────┐
│   ESM2 (frozen)  │   Extracts per-residue embeddings [B, L, d]
└──────────────────┘
        │
┌──────────────────┐
│    Aggregator    │   Compresses token-level → sequence-level [B, agg_dim]
│    (one of 6)    │
└──────────────────┘
        │
┌──────────────────┐
│ Classifier Head  │   Linear (+ optional hidden layer) → [B, num_classes]
└──────────────────┘
```
Only the aggregator and classifier are trained. ESM2 is always frozen.
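The freezing pattern can be sketched with toy stand-in modules (a `Linear` here plays the role of the ESM2 encoder): frozen parameters receive no gradients and are excluded from the optimizer.

```python
import torch
import torch.nn as nn

# Toy stand-in modules; in the real model the backbone is ESM2.
backbone = nn.Linear(480, 480)   # pretend this is the frozen ESM2 encoder
aggregator = nn.Linear(480, 384)
head = nn.Linear(384, 10)

# Freeze the backbone: no gradients flow into its parameters
backbone.requires_grad_(False)

# Only aggregator + head parameters go to the optimizer
trainable = [p for m in (aggregator, head) for p in m.parameters()]
optimizer = torch.optim.Adam(trainable, lr=2e-4)

x = torch.randn(2, 480)
loss = head(aggregator(backbone(x))).sum()
loss.backward()

assert backbone.weight.grad is None        # frozen: untouched by backprop
assert aggregator.weight.grad is not None  # trained
```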
## GLOT Details
The GLOT aggregation (methods 4 & 5) follows [arXiv:2603.03389](https://arxiv.org/abs/2603.03389):
1. **Token Graph Construction**: For standard GLOT, pairwise cosine similarity between residue embeddings → threshold at τ (default 0.6) → binary adjacency. For GLOT-Residue, a Cα-Cα distance contact map from the 3D structure.
2. **Token-GNN**: K layers of GATConv (Graph Attention Network) with ReLU, followed by Jumping Knowledge concatenation of all layer outputs.
3. **Attention Readout**: Learned per-token importance scores → softmax → weighted sum to produce the sequence vector.
Default hyperparameters (from the paper): `p=128, K=2, tau=0.6, n_heads=4, lr=2e-4, no weight decay, Adam optimizer`.
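Steps 1 and 3 can be sketched in plain torch; the GAT layers and Jumping Knowledge of step 2 are omitted here, and these helpers are illustrative rather than the package's `GLOTPooling` internals:

```python
import torch
import torch.nn.functional as F

def cosine_token_graph(h: torch.Tensor, tau: float = 0.6) -> torch.Tensor:
    """Step 1: threshold pairwise cosine similarity into a binary adjacency."""
    hn = F.normalize(h, dim=-1)   # unit-norm residue embeddings [L, d]
    sim = hn @ hn.T               # pairwise cosine similarity  [L, L]
    adj = (sim >= tau).float()
    adj.fill_diagonal_(0)         # no self-loops
    return adj

def attention_readout(z: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Step 3: learned per-token scores -> softmax -> weighted sum."""
    scores = torch.softmax(z @ w, dim=0)           # per-token weights [L]
    return (scores.unsqueeze(-1) * z).sum(dim=0)   # sequence vector [d]

h = torch.randn(16, 480)  # 16 residues
adj = cosine_token_graph(h, tau=0.6)
pooled = attention_readout(h, torch.randn(480))
print(adj.shape, pooled.shape)  # [16, 16] and [480]
```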
## Dependencies
```bash
pip install torch torch-geometric transformers
# For GLOT-Residue with PDB files:
pip install graphein biopython
```
## File Structure
```
protein_aggregator/
├── __init__.py          # Package exports
├── aggregators.py       # All 6 aggregation method implementations
└── model.py             # ProteinSequenceClassifier (ESM2 + aggregator + head)
example_localization.py  # Usage example for subcellular localization
```
## References
- **GLOT**: Mantri et al., "Towards Improved Sentence Representations using Token Graphs", arXiv:2603.03389 (2025). [Paper](https://arxiv.org/abs/2603.03389) | [Code](https://github.com/ipsitmantri/GLOT)
- **ESM2**: Lin et al., "Evolutionary-scale prediction of atomic-level protein structure with a language model", Science 2023. [Models](https://huggingface.co/facebook/esm2_t12_35M_UR50D)
- **Covariance Pooling**: [Goodfire Research](https://www.goodfire.ai/research/covariance-pooling)
- **Graphein**: Jamasb et al., "Graphein - a Python Library for Geometric Deep Learning and Network Analysis on Biomolecular Structures and Interaction Networks", NeurIPS 2022. [Docs](https://graphein.ai/)