""" Fine-tune pretrained ModernProteinLM on downstream predictive tasks. Supports: regression (fluorescence, stability), classification (solubility, remote homology). """ import os import sys import argparse import json import random import math from typing import Dict, List import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, Dataset, DistributedSampler from torch.cuda.amp import autocast, GradScaler from transformers import get_cosine_schedule_with_warmup from datasets import load_dataset from scipy.stats import spearmanr from sklearn.metrics import accuracy_score, f1_score from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig # ============================================================================= # TOKENIZER (shared with pretrain) # ============================================================================= class ProteinTokenizer: def __init__(self): self.vocab = { "": 0, "": 1, "": 2, "": 3, "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10, "T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17, "F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24, "B": 25, "U": 26, "Z": 27, "O": 28, "": 29, "": 30, } while len(self.vocab) < 33: self.vocab[f""] = len(self.vocab) self.id_to_token = {v: k for k, v in self.vocab.items()} self.mask_token_id = 29 self.pad_token_id = 1 self.cls_token_id = 0 self.eos_token_id = 2 def encode(self, sequence: str, max_length: int = 1024): tokens = [self.cls_token_id] for aa in sequence.upper(): tokens.append(self.vocab.get(aa, self.vocab[""])) tokens.append(self.eos_token_id) if len(tokens) > max_length: tokens = tokens[:max_length] attention_mask = [1] * len(tokens) while len(tokens) < max_length: tokens.append(self.pad_token_id) attention_mask.append(0) return {"input_ids": tokens, "attention_mask": attention_mask} 
def setup_distributed():
    """Initialize torch.distributed from torchrun-style environment variables.

    Returns:
        (rank, world_size, local_rank); (0, 1, 0) when not launched under a
        distributed launcher (RANK/WORLD_SIZE absent from the environment).
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)
        return rank, world_size, local_rank
    return 0, 1, 0


def log_rank0(msg):
    """Print `msg` only on rank 0 (or always when not running distributed)."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(msg)


# =============================================================================
# TASK DEFINITIONS
# =============================================================================

TASK_SPECS = {
    "fluorescence": {
        "dataset": "proteinea/fluorescence",
        "seq_key": "primary",
        "label_key": "log_fluorescence",
        "task_type": "regression",
        "metric": "spearman",
        "splits": ["train", "validation", "test"],
    },
    "stability": {
        # BUG FIX: this entry was a verbatim copy of "fluorescence" (same
        # dataset and label key), so the "stability" task silently trained
        # and evaluated on fluorescence data. Pointed at the stability
        # dataset instead.
        # TODO(review): confirm the exact label column name against the
        # published dataset schema.
        "dataset": "proteinea/stability",
        "seq_key": "primary",
        "label_key": "stability_score",
        "task_type": "regression",
        "metric": "spearman",
        "splits": ["train", "validation", "test"],
    },
    "solubility": {
        "dataset": "proteinea/solubility",
        "seq_key": "sequences",
        "label_key": "labels",
        "task_type": "classification",
        "num_labels": 2,
        "metric": "accuracy",
        "splits": ["train", "validation", "test"],
    },
    "remote_homology": {
        "dataset": "proteinea/remote_homology",
        "seq_key": "primary",
        "label_key": "fold_label",
        "task_type": "classification",
        "num_labels": 1195,
        "metric": "accuracy",
        "splits": ["train", "validation", "test"],
    },
}


class DownstreamDataset(Dataset):
    """One HuggingFace split of a downstream task, tokenized on the fly."""

    def __init__(self, task_name, split, tokenizer, max_length=1024):
        self.spec = TASK_SPECS[task_name]
        self.tokenizer = tokenizer
        self.max_length = max_length
        try:
            self.data = load_dataset(self.spec["dataset"], split=split)
        except Exception as e:
            # Deliberate best-effort fallback: some datasets lack the
            # requested split, so fall back to train rather than crash.
            log_rank0(f"Failed to load {split}: {e}, using train")
            self.data = load_dataset(self.spec["dataset"], split="train")
        # Materialize once so __getitem__ is cheap and index-stable.
        self.examples = list(self.data)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        seq = ex[self.spec["seq_key"]]
        encoded = self.tokenizer.encode(seq, self.max_length)
        item = {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
        }
        # Regression targets are float scalars; classification targets are
        # integer class indices (required by F.cross_entropy).
        if self.spec["task_type"] == "regression":
            item["labels"] = torch.tensor(ex[self.spec["label_key"]], dtype=torch.float)
        else:
            item["labels"] = torch.tensor(ex[self.spec["label_key"]], dtype=torch.long)
        return item


