# Protein Sequence-Level Prediction with Multiple Token Aggregation Methods

Extract residue embeddings from a **frozen ESM2** backbone and aggregate them into sequence-level representations using **6 different strategies** for downstream tasks like subcellular localization prediction.

## Aggregation Methods

| # | Method | Class | Output Dim | Description |
|---|--------|-------|-----------|-------------|
| 1 | **Mean** | `MeanPooling` | `d` | Average over non-padded residue embeddings |
| 2 | **Max** | `MaxPooling` | `d` | Element-wise max over non-padded residue embeddings |
| 3 | **CLS** | `CLSPooling` | `d` | ESM2's `<cls>` token representation (position 0) |
| 4 | **GLOT** | `GLOTPooling` | `p*(K+1)` | Cosine-similarity token graph → GAT GNN → attention readout ([arXiv:2603.03389](https://arxiv.org/abs/2603.03389)) |
| 5 | **GLOT-Residue** | `GLOTResidueGraphPooling` | `p*(K+1)` | Protein 3D residue contact graph (via [graphein](https://graphein.ai/)) → GAT GNN → attention readout |
| 6 | **Covariance** | `CovariancePooling` | `d_proj*(d_proj+1)/2` | Second-order covariance pooling with power normalization ([ref](https://www.goodfire.ai/research/covariance-pooling)) |

Where `d` = ESM2 hidden dimension (e.g. 480 for the 35M model), `p` = GNN hidden dim (default 128), and `K` = number of GNN layers (default 2).
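
As a sanity check on the table above, each method's output dimension follows directly from these three constants. A tiny illustrative helper (not part of the package):

```python
def aggregated_dim(method: str, d: int = 480, p: int = 128, K: int = 2, d_proj: int = 64) -> int:
    """Output dimension of each aggregation method (names follow the table above)."""
    if method in ("mean", "max", "cls"):
        return d                           # first-order pooling keeps the backbone width
    if method in ("glot", "glot_residue"):
        return p * (K + 1)                 # Jumping Knowledge concat of the projected input + K GAT layers
    if method == "covariance":
        return d_proj * (d_proj + 1) // 2  # upper triangle of a d_proj x d_proj covariance matrix
    raise ValueError(f"unknown method: {method}")

print(aggregated_dim("glot"))        # 384
print(aggregated_dim("covariance"))  # 2080
```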

## Supported ESM2 Backbones

The backbone is swappable; just pass a different model name:

| Model | Params | Hidden Dim |
|-------|--------|-----------|
| `facebook/esm2_t6_8M_UR50D` | 8M | 320 |
| `facebook/esm2_t12_35M_UR50D` | 35M | 480 (default) |
| `facebook/esm2_t30_150M_UR50D` | 150M | 640 |
| `facebook/esm2_t33_650M_UR50D` | 650M | 1280 |
| `facebook/esm2_t36_3B_UR50D` | 3B | 2560 |
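
If you are unsure of a checkpoint's width, it can be read from the Hugging Face config instead of hard-coding it:

```python
from transformers import AutoConfig

# Read the embedding width straight from the checkpoint's config
config = AutoConfig.from_pretrained("facebook/esm2_t12_35M_UR50D")
print(config.hidden_size)  # 480
```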

## Quick Start

```python
import torch

from protein_aggregator import ProteinSequenceClassifier

# Build model: frozen ESM2 + GLOT aggregation + 10-class head
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",
    aggregation="glot",  # "mean", "max", "cls", "glot", "glot_residue", "covariance"
    num_classes=10,
    aggregator_kwargs={"p": 128, "K": 2, "tau": 0.6},
    classifier_hidden=256,
    dropout=0.1,
).cuda()

# Get sequence-level embeddings
embeddings = model.encode(["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"])
print(embeddings.shape)  # [2, 384] (p*(K+1) = 128*3)

# Or a full forward pass with loss
inputs = model.tokenizer(
    ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"],
    padding=True, truncation=True, return_tensors="pt",
).to("cuda")

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
loss = outputs["loss"]
logits = outputs["logits"]
```

## Using GLOT-Residue with PDB Files

When 3D structure is available, `glot_residue` builds the token graph from the protein's Cα-Cα contact map (8 Å threshold) using [graphein](https://graphein.ai/):

```python
model = ProteinSequenceClassifier(
    aggregation="glot_residue",
    num_classes=10,
    aggregator_kwargs={
        "contact_threshold": 8.0,  # Cα-Cα distance in Å
        "seq_neighbor_k": 5,       # fallback: ±k sequence neighbors if no PDB
    },
)

# With PDB files
outputs = model(input_ids=ids, attention_mask=mask, pdb_paths=["1abc.pdb", "2def.pdb"])

# Without PDB files (falls back to sequence-distance graph)
outputs = model(input_ids=ids, attention_mask=mask)
```
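
For intuition, the sequence-distance fallback is just a banded graph: residue `i` connects to residues `i±1 … i±k` along the chain. A minimal sketch of that construction (illustrative; the package's internal builder may differ):

```python
import torch

def seq_neighbor_edges(length: int, k: int = 5) -> torch.Tensor:
    """Edge index [2, E] connecting each residue i to residues i-k..i+k (excluding self)."""
    src, dst = [], []
    for i in range(length):
        for j in range(max(0, i - k), min(length, i + k + 1)):
            if i != j:
                src.append(i)
                dst.append(j)
    return torch.tensor([src, dst], dtype=torch.long)

edges = seq_neighbor_edges(length=16, k=5)
print(edges.shape)  # [2, E]
```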

## Using Covariance Pooling

Captures second-order statistics (feature co-activations) across residue positions:

```python
model = ProteinSequenceClassifier(
    aggregation="covariance",
    num_classes=10,
    aggregator_kwargs={"d_proj": 64},  # output dim = 64*65/2 = 2080
)
```

The `d_proj` parameter controls the output dimensionality:
- `d_proj=32` → 528 dims
- `d_proj=64` → 2080 dims (default)
- `d_proj=128` → 8256 dims
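
As a sketch of what this computes, assuming the usual covariance-pooling recipe (linear projection to `d_proj`, masked covariance over positions, signed-square-root power normalization, then the upper triangle); the package's `CovariancePooling` may differ in details:

```python
import torch

def covariance_pool(h: torch.Tensor, mask: torch.Tensor, proj: torch.nn.Linear) -> torch.Tensor:
    """h: [B, L, d] residue embeddings; mask: [B, L] with 1 for real residues."""
    z = proj(h)                                      # [B, L, d_proj]
    m = mask.unsqueeze(-1).float()                   # [B, L, 1]
    n = m.sum(dim=1).clamp(min=1.0)                  # [B, 1] residue counts
    mu = (z * m).sum(dim=1) / n                      # [B, d_proj] masked mean
    zc = (z - mu.unsqueeze(1)) * m                   # centered, padding zeroed
    cov = zc.transpose(1, 2) @ zc / n.unsqueeze(-1)  # [B, d_proj, d_proj]
    cov = torch.sign(cov) * torch.sqrt(cov.abs() + 1e-8)  # power normalization
    iu = torch.triu_indices(cov.size(1), cov.size(2))
    return cov[:, iu[0], iu[1]]                      # [B, d_proj*(d_proj+1)/2]

proj = torch.nn.Linear(480, 64)
out = covariance_pool(torch.randn(2, 20, 480), torch.ones(2, 20), proj)
print(out.shape)  # torch.Size([2, 2080])
```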

## Architecture Overview

```
Protein Sequence
        ↓
┌──────────────────┐
│ ESM2 (frozen)    │  Extracts per-residue embeddings [B, L, d]
└──────────────────┘
        ↓
┌──────────────────┐
│ Aggregator       │  Compresses token-level → sequence-level [B, agg_dim]
│ (one of 6)       │
└──────────────────┘
        ↓
┌──────────────────┐
│ Classifier Head  │  Linear (+ optional hidden layer) → [B, num_classes]
└──────────────────┘
```

Only the aggregator and classifier are trained. ESM2 is always frozen.
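
In practice that means the optimizer should only see the trainable parameters. A minimal training-step sketch, assuming the `model` and `inputs` from Quick Start and that the package sets `requires_grad=False` on the backbone:

```python
import torch

# Only aggregator + classifier parameters require gradients; ESM2 stays frozen
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=2e-4)  # paper defaults: Adam, no weight decay

optimizer.zero_grad()
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
outputs["loss"].backward()
optimizer.step()
```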

## GLOT Details

The GLOT aggregation (methods 4 & 5) follows [arXiv:2603.03389](https://arxiv.org/abs/2603.03389):

1. **Token Graph Construction.** For standard GLOT: pairwise cosine similarity between residue embeddings → threshold at τ (default 0.6) → binary adjacency. For GLOT-Residue: Cα-Cα distance contact map from the 3D structure.
2. **Token-GNN.** K layers of GATConv (Graph Attention Network) with ReLU, followed by Jumping Knowledge concatenation of all layer outputs.
3. **Attention Readout.** Learned per-token importance scores → softmax → weighted sum to produce the sequence vector.

Default hyperparameters (from the paper): `p=128, K=2, tau=0.6, n_heads=4, lr=2e-4, no weight decay, Adam optimizer`.
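
A minimal single-sequence sketch of these three steps, assuming `torch_geometric` and the default sizes above; the package's `GLOTPooling` may differ in details:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GLOTSketch(torch.nn.Module):
    """Illustrative GLOT pooling for one unpadded sequence."""

    def __init__(self, d: int = 480, p: int = 128, K: int = 2, n_heads: int = 4, tau: float = 0.6):
        super().__init__()
        self.tau = tau
        self.proj = torch.nn.Linear(d, p)  # input projection (first Jumping Knowledge slot)
        self.gats = torch.nn.ModuleList(
            [GATConv(p, p, heads=n_heads, concat=False) for _ in range(K)]
        )
        self.att = torch.nn.Linear(p * (K + 1), 1)  # attention-readout scores

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [L, d] residue embeddings for a single sequence
        sim = F.cosine_similarity(h.unsqueeze(1), h.unsqueeze(0), dim=-1)  # [L, L]
        edge_index = (sim >= self.tau).nonzero().t()  # threshold at tau -> edges [2, E]

        x = self.proj(h)
        layers = [x]
        for gat in self.gats:
            x = F.relu(gat(x, edge_index))
            layers.append(x)
        jk = torch.cat(layers, dim=-1)           # [L, p*(K+1)] Jumping Knowledge concat

        w = torch.softmax(self.att(jk), dim=0)   # [L, 1] per-token importance
        return (w * jk).sum(dim=0)               # [p*(K+1)] sequence vector

vec = GLOTSketch()(torch.randn(20, 480))
print(vec.shape)  # torch.Size([384])
```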

## Dependencies

```bash
pip install torch torch-geometric transformers
# For GLOT-Residue with PDB files:
pip install graphein biopython
```

## File Structure

```
protein_aggregator/
├── __init__.py            # Package exports
├── aggregators.py         # All 6 aggregation method implementations
└── model.py               # ProteinSequenceClassifier (ESM2 + aggregator + head)
example_localization.py    # Usage example for subcellular localization
```

## References

- **GLOT**: Mantri et al., "Towards Improved Sentence Representations using Token Graphs", arXiv:2603.03389 (2025). [Paper](https://arxiv.org/abs/2603.03389) | [Code](https://github.com/ipsitmantri/GLOT)
- **ESM2**: Lin et al., "Evolutionary-scale prediction of atomic-level protein structure with a language model", Science 2023. [Models](https://huggingface.co/facebook/esm2_t12_35M_UR50D)
- **Covariance Pooling**: [Goodfire Research](https://www.goodfire.ai/research/covariance-pooling)
- **Graphein**: Jamasb et al., "Graphein - a Python Library for Geometric Deep Learning and Network Analysis on Biomolecular Structures and Interaction Networks", NeurIPS 2022. [Docs](https://graphein.ai/)