"""
Single-process KL distillation with a sharded frozen teacher and one trainable
student GPU.

This is a derivative of distill.py tailored for large-teacher / smaller-student
setups where replicating the teacher per process is wasteful or infeasible.
"""

from __future__ import annotations

import argparse
import gc
import json
import logging
import random
import re
import shutil
import time
import tomllib
from collections import OrderedDict
from pathlib import Path

import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint_utils
from torch.optim import AdamW

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("distill_sharded")

REQUIRED_SECTIONS = ("model", "data", "train", "eval", "log", "init")
REQUIRED_KEYS = {
    "model": (
        "teacher",
        "student",
        "tokenizer",
        "student_device",
        "teacher_devices",
        "teacher_max_memory_gb",
    ),
    "data": ("min_chars", "max_seq_len", "kl_start_pos", "seed", "shuffle_buffer"),
    "train": (
        "seed",
        "lr",
        "schedule",
        "warmup_steps",
        "weight_decay",
        "grad_clip",
        "betas",
        "eps",
        "samples_per_step",
        "max_steps",
        "grad_checkpointing",
        "attn_implementation",
        "student_dtype",
        "teacher_dtype",
        "kl_chunk_size",
        "micro_batch_size",
        "new_layer_lr_mul",
    ),
    "eval": ("every_steps", "samples", "seed", "cache_path"),
    "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir", "experiment_log"),
    "init": ("zero_layers", "target_num_layers"),
}
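
# For reference, a minimal config sketch covering the required sections/keys.
# All model IDs, dataset names, paths, and numbers below are illustrative
# placeholders, not defaults shipped with this script. Keys read elsewhere but
# not validated here ([data].dataset/datasets, text_field(s), dataset_weights,
# [train].trainable_patterns / freeze_patterns) stay optional.
#
#   [model]
#   teacher = "org/teacher-model"          # placeholder HF repo id
#   student = "org/student-model"          # placeholder HF repo id
#   tokenizer = "org/teacher-model"
#   student_device = "cuda:0"
#   teacher_devices = [1, 2, 3]
#   teacher_max_memory_gb = 70
#
#   [data]
#   dataset = "org/text-corpus"            # or: datasets = [...] plus weights
#   text_field = "text"
#   min_chars = 2000
#   max_seq_len = 4096
#   kl_start_pos = 0
#   seed = 0
#   shuffle_buffer = 10000
#
#   [train]
#   seed = 0
#   lr = 1e-4
#   schedule = "cosine"                    # constant | cosine | linear
#   warmup_steps = 100
#   weight_decay = 0.0
#   grad_clip = 1.0
#   betas = [0.9, 0.95]
#   eps = 1e-8
#   samples_per_step = 32
#   micro_batch_size = 2
#   max_steps = 2000
#   grad_checkpointing = true
#   attn_implementation = "sdpa"
#   student_dtype = "bfloat16"
#   teacher_dtype = "bfloat16"
#   kl_chunk_size = 512
#   new_layer_lr_mul = 1.0
#
#   [eval]
#   every_steps = 100
#   samples = 64
#   seed = 1
#   cache_path = "cache/eval_tokens.pt"
#
#   [log]
#   wandb = false
#   wandb_project = "distill"
#   wandb_run = "run-001"
#   log_every = 10
#   output_dir = "runs/run-001"
#   experiment_log = "runs/experiments.jsonl"
#
#   [init]
#   zero_layers = []
#   target_num_layers = 32                 # set to the student's current depth to skip growth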

DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def parse_dtype(s: str) -> torch.dtype:
    if s not in DTYPE_MAP:
        raise ValueError(f"unknown dtype {s!r}; must be one of {list(DTYPE_MAP)}")
    return DTYPE_MAP[s]


def load_config(path: str) -> dict:
    with open(path, "rb") as f:
        cfg = tomllib.load(f)
    for sec in REQUIRED_SECTIONS:
        if sec not in cfg:
            raise KeyError(f"config missing required section [{sec}]")
        for key in REQUIRED_KEYS[sec]:
            if key not in cfg[sec]:
                raise KeyError(f"config missing required key [{sec}].{key}")
    return cfg


def get_inner_with_layers(model):
    seen = set()
    stack = [model]
    while stack:
        m = stack.pop()
        if id(m) in seen:
            continue
        seen.add(id(m))
        if hasattr(m, "layers"):
            return m
        for attr in ("model", "language_model", "transformer", "base_model"):
            child = getattr(m, attr, None)
            if child is not None:
                stack.append(child)
    raise RuntimeError(f"Could not locate `.layers` inside {type(model).__name__}")


def zero_layers(model, layer_indices):
    inner = get_inner_with_layers(model)
    layers = inner.layers
    n = len(layers)
    for idx in layer_indices:
        if idx < 0 or idx >= n:
            raise IndexError(f"layer {idx} out of range (0..{n - 1})")
        with torch.no_grad():
            for p in layers[idx].parameters():
                p.zero_()
    return n


def _zero_output_projections(layer):
    zeroed = []
    with torch.no_grad():
        if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "o_proj"):
            layer.self_attn.o_proj.weight.zero_()
            zeroed.append("self_attn.o_proj")
        if hasattr(layer, "linear_attn") and hasattr(layer.linear_attn, "out_proj"):
            layer.linear_attn.out_proj.weight.zero_()
            zeroed.append("linear_attn.out_proj")
        if hasattr(layer, "mlp") and hasattr(layer.mlp, "down_proj"):
            layer.mlp.down_proj.weight.zero_()
            zeroed.append("mlp.down_proj")
    return zeroed


def grow_layers(model, target_n):
    inner = get_inner_with_layers(model)
    cur_n = len(inner.layers)
    if target_n == cur_n:
        return cur_n, []
    if target_n < cur_n:
        raise ValueError(f"target_num_layers={target_n} < current {cur_n}; cannot shrink")

    cfg = model.config
    text_cfg = getattr(cfg, "text_config", cfg)
    if not hasattr(text_cfg, "layer_types") or not text_cfg.layer_types:
        raise RuntimeError("text config has no layer_types; cannot extend pattern")

    period = getattr(text_cfg, "full_attention_interval", 4)
    new_types = list(text_cfg.layer_types)
    while len(new_types) < target_n:
        new_types.append(new_types[len(new_types) % period])
    text_cfg.layer_types = new_types
    text_cfg.num_hidden_layers = target_n
    if hasattr(cfg, "num_hidden_layers") and cfg is not text_cfg:
        cfg.num_hidden_layers = target_n

    layer_cls = type(inner.layers[0])
    device = next(inner.parameters()).device
    dtype = next(inner.parameters()).dtype

    new_layer_zeroed = []
    for i in range(cur_n, target_n):
        new_layer = layer_cls(text_cfg, layer_idx=i)
        new_layer.apply(model._init_weights)
        new_layer.to(device=device, dtype=dtype)
        zeroed = _zero_output_projections(new_layer)
        new_layer_zeroed.append((i, zeroed))
        inner.layers.append(new_layer)

    return target_n, new_layer_zeroed
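
# Note on grow_layers: each appended block has its attention / linear-attention /
# MLP output projections zeroed, so its residual contribution starts at exactly
# zero and the grown student initially computes the same function as before
# growth. The layer_types pattern is extended by repeating its first `period`
# entries cyclically, which assumes the architecture cycles its layer types
# with that period (full_attention_interval, defaulting to 4 here).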


def detect_model_kind(model_id: str) -> str:
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(model_id)
    archs = list(getattr(cfg, "architectures", []) or [])
    arch = archs[0] if archs else ""
    if "ConditionalGeneration" in arch or "ImageText" in arch:
        return "image_text"
    return "causal_lm"


def load_student(model_id: str, dtype: torch.dtype, grad_ckpt: bool, attn_impl: str):
    kind = detect_model_kind(model_id)
    if kind == "image_text":
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            dtype=dtype,
            low_cpu_mem_usage=True,
            attn_implementation=attn_impl,
        )
    else:
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=dtype,
            low_cpu_mem_usage=True,
            attn_implementation=attn_impl,
        )
    model.config.use_cache = False
    if grad_ckpt:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
    return model


def load_teacher(
    model_id: str,
    dtype: torch.dtype,
    attn_impl: str,
    devices: list[int],
    max_mem_gb: int,
):
    kind = detect_model_kind(model_id)
    max_memory = {idx: f"{max_mem_gb}GiB" for idx in devices}
    max_memory["cpu"] = "256GiB"
    common = dict(
        dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl,
        device_map="auto",
        max_memory=max_memory,
    )

    if kind == "image_text":
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(model_id, **common)
    else:
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(model_id, **common)
    model.config.use_cache = False
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model
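
# The teacher is sharded across the GPUs listed in [model].teacher_devices via
# device_map="auto", with a per-GPU cap of teacher_max_memory_gb and CPU
# spill-over allowed; it is frozen and kept in eval mode, so only the student's
# single device holds gradients and optimizer state.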


def get_teacher_devices(model) -> tuple[torch.device, torch.device]:
    device_map = getattr(model, "hf_device_map", None) or {}
    ordered = OrderedDict()
    for _, dev in device_map.items():
        if isinstance(dev, int):
            ordered.setdefault(f"cuda:{dev}", None)
        elif isinstance(dev, str) and dev.startswith("cuda:"):
            ordered.setdefault(dev, None)
    if not ordered:
        first = next(model.parameters()).device
        return first, first
    keys = list(ordered.keys())
    return torch.device(keys[0]), torch.device(keys[-1])


def teacher_forward(teacher, input_ids, attention_mask, out_device):
    out = teacher(input_ids=input_ids, attention_mask=attention_mask)
    logits = getattr(out, "logits", None)
    if logits is None:
        raise RuntimeError("teacher forward did not return .logits")
    if logits.device != out_device:
        logits = logits.to(out_device, non_blocking=True)
    return logits


class StreamingTextLoader:
    def __init__(
        self,
        name,
        text_field,
        min_chars,
        max_seq_len,
        kl_start_pos,
        tokenizer,
        seed,
        shuffle_buffer,
    ):
        from datasets import load_dataset

        last_err = None
        for attempt in range(8):
            try:
                ds = load_dataset(name, split="train", streaming=True)
                break
            except Exception as e:
                last_err = e
                wait = min(2 ** attempt, 30)
                log.warning(
                    f"load_dataset({name!r}) failed (attempt {attempt + 1}/8): "
                    f"{type(e).__name__}: {e}; sleeping {wait}s"
                )
                time.sleep(wait)
        else:
            raise RuntimeError(f"load_dataset({name!r}) failed after 8 retries") from last_err
        ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer)
        self._ds = iter(ds)
        self._text_field = text_field
        self._min_chars = min_chars
        self._max_seq_len = max_seq_len
        self._min_tokens = kl_start_pos + 16
        self._tokenizer = tokenizer
        self._name = name

    def next_sample(self):
        scanned = 0
        while scanned < 100:
            try:
                item = next(self._ds)
            except StopIteration:
                return None
            scanned += 1
            text = item.get(self._text_field, "") or ""
            if len(text) < self._min_chars:
                continue
            ids = self._tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self._max_seq_len,
            ).input_ids.squeeze(0)
            if ids.shape[0] < self._min_tokens:
                continue
            return ids
        return None
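
# StreamingTextLoader retries load_dataset with exponential backoff (capped at
# 30 s) to ride out transient hub/network errors, then filters the stream by
# character count and by tokenized length (at least kl_start_pos + 16 tokens,
# so some positions survive the KL start offset). next_sample() scans at most
# 100 items per call; returning None means "nothing usable right now" and the
# mixed loader below simply draws again.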


class MixedStreamingLoader:
    def __init__(
        self,
        specs,
        tokenizer,
        min_chars,
        max_seq_len,
        kl_start_pos,
        seed,
        shuffle_buffer,
    ):
        self._rng = random.Random(seed)
        self._weights = []
        self._loaders = []
        for spec in specs:
            self._weights.append(spec["weight"])
            self._loaders.append(
                StreamingTextLoader(
                    name=spec["name"],
                    text_field=spec["text_field"],
                    min_chars=min_chars,
                    max_seq_len=max_seq_len,
                    kl_start_pos=kl_start_pos,
                    tokenizer=tokenizer,
                    seed=seed + len(self._loaders),
                    shuffle_buffer=shuffle_buffer,
                )
            )

    def next_batch(self, n):
        out = []
        while len(out) < n:
            idx = self._rng.choices(range(len(self._loaders)), weights=self._weights, k=1)[0]
            sample = self._loaders[idx].next_sample()
            if sample is None:
                continue
            out.append(sample)
        return out


def collate_pad(token_lists, pad_id):
    max_len = max(t.shape[0] for t in token_lists)
    B = len(token_lists)
    input_ids = torch.full((B, max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((B, max_len), dtype=torch.long)
    for i, t in enumerate(token_lists):
        L = t.shape[0]
        input_ids[i, :L] = t
        attention_mask[i, :L] = 1
    return input_ids, attention_mask


def _kl_chunk_sum(s_chunk, t_chunk, m_chunk):
    s = s_chunk.float()
    t = t_chunk.float()
    t_log_p = F.log_softmax(t, dim=-1)
    s_log_p = F.log_softmax(s, dim=-1)
    t_p = t_log_p.exp()
    per_token = (t_p * (t_log_p - s_log_p)).sum(-1)
    return (per_token * m_chunk).sum()


def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos, chunk_size):
    s_full = student_logits[:, start_pos:, :]
    t_full = teacher_logits[:, start_pos:, :].detach()
    m_full = attention_mask[:, start_pos:].float()

    T = s_full.shape[1]
    if chunk_size <= 0 or chunk_size >= T:
        return _kl_chunk_sum(s_full, t_full, m_full) / m_full.sum().clamp_min(1.0)

    total_kl = torch.zeros((), device=s_full.device, dtype=torch.float32)
    for i in range(0, T, chunk_size):
        end = min(i + chunk_size, T)
        s_c = s_full[:, i:end, :]
        t_c = t_full[:, i:end, :]
        m_c = m_full[:, i:end]
        chunk_kl = checkpoint_utils.checkpoint(
            _kl_chunk_sum, s_c, t_c, m_c, use_reentrant=False
        )
        total_kl = total_kl + chunk_kl
    return total_kl / m_full.sum().clamp_min(1.0)
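
# The loss is the forward KL, averaged over unmasked positions from start_pos on:
#
#     KL(p_T || p_S) = sum_v p_T(v) * (log p_T(v) - log p_S(v))
#
# computed in float32 for numerical stability. Chunking the sequence dimension
# and wrapping each chunk in torch.utils.checkpoint keeps only one chunk's
# float32 softmax activations live at a time; they are recomputed in the
# backward pass, trading FLOPs for peak memory on the student GPU.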


def apply_trainable_masks(model, train_cfg):
    trainable = train_cfg.get("trainable_patterns", [])
    frozen = train_cfg.get("freeze_patterns", [])
    if not trainable and not frozen:
        return

    trainable_re = [re.compile(p) for p in trainable]
    frozen_re = [re.compile(p) for p in frozen]
    for name, p in model.named_parameters():
        keep = True
        if trainable_re:
            keep = any(r.search(name) for r in trainable_re)
        if keep and frozen_re and any(r.search(name) for r in frozen_re):
            keep = False
        p.requires_grad_(keep)


def make_optimizer(model, train_cfg, new_layer_indices=None):
    base_lr = train_cfg["lr"]
    mul = train_cfg["new_layer_lr_mul"]
    common = dict(
        weight_decay=train_cfg["weight_decay"],
        betas=tuple(train_cfg["betas"]),
        eps=train_cfg["eps"],
    )

    if not new_layer_indices or mul == 1.0:
        return AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=base_lr,
            **common,
        )

    inner = get_inner_with_layers(model)
    new_pids = set()
    for idx in new_layer_indices:
        for p in inner.layers[idx].parameters():
            if p.requires_grad:
                new_pids.add(id(p))

    new_params = []
    rest_params = []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (new_params if id(p) in new_pids else rest_params).append(p)

    return AdamW(
        [
            {"params": rest_params, "lr": base_lr},
            {"params": new_params, "lr": base_lr * mul},
        ],
        **common,
    )
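
# When layers were grown, their parameters form a second AdamW param group with
# lr = lr * new_layer_lr_mul, so freshly initialised blocks can learn faster
# (or slower) than the inherited weights; with new_layer_lr_mul = 1.0 a single
# group is used and the split is skipped entirely.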


def make_scheduler(optimizer, train_cfg):
    schedule = train_cfg["schedule"]
    warmup = train_cfg["warmup_steps"]
    total = train_cfg["max_steps"]

    if schedule == "constant":
        from transformers import get_constant_schedule_with_warmup

        return get_constant_schedule_with_warmup(optimizer, warmup)
    if schedule == "cosine":
        from transformers import get_cosine_schedule_with_warmup

        return get_cosine_schedule_with_warmup(optimizer, warmup, total)
    if schedule == "linear":
        from transformers import get_linear_schedule_with_warmup

        return get_linear_schedule_with_warmup(optimizer, warmup, total)
    raise ValueError(f"unknown schedule: {schedule!r}")


def build_dataset_specs(data_cfg):
    if "datasets" in data_cfg:
        names = data_cfg["datasets"]
        text_fields = data_cfg.get("text_fields", [data_cfg.get("text_field", "text")] * len(names))
        weights = data_cfg.get("dataset_weights", [1.0] * len(names))
        if not (len(names) == len(text_fields) == len(weights)):
            raise ValueError("datasets/text_fields/dataset_weights length mismatch")
        return [
            {"name": name, "text_field": field, "weight": weight}
            for name, field, weight in zip(names, text_fields, weights)
        ]
    return [
        {
            "name": data_cfg["dataset"],
            "text_field": data_cfg["text_field"],
            "weight": 1.0,
        }
    ]


def build_or_load_eval_cache(path, loader=None, samples=None):
    path = Path(path)
    if path.exists():
        log.info(f"Loading eval cache from {path}")
        raw = torch.load(path)
        return [torch.tensor(x, dtype=torch.long) for x in raw]
    if loader is None or samples is None:
        raise ValueError("loader and samples are required when building a new eval cache")
    path.parent.mkdir(parents=True, exist_ok=True)
    log.info(f"Building eval cache at {path}")
    batches = loader.next_batch(samples)
    torch.save([x.tolist() for x in batches], path)
    return batches
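
# The eval cache stores plain token-id lists, so the same held-out samples are
# reused across runs as long as cache_path points at the same file; delete the
# file to resample from the stream.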


def log_jsonl(path: Path, record: dict):
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a") as f:
        f.write(json.dumps(record, sort_keys=True) + "\n")


@torch.no_grad()
def evaluate(
    student,
    teacher,
    eval_batches,
    pad_id,
    kl_start_pos,
    kl_chunk_size,
    student_device,
    teacher_input_device,
):
    student.eval()
    total = 0.0
    n = 0
    for sample in eval_batches:
        ids, mask = collate_pad([sample], pad_id)
        teacher_ids = ids.to(teacher_input_device, non_blocking=True)
        teacher_mask = mask.to(teacher_input_device, non_blocking=True)
        student_ids = ids.to(student_device, non_blocking=True)
        student_mask = mask.to(student_device, non_blocking=True)
        t_logits = teacher_forward(teacher, teacher_ids, teacher_mask, student_device)
        s_logits = student(input_ids=student_ids, attention_mask=student_mask).logits
        loss = kl_loss_masked(
            s_logits,
            t_logits,
            student_mask,
            start_pos=kl_start_pos,
            chunk_size=kl_chunk_size,
        )
        total += loss.item()
        n += 1
        del t_logits, s_logits, loss, teacher_ids, teacher_mask, student_ids, student_mask
    student.train()
    return total / max(n, 1)


def save_best(student, tokenizer, output_dir, step, eval_kl):
    out_dir = Path(output_dir) / "best"
    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    student.save_pretrained(out_dir, safe_serialization=True)
    tokenizer.save_pretrained(out_dir)
    with (out_dir / "best.json").open("w") as f:
        json.dump({"step": step, "eval_kl": eval_kl}, f, indent=2)
    log.info(f"saved best @ step {step}: eval_kl={eval_kl:.6f} -> {out_dir}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    cfg = load_config(args.config)
    torch.manual_seed(cfg["train"]["seed"])
    random.seed(cfg["train"]["seed"])

    student_device = torch.device(cfg["model"]["student_device"])
    teacher_devices = list(cfg["model"]["teacher_devices"])

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"], trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    student = load_student(
        cfg["model"]["student"],
        parse_dtype(cfg["train"]["student_dtype"]),
        grad_ckpt=cfg["train"]["grad_checkpointing"],
        attn_impl=cfg["train"]["attn_implementation"],
    )
    student.to(student_device)
    student.config.use_cache = False

    target_n = cfg["init"]["target_num_layers"]
    cur_n = len(get_inner_with_layers(student).layers)
    new_layer_indices = []
    if target_n != cur_n:
        new_n, new_zeroed = grow_layers(student, target_n)
        new_layer_indices = [idx for idx, _ in new_zeroed]
        log.info(f"Grew student from {cur_n} -> {new_n} layers")
        for idx, names in new_zeroed:
            log.info(f"  layer {idx}: zeroed {names}")

    zero_idx = cfg["init"]["zero_layers"]
    if zero_idx:
        n = zero_layers(student, zero_idx)
        log.info(f"Zeroed student layers {zero_idx} (model has {n} layers)")

    apply_trainable_masks(student, cfg["train"])
    trainable_params = sum(p.numel() for p in student.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in student.parameters())
    if trainable_params == 0:
        raise RuntimeError("No trainable parameters remain after applying trainable/freeze patterns")
    log.info(f"Student params: total={total_params / 1e9:.3f}B trainable={trainable_params / 1e9:.3f}B")

    teacher = load_teacher(
        cfg["model"]["teacher"],
        parse_dtype(cfg["train"]["teacher_dtype"]),
        attn_impl=cfg["train"]["attn_implementation"],
        devices=teacher_devices,
        max_mem_gb=cfg["model"]["teacher_max_memory_gb"],
    )
    teacher_input_device, _ = get_teacher_devices(teacher)
    log.info(f"Teacher input device: {teacher_input_device}")

    optimizer = make_optimizer(student, cfg["train"], new_layer_indices=new_layer_indices)
    scheduler = make_scheduler(optimizer, cfg["train"])

    output_dir = Path(cfg["log"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(args.config, output_dir / "config.snapshot.toml")
    metrics_path = output_dir / "metrics.jsonl"
    experiment_log = Path(cfg["log"]["experiment_log"])

    use_wandb = cfg["log"]["wandb"]
    if use_wandb:
        import wandb

        wandb.init(
            project=cfg["log"]["wandb_project"],
            name=cfg["log"]["wandb_run"],
            config=cfg,
        )

    specs = build_dataset_specs(cfg["data"])
    train_loader = MixedStreamingLoader(
        specs=specs,
        tokenizer=tokenizer,
        min_chars=cfg["data"]["min_chars"],
        max_seq_len=cfg["data"]["max_seq_len"],
        kl_start_pos=cfg["data"]["kl_start_pos"],
        seed=cfg["data"]["seed"],
        shuffle_buffer=cfg["data"]["shuffle_buffer"],
    )
    eval_cache_path = Path(cfg["eval"]["cache_path"])
    if eval_cache_path.exists():
        eval_batches = build_or_load_eval_cache(eval_cache_path)
    else:
        eval_loader = MixedStreamingLoader(
            specs=specs,
            tokenizer=tokenizer,
            min_chars=cfg["data"]["min_chars"],
            max_seq_len=cfg["data"]["max_seq_len"],
            kl_start_pos=cfg["data"]["kl_start_pos"],
            seed=cfg["eval"]["seed"],
            shuffle_buffer=cfg["data"]["shuffle_buffer"],
        )
        eval_batches = build_or_load_eval_cache(eval_cache_path, eval_loader, cfg["eval"]["samples"])
    log.info(f"Eval samples: {len(eval_batches)}")

    samples_per_step = cfg["train"]["samples_per_step"]
    micro_batch_size = cfg["train"]["micro_batch_size"]
    grad_clip = cfg["train"]["grad_clip"]
    kl_start_pos = cfg["data"]["kl_start_pos"]
    kl_chunk_size = cfg["train"]["kl_chunk_size"]
    max_steps = cfg["train"]["max_steps"]
    eval_every = cfg["eval"]["every_steps"]
    log_every = cfg["log"]["log_every"]

    student.train()
    best_kl = float("inf")
    global_step = 0
    run_summary = {
        "config": args.config,
        "run_name": cfg["log"]["wandb_run"],
        "student": cfg["model"]["student"],
        "teacher": cfg["model"]["teacher"],
        "start_time": int(time.time()),
    }

    while global_step < max_steps:
        t0 = time.time()
        batch = train_loader.next_batch(samples_per_step)
        optimizer.zero_grad(set_to_none=True)
        batch_n = len(batch)
        kl_sum = 0.0

        for mb_start in range(0, batch_n, micro_batch_size):
            micro = batch[mb_start : mb_start + micro_batch_size]
            mb_n = len(micro)
            ids, mask = collate_pad(micro, pad_id)
            teacher_ids = ids.to(teacher_input_device, non_blocking=True)
            teacher_mask = mask.to(teacher_input_device, non_blocking=True)
            student_ids = ids.to(student_device, non_blocking=True)
            student_mask = mask.to(student_device, non_blocking=True)

            with torch.no_grad():
                t_logits = teacher_forward(teacher, teacher_ids, teacher_mask, student_device)
            s_logits = student(input_ids=student_ids, attention_mask=student_mask).logits
            loss = kl_loss_masked(
                s_logits,
                t_logits,
                student_mask,
                start_pos=kl_start_pos,
                chunk_size=kl_chunk_size,
            )
            scaled = loss * (mb_n / batch_n)
            scaled.backward()
            kl_sum += loss.item() * mb_n
            del teacher_ids, teacher_mask, student_ids, student_mask, t_logits, s_logits, loss, scaled

        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(student.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()
        global_step += 1

        elapsed = time.time() - t0
        kl_avg = kl_sum / batch_n
        lr_now = scheduler.get_last_lr()[0]
        record = {
            "step": global_step,
            "train_kl": kl_avg,
            "lr": lr_now,
            "step_time_s": elapsed,
        }
        log_jsonl(metrics_path, record)

        if global_step % log_every == 0:
            log.info(
                f"step {global_step}/{max_steps} | kl {kl_avg:.6f} | "
                f"lr {lr_now:.2e} | {elapsed:.2f}s"
            )
            if use_wandb:
                import wandb

                wandb.log(
                    {
                        "train/kl": kl_avg,
                        "train/lr": lr_now,
                        "perf/step_time_s": elapsed,
                    },
                    step=global_step,
                )

        if global_step % eval_every == 0:
            eval_kl = evaluate(
                student,
                teacher,
                eval_batches,
                pad_id,
                kl_start_pos,
                kl_chunk_size,
                student_device,
                teacher_input_device,
            )
            log.info(f"eval @ step {global_step}: kl={eval_kl:.6f} (best={best_kl:.6f})")
            log_jsonl(metrics_path, {"step": global_step, "eval_kl": eval_kl})
            if use_wandb:
                import wandb

                wandb.log({"eval/kl": eval_kl}, step=global_step)
            if eval_kl < best_kl:
                best_kl = eval_kl
                save_best(student, tokenizer, output_dir, global_step, eval_kl)
            student.train()

        if global_step % 10 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    final_eval = evaluate(
        student,
        teacher,
        eval_batches,
        pad_id,
        kl_start_pos,
        kl_chunk_size,
        student_device,
        teacher_input_device,
    )
    log.info(f"final eval: kl={final_eval:.6f} (best={best_kl:.6f})")
    if final_eval < best_kl:
        best_kl = final_eval
        save_best(student, tokenizer, output_dir, global_step, final_eval)

    run_summary.update(
        {
            "end_time": int(time.time()),
            "best_eval_kl": best_kl,
            "final_eval_kl": final_eval,
            "max_steps": max_steps,
            "student_total_params": total_params,
            "student_trainable_params": trainable_params,
        }
    )
    log_jsonl(experiment_log, run_summary)

    if use_wandb:
        import wandb

        wandb.log({"eval/final_kl": final_eval, "eval/best_kl": best_kl}, step=global_step)
        wandb.finish()


if __name__ == "__main__":
    main()
|