| """ |
| Production ELECTRA pre-training script for ModernProteinLM. |
| Supports: single GPU, multi-GPU DDP, FSDP (optional), bf16 AMP, gradient checkpointing. |
| """ |
|
|
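# Example launch (single node, 8 GPUs). The file name below is an assumption --
# substitute whatever this script is saved as:
#
#   torchrun --nproc_per_node=8 pretrain_electra.py \
#       --use_amp --gradient_checkpointing --use_streaming \
#       --batch_size 64 --max_steps 100000
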
import os
import sys
import argparse
import math
import random
import time
import json
from typing import List, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torch.cuda.amp import autocast, GradScaler
from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm

from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig


def setup_distributed():
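    """Initialise torch.distributed from torchrun env vars and pin the local GPU.

    Returns (rank, world_size, local_rank); falls back to (0, 1, 0) for single-process runs.
    """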
| if "RANK" in os.environ and "WORLD_SIZE" in os.environ: |
| rank = int(os.environ["RANK"]) |
| world_size = int(os.environ["WORLD_SIZE"]) |
| local_rank = int(os.environ.get("LOCAL_RANK", 0)) |
| dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) |
| torch.cuda.set_device(local_rank) |
| return rank, world_size, local_rank |
| return 0, 1, 0 |
|
|
|
|
def cleanup_distributed():
    if dist.is_initialized():
        dist.destroy_process_group()


def log_rank0(msg):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(msg)


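# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------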
class ProteinTokenizer:
    """ESM-2 compatible protein tokenizer."""

    def __init__(self):
        self.vocab = {
            "<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
            "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10,
            "T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17,
            "F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24,
            "B": 25, "U": 26, "Z": 27, "O": 28, "<mask>": 29,
            "<sep>": 30,
        }
        while len(self.vocab) < 33:
            self.vocab[f"<special_{len(self.vocab)}>"] = len(self.vocab)
        self.id_to_token = {v: k for k, v in self.vocab.items()}
        self.mask_token_id = 29
        self.pad_token_id = 1
        self.cls_token_id = 0
        self.eos_token_id = 2

    def encode(self, sequence: str, max_length: int = 1024, add_special_tokens: bool = True):
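        """Encode an amino-acid string as <cls> ... <eos>, truncated/padded to
        max_length; returns input_ids and a matching attention_mask."""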
        tokens = []
        if add_special_tokens:
            tokens.append(self.cls_token_id)
        for aa in sequence.upper():
            tokens.append(self.vocab.get(aa, self.vocab["<unk>"]))
        if add_special_tokens:
            tokens.append(self.eos_token_id)

        if len(tokens) > max_length:
            tokens = tokens[:max_length]
            if add_special_tokens:
                # Keep <eos> as the final token when the sequence is truncated.
                tokens[-1] = self.eos_token_id

        attention_mask = [1] * len(tokens)
        while len(tokens) < max_length:
            tokens.append(self.pad_token_id)
            attention_mask.append(0)

        return {
            "input_ids": tokens,
            "attention_mask": attention_mask,
        }


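# ---------------------------------------------------------------------------
# Span masking
# ---------------------------------------------------------------------------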
def create_span_mask(length: int, mask_ratio: float, mean_span_length: int = 3):
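    """Build a boolean mask over `length` positions by drawing contiguous spans
    (about `mean_span_length` tokens each) until roughly `mask_ratio` of the
    positions are covered."""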
    num_to_mask = max(1, int(length * mask_ratio))
    mask = [False] * length

    masked = 0
    attempts = 0
    while masked < num_to_mask and attempts < num_to_mask * 10:
        span_len = max(1, min(mean_span_length + random.randint(-1, 1), num_to_mask - masked))
        start = random.randint(0, max(0, length - span_len))
        if any(mask[start:start+span_len]):
            attempts += 1
            continue
        for i in range(start, min(start + span_len, length)):
            mask[i] = True
            masked += 1
    return mask


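# ---------------------------------------------------------------------------
# Pre-training dataset
# ---------------------------------------------------------------------------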
class PretrainDataset(Dataset):
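    """Masked-LM dataset for ELECTRA pre-training.

    Span masking is applied at a ratio annealed from args.mask_start to
    args.mask_end; masked positions use BERT-style 80/10/10 corruption.
    Each item returns the corrupted input_ids, the attention_mask, mlm_labels
    (-100 outside masked positions), a boolean `replaced` vector of masked
    positions, and the original_ids.
    """
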
    def __init__(self, sequences: List[str], tokenizer, args, current_step: int = 0):
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.args = args
        self.current_step = current_step

    def get_mask_ratio(self):
        # Linearly anneal the mask ratio from mask_start to mask_end over training;
        # the trainer refreshes current_step at the start of every epoch.
        progress = min(1.0, self.current_step / self.args.max_steps)
        return self.args.mask_start + (self.args.mask_end - self.args.mask_start) * progress

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        encoded = self.tokenizer.encode(seq, max_length=self.args.max_seq_length)
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Mask only real residues (skip <cls> at position 0 and the trailing <eos>).
        seq_len = sum(attention_mask)
        effective_len = max(1, seq_len - 2)

        span_mask = create_span_mask(effective_len, self.get_mask_ratio(), self.args.span_length)

        masked_input = input_ids.copy()
        labels = [-100] * len(input_ids)
        replaced = [False] * len(input_ids)

        for i in range(1, 1 + effective_len):
            if span_mask[i - 1]:
                labels[i] = input_ids[i]
                replaced[i] = True
                # BERT-style corruption: 80% <mask>, 10% random residue, 10% keep original.
                r = random.random()
                if r < 0.8:
                    masked_input[i] = self.tokenizer.mask_token_id
                elif r < 0.9:
                    masked_input[i] = random.randint(4, 28)

        return {
            "input_ids": torch.tensor(masked_input, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "mlm_labels": torch.tensor(labels, dtype=torch.long),
            "replaced": torch.tensor(replaced, dtype=torch.bool),
            "original_ids": torch.tensor(input_ids, dtype=torch.long),
        }


def load_sequences(args):
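    """Collect training sequences from the configured HF datasets (optionally
    streaming), falling back to synthetic random sequences when little real
    data is available."""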
    all_sequences = []
    # Use a rank-independent RNG so every DDP rank builds the identical dataset;
    # the per-rank seeds set in main() would otherwise shuffle differently per rank.
    rng = random.Random(args.seed)

    sources = [
        ("lamm-mit/protein_secondary_structure_from_PDB", "train", "input"),
        ("adamstogsdill/pdb_protein_dataset_100_4000_1024", "train", "sequence"),
    ]

    for dataset_name, split, seq_key in sources:
        try:
            if args.use_streaming:
                ds = load_dataset(dataset_name, split=split, streaming=True)
                count = 0
                for ex in ds:
                    seq = ex.get(seq_key, "")
                    if isinstance(seq, str) and len(seq) >= 20:
                        all_sequences.append(seq)
                        count += 1
                    if count >= args.max_sequences:
                        break
            else:
                ds = load_dataset(dataset_name, split=split)
                for ex in ds:
                    seq = ex.get(seq_key, "")
                    if isinstance(seq, str) and len(seq) >= 20:
                        all_sequences.append(seq)
            log_rank0(f"Loaded {len(all_sequences)} from {dataset_name}")
        except Exception as e:
            log_rank0(f"Failed {dataset_name}: {e}")

    # Fallback: synthetic random sequences keep the script runnable for smoke tests.
    if len(all_sequences) < 1000:
        log_rank0("Using synthetic sequences for testing")
        amino_acids = "ACDEFGHIKLMNPQRSTVWY"
        all_sequences = [
            "".join(rng.choices(amino_acids, k=rng.randint(50, 500)))
            for _ in range(min(args.max_sequences, 50000))
        ]

    if len(all_sequences) > args.max_sequences:
        rng.shuffle(all_sequences)
        all_sequences = all_sequences[:args.max_sequences]

    return all_sequences


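# ---------------------------------------------------------------------------
# ELECTRA generator and discriminator
# ---------------------------------------------------------------------------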
class Generator(nn.Module):
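    """Small masked-language-model generator that proposes replacements for masked positions."""
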
    def __init__(self, args):
        super().__init__()
        config = ModernProteinLMConfig(
            vocab_size=33,
            hidden_size=args.gen_hidden_size,
            num_hidden_layers=args.gen_num_layers,
            num_attention_heads=args.gen_num_heads,
            intermediate_size=args.gen_intermediate_size,
            use_geglu=True,
            tie_word_embeddings=True,
            max_position_embeddings=args.max_seq_length + 2,
        )
        self.model = ModernProteinLM(config)

    def forward(self, input_ids, attention_mask, labels):
        return self.model(input_ids, attention_mask, labels=labels)


class Discriminator(nn.Module):
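    """Replaced-token-detection model: backbone plus a per-token binary head
    (1 = original, 0 = replaced); positions labelled -100 are ignored."""
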
    def __init__(self, args):
        super().__init__()
        config = ModernProteinLMConfig(
            vocab_size=33,
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_layers,
            num_attention_heads=args.num_heads,
            intermediate_size=args.intermediate_size,
            use_geglu=True,
            tie_word_embeddings=True,
            max_position_embeddings=args.max_seq_length + 2,
        )
        self.model = ModernProteinLM(config)
        self.discriminator_head = nn.Linear(args.hidden_size, 1)

        params = sum(p.numel() for p in self.model.parameters())
        log_rank0(f"Discriminator: {params/1e6:.1f}M params")

    def forward(self, input_ids, attention_mask, disc_labels=None):
        outputs = self.model(input_ids, attention_mask, output_hidden_states=True, return_dict=True)
        hidden = outputs.hidden_states[-1]
        logits = self.discriminator_head(hidden).squeeze(-1)

        loss = None
        if disc_labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            active = disc_labels != -100
            if active.any():
                loss = loss_fct(logits[active], disc_labels[active].float())

        return {"loss": loss, "logits": logits, "hidden_states": hidden}


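# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------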
class Trainer:
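    """Joint ELECTRA training: total loss = gen_weight * MLM loss + disc_weight * RTD loss."""
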
    def __init__(self, args, generator, discriminator, tokenizer, device, rank, world_size):
        self.args = args
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.rank = rank
        self.world_size = world_size
        self.global_step = 0

        if world_size > 1:
            # Wrap with DDP on this process's local GPU (device.index), not the global rank.
            self.generator = DDP(self.generator, device_ids=[device.index], find_unused_parameters=False)
            self.discriminator = DDP(self.discriminator, device_ids=[device.index], find_unused_parameters=False)

        self.gen_opt = torch.optim.AdamW(
            generator.parameters(), lr=args.lr,
            betas=(0.9, 0.98), eps=1e-6, weight_decay=args.weight_decay
        )
        self.disc_opt = torch.optim.AdamW(
            discriminator.parameters(), lr=args.lr,
            betas=(0.9, 0.98), eps=1e-6, weight_decay=args.weight_decay
        )

        self.gen_sched = get_cosine_schedule_with_warmup(
            self.gen_opt, args.warmup_steps, args.max_steps
        )
        self.disc_sched = get_cosine_schedule_with_warmup(
            self.disc_opt, args.warmup_steps, args.max_steps
        )

        self.scaler = GradScaler() if args.use_amp else None

        if args.gradient_checkpointing:
            gen_core = self.generator.module if world_size > 1 else self.generator
            disc_core = self.discriminator.module if world_size > 1 else self.discriminator
            gen_core.model.gradient_checkpointing_enable()
            disc_core.model.gradient_checkpointing_enable()

        # Experiment tracking is initialised on rank 0 only to avoid duplicate runs.
        self.trackio = None
        if args.use_trackio and rank == 0:
            try:
                import trackio
                trackio.init(project=args.trackio_project, space_id=args.trackio_space_id or None)
                self.trackio = trackio
                log_rank0("Trackio initialized")
            except ImportError:
                log_rank0("Trackio not available")

    def train_step(self, batch):
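        """Run one ELECTRA step: generator MLM loss, sample replacements for the
        masked positions, discriminator RTD loss on the corrupted sequence, then
        a combined backward pass and optimizer step for both models."""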
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        mlm_labels = batch["mlm_labels"].to(self.device)
        original_ids = batch["original_ids"].to(self.device)

        with autocast(enabled=self.args.use_amp):
            # Generator: masked-language-model loss on the masked input.
            gen_out = self.generator(input_ids, attention_mask, mlm_labels)
            gen_loss = gen_out.loss

            # Sample replacement tokens from the generator (no gradient flows
            # through the sampling step).
            with torch.no_grad():
                gen_logits = gen_out.logits
                gen_probs = F.softmax(gen_logits, dim=-1)
                sampled = torch.multinomial(
                    gen_probs.view(-1, gen_probs.size(-1)), 1
                ).view(gen_probs.shape[:-1])

                corrupted = original_ids.clone()
                mask_pos = mlm_labels != -100
                corrupted[mask_pos] = sampled[mask_pos]

            # Discriminator labels: 1 = original token, 0 = replaced. A sampled
            # token that happens to equal the original counts as original.
            disc_labels = torch.ones_like(original_ids, dtype=torch.float)
            disc_labels[corrupted != original_ids] = 0.0
            disc_labels[attention_mask == 0] = -100

            disc_out = self.discriminator(corrupted, attention_mask, disc_labels)
            disc_loss = disc_out["loss"]
            if disc_loss is None:
                disc_loss = torch.zeros((), device=self.device)

            total_loss = self.args.gen_weight * gen_loss + self.args.disc_weight * disc_loss

        if self.scaler:
            self.scaler.scale(total_loss).backward()
            self.scaler.unscale_(self.gen_opt)
            self.scaler.unscale_(self.disc_opt)
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.args.grad_clip)
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.args.grad_clip)
            self.scaler.step(self.gen_opt)
            self.scaler.step(self.disc_opt)
            self.scaler.update()
        else:
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.args.grad_clip)
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.args.grad_clip)
            self.gen_opt.step()
            self.disc_opt.step()

        self.gen_sched.step()
        self.disc_sched.step()
        self.gen_opt.zero_grad()
        self.disc_opt.zero_grad()

        self.global_step += 1

        return {
            "gen_loss": gen_loss.item(),
            "disc_loss": disc_loss.item(),
            "total_loss": total_loss.item(),
            "lr": self.gen_sched.get_last_lr()[0],
        }

    def evaluate(self, eval_loader):
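        """Average generator MLM loss and discriminator RTD loss over eval_loader
        (per-rank shard only; metrics are not all-reduced across processes)."""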
        self.generator.eval()
        self.discriminator.eval()

        total_gen = 0.0
        total_disc = 0.0
        n = 0

        with torch.no_grad():
            for batch in eval_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                mlm_labels = batch["mlm_labels"].to(self.device)
                original_ids = batch["original_ids"].to(self.device)

                gen_out = self.generator(input_ids, attention_mask, mlm_labels)
                total_gen += gen_out.loss.item()

                # Mirror training: score the generator-corrupted sequence, not the masked input.
                gen_probs = F.softmax(gen_out.logits, dim=-1)
                sampled = torch.multinomial(
                    gen_probs.view(-1, gen_probs.size(-1)), 1
                ).view(gen_probs.shape[:-1])
                corrupted = original_ids.clone()
                mask_pos = mlm_labels != -100
                corrupted[mask_pos] = sampled[mask_pos]

                disc_labels = torch.ones_like(original_ids, dtype=torch.float)
                disc_labels[corrupted != original_ids] = 0.0
                disc_labels[attention_mask == 0] = -100

                disc_out = self.discriminator(corrupted, attention_mask, disc_labels)
                if disc_out["loss"] is not None:
                    total_disc += disc_out["loss"].item()
                n += 1

        self.generator.train()
        self.discriminator.train()

        return {"gen_loss": total_gen / max(n, 1), "disc_loss": total_disc / max(n, 1)}

    def save(self, path, name):
        """Write generator/discriminator weights and the step counter (rank 0 only)."""
        if self.rank != 0:
            return
        save_dir = os.path.join(path, name)
        os.makedirs(save_dir, exist_ok=True)

        gen_state = self.generator.module.state_dict() if self.world_size > 1 else self.generator.state_dict()
        disc_state = self.discriminator.module.state_dict() if self.world_size > 1 else self.discriminator.state_dict()

        torch.save({
            "generator": gen_state,
            "discriminator": disc_state,
            "step": self.global_step,
        }, os.path.join(save_dir, "checkpoint.pt"))

        log_rank0(f"Saved checkpoint to {save_dir}")

    def train(self, train_loader, eval_loader=None):
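        """Main loop: iterate over epochs until max_steps, logging, evaluating,
        and checkpointing at the configured intervals."""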
| log_rank0(f"\n{'='*60}") |
| log_rank0(f"ELECTRA Pre-training: {self.args.max_steps} steps") |
| log_rank0(f"{'='*60}\n") |
| |
| self.generator.train() |
| self.discriminator.train() |
| |
| epoch = 0 |
| while self.global_step < self.args.max_steps: |
| epoch += 1 |
| if isinstance(train_loader.sampler, DistributedSampler): |
| train_loader.sampler.set_epoch(epoch) |
| |
| for batch in train_loader: |
| if self.global_step >= self.args.max_steps: |
| break |
| |
| metrics = self.train_step(batch) |
| |
| if self.global_step % self.args.log_interval == 0 and self.rank == 0: |
| log_rank0( |
| f"Step {self.global_step:6d} | " |
| f"gen_loss={metrics['gen_loss']:.4f} | " |
| f"disc_loss={metrics['disc_loss']:.4f} | " |
| f"total={metrics['total_loss']:.4f} | " |
| f"lr={metrics['lr']:.2e}" |
| ) |
| |
| if self.trackio: |
| self.trackio.log(metrics, step=self.global_step) |
| |
| if eval_loader and self.global_step % self.args.eval_interval == 0: |
| eval_metrics = self.evaluate(eval_loader) |
| if self.rank == 0: |
| log_rank0(f"Eval @ {self.global_step}: gen={eval_metrics['gen_loss']:.4f}, disc={eval_metrics['disc_loss']:.4f}") |
| if self.trackio: |
| self.trackio.log({f"eval_{k}": v for k, v in eval_metrics.items()}, step=self.global_step) |
| |
| if self.global_step % self.args.save_interval == 0: |
| self.save(self.args.output_dir, f"step_{self.global_step}") |
| |
| |
| self.save(self.args.output_dir, "final") |
|
|
|
|
| |
| |
| |
|
|
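# ---------------------------------------------------------------------------
# Command-line interface and entry point
# ---------------------------------------------------------------------------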
def parse_args():
    parser = argparse.ArgumentParser()

    # Discriminator architecture
    parser.add_argument("--hidden_size", type=int, default=576)
    parser.add_argument("--num_layers", type=int, default=28)
    parser.add_argument("--num_heads", type=int, default=9)
    parser.add_argument("--intermediate_size", type=int, default=2304)
    # Generator architecture (smaller than the discriminator, as in ELECTRA)
    parser.add_argument("--gen_hidden_size", type=int, default=320)
    parser.add_argument("--gen_num_layers", type=int, default=8)
    parser.add_argument("--gen_num_heads", type=int, default=8)
    parser.add_argument("--gen_intermediate_size", type=int, default=1280)
    parser.add_argument("--max_seq_length", type=int, default=1024)

    # Optimization
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_steps", type=int, default=100000)
    parser.add_argument("--warmup_steps", type=int, default=10000)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--gen_weight", type=float, default=1.0)
    parser.add_argument("--disc_weight", type=float, default=50.0)

    # Masking schedule (annealed from mask_start to mask_end over max_steps)
    parser.add_argument("--mask_start", type=float, default=0.30)
    parser.add_argument("--mask_end", type=float, default=0.05)
    parser.add_argument("--span_length", type=int, default=3)

    # Data
    parser.add_argument("--max_sequences", type=int, default=1000000)
    parser.add_argument("--use_streaming", action="store_true")

    # Runtime, logging, and checkpointing
    parser.add_argument("--output_dir", default="./outputs/pretrain")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--eval_interval", type=int, default=5000)
    parser.add_argument("--save_interval", type=int, default=5000)
    parser.add_argument("--use_amp", action="store_true")
    parser.add_argument("--use_flash_attn", action="store_true")
    parser.add_argument("--resume_from", default="")
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    # Experiment tracking
    parser.add_argument("--use_trackio", action="store_true")
    parser.add_argument("--trackio_project", default="modern-protein-lm")
    parser.add_argument("--trackio_space_id", default="")

    return parser.parse_args()


def main():
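    """Entry point: set up (optional) DDP, build the data pipeline and models, and train."""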
    args = parse_args()

    rank, world_size, local_rank = setup_distributed()

    # Per-rank seed offsets give each rank different masking noise; the dataset
    # itself is shuffled with a rank-independent RNG inside load_sequences().
    random.seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)

    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

    tokenizer = ProteinTokenizer()
    sequences = load_sequences(args)

    if world_size > 1:
        dist.barrier()

    # 95/5 train/eval split (identical on every rank).
    n_train = int(0.95 * len(sequences))
    train_seqs = sequences[:n_train]
    eval_seqs = sequences[n_train:]

    train_dataset = PretrainDataset(train_seqs, tokenizer, args)
    eval_dataset = PretrainDataset(eval_seqs, tokenizer, args)

    if world_size > 1:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        eval_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    else:
        train_sampler = None
        eval_sampler = None

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, sampler=train_sampler,
        shuffle=(train_sampler is None),  # shuffle locally when no DistributedSampler is used
        num_workers=args.num_workers, pin_memory=True, drop_last=True,
    )
    eval_loader = DataLoader(
        eval_dataset, batch_size=args.batch_size, sampler=eval_sampler,
        num_workers=args.num_workers, pin_memory=True, drop_last=False,
    )

    generator = Generator(args)
    discriminator = Discriminator(args)

    gen_params = sum(p.numel() for p in generator.parameters())
    log_rank0(f"Generator: {gen_params/1e6:.1f}M params")

    if args.resume_from:
        # NOTE: restores model weights only; optimizer, scheduler, and the step
        # counter start fresh.
        checkpoint = torch.load(args.resume_from, map_location="cpu")
        generator.load_state_dict(checkpoint["generator"])
        discriminator.load_state_dict(checkpoint["discriminator"])
        log_rank0(f"Resumed from {args.resume_from}")

    trainer = Trainer(args, generator, discriminator, tokenizer, device, rank, world_size)
    trainer.train(train_loader, eval_loader)

    cleanup_distributed()
    log_rank0("Training complete!")


if __name__ == "__main__":
    main()