Delta-Vector committed on
Commit
eb5278f
·
verified ·
1 Parent(s): 3f04365

fix OOM: chunked KL with checkpointing + PYTORCH_CUDA_ALLOC_CONF expandable_segments; add kl_chunk_size config key

Browse files
configs/base.toml CHANGED
@@ -31,6 +31,7 @@ attn_implementation = "flash_attention_2"
31
  student_dtype = "bfloat16"
32
  teacher_dtype = "bfloat16"
33
  mixed_precision = "bf16"
 
34
 
35
  [eval]
36
  every_steps = 5
 
31
  student_dtype = "bfloat16"
32
  teacher_dtype = "bfloat16"
33
  mixed_precision = "bf16"
34
+ kl_chunk_size = 0
35
 
36
  [eval]
37
  every_steps = 5
configs/grow40_simple.toml CHANGED
@@ -32,6 +32,7 @@ attn_implementation = "flash_attention_2"
32
  student_dtype = "bfloat16"
33
  teacher_dtype = "bfloat16"
34
  mixed_precision = "bf16"
 
35
 
36
  [eval]
37
  every_steps = 50
 
32
  student_dtype = "bfloat16"
33
  teacher_dtype = "bfloat16"
34
  mixed_precision = "bf16"
35
+ kl_chunk_size = 0
36
 
37
  [eval]
38
  every_steps = 50
configs/grow40_winning.toml CHANGED
@@ -32,6 +32,7 @@ attn_implementation = "flash_attention_2"
32
  student_dtype = "float32"
33
  teacher_dtype = "bfloat16"
34
  mixed_precision = "bf16"
 
35
 
36
  [eval]
37
  every_steps = 50
 
32
  student_dtype = "float32"
33
  teacher_dtype = "bfloat16"
34
  mixed_precision = "bf16"
35
+ kl_chunk_size = 256
36
 
37
  [eval]
38
  every_steps = 50
configs/replicate_zero4.toml CHANGED
@@ -31,6 +31,7 @@ attn_implementation = "flash_attention_2"
31
  student_dtype = "float32"
32
  teacher_dtype = "bfloat16"
33
  mixed_precision = "bf16"
 
34
 
35
  [eval]
36
  every_steps = 50
 
31
  student_dtype = "float32"
32
  teacher_dtype = "bfloat16"
33
  mixed_precision = "bf16"
34
+ kl_chunk_size = 256
35
 
36
  [eval]
37
  every_steps = 50
configs/zero_14_17.toml CHANGED
@@ -32,6 +32,7 @@ attn_implementation = "flash_attention_2"
32
  student_dtype = "bfloat16"
33
  teacher_dtype = "bfloat16"
34
  mixed_precision = "bf16"
 
35
 
36
  [eval]
37
  every_steps = 50
 
32
  student_dtype = "bfloat16"
33
  teacher_dtype = "bfloat16"
34
  mixed_precision = "bf16"
35
+ kl_chunk_size = 0
36
 
37
  [eval]
38
  every_steps = 50
distill.py CHANGED
@@ -9,6 +9,10 @@ The TOML config is the single source of truth - no hardcoded defaults in this fi
9
  The only command line argument is --config <path-to-toml>.
