Delta-Vector committed on
Commit
3f04365
·
verified ·
1 Parent(s): f6e42f8

add grow_layers, sweep configs (replicate_zero4, grow40_winning, grow40_simple), sweep runner

Browse files
configs/base.toml CHANGED
@@ -28,6 +28,9 @@ samples_per_step = 4
28
  max_steps = 5
29
  grad_checkpointing = true
30
  attn_implementation = "flash_attention_2"
 
 
 
31
 
32
  [eval]
33
  every_steps = 5
@@ -42,4 +45,5 @@ log_every = 1
42
  output_dir = "./out/smoketest"
43
 
44
  [init]
45
- zero_layers = []
 
 
28
  max_steps = 5
29
  grad_checkpointing = true
30
  attn_implementation = "flash_attention_2"
31
+ student_dtype = "bfloat16"
32
+ teacher_dtype = "bfloat16"
33
+ mixed_precision = "bf16"
34
 
35
  [eval]
36
  every_steps = 5
 
45
  output_dir = "./out/smoketest"
46
 
47
  [init]
48
+ zero_layers = []
49
+ target_num_layers = 32
configs/grow40_simple.toml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Grow student to 40 layers with the current (bf16, seq=640) hparams.
2
+ # Tests the architectural change in isolation without the winning hparams,
3
+ # so we can attribute any improvement.
4
+
5
+ [model]
6
+ teacher = "Qwen/Qwen3.5-35B-A3B"
7
+ student = "Troiaaa/m-6a3lnzvb"
8
+ tokenizer = "Qwen/Qwen3.5-35B-A3B"
9
+
10
+ [data]
11
+ dataset = "karpathy/climbmix-400b-shuffle"
12
+ text_field = "text"
13
+ min_chars = 2560
14
+ max_seq_len = 640
15
+ kl_start_pos = 128
16
+ seed = 42
17
+ shuffle_buffer = 10000
18
+
19
+ [train]
20
+ seed = 42
21
+ lr = 5.0e-7
22
+ schedule = "constant"
23
+ warmup_steps = 0
24
+ weight_decay = 0.0
25
+ grad_clip = 1.0
26
+ betas = [0.9, 0.95]
27
+ eps = 1.0e-8
28
+ samples_per_step = 8
29
+ max_steps = 2000
30
+ grad_checkpointing = true
31
+ attn_implementation = "flash_attention_2"
32
+ student_dtype = "bfloat16"
33
+ teacher_dtype = "bfloat16"
34
+ mixed_precision = "bf16"
35
+
36
+ [eval]
37
+ every_steps = 50
38
+ samples = 64
39
+ seed = 1234
40
+
41
+ [log]
42
+ wandb = true
43
+ wandb_project = "distil-subnet97"
44
+ wandb_run = "grow40_simple"
45
+ log_every = 1
46
+ output_dir = "./out/grow40_simple"
47
+
48
+ [init]
49
+ zero_layers = []
50
+ target_num_layers = 40
configs/grow40_winning.toml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Grow student to 40 layers AND apply the winning hparams from zero4_long.
2
+ # New layers (32-39) are appended at the end with output projections zeroed
3
+ # (identity at init, gradients still flow). No layer zeroing.
4
+
5
+ [model]
6
+ teacher = "Qwen/Qwen3.5-35B-A3B"
7
+ student = "Troiaaa/m-6a3lnzvb"
8
+ tokenizer = "Qwen/Qwen3.5-35B-A3B"
9
+
10
+ [data]
11
+ dataset = "karpathy/climbmix-400b-shuffle"
12
+ text_field = "text"
13
+ min_chars = 2560
14
+ max_seq_len = 2048
15
+ kl_start_pos = 128
16
+ seed = 6767
17
+ shuffle_buffer = 10000
18
+
19
+ [train]
20
+ seed = 6767
21
+ lr = 5.0e-7
22
+ schedule = "cosine"
23
+ warmup_steps = 100
24
+ weight_decay = 0.0
25
+ grad_clip = 1.0
26
+ betas = [0.9, 0.999]
27
+ eps = 1.0e-3
28
+ samples_per_step = 4
29
+ max_steps = 2000
30
+ grad_checkpointing = true
31
+ attn_implementation = "flash_attention_2"
32
+ student_dtype = "float32"
33
+ teacher_dtype = "bfloat16"
34
+ mixed_precision = "bf16"
35
+
36
+ [eval]
37
+ every_steps = 50
38
+ samples = 500
39
+ seed = 4242
40
+
41
+ [log]
42
+ wandb = true
43
+ wandb_project = "distil-subnet97"
44
+ wandb_run = "grow40_winning"
45
+ log_every = 1
46
+ output_dir = "./out/grow40_winning"
47
+
48
+ [init]
49
+ zero_layers = []
50
+ target_num_layers = 40
configs/replicate_zero4.toml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Replicates wandb run "zero4_long" (mepqfry1, eval kl 0.275).
2
+ # Same hparams as that run; same 4-layer zero (14-17). 32-layer student.
3
+
4
+ [model]
5
+ teacher = "Qwen/Qwen3.5-35B-A3B"
6
+ student = "Troiaaa/m-6a3lnzvb"
7
+ tokenizer = "Qwen/Qwen3.5-35B-A3B"
8
+
9
+ [data]
10
+ dataset = "karpathy/climbmix-400b-shuffle"
11
+ text_field = "text"
12
+ min_chars = 2560
13
+ max_seq_len = 2048
14
+ kl_start_pos = 128
15
+ seed = 6767
16
+ shuffle_buffer = 10000
17
+
18
+ [train]
19
+ seed = 6767
20
+ lr = 5.0e-7
21
+ schedule = "cosine"
22
+ warmup_steps = 100
23
+ weight_decay = 0.0
24
+ grad_clip = 1.0
25
+ betas = [0.9, 0.999]
26
+ eps = 1.0e-3
27
+ samples_per_step = 4
28
+ max_steps = 2000
29
+ grad_checkpointing = true
30
+ attn_implementation = "flash_attention_2"
31
+ student_dtype = "float32"
32
+ teacher_dtype = "bfloat16"
33
+ mixed_precision = "bf16"
34
+
35
+ [eval]
36
+ every_steps = 50
37
+ samples = 500
38
+ seed = 4242
39
+
40
+ [log]
41
+ wandb = true
42
+ wandb_project = "distil-subnet97"
43
+ wandb_run = "replicate_zero4"
44
+ log_every = 1
45
+ output_dir = "./out/replicate_zero4"
46
+
47
+ [init]
48
+ zero_layers = [14, 15, 16, 17]
49
+ target_num_layers = 32
configs/zero_14_17.toml CHANGED
@@ -29,6 +29,9 @@ samples_per_step = 8
29
  max_steps = 2000
