# ModernProteinLM / downstream_eval.py
# GrimSqueaker's picture
# Upload downstream_eval.py with huggingface_hub
# 9c95323 verified
"""
Downstream evaluation for ModernProteinLM on predictive protein tasks:
- Fluorescence (regression, Spearman)
- Solubility (binary classification)
- Secondary Structure (token classification, Q3/Q8 accuracy)
- Remote Homology (classification)
Compares against ESM-2 baselines.
"""
import os
import json
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, mean_squared_error
from scipy.stats import spearmanr
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig
from electra_pretrain import ProteinTokenizer
class ProteinDownstreamDataset(Dataset):
    """Map-style dataset exposing one downstream protein task to a DataLoader.

    The requested split is loaded with ``load_dataset`` and materialised into
    a Python list in ``__init__``; each item is tokenised lazily in
    ``__getitem__``.

    The tokenizer must provide ``encode(seq, max_length=...)`` returning a
    mapping with ``"input_ids"`` and ``"attention_mask"`` entries.
    """

    # Per-task settings: hub dataset id, column names, task type, head size
    # and the metric used for model selection.
    TASK_CONFIGS = {
        "fluorescence": {
            "dataset": "proteinea/fluorescence",
            "seq_col": "primary",
            "label_col": "log_fluorescence",
            "task": "regression",
            "metric": "spearman",
        },
        "solubility": {
            "dataset": "proteinea/solubility",
            "seq_col": "sequences",
            "label_col": "labels",
            "task": "classification",
            "num_labels": 2,
            "metric": "accuracy",
        },
        "secondary_structure": {
            "dataset": "proteinea/secondary_structure_prediction",
            "seq_col": "input",
            # Only the first entry (dssp3) is read by __getitem__.
            "label_cols": ["dssp3", "dssp8"],
            "task": "token_classification",
            "num_labels": 3,  # Q3 first
            "metric": "accuracy",
        },
        "remote_homology": {
            "dataset": "proteinea/remote_homology",
            "seq_col": "primary",
            "label_col": "fold_label",
            "task": "classification",
            "num_labels": 1195,  # NOTE(review): presumably the fold count - confirm
            "metric": "accuracy",
        },
    }

    # 3-state secondary-structure letters -> class ids. Any other character is
    # mapped to class 0 by __getitem__.
    _SS3_TO_ID = {"C": 0, "H": 1, "E": 2}

    def __init__(self, task_name, split, tokenizer, max_length=1024):
        """Load ``split`` of the dataset configured for ``task_name``.

        Args:
            task_name: Key into ``TASK_CONFIGS``.
            split: Split name passed to ``load_dataset``.
            tokenizer: Object with the ``encode`` contract described on the class.
            max_length: Forwarded to ``tokenizer.encode``.

        Raises:
            KeyError: If ``task_name`` is not a configured task.
            Exception: Whatever ``load_dataset`` raises when the "train"
                split itself cannot be loaded.
        """
        self.task_name = task_name
        self.config = self.TASK_CONFIGS[task_name]
        self.tokenizer = tokenizer
        self.max_length = max_length
        try:
            self.data = load_dataset(self.config["dataset"], split=split)
        except Exception as err:
            # Retrying "train" after "train" failed cannot help; surface the
            # original error instead of masking it.
            if split == "train":
                raise
            # Some datasets don't have validation/test splits, use train.
            # Say so loudly: callers would otherwise evaluate on training data
            # without knowing it.
            print(
                f"[{task_name}] could not load split '{split}' ({err}); "
                "falling back to 'train'"
            )
            self.data = load_dataset(self.config["dataset"], split="train")
        self.examples = list(self.data)

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenise example ``idx`` and attach its label tensor.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (long tensors) plus
            ``labels``: a float scalar for regression, a long scalar for
            classification, or a long vector as long as ``input_ids`` for
            token classification, with ``-100`` on positions without a label.
        """
        ex = self.examples[idx]
        encoded = self.tokenizer.encode(
            ex[self.config["seq_col"]], max_length=self.max_length
        )
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        item = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

        task = self.config["task"]
        if task == "regression":
            item["labels"] = torch.tensor(
                ex[self.config["label_col"]], dtype=torch.float
            )
        elif task == "classification":
            item["labels"] = torch.tensor(
                ex[self.config["label_col"]], dtype=torch.long
            )
        elif task == "token_classification":
            # One structure letter per residue (dssp3 column).
            structure = ex[self.config["label_cols"][0]]
            labels = [self._SS3_TO_ID.get(letter, 0) for letter in structure]
            # Keep at most as many labels as attended tokens, then mark every
            # remaining position as ignored so the loss skips it.
            # NOTE(review): label i is paired with token i, which assumes the
            # tokenizer adds no leading special token - confirm.
            labels = labels[: sum(attention_mask)]
            labels += [-100] * (len(input_ids) - len(labels))
            item["labels"] = torch.tensor(labels, dtype=torch.long)
        return item
class DownstreamModel(nn.Module):
    """Encoder with one linear head for regression or (token) classification.

    ``base_model`` must expose ``config.hidden_size`` and, when called with
    ``output_hidden_states=True, return_dict=True``, return an object whose
    ``hidden_states[-1]`` is the final layer's per-token representation.
    """

    def __init__(self, base_model, task_config):
        super().__init__()
        self.base = base_model
        self.task = task_config["task"]
        self.config = task_config

        width = base_model.config.hidden_size
        if self.task == "regression":
            out_features = 1
        elif self.task == "classification":
            out_features = task_config.get("num_labels", 2)
        elif self.task == "token_classification":
            out_features = task_config.get("num_labels", 3)
        else:
            # Unrecognised task: no head is created.
            out_features = None
        if out_features is not None:
            self.head = nn.Linear(width, out_features)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and head; return ``{"loss": ..., "logits": ...}``.

        ``loss`` is ``None`` when ``labels`` is not given.
        """
        encoder_out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        last_hidden = encoder_out.hidden_states[-1]

        if self.task == "regression" or self.task == "classification":
            # Masked mean over the sequence axis; the clamp guards against an
            # all-zero attention mask.
            weights = attention_mask.unsqueeze(-1).float()
            summed = (last_hidden * weights).sum(dim=1)
            counts = weights.sum(dim=1).clamp(min=1e-9)
            logits = self.head(summed / counts)
        else:
            # One prediction per token.
            logits = self.head(last_hidden)

        if labels is None:
            return {"loss": None, "logits": logits}

        if self.task == "regression":
            loss = nn.functional.mse_loss(logits.squeeze(-1), labels)
        elif self.task == "classification":
            loss = nn.functional.cross_entropy(logits, labels)
        elif self.task == "token_classification":
            # Flatten tokens; positions labelled -100 do not contribute.
            flat_logits = logits.view(-1, self.config.get("num_labels", 3))
            loss = nn.functional.cross_entropy(
                flat_logits, labels.view(-1), ignore_index=-100
            )
        else:
            loss = None
        return {"loss": loss, "logits": logits}
