Delta-Vector committed on
Commit
729546e
·
verified ·
1 Parent(s): 3af7f4c

add phase-2 ultra-conservative sweep (J,K,L,M) + waiter that auto-launches after phase 1 from the best ckpt

Browse files
configs/sweep/J_phase2_lr5e9_const.toml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Phase 2: ultra-conservative resume from the phase 1 best checkpoint.
# Tiny constant LR with zero warmup for maximum stability.
#
# NOTE(review): a previous comment here claimed "very high beta2 for max
# smoothing", but betas = [0.9, 0.99] is identical to configs K/L/M and
# 0.99 is below the usual Adam default of 0.999 — confirm intent.
# NOTE(review): eps = 1.0e-2 is much larger than the 1.0e-3 used by
# configs K/L/M — verify this difference is deliberate.

[model]
teacher = "Qwen/Qwen3.5-35B-A3B"
student = "./out/phase1_best"
tokenizer = "Qwen/Qwen3.5-35B-A3B"

[data]
dataset = "karpathy/climbmix-400b-shuffle"
text_field = "text"
min_chars = 2560
max_seq_len = 2048
kl_start_pos = 128
seed = 6767
shuffle_buffer = 10000

[train]
seed = 6767
lr = 5.0e-9
schedule = "constant"
warmup_steps = 0
weight_decay = 0.0
grad_clip = 1.0
betas = [0.9, 0.99]
eps = 1.0e-2
samples_per_step = 4
micro_batch_size = 4
max_steps = 3000
grad_checkpointing = true
attn_implementation = "flash_attention_2"
student_dtype = "bfloat16"
teacher_dtype = "bfloat16"
mixed_precision = "bf16"
kl_chunk_size = 256
new_layer_lr_mul = 1.0

[eval]
every_steps = 50
samples = 500
seed = 4242

[log]
wandb = true
wandb_project = "distil-subnet97"
wandb_run = "J_phase2_lr5e9_const"
log_every = 1
output_dir = "./out/sweep/J_phase2_lr5e9_const"

[init]
zero_layers = []
target_num_layers = 40
configs/sweep/K_phase2_lr2e8_const.toml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Phase 2: conservative resume from the phase 1 best checkpoint.
# Same shape as config J but with a somewhat higher LR (2e-8 vs 5e-9),
# still constant schedule with no warmup.

[model]
teacher = "Qwen/Qwen3.5-35B-A3B"
student = "./out/phase1_best"
tokenizer = "Qwen/Qwen3.5-35B-A3B"

[data]
dataset = "karpathy/climbmix-400b-shuffle"
text_field = "text"
min_chars = 2560
max_seq_len = 2048
kl_start_pos = 128
seed = 6767
shuffle_buffer = 10000

[train]
seed = 6767
lr = 2.0e-8
schedule = "constant"
warmup_steps = 0
weight_decay = 0.0
grad_clip = 1.0
betas = [0.9, 0.99]
eps = 1.0e-3
samples_per_step = 4
micro_batch_size = 4
max_steps = 3000
grad_checkpointing = true
attn_implementation = "flash_attention_2"
student_dtype = "bfloat16"
teacher_dtype = "bfloat16"
mixed_precision = "bf16"
kl_chunk_size = 256
new_layer_lr_mul = 1.0

[eval]
every_steps = 50
samples = 500
seed = 4242

[log]
wandb = true
wandb_project = "distil-subnet97"
wandb_run = "K_phase2_lr2e8_const"
log_every = 1
output_dir = "./out/sweep/K_phase2_lr2e8_const"

[init]
zero_layers = []
target_num_layers = 40
configs/sweep/L_phase2_lr1e8_warmup500.toml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Phase 2: resume from the phase 1 best checkpoint with a cosine schedule
# and a 500-step warmup, so the LR ramps in gently instead of jumping to
# its peak on step 0.

[model]
teacher = "Qwen/Qwen3.5-35B-A3B"
student = "./out/phase1_best"
tokenizer = "Qwen/Qwen3.5-35B-A3B"

[data]
dataset = "karpathy/climbmix-400b-shuffle"
text_field = "text"
min_chars = 2560
max_seq_len = 2048
kl_start_pos = 128
seed = 6767
shuffle_buffer = 10000