30
  grad_checkpointing = true
31
  attn_implementation = "flash_attention_2"
 
 
 
32
 
33
  [eval]
34
  every_steps = 50
@@ -43,4 +46,5 @@ log_every = 1
43
  output_dir = "./out/zero_14_17"
44
 
45
  [init]
46
- zero_layers = [14, 15, 16, 17]
 
 
29
  max_steps = 2000
30
  grad_checkpointing = true
31
  attn_implementation = "flash_attention_2"
32
+ student_dtype = "bfloat16"
33
+ teacher_dtype = "bfloat16"
34
+ mixed_precision = "bf16"
35
 
36
  [eval]
37
  every_steps = 50
 
46
  output_dir = "./out/zero_14_17"
47
 
48
  [init]
49
+ zero_layers = [14, 15, 16, 17]
50
+ target_num_layers = 32
distill.py CHANGED
@@ -62,12 +62,26 @@ REQUIRED_KEYS = {
62
  "max_steps",
63
  "grad_checkpointing",
64
  "attn_implementation",
 
 
 
65
  ),
66
  "eval": ("every_steps", "samples", "seed"),
67
  "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
68
- "init": ("zero_layers",),
69
  }
70
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def load_config(path):
73
  with open(path, "rb") as f:
@@ -117,9 +131,82 @@ def zero_layers(model, layer_indices):
117
  return n
118
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  def load_student(model_id, dtype, grad_ckpt, attn_impl):
121
  from transformers import AutoModelForCausalLM
