"""
ELECTRA-style discriminative pre-training for ModernProteinLM.
Generator (small): ~25% of discriminator size, trained with MLM.
Discriminator (main model): Trained to detect replaced tokens (RTD objective).
Key improvements over standard ELECTRA:
1. Curriculum masking: start at 30%, decay to 5%
2. Span masking: mask contiguous regions (protein structural motifs)
3. Generator temperature annealing (not implemented in this version of the script)
4. No NSP, no dropout (following ESM-2)
"""
import os
import random
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm

from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig
class ProteinTokenizer:
"""Simple protein tokenizer matching ESM-2 vocab."""
ALL_AA = "LAGVSERTIDPQKNFYWMHCXBUZO"
def __init__(self):
        # ESM-2-style vocabulary:
        # 0: <cls>, 1: <pad>, 2: <eos>, 3: <unk>
        # 4-28: amino acids
        # 29: <mask>, 30: <sep>, 31-32: padding specials for ESM compatibility
self.vocab = {
"<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
"L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10,
"T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17,
"F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24,
"B": 25, "U": 26, "Z": 27, "O": 28, "<mask>": 29,
"<sep>": 30, # additional sep
}
# Pad to 33 for ESM compatibility
while len(self.vocab) < 33:
self.vocab[f"<special_{len(self.vocab)}>"] = len(self.vocab)
self.id_to_token = {v: k for k, v in self.vocab.items()}
self.mask_token_id = 29
self.pad_token_id = 1
self.cls_token_id = 0
self.eos_token_id = 2
def encode(self, sequence: str, max_length: int = 1024, add_special_tokens: bool = True):
tokens = []
if add_special_tokens:
tokens.append(self.cls_token_id)
for aa in sequence.upper():
if aa in self.vocab:
tokens.append(self.vocab[aa])
else:
tokens.append(self.vocab["<unk>"])
if add_special_tokens:
tokens.append(self.eos_token_id)
        # Truncate or pad to max_length; keep <eos> as the final token when truncating.
        if len(tokens) > max_length:
            tokens = tokens[:max_length]
            if add_special_tokens:
                tokens[-1] = self.eos_token_id
attention_mask = [1] * len(tokens)
while len(tokens) < max_length:
tokens.append(self.pad_token_id)
attention_mask.append(0)
return {
"input_ids": tokens,
"attention_mask": attention_mask,
}
def batch_encode(self, sequences: List[str], max_length: int = 1024):
results = [self.encode(seq, max_length) for seq in sequences]
return {
"input_ids": torch.tensor([r["input_ids"] for r in results], dtype=torch.long),
"attention_mask": torch.tensor([r["attention_mask"] for r in results], dtype=torch.long),
}
def decode(self, token_ids):
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.tolist()
return "".join([self.id_to_token.get(t, "<unk>") for t in token_ids])
def create_span_mask(length, mask_ratio=0.30, mean_span_length=3, min_span_length=1):
"""Create span mask for protein sequences."""
    if length <= 0:
        return []
    num_to_mask = max(1, int(length * mask_ratio))
    mask = [False] * length
attempts = 0
masked = 0
while masked < num_to_mask and attempts < num_to_mask * 10:
span_len = max(min_span_length, min(mean_span_length + random.randint(-1, 1), num_to_mask - masked))
start = random.randint(0, max(0, length - span_len - 1))
# Don't mask if already masked
if any(mask[start:start+span_len]):
attempts += 1
continue
for i in range(start, min(start + span_len, length)):
mask[i] = True
masked += 1
return mask
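
# A small sketch of how create_span_mask behaves (illustration only): for a 20-residue
# region at a 30% ratio, about 6 positions are masked in contiguous spans of ~3.
def _example_span_mask():
    random.seed(0)
    mask = create_span_mask(20, mask_ratio=0.30, mean_span_length=3)
    return [i for i, m in enumerate(mask) if m]  # indices of masked positions
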
class ProteinDataset(Dataset):
    def __init__(self, sequences, tokenizer, max_length=1024,
                 mean_span_length=3, curriculum_start_ratio=0.30, curriculum_end_ratio=0.05,
                 total_steps=100000, current_step=0):
self.sequences = sequences
self.tokenizer = tokenizer
self.max_length = max_length
self.mean_span_length = mean_span_length
self.curriculum_start_ratio = curriculum_start_ratio
self.curriculum_end_ratio = curriculum_end_ratio
self.total_steps = total_steps
self.current_step = current_step
def get_current_mask_ratio(self):
"""Linear decay from start to end ratio."""
progress = min(1.0, self.current_step / self.total_steps)
return self.curriculum_start_ratio + (self.curriculum_end_ratio - self.curriculum_start_ratio) * progress
def __len__(self):
return len(self.sequences)
def __getitem__(self, idx):
seq = self.sequences[idx]
encoded = self.tokenizer.encode(seq, max_length=self.max_length)
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
# Find actual sequence length (before padding)
seq_len = sum(attention_mask)
# Exclude special tokens from masking
effective_len = seq_len - 2 if seq_len > 2 else seq_len
# Apply span masking
mask_ratio = self.get_current_mask_ratio()
span_mask = create_span_mask(effective_len, mask_ratio, self.mean_span_length)
# Create masked input and labels
masked_input = input_ids.copy()
labels = [-100] * len(input_ids) # -100 = ignore in loss
        replaced = [False] * len(input_ids)  # Masked positions (candidates for replacement by the generator)
for i in range(1, 1 + effective_len): # Skip CLS
if span_mask[i - 1]:
labels[i] = input_ids[i]
replaced[i] = True
# 80% mask, 10% random, 10% keep
r = random.random()
if r < 0.8:
masked_input[i] = self.tokenizer.mask_token_id
elif r < 0.9:
masked_input[i] = random.randint(4, 28) # Random AA
# else: keep original
return {
"input_ids": torch.tensor(masked_input, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
"replaced": torch.tensor(replaced, dtype=torch.bool),
"original_ids": torch.tensor(input_ids, dtype=torch.long),
}
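
# A brief sketch of the curriculum schedule (illustration only): the mask ratio decays
# linearly from curriculum_start_ratio to curriculum_end_ratio as current_step advances.
def _example_curriculum_ratio():
    ds = ProteinDataset(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"], ProteinTokenizer(),
                        total_steps=100000)
    ratios = []
    for step in (0, 50000, 100000):
        ds.current_step = step
        ratios.append(ds.get_current_mask_ratio())
    return ratios  # approximately [0.30, 0.175, 0.05]
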
class GeneratorModel(nn.Module):
"""Small generator model for ELECTRA."""
def __init__(self, vocab_size, hidden_size=256, num_layers=4, num_heads=4, intermediate_size=1024):
super().__init__()
config = ModernProteinLMConfig(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_hidden_layers=num_layers,
num_attention_heads=num_heads,
intermediate_size=intermediate_size,
tie_word_embeddings=True,
)
self.model = ModernProteinLM(config)
def forward(self, input_ids, attention_mask, labels):
return self.model(input_ids, attention_mask, labels=labels)
class DiscriminatorModel(ModernProteinLM):
"""Discriminator with additional classification head for RTD."""
def __init__(self, config):
super().__init__(config)
self.discriminator_head = nn.Linear(config.hidden_size, 1)
    def forward(self, input_ids, attention_mask, labels=None):
        # Assumes the base ModernProteinLM accepts HF-style kwargs and returns hidden states.
        outputs = super().forward(
            input_ids, attention_mask, output_hidden_states=True, return_dict=True
        )
        hidden = outputs.hidden_states[-1]  # (B, T, H)
# Discriminator logits: real vs fake
disc_logits = self.discriminator_head(hidden).squeeze(-1) # (B, T)
disc_loss = None
if labels is not None:
# labels: 1 = real, 0 = fake (replaced)
loss_fct = nn.BCEWithLogitsLoss()
active_loss = labels != -100
active_logits = disc_logits[active_loss]
active_labels = labels[active_loss].float()
disc_loss = loss_fct(active_logits, active_labels)
return {
"loss": disc_loss,
"logits": disc_logits,
"hidden_states": outputs.hidden_states,
}
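
# A small sketch of how to read the RTD head at inference time (illustration only):
# the head emits one logit per token; sigmoid(logit) is the probability that the token
# is original (labels use 1 = original, 0 = replaced).
def _example_rtd_predictions(discriminator, input_ids, attention_mask):
    with torch.no_grad():
        out = discriminator(input_ids, attention_mask)
        p_original = torch.sigmoid(out["logits"])  # (B, T)
    # Predicted replaced positions, ignoring padding.
    return (p_original < 0.5) & attention_mask.bool()
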
class ELECTRAProteinTrainer:
def __init__(
self,
generator: GeneratorModel,
discriminator: DiscriminatorModel,
tokenizer,
train_dataset,
eval_dataset,
output_dir="./electra_protein",
lr=5e-4,
batch_size=32,
max_steps=100000,
warmup_steps=10000,
weight_decay=0.01,
grad_clip=1.0,
generator_weight=1.0,
discriminator_weight=50.0,
device="cuda",
):
self.generator = generator.to(device)
self.discriminator = discriminator.to(device)
self.tokenizer = tokenizer
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.output_dir = output_dir
self.device = device
self.max_steps = max_steps
self.grad_clip = grad_clip
self.generator_weight = generator_weight
self.discriminator_weight = discriminator_weight
os.makedirs(output_dir, exist_ok=True)
# Optimizers
self.gen_optimizer = torch.optim.AdamW(
generator.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=weight_decay
)
self.disc_optimizer = torch.optim.AdamW(
discriminator.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=weight_decay
)
# Schedulers
self.gen_scheduler = get_cosine_schedule_with_warmup(
self.gen_optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
)
self.disc_scheduler = get_cosine_schedule_with_warmup(
self.disc_optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
)
self.train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
)
self.eval_loader = DataLoader(
eval_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
)
self.global_step = 0
self.best_eval_loss = float("inf")
    def train(self):
        self.generator.train()
        self.discriminator.train()
        self.pbar = tqdm(total=self.max_steps, desc="Training")
        # Cycle over the dataloader until max_steps is reached (a single pass over the
        # dataset is typically far fewer than max_steps).
        while self.global_step < self.max_steps:
            # Keep the curriculum masking schedule in sync with training progress.
            # Note: with num_workers > 0, workers pick up the updated step at the
            # start of each pass over the dataloader.
            self.train_dataset.current_step = self.global_step
            for batch in self.train_loader:
                if self.global_step >= self.max_steps:
                    break
                self._train_step(batch)
                self.global_step += 1
                self.pbar.update(1)
                if self.global_step % 1000 == 0:
                    eval_loss = self.evaluate()
                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint("best")
                    self.generator.train()
                    self.discriminator.train()
                if self.global_step % 5000 == 0:
                    self.save_checkpoint(f"step_{self.global_step}")
        self.pbar.close()
        self.save_checkpoint("final")
def _train_step(self, batch):
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
mlm_labels = batch["labels"].to(self.device)
replaced_positions = batch["replaced"].to(self.device)
original_ids = batch["original_ids"].to(self.device)
# ====== GENERATOR STEP ======
gen_outputs = self.generator(input_ids, attention_mask, mlm_labels)
gen_loss = gen_outputs.loss
# Sample from generator to create corrupted input for discriminator
with torch.no_grad():
gen_logits = gen_outputs.logits # (B, T, V)
gen_probs = F.softmax(gen_logits, dim=-1)
sampled_ids = torch.multinomial(
gen_probs.view(-1, gen_probs.size(-1)), 1
).view(gen_probs.shape[:-1])
# Replace masked positions with generator samples
corrupted_input = original_ids.clone()
mask_positions = mlm_labels != -100
corrupted_input[mask_positions] = sampled_ids[mask_positions]
        # ====== DISCRIMINATOR STEP ======
        # RTD labels: 1 = original token, 0 = replaced. Positions where the generator
        # happens to resample the original token are labelled as original, following ELECTRA.
        disc_labels = torch.ones_like(original_ids, dtype=torch.float)  # (B, T)
        disc_labels[replaced_positions & (corrupted_input != original_ids)] = 0.0
        # Ignore padding
        disc_labels[attention_mask == 0] = -100
disc_outputs = self.discriminator(corrupted_input, attention_mask, disc_labels)
disc_loss = disc_outputs["loss"]
# ====== BACKWARD ======
# Combined loss with weighting
total_loss = self.generator_weight * gen_loss + self.discriminator_weight * disc_loss
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)
self.gen_optimizer.step()
self.disc_optimizer.step()
self.gen_scheduler.step()
self.disc_scheduler.step()
self.gen_optimizer.zero_grad()
self.disc_optimizer.zero_grad()
        if self.global_step % 100 == 0:
            self.pbar.set_postfix({
                "gen_loss": f"{gen_loss.item():.4f}",
                "disc_loss": f"{disc_loss.item():.4f}",
                "lr": f"{self.gen_scheduler.get_last_lr()[0]:.2e}",
            })
def evaluate(self):
self.generator.eval()
self.discriminator.eval()
total_gen_loss = 0
total_disc_loss = 0
total_samples = 0
        with torch.no_grad():
            for batch in self.eval_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                mlm_labels = batch["labels"].to(self.device)
                replaced_positions = batch["replaced"].to(self.device)
                original_ids = batch["original_ids"].to(self.device)
                gen_outputs = self.generator(input_ids, attention_mask, mlm_labels)
                total_gen_loss += gen_outputs.loss.item() * input_ids.size(0)
                # Build the discriminator input as in training: replace masked positions
                # with generator samples rather than feeding raw <mask> tokens.
                gen_probs = F.softmax(gen_outputs.logits, dim=-1)
                sampled_ids = torch.multinomial(
                    gen_probs.view(-1, gen_probs.size(-1)), 1
                ).view(gen_probs.shape[:-1])
                corrupted_input = original_ids.clone()
                mask_positions = mlm_labels != -100
                corrupted_input[mask_positions] = sampled_ids[mask_positions]
                disc_labels = torch.ones_like(original_ids, dtype=torch.float)
                disc_labels[replaced_positions & (corrupted_input != original_ids)] = 0.0
                disc_labels[attention_mask == 0] = -100
                disc_outputs = self.discriminator(corrupted_input, attention_mask, disc_labels)
                total_disc_loss += disc_outputs["loss"].item() * input_ids.size(0)
                total_samples += input_ids.size(0)
avg_gen = total_gen_loss / total_samples
avg_disc = total_disc_loss / total_samples
print(f"Eval - Gen Loss: {avg_gen:.4f}, Disc Loss: {avg_disc:.4f}")
return avg_gen + avg_disc
def save_checkpoint(self, name):
path = os.path.join(self.output_dir, name)
os.makedirs(path, exist_ok=True)
torch.save({
"generator": self.generator.state_dict(),
"discriminator": self.discriminator.state_dict(),
"gen_optimizer": self.gen_optimizer.state_dict(),
"disc_optimizer": self.disc_optimizer.state_dict(),
"step": self.global_step,
}, os.path.join(path, "checkpoint.pt"))
# Save discriminator config (main model)
self.discriminator.config.save_pretrained(path)
print(f"Saved checkpoint to {path}")
def load_protein_sequences(dataset_name="lamm-mit/protein_secondary_structure_from_PDB", split="train", max_seqs=None):
"""Load protein sequences from HF dataset."""
ds = load_dataset(dataset_name, split=split, streaming=True)
sequences = []
for i, example in enumerate(ds):
if max_seqs and i >= max_seqs:
break
# Try common column names
seq = None
for key in ["input", "primary", "sequences", "sequence", "protein", "text"]:
if key in example:
seq = example[key]
break
if seq and len(seq) > 10:
sequences.append(seq)
return sequences
def main():
# Config
DISC_CONFIG = ModernProteinLMConfig(
vocab_size=33,
hidden_size=576,
num_hidden_layers=28,
num_attention_heads=9,
intermediate_size=2304,
use_geglu=True,
tie_word_embeddings=True,
max_position_embeddings=1026,
position_embedding_type="rotary",
rope_theta=10000.0,
)
# Generator: ~25% of discriminator size
GEN_CONFIG = ModernProteinLMConfig(
vocab_size=33,
hidden_size=320,
num_hidden_layers=8,
num_attention_heads=8,
intermediate_size=1280,
use_geglu=True,
tie_word_embeddings=True,
)
tokenizer = ProteinTokenizer()
# Load data
print("Loading protein sequences...")
train_seqs = load_protein_sequences("lamm-mit/protein_secondary_structure_from_PDB", "train", max_seqs=50000)
eval_seqs = load_protein_sequences("lamm-mit/protein_secondary_structure_from_PDB", "train", max_seqs=5000)
print(f"Loaded {len(train_seqs)} train, {len(eval_seqs)} eval sequences")
train_dataset = ProteinDataset(
train_seqs, tokenizer, max_length=1024,
curriculum_start_ratio=0.30, curriculum_end_ratio=0.05,
total_steps=100000,
)
eval_dataset = ProteinDataset(
eval_seqs, tokenizer, max_length=1024,
curriculum_start_ratio=0.30, curriculum_end_ratio=0.05,
total_steps=100000, current_step=100000, # Fixed at end ratio for eval
)
# Models
generator = GeneratorModel(
vocab_size=33,
hidden_size=GEN_CONFIG.hidden_size,
num_layers=GEN_CONFIG.num_hidden_layers,
num_heads=GEN_CONFIG.num_attention_heads,
intermediate_size=GEN_CONFIG.intermediate_size,
)
discriminator = DiscriminatorModel(DISC_CONFIG)
# Count parameters
gen_params = sum(p.numel() for p in generator.parameters())
disc_params = sum(p.numel() for p in discriminator.parameters())
print(f"Generator params: {gen_params/1e6:.1f}M")
print(f"Discriminator params: {disc_params/1e6:.1f}M")
trainer = ELECTRAProteinTrainer(
generator=generator,
discriminator=discriminator,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
output_dir="./modern_protein_electra",
lr=5e-4,
batch_size=16,
max_steps=100000,
warmup_steps=10000,
weight_decay=0.01,
grad_clip=1.0,
generator_weight=1.0,
discriminator_weight=50.0,
device="cuda" if torch.cuda.is_available() else "cpu",
)
print("Starting ELECTRA pre-training...")
trainer.train()
if __name__ == "__main__":
main()