| """ |
| Downstream evaluation for ModernProteinLM on predictive protein tasks: |
| - Fluorescence (regression, Spearman) |
| - Solubility (binary classification) |
| - Secondary Structure (token classification, Q3 accuracy; the DSSP8/Q8 labels are not used yet) |
| - Remote Homology (classification) |
| |
| Compares against ESM-2 baselines. |
| """ |
|
|
| import os |
| import json |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from torch.utils.data import DataLoader, Dataset |
| from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, mean_squared_error |
| from scipy.stats import spearmanr |
| from transformers import get_linear_schedule_with_warmup |
| from datasets import load_dataset |
| from tqdm import tqdm |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig |
| from electra_pretrain import ProteinTokenizer |
|
|
|
|
class ProteinDownstreamDataset(Dataset):
    """Map-style dataset for one of the downstream tasks in ``TASK_CONFIGS``.

    Every item is a dict with ``input_ids`` and ``attention_mask`` (long
    tensors built from ``tokenizer.encode``) plus ``labels``, whose form
    depends on the task type:

    * ``regression``           -> scalar float tensor
    * ``classification``       -> scalar long tensor
    * ``token_classification`` -> long tensor with one entry per position of
      ``input_ids``; ``-100`` marks positions the loss should ignore.
    """

    # Per-task settings: Hugging Face dataset id, sequence/label column names,
    # task type (selects head and loss in DownstreamModel), head width and the
    # metric name understood by evaluate().
    TASK_CONFIGS = {
        "fluorescence": {
            "dataset": "proteinea/fluorescence",
            "seq_col": "primary",
            "label_col": "log_fluorescence",
            "task": "regression",
            "metric": "spearman",
        },
        "solubility": {
            "dataset": "proteinea/solubility",
            "seq_col": "sequences",
            "label_col": "labels",
            "task": "classification",
            "num_labels": 2,
            "metric": "accuracy",
        },
        "secondary_structure": {
            "dataset": "proteinea/secondary_structure_prediction",
            "seq_col": "input",
            # Only the first column (dssp3) is used by __getitem__.
            "label_cols": ["dssp3", "dssp8"],
            "task": "token_classification",
            "num_labels": 3,
            "metric": "accuracy",
        },
        "remote_homology": {
            "dataset": "proteinea/remote_homology",
            "seq_col": "primary",
            "label_col": "fold_label",
            "task": "classification",
            "num_labels": 1195,
            "metric": "accuracy",
        },
    }

    # 3-state secondary-structure symbols -> class ids; any other symbol is
    # mapped to coil (0).
    _SS3_MAP = {"C": 0, "H": 1, "E": 2}

    def __init__(self, task_name, split, tokenizer, max_length=1024):
        """Load ``split`` of the dataset configured for ``task_name``.

        Args:
            task_name: key into ``TASK_CONFIGS``.
            split: dataset split name passed to ``load_dataset``.
            tokenizer: object whose ``encode(seq, max_length=...)`` returns a
                mapping with ``input_ids`` and ``attention_mask`` sequences.
            max_length: forwarded to ``tokenizer.encode``.

        Raises:
            KeyError: if ``task_name`` is not in ``TASK_CONFIGS``.
            Any exception raised by ``load_dataset`` (e.g. for a split that
            does not exist) is propagated unchanged.
        """
        self.task_name = task_name
        self.config = self.TASK_CONFIGS[task_name]
        self.tokenizer = tokenizer
        self.max_length = max_length

        # No silent fallback to another split: the previous fallback to
        # "train" meant a missing validation/test split was replaced by the
        # training data, so evaluation scored the model on what it trained on.
        # Propagating the error lets the caller choose another split explicitly.
        self.data = load_dataset(self.config["dataset"], split=split)
        self.examples = list(self.data)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and attach its task-specific labels."""
        ex = self.examples[idx]
        encoded = self.tokenizer.encode(ex[self.config["seq_col"]], max_length=self.max_length)
        input_ids = encoded["input_ids"]

        item = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
        }

        task = self.config["task"]
        if task == "regression":
            item["labels"] = torch.tensor(ex[self.config["label_col"]], dtype=torch.float)
        elif task == "classification":
            item["labels"] = torch.tensor(ex[self.config["label_col"]], dtype=torch.long)
        elif task == "token_classification":
            ss = ex[self.config["label_cols"][0]]
            labels = [self._SS3_MAP.get(c, 0) for c in ss]

            # NOTE(review): labels are written starting at position 0 of
            # input_ids. If the tokenizer prepends a special token (CLS/BOS),
            # every residue label is shifted by one relative to its token —
            # confirm against the tokenizer's encode() output.
            seq_len = sum(encoded["attention_mask"])
            labels = labels[:seq_len]
            # Pad with the loss ignore index up to the full input length.
            labels += [-100] * (len(input_ids) - len(labels))
            item["labels"] = torch.tensor(labels, dtype=torch.long)

        return item
|
|
|
|
class DownstreamModel(nn.Module):
    """Backbone encoder plus a task-specific linear head.

    The backbone is called with ``output_hidden_states=True`` and
    ``return_dict=True``, and its last hidden state is used:

    * ``regression`` / ``classification``: mean-pool over positions where
      ``attention_mask`` is non-zero, then apply the linear head.
    * ``token_classification``: apply the linear head at every position.
    """

    def __init__(self, base_model, task_config):
        """
        Args:
            base_model: module exposing ``config.hidden_size`` and returning an
                object with ``hidden_states`` when called as in ``forward``.
            task_config: dict with at least ``"task"``; ``"num_labels"`` is read
                for the classification tasks (defaults 2 and 3).

        Raises:
            ValueError: if ``task_config["task"]`` is not a supported task.
        """
        super().__init__()
        self.base = base_model
        self.task = task_config["task"]
        self.config = task_config

        hidden_size = base_model.config.hidden_size

        if self.task == "regression":
            out_dim = 1
        elif self.task == "classification":
            out_dim = task_config.get("num_labels", 2)
        elif self.task == "token_classification":
            out_dim = task_config.get("num_labels", 3)
        else:
            # Fail here rather than with an AttributeError on ``self.head``
            # during the first forward pass.
            raise ValueError(f"Unsupported task type: {self.task!r}")
        self.head = nn.Linear(hidden_size, out_dim)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute logits and, when ``labels`` is given, the task loss.

        Returns:
            dict with ``"logits"`` and ``"loss"`` (``None`` without labels).
        """
        outputs = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = outputs.hidden_states[-1]

        if self.task in ("regression", "classification"):
            # Masked mean pooling; the clamp guards against all-zero masks.
            mask_expanded = attention_mask.unsqueeze(-1).float()
            pooled = (hidden * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1e-9)
            logits = self.head(pooled)
        else:
            logits = self.head(hidden)

        loss = None
        if labels is not None:
            if self.task == "regression":
                loss = nn.MSELoss()(logits.squeeze(-1), labels)
            elif self.task == "classification":
                loss = nn.CrossEntropyLoss()(logits, labels)
            else:
                # Flatten (batch, seq, classes) for CE; -100 positions are ignored.
                loss = nn.CrossEntropyLoss(ignore_index=-100)(
                    logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
                )

        return {"loss": loss, "logits": logits}
