| """ |
Fine-tune a pretrained ModernProteinLM on downstream predictive tasks.
Supports regression (fluorescence, stability) and classification (solubility, remote homology).
| """ |
|
|
| import os |
| import argparse |
| import json |
| import random |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import DataLoader, Dataset, DistributedSampler |
| from torch.cuda.amp import autocast, GradScaler |
| from transformers import get_cosine_schedule_with_warmup |
| from datasets import load_dataset |
| from scipy.stats import spearmanr |
| from sklearn.metrics import accuracy_score, f1_score |
|
|
| from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig |
|
|
|
|
|
|
| class ProteinTokenizer: |
| def __init__(self): |
| self.vocab = { |
| "<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3, |
| "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10, |
| "T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17, |
| "F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24, |
| "B": 25, "U": 26, "Z": 27, "O": 28, "<mask>": 29, "<sep>": 30, |
| } |
        # Pad the vocab out to 33 entries so token ids line up with the
        # pretraining config (vocab_size=33 in main()).
        while len(self.vocab) < 33:
            self.vocab[f"<special_{len(self.vocab)}>"] = len(self.vocab)
| self.id_to_token = {v: k for k, v in self.vocab.items()} |
| self.mask_token_id = 29 |
| self.pad_token_id = 1 |
| self.cls_token_id = 0 |
| self.eos_token_id = 2 |
|
|
| def encode(self, sequence: str, max_length: int = 1024): |
| tokens = [self.cls_token_id] |
| for aa in sequence.upper(): |
| tokens.append(self.vocab.get(aa, self.vocab["<unk>"])) |
| tokens.append(self.eos_token_id) |
        if len(tokens) > max_length:
            # Truncate, but keep <eos> as the final real token.
            tokens = tokens[:max_length]
            tokens[-1] = self.eos_token_id
| attention_mask = [1] * len(tokens) |
| while len(tokens) < max_length: |
| tokens.append(self.pad_token_id) |
| attention_mask.append(0) |
| return {"input_ids": tokens, "attention_mask": attention_mask} |
|
|
|
|
| def setup_distributed(): |
| if "RANK" in os.environ and "WORLD_SIZE" in os.environ: |
| rank = int(os.environ["RANK"]) |
| world_size = int(os.environ["WORLD_SIZE"]) |
| local_rank = int(os.environ.get("LOCAL_RANK", 0)) |
| dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) |
| torch.cuda.set_device(local_rank) |
| return rank, world_size, local_rank |
| return 0, 1, 0 |
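# RANK / WORLD_SIZE / LOCAL_RANK are set automatically by torchrun; with a plain
# `python` launch none of them exist and this falls through to single-process mode.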
|
|
|
|
| def log_rank0(msg): |
| if not dist.is_initialized() or dist.get_rank() == 0: |
| print(msg) |
|
|
|
|
|
|
| TASK_SPECS = { |
| "fluorescence": { |
| "dataset": "proteinea/fluorescence", |
| "seq_key": "primary", |
| "label_key": "log_fluorescence", |
| "task_type": "regression", |
| "metric": "spearman", |
| "splits": ["train", "validation", "test"], |
| }, |
| "stability": { |
| "dataset": "proteinea/fluorescence", |
| "seq_key": "primary", |
| "label_key": "log_fluorescence", |
| "task_type": "regression", |
| "metric": "spearman", |
| "splits": ["train", "validation", "test"], |
| }, |
| "solubility": { |
| "dataset": "proteinea/solubility", |
| "seq_key": "sequences", |
| "label_key": "labels", |
| "task_type": "classification", |
| "num_labels": 2, |
| "metric": "accuracy", |
| "splits": ["train", "validation", "test"], |
| }, |
| "remote_homology": { |
| "dataset": "proteinea/remote_homology", |
| "seq_key": "primary", |
| "label_key": "fold_label", |
| "task_type": "classification", |
| "num_labels": 1195, |
| "metric": "accuracy", |
| "splits": ["train", "validation", "test"], |
| }, |
| } |
|
|
|
|
| class DownstreamDataset(Dataset): |
| def __init__(self, task_name, split, tokenizer, max_length=1024): |
| self.spec = TASK_SPECS[task_name] |
| self.tokenizer = tokenizer |
| self.max_length = max_length |
| |
        try:
            self.data = load_dataset(self.spec["dataset"], split=split)
        except Exception as e:
            # Falling back to train keeps smoke tests alive, but any metric
            # computed on this split is then meaningless -- warn loudly.
            log_rank0(f"Failed to load split '{split}' ({e}); falling back to 'train'")
            self.data = load_dataset(self.spec["dataset"], split="train")
| |
| self.examples = list(self.data) |
| |
| def __len__(self): |
| return len(self.examples) |
| |
| def __getitem__(self, idx): |
| ex = self.examples[idx] |
| seq = ex[self.spec["seq_key"]] |
| encoded = self.tokenizer.encode(seq, self.max_length) |
| |
| item = { |
| "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long), |
| "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long), |
| } |
| |
| if self.spec["task_type"] == "regression": |
| item["labels"] = torch.tensor(ex[self.spec["label_key"]], dtype=torch.float) |
| else: |
| item["labels"] = torch.tensor(ex[self.spec["label_key"]], dtype=torch.long) |
| |
| return item |
|
|
|
|
| def mean_pool(hidden_states, attention_mask): |
| mask = attention_mask.unsqueeze(-1).float() |
| return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9) |
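# Worked example: a row with attention_mask [1, 1, 0] has its padded position
# zeroed out, and the sum is divided by 2 (the real length), not 3.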
|
|
|
|
| class TaskHead(nn.Module): |
| def __init__(self, hidden_size, task_spec): |
| super().__init__() |
| if task_spec["task_type"] == "regression": |
| self.head = nn.Linear(hidden_size, 1) |
| else: |
| self.head = nn.Linear(hidden_size, task_spec.get("num_labels", 2)) |
| self.task_type = task_spec["task_type"] |
| |
| def forward(self, pooled): |
| return self.head(pooled) |
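    # Shape note: pooled is (batch, hidden_size); the head maps it to (batch, 1)
    # for regression or (batch, num_labels) for classification.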
|
|
|
|
| def evaluate(model, head, dataloader, task_spec, device): |
| model.eval() |
| head.eval() |
| |
| all_preds = [] |
| all_labels = [] |
| total_loss = 0.0 |
| |
| with torch.no_grad(): |
| for batch in dataloader: |
| input_ids = batch["input_ids"].to(device) |
| attention_mask = batch["attention_mask"].to(device) |
| labels = batch["labels"].to(device) |
| |
| outputs = model(input_ids, attention_mask, output_hidden_states=True, return_dict=True) |
| hidden = outputs.hidden_states[-1] |
| pooled = mean_pool(hidden, attention_mask) |
| logits = head(pooled) |
| |
| if task_spec["task_type"] == "regression": |
| loss = F.mse_loss(logits.squeeze(-1), labels) |
| preds = logits.squeeze(-1).cpu().numpy() |
| else: |
| loss = F.cross_entropy(logits, labels) |
| preds = torch.argmax(logits, dim=-1).cpu().numpy() |
| |
| total_loss += loss.item() * input_ids.size(0) |
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.cpu().numpy().tolist())
| |
    metric = task_spec["metric"]
    if metric == "spearman":
        # spearmanr returns (statistic, pvalue); the statistic is NaN if either
        # input is constant.
        score, _ = spearmanr(all_labels, all_preds)
    elif metric == "accuracy":
        score = accuracy_score(all_labels, all_preds)
    elif metric == "f1":
        score = f1_score(all_labels, all_preds, average="macro")
    else:
        raise ValueError(f"Unknown metric: {metric}")

    return score, total_loss / len(dataloader.dataset)
|
|
|
|
| def train_task(args, model, task_name, tokenizer, device, rank, world_size): |
| spec = TASK_SPECS[task_name] |
| |
| train_ds = DownstreamDataset(task_name, spec["splits"][0], tokenizer, args.max_seq_length) |
| val_ds = DownstreamDataset( |
| task_name, |
| spec["splits"][1] if len(spec["splits"]) > 1 else spec["splits"][0], |
| tokenizer, args.max_seq_length |
| ) |
| test_ds = DownstreamDataset( |
| task_name, |
| spec["splits"][-1], |
| tokenizer, args.max_seq_length |
| ) |
| |
| if world_size > 1: |
| train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank) |
| else: |
| train_sampler = None |
| |
| train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler, |
| num_workers=args.num_workers, pin_memory=True, drop_last=True) |
| val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, |
| num_workers=args.num_workers, pin_memory=True) |
| test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, |
| num_workers=args.num_workers, pin_memory=True) |
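    # Only the train loader is sharded. Val/test loaders have no DistributedSampler,
    # so under DDP every rank evaluates the full split and gets identical metrics.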
| |
| head = TaskHead(args.hidden_size, spec).to(device) |
| |
| |
    # DDP wraps the encoder, so unwrap to .module before indexing submodules.
    base_model = model.module if isinstance(model, DDP) else model

    # Discriminative learning rates: full lr on the fresh head, half on the top
    # four encoder layers, a tenth on the lower layers and the embeddings.
    params = [
        {"params": head.parameters(), "lr": args.lr},
        {"params": base_model.layers[-4:].parameters(), "lr": args.lr * 0.5},
        {"params": base_model.layers[:-4].parameters(), "lr": args.lr * 0.1},
        {"params": [base_model.embeddings.weight], "lr": args.lr * 0.1},
    ]
| |
| optimizer = torch.optim.AdamW(params, weight_decay=args.weight_decay) |
| |
| total_steps = len(train_loader) * args.epochs |
| scheduler = get_cosine_schedule_with_warmup( |
| optimizer, int(args.warmup_ratio * total_steps), total_steps |
| ) |
| |
| scaler = GradScaler() if args.use_amp else None |
| |
| best_score = -float("inf") |
| best_state = None |
| |
| for epoch in range(args.epochs): |
| model.train() |
| head.train() |
| |
| if train_sampler: |
| train_sampler.set_epoch(epoch) |
| |
| for batch in train_loader: |
| input_ids = batch["input_ids"].to(device) |
| attention_mask = batch["attention_mask"].to(device) |
| labels = batch["labels"].to(device) |
| |
| with autocast(enabled=args.use_amp): |
| outputs = model(input_ids, attention_mask, output_hidden_states=True, return_dict=True) |
| hidden = outputs.hidden_states[-1] |
| pooled = mean_pool(hidden, attention_mask) |
| logits = head(pooled) |
| |
| if spec["task_type"] == "regression": |
| loss = F.mse_loss(logits.squeeze(-1), labels) |
| else: |
| loss = F.cross_entropy(logits, labels) |
| |
| if scaler: |
| scaler.scale(loss).backward() |
| scaler.unscale_(optimizer) |
| torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), 1.0) |
| scaler.step(optimizer) |
| scaler.update() |
| else: |
| loss.backward() |
| torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), 1.0) |
| optimizer.step() |
| |
| scheduler.step() |
| optimizer.zero_grad() |
| |
| |
| val_score, val_loss = evaluate(model, head, val_loader, spec, device) |
| |
        log_rank0(f"  Epoch {epoch+1}/{args.epochs}: val_{spec['metric']}={val_score:.4f}, loss={val_loss:.4f}")
| |
        if val_score > best_score:
            best_score = val_score
            # state_dict() returns live references; clone to CPU so the snapshot
            # doesn't silently track later optimizer updates.
            best_state = {
                "model": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
                "head": {k: v.detach().cpu().clone() for k, v in head.state_dict().items()},
            }
| |
| |
| if best_state: |
| model.load_state_dict(best_state["model"]) |
| head.load_state_dict(best_state["head"]) |
| |
| test_score, test_loss = evaluate(model, head, test_loader, spec, device) |
| |
| return { |
| "task": task_name, |
| "val_score": float(best_score), |
| "test_score": float(test_score), |
| "metric": spec["metric"], |
| } |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--pretrain_dir", required=True) |
| parser.add_argument("--tasks", default="fluorescence,solubility") |
| parser.add_argument("--epochs", type=int, default=20) |
| parser.add_argument("--batch_size", type=int, default=16) |
| parser.add_argument("--lr", type=float, default=1e-4) |
| parser.add_argument("--warmup_ratio", type=float, default=0.1) |
| parser.add_argument("--weight_decay", type=float, default=0.01) |
| parser.add_argument("--max_seq_length", type=int, default=1024) |
| parser.add_argument("--output_dir", default="./outputs/finetune") |
| parser.add_argument("--num_workers", type=int, default=4) |
| parser.add_argument("--use_amp", action="store_true") |
| parser.add_argument("--seed", type=int, default=42) |
| parser.add_argument("--use_trackio", action="store_true") |
| parser.add_argument("--trackio_project", default="modern-protein-lm") |
| args = parser.parse_args() |
| |
| rank, world_size, local_rank = setup_distributed() |
| |
| random.seed(args.seed + rank) |
| np.random.seed(args.seed + rank) |
| torch.manual_seed(args.seed + rank) |
| |
| device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") |
| |
| tokenizer = ProteinTokenizer() |
| |
| |
| checkpoint_path = os.path.join(args.pretrain_dir, "checkpoint.pt") |
| if not os.path.exists(checkpoint_path): |
| raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") |
| |
| checkpoint = torch.load(checkpoint_path, map_location="cpu") |
| |
| |
    # The pretraining checkpoint stores the encoder under "discriminator"; infer
    # hidden_size from the embedding matrix, which is (vocab_size, hidden_size).
    disc_state = checkpoint["discriminator"]

    hidden_size = None
    for key in disc_state:
        if "model.embeddings.weight" in key:
            hidden_size = disc_state[key].shape[1]
            break
| |
| if hidden_size is None: |
| raise ValueError("Could not infer model size from checkpoint") |
| |
| args.hidden_size = hidden_size |
| |
    # hidden_size is inferred from the checkpoint; the remaining dimensions are
    # hardcoded and must match the pretraining configuration.
    config = ModernProteinLMConfig(
| vocab_size=33, |
| hidden_size=hidden_size, |
| num_hidden_layers=28, |
| num_attention_heads=9, |
| intermediate_size=2304, |
| use_geglu=True, |
| tie_word_embeddings=True, |
| ) |
| |
| model = ModernProteinLM(config).to(device) |
| |
    # Strip the "model." prefix (prefix-only; str.replace would also rewrite any
    # interior occurrences) and load non-strictly, reporting key mismatches.
    base_state = {k[len("model."):]: v for k, v in disc_state.items() if k.startswith("model.")}
    missing, unexpected = model.load_state_dict(base_state, strict=False)
    if missing or unexpected:
        log_rank0(f"load_state_dict: {len(missing)} missing, {len(unexpected)} unexpected keys")
| |
| log_rank0(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6:.1f}M params") |
| |
| if world_size > 1: |
| model = DDP(model, device_ids=[local_rank]) |
| |
| tasks = [t.strip() for t in args.tasks.split(",")] |
| results = {} |
| |
    # Tasks run sequentially on the same model instance, so each task starts from
    # the previous task's best weights rather than the raw pretrained checkpoint.
    for task in tasks:
| log_rank0(f"\n{'='*50}") |
| log_rank0(f"Task: {task}") |
| log_rank0(f"{'='*50}") |
| |
| result = train_task(args, model, task, tokenizer, device, rank, world_size) |
| results[task] = result |
| |
        log_rank0(f"  Test {result['metric']}: {result['test_score']:.4f}")
| |
| if rank == 0: |
| os.makedirs(args.output_dir, exist_ok=True) |
| with open(os.path.join(args.output_dir, "results.json"), "w") as f: |
| json.dump(results, f, indent=2) |
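    # results.json is keyed by task name, one record per task:
    #   {"fluorescence": {"task": "fluorescence", "val_score": ...,
    #                     "test_score": ..., "metric": "spearman"}, ...}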
| |
| log_rank0(f"\n{'='*50}") |
| log_rank0("FINAL RESULTS") |
| log_rank0(f"{'='*50}") |
| for task, res in results.items(): |
| log_rank0(f" {task}: {res['test_score']:.4f} ({res['metric']})") |
| |
| if dist.is_initialized(): |
| dist.destroy_process_group() |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|