def mean_pool(hidden_states, attention_mask):
    """Masked mean over the sequence dimension: (B, T, H) x (B, T) -> (B, H).

    Padding positions (mask == 0) are excluded; the clamp guards against a
    division by zero for an all-padding row.
    """
    mask = attention_mask.unsqueeze(-1).float()
    return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)


class TaskHead(nn.Module):
    """Single linear head over pooled representations.

    Regression -> 1 output; classification -> `num_labels` logits (default 2).
    """

    def __init__(self, hidden_size, task_spec):
        super().__init__()
        if task_spec["task_type"] == "regression":
            self.head = nn.Linear(hidden_size, 1)
        else:
            self.head = nn.Linear(hidden_size, task_spec.get("num_labels", 2))
        self.task_type = task_spec["task_type"]

    def forward(self, pooled):
        return self.head(pooled)


def evaluate(model, head, dataloader, task_spec, device):
    """Run inference over `dataloader` and score with the task's metric.

    Returns:
        (score, mean_loss) where score is Spearman rho, accuracy, or macro-F1
        depending on task_spec["metric"].

    Raises:
        ValueError: if task_spec["metric"] is not a known metric name.
    """
    model.eval()
    head.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids, attention_mask, output_hidden_states=True, return_dict=True)
            hidden = outputs.hidden_states[-1]
            pooled = mean_pool(hidden, attention_mask)
            logits = head(pooled)
            if task_spec["task_type"] == "regression":
                loss = F.mse_loss(logits.squeeze(-1), labels)
                preds = logits.squeeze(-1).cpu().numpy()
            else:
                loss = F.cross_entropy(logits, labels)
                preds = torch.argmax(logits, dim=-1).cpu().numpy()
            # Weight by batch size so the final mean is per-example even when
            # the last batch is smaller.
            total_loss += loss.item() * input_ids.size(0)
            all_preds.extend(preds.tolist() if hasattr(preds, 'tolist') else preds)
            all_labels.extend(labels.cpu().numpy().tolist())

    metric = task_spec["metric"]
    if metric == "spearman":
        score, _ = spearmanr(all_labels, all_preds)
    elif metric == "accuracy":
        score = accuracy_score(all_labels, all_preds)
    elif metric == "f1":
        score = f1_score(all_labels, all_preds, average="macro")
    else:
        # BUG FIX: previously an unknown metric fell through with `score`
        # unbound, raising an opaque NameError at the return statement.
        raise ValueError(f"Unknown metric: {metric}")
    return score, total_loss / len(dataloader.dataset)


def train_task(args, model, task_name, tokenizer, device, rank, world_size):
    """Fine-tune `model` plus a fresh TaskHead on one downstream task.

    Trains with layer-wise LR decay, keeps the checkpoint with the best
    validation score, and evaluates it on the test split.

    Returns:
        dict with task name, best val score, test score, and metric name.
    """
    spec = TASK_SPECS[task_name]

    train_ds = DownstreamDataset(task_name, spec["splits"][0], tokenizer, args.max_seq_length)
    val_ds = DownstreamDataset(
        task_name,
        spec["splits"][1] if len(spec["splits"]) > 1 else spec["splits"][0],
        tokenizer, args.max_seq_length,
    )
    test_ds = DownstreamDataset(task_name, spec["splits"][-1], tokenizer, args.max_seq_length)

    if world_size > 1:
        train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank)
    else:
        train_sampler = None
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler,
                              shuffle=(train_sampler is None),
                              num_workers=args.num_workers, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers, pin_memory=True)

    head = TaskHead(args.hidden_size, spec).to(device)

    # Layer-wise LR decay: head at full LR, top 4 transformer layers at half,
    # everything else (lower layers + embeddings) at a tenth.
    # BUG FIX: when main() wraps the model in DDP, `.layers`/`.embeddings`
    # live on `.module`; accessing them on the wrapper raised AttributeError.
    base = model.module if isinstance(model, DDP) else model
    params = [
        {"params": head.parameters(), "lr": args.lr},
        {"params": base.layers[-4:].parameters(), "lr": args.lr * 0.5},
        {"params": base.layers[:-4].parameters(), "lr": args.lr * 0.1},
        {"params": [base.embeddings.weight], "lr": args.lr * 0.1},
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=args.weight_decay)
    total_steps = len(train_loader) * args.epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, int(args.warmup_ratio * total_steps), total_steps
    )
    scaler = GradScaler() if args.use_amp else None

    best_score = -float("inf")
    best_state = None

    for epoch in range(args.epochs):
        model.train()
        head.train()
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=args.use_amp):
                outputs = model(input_ids, attention_mask, output_hidden_states=True, return_dict=True)
                hidden = outputs.hidden_states[-1]
                pooled = mean_pool(hidden, attention_mask)
                logits = head(pooled)
                if spec["task_type"] == "regression":
                    loss = F.mse_loss(logits.squeeze(-1), labels)
                else:
                    loss = F.cross_entropy(logits, labels)

            if scaler:
                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is computed on true grads.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), 1.0)
                optimizer.step()
            scheduler.step()

        # Evaluate
        val_score, val_loss = evaluate(model, head, val_loader, spec, device)
        if rank == 0:
            log_rank0(f"  Epoch {epoch+1}/{args.epochs}: val_{spec['metric']}={val_score:.4f}, loss={val_loss:.4f}")
        if val_score > best_score:
            best_score = val_score
            # BUG FIX: state_dict() returns references to the live parameter
            # tensors, which the optimizer keeps mutating in place — the
            # "best" snapshot silently tracked the *latest* weights. Detached
            # CPU clones freeze the snapshot (and free GPU memory).
            best_state = {
                "model": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
                "head": {k: v.detach().cpu().clone() for k, v in head.state_dict().items()},
            }

    # Load best and test
    if best_state:
        model.load_state_dict(best_state["model"])
        head.load_state_dict(best_state["head"])
    test_score, test_loss = evaluate(model, head, test_loader, spec, device)

    return {
        "task": task_name,
        "val_score": float(best_score),
        "test_score": float(test_score),
        "metric": spec["metric"],
    }