[train]
seed = 6767
lr = 1.0e-8
schedule = "cosine"
warmup_steps = 500
weight_decay = 0.0
grad_clip = 1.0
betas = [0.9, 0.99]
eps = 1.0e-3
samples_per_step = 4
micro_batch_size = 4
max_steps = 3000
grad_checkpointing = true
attn_implementation = "flash_attention_2"
student_dtype = "bfloat16"
teacher_dtype = "bfloat16"
mixed_precision = "bf16"
kl_chunk_size = 256
new_layer_lr_mul = 1.0

[eval]
every_steps = 50
samples = 500
seed = 4242

[log]
wandb = true
wandb_project = "distil-subnet97"
wandb_run = "L_phase2_lr1e8_warmup500"
log_every = 1
output_dir = "./out/sweep/L_phase2_lr1e8_warmup500"

[init]
zero_layers = []
target_num_layers = 40
configs/sweep/M_phase2_lr2e8_largebatch.toml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Phase 2: same small LR as config K, but with a larger per-step sample
# count (16 samples/rank, micro-batch 1) so per-step gradients are much
# smoother. Original note: "effective 128" — presumably assumes 8 ranks;
# confirm against the accelerate config.

[model]
teacher = "Qwen/Qwen3.5-35B-A3B"
student = "./out/phase1_best"
tokenizer = "Qwen/Qwen3.5-35B-A3B"

[data]
dataset = "karpathy/climbmix-400b-shuffle"
text_field = "text"
min_chars = 2560
max_seq_len = 2048
kl_start_pos = 128
seed = 6767
shuffle_buffer = 10000

[train]
seed = 6767
lr = 2.0e-8
schedule = "constant"
warmup_steps = 0
weight_decay = 0.0
grad_clip = 1.0
betas = [0.9, 0.99]
eps = 1.0e-3
samples_per_step = 16
micro_batch_size = 1
max_steps = 2000
grad_checkpointing = true
attn_implementation = "flash_attention_2"
student_dtype = "bfloat16"
teacher_dtype = "bfloat16"
mixed_precision = "bf16"
kl_chunk_size = 256
new_layer_lr_mul = 1.0

[eval]
every_steps = 50
samples = 500
seed = 4242

[log]
wandb = true
wandb_project = "distil-subnet97"
wandb_run = "M_phase2_lr2e8_largebatch"
log_every = 1
output_dir = "./out/sweep/M_phase2_lr2e8_largebatch"

[init]
zero_layers = []
target_num_layers = 40
+ target_num_layers = 40
scripts/backup_to_hf.py CHANGED
@@ -31,11 +31,16 @@ INCLUDE = [
31
  "configs/sweep/G_cold_lr2e7_grow40.toml",
32
  "configs/sweep/H_cold_lr1e7_32L.toml",
33
  "configs/sweep/I_cold_paramgroups_grow40.toml",
 
 
 
 
34
  "configs/accelerate.yaml",
35
  "scripts/backup_to_hf.py",
36
  "scripts/run_sweep.sh",
37
  "scripts/run_sweep_rerun.sh",
38
  "scripts/run_hparam_sweep.sh",
 
39
  "pyproject.toml",
40
  "requirements.lock.txt",
41
  ]
 
31
  "configs/sweep/G_cold_lr2e7_grow40.toml",
32
  "configs/sweep/H_cold_lr1e7_32L.toml",
33
  "configs/sweep/I_cold_paramgroups_grow40.toml",
34
+ "configs/sweep/J_phase2_lr5e9_const.toml",
35
+ "configs/sweep/K_phase2_lr2e8_const.toml",
36
+ "configs/sweep/L_phase2_lr1e8_warmup500.toml",
37
+ "configs/sweep/M_phase2_lr2e8_largebatch.toml",
38
  "configs/accelerate.yaml",
39
  "scripts/backup_to_hf.py",
40
  "scripts/run_sweep.sh",
41
  "scripts/run_sweep_rerun.sh",
42
  "scripts/run_hparam_sweep.sh",
43
+ "scripts/run_phase2_sweep.sh",
44
  "pyproject.toml",
45
  "requirements.lock.txt",
46
  ]
