| """ |
| ELECTRA-style discriminative pre-training for ModernProteinLM. |
| |
| Generator (small): ~25% of discriminator size, trained with MLM. |
| Discriminator (main model): Trained to detect replaced tokens (RTD objective). |
| |
| Key improvements over standard ELECTRA: |
| 1. Curriculum masking: start at 30%, decay to 5% |
| 2. Span masking: mask contiguous regions (protein structural motifs) |
3. Generator temperature annealing: replacements are sampled at a higher temperature early in training
4. No NSP, no dropout (following ESM-2)
"""


import os
import random
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm


from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig


class ProteinTokenizer:
    """Simple protein tokenizer with an ESM-2-style 33-token vocabulary."""

    ALL_AA = "LAGVSERTIDPQKNFYWMHCXBUZO"

    def __init__(self):
        self.vocab = {
            "<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
            "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10,
            "T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17,
            "F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24,
            "B": 25, "U": 26, "Z": 27, "O": 28, "<mask>": 29,
            "<sep>": 30,
        }

        while len(self.vocab) < 33:
            self.vocab[f"<special_{len(self.vocab)}>"] = len(self.vocab)

        self.id_to_token = {v: k for k, v in self.vocab.items()}
        self.mask_token_id = 29
        self.pad_token_id = 1
        self.cls_token_id = 0
        self.eos_token_id = 2

    def encode(self, sequence: str, max_length: int = 1024, add_special_tokens: bool = True):
        tokens = []
        if add_special_tokens:
            tokens.append(self.cls_token_id)

        for aa in sequence.upper():
            if aa in self.vocab:
                tokens.append(self.vocab[aa])
            else:
                tokens.append(self.vocab["<unk>"])

        if add_special_tokens:
            tokens.append(self.eos_token_id)

        if len(tokens) > max_length:
            tokens = tokens[:max_length]
            if add_special_tokens:
                # Keep <eos> as the final token when the sequence is truncated.
                tokens[-1] = self.eos_token_id

        attention_mask = [1] * len(tokens)
        while len(tokens) < max_length:
            tokens.append(self.pad_token_id)
            attention_mask.append(0)

        return {
            "input_ids": tokens,
            "attention_mask": attention_mask,
        }

    def batch_encode(self, sequences: List[str], max_length: int = 1024):
        results = [self.encode(seq, max_length) for seq in sequences]
        return {
            "input_ids": torch.tensor([r["input_ids"] for r in results], dtype=torch.long),
            "attention_mask": torch.tensor([r["attention_mask"] for r in results], dtype=torch.long),
        }

    def decode(self, token_ids):
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return "".join([self.id_to_token.get(t, "<unk>") for t in token_ids])




def create_span_mask(length, mask_ratio=0.30, mean_span_length=3, min_span_length=1):
    """Create span mask for protein sequences."""
    # Guard against empty inputs so the sampling loop below cannot spin forever.
    if length <= 0:
        return []

    num_to_mask = max(1, int(length * mask_ratio))
    mask = [False] * length

    attempts = 0
    masked = 0
    while masked < num_to_mask and attempts < num_to_mask * 10:
        span_len = max(min_span_length, min(mean_span_length + random.randint(-1, 1), num_to_mask - masked))
        start = random.randint(0, max(0, length - span_len))

        # Re-draw if the proposed span overlaps an already-masked region.
        if any(mask[start:start + span_len]):
            attempts += 1
            continue

        for i in range(start, min(start + span_len, length)):
            mask[i] = True
            masked += 1

    return mask
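
# Note: at the initial 30% mask ratio, a 100-residue sequence gets ~30 masked
# positions grouped into contiguous spans of roughly 2-4 residues
# (mean_span_length=3 +/- 1) rather than isolated tokens, which better covers
# local structural motifs.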




class ProteinDataset(Dataset):
    def __init__(self, sequences, tokenizer, max_length=1024, mask_ratio=0.30,
                 mean_span_length=3, curriculum_start_ratio=0.30, curriculum_end_ratio=0.05,
                 total_steps=100000, current_step=0):
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mean_span_length = mean_span_length
        self.curriculum_start_ratio = curriculum_start_ratio
        self.curriculum_end_ratio = curriculum_end_ratio
        self.total_steps = total_steps
        self.current_step = current_step

    def get_current_mask_ratio(self):
        """Linear decay from start to end ratio."""
        progress = min(1.0, self.current_step / self.total_steps)
        return self.curriculum_start_ratio + (self.curriculum_end_ratio - self.curriculum_start_ratio) * progress
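
    # Example: with the default 0.30 -> 0.05 schedule over 100,000 steps,
    # step 50,000 gives progress = 0.5 and a mask ratio of
    # 0.30 + (0.05 - 0.30) * 0.5 = 0.175.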

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        encoded = self.tokenizer.encode(seq, max_length=self.max_length)
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        seq_len = sum(attention_mask)
        effective_len = seq_len - 2 if seq_len > 2 else seq_len

        mask_ratio = self.get_current_mask_ratio()
        span_mask = create_span_mask(effective_len, mask_ratio, self.mean_span_length)

        masked_input = input_ids.copy()
        labels = [-100] * len(input_ids)
        replaced = [False] * len(input_ids)

        for i in range(1, 1 + effective_len):
            if span_mask[i - 1]:
                labels[i] = input_ids[i]
                replaced[i] = True
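
                # BERT-style corruption of a masked position: 80% -> <mask>,
                # 10% -> a random amino-acid id (4-28), 10% -> keep the original token.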
                r = random.random()
                if r < 0.8:
                    masked_input[i] = self.tokenizer.mask_token_id
                elif r < 0.9:
                    masked_input[i] = random.randint(4, 28)

        return {
            "input_ids": torch.tensor(masked_input, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "replaced": torch.tensor(replaced, dtype=torch.bool),
            "original_ids": torch.tensor(input_ids, dtype=torch.long),
        }


class GeneratorModel(nn.Module):
    """Small generator model for ELECTRA."""

    def __init__(self, vocab_size, hidden_size=256, num_layers=4, num_heads=4, intermediate_size=1024):
        super().__init__()
        config = ModernProteinLMConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_layers,
            num_attention_heads=num_heads,
            intermediate_size=intermediate_size,
            tie_word_embeddings=True,
        )
        self.model = ModernProteinLM(config)
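
        # Note: the original ELECTRA ties the generator's and discriminator's
        # token embeddings; here the two models keep separate embedding tables.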

    def forward(self, input_ids, attention_mask, labels):
        return self.model(input_ids, attention_mask, labels=labels)


class DiscriminatorModel(ModernProteinLM):
    """Discriminator with additional classification head for RTD."""

    def __init__(self, config):
        super().__init__(config)
        self.discriminator_head = nn.Linear(config.hidden_size, 1)
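        # One logit per token: the RTD head scores each position as original vs. replaced.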

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = super().forward(input_ids, attention_mask, return_dict=True)
        hidden = outputs.hidden_states[-1]

        disc_logits = self.discriminator_head(hidden).squeeze(-1)

        disc_loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            active_loss = labels != -100
            active_logits = disc_logits[active_loss]
            active_labels = labels[active_loss].float()
            disc_loss = loss_fct(active_logits, active_labels)

        return {
            "loss": disc_loss,
            "logits": disc_logits,
            "hidden_states": outputs.hidden_states,
        }


class ELECTRAProteinTrainer:
    def __init__(
        self,
        generator: GeneratorModel,
        discriminator: DiscriminatorModel,
        tokenizer,
        train_dataset,
        eval_dataset,
        output_dir="./electra_protein",
        lr=5e-4,
        batch_size=32,
        max_steps=100000,
        warmup_steps=10000,
        weight_decay=0.01,
        grad_clip=1.0,
        generator_weight=1.0,
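        # The RTD loss is up-weighted relative to the generator's MLM loss;
        # 50 matches the discriminator loss weight (lambda) from the ELECTRA paper.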
        discriminator_weight=50.0,
        device="cuda",
    ):
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.output_dir = output_dir
        self.device = device
        self.max_steps = max_steps
        self.grad_clip = grad_clip
        self.generator_weight = generator_weight
        self.discriminator_weight = discriminator_weight

        os.makedirs(output_dir, exist_ok=True)

        self.gen_optimizer = torch.optim.AdamW(
            generator.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=weight_decay
        )
        self.disc_optimizer = torch.optim.AdamW(
            discriminator.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=weight_decay
        )

        self.gen_scheduler = get_cosine_schedule_with_warmup(
            self.gen_optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
        )
        self.disc_scheduler = get_cosine_schedule_with_warmup(
            self.disc_optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
        )

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
        )
        self.eval_loader = DataLoader(
            eval_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
        )

        self.global_step = 0
        self.best_eval_loss = float("inf")
    def train(self):
        self.generator.train()
        self.discriminator.train()

        self.pbar = tqdm(total=self.max_steps, desc="Training")

        # Loop over the dataset for as many epochs as needed to reach max_steps.
        while self.global_step < self.max_steps:
            for batch in self.train_loader:
                if self.global_step >= self.max_steps:
                    break

                # Advance the masking curriculum with training progress. With
                # num_workers > 0, worker copies of the dataset pick this up at
                # the start of each epoch.
                self.train_dataset.current_step = self.global_step

                self._train_step(batch)
                self.global_step += 1
                self.pbar.update(1)

                if self.global_step % 1000 == 0:
                    eval_loss = self.evaluate()
                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint("best")
                    self.generator.train()
                    self.discriminator.train()

                if self.global_step % 5000 == 0:
                    self.save_checkpoint(f"step_{self.global_step}")

        self.pbar.close()
        self.save_checkpoint("final")
    def _train_step(self, batch):
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        mlm_labels = batch["labels"].to(self.device)
        original_ids = batch["original_ids"].to(self.device)

        # Generator: standard MLM loss on the masked input.
        gen_outputs = self.generator(input_ids, attention_mask, mlm_labels)
        gen_loss = gen_outputs.loss

        # Sample replacement tokens from the generator; no gradient flows from
        # the discriminator back into the generator through these samples.
        with torch.no_grad():
            gen_logits = gen_outputs.logits
            # Generator temperature annealing: sample from a flatter distribution
            # early in training and sharpen toward T = 1.0. The linear 2.0 -> 1.0
            # schedule is an assumption, not prescribed by ELECTRA.
            temperature = 2.0 - min(1.0, self.global_step / self.max_steps)
            gen_probs = F.softmax(gen_logits / temperature, dim=-1)
            sampled_ids = torch.multinomial(
                gen_probs.view(-1, gen_probs.size(-1)), 1
            ).view(gen_probs.shape[:-1])

        # Build the corrupted input the discriminator will score.
        corrupted_input = original_ids.clone()
        mask_positions = mlm_labels != -100
        corrupted_input[mask_positions] = sampled_ids[mask_positions]

        # RTD labels: 1.0 = original, 0.0 = replaced. A position only counts as
        # replaced if the sampled token actually differs from the original, as in
        # ELECTRA; generator-correct predictions are treated as "original".
        disc_labels = torch.ones_like(original_ids, dtype=torch.float)
        disc_labels[mask_positions & (corrupted_input != original_ids)] = 0.0
        # Exclude padding positions from the RTD loss.
        disc_labels[attention_mask == 0] = -100

        disc_outputs = self.discriminator(corrupted_input, attention_mask, disc_labels)
        disc_loss = disc_outputs["loss"]

        total_loss = self.generator_weight * gen_loss + self.discriminator_weight * disc_loss

        total_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
        torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)

        self.gen_optimizer.step()
        self.disc_optimizer.step()
        self.gen_scheduler.step()
        self.disc_scheduler.step()

        self.gen_optimizer.zero_grad()
        self.disc_optimizer.zero_grad()

        if self.global_step % 100 == 0 and getattr(self, "pbar", None) is not None:
            self.pbar.set_postfix({
                "gen_loss": f"{gen_loss.item():.4f}",
                "disc_loss": f"{disc_loss.item():.4f}",
                "lr": f"{self.gen_scheduler.get_last_lr()[0]:.2e}",
            })
    def evaluate(self):
        self.generator.eval()
        self.discriminator.eval()

        total_gen_loss = 0
        total_disc_loss = 0
        total_samples = 0

        with torch.no_grad():
            for batch in self.eval_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                mlm_labels = batch["labels"].to(self.device)
                original_ids = batch["original_ids"].to(self.device)

                gen_outputs = self.generator(input_ids, attention_mask, mlm_labels)
                total_gen_loss += gen_outputs.loss.item() * input_ids.size(0)

                # Mirror training: the discriminator is evaluated on generator-
                # corrupted sequences, not on the <mask>-containing inputs.
                gen_probs = F.softmax(gen_outputs.logits, dim=-1)
                sampled_ids = torch.multinomial(
                    gen_probs.view(-1, gen_probs.size(-1)), 1
                ).view(gen_probs.shape[:-1])
                mask_positions = mlm_labels != -100
                corrupted_input = original_ids.clone()
                corrupted_input[mask_positions] = sampled_ids[mask_positions]

                disc_labels = torch.ones_like(original_ids, dtype=torch.float)
                disc_labels[mask_positions & (corrupted_input != original_ids)] = 0.0
                disc_labels[attention_mask == 0] = -100

                disc_outputs = self.discriminator(corrupted_input, attention_mask, disc_labels)
                total_disc_loss += disc_outputs["loss"].item() * input_ids.size(0)
                total_samples += input_ids.size(0)

        avg_gen = total_gen_loss / total_samples
        avg_disc = total_disc_loss / total_samples

        print(f"Eval - Gen Loss: {avg_gen:.4f}, Disc Loss: {avg_disc:.4f}")
        return avg_gen + avg_disc

    def save_checkpoint(self, name):
        path = os.path.join(self.output_dir, name)
        os.makedirs(path, exist_ok=True)

        torch.save({
            "generator": self.generator.state_dict(),
            "discriminator": self.discriminator.state_dict(),
            "gen_optimizer": self.gen_optimizer.state_dict(),
            "disc_optimizer": self.disc_optimizer.state_dict(),
            # Scheduler states are included so training can be resumed exactly.
            "gen_scheduler": self.gen_scheduler.state_dict(),
            "disc_scheduler": self.disc_scheduler.state_dict(),
            "step": self.global_step,
        }, os.path.join(path, "checkpoint.pt"))

        self.discriminator.config.save_pretrained(path)

        print(f"Saved checkpoint to {path}")


def load_protein_sequences(dataset_name="lamm-mit/protein_secondary_structure_from_PDB", split="train", max_seqs=None):
    """Load protein sequences from HF dataset."""
    ds = load_dataset(dataset_name, split=split, streaming=True)
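    # streaming=True iterates examples lazily, so the full dataset is never
    # materialized locally before we start filtering sequences.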
    sequences = []

    for i, example in enumerate(ds):
        if max_seqs and i >= max_seqs:
            break

        # Datasets name the sequence column differently; probe common field names.
        seq = None
        for key in ["input", "primary", "sequences", "sequence", "protein", "text"]:
            if key in example:
                seq = example[key]
                break
        if isinstance(seq, str) and len(seq) > 10:
            sequences.append(seq)

    return sequences




def main():
    DISC_CONFIG = ModernProteinLMConfig(
        vocab_size=33,
        hidden_size=576,
        num_hidden_layers=28,
        num_attention_heads=9,
        intermediate_size=2304,
        use_geglu=True,
        tie_word_embeddings=True,
        max_position_embeddings=1026,
        position_embedding_type="rotary",
        rope_theta=10000.0,
    )

    GEN_CONFIG = ModernProteinLMConfig(
        vocab_size=33,
        hidden_size=320,
        num_hidden_layers=8,
        num_attention_heads=8,
        intermediate_size=1280,
        use_geglu=True,
        tie_word_embeddings=True,
    )
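    # The generator is kept much smaller than the discriminator; ELECTRA reports
    # that generators roughly 1/4-1/2 the discriminator's size work best, and an
    # overly strong generator hurts replaced-token-detection training.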

    tokenizer = ProteinTokenizer()

    print("Loading protein sequences...")
    # Load once and hold out the last 5,000 sequences for evaluation, so the
    # eval set does not overlap the training set.
    all_seqs = load_protein_sequences("lamm-mit/protein_secondary_structure_from_PDB", "train", max_seqs=55000)
    train_seqs, eval_seqs = all_seqs[:50000], all_seqs[50000:]

    print(f"Loaded {len(train_seqs)} train, {len(eval_seqs)} eval sequences")

    train_dataset = ProteinDataset(
        train_seqs, tokenizer, max_length=1024,
        curriculum_start_ratio=0.30, curriculum_end_ratio=0.05,
        total_steps=100000,
    )
    eval_dataset = ProteinDataset(
        eval_seqs, tokenizer, max_length=1024,
        curriculum_start_ratio=0.30, curriculum_end_ratio=0.05,
        total_steps=100000, current_step=100000,
    )

    generator = GeneratorModel(
        vocab_size=33,
        hidden_size=GEN_CONFIG.hidden_size,
        num_layers=GEN_CONFIG.num_hidden_layers,
        num_heads=GEN_CONFIG.num_attention_heads,
        intermediate_size=GEN_CONFIG.intermediate_size,
    )
    discriminator = DiscriminatorModel(DISC_CONFIG)

    gen_params = sum(p.numel() for p in generator.parameters())
    disc_params = sum(p.numel() for p in discriminator.parameters())
    print(f"Generator params: {gen_params/1e6:.1f}M")
    print(f"Discriminator params: {disc_params/1e6:.1f}M")

    trainer = ELECTRAProteinTrainer(
        generator=generator,
        discriminator=discriminator,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./modern_protein_electra",
        lr=5e-4,
        batch_size=16,
        max_steps=100000,
        warmup_steps=10000,
        weight_decay=0.01,
        grad_clip=1.0,
        generator_weight=1.0,
        discriminator_weight=50.0,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    print("Starting ELECTRA pre-training...")
    trainer.train()


if __name__ == "__main__":
    main()