def main():
    """CLI entry point: load the pretrained base, then fine-tune per task."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrain_dir", required=True)
    parser.add_argument("--tasks", default="fluorescence,solubility")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--max_seq_length", type=int, default=1024)
    parser.add_argument("--output_dir", default="./outputs/finetune")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--use_amp", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--use_trackio", action="store_true")
    parser.add_argument("--trackio_project", default="modern-protein-lm")
    args = parser.parse_args()

    rank, world_size, local_rank = setup_distributed()

    # Per-rank seed offset keeps data-order randomness decorrelated across ranks.
    random.seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)

    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    tokenizer = ProteinTokenizer()

    # Load pretrained discriminator base
    checkpoint_path = os.path.join(args.pretrain_dir, "checkpoint.pt")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Infer config from checkpoint
    disc_state = checkpoint["discriminator"]
    # Find hidden_size from state dict
    hidden_size = None
    for key in disc_state:
        if "model.embeddings.weight" in key:
            hidden_size = disc_state[key].shape[1]
            break
    if hidden_size is None:
        raise ValueError("Could not infer model size from checkpoint")
    args.hidden_size = hidden_size

    config = ModernProteinLMConfig(
        vocab_size=33,
        hidden_size=hidden_size,
        num_hidden_layers=28,
        num_attention_heads=9,
        intermediate_size=2304,
        use_geglu=True,
        tie_word_embeddings=True,
    )
    model = ModernProteinLM(config).to(device)

    # Load only base model weights (not discriminator head).
    # BUG FIX: str.replace("model.", "") rewrote *every* occurrence of the
    # substring inside a key; strip only the leading prefix instead.
    prefix = "model."
    base_state = {k[len(prefix):]: v for k, v in disc_state.items() if k.startswith(prefix)}
    model.load_state_dict(base_state, strict=False)
    log_rank0(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6:.1f}M params")

    if world_size > 1:
        model = DDP(model, device_ids=[local_rank])

    tasks = [t.strip() for t in args.tasks.split(",")]
    results = {}
    for task in tasks:
        log_rank0(f"\n{'='*50}")
        log_rank0(f"Task: {task}")
        log_rank0(f"{'='*50}")
        result = train_task(args, model, task, tokenizer, device, rank, world_size)
        results[task] = result
        if rank == 0:
            log_rank0(f"  Test {result['metric']}: {result['test_score']:.4f}")

    if rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "results.json"), "w") as f:
            json.dump(results, f, indent=2)
        log_rank0(f"\n{'='*50}")
        log_rank0("FINAL RESULTS")
        log_rank0(f"{'='*50}")
        for task, res in results.items():
            log_rank0(f"  {task}: {res['test_score']:.4f} ({res['metric']})")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()