#!/usr/bin/env python3 """ Evaluate a student checkpoint against the frozen teacher using the same single-process sharded setup and fixed eval cache as distill_sharded.py. """ from __future__ import annotations import argparse from pathlib import Path import torch import distill_sharded as ds def main(): p = argparse.ArgumentParser() p.add_argument("--config", required=True) p.add_argument("--student", default=None, help="Optional student override path") p.add_argument("--samples", type=int, default=None, help="Optional eval sample override") args = p.parse_args() cfg = ds.load_config(args.config) if args.student: cfg["model"]["student"] = args.student if args.samples: cfg["eval"]["samples"] = args.samples student_device = torch.device(cfg["model"]["student_device"]) teacher_devices = list(cfg["model"]["teacher_devices"]) from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"], trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token pad_id = tokenizer.pad_token_id student = ds.load_student( cfg["model"]["student"], ds.parse_dtype(cfg["train"]["student_dtype"]), grad_ckpt=False, attn_impl=cfg["train"]["attn_implementation"], ) student.to(student_device) student.eval() teacher = ds.load_teacher( cfg["model"]["teacher"], ds.parse_dtype(cfg["train"]["teacher_dtype"]), attn_impl=cfg["train"]["attn_implementation"], devices=teacher_devices, max_mem_gb=cfg["model"]["teacher_max_memory_gb"], ) teacher_input_device, _ = ds.get_teacher_devices(teacher) specs = ds.build_dataset_specs(cfg["data"]) if Path(cfg["eval"]["cache_path"]).exists(): eval_batches = ds.build_or_load_eval_cache(cfg["eval"]["cache_path"]) else: eval_loader = ds.MixedStreamingLoader( specs=specs, tokenizer=tokenizer, min_chars=cfg["data"]["min_chars"], max_seq_len=cfg["data"]["max_seq_len"], kl_start_pos=cfg["data"]["kl_start_pos"], seed=cfg["eval"]["seed"], shuffle_buffer=cfg["data"]["shuffle_buffer"], ) eval_batches = ds.build_or_load_eval_cache( cfg["eval"]["cache_path"], eval_loader, cfg["eval"]["samples"], ) kl = ds.evaluate( student, teacher, eval_batches, pad_id, cfg["data"]["kl_start_pos"], cfg["train"]["kl_chunk_size"], student_device, teacher_input_device, ) print(f"{kl:.6f}") if __name__ == "__main__": main()