# Protein Sequence-Level Prediction with Multiple Token Aggregation Methods

Extract residue embeddings from a **frozen ESM2** backbone and aggregate them into sequence-level representations using **6 different strategies** for downstream tasks like subcellular localization prediction.

## Aggregation Methods

| # | Method | Class | Output Dim | Description |
|---|--------|-------|-----------|-------------|
| 1 | **Mean** | `MeanPooling` | `d` | Average over non-padded residue embeddings |
| 2 | **Max** | `MaxPooling` | `d` | Element-wise max over non-padded residue embeddings |
| 3 | **CLS** | `CLSPooling` | `d` | ESM2's `<cls>` token representation (position 0) |
| 4 | **GLOT** | `GLOTPooling` | `p*(K+1)` | Cosine-similarity token graph → GAT GNN → attention readout ([arXiv:2603.03389](https://arxiv.org/abs/2603.03389)) |
| 5 | **GLOT-Residue** | `GLOTResidueGraphPooling` | `p*(K+1)` | Protein 3D residue contact graph (via [graphein](https://graphein.ai/)) → GAT GNN → attention readout |
| 6 | **Covariance** | `CovariancePooling` | `d_proj*(d_proj+1)/2` | Second-order covariance pooling with power normalization ([ref](https://www.goodfire.ai/research/covariance-pooling)) |

Where `d` = ESM2 hidden dimension (e.g. 480 for the 35M model), `p` = GNN hidden dim (default 128), `K` = number of GNN layers (default 2), and `d_proj` = covariance projection dim (default 64).
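
For reference, here is a minimal sketch of what the two simplest methods compute, assuming the standard mask convention (1 = real token, 0 = padding); the actual implementations live in `aggregators.py`:

```python
import torch
import torch.nn as nn

class MeanPooling(nn.Module):
    """Average residue embeddings, ignoring padded positions."""
    def forward(self, token_embeddings, attention_mask):
        # token_embeddings: [B, L, d], attention_mask: [B, L]
        mask = attention_mask.unsqueeze(-1).float()        # [B, L, 1]
        summed = (token_embeddings * mask).sum(dim=1)      # [B, d]
        counts = mask.sum(dim=1).clamp(min=1e-9)           # avoid divide-by-zero
        return summed / counts

class MaxPooling(nn.Module):
    """Element-wise max over non-padded residue embeddings."""
    def forward(self, token_embeddings, attention_mask):
        masked = token_embeddings.masked_fill(
            attention_mask.unsqueeze(-1) == 0, float("-inf")  # exclude padding
        )
        return masked.max(dim=1).values                    # [B, d]
```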

## Supported ESM2 Backbones

The backbone is changeable; just pass a different model name:

| Model | Params | Hidden Dim |
|-------|--------|-----------|
| `facebook/esm2_t6_8M_UR50D` | 8M | 320 |
| `facebook/esm2_t12_35M_UR50D` | 35M | 480 (default) |
| `facebook/esm2_t30_150M_UR50D` | 150M | 640 |
| `facebook/esm2_t33_650M_UR50D` | 650M | 1280 |
| `facebook/esm2_t36_3B_UR50D` | 3B | 2560 |
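
All of these checkpoints load through HuggingFace `transformers`. A minimal, package-independent sketch of extracting frozen per-residue embeddings:

```python
import torch
from transformers import AutoTokenizer, AutoModel

name = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(name)
esm2 = AutoModel.from_pretrained(name)

# Freeze the backbone: only the aggregator + classifier will receive gradients.
for param in esm2.parameters():
    param.requires_grad = False
esm2.eval()

inputs = tokenizer(["MKTAYIAKQRQISFVK"], return_tensors="pt")
with torch.no_grad():
    out = esm2(**inputs)
print(out.last_hidden_state.shape)  # [1, L+2, 480]: <cls> + residues + <eos>
```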

## Quick Start

```python
import torch
from protein_aggregator import ProteinSequenceClassifier

# Build model: frozen ESM2 + GLOT aggregation + 10-class head
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",
    aggregation="glot",          # "mean", "max", "cls", "glot", "glot_residue", "covariance"
    num_classes=10,
    aggregator_kwargs={"p": 128, "K": 2, "tau": 0.6},
    classifier_hidden=256,
    dropout=0.1,
).cuda()

# Get sequence-level embeddings
embeddings = model.encode(["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"])
print(embeddings.shape)  # [2, 384]  (p*(K+1) = 128*3)

# Or full forward pass with loss
inputs = model.tokenizer(
    ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"],
    padding=True, truncation=True, return_tensors="pt"
).to("cuda")

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
loss = outputs["loss"]
logits = outputs["logits"]
```

## Using GLOT-Residue with PDB Files

When 3D structure is available, `glot_residue` builds the token graph from the protein's Cα-Cα contact map (8 Å threshold) using [graphein](https://graphein.ai/):

```python
model = ProteinSequenceClassifier(
    aggregation="glot_residue",
    num_classes=10,
    aggregator_kwargs={
        "contact_threshold": 8.0,  # CΞ±-CΞ± distance in Γ…
        "seq_neighbor_k": 5,       # fallback: Β±k sequence neighbors if no PDB
    },
)

# With PDB files
outputs = model(input_ids=ids, attention_mask=mask, pdb_paths=["1abc.pdb", "2def.pdb"])

# Without PDB files (falls back to sequence-distance graph)
outputs = model(input_ids=ids, attention_mask=mask)
```
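
The fallback graph simply connects each residue to its ±k sequence neighbors. A minimal sketch of that construction (the helper name is illustrative, not part of the package API):

```python
import torch

def sequence_neighbor_edges(seq_len: int, k: int = 5) -> torch.Tensor:
    """Edge index [2, E] linking each residue i to residues i±1 .. i±k."""
    src, dst = [], []
    for i in range(seq_len):
        for offset in range(1, k + 1):
            if i + offset < seq_len:
                src += [i, i + offset]   # add both directions
                dst += [i + offset, i]
    return torch.tensor([src, dst], dtype=torch.long)

edge_index = sequence_neighbor_edges(seq_len=16, k=5)  # ready for torch_geometric
```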

## Using Covariance Pooling

Captures second-order statistics (feature co-activations) across residue positions:

```python
model = ProteinSequenceClassifier(
    aggregation="covariance",
    num_classes=10,
    aggregator_kwargs={"d_proj": 64},  # output dim = 64*65/2 = 2080
)
```

The `d_proj` parameter controls the output dimensionality:
- `d_proj=32` → 528 dims
- `d_proj=64` → 2080 dims (default)
- `d_proj=128` → 8256 dims
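
As a rough sketch of the idea (mask-aware covariance across positions, signed-square-root power normalization, upper-triangle flattening); the exact normalization in `CovariancePooling` may differ:

```python
import torch
import torch.nn as nn

class CovariancePoolingSketch(nn.Module):
    """Illustrative second-order pooling; output dim = d_proj*(d_proj+1)/2."""
    def __init__(self, d_in: int, d_proj: int = 64):
        super().__init__()
        self.proj = nn.Linear(d_in, d_proj)
        # Indices of the upper triangle (incl. diagonal)
        self.register_buffer("tri", torch.triu_indices(d_proj, d_proj))

    def forward(self, token_embeddings, attention_mask):
        x = self.proj(token_embeddings)                       # [B, L, d_proj]
        mask = attention_mask.unsqueeze(-1).float()
        n = mask.sum(dim=1).clamp(min=1.0)                    # tokens per sequence
        mean = (x * mask).sum(dim=1, keepdim=True) / n.unsqueeze(1)
        centered = (x - mean) * mask
        cov = centered.transpose(1, 2) @ centered / n.unsqueeze(-1)  # [B, d, d]
        cov = torch.sign(cov) * torch.sqrt(cov.abs() + 1e-8)  # power normalization
        return cov[:, self.tri[0], self.tri[1]]               # [B, d*(d+1)/2]
```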

## Architecture Overview

```
Protein Sequence
       ↓
┌──────────────────┐
│  ESM2 (frozen)   │  Extracts per-residue embeddings [B, L, d]
└──────────────────┘
       ↓
┌──────────────────┐
│   Aggregator     │  Compresses token-level → sequence-level [B, agg_dim]
│  (one of 6)      │
└──────────────────┘
       ↓
┌──────────────────┐
│  Classifier Head │  Linear (+ optional hidden layer) → [B, num_classes]
└──────────────────┘
```

Only the aggregator and classifier are trained. ESM2 is always frozen.
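
In practice this means the optimizer should only receive the trainable parameters. A minimal sketch (`model` is the `ProteinSequenceClassifier` from the Quick Start; lr and optimizer follow the GLOT defaults listed below):

```python
import torch

# ESM2 parameters have requires_grad=False, so filtering by requires_grad
# leaves exactly the aggregator + classifier-head weights.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=2e-4, weight_decay=0.0)
```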

## GLOT Details

The GLOT aggregation (methods 4 & 5) follows [arXiv:2603.03389](https://arxiv.org/abs/2603.03389):

1. **Token Graph Construction**: standard GLOT computes pairwise cosine similarity between residue embeddings, thresholded at τ (default 0.6) to form a binary adjacency; GLOT-Residue instead uses the Cα-Cα contact map from the 3D structure.
2. **Token-GNN**: K layers of GATConv (Graph Attention Network) with ReLU, followed by Jumping Knowledge concatenation of all layer outputs.
3. **Attention Readout**: learned per-token importance scores → softmax → weighted sum, producing the sequence vector.

Default hyperparameters (from the paper): `p=128, K=2, tau=0.6, n_heads=4, lr=2e-4, no weight decay, Adam optimizer`.
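
A condensed single-sequence sketch of steps 1-3 using `torch_geometric`'s `GATConv` (batching and padding masks omitted; the real implementation is in `aggregators.py`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GLOTSketch(nn.Module):
    """Single-sequence sketch of GLOT pooling; output dim is p*(K+1)."""
    def __init__(self, d_in: int, p: int = 128, K: int = 2,
                 tau: float = 0.6, n_heads: int = 4):
        super().__init__()
        self.inp = nn.Linear(d_in, p)          # project residues to GNN width
        self.convs = nn.ModuleList(
            GATConv(p, p, heads=n_heads, concat=False) for _ in range(K)
        )
        self.att = nn.Linear(p * (K + 1), 1)   # attention-readout scorer
        self.tau = tau

    def forward(self, h):                      # h: [L, d_in] residue embeddings
        # 1. Token graph: cosine similarity thresholded at tau
        sim = F.cosine_similarity(h.unsqueeze(1), h.unsqueeze(0), dim=-1)
        edge_index = (sim >= self.tau).nonzero().t()          # [2, E]

        # 2. K GAT layers + Jumping Knowledge concat of all layer outputs
        layers = [self.inp(h)]
        for conv in self.convs:
            layers.append(F.relu(conv(layers[-1], edge_index)))
        jk = torch.cat(layers, dim=-1)                        # [L, p*(K+1)]

        # 3. Attention readout: softmax over tokens, weighted sum
        weights = torch.softmax(self.att(jk), dim=0)          # [L, 1]
        return (weights * jk).sum(dim=0)                      # [p*(K+1)]
```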

## Dependencies

```bash
pip install torch torch-geometric transformers
# For GLOT-Residue with PDB files:
pip install graphein biopython
```

## File Structure

```
protein_aggregator/
├── __init__.py          # Package exports
├── aggregators.py       # All 6 aggregation method implementations
└── model.py             # ProteinSequenceClassifier (ESM2 + aggregator + head)
example_localization.py  # Usage example for subcellular localization
```

## References

- **GLOT**: Mantri et al., "Towards Improved Sentence Representations using Token Graphs", arXiv:2603.03389 (2025). [Paper](https://arxiv.org/abs/2603.03389) | [Code](https://github.com/ipsitmantri/GLOT)
- **ESM2**: Lin et al., "Evolutionary-scale prediction of atomic-level protein structure with a language model", Science 2023. [Models](https://huggingface.co/facebook/esm2_t12_35M_UR50D)
- **Covariance Pooling**: [Goodfire Research](https://www.goodfire.ai/research/covariance-pooling)
- **Graphein**: Jamasb et al., "Graphein - a Python Library for Geometric Deep Learning and Network Analysis on Biomolecular Structures and Interaction Networks", NeurIPS 2022. [Docs](https://graphein.ai/)