fix OOM: chunked KL with checkpointing + PYTORCH_CUDA_ALLOC_CONF expandable_segments; add kl_chunk_size config key
Browse files- configs/base.toml +1 -0
- configs/grow40_simple.toml +1 -0
- configs/grow40_winning.toml +1 -0
- configs/replicate_zero4.toml +1 -0
- configs/zero_14_17.toml +1 -0
- distill.py +54 -16
- scripts/backup_to_hf.py +1 -0
- scripts/run_sweep_rerun.sh +37 -0
configs/base.toml
CHANGED
|
@@ -31,6 +31,7 @@ attn_implementation = "flash_attention_2"
|
|
| 31 |
student_dtype = "bfloat16"
|
| 32 |
teacher_dtype = "bfloat16"
|
| 33 |
mixed_precision = "bf16"
|
|
|
|
| 34 |
|
| 35 |
[eval]
|
| 36 |
every_steps = 5
|
|
|
|
| 31 |
student_dtype = "bfloat16"
|
| 32 |
teacher_dtype = "bfloat16"
|
| 33 |
mixed_precision = "bf16"
|
| 34 |
+
kl_chunk_size = 0
|
| 35 |
|
| 36 |
[eval]
|
| 37 |
every_steps = 5
|
configs/grow40_simple.toml
CHANGED
|
@@ -32,6 +32,7 @@ attn_implementation = "flash_attention_2"
|
|
| 32 |
student_dtype = "bfloat16"
|
| 33 |
teacher_dtype = "bfloat16"
|
| 34 |
mixed_precision = "bf16"
|
|
|
|
| 35 |
|
| 36 |
[eval]
|
| 37 |
every_steps = 50
|
|
|
|
| 32 |
student_dtype = "bfloat16"
|
| 33 |
teacher_dtype = "bfloat16"
|
| 34 |
mixed_precision = "bf16"
|
| 35 |
+
kl_chunk_size = 0
|
| 36 |
|
| 37 |
[eval]
|
| 38 |
every_steps = 50
|
configs/grow40_winning.toml
CHANGED
|
@@ -32,6 +32,7 @@ attn_implementation = "flash_attention_2"
|
|
| 32 |
student_dtype = "float32"
|
| 33 |
teacher_dtype = "bfloat16"
|
| 34 |
mixed_precision = "bf16"
|
|
|
|
| 35 |
|
| 36 |
[eval]
|
| 37 |
every_steps = 50
|
|
|
|
| 32 |
student_dtype = "float32"
|
| 33 |
teacher_dtype = "bfloat16"
|
| 34 |
mixed_precision = "bf16"
|
| 35 |
+
kl_chunk_size = 256
|
| 36 |
|
| 37 |
[eval]
|
| 38 |
every_steps = 50
|
configs/replicate_zero4.toml
CHANGED
|
@@ -31,6 +31,7 @@ attn_implementation = "flash_attention_2"
|
|
| 31 |
student_dtype = "float32"
|
| 32 |
teacher_dtype = "bfloat16"
|
| 33 |
mixed_precision = "bf16"
|
|
|
|
| 34 |
|
| 35 |
[eval]
|
| 36 |
every_steps = 50
|
|
|
|
| 31 |
student_dtype = "float32"
|
| 32 |
teacher_dtype = "bfloat16"
|
| 33 |
mixed_precision = "bf16"
|
| 34 |
+
kl_chunk_size = 256
|
| 35 |
|
| 36 |
[eval]
|
| 37 |
every_steps = 50
|
configs/zero_14_17.toml
CHANGED
|
@@ -32,6 +32,7 @@ attn_implementation = "flash_attention_2"
|
|
| 32 |
student_dtype = "bfloat16"
|
| 33 |
teacher_dtype = "bfloat16"
|
| 34 |
mixed_precision = "bf16"
|
|
|
|
| 35 |
|
| 36 |
[eval]
|
| 37 |
every_steps = 50
|
|
|
|
| 32 |
student_dtype = "bfloat16"
|
| 33 |
teacher_dtype = "bfloat16"
|
| 34 |
mixed_precision = "bf16"
|
| 35 |
+
kl_chunk_size = 0
|
| 36 |
|
| 37 |
[eval]
|
| 38 |
every_steps = 50
|
distill.py
CHANGED
|
@@ -9,6 +9,10 @@ The TOML config is the single source of truth - no hardcoded defaults in this fi
|
|
| 9 |
The only command line argument is --config <path-to-toml>.
|
| 10 |
"""
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
import argparse
|
| 13 |
import gc
|
| 14 |
import json
|
|
@@ -20,6 +24,7 @@ from pathlib import Path
|
|
| 20 |
|
| 21 |
import torch
|
| 22 |
import torch.nn.functional as F
|
|
|
|
| 23 |
from torch.optim import AdamW
|
| 24 |
|
| 25 |
from accelerate import Accelerator
|
|
@@ -65,6 +70,7 @@ REQUIRED_KEYS = {
|
|
| 65 |
"student_dtype",
|
| 66 |
"teacher_dtype",
|
| 67 |
"mixed_precision",
|
|
|
|
| 68 |
),
|
| 69 |
"eval": ("every_steps", "samples", "seed"),
|
| 70 |
"log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
|
|
@@ -337,21 +343,44 @@ def collate_pad(token_lists, pad_id):
|
|
| 337 |
# Loss
|
| 338 |
# ----------------------------------------------------------------------------
|
| 339 |
|
| 340 |
-
def
|
| 341 |
-
"""
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
s = student_logits[:, start_pos:, :].float()
|
| 346 |
-
t = teacher_logits[:, start_pos:, :].detach().float()
|
| 347 |
-
mask = attention_mask[:, start_pos:].float()
|
| 348 |
-
|
| 349 |
t_log_p = F.log_softmax(t, dim=-1)
|
| 350 |
s_log_p = F.log_softmax(s, dim=-1)
|
| 351 |
t_p = t_log_p.exp()
|
|
|
|
|
|
|
|
|
|
| 352 |
|
| 353 |
-
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
|
| 357 |
# ----------------------------------------------------------------------------
|
|
@@ -390,7 +419,7 @@ def make_scheduler(optimizer, train_cfg):
|
|
| 390 |
# ----------------------------------------------------------------------------
|
| 391 |
|
| 392 |
@torch.no_grad()
|
| 393 |
-
def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos):
|
| 394 |
student.eval()
|
| 395 |
sdev = accelerator.device
|
| 396 |
total = 0.0
|
|
@@ -401,7 +430,10 @@ def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos):
|
|
| 401 |
mask = mask.to(sdev)
|
| 402 |
t_logits = teacher_forward(teacher, ids, mask)
|
| 403 |
s_logits = student(input_ids=ids, attention_mask=mask).logits
|
| 404 |
-
loss = kl_loss_masked(
|
|
|
|
|
|
|
|
|
|
| 405 |
total += loss.item()
|
| 406 |
n += 1
|
| 407 |
del t_logits, s_logits, loss
|
|
@@ -550,6 +582,7 @@ def main():
|
|
| 550 |
samples_per_step = cfg["train"]["samples_per_step"]
|
| 551 |
grad_clip = cfg["train"]["grad_clip"]
|
| 552 |
kl_start_pos = cfg["data"]["kl_start_pos"]
|
|
|
|
| 553 |
max_steps = cfg["train"]["max_steps"]
|
| 554 |
eval_every = cfg["eval"]["every_steps"]
|
| 555 |
log_every = cfg["log"]["log_every"]
|
|
@@ -578,7 +611,10 @@ def main():
|
|
| 578 |
with torch.no_grad():
|
| 579 |
t_logits = teacher_forward(teacher, ids, mask)
|
| 580 |
s_logits = student(input_ids=ids, attention_mask=mask).logits
|
| 581 |
-
loss = kl_loss_masked(
|
|
|
|
|
|
|
|
|
|
| 582 |
|
| 583 |
optimizer.zero_grad()
|
| 584 |
accelerator.backward(loss)
|
|
@@ -612,7 +648,8 @@ def main():
|
|
| 612 |
|
| 613 |
if global_step % eval_every == 0:
|
| 614 |
eval_kl = evaluate(
|
| 615 |
-
accelerator, student, teacher, eval_batches,
|
|
|
|
| 616 |
)
|
| 617 |
if accelerator.is_main_process:
|
| 618 |
log.info(
|
|
@@ -635,7 +672,8 @@ def main():
|
|
| 635 |
|
| 636 |
# Final eval
|
| 637 |
eval_kl = evaluate(
|
| 638 |
-
accelerator, student, teacher, eval_batches,
|
|
|
|
| 639 |
)
|
| 640 |
if accelerator.is_main_process:
|
| 641 |
log.info(f" final eval: kl={eval_kl:.6f} (best={best_kl:.6f})")
|
|
|
|
| 9 |
The only command line argument is --config <path-to-toml>.
|
| 10 |
"""
|
| 11 |
|
| 12 |
+
import os
|
| 13 |
+
# Reduce fragmentation; large vocab + long seq creates many short-lived big tensors.
|
| 14 |
+
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 15 |
+
|
| 16 |
import argparse
|
| 17 |
import gc
|
| 18 |
import json
|
|
|
|
| 24 |
|
| 25 |
import torch
|
| 26 |
import torch.nn.functional as F
|
| 27 |
+
import torch.utils.checkpoint as checkpoint_utils
|
| 28 |
from torch.optim import AdamW
|
| 29 |
|
| 30 |
from accelerate import Accelerator
|
|
|
|
| 70 |
"student_dtype",
|
| 71 |
"teacher_dtype",
|
| 72 |
"mixed_precision",
|
| 73 |
+
"kl_chunk_size",
|
| 74 |
),
|
| 75 |
"eval": ("every_steps", "samples", "seed"),
|
| 76 |
"log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
|
|
|
|
| 343 |
# Loss
|
| 344 |
# ----------------------------------------------------------------------------
|
| 345 |
|
| 346 |
+
def _kl_chunk_sum(s_chunk, t_chunk, m_chunk):
    """Return the SUM of masked, per-token forward KL over one sequence slice.

    Args:
        s_chunk: student logits slice, shape (batch, chunk_len, vocab).
        t_chunk: teacher logits slice, same shape (caller has already detached it).
        m_chunk: float mask, shape (batch, chunk_len); 1.0 marks real tokens.

    Returns:
        0-dim fp32 tensor: sum over batch and positions of mask * KL(teacher || student).

    Used as a checkpointed unit so the fp32 softmax intermediates only live
    for one chunk's worth of memory at a time.
    """
    s = s_chunk.float()
    t = t_chunk.float()
    t_log_p = F.log_softmax(t, dim=-1)
    s_log_p = F.log_softmax(s, dim=-1)
    t_p = t_log_p.exp()
    per_token = (t_p * (t_log_p - s_log_p)).sum(-1)
    return (per_token * m_chunk).sum()


def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos, chunk_size):
    """Forward KL(teacher || student), masked for padding & start_pos, in fp32.

    If chunk_size > 0, processes the [start_pos:] sequence in chunks of that many
    positions, with gradient checkpointing on each chunk so peak memory is bounded
    by one chunk's intermediates rather than the full sequence's.

    Fix: checkpointing is now skipped when autograd is disabled or the inputs carry
    no grad (e.g. the @torch.no_grad() eval loop) — torch.utils.checkpoint would
    otherwise warn "none of the inputs have requires_grad=True" and record useless
    recompute state there.
    """
    s_full = student_logits[:, start_pos:, :]
    t_full = teacher_logits[:, start_pos:, :].detach()
    m_full = attention_mask[:, start_pos:].float()

    # Token-count denominator; clamp guards against 0/0 on an all-padding batch.
    denom = m_full.sum().clamp_min(1.0)

    seq_len = s_full.shape[1]
    if chunk_size <= 0 or chunk_size >= seq_len:
        return _kl_chunk_sum(s_full, t_full, m_full) / denom

    # Checkpointing only pays off (and only runs warning-free) when a backward
    # pass can actually happen.
    use_ckpt = torch.is_grad_enabled() and s_full.requires_grad

    total_kl = torch.zeros((), device=s_full.device, dtype=torch.float32)
    for i in range(0, seq_len, chunk_size):
        end = min(i + chunk_size, seq_len)
        s_c = s_full[:, i:end, :]
        t_c = t_full[:, i:end, :]
        m_c = m_full[:, i:end]
        if use_ckpt:
            chunk_kl = checkpoint_utils.checkpoint(
                _kl_chunk_sum, s_c, t_c, m_c, use_reentrant=False
            )
        else:
            chunk_kl = _kl_chunk_sum(s_c, t_c, m_c)
        total_kl = total_kl + chunk_kl
    return total_kl / denom
|
| 384 |
|
| 385 |
|
| 386 |
# ----------------------------------------------------------------------------
|
|
|
|
| 419 |
# ----------------------------------------------------------------------------
|
| 420 |
|
| 421 |
@torch.no_grad()
|
| 422 |
+
def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos, kl_chunk_size):
|
| 423 |
student.eval()
|
| 424 |
sdev = accelerator.device
|
| 425 |
total = 0.0
|
|
|
|
| 430 |
mask = mask.to(sdev)
|
| 431 |
t_logits = teacher_forward(teacher, ids, mask)
|
| 432 |
s_logits = student(input_ids=ids, attention_mask=mask).logits
|
| 433 |
+
loss = kl_loss_masked(
|
| 434 |
+
s_logits, t_logits, mask,
|
| 435 |
+
start_pos=kl_start_pos, chunk_size=kl_chunk_size,
|
| 436 |
+
)
|
| 437 |
total += loss.item()
|
| 438 |
n += 1
|
| 439 |
del t_logits, s_logits, loss
|
|
|
|
| 582 |
samples_per_step = cfg["train"]["samples_per_step"]
|
| 583 |
grad_clip = cfg["train"]["grad_clip"]
|
| 584 |
kl_start_pos = cfg["data"]["kl_start_pos"]
|
| 585 |
+
kl_chunk_size = cfg["train"]["kl_chunk_size"]
|
| 586 |
max_steps = cfg["train"]["max_steps"]
|
| 587 |
eval_every = cfg["eval"]["every_steps"]
|
| 588 |
log_every = cfg["log"]["log_every"]
|
|
|
|
| 611 |
with torch.no_grad():
|
| 612 |
t_logits = teacher_forward(teacher, ids, mask)
|
| 613 |
s_logits = student(input_ids=ids, attention_mask=mask).logits
|
| 614 |
+
loss = kl_loss_masked(
|
| 615 |
+
s_logits, t_logits, mask,
|
| 616 |
+
start_pos=kl_start_pos, chunk_size=kl_chunk_size,
|
| 617 |
+
)
|
| 618 |
|
| 619 |
optimizer.zero_grad()
|
| 620 |
accelerator.backward(loss)
|
|
|
|
| 648 |
|
| 649 |
if global_step % eval_every == 0:
|
| 650 |
eval_kl = evaluate(
|
| 651 |
+
accelerator, student, teacher, eval_batches,
|
| 652 |
+
pad_id, kl_start_pos, kl_chunk_size,
|
| 653 |
)
|
| 654 |
if accelerator.is_main_process:
|
| 655 |
log.info(
|
|
|
|
| 672 |
|
| 673 |
# Final eval
|
| 674 |
eval_kl = evaluate(
|
| 675 |
+
accelerator, student, teacher, eval_batches,
|
| 676 |
+
pad_id, kl_start_pos, kl_chunk_size,
|
| 677 |
)
|
| 678 |
if accelerator.is_main_process:
|
| 679 |
log.info(f" final eval: kl={eval_kl:.6f} (best={best_kl:.6f})")
|
scripts/backup_to_hf.py
CHANGED
|
@@ -24,6 +24,7 @@ INCLUDE = [
|
|
| 24 |
"configs/accelerate.yaml",
|
| 25 |
"scripts/backup_to_hf.py",
|
| 26 |
"scripts/run_sweep.sh",
|
|
|
|
| 27 |
"pyproject.toml",
|
| 28 |
"requirements.lock.txt",
|
| 29 |
]
|
|
|
|
| 24 |
"configs/accelerate.yaml",
|
| 25 |
"scripts/backup_to_hf.py",
|
| 26 |
"scripts/run_sweep.sh",
|
| 27 |
+
"scripts/run_sweep_rerun.sh",
|
| 28 |
"pyproject.toml",
|
| 29 |
"requirements.lock.txt",
|
| 30 |
]
|
scripts/run_sweep_rerun.sh
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Re-runs the two configs that OOM'd in the original sweep, now with the
# chunked-KL fix and PYTORCH_CUDA_ALLOC_CONF=expandable_segments in distill.py.
# Reads HF_TOKEN, HUGGING_FACE_HUB_TOKEN, WANDB_API_KEY from the calling env.
#
# Launch with:
#   nohup ./scripts/run_sweep_rerun.sh > logs/sweep_rerun_master.log 2>&1 &

set -uo pipefail
cd "$(dirname "$0")/.."

LOG_DIR="logs"
mkdir -p "$LOG_DIR"

# Run a single config end-to-end, logging to its own file; on failure, echo the
# tail of that log so the master log is self-contained.
run_config() {
  local cfg="$1"
  local name log rc
  name="$(basename "$cfg" .toml)"
  log="$LOG_DIR/$name.log"
  echo ">>> [$(date '+%F %T')] starting $name -> $log"
  .venv/bin/accelerate launch \
    --config_file configs/accelerate.yaml \
    distill.py \
    --config "$cfg" \
    > "$log" 2>&1
  rc=$?
  echo "<<< [$(date '+%F %T')] finished $name (exit=$rc)"
  if [[ $rc -ne 0 ]]; then
    echo "    last 20 lines of $log:"
    tail -20 "$log" | sed 's/^/    /'
  fi
}

run_config "configs/replicate_zero4.toml"
run_config "configs/grow40_winning.toml"

echo ">>> [$(date '+%F %T')] rerun complete"
|