GrimSqueaker committed on
Commit
9c95323
·
verified ·
1 Parent(s): 36bbb76

Upload downstream_eval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. downstream_eval.py +376 -0
downstream_eval.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Downstream evaluation for ModernProteinLM on predictive protein tasks:
3
+ - Fluorescence (regression, Spearman)
4
+ - Solubility (binary classification)
5
+ - Secondary Structure (token classification, Q3/Q8 accuracy)
6
+ - Remote Homology (classification)
7
+
8
+ Compares against ESM-2 baselines.
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from torch.utils.data import DataLoader, Dataset
17
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, mean_squared_error
18
+ from scipy.stats import spearmanr
19
+ from transformers import get_linear_schedule_with_warmup
20
+ from datasets import load_dataset
21
+ from tqdm import tqdm
22
+ import warnings
23
+ warnings.filterwarnings("ignore")
24
+
25
+ from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig
26
+ from electra_pretrain import ProteinTokenizer
27
+
28
+
29
class ProteinDownstreamDataset(Dataset):
    """Map-style dataset yielding one tokenized downstream-task example per item.

    ``TASK_CONFIGS`` maps a task name to the dataset identifier handed to
    ``load_dataset``, the column holding the sequence, the column(s) holding
    the target, the task type (``regression``, ``classification`` or
    ``token_classification``) and the metric name used for scoring.

    Args:
        task_name: Key into ``TASK_CONFIGS``.
        split: Split name passed to ``load_dataset``.
        tokenizer: Object whose ``encode(seq, max_length=...)`` returns a
            mapping with ``input_ids`` and ``attention_mask`` sequences.
        max_length: Forwarded unchanged to ``tokenizer.encode``.

    Raises:
        KeyError: If ``task_name`` is not a key of ``TASK_CONFIGS``.
    """

    TASK_CONFIGS = {
        "fluorescence": {
            "dataset": "proteinea/fluorescence",
            "seq_col": "primary",
            "label_col": "log_fluorescence",
            "task": "regression",
            "metric": "spearman",
        },
        "solubility": {
            "dataset": "proteinea/solubility",
            "seq_col": "sequences",
            "label_col": "labels",
            "task": "classification",
            "num_labels": 2,
            "metric": "accuracy",
        },
        "secondary_structure": {
            "dataset": "proteinea/secondary_structure_prediction",
            "seq_col": "input",
            "label_cols": ["dssp3", "dssp8"],
            "task": "token_classification",
            "num_labels": 3,  # Q3 first
            "metric": "accuracy",
        },
        "remote_homology": {
            "dataset": "proteinea/remote_homology",
            "seq_col": "primary",
            "label_col": "fold_label",
            "task": "classification",
            "num_labels": 1195,  # Actually fold labels
            "metric": "accuracy",
        },
    }

    # Three-state secondary-structure symbols -> class index. Symbols that are
    # not listed fall back to class 0 in __getitem__.
    _SS3_TO_INDEX = {"C": 0, "H": 1, "E": 2}
    # Label value used for positions that carry no target.
    _IGNORE_INDEX = -100

    def __init__(self, task_name, split, tokenizer, max_length=1024):
        if task_name not in self.TASK_CONFIGS:
            raise KeyError(
                f"Unknown task {task_name!r}; expected one of "
                f"{sorted(self.TASK_CONFIGS)}"
            )
        self.task_name = task_name
        self.config = self.TASK_CONFIGS[task_name]
        self.tokenizer = tokenizer
        self.max_length = max_length

        dataset_id = self.config["dataset"]
        try:
            self.data = load_dataset(dataset_id, split=split)
        except Exception as err:
            # Retrying "train" after "train" itself failed cannot succeed, so
            # surface the original error instead.
            if split == "train":
                raise
            # Some datasets don't have validation/test splits; fall back to
            # train, but say so: scores computed on this data are not
            # held-out scores.
            print(
                f"Warning: could not load split {split!r} of {dataset_id} "
                f"({err}); falling back to 'train'."
            )
            self.data = load_dataset(dataset_id, split="train")

        # Materialize the rows once so __getitem__ is a plain list lookup.
        self.examples = list(self.data)

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and attach its label tensor.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (long tensors) plus
            ``labels``: a float scalar for regression, a long scalar for
            classification, or a long vector as long as ``input_ids`` for
            token classification.
        """
        ex = self.examples[idx]
        encoded = self.tokenizer.encode(
            ex[self.config["seq_col"]], max_length=self.max_length
        )

        item = {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
        }

        task = self.config["task"]
        if task == "regression":
            item["labels"] = torch.tensor(ex[self.config["label_col"]], dtype=torch.float)
        elif task == "classification":
            item["labels"] = torch.tensor(ex[self.config["label_col"]], dtype=torch.long)
        elif task == "token_classification":
            # Secondary structure: one label per residue, taken from the first
            # label column (dssp3).
            structure = ex[self.config["label_cols"][0]]
            labels = [self._SS3_TO_INDEX.get(symbol, 0) for symbol in structure]
            # Keep at most one label per attended token, then fill the rest of
            # the encoded length with the ignore index.
            # NOTE(review): labels are aligned to token position 0 onward; if
            # the tokenizer prepends a special token they are shifted by one.
            # Confirm against ProteinTokenizer.
            num_attended = sum(encoded["attention_mask"])
            labels = labels[:num_attended]
            missing = len(encoded["input_ids"]) - len(labels)
            labels.extend([self._IGNORE_INDEX] * missing)
            item["labels"] = torch.tensor(labels, dtype=torch.long)

        return item
111
+
112
+
113
class DownstreamModel(nn.Module):
    """Encoder wrapped with a single linear head for one downstream task.

    Args:
        base_model: Callable encoder exposing ``config.hidden_size``. It is
            invoked with ``input_ids``, ``attention_mask``,
            ``output_hidden_states=True`` and ``return_dict=True`` and must
            return an object with a ``hidden_states`` sequence.
        task_config: Mapping with a ``"task"`` entry (``"regression"``,
            ``"classification"`` or ``"token_classification"``) and an
            optional ``"num_labels"``.

    Raises:
        ValueError: If ``task_config["task"]`` is not a supported task.
    """

    _SEQUENCE_TASKS = ("regression", "classification")
    _SUPPORTED_TASKS = _SEQUENCE_TASKS + ("token_classification",)

    def __init__(self, base_model, task_config):
        super().__init__()
        task = task_config["task"]
        # Fail fast: without this check an unknown task would build a model
        # with no head and only fail later inside forward().
        if task not in self._SUPPORTED_TASKS:
            raise ValueError(
                f"Unsupported task {task!r}; expected one of {self._SUPPORTED_TASKS}"
            )

        self.base = base_model
        self.task = task
        self.config = task_config

        hidden_size = base_model.config.hidden_size
        if task == "regression":
            out_features = 1
        elif task == "classification":
            out_features = task_config.get("num_labels", 2)
        else:
            out_features = task_config.get("num_labels", 3)
        self.head = nn.Linear(hidden_size, out_features)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and head; compute the task loss when labels are given.

        Returns:
            Dict with ``"logits"`` and ``"loss"``. ``"loss"`` is ``None`` when
            ``labels`` is ``None``. Logits come from the mask-weighted mean of
            the last hidden state for regression/classification, and from
            every position of the last hidden state for token classification.
        """
        outputs = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = outputs.hidden_states[-1]

        if self.task in self._SEQUENCE_TASKS:
            # Mean pool over attended positions only; the clamp keeps an
            # all-zero mask from dividing by zero.
            weights = attention_mask.unsqueeze(-1).float()
            token_count = weights.sum(dim=1).clamp(min=1e-9)
            pooled = (hidden * weights).sum(dim=1) / token_count
            logits = self.head(pooled)
        else:
            # Token-level: one prediction per position.
            logits = self.head(hidden)

        loss = None
        if labels is not None:
            if self.task == "regression":
                loss = nn.MSELoss()(logits.squeeze(-1), labels)
            elif self.task == "classification":
                loss = nn.CrossEntropyLoss()(logits, labels)
            else:
                # Flatten to (positions, classes); -100 labels are skipped.
                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(
                    logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
                )

        return {"loss": loss, "logits": logits}
