# ModernProteinLM β€” Private GPU Cluster Instructions

## Overview

ModernProteinLM is a next-generation protein encoder (<200M params) that combines:
1. **ModernBERT architecture** (RoPE, Pre-LN, GeGLU, deep & narrow)
2. **ELECTRA discriminative pre-training** (replaced token detection)
3. **Span masking curriculum** (30% β†’ 5% over training)

This is the **first protein encoder** to combine all three proven techniques, targeting predictive downstream tasks (fluorescence, stability, solubility, structure, etc.).
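
For reference, the ELECTRA objective pairs a small generator (masked-language modeling) with the main encoder acting as a discriminator that flags replaced tokens. Below is a minimal sketch, assuming `generator` returns per-token vocabulary logits and `discriminator` returns a per-token replaced/original logit; the actual implementation lives in `train_pretrain.py` and may differ.

```python
import torch
import torch.nn.functional as F

def electra_step(generator, discriminator, original_ids, masked_ids, mask, lambda_disc=50.0):
    # 1. Generator predicts the original residues at masked positions (MLM loss).
    gen_logits = generator(masked_ids)                          # (B, L, vocab)
    mlm_loss = F.cross_entropy(gen_logits[mask], original_ids[mask])

    # 2. Sample the generator's predictions to build a corrupted sequence.
    with torch.no_grad():
        sampled = torch.distributions.Categorical(logits=gen_logits[mask]).sample()
    corrupted = original_ids.clone()
    corrupted[mask] = sampled

    # 3. Discriminator marks every position as original (0) or replaced (1).
    labels = (corrupted != original_ids).float()
    disc_logits = discriminator(corrupted).squeeze(-1)          # (B, L)
    disc_loss = F.binary_cross_entropy_with_logits(disc_logits, labels)

    # ELECTRA weights the discriminative loss heavily (lambda = 50 in the paper).
    return mlm_loss + lambda_disc * disc_loss
```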

---

## Quick Start

```bash
# 1. Clone / copy the codebase to your cluster
# 2. Install dependencies
pip install -r requirements.txt

# 3. (Optional) Install FlashAttention for speedup
pip install flash-attn --no-build-isolation

# 4. Run pre-training
bash run_pretrain.sh

# 5. Run downstream fine-tuning + evaluation
bash run_finetune.sh
```

---

## Architecture Summary

| Component | Value | Why |
|-----------|-------|-----|
| **Params** | ~150M | Competitive with ESM-2 150M |
| **Layers** | 28 | Deep & narrow (NeoBERT/ModernBERT best practice) |
| **Hidden** | 576 | Head dim = 64 (tensor core optimal) |
| **Heads** | 9 | 576/9 = 64 |
| **FFN** | 2304 | GeGLU (4Γ— hidden) |
| **Pos Emb** | RoPE (ΞΈ=10k) | Extrapolates to longer proteins |
| **Norm** | Pre-LN | Stable at 28 layers |
| **Dropout** | 0.0 | Following ESM-2 (the data is noisy enough to regularize) |
| **Vocab** | 33 | ESM-2 compatible |
| **Generator** | 320 hidden, 8L | 25% of discriminator (ELECTRA recipe) |

**Discriminator params: ~150M | Generator params: ~25M**
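
The table maps onto a configuration roughly like the sketch below. Field names here are illustrative; the authoritative defaults live in `modeling_modern_protein.py`.

```python
from dataclasses import dataclass

@dataclass
class ModernProteinConfig:                 # illustrative field names
    vocab_size: int = 33                   # ESM-2-compatible alphabet
    hidden_size: int = 576
    num_hidden_layers: int = 28            # deep & narrow
    num_attention_heads: int = 9           # 576 / 9 = 64 -> tensor-core-friendly head dim
    intermediate_size: int = 2304          # GeGLU FFN, 4x hidden
    rope_theta: float = 10_000.0
    dropout: float = 0.0

cfg = ModernProteinConfig()
assert cfg.hidden_size % cfg.num_attention_heads == 0   # head dim = 64
```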

---

## Stage 1: Pre-Training (ELECTRA)

### Single GPU

```bash
CUDA_VISIBLE_DEVICES=0 bash run_pretrain.sh
```

### Multi-GPU (DDP)

```bash
# 4 GPUs (torchrun launches the Python training script directly)
torchrun --standalone --nnodes=1 --nproc_per_node=4 train_pretrain.py
```

### SLURM

```bash
#!/bin/bash
#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=16
#SBATCH --mem=128G

module load cuda/12.1
source ~/venv/bin/activate

export NUM_GPUS=4
export BATCH_SIZE=32        # Per-device
export MAX_STEPS=500000
export USE_AMP=1
export USE_FLASH_ATTN=1

bash run_pretrain.sh
```

### Key Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `NUM_GPUS` | 1 | Number of GPUs |
| `BATCH_SIZE` | 64 | Per-device batch size |
| `MAX_STEPS` | 100000 | Total training steps |
| `LR` | 5e-4 | Peak learning rate |
| `MASK_START` | 0.30 | Initial mask ratio |
| `MASK_END` | 0.05 | Final mask ratio |
| `USE_AMP` | 1 | bf16 mixed precision |
| `USE_FLASH_ATTN` | 1 | FlashAttention (requires install) |
| `GRADIENT_CHECKPOINTING` | 0 | Trade compute for memory |
| `USE_TRACKIO` | 0 | Enable experiment tracking |
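
`MASK_START` / `MASK_END` drive the span-masking curriculum. A minimal linear schedule is sketched below; the exact annealing shape is defined in `train_pretrain.py` and may differ.

```python
def mask_ratio(step: int, max_steps: int,
               start: float = 0.30, end: float = 0.05) -> float:
    """Linearly anneal the span-mask ratio from `start` down to `end`."""
    frac = min(step / max_steps, 1.0)
    return start + (end - start) * frac

assert abs(mask_ratio(0, 100_000) - 0.30) < 1e-9
assert abs(mask_ratio(100_000, 100_000) - 0.05) < 1e-9
```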

### Data Sources

Pre-training pulls from HuggingFace datasets by default:
- `lamm-mit/protein_secondary_structure_from_PDB` (~126k sequences)
- `adamstogsdill/pdb_protein_dataset_100_4000_1024`

**For full pre-training**, set `USE_STREAMING=1` and add UniRef50/UniRef90:

```bash
export USE_STREAMING=1
# Or provide local UniRef FASTA:
export UNIREF_PATH=/path/to/uniref50.fasta
```

To add UniRef support, modify `load_sequences()` in `train_pretrain.py`:

```python
from Bio import SeqIO

def load_uniref_fasta(path, max_seqs=5_000_000):
    """Load sequences from a UniRef FASTA, keeping lengths in [20, 1024]."""
    sequences = []
    for record in SeqIO.parse(path, "fasta"):
        seq = str(record.seq)
        if 20 <= len(seq) <= 1024:
            sequences.append(seq)
        if len(sequences) >= max_seqs:
            break
    return sequences
```
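
With `USE_STREAMING=1`, data can instead be iterated lazily from the Hub. A minimal sketch using the Hugging Face `datasets` streaming API follows; the split name and column layout depend on the dataset, and the actual wiring is in `train_pretrain.py`.

```python
from datasets import load_dataset

# Stream examples without downloading the full dataset to disk.
ds = load_dataset(
    "lamm-mit/protein_secondary_structure_from_PDB",
    split="train",
    streaming=True,
)
for example in ds.take(3):
    print(list(example.keys()))   # inspect the available columns
```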

### Expected Pre-Training Time

| Hardware | Batch Size | Steps/Day | 100K Steps | 500K Steps |
|----------|-----------|-----------|------------|------------|
| 1× A100 80GB | 128 | ~50K | 2 days | 10 days |
| 4× A100 80GB | 128×4 | ~200K | 12 hours | 2.5 days |
| 8× A100 80GB | 128×8 | ~400K | 6 hours | ~30 hours |

*With bf16 AMP and FlashAttention*

---

## Stage 2: Downstream Fine-Tuning

After pre-training completes, fine-tune on specific tasks:

```bash
# Fine-tune on all available tasks
bash run_finetune.sh

# Or point explicitly at a pre-trained checkpoint
PRETRAIN_DIR=./outputs/pretrain/final bash run_finetune.sh
```

### Supported Benchmark Tasks

| Task | Type | Metric | Baseline (ESM-2 150M) | Target |
|------|------|--------|----------------------|--------|
| **Fluorescence** | Regression | Spearman ρ | 0.68 | β‰₯ 0.75 |
| **Stability** | Regression | Spearman ρ | 0.79 | β‰₯ 0.85 |
| **Solubility** | Classification | Accuracy | ~74% | β‰₯ 80% |
| **Remote Homology** | Classification | Accuracy | ~20% | β‰₯ 25% |
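
The regression tasks (fluorescence, stability) are scored with Spearman correlation between predicted and measured values; a quick check with `scipy` on illustrative numbers:

```python
from scipy.stats import spearmanr

predictions = [0.1, 0.4, 0.35, 0.8]   # model outputs (illustrative values)
targets     = [0.0, 0.3, 0.50, 0.9]   # experimental labels (illustrative values)
rho, _ = spearmanr(predictions, targets)
print(f"Spearman rho = {rho:.3f}")
```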

### Fine-Tuning Strategy

The script uses **layer-wise learning rate decay**:
- Task head: `lr`
- Last 4 transformer layers: `lr Γ— 0.5`
- Earlier layers + embeddings: `lr Γ— 0.1`

This is critical for small downstream datasets (fluorescence has ~21k samples).
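
A sketch of how those three groups might be constructed is shown below. Module names such as `head` and `layers.{i}` are assumptions; the actual grouping is in `train_finetune.py`.

```python
import torch

def build_param_groups(model, lr=1e-4):
    head, top, rest = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.startswith("head"):                               # task head
            head.append(param)
        elif any(f"layers.{i}." in name for i in range(24, 28)):  # last 4 of 28 layers
            top.append(param)
        else:                                                     # earlier layers + embeddings
            rest.append(param)
    return torch.optim.AdamW([
        {"params": head, "lr": lr},
        {"params": top,  "lr": lr * 0.5},
        {"params": rest, "lr": lr * 0.1},
    ])
```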

For even smaller datasets, add LoRA:

```bash
# Install PEFT
pip install peft
```

Then, in `train_finetune.py`, replace full fine-tuning with:

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8, lora_alpha=16,
    target_modules=["qkv_proj", "out_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.0,
    bias="none",
)
model = get_peft_model(model, lora_config)
```

---

## Stage 3: Pushing to HuggingFace Hub

After fine-tuning, push the pretrained encoder for community use:

```python
from modeling_modern_protein import ModernProteinLM

# Load your trained model
model = ModernProteinLM.from_pretrained("./outputs/pretrain/final")

# Push to Hub
model.push_to_hub("your-username/ModernProteinLM-150M")

# With a task-specific head
from modeling_modern_protein import ModernProteinLMForSequenceClassification
cls_model = ModernProteinLMForSequenceClassification.from_pretrained(
    "./outputs/finetune/fluorescence/best"
)
cls_model.push_to_hub("your-username/ModernProteinLM-fluorescence")
```

---

## Expected Improvements Over ESM-2 150M

| Technique | Source | Expected Gain |
|-----------|--------|--------------|
| ELECTRA vs MLM | ELECTRA paper | +3-5% on discriminative tasks |
| GeGLU vs GELU | ModernBERT | +1-2% |
| Deep & narrow (28L) | NeoBERT | +1-3% on embeddings |
| Span masking | SpanBERT analogy | +1-2% on structure tasks |
| Curriculum 30%β†’5% | mmBERT | Faster convergence |
| **Combined (conservative)** | β€” | **+7-14% on predictive benchmarks** |

---

## Troubleshooting

### OOM during pre-training

```bash
# Reduce per-device batch size
export BATCH_SIZE=32

# Enable gradient checkpointing
export GRADIENT_CHECKPOINTING=1

# Reduce sequence length
export MAX_SEQ_LENGTH=512
```

### FlashAttention install fails

```bash
# Skip FlashAttention (slower but works)
export USE_FLASH_ATTN=0

# Or download a prebuilt wheel matching your Python / CUDA / torch versions from
# https://github.com/Dao-AILab/flash-attention/releases and install it directly:
pip install ./flash_attn-*.whl
```

### Slow data loading

```bash
# Increase workers
export NUM_WORKERS=16

# Pre-tokenize and cache
python - <<'EOF'
import pickle
from train_pretrain import load_sequences, ProteinTokenizer

tokenizer = ProteinTokenizer()
seqs = load_sequences(None)
tokenized = [tokenizer.encode(s) for s in seqs]
with open("tokenized_cache.pkl", "wb") as f:
    pickle.dump(tokenized, f)
EOF
```

---

## File Reference

```
modern_protein_lm/
β”œβ”€β”€ modeling_modern_protein.py    # Core architecture (ModernBERT-style + ELECTRA)
β”œβ”€β”€ train_pretrain.py             # ELECTRA pre-training (supports DDP, AMP)
β”œβ”€β”€ train_finetune.py             # Downstream fine-tuning (layer-wise LR)
β”œβ”€β”€ run_pretrain.sh               # Launch script for pre-training
β”œβ”€β”€ run_finetune.sh               # Launch script for fine-tuning
β”œβ”€β”€ requirements.txt              # Dependencies
β”œβ”€β”€ README.md                     # Architecture docs
└── CLUSTER_INSTRUCTIONS.md       # This file
```

---

## Citation

If you use this architecture or achieve SOTA results, please cite:

```bibtex
@article{lin2023evolutionary,
  title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
  author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zhongkai and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yaniv and others},
  journal={Science},
  year={2023}
}

@article{warner2024modernbert,
  title={Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient and Long Context Finetuning and Inference},
  author={Warner, Benjamin and Chaffin, Antoine and Clavi{\'e}, Benjamin and Weller, Orion and others},
  journal={arXiv preprint arXiv:2412.13663},
  year={2024}
}

@inproceedings{clark2020electra,
  title={ELECTRA: Pre-training text encoders as discriminators rather than generators},
  author={Clark, Kevin and Luong, Minh-Thang and Le, Quoc V and Manning, Christopher D},
  booktitle={ICLR},
  year={2020}
}
```