def evaluate(model, dataloader, task_config, device):
    """Score ``model`` on ``dataloader`` with the task's configured metric.

    Args:
        model: Callable as ``model(input_ids, attention_mask, labels)`` and
            returning a dict with ``"loss"`` and ``"logits"``.
        dataloader: Yields dict batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        task_config: Needs ``"task"`` (``regression``, ``classification`` or
            ``token_classification``) and ``"metric"`` (``spearman``,
            ``accuracy`` or ``f1``).
        device: Device the batch tensors are moved to.

    Returns:
        ``(score, avg_loss)`` where ``avg_loss`` is the batch loss averaged
        over the examples seen, weighted by batch size.

    Raises:
        ValueError: For an unsupported task or metric, or if the dataloader
            yields no batches.
    """
    task = task_config["task"]
    metric = task_config["metric"]
    # Reject bad configs before spending a full pass over the data.
    if metric not in ("spearman", "accuracy", "f1"):
        raise ValueError(f"Unsupported metric: {metric!r}")
    if task not in ("regression", "classification", "token_classification"):
        raise ValueError(f"Unsupported task: {task!r}")

    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0
    n_examples = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids, attention_mask, labels)

            batch_size = input_ids.size(0)
            total_loss += outputs["loss"].item() * batch_size
            # Count what was actually evaluated; len(dataloader.dataset) is
            # wrong when the loader drops or subsamples examples.
            n_examples += batch_size

            logits = outputs["logits"]
            if task == "regression":
                all_preds.extend(logits.squeeze(-1).cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
            elif task == "classification":
                all_preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
            else:  # token_classification
                preds = torch.argmax(logits, dim=-1)
                # Only positions with a real label (not -100) are scored.
                keep = labels != -100
                all_preds.extend(preds[keep].cpu().tolist())
                all_labels.extend(labels[keep].cpu().tolist())

    if n_examples == 0:
        raise ValueError("evaluate() received a dataloader with no examples")

    if metric == "spearman":
        score, _ = spearmanr(all_labels, all_preds)
    elif metric == "accuracy":
        score = accuracy_score(all_labels, all_preds)
    else:  # f1
        score = f1_score(all_labels, all_preds, average="macro")

    return float(score), total_loss / n_examples
def _load_validation_dataset(task_name, tokenizer, dataset_name):
    """Build the model-selection dataset, preferring "validation" then "test"."""
    # Local import: only this helper needs it.
    from datasets import get_dataset_split_names

    try:
        available = set(get_dataset_split_names(dataset_name))
    except Exception as err:
        print(f"Could not list splits of {dataset_name}: {err}")
        available = None

    if available is not None:
        # Ask for a split that is known to exist. ProteinDownstreamDataset
        # substitutes "train" for a missing split instead of raising, so
        # "try validation, except test" never reaches the test split.
        for candidate in ("validation", "test"):
            if candidate in available:
                return ProteinDownstreamDataset(task_name, candidate, tokenizer)
        print(
            f"WARNING: {dataset_name} has no 'validation' or 'test' split; "
            "validating on 'train', so scores are not held-out estimates"
        )
        return ProteinDownstreamDataset(task_name, "train", tokenizer)

    # Split names unknown: keep the original best-effort order.
    try:
        return ProteinDownstreamDataset(task_name, "validation", tokenizer)
    except Exception:
        return ProteinDownstreamDataset(task_name, "test", tokenizer)


def train_downstream(
    base_model,
    task_name,
    tokenizer,
    epochs=20,
    batch_size=16,
    lr=1e-4,
    device="cuda",
    seed=42,
):
    """Fine-tune ``base_model`` with a task head and keep the best epoch.

    All parameters (encoder and head) are optimised with AdamW under a linear
    warmup/decay schedule; gradients are clipped to norm 1.0. After every
    epoch the model is scored on the validation data and the best-scoring
    weights are restored before returning.

    Args:
        base_model: Encoder to wrap in ``DownstreamModel``; it is trained in
            place.
        task_name: Key into ``ProteinDownstreamDataset.TASK_CONFIGS``.
        tokenizer: Passed through to ``ProteinDownstreamDataset``.
        epochs: Number of passes over the training split.
        batch_size: Batch size for both loaders.
        lr: Peak learning rate.
        device: Device for the model and batches.
        seed: Seed for the torch and numpy RNGs.

    Returns:
        ``(model, best_score)``. ``best_score`` stays at its initial infinite
        value if no epoch ever compared as better (e.g. the metric was NaN).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    task_config = ProteinDownstreamDataset.TASK_CONFIGS[task_name]
    train_dataset = ProteinDownstreamDataset(task_name, "train", tokenizer)
    val_dataset = _load_validation_dataset(
        task_name, tokenizer, task_config["dataset"]
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    model = DownstreamModel(base_model, task_config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )

    # Every metric is maximised except an error metric named "mse".
    higher_is_better = task_config["metric"] != "mse"
    best_score = -float("inf") if higher_is_better else float("inf")
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Clear first so gradients left on base_model by an earlier use
            # cannot leak into the first update.
            optimizer.zero_grad()
            loss = model(input_ids, attention_mask, labels)["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        # Evaluate
        score, val_loss = evaluate(model, val_loader, task_config, device)
        print(f"Epoch {epoch+1}: Val {task_config['metric']}={score:.4f}, Loss={val_loss:.4f}")

        # Defined for every metric, so e.g. "f1" no longer hits an unbound name.
        is_better = score > best_score if higher_is_better else score < best_score
        if is_better:
            best_score = score
            # Snapshot on CPU so the copy does not occupy accelerator memory.
            best_model_state = {
                k: v.cpu().clone() for k, v in model.state_dict().items()
            }

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, best_score
def compare_models(
    task_names=("fluorescence", "solubility", "secondary_structure"),
    epochs=20,
    device="cuda",
):
    """Fine-tune ModernProteinLM and an ESM-2 35M baseline on each task.

    For every task a fresh ModernProteinLM is built from a fixed config (no
    checkpoint is loaded here) and fine-tuned, then ``facebook/esm2_t12_35M_UR50D``
    is fine-tuned the same way. A failure anywhere in the ESM-2 branch is
    printed and recorded as ``None`` so the remaining tasks still run.

    Args:
        task_names: Iterable of keys into ``ProteinDownstreamDataset.TASK_CONFIGS``.
        epochs: Fine-tuning epochs per model and task.
        device: Device passed to ``train_downstream``.

    Returns:
        ``{task: {"modern": score, "esm2_35m": score_or_None}}``; the same
        mapping is written to ``downstream_results.json``.
    """
    tokenizer = ProteinTokenizer()
    results = {}
    for task in task_names:
        print(f"\n{'='*50}")
        print(f"Task: {task}")
        print(f"{'='*50}")

        # ModernProteinLM built from config only.
        config = ModernProteinLMConfig(
            vocab_size=33,
            hidden_size=640,
            num_hidden_layers=24,
            num_attention_heads=10,
            intermediate_size=2304,
            use_geglu=True,
            tie_word_embeddings=True,
        )
        modern_model = ModernProteinLM(config)
        print(f"ModernProteinLM params: {sum(p.numel() for p in modern_model.parameters())/1e6:.1f}M")
        modern_model, modern_score = train_downstream(
            modern_model, task, tokenizer, epochs=epochs, device=device
        )

        # ESM-2 baseline (best effort).
        esm_score = None
        try:
            from transformers import AutoModel

            esm_model = AutoModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
            print(f"ESM-2 35M params: {sum(p.numel() for p in esm_model.parameters())/1e6:.1f}M")
            # NOTE(review): the baseline is fed ids produced by ProteinTokenizer,
            # not by ESM-2's own tokenizer. That is only a fair comparison if
            # both assign the same ids to the same tokens - confirm.
            esm_model, esm_score = train_downstream(
                esm_model, task, tokenizer, epochs=epochs, device=device
            )
        except Exception as e:
            print(f"ESM-2 comparison failed: {e}")
            esm_score = None
        results[task] = {"modern": modern_score, "esm2_35m": esm_score}

        print(f"\nResults for {task}:")
        print(f" ModernProteinLM: {modern_score:.4f}")
        if esm_score is not None:
            print(f" ESM-2 35M: {esm_score:.4f}")

    with open("downstream_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    return results
if __name__ == "__main__":
    # Smoke test: fine-tune a deliberately tiny ModernProteinLM on the
    # solubility task for a few epochs and report the best validation score.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    print(f"Using device: {device}")

    tokenizer = ProteinTokenizer()

    # Small enough to train quickly.
    tiny_settings = {
        "vocab_size": 33,
        "hidden_size": 128,
        "num_hidden_layers": 4,
        "num_attention_heads": 4,
        "intermediate_size": 512,
        "use_geglu": True,
        "tie_word_embeddings": True,
    }
    model = ModernProteinLM(ModernProteinLMConfig(**tiny_settings))

    print(f"\nTesting on solubility (tiny model)...")
    trained_model, score = train_downstream(
        model,
        "solubility",
        tokenizer,
        epochs=5,
        batch_size=8,
        lr=5e-4,
        device=device,
    )
    print(f"Solubility accuracy: {score:.4f}")