#!/usr/bin/env python3 """ Single-process KL distillation with a sharded frozen teacher and one trainable student GPU. This is a derivative of distill.py tailored for large-teacher / smaller-student setups where replicating the teacher per process is wasteful or infeasible. """ from __future__ import annotations import argparse import gc import json import logging import random import re import shutil import time import tomllib from collections import OrderedDict from pathlib import Path import torch import torch.nn.functional as F import torch.utils.checkpoint as checkpoint_utils from torch.optim import AdamW logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S", ) log = logging.getLogger("distill_sharded") REQUIRED_SECTIONS = ("model", "data", "train", "eval", "log", "init") REQUIRED_KEYS = { "model": ("teacher", "student", "tokenizer", "student_device", "teacher_devices", "teacher_max_memory_gb"), "data": ("min_chars", "max_seq_len", "kl_start_pos", "seed", "shuffle_buffer"), "train": ( "seed", "lr", "schedule", "warmup_steps", "weight_decay", "grad_clip", "betas", "eps", "samples_per_step", "max_steps", "grad_checkpointing", "attn_implementation", "student_dtype", "teacher_dtype", "kl_chunk_size", "micro_batch_size", "new_layer_lr_mul", ), "eval": ("every_steps", "samples", "seed", "cache_path"), "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir", "experiment_log"), "init": ("zero_layers", "target_num_layers"), } DTYPE_MAP = { "float32": torch.float32, "bfloat16": torch.bfloat16, } def parse_dtype(s: str) -> torch.dtype: if s not in DTYPE_MAP: raise ValueError(f"unknown dtype {s!r}; must be one of {list(DTYPE_MAP)}") return DTYPE_MAP[s] def load_config(path: str) -> dict: with open(path, "rb") as f: cfg = tomllib.load(f) for sec in REQUIRED_SECTIONS: if sec not in cfg: raise KeyError(f"config missing required section [{sec}]") for key in REQUIRED_KEYS[sec]: if key not in cfg[sec]: raise KeyError(f"config missing required key [{sec}].{key}") return cfg def get_inner_with_layers(model): seen = set() stack = [model] while stack: m = stack.pop() if id(m) in seen: continue seen.add(id(m)) if hasattr(m, "layers"): return m for attr in ("model", "language_model", "transformer", "base_model"): child = getattr(m, attr, None) if child is not None: stack.append(child) raise RuntimeError(f"Could not locate `.layers` inside {type(model).__name__}") def zero_layers(model, layer_indices): inner = get_inner_with_layers(model) layers = inner.layers n = len(layers) for idx in layer_indices: if idx < 0 or idx >= n: raise IndexError(f"layer {idx} out of range (0..{n - 1})") with torch.no_grad(): for p in layers[idx].parameters(): p.zero_() return n def _zero_output_projections(layer): zeroed = [] with torch.no_grad(): if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "o_proj"): layer.self_attn.o_proj.weight.zero_() zeroed.append("self_attn.o_proj") if hasattr(layer, "linear_attn") and hasattr(layer.linear_attn, "out_proj"): layer.linear_attn.out_proj.weight.zero_() zeroed.append("linear_attn.out_proj") if hasattr(layer, "mlp") and hasattr(layer.mlp, "down_proj"): layer.mlp.down_proj.weight.zero_() zeroed.append("mlp.down_proj") return zeroed def grow_layers(model, target_n): inner = get_inner_with_layers(model) cur_n = len(inner.layers) if target_n == cur_n: return cur_n, [] if target_n < cur_n: raise ValueError(f"target_num_layers={target_n} < current {cur_n}; cannot shrink") cfg = model.config text_cfg = getattr(cfg, "text_config", cfg) if not hasattr(text_cfg, "layer_types") or not text_cfg.layer_types: raise RuntimeError("text config has no layer_types; cannot extend pattern") period = getattr(text_cfg, "full_attention_interval", 4) new_types = list(text_cfg.layer_types) while len(new_types) < target_n: new_types.append(new_types[len(new_types) % period]) text_cfg.layer_types = new_types text_cfg.num_hidden_layers = target_n if hasattr(cfg, "num_hidden_layers") and cfg is not text_cfg: cfg.num_hidden_layers = target_n layer_cls = type(inner.layers[0]) device = next(inner.parameters()).device dtype = next(inner.parameters()).dtype new_layer_zeroed = [] for i in range(cur_n, target_n): new_layer = layer_cls(text_cfg, layer_idx=i) new_layer.apply(model._init_weights) new_layer.to(device=device, dtype=dtype) zeroed = _zero_output_projections(new_layer) new_layer_zeroed.append((i, zeroed)) inner.layers.append(new_layer) return target_n, new_layer_zeroed def detect_model_kind(model_id: str) -> str: from transformers import AutoConfig cfg = AutoConfig.from_pretrained(model_id) archs = list(getattr(cfg, "architectures", []) or []) arch = archs[0] if archs else "" if "ConditionalGeneration" in arch or "ImageText" in arch: return "image_text" return "causal_lm" def load_student(model_id: str, dtype: torch.dtype, grad_ckpt: bool, attn_impl: str): kind = detect_model_kind(model_id) if kind == "image_text": from transformers import AutoModelForImageTextToText model = AutoModelForImageTextToText.from_pretrained( model_id, dtype=dtype, low_cpu_mem_usage=True, attn_implementation=attn_impl, ) else: from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained( model_id, dtype=dtype, low_cpu_mem_usage=True, attn_implementation=attn_impl, ) model.config.use_cache = False if grad_ckpt: model.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) return model def load_teacher(model_id: str, dtype: torch.dtype, attn_impl: str, devices: list[int], max_mem_gb: int): kind = detect_model_kind(model_id) max_memory = {idx: f"{max_mem_gb}GiB" for idx in devices} max_memory["cpu"] = "256GiB" common = dict( dtype=dtype, low_cpu_mem_usage=True, attn_implementation=attn_impl, device_map="auto", max_memory=max_memory, ) if kind == "image_text": from transformers import AutoModelForImageTextToText model = AutoModelForImageTextToText.from_pretrained(model_id, **common) else: from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(model_id, **common) model.config.use_cache = False model.eval() for p in model.parameters(): p.requires_grad_(False) return model def get_teacher_devices(model) -> tuple[torch.device, torch.device]: device_map = getattr(model, "hf_device_map", None) or {} ordered = OrderedDict() for _, dev in device_map.items(): if isinstance(dev, int): ordered.setdefault(f"cuda:{dev}", None) elif isinstance(dev, str) and dev.startswith("cuda:"): ordered.setdefault(dev, None) if not ordered: first = next(model.parameters()).device return first, first keys = list(ordered.keys()) return torch.device(keys[0]), torch.device(keys[-1]) def teacher_forward(teacher, input_ids, attention_mask, out_device): out = teacher(input_ids=input_ids, attention_mask=attention_mask) logits = getattr(out, "logits", None) if logits is None: raise RuntimeError("teacher forward did not return .logits") if logits.device != out_device: logits = logits.to(out_device, non_blocking=True) return logits class StreamingTextLoader: def __init__( self, name, text_field, min_chars, max_seq_len, kl_start_pos, tokenizer, seed, shuffle_buffer, ): from datasets import load_dataset last_err = None for attempt in range(8): try: ds = load_dataset(name, split="train", streaming=True) break except Exception as e: last_err = e wait = min(2 ** attempt, 30) log.warning( f"load_dataset({name!r}) failed (attempt {attempt + 1}/8): " f"{type(e).__name__}: {e}; sleeping {wait}s" ) time.sleep(wait) else: raise RuntimeError(f"load_dataset failed after 8 retries") from last_err ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer) self._ds = iter(ds) self._text_field = text_field self._min_chars = min_chars self._max_seq_len = max_seq_len self._min_tokens = kl_start_pos + 16 self._tokenizer = tokenizer self._name = name def next_sample(self): scanned = 0 while scanned < 100: try: item = next(self._ds) except StopIteration: return None scanned += 1 text = item.get(self._text_field, "") or "" if len(text) < self._min_chars: continue ids = self._tokenizer( text, return_tensors="pt", truncation=True, max_length=self._max_seq_len, ).input_ids.squeeze(0) if ids.shape[0] < self._min_tokens: continue return ids return None class MixedStreamingLoader: def __init__(self, specs, tokenizer, min_chars, max_seq_len, kl_start_pos, seed, shuffle_buffer): self._rng = random.Random(seed) self._weights = [] self._loaders = [] for spec in specs: self._weights.append(spec["weight"]) self._loaders.append( StreamingTextLoader( name=spec["name"], text_field=spec["text_field"], min_chars=min_chars, max_seq_len=max_seq_len, kl_start_pos=kl_start_pos, tokenizer=tokenizer, seed=seed + len(self._loaders), shuffle_buffer=shuffle_buffer, ) ) def next_batch(self, n): out = [] while len(out) < n: idx = self._rng.choices(range(len(self._loaders)), weights=self._weights, k=1)[0] sample = self._loaders[idx].next_sample() if sample is None: continue out.append(sample) return out def collate_pad(token_lists, pad_id): max_len = max(t.shape[0] for t in token_lists) B = len(token_lists) input_ids = torch.full((B, max_len), pad_id, dtype=torch.long) attention_mask = torch.zeros((B, max_len), dtype=torch.long) for i, t in enumerate(token_lists): L = t.shape[0] input_ids[i, :L] = t attention_mask[i, :L] = 1 return input_ids, attention_mask def _kl_chunk_sum(s_chunk, t_chunk, m_chunk): s = s_chunk.float() t = t_chunk.float() t_log_p = F.log_softmax(t, dim=-1) s_log_p = F.log_softmax(s, dim=-1) t_p = t_log_p.exp() per_token = (t_p * (t_log_p - s_log_p)).sum(-1) return (per_token * m_chunk).sum() def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos, chunk_size): s_full = student_logits[:, start_pos:, :] t_full = teacher_logits[:, start_pos:, :].detach() m_full = attention_mask[:, start_pos:].float() T = s_full.shape[1] if chunk_size <= 0 or chunk_size >= T: return _kl_chunk_sum(s_full, t_full, m_full) / m_full.sum().clamp_min(1.0) total_kl = torch.zeros((), device=s_full.device, dtype=torch.float32) for i in range(0, T, chunk_size): end = min(i + chunk_size, T) s_c = s_full[:, i:end, :] t_c = t_full[:, i:end, :] m_c = m_full[:, i:end] chunk_kl = checkpoint_utils.checkpoint( _kl_chunk_sum, s_c, t_c, m_c, use_reentrant=False ) total_kl = total_kl + chunk_kl return total_kl / m_full.sum().clamp_min(1.0) def apply_trainable_masks(model, train_cfg): trainable = train_cfg.get("trainable_patterns", []) frozen = train_cfg.get("freeze_patterns", []) if not trainable and not frozen: return trainable_re = [re.compile(p) for p in trainable] frozen_re = [re.compile(p) for p in frozen] for name, p in model.named_parameters(): keep = True if trainable_re: keep = any(r.search(name) for r in trainable_re) if keep and frozen_re and any(r.search(name) for r in frozen_re): keep = False p.requires_grad_(keep) def make_optimizer(model, train_cfg, new_layer_indices=None): base_lr = train_cfg["lr"] mul = train_cfg["new_layer_lr_mul"] common = dict( weight_decay=train_cfg["weight_decay"], betas=tuple(train_cfg["betas"]), eps=train_cfg["eps"], ) if not new_layer_indices or mul == 1.0: return AdamW( [p for p in model.parameters() if p.requires_grad], lr=base_lr, **common, ) inner = get_inner_with_layers(model) new_pids = set() for idx in new_layer_indices: for p in inner.layers[idx].parameters(): if p.requires_grad: new_pids.add(id(p)) new_params = [] rest_params = [] for p in model.parameters(): if not p.requires_grad: continue (new_params if id(p) in new_pids else rest_params).append(p) return AdamW( [ {"params": rest_params, "lr": base_lr}, {"params": new_params, "lr": base_lr * mul}, ], **common, ) def make_scheduler(optimizer, train_cfg): schedule = train_cfg["schedule"] warmup = train_cfg["warmup_steps"] total = train_cfg["max_steps"] if schedule == "constant": from transformers import get_constant_schedule_with_warmup return get_constant_schedule_with_warmup(optimizer, warmup) if schedule == "cosine": from transformers import get_cosine_schedule_with_warmup return get_cosine_schedule_with_warmup(optimizer, warmup, total) if schedule == "linear": from transformers import get_linear_schedule_with_warmup return get_linear_schedule_with_warmup(optimizer, warmup, total) raise ValueError(f"unknown schedule: {schedule!r}") def build_dataset_specs(data_cfg): if "datasets" in data_cfg: names = data_cfg["datasets"] text_fields = data_cfg.get("text_fields", [data_cfg.get("text_field", "text")] * len(names)) weights = data_cfg.get("dataset_weights", [1.0] * len(names)) if not (len(names) == len(text_fields) == len(weights)): raise ValueError("datasets/text_fields/dataset_weights length mismatch") return [ {"name": name, "text_field": field, "weight": weight} for name, field, weight in zip(names, text_fields, weights) ] return [ { "name": data_cfg["dataset"], "text_field": data_cfg["text_field"], "weight": 1.0, } ] def build_or_load_eval_cache(path, loader=None, samples=None): path = Path(path) if path.exists(): log.info(f"Loading eval cache from {path}") raw = torch.load(path) return [torch.tensor(x, dtype=torch.long) for x in raw] if loader is None or samples is None: raise ValueError("loader and samples are required when building a new eval cache") path.parent.mkdir(parents=True, exist_ok=True) log.info(f"Building eval cache at {path}") batches = loader.next_batch(samples) torch.save([x.tolist() for x in batches], path) return batches def log_jsonl(path: Path, record: dict): path.parent.mkdir(parents=True, exist_ok=True) with path.open("a") as f: f.write(json.dumps(record, sort_keys=True) + "\n") @torch.no_grad() def evaluate(student, teacher, eval_batches, pad_id, kl_start_pos, kl_chunk_size, student_device, teacher_input_device): student.eval() total = 0.0 n = 0 for sample in eval_batches: ids, mask = collate_pad([sample], pad_id) teacher_ids = ids.to(teacher_input_device, non_blocking=True) teacher_mask = mask.to(teacher_input_device, non_blocking=True) student_ids = ids.to(student_device, non_blocking=True) student_mask = mask.to(student_device, non_blocking=True) t_logits = teacher_forward(teacher, teacher_ids, teacher_mask, student_device) s_logits = student(input_ids=student_ids, attention_mask=student_mask).logits loss = kl_loss_masked( s_logits, t_logits, student_mask, start_pos=kl_start_pos, chunk_size=kl_chunk_size, ) total += loss.item() n += 1 del t_logits, s_logits, loss, teacher_ids, teacher_mask, student_ids, student_mask student.train() return total / max(n, 1) def save_best(student, tokenizer, output_dir, step, eval_kl): out_dir = Path(output_dir) / "best" if out_dir.exists(): shutil.rmtree(out_dir) out_dir.mkdir(parents=True, exist_ok=True) student.save_pretrained(out_dir, safe_serialization=True) tokenizer.save_pretrained(out_dir) with (out_dir / "best.json").open("w") as f: json.dump({"step": step, "eval_kl": eval_kl}, f, indent=2) log.info(f"saved best @ step {step}: eval_kl={eval_kl:.6f} -> {out_dir}") def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", required=True) args = parser.parse_args() cfg = load_config(args.config) torch.manual_seed(cfg["train"]["seed"]) random.seed(cfg["train"]["seed"]) student_device = torch.device(cfg["model"]["student_device"]) teacher_devices = list(cfg["model"]["teacher_devices"]) from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"], trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token pad_id = tokenizer.pad_token_id student = load_student( cfg["model"]["student"], parse_dtype(cfg["train"]["student_dtype"]), grad_ckpt=cfg["train"]["grad_checkpointing"], attn_impl=cfg["train"]["attn_implementation"], ) student.to(student_device) student.config.use_cache = False target_n = cfg["init"]["target_num_layers"] cur_n = len(get_inner_with_layers(student).layers) new_layer_indices = [] if target_n != cur_n: new_n, new_zeroed = grow_layers(student, target_n) new_layer_indices = [idx for idx, _ in new_zeroed] log.info(f"Grew student from {cur_n} -> {new_n} layers") for idx, names in new_zeroed: log.info(f" layer {idx}: zeroed {names}") zero_idx = cfg["init"]["zero_layers"] if zero_idx: n = zero_layers(student, zero_idx) log.info(f"Zeroed student layers {zero_idx} (model has {n} layers)") apply_trainable_masks(student, cfg["train"]) trainable_params = sum(p.numel() for p in student.parameters() if p.requires_grad) total_params = sum(p.numel() for p in student.parameters()) if trainable_params == 0: raise RuntimeError("No trainable parameters remain after applying trainable/freeze patterns") log.info(f"Student params: total={total_params/1e9:.3f}B trainable={trainable_params/1e9:.3f}B") teacher = load_teacher( cfg["model"]["teacher"], parse_dtype(cfg["train"]["teacher_dtype"]), attn_impl=cfg["train"]["attn_implementation"], devices=teacher_devices, max_mem_gb=cfg["model"]["teacher_max_memory_gb"], ) teacher_input_device, _ = get_teacher_devices(teacher) log.info(f"Teacher input device: {teacher_input_device}") optimizer = make_optimizer(student, cfg["train"], new_layer_indices=new_layer_indices) scheduler = make_scheduler(optimizer, cfg["train"]) output_dir = Path(cfg["log"]["output_dir"]) output_dir.mkdir(parents=True, exist_ok=True) shutil.copy2(args.config, output_dir / "config.snapshot.toml") metrics_path = output_dir / "metrics.jsonl" experiment_log = Path(cfg["log"]["experiment_log"]) use_wandb = cfg["log"]["wandb"] if use_wandb: import wandb wandb.init( project=cfg["log"]["wandb_project"], name=cfg["log"]["wandb_run"], config=cfg, ) specs = build_dataset_specs(cfg["data"]) train_loader = MixedStreamingLoader( specs=specs, tokenizer=tokenizer, min_chars=cfg["data"]["min_chars"], max_seq_len=cfg["data"]["max_seq_len"], kl_start_pos=cfg["data"]["kl_start_pos"], seed=cfg["data"]["seed"], shuffle_buffer=cfg["data"]["shuffle_buffer"], ) eval_cache_path = Path(cfg["eval"]["cache_path"]) if eval_cache_path.exists(): eval_batches = build_or_load_eval_cache(eval_cache_path) else: eval_loader = MixedStreamingLoader( specs=specs, tokenizer=tokenizer, min_chars=cfg["data"]["min_chars"], max_seq_len=cfg["data"]["max_seq_len"], kl_start_pos=cfg["data"]["kl_start_pos"], seed=cfg["eval"]["seed"], shuffle_buffer=cfg["data"]["shuffle_buffer"], ) eval_batches = build_or_load_eval_cache(eval_cache_path, eval_loader, cfg["eval"]["samples"]) log.info(f"Eval samples: {len(eval_batches)}") samples_per_step = cfg["train"]["samples_per_step"] micro_batch_size = cfg["train"]["micro_batch_size"] grad_clip = cfg["train"]["grad_clip"] kl_start_pos = cfg["data"]["kl_start_pos"] kl_chunk_size = cfg["train"]["kl_chunk_size"] max_steps = cfg["train"]["max_steps"] eval_every = cfg["eval"]["every_steps"] log_every = cfg["log"]["log_every"] student.train() best_kl = float("inf") global_step = 0 run_summary = { "config": args.config, "run_name": cfg["log"]["wandb_run"], "student": cfg["model"]["student"], "teacher": cfg["model"]["teacher"], "start_time": int(time.time()), } while global_step < max_steps: t0 = time.time() batch = train_loader.next_batch(samples_per_step) optimizer.zero_grad(set_to_none=True) batch_n = len(batch) kl_sum = 0.0 for mb_start in range(0, batch_n, micro_batch_size): micro = batch[mb_start : mb_start + micro_batch_size] mb_n = len(micro) ids, mask = collate_pad(micro, pad_id) teacher_ids = ids.to(teacher_input_device, non_blocking=True) teacher_mask = mask.to(teacher_input_device, non_blocking=True) student_ids = ids.to(student_device, non_blocking=True) student_mask = mask.to(student_device, non_blocking=True) with torch.no_grad(): t_logits = teacher_forward(teacher, teacher_ids, teacher_mask, student_device) s_logits = student(input_ids=student_ids, attention_mask=student_mask).logits loss = kl_loss_masked( s_logits, t_logits, student_mask, start_pos=kl_start_pos, chunk_size=kl_chunk_size, ) scaled = loss * (mb_n / batch_n) scaled.backward() kl_sum += loss.item() * mb_n del teacher_ids, teacher_mask, student_ids, student_mask, t_logits, s_logits, loss, scaled if grad_clip > 0: torch.nn.utils.clip_grad_norm_(student.parameters(), grad_clip) optimizer.step() scheduler.step() global_step += 1 elapsed = time.time() - t0 kl_avg = kl_sum / batch_n lr_now = scheduler.get_last_lr()[0] record = { "step": global_step, "train_kl": kl_avg, "lr": lr_now, "step_time_s": elapsed, } log_jsonl(metrics_path, record) if global_step % log_every == 0: log.info( f"step {global_step}/{max_steps} | kl {kl_avg:.6f} | " f"lr {lr_now:.2e} | {elapsed:.2f}s" ) if use_wandb: import wandb wandb.log( { "train/kl": kl_avg, "train/lr": lr_now, "perf/step_time_s": elapsed, }, step=global_step, ) if global_step % eval_every == 0: eval_kl = evaluate( student, teacher, eval_batches, pad_id, kl_start_pos, kl_chunk_size, student_device, teacher_input_device, ) log.info(f"eval @ step {global_step}: kl={eval_kl:.6f} (best={best_kl:.6f})") log_jsonl(metrics_path, {"step": global_step, "eval_kl": eval_kl}) if use_wandb: import wandb wandb.log({"eval/kl": eval_kl}, step=global_step) if eval_kl < best_kl: best_kl = eval_kl save_best(student, tokenizer, output_dir, global_step, eval_kl) student.train() if global_step % 10 == 0: gc.collect() torch.cuda.empty_cache() final_eval = evaluate( student, teacher, eval_batches, pad_id, kl_start_pos, kl_chunk_size, student_device, teacher_input_device, ) log.info(f"final eval: kl={final_eval:.6f} (best={best_kl:.6f})") if final_eval < best_kl: best_kl = final_eval save_best(student, tokenizer, output_dir, global_step, final_eval) run_summary.update( { "end_time": int(time.time()), "best_eval_kl": best_kl, "final_eval_kl": final_eval, "max_steps": max_steps, "student_total_params": total_params, "student_trainable_params": trainable_params, } ) log_jsonl(experiment_log, run_summary) if use_wandb: import wandb wandb.log({"eval/final_kl": final_eval, "eval/best_kl": best_kl}, step=global_step) wandb.finish() if __name__ == "__main__": main()