""" ELECTRA-style discriminative pre-training for ModernProteinLM. Generator (small): ~25% of discriminator size, trained with MLM. Discriminator (main model): Trained to detect replaced tokens (RTD objective). Key improvements over standard ELECTRA: 1. Curriculum masking: start at 30%, decay to 5% 2. Span masking: mask contiguous regions (protein structural motifs) 3. Generator-distillation: generator temperature annealing 4. No NSP, no dropout (following ESM-2) """ import os import math import random from typing import Dict, List, Optional import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from transformers import ( PreTrainedTokenizerFast, get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup, ) from datasets import load_dataset, concatenate_datasets import numpy as np from tqdm import tqdm from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig class ProteinTokenizer: """Simple protein tokenizer matching ESM-2 vocab.""" ALL_AA = "LAGVSERTIDPQKNFYWMHCXBUZO" def __init__(self): # ESM-2 vocab # 0: , 1: , 2: , 3: # 4-29: amino acids # 30: , 31: , 32: (duplicate for compatibility) self.vocab = { "": 0, "": 1, "": 2, "": 3, "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10, "T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17, "F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24, "B": 25, "U": 26, "Z": 27, "O": 28, "": 29, "": 30, # additional sep } # Pad to 33 for ESM compatibility while len(self.vocab) < 33: self.vocab[f""] = len(self.vocab) self.id_to_token = {v: k for k, v in self.vocab.items()} self.mask_token_id = 29 self.pad_token_id = 1 self.cls_token_id = 0 self.eos_token_id = 2 def encode(self, sequence: str, max_length: int = 1024, add_special_tokens: bool = True): tokens = [] if add_special_tokens: tokens.append(self.cls_token_id) for aa in sequence.upper(): if aa in self.vocab: tokens.append(self.vocab[aa]) else: tokens.append(self.vocab[""]) if add_special_tokens: tokens.append(self.eos_token_id) # Truncate or pad if len(tokens) > max_length: tokens = tokens[:max_length] attention_mask = [1] * len(tokens) while len(tokens) < max_length: tokens.append(self.pad_token_id) attention_mask.append(0) return { "input_ids": tokens, "attention_mask": attention_mask, } def batch_encode(self, sequences: List[str], max_length: int = 1024): results = [self.encode(seq, max_length) for seq in sequences] return { "input_ids": torch.tensor([r["input_ids"] for r in results], dtype=torch.long), "attention_mask": torch.tensor([r["attention_mask"] for r in results], dtype=torch.long), } def decode(self, token_ids): if isinstance(token_ids, torch.Tensor): token_ids = token_ids.tolist() return "".join([self.id_to_token.get(t, "") for t in token_ids]) def create_span_mask(length, mask_ratio=0.30, mean_span_length=3, min_span_length=1): """Create span mask for protein sequences.""" num_to_mask = max(1, int(length * mask_ratio)) mask = [False] * length attempts = 0 masked = 0 while masked < num_to_mask and attempts < num_to_mask * 10: span_len = max(min_span_length, min(mean_span_length + random.randint(-1, 1), num_to_mask - masked)) start = random.randint(0, max(0, length - span_len - 1)) # Don't mask if already masked if any(mask[start:start+span_len]): attempts += 1 continue for i in range(start, min(start + span_len, length)): mask[i] = True masked += 1 return mask class ProteinDataset(Dataset): def __init__(self, sequences, tokenizer, max_length=1024, mask_ratio=0.30, 

class ProteinDataset(Dataset):
    def __init__(self, sequences, tokenizer, max_length=1024, mask_ratio=0.30,
                 mean_span_length=3, curriculum_start_ratio=0.30,
                 curriculum_end_ratio=0.05, total_steps=100000, current_step=0):
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mean_span_length = mean_span_length
        self.curriculum_start_ratio = curriculum_start_ratio
        self.curriculum_end_ratio = curriculum_end_ratio
        self.total_steps = total_steps
        self.current_step = current_step

    def get_current_mask_ratio(self):
        """Linear decay from the start ratio to the end ratio over total_steps."""
        progress = min(1.0, self.current_step / self.total_steps)
        return (self.curriculum_start_ratio
                + (self.curriculum_end_ratio - self.curriculum_start_ratio) * progress)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        encoded = self.tokenizer.encode(seq, max_length=self.max_length)
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Actual sequence length (before padding); exclude special tokens from masking
        seq_len = sum(attention_mask)
        effective_len = seq_len - 2 if seq_len > 2 else seq_len

        # Apply curriculum span masking
        mask_ratio = self.get_current_mask_ratio()
        span_mask = create_span_mask(effective_len, mask_ratio, self.mean_span_length)

        # Create masked input and MLM labels
        masked_input = input_ids.copy()
        labels = [-100] * len(input_ids)      # -100 = ignore in loss
        replaced = [False] * len(input_ids)   # positions eligible for replacement (discriminator)
        for i in range(1, 1 + effective_len):  # skip CLS at position 0
            if span_mask[i - 1]:
                labels[i] = input_ids[i]
                replaced[i] = True
                # BERT-style corruption: 80% mask, 10% random AA, 10% keep
                r = random.random()
                if r < 0.8:
                    masked_input[i] = self.tokenizer.mask_token_id
                elif r < 0.9:
                    masked_input[i] = random.randint(4, 28)  # random amino-acid token
                # else: keep original

        return {
            "input_ids": torch.tensor(masked_input, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "replaced": torch.tensor(replaced, dtype=torch.bool),
            "original_ids": torch.tensor(input_ids, dtype=torch.long),
        }


class GeneratorModel(nn.Module):
    """Small generator model for ELECTRA (MLM objective)."""

    def __init__(self, vocab_size, hidden_size=256, num_layers=4, num_heads=4,
                 intermediate_size=1024):
        super().__init__()
        config = ModernProteinLMConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_layers,
            num_attention_heads=num_heads,
            intermediate_size=intermediate_size,
            tie_word_embeddings=True,
        )
        self.model = ModernProteinLM(config)

    def forward(self, input_ids, attention_mask, labels):
        return self.model(input_ids, attention_mask, labels=labels)


class DiscriminatorModel(ModernProteinLM):
    """Discriminator with an additional token-level classification head for RTD."""

    def __init__(self, config):
        super().__init__(config)
        self.discriminator_head = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = super().forward(input_ids, attention_mask, return_dict=True)
        hidden = outputs.hidden_states[-1]  # (B, T, H)

        # Per-token discriminator logits: real vs. replaced
        disc_logits = self.discriminator_head(hidden).squeeze(-1)  # (B, T)

        disc_loss = None
        if labels is not None:
            # labels: 1 = real, 0 = fake (replaced); -100 = ignore (padding)
            loss_fct = nn.BCEWithLogitsLoss()
            active_loss = labels != -100
            active_logits = disc_logits[active_loss]
            active_labels = labels[active_loss].float()
            disc_loss = loss_fct(active_logits, active_labels)

        return {
            "loss": disc_loss,
            "logits": disc_logits,
            "hidden_states": outputs.hidden_states,
        }

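
# Minimal sketch (illustrative, not called during training) of how the RTD labels used by
# DiscriminatorModel are constructed: positions where the generator happens to resample the
# original token stay labeled "real", following the ELECTRA formulation. The tensors below
# are toy values, not real model output.
def _rtd_label_example():
    original_ids = torch.tensor([[0, 5, 6, 7, 2]])   # <cls> A G V <eos>
    corrupted = torch.tensor([[0, 5, 9, 7, 2]])      # generator changed only position 2 (G -> E)
    mask_positions = torch.tensor([[False, True, True, False, False]])
    attention_mask = torch.ones_like(original_ids)
    labels = torch.ones_like(original_ids, dtype=torch.float)
    labels[mask_positions & (corrupted != original_ids)] = 0.0  # fake only where tokens differ
    labels[attention_mask == 0] = -100.0                        # ignore padding
    return labels  # tensor([[1., 1., 0., 1., 1.]])
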

class ELECTRAProteinTrainer:
    def __init__(
        self,
        generator: GeneratorModel,
        discriminator: DiscriminatorModel,
        tokenizer,
        train_dataset,
        eval_dataset,
        output_dir="./electra_protein",
        lr=5e-4,
        batch_size=32,
        max_steps=100000,
        warmup_steps=10000,
        weight_decay=0.01,
        grad_clip=1.0,
        generator_weight=1.0,
        discriminator_weight=50.0,
        device="cuda",
    ):
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.output_dir = output_dir
        self.device = device
        self.max_steps = max_steps
        self.grad_clip = grad_clip
        self.generator_weight = generator_weight
        self.discriminator_weight = discriminator_weight
        os.makedirs(output_dir, exist_ok=True)

        # Optimizers
        self.gen_optimizer = torch.optim.AdamW(
            generator.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=weight_decay
        )
        self.disc_optimizer = torch.optim.AdamW(
            discriminator.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=weight_decay
        )

        # Schedulers
        self.gen_scheduler = get_cosine_schedule_with_warmup(
            self.gen_optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
        )
        self.disc_scheduler = get_cosine_schedule_with_warmup(
            self.disc_optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
        )

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
        )
        self.eval_loader = DataLoader(
            eval_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
        )

        self.global_step = 0
        self.best_eval_loss = float("inf")

    def train(self):
        self.generator.train()
        self.discriminator.train()
        self.pbar = tqdm(total=self.max_steps, desc="Training")

        # Iterate over the dataloader repeatedly until max_steps is reached
        while self.global_step < self.max_steps:
            # Advance the masking curriculum; with worker processes this takes effect
            # when the next epoch's DataLoader iterator is created.
            self.train_dataset.current_step = self.global_step
            for batch in self.train_loader:
                if self.global_step >= self.max_steps:
                    break
                self._train_step(batch)
                self.global_step += 1
                self.pbar.update(1)

                if self.global_step % 1000 == 0:
                    eval_loss = self.evaluate()
                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint("best")
                    self.generator.train()
                    self.discriminator.train()

                if self.global_step % 5000 == 0:
                    self.save_checkpoint(f"step_{self.global_step}")

        self.pbar.close()
        self.save_checkpoint("final")

    def _train_step(self, batch):
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        mlm_labels = batch["labels"].to(self.device)
        original_ids = batch["original_ids"].to(self.device)

        # ====== GENERATOR STEP ======
        gen_outputs = self.generator(input_ids, attention_mask, mlm_labels)
        gen_loss = gen_outputs.loss

        # Sample from the generator to create corrupted input for the discriminator
        with torch.no_grad():
            gen_logits = gen_outputs.logits  # (B, T, V)
            gen_probs = F.softmax(gen_logits, dim=-1)
            sampled_ids = torch.multinomial(
                gen_probs.view(-1, gen_probs.size(-1)), 1
            ).view(gen_probs.shape[:-1])

            # Replace masked positions with generator samples
            corrupted_input = original_ids.clone()
            mask_positions = mlm_labels != -100
            corrupted_input[mask_positions] = sampled_ids[mask_positions]

        # ====== DISCRIMINATOR STEP ======
        # Labels: 1 = original, 0 = replaced. A position counts as "fake" only if the
        # sampled token actually differs from the original (standard ELECTRA convention).
        disc_labels = torch.ones_like(original_ids, dtype=torch.float)  # (B, T)
        disc_labels[mask_positions & (corrupted_input != original_ids)] = 0.0
        # Ignore padding
        disc_labels[attention_mask == 0] = -100

        disc_outputs = self.discriminator(corrupted_input, attention_mask, disc_labels)
        disc_loss = disc_outputs["loss"]

        # ====== BACKWARD ======
        # Combined loss with weighting; the corrupted input is detached, so the
        # discriminator loss does not backpropagate into the generator.
        total_loss = self.generator_weight * gen_loss + self.discriminator_weight * disc_loss
        total_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
        torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)
        self.gen_optimizer.step()
        self.disc_optimizer.step()
        self.gen_scheduler.step()
        self.disc_scheduler.step()
        self.gen_optimizer.zero_grad()
        self.disc_optimizer.zero_grad()

        if self.global_step % 100 == 0:
            self.pbar.set_postfix({
                "gen_loss": f"{gen_loss.item():.4f}",
                "disc_loss": f"{disc_loss.item():.4f}",
                "lr": f"{self.gen_scheduler.get_last_lr()[0]:.2e}",
            })

    def evaluate(self):
        self.generator.eval()
        self.discriminator.eval()
        total_gen_loss = 0.0
        total_disc_loss = 0.0
        total_samples = 0

        with torch.no_grad():
            for batch in self.eval_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                mlm_labels = batch["labels"].to(self.device)
                original_ids = batch["original_ids"].to(self.device)

                gen_outputs = self.generator(input_ids, attention_mask, mlm_labels)
                total_gen_loss += gen_outputs.loss.item() * input_ids.size(0)

                # Build the corrupted input the same way as in training
                gen_probs = F.softmax(gen_outputs.logits, dim=-1)
                sampled_ids = torch.multinomial(
                    gen_probs.view(-1, gen_probs.size(-1)), 1
                ).view(gen_probs.shape[:-1])
                corrupted_input = original_ids.clone()
                mask_positions = mlm_labels != -100
                corrupted_input[mask_positions] = sampled_ids[mask_positions]

                disc_labels = torch.ones_like(original_ids, dtype=torch.float)
                disc_labels[mask_positions & (corrupted_input != original_ids)] = 0.0
                disc_labels[attention_mask == 0] = -100

                disc_outputs = self.discriminator(corrupted_input, attention_mask, disc_labels)
                total_disc_loss += disc_outputs["loss"].item() * input_ids.size(0)
                total_samples += input_ids.size(0)

        avg_gen = total_gen_loss / total_samples
        avg_disc = total_disc_loss / total_samples
        print(f"Eval - Gen Loss: {avg_gen:.4f}, Disc Loss: {avg_disc:.4f}")
        return avg_gen + avg_disc

    def save_checkpoint(self, name):
        path = os.path.join(self.output_dir, name)
        os.makedirs(path, exist_ok=True)
        torch.save({
            "generator": self.generator.state_dict(),
            "discriminator": self.discriminator.state_dict(),
            "gen_optimizer": self.gen_optimizer.state_dict(),
            "disc_optimizer": self.disc_optimizer.state_dict(),
            "step": self.global_step,
        }, os.path.join(path, "checkpoint.pt"))
        # Save the discriminator config (the main model)
        self.discriminator.config.save_pretrained(path)
        print(f"Saved checkpoint to {path}")

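
# The module docstring lists generator temperature annealing (item 3), but the training
# loop above samples at temperature 1.0. Below is a minimal sketch of a linear anneal
# from `start_temp` to `end_temp`; the helper name and schedule are illustrative choices,
# not part of the original training step. Inside ELECTRAProteinTrainer._train_step it
# would be used as:
#     temp = generator_temperature(self.global_step, self.max_steps)
#     gen_probs = F.softmax(gen_logits / temp, dim=-1)
def generator_temperature(step: int, total_steps: int,
                          start_temp: float = 2.0, end_temp: float = 1.0) -> float:
    """Linearly anneal the generator sampling temperature over training."""
    progress = min(1.0, step / max(1, total_steps))
    return start_temp + (end_temp - start_temp) * progress
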

def load_protein_sequences(dataset_name="lamm-mit/protein_secondary_structure_from_PDB",
                           split="train", max_seqs=None):
    """Load protein sequences from a Hugging Face dataset (streaming)."""
    ds = load_dataset(dataset_name, split=split, streaming=True)
    sequences = []
    for i, example in enumerate(ds):
        if max_seqs and i >= max_seqs:
            break
        # Try common column names
        seq = None
        for key in ["input", "primary", "sequences", "sequence", "protein", "text"]:
            if key in example:
                seq = example[key]
                break
        if seq and len(seq) > 10:
            sequences.append(seq)
    return sequences

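
# Resuming from disk is not implemented above; this is a minimal sketch of the inverse of
# ELECTRAProteinTrainer.save_checkpoint, restoring the state dicts it writes ("generator",
# "discriminator", "gen_optimizer", "disc_optimizer", "step"). The function name is
# illustrative; scheduler state is not saved, so schedulers would need to be stepped
# forward separately if exact resumption matters.
def load_checkpoint_into_trainer(trainer: "ELECTRAProteinTrainer", checkpoint_path: str):
    state = torch.load(checkpoint_path, map_location=trainer.device)
    trainer.generator.load_state_dict(state["generator"])
    trainer.discriminator.load_state_dict(state["discriminator"])
    trainer.gen_optimizer.load_state_dict(state["gen_optimizer"])
    trainer.disc_optimizer.load_state_dict(state["disc_optimizer"])
    trainer.global_step = state["step"]
    return trainer
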

def main():
    # Discriminator (main model) config
    DISC_CONFIG = ModernProteinLMConfig(
        vocab_size=33,
        hidden_size=576,
        num_hidden_layers=28,
        num_attention_heads=9,
        intermediate_size=2304,
        use_geglu=True,
        tie_word_embeddings=True,
        max_position_embeddings=1026,
        position_embedding_type="rotary",
        rope_theta=10000.0,
    )

    # Generator: ~25% of discriminator size
    GEN_CONFIG = ModernProteinLMConfig(
        vocab_size=33,
        hidden_size=320,
        num_hidden_layers=8,
        num_attention_heads=8,
        intermediate_size=1280,
        use_geglu=True,
        tie_word_embeddings=True,
    )

    tokenizer = ProteinTokenizer()

    # Load data once, then split so the train and eval sets do not overlap
    print("Loading protein sequences...")
    all_seqs = load_protein_sequences(
        "lamm-mit/protein_secondary_structure_from_PDB", "train", max_seqs=55000
    )
    train_seqs, eval_seqs = all_seqs[:50000], all_seqs[50000:]
    print(f"Loaded {len(train_seqs)} train, {len(eval_seqs)} eval sequences")

    train_dataset = ProteinDataset(
        train_seqs, tokenizer, max_length=1024,
        curriculum_start_ratio=0.30,
        curriculum_end_ratio=0.05,
        total_steps=100000,
    )
    eval_dataset = ProteinDataset(
        eval_seqs, tokenizer, max_length=1024,
        curriculum_start_ratio=0.30,
        curriculum_end_ratio=0.05,
        total_steps=100000,
        current_step=100000,  # fixed at the end ratio for eval
    )

    # Models
    generator = GeneratorModel(
        vocab_size=33,
        hidden_size=GEN_CONFIG.hidden_size,
        num_layers=GEN_CONFIG.num_hidden_layers,
        num_heads=GEN_CONFIG.num_attention_heads,
        intermediate_size=GEN_CONFIG.intermediate_size,
    )
    discriminator = DiscriminatorModel(DISC_CONFIG)

    # Count parameters
    gen_params = sum(p.numel() for p in generator.parameters())
    disc_params = sum(p.numel() for p in discriminator.parameters())
    print(f"Generator params: {gen_params / 1e6:.1f}M")
    print(f"Discriminator params: {disc_params / 1e6:.1f}M")

    trainer = ELECTRAProteinTrainer(
        generator=generator,
        discriminator=discriminator,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./modern_protein_electra",
        lr=5e-4,
        batch_size=16,
        max_steps=100000,
        warmup_steps=10000,
        weight_decay=0.01,
        grad_clip=1.0,
        generator_weight=1.0,
        discriminator_weight=50.0,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    print("Starting ELECTRA pre-training...")
    trainer.train()


if __name__ == "__main__":
    main()