10
  """
11
 
 
 
 
 
12
  import argparse
13
  import gc
14
  import json
@@ -20,6 +24,7 @@ from pathlib import Path
20
 
21
  import torch
22
  import torch.nn.functional as F
 
23
  from torch.optim import AdamW
24
 
25
  from accelerate import Accelerator
@@ -65,6 +70,7 @@ REQUIRED_KEYS = {
65
  "student_dtype",
66
  "teacher_dtype",
67
  "mixed_precision",
 
68
  ),
69
  "eval": ("every_steps", "samples", "seed"),
70
  "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
@@ -337,21 +343,44 @@ def collate_pad(token_lists, pad_id):
337
  # Loss
338
  # ----------------------------------------------------------------------------
339
 
340
- def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos):
341
- """Forward KL(teacher || student), masked for padding & start_pos.
342
-
343
- Computed in fp32 for numerical stability.
344
- """
345
- s = student_logits[:, start_pos:, :].float()
346
- t = teacher_logits[:, start_pos:, :].detach().float()
347
- mask = attention_mask[:, start_pos:].float()
348
-
349
  t_log_p = F.log_softmax(t, dim=-1)
350
  s_log_p = F.log_softmax(s, dim=-1)
351
  t_p = t_log_p.exp()
 
 
 
352
 
353
- per_token = (t_p * (t_log_p - s_log_p)).sum(-1) # [B, T-start]
354
- return (per_token * mask).sum() / mask.sum().clamp_min(1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
 
357
  # ----------------------------------------------------------------------------
@@ -390,7 +419,7 @@ def make_scheduler(optimizer, train_cfg):
390
  # ----------------------------------------------------------------------------
391
 
392
  @torch.no_grad()
393
- def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos):
394
  student.eval()
395
  sdev = accelerator.device
396
  total = 0.0
@@ -401,7 +430,10 @@ def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos):
401
  mask = mask.to(sdev)
402
  t_logits = teacher_forward(teacher, ids, mask)
403
  s_logits = student(input_ids=ids, attention_mask=mask).logits
404
- loss = kl_loss_masked(s_logits, t_logits, mask, start_pos=kl_start_pos)
 
 
 
405
  total += loss.item()
406
  n += 1
407
  del t_logits, s_logits, loss
@@ -550,6 +582,7 @@ def main():
550
  samples_per_step = cfg["train"]["samples_per_step"]
551
  grad_clip = cfg["train"]["grad_clip"]
552
  kl_start_pos = cfg["data"]["kl_start_pos"]
 
553
  max_steps = cfg["train"]["max_steps"]
554
  eval_every = cfg["eval"]["every_steps"]
555
  log_every = cfg["log"]["log_every"]
@@ -578,7 +611,10 @@ def main():
578
  with torch.no_grad():
579
  t_logits = teacher_forward(teacher, ids, mask)
580
  s_logits = student(input_ids=ids, attention_mask=mask).logits
581
- loss = kl_loss_masked(s_logits, t_logits, mask, start_pos=kl_start_pos)
 
 
 
582
 
583
  optimizer.zero_grad()
584
  accelerator.backward(loss)
@@ -612,7 +648,8 @@ def main():
612
 
613
  if global_step % eval_every == 0:
614
  eval_kl = evaluate(
615
- accelerator, student, teacher, eval_batches, pad_id, kl_start_pos
 
616
  )
617
  if accelerator.is_main_process:
618
  log.info(
@@ -635,7 +672,8 @@ def main():
635
 
636
  # Final eval
637
  eval_kl = evaluate(
638
- accelerator, student, teacher, eval_batches, pad_id, kl_start_pos
 
639
  )
640
  if accelerator.is_main_process:
641
  log.info(f" final eval: kl={eval_kl:.6f} (best={best_kl:.6f})")
 
9
  The only command line argument is --config <path-to-toml>.
10
  """
11
 
12
+ import os
13
+ # Reduce fragmentation; large vocab + long seq creates many short-lived big tensors.
14
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
15
+
16
  import argparse
17
  import gc
18
  import json
 
24
 
25
  import torch
26
  import torch.nn.functional as F
27
+ import torch.utils.checkpoint as checkpoint_utils
28
  from torch.optim import AdamW
29
 
30
  from accelerate import Accelerator
 
70
  "student_dtype",
71
  "teacher_dtype",
72
  "mixed_precision",
73
+ "kl_chunk_size",
74
  ),
75
  "eval": ("every_steps", "samples", "seed"),
76
  "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
 
343
  # Loss
344
  # ----------------------------------------------------------------------------
345
 
346
+ def _kl_chunk_sum(s_chunk, t_chunk, m_chunk):
347
+ """Compute (sum of masked KL) over a slice. Used as a checkpointed unit so the
348
+ fp32 softmax intermediates only live for one chunk's worth of memory at a time."""
349
+ s = s_chunk.float()
350
+ t = t_chunk.float()
 
 
 
 
351
  t_log_p = F.log_softmax(t, dim=-1)
352
  s_log_p = F.log_softmax(s, dim=-1)
353
  t_p = t_log_p.exp()
354
+ per_token = (t_p * (t_log_p - s_log_p)).sum(-1)
355
+ return (per_token * m_chunk).sum()
356
+
357
 
358
def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos, chunk_size):
    """Mean forward KL(teacher || student) over unmasked positions from start_pos on.

    Computed in fp32 for numerical stability. When 0 < chunk_size < scored
    length, the positions are processed in chunk_size-sized slices under
    gradient checkpointing, so peak activation memory is bounded by one
    slice's intermediates rather than the whole sequence's.
    """
    student_part = student_logits[:, start_pos:, :]
    teacher_part = teacher_logits[:, start_pos:, :].detach()
    mask_part = attention_mask[:, start_pos:].float()
    # Guard against an all-masked batch producing a 0/0.
    denom = mask_part.sum().clamp_min(1.0)

    seq_len = student_part.shape[1]
    if not (0 < chunk_size < seq_len):
        # Single pass: chunking is disabled or would cover everything anyway.
        return _kl_chunk_sum(student_part, teacher_part, mask_part) / denom

    pieces = []
    for lo in range(0, seq_len, chunk_size):
        hi = min(lo + chunk_size, seq_len)
        # Checkpoint each slice so its fp32 softmax buffers are freed after
        # the forward pass and recomputed during backward.
        pieces.append(
            checkpoint_utils.checkpoint(
                _kl_chunk_sum,
                student_part[:, lo:hi, :],
                teacher_part[:, lo:hi, :],
                mask_part[:, lo:hi],
                use_reentrant=False,
            )
        )
    return torch.stack(pieces).sum() / denom
