""" Production ELECTRA pre-training script for ModernProteinLM. Supports: single GPU, multi-GPU DDP, FSDP (optional), bf16 AMP, gradient checkpointing. """ import os import sys import argparse import math import random import time import json from typing import List, Dict, Optional import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, Dataset, DistributedSampler from torch.cuda.amp import autocast, GradScaler from transformers import get_cosine_schedule_with_warmup from datasets import load_dataset from tqdm import tqdm from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig def setup_distributed(): if "RANK" in os.environ and "WORLD_SIZE" in os.environ: rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) local_rank = int(os.environ.get("LOCAL_RANK", 0)) dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) torch.cuda.set_device(local_rank) return rank, world_size, local_rank return 0, 1, 0 def cleanup_distributed(): if dist.is_initialized(): dist.destroy_process_group() def log_rank0(msg): if not dist.is_initialized() or dist.get_rank() == 0: print(msg) # ============================================================================= # TOKENIZER # ============================================================================= class ProteinTokenizer: """ESM-2 compatible protein tokenizer.""" def __init__(self): self.vocab = { "": 0, "": 1, "": 2, "": 3, "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10, "T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17, "F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24, "B": 25, "U": 26, "Z": 27, "O": 28, "": 29, "": 30, } while len(self.vocab) < 33: self.vocab[f""] = len(self.vocab) self.id_to_token = {v: k for k, v in self.vocab.items()} self.mask_token_id = 29 self.pad_token_id = 1 self.cls_token_id = 0 self.eos_token_id = 2 def encode(self, sequence: str, max_length: int = 1024, add_special_tokens: bool = True): tokens = [] if add_special_tokens: tokens.append(self.cls_token_id) for aa in sequence.upper(): tokens.append(self.vocab.get(aa, self.vocab[""])) if add_special_tokens: tokens.append(self.eos_token_id) if len(tokens) > max_length: tokens = tokens[:max_length] attention_mask = [1] * len(tokens) while len(tokens) < max_length: tokens.append(self.pad_token_id) attention_mask.append(0) return { "input_ids": tokens, "attention_mask": attention_mask, } # ============================================================================= # MASKING # ============================================================================= def create_span_mask(length: int, mask_ratio: float, mean_span_length: int = 3): num_to_mask = max(1, int(length * mask_ratio)) mask = [False] * length masked = 0 attempts = 0 while masked < num_to_mask and attempts < num_to_mask * 10: span_len = max(1, min(mean_span_length + random.randint(-1, 1), num_to_mask - masked)) start = random.randint(0, max(0, length - span_len)) if any(mask[start:start+span_len]): attempts += 1 continue for i in range(start, min(start + span_len, length)): mask[i] = True masked += 1 return mask # ============================================================================= # DATASET # ============================================================================= class PretrainDataset(Dataset): def __init__(self, sequences: List[str], tokenizer, args, current_step: int = 0): self.sequences = sequences self.tokenizer = tokenizer self.args = args self.current_step = current_step def get_mask_ratio(self): progress = min(1.0, self.current_step / self.args.max_steps) return self.args.mask_start + (self.args.mask_end - self.args.mask_start) * progress def __len__(self): return len(self.sequences) def __getitem__(self, idx): seq = self.sequences[idx] encoded = self.tokenizer.encode(seq, max_length=self.args.max_seq_length) input_ids = encoded["input_ids"] attention_mask = encoded["attention_mask"] seq_len = sum(attention_mask) effective_len = max(1, seq_len - 2) span_mask = create_span_mask(effective_len, self.get_mask_ratio(), self.args.span_length) masked_input = input_ids.copy() labels = [-100] * len(input_ids) replaced = [False] * len(input_ids) for i in range(1, 1 + effective_len): if span_mask[i - 1]: labels[i] = input_ids[i] replaced[i] = True r = random.random() if r < 0.8: masked_input[i] = self.tokenizer.mask_token_id elif r < 0.9: masked_input[i] = random.randint(4, 28) return { "input_ids": torch.tensor(masked_input, dtype=torch.long), "attention_mask": torch.tensor(attention_mask, dtype=torch.long), "mlm_labels": torch.tensor(labels, dtype=torch.long), "replaced": torch.tensor(replaced, dtype=torch.bool), "original_ids": torch.tensor(input_ids, dtype=torch.long), } def load_sequences(args): all_sequences = [] # Try HF datasets first sources = [ ("lamm-mit/protein_secondary_structure_from_PDB", "train", "input"), ("adamstogsdill/pdb_protein_dataset_100_4000_1024", "train", "sequence"), ] for dataset_name, split, seq_key in sources: try: if args.use_streaming: ds = load_dataset(dataset_name, split=split, streaming=True) count = 0 for ex in ds: seq = ex.get(seq_key, "") if isinstance(seq, str) and len(seq) >= 20: all_sequences.append(seq) count += 1 if count >= args.max_sequences: break else: ds = load_dataset(dataset_name, split=split) for ex in ds: seq = ex.get(seq_key, "") if isinstance(seq, str) and len(seq) >= 20: all_sequences.append(seq) log_rank0(f"Loaded {len(all_sequences)} from {dataset_name}") except Exception as e: log_rank0(f"Failed {dataset_name}: {e}") # Fallback to synthetic if len(all_sequences) < 1000: log_rank0("Using synthetic sequences for testing") amino_acids = "ACDEFGHIKLMNPQRSTVWY" all_sequences = [ "".join(random.choices(amino_acids, k=random.randint(50, 500))) for _ in range(min(args.max_sequences, 50000)) ] # Limit total if len(all_sequences) > args.max_sequences: random.shuffle(all_sequences) all_sequences = all_sequences[:args.max_sequences] return all_sequences # ============================================================================= # MODELS # ============================================================================= class Generator(nn.Module): def __init__(self, args): super().__init__() config = ModernProteinLMConfig( vocab_size=33, hidden_size=args.gen_hidden_size, num_hidden_layers=args.gen_num_layers, num_attention_heads=args.gen_num_heads, intermediate_size=args.gen_intermediate_size, use_geglu=True, tie_word_embeddings=True, max_position_embeddings=args.max_seq_length + 2, ) self.model = ModernProteinLM(config) def forward(self, input_ids, attention_mask, labels): return self.model(input_ids, attention_mask, labels=labels) class Discriminator(nn.Module): def __init__(self, args): super().__init__() config = ModernProteinLMConfig( vocab_size=33, hidden_size=args.hidden_size, num_hidden_layers=args.num_layers, num_attention_heads=args.num_heads, intermediate_size=args.intermediate_size, use_geglu=True, tie_word_embeddings=True, max_position_embeddings=args.max_seq_length + 2, ) self.model = ModernProteinLM(config) self.discriminator_head = nn.Linear(args.hidden_size, 1) params = sum(p.numel() for p in self.model.parameters()) log_rank0(f"Discriminator: {params/1e6:.1f}M params") def forward(self, input_ids, attention_mask, disc_labels=None): outputs = self.model(input_ids, attention_mask, output_hidden_states=True, return_dict=True) hidden = outputs.hidden_states[-1] logits = self.discriminator_head(hidden).squeeze(-1) loss = None if disc_labels is not None: loss_fct = nn.BCEWithLogitsLoss() active = disc_labels != -100 if active.any(): loss = loss_fct(logits[active], disc_labels[active].float()) return {"loss": loss, "logits": logits, "hidden_states": hidden} # ============================================================================= # TRAINING # ============================================================================= class Trainer: def __init__(self, args, generator, discriminator, tokenizer, device, rank, world_size): self.args = args self.generator = generator.to(device) self.discriminator = discriminator.to(device) self.tokenizer = tokenizer self.device = device self.rank = rank self.world_size = world_size self.global_step = 0 if world_size > 1: self.generator = DDP(self.generator, device_ids=[rank], find_unused_parameters=False) self.discriminator = DDP(self.discriminator, device_ids=[rank], find_unused_parameters=False) self.gen_opt = torch.optim.AdamW( generator.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=args.weight_decay ) self.disc_opt = torch.optim.AdamW( discriminator.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=args.weight_decay ) self.gen_sched = get_cosine_schedule_with_warmup( self.gen_opt, args.warmup_steps, args.max_steps ) self.disc_sched = get_cosine_schedule_with_warmup( self.disc_opt, args.warmup_steps, args.max_steps ) self.scaler = GradScaler() if args.use_amp else None if args.gradient_checkpointing: self.generator.module.model.gradient_checkpointing_enable() if world_size > 1 else self.generator.model.gradient_checkpointing_enable() self.discriminator.module.model.gradient_checkpointing_enable() if world_size > 1 else self.discriminator.model.gradient_checkpointing_enable() # Trackio self.trackio = None if args.use_trackio: try: import trackio trackio.init(project=args.trackio_project, space_id=args.trackio_space_id or None) self.trackio = trackio log_rank0("Trackio initialized") except ImportError: log_rank0("Trackio not available") def train_step(self, batch): input_ids = batch["input_ids"].to(self.device) attention_mask = batch["attention_mask"].to(self.device) mlm_labels = batch["mlm_labels"].to(self.device) replaced = batch["replaced"].to(self.device) original_ids = batch["original_ids"].to(self.device) with autocast(enabled=self.args.use_amp): # Generator gen_out = self.generator(input_ids, attention_mask, mlm_labels) gen_loss = gen_out.loss # Sample corrupted input with torch.no_grad(): gen_logits = gen_out.logits gen_probs = F.softmax(gen_logits, dim=-1) sampled = torch.multinomial( gen_probs.view(-1, gen_probs.size(-1)), 1 ).view(gen_probs.shape[:-1]) corrupted = original_ids.clone() mask_pos = mlm_labels != -100 corrupted[mask_pos] = sampled[mask_pos] # Discriminator disc_labels = torch.ones_like(original_ids, dtype=torch.float) disc_labels[replaced] = 0.0 disc_labels[attention_mask == 0] = -100 disc_out = self.discriminator(corrupted, attention_mask, disc_labels) disc_loss = disc_out["loss"] total_loss = self.args.gen_weight * gen_loss + self.args.disc_weight * disc_loss # Backward if self.scaler: self.scaler.scale(total_loss).backward() self.scaler.unscale_(self.gen_opt) self.scaler.unscale_(self.disc_opt) torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.args.grad_clip) torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.args.grad_clip) self.scaler.step(self.gen_opt) self.scaler.step(self.disc_opt) self.scaler.update() else: total_loss.backward() torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.args.grad_clip) torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.args.grad_clip) self.gen_opt.step() self.disc_opt.step() self.gen_sched.step() self.disc_sched.step() self.gen_opt.zero_grad() self.disc_opt.zero_grad() self.global_step += 1 return { "gen_loss": gen_loss.item(), "disc_loss": disc_loss.item() if disc_loss else 0.0, "total_loss": total_loss.item(), "lr": self.gen_sched.get_last_lr()[0], } def evaluate(self, eval_loader): self.generator.eval() self.discriminator.eval() total_gen = 0.0 total_disc = 0.0 n = 0 with torch.no_grad(): for batch in eval_loader: input_ids = batch["input_ids"].to(self.device) attention_mask = batch["attention_mask"].to(self.device) mlm_labels = batch["mlm_labels"].to(self.device) replaced = batch["replaced"].to(self.device) original_ids = batch["original_ids"].to(self.device) gen_out = self.generator(input_ids, attention_mask, mlm_labels) total_gen += gen_out.loss.item() disc_labels = torch.ones_like(original_ids, dtype=torch.float) disc_labels[replaced] = 0.0 disc_labels[attention_mask == 0] = -100 disc_out = self.discriminator(input_ids, attention_mask, disc_labels) if disc_out["loss"]: total_disc += disc_out["loss"].item() n += 1 self.generator.train() self.discriminator.train() return {"gen_loss": total_gen / max(n, 1), "disc_loss": total_disc / max(n, 1)} def save(self, path, name): save_dir = os.path.join(path, name) os.makedirs(save_dir, exist_ok=True) gen_state = self.generator.module.state_dict() if self.world_size > 1 else self.generator.state_dict() disc_state = self.discriminator.module.state_dict() if self.world_size > 1 else self.discriminator.state_dict() torch.save({ "generator": gen_state, "discriminator": disc_state, "step": self.global_step, }, os.path.join(save_dir, "checkpoint.pt")) log_rank0(f"Saved checkpoint to {save_dir}") def train(self, train_loader, eval_loader=None): log_rank0(f"\n{'='*60}") log_rank0(f"ELECTRA Pre-training: {self.args.max_steps} steps") log_rank0(f"{'='*60}\n") self.generator.train() self.discriminator.train() epoch = 0 while self.global_step < self.args.max_steps: epoch += 1 if isinstance(train_loader.sampler, DistributedSampler): train_loader.sampler.set_epoch(epoch) for batch in train_loader: if self.global_step >= self.args.max_steps: break metrics = self.train_step(batch) if self.global_step % self.args.log_interval == 0 and self.rank == 0: log_rank0( f"Step {self.global_step:6d} | " f"gen_loss={metrics['gen_loss']:.4f} | " f"disc_loss={metrics['disc_loss']:.4f} | " f"total={metrics['total_loss']:.4f} | " f"lr={metrics['lr']:.2e}" ) if self.trackio: self.trackio.log(metrics, step=self.global_step) if eval_loader and self.global_step % self.args.eval_interval == 0: eval_metrics = self.evaluate(eval_loader) if self.rank == 0: log_rank0(f"Eval @ {self.global_step}: gen={eval_metrics['gen_loss']:.4f}, disc={eval_metrics['disc_loss']:.4f}") if self.trackio: self.trackio.log({f"eval_{k}": v for k, v in eval_metrics.items()}, step=self.global_step) if self.global_step % self.args.save_interval == 0: self.save(self.args.output_dir, f"step_{self.global_step}") # Final save self.save(self.args.output_dir, "final") # ============================================================================= # MAIN # ============================================================================= def parse_args(): parser = argparse.ArgumentParser() # Model parser.add_argument("--hidden_size", type=int, default=576) parser.add_argument("--num_layers", type=int, default=28) parser.add_argument("--num_heads", type=int, default=9) parser.add_argument("--intermediate_size", type=int, default=2304) parser.add_argument("--gen_hidden_size", type=int, default=320) parser.add_argument("--gen_num_layers", type=int, default=8) parser.add_argument("--gen_num_heads", type=int, default=8) parser.add_argument("--gen_intermediate_size", type=int, default=1280) parser.add_argument("--max_seq_length", type=int, default=1024) # Training parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--max_steps", type=int, default=100000) parser.add_argument("--warmup_steps", type=int, default=10000) parser.add_argument("--lr", type=float, default=5e-4) parser.add_argument("--weight_decay", type=float, default=0.01) parser.add_argument("--grad_clip", type=float, default=1.0) parser.add_argument("--gen_weight", type=float, default=1.0) parser.add_argument("--disc_weight", type=float, default=50.0) # Masking parser.add_argument("--mask_start", type=float, default=0.30) parser.add_argument("--mask_end", type=float, default=0.05) parser.add_argument("--span_length", type=int, default=3) # Data parser.add_argument("--max_sequences", type=int, default=1000000) parser.add_argument("--use_streaming", action="store_true") # System parser.add_argument("--output_dir", default="./outputs/pretrain") parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--log_interval", type=int, default=100) parser.add_argument("--eval_interval", type=int, default=5000) parser.add_argument("--save_interval", type=int, default=5000) parser.add_argument("--use_amp", action="store_true") parser.add_argument("--use_flash_attn", action="store_true") parser.add_argument("--resume_from", default="") parser.add_argument("--gradient_checkpointing", action="store_true") parser.add_argument("--seed", type=int, default=42) # Tracking parser.add_argument("--use_trackio", action="store_true") parser.add_argument("--trackio_project", default="modern-protein-lm") parser.add_argument("--trackio_space_id", default="") return parser.parse_args() def main(): args = parse_args() rank, world_size, local_rank = setup_distributed() # Set seed random.seed(args.seed + rank) np.random.seed(args.seed + rank) torch.manual_seed(args.seed + rank) device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") # Load data tokenizer = ProteinTokenizer() sequences = load_sequences(args) if world_size > 1: dist.barrier() # Split n_train = int(0.95 * len(sequences)) train_seqs = sequences[:n_train] eval_seqs = sequences[n_train:] train_dataset = PretrainDataset(train_seqs, tokenizer, args) eval_dataset = PretrainDataset(eval_seqs, tokenizer, args) if world_size > 1: train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) eval_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False) else: train_sampler = None eval_sampler = None train_loader = DataLoader( train_dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.num_workers, pin_memory=True, drop_last=True, ) eval_loader = DataLoader( eval_dataset, batch_size=args.batch_size, sampler=eval_sampler, num_workers=args.num_workers, pin_memory=True, drop_last=False, ) # Models generator = Generator(args) discriminator = Discriminator(args) gen_params = sum(p.numel() for p in generator.parameters()) log_rank0(f"Generator: {gen_params/1e6:.1f}M params") # Resume if args.resume_from: checkpoint = torch.load(args.resume_from, map_location="cpu") generator.load_state_dict(checkpoint["generator"]) discriminator.load_state_dict(checkpoint["discriminator"]) log_rank0(f"Resumed from {args.resume_from}") trainer = Trainer(args, generator, discriminator, tokenizer, device, rank, world_size) trainer.train(train_loader, eval_loader) cleanup_distributed() log_rank0("Training complete!") if __name__ == "__main__": main()