Delta-Vector committed · Commit 3af7f4c · verified · Parent: 35d9db6

add 9-config hparam sweep + new_layer_lr_mul param-groups support
configs/base.toml CHANGED
@@ -33,6 +33,7 @@ student_dtype = "bfloat16"
 teacher_dtype = "bfloat16"
 mixed_precision = "bf16"
 kl_chunk_size = 0
+new_layer_lr_mul = 1.0
 
 [eval]
 every_steps = 5
configs/grow40_simple.toml CHANGED
@@ -34,6 +34,7 @@ student_dtype = "bfloat16"
 teacher_dtype = "bfloat16"
 mixed_precision = "bf16"
 kl_chunk_size = 0
+new_layer_lr_mul = 1.0
 
 [eval]
 every_steps = 50
configs/grow40_winning.toml CHANGED
@@ -35,6 +35,7 @@ student_dtype = "bfloat16"
 teacher_dtype = "bfloat16"
 mixed_precision = "bf16"
 kl_chunk_size = 256
+new_layer_lr_mul = 1.0
 
 [eval]
 every_steps = 50
configs/grow40_winning_v2.toml CHANGED
@@ -34,6 +34,7 @@ student_dtype = "bfloat16"
 teacher_dtype = "bfloat16"
 mixed_precision = "bf16"
 kl_chunk_size = 256
+new_layer_lr_mul = 1.0
 
 [eval]
 every_steps = 50
configs/replicate_zero4.toml CHANGED
@@ -33,6 +33,7 @@ student_dtype = "float32"
 teacher_dtype = "bfloat16"
 mixed_precision = "bf16"
 kl_chunk_size = 256
+new_layer_lr_mul = 1.0
 
 [eval]
 every_steps = 50
configs/sweep/A_resume_lr1e7_cos.toml ADDED
@@ -0,0 +1,52 @@
+# Resume from grow40_winning best (eval kl 0.2219). Lower peak LR to avoid the
+# overshoot we saw at 5e-7. Cosine warmup 100, 1500 steps.
+
+[model]
+teacher = "Qwen/Qwen3.5-35B-A3B"
+student = "./out/grow40_winning/best"
+tokenizer = "Qwen/Qwen3.5-35B-A3B"
+
+[data]
+dataset = "karpathy/climbmix-400b-shuffle"
+text_field = "text"
+min_chars = 2560
+max_seq_len = 2048
+kl_start_pos = 128
+seed = 6767
+shuffle_buffer = 10000
+
+[train]
+seed = 6767
+lr = 1.0e-7
+schedule = "cosine"
+warmup_steps = 100
+weight_decay = 0.0
+grad_clip = 1.0
+betas = [0.9, 0.999]
+eps = 1.0e-3
+samples_per_step = 4
+micro_batch_size = 4
+max_steps = 1500
+grad_checkpointing = true
+attn_implementation = "flash_attention_2"
+student_dtype = "bfloat16"
+teacher_dtype = "bfloat16"
+mixed_precision = "bf16"
+kl_chunk_size = 256
+new_layer_lr_mul = 1.0
+
+[eval]
+every_steps = 50
+samples = 500
+seed = 4242
+
+[log]
+wandb = true
+wandb_project = "distil-subnet97"
+wandb_run = "A_resume_lr1e7_cos"
+log_every = 1
+output_dir = "./out/sweep/A_resume_lr1e7_cos"
+
+[init]
+zero_layers = []
+target_num_layers = 40
configs/sweep/B_resume_lr5e8_cos.toml ADDED
@@ -0,0 +1,51 @@
+# Resume from grow40_winning best (eval kl 0.2219). Even lower peak LR.
+
+[model]
+teacher = "Qwen/Qwen3.5-35B-A3B"
+student = "./out/grow40_winning/best"
+tokenizer = "Qwen/Qwen3.5-35B-A3B"
+
+[data]
+dataset = "karpathy/climbmix-400b-shuffle"
+text_field = "text"
+min_chars = 2560
+max_seq_len = 2048
+kl_start_pos = 128
+seed = 6767
+shuffle_buffer = 10000
+
+[train]
+seed = 6767
+lr = 5.0e-8
+schedule = "cosine"
+warmup_steps = 100
+weight_decay = 0.0
+grad_clip = 1.0
+betas = [0.9, 0.999]
+eps = 1.0e-3
+samples_per_step = 4
+micro_batch_size = 4
+max_steps = 1500
+grad_checkpointing = true
+attn_implementation = "flash_attention_2"
+student_dtype = "bfloat16"
+teacher_dtype = "bfloat16"
+mixed_precision = "bf16"
+kl_chunk_size = 256
+new_layer_lr_mul = 1.0
+
+[eval]
+every_steps = 50
+samples = 500
+seed = 4242
+
+[log]
+wandb = true
+wandb_project = "distil-subnet97"
+wandb_run = "B_resume_lr5e8_cos"
+log_every = 1
+output_dir = "./out/sweep/B_resume_lr5e8_cos"
+
+[init]
+zero_layers = []
+target_num_layers = 40
configs/sweep/C_resume_lr2e8_cos.toml ADDED
@@ -0,0 +1,51 @@
+# Resume from grow40_winning best. Very small LR - basically a fine-tune.
+
+[model]
+teacher = "Qwen/Qwen3.5-35B-A3B"
+student = "./out/grow40_winning/best"
+tokenizer = "Qwen/Qwen3.5-35B-A3B"
+
+[data]
+dataset = "karpathy/climbmix-400b-shuffle"
+text_field = "text"
+min_chars = 2560
+max_seq_len = 2048
+kl_start_pos = 128
+seed = 6767
+shuffle_buffer = 10000
+
+[train]
+seed = 6767
+lr = 2.0e-8
+schedule = "cosine"
+warmup_steps = 100
+weight_decay = 0.0
+grad_clip = 1.0
+betas = [0.9, 0.999]
+eps = 1.0e-3
+samples_per_step = 4
+micro_batch_size = 4
+max_steps = 1500
+grad_checkpointing = true
+attn_implementation = "flash_attention_2"
+student_dtype = "bfloat16"
+teacher_dtype = "bfloat16"
+mixed_precision = "bf16"
+kl_chunk_size = 256
+new_layer_lr_mul = 1.0
+
+[eval]
+every_steps = 50
+samples = 500
+seed = 4242
+
+[log]
+wandb = true
+wandb_project = "distil-subnet97"
+wandb_run = "C_resume_lr2e8_cos"
+log_every = 1
+output_dir = "./out/sweep/C_resume_lr2e8_cos"
+
+[init]
+zero_layers = []
+target_num_layers = 40
configs/sweep/D_resume_lr1e7_const.toml ADDED
@@ -0,0 +1,51 @@
+# Resume from grow40_winning best. Constant LR (no schedule overshoot at all).
+
+[model]
+teacher = "Qwen/Qwen3.5-35B-A3B"
+student = "./out/grow40_winning/best"
+tokenizer = "Qwen/Qwen3.5-35B-A3B"
+
+[data]
+dataset = "karpathy/climbmix-400b-shuffle"
+text_field = "text"
+min_chars = 2560
+max_seq_len = 2048
+kl_start_pos = 128
+seed = 6767
+shuffle_buffer = 10000
+
+[train]
+seed = 6767
+lr = 1.0e-7
+schedule = "constant"
+warmup_steps = 0
+weight_decay = 0.0
+grad_clip = 1.0
+betas = [0.9, 0.999]
+eps = 1.0e-3
+samples_per_step = 4
+micro_batch_size = 4
+max_steps = 1500
+grad_checkpointing = true
+attn_implementation = "flash_attention_2"
+student_dtype = "bfloat16"
+teacher_dtype = "bfloat16"
+mixed_precision = "bf16"
+kl_chunk_size = 256
+new_layer_lr_mul = 1.0
+
+[eval]
+every_steps = 50
+samples = 500
+seed = 4242
+
+[log]
+wandb = true
+wandb_project = "distil-subnet97"
+wandb_run = "D_resume_lr1e7_const"
+log_every = 1
+output_dir = "./out/sweep/D_resume_lr1e7_const"
+
+[init]
+zero_layers = []
+target_num_layers = 40
configs/sweep/E_resume_lr5e8_b95.toml ADDED
@@ -0,0 +1,52 @@
+# Resume from grow40_winning best. Smaller second-moment memory (beta2=0.95)
+# so Adam stabilizes faster. Same low LR.
+
+[model]
+teacher = "Qwen/Qwen3.5-35B-A3B"
+student = "./out/grow40_winning/best"
+tokenizer = "Qwen/Qwen3.5-35B-A3B"
+
+[data]
+dataset = "karpathy/climbmix-400b-shuffle"
+text_field = "text"
+min_chars = 2560
+max_seq_len = 2048
+kl_start_pos = 128
+seed = 6767
+shuffle_buffer = 10000
+
+[train]
+seed = 6767
+lr = 5.0e-8
+schedule = "cosine"
+warmup_steps = 100
+weight_decay = 0.0
+grad_clip = 1.0
+betas = [0.9, 0.95]
+eps = 1.0e-8
+samples_per_step = 4
+micro_batch_size = 4
+max_steps = 1500
+grad_checkpointing = true
+attn_implementation = "flash_attention_2"
+student_dtype = "bfloat16"
+teacher_dtype = "bfloat16"
+mixed_precision = "bf16"
+kl_chunk_size = 256
+new_layer_lr_mul = 1.0
+
+[eval]
+every_steps = 50
+samples = 500
+seed = 4242
+
+[log]
+wandb = true
+wandb_project = "distil-subnet97"
+wandb_run = "E_resume_lr5e8_b95"
+log_every = 1
+output_dir = "./out/sweep/E_resume_lr5e8_b95"
+
+[init]
+zero_layers = []
+target_num_layers = 40
configs/sweep/F_cold_lr1e7_grow40.toml ADDED
@@ -0,0 +1,51 @@
+# Cold start, 40 layers, lower peak LR than the original winning recipe.
+
+[model]
+teacher = "Qwen/Qwen3.5-35B-A3B"
+student = "Troiaaa/m-6a3lnzvb"
+tokenizer = "Qwen/Qwen3.5-35B-A3B"
+
+[data]
+dataset = "karpathy/climbmix-400b-shuffle"
+text_field = "text"
+min_chars = 2560
+max_seq_len = 2048
+kl_start_pos = 128
+seed = 6767
+shuffle_buffer = 10000
+
+[train]
+seed = 6767
+lr = 1.0e-7
+schedule = "cosine"
+warmup_steps = 100
+weight_decay = 0.0
+grad_clip = 1.0
+betas = [0.9, 0.999]
+eps = 1.0e-3
+samples_per_step = 4
+micro_batch_size = 4
+max_steps = 2000
+grad_checkpointing = true
+attn_implementation = "flash_attention_2"
+student_dtype = "bfloat16"
+teacher_dtype = "bfloat16"
+mixed_precision = "bf16"
+kl_chunk_size = 256
+new_layer_lr_mul = 1.0
+
+[eval]
+every_steps = 50
+samples = 500
+seed = 4242
+
+[log]
+wandb = true
+wandb_project = "distil-subnet97"
+wandb_run = "F_cold_lr1e7_grow40"
+log_every = 1
+output_dir = "./out/sweep/F_cold_lr1e7_grow40"
+
+[init]
+zero_layers = []
+target_num_layers = 40
configs/sweep/G_cold_lr2e7_grow40.toml ADDED
@@ -0,0 +1,51 @@
+# Cold start, 40 layers, lr=2e-7 (between 1e-7 and the failing 5e-7).
+
+[model]
+teacher = "Qwen/Qwen3.5-35B-A3B"
+student = "Troiaaa/m-6a3lnzvb"
+tokenizer = "Qwen/Qwen3.5-35B-A3B"
+
+[data]
+dataset = "karpathy/climbmix-400b-shuffle"
+text_field = "text"
+min_chars = 2560
+max_seq_len = 2048
+kl_start_pos = 128
+seed = 6767
+shuffle_buffer = 10000
+
+[train]
+seed = 6767
+lr = 2.0e-7
+schedule = "cosine"
+warmup_steps = 100
+weight_decay = 0.0
+grad_clip = 1.0
+betas = [0.9, 0.999]
+eps = 1.0e-3
+samples_per_step = 4
+micro_batch_size = 4
+max_steps = 2000
+grad_checkpointing = true
+attn_implementation = "flash_attention_2"
+student_dtype = "bfloat16"
+teacher_dtype = "bfloat16"
+mixed_precision = "bf16"
+kl_chunk_size = 256
+new_layer_lr_mul = 1.0
+
+[eval]
+every_steps = 50
+samples = 500
+seed = 4242
+
+[log]
+wandb = true
+wandb_project = "distil-subnet97"
+wandb_run = "G_cold_lr2e7_grow40"
+log_every = 1
+output_dir = "./out/sweep/G_cold_lr2e7_grow40"
+
+[init]
+zero_layers = []
+target_num_layers = 40
configs/sweep/H_cold_lr1e7_32L.toml ADDED
@@ -0,0 +1,52 @@
+# Cold start, 32 layers (NO grow), lower LR. Tests whether the +8 layers were
+# helping at all once we use the right LR.
+
+[model]
+teacher = "Qwen/Qwen3.5-35B-A3B"
+student = "Troiaaa/m-6a3lnzvb"
+tokenizer = "Qwen/Qwen3.5-35B-A3B"
+
+[data]
+dataset = "karpathy/climbmix-400b-shuffle"
+text_field = "text"
+min_chars = 2560
+max_seq_len = 2048
+kl_start_pos = 128
+seed = 6767
+shuffle_buffer = 10000
+
+[train]
+seed = 6767
+lr = 1.0e-7
+schedule = "cosine"
+warmup_steps = 100
+weight_decay = 0.0
+grad_clip = 1.0
+betas = [0.9, 0.999]
+eps = 1.0e-3
+samples_per_step = 4
+micro_batch_size = 4
+max_steps = 2000
+grad_checkpointing = true
+attn_implementation = "flash_attention_2"
+student_dtype = "bfloat16"
+teacher_dtype = "bfloat16"
+mixed_precision = "bf16"
+kl_chunk_size = 256
+new_layer_lr_mul = 1.0
+
+[eval]
+every_steps = 50
+samples = 500
+seed = 4242
+
+[log]
+wandb = true
+wandb_project = "distil-subnet97"
+wandb_run = "H_cold_lr1e7_32L"
+log_every = 1
+output_dir = "./out/sweep/H_cold_lr1e7_32L"
+
+[init]
+zero_layers = []
+target_num_layers = 32
configs/sweep/I_cold_paramgroups_grow40.toml ADDED
@@ -0,0 +1,52 @@
+# Cold start, 40 layers, low LR for original layers + 5x for the new ones.
+# Lets the new layers wake up faster without disturbing the trained layers.
+
+[model]
+teacher = "Qwen/Qwen3.5-35B-A3B"
+student = "Troiaaa/m-6a3lnzvb"
+tokenizer = "Qwen/Qwen3.5-35B-A3B"
+
+[data]
+dataset = "karpathy/climbmix-400b-shuffle"
+text_field = "text"
+min_chars = 2560
+max_seq_len = 2048
+kl_start_pos = 128
+seed = 6767
+shuffle_buffer = 10000
+
+[train]
+seed = 6767
+lr = 1.0e-7
+schedule = "cosine"
+warmup_steps = 100
+weight_decay = 0.0
+grad_clip = 1.0
+betas = [0.9, 0.999]
+eps = 1.0e-3
+samples_per_step = 4
+micro_batch_size = 4
+max_steps = 2000
+grad_checkpointing = true
+attn_implementation = "flash_attention_2"
+student_dtype = "bfloat16"
+teacher_dtype = "bfloat16"
+mixed_precision = "bf16"
+kl_chunk_size = 256
+new_layer_lr_mul = 5.0
+
+[eval]
+every_steps = 50
+samples = 500
+seed = 4242
+
+[log]
+wandb = true
+wandb_project = "distil-subnet97"
+wandb_run = "I_cold_paramgroups_grow40"
+log_every = 1
+output_dir = "./out/sweep/I_cold_paramgroups_grow40"
+
+[init]
+zero_layers = []
+target_num_layers = 40
configs/zero_14_17.toml CHANGED
@@ -34,6 +34,7 @@ student_dtype = "bfloat16"
 teacher_dtype = "bfloat16"
 mixed_precision = "bf16"
 kl_chunk_size = 0
+new_layer_lr_mul = 1.0
 
 [eval]
 every_steps = 50
distill.py CHANGED
@@ -72,6 +72,7 @@ REQUIRED_KEYS = {
         "mixed_precision",
         "kl_chunk_size",
         "micro_batch_size",
+        "new_layer_lr_mul",
     ),
     "eval": ("every_steps", "samples", "seed"),
    "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
@@ -403,15 +404,48 @@ def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos, ch
 # Optimizer / scheduler
 # ----------------------------------------------------------------------------
 
-def make_optimizer(model, train_cfg):
-    return AdamW(
-        [p for p in model.parameters() if p.requires_grad],
-        lr=train_cfg["lr"],
+def make_optimizer(model, train_cfg, new_layer_indices=None):
+    """Create AdamW. If `new_layer_lr_mul != 1.0` and we know which layers are
+    'new' (returned from grow_layers), put their params in a separate group with
+    a multiplied LR. Useful for the 'wake up new layers without disturbing the
+    old ones' regime."""
+    base_lr = train_cfg["lr"]
+    mul = train_cfg["new_layer_lr_mul"]
+    common = dict(
         weight_decay=train_cfg["weight_decay"],
         betas=tuple(train_cfg["betas"]),
         eps=train_cfg["eps"],
     )
 
+    if not new_layer_indices or mul == 1.0:
+        return AdamW(
+            [p for p in model.parameters() if p.requires_grad],
+            lr=base_lr,
+            **common,
+        )
+
+    inner = get_inner_with_layers(model)
+    new_pids = set()
+    for idx in new_layer_indices:
+        for p in inner.layers[idx].parameters():
+            if p.requires_grad:
+                new_pids.add(id(p))
+
+    new_params = []
+    rest_params = []
+    for p in model.parameters():
+        if not p.requires_grad:
+            continue
+        (new_params if id(p) in new_pids else rest_params).append(p)
+
+    return AdamW(
+        [
+            {"params": rest_params, "lr": base_lr},
+            {"params": new_params, "lr": base_lr * mul},
+        ],
+        **common,
+    )
+
 
 def make_scheduler(optimizer, train_cfg):
     schedule = train_cfg["schedule"]
@@ -522,8 +556,10 @@ def main():
     # ---- Layer modifications: grow first, then zero (composable)
     target_n = cfg["init"]["target_num_layers"]
     cur_n = len(get_inner_with_layers(student).layers)
+    new_layer_indices = []
     if target_n != cur_n:
         new_n, new_zeroed = grow_layers(student, target_n)
+        new_layer_indices = [idx for idx, _ in new_zeroed]
         if accelerator.is_main_process:
             log.info(f"Grew student from {cur_n} -> {new_n} layers")
             for idx, names in new_zeroed:
@@ -538,8 +574,14 @@ def main():
     teacher = teacher.to(accelerator.device)
 
     # ---- Optimizer / scheduler
-    optimizer = make_optimizer(student, cfg["train"])
+    optimizer = make_optimizer(student, cfg["train"], new_layer_indices=new_layer_indices)
     scheduler = make_scheduler(optimizer, cfg["train"])
+    if accelerator.is_main_process and len(optimizer.param_groups) > 1:
+        log.info(
+            f"Param groups: rest lr={optimizer.param_groups[0]['lr']:.2e}, "
+            f"new lr={optimizer.param_groups[1]['lr']:.2e} "
+            f"({len(new_layer_indices)} layers grown)"
+        )
 
     # NB: do NOT pass `scheduler` to accelerator.prepare. When prepared, accelerate
     # advances the scheduler by `num_processes` steps per call (to match the
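
Note on the new param-group path: the scheduler built by `make_scheduler` is shared by both groups, and PyTorch's lambda-style LR schedulers scale each group from its own base LR, so the `new_layer_lr_mul` ratio between grown and original layers is preserved through warmup and cosine decay. The following standalone sketch (illustrative only, not the repo's code; the toy `Linear` blocks and the `cosine_with_warmup` lambda are assumptions) demonstrates that behaviour with the `[train]` numbers from config I.

```python
# Standalone sketch of the param-group behaviour behind `new_layer_lr_mul`:
# a LambdaLR schedule multiplies every group's base LR by the same factor,
# so the grown layers stay at `mul` times the LR of the original layers.
import math
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

old_block = torch.nn.Linear(8, 8)  # stands in for the pretrained layers
new_block = torch.nn.Linear(8, 8)  # stands in for the grown layers

base_lr, mul, warmup, total = 1.0e-7, 5.0, 100, 2000  # numbers from config I

optimizer = AdamW(
    [
        {"params": old_block.parameters(), "lr": base_lr},
        {"params": new_block.parameters(), "lr": base_lr * mul},
    ],
    betas=(0.9, 0.999),
    eps=1.0e-3,
    weight_decay=0.0,
)

def cosine_with_warmup(step: int) -> float:
    # Linear warmup then cosine decay to zero, applied identically to both groups.
    if step < warmup:
        return (step + 1) / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda=cosine_with_warmup)

for step in range(3):
    optimizer.step()   # a real loop would run forward/backward first
    scheduler.step()
    rest_lr, new_lr = (g["lr"] for g in optimizer.param_groups)
    print(f"step {step}: rest={rest_lr:.3e} new={new_lr:.3e} ratio={new_lr / rest_lr:.1f}")
```

The ratio printed stays at 5.0 at every step, which is why the diff only logs the two group LRs once at startup rather than tracking them separately.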
scripts/backup_to_hf.py CHANGED
@@ -22,10 +22,20 @@ INCLUDE = [
     "configs/grow40_winning.toml",
     "configs/grow40_simple.toml",
     "configs/grow40_winning_v2.toml",
+    "configs/sweep/A_resume_lr1e7_cos.toml",
+    "configs/sweep/B_resume_lr5e8_cos.toml",
+    "configs/sweep/C_resume_lr2e8_cos.toml",
+    "configs/sweep/D_resume_lr1e7_const.toml",
+    "configs/sweep/E_resume_lr5e8_b95.toml",
+    "configs/sweep/F_cold_lr1e7_grow40.toml",
+    "configs/sweep/G_cold_lr2e7_grow40.toml",
+    "configs/sweep/H_cold_lr1e7_32L.toml",
+    "configs/sweep/I_cold_paramgroups_grow40.toml",
     "configs/accelerate.yaml",
     "scripts/backup_to_hf.py",
     "scripts/run_sweep.sh",
     "scripts/run_sweep_rerun.sh",
+    "scripts/run_hparam_sweep.sh",
     "pyproject.toml",
     "requirements.lock.txt",
 ]
scripts/run_hparam_sweep.sh ADDED
@@ -0,0 +1,54 @@
+#!/usr/bin/env bash
+# Hyperparameter sweep over 9 configs that try to push past grow40_winning's 0.2219.
+#
+# Each config grabs all 8 GPUs via accelerate, so they run sequentially.
+# Output goes to logs/<run>.log; the master log goes to logs/sweep_hparam_master.log.
+# Reads HF_TOKEN, HUGGING_FACE_HUB_TOKEN, WANDB_API_KEY from the calling env.
+#
+# Launch in the background with:
+#   nohup ./scripts/run_hparam_sweep.sh > logs/sweep_hparam_master.log 2>&1 &
+
+set -uo pipefail
+cd "$(dirname "$0")/.."
+
+CONFIGS=(
+  "configs/sweep/A_resume_lr1e7_cos.toml"
+  "configs/sweep/B_resume_lr5e8_cos.toml"
+  "configs/sweep/C_resume_lr2e8_cos.toml"
+  "configs/sweep/D_resume_lr1e7_const.toml"
+  "configs/sweep/E_resume_lr5e8_b95.toml"
+  "configs/sweep/F_cold_lr1e7_grow40.toml"
+  "configs/sweep/G_cold_lr2e7_grow40.toml"
+  "configs/sweep/H_cold_lr1e7_32L.toml"
+  "configs/sweep/I_cold_paramgroups_grow40.toml"
+)
+
+LOG_DIR="logs"
+mkdir -p "$LOG_DIR"
+
+for cfg in "${CONFIGS[@]}"; do
+  name="$(basename "$cfg" .toml)"
+  log="$LOG_DIR/$name.log"
+  echo ">>> [$(date '+%F %T')] starting $name -> $log"
+  .venv/bin/accelerate launch \
+    --config_file configs/accelerate.yaml \
+    distill.py \
+    --config "$cfg" \
+    > "$log" 2>&1
+  rc=$?
+  best_line=$(grep -E "Best eval KL" "$log" | tail -1)
+  echo "<<< [$(date '+%F %T')] finished $name (exit=$rc) ${best_line}"
+  if [[ $rc -ne 0 ]]; then
+    echo "    last 12 lines of $log:"
+    tail -12 "$log" | sed 's/^/    /'
+  fi
+done
+
+echo ">>> [$(date '+%F %T')] hparam sweep complete"
+echo ">>> summary of best eval KLs:"
+for cfg in "${CONFIGS[@]}"; do
+  name="$(basename "$cfg" .toml)"
+  log="$LOG_DIR/$name.log"
+  best=$(grep -E "Best eval KL" "$log" | tail -1 | sed 's/.*Best eval KL = //')
+  printf "  %-32s %s\n" "$name" "${best:-FAILED}"
+done
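
A possible companion to the summary loop at the end of the script: a small Python helper (hypothetical, not part of this commit) that ranks the runs by the last `Best eval KL = <value>` line in each `logs/<run>.log`. It relies only on the same log phrase the script's own grep/sed already assume.

```python
# Hypothetical follow-up: rank sweep runs by the last "Best eval KL = <value>"
# line in each logs/<run>.log, mirroring the shell summary above.
import re
from pathlib import Path

PATTERN = re.compile(r"Best eval KL = ([0-9.]+)")

def best_kl(log_path: Path):
    """Return the last reported best eval KL in a run log, or None if absent."""
    matches = PATTERN.findall(log_path.read_text(errors="ignore"))
    return float(matches[-1]) if matches else None

run_logs = [
    p for p in sorted(Path("logs").glob("*.log"))
    if p.name != "sweep_hparam_master.log"  # skip the master log's echoed lines
]
results = {p.stem: best_kl(p) for p in run_logs}

# Successful runs first, lowest (best) KL at the top; failed runs at the bottom.
for name, kl in sorted(results.items(), key=lambda kv: (kv[1] is None, kv[1])):
    print(f"{name:<32} {kl if kl is not None else 'FAILED'}")
```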