add grow_layers, sweep configs (replicate_zero4, grow40_winning, grow40_simple), sweep runner
Browse files- configs/base.toml +5 -1
- configs/grow40_simple.toml +50 -0
- configs/grow40_winning.toml +50 -0
- configs/replicate_zero4.toml +49 -0
- configs/zero_14_17.toml +5 -1
- distill.py +107 -9
- scripts/backup_to_hf.py +4 -0
- scripts/run_sweep.sh +40 -0
configs/base.toml
CHANGED
|
@@ -28,6 +28,9 @@ samples_per_step = 4
|
|
| 28 |
max_steps = 5
|
| 29 |
grad_checkpointing = true
|
| 30 |
attn_implementation = "flash_attention_2"
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
[eval]
|
| 33 |
every_steps = 5
|
|
@@ -42,4 +45,5 @@ log_every = 1
|
|
| 42 |
output_dir = "./out/smoketest"
|
| 43 |
|
| 44 |
[init]
|
| 45 |
-
zero_layers
|
|
|
|
|
|
| 28 |
max_steps = 5
|
| 29 |
grad_checkpointing = true
|
| 30 |
attn_implementation = "flash_attention_2"
|
| 31 |
+
student_dtype = "bfloat16"
|
| 32 |
+
teacher_dtype = "bfloat16"
|
| 33 |
+
mixed_precision = "bf16"
|
| 34 |
|
| 35 |
[eval]
|
| 36 |
every_steps = 5
|
|
|
|
| 45 |
output_dir = "./out/smoketest"
|
| 46 |
|
| 47 |
[init]
|
| 48 |
+
zero_layers = []
|
| 49 |
+
target_num_layers = 32
|
configs/grow40_simple.toml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Grow student to 40 layers with the current (bf16, seq=640) hparams.
|
| 2 |
+
# Tests the architectural change in isolation without the winning hparams,
|
| 3 |
+
# so we can attribute any improvement.
|
| 4 |
+
|
| 5 |
+
[model]
|
| 6 |
+
teacher = "Qwen/Qwen3.5-35B-A3B"
|
| 7 |
+
student = "Troiaaa/m-6a3lnzvb"
|
| 8 |
+
tokenizer = "Qwen/Qwen3.5-35B-A3B"
|
| 9 |
+
|
| 10 |
+
[data]
|
| 11 |
+
dataset = "karpathy/climbmix-400b-shuffle"
|
| 12 |
+
text_field = "text"
|
| 13 |
+
min_chars = 2560
|
| 14 |
+
max_seq_len = 640
|
| 15 |
+
kl_start_pos = 128
|
| 16 |
+
seed = 42
|
| 17 |
+
shuffle_buffer = 10000
|
| 18 |
+
|
| 19 |
+
[train]
|
| 20 |
+
seed = 42
|
| 21 |
+
lr = 5.0e-7
|
| 22 |
+
schedule = "constant"
|
| 23 |
+
warmup_steps = 0
|
| 24 |
+
weight_decay = 0.0
|
| 25 |
+
grad_clip = 1.0
|
| 26 |
+
betas = [0.9, 0.95]
|
| 27 |
+
eps = 1.0e-8
|
| 28 |
+
samples_per_step = 8
|
| 29 |
+
max_steps = 2000
|
| 30 |
+
grad_checkpointing = true
|
| 31 |
+
attn_implementation = "flash_attention_2"
|
| 32 |
+
student_dtype = "bfloat16"
|
| 33 |
+
teacher_dtype = "bfloat16"
|
| 34 |
+
mixed_precision = "bf16"
|
| 35 |
+
|
| 36 |
+
[eval]
|
| 37 |
+
every_steps = 50
|
| 38 |
+
samples = 64
|
| 39 |
+
seed = 1234
|
| 40 |
+
|
| 41 |
+
[log]
|
| 42 |
+
wandb = true
|
| 43 |
+
wandb_project = "distil-subnet97"
|
| 44 |
+
wandb_run = "grow40_simple"
|
| 45 |
+
log_every = 1
|
| 46 |
+
output_dir = "./out/grow40_simple"
|
| 47 |
+
|
| 48 |
+
[init]
|
| 49 |
+
zero_layers = []
|
| 50 |
+
target_num_layers = 40
|
configs/grow40_winning.toml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Grow student to 40 layers AND apply the winning hparams from zero4_long.
|
| 2 |
+
# New layers (32-39) are appended at the end with output projections zeroed
|
| 3 |
+
# (identity at init, gradients still flow). No layer zeroing.
|
| 4 |
+
|
| 5 |
+
[model]
|
| 6 |
+
teacher = "Qwen/Qwen3.5-35B-A3B"
|
| 7 |
+
student = "Troiaaa/m-6a3lnzvb"
|
| 8 |
+
tokenizer = "Qwen/Qwen3.5-35B-A3B"
|
| 9 |
+
|
| 10 |
+
[data]
|
| 11 |
+
dataset = "karpathy/climbmix-400b-shuffle"
|
| 12 |
+
text_field = "text"
|
| 13 |
+
min_chars = 2560
|
| 14 |
+
max_seq_len = 2048
|
| 15 |
+
kl_start_pos = 128
|
| 16 |
+
seed = 6767
|
| 17 |
+
shuffle_buffer = 10000
|
| 18 |
+
|
| 19 |
+
[train]
|
| 20 |
+
seed = 6767
|
| 21 |
+
lr = 5.0e-7
|
| 22 |
+
schedule = "cosine"
|
| 23 |
+
warmup_steps = 100
|
| 24 |
+
weight_decay = 0.0
|
| 25 |
+
grad_clip = 1.0
|
| 26 |
+
betas = [0.9, 0.999]
|
| 27 |
+
eps = 1.0e-3
|
| 28 |
+
samples_per_step = 4
|
| 29 |
+
max_steps = 2000
|
| 30 |
+
grad_checkpointing = true
|
| 31 |
+
attn_implementation = "flash_attention_2"
|
| 32 |
+
student_dtype = "float32"
|
| 33 |
+
teacher_dtype = "bfloat16"
|
| 34 |
+
mixed_precision = "bf16"
|
| 35 |
+
|
| 36 |
+
[eval]
|
| 37 |
+
every_steps = 50
|
| 38 |
+
samples = 500
|
| 39 |
+
seed = 4242
|
| 40 |
+
|
| 41 |
+
[log]
|
| 42 |
+
wandb = true
|
| 43 |
+
wandb_project = "distil-subnet97"
|
| 44 |
+
wandb_run = "grow40_winning"
|
| 45 |
+
log_every = 1
|
| 46 |
+
output_dir = "./out/grow40_winning"
|
| 47 |
+
|
| 48 |
+
[init]
|
| 49 |
+
zero_layers = []
|
| 50 |
+
target_num_layers = 40
|
configs/replicate_zero4.toml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Replicates wandb run "zero4_long" (mepqfry1, eval kl 0.275).
|
| 2 |
+
# Same hparams as that run; same 4-layer zero (14-17). 32-layer student.
|
| 3 |
+
|
| 4 |
+
[model]
|
| 5 |
+
teacher = "Qwen/Qwen3.5-35B-A3B"
|
| 6 |
+
student = "Troiaaa/m-6a3lnzvb"
|
| 7 |
+
tokenizer = "Qwen/Qwen3.5-35B-A3B"
|
| 8 |
+
|
| 9 |
+
[data]
|
| 10 |
+
dataset = "karpathy/climbmix-400b-shuffle"
|
| 11 |
+
text_field = "text"
|
| 12 |
+
min_chars = 2560
|
| 13 |
+
max_seq_len = 2048
|
| 14 |
+
kl_start_pos = 128
|
| 15 |
+
seed = 6767
|
| 16 |
+
shuffle_buffer = 10000
|
| 17 |
+
|
| 18 |
+
[train]
|
| 19 |
+
seed = 6767
|
| 20 |
+
lr = 5.0e-7
|
| 21 |
+
schedule = "cosine"
|
| 22 |
+
warmup_steps = 100
|
| 23 |
+
weight_decay = 0.0
|
| 24 |
+
grad_clip = 1.0
|
| 25 |
+
betas = [0.9, 0.999]
|
| 26 |
+
eps = 1.0e-3
|
| 27 |
+
samples_per_step = 4
|
| 28 |
+
max_steps = 2000
|
| 29 |
+
grad_checkpointing = true
|
| 30 |
+
attn_implementation = "flash_attention_2"
|
| 31 |
+
student_dtype = "float32"
|
| 32 |
+
teacher_dtype = "bfloat16"
|
| 33 |
+
mixed_precision = "bf16"
|
| 34 |
+
|
| 35 |
+
[eval]
|
| 36 |
+
every_steps = 50
|
| 37 |
+
samples = 500
|
| 38 |
+
seed = 4242
|
| 39 |
+
|
| 40 |
+
[log]
|
| 41 |
+
wandb = true
|
| 42 |
+
wandb_project = "distil-subnet97"
|
| 43 |
+
wandb_run = "replicate_zero4"
|
| 44 |
+
log_every = 1
|
| 45 |
+
output_dir = "./out/replicate_zero4"
|
| 46 |
+
|
| 47 |
+
[init]
|
| 48 |
+
zero_layers = [14, 15, 16, 17]
|
| 49 |
+
target_num_layers = 32
|
configs/zero_14_17.toml
CHANGED
|
@@ -29,6 +29,9 @@ samples_per_step = 8
|
|
| 29 |
max_steps = 2000
|
| 30 |
grad_checkpointing = true
|
| 31 |
attn_implementation = "flash_attention_2"
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
[eval]
|
| 34 |
every_steps = 50
|
|
@@ -43,4 +46,5 @@ log_every = 1
|
|
| 43 |
output_dir = "./out/zero_14_17"
|
| 44 |
|
| 45 |
[init]
|
| 46 |
-
zero_layers
|
|
|
|
|
|
| 29 |
max_steps = 2000
|
| 30 |
grad_checkpointing = true
|
| 31 |
attn_implementation = "flash_attention_2"
|
| 32 |
+
student_dtype = "bfloat16"
|
| 33 |
+
teacher_dtype = "bfloat16"
|
| 34 |
+
mixed_precision = "bf16"
|
| 35 |
|
| 36 |
[eval]
|
| 37 |
every_steps = 50
|
|
|
|
| 46 |
output_dir = "./out/zero_14_17"
|
| 47 |
|
| 48 |
[init]
|
| 49 |
+
zero_layers = [14, 15, 16, 17]
|
| 50 |
+
target_num_layers = 32
|
distill.py
CHANGED
|
@@ -62,12 +62,26 @@ REQUIRED_KEYS = {
|
|
| 62 |
"max_steps",
|
| 63 |
"grad_checkpointing",
|
| 64 |
"attn_implementation",
|
|
|
|
|
|
|
|
|
|
| 65 |
),
|
| 66 |
"eval": ("every_steps", "samples", "seed"),
|
| 67 |
"log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
|
| 68 |
-
"init": ("zero_layers",),
|
| 69 |
}
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
def load_config(path):
|
| 73 |
with open(path, "rb") as f:
|
|
@@ -117,9 +131,82 @@ def zero_layers(model, layer_indices):
|
|
| 117 |
return n
|
| 118 |
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
def load_student(model_id, dtype, grad_ckpt, attn_impl):
|
| 121 |
from transformers import AutoModelForCausalLM
|
| 122 |
-
log.info(f"Loading student: {model_id}")
|
| 123 |
model = AutoModelForCausalLM.from_pretrained(
|
| 124 |
model_id,
|
| 125 |
dtype=dtype,
|
|
@@ -142,7 +229,7 @@ def load_teacher(model_id, dtype, attn_impl):
|
|
| 142 |
archs = list(getattr(cfg, "architectures", []) or [])
|
| 143 |
arch = archs[0] if archs else ""
|
| 144 |
is_multimodal = "ConditionalGeneration" in arch or "ImageText" in arch
|
| 145 |
-
log.info(f"Loading teacher: {model_id} (arch={arch}, multimodal={is_multimodal})")
|
| 146 |
|
| 147 |
if is_multimodal:
|
| 148 |
from transformers import AutoModelForImageTextToText
|
|
@@ -354,12 +441,13 @@ def main():
|
|
| 354 |
|
| 355 |
cfg = load_config(args.config)
|
| 356 |
|
| 357 |
-
accelerator = Accelerator(mixed_precision="
|
| 358 |
set_seed(cfg["train"]["seed"])
|
| 359 |
|
| 360 |
if accelerator.is_main_process:
|
| 361 |
log.info(f"Loaded config from {args.config}")
|
| 362 |
log.info(f"World size: {accelerator.num_processes}")
|
|
|
|
| 363 |
|
| 364 |
# ---- Tokenizer
|
| 365 |
from transformers import AutoTokenizer
|
|
@@ -368,21 +456,31 @@ def main():
|
|
| 368 |
tokenizer.pad_token = tokenizer.eos_token
|
| 369 |
pad_id = tokenizer.pad_token_id
|
| 370 |
|
| 371 |
-
# ---- Models
|
| 372 |
-
|
|
|
|
| 373 |
student = load_student(
|
| 374 |
cfg["model"]["student"],
|
| 375 |
-
|
| 376 |
grad_ckpt=cfg["train"]["grad_checkpointing"],
|
| 377 |
attn_impl=cfg["train"]["attn_implementation"],
|
| 378 |
)
|
| 379 |
teacher = load_teacher(
|
| 380 |
cfg["model"]["teacher"],
|
| 381 |
-
|
| 382 |
attn_impl=cfg["train"]["attn_implementation"],
|
| 383 |
)
|
| 384 |
|
| 385 |
-
# ---- Layer modifications
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
zero_idx = cfg["init"]["zero_layers"]
|
| 387 |
if zero_idx:
|
| 388 |
n = zero_layers(student, zero_idx)
|
|
|
|
| 62 |
"max_steps",
|
| 63 |
"grad_checkpointing",
|
| 64 |
"attn_implementation",
|
| 65 |
+
"student_dtype",
|
| 66 |
+
"teacher_dtype",
|
| 67 |
+
"mixed_precision",
|
| 68 |
),
|
| 69 |
"eval": ("every_steps", "samples", "seed"),
|
| 70 |
"log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
|
| 71 |
+
"init": ("zero_layers", "target_num_layers"),
|
| 72 |
}
|
| 73 |
|
| 74 |
+
# Config-string -> torch dtype lookup used for the `student_dtype` and
# `teacher_dtype` config keys. Only these two spellings are accepted.
DTYPE_MAP = dict(
    float32=torch.float32,
    bfloat16=torch.bfloat16,
)


