# ModernProteinLM / train_finetune.py
# GrimSqueaker's picture
# Upload train_finetune.py with huggingface_hub
# 3714d46 verified
"""
Fine-tune pretrained ModernProteinLM on downstream predictive tasks.
Supports: regression (fluorescence, stability), classification (solubility, remote homology).
"""
import os
import sys
import argparse
import json
import random
import math
from typing import Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torch.cuda.amp import autocast, GradScaler
from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset
from scipy.stats import spearmanr
from sklearn.metrics import accuracy_score, f1_score
from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig
# =============================================================================
# TOKENIZER (shared with pretrain)
# =============================================================================
class ProteinTokenizer:
    """Character-level amino-acid tokenizer with a fixed 33-entry vocabulary.

    Token ids must stay identical to the ones used during pretraining:
    special tokens first, then the residue letters, then ``<mask>``/``<sep>``,
    and finally ``<special_N>`` fillers up to 33 entries.
    """

    def __init__(self):
        ordered_tokens = ["<cls>", "<pad>", "<eos>", "<unk>"]
        # Residue letters occupy ids 4..28 in exactly this order.
        ordered_tokens.extend("LAGVSERTIDPQKNFYWMHCXBUZO")
        ordered_tokens.extend(["<mask>", "<sep>"])
        # Pad the vocabulary to 33 entries with placeholder tokens named by their id.
        ordered_tokens.extend(
            f"<special_{idx}>" for idx in range(len(ordered_tokens), 33)
        )
        self.vocab = {token: idx for idx, token in enumerate(ordered_tokens)}
        self.id_to_token = {idx: token for token, idx in self.vocab.items()}
        self.mask_token_id = self.vocab["<mask>"]
        self.pad_token_id = self.vocab["<pad>"]
        self.cls_token_id = self.vocab["<cls>"]
        self.eos_token_id = self.vocab["<eos>"]

    def encode(self, sequence: str, max_length: int = 1024):
        """Tokenize ``sequence`` into a fixed-length id list plus attention mask.

        The sequence is upper-cased, wrapped in ``<cls>`` ... ``<eos>``, and
        unknown characters map to ``<unk>``. The result is cut to
        ``max_length`` (which may drop ``<eos>``) and right-padded with
        ``<pad>``; padded positions get mask 0.
        """
        unk_id = self.vocab["<unk>"]
        body = [self.vocab.get(char, unk_id) for char in sequence.upper()]
        ids = [self.cls_token_id, *body, self.eos_token_id][:max_length]
        pad_count = max(max_length - len(ids), 0)
        return {
            "input_ids": ids + [self.pad_token_id] * pad_count,
            "attention_mask": [1] * len(ids) + [0] * pad_count,
        }
