GrimSqueaker commited on
Commit
3714d46
·
verified ·
1 Parent(s): 9cac9b5

Upload train_finetune.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_finetune.py +430 -0
train_finetune.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fine-tune pretrained ModernProteinLM on downstream predictive tasks.
3
+ Supports: regression (fluorescence, stability), classification (solubility, remote homology).
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import argparse
9
+ import json
10
+ import random
11
+ import math
12
+ from typing import Dict, List
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.distributed as dist
19
+ from torch.nn.parallel import DistributedDataParallel as DDP
20
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler
21
+ from torch.cuda.amp import autocast, GradScaler
22
+ from transformers import get_cosine_schedule_with_warmup
23
+ from datasets import load_dataset
24
+ from scipy.stats import spearmanr
25
+ from sklearn.metrics import accuracy_score, f1_score
26
+
27
+ from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig
28
+
29
+
30
+ # =============================================================================
31
+ # TOKENIZER (shared with pretrain)
32
+ # =============================================================================
33
+
34
+ class ProteinTokenizer:
35
+ def __init__(self):
36
+ self.vocab = {
37
+ "<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
38
+ "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10,
39
+ "T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17,
40
+ "F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24,
41
+ "B": 25, "U": 26, "Z": 27, "O": 28, "<mask>": 29, "<sep>": 30,
42
+ }
43
+ while len(self.vocab) < 33:
44
+ self.vocab[f"<special_{len(self.vocab)}>"] = len(self.vocab)
45
+ self.id_to_token = {v: k for k, v in self.vocab.items()}
46
+ self.mask_token_id = 29
47
+ self.pad_token_id = 1
48
+ self.cls_token_id = 0
49
+ self.eos_token_id = 2
50
+
51
+ def encode(self, sequence: str, max_length: int = 1024):
52
+ tokens = [self.cls_token_id]
53
+ for aa in sequence.upper():
54
+ tokens.append(self.vocab.get(aa, self.vocab["<unk>"]))
55
+ tokens.append(self.eos_token_id)
56
+ if len(tokens) > max_length:
57
+ tokens = tokens[:max_length]
58
+ attention_mask = [1] * len(tokens)
59
+ while len(tokens) < max_length:
60
+ tokens.append(self.pad_token_id)
61
+ attention_mask.append(0)
62
+ return {"input_ids": tokens, "attention_mask": attention_mask}
63
+
64
+
65
+ def setup_distributed():
66
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
67
+ rank = int(os.environ["RANK"])
68
+ world_size = int(os.environ["WORLD_SIZE"])
69
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
70
+ dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
71
+ torch.cuda.set_device(local_rank)
72
+ return rank, world_size, local_rank
73
+ return 0, 1, 0
74
+
75
+
76
+ def log_rank0(msg):
77
+ if not dist.is_initialized() or dist.get_rank() == 0:
78
+ print(msg)
79
+
80
+
81
+ # =============================================================================
82
+ # TASK DEFINITIONS
83
+ # =============================================================================
84
+
85
+ TASK_SPECS = {
86
+ "fluorescence": {
87
+ "dataset": "proteinea/fluorescence",
88
+ "seq_key": "primary",
89
+ "label_key": "log_fluorescence",
90
+ "task_type": "regression",
91
+ "metric": "spearman",
92
+ "splits": ["train", "validation", "test"],
93
+ },
94
+ "stability": {
95
+ "dataset": "proteinea/fluorescence",
96
+ "seq_key": "primary",
97
+ "label_key": "log_fluorescence",
98
+ "task_type": "regression",
99
+ "metric": "spearman",
100
+ "splits": ["train", "validation", "test"],
101
+ },
102
+ "solubility": {
103
+ "dataset": "proteinea/solubility",
104
+ "seq_key": "sequences",
105
+ "label_key": "labels",
106
+ "task_type": "classification",
107
+ "num_labels": 2,
108
+ "metric": "accuracy",
109
+ "splits": ["train", "validation", "test"],
110
+ },
111
+ "remote_homology": {
112
+ "dataset": "proteinea/remote_homology",
113
+ "seq_key": "primary",
114
+ "label_key": "fold_label",
115
+ "task_type": "classification",
116
+ "num_labels": 1195,
117
+ "metric": "accuracy",
118
+ "splits": ["train", "validation", "test"],
119
+ },
120
+ }
121
+
122
+
123
+ class DownstreamDataset(Dataset):
124
+ def __init__(self, task_name, split, tokenizer, max_length=1024):
125
+ self.spec = TASK_SPECS[task_name]
126
+ self.tokenizer = tokenizer
127
+ self.max_length = max_length
128
+
129
+ try:
130
+ self.data = load_dataset(self.spec["dataset"], split=split)
131
+ except Exception as e:
132
+ log_rank0(f"Failed to load {split}: {e}, using train")
133
+ self.data = load_dataset(self.spec["dataset"], split="train")
134
+
135
+ self.examples = list(self.data)
136
+
137
+ def __len__(self):
138
+ return len(self.examples)
139
+
140
+ def __getitem__(self, idx):
141
+ ex = self.examples[idx]
142
+ seq = ex[self.spec["seq_key"]]
143
+ encoded = self.tokenizer.encode(seq, self.max_length)
144
+
145
+ item = {
146
+ "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
147
+ "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
148
+ }
149
+
150
+ if self.spec["task_type"] == "regression":
151
+ item["labels"] = torch.tensor(ex[self.spec["label_key"]], dtype=torch.float)
152
+ else:
153
+ item["labels"] = torch.tensor(ex[self.spec["label_key"]], dtype=torch.long)
154
+
155
+ return item
156
+
157
+
158
+ def mean_pool(hidden_states, attention_mask):
159
+ mask = attention_mask.unsqueeze(-1).float()
160
+ return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
161
+
162
+
163
+ class TaskHead(nn.Module):
164
+ def __init__(self, hidden_size, task_spec):
165
+ super().__init__()
166
+ if task_spec["task_type"] == "regression":
167
+ self.head = nn.Linear(hidden_size, 1)
168
+ else:
169
+ self.head = nn.Linear(hidden_size, task_spec.get("num_labels", 2))
170
+ self.task_type = task_spec["task_type"]
171
+
172
+ def forward(self, pooled):
173
+ return self.head(pooled)
174
+
175
+
176
+ def evaluate(model, head, dataloader, task_spec, device):
177
+ model.eval()
178
+ head.eval()
179
+
180
+ all_preds = []
181
+ all_labels = []
182
+ total_loss = 0.0
183
+
184
+ with torch.no_grad():
185
+ for batch in dataloader:
186
+ input_ids = batch["input_ids"].to(device)
187
+ attention_mask = batch["attention_mask"].to(device)
188
+ labels = batch["labels"].to(device)
189
+
190
+ outputs = model(input_ids, attention_mask, output_hidden_states=True, return_dict=True)
191
+ hidden = outputs.hidden_states[-1]
192
+ pooled = mean_pool(hidden, attention_mask)
193
+ logits = head(pooled)
194
+
195
+ if task_spec["task_type"] == "regression":
196
+ loss = F.mse_loss(logits.squeeze(-1), labels)
197
+ preds = logits.squeeze(-1).cpu().numpy()
198
+ else:
199
+ loss = F.cross_entropy(logits, labels)
200
+ preds = torch.argmax(logits, dim=-1).cpu().numpy()
201
+
202
+ total_loss += loss.item() * input_ids.size(0)
203
+ all_preds.extend(preds.tolist() if hasattr(preds, 'tolist') else preds)
204
+ all_labels.extend(labels.cpu().numpy().tolist())
205
+
206
+ metric = task_spec["metric"]
207
+ if metric == "spearman":
208
+ score, _ = spearmanr(all_labels, all_preds)
209
+ elif metric == "accuracy":
210
+ score = accuracy_score(all_labels, all_preds)
211
+ elif metric == "f1":
212
+ score = f1_score(all_labels, all_preds, average="macro")
213
+
214
+ return score, total_loss / len(dataloader.dataset)
215
+
216
+
217
+ def train_task(args, model, task_name, tokenizer, device, rank, world_size):
218
+ spec = TASK_SPECS[task_name]
219
+
220
+ train_ds = DownstreamDataset(task_name, spec["splits"][0], tokenizer, args.max_seq_length)
221
+ val_ds = DownstreamDataset(
222
+ task_name,
223
+ spec["splits"][1] if len(spec["splits"]) > 1 else spec["splits"][0],
224
+ tokenizer, args.max_seq_length
225
+ )
226
+ test_ds = DownstreamDataset(
227
+ task_name,
228
+ spec["splits"][-1],
229
+ tokenizer, args.max_seq_length
230
+ )
231
+
232
+ if world_size > 1:
233
+ train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank)
234
+ else:
235
+ train_sampler = None
236
+
237
+ train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler,
238
+ num_workers=args.num_workers, pin_memory=True, drop_last=True)
239
+ val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
240
+ num_workers=args.num_workers, pin_memory=True)
241
+ test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
242
+ num_workers=args.num_workers, pin_memory=True)
243
+
244
+ head = TaskHead(args.hidden_size, spec).to(device)
245
+
246
+ # Layer-wise LR decay
247
+ params = [
248
+ {"params": head.parameters(), "lr": args.lr},
249
+ {"params": model.layers[-4:].parameters(), "lr": args.lr * 0.5},
250
+ {"params": model.layers[:-4].parameters(), "lr": args.lr * 0.1},
251
+ {"params": [model.embeddings.weight], "lr": args.lr * 0.1},
252
+ ]
253
+
254
+ optimizer = torch.optim.AdamW(params, weight_decay=args.weight_decay)
255
+
256
+ total_steps = len(train_loader) * args.epochs
257
+ scheduler = get_cosine_schedule_with_warmup(
258
+ optimizer, int(args.warmup_ratio * total_steps), total_steps
259
+ )
260
+
261
+ scaler = GradScaler() if args.use_amp else None
262
+
263
+ best_score = -float("inf")
264
+ best_state = None
265
+
266
+ for epoch in range(args.epochs):
267
+ model.train()
268
+ head.train()
269
+
270
+ if train_sampler:
271
+ train_sampler.set_epoch(epoch)
272
+
273
+ for batch in train_loader:
274
+ input_ids = batch["input_ids"].to(device)
275
+ attention_mask = batch["attention_mask"].to(device)
276
+ labels = batch["labels"].to(device)
277
+
278
+ with autocast(enabled=args.use_amp):
279
+ outputs = model(input_ids, attention_mask, output_hidden_states=True, return_dict=True)
280
+ hidden = outputs.hidden_states[-1]
281
+ pooled = mean_pool(hidden, attention_mask)
282
+ logits = head(pooled)
283
+
284
+ if spec["task_type"] == "regression":
285
+ loss = F.mse_loss(logits.squeeze(-1), labels)
286
+ else:
287
+ loss = F.cross_entropy(logits, labels)
288
+
289
+ if scaler:
290
+ scaler.scale(loss).backward()
291
+ scaler.unscale_(optimizer)
292
+ torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), 1.0)
293
+ scaler.step(optimizer)
294
+ scaler.update()
295
+ else:
296
+ loss.backward()
297
+ torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), 1.0)
298
+ optimizer.step()
299
+
300
+ scheduler.step()
301
+ optimizer.zero_grad()
302
+
303
+ # Evaluate
304
+ val_score, val_loss = evaluate(model, head, val_loader, spec, device)
305
+
306
+ if rank == 0:
307
+ log_rank0(f" Epoch {epoch+1}/{args.epochs}: val_{spec['metric']}={val_score:.4f}, loss={val_loss:.4f}")
308
+
309
+ if val_score > best_score:
310
+ best_score = val_score
311
+ best_state = {
312
+ "model": model.state_dict(),
313
+ "head": head.state_dict(),
314
+ }
315
+
316
+ # Load best and test
317
+ if best_state:
318
+ model.load_state_dict(best_state["model"])
319
+ head.load_state_dict(best_state["head"])
320
+
321
+ test_score, test_loss = evaluate(model, head, test_loader, spec, device)
322
+
323
+ return {
324
+ "task": task_name,
325
+ "val_score": float(best_score),
326
+ "test_score": float(test_score),
327
+ "metric": spec["metric"],
328
+ }
329
+
330
+
331
+ def main():
332
+ parser = argparse.ArgumentParser()
333
+ parser.add_argument("--pretrain_dir", required=True)
334
+ parser.add_argument("--tasks", default="fluorescence,solubility")
335
+ parser.add_argument("--epochs", type=int, default=20)
336
+ parser.add_argument("--batch_size", type=int, default=16)
337
+ parser.add_argument("--lr", type=float, default=1e-4)
338
+ parser.add_argument("--warmup_ratio", type=float, default=0.1)
339
+ parser.add_argument("--weight_decay", type=float, default=0.01)
340
+ parser.add_argument("--max_seq_length", type=int, default=1024)
341
+ parser.add_argument("--output_dir", default="./outputs/finetune")
342
+ parser.add_argument("--num_workers", type=int, default=4)
343
+ parser.add_argument("--use_amp", action="store_true")
344
+ parser.add_argument("--seed", type=int, default=42)
345
+ parser.add_argument("--use_trackio", action="store_true")
346
+ parser.add_argument("--trackio_project", default="modern-protein-lm")
347
+ args = parser.parse_args()
348
+
349
+ rank, world_size, local_rank = setup_distributed()
350
+
351
+ random.seed(args.seed + rank)
352
+ np.random.seed(args.seed + rank)
353
+ torch.manual_seed(args.seed + rank)
354
+
355
+ device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
356
+
357
+ tokenizer = ProteinTokenizer()
358
+
359
+ # Load pretrained discriminator base
360
+ checkpoint_path = os.path.join(args.pretrain_dir, "checkpoint.pt")
361
+ if not os.path.exists(checkpoint_path):
362
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
363
+
364
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
365
+
366
+ # Infer config from checkpoint
367
+ disc_state = checkpoint["discriminator"]
368
+ # Find hidden_size from state dict
369
+ hidden_size = None
370
+ for key in disc_state:
371
+ if "model.embeddings.weight" in key:
372
+ hidden_size = disc_state[key].shape[1]
373
+ break
374
+
375
+ if hidden_size is None:
376
+ raise ValueError("Could not infer model size from checkpoint")
377
+
378
+ args.hidden_size = hidden_size
379
+
380
+ config = ModernProteinLMConfig(
381
+ vocab_size=33,
382
+ hidden_size=hidden_size,
383
+ num_hidden_layers=28,
384
+ num_attention_heads=9,
385
+ intermediate_size=2304,
386
+ use_geglu=True,
387
+ tie_word_embeddings=True,
388
+ )
389
+
390
+ model = ModernProteinLM(config).to(device)
391
+ # Load only base model weights (not discriminator head)
392
+ base_state = {k.replace("model.", ""): v for k, v in disc_state.items() if k.startswith("model.")}
393
+ model.load_state_dict(base_state, strict=False)
394
+
395
+ log_rank0(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6:.1f}M params")
396
+
397
+ if world_size > 1:
398
+ model = DDP(model, device_ids=[local_rank])
399
+
400
+ tasks = [t.strip() for t in args.tasks.split(",")]
401
+ results = {}
402
+
403
+ for task in tasks:
404
+ log_rank0(f"\n{'='*50}")
405
+ log_rank0(f"Task: {task}")
406
+ log_rank0(f"{'='*50}")
407
+
408
+ result = train_task(args, model, task, tokenizer, device, rank, world_size)
409
+ results[task] = result
410
+
411
+ if rank == 0:
412
+ log_rank0(f" Test {result['metric']}: {result['test_score']:.4f}")
413
+
414
+ if rank == 0:
415
+ os.makedirs(args.output_dir, exist_ok=True)
416
+ with open(os.path.join(args.output_dir, "results.json"), "w") as f:
417
+ json.dump(results, f, indent=2)
418
+
419
+ log_rank0(f"\n{'='*50}")
420
+ log_rank0("FINAL RESULTS")
421
+ log_rank0(f"{'='*50}")
422
+ for task, res in results.items():
423
+ log_rank0(f" {task}: {res['test_score']:.4f} ({res['metric']})")
424
+
425
+ if dist.is_initialized():
426
+ dist.destroy_process_group()
427
+
428
+
429
+ if __name__ == "__main__":
430
+ main()