def parse_dtype(s):
    """Translate a dtype name from the config into the matching torch dtype.

    Args:
        s: dtype name; must be one of the keys of ``DTYPE_MAP``.

    Returns:
        The ``torch.dtype`` registered under ``s``.

    Raises:
        ValueError: if ``s`` is not a key of ``DTYPE_MAP``.
    """
    resolved = DTYPE_MAP.get(s)
    if resolved is None:
        # No entry maps to None, so a None result always means "unknown name".
        raise ValueError(f"unknown dtype {s!r}; must be one of {list(DTYPE_MAP)}")
    return resolved
|
| 84 |
+
|
| 85 |
|
| 86 |
def load_config(path):
|
| 87 |
with open(path, "rb") as f:
|
|
|
|
| 131 |
return n
|
| 132 |
|
| 133 |
|
| 134 |
+
def _zero_output_projections(layer):
|
| 135 |
+
"""Zero out attention and MLP output projections so the layer is identity
|
| 136 |
+
at init while still allowing gradients to flow into o_proj/down_proj first
|
| 137 |
+
(and from there back into the rest of the layer's params after one step).
|
| 138 |
+
|
| 139 |
+
Knows about Qwen3.5 names: self_attn.o_proj (full attention),
|
| 140 |
+
linear_attn.out_proj (linear attention), mlp.down_proj.
|
| 141 |
+
"""
|
| 142 |
+
zeroed = []
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "o_proj"):
|
| 145 |
+
layer.self_attn.o_proj.weight.zero_()
|
| 146 |
+
zeroed.append("self_attn.o_proj")
|
| 147 |
+
if hasattr(layer, "linear_attn") and hasattr(layer.linear_attn, "out_proj"):
|
| 148 |
+
layer.linear_attn.out_proj.weight.zero_()
|
| 149 |
+
zeroed.append("linear_attn.out_proj")
|
| 150 |
+
if hasattr(layer, "mlp") and hasattr(layer.mlp, "down_proj"):
|
| 151 |
+
layer.mlp.down_proj.weight.zero_()
|
| 152 |
+
zeroed.append("mlp.down_proj")
|
| 153 |
+
return zeroed
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def grow_layers(model, target_n):
    """Grow the student to `target_n` decoder layers by appending new ones at the end.

    New layers are constructed via the existing decoder layer class with the
    model's own ``_init_weights``, moved to the device/dtype of the existing
    parameters, and then have their output projections zeroed so each new layer
    starts as the identity but is still trainable.

    Side effects: ``layer_types`` and ``num_hidden_layers`` on the (text) config
    are updated in place to describe the grown stack, and the new layers are
    appended to the model's layer list.

    Args:
        model: model whose decoder stack is located via ``get_inner_with_layers``.
        target_n: desired total number of decoder layers.

    Returns:
        ``(n_layers, zeroed)`` where ``n_layers`` is the resulting layer count and
        ``zeroed`` is a list of ``(layer_idx, [projection names])`` for every
        newly added layer. When the model already has ``target_n`` layers the
        result is ``(target_n, [])`` so callers can always unpack two values.

    Raises:
        ValueError: if ``target_n`` is smaller than the current layer count.
        RuntimeError: if the text config has no ``layer_types``, or the pattern
            period cannot be used to extend it.
    """
    inner = get_inner_with_layers(model)
    cur_n = len(inner.layers)
    if target_n == cur_n:
        # Same tuple shape as the growth path (previously a bare int).
        return cur_n, []
    if target_n < cur_n:
        raise ValueError(f"target_num_layers={target_n} < current {cur_n}; cannot shrink")

    # Locate the (text) config that the layers are built from. For multimodal
    # wrappers this lives at .text_config; for the dense student it's the same
    # object as model.config.
    cfg = model.config
    text_cfg = getattr(cfg, "text_config", cfg)

    # Extend layer_types by repeating the existing periodic pattern.
    if not hasattr(text_cfg, "layer_types") or not text_cfg.layer_types:
        raise RuntimeError("text config has no layer_types; cannot extend pattern")
    period = getattr(text_cfg, "full_attention_interval", 4)
    new_types = list(text_cfg.layer_types)
    if len(new_types) < target_n and (
        not isinstance(period, int) or period <= 0 or period > len(new_types)
    ):
        # Without this check a bad period surfaces as an opaque
        # ZeroDivisionError / IndexError / TypeError from the loop below.
        raise RuntimeError(
            f"cannot extend layer_types: period={period!r} is not a positive int "
            f"<= len(layer_types)={len(new_types)}"
        )
    while len(new_types) < target_n:
        # Position i copies the type found at i mod period in the first period.
        new_types.append(new_types[len(new_types) % period])
    text_cfg.layer_types = new_types
    text_cfg.num_hidden_layers = target_n
    if hasattr(cfg, "num_hidden_layers") and cfg is not text_cfg:
        cfg.num_hidden_layers = target_n

    # Construct new layers using the same class as the existing ones, placed on
    # the device/dtype of the first existing parameter.
    layer_cls = type(inner.layers[0])
    ref_param = next(inner.parameters())
    device = ref_param.device
    dtype = ref_param.dtype

    new_layer_zeroed = []
    for i in range(cur_n, target_n):
        new_layer = layer_cls(text_cfg, layer_idx=i)
        # Apply the parent model's init scheme (std=initializer_range etc.)
        new_layer.apply(model._init_weights)
        new_layer.to(device=device, dtype=dtype)
        # Zero output projections -> identity at init, gradients still flow
        zeroed = _zero_output_projections(new_layer)
        new_layer_zeroed.append((i, zeroed))
        inner.layers.append(new_layer)

    return target_n, new_layer_zeroed