def setup_distributed():
    """Initialise NCCL process-group training when launched under torchrun.

    Reads RANK and WORLD_SIZE (and optional LOCAL_RANK) from the environment.
    If either of the first two is absent the run is treated as single-process
    and no process group is created.

    Returns:
        (rank, world_size, local_rank); (0, 1, 0) when not distributed.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return 0, 1, 0
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # Bind this process to its GPU so later default-device CUDA ops land there.
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
def log_rank0(msg):
    """Print ``msg`` from the main process only.

    Without an initialised process group every caller prints; otherwise only
    rank 0 does.
    """
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(msg)
# =============================================================================
# TASK DEFINITIONS
# =============================================================================
# Downstream task registry, keyed by the task names accepted by --tasks.
# Fields read elsewhere in this file:
#   dataset    - dataset id passed to datasets.load_dataset
#   seq_key    - column holding the amino-acid sequence string
#   label_key  - column holding the prediction target
#   task_type  - "regression" (single output, MSE loss) or "classification"
#                (cross-entropy loss)
#   num_labels - classifier width; TaskHead falls back to 2 when absent
#   metric     - "spearman", "accuracy" or "f1" (scored in evaluate)
#   splits     - split names: splits[0] is trained on, splits[1] selects the
#                best epoch, splits[-1] is the reported test set
TASK_SPECS = {
    "fluorescence": {
        "dataset": "proteinea/fluorescence",
        "seq_key": "primary",
        "label_key": "log_fluorescence",
        "task_type": "regression",
        "metric": "spearman",
        "splits": ["train", "validation", "test"],
    },
    # NOTE(review): this entry is a verbatim copy of "fluorescence" (same
    # dataset id and same label column), so running the "stability" task
    # actually trains and reports on fluorescence data. TODO: point it at a
    # stability dataset and its label column once the correct dataset id is
    # confirmed; until then treat "stability" results as fluorescence results.
    "stability": {
        "dataset": "proteinea/fluorescence",
        "seq_key": "primary",
        "label_key": "log_fluorescence",
        "task_type": "regression",
        "metric": "spearman",
        "splits": ["train", "validation", "test"],
    },
    "solubility": {
        "dataset": "proteinea/solubility",
        "seq_key": "sequences",
        "label_key": "labels",
        "task_type": "classification",
        "num_labels": 2,
        "metric": "accuracy",
        "splits": ["train", "validation", "test"],
    },
    "remote_homology": {
        "dataset": "proteinea/remote_homology",
        "seq_key": "primary",
        "label_key": "fold_label",
        "task_type": "classification",
        # Width of the classification head; labels presumably lie in
        # [0, 1195) — TODO confirm against the dataset.
        "num_labels": 1195,
        "metric": "accuracy",
        "splits": ["train", "validation", "test"],
    },
}
class DownstreamDataset(Dataset):
    """Tokenised examples for one split of a task registered in TASK_SPECS.

    Each item is a dict with ``input_ids`` and ``attention_mask`` (long
    tensors of length ``max_length``) and ``labels`` (a float tensor for
    regression tasks, a long tensor otherwise).
    """

    def __init__(self, task_name, split, tokenizer, max_length=1024):
        spec = TASK_SPECS[task_name]
        self.spec = spec
        self.tokenizer = tokenizer
        self.max_length = max_length
        dataset_id = spec["dataset"]
        try:
            loaded = load_dataset(dataset_id, split=split)
        except Exception as err:
            # Best-effort fallback: any load failure substitutes the train split.
            # NOTE(review): for validation/test requests this means metrics are
            # computed on training data.
            log_rank0(f"Failed to load {split}: {err}, using train")
            loaded = load_dataset(dataset_id, split="train")
        self.data = loaded
        # Materialise every record up front so __getitem__ is plain list indexing.
        self.examples = list(loaded)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        record = self.examples[idx]
        encoding = self.tokenizer.encode(record[self.spec["seq_key"]], self.max_length)
        is_regression = self.spec["task_type"] == "regression"
        label_dtype = torch.float if is_regression else torch.long
        return {
            "input_ids": torch.tensor(encoding["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoding["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(record[self.spec["label_key"]], dtype=label_dtype),
        }
def mean_pool(hidden_states, attention_mask):
    """Average ``hidden_states`` over dim 1, counting only positions where the mask is non-zero.

    The mask is broadcast over the last dimension. The divisor is clamped
    at 1e-9, so a row with an all-zero mask pools to zeros instead of NaN.
    """
    weights = attention_mask.unsqueeze(-1).float()
    summed = torch.sum(hidden_states * weights, dim=1)
    counts = torch.clamp(weights.sum(dim=1), min=1e-9)
    return summed / counts
class TaskHead(nn.Module):
    """Single linear layer mapping a pooled embedding to task outputs.

    Regression tasks get one output unit. Classification tasks get
    ``task_spec["num_labels"]`` units, or 2 when that key is absent.
    """

    def __init__(self, hidden_size, task_spec):
        super().__init__()
        kind = task_spec["task_type"]
        out_features = 1 if kind == "regression" else task_spec.get("num_labels", 2)
        self.head = nn.Linear(hidden_size, out_features)
        self.task_type = kind

    def forward(self, pooled):
        return self.head(pooled)
def evaluate(model, head, dataloader, task_spec, device):
    """Score ``model`` + ``head`` on ``dataloader`` without gradient tracking.

    Each batch is encoded by the backbone, and the last hidden state is
    mean-pooled over the attention mask and fed to the head.
    - Regression: the prediction is the squeezed head output and the loss is MSE.
    - Classification: the prediction is the argmax class and the loss is cross-entropy.

    Leaves both ``model`` and ``head`` in eval mode.

    Args:
        model: backbone called as ``model(input_ids, attention_mask,
            output_hidden_states=True, return_dict=True)``.
        head: module mapping pooled embeddings to logits.
        dataloader: yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels``; its ``dataset`` must support ``len``.
        task_spec: a TASK_SPECS entry (reads ``task_type`` and ``metric``).
        device: device the batch tensors are moved to.

    Returns:
        ``(score, mean_loss)``:
        - ``score`` is the configured metric over all examples.
        - ``mean_loss`` is the per-example average loss, or NaN if the dataset is empty.

    Raises:
        ValueError: if ``task_spec["metric"]`` is not "spearman", "accuracy" or "f1".
    """
    metric = task_spec["metric"]
    # Validate before the (possibly long) evaluation pass; previously an unknown
    # metric surfaced only afterwards as an UnboundLocalError on `score`.
    if metric not in ("spearman", "accuracy", "f1"):
        raise ValueError(f"Unsupported metric: {metric!r}")
    is_regression = task_spec["task_type"] == "regression"

    model.eval()
    head.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids, attention_mask, output_hidden_states=True, return_dict=True)
            pooled = mean_pool(outputs.hidden_states[-1], attention_mask)
            logits = head(pooled)
            if is_regression:
                preds = logits.squeeze(-1)
                loss = F.mse_loss(preds, labels)
            else:
                loss = F.cross_entropy(logits, labels)
                preds = torch.argmax(logits, dim=-1)
            # Losses are batch means; weight by batch size to get a per-example mean.
            total_loss += loss.item() * input_ids.size(0)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if metric == "spearman":
        score, _ = spearmanr(all_labels, all_preds)
    elif metric == "accuracy":
        score = accuracy_score(all_labels, all_preds)
    else:
        score = f1_score(all_labels, all_preds, average="macro")

    num_examples = len(dataloader.dataset)
    mean_loss = total_loss / num_examples if num_examples else float("nan")
    return score, mean_loss
def _unwrap_ddp(module):
    """Return the wrapped module of a DDP wrapper, or ``module`` itself."""
    return module.module if isinstance(module, DDP) else module


def _cpu_state_snapshot(module):
    """Return a detached CPU copy of ``module.state_dict()``.

    ``state_dict()`` returns references to the live tensors, so a "best"
    state stored without copying keeps changing as training continues.
    """
    return {k: v.detach().to("cpu", copy=True) for k, v in module.state_dict().items()}


def train_task(args, model, task_name, tokenizer, device, rank, world_size):
    """Fine-tune ``model`` plus a fresh TaskHead on one task, then test it.

    The epoch with the highest validation score is selected. Its weights are
    restored into ``model`` and the head before the test split is scored.
    On return ``model`` therefore holds this task's best fine-tuned weights,
    not the weights it was called with.

    Args:
        args: parsed CLI namespace. Reads max_seq_length, batch_size,
            num_workers, hidden_size, lr, weight_decay, epochs, warmup_ratio
            and use_amp.
        model: backbone, optionally DDP-wrapped. The unwrapped module must
            expose ``layers`` (sliceable) and ``embeddings.weight``.
        task_name: key into TASK_SPECS.
        tokenizer: tokenizer handed to DownstreamDataset.
        device: device batches and the head are placed on.
        rank: this process's rank.
        world_size: number of processes; 1 means single-process training.

    Returns:
        dict with ``task``, best ``val_score``, ``test_score`` and ``metric``.
    """
    spec = TASK_SPECS[task_name]
    splits = spec["splits"]
    val_split = splits[1] if len(splits) > 1 else splits[0]
    train_ds = DownstreamDataset(task_name, splits[0], tokenizer, args.max_seq_length)
    val_ds = DownstreamDataset(task_name, val_split, tokenizer, args.max_seq_length)
    test_ds = DownstreamDataset(task_name, splits[-1], tokenizer, args.max_seq_length)

    train_sampler = (
        DistributedSampler(train_ds, num_replicas=world_size, rank=rank)
        if world_size > 1
        else None
    )
    # DistributedSampler shuffles on its own. Without it the DataLoader must
    # shuffle, or single-process runs see the same example order every epoch.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler,
                              shuffle=train_sampler is None,
                              num_workers=args.num_workers, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers, pin_memory=True)

    head = TaskHead(args.hidden_size, spec).to(device)
    if world_size > 1:
        # Without DDP the head's gradients are never all-reduced and every
        # rank would train a different head.
        head = DDP(head, device_ids=[device.index] if device.type == "cuda" else None)

    # Layer-wise LR decay. DDP does not forward submodule attribute access,
    # so reach `layers`/`embeddings` through the unwrapped backbone.
    backbone = _unwrap_ddp(model)
    top_params = list(backbone.layers[-4:].parameters())
    lower_params = list(backbone.layers[:-4].parameters())
    embed_params = [backbone.embeddings.weight]
    param_groups = [
        {"params": list(head.parameters()), "lr": args.lr},
        {"params": top_params, "lr": args.lr * 0.5},
        {"params": lower_params, "lr": args.lr * 0.1},
        {"params": embed_params, "lr": args.lr * 0.1},
    ]
    # Trainable backbone parameters outside `layers`/`embeddings` would
    # otherwise never be stepped; give them the top-layer rate.
    grouped = {id(p) for p in (*top_params, *lower_params, *embed_params)}
    remaining = [p for p in backbone.parameters() if p.requires_grad and id(p) not in grouped]
    if remaining:
        param_groups.append({"params": remaining, "lr": args.lr * 0.5})
    optimizer = torch.optim.AdamW(param_groups, weight_decay=args.weight_decay)

    total_steps = len(train_loader) * args.epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, int(args.warmup_ratio * total_steps), total_steps
    )
    scaler = GradScaler() if args.use_amp else None
    clip_params = [*model.parameters(), *head.parameters()]

    best_score = -float("inf")
    best_state = None
    for epoch in range(args.epochs):
        model.train()
        head.train()
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            with autocast(enabled=args.use_amp):
                outputs = model(input_ids, attention_mask, output_hidden_states=True, return_dict=True)
                pooled = mean_pool(outputs.hidden_states[-1], attention_mask)
                logits = head(pooled)
                if spec["task_type"] == "regression":
                    loss = F.mse_loss(logits.squeeze(-1), labels)
                else:
                    loss = F.cross_entropy(logits, labels)
            if scaler is not None:
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        val_score, val_loss = evaluate(model, head, val_loader, spec, device)
        log_rank0(f" Epoch {epoch+1}/{args.epochs}: val_{spec['metric']}={val_score:.4f}, loss={val_loss:.4f}")
        # Every rank evaluates the full validation set, so every rank must track
        # the best state; otherwise only rank 0 would ever restore it.
        improved = val_score > best_score
        if world_size > 1 and dist.is_initialized():
            # Adopt rank 0's decision so all ranks keep the same epoch's weights.
            flag = torch.tensor([1 if improved else 0], device=device)
            dist.broadcast(flag, src=0)
            improved = bool(flag.item())
        if improved:
            best_score = val_score
            best_state = {
                "model": _cpu_state_snapshot(model),
                "head": _cpu_state_snapshot(head),
            }

    # Restore the best-validation weights before testing.
    if best_state is not None:
        model.load_state_dict(best_state["model"])
        head.load_state_dict(best_state["head"])
    test_score, _ = evaluate(model, head, test_loader, spec, device)
    return {
        "task": task_name,
        "val_score": float(best_score),
        "test_score": float(test_score),
        "metric": spec["metric"],
    }
def _load_pretrained_backbone(pretrain_dir, device):
    """Build ModernProteinLM and load backbone weights from ``<pretrain_dir>/checkpoint.pt``.

    Only entries of ``checkpoint["discriminator"]`` whose keys start with
    ``"model."`` are loaded, with that leading prefix removed. Everything
    else (e.g. the discriminator head) is ignored.

    Returns:
        ``(model, hidden_size)``; the model is already on ``device``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        ValueError: if no ``model.embeddings.weight`` entry is present.
    """
    checkpoint_path = os.path.join(pretrain_dir, "checkpoint.pt")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # torch.load unpickles the file: only point --pretrain_dir at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    disc_state = checkpoint["discriminator"]

    # Infer the hidden size from the embedding matrix (second dimension).
    hidden_size = next(
        (tensor.shape[1] for key, tensor in disc_state.items() if "model.embeddings.weight" in key),
        None,
    )
    if hidden_size is None:
        raise ValueError("Could not infer model size from checkpoint")

    # NOTE(review): only hidden_size is inferred; depth, head count and FFN
    # width are hard-coded and must match the pretraining run.
    config = ModernProteinLMConfig(
        vocab_size=33,
        hidden_size=hidden_size,
        num_hidden_layers=28,
        num_attention_heads=9,
        intermediate_size=2304,
        use_geglu=True,
        tie_word_embeddings=True,
    )
    model = ModernProteinLM(config).to(device)

    # Strip only the leading prefix: str.replace would also mangle keys that
    # contain "model." further in (e.g. "...submodel.weight").
    prefix = "model."
    base_state = {k[len(prefix):]: v for k, v in disc_state.items() if k.startswith(prefix)}
    load_result = model.load_state_dict(base_state, strict=False)
    # strict=False silently ignores name mismatches, which could leave the
    # backbone randomly initialised; report them.
    if load_result.missing_keys:
        log_rank0(f"Warning: {len(load_result.missing_keys)} model keys not found in checkpoint, "
                  f"e.g. {load_result.missing_keys[:5]}")
    if load_result.unexpected_keys:
        log_rank0(f"Warning: {len(load_result.unexpected_keys)} checkpoint keys not used by model, "
                  f"e.g. {load_result.unexpected_keys[:5]}")
    return model, hidden_size


def main():
    """CLI entry point: fine-tune a pretrained backbone on each requested task.

    Every task starts from the pretrained checkpoint weights. Per-task
    results are written to ``<output_dir>/results.json`` by rank 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrain_dir", required=True)
    parser.add_argument("--tasks", default="fluorescence,solubility")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--max_seq_length", type=int, default=1024)
    parser.add_argument("--output_dir", default="./outputs/finetune")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--use_amp", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--use_trackio", action="store_true")
    parser.add_argument("--trackio_project", default="modern-protein-lm")
    args = parser.parse_args()

    # Validate task names before any expensive setup.
    tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]
    unknown = [t for t in tasks if t not in TASK_SPECS]
    if unknown:
        parser.error(f"Unknown task(s): {', '.join(unknown)}. "
                     f"Available: {', '.join(TASK_SPECS)}")

    rank, world_size, local_rank = setup_distributed()
    try:
        random.seed(args.seed + rank)
        np.random.seed(args.seed + rank)
        torch.manual_seed(args.seed + rank)
        device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
        tokenizer = ProteinTokenizer()

        model, hidden_size = _load_pretrained_backbone(args.pretrain_dir, device)
        args.hidden_size = hidden_size
        log_rank0(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6:.1f}M params")
        if world_size > 1:
            model = DDP(model, device_ids=[local_rank])

        # train_task leaves its best fine-tuned weights in `model`; keep a CPU
        # copy of the pretrained weights so every task starts from them.
        pretrained_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

        results = {}
        for task in tasks:
            model.load_state_dict(pretrained_state)
            log_rank0(f"\n{'='*50}")
            log_rank0(f"Task: {task}")
            log_rank0(f"{'='*50}")
            result = train_task(args, model, task, tokenizer, device, rank, world_size)
            results[task] = result
            log_rank0(f" Test {result['metric']}: {result['test_score']:.4f}")

        if rank == 0:
            os.makedirs(args.output_dir, exist_ok=True)
            with open(os.path.join(args.output_dir, "results.json"), "w") as f:
                json.dump(results, f, indent=2)
        log_rank0(f"\n{'='*50}")
        log_rank0("FINAL RESULTS")
        log_rank0(f"{'='*50}")
        for task, res in results.items():
            log_rank0(f" {task}: {res['test_score']:.4f} ({res['metric']})")
    finally:
        # Tear down the process group even when a task fails.
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()