122
- log.info(f"Loading student: {model_id}")
123
  model = AutoModelForCausalLM.from_pretrained(
124
  model_id,
125
  dtype=dtype,
@@ -142,7 +229,7 @@ def load_teacher(model_id, dtype, attn_impl):
142
  archs = list(getattr(cfg, "architectures", []) or [])
143
  arch = archs[0] if archs else ""
144
  is_multimodal = "ConditionalGeneration" in arch or "ImageText" in arch
145
- log.info(f"Loading teacher: {model_id} (arch={arch}, multimodal={is_multimodal})")
146
 
147
  if is_multimodal:
148
  from transformers import AutoModelForImageTextToText
@@ -354,12 +441,13 @@ def main():
354
 
355
  cfg = load_config(args.config)
356
 
357
- accelerator = Accelerator(mixed_precision="bf16")
358
  set_seed(cfg["train"]["seed"])
359
 
360
  if accelerator.is_main_process:
361
  log.info(f"Loaded config from {args.config}")
362
  log.info(f"World size: {accelerator.num_processes}")
 
363
 
364
  # ---- Tokenizer
365
  from transformers import AutoTokenizer
@@ -368,21 +456,31 @@ def main():
368
  tokenizer.pad_token = tokenizer.eos_token
369
  pad_id = tokenizer.pad_token_id
370
 
371
- # ---- Models
372
- dtype = torch.bfloat16
 
373
  student = load_student(
374
  cfg["model"]["student"],
375
- dtype,
376
  grad_ckpt=cfg["train"]["grad_checkpointing"],
377
  attn_impl=cfg["train"]["attn_implementation"],
378
  )
379
  teacher = load_teacher(
380
  cfg["model"]["teacher"],
381
- dtype,
382
  attn_impl=cfg["train"]["attn_implementation"],
383
  )
384
 
385
- # ---- Layer modifications (post-load, pre-prepare)
 
 
 
 
 
 
 
 
 
386
  zero_idx = cfg["init"]["zero_layers"]
387
  if zero_idx:
388
  n = zero_layers(student, zero_idx)
 
62
  "max_steps",
63
  "grad_checkpointing",
64
  "attn_implementation",
65
+ "student_dtype",
66
+ "teacher_dtype",
67
+ "mixed_precision",
68
  ),
69
  "eval": ("every_steps", "samples", "seed"),
70
  "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
71
+ "init": ("zero_layers", "target_num_layers"),
72
  }
73
 
74
# Config-string -> torch dtype lookup used by parse_dtype below.
DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def parse_dtype(s):
    """Translate a dtype string from the TOML config into a torch.dtype.

    Rejects anything outside DTYPE_MAP so a config typo fails loudly at
    startup instead of partway through a run.
    """
    dtype = DTYPE_MAP.get(s)
    if dtype is None:
        raise ValueError(f"unknown dtype {s!r}; must be one of {list(DTYPE_MAP)}")
    return dtype
84
+
85
 
86
  def load_config(path):
87
  with open(path, "rb") as f:
 
131
  return n
132
 
133
 
