"""
Downstream evaluation for ModernProteinLM on predictive protein tasks:
- Fluorescence (regression, Spearman)
- Solubility (binary classification)
- Secondary Structure (token classification, Q3/Q8 accuracy)
- Remote Homology (classification)

Compares against ESM-2 baselines.
"""

import os
import json
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, mean_squared_error
from scipy.stats import spearmanr
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
import warnings

# Kept ahead of the project imports so warnings raised while importing them
# are silenced too, exactly as in the original module.
warnings.filterwarnings("ignore")

from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig
from electra_pretrain import ProteinTokenizer


# Metrics that `evaluate` knows how to compute.
_SUPPORTED_METRICS = ("spearman", "accuracy", "f1")

# Metrics for which a smaller value is an improvement. None of these is
# currently computed by `evaluate`; the set only drives best-score tracking.
_LOWER_IS_BETTER_METRICS = frozenset({"mse"})


class ProteinDownstreamDataset(Dataset):
    """Generic downstream dataset wrapper.

    Loads one split of a Hugging Face dataset described by ``TASK_CONFIGS``
    and yields tokenized sequences together with task-specific labels.
    """

    TASK_CONFIGS = {
        "fluorescence": {
            "dataset": "proteinea/fluorescence",
            "seq_col": "primary",
            "label_col": "log_fluorescence",
            "task": "regression",
            "metric": "spearman",
        },
        "solubility": {
            "dataset": "proteinea/solubility",
            "seq_col": "sequences",
            "label_col": "labels",
            "task": "classification",
            "num_labels": 2,
            "metric": "accuracy",
        },
        "secondary_structure": {
            "dataset": "proteinea/secondary_structure_prediction",
            "seq_col": "input",
            "label_cols": ["dssp3", "dssp8"],
            "task": "token_classification",
            "num_labels": 3,  # Q3 first
            "metric": "accuracy",
        },
        "remote_homology": {
            "dataset": "proteinea/remote_homology",
            "seq_col": "primary",
            "label_col": "fold_label",
            "task": "classification",
            "num_labels": 1195,  # Actually fold labels
            "metric": "accuracy",
        },
    }

    # Map 'C', 'H', 'E' to 0, 1, 2. Any other character is treated as 'C'.
    _SS3_MAP = {'C': 0, 'H': 1, 'E': 2}

    def __init__(self, task_name, split, tokenizer, max_length=1024,
                 fallback_to_train=True):
        """Load ``split`` of the dataset registered under ``task_name``.

        Args:
            task_name: Key into ``TASK_CONFIGS``.
            split: Name of the dataset split to load.
            tokenizer: Object whose ``encode(seq, max_length=...)`` returns a
                mapping with ``"input_ids"`` and ``"attention_mask"``.
            max_length: Forwarded to ``tokenizer.encode``.
            fallback_to_train: When True (the default, matching the previous
                behaviour) a split that cannot be loaded is replaced by the
                ``"train"`` split and a notice is printed. When False the
                loading error propagates so the caller can try another split.

        Raises:
            KeyError: If ``task_name`` is not in ``TASK_CONFIGS``.
            Exception: Whatever ``load_dataset`` raises when the split cannot
                be loaded and no fallback applies.
        """
        self.task_name = task_name
        self.config = self.TASK_CONFIGS[task_name]
        self.tokenizer = tokenizer
        self.max_length = max_length

        try:
            self.data = load_dataset(self.config["dataset"], split=split)
        except Exception as err:
            # Retrying "train" after "train" failed would only fail again.
            if not fallback_to_train or split == "train":
                raise
            # Some datasets don't have validation/test splits, use train.
            # Printed rather than warned: warnings are globally silenced above.
            print(
                f"Split '{split}' unavailable for {self.config['dataset']} "
                f"({err}); falling back to 'train'."
            )
            self.data = load_dataset(self.config["dataset"], split="train")

        self.examples = list(self.data)

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and attach its labels.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and ``labels``
            tensors. For token classification the labels are padded with
            ``-100`` up to the length of ``input_ids``.
        """
        ex = self.examples[idx]
        seq = ex[self.config["seq_col"]]
        encoded = self.tokenizer.encode(seq, max_length=self.max_length)

        item = {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
        }

        task = self.config["task"]
        if task == "regression":
            item["labels"] = torch.tensor(ex[self.config["label_col"]], dtype=torch.float)
        elif task == "classification":
            item["labels"] = torch.tensor(ex[self.config["label_col"]], dtype=torch.long)
        elif task == "token_classification":
            # Secondary structure: each AA has a label.
            ss = ex[self.config["label_cols"][0]]  # dssp3
            labels = [self._SS3_MAP.get(c, 0) for c in ss]

            # Truncate to the number of attended tokens, then pad with the
            # ignore index so labels line up with input_ids in length.
            # NOTE(review): if the tokenizer adds special tokens (e.g. a
            # leading CLS), residue labels are shifted relative to tokens --
            # confirm against ProteinTokenizer.encode.
            seq_len = sum(encoded["attention_mask"])
            labels = labels[:seq_len]
            labels.extend([-100] * (len(encoded["input_ids"]) - len(labels)))
            item["labels"] = torch.tensor(labels, dtype=torch.long)

        return item


class DownstreamModel(nn.Module):
    """A base encoder plus a single linear head for one downstream task."""

    def __init__(self, base_model, task_config):
        """Attach a task head to ``base_model``.

        Args:
            base_model: Encoder exposing ``config.hidden_size`` and returning
                ``hidden_states`` when called with
                ``output_hidden_states=True``.
            task_config: One entry of
                ``ProteinDownstreamDataset.TASK_CONFIGS``.

        Raises:
            ValueError: If ``task_config["task"]`` is not a known task type.
        """
        super().__init__()
        self.base = base_model
        self.task = task_config["task"]
        self.config = task_config

        hidden_size = base_model.config.hidden_size

        if self.task == "regression":
            self.head = nn.Linear(hidden_size, 1)
        elif self.task == "classification":
            self.head = nn.Linear(hidden_size, task_config.get("num_labels", 2))
        elif self.task == "token_classification":
            self.head = nn.Linear(hidden_size, task_config.get("num_labels", 3))
        else:
            # Previously this fell through and failed later with an
            # AttributeError on the missing `self.head`.
            raise ValueError(f"Unsupported task type: {self.task!r}")

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and head; compute the loss when labels are given.

        Returns:
            ``{"loss": loss_or_None, "logits": logits}``. Logits are
            per-sequence for regression/classification (masked mean pooling)
            and per-token for token classification.
        """
        outputs = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = outputs.hidden_states[-1]

        if self.task in ("regression", "classification"):
            # Mean pool over attended positions only.
            mask_expanded = attention_mask.unsqueeze(-1).float()
            pooled = (hidden * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1e-9)
            logits = self.head(pooled)
        else:
            # Token-level
            logits = self.head(hidden)

        loss = None
        if labels is not None:
            if self.task == "regression":
                loss = nn.MSELoss()(logits.squeeze(-1), labels)
            elif self.task == "classification":
                loss = nn.CrossEntropyLoss()(logits, labels)
            elif self.task == "token_classification":
                # -100 marks padded label positions, which must not contribute.
                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(
                    logits.view(-1, self.config.get("num_labels", 3)),
                    labels.view(-1),
                )

        return {"loss": loss, "logits": logits}