|
|
|
|
def evaluate(model, dataloader, task_config, device):
    """Run ``model`` over ``dataloader`` without gradients and score it.

    Args:
        model: callable as ``model(input_ids, attention_mask, labels)``,
            returning a dict with ``"loss"`` and ``"logits"``.
        dataloader: yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors; ``dataloader.dataset`` must support ``len``.
        task_config: dict with ``"task"`` and ``"metric"`` (one of
            ``spearman``, ``accuracy``, ``f1``, ``mse``).
        device: target device for the batch tensors.

    Returns:
        ``(score, avg_loss)``. ``avg_loss`` is the sum over batches of
        (batch loss x batch size) divided by ``len(dataloader.dataset)``.
        For token classification, positions labelled ``-100`` are excluded
        from the score.

    Raises:
        ValueError: if the metric is unsupported, or no labelled predictions
            were collected (e.g. an empty evaluation set).
    """
    task = task_config["task"]
    metric = task_config["metric"]
    # Validate before the (potentially long) evaluation pass instead of
    # failing afterwards with an unbound ``score``.
    if metric not in ("spearman", "accuracy", "f1", "mse"):
        raise ValueError(f"Unsupported metric: {metric!r}")

    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask, labels)
            total_loss += outputs["loss"].item() * input_ids.size(0)

            logits = outputs["logits"]
            labels_np = labels.cpu().numpy()
            if task == "regression":
                all_preds.extend(logits.squeeze(-1).cpu().numpy())
                all_labels.extend(labels_np)
            elif task == "classification":
                all_preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
                all_labels.extend(labels_np)
            elif task == "token_classification":
                preds = torch.argmax(logits, dim=-1).cpu().numpy()
                # Boolean mask over the whole batch; row-major order matches
                # a per-sequence loop.
                keep = labels_np != -100
                all_preds.extend(preds[keep])
                all_labels.extend(labels_np[keep])

    if not all_labels:
        raise ValueError("Evaluation produced no labelled predictions (empty dataset?)")

    if metric == "spearman":
        score, _ = spearmanr(all_labels, all_preds)
    elif metric == "accuracy":
        score = accuracy_score(all_labels, all_preds)
    elif metric == "f1":
        score = f1_score(all_labels, all_preds, average="macro")
    else:  # "mse" — lower is better
        diff = np.asarray(all_labels, dtype=float) - np.asarray(all_preds, dtype=float)
        score = float(np.mean(diff ** 2))

    avg_loss = total_loss / len(dataloader.dataset)
    return score, avg_loss