scripts/run_phase2_sweep.sh ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Phase 2 sweep driver.
#
# Waits for the phase 1 sweep to finish, then resumes from whichever
# phase 1 run achieved the lowest eval KL. All phase 2 configs use very
# small LRs and constant/very-slow schedules. Goal: monotone, very slow
# KL descent.
#
# Launch in the background with:
#   nohup ./scripts/run_phase2_sweep.sh > logs/sweep_phase2_master.log 2>&1 &

# Deliberately NOT `set -e`: a failed phase 2 run must not abort the rest
# of the sweep — each run's exit code is captured and reported explicitly.
set -uo pipefail
cd "$(dirname "$0")/.."

die() { printf '%s\n' "$*" >&2; exit 1; }

LOG_DIR="logs"
mkdir -p "$LOG_DIR"

# 1. Wait for the phase 1 driver to exit.
echo ">>> [$(date '+%F %T')] phase2 waiter: waiting for phase 1 to finish..."
while pgrep -f "run_hparam_sweep.sh" > /dev/null; do
  sleep 30
done
# Also wait for any straggler distill.py procs from phase 1 to die.
while pgrep -f "distill.py --config configs/sweep/[A-I]_" > /dev/null; do
  sleep 30
done
echo ">>> [$(date '+%F %T')] phase2 waiter: phase 1 done."

# 2. Find phase 1's best ckpt (lowest eval_kl across the A..I runs).
# The heredoc prints "<dir>\t<kl>"; abort if it finds nothing — without
# this check the script would go on to symlink an empty path.
PHASE1_BEST=$(.venv/bin/python - <<'PY'
import json, glob, os, sys
best_kl = float("inf")
best_dir = None
for f in glob.glob("out/sweep/[A-I]_*/best/best.json"):
    try:
        kl = json.load(open(f))["eval_kl"]
    except Exception:
        continue
    if kl < best_kl:
        best_kl = kl
        best_dir = os.path.dirname(f)
if best_dir is None:
    sys.exit("no phase 1 best found")
print(f"{best_dir}\t{best_kl}")
PY
) || die "phase2: could not determine phase 1 best checkpoint; aborting"
BEST_DIR=${PHASE1_BEST%%$'\t'*}
BEST_KL=${PHASE1_BEST#*$'\t'}
[[ -n "$BEST_DIR" && -d "$BEST_DIR" ]] || die "phase2: bad best dir '$BEST_DIR'"
echo ">>> phase 1 best: $BEST_DIR (eval_kl=$BEST_KL)"

# 3. Symlink ./out/phase1_best -> the winner so phase 2 configs can
#    reference a stable path.
mkdir -p out
rm -f out/phase1_best
ln -sfn "$(realpath "$BEST_DIR")" out/phase1_best
echo ">>> linked out/phase1_best -> $(readlink out/phase1_best)"

# 4. Run phase 2 configs sequentially.
CONFIGS=(
  "configs/sweep/J_phase2_lr5e9_const.toml"
  "configs/sweep/K_phase2_lr2e8_const.toml"
  "configs/sweep/L_phase2_lr1e8_warmup500.toml"
  "configs/sweep/M_phase2_lr2e8_largebatch.toml"
)

for cfg in "${CONFIGS[@]}"; do
  name="$(basename "$cfg" .toml)"
  log="$LOG_DIR/$name.log"
  echo ">>> [$(date '+%F %T')] starting $name -> $log"
  rc=0
  .venv/bin/accelerate launch \
    --config_file configs/accelerate.yaml \
    distill.py \
    --config "$cfg" \
    > "$log" 2>&1 || rc=$?
  best_line=$(grep -E "Best eval KL" "$log" | tail -1)
  echo "<<< [$(date '+%F %T')] finished $name (exit=$rc) ${best_line}"
  if (( rc != 0 )); then
    echo "    last 12 lines of $log:"
    tail -n 12 "$log" | sed 's/^/      /'
  fi
done

echo ">>> [$(date '+%F %T')] phase2 sweep complete"
echo ">>> overall summary (phase 1 + phase 2):"
for log in "$LOG_DIR"/[A-M]_*.log; do
  [[ -e "$log" ]] || continue   # no logs at all -> skip literal glob pattern
  name=$(basename "$log" .log)
  best=$(grep -E "Best eval KL" "$log" 2>/dev/null | tail -1 | sed 's/.*Best eval KL = //')
  printf "  %-32s %s\n" "$name" "${best:-FAILED}"
done | sort -k2