160
+
161
+
162
# Metric names understood by evaluate().
_SUPPORTED_METRICS = ("spearman", "accuracy", "f1", "mse")


def evaluate(model, dataloader, task_config, device):
    """Score ``model`` on every batch of ``dataloader``.

    Args:
        model: Callable as ``model(input_ids, attention_mask, labels)`` that
            returns a dict with ``"loss"`` and ``"logits"``; it is switched to
            eval mode first.
        dataloader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        task_config: Mapping with ``"task"`` and ``"metric"`` entries.
        device: Device the batch tensors are moved to.

    Returns:
        ``(score, avg_loss)``: the metric over all collected predictions and
        the batch loss averaged with per-batch example-count weights.

    Raises:
        ValueError: If the metric is not supported, or no batch was produced.
    """
    task = task_config["task"]
    metric = task_config["metric"]
    # Validate before running inference so a bad config does not waste a
    # whole pass over the data.
    if metric not in _SUPPORTED_METRICS:
        raise ValueError(
            f"Unsupported metric {metric!r}; expected one of {_SUPPORTED_METRICS}"
        )

    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0
    num_examples = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask, labels)
            batch_size = input_ids.size(0)
            total_loss += outputs["loss"].item() * batch_size
            num_examples += batch_size

            logits = outputs["logits"]
            if task == "regression":
                all_preds.extend(logits.squeeze(-1).float().cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
            elif task == "classification":
                all_preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
            elif task == "token_classification":
                preds = torch.argmax(logits, dim=-1).cpu().numpy()
                labels_np = labels.cpu().numpy()
                # Only score positions that carry a real label; boolean
                # indexing flattens row by row, matching a per-row loop.
                keep = labels_np != -100
                all_preds.extend(preds[keep])
                all_labels.extend(labels_np[keep])

    if num_examples == 0:
        raise ValueError("evaluate() received no examples")

    if metric == "spearman":
        score, _ = spearmanr(all_labels, all_preds)
    elif metric == "accuracy":
        score = accuracy_score(all_labels, all_preds)
    elif metric == "f1":
        score = f1_score(all_labels, all_preds, average="macro")
    else:  # "mse"
        diff = np.asarray(all_preds, dtype=np.float64) - np.asarray(
            all_labels, dtype=np.float64
        )
        score = np.mean(diff ** 2)

    # Divide by the examples actually seen rather than len(dataset), so the
    # average stays right for loaders that drop or subsample examples.
    avg_loss = total_loss / num_examples
    return float(score), avg_loss
205
+
206
+
207
def _resolve_eval_split(dataset_id):
    """Return "validation" or "test" if the dataset provides one, else None."""
    from datasets import get_dataset_split_names

    try:
        available = set(get_dataset_split_names(dataset_id))
    except Exception as err:
        print(f"Warning: could not list splits of {dataset_id} ({err}).")
        return None
    for candidate in ("validation", "test"):
        if candidate in available:
            return candidate
    return None


def _is_improvement(metric, score, best_score):
    """Tell whether ``score`` beats ``best_score`` ("mse" is lower-is-better)."""
    if score != score:  # NaN (e.g. constant predictions) never counts as better
        return False
    if metric == "mse":
        return score < best_score
    return score > best_score


def train_downstream(
    base_model,
    task_name,
    tokenizer,
    epochs=20,
    batch_size=16,
    lr=1e-4,
    device="cuda",
    seed=42,
):
    """Fine-tune ``base_model`` with a task head and keep the best epoch.

    Args:
        base_model: Encoder to wrap in ``DownstreamModel``. It is trained in
            place.
        task_name: Key of ``ProteinDownstreamDataset.TASK_CONFIGS``.
        tokenizer: Tokenizer handed to ``ProteinDownstreamDataset``.
        epochs: Number of passes over the training data.
        batch_size: Batch size for both loaders.
        lr: Peak AdamW learning rate (linear warmup over the first 10% of
            steps, then linear decay).
        device: Device used for the model and the batches.
        seed: Seed for torch, numpy and the hold-out split.

    Returns:
        ``(model, best_score)``: the wrapped model with the weights of the
        best-scoring epoch restored, and that epoch's validation score.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    task_config = ProteinDownstreamDataset.TASK_CONFIGS[task_name]
    metric = task_config["metric"]

    train_dataset = ProteinDownstreamDataset(task_name, "train", tokenizer)

    # Pick the evaluation split explicitly. Merely attempting to load
    # "validation" is not enough to detect a missing split, because the
    # dataset class may substitute the train split instead of raising.
    eval_split = _resolve_eval_split(task_config["dataset"])
    if eval_split is not None:
        val_dataset = ProteinDownstreamDataset(task_name, eval_split, tokenizer)
    else:
        # No held-out split available: carve 10% out of train so model
        # selection is never done on data the model was fitted to.
        num_val = max(1, len(train_dataset) // 10)
        print(
            f"No validation/test split found for {task_name}; "
            f"holding out {num_val} training examples for validation."
        )
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset,
            [len(train_dataset) - num_val, num_val],
            generator=torch.Generator().manual_seed(seed),
        )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    model = DownstreamModel(base_model, task_config).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps
    )

    best_score = float("inf") if metric == "mse" else -float("inf")
    best_model_state = None

    for epoch in range(epochs):
        model.train()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Clear first so gradients left on base_model by earlier use
            # cannot leak into the first step.
            optimizer.zero_grad()
            loss = model(input_ids, attention_mask, labels)["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        # Evaluate
        score, val_loss = evaluate(model, val_loader, task_config, device)
        print(f"Epoch {epoch+1}: Val {metric}={score:.4f}, Loss={val_loss:.4f}")

        if _is_improvement(metric, score, best_score):
            best_score = score
            # Snapshot on CPU so the copy does not occupy accelerator memory.
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, best_score
289
+
290
+
291
def compare_models(
    task_names=("fluorescence", "solubility", "secondary_structure"),
    epochs=20,
    device="cuda",
):
    """Train ModernProteinLM and an ESM-2 35M baseline on each task.

    For every task a freshly initialised ModernProteinLM is fine-tuned with
    ``train_downstream``, then ``facebook/esm2_t12_35M_UR50D`` is fine-tuned
    the same way. A failure in the ESM-2 part is reported and recorded as
    ``None`` rather than aborting the run. Scores are written to
    ``downstream_results.json`` after each task.

    Args:
        task_names: Iterable of ``ProteinDownstreamDataset.TASK_CONFIGS`` keys.
        epochs: Epochs per fine-tuning run.
        device: Device passed to ``train_downstream``.

    Returns:
        Dict mapping task name to ``{"modern": score, "esm2_35m": score | None}``.
    """
    tokenizer = ProteinTokenizer()
    results = {}
    banner = "=" * 50

    for task in task_names:
        print(f"\n{banner}")
        print(f"Task: {task}")
        print(banner)

        # ModernProteinLM (random init)
        config = ModernProteinLMConfig(
            vocab_size=33,
            hidden_size=640,
            num_hidden_layers=24,
            num_attention_heads=10,
            intermediate_size=2304,
            use_geglu=True,
            tie_word_embeddings=True,
        )
        modern_model = ModernProteinLM(config)
        print(f"ModernProteinLM params: {sum(p.numel() for p in modern_model.parameters())/1e6:.1f}M")

        modern_model, modern_score = train_downstream(
            modern_model, task, tokenizer, epochs=epochs, device=device
        )
        # Only the score is used from here on; drop the reference so the
        # weights can be reclaimed before the baseline is trained.
        del modern_model

        # ESM-2 baseline
        esm_score = None
        try:
            from transformers import AutoModel

            esm_model = AutoModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
            print(f"ESM-2 35M params: {sum(p.numel() for p in esm_model.parameters())/1e6:.1f}M")

            # NOTE(review): the baseline is fed ids produced by
            # ProteinTokenizer, not by the ESM-2 tokenizer. That is only a
            # fair comparison if both share the same vocabulary/id mapping.
            # TODO confirm.
            esm_model, esm_score = train_downstream(
                esm_model, task, tokenizer, epochs=epochs, device=device
            )
            del esm_model
        except Exception as e:
            print(f"ESM-2 comparison failed: {e}")
            esm_score = None

        results[task] = {
            "modern": float(modern_score),
            "esm2_35m": None if esm_score is None else float(esm_score),
        }

        print(f"\nResults for {task}:")
        print(f" ModernProteinLM: {modern_score:.4f}")
        if esm_score is not None:
            print(f" ESM-2 35M: {esm_score:.4f}")

        # Rewrite the file after every task so a crash in a later task does
        # not discard the scores already obtained.
        with open("downstream_results.json", "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

    return results
352
+
353
+
354
if __name__ == "__main__":
    # Use the GPU when torch can see one, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print("Using device: {}".format(device))

    # Smoke test: fine-tune a tiny randomly initialised model on solubility.
    tokenizer = ProteinTokenizer()

    tiny_settings = dict(
        vocab_size=33,
        hidden_size=128,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=512,
        use_geglu=True,
        tie_word_embeddings=True,
    )
    config = ModernProteinLMConfig(**tiny_settings)
    model = ModernProteinLM(config)

    print("\nTesting on solubility (tiny model)...")
    trained_model, score = train_downstream(
        base_model=model,
        task_name="solubility",
        tokenizer=tokenizer,
        epochs=5,
        batch_size=8,
        lr=5e-4,
        device=device,
    )
    print("Solubility accuracy: {:.4f}".format(score))