"""
Production ELECTRA pre-training script for ModernProteinLM.
Supports: single-GPU and multi-GPU (DDP) training, mixed-precision AMP, and gradient checkpointing.
"""
import os
import sys
import argparse
import math
import random
import time
import json
from typing import List, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torch.cuda.amp import autocast, GradScaler
from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig
def setup_distributed():
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", 0))
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(local_rank)
return rank, world_size, local_rank
return 0, 1, 0
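# setup_distributed() reads the RANK / WORLD_SIZE / LOCAL_RANK variables exported by torchrun;
# when they are absent it falls back to single-process mode (rank 0, world size 1).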
def cleanup_distributed():
if dist.is_initialized():
dist.destroy_process_group()
def log_rank0(msg):
if not dist.is_initialized() or dist.get_rank() == 0:
print(msg)
# =============================================================================
# TOKENIZER
# =============================================================================
class ProteinTokenizer:
"""ESM-2 compatible protein tokenizer."""
def __init__(self):
self.vocab = {
"<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
"L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10,
"T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17,
"F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24,
"B": 25, "U": 26, "Z": 27, "O": 28, "<mask>": 29,
"<sep>": 30,
}
while len(self.vocab) < 33:
self.vocab[f"<special_{len(self.vocab)}>"] = len(self.vocab)
self.id_to_token = {v: k for k, v in self.vocab.items()}
self.mask_token_id = 29
self.pad_token_id = 1
self.cls_token_id = 0
self.eos_token_id = 2
def encode(self, sequence: str, max_length: int = 1024, add_special_tokens: bool = True):
tokens = []
if add_special_tokens:
tokens.append(self.cls_token_id)
for aa in sequence.upper():
tokens.append(self.vocab.get(aa, self.vocab["<unk>"]))
if add_special_tokens:
tokens.append(self.eos_token_id)
        if len(tokens) > max_length:
            tokens = tokens[:max_length]
            if add_special_tokens:
                # Keep the trailing <eos> after truncation so the last position stays a special token.
                tokens[-1] = self.eos_token_id
attention_mask = [1] * len(tokens)
while len(tokens) < max_length:
tokens.append(self.pad_token_id)
attention_mask.append(0)
return {
"input_ids": tokens,
"attention_mask": attention_mask,
}
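# Example (using the vocabulary above):
#   ProteinTokenizer().encode("MKT", max_length=8)
#   -> input_ids      [0, 21, 16, 11, 2, 1, 1, 1]   (<cls> M K T <eos> <pad> <pad> <pad>)
#      attention_mask [1, 1, 1, 1, 1, 0, 0, 0]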
# =============================================================================
# MASKING
# =============================================================================
def create_span_mask(length: int, mask_ratio: float, mean_span_length: int = 3):
num_to_mask = max(1, int(length * mask_ratio))
mask = [False] * length
masked = 0
attempts = 0
while masked < num_to_mask and attempts < num_to_mask * 10:
span_len = max(1, min(mean_span_length + random.randint(-1, 1), num_to_mask - masked))
start = random.randint(0, max(0, length - span_len))
if any(mask[start:start+span_len]):
attempts += 1
continue
for i in range(start, min(start + span_len, length)):
mask[i] = True
masked += 1
return mask
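# create_span_mask marks contiguous spans of roughly mean_span_length (+/- 1) positions until
# about mask_ratio * length positions are covered, retrying overlapping starts up to a fixed
# attempt budget; e.g. length=100, mask_ratio=0.15 masks ~15 positions across ~5 spans.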
# =============================================================================
# DATASET
# =============================================================================
class PretrainDataset(Dataset):
def __init__(self, sequences: List[str], tokenizer, args, current_step: int = 0):
self.sequences = sequences
self.tokenizer = tokenizer
self.args = args
self.current_step = current_step
def get_mask_ratio(self):
progress = min(1.0, self.current_step / self.args.max_steps)
return self.args.mask_start + (self.args.mask_end - self.args.mask_start) * progress
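    # Note: current_step is only what was passed at construction time; for the mask-ratio
    # schedule to actually anneal from mask_start to mask_end during training, the trainer
    # would need to keep dataset.current_step in sync (non-trivial once DataLoader workers
    # hold their own copies of the dataset).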
def __len__(self):
return len(self.sequences)
def __getitem__(self, idx):
seq = self.sequences[idx]
encoded = self.tokenizer.encode(seq, max_length=self.args.max_seq_length)
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
seq_len = sum(attention_mask)
effective_len = max(1, seq_len - 2)
span_mask = create_span_mask(effective_len, self.get_mask_ratio(), self.args.span_length)
masked_input = input_ids.copy()
labels = [-100] * len(input_ids)
replaced = [False] * len(input_ids)
for i in range(1, 1 + effective_len):
if span_mask[i - 1]:
labels[i] = input_ids[i]
replaced[i] = True
r = random.random()
if r < 0.8:
masked_input[i] = self.tokenizer.mask_token_id
elif r < 0.9:
masked_input[i] = random.randint(4, 28)
return {
"input_ids": torch.tensor(masked_input, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
"mlm_labels": torch.tensor(labels, dtype=torch.long),
"replaced": torch.tensor(replaced, dtype=torch.bool),
"original_ids": torch.tensor(input_ids, dtype=torch.long),
}
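# Corruption recipe per masked position (BERT-style 80/10/10): 80% -> <mask>, 10% -> a random
# token id in [4, 28], 10% -> kept unchanged. Every span position is flagged in `replaced`
# and exposed through mlm_labels for the generator loss.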
def load_sequences(args):
all_sequences = []
# Try HF datasets first
sources = [
("lamm-mit/protein_secondary_structure_from_PDB", "train", "input"),
("adamstogsdill/pdb_protein_dataset_100_4000_1024", "train", "sequence"),
]
for dataset_name, split, seq_key in sources:
try:
if args.use_streaming:
ds = load_dataset(dataset_name, split=split, streaming=True)
count = 0
for ex in ds:
seq = ex.get(seq_key, "")
if isinstance(seq, str) and len(seq) >= 20:
all_sequences.append(seq)
count += 1
if count >= args.max_sequences:
break
else:
ds = load_dataset(dataset_name, split=split)
for ex in ds:
seq = ex.get(seq_key, "")
if isinstance(seq, str) and len(seq) >= 20:
all_sequences.append(seq)
            log_rank0(f"Total sequences after {dataset_name}: {len(all_sequences)}")
except Exception as e:
log_rank0(f"Failed {dataset_name}: {e}")
    # Use a seed-only RNG here so every DDP rank builds an identical corpus and split;
    # the global RNG is seeded per rank (seed + rank) in main() and would diverge.
    rng = random.Random(args.seed)
    # Fallback to synthetic
    if len(all_sequences) < 1000:
        log_rank0("Using synthetic sequences for testing")
        amino_acids = "ACDEFGHIKLMNPQRSTVWY"
        all_sequences = [
            "".join(rng.choices(amino_acids, k=rng.randint(50, 500)))
            for _ in range(min(args.max_sequences, 50000))
        ]
    # Limit total
    if len(all_sequences) > args.max_sequences:
        rng.shuffle(all_sequences)
        all_sequences = all_sequences[:args.max_sequences]
    return all_sequences
# =============================================================================
# MODELS
# =============================================================================
class Generator(nn.Module):
def __init__(self, args):
super().__init__()
config = ModernProteinLMConfig(
vocab_size=33,
hidden_size=args.gen_hidden_size,
num_hidden_layers=args.gen_num_layers,
num_attention_heads=args.gen_num_heads,
intermediate_size=args.gen_intermediate_size,
use_geglu=True,
tie_word_embeddings=True,
max_position_embeddings=args.max_seq_length + 2,
)
self.model = ModernProteinLM(config)
def forward(self, input_ids, attention_mask, labels):
return self.model(input_ids, attention_mask, labels=labels)
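# Per ELECTRA, the generator is a smaller masked-LM than the discriminator (the paper found
# generators around 1/4-1/2 the discriminator size work best); the defaults below pair a
# 320-d / 8-layer generator with a 576-d / 28-layer discriminator.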
class Discriminator(nn.Module):
def __init__(self, args):
super().__init__()
config = ModernProteinLMConfig(
vocab_size=33,
hidden_size=args.hidden_size,
num_hidden_layers=args.num_layers,
num_attention_heads=args.num_heads,
intermediate_size=args.intermediate_size,
use_geglu=True,
tie_word_embeddings=True,
max_position_embeddings=args.max_seq_length + 2,
)
self.model = ModernProteinLM(config)
self.discriminator_head = nn.Linear(args.hidden_size, 1)
params = sum(p.numel() for p in self.model.parameters())
log_rank0(f"Discriminator: {params/1e6:.1f}M params")
def forward(self, input_ids, attention_mask, disc_labels=None):
outputs = self.model(input_ids, attention_mask, output_hidden_states=True, return_dict=True)
hidden = outputs.hidden_states[-1]
logits = self.discriminator_head(hidden).squeeze(-1)
loss = None
if disc_labels is not None:
loss_fct = nn.BCEWithLogitsLoss()
active = disc_labels != -100
if active.any():
loss = loss_fct(logits[active], disc_labels[active].float())
return {"loss": loss, "logits": logits, "hidden_states": hidden}
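# Replaced-token-detection head: one logit per position, trained with BCEWithLogitsLoss where
# label 1 = original token, 0 = replaced, and -100 marks positions (padding) to ignore.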
# =============================================================================
# TRAINING
# =============================================================================
class Trainer:
def __init__(self, args, generator, discriminator, tokenizer, device, rank, world_size):
self.args = args
self.generator = generator.to(device)
self.discriminator = discriminator.to(device)
self.tokenizer = tokenizer
self.device = device
self.rank = rank
self.world_size = world_size
self.global_step = 0
if world_size > 1:
            # Use the local CUDA device index (device was built from LOCAL_RANK), so this also
            # works for multi-node jobs where the global rank differs from the local rank.
            self.generator = DDP(self.generator, device_ids=[device.index], find_unused_parameters=False)
            self.discriminator = DDP(self.discriminator, device_ids=[device.index], find_unused_parameters=False)
self.gen_opt = torch.optim.AdamW(
generator.parameters(), lr=args.lr,
betas=(0.9, 0.98), eps=1e-6, weight_decay=args.weight_decay
)
self.disc_opt = torch.optim.AdamW(
discriminator.parameters(), lr=args.lr,
betas=(0.9, 0.98), eps=1e-6, weight_decay=args.weight_decay
)
self.gen_sched = get_cosine_schedule_with_warmup(
self.gen_opt, args.warmup_steps, args.max_steps
)
self.disc_sched = get_cosine_schedule_with_warmup(
self.disc_opt, args.warmup_steps, args.max_steps
)
self.scaler = GradScaler() if args.use_amp else None
if args.gradient_checkpointing:
            # Unwrap the DDP container (if any) before toggling checkpointing on the inner model.
            gen_core = self.generator.module if world_size > 1 else self.generator
            disc_core = self.discriminator.module if world_size > 1 else self.discriminator
            gen_core.model.gradient_checkpointing_enable()
            disc_core.model.gradient_checkpointing_enable()
# Trackio
self.trackio = None
        if args.use_trackio and rank == 0:
try:
import trackio
trackio.init(project=args.trackio_project, space_id=args.trackio_space_id or None)
self.trackio = trackio
log_rank0("Trackio initialized")
except ImportError:
log_rank0("Trackio not available")
def train_step(self, batch):
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
mlm_labels = batch["mlm_labels"].to(self.device)
replaced = batch["replaced"].to(self.device)
original_ids = batch["original_ids"].to(self.device)
with autocast(enabled=self.args.use_amp):
# Generator
gen_out = self.generator(input_ids, attention_mask, mlm_labels)
gen_loss = gen_out.loss
# Sample corrupted input
with torch.no_grad():
gen_logits = gen_out.logits
gen_probs = F.softmax(gen_logits, dim=-1)
sampled = torch.multinomial(
gen_probs.view(-1, gen_probs.size(-1)), 1
).view(gen_probs.shape[:-1])
corrupted = original_ids.clone()
mask_pos = mlm_labels != -100
corrupted[mask_pos] = sampled[mask_pos]
# Discriminator
            # ELECTRA labels: 1 = original, 0 = replaced. Positions where the generator happened
            # to sample the original token count as "original", per the ELECTRA formulation.
            disc_labels = torch.ones_like(original_ids, dtype=torch.float)
            disc_labels[(corrupted != original_ids) & mask_pos] = 0.0
            disc_labels[attention_mask == 0] = -100
disc_out = self.discriminator(corrupted, attention_mask, disc_labels)
disc_loss = disc_out["loss"]
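            # Joint ELECTRA objective: weighted sum of the generator MLM loss and the
            # discriminator RTD loss (the ELECTRA paper uses a discriminator weight of ~50,
            # matching the --disc_weight default).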
total_loss = self.args.gen_weight * gen_loss + self.args.disc_weight * disc_loss
# Backward
if self.scaler:
self.scaler.scale(total_loss).backward()
self.scaler.unscale_(self.gen_opt)
self.scaler.unscale_(self.disc_opt)
torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.args.grad_clip)
torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.args.grad_clip)
self.scaler.step(self.gen_opt)
self.scaler.step(self.disc_opt)
self.scaler.update()
else:
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.args.grad_clip)
torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.args.grad_clip)
self.gen_opt.step()
self.disc_opt.step()
self.gen_sched.step()
self.disc_sched.step()
self.gen_opt.zero_grad()
self.disc_opt.zero_grad()
self.global_step += 1
        return {
            "gen_loss": gen_loss.item(),
            "disc_loss": disc_loss.item() if disc_loss is not None else 0.0,
            "total_loss": total_loss.item(),
            "lr": self.gen_sched.get_last_lr()[0],
        }
def evaluate(self, eval_loader):
self.generator.eval()
self.discriminator.eval()
total_gen = 0.0
total_disc = 0.0
n = 0
with torch.no_grad():
for batch in eval_loader:
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
mlm_labels = batch["mlm_labels"].to(self.device)
replaced = batch["replaced"].to(self.device)
original_ids = batch["original_ids"].to(self.device)
gen_out = self.generator(input_ids, attention_mask, mlm_labels)
total_gen += gen_out.loss.item()
disc_labels = torch.ones_like(original_ids, dtype=torch.float)
disc_labels[replaced] = 0.0
disc_labels[attention_mask == 0] = -100
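                # Note: unlike train_step, evaluation scores the discriminator on the masked
                # inputs directly (no generator-sampled corruption), so disc_loss here is only
                # a rough proxy for replaced-token-detection quality.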
disc_out = self.discriminator(input_ids, attention_mask, disc_labels)
                if disc_out["loss"] is not None:
                    total_disc += disc_out["loss"].item()
n += 1
self.generator.train()
self.discriminator.train()
return {"gen_loss": total_gen / max(n, 1), "disc_loss": total_disc / max(n, 1)}
def save(self, path, name):
        # Only rank 0 writes checkpoints, avoiding concurrent writes to the same file under DDP.
        if self.rank != 0:
            return
        save_dir = os.path.join(path, name)
        os.makedirs(save_dir, exist_ok=True)
gen_state = self.generator.module.state_dict() if self.world_size > 1 else self.generator.state_dict()
disc_state = self.discriminator.module.state_dict() if self.world_size > 1 else self.discriminator.state_dict()
torch.save({
"generator": gen_state,
"discriminator": disc_state,
"step": self.global_step,
}, os.path.join(save_dir, "checkpoint.pt"))
log_rank0(f"Saved checkpoint to {save_dir}")
def train(self, train_loader, eval_loader=None):
log_rank0(f"\n{'='*60}")
log_rank0(f"ELECTRA Pre-training: {self.args.max_steps} steps")
log_rank0(f"{'='*60}\n")
self.generator.train()
self.discriminator.train()
epoch = 0
while self.global_step < self.args.max_steps:
epoch += 1
if isinstance(train_loader.sampler, DistributedSampler):
train_loader.sampler.set_epoch(epoch)
for batch in train_loader:
if self.global_step >= self.args.max_steps:
break
metrics = self.train_step(batch)
if self.global_step % self.args.log_interval == 0 and self.rank == 0:
log_rank0(
f"Step {self.global_step:6d} | "
f"gen_loss={metrics['gen_loss']:.4f} | "
f"disc_loss={metrics['disc_loss']:.4f} | "
f"total={metrics['total_loss']:.4f} | "
f"lr={metrics['lr']:.2e}"
)
if self.trackio:
self.trackio.log(metrics, step=self.global_step)
if eval_loader and self.global_step % self.args.eval_interval == 0:
eval_metrics = self.evaluate(eval_loader)
if self.rank == 0:
log_rank0(f"Eval @ {self.global_step}: gen={eval_metrics['gen_loss']:.4f}, disc={eval_metrics['disc_loss']:.4f}")
if self.trackio:
self.trackio.log({f"eval_{k}": v for k, v in eval_metrics.items()}, step=self.global_step)
if self.global_step % self.args.save_interval == 0:
self.save(self.args.output_dir, f"step_{self.global_step}")
# Final save
self.save(self.args.output_dir, "final")
# =============================================================================
# MAIN
# =============================================================================
def parse_args():
parser = argparse.ArgumentParser()
# Model
parser.add_argument("--hidden_size", type=int, default=576)
parser.add_argument("--num_layers", type=int, default=28)
parser.add_argument("--num_heads", type=int, default=9)
parser.add_argument("--intermediate_size", type=int, default=2304)
parser.add_argument("--gen_hidden_size", type=int, default=320)
parser.add_argument("--gen_num_layers", type=int, default=8)
parser.add_argument("--gen_num_heads", type=int, default=8)
parser.add_argument("--gen_intermediate_size", type=int, default=1280)
parser.add_argument("--max_seq_length", type=int, default=1024)
# Training
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--max_steps", type=int, default=100000)
parser.add_argument("--warmup_steps", type=int, default=10000)
parser.add_argument("--lr", type=float, default=5e-4)
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--gen_weight", type=float, default=1.0)
parser.add_argument("--disc_weight", type=float, default=50.0)
# Masking
parser.add_argument("--mask_start", type=float, default=0.30)
parser.add_argument("--mask_end", type=float, default=0.05)
parser.add_argument("--span_length", type=int, default=3)
# Data
parser.add_argument("--max_sequences", type=int, default=1000000)
parser.add_argument("--use_streaming", action="store_true")
# System
parser.add_argument("--output_dir", default="./outputs/pretrain")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--eval_interval", type=int, default=5000)
parser.add_argument("--save_interval", type=int, default=5000)
parser.add_argument("--use_amp", action="store_true")
parser.add_argument("--use_flash_attn", action="store_true")
parser.add_argument("--resume_from", default="")
parser.add_argument("--gradient_checkpointing", action="store_true")
parser.add_argument("--seed", type=int, default=42)
# Tracking
parser.add_argument("--use_trackio", action="store_true")
parser.add_argument("--trackio_project", default="modern-protein-lm")
parser.add_argument("--trackio_space_id", default="")
return parser.parse_args()
def main():
args = parse_args()
rank, world_size, local_rank = setup_distributed()
# Set seed
random.seed(args.seed + rank)
np.random.seed(args.seed + rank)
torch.manual_seed(args.seed + rank)
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
# Load data
tokenizer = ProteinTokenizer()
sequences = load_sequences(args)
if world_size > 1:
dist.barrier()
# Split
n_train = int(0.95 * len(sequences))
train_seqs = sequences[:n_train]
eval_seqs = sequences[n_train:]
train_dataset = PretrainDataset(train_seqs, tokenizer, args)
eval_dataset = PretrainDataset(eval_seqs, tokenizer, args)
if world_size > 1:
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
eval_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
else:
train_sampler = None
eval_sampler = None
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, drop_last=True,
    )
eval_loader = DataLoader(
eval_dataset, batch_size=args.batch_size, sampler=eval_sampler,
num_workers=args.num_workers, pin_memory=True, drop_last=False,
)
# Models
generator = Generator(args)
discriminator = Discriminator(args)
gen_params = sum(p.numel() for p in generator.parameters())
log_rank0(f"Generator: {gen_params/1e6:.1f}M params")
# Resume
if args.resume_from:
checkpoint = torch.load(args.resume_from, map_location="cpu")
generator.load_state_dict(checkpoint["generator"])
discriminator.load_state_dict(checkpoint["discriminator"])
log_rank0(f"Resumed from {args.resume_from}")
trainer = Trainer(args, generator, discriminator, tokenizer, device, rank, world_size)
trainer.train(train_loader, eval_loader)
cleanup_distributed()
log_rank0("Training complete!")
if __name__ == "__main__":
main()