GrimSqueaker committed
Commit 10db53f · verified · 1 Parent(s): fa85dd6

Upload train_pretrain.py with huggingface_hub

Files changed (1)
  1. train_pretrain.py +610 -1
train_pretrain.py CHANGED
@@ -1 +1,610 @@
- ...
+ """
+ Production ELECTRA pre-training script for ModernProteinLM.
+ Supports: single GPU, multi-GPU DDP, fp16 AMP (autocast + GradScaler), and
+ gradient checkpointing.
+ """
+
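+ # Launch examples (assumed invocations; torchrun exports the RANK / WORLD_SIZE /
+ # LOCAL_RANK variables that setup_distributed() below reads):
+ #   python train_pretrain.py --use_amp --gradient_checkpointing
+ #   torchrun --nproc_per_node=8 train_pretrain.py --use_amp --use_streaming
+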
+ import os
+ import sys
+ import argparse
+ import math
+ import random
+ import time
+ import json
+ from typing import List, Dict, Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler
+ from torch.cuda.amp import autocast, GradScaler
+ from transformers import get_cosine_schedule_with_warmup
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig
+
+
+ def setup_distributed():
+     # torchrun (or torch.distributed.launch) exports these variables.
+     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+         rank = int(os.environ["RANK"])
+         world_size = int(os.environ["WORLD_SIZE"])
+         local_rank = int(os.environ.get("LOCAL_RANK", 0))
+         dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+         torch.cuda.set_device(local_rank)
+         return rank, world_size, local_rank
+     return 0, 1, 0
+
+
+ def cleanup_distributed():
+     if dist.is_initialized():
+         dist.destroy_process_group()
+
+
+ def log_rank0(msg):
+     if not dist.is_initialized() or dist.get_rank() == 0:
+         print(msg)
+
+
+ # =============================================================================
+ # TOKENIZER
+ # =============================================================================
+
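+ # Character-level tokenizer: the 20 standard amino acids, ambiguity codes
+ # (X, B, U, Z, O), and ESM-2-style special tokens; the vocab is padded out
+ # to 33 entries to match vocab_size in the model configs below.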
+ class ProteinTokenizer:
+     """ESM-2 compatible protein tokenizer."""
+
+     def __init__(self):
+         self.vocab = {
+             "<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
+             "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10,
+             "T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17,
+             "F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24,
+             "B": 25, "U": 26, "Z": 27, "O": 28, "<mask>": 29,
+             "<sep>": 30,
+         }
+         while len(self.vocab) < 33:
+             self.vocab[f"<special_{len(self.vocab)}>"] = len(self.vocab)
+         self.id_to_token = {v: k for k, v in self.vocab.items()}
+         self.mask_token_id = 29
+         self.pad_token_id = 1
+         self.cls_token_id = 0
+         self.eos_token_id = 2
+
+     def encode(self, sequence: str, max_length: int = 1024, add_special_tokens: bool = True):
+         tokens = []
+         if add_special_tokens:
+             tokens.append(self.cls_token_id)
+         for aa in sequence.upper():
+             tokens.append(self.vocab.get(aa, self.vocab["<unk>"]))
+         if add_special_tokens:
+             tokens.append(self.eos_token_id)
+
+         if len(tokens) > max_length:
+             tokens = tokens[:max_length]
+
+         attention_mask = [1] * len(tokens)
+         while len(tokens) < max_length:
+             tokens.append(self.pad_token_id)
+             attention_mask.append(0)
+
+         return {
+             "input_ids": tokens,
+             "attention_mask": attention_mask,
+         }
+
+
+ # =============================================================================
+ # MASKING
+ # =============================================================================
+
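+ # Span masking: instead of masking i.i.d. positions, contiguous spans of
+ # ~mean_span_length residues are masked until ~mask_ratio of the sequence is
+ # covered. Overlapping draws are rejected, and an attempt cap guarantees the
+ # loop terminates.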
+ def create_span_mask(length: int, mask_ratio: float, mean_span_length: int = 3):
+     num_to_mask = max(1, int(length * mask_ratio))
+     mask = [False] * length
+
+     masked = 0
+     attempts = 0
+     while masked < num_to_mask and attempts < num_to_mask * 10:
+         span_len = max(1, min(mean_span_length + random.randint(-1, 1), num_to_mask - masked))
+         start = random.randint(0, max(0, length - span_len))
+         if any(mask[start:start + span_len]):
+             attempts += 1
+             continue
+         for i in range(start, min(start + span_len, length)):
+             mask[i] = True
+             masked += 1
+     return mask
+
+
+ # =============================================================================
+ # DATASET
+ # =============================================================================
+
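+ # Each item carries everything both models need for one ELECTRA step:
+ # masked input_ids for the generator, mlm_labels (-100 = not masked), a
+ # boolean "replaced" map of masked positions, and the uncorrupted
+ # original_ids used to build the discriminator's real/fake targets.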
+ class PretrainDataset(Dataset):
+     def __init__(self, sequences: List[str], tokenizer, args, current_step: int = 0):
+         self.sequences = sequences
+         self.tokenizer = tokenizer
+         self.args = args
+         self.current_step = current_step
+
+     def get_mask_ratio(self):
+         # Linear curriculum from mask_start down to mask_end over max_steps.
+         progress = min(1.0, self.current_step / self.args.max_steps)
+         return self.args.mask_start + (self.args.mask_end - self.args.mask_start) * progress
+
+     def __len__(self):
+         return len(self.sequences)
+
+     def __getitem__(self, idx):
+         seq = self.sequences[idx]
+         encoded = self.tokenizer.encode(seq, max_length=self.args.max_seq_length)
+         input_ids = encoded["input_ids"]
+         attention_mask = encoded["attention_mask"]
+
+         seq_len = sum(attention_mask)
+         effective_len = max(1, seq_len - 2)  # exclude <cls> and <eos>
+
+         span_mask = create_span_mask(effective_len, self.get_mask_ratio(), self.args.span_length)
+
+         masked_input = input_ids.copy()
+         labels = [-100] * len(input_ids)
+         replaced = [False] * len(input_ids)
+
+         # BERT-style corruption: 80% <mask>, 10% random residue, 10% unchanged.
+         for i in range(1, 1 + effective_len):
+             if span_mask[i - 1]:
+                 labels[i] = input_ids[i]
+                 replaced[i] = True
+                 r = random.random()
+                 if r < 0.8:
+                     masked_input[i] = self.tokenizer.mask_token_id
+                 elif r < 0.9:
+                     masked_input[i] = random.randint(4, 28)
+
+         return {
+             "input_ids": torch.tensor(masked_input, dtype=torch.long),
+             "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+             "mlm_labels": torch.tensor(labels, dtype=torch.long),
+             "replaced": torch.tensor(replaced, dtype=torch.bool),
+             "original_ids": torch.tensor(input_ids, dtype=torch.long),
+         }
+
+
+ def load_sequences(args):
+     all_sequences = []
+
+     # Try HF datasets first
+     sources = [
+         ("lamm-mit/protein_secondary_structure_from_PDB", "train", "input"),
+         ("adamstogsdill/pdb_protein_dataset_100_4000_1024", "train", "sequence"),
+     ]
+
+     for dataset_name, split, seq_key in sources:
+         try:
+             if args.use_streaming:
+                 ds = load_dataset(dataset_name, split=split, streaming=True)
+                 count = 0
+                 for ex in ds:
+                     seq = ex.get(seq_key, "")
+                     if isinstance(seq, str) and len(seq) >= 20:
+                         all_sequences.append(seq)
+                         count += 1
+                     if count >= args.max_sequences:
+                         break
+             else:
+                 ds = load_dataset(dataset_name, split=split)
+                 for ex in ds:
+                     seq = ex.get(seq_key, "")
+                     if isinstance(seq, str) and len(seq) >= 20:
+                         all_sequences.append(seq)
+             log_rank0(f"Loaded {len(all_sequences)} from {dataset_name}")
+         except Exception as e:
+             log_rank0(f"Failed {dataset_name}: {e}")
+
+     # Fallback to synthetic
+     if len(all_sequences) < 1000:
+         log_rank0("Using synthetic sequences for testing")
+         amino_acids = "ACDEFGHIKLMNPQRSTVWY"
+         all_sequences = [
+             "".join(random.choices(amino_acids, k=random.randint(50, 500)))
+             for _ in range(min(args.max_sequences, 50000))
+         ]
+
+     # Limit total
+     if len(all_sequences) > args.max_sequences:
+         random.shuffle(all_sequences)
+         all_sequences = all_sequences[:args.max_sequences]
+
+     return all_sequences
+
+
+ # =============================================================================
+ # MODELS
+ # =============================================================================
+
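+ # ELECTRA pairs a small generator with a larger discriminator; typically only
+ # the discriminator is kept for downstream fine-tuning. With the defaults
+ # below (320-d x 8 layers vs. 576-d x 28 layers) the generator is a fraction
+ # of the discriminator's size, in line with the ELECTRA recipe.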
+ class Generator(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+         config = ModernProteinLMConfig(
+             vocab_size=33,
+             hidden_size=args.gen_hidden_size,
+             num_hidden_layers=args.gen_num_layers,
+             num_attention_heads=args.gen_num_heads,
+             intermediate_size=args.gen_intermediate_size,
+             use_geglu=True,
+             tie_word_embeddings=True,
+             max_position_embeddings=args.max_seq_length + 2,
+         )
+         self.model = ModernProteinLM(config)
+
+     def forward(self, input_ids, attention_mask, labels):
+         return self.model(input_ids, attention_mask, labels=labels)
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+         config = ModernProteinLMConfig(
+             vocab_size=33,
+             hidden_size=args.hidden_size,
+             num_hidden_layers=args.num_layers,
+             num_attention_heads=args.num_heads,
+             intermediate_size=args.intermediate_size,
+             use_geglu=True,
+             tie_word_embeddings=True,
+             max_position_embeddings=args.max_seq_length + 2,
+         )
+         self.model = ModernProteinLM(config)
+         self.discriminator_head = nn.Linear(args.hidden_size, 1)
+
+         params = sum(p.numel() for p in self.model.parameters())
+         log_rank0(f"Discriminator: {params/1e6:.1f}M params")
+
+     def forward(self, input_ids, attention_mask, disc_labels=None):
+         outputs = self.model(input_ids, attention_mask, output_hidden_states=True, return_dict=True)
+         hidden = outputs.hidden_states[-1]
+         logits = self.discriminator_head(hidden).squeeze(-1)
+
+         loss = None
+         if disc_labels is not None:
+             loss_fct = nn.BCEWithLogitsLoss()
+             active = disc_labels != -100
+             if active.any():
+                 loss = loss_fct(logits[active], disc_labels[active].float())
+
+         return {"loss": loss, "logits": logits, "hidden_states": hidden}
+
+
+ # =============================================================================
+ # TRAINING
+ # =============================================================================
+
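+ # Joint objective, following ELECTRA: L = gen_weight * L_MLM + disc_weight * L_RTD.
+ # The replaced-token-detection loss is per-token and much smaller in magnitude
+ # than the MLM loss, hence the large default disc_weight (50, as in the paper).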
+ class Trainer:
+     def __init__(self, args, generator, discriminator, tokenizer, device, rank, world_size):
+         self.args = args
+         self.generator = generator.to(device)
+         self.discriminator = discriminator.to(device)
+         self.tokenizer = tokenizer
+         self.device = device
+         self.rank = rank
+         self.world_size = world_size
+         self.global_step = 0
+
+         if world_size > 1:
+             # device_ids must be the *local* device, not the global rank.
+             self.generator = DDP(self.generator, device_ids=[device.index], find_unused_parameters=False)
+             self.discriminator = DDP(self.discriminator, device_ids=[device.index], find_unused_parameters=False)
+
+         self.gen_opt = torch.optim.AdamW(
+             generator.parameters(), lr=args.lr,
+             betas=(0.9, 0.98), eps=1e-6, weight_decay=args.weight_decay
+         )
+         self.disc_opt = torch.optim.AdamW(
+             discriminator.parameters(), lr=args.lr,
+             betas=(0.9, 0.98), eps=1e-6, weight_decay=args.weight_decay
+         )
+
+         self.gen_sched = get_cosine_schedule_with_warmup(
+             self.gen_opt, args.warmup_steps, args.max_steps
+         )
+         self.disc_sched = get_cosine_schedule_with_warmup(
+             self.disc_opt, args.warmup_steps, args.max_steps
+         )
+
+         self.scaler = GradScaler() if args.use_amp else None
+
+         if args.gradient_checkpointing:
+             gen_core = self.generator.module if world_size > 1 else self.generator
+             disc_core = self.discriminator.module if world_size > 1 else self.discriminator
+             gen_core.model.gradient_checkpointing_enable()
+             disc_core.model.gradient_checkpointing_enable()
+
+         # Trackio
+         self.trackio = None
+         if args.use_trackio:
+             try:
+                 import trackio
+                 trackio.init(project=args.trackio_project, space_id=args.trackio_space_id or None)
+                 self.trackio = trackio
+                 log_rank0("Trackio initialized")
+             except ImportError:
+                 log_rank0("Trackio not available")
+
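+     # One ELECTRA step: (1) generator MLM loss on the masked input, (2) sample
+     # replacement tokens from the generator under no_grad (no gradient flows
+     # through sampling), (3) splice them into the original sequence, and
+     # (4) train the discriminator to tell original from replaced tokens.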
+     def train_step(self, batch):
+         input_ids = batch["input_ids"].to(self.device)
+         attention_mask = batch["attention_mask"].to(self.device)
+         mlm_labels = batch["mlm_labels"].to(self.device)
+         original_ids = batch["original_ids"].to(self.device)
+
+         with autocast(enabled=self.args.use_amp):
+             # Generator
+             gen_out = self.generator(input_ids, attention_mask, mlm_labels)
+             gen_loss = gen_out.loss
+
+             # Sample corrupted input
+             with torch.no_grad():
+                 gen_logits = gen_out.logits
+                 gen_probs = F.softmax(gen_logits, dim=-1)
+                 sampled = torch.multinomial(
+                     gen_probs.view(-1, gen_probs.size(-1)), 1
+                 ).view(gen_probs.shape[:-1])
+
+                 corrupted = original_ids.clone()
+                 mask_pos = mlm_labels != -100
+                 corrupted[mask_pos] = sampled[mask_pos]
+
+             # Discriminator targets: 1 = original, 0 = replaced. Per the ELECTRA
+             # convention, a sampled token that happens to equal the original
+             # counts as "original"; padding is ignored via -100.
+             disc_labels = (corrupted == original_ids).float()
+             disc_labels[attention_mask == 0] = -100
+
+             disc_out = self.discriminator(corrupted, attention_mask, disc_labels)
+             disc_loss = disc_out["loss"]
+
+             total_loss = self.args.gen_weight * gen_loss + self.args.disc_weight * disc_loss
+
+         # Backward
+         if self.scaler:
+             self.scaler.scale(total_loss).backward()
+             self.scaler.unscale_(self.gen_opt)
+             self.scaler.unscale_(self.disc_opt)
+             torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.args.grad_clip)
+             torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.args.grad_clip)
+             self.scaler.step(self.gen_opt)
+             self.scaler.step(self.disc_opt)
+             self.scaler.update()
+         else:
+             total_loss.backward()
+             torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.args.grad_clip)
+             torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.args.grad_clip)
+             self.gen_opt.step()
+             self.disc_opt.step()
+
+         self.gen_sched.step()
+         self.disc_sched.step()
+         self.gen_opt.zero_grad()
+         self.disc_opt.zero_grad()
+
+         self.global_step += 1
+
+         return {
+             "gen_loss": gen_loss.item(),
+             "disc_loss": disc_loss.item() if disc_loss is not None else 0.0,
+             "total_loss": total_loss.item(),
+             "lr": self.gen_sched.get_last_lr()[0],
+         }
+
+     def evaluate(self, eval_loader):
+         self.generator.eval()
+         self.discriminator.eval()
+
+         total_gen = 0.0
+         total_disc = 0.0
+         n = 0
+
+         with torch.no_grad():
+             for batch in eval_loader:
+                 input_ids = batch["input_ids"].to(self.device)
+                 attention_mask = batch["attention_mask"].to(self.device)
+                 mlm_labels = batch["mlm_labels"].to(self.device)
+                 original_ids = batch["original_ids"].to(self.device)
+
+                 gen_out = self.generator(input_ids, attention_mask, mlm_labels)
+                 total_gen += gen_out.loss.item()
+
+                 # Mirror the training setup: score the discriminator on
+                 # generator-corrupted sequences rather than on the masked input.
+                 gen_probs = F.softmax(gen_out.logits, dim=-1)
+                 sampled = torch.multinomial(
+                     gen_probs.view(-1, gen_probs.size(-1)), 1
+                 ).view(gen_probs.shape[:-1])
+                 corrupted = original_ids.clone()
+                 mask_pos = mlm_labels != -100
+                 corrupted[mask_pos] = sampled[mask_pos]
+
+                 disc_labels = (corrupted == original_ids).float()
+                 disc_labels[attention_mask == 0] = -100
+
+                 disc_out = self.discriminator(corrupted, attention_mask, disc_labels)
+                 if disc_out["loss"] is not None:
+                     total_disc += disc_out["loss"].item()
+                 n += 1
+
+         self.generator.train()
+         self.discriminator.train()
+
+         return {"gen_loss": total_gen / max(n, 1), "disc_loss": total_disc / max(n, 1)}
+
+     def save(self, path, name):
+         # Only rank 0 writes, to avoid concurrent writes to the same file.
+         if self.rank != 0:
+             return
+         save_dir = os.path.join(path, name)
+         os.makedirs(save_dir, exist_ok=True)
+
+         gen_state = self.generator.module.state_dict() if self.world_size > 1 else self.generator.state_dict()
+         disc_state = self.discriminator.module.state_dict() if self.world_size > 1 else self.discriminator.state_dict()
+
+         torch.save({
+             "generator": gen_state,
+             "discriminator": disc_state,
+             "step": self.global_step,
+         }, os.path.join(save_dir, "checkpoint.pt"))
+
+         log_rank0(f"Saved checkpoint to {save_dir}")
+
+     def train(self, train_loader, eval_loader=None):
+         log_rank0(f"\n{'='*60}")
+         log_rank0(f"ELECTRA Pre-training: {self.args.max_steps} steps")
+         log_rank0(f"{'='*60}\n")
+
+         self.generator.train()
+         self.discriminator.train()
+
+         epoch = 0
+         while self.global_step < self.args.max_steps:
+             epoch += 1
+             if isinstance(train_loader.sampler, DistributedSampler):
+                 train_loader.sampler.set_epoch(epoch)
+             # Advance the masking curriculum; with num_workers > 0 this takes
+             # effect when worker processes are re-forked at the next epoch.
+             train_loader.dataset.current_step = self.global_step
+
+             for batch in train_loader:
+                 if self.global_step >= self.args.max_steps:
+                     break
+
+                 metrics = self.train_step(batch)
+
+                 if self.global_step % self.args.log_interval == 0 and self.rank == 0:
+                     log_rank0(
+                         f"Step {self.global_step:6d} | "
+                         f"gen_loss={metrics['gen_loss']:.4f} | "
+                         f"disc_loss={metrics['disc_loss']:.4f} | "
+                         f"total={metrics['total_loss']:.4f} | "
+                         f"lr={metrics['lr']:.2e}"
+                     )
+
+                     if self.trackio:
+                         self.trackio.log(metrics, step=self.global_step)
+
+                 if eval_loader and self.global_step % self.args.eval_interval == 0:
+                     eval_metrics = self.evaluate(eval_loader)
+                     if self.rank == 0:
+                         log_rank0(f"Eval @ {self.global_step}: gen={eval_metrics['gen_loss']:.4f}, disc={eval_metrics['disc_loss']:.4f}")
+                         if self.trackio:
+                             self.trackio.log({f"eval_{k}": v for k, v in eval_metrics.items()}, step=self.global_step)
+
+                 if self.global_step % self.args.save_interval == 0:
+                     self.save(self.args.output_dir, f"step_{self.global_step}")
+
+         # Final save
+         self.save(self.args.output_dir, "final")
+
+
+ # =============================================================================
+ # MAIN
+ # =============================================================================
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+
+     # Model
+     parser.add_argument("--hidden_size", type=int, default=576)
+     parser.add_argument("--num_layers", type=int, default=28)
+     parser.add_argument("--num_heads", type=int, default=9)
+     parser.add_argument("--intermediate_size", type=int, default=2304)
+     parser.add_argument("--gen_hidden_size", type=int, default=320)
+     parser.add_argument("--gen_num_layers", type=int, default=8)
+     parser.add_argument("--gen_num_heads", type=int, default=8)
+     parser.add_argument("--gen_intermediate_size", type=int, default=1280)
+     parser.add_argument("--max_seq_length", type=int, default=1024)
+
+     # Training
+     parser.add_argument("--batch_size", type=int, default=64)
+     parser.add_argument("--max_steps", type=int, default=100000)
+     parser.add_argument("--warmup_steps", type=int, default=10000)
+     parser.add_argument("--lr", type=float, default=5e-4)
+     parser.add_argument("--weight_decay", type=float, default=0.01)
+     parser.add_argument("--grad_clip", type=float, default=1.0)
+     parser.add_argument("--gen_weight", type=float, default=1.0)
+     parser.add_argument("--disc_weight", type=float, default=50.0)
+
+     # Masking
+     parser.add_argument("--mask_start", type=float, default=0.30)
+     parser.add_argument("--mask_end", type=float, default=0.05)
+     parser.add_argument("--span_length", type=int, default=3)
+
+     # Data
+     parser.add_argument("--max_sequences", type=int, default=1000000)
+     parser.add_argument("--use_streaming", action="store_true")
+
+     # System
+     parser.add_argument("--output_dir", default="./outputs/pretrain")
+     parser.add_argument("--num_workers", type=int, default=8)
+     parser.add_argument("--log_interval", type=int, default=100)
+     parser.add_argument("--eval_interval", type=int, default=5000)
+     parser.add_argument("--save_interval", type=int, default=5000)
+     parser.add_argument("--use_amp", action="store_true")
+     parser.add_argument("--use_flash_attn", action="store_true")
+     parser.add_argument("--resume_from", default="")
+     parser.add_argument("--gradient_checkpointing", action="store_true")
+     parser.add_argument("--seed", type=int, default=42)
+
+     # Tracking
+     parser.add_argument("--use_trackio", action="store_true")
+     parser.add_argument("--trackio_project", default="modern-protein-lm")
+     parser.add_argument("--trackio_space_id", default="")
+
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+
+     rank, world_size, local_rank = setup_distributed()
+
+     # Seed identically across ranks first, so every rank builds the same
+     # dataset and train/eval split (DistributedSampler assumes identical data).
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+     torch.manual_seed(args.seed)
+
+     device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+
+     # Load data
+     tokenizer = ProteinTokenizer()
+     sequences = load_sequences(args)
+
+     if world_size > 1:
+         dist.barrier()
+
+     # Split
+     n_train = int(0.95 * len(sequences))
+     train_seqs = sequences[:n_train]
+     eval_seqs = sequences[n_train:]
+
+     train_dataset = PretrainDataset(train_seqs, tokenizer, args)
+     eval_dataset = PretrainDataset(eval_seqs, tokenizer, args)
+
+     # Re-seed with a per-rank offset so masking randomness decorrelates.
+     random.seed(args.seed + rank)
+     torch.manual_seed(args.seed + rank)
+
+     if world_size > 1:
+         train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
+         eval_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
+     else:
+         train_sampler = None
+         eval_sampler = None
+
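+     # batch_size is per process: the effective global batch is
+     # batch_size * world_size (this script does no gradient accumulation).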
+     train_loader = DataLoader(
+         train_dataset, batch_size=args.batch_size, sampler=train_sampler,
+         shuffle=(train_sampler is None),  # shuffle here when no DistributedSampler
+         num_workers=args.num_workers, pin_memory=True, drop_last=True,
+     )
+     eval_loader = DataLoader(
+         eval_dataset, batch_size=args.batch_size, sampler=eval_sampler,
+         num_workers=args.num_workers, pin_memory=True, drop_last=False,
+     )
+
+     # Models
+     generator = Generator(args)
+     discriminator = Discriminator(args)
+
+     gen_params = sum(p.numel() for p in generator.parameters())
+     log_rank0(f"Generator: {gen_params/1e6:.1f}M params")
+
+     # Resume (weights only; optimizer state and step counter are not restored)
+     if args.resume_from:
+         checkpoint = torch.load(args.resume_from, map_location="cpu")
+         generator.load_state_dict(checkpoint["generator"])
+         discriminator.load_state_dict(checkpoint["discriminator"])
+         log_rank0(f"Resumed from {args.resume_from}")
+
+     trainer = Trainer(args, generator, discriminator, tokenizer, device, rank, world_size)
+     trainer.train(train_loader, eval_loader)
+
+     cleanup_distributed()
+     log_rank0("Training complete!")
+
+
+ if __name__ == "__main__":
+     main()
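+
+ # Loading the saved weights later (a minimal sketch, assuming the same args):
+ #   ckpt = torch.load("outputs/pretrain/final/checkpoint.pt", map_location="cpu")
+ #   disc = Discriminator(args)
+ #   disc.load_state_dict(ckpt["discriminator"])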