AliSaadatV committed c451d12 (verified) · Parent(s): c2ef34a

Add README

Files changed (1): README.md (+156, -0)

# Protein Sequence-Level Prediction with Multiple Token Aggregation Methods

Extract residue embeddings from a **frozen ESM2** backbone and aggregate them into sequence-level representations using **6 different strategies** for downstream tasks like subcellular localization prediction.

## Aggregation Methods

| # | Method | Class | Output Dim | Description |
|---|--------|-------|------------|-------------|
| 1 | **Mean** | `MeanPooling` | `d` | Average over non-padded residue embeddings |
| 2 | **Max** | `MaxPooling` | `d` | Element-wise max over non-padded residue embeddings |
| 3 | **CLS** | `CLSPooling` | `d` | ESM2's `<cls>` token representation (position 0) |
| 4 | **GLOT** | `GLOTPooling` | `p*(K+1)` | Cosine-similarity token graph → GAT GNN → attention readout ([arXiv:2603.03389](https://arxiv.org/abs/2603.03389)) |
| 5 | **GLOT-Residue** | `GLOTResidueGraphPooling` | `p*(K+1)` | Protein 3D residue contact graph (via [graphein](https://graphein.ai/)) → GAT GNN → attention readout |
| 6 | **Covariance** | `CovariancePooling` | `d_proj*(d_proj+1)/2` | Second-order covariance pooling with power normalization ([ref](https://www.goodfire.ai/research/covariance-pooling)) |

Here `d` = ESM2 hidden dimension (e.g. 480 for the 35M model), `p` = GNN hidden dim (default 128), `K` = number of GNN layers (default 2), and `d_proj` = covariance projection dim (default 64).
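
For the first-order methods (1 and 2), the key detail is masking: padded positions must be excluded before averaging or taking the max. A minimal sketch of masked mean and max pooling (function names here are illustrative, not the package's exact internals):

```python
import torch

def masked_mean(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """h: [B, L, d] residue embeddings, mask: [B, L] with 1 for real tokens."""
    m = mask.unsqueeze(-1).float()                      # [B, L, 1]
    return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)

def masked_max(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Element-wise max over non-padded positions only."""
    neg_inf = torch.finfo(h.dtype).min
    return h.masked_fill(mask.unsqueeze(-1) == 0, neg_inf).max(dim=1).values
```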

## Supported ESM2 Backbones

The backbone is interchangeable; just pass a different model name:

| Model | Params | Hidden Dim |
|-------|--------|------------|
| `facebook/esm2_t6_8M_UR50D` | 8M | 320 |
| `facebook/esm2_t12_35M_UR50D` | 35M | 480 (default) |
| `facebook/esm2_t30_150M_UR50D` | 150M | 640 |
| `facebook/esm2_t33_650M_UR50D` | 650M | 1280 |
| `facebook/esm2_t36_3B_UR50D` | 3B | 2560 |
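
If in doubt about a checkpoint's hidden dimension, it can be read from the Hugging Face config without downloading the weights:

```python
from transformers import AutoConfig

for name in [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]:
    cfg = AutoConfig.from_pretrained(name)  # fetches config.json only
    print(name, cfg.hidden_size)            # 320, 480, 1280
```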

## Quick Start

```python
import torch

from protein_aggregator import ProteinSequenceClassifier

# Build model: frozen ESM2 + GLOT aggregation + 10-class head
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",
    aggregation="glot",  # "mean", "max", "cls", "glot", "glot_residue", "covariance"
    num_classes=10,
    aggregator_kwargs={"p": 128, "K": 2, "tau": 0.6},
    classifier_hidden=256,
    dropout=0.1,
).cuda()

# Get sequence-level embeddings
embeddings = model.encode(["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"])
print(embeddings.shape)  # [2, 384] (p*(K+1) = 128*3)

# Or full forward pass with loss
inputs = model.tokenizer(
    ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"],
    padding=True, truncation=True, return_tensors="pt",
).to("cuda")

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
loss = outputs["loss"]
logits = outputs["logits"]
```
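
For inference, omit `labels` and take the argmax of the logits (this assumes, as is conventional, that the loss is only computed when `labels` is provided):

```python
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    preds = outputs["logits"].argmax(dim=-1)  # [B] predicted class indices
```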

## Using GLOT-Residue with PDB Files

When 3D structure is available, `glot_residue` builds the token graph from the protein's Cα-Cα contact map (8 Å threshold) using [graphein](https://graphein.ai/):

```python
model = ProteinSequenceClassifier(
    aggregation="glot_residue",
    num_classes=10,
    aggregator_kwargs={
        "contact_threshold": 8.0,  # Cα-Cα distance in Å
        "seq_neighbor_k": 5,       # fallback: ±k sequence neighbors if no PDB
    },
)

# With PDB files
outputs = model(input_ids=ids, attention_mask=mask, pdb_paths=["1abc.pdb", "2def.pdb"])

# Without PDB files (falls back to sequence-distance graph)
outputs = model(input_ids=ids, attention_mask=mask)
```
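
For intuition, a contact map of this kind can be derived directly from Cα coordinates. A minimal Biopython sketch (illustrative only; the package delegates graph construction to graphein):

```python
import numpy as np
from Bio.PDB import PDBParser

def ca_contact_map(pdb_path: str, threshold: float = 8.0) -> np.ndarray:
    """Binary residue adjacency from Cα-Cα distances under `threshold` Å."""
    structure = PDBParser(QUIET=True).get_structure("protein", pdb_path)
    coords = np.array([
        res["CA"].coord
        for res in structure.get_residues()
        if "CA" in res  # skip waters/hetero residues without a Cα atom
    ])
    dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    return (dists < threshold).astype(np.int8)
```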

## Using Covariance Pooling

Captures second-order statistics (feature co-activations) across residue positions:

```python
model = ProteinSequenceClassifier(
    aggregation="covariance",
    num_classes=10,
    aggregator_kwargs={"d_proj": 64},  # output dim = 64*65/2 = 2080
)
```

The `d_proj` parameter controls the output dimensionality, which grows as `d_proj*(d_proj+1)/2` (the upper triangle of the covariance matrix); a sketch of the computation follows the list:
- `d_proj=32` → 528 dims
- `d_proj=64` → 2080 dims (default)
- `d_proj=128` → 8256 dims
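
A minimal sketch of second-order pooling under these assumptions (learned projection to `d_proj`, covariance over positions, signed square-root power normalization, upper-triangle flatten); the package's `CovariancePooling` may differ in details:

```python
import torch
import torch.nn as nn

class CovariancePoolingSketch(nn.Module):
    """Illustrative second-order pooling, not the package's exact implementation."""

    def __init__(self, d: int, d_proj: int = 64):
        super().__init__()
        self.proj = nn.Linear(d, d_proj)  # projection keeps the output size tractable

    def forward(self, h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """h: [B, L, d] residue embeddings, mask: [B, L] -> [B, d_proj*(d_proj+1)/2]."""
        m = mask.unsqueeze(-1).float()                        # [B, L, 1]
        z = self.proj(h) * m                                  # zero out padded positions
        n = m.sum(dim=1).clamp(min=1.0)                       # residues per sequence [B, 1]
        mu = z.sum(dim=1) / n                                 # per-sequence mean [B, d_proj]
        zc = (z - mu.unsqueeze(1)) * m                        # center, then re-mask padding
        cov = zc.transpose(1, 2) @ zc / n.unsqueeze(-1)       # covariance [B, d_proj, d_proj]
        cov = torch.sign(cov) * torch.sqrt(cov.abs() + 1e-6)  # power normalization
        iu = torch.triu_indices(cov.size(1), cov.size(2))     # upper-triangle indices
        return cov[:, iu[0], iu[1]]                           # [B, d_proj*(d_proj+1)/2]
```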

## Architecture Overview

```
Protein Sequence
        ↓
┌──────────────────┐
│  ESM2 (frozen)   │  Extracts per-residue embeddings [B, L, d]
└──────────────────┘
        ↓
┌──────────────────┐
│   Aggregator     │  Compresses token-level → sequence-level [B, agg_dim]
│   (one of 6)     │
└──────────────────┘
        ↓
┌──────────────────┐
│ Classifier Head  │  Linear (+ optional hidden layer) → [B, num_classes]
└──────────────────┘
```

Only the aggregator and classifier are trained. ESM2 is always frozen.
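
In practice this means the optimizer should only see the trainable parameters. A typical setup, using the GLOT defaults listed below:

```python
import torch

optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),  # aggregator + head only
    lr=2e-4,
    weight_decay=0.0,
)
```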

## GLOT Details

The GLOT aggregation (methods 4 & 5) follows [arXiv:2603.03389](https://arxiv.org/abs/2603.03389):

1. **Token Graph Construction**: for standard GLOT, pairwise cosine similarity between residue embeddings → threshold at τ (default 0.6) → binary adjacency (a sketch of this variant follows below). For GLOT-Residue, a Cα-Cα distance contact map from the 3D structure.
2. **Token-GNN**: K layers of GATConv (Graph Attention Network) with ReLU, followed by Jumping Knowledge concatenation of all layer outputs.
3. **Attention Readout**: learned per-token importance scores → softmax → weighted sum to produce the sequence vector.

Default hyperparameters (from the paper): `p=128, K=2, tau=0.6, n_heads=4, lr=2e-4, no weight decay, Adam optimizer`.
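
A minimal sketch of step 1 for the cosine-similarity variant, producing a PyG-style `edge_index` (illustrative; the package's internals may differ):

```python
import torch
import torch.nn.functional as F

def cosine_token_graph(h: torch.Tensor, tau: float = 0.6) -> torch.Tensor:
    """h: [L, d] residue embeddings -> [2, E] edge_index of the thresholded graph."""
    hn = F.normalize(h, dim=-1)           # unit-norm rows
    sim = hn @ hn.T                       # pairwise cosine similarity [L, L]
    adj = sim >= tau                      # binary adjacency (self-loops included)
    return adj.nonzero(as_tuple=False).T  # [2, E] for torch_geometric layers
```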
132
+
133
+ ## Dependencies
134
+
135
+ ```bash
136
+ pip install torch torch-geometric transformers
137
+ # For GLOT-Residue with PDB files:
138
+ pip install graphein biopython
139
+ ```
140
+
141
+ ## File Structure
142
+
143
+ ```
144
+ protein_aggregator/
145
+ β”œβ”€β”€ __init__.py # Package exports
146
+ β”œβ”€β”€ aggregators.py # All 6 aggregation method implementations
147
+ └── model.py # ProteinSequenceClassifier (ESM2 + aggregator + head)
148
+ example_localization.py # Usage example for subcellular localization
149
+ ```
150
+
151
+ ## References
152
+
153
+ - **GLOT**: Mantri et al., "Towards Improved Sentence Representations using Token Graphs", arXiv:2603.03389 (2025). [Paper](https://arxiv.org/abs/2603.03389) | [Code](https://github.com/ipsitmantri/GLOT)
154
+ - **ESM2**: Lin et al., "Evolutionary-scale prediction of atomic-level protein structure with a language model", Science 2023. [Models](https://huggingface.co/facebook/esm2_t12_35M_UR50D)
155
+ - **Covariance Pooling**: [Goodfire Research](https://www.goodfire.ai/research/covariance-pooling)
156
+ - **Graphein**: Jamasb et al., "Graphein - a Python Library for Geometric Deep Learning and Network Analysis on Biomolecular Structures and Interaction Networks", NeurIPS 2022. [Docs](https://graphein.ai/)