384
 
385
 
386
  # ----------------------------------------------------------------------------
 
419
  # ----------------------------------------------------------------------------
420
 
421
  @torch.no_grad()
422
+ def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos, kl_chunk_size):
423
  student.eval()
424
  sdev = accelerator.device
425
  total = 0.0
 
430
  mask = mask.to(sdev)
431
  t_logits = teacher_forward(teacher, ids, mask)
432
  s_logits = student(input_ids=ids, attention_mask=mask).logits
433
+ loss = kl_loss_masked(
434
+ s_logits, t_logits, mask,
435
+ start_pos=kl_start_pos, chunk_size=kl_chunk_size,
436
+ )
437
  total += loss.item()
438
  n += 1
439
  del t_logits, s_logits, loss
 
582
  samples_per_step = cfg["train"]["samples_per_step"]
583
  grad_clip = cfg["train"]["grad_clip"]
584
  kl_start_pos = cfg["data"]["kl_start_pos"]
585
+ kl_chunk_size = cfg["train"]["kl_chunk_size"]
586
  max_steps = cfg["train"]["max_steps"]
587
  eval_every = cfg["eval"]["every_steps"]
588
  log_every = cfg["log"]["log_every"]
 
611
  with torch.no_grad():
612
  t_logits = teacher_forward(teacher, ids, mask)
613
  s_logits = student(input_ids=ids, attention_mask=mask).logits
614
+ loss = kl_loss_masked(
615
+ s_logits, t_logits, mask,
616
+ start_pos=kl_start_pos, chunk_size=kl_chunk_size,
617
+ )
618
 
619
  optimizer.zero_grad()
620
  accelerator.backward(loss)
 
648
 
649
  if global_step % eval_every == 0:
650
  eval_kl = evaluate(
651
+ accelerator, student, teacher, eval_batches,
652
+ pad_id, kl_start_pos, kl_chunk_size,
653
  )
654
  if accelerator.is_main_process:
655
  log.info(
 
672
 
673
  # Final eval
674
  eval_kl = evaluate(
675
+ accelerator, student, teacher, eval_batches,
676
+ pad_id, kl_start_pos, kl_chunk_size,
677
  )
678
  if accelerator.is_main_process:
679
  log.info(f" final eval: kl={eval_kl:.6f} (best={best_kl:.6f})")
scripts/backup_to_hf.py CHANGED
@@ -24,6 +24,7 @@ INCLUDE = [
24
  "configs/accelerate.yaml",
25
  "scripts/backup_to_hf.py",
26
  "scripts/run_sweep.sh",
 
27
  "pyproject.toml",
28
  "requirements.lock.txt",
29
  ]
 
24
  "configs/accelerate.yaml",
25
  "scripts/backup_to_hf.py",
26
  "scripts/run_sweep.sh",
27
+ "scripts/run_sweep_rerun.sh",
28
  "pyproject.toml",
29
  "requirements.lock.txt",
30
  ]
scripts/run_sweep_rerun.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Re-runs the two configs that OOM'd in the original sweep, now with the
# chunked-KL fix and PYTORCH_CUDA_ALLOC_CONF=expandable_segments in distill.py.
# Reads HF_TOKEN, HUGGING_FACE_HUB_TOKEN, WANDB_API_KEY from the calling env.
#
# Launch with:
#   nohup ./scripts/run_sweep_rerun.sh > logs/sweep_rerun_master.log 2>&1 &

# No -e: a failing run must not abort the remaining configs; each run's exit
# code is captured and reported below instead.
set -uo pipefail
# Run from the repo root regardless of where the script was invoked from.
cd "$(dirname "$0")/.."

# The configs that previously OOM'd (see header comment).
CONFIGS=(
  "configs/replicate_zero4.toml"
  "configs/grow40_winning.toml"
)

LOG_DIR="logs"
mkdir -p "$LOG_DIR"

for cfg in "${CONFIGS[@]}"; do
  # One log file per config, named after the config (e.g. replicate_zero4.log).
  name="$(basename "$cfg" .toml)"
  log="$LOG_DIR/$name.log"
  echo ">>> [$(date '+%F %T')] starting $name -> $log"
  # All run output (stdout+stderr) goes to the per-config log, not the master log.
  .venv/bin/accelerate launch \
    --config_file configs/accelerate.yaml \
    distill.py \
    --config "$cfg" \
    > "$log" 2>&1
  rc=$?
  echo "<<< [$(date '+%F %T')] finished $name (exit=$rc)"
  if [[ $rc -ne 0 ]]; then
    # Surface the tail of the failing run's log in the master log for quick triage.
    echo "    last 20 lines of $log:"
    tail -20 "$log" | sed 's/^/    /'
  fi
done

echo ">>> [$(date '+%F %T')] rerun complete"