|
|
|
|
def train_downstream(
    base_model,
    task_name,
    tokenizer,
    epochs=20,
    batch_size=16,
    lr=1e-4,
    device="cuda",
    seed=42,
):
    """Fine-tune ``base_model`` plus a task head and keep the best epoch.

    The whole model (backbone and head) is trained with AdamW, a linear
    warm-up/decay schedule (10% warm-up) and gradient clipping at 1.0. After
    every epoch the model is scored with ``evaluate`` on the "validation"
    split (or "test" when no validation split can be loaded). The weights of
    the best-scoring epoch are restored before returning.

    Args:
        base_model: backbone passed to ``DownstreamModel``.
        task_name: key into ``ProteinDownstreamDataset.TASK_CONFIGS``.
        tokenizer: passed to ``ProteinDownstreamDataset``.
        epochs, batch_size, lr: optimisation settings.
        device: device the model and batches are moved to.
        seed: seeds ``torch`` and ``numpy`` RNGs.

    Returns:
        ``(model, best_score)``. ``best_score`` is the best validation score,
        or +/-inf if no epoch produced a comparable score (e.g. NaN Spearman).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    task_config = ProteinDownstreamDataset.TASK_CONFIGS[task_name]
    metric = task_config["metric"]
    # "mse" is minimised; every other supported metric is maximised.
    lower_is_better = metric == "mse"

    train_dataset = ProteinDownstreamDataset(task_name, "train", tokenizer)
    try:
        val_dataset = ProteinDownstreamDataset(task_name, "validation", tokenizer)
    except (ValueError, KeyError):
        # NOTE: with this fallback, model selection happens on the test split,
        # so the returned score is an optimistic estimate.
        val_dataset = ProteinDownstreamDataset(task_name, "test", tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    model = DownstreamModel(base_model, task_config).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps
    )

    best_score = float("inf") if lower_is_better else -float("inf")
    best_model_state = None

    for epoch in range(epochs):
        model.train()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            loss = model(input_ids, attention_mask, labels)["loss"]

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        score, val_loss = evaluate(model, val_loader, task_config, device)
        print(f"Epoch {epoch+1}: Val {task_config['metric']}={score:.4f}, Loss={val_loss:.4f}")

        # Previously only spearman/accuracy defined ``is_better``; any other
        # metric crashed here. NaN scores compare False and are never kept.
        is_better = score < best_score if lower_is_better else score > best_score
        if is_better:
            best_score = score
            # Snapshot on CPU so the copy does not hold extra device memory.
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, best_score
|
|
|
|
class _HFTokenizerAdapter:
    """Expose a Hugging Face tokenizer via the ``encode`` interface used by ProteinDownstreamDataset."""

    def __init__(self, hf_tokenizer):
        self.hf_tokenizer = hf_tokenizer

    def encode(self, seq, max_length=1024):
        # Pad to a fixed length: train_downstream's DataLoaders use the default
        # collate, which stacks tensors and needs equal shapes.
        enc = self.hf_tokenizer(seq, max_length=max_length, padding="max_length", truncation=True)
        return {"input_ids": list(enc["input_ids"]), "attention_mask": list(enc["attention_mask"])}


def compare_models(
    task_names=("fluorescence", "solubility", "secondary_structure"),
    epochs=20,
    device="cuda",
    output_path="downstream_results.json",
):
    """Fine-tune ModernProteinLM and ESM-2 35M on each task and record scores.

    Args:
        task_names: iterable of keys into ``ProteinDownstreamDataset.TASK_CONFIGS``.
        epochs: training epochs per model and task.
        device: device passed to ``train_downstream``.
        output_path: JSON file the results dict is written to (rewritten after
            every task so partial results survive a later failure).

    Returns:
        ``{task: {"modern": score, "esm2_35m": score_or_None}}``. The ESM-2
        entry is ``None`` if loading or training the baseline raised.
    """
    tokenizer = ProteinTokenizer()
    results = {}

    def _save_results():
        with open(output_path, "w") as f:
            json.dump(results, f, indent=2)

    for task in task_names:
        print(f"\n{'='*50}")
        print(f"Task: {task}")
        print(f"{'='*50}")

        config = ModernProteinLMConfig(
            vocab_size=33,
            hidden_size=640,
            num_hidden_layers=24,
            num_attention_heads=10,
            intermediate_size=2304,
            use_geglu=True,
            tie_word_embeddings=True,
        )
        # NOTE(review): the model is built from a config only — no checkpoint
        # is loaded here — whereas ESM-2 below is pretrained. Load pretrained
        # weights for a like-for-like comparison.
        modern_model = ModernProteinLM(config)
        print(f"ModernProteinLM params: {sum(p.numel() for p in modern_model.parameters())/1e6:.1f}M")

        modern_model, modern_score = train_downstream(
            modern_model, task, tokenizer, epochs=epochs, device=device
        )
        # Only the score is needed; free the model before loading the baseline.
        del modern_model

        esm_score = None
        try:
            from transformers import AutoModel, AutoTokenizer
            esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
            esm_model = AutoModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
            print(f"ESM-2 35M params: {sum(p.numel() for p in esm_model.parameters())/1e6:.1f}M")

            # ESM-2 must be fed ids from its own vocabulary; nothing here
            # guarantees ProteinTokenizer's ids match it.
            esm_model, esm_score = train_downstream(
                esm_model, task, _HFTokenizerAdapter(esm_tokenizer), epochs=epochs, device=device
            )
            del esm_model
        except Exception as e:
            # Best-effort baseline: record None and keep going.
            print(f"ESM-2 comparison failed: {e}")

        results[task] = {"modern": modern_score, "esm2_35m": esm_score}

        print(f"\nResults for {task}:")
        print(f" ModernProteinLM: {modern_score:.4f}")
        if esm_score is not None:
            print(f" ESM-2 35M: {esm_score:.4f}")

        _save_results()

    # Also covers an empty ``task_names`` (writes an empty dict).
    _save_results()
    return results
|
|
|
|
if __name__ == "__main__":
    # Smoke run: fine-tune a small, freshly constructed ModernProteinLM on the
    # solubility task and report the best validation accuracy.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {run_device}")

    smoke_tokenizer = ProteinTokenizer()

    tiny_model_settings = dict(
        vocab_size=33,
        hidden_size=128,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=512,
        use_geglu=True,
        tie_word_embeddings=True,
    )
    smoke_model = ModernProteinLM(ModernProteinLMConfig(**tiny_model_settings))

    print("\nTesting on solubility (tiny model)...")
    _, solubility_score = train_downstream(
        smoke_model,
        "solubility",
        smoke_tokenizer,
        epochs=5,
        batch_size=8,
        lr=5e-4,
        device=run_device,
    )
    print(f"Solubility accuracy: {solubility_score:.4f}")
|
|