def evaluate(model, dataloader, task_config, device):
    """Score ``model`` on ``dataloader`` with the task's configured metric.

    Args:
        model: A ``DownstreamModel``-like callable returning a dict with
            ``"loss"`` and ``"logits"``.
        dataloader: Yields batches with ``input_ids``, ``attention_mask`` and
            ``labels``.
        task_config: One entry of ``ProteinDownstreamDataset.TASK_CONFIGS``.
        device: Device the batches are moved to.

    Returns:
        ``(score, avg_loss)`` where ``avg_loss`` is the batch loss averaged
        over the samples seen.

    Raises:
        ValueError: If ``task_config["metric"]`` is not supported.
    """
    metric = task_config["metric"]
    if metric not in _SUPPORTED_METRICS:
        # Fail before the forward passes instead of with an unbound `score`.
        raise ValueError(
            f"Unsupported metric {metric!r}; expected one of {_SUPPORTED_METRICS}"
        )

    task = task_config["task"]
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask, labels)
            batch_size = input_ids.size(0)
            total_loss += outputs["loss"].item() * batch_size
            num_samples += batch_size

            logits = outputs["logits"]
            if task == "regression":
                all_preds.extend(logits.squeeze(-1).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
            elif task == "classification":
                all_preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
            elif task == "token_classification":
                preds = torch.argmax(logits, dim=-1).cpu().numpy()
                labels_np = labels.cpu().numpy()
                # Only evaluate non-padding positions.
                for row_preds, row_labels in zip(preds, labels_np):
                    keep = row_labels != -100
                    all_preds.extend(row_preds[keep])
                    all_labels.extend(row_labels[keep])

    if metric == "spearman":
        score, _ = spearmanr(all_labels, all_preds)
    elif metric == "accuracy":
        score = accuracy_score(all_labels, all_preds)
    else:  # "f1"
        score = f1_score(all_labels, all_preds, average="macro")

    # Guard against an empty loader rather than dividing by zero.
    avg_loss = total_loss / max(num_samples, 1)
    return score, avg_loss


def _load_eval_dataset(task_name, tokenizer):
    """Return a held-out dataset for ``task_name``, preferring validation."""
    for split in ("validation", "test"):
        try:
            return ProteinDownstreamDataset(
                task_name, split, tokenizer, fallback_to_train=False
            )
        except Exception as err:
            print(f"Split '{split}' not available for {task_name}: {err}")

    # Last resort, made explicit: scores computed on the training split are
    # not a measure of generalisation.
    print(f"No held-out split found for {task_name}; evaluating on 'train'.")
    return ProteinDownstreamDataset(task_name, "train", tokenizer)


def train_downstream(
    base_model,
    task_name,
    tokenizer,
    epochs=20,
    batch_size=16,
    lr=1e-4,
    device="cuda",
    seed=42,
):
    """Fine-tune ``base_model`` on one task and keep the best epoch.

    Args:
        base_model: Encoder wrapped in a ``DownstreamModel``; all of its
            parameters are trained.
        task_name: Key into ``ProteinDownstreamDataset.TASK_CONFIGS``.
        tokenizer: Passed through to ``ProteinDownstreamDataset``.
        epochs: Number of passes over the training split.
        batch_size: Batch size for both loaders.
        lr: Peak AdamW learning rate (linear warmup over 10% of steps, then
            linear decay).
        device: Device to train on.
        seed: Seed for ``torch`` and ``numpy``.

    Returns:
        ``(model, best_score)``. ``model`` carries the weights of the best
        validation epoch, if any epoch improved on the initial best score.

    Raises:
        ValueError: If the task's metric is not supported by ``evaluate``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    task_config = ProteinDownstreamDataset.TASK_CONFIGS[task_name]
    metric = task_config["metric"]
    if metric not in _SUPPORTED_METRICS:
        # Fail before spending an epoch of training.
        raise ValueError(
            f"Unsupported metric {metric!r}; expected one of {_SUPPORTED_METRICS}"
        )
    lower_is_better = metric in _LOWER_IS_BETTER_METRICS

    train_dataset = ProteinDownstreamDataset(task_name, "train", tokenizer)
    # For validation, use the validation split, else test, else (loudly) train.
    val_dataset = _load_eval_dataset(task_name, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    model = DownstreamModel(base_model, task_config).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )

    best_score = float("inf") if lower_is_better else -float("inf")
    best_model_state = None

    for epoch in range(epochs):
        model.train()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Clear first so gradients left on a reused base model are never
            # applied by the first step.
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask, labels)
            loss = outputs["loss"]

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        # Evaluate
        score, val_loss = evaluate(model, val_loader, task_config, device)
        print(f"Epoch {epoch+1}: Val {task_config['metric']}={score:.4f}, Loss={val_loss:.4f}")

        # A NaN score compares False either way and is never taken as best.
        is_better = score < best_score if lower_is_better else score > best_score
        if is_better:
            best_score = score
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, best_score


def compare_models(
    task_names=("fluorescence", "solubility", "secondary_structure"),
    epochs=20,
    device="cuda",
):
    """Train ModernProteinLM and an ESM-2 35M baseline on each task.

    Args:
        task_names: Iterable of keys into
            ``ProteinDownstreamDataset.TASK_CONFIGS``.
        epochs: Fine-tuning epochs per model and task.
        device: Device to train on.

    Returns:
        ``{task: {"modern": score, "esm2_35m": score_or_None}}``. The same
        mapping is written to ``downstream_results.json``.
    """
    tokenizer = ProteinTokenizer()
    results = {}

    for task in task_names:
        print(f"\n{'='*50}")
        print(f"Task: {task}")
        print(f"{'='*50}")

        # ModernProteinLM (random init)
        config = ModernProteinLMConfig(
            vocab_size=33,
            hidden_size=640,
            num_hidden_layers=24,
            num_attention_heads=10,
            intermediate_size=2304,
            use_geglu=True,
            tie_word_embeddings=True,
        )
        modern_model = ModernProteinLM(config)
        print(f"ModernProteinLM params: {sum(p.numel() for p in modern_model.parameters())/1e6:.1f}M")

        modern_model, modern_score = train_downstream(
            modern_model, task, tokenizer, epochs=epochs, device=device
        )

        # ESM-2 baseline (best effort: a failure is reported, not fatal).
        try:
            from transformers import AutoModel

            esm_model = AutoModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
            print(f"ESM-2 35M params: {sum(p.numel() for p in esm_model.parameters())/1e6:.1f}M")

            # NOTE(review): the baseline is fed token IDs from ProteinTokenizer,
            # not from the ESM-2 tokenizer. That is only valid if the two
            # vocabularies assign identical IDs -- confirm, or add an adapter.
            esm_model, esm_score = train_downstream(
                esm_model, task, tokenizer, epochs=epochs, device=device
            )

            results[task] = {
                "modern": modern_score,
                "esm2_35m": esm_score,
            }
        except Exception as e:
            print(f"ESM-2 comparison failed: {e}")
            results[task] = {"modern": modern_score, "esm2_35m": None}

        print(f"\nResults for {task}:")
        print(f" ModernProteinLM: {modern_score:.4f}")
        if results[task]["esm2_35m"] is not None:
            print(f" ESM-2 35M: {results[task]['esm2_35m']:.4f}")

    with open("downstream_results.json", "w") as f:
        json.dump(results, f, indent=2)

    return results


def main():
    """Smoke test: fine-tune a tiny ModernProteinLM on solubility."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Quick test on solubility (smallest dataset)
    tokenizer = ProteinTokenizer()
    config = ModernProteinLMConfig(
        vocab_size=33,
        hidden_size=128,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=512,
        use_geglu=True,
        tie_word_embeddings=True,
    )
    model = ModernProteinLM(config)

    print("\nTesting on solubility (tiny model)...")
    trained_model, score = train_downstream(
        model, "solubility", tokenizer, epochs=5, batch_size=8, lr=5e-4, device=device
    )
    print(f"Solubility accuracy: {score:.4f}")


if __name__ == "__main__":
    main()