|
| 205 |
+
|
| 206 |
+
|
| 207 |
def load_student(model_id, dtype, grad_ckpt, attn_impl):
|
| 208 |
from transformers import AutoModelForCausalLM
|
| 209 |
+
log.info(f"Loading student: {model_id} (dtype={dtype})")
|
| 210 |
model = AutoModelForCausalLM.from_pretrained(
|
| 211 |
model_id,
|
| 212 |
dtype=dtype,
|
|
|
|
| 229 |
archs = list(getattr(cfg, "architectures", []) or [])
|
| 230 |
arch = archs[0] if archs else ""
|
| 231 |
is_multimodal = "ConditionalGeneration" in arch or "ImageText" in arch
|
| 232 |
+
log.info(f"Loading teacher: {model_id} (arch={arch}, multimodal={is_multimodal}, dtype={dtype})")
|
| 233 |
|
| 234 |
if is_multimodal:
|
| 235 |
from transformers import AutoModelForImageTextToText
|
|
|
|
| 441 |
|
| 442 |
cfg = load_config(args.config)
|
| 443 |
|
| 444 |
+
accelerator = Accelerator(mixed_precision=cfg["train"]["mixed_precision"])
|
| 445 |
set_seed(cfg["train"]["seed"])
|
| 446 |
|
| 447 |
if accelerator.is_main_process:
|
| 448 |
log.info(f"Loaded config from {args.config}")
|
| 449 |
log.info(f"World size: {accelerator.num_processes}")
|
| 450 |
+
log.info(f"Mixed precision: {cfg['train']['mixed_precision']}")
|
| 451 |
|
| 452 |
# ---- Tokenizer
|
| 453 |
from transformers import AutoTokenizer
|
|
|
|
| 456 |
tokenizer.pad_token = tokenizer.eos_token
|
| 457 |
pad_id = tokenizer.pad_token_id
|
| 458 |
|
| 459 |
+
# ---- Models (separate dtypes per config)
|
| 460 |
+
student_dtype = parse_dtype(cfg["train"]["student_dtype"])
|
| 461 |
+
teacher_dtype = parse_dtype(cfg["train"]["teacher_dtype"])
|
| 462 |
student = load_student(
|
| 463 |
cfg["model"]["student"],
|
| 464 |
+
student_dtype,
|
| 465 |
grad_ckpt=cfg["train"]["grad_checkpointing"],
|
| 466 |
attn_impl=cfg["train"]["attn_implementation"],
|
| 467 |
)
|
| 468 |
teacher = load_teacher(
|
| 469 |
cfg["model"]["teacher"],
|
| 470 |
+
teacher_dtype,
|
| 471 |
attn_impl=cfg["train"]["attn_implementation"],
|
| 472 |
)
|
| 473 |
|
| 474 |
+
# ---- Layer modifications: grow first, then zero (composable)
|
| 475 |
+
target_n = cfg["init"]["target_num_layers"]
|
| 476 |
+
cur_n = len(get_inner_with_layers(student).layers)
|
| 477 |
+
if target_n != cur_n:
|
| 478 |
+
new_n, new_zeroed = grow_layers(student, target_n)
|
| 479 |
+
if accelerator.is_main_process:
|
| 480 |
+
log.info(f"Grew student from {cur_n} -> {new_n} layers")
|
| 481 |
+
for idx, names in new_zeroed:
|
| 482 |
+
log.info(f" layer {idx}: zeroed {names}")
|
| 483 |
+
|
| 484 |
zero_idx = cfg["init"]["zero_layers"]
|
| 485 |
if zero_idx:
|
| 486 |
n = zero_layers(student, zero_idx)
|
scripts/backup_to_hf.py
CHANGED
|
@@ -18,8 +18,12 @@ INCLUDE = [
|
|
| 18 |
"distill.py",
|
| 19 |
"configs/base.toml",
|
| 20 |
"configs/zero_14_17.toml",
|
|
|
|
|
|
|
|
|
|
| 21 |
"configs/accelerate.yaml",
|
| 22 |
"scripts/backup_to_hf.py",
|
|
|
|
| 23 |
"pyproject.toml",
|
| 24 |
"requirements.lock.txt",
|
| 25 |
]
|
|
|
|
| 18 |
"distill.py",
|
| 19 |
"configs/base.toml",
|
| 20 |
"configs/zero_14_17.toml",
|
| 21 |
+
"configs/replicate_zero4.toml",
|
| 22 |
+
"configs/grow40_winning.toml",
|
| 23 |
+
"configs/grow40_simple.toml",
|
| 24 |
"configs/accelerate.yaml",
|
| 25 |
"scripts/backup_to_hf.py",
|
| 26 |
+
"scripts/run_sweep.sh",
|
| 27 |
"pyproject.toml",
|
| 28 |
"requirements.lock.txt",
|
| 29 |
]
|
scripts/run_sweep.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Sequential sweep runner.
#
# Each config grabs all 8 GPUs via accelerate, so they run back-to-back, not in
# parallel. Output goes to logs/<run>.log; the master log goes to logs/sweep_master.log.
# Reads HF_TOKEN, HUGGING_FACE_HUB_TOKEN, WANDB_API_KEY from the calling env.
#
# A failing run does not abort the sweep (`set -e` is deliberately not used):
# the remaining configs still run, and the script exits non-zero at the end if
# any run failed.
#
# Launch in the background with (the calling shell opens the redirect target
# before this script runs, so logs/ has to exist first):
#   mkdir -p logs && nohup ./scripts/run_sweep.sh > logs/sweep_master.log 2>&1 &

set -uo pipefail
# Without `set -e` a failed cd would otherwise let the sweep run from the wrong
# directory, so bail out explicitly.
cd "$(dirname "$0")/.." || exit 1

CONFIGS=(
  "configs/replicate_zero4.toml"
  "configs/grow40_winning.toml"
  "configs/grow40_simple.toml"
)

LOG_DIR="logs"
mkdir -p "$LOG_DIR"

# Names of runs whose launcher exited non-zero; reported after the loop.
failed=()

for cfg in "${CONFIGS[@]}"; do
  name="$(basename "$cfg" .toml)"
  log="$LOG_DIR/$name.log"
  echo ">>> [$(date '+%F %T')] starting $name -> $log"
  .venv/bin/accelerate launch \
    --config_file configs/accelerate.yaml \
    distill.py \
    --config "$cfg" \
    > "$log" 2>&1
  rc=$?
  echo "<<< [$(date '+%F %T')] finished $name (exit=$rc)"
  if [[ $rc -ne 0 ]]; then
    failed+=("$name")
    echo " last 20 lines of $log:"
    tail -20 "$log" | sed 's/^/ /'
  fi
done

echo ">>> [$(date '+%F %T')] sweep complete"
if (( ${#failed[@]} > 0 )); then
  echo " failed runs: ${failed[*]}"
  exit 1
fi
|