134
+ def _zero_output_projections(layer):
135
+ """Zero out attention and MLP output projections so the layer is identity
136
+ at init while still allowing gradients to flow into o_proj/down_proj first
137
+ (and from there back into the rest of the layer's params after one step).
138
+
139
+ Knows about Qwen3.5 names: self_attn.o_proj (full attention),
140
+ linear_attn.out_proj (linear attention), mlp.down_proj.
141
+ """
142
+ zeroed = []
143
+ with torch.no_grad():
144
+ if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "o_proj"):
145
+ layer.self_attn.o_proj.weight.zero_()
146
+ zeroed.append("self_attn.o_proj")
147
+ if hasattr(layer, "linear_attn") and hasattr(layer.linear_attn, "out_proj"):
148
+ layer.linear_attn.out_proj.weight.zero_()
149
+ zeroed.append("linear_attn.out_proj")
150
+ if hasattr(layer, "mlp") and hasattr(layer.mlp, "down_proj"):
151
+ layer.mlp.down_proj.weight.zero_()
152
+ zeroed.append("mlp.down_proj")
153
+ return zeroed
154
+
155
+
156
def grow_layers(model, target_n):
    """Grow the student to `target_n` decoder layers by appending new ones at the end.

    New layers are constructed via the existing decoder layer class with the model's
    own _init_weights, then their output projections are zeroed so each new layer
    starts as the identity but is still trainable.

    Returns:
        (num_layers, new_layer_zeroed) where new_layer_zeroed is a list of
        (layer_idx, [zeroed module names]) pairs, empty when no layers were added.

    Raises:
        ValueError: if target_n is smaller than the current layer count.
        RuntimeError: if the text config exposes no layer_types pattern to extend.
    """
    inner = get_inner_with_layers(model)
    cur_n = len(inner.layers)
    if target_n == cur_n:
        # Fixed: this path used to return a bare int while the growth path
        # returns a 2-tuple, so any caller unpacking two values would crash
        # on the no-op case. Keep the return shape uniform.
        return cur_n, []
    if target_n < cur_n:
        raise ValueError(f"target_num_layers={target_n} < current {cur_n}; cannot shrink")

    # Locate the (text) config that the layers are built from. For multimodal
    # wrappers this lives at .text_config; for the dense student it's the same
    # object as model.config.
    cfg = model.config
    text_cfg = getattr(cfg, "text_config", cfg)

    # Extend layer_types by repeating the existing periodic pattern.
    if not getattr(text_cfg, "layer_types", None):
        raise RuntimeError("text config has no layer_types; cannot extend pattern")
    period = getattr(text_cfg, "full_attention_interval", 4)
    new_types = list(text_cfg.layer_types)
    while len(new_types) < target_n:
        # NOTE(review): assumes the attention-type pattern is aligned to
        # layer 0 with period `full_attention_interval` — confirm for Qwen3.5.
        new_types.append(new_types[len(new_types) % period])
    text_cfg.layer_types = new_types
    text_cfg.num_hidden_layers = target_n
    if hasattr(cfg, "num_hidden_layers") and cfg is not text_cfg:
        cfg.num_hidden_layers = target_n

    # Construct new layers using the same class as the existing ones,
    # placing them on the same device/dtype as the current parameters.
    layer_cls = type(inner.layers[0])
    device = next(inner.parameters()).device
    dtype = next(inner.parameters()).dtype

    new_layer_zeroed = []
    for i in range(cur_n, target_n):
        new_layer = layer_cls(text_cfg, layer_idx=i)
        # Apply the parent model's init scheme (std=initializer_range etc.)
        new_layer.apply(model._init_weights)
        new_layer.to(device=device, dtype=dtype)
        # Zero output projections -> identity at init, gradients still flow.
        zeroed = _zero_output_projections(new_layer)
        new_layer_zeroed.append((i, zeroed))
        inner.layers.append(new_layer)

    return target_n, new_layer_zeroed
205
+
206
+
207
  def load_student(model_id, dtype, grad_ckpt, attn_impl):
208
  from transformers import AutoModelForCausalLM
