# Protein Sequence-Level Prediction with Multiple Token Aggregation Methods
Extract residue embeddings from a **frozen ESM2** backbone and aggregate them into sequence-level representations using **6 different strategies** for downstream tasks like subcellular localization prediction.
## Aggregation Methods
| # | Method | Class | Output Dim | Description |
|---|--------|-------|-----------|-------------|
| 1 | **Mean** | `MeanPooling` | `d` | Average over non-padded residue embeddings |
| 2 | **Max** | `MaxPooling` | `d` | Element-wise max over non-padded residue embeddings |
| 3 | **CLS** | `CLSPooling` | `d` | ESM2's `<cls>` token representation (position 0) |
| 4 | **GLOT** | `GLOTPooling` | `p*(K+1)` | Cosine-similarity token graph → GAT GNN → attention readout ([arXiv:2603.03389](https://arxiv.org/abs/2603.03389)) |
| 5 | **GLOT-Residue** | `GLOTResidueGraphPooling` | `p*(K+1)` | Protein 3D residue contact graph (via [graphein](https://graphein.ai/)) → GAT GNN → attention readout |
| 6 | **Covariance** | `CovariancePooling` | `d_proj*(d_proj+1)/2` | Second-order covariance pooling with power normalization ([ref](https://www.goodfire.ai/research/covariance-pooling)) |
Here `d` is the ESM2 hidden dimension (e.g. 480 for the 35M model), `p` is the GNN hidden dim (default 128), `K` is the number of GNN layers (default 2), and `d_proj` is the covariance projection dim (default 64).
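The three parameter-free methods (Mean, Max, CLS) reduce to a few masked array operations. The NumPy sketch below is illustrative, not the package's `MeanPooling`/`MaxPooling`/`CLSPooling` code. It assumes hidden states of shape `[B, L, d]` and an attention mask of shape `[B, L]` (1 for real tokens, 0 for padding). This README does not say whether ESM2's `<cls>`/`<eos>` tokens are excluded from Mean and Max, so the sketch simply trusts whatever the mask marks.

```python
import numpy as np


def mean_pool(h: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average over non-padded positions. h: [B, L, d], mask: [B, L] of 0/1."""
    m = mask[..., None].astype(h.dtype)            # [B, L, 1]
    counts = np.clip(m.sum(axis=1), 1e-9, None)    # [B, 1], guards against empty rows
    return (h * m).sum(axis=1) / counts            # [B, d]


def max_pool(h: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Element-wise max over non-padded positions."""
    keep = mask[..., None].astype(bool)            # [B, L, 1]
    return np.where(keep, h, -np.inf).max(axis=1)  # padding can never win the max


def cls_pool(h: np.ndarray) -> np.ndarray:
    """ESM2 prepends <cls>, so its representation sits at position 0."""
    return h[:, 0, :]


# Toy batch: 2 sequences, 3 positions, d = 2; the second sequence has one padded slot.
h = np.array([[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]],
              [[6.0, 7.0], [8.0, 9.0], [10.0, 11.0]]])
mask = np.array([[1, 1, 1], [1, 1, 0]])
print(mean_pool(h, mask).tolist())  # [[2.0, 3.0], [7.0, 8.0]]
```

Setting padded entries to `-inf` (rather than 0) matters for Max: ESM2 embeddings can be negative, so a zero fill could otherwise leak padding into the result.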
## Supported ESM2 Backbones
The backbone is swappable: pass a different model name via `esm2_model_name`.
| Model | Params | Hidden Dim |
|-------|--------|-----------|
| `facebook/esm2_t6_8M_UR50D` | 8M | 320 |
| `facebook/esm2_t12_35M_UR50D` | 35M | 480 (default) |
| `facebook/esm2_t30_150M_UR50D` | 150M | 640 |
| `facebook/esm2_t33_650M_UR50D` | 650M | 1280 |
| `facebook/esm2_t36_3B_UR50D` | 3B | 2560 |
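Swapping the backbone changes `d`, and therefore the output width of Mean, Max and CLS, while the GLOT and Covariance widths depend only on their own hyperparameters. The helper below is hypothetical (it is not part of `protein_aggregator`); it just encodes the formulas from the aggregation table, with defaults mirroring the values quoted in this README.

```python
# Hypothetical helper: not part of protein_aggregator.
ESM2_HIDDEN_DIM = {
    "facebook/esm2_t6_8M_UR50D": 320,
    "facebook/esm2_t12_35M_UR50D": 480,
    "facebook/esm2_t30_150M_UR50D": 640,
    "facebook/esm2_t33_650M_UR50D": 1280,
    "facebook/esm2_t36_3B_UR50D": 2560,
}


def aggregator_output_dim(method: str, d: int, p: int = 128, K: int = 2, d_proj: int = 64) -> int:
    """Sequence-level feature width produced by each aggregation method."""
    if method in ("mean", "max", "cls"):
        return d                            # backbone width passes straight through
    if method in ("glot", "glot_residue"):
        return p * (K + 1)                  # Jumping Knowledge concat of K+1 blocks
    if method == "covariance":
        return d_proj * (d_proj + 1) // 2   # upper triangle of a d_proj x d_proj matrix
    raise ValueError(f"unknown aggregation: {method!r}")


d = ESM2_HIDDEN_DIM["facebook/esm2_t33_650M_UR50D"]
for method in ("mean", "glot", "covariance"):
    print(method, aggregator_output_dim(method, d))
# mean 1280
# glot 384
# covariance 2080
```

This is also the `agg_dim` the classifier head has to accept, so it is a quick way to size the head before switching backbones.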
## Quick Start
```python
import torch
from protein_aggregator import ProteinSequenceClassifier

# Build model: frozen ESM2 + GLOT aggregation + 10-class head
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",
    aggregation="glot",  # "mean", "max", "cls", "glot", "glot_residue", "covariance"
    num_classes=10,
    aggregator_kwargs={"p": 128, "K": 2, "tau": 0.6},
    classifier_hidden=256,
    dropout=0.1,
).cuda()

# Get sequence-level embeddings
embeddings = model.encode(["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"])
print(embeddings.shape)  # [2, 384] (p*(K+1) = 128*3)

# Or full forward pass with loss
inputs = model.tokenizer(
    ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"],
    padding=True, truncation=True, return_tensors="pt",
).to("cuda")
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
loss = outputs["loss"]
logits = outputs["logits"]
```
## Using GLOT-Residue with PDB Files
When 3D structure is available, `glot_residue` builds the token graph from the protein's Cα-Cα contact map (8 Å threshold) using [graphein](https://graphein.ai/):
```python
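from protein_aggregator import ProteinSequenceClassifier

model = ProteinSequenceClassifier(
    aggregation="glot_residue",
    num_classes=10,
    aggregator_kwargs={
        "contact_threshold": 8.0,  # Cα-Cα distance in Å
        "seq_neighbor_k": 5,       # fallback: ±k sequence neighbors if no PDB
    },
)
inputs = model.tokenizer(
    ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"],
    padding=True, return_tensors="pt",
)
ids, mask = inputs["input_ids"], inputs["attention_mask"]

# With PDB files (one path per sequence in the batch)
outputs = model(input_ids=ids, attention_mask=mask, pdb_paths=["1abc.pdb", "2def.pdb"])
# Without PDB files (falls back to sequence-distance graph)
outputs = model(input_ids=ids, attention_mask=mask)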
```
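Underneath, the structural graph is a thresholded Cα-Cα distance matrix. The NumPy sketch below illustrates that idea and the ±k sequence-neighbor fallback; it is not the package's graphein pipeline. It assumes the Cα coordinates have already been extracted as an `[L, 3]` array in Å, and it treats "within the threshold" as `<=`. Both function names are hypothetical.

```python
import numpy as np


def contact_adjacency(ca_coords: np.ndarray, threshold: float = 8.0) -> np.ndarray:
    """Binary [L, L] adjacency: residues i and j are linked if their C-alpha atoms
    are at most `threshold` angstroms apart. ca_coords: [L, 3]."""
    diff = ca_coords[:, None, :] - ca_coords[None, :, :]   # [L, L, 3]
    dist = np.linalg.norm(diff, axis=-1)                   # [L, L] pairwise distances
    adj = (dist <= threshold).astype(np.int8)
    np.fill_diagonal(adj, 0)                               # no self-contacts
    return adj


def sequence_neighbor_adjacency(length: int, k: int = 5) -> np.ndarray:
    """Fallback graph without a structure: link each residue to its +-k sequence neighbors."""
    idx = np.arange(length)
    sep = np.abs(idx[:, None] - idx[None, :])              # |i - j|
    return ((sep > 0) & (sep <= k)).astype(np.int8)


# Toy chain: 4 residues spaced 3.8 A apart along x (a typical consecutive C-alpha spacing).
ca = np.array([[0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [7.6, 0.0, 0.0], [11.4, 0.0, 0.0]])
print(contact_adjacency(ca).sum(axis=1).tolist())                # [2, 3, 3, 2]
print(sequence_neighbor_adjacency(4, k=1).sum(axis=1).tolist())  # [1, 2, 2, 1]
```

The contact graph links residues that are close in space even when they are far apart in sequence, which is exactly the information the ±k fallback cannot capture.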
## Using Covariance Pooling
Captures second-order statistics (feature co-activations) across residue positions:
```python
model = ProteinSequenceClassifier(
    aggregation="covariance",
    num_classes=10,
    aggregator_kwargs={"d_proj": 64},  # output dim = 64*65/2 = 2080
)
```
The `d_proj` parameter controls the output dimensionality:
- `d_proj=32` → 528 dims
- `d_proj=64` → 2080 dims (default)
- `d_proj=128` → 8256 dims
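A minimal NumPy sketch of second-order pooling for one sequence. It rests on assumptions not confirmed by this README: a linear projection from `d` to `d_proj`, centering over valid residues, signed square-root power normalization, and flattening the upper triangle including the diagonal. That last step is what yields `d_proj*(d_proj+1)/2` features. The random matrix `W` stands in for a learned projection, and `covariance_pool` is a hypothetical name.

```python
import numpy as np


def covariance_pool(h: np.ndarray, mask: np.ndarray, W: np.ndarray) -> np.ndarray:
    """h: [L, d] residue embeddings, mask: [L] of 0/1, W: [d, d_proj] projection.
    Returns a vector of length d_proj*(d_proj+1)/2."""
    x = (h @ W)[mask.astype(bool)]              # [n_valid, d_proj], padding dropped
    x = x - x.mean(axis=0, keepdims=True)       # center each projected feature
    cov = x.T @ x / max(len(x) - 1, 1)          # [d_proj, d_proj] sample covariance
    cov = np.sign(cov) * np.sqrt(np.abs(cov))   # power normalization (signed sqrt)
    rows, cols = np.triu_indices(cov.shape[0])  # symmetric: keep upper triangle + diagonal
    return cov[rows, cols]


rng = np.random.default_rng(0)
L, d, d_proj = 20, 480, 64
h = rng.normal(size=(L, d))
mask = np.ones(L)
W = rng.normal(size=(d, d_proj)) / np.sqrt(d)   # stand-in for a learned projection
print(covariance_pool(h, mask, W).shape)        # (2080,)
```

Projecting before taking the covariance is what keeps the output manageable: a raw covariance of 480-dim embeddings would already need 480*481/2 = 115,440 features.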
## Architecture Overview
```
Protein Sequence
         ↓
┌───────────────────┐
│   ESM2 (frozen)   │  Extracts per-residue embeddings [B, L, d]
└───────────────────┘
         ↓
┌───────────────────┐
│    Aggregator     │  Compresses token-level → sequence-level [B, agg_dim]
│    (one of 6)     │
└───────────────────┘
         ↓
┌───────────────────┐
│  Classifier Head  │  Linear (+ optional hidden layer) → [B, num_classes]
└───────────────────┘
```
Only the aggregator and classifier are trained. ESM2 is always frozen.
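The classifier head is described only as "Linear (+ optional hidden layer)". The NumPy forward pass below assumes `classifier_hidden` adds a single ReLU hidden layer. Dropout is omitted because it is inactive at inference. The sketch shows how shapes flow from `[B, agg_dim]` to `[B, num_classes]`; `init_head` and `head_forward` are hypothetical names, and the random weights stand in for trained ones.

```python
from typing import List, Optional, Tuple

import numpy as np

rng = np.random.default_rng(0)


def init_head(agg_dim: int, num_classes: int,
              hidden: Optional[int] = None) -> List[Tuple[np.ndarray, np.ndarray]]:
    """One Linear layer, or Linear -> ReLU -> Linear when a hidden width is given."""
    dims = [agg_dim, hidden, num_classes] if hidden else [agg_dim, num_classes]
    return [
        (rng.normal(scale=dims[i] ** -0.5, size=(dims[i], dims[i + 1])), np.zeros(dims[i + 1]))
        for i in range(len(dims) - 1)
    ]


def head_forward(z: np.ndarray, layers) -> np.ndarray:
    """z: [B, agg_dim] pooled embeddings -> [B, num_classes] logits."""
    for i, (W, b) in enumerate(layers):
        z = z @ W + b
        if i < len(layers) - 1:
            z = np.maximum(z, 0.0)  # ReLU on the hidden layer only; logits stay linear
    return z


z = rng.normal(size=(2, 384))  # e.g. a GLOT batch, p*(K+1) = 128*3
head = init_head(384, num_classes=10, hidden=256)
print(head_forward(z, head).shape)  # (2, 10)
```

Because ESM2 stays frozen, these head weights plus the aggregator's parameters are the only values the optimizer updates.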
## GLOT Details
The GLOT aggregation (methods 4 & 5) follows [arXiv:2603.03389](https://arxiv.org/abs/2603.03389):
1. **Token Graph Construction**: For standard GLOT, pairwise cosine similarity between residue embeddings → threshold at τ (default 0.6) → binary adjacency. For GLOT-Residue, a Cα-Cα distance contact map from the 3D structure.
2. **Token-GNN**: K layers of GATConv (Graph Attention Network) with ReLU, followed by Jumping Knowledge concatenation of all layer outputs.
3. **Attention Readout**: Learned per-token importance scores → softmax → weighted sum to produce the sequence vector.
Default hyperparameters (from the paper): `p=128, K=2, tau=0.6, n_heads=4, lr=2e-4, no weight decay, Adam optimizer`.
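The three GLOT steps can be sketched in NumPy for a single unpadded sequence. This is a simplification, not the paper's or the package's implementation. Plain mean aggregation over neighbors stands in for multi-head GATConv, and the weights are random rather than learned. The sketch also assumes the `K+1` Jumping Knowledge blocks are the input projection plus the `K` layer outputs, which matches the `p*(K+1)` width in the aggregation table. `cosine_graph` and `glot_readout` are hypothetical names.

```python
import numpy as np


def cosine_graph(x: np.ndarray, tau: float = 0.6) -> np.ndarray:
    """Step 1: binary [L, L] adjacency from pairwise cosine similarity >= tau. x: [L, d]."""
    unit = x / np.clip(np.linalg.norm(x, axis=1, keepdims=True), 1e-9, None)
    adj = (unit @ unit.T >= tau).astype(float)
    np.fill_diagonal(adj, 1.0)  # self-loops so every token keeps its own features
    return adj


def glot_readout(x, adj, W_in, W_layers, a):
    """Steps 2-3 for one unpadded sequence.
    x: [L, d], adj: [L, L], W_in: [d, p], W_layers: K arrays [p, p], a: [p*(K+1)]."""
    deg = adj.sum(axis=1, keepdims=True)          # >= 1 thanks to the self-loops
    h = x @ W_in                                  # project tokens to width p
    layer_outputs = [h]
    for W in W_layers:                            # stand-in for GATConv: neighbor mean,
        h = np.maximum(((adj @ h) / deg) @ W, 0.0)  # then linear map and ReLU
        layer_outputs.append(h)
    jk = np.concatenate(layer_outputs, axis=1)    # Jumping Knowledge: [L, p*(K+1)]
    scores = jk @ a                               # one importance score per token
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                      # softmax over tokens
    return weights @ jk                           # weighted sum -> [p*(K+1)]


rng = np.random.default_rng(0)
L, d, p, K = 12, 480, 128, 2
x = rng.normal(size=(L, d))
W_in = rng.normal(size=(d, p)) / np.sqrt(d)
W_layers = [rng.normal(size=(p, p)) / np.sqrt(p) for _ in range(K)]
a = rng.normal(size=p * (K + 1)) / np.sqrt(p * (K + 1))
print(glot_readout(x, cosine_graph(x), W_in, W_layers, a).shape)  # (384,)
```

For GLOT-Residue, only step 1 changes: the adjacency comes from the Cα contact map instead of `cosine_graph`, and the GNN and readout stay the same.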
## Dependencies
```bash
pip install torch torch-geometric transformers
# For GLOT-Residue with PDB files:
pip install graphein biopython
```
## File Structure
```
protein_aggregator/
├── __init__.py          # Package exports
├── aggregators.py       # All 6 aggregation method implementations
└── model.py             # ProteinSequenceClassifier (ESM2 + aggregator + head)
example_localization.py  # Usage example for subcellular localization
```
## References
- **GLOT**: Mantri et al., "Towards Improved Sentence Representations using Token Graphs", arXiv:2603.03389 (2026). [Paper](https://arxiv.org/abs/2603.03389) | [Code](https://github.com/ipsitmantri/GLOT)
- **ESM2**: Lin et al., "Evolutionary-scale prediction of atomic-level protein structure with a language model", Science 2023. [Models](https://huggingface.co/facebook/esm2_t12_35M_UR50D)
- **Covariance Pooling**: [Goodfire Research](https://www.goodfire.ai/research/covariance-pooling)
- **Graphein**: Jamasb et al., "Graphein - a Python Library for Geometric Deep Learning and Network Analysis on Biomolecular Structures and Interaction Networks", NeurIPS 2022. [Docs](https://graphein.ai/)