Delta-Vector committed on
Commit
f6e42f8
·
verified ·
1 Parent(s): 46f472b

initial scaffold: distill.py + base/zero_14_17 configs + accelerate yaml

Browse files
configs/accelerate.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ distributed_type: MULTI_GPU
3
+ mixed_precision: bf16
4
+ num_processes: 8
5
+ num_machines: 1
6
+ machine_rank: 0
7
+ gpu_ids: all
8
+ rdzv_backend: static
9
+ same_network: true
10
+ tpu_use_cluster: false
11
+ tpu_use_sudo: false
12
+ use_cpu: false
13
+ debug: false
14
+ enable_cpu_affinity: false
15
+ main_training_function: main
16
+ downcast_bf16: 'no'
configs/base.toml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base distillation config (smoketest variant).
2
+ # Every value the script reads must live in this file - no defaults in code.
3
+
4
+ [model]
5
+ teacher = "Qwen/Qwen3.5-35B-A3B"
6
+ student = "Troiaaa/m-6a3lnzvb"
7
+ tokenizer = "Qwen/Qwen3.5-35B-A3B"
8
+
9
+ [data]
10
+ dataset = "karpathy/climbmix-400b-shuffle"
11
+ text_field = "text"
12
+ min_chars = 2560
13
+ max_seq_len = 640
14
+ kl_start_pos = 128
15
+ seed = 42
16
+ shuffle_buffer = 10000
17
+
18
+ [train]
19
+ seed = 42
20
+ lr = 5.0e-7
21
+ schedule = "constant"
22
+ warmup_steps = 0
23
+ weight_decay = 0.0
24
+ grad_clip = 1.0
25
+ betas = [0.9, 0.95]
26
+ eps = 1.0e-8
27
+ samples_per_step = 4
28
+ max_steps = 5
29
+ grad_checkpointing = true
30
+ attn_implementation = "flash_attention_2"
31
+
32
+ [eval]
33
+ every_steps = 5
34
+ samples = 16
35
+ seed = 1234
36
+
37
+ [log]
38
+ wandb = true
39
+ wandb_project = "distil-subnet97"
40
+ wandb_run = "smoketest"
41
+ log_every = 1
42
+ output_dir = "./out/smoketest"
43
+
44
+ [init]
45
+ zero_layers = []
configs/zero_14_17.toml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Layer-zero distillation: zero student layers 14-17 at init,
2
+ # constant LR 5e-7, 2000 steps. Aim: lower KL than the prior checkpoint
3
+ # despite the surgery.
4
+
5
+ [model]
6
+ teacher = "Qwen/Qwen3.5-35B-A3B"
7
+ student = "Troiaaa/m-6a3lnzvb"
8
+ tokenizer = "Qwen/Qwen3.5-35B-A3B"
9
+
10
+ [data]
11
+ dataset = "karpathy/climbmix-400b-shuffle"
12
+ text_field = "text"
13
+ min_chars = 2560
14
+ max_seq_len = 640
15
+ kl_start_pos = 128
16
+ seed = 42
17
+ shuffle_buffer = 10000
18
+
19
+ [train]
20
+ seed = 42
21
+ lr = 5.0e-7
22
+ schedule = "constant"
23
+ warmup_steps = 0
24
+ weight_decay = 0.0
25
+ grad_clip = 1.0
26
+ betas = [0.9, 0.95]
27
+ eps = 1.0e-8
28
+ samples_per_step = 8
29
+ max_steps = 2000
30
+ grad_checkpointing = true
31
+ attn_implementation = "flash_attention_2"
32
+
33
+ [eval]
34
+ every_steps = 50
35
+ samples = 64
36
+ seed = 1234
37
+
38
+ [log]
39
+ wandb = true
40
+ wandb_project = "distil-subnet97"
41
+ wandb_run = "m-6a3lnzvb-zero14_17"
42
+ log_every = 1
43
+ output_dir = "./out/zero_14_17"
44
+
45
+ [init]
46
+ zero_layers = [14, 15, 16, 17]
distill.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ KL Distillation Training - TOML-driven, accelerate multi-GPU.
4
+
5
+ Run with:
6
+ accelerate launch --config_file configs/accelerate.yaml distill.py --config configs/base.toml
7
+
8
+ The TOML config is the single source of truth - no hardcoded defaults in this file.
9
+ The only command line argument is --config <path-to-toml>.
10
+ """
11
+
12
+ import argparse
13
+ import gc
14
+ import json
15
+ import logging
16
+ import shutil
17
+ import time
18
+ import tomllib
19
+ from pathlib import Path
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch.optim import AdamW
24
+
25
+ from accelerate import Accelerator
26
+ from accelerate.utils import set_seed
27
+
28
+ logging.basicConfig(
29
+ level=logging.INFO,
30
+ format="%(asctime)s [%(levelname)s] %(message)s",
31
+ datefmt="%H:%M:%S",
32
+ )
33
+ log = logging.getLogger("distill")
34
+
35
+
36
+ # ----------------------------------------------------------------------------
37
+ # Config
38
+ # ----------------------------------------------------------------------------
39
+
40
+ REQUIRED_SECTIONS = ("model", "data", "train", "eval", "log", "init")
41
+ REQUIRED_KEYS = {
42
+ "model": ("teacher", "student", "tokenizer"),
43
+ "data": (
44
+ "dataset",
45
+ "text_field",
46
+ "min_chars",
47
+ "max_seq_len",
48
+ "kl_start_pos",
49
+ "seed",
50
+ "shuffle_buffer",
51
+ ),
52
+ "train": (
53
+ "seed",
54
+ "lr",
55
+ "schedule",
56
+ "warmup_steps",
57
+ "weight_decay",
58
+ "grad_clip",
59
+ "betas",
60
+ "eps",
61
+ "samples_per_step",
62
+ "max_steps",
63
+ "grad_checkpointing",
64
+ "attn_implementation",
65
+ ),
66
+ "eval": ("every_steps", "samples", "seed"),
67
+ "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
68
+ "init": ("zero_layers",),
69
+ }
70
+
71
+
72
+ def load_config(path):
73
+ with open(path, "rb") as f:
74
+ cfg = tomllib.load(f)
75
+ for sec in REQUIRED_SECTIONS:
76
+ if sec not in cfg:
77
+ raise KeyError(f"config missing required section [{sec}]")
78
+ for key in REQUIRED_KEYS[sec]:
79
+ if key not in cfg[sec]:
80
+ raise KeyError(f"config missing required key [{sec}].{key}")
81
+ return cfg
82
+
83
+
84
+ # ----------------------------------------------------------------------------
85
+ # Model loading
86
+ # ----------------------------------------------------------------------------
87
+
88
def get_inner_with_layers(model):
    """Locate the first nested submodule that exposes a `.layers` attribute.

    HF checkpoints wrap the decoder stack under attributes such as `.model`,
    `.language_model`, `.transformer`, or `.base_model`; this walks those
    wrappers (cycle-safe via id tracking) until something with `.layers`
    turns up. Raises RuntimeError if nothing qualifies.
    """
    visited = set()
    pending = [model]
    while pending:
        node = pending.pop()
        if id(node) in visited:
            continue  # guard against wrapper cycles (e.g. base_model == self)
        visited.add(id(node))
        if hasattr(node, "layers"):
            return node
        pending.extend(
            child
            for attr in ("model", "language_model", "transformer", "base_model")
            if (child := getattr(node, attr, None)) is not None
        )
    raise RuntimeError(f"Could not locate `.layers` inside {type(model).__name__}")
105
+
106
+
107
def zero_layers(model, layer_indices):
    """Zero out every parameter of the selected decoder layers, in place.

    Validation happens per-index as the loop runs, so an out-of-range index
    raises IndexError at the point it is reached. Returns the total layer
    count so the caller can log it.
    """
    layers = get_inner_with_layers(model).layers
    n_layers = len(layers)
    for i in layer_indices:
        if not 0 <= i < n_layers:
            raise IndexError(f"layer {i} out of range (0..{n_layers - 1})")
        with torch.no_grad():  # surgery, not training: no grad tracking
            for param in layers[i].parameters():
                param.zero_()
    return n_layers
118
+
119
+
120
def load_student(model_id, dtype, grad_ckpt, attn_impl):
    """Load the trainable student causal LM.

    KV-caching is disabled (no generation during training) and gradient
    checkpointing, when requested, uses the non-reentrant implementation.
    """
    from transformers import AutoModelForCausalLM

    log.info(f"Loading student: {model_id}")
    student = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl,
    )
    student.config.use_cache = False
    if grad_ckpt:
        ckpt_kwargs = {"use_reentrant": False}
        student.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=ckpt_kwargs
        )
    return student
135
+
136
+
137
def load_teacher(model_id, dtype, attn_impl):
    """Load the frozen teacher in eval mode.

    The architecture name from the checkpoint config decides the loader:
    names containing "ConditionalGeneration" or "ImageText" go through
    AutoModelForImageTextToText (multimodal wrappers); everything else
    through AutoModelForCausalLM. All parameters are frozen.
    """
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(model_id)
    arch_names = list(getattr(config, "architectures", []) or [])
    arch = arch_names[0] if arch_names else ""
    is_multimodal = "ConditionalGeneration" in arch or "ImageText" in arch
    log.info(f"Loading teacher: {model_id} (arch={arch}, multimodal={is_multimodal})")

    if is_multimodal:
        from transformers import AutoModelForImageTextToText as _AutoCls
    else:
        from transformers import AutoModelForCausalLM as _AutoCls
    teacher = _AutoCls.from_pretrained(
        model_id,
        dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl,
    )

    teacher.config.use_cache = False
    teacher.eval()
    for param in teacher.parameters():
        param.requires_grad_(False)  # teacher is a fixed reference distribution
    return teacher
168
+
169
+
170
def teacher_forward(teacher, input_ids, attention_mask):
    """Run the teacher and return its logits.

    Works for both unimodal and multimodal wrappers, as long as the forward
    output exposes `.logits`; raises RuntimeError otherwise.
    """
    output = teacher(input_ids=input_ids, attention_mask=attention_mask)
    logits = getattr(output, "logits", None)
    if logits is None:
        raise RuntimeError("teacher forward did not return .logits")
    return logits
177
+
178
+
179
+ # ----------------------------------------------------------------------------
180
+ # Data
181
+ # ----------------------------------------------------------------------------
182
+
183
class StreamingTextLoader:
    """Rank-local view of a HF streaming dataset, yielding tokenized samples.

    All ranks shuffle with the same seed, then each takes its disjoint shard
    via `split_dataset_by_node`. Rows are filtered by character count before
    tokenization (cheap pre-filter) and by token count after.
    """

    def __init__(
        self,
        name,
        text_field,
        min_chars,
        max_seq_len,
        kl_start_pos,
        tokenizer,
        rank,
        world_size,
        seed,
        shuffle_buffer,
    ):
        from datasets import load_dataset
        from datasets.distributed import split_dataset_by_node

        stream = load_dataset(name, split="train", streaming=True)
        stream = stream.shuffle(seed=seed, buffer_size=shuffle_buffer)
        stream = split_dataset_by_node(stream, rank=rank, world_size=world_size)
        self._ds = iter(stream)
        self._text_field = text_field
        self._min_chars = min_chars
        self._max_seq_len = max_seq_len
        # Require headroom past kl_start_pos so the KL mask is never empty.
        self._min_tokens = kl_start_pos + 16
        self._tokenizer = tokenizer

    def next_batch(self, n):
        """Return up to `n` 1-D token tensors.

        Scans at most 50*n raw rows so a pathological stretch of short
        documents cannot stall a training step forever; may return fewer
        than `n` (or an empty list once the stream is exhausted).
        """
        batch = []
        seen = 0
        scan_limit = n * 50
        while len(batch) < n and seen < scan_limit:
            try:
                row = next(self._ds)
            except StopIteration:
                break
            seen += 1
            text = row.get(self._text_field, "") or ""
            if len(text) < self._min_chars:
                continue
            encoded = self._tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self._max_seq_len,
            )
            token_ids = encoded.input_ids.squeeze(0)
            if token_ids.shape[0] >= self._min_tokens:
                batch.append(token_ids)
        return batch
234
+
235
+
236
def collate_pad(token_lists, pad_id):
    """Right-pad variable-length 1-D token tensors into a [B, max_L] batch.

    Returns (input_ids, attention_mask), both torch.long; the mask is 1 on
    real tokens and 0 on padding.
    """
    lengths = [t.shape[0] for t in token_lists]
    width = max(lengths)
    batch = len(token_lists)
    input_ids = torch.full((batch, width), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((batch, width), dtype=torch.long)
    for row, (tokens, length) in enumerate(zip(token_lists, lengths)):
        input_ids[row, :length] = tokens
        attention_mask[row, :length] = 1
    return input_ids, attention_mask
247
+
248
+
249
+ # ----------------------------------------------------------------------------
250
+ # Loss
251
+ # ----------------------------------------------------------------------------
252
+
253
def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos):
    """Token-averaged forward KL(teacher || student).

    Only positions >= start_pos contribute, and padded positions are masked
    out. Math is done in fp32 for numerical stability; the teacher side is
    detached so no gradient flows into it. The clamp on the denominator
    makes an all-masked batch return 0 instead of NaN.
    """
    mask = attention_mask[:, start_pos:].float()
    s_log_probs = F.log_softmax(student_logits[:, start_pos:, :].float(), dim=-1)
    t_log_probs = F.log_softmax(
        teacher_logits[:, start_pos:, :].detach().float(), dim=-1
    )
    t_probs = t_log_probs.exp()

    token_kl = (t_probs * (t_log_probs - s_log_probs)).sum(dim=-1)  # [B, T-start]
    denom = mask.sum().clamp_min(1.0)
    return (token_kl * mask).sum() / denom
268
+
269
+
270
+ # ----------------------------------------------------------------------------
271
+ # Optimizer / scheduler
272
+ # ----------------------------------------------------------------------------
273
+
274
def make_optimizer(model, train_cfg):
    """Build AdamW over the trainable parameters only.

    Every hyperparameter comes from the [train] config section - nothing is
    defaulted in code.
    """
    trainable = (p for p in model.parameters() if p.requires_grad)
    return AdamW(
        trainable,
        lr=train_cfg["lr"],
        betas=tuple(train_cfg["betas"]),
        eps=train_cfg["eps"],
        weight_decay=train_cfg["weight_decay"],
    )
282
+
283
+
284
def make_scheduler(optimizer, train_cfg):
    """Build the LR scheduler named by train_cfg["schedule"].

    Supported names: "constant", "cosine", "linear" - all with warmup.
    Raises ValueError for anything else. Imports are local so transformers
    is only touched for the branch actually taken.
    """
    name = train_cfg["schedule"]
    warmup = train_cfg["warmup_steps"]
    total = train_cfg["max_steps"]

    if name == "constant":
        from transformers import get_constant_schedule_with_warmup
        return get_constant_schedule_with_warmup(optimizer, warmup)
    elif name == "cosine":
        from transformers import get_cosine_schedule_with_warmup
        return get_cosine_schedule_with_warmup(optimizer, warmup, total)
    elif name == "linear":
        from transformers import get_linear_schedule_with_warmup
        return get_linear_schedule_with_warmup(optimizer, warmup, total)
    raise ValueError(f"unknown schedule: {name!r}")
299
+
300
+
301
+ # ----------------------------------------------------------------------------
302
+ # Eval
303
+ # ----------------------------------------------------------------------------
304
+
305
@torch.no_grad()
def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos):
    """Compute the mean eval KL across all ranks.

    Each rank runs its own eval samples one at a time (batch size 1), then
    the per-rank (sum, count) pairs are all-gathered and reduced to a
    sample-weighted global mean.

    Fix vs the original: the previous version gathered per-rank *means* and
    used inf for an empty rank, so one rank with no eval data poisoned the
    global mean with inf forever - best-checkpoint saving could then never
    trigger. Gathering sums and counts lets empty ranks contribute zero
    weight; inf is returned only when no rank has any eval data at all.
    """
    student.eval()
    device = accelerator.device
    total = 0.0
    count = 0
    for sample in eval_batches:
        ids, mask = collate_pad([sample], pad_id)
        ids = ids.to(device)
        mask = mask.to(device)
        t_logits = teacher_forward(teacher, ids, mask)
        s_logits = student(input_ids=ids, attention_mask=mask).logits
        loss = kl_loss_masked(s_logits, t_logits, mask, start_pos=kl_start_pos)
        total += loss.item()
        count += 1
        del t_logits, s_logits, loss  # free logits promptly; vocab dim is large
    student.train()

    # Gather raw sums and counts (one [1, 2] row per rank).
    local = torch.tensor([total, float(count)], device=device).reshape(1, 2)
    gathered = accelerator.gather(local)
    global_total = gathered[:, 0].sum().item()
    global_count = gathered[:, 1].sum().item()
    if global_count == 0:
        return float("inf")
    return global_total / global_count
328
+
329
+
330
def save_best(accelerator, student, tokenizer, output_dir, step, eval_kl):
    """Persist the current student to `<output_dir>/best` (main process only).

    The previous best directory is deleted first so the folder always holds
    exactly one checkpoint; best.json records the step and eval KL it came
    from. The surrounding barriers keep the other ranks parked while rank 0
    writes to disk.
    """
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        out_dir = Path(output_dir) / "best"
        if out_dir.exists():
            # Overwrite semantics: drop the stale best checkpoint entirely.
            shutil.rmtree(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        # Unwrap the DDP/accelerate wrapper so weights save in plain HF format.
        unwrapped = accelerator.unwrap_model(student)
        unwrapped.save_pretrained(out_dir, safe_serialization=True)
        tokenizer.save_pretrained(out_dir)
        with open(out_dir / "best.json", "w") as f:
            json.dump({"step": step, "eval_kl": eval_kl}, f, indent=2)
        log.info(f" saved best @ step {step}: eval_kl={eval_kl:.6f} -> {out_dir}")
    accelerator.wait_for_everyone()
344
+
345
+
346
+ # ----------------------------------------------------------------------------
347
+ # Main
348
+ # ----------------------------------------------------------------------------
349
+
350
def main():
    """Entry point: KL-distill the student toward the frozen teacher.

    Driven entirely by the TOML config named on the command line; the only
    CLI flag is --config. Designed to run under `accelerate launch`.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True, help="Path to TOML config")
    args = p.parse_args()

    cfg = load_config(args.config)

    accelerator = Accelerator(mixed_precision="bf16")
    set_seed(cfg["train"]["seed"])

    if accelerator.is_main_process:
        log.info(f"Loaded config from {args.config}")
        log.info(f"World size: {accelerator.num_processes}")

    # ---- Tokenizer
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"])
    if tokenizer.pad_token is None:
        # Fall back to EOS as pad so collate_pad always has a valid pad id.
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    # ---- Models
    dtype = torch.bfloat16
    student = load_student(
        cfg["model"]["student"],
        dtype,
        grad_ckpt=cfg["train"]["grad_checkpointing"],
        attn_impl=cfg["train"]["attn_implementation"],
    )
    teacher = load_teacher(
        cfg["model"]["teacher"],
        dtype,
        attn_impl=cfg["train"]["attn_implementation"],
    )

    # ---- Layer modifications (post-load, pre-prepare)
    # Zeroing must happen before accelerator.prepare() wraps the model so
    # every rank starts from identical surgically-modified weights.
    zero_idx = cfg["init"]["zero_layers"]
    if zero_idx:
        n = zero_layers(student, zero_idx)
        if accelerator.is_main_process:
            log.info(f"Zeroed student layers {zero_idx} (model has {n} layers)")

    # Teacher is frozen, so it is moved to the device directly rather than
    # going through accelerator.prepare (no DDP wrapper, no optimizer state).
    teacher = teacher.to(accelerator.device)

    # ---- Optimizer / scheduler
    optimizer = make_optimizer(student, cfg["train"])
    scheduler = make_scheduler(optimizer, cfg["train"])

    student, optimizer, scheduler = accelerator.prepare(
        student, optimizer, scheduler
    )

    # ---- Output dir + config snapshot
    output_dir = Path(cfg["log"]["output_dir"])
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)
        # Snapshot the exact config used, for reproducibility.
        shutil.copy2(args.config, output_dir / "config.snapshot.toml")

    # ---- Wandb (main process only)
    use_wandb = cfg["log"]["wandb"]
    if use_wandb and accelerator.is_main_process:
        import wandb
        wandb.init(
            project=cfg["log"]["wandb_project"],
            name=cfg["log"]["wandb_run"],
            config=cfg,
        )

    # ---- Data loaders
    # Train and eval streams share the dataset but use different shuffle
    # seeds ([data].seed vs [eval].seed) so the eval set is a distinct slice.
    train_loader = StreamingTextLoader(
        name=cfg["data"]["dataset"],
        text_field=cfg["data"]["text_field"],
        min_chars=cfg["data"]["min_chars"],
        max_seq_len=cfg["data"]["max_seq_len"],
        kl_start_pos=cfg["data"]["kl_start_pos"],
        tokenizer=tokenizer,
        rank=accelerator.process_index,
        world_size=accelerator.num_processes,
        seed=cfg["data"]["seed"],
        shuffle_buffer=cfg["data"]["shuffle_buffer"],
    )
    eval_loader = StreamingTextLoader(
        name=cfg["data"]["dataset"],
        text_field=cfg["data"]["text_field"],
        min_chars=cfg["data"]["min_chars"],
        max_seq_len=cfg["data"]["max_seq_len"],
        kl_start_pos=cfg["data"]["kl_start_pos"],
        tokenizer=tokenizer,
        rank=accelerator.process_index,
        world_size=accelerator.num_processes,
        seed=cfg["eval"]["seed"],
        shuffle_buffer=cfg["data"]["shuffle_buffer"],
    )
    # Fixed eval set: drawn once per rank, reused at every eval step.
    eval_per_rank = max(1, cfg["eval"]["samples"] // accelerator.num_processes)
    eval_batches = eval_loader.next_batch(eval_per_rank)
    if accelerator.is_main_process:
        log.info(
            f"Eval set: {len(eval_batches)}/rank x {accelerator.num_processes} ranks "
            f"= {len(eval_batches) * accelerator.num_processes} samples"
        )

    # ---- Train loop
    samples_per_step = cfg["train"]["samples_per_step"]
    grad_clip = cfg["train"]["grad_clip"]
    kl_start_pos = cfg["data"]["kl_start_pos"]
    max_steps = cfg["train"]["max_steps"]
    eval_every = cfg["eval"]["every_steps"]
    log_every = cfg["log"]["log_every"]

    if accelerator.is_main_process:
        log.info(
            f"=== Training: max_steps={max_steps}, samples_per_step={samples_per_step} "
            f"(per rank), effective batch={samples_per_step * accelerator.num_processes}"
        )

    student.train()
    best_kl = float("inf")
    global_step = 0

    while global_step < max_steps:
        t0 = time.time()
        batch = train_loader.next_batch(samples_per_step)
        if not batch:
            # NOTE(review): one rank breaking alone may desync collectives if
            # other ranks still have data - assumes the stream is effectively
            # inexhaustible for the configured max_steps.
            log.warning(f"rank {accelerator.process_index}: data exhausted")
            break

        ids, mask = collate_pad(batch, pad_id)
        ids = ids.to(accelerator.device)
        mask = mask.to(accelerator.device)

        # Teacher forward needs no autograd graph; student forward does.
        with torch.no_grad():
            t_logits = teacher_forward(teacher, ids, mask)
        s_logits = student(input_ids=ids, attention_mask=mask).logits
        loss = kl_loss_masked(s_logits, t_logits, mask, start_pos=kl_start_pos)

        optimizer.zero_grad()
        accelerator.backward(loss)
        if grad_clip > 0:
            accelerator.clip_grad_norm_(student.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()
        global_step += 1

        elapsed = time.time() - t0
        # Average the step KL across ranks for logging.
        kl_local = loss.detach()
        kl_avg = accelerator.gather(kl_local.unsqueeze(0)).mean().item()
        del t_logits, s_logits, loss, kl_local

        if accelerator.is_main_process and global_step % log_every == 0:
            lr_now = scheduler.get_last_lr()[0]
            log.info(
                f"step {global_step}/{max_steps} | kl {kl_avg:.4f} | "
                f"lr {lr_now:.2e} | {elapsed:.2f}s"
            )
            if use_wandb:
                import wandb
                wandb.log(
                    {
                        "train/kl": kl_avg,
                        "train/lr": lr_now,
                        "perf/step_time_s": elapsed,
                    },
                    step=global_step,
                )

        if global_step % eval_every == 0:
            eval_kl = evaluate(
                accelerator, student, teacher, eval_batches, pad_id, kl_start_pos
            )
            if accelerator.is_main_process:
                log.info(
                    f" eval @ step {global_step}: kl={eval_kl:.6f} "
                    f"(best={best_kl:.6f})"
                )
                if use_wandb:
                    import wandb
                    wandb.log({"eval/kl": eval_kl}, step=global_step)
            # All ranks see the same gathered eval_kl, so every rank takes
            # this branch together and save_best's barriers line up.
            if eval_kl < best_kl:
                best_kl = eval_kl
                save_best(
                    accelerator, student, tokenizer, output_dir, global_step, eval_kl
                )
            student.train()

        if global_step % 20 == 0:
            # Periodic cleanup to keep fragmentation and cached blocks down.
            gc.collect()
            torch.cuda.empty_cache()

    # Final eval (catches runs whose last step isn't a multiple of eval_every)
    eval_kl = evaluate(
        accelerator, student, teacher, eval_batches, pad_id, kl_start_pos
    )
    if accelerator.is_main_process:
        log.info(f" final eval: kl={eval_kl:.6f} (best={best_kl:.6f})")
        if use_wandb:
            import wandb
            wandb.log({"eval/kl": eval_kl}, step=global_step)
    if eval_kl < best_kl:
        best_kl = eval_kl
        save_best(accelerator, student, tokenizer, output_dir, global_step, eval_kl)

    if accelerator.is_main_process:
        log.info(f"Done. Best eval KL = {best_kl:.6f}")
        if use_wandb:
            import wandb
            wandb.finish()
556
+
557
+
558
+ if __name__ == "__main__":
559
+ main()
pyproject.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [project]
2
+ name = "distill"
3
+ version = "0.1.0"
4
+ requires-python = ">=3.12"
5
+ dependencies = []
requirements.lock.txt ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.13.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.13.5
4
+ aiosignal==1.4.0
5
+ annotated-doc==0.0.4
6
+ annotated-types==0.7.0
7
+ anyio==4.13.0
8
+ attrs==26.1.0
9
+ certifi==2026.2.25
10
+ charset-normalizer==3.4.7
11
+ click==8.3.2
12
+ cuda-bindings==12.9.4
13
+ cuda-pathfinder==1.2.2
14
+ cuda-toolkit==12.8.1
15
+ datasets==4.8.4
16
+ dill==0.4.1
17
+ einops==0.8.2
18
+ filelock==3.25.2
19
+ fla-core==0.4.2
20
+ flash-attn @ file:///tmp/flash_attn-2.8.3+cu128torch2.11-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
21
+ flash-linear-attention==0.4.2
22
+ frozenlist==1.8.0
23
+ fsspec==2026.2.0
24
+ gitdb==4.0.12
25
+ gitpython==3.1.46
26
+ h11==0.16.0
27
+ hf-xet==1.4.3
28
+ httpcore==1.0.9
29
+ httpx==0.28.1
30
+ huggingface-hub==1.9.0
31
+ idna==3.11
32
+ jinja2==3.1.6
33
+ markdown-it-py==4.0.0
34
+ markupsafe==3.0.3
35
+ mdurl==0.1.2
36
+ mpmath==1.3.0
37
+ multidict==6.7.1
38
+ multiprocess==0.70.19
39
+ networkx==3.6.1
40
+ numpy==2.4.4
41
+ nvidia-cublas-cu12==12.8.4.1
42
+ nvidia-cuda-cupti-cu12==12.8.90
43
+ nvidia-cuda-nvrtc-cu12==12.8.93
44
+ nvidia-cuda-runtime-cu12==12.8.90
45
+ nvidia-cudnn-cu12==9.19.0.56
46
+ nvidia-cufft-cu12==11.3.3.83
47
+ nvidia-cufile-cu12==1.13.1.3
48
+ nvidia-curand-cu12==10.3.9.90
49
+ nvidia-cusolver-cu12==11.7.3.90
50
+ nvidia-cusparse-cu12==12.5.8.93
51
+ nvidia-cusparselt-cu12==0.7.1
52
+ nvidia-nccl-cu12==2.28.9
53
+ nvidia-nvjitlink-cu12==12.8.93
54
+ nvidia-nvshmem-cu12==3.4.5
55
+ nvidia-nvtx-cu12==12.8.90
56
+ packaging==26.0
57
+ pandas==3.0.2
58
+ platformdirs==4.9.4
59
+ propcache==0.4.1
60
+ protobuf==6.33.6
61
+ psutil==7.2.2
62
+ pyarrow==23.0.1
63
+ pydantic==2.12.5
64
+ pydantic-core==2.41.5
65
+ pygments==2.20.0
66
+ python-dateutil==2.9.0.post0
67
+ pyyaml==6.0.3
68
+ regex==2026.4.4
69
+ requests==2.33.1
70
+ rich==14.3.3
71
+ safetensors==0.7.0
72
+ sentencepiece==0.2.1
73
+ sentry-sdk==2.57.0
74
+ setuptools==70.2.0
75
+ shellingham==1.5.4
76
+ six==1.17.0
77
+ smmap==5.0.3
78
+ sympy==1.14.0
79
+ tokenizers==0.22.2
80
+ tomli-w==1.2.0
81
+ torch==2.11.0+cu128
82
+ tqdm==4.67.3
83
+ transformers @ git+https://github.com/huggingface/transformers.git@52cb0653b48fcb0737a74546911df77034b61732
84
+ triton==3.6.0
85
+ typer==0.24.1
86
+ typing-extensions==4.15.0
87
+ typing-inspection==0.4.2
88
+ urllib3==2.6.3
89
+ wandb==0.25.1
90
+ xxhash==3.6.0
91
+ yarl==1.23.0
scripts/backup_to_hf.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Push the distill code/configs to the HF backup repo.
3
+
4
+ Usage:
5
+ .venv/bin/python scripts/backup_to_hf.py "<commit message>"
6
+ """
7
+ import os
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ from huggingface_hub import HfApi, CommitOperationAdd, create_commit
12
+
13
+ REPO_ID = "Delta-Vector/distill-m-6a3lnzvb-code"
14
+ REPO_TYPE = "model"
15
+
16
+ # Files/directories to mirror to the repo
17
+ INCLUDE = [
18
+ "distill.py",
19
+ "configs/base.toml",
20
+ "configs/zero_14_17.toml",
21
+ "configs/accelerate.yaml",
22
+ "scripts/backup_to_hf.py",
23
+ "pyproject.toml",
24
+ "requirements.lock.txt",
25
+ ]
26
+
27
+
28
def main():
    """Mirror every existing INCLUDE file to REPO_ID in a single commit.

    The commit message is argv[1] (default "update"); authentication comes
    from the HF_TOKEN environment variable, which is required.
    """
    msg = sys.argv[1] if len(sys.argv) > 1 else "update"
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("HF_TOKEN env var required", file=sys.stderr)
        sys.exit(1)

    root = Path(__file__).resolve().parent.parent
    ops = []
    for rel in INCLUDE:
        local = root / rel
        if not local.exists():
            print(f" skip (missing): {rel}")
            continue
        op = CommitOperationAdd(path_in_repo=rel, path_or_fileobj=str(local))
        ops.append(op)
        print(f" add: {rel}")

    if not ops:
        print("nothing to upload")
        return

    # One atomic commit for all files rather than one push per file.
    api = HfApi(token=token)
    api.create_commit(
        repo_id=REPO_ID,
        repo_type=REPO_TYPE,
        operations=ops,
        commit_message=msg,
    )
    print(f"pushed {len(ops)} files to {REPO_ID}: {msg}")
59
+
60
+
61
+ if __name__ == "__main__":
62
+ main()