209
+ log.info(f"Loading student: {model_id} (dtype={dtype})")
210
  model = AutoModelForCausalLM.from_pretrained(
211
  model_id,
212
  dtype=dtype,
 
229
  archs = list(getattr(cfg, "architectures", []) or [])
230
  arch = archs[0] if archs else ""
231
  is_multimodal = "ConditionalGeneration" in arch or "ImageText" in arch
232
+ log.info(f"Loading teacher: {model_id} (arch={arch}, multimodal={is_multimodal}, dtype={dtype})")
233
 
234
  if is_multimodal:
235
  from transformers import AutoModelForImageTextToText
 
441
 
442
  cfg = load_config(args.config)
443
 
444
+ accelerator = Accelerator(mixed_precision=cfg["train"]["mixed_precision"])
445
  set_seed(cfg["train"]["seed"])
446
 
447
  if accelerator.is_main_process:
448
  log.info(f"Loaded config from {args.config}")
449
  log.info(f"World size: {accelerator.num_processes}")
450
+ log.info(f"Mixed precision: {cfg['train']['mixed_precision']}")
451
 
452
  # ---- Tokenizer
453
  from transformers import AutoTokenizer
 
456
  tokenizer.pad_token = tokenizer.eos_token
457
  pad_id = tokenizer.pad_token_id
458
 
459
+ # ---- Models (separate dtypes per config)
460
+ student_dtype = parse_dtype(cfg["train"]["student_dtype"])
461
+ teacher_dtype = parse_dtype(cfg["train"]["teacher_dtype"])
462
  student = load_student(
463
  cfg["model"]["student"],
464
+ student_dtype,
465
  grad_ckpt=cfg["train"]["grad_checkpointing"],
466
  attn_impl=cfg["train"]["attn_implementation"],
467
  )
468
  teacher = load_teacher(
469
  cfg["model"]["teacher"],
470
+ teacher_dtype,
471
  attn_impl=cfg["train"]["attn_implementation"],
472
  )
473
 
474
+ # ---- Layer modifications: grow first, then zero (composable)
475
+ target_n = cfg["init"]["target_num_layers"]
476
+ cur_n = len(get_inner_with_layers(student).layers)
477
+ if target_n != cur_n:
478
+ new_n, new_zeroed = grow_layers(student, target_n)
479
+ if accelerator.is_main_process:
480
+ log.info(f"Grew student from {cur_n} -> {new_n} layers")
481
+ for idx, names in new_zeroed:
482
+ log.info(f" layer {idx}: zeroed {names}")
483
+
484
  zero_idx = cfg["init"]["zero_layers"]
485
  if zero_idx:
486
  n = zero_layers(student, zero_idx)
scripts/backup_to_hf.py CHANGED
@@ -18,8 +18,12 @@ INCLUDE = [
18
  "distill.py",
19
  "configs/base.toml",
20
  "configs/zero_14_17.toml",
 
 
 
21
  "configs/accelerate.yaml",
22
  "scripts/backup_to_hf.py",
 
23
  "pyproject.toml",
24
  "requirements.lock.txt",
25
  ]
 
18
  "distill.py",
19
  "configs/base.toml",
20
  "configs/zero_14_17.toml",
21
+ "configs/replicate_zero4.toml",
22
+ "configs/grow40_winning.toml",
23
+ "configs/grow40_simple.toml",
24
  "configs/accelerate.yaml",
25
  "scripts/backup_to_hf.py",
26
+ "scripts/run_sweep.sh",
27
  "pyproject.toml",
28
  "requirements.lock.txt",
29
  ]
scripts/run_sweep.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Sequential sweep runner.
#
# Each config grabs all 8 GPUs via accelerate, so they run back-to-back, not in
# parallel. Output goes to logs/<run>.log; the master log goes to logs/sweep_master.log.
# Reads HF_TOKEN, HUGGING_FACE_HUB_TOKEN, WANDB_API_KEY from the calling env.
#
# Launch in the background with (mkdir first: the redirect below needs logs/
# to exist before this script runs):
#   mkdir -p logs && nohup ./scripts/run_sweep.sh > logs/sweep_master.log 2>&1 &

# Deliberately no -e: a failed run must not abort the rest of the sweep.
set -uo pipefail

# Without set -e, a failed cd would silently keep running from whatever
# directory we were called in — guard it explicitly.
cd "$(dirname "$0")/.." || exit 1

readonly LAUNCHER=".venv/bin/accelerate"
if [[ ! -x "$LAUNCHER" ]]; then
  echo "error: $LAUNCHER not found or not executable" >&2
  exit 1
fi

CONFIGS=(
  "configs/replicate_zero4.toml"
  "configs/grow40_winning.toml"
  "configs/grow40_simple.toml"
)

readonly LOG_DIR="logs"
mkdir -p "$LOG_DIR" || exit 1

for cfg in "${CONFIGS[@]}"; do
  name="$(basename "$cfg" .toml)"
  log="$LOG_DIR/$name.log"
  echo ">>> [$(date '+%F %T')] starting $name -> $log"
  rc=0
  "$LAUNCHER" launch \
    --config_file configs/accelerate.yaml \
    distill.py \
    --config "$cfg" \
    > "$log" 2>&1 || rc=$?
  echo "<<< [$(date '+%F %T')] finished $name (exit=$rc)"
  if [[ $rc -ne 0 ]]; then
    echo "    last 20 lines of $log:"
    # tail -n is the portable spelling; bare 'tail -20' is obsolescent.
    tail -n 20 "$log" | sed 's/^/    /'
  fi
done

echo ">>> [$(date '+%F %T')] sweep complete"