Delta-Vector committed on
Commit
93462da
·
verified ·
1 Parent(s): 729546e

Upload eval_kl.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eval_kl.py +90 -0
eval_kl.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Evaluate a student checkpoint against the frozen teacher using the same
4
+ single-process sharded setup and fixed eval cache as distill_sharded.py.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ from pathlib import Path
11
+
12
+ import torch
13
+
14
+ import distill_sharded as ds
15
+
16
+
17
def main() -> None:
    """Evaluate the configured student against the teacher and print the score.

    Command-line flags:
        --config   Path handed to ``ds.load_config`` (required).
        --student  Optional replacement for ``cfg["model"]["student"]``.
        --samples  Optional replacement for ``cfg["eval"]["samples"]``; must
                   be a positive integer when given.

    If the file at ``cfg["eval"]["cache_path"]`` already exists, the eval
    batches are obtained from ``ds.build_or_load_eval_cache`` with the path
    alone. Otherwise a ``ds.MixedStreamingLoader`` seeded with
    ``cfg["eval"]["seed"]`` is passed to the same helper together with the
    sample count (presumably so the cache gets built -- confirm in
    ``distill_sharded``).

    The value returned by ``ds.evaluate`` is printed to stdout with six
    decimal places; diagnostics go to stderr so stdout stays machine-readable.
    """
    # Local import: only this entry point needs it, mirroring the local
    # ``transformers`` import further down.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--student", default=None, help="Optional student override path")
    parser.add_argument("--samples", type=int, default=None, help="Optional eval sample override")
    args = parser.parse_args()

    # Compare against None instead of relying on truthiness so that an
    # explicit ``--samples 0`` (or a negative value) is rejected loudly rather
    # than being silently dropped or forwarded.
    if args.samples is not None and args.samples <= 0:
        parser.error("--samples must be a positive integer")

    cfg = ds.load_config(args.config)
    if args.student:
        cfg["model"]["student"] = args.student
    if args.samples is not None:
        cfg["eval"]["samples"] = args.samples

    student_device = torch.device(cfg["model"]["student_device"])
    teacher_devices = list(cfg["model"]["teacher_devices"])

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"], trust_remote_code=True)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    student = ds.load_student(
        cfg["model"]["student"],
        ds.parse_dtype(cfg["train"]["student_dtype"]),
        grad_ckpt=False,
        attn_impl=cfg["train"]["attn_implementation"],
    )
    student.to(student_device)
    student.eval()

    teacher = ds.load_teacher(
        cfg["model"]["teacher"],
        ds.parse_dtype(cfg["train"]["teacher_dtype"]),
        attn_impl=cfg["train"]["attn_implementation"],
        devices=teacher_devices,
        max_mem_gb=cfg["model"]["teacher_max_memory_gb"],
    )
    # Only the first element of the returned pair is used here.
    teacher_input_device, _ = ds.get_teacher_devices(teacher)

    specs = ds.build_dataset_specs(cfg["data"])
    cache_path = cfg["eval"]["cache_path"]
    if Path(cache_path).exists():
        # On this path the sample count is not handed to any helper, so a
        # --samples override cannot take effect; tell the user instead of
        # silently ignoring it.
        if args.samples is not None:
            print(
                f"warning: eval cache {cache_path} already exists; "
                "--samples override is ignored",
                file=sys.stderr,
            )
        eval_batches = ds.build_or_load_eval_cache(cache_path)
    else:
        eval_loader = ds.MixedStreamingLoader(
            specs=specs,
            tokenizer=tokenizer,
            min_chars=cfg["data"]["min_chars"],
            max_seq_len=cfg["data"]["max_seq_len"],
            kl_start_pos=cfg["data"]["kl_start_pos"],
            seed=cfg["eval"]["seed"],
            shuffle_buffer=cfg["data"]["shuffle_buffer"],
        )
        eval_batches = ds.build_or_load_eval_cache(
            cache_path,
            eval_loader,
            cfg["eval"]["samples"],
        )

    kl = ds.evaluate(
        student,
        teacher,
        eval_batches,
        pad_id,
        cfg["data"]["kl_start_pos"],
        cfg["train"]["kl_chunk_size"],
        student_device,
        teacher_input_device,
    )
    print(f"{kl:.6f}")


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()