jsantillana commited on
Commit
6848cb6
·
verified ·
1 Parent(s): 9274869

Upload folder using huggingface_hub

Browse files
training/aws_lora_base_tools_s3.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """SageMaker entrypoint: LoRA tool-use SFT para VectraYX Base 260M - S3 ONLY.
3
+
4
+ Igual que aws_lora_nano_tools_s3.py pero con checkpoint y config de Base 260M.
5
+
6
+ Hyperparameters via env:
7
+ CORPUS_NAME = "v3_bash" (default)
8
+ EPOCHS = "5"
9
+ LR = "2e-4"
10
+ LORA_RANK = "16"
11
+ LORA_ALPHA = "32"
12
+ SEED = "42"
13
+ """
14
+ import os, sys, json, subprocess, shutil
15
+ from pathlib import Path
16
+
17
+ S3_BUCKET = "s3://vectrayx-sagemaker-792811916323"
18
+ SM_OUTPUT = Path(os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
19
+ WD = Path("/opt/ml/code/work")
20
+ ENV = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}
21
+
22
+ # Base 260M — checkpoint post-P3 (phase3_last.pt)
23
+ BASE_CKPT = f"{S3_BUCKET}/checkpoints/vectrayx-base-20260506-1901/phase3_last.pt"
24
+ BASE_CFG = "base.json"
25
+ BASE_BATCH = 8
26
+ BASE_ACCUM = 8 # effective batch = 64
27
+
28
+
29
def die(m):
    """Print a fatal-error banner for ``m`` and exit the process with status 1.

    The exit is performed with ``sys.exit(1)``, so callers see ``SystemExit``.
    """
    banner = f"\n[FATAL] {m}"
    print(banner, flush=True)
    sys.exit(1)
30
+
31
+
32
def s3_download(src, dst):
    """Copy one S3 object to a local path with ``aws s3 cp``.

    Args:
        src: ``s3://`` URI of the object to fetch.
        dst: Local destination (str or Path); missing parent directories
            are created first.

    A non-zero exit status from the CLI is reported through ``die()``
    together with the captured stderr. On success the source URI and the
    downloaded size in MB are printed.
    """
    target = Path(dst)
    target.parent.mkdir(parents=True, exist_ok=True)

    copy_cmd = ["aws", "s3", "cp", src, str(target)]
    result = subprocess.run(copy_cmd, capture_output=True, text=True)
    if result.returncode != 0:
        die(f"s3 download failed: {src}\n{result.stderr}")

    size_mb = target.stat().st_size / 1e6
    print(f"[s3] ✓ {src} ({size_mb:.1f}MB)", flush=True)
40
+
41
+
42
def sh(cmd, cwd=None):
    """Echo and run a shell command line, aborting via ``die()`` if it fails.

    Args:
        cmd: Command line, executed through the shell.
        cwd: Directory to run in; falls back to ``WD`` when falsy.

    The child process inherits the current environment with the entries of
    ``ENV`` layered on top.
    """
    print(f"$ {cmd}", flush=True)

    child_env = dict(os.environ)
    child_env.update(ENV)
    workdir = str(cwd if cwd else WD)

    completed = subprocess.run(cmd, shell=True, env=child_env, cwd=workdir)
    if completed.returncode:
        die(f"Failed: {cmd}")
48
+
49
+
50
def main():
    """Run the LoRA tool-use SFT job for the Base model end to end.

    Hyperparameters are read from the environment (``CORPUS_NAME``,
    ``EPOCHS``, ``LR``, ``LORA_RANK``, ``LORA_ALPHA``, ``SEED``). The job
    then installs dependencies, downloads code / tokenizer / checkpoint /
    corpus / eval data from S3, runs ``training_v2.train.finetune_lora_tools``,
    copies the resulting artifacts to ``SM_OUTPUT``, runs the benchmark and
    writes ``manifest.json``.
    """
    corpus_name = os.environ.get("CORPUS_NAME", "v3_bash")
    epochs = int(os.environ.get("EPOCHS", "5"))
    lr = float(os.environ.get("LR", "2e-4"))
    lora_rank = int(os.environ.get("LORA_RANK", "16"))
    lora_alpha = float(os.environ.get("LORA_ALPHA", "32"))
    seed = int(os.environ.get("SEED", "42"))

    WD.mkdir(parents=True, exist_ok=True)
    SM_OUTPUT.mkdir(parents=True, exist_ok=True)

    print(f"[config] model=base corpus={corpus_name} epochs={epochs} lr={lr} "
          f"lora_rank={lora_rank} lora_alpha={lora_alpha} seed={seed}", flush=True)

    # 1. Dependencies
    subprocess.run([sys.executable, "-m", "pip", "install", "-q",
                    "sentencepiece", "tokenizers"], check=True)

    # 2. training_v2 code (includes the corrected finetune_lora_tools.py and utils.py)
    print("[code] Downloading training_v2 from S3...", flush=True)
    subprocess.run(["aws", "s3", "cp",
                    f"{S3_BUCKET}/code/training_v2.tar.gz",
                    "/tmp/tv2.tar.gz"], check=True)
    sh("tar xzf /tmp/tv2.tar.gz", cwd=WD)
    print("[code] ✓ training_v2 extracted", flush=True)

    # 3. Tokenizer (same as Nano — BPE 16384)
    s3_download(f"{S3_BUCKET}/tokenizers/vectrayx_bpe.model", WD / "tokenizer.model")

    # 4. Base 260M checkpoint (post-P3, pre-SFT)
    s3_download(BASE_CKPT, WD / "resume.pt")

    # 5. Tool-use corpus
    s3_download(f"{S3_BUCKET}/training-data/tool_sft_{corpus_name}.jsonl",
                WD / "tool_sft.jsonl")

    # 6. Eval data — b4_tooluse_v2 with basic bash (60%)
    eval_dir = WD / "eval_data"
    for b in ["b1_cveqa", "b2_classification", "b3_commands", "b5_conversational"]:
        try:
            s3_download(f"{S3_BUCKET}/eval-data/{b}.jsonl",
                        eval_dir / f"{b}.jsonl")
        except (Exception, SystemExit):
            # s3_download() reports failure through die(), i.e. SystemExit,
            # which is NOT an Exception subclass. Catching it explicitly is
            # what makes these files optional; die() has already printed
            # the underlying aws error at this point.
            print(f"[s3] skip (optional) {b}.jsonl", flush=True)
    # B4 is required; it is stored under the un-versioned file name.
    s3_download(f"{S3_BUCKET}/eval-data/b4_tooluse_v2.jsonl",
                eval_dir / "b4_tooluse.jsonl")

    # 7. LoRA fine-tune on top of Base 260M
    out_dir = WD / "checkpoints/lora_tool_sft"
    sh(f"{sys.executable} -m training_v2.train.finetune_lora_tools "
       f"--config {WD}/training_v2/configs/{BASE_CFG} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--resume {WD}/resume.pt "
       f"--tool-corpus {WD}/tool_sft.jsonl "
       f"--out {out_dir} "
       f"--lora-rank {lora_rank} "
       f"--lora-alpha {lora_alpha} "
       f"--batch-size {BASE_BATCH} "
       f"--grad-accum {BASE_ACCUM} "
       f"--epochs {epochs} "
       f"--lr {lr} "
       f"--seed {seed}")

    # 8. Copy artifacts to the SageMaker output directory
    shutil.copy(out_dir / "final.pt", SM_OUTPUT / "final.pt")
    shutil.copy(out_dir / "final_lora_only.pt", SM_OUTPUT / "final_lora_only.pt")
    shutil.copy(WD / f"training_v2/configs/{BASE_CFG}", SM_OUTPUT / "model_config.json")

    # 9. Benchmark B1–B5
    sh(f"{sys.executable} -m training_v2.eval.benchmark "
       f"--checkpoint {out_dir}/final.pt "
       f"--config {WD}/training_v2/configs/{BASE_CFG} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--data-dir {eval_dir} "
       f"--out {SM_OUTPUT}/bench_lora_tools.json")

    # 10. Manifest describing the run
    manifest = {
        "model": "base",
        "method": "lora",
        "corpus": corpus_name,
        "lora_rank": lora_rank,
        "lora_alpha": lora_alpha,
        "epochs": epochs,
        "lr": lr,
        "seed": seed,
        "resume_from": BASE_CKPT,
        "effective_batch": BASE_BATCH * BASE_ACCUM,
    }
    (SM_OUTPUT / "manifest.json").write_text(json.dumps(manifest, indent=2))
    print(f"[done] LoRA tool-SFT Base 260M → {SM_OUTPUT}", flush=True)
141
+
142
+
143
+ if __name__ == "__main__":
144
+ main()
training/aws_lora_nano_tools_s3.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """SageMaker entrypoint: LoRA tool-use SFT para VectraYX Nano - S3 ONLY.
3
+
4
+ Hyperparameters via env:
5
+ CORPUS_NAME = "v3_bash" (default)
6
+ EPOCHS = "5"
7
+ LR = "2e-4"
8
+ LORA_RANK = "16"
9
+ LORA_ALPHA = "32"
10
+ SEED = "42"
11
+ """
12
+ import os, sys, json, subprocess, shutil
13
+ from pathlib import Path
14
+
15
+ S3_BUCKET = "s3://vectrayx-sagemaker-792811916323"
16
+ SM_OUTPUT = Path(os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
17
+ WD = Path("/opt/ml/code/work")
18
+ ENV = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}
19
+
20
+ # Nano config — checkpoint post-SFT mixto
21
+ NANO_CKPT = f"{S3_BUCKET}/checkpoints/nano_sft_v5.pt"
22
+ NANO_CFG = "nano.json"
23
+ NANO_BATCH = 16
24
+ NANO_ACCUM = 4 # effective batch = 64
25
+
26
+
27
def die(m):
    """Report ``m`` as a fatal error on stdout, then exit with status 1."""
    text = f"\n[FATAL] {m}"
    print(text, flush=True)
    sys.exit(1)
28
+
29
+
30
def s3_download(src, dst):
    """Fetch ``src`` (an ``s3://`` URI) into the local file ``dst``.

    Parent directories of ``dst`` are created as needed and the transfer is
    delegated to ``aws s3 cp``. If the CLI exits non-zero, ``die()`` is
    called with the captured stderr; otherwise a one-line summary with the
    file size in MB is printed.
    """
    local_path = Path(dst)
    local_path.parent.mkdir(parents=True, exist_ok=True)

    proc = subprocess.run(
        ["aws", "s3", "cp", src, str(local_path)],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        die(f"s3 download failed: {src}\n{proc.stderr}")

    megabytes = local_path.stat().st_size / 1e6
    print(f"[s3] ✓ {src} ({megabytes:.1f}MB)", flush=True)
38
+
39
+
40
def sh(cmd, cwd=None):
    """Run ``cmd`` through the shell and abort the job if it fails.

    The command is echoed first. It runs in ``cwd`` (``WD`` when ``cwd`` is
    falsy) with ``os.environ`` extended by ``ENV``. A non-zero exit status
    is turned into a ``die()`` call.
    """
    print(f"$ {cmd}", flush=True)

    environment = {**os.environ, **ENV}
    run_dir = cwd or WD
    outcome = subprocess.run(cmd, shell=True, env=environment, cwd=str(run_dir))

    if outcome.returncode == 0:
        return
    die(f"Failed: {cmd}")
46
+
47
+
48
def main():
    """Run the LoRA tool-use SFT job for the Nano model end to end.

    Hyperparameters are read from the environment (``CORPUS_NAME``,
    ``EPOCHS``, ``LR``, ``LORA_RANK``, ``LORA_ALPHA``, ``SEED``). The job
    then installs dependencies, downloads code / tokenizer / checkpoint /
    corpus / eval data from S3, runs ``training_v2.train.finetune_lora_tools``,
    copies the resulting artifacts to ``SM_OUTPUT``, runs the benchmark and
    writes ``manifest.json``.
    """
    corpus_name = os.environ.get("CORPUS_NAME", "v3_bash")
    epochs = int(os.environ.get("EPOCHS", "5"))
    lr = float(os.environ.get("LR", "2e-4"))
    lora_rank = int(os.environ.get("LORA_RANK", "16"))
    lora_alpha = float(os.environ.get("LORA_ALPHA", "32"))
    seed = int(os.environ.get("SEED", "42"))

    WD.mkdir(parents=True, exist_ok=True)
    SM_OUTPUT.mkdir(parents=True, exist_ok=True)

    print(f"[config] corpus={corpus_name} epochs={epochs} lr={lr} "
          f"lora_rank={lora_rank} lora_alpha={lora_alpha} seed={seed}", flush=True)

    # 1. Dependencies
    subprocess.run([sys.executable, "-m", "pip", "install", "-q",
                    "sentencepiece", "tokenizers"], check=True)

    # 2. training_v2 code
    print("[code] Downloading training_v2 from S3...", flush=True)
    subprocess.run(["aws", "s3", "cp",
                    f"{S3_BUCKET}/code/training_v2.tar.gz",
                    "/tmp/tv2.tar.gz"], check=True)
    sh("tar xzf /tmp/tv2.tar.gz", cwd=WD)
    print("[code] ✓ training_v2 extracted", flush=True)

    # 3. Tokenizer
    s3_download(f"{S3_BUCKET}/tokenizers/vectrayx_bpe.model", WD / "tokenizer.model")

    # 4. Nano starting checkpoint (post mixed SFT)
    s3_download(NANO_CKPT, WD / "resume.pt")

    # 5. Tool-use corpus
    s3_download(f"{S3_BUCKET}/training-data/tool_sft_{corpus_name}.jsonl",
                WD / "tool_sft.jsonl")

    # 6. Eval data — b4_tooluse_v2 has 50 questions with basic bash
    eval_dir = WD / "eval_data"
    for b in ["b1_cveqa", "b2_classification", "b3_commands",
              "b5_conversational"]:
        try:
            s3_download(f"{S3_BUCKET}/eval-data/{b}.jsonl",
                        eval_dir / f"{b}.jsonl")
        except (Exception, SystemExit):
            # s3_download() reports failure through die(), i.e. SystemExit,
            # which is NOT an Exception subclass. Catching it explicitly is
            # what makes these files optional; die() has already printed
            # the underlying aws error at this point.
            print(f"[s3] skip (optional) {b}.jsonl", flush=True)
    # B4 v2 — extended benchmark with basic bash (60%) + MCP (40%).
    # Stored under the un-versioned name so that benchmark.py finds it.
    s3_download(f"{S3_BUCKET}/eval-data/b4_tooluse_v2.jsonl",
                eval_dir / "b4_tooluse.jsonl")

    # 7. LoRA fine-tune
    out_dir = WD / "checkpoints/lora_tool_sft"
    sh(f"{sys.executable} -m training_v2.train.finetune_lora_tools "
       f"--config {WD}/training_v2/configs/{NANO_CFG} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--resume {WD}/resume.pt "
       f"--tool-corpus {WD}/tool_sft.jsonl "
       f"--out {out_dir} "
       f"--lora-rank {lora_rank} "
       f"--lora-alpha {lora_alpha} "
       f"--batch-size {NANO_BATCH} "
       f"--grad-accum {NANO_ACCUM} "
       f"--epochs {epochs} "
       f"--lr {lr} "
       f"--seed {seed}")

    # 8. Copy artifacts to the SageMaker output directory
    shutil.copy(out_dir / "final.pt", SM_OUTPUT / "final.pt")
    shutil.copy(out_dir / "final_lora_only.pt", SM_OUTPUT / "final_lora_only.pt")
    shutil.copy(WD / f"training_v2/configs/{NANO_CFG}", SM_OUTPUT / "model_config.json")

    # 9. Benchmark B1–B5 (uses the merged final.pt)
    sh(f"{sys.executable} -m training_v2.eval.benchmark "
       f"--checkpoint {out_dir}/final.pt "
       f"--config {WD}/training_v2/configs/{NANO_CFG} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--data-dir {eval_dir} "
       f"--out {SM_OUTPUT}/bench_lora_tools.json")

    # 10. Manifest describing the run
    manifest = {
        "model": "nano",
        "method": "lora",
        "corpus": corpus_name,
        "lora_rank": lora_rank,
        "lora_alpha": lora_alpha,
        "epochs": epochs,
        "lr": lr,
        "seed": seed,
        "resume_from": NANO_CKPT,
        "effective_batch": NANO_BATCH * NANO_ACCUM,
    }
    (SM_OUTPUT / "manifest.json").write_text(json.dumps(manifest, indent=2))
    print(f"[done] LoRA tool-SFT Nano → {SM_OUTPUT}", flush=True)
141
+
142
+
143
+ if __name__ == "__main__":
144
+ main()
training/aws_tool_sft_train_s3.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """SageMaker entrypoint: tool-use mini-SFT focalizado (Nano o Base) - S3 ONLY.
3
+
4
+ Hyperparameters via env:
5
+ MODEL = "nano" | "base"
6
+ CORPUS_NAME = "v1" | "v2"
7
+ EPOCHS = "2"
8
+ LR = "1e-5"
9
+ SEED = "42"
10
+ """
11
+ import os, sys, json, subprocess, shutil
12
+ from pathlib import Path
13
+
14
+ S3_BUCKET = "s3://vectrayx-sagemaker-792811916323"
15
+ SM_OUTPUT = Path(os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
16
+ WD = Path("/opt/ml/code/work")
17
+ ENV = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}
18
+
19
+ MODEL_CFG = {
20
+ "nano": {
21
+ "config": "nano.json",
22
+ "ckpt_src": f"{S3_BUCKET}/checkpoints/nano_sft_v5.pt",
23
+ "batch": 16,
24
+ "accum": 4,
25
+ },
26
+ "base": {
27
+ "config": "base.json",
28
+ "ckpt_src": f"{S3_BUCKET}/checkpoints/vectrayx-base-20260506-1901/phase3_last.pt",
29
+ "batch": 8,
30
+ "accum": 8,
31
+ },
32
+ }
33
+
34
+
35
def die(m):
    """Emit a ``[FATAL]`` line carrying ``m`` and stop the process (exit status 1)."""
    fatal_line = f"\n[FATAL] {m}"
    print(fatal_line, flush=True)
    sys.exit(1)
36
+
37
+
38
def s3_download(src, dst):
    """Download from S3 using the AWS CLI.

    Args:
        src: ``s3://`` URI to fetch.
        dst: Local destination path; parent directories are created.

    Calls ``die()`` with the CLI's stderr when ``aws s3 cp`` exits non-zero,
    and prints the downloaded size in MB on success.
    """
    out_path = Path(dst)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    argv = ["aws", "s3", "cp", src, str(out_path)]
    res = subprocess.run(argv, capture_output=True, text=True)
    if res.returncode != 0:
        die(f"s3 download failed: {src}\n{res.stderr}")

    mb = out_path.stat().st_size / 1e6
    print(f"[s3] ✓ {src} ({mb:.1f}MB)", flush=True)
47
+
48
+
49
def sh(cmd, cwd=None):
    """Echo ``cmd``, run it through the shell, and ``die()`` on a non-zero exit.

    The command runs in ``cwd`` (or ``WD`` when ``cwd`` is falsy) with the
    current environment plus the overrides in ``ENV``.
    """
    print(f"$ {cmd}", flush=True)
    merged_env = dict(os.environ)
    merged_env.update(ENV)
    finished = subprocess.run(
        cmd,
        shell=True,
        env=merged_env,
        cwd=str(cwd or WD),
    )
    if finished.returncode != 0:
        die(f"Failed: {cmd}")
53
+
54
+
55
def main():
    """Run the focused tool-use mini-SFT job (Nano or Base) end to end.

    Hyperparameters are read from the environment (``MODEL``,
    ``CORPUS_NAME``, ``EPOCHS``, ``LR``, ``SEED``); ``MODEL`` selects an
    entry of ``MODEL_CFG``. The job installs dependencies, downloads code /
    tokenizer / checkpoint / corpus / eval data from S3, runs
    ``training_v2.train.finetune_tools``, copies the final checkpoint and
    model config to ``SM_OUTPUT``, runs the benchmark and writes
    ``manifest.json``.
    """
    model_name = os.environ.get("MODEL", "nano")
    corpus_name = os.environ.get("CORPUS_NAME", "v1")
    epochs = int(os.environ.get("EPOCHS", "2"))
    lr = float(os.environ.get("LR", "1e-5"))
    seed = int(os.environ.get("SEED", "42"))

    if model_name not in MODEL_CFG:
        die(f"Unknown MODEL={model_name}")
    cfg = MODEL_CFG[model_name]

    WD.mkdir(parents=True, exist_ok=True)
    SM_OUTPUT.mkdir(parents=True, exist_ok=True)

    # 1. Dependencies
    subprocess.run([sys.executable, "-m", "pip", "install", "-q",
                    "sentencepiece", "tokenizers"], check=True)

    # 2. Download and extract training_v2 code
    print("[code] Downloading training_v2 from S3...", flush=True)
    subprocess.run(["aws", "s3", "cp",
                    # Built from S3_BUCKET so the bucket name lives in one place.
                    f"{S3_BUCKET}/code/training_v2.tar.gz",
                    "/tmp/tv2.tar.gz"], check=True)
    sh("tar xzf /tmp/tv2.tar.gz", cwd=WD)
    print(f"[code] ✓ training_v2 extracted to {WD}", flush=True)

    # 3. Tokenizer
    s3_download(f"{S3_BUCKET}/tokenizers/vectrayx_bpe.model", WD / "tokenizer.model")

    # 4. Starting checkpoint
    s3_download(cfg["ckpt_src"], WD / "resume.pt")

    # 5. Tool SFT corpus
    s3_download(f"{S3_BUCKET}/training-data/tool_sft_{corpus_name}.jsonl",
                WD / "tool_sft.jsonl")

    # 6. Eval data
    eval_dir = WD / "eval_data"
    for b in ["b1_cveqa", "b2_classification", "b3_commands",
              "b4_tooluse", "b5_conversational"]:
        try:
            s3_download(f"{S3_BUCKET}/eval-data/{b}.jsonl",
                        eval_dir / f"{b}.jsonl")
        except (Exception, SystemExit):
            # s3_download() fails through die(), i.e. SystemExit, so it must
            # be named explicitly for the file to stay optional. Unlike a
            # bare ``except:`` this lets KeyboardInterrupt propagate.
            print(f"[s3] skip (optional) {b}.jsonl", flush=True)

    # 7. Focused mini-SFT
    out_dir = WD / "checkpoints/tool_sft"
    sh(f"{sys.executable} -m training_v2.train.finetune_tools "
       f"--config {WD}/training_v2/configs/{cfg['config']} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--resume {WD}/resume.pt "
       f"--tool-corpus {WD}/tool_sft.jsonl "
       f"--out {out_dir} "
       f"--batch-size {cfg['batch']} --grad-accum {cfg['accum']} "
       f"--epochs {epochs} --lr {lr} --seed {seed}")

    # 8. Copy the final checkpoint and its model config
    shutil.copy(out_dir / "final.pt", SM_OUTPUT / "final.pt")
    shutil.copy(WD / f"training_v2/configs/{cfg['config']}",
                SM_OUTPUT / "model_config.json")

    # 9. Benchmark B1–B5
    sh(f"{sys.executable} -m training_v2.eval.benchmark "
       f"--checkpoint {out_dir}/final.pt "
       f"--config {WD}/training_v2/configs/{cfg['config']} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--data-dir {eval_dir} "
       f"--out {SM_OUTPUT}/bench_tool_sft.json")

    # 10. Manifest describing the run
    manifest = {
        "model": model_name,
        "corpus": corpus_name,
        "epochs": epochs, "lr": lr, "seed": seed,
        "resume_from": cfg["ckpt_src"],
    }
    (SM_OUTPUT / "manifest.json").write_text(json.dumps(manifest, indent=2))
    print(f"[done] tool-SFT {model_name}/{corpus_name}/seed={seed} → {SM_OUTPUT}", flush=True)
133
+
134
+
135
+ if __name__ == "__main__":
136
+ main()
training/finetune_lora_tools.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LoRA tool-use SFT para VectraYX Nano.
2
+
3
+ Aplica LoRA sobre las proyecciones de atención (wq, wk, wv, wo) del modelo
4
+ custom VectraYXNano. Congela todos los pesos base y solo entrena los adaptadores.
5
+
6
+ Ventaja sobre full fine-tune:
7
+ - Solo ~0.5% de parámetros entrenables (~200K vs 42M)
8
+ - Menos riesgo de catastrophic forgetting en B1/B2/B5
9
+ - SmolLM2-135M logra B4=0.16 con LoRA — probamos si Nano puede hacer lo mismo
10
+
11
+ Run example:
12
+ python -m training_v2.train.finetune_lora_tools \
13
+ --config training_v2/configs/nano.json \
14
+ --tokenizer models/vectrayx_bpe.model \
15
+ --resume checkpoints/nano_sft_v5.pt \
16
+ --tool-corpus corpus/tool_sft_v2_simple.jsonl \
17
+ --out checkpoints/nano_lora_tools \
18
+ --lora-rank 16 --lora-alpha 32 \
19
+ --batch-size 16 --grad-accum 4 --epochs 5 --lr 2e-4
20
+ """
21
+
22
+ import argparse
23
+ import json
24
+ import math
25
+ import sys
26
+ import time
27
+ from pathlib import Path
28
+
29
+ import numpy as np
30
+ import sentencepiece as spm
31
+ import torch
32
+ import torch.nn as nn
33
+ from torch.utils.data import DataLoader
34
+
35
+ ROOT = Path(__file__).resolve().parents[2]
36
+ sys.path.insert(0, str(ROOT))
37
+
38
+ from training_v2.data.sft_dataset import SFTDataset
39
+ from training_v2.model.transformer import VectraYXNano, ModelConfig
40
+ from training_v2.train.utils import (
41
+ cosine_with_warmup, log_jsonl,
42
+ )
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # LoRA implementation
47
+ # ---------------------------------------------------------------------------
48
+
49
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank update.

    The effective weight is ``W' = W + (B @ A) * scale`` with
    ``scale = alpha / rank``. The wrapped layer's own parameters are frozen;
    only ``lora_A`` (rank x in_features) and ``lora_B`` (out_features x rank)
    require gradients.

    Args:
        linear: Layer to wrap; kept as ``self.linear``.
        rank: Rank of the update; must be a positive integer.
        alpha: Scaling numerator.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: float):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"LoRA rank must be a positive integer, got {rank}")

        self.linear = linear  # base weights — FROZEN
        self.rank = rank
        self.scale = alpha / rank

        in_f = linear.in_features
        out_f = linear.out_features
        base_weight = linear.weight

        # A: Kaiming-uniform init, B: zeros, so the wrapped layer initially
        # computes exactly what the base layer does.
        # A is drawn on the CPU in the default dtype (same random stream as
        # before) and only then moved to the base layer's device/dtype.
        # Matching the base layer avoids dtype/device mismatches in forward().
        init_a = torch.empty(rank, in_f)
        nn.init.kaiming_uniform_(init_a, a=math.sqrt(5))
        self.lora_A = nn.Parameter(
            init_a.to(device=base_weight.device, dtype=base_weight.dtype))
        self.lora_B = nn.Parameter(
            torch.zeros(out_f, rank,
                        device=base_weight.device, dtype=base_weight.dtype))

        # Freeze the base weights
        for p in self.linear.parameters():
            p.requires_grad_(False)

    def forward(self, x):
        """Return ``linear(x) + ((x @ A^T) @ B^T) * scale``."""
        base = self.linear(x)
        # Kept as a safety net: a no-op when the adapters already live on
        # x's device.
        lora = (x @ self.lora_A.to(x.device).T) @ self.lora_B.to(x.device).T
        return base + lora * self.scale
75
+
76
+
77
def inject_lora(model: nn.Module, rank: int, alpha: float,
                target_modules=("wq", "wk", "wv", "wo")) -> int:
    """Replace the targeted ``nn.Linear`` attributes of ``model`` with LoRA wrappers.

    Every sub-module is inspected; each attribute named in ``target_modules``
    that holds an ``nn.Linear`` is swapped for ``LoRALinear(original, rank,
    alpha)``. Afterwards every parameter whose name contains neither
    ``lora_A`` nor ``lora_B`` is frozen, and a summary line is printed.

    Args:
        model: Model to modify in place.
        rank: LoRA rank forwarded to ``LoRALinear``.
        alpha: LoRA alpha forwarded to ``LoRALinear``.
        target_modules: Attribute names to wrap.

    Returns:
        The number of trainable parameters after injection.
    """
    replaced = 0
    # Snapshot the module list first: setattr() below mutates the tree that
    # named_modules() walks lazily.
    for _, module in list(model.named_modules()):
        for attr_name in target_modules:
            original = getattr(module, attr_name, None)
            if isinstance(original, nn.Linear):
                setattr(module, attr_name, LoRALinear(original, rank, alpha))
                replaced += 1

    # Freeze everything except the LoRA adapters
    for name, param in model.named_parameters():
        if "lora_A" not in name and "lora_B" not in name:
            param.requires_grad_(False)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    # A model without parameters would otherwise divide by zero here.
    pct = trainable / total * 100 if total else 0.0
    print(f"[lora] Inyectado en {replaced} módulos | "
          f"Entrenables: {trainable/1e3:.1f}K / {total/1e6:.2f}M "
          f"({pct:.2f}%)")
    return trainable
103
+
104
+
105
def save_lora_checkpoint(path: Path, model: nn.Module, optimizer, step: int,
                         extra: dict = None):
    """Save only the LoRA weights (not the base model) to ``path``.

    The file holds ``lora_state_dict`` (every state-dict entry whose key
    contains ``lora_A`` or ``lora_B``), ``optimizer_state_dict`` (``None``
    when no optimizer is given), ``step`` and any items of ``extra``.

    Args:
        path: Destination file; a ``str`` is accepted as well as a ``Path``.
        model: Model whose LoRA tensors are saved.
        optimizer: Optimizer to snapshot, or ``None``.
        step: Training step stored in the checkpoint.
        extra: Optional additional top-level entries.
    """
    # Normalise so that the size report below also works for str paths.
    path = Path(path)
    lora_state = {k: v for k, v in model.state_dict().items()
                  if "lora_A" in k or "lora_B" in k}
    torch.save({
        "lora_state_dict": lora_state,
        "optimizer_state_dict": optimizer.state_dict() if optimizer else None,
        "step": step,
        **(extra or {}),
    }, path)
    print(f"[save] LoRA checkpoint → {path} ({path.stat().st_size/1e6:.1f}MB)")
117
+
118
+
119
def load_lora_checkpoint(path: Path, model: nn.Module, optimizer=None,
                         map_location="cpu"):
    """Load LoRA weights from ``path`` into ``model``.

    Args:
        path: Checkpoint file containing a ``lora_state_dict`` entry.
        model: Model that already contains the LoRA parameters.
        optimizer: If given and the checkpoint has a non-empty
            ``optimizer_state_dict``, that state is restored as well.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The ``step`` stored in the checkpoint, or 0 when absent.
    """
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # from a trusted source.
    ckpt = torch.load(path, map_location=map_location)
    lora_state = ckpt["lora_state_dict"]
    missing, unexpected = model.load_state_dict(lora_state, strict=False)

    # With strict=False every base weight is reported as "missing" because
    # the checkpoint only carries adapters; only absent LoRA keys are
    # meaningful. Unexpected keys were not loaded, so they are not counted.
    missing_lora = [k for k in missing if "lora_A" in k or "lora_B" in k]
    loaded = len(lora_state) - len(unexpected)
    print(f"[load] LoRA: {loaded} keys loaded, "
          f"{len(missing_lora)} missing, {len(unexpected)} unexpected")
    if unexpected:
        print(f"[load] WARNING: checkpoint keys not present in model "
              f"(was LoRA injected with the same targets?): {list(unexpected)[:5]}")

    if optimizer is not None and ckpt.get("optimizer_state_dict"):
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt.get("step", 0)
130
+
131
+
132
+ # ---------------------------------------------------------------------------
133
+ # Main
134
+ # ---------------------------------------------------------------------------
135
+
136
def _merged_state_dict(model):
    """Return ``model``'s state dict with every LoRA adapter folded into its base weight.

    For each ``LoRALinear`` at path ``P``: ``P.linear.weight`` is emitted as
    ``P.weight`` holding ``W + (B @ A) * scale``, ``P.linear.bias`` as
    ``P.bias``, and ``P.lora_A`` / ``P.lora_B`` are dropped. All other
    entries are copied through unchanged.
    """
    lora_modules = {name: mod for name, mod in model.named_modules()
                    if isinstance(mod, LoRALinear)}

    merged = {}
    # state_dict() (rather than named_parameters()) also carries buffers and
    # both names of tied weights, so the merged file mirrors the key set of
    # an un-adapted model.
    for key, tensor in model.state_dict().items():
        owner, _, leaf = key.rpartition(".")
        if owner in lora_modules and leaf.startswith("lora_"):
            continue  # adapters are folded into the base weight instead
        parent, _, inner = owner.rpartition(".")
        if inner == "linear" and parent in lora_modules:
            if leaf == "weight":
                lora_mod = lora_modules[parent]
                delta = (lora_mod.lora_B.data @ lora_mod.lora_A.data) * lora_mod.scale
                # Store under the clean name (without ".linear")
                merged[f"{parent}.weight"] = tensor + delta
                continue
            if leaf == "bias":
                merged[f"{parent}.bias"] = tensor
                continue
        merged[key] = tensor
    return merged


def main():
    """Command-line entry point: LoRA fine-tune a base checkpoint on a tool-use corpus.

    Steps: parse arguments, build the model and load the base checkpoint,
    inject LoRA adapters, train only the adapters with gradient
    accumulation and a warmup/cosine schedule, save an adapter-only
    checkpoint per epoch, and finally write ``final_lora_only.pt`` (adapters)
    and ``final.pt`` (base weights with LoRA merged in) to ``--out``.

    Raises:
        FileNotFoundError: If ``--tool-corpus`` does not exist.
        ValueError: If no trainable LoRA parameter was created.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--tokenizer", required=True)
    p.add_argument("--resume", required=True, help="checkpoint base a fine-tunear")
    p.add_argument("--tool-corpus", required=True, help="tool-use JSONL corpus")
    p.add_argument("--out", required=True)
    # LoRA
    p.add_argument("--lora-rank", type=int, default=16,
                   help="LoRA rank r (default 16)")
    p.add_argument("--lora-alpha", type=float, default=32.0,
                   help="LoRA alpha (default 32, scale=alpha/rank=2)")
    p.add_argument("--lora-targets", nargs="+",
                   default=["wq", "wk", "wv", "wo"],
                   help="Módulos de atención a inyectar LoRA")
    # Training
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--grad-accum", type=int, default=4)
    p.add_argument("--epochs", type=int, default=5)
    p.add_argument("--lr", type=float, default=2e-4,
                   help="LR más alto que full FT (LoRA converge más rápido)")
    p.add_argument("--weight-decay", type=float, default=0.01)
    p.add_argument("--grad-clip", type=float, default=1.0)
    p.add_argument("--warmup-frac", type=float, default=0.05)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--log-every", type=int, default=10)
    p.add_argument("--save-every", type=int, default=200)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", default="bfloat16",
                   choices=["bfloat16", "float16", "float32"])
    p.add_argument("--max-steps", type=int, default=None)
    args = p.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # 1. Load the base model
    cfg = ModelConfig.from_json(args.config)
    model = VectraYXNano(cfg).to(args.device)
    total_params = model.num_params()
    print(f"[model] {total_params/1e6:.2f}M params (base)")

    # Load the base checkpoint (full weights) through utils.load_checkpoint
    from training_v2.train.utils import load_checkpoint as _load_ckpt
    _load_ckpt(args.resume, model, optimizer=None, map_location=args.device)
    print(f"[resume] {args.resume}")

    # 2. Inject LoRA
    trainable = inject_lora(model, rank=args.lora_rank, alpha=args.lora_alpha,
                            target_modules=args.lora_targets)
    if trainable == 0:
        # Fail here with a clear message instead of letting the optimizer
        # reject an empty parameter list after the dataset has been built.
        raise ValueError(
            f"No LoRA parameters were created; none of the targets "
            f"{args.lora_targets} matched an nn.Linear in the model")
    # Move the LoRA parameters to the same device as the model
    model = model.to(args.device)

    # 3. Tokenizer
    sp = spm.SentencePieceProcessor()
    sp.load(args.tokenizer)
    pad_id = sp.pad_id() if sp.pad_id() >= 0 else 0

    # 4. Dataset
    block_size = cfg.max_seq_len
    tool_corpus = Path(args.tool_corpus)
    if not tool_corpus.exists():
        raise FileNotFoundError(f"Tool corpus not found: {tool_corpus}")

    dataset = SFTDataset([tool_corpus], sp, block_size, pad_id=pad_id, seed=args.seed)
    print(f"[dataset] {len(dataset)} ejemplos de {tool_corpus.name}")

    # 5. Output dir
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "train_log.jsonl"

    # 6. Optimizer — LoRA parameters only
    lora_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(lora_params, lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  betas=(0.9, 0.95))

    # 7. AMP
    # NOTE(review): float16 runs without a GradScaler here — confirm that is
    # acceptable before using --dtype float16.
    dtype = {"bfloat16": torch.bfloat16,
             "float16": torch.float16,
             "float32": torch.float32}[args.dtype]
    use_amp = args.device == "cuda" and dtype != torch.float32

    # 8. Training loop
    def collate(batch):
        # Each item is indexed as (input, target, mask); stack per field.
        xs = torch.stack([b[0] for b in batch])
        ys = torch.stack([b[1] for b in batch])
        ms = torch.stack([b[2] for b in batch])
        return xs, ys, ms

    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, collate_fn=collate, pin_memory=True,
        persistent_workers=args.num_workers > 0,
    )

    steps_per_epoch = max(1, len(loader) // args.grad_accum)
    total_steps = steps_per_epoch * args.epochs
    if args.max_steps:
        total_steps = min(total_steps, args.max_steps)
    # NOTE(review): the floor of 20 can exceed total_steps on very small
    # corpora — confirm cosine_with_warmup behaves as intended in that case.
    warmup = max(20, int(args.warmup_frac * total_steps))

    print(f"\n[train] LoRA rank={args.lora_rank} alpha={args.lora_alpha} "
          f"scale={args.lora_alpha/args.lora_rank:.1f}")
    print(f"[train] epochs={args.epochs} steps/epoch≈{steps_per_epoch} "
          f"total={total_steps} warmup={warmup}")
    print(f"[train] lr={args.lr} batch={args.batch_size} accum={args.grad_accum} "
          f"effective_batch={args.batch_size * args.grad_accum}")

    model.train()
    t_start = time.time()
    step = 0
    running_loss = 0.0
    running_n = 0

    for ep in range(args.epochs):
        print(f"\n=== epoch {ep+1}/{args.epochs} (LoRA tool-SFT) ===")
        data_iter = iter(loader)

        for _ in range(steps_per_epoch):
            if args.max_steps and step >= args.max_steps:
                break

            cur_lr = cosine_with_warmup(step, warmup, total_steps, args.lr)
            for g in optimizer.param_groups:
                g["lr"] = cur_lr

            optimizer.zero_grad(set_to_none=True)
            loss_accum = 0.0

            for _micro in range(args.grad_accum):
                try:
                    xs, ys, ms = next(data_iter)
                except StopIteration:
                    # Loader exhausted mid-step: restart it so every
                    # optimizer step sees grad_accum micro-batches.
                    data_iter = iter(loader)
                    xs, ys, ms = next(data_iter)

                xs = xs.to(args.device, non_blocking=True)
                ys = ys.to(args.device, non_blocking=True)
                ms = ms.to(args.device, non_blocking=True)

                with torch.amp.autocast("cuda", dtype=dtype, enabled=use_amp):
                    _, loss = model(xs, targets=ys, loss_mask=ms)
                    loss = loss / args.grad_accum
                loss.backward()
                # Undo the 1/grad_accum scaling for reporting purposes.
                loss_accum += loss.item() * args.grad_accum

            gnorm = torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
            optimizer.step()
            step += 1
            running_loss += loss_accum / args.grad_accum
            running_n += 1

            if step % args.log_every == 0:
                elapsed = time.time() - t_start
                avg = running_loss / running_n
                print(f"[lora ep{ep+1} step {step:>4}/{total_steps}] "
                      f"loss={avg:.4f} lr={cur_lr:.2e} "
                      f"gnorm={gnorm:.2f} {elapsed/60:.1f}min")
                log_jsonl(log_path, {"epoch": ep+1, "step": step, "loss": avg,
                                     "lr": cur_lr, "gnorm": float(gnorm)})
                running_loss = 0.0
                running_n = 0

            if step % args.save_every == 0:
                save_lora_checkpoint(out_dir / "last_lora.pt", model, optimizer,
                                     step, {"epoch": ep+1})

        if args.max_steps and step >= args.max_steps:
            break

        save_lora_checkpoint(out_dir / f"epoch{ep+1}_lora.pt", model, optimizer,
                             step, {"epoch": ep+1})
        print(f"[save] epoch{ep+1}_lora.pt")

    # Final checkpoint with FULL weights (base + LoRA merged)
    print("\n[merge] Mergeando LoRA en pesos base...")
    merged_state = _merged_state_dict(model)
    print(f"[merge] {len(merged_state)} keys en merged state_dict")

    # Adapter-only checkpoint
    save_lora_checkpoint(out_dir / "final_lora_only.pt", model, optimizer,
                         step, {"done": True, "lora_rank": args.lora_rank,
                                "lora_alpha": args.lora_alpha})

    # Merged full model for the benchmark, stored under the "model" key
    # (presumably what load_checkpoint reads — verify against utils).
    torch.save({"model": merged_state, "step": step,
                "lora_rank": args.lora_rank, "lora_alpha": args.lora_alpha,
                "merged": True, "tie_embeddings": True},
               out_dir / "final.pt")
    print(f"[done] final.pt (merged) → {out_dir}")
    print(f"[done] final_lora_only.pt (adapter only) → {out_dir}")
364
+
365
+
366
+ if __name__ == "__main__":
367
+ main()
training/finetune_sft.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SFT fine-tuning with assistant-only loss masking and an internal mini-curriculum.
2
+
3
+ Mini-curriculum (within SFT):
4
+   Epoch 1:   100% conversational (OASST1 ES + sft_conv)
5
+   Epoch 2:   70% conv + 30% CVE Q&A; epoch 3+: add tool-use (55% conv + 30% CVE + 15% tool_use)
6
+
7
+ This avoids drowning the chat behavior in JSON tool-call patterns the way SFT v3 did.
8
+
9
+ Run example:
10
+ python -m training_v2.train.finetune_sft \
11
+ --config training_v2/configs/nano.json \
12
+ --tokenizer training_v2/tokenizer/out/vectrayx_bpe.model \
13
+ --resume training_v2/checkpoints/phase3/last.pt \
14
+ --out training_v2/checkpoints/sft_v4 \
15
+ --batch-size 16 --grad-accum 4 --epochs 3 --lr 2e-5
16
+ """
17
+
18
+ import argparse
19
+ import json
20
+ import sys
21
+ import time
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ import sentencepiece as spm
26
+ import torch
27
+ from torch.utils.data import DataLoader, ConcatDataset
28
+
29
+ ROOT = Path(__file__).resolve().parents[2]
30
+ sys.path.insert(0, str(ROOT))
31
+
32
+ from training_v2.data.sft_dataset import SFTDataset
33
+ from training_v2.model.transformer import VectraYXNano, ModelConfig
34
+ from training_v2.train.utils import (
35
+ cosine_with_warmup, make_optimizer, save_checkpoint, load_checkpoint, log_jsonl,
36
+ )
37
+
38
+
39
# Default JSONL sources for each SFT data family; paths are joined onto
# --corpus-root by discover().
SFT_FILES = {
    "conversational": [
        "corpus/sft_conversational.jsonl",
        "sft_v2_data/oasst1_es.jsonl",
    ],
    "cve_qa": ["corpus/sft_v2_dataset.jsonl"],
    "tool_use": ["corpus/tooluse_dataset.jsonl"],
}


def load_sft_corpus_config(path):
    """Rebind the module-level ``SFT_FILES`` table from a JSON config file.

    The JSON object may carry ``sft_conversational``, ``sft_cve_qa`` and
    ``sft_tool_use`` entries. A family whose entry is absent keeps the list
    currently stored in ``SFT_FILES``.
    """
    global SFT_FILES
    overrides = json.loads(Path(path).read_text())
    current = SFT_FILES
    SFT_FILES = {
        family: overrides.get(f"sft_{family}", current[family])
        for family in ("conversational", "cve_qa", "tool_use")
    }
61
+
62
+
63
def discover(paths, root):
    """Return the entries of *paths* that exist under *root*, as ``Path`` objects.

    Input order is preserved. Every missing entry is reported on stdout and
    left out of the result.
    """
    base = Path(root)
    existing = []
    for candidate in (base / rel for rel in paths):
        if not candidate.exists():
            print(f" [skip missing] {candidate}")
            continue
        existing.append(candidate)
    return existing
72
+
73
+
74
def build_dataset(args, sp, include_tools):
    """Create one ``SFTDataset`` per data family that has files on disk.

    Returns ``(parts, pad_id)``. ``parts`` is a list of ``(name, dataset)``
    pairs in the order conv, cve, tools; a family is omitted when none of its
    files exist, and tools is also omitted when *include_tools* is false.
    ``pad_id`` is the tokenizer's pad id, or 0 when the tokenizer reports a
    negative one.
    """
    block_size = ModelConfig.from_json(args.config).max_seq_len
    tokenizer_pad = sp.pad_id()
    pad_id = tokenizer_pad if tokenizer_pad >= 0 else 0

    # Resolve every family up front so the "[skip missing]" reports come out
    # before any dataset is built.
    families = (
        ("conv", "conversational", 0, True),
        ("cve", "cve_qa", 1, True),
        ("tools", "tool_use", 2, include_tools),
    )
    found = {key: discover(SFT_FILES[key], args.corpus_root) for _, key, _, _ in families}

    parts = []
    for name, key, seed_offset, enabled in families:
        files = found[key]
        if enabled and files:
            # Each family gets its own seed, offset from the base seed.
            dataset = SFTDataset(files, sp, block_size, pad_id=pad_id,
                                 seed=args.seed + seed_offset)
            parts.append((name, dataset))
    return parts, pad_id
90
+
91
+
92
def make_loader(parts, weights, batch_size, num_workers):
    """Build a DataLoader that samples the named *parts* by per-part weight.

    Every example of part ``n`` gets sampling weight ``weights.get(n, 1.0)``
    divided by the part's size, so a part's total probability mass follows its
    weight regardless of how many examples it holds. One pass draws
    ``sum(len(part))`` examples with replacement.
    """
    names = [name for name, _ in parts]
    datasets = [dataset for _, dataset in parts]
    sizes = [len(dataset) for dataset in datasets]
    combined = ConcatDataset(datasets)

    # bounds[i]:bounds[i + 1] is the index range part i occupies in `combined`.
    bounds = np.cumsum([0] + sizes)
    per_example = np.zeros(bounds[-1], dtype=np.float64)
    for name, size, lo, hi in zip(names, sizes, bounds[:-1], bounds[1:]):
        per_example[lo:hi] = weights.get(name, 1.0) / max(1, size)

    sampler = torch.utils.data.WeightedRandomSampler(
        weights=per_example,
        num_samples=int(sum(sizes)),
        replacement=True,
    )

    def collate(batch):
        # Stack the (input, target, mask) fields of the items into batch tensors.
        return tuple(torch.stack([item[field] for item in batch], 0) for field in range(3))

    return DataLoader(
        combined,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
121
+
122
+
123
def main():
    """Run SFT with the per-epoch data-mix curriculum defined in ``epoch_plans``.

    Parses the CLI, loads model weights from ``--resume`` (the optimizer is
    created fresh), builds the conv / cve / tools datasets and trains for
    ``--epochs`` epochs, re-weighting the sampler each epoch. Checkpoints go
    to ``--out`` as ``last.pt`` (every ``--save-every`` steps),
    ``epoch<N>.pt`` and ``final.pt``; metrics are appended to
    ``train_log.jsonl``.

    Raises:
        RuntimeError: if none of the configured SFT files exist.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--tokenizer", required=True)
    p.add_argument("--resume", required=True, help="pre-training checkpoint to fine-tune")
    p.add_argument("--out", required=True)
    p.add_argument("--corpus-root", default=".")
    p.add_argument("--corpus-config", default=None)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--grad-accum", type=int, default=4)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--weight-decay", type=float, default=0.0)
    p.add_argument("--grad-clip", type=float, default=1.0)
    p.add_argument("--warmup-frac", type=float, default=0.03)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--log-every", type=int, default=20)
    p.add_argument("--save-every", type=int, default=500)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
    args = p.parse_args()

    if args.corpus_config:
        load_sft_corpus_config(args.corpus_config)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    cfg = ModelConfig.from_json(args.config)
    model = VectraYXNano(cfg).to(args.device)
    print(f"[model] {model.num_params()/1e6:.2f}M params")
    load_checkpoint(args.resume, model, optimizer=None, map_location=args.device)
    print(f"[resume] {args.resume}")

    sp = spm.SentencePieceProcessor()
    sp.load(args.tokenizer)
    parts, pad_id = build_dataset(args, sp, include_tools=True)
    if not parts:
        raise RuntimeError("no SFT files found")

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "train_log.jsonl"

    optimizer = make_optimizer(model, lr=args.lr, weight_decay=args.weight_decay)

    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    # Compare the device *type* so "cuda:0", "cuda:1", ... also enable autocast.
    use_amp = torch.device(args.device).type == "cuda" and dtype != torch.float32
    # float16 autocast needs loss scaling to avoid gradient underflow; the
    # scaler is disabled for bfloat16 / float32, where none is needed.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp and dtype == torch.float16)

    # Sampling weights per epoch; epochs past the end of the list reuse the last plan.
    epoch_plans = [
        {"conv": 1.00, "cve": 0.00, "tools": 0.0},   # epoch 1: conversational ONLY
        {"conv": 0.70, "cve": 0.30, "tools": 0.00},  # epoch 2: + CVE Q&A
        {"conv": 0.55, "cve": 0.30, "tools": 0.15},  # epoch 3: + tool use
    ]

    # Planning pass: count optimizer steps over all epochs so the LR schedule
    # knows its horizon before training starts.
    total_steps = 0
    for ep in range(args.epochs):
        weights = epoch_plans[min(ep, len(epoch_plans) - 1)]
        print(f"\n=== epoch {ep+1}/{args.epochs} | mix={weights} ===")
        loader = make_loader(parts, weights, args.batch_size, args.num_workers)
        steps_per_epoch = max(1, len(loader) // args.grad_accum)
        total_steps += steps_per_epoch
    warmup = max(50, int(args.warmup_frac * total_steps))
    print(f"[sft] total_steps≈{total_steps} warmup={warmup}")

    model.train()
    t_start = time.time()
    step = 0
    running_loss = 0.0
    running_n = 0

    for ep in range(args.epochs):
        weights = epoch_plans[min(ep, len(epoch_plans) - 1)]
        loader = make_loader(parts, weights, args.batch_size, args.num_workers)
        data_iter = iter(loader)
        steps_per_epoch = max(1, len(loader) // args.grad_accum)

        for _ in range(steps_per_epoch):
            cur_lr = cosine_with_warmup(step, warmup, total_steps, args.lr)
            for g in optimizer.param_groups:
                g["lr"] = cur_lr

            optimizer.zero_grad(set_to_none=True)
            loss_accum = 0.0
            for _micro in range(args.grad_accum):
                try:
                    xs, ys, ms = next(data_iter)
                except StopIteration:
                    # Loader exhausted mid-step: restart it to finish the accumulation.
                    data_iter = iter(loader)
                    xs, ys, ms = next(data_iter)
                xs = xs.to(args.device, non_blocking=True)
                ys = ys.to(args.device, non_blocking=True)
                ms = ms.to(args.device, non_blocking=True)
                with torch.amp.autocast("cuda", dtype=dtype, enabled=use_amp):
                    _, loss = model(xs, targets=ys, loss_mask=ms)
                    loss = loss / args.grad_accum
                if scaler.is_enabled():
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                # Undo the 1/grad_accum factor so loss_accum sums raw micro-batch losses.
                loss_accum += loss.item() * args.grad_accum

            if scaler.is_enabled():
                # Gradients must be unscaled before clipping against --grad-clip.
                scaler.unscale_(optimizer)
            gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            if scaler.is_enabled():
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            step += 1
            running_loss += loss_accum / args.grad_accum
            running_n += 1

            if step % args.log_every == 0:
                elapsed = time.time() - t_start
                avg = running_loss / running_n
                print(f"[sft ep{ep+1} step {step:>5}/{total_steps}] loss={avg:.4f} "
                      f"lr={cur_lr:.2e} gnorm={gnorm:.2f} elapsed={elapsed/60:.1f}min")
                log_jsonl(log_path, {"epoch": ep + 1, "step": step, "loss": avg,
                                     "lr": cur_lr, "gnorm": float(gnorm)})
                running_loss = 0.0
                running_n = 0

            if step % args.save_every == 0:
                save_checkpoint(out_dir / "last.pt", model, optimizer,
                                {"step": step}, step,
                                extra={"epoch": ep + 1, "weights": weights})

        save_checkpoint(out_dir / f"epoch{ep+1}.pt", model, optimizer,
                        {"step": step}, step,
                        extra={"epoch": ep + 1, "weights": weights})
        print(f"[save] {out_dir}/epoch{ep+1}.pt")

    save_checkpoint(out_dir / "final.pt", model, optimizer, {"step": step}, step,
                    extra={"done": True})
    print(f"[done] SFT → {out_dir}/final.pt")
252
+
253
+
254
+ if __name__ == "__main__":
255
+ main()
training/finetune_tools.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tool-use focused SFT for VectraYX Nano/Base.
2
+
3
+ This is a simplified version of finetune_sft.py that trains ONLY on tool-use examples.
4
+ The goal is to test the hypothesis that B4=0.000 is due to diluted tool-call gradients
5
+ in the mixed SFT corpus, not a capacity gate.
6
+
7
+ Run example:
8
+ python -m training_v2.train.finetune_tools \
9
+ --config training_v2/configs/nano.json \
10
+ --tokenizer models/vectrayx_bpe.model \
11
+ --resume checkpoints/nano_final.pt \
12
+ --tool-corpus /tmp/tool_sft_v1.jsonl \
13
+ --out checkpoints/tool_sft_nano \
14
+ --batch-size 16 --grad-accum 4 --epochs 2 --lr 1e-5
15
+ """
16
+
17
+ import argparse
18
+ import json
19
+ import sys
20
+ import time
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ import sentencepiece as spm
25
+ import torch
26
+ from torch.utils.data import DataLoader
27
+
28
+ ROOT = Path(__file__).resolve().parents[2]
29
+ sys.path.insert(0, str(ROOT))
30
+
31
+ from training_v2.data.sft_dataset import SFTDataset
32
+ from training_v2.model.transformer import VectraYXNano, ModelConfig
33
+ from training_v2.train.utils import (
34
+ cosine_with_warmup, make_optimizer, save_checkpoint, load_checkpoint, log_jsonl,
35
+ )
36
+
37
+
38
def main():
    """Fine-tune a checkpoint on a single tool-use JSONL corpus.

    Parses the CLI, loads model weights from ``--resume`` (the optimizer is
    created fresh), wraps ``--tool-corpus`` in an ``SFTDataset`` and trains
    for ``--epochs`` epochs, or until ``--max-steps`` optimizer steps when
    that is set. Checkpoints go to ``--out`` as ``last.pt`` (every
    ``--save-every`` steps), ``epoch<N>.pt`` and ``final.pt``; metrics are
    appended to ``train_log.jsonl``.

    Raises:
        FileNotFoundError: if ``--tool-corpus`` does not exist.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--tokenizer", required=True)
    p.add_argument("--resume", required=True, help="checkpoint to fine-tune from")
    p.add_argument("--tool-corpus", required=True, help="tool-use JSONL corpus")
    p.add_argument("--out", required=True)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--grad-accum", type=int, default=4)
    p.add_argument("--epochs", type=int, default=2)
    p.add_argument("--lr", type=float, default=1e-5)
    p.add_argument("--weight-decay", type=float, default=0.0)
    p.add_argument("--grad-clip", type=float, default=1.0)
    p.add_argument("--warmup-frac", type=float, default=0.03)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--log-every", type=int, default=20)
    p.add_argument("--save-every", type=int, default=500)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
    p.add_argument("--max-steps", type=int, default=None, help="for testing")
    args = p.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Load model
    cfg = ModelConfig.from_json(args.config)
    model = VectraYXNano(cfg).to(args.device)
    print(f"[model] {model.num_params()/1e6:.2f}M params")
    load_checkpoint(args.resume, model, optimizer=None, map_location=args.device)
    print(f"[resume] {args.resume}")

    # Load tokenizer; fall back to id 0 when the tokenizer reports no pad id.
    sp = spm.SentencePieceProcessor()
    sp.load(args.tokenizer)
    pad_id = sp.pad_id() if sp.pad_id() >= 0 else 0

    # Build tool-only dataset
    block_size = cfg.max_seq_len
    tool_corpus = Path(args.tool_corpus)
    if not tool_corpus.exists():
        raise FileNotFoundError(f"Tool corpus not found: {tool_corpus}")

    dataset = SFTDataset([tool_corpus], sp, block_size, pad_id=pad_id, seed=args.seed)
    print(f"[dataset] {len(dataset)} tool-use examples from {tool_corpus}")

    # Setup output
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "train_log.jsonl"

    # Optimizer
    optimizer = make_optimizer(model, lr=args.lr, weight_decay=args.weight_decay)

    # AMP setup. Compare the device *type* so "cuda:0", "cuda:1", ... also
    # enable autocast.
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    use_amp = torch.device(args.device).type == "cuda" and dtype != torch.float32
    # float16 autocast needs loss scaling to avoid gradient underflow; the
    # scaler is disabled for bfloat16 / float32, where none is needed.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp and dtype == torch.float16)

    def collate(batch):
        # Stack per-example (input, target, mask) tensors into batch tensors.
        xs = torch.stack([b[0] for b in batch], 0)
        ys = torch.stack([b[1] for b in batch], 0)
        ms = torch.stack([b[2] for b in batch], 0)
        return xs, ys, ms

    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, collate_fn=collate, pin_memory=True,
        persistent_workers=args.num_workers > 0,
    )

    steps_per_epoch = max(1, len(loader) // args.grad_accum)
    total_steps = steps_per_epoch * args.epochs
    if args.max_steps:
        total_steps = min(total_steps, args.max_steps)
    warmup = max(50, int(args.warmup_frac * total_steps))
    print(f"[train] epochs={args.epochs} steps_per_epoch≈{steps_per_epoch} total_steps={total_steps} warmup={warmup}")

    # Training loop
    model.train()
    t_start = time.time()
    step = 0
    running_loss = 0.0
    running_n = 0

    for ep in range(args.epochs):
        print(f"\n=== epoch {ep+1}/{args.epochs} (tool-only) ===")
        data_iter = iter(loader)

        for _ in range(steps_per_epoch):
            if args.max_steps and step >= args.max_steps:
                break

            cur_lr = cosine_with_warmup(step, warmup, total_steps, args.lr)
            for g in optimizer.param_groups:
                g["lr"] = cur_lr

            optimizer.zero_grad(set_to_none=True)
            loss_accum = 0.0
            for _micro in range(args.grad_accum):
                try:
                    xs, ys, ms = next(data_iter)
                except StopIteration:
                    # Loader exhausted mid-step: restart it to finish the accumulation.
                    data_iter = iter(loader)
                    xs, ys, ms = next(data_iter)
                xs = xs.to(args.device, non_blocking=True)
                ys = ys.to(args.device, non_blocking=True)
                ms = ms.to(args.device, non_blocking=True)

                with torch.amp.autocast("cuda", dtype=dtype, enabled=use_amp):
                    _, loss = model(xs, targets=ys, loss_mask=ms)
                    loss = loss / args.grad_accum
                if scaler.is_enabled():
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                # Undo the 1/grad_accum factor so loss_accum sums raw micro-batch losses.
                loss_accum += loss.item() * args.grad_accum

            if scaler.is_enabled():
                # Gradients must be unscaled before clipping against --grad-clip.
                scaler.unscale_(optimizer)
            gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            if scaler.is_enabled():
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            step += 1
            running_loss += loss_accum / args.grad_accum
            running_n += 1

            if step % args.log_every == 0:
                elapsed = time.time() - t_start
                avg = running_loss / running_n
                print(f"[tool-sft ep{ep+1} step {step:>5}/{total_steps}] loss={avg:.4f} "
                      f"lr={cur_lr:.2e} gnorm={gnorm:.2f} elapsed={elapsed/60:.1f}min")
                log_jsonl(log_path, {"epoch": ep + 1, "step": step, "loss": avg,
                                     "lr": cur_lr, "gnorm": float(gnorm)})
                running_loss = 0.0
                running_n = 0

            if step % args.save_every == 0:
                save_checkpoint(out_dir / "last.pt", model, optimizer,
                                {"step": step}, step,
                                extra={"epoch": ep + 1, "tool_only": True})

        # Step budget reached: skip the per-epoch checkpoint and go straight
        # to the final save below.
        if args.max_steps and step >= args.max_steps:
            break

        save_checkpoint(out_dir / f"epoch{ep+1}.pt", model, optimizer,
                        {"step": step}, step,
                        extra={"epoch": ep + 1, "tool_only": True})
        print(f"[save] {out_dir}/epoch{ep+1}.pt")

    save_checkpoint(out_dir / "final.pt", model, optimizer, {"step": step}, step,
                    extra={"done": True, "tool_only": True})
    print(f"[done] {out_dir}/final.pt")
185
+
186
+
187
+ if __name__ == "__main__":
188
+ main()
training/pretrain.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Curriculum pre-training driver for VectraYX-Nano v2.
2
+
3
+ Phase 1: 100% conversational (LR 3e-4 from scratch)
4
+ Phase 2: 75% tech + 25% conv (LR 1.5e-4, resumed from phase 1)
5
+ Phase 3: 70% tools + 20% tech + 10% conv (LR 8e-5, resumed from phase 2)
6
+
7
+ Run example:
8
+ python -m training_v2.train.pretrain \
9
+ --config training_v2/configs/nano.json \
10
+ --bins training_v2/data/bins \
11
+ --out training_v2/checkpoints \
12
+ --phase 1 --max-steps 8000 --batch-size 16 --grad-accum 8
13
+
14
+ Then:
15
+ --phase 2 --resume training_v2/checkpoints/phase1/last.pt
16
+ --phase 3 --resume training_v2/checkpoints/phase2/last.pt
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import sys
22
+ import time
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.utils.data import DataLoader
29
+
30
+ ROOT = Path(__file__).resolve().parents[2]
31
+ sys.path.insert(0, str(ROOT))
32
+
33
+ from training_v2.data.curriculum_dataset import (
34
+ MixedCurriculumDataset, make_phase_mix, load_phase_summary,
35
+ )
36
+ from training_v2.model.transformer import VectraYXNano, ModelConfig
37
+ from training_v2.train.utils import (
38
+ cosine_with_warmup, make_optimizer, save_checkpoint, load_checkpoint, log_jsonl,
39
+ )
40
+
41
+
42
+ PHASE_LR = {1: 3.0e-4, 2: 1.5e-4, 3: 8.0e-5}
43
+ PHASE_WARMUP_FRAC = {1: 0.05, 2: 0.02, 3: 0.02}
44
+
45
+
46
def build_dataloader(args, mix, block_size):
    """Create the DataLoader over the binary shards that feed one phase.

    Only shard directories whose weight in *mix* is positive are handed to
    ``MixedCurriculumDataset``; the full *mix* dict is still passed as its
    weights. Batches are ``(inputs, targets)`` tensor pairs.
    """
    bins_root = Path(args.bins)
    active_dirs = {
        name: bins_root / name
        for name in ("phase1_conv", "phase2_tech", "phase3_tools")
        if mix.get(name, 0) > 0
    }
    dataset = MixedCurriculumDataset(
        phase_dirs=active_dirs,
        weights=mix,
        block_size=block_size,
        dtype=np.uint16,
        seed=args.seed,
    )

    def collate(batch):
        # Stack the first two fields of every sample into batch tensors.
        inputs = torch.stack([sample[0] for sample in batch], 0)
        targets = torch.stack([sample[1] for sample in batch], 0)
        return inputs, targets

    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=collate,
        pin_memory=True,
        persistent_workers=args.num_workers > 0,
    )
    return DataLoader(dataset, **loader_kwargs)
73
+
74
+
75
def estimate_phase_tokens(phase_idx, mix, summary):
    """Sum the ``n_tokens`` of every source in *mix* that has a positive weight.

    Sources missing from *summary*, or listed with a non-positive token count,
    contribute nothing. The weights only select sources; they do not scale the
    counts. *phase_idx* is accepted but not used. Returns an ``int``.
    """
    tokens = 0.0
    for source, weight in mix.items():
        count = summary.get(source, {}).get("n_tokens", 0)
        if not (weight > 0 and count > 0):
            continue
        tokens += count
    return int(tokens)
82
+
83
+
84
def main():
    """Run one curriculum pre-training phase.

    Parses the CLI, builds the model, derives the data mix for ``--phase`` and
    trains for ``--max-steps`` optimizer steps. When ``--max-steps`` is not
    given it is estimated from ``--epochs`` and the phase token count, with a
    floor of 1000. ``--resume`` loads model weights only; the optimizer and
    step counter start fresh for the new phase. Checkpoints are written to
    ``<out>/phase<N>/last.pt`` and metrics to ``train_log.jsonl`` beside it.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--bins", required=True, help="root of binary shard dirs")
    p.add_argument("--out", required=True, help="checkpoint output root")
    p.add_argument("--phase", type=int, choices=[1, 2, 3], required=True)
    p.add_argument("--resume", type=str, default=None)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--grad-accum", type=int, default=8)
    p.add_argument("--max-steps", type=int, default=None)
    p.add_argument("--epochs", type=float, default=2.0,
                   help="estimate steps as epochs*phase_tokens/(batch*ga*block)")
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--weight-decay", type=float, default=0.1)
    p.add_argument("--grad-clip", type=float, default=1.0)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--log-every", type=int, default=20)
    p.add_argument("--save-every", type=int, default=1000)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--replay-conv", type=float, default=None,
                   help="override replay ratio of conversational data in phase 2/3")
    p.add_argument("--replay-tech", type=float, default=None,
                   help="override replay ratio of technical data in phase 3")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
    p.add_argument("--compile", action="store_true")
    args = p.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    cfg = ModelConfig.from_json(args.config)
    model = VectraYXNano(cfg).to(args.device)
    n_params = model.num_params()
    print(f"[model] {n_params/1e6:.2f}M params · cfg={cfg}")

    mix = make_phase_mix(args.phase, replay_conv=args.replay_conv, replay_tech=args.replay_tech)
    summary = load_phase_summary(args.bins)
    phase_tokens = estimate_phase_tokens(args.phase, mix, summary)
    tokens_per_step = args.batch_size * args.grad_accum * cfg.max_seq_len
    if args.max_steps is None:
        args.max_steps = max(1000, int(args.epochs * phase_tokens / tokens_per_step))
    print(f"[phase {args.phase}] mix={mix}")
    print(f"[phase {args.phase}] phase_tokens={phase_tokens:,} tokens/step={tokens_per_step:,} steps={args.max_steps}")

    lr = args.lr if args.lr is not None else PHASE_LR[args.phase]
    warmup = max(50, int(PHASE_WARMUP_FRAC[args.phase] * args.max_steps))
    optimizer = make_optimizer(model, lr=lr, weight_decay=args.weight_decay)

    start_step = 0
    if args.resume:
        start_step, _ = load_checkpoint(args.resume, model, optimizer=None, map_location=args.device)
        print(f"[resume] loaded weights from {args.resume} (step={start_step})")
        start_step = 0  # fresh optimizer for new phase

    loader = build_dataloader(args, mix, cfg.max_seq_len)
    data_iter = iter(loader)

    out_dir = Path(args.out) / f"phase{args.phase}"
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "train_log.jsonl"

    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    # Compare the device *type* so "cuda:0", "cuda:1", ... also enable autocast.
    use_amp = torch.device(args.device).type == "cuda" and dtype != torch.float32
    # Loss scaling is only meaningful when float16 autocast is actually active.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp and dtype == torch.float16)

    if args.compile:
        try:
            model = torch.compile(model)
        except Exception as e:
            # Compilation is best-effort: report and continue with the eager model.
            print(f"[compile] skipped: {e}")

    model.train()
    t_start = time.time()
    tokens_seen = 0
    running_loss = 0.0
    running_n = 0

    for step in range(start_step, args.max_steps):
        cur_lr = cosine_with_warmup(step, warmup, args.max_steps, lr)
        for g in optimizer.param_groups:
            g["lr"] = cur_lr

        optimizer.zero_grad(set_to_none=True)
        loss_accum = 0.0
        for _micro in range(args.grad_accum):
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: restart it and keep going.
                data_iter = iter(loader)
                batch = next(data_iter)
            xs, ys = batch[0], batch[1]
            xs = xs.to(args.device, non_blocking=True)
            ys = ys.to(args.device, non_blocking=True)
            with torch.amp.autocast("cuda", dtype=dtype, enabled=use_amp):
                _, loss = model(xs, targets=ys)
                loss = loss / args.grad_accum
            if scaler.is_enabled():
                scaler.scale(loss).backward()
            else:
                loss.backward()
            # Undo the 1/grad_accum factor so loss_accum sums raw micro-batch losses.
            loss_accum += loss.item() * args.grad_accum

        if scaler.is_enabled():
            # Gradients must be unscaled before clipping against --grad-clip.
            scaler.unscale_(optimizer)
        gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        if scaler.is_enabled():
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        tokens_seen += tokens_per_step
        running_loss += loss_accum / args.grad_accum
        running_n += 1

        if (step + 1) % args.log_every == 0:
            elapsed = time.time() - t_start
            tps = tokens_seen / max(1.0, elapsed)
            avg_loss = running_loss / running_n
            print(f"[p{args.phase} step {step+1:>6}/{args.max_steps}] "
                  f"loss={avg_loss:.4f} lr={cur_lr:.2e} gnorm={gnorm:.2f} "
                  f"tok/s={tps:>7,.0f} elapsed={elapsed/60:.1f}min")
            log_jsonl(log_path, {
                "phase": args.phase, "step": step + 1, "loss": avg_loss,
                "lr": cur_lr, "gnorm": float(gnorm), "tok_per_s": tps,
                "tokens_seen": tokens_seen,
            })
            running_loss = 0.0
            running_n = 0

        if (step + 1) % args.save_every == 0 or (step + 1) == args.max_steps:
            ckpt_path = out_dir / "last.pt"
            save_checkpoint(ckpt_path, model, optimizer, {"step": step + 1}, step + 1,
                            extra={"phase": args.phase, "mix": mix, "lr": lr})
            print(f"[save] {ckpt_path}")

    # Rewrite last.pt once more so the final checkpoint carries the "done" marker.
    final = out_dir / "last.pt"
    save_checkpoint(final, model, optimizer, {"step": args.max_steps}, args.max_steps,
                    extra={"phase": args.phase, "mix": mix, "lr": lr, "done": True})
    print(f"[done] phase {args.phase} → {final}")
225
+
226
+
227
+ if __name__ == "__main__":
228
+ main()
training/sft_dataset.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SFT dataset with proper assistant-only loss masking and safe packing.
2
+
3
+ Each example is a chat-formatted string with `<|system|> <|user|> <|assistant|> <|end|>`
4
+ turn delimiters. We tokenize on the fly (corpus is small, ~25M tokens) and build a
5
+ mask=1 only on tokens that are part of an assistant response (everything between
6
+ `<|assistant|>` and the next `<|end|>`).
7
+
8
+ For pre-training-style packing without cross-example contamination we group multiple
9
+ short examples into a fixed-length window using `cu_seqlens`-style document boundaries
10
+ implemented via per-document attention reset. Here we keep it simple: pad/truncate
11
+ each example to `block_size`. Throughput is still high (>40k tok/s on L4) for this
12
+ volume.
13
+ """
14
+
15
+ import json
16
+ import random
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+ from torch.utils.data import Dataset
22
+
23
+
24
def _read_jsonl(path):
    """Load chat examples from a JSONL file.

    Returns a list of ``{"text": ..., "source": ...}`` dicts, one per usable
    line, in file order. ``source`` defaults to the file's stem when the
    record has none. Loading is deliberately best-effort: blank lines, lines
    that are not valid JSON, records that are not JSON objects, and records
    whose ``text`` is missing or empty are skipped rather than raising.
    """
    default_source = Path(path).stem
    records = []
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                obj = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # A valid JSON line may still be a list/string/number, which has no .get().
            if not isinstance(obj, dict):
                continue
            text = obj.get("text") or ""
            if text:
                records.append({"text": text, "source": obj.get("source", default_source)})
    return records
39
+
40
+
41
def build_assistant_mask(token_ids, assistant_id, end_id):
    """Flag the tokens that belong to an assistant reply.

    Returns an int64 array with one entry per token. An entry is 1 for every
    token after an opening ``assistant_id`` up to and including the closing
    ``end_id``, so the model is also trained to emit the closing delimiter.
    The opening tag itself stays 0. A reply left unterminated is flagged
    through to the end of the sequence.
    """
    flags = np.zeros(len(token_ids), dtype=np.int64)
    in_reply = False
    for pos, tok in enumerate(token_ids):
        if in_reply:
            flags[pos] = 1
            # The end tag is part of the reply but closes it for what follows.
            in_reply = tok != end_id
        elif tok == assistant_id:
            in_reply = True
    return flags
58
+
59
+
60
class SFTDataset(Dataset):
    """Chat examples tokenized on the fly, with loss restricted to assistant tokens.

    Each item is an ``(x, y, m)`` triple of int64 tensors of length
    ``block_size``: inputs, next-token targets, and a 0/1 mask aligned with
    the targets. Targets outside assistant replies (padding included) are set
    to -100.
    """

    def __init__(self, jsonl_paths, sp, block_size, assistant_token="<|assistant|>",
                 end_token="<|end|>", pad_id=0, seed=42, mix_weights=None):
        """Load and shuffle the examples from *jsonl_paths*.

        Args:
            jsonl_paths: JSONL files to read.
            sp: tokenizer exposing ``piece_to_id`` and ``encode``.
            block_size: number of tokens per returned ``x`` / ``y`` / ``m``.
            assistant_token: piece that opens an assistant reply.
            end_token: piece that closes a turn.
            pad_id: token id used to right-pad short examples.
            seed: seed for subsampling and for the final shuffle.
            mix_weights: optional ``{file name: fraction}``; a fraction other
                than 1.0 keeps a random ``int(len * fraction)`` records of
                that file (capped at the file size).

        Raises:
            ValueError: if either special token is not in the tokenizer vocabulary.
        """
        self.sp = sp
        self.block_size = block_size
        self.pad_id = pad_id
        self.assistant_id = sp.piece_to_id(assistant_token)
        self.end_id = sp.piece_to_id(end_token)
        # SentencePiece maps unknown pieces to the <unk> id instead of a
        # negative value, so a missing special token must be detected by
        # comparing against unk_id(). Tokenizers without unk_id() keep the
        # plain "< 0" check.
        unk_id = sp.unk_id() if hasattr(sp, "unk_id") else -1
        if any(tid < 0 or tid == unk_id for tid in (self.assistant_id, self.end_id)):
            raise ValueError(f"missing special tokens in tokenizer: "
                             f"{assistant_token}={self.assistant_id} {end_token}={self.end_id}")

        self.examples = []
        rng = random.Random(seed)
        for p in jsonl_paths:
            recs = _read_jsonl(p)
            w = (mix_weights or {}).get(Path(p).name, 1.0)
            if w != 1.0:
                k = int(len(recs) * w)
                recs = rng.sample(recs, min(k, len(recs)))
            self.examples.extend(recs)
            print(f" [sft] {p}: {len(recs):,} ex (w={w})")
        rng.shuffle(self.examples)
        print(f"[sft] total: {len(self.examples):,} examples")

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example *idx* and return ``(x, y, m)`` for next-token training."""
        text = self.examples[idx]["text"]
        ids = self.sp.encode(text, out_type=int)
        # One extra token so inputs and shifted targets both span block_size.
        ids = ids[: self.block_size + 1]
        mask = build_assistant_mask(ids, self.assistant_id, self.end_id)

        if len(ids) < self.block_size + 1:
            need = self.block_size + 1 - len(ids)
            ids = ids + [self.pad_id] * need
            mask = np.concatenate([mask, np.zeros(need, dtype=np.int64)])

        ids = np.asarray(ids, dtype=np.int64)
        x = torch.from_numpy(ids[:-1])
        y = torch.from_numpy(ids[1:].copy())
        m = torch.from_numpy(mask[1:].copy())  # mask aligned with targets
        # Every target outside an assistant reply (prompt tokens and padding
        # alike) becomes -100 -- presumably the loss's ignore index; confirm
        # against the model's loss.
        y[m == 0] = -100
        return x, y, m
training/transformer.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """VectraYX-Nano transformer (decoder-only, ~42M params).
2
+
3
+ Modern small-LLM stack:
4
+ RMSNorm (pre-norm) · SwiGLU FFN · RoPE · GQA (8q/2kv)
5
+ QK-Norm · no biases · tied embeddings · z-loss
6
+ """
7
+
8
+ import json
9
+ import math
10
+ from dataclasses import dataclass
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
@dataclass
class ModelConfig:
    """Hyperparameters of the decoder-only transformer defined in this module."""

    vocab_size: int = 16384
    n_layers: int = 8
    n_heads: int = 8          # number of query heads
    n_kv_heads: int = 2       # number of key/value heads (must divide n_heads)
    d_model: int = 512
    d_ffn: int = 2048
    max_seq_len: int = 1024
    rope_theta: float = 10000.0
    rms_eps: float = 1e-6
    init_std: float = 0.02
    dropout: float = 0.0
    tie_embeddings: bool = True
    qk_norm: bool = True
    z_loss_coef: float = 1e-4

    @classmethod
    def from_json(cls, path):
        """Build a config from the ``"model"`` section of a JSON file.

        Keys that are not dataclass fields are ignored; missing fields keep
        their defaults.

        Raises:
            KeyError: if the JSON document has no ``"model"`` entry.
        """
        # Context manager closes the handle deterministically; JSON is UTF-8 by spec.
        with open(path, encoding="utf-8") as f:
            cfg = json.load(f)["model"]
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in cfg.items() if k in known})
38
+
39
+
40
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Scale ``x`` by the reciprocal RMS of its last axis, then apply the gain."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = x * inv_rms
        # The result is cast to the gain's dtype before the multiplication.
        return normed.to(self.weight.dtype) * self.weight
50
+
51
+
52
def precompute_rope(head_dim, max_seq_len, theta=10000.0, device=None):
    """Return the rotary-embedding ``(cos, sin)`` tables.

    Both tables have shape ``(max_seq_len, head_dim // 2)`` and dtype float32.
    They are computed on the default device and moved to ``device`` when given.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    cos_table, sin_table = angles.cos(), angles.sin()
    if device is None:
        return cos_table, sin_table
    return cos_table.to(device), sin_table.to(device)
62
+
63
+
64
def apply_rope(x, cos, sin):
    """Rotate the two halves of the last axis of ``x`` by position-dependent angles.

    ``x`` is ``(B, H, T, D)`` with even ``D``; ``cos``/``sin`` are ``(>=T, D/2)``
    tables, of which the first ``T`` rows are used.
    """
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    half = head_dim // 2
    # Reshape so the tables broadcast over the batch and head axes.
    cos_t = cos[:seq_len].view(1, 1, seq_len, half)
    sin_t = sin[:seq_len].view(1, 1, seq_len, half)
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos_t - second * sin_t
    rotated_second = first * sin_t + second * cos_t
    return torch.cat([rotated_first, rotated_second], dim=-1)
74
+
75
+
76
class GQAttention(nn.Module):
    """Causal grouped-query self-attention with RoPE and optional QK-norm.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads; each KV head
    is repeated ``n_heads // n_kv_heads`` times before attention.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        assert cfg.n_heads % cfg.n_kv_heads == 0
        head_dim = cfg.d_model // cfg.n_heads
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = head_dim
        self.repeat = cfg.n_heads // cfg.n_kv_heads

        q_width = cfg.n_heads * head_dim
        kv_width = cfg.n_kv_heads * head_dim
        self.wq = nn.Linear(cfg.d_model, q_width, bias=False)
        self.wk = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.wv = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.wo = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

        self.qk_norm = cfg.qk_norm
        if cfg.qk_norm:
            # Per-head normalization of queries and keys.
            self.q_norm = RMSNorm(head_dim, eps=cfg.rms_eps)
            self.k_norm = RMSNorm(head_dim, eps=cfg.rms_eps)

        self.dropout = cfg.dropout

    def forward(self, x, cos, sin):
        """Attend over ``x`` of shape ``(B, T, d_model)`` and return the same shape."""
        batch, seq_len, _ = x.shape

        def split_heads(projected, heads):
            # (B, T, heads * head_dim) -> (B, heads, T, head_dim)
            return projected.view(batch, seq_len, heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.wq(x), self.n_heads)
        k = split_heads(self.wk(x), self.n_kv_heads)
        v = split_heads(self.wv(x), self.n_kv_heads)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

        if self.repeat > 1:
            # Expand KV heads so every query head has a matching key/value head.
            k = k.repeat_interleave(self.repeat, dim=1)
            v = v.repeat_interleave(self.repeat, dim=1)

        drop_p = self.dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(merged)
122
+
123
+
124
class SwiGLU(nn.Module):
    """Bias-free gated feed-forward block: ``w_down(silu(w_gate(x)) * w_up(x))``."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        d_model, d_ffn = cfg.d_model, cfg.d_ffn
        self.w_gate = nn.Linear(d_model, d_ffn, bias=False)
        self.w_up = nn.Linear(d_model, d_ffn, bias=False)
        self.w_down = nn.Linear(d_ffn, d_model, bias=False)

    def forward(self, x):
        """Project up through the SiLU gate, then back down to ``d_model``."""
        gate = F.silu(self.w_gate(x))
        hidden = gate * self.w_up(x)
        return self.w_down(hidden)
133
+
134
+
135
class Block(nn.Module):
    """Pre-norm transformer layer: attention and SwiGLU FFN, each with a residual."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        eps = cfg.rms_eps
        self.attn_norm = RMSNorm(cfg.d_model, eps=eps)
        self.attn = GQAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.d_model, eps=eps)
        self.ffn = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        """Apply the attention and feed-forward sub-layers to ``x``."""
        attn_out = self.attn(self.attn_norm(x), cos, sin)
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
147
+
148
+
149
class VectraYXNano(nn.Module):
    """Decoder-only transformer language model.

    Token embedding -> ``n_layers`` pre-norm blocks -> final RMSNorm -> LM head.
    With ``cfg.tie_embeddings`` the LM head shares its weight with the token
    embedding. RoPE tables are precomputed for ``cfg.max_seq_len`` positions
    and stored as non-persistent buffers (not written to the state dict).
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.layers = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        if cfg.tie_embeddings:
            self.lm_head.weight = self.tok_emb.weight

        head_dim = cfg.d_model // cfg.n_heads
        cos, sin = precompute_rope(head_dim, cfg.max_seq_len, cfg.rope_theta)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)

        self.apply(self._init_weights)
        # Residual-branch output projections get a smaller std, scaled by depth.
        residual_std = cfg.init_std / math.sqrt(2 * cfg.n_layers)
        for n, p in self.named_parameters():
            if n.endswith(("wo.weight", "w_down.weight")):
                nn.init.normal_(p, mean=0.0, std=residual_std)

    def _init_weights(self, m):
        """Normal(0, init_std) init for Linear/Embedding weights; zero any Linear bias."""
        std = self.cfg.init_std
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=std)

    def num_params(self, exclude_embedding=False):
        """Return the parameter count.

        ``exclude_embedding`` subtracts the token-embedding size only when
        embeddings are tied; with untied embeddings the flag has no effect.
        """
        n = sum(p.numel() for p in self.parameters())
        if exclude_embedding and self.cfg.tie_embeddings:
            n -= self.tok_emb.weight.numel()
        return n

    def forward(self, idx, targets=None, loss_mask=None):
        """Compute logits and, when ``targets`` is given, the training loss.

        Args:
            idx: ``(B, T)`` token ids with ``T <= cfg.max_seq_len``.
            targets: optional ``(B, T)`` target ids; ``-100`` entries are ignored.
            loss_mask: optional ``(B, T)`` weights; when given it replaces the
                ``targets != -100`` mask for averaging the loss.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` without targets,
            otherwise masked mean cross-entropy plus ``z_loss_coef`` times the
            masked mean of ``logsumexp(logits) ** 2``.
        """
        B, T = idx.shape
        assert T <= self.cfg.max_seq_len, f"seq {T} > max {self.cfg.max_seq_len}"
        x = self.tok_emb(idx)
        cos = self.rope_cos
        sin = self.rope_sin
        for layer in self.layers:
            x = layer(x, cos, sin)
        x = self.final_norm(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        # cross-entropy + z-loss for stability
        flat_logits = logits.view(-1, logits.size(-1))
        flat_tgt = targets.view(-1)
        ce = F.cross_entropy(flat_logits, flat_tgt, reduction="none", ignore_index=-100)
        if loss_mask is not None:
            weights = loss_mask.view(-1).float()
        else:
            weights = (flat_tgt != -100).float()
        # clamp_min avoids a division by zero when no position is selected.
        denom = weights.sum().clamp_min(1.0)
        loss = (ce * weights).sum() / denom

        if self.cfg.z_loss_coef > 0:
            lse = torch.logsumexp(flat_logits.float(), dim=-1)
            z = ((lse ** 2) * weights).sum() / denom
            loss = loss + self.cfg.z_loss_coef * z
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=40, top_p=0.9,
                 eos_id=None, repeat_penalty=1.0):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        Works for any batch size: the repetition penalty is applied per row
        using that row's own history, and with ``eos_id`` set, rows that have
        already produced EOS keep emitting ``eos_id`` until every row is done.
        For a single-row batch this matches the previous behaviour (stop right
        after the first EOS token is appended).

        Note: switches the module to eval mode and does not restore the
        previous mode.

        Args:
            idx: ``(B, T)`` prompt token ids.
            max_new_tokens: maximum number of sampling steps.
            temperature: ``<= 0`` selects greedy decoding.
            top_k: keep only the ``top_k`` highest logits (falsy disables).
            top_p: nucleus threshold (falsy or ``>= 1.0`` disables).
            eos_id: optional end-of-sequence token id.
            repeat_penalty: factor applied to logits of tokens already in ``idx``.

        Returns:
            ``(B, T + n)`` tensor with the prompt followed by generated tokens.
        """
        self.eval()
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        for _ in range(max_new_tokens):
            cond = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(cond)
            logits = logits[:, -1, :].float()

            if repeat_penalty != 1.0:
                # Penalize every token in each row's history exactly once:
                # duplicates gather the same original value and scatter the
                # same penalized value, so repeated ids do not compound.
                history = idx.long()
                seen = logits.gather(-1, history)
                seen = torch.where(seen > 0, seen / repeat_penalty, seen * repeat_penalty)
                logits = logits.scatter(-1, history, seen)

            if temperature <= 0:
                next_id = logits.argmax(-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")
                if top_p and top_p < 1.0:
                    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                    probs = F.softmax(sorted_logits, dim=-1)
                    cumprobs = probs.cumsum(-1)
                    drop = cumprobs > top_p
                    # Shift right so the token that crosses the threshold is kept.
                    drop[..., 1:] = drop[..., :-1].clone()
                    drop[..., 0] = False
                    sorted_logits[drop] = -float("inf")
                    logits = torch.full_like(logits, -float("inf")).scatter(-1, sorted_idx, sorted_logits)
                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, 1)

            if eos_id is not None:
                # Rows that already ended keep emitting EOS as padding.
                next_id = next_id.masked_fill(finished.unsqueeze(-1), eos_id)
                finished = finished | (next_id.squeeze(-1) == eos_id)
            idx = torch.cat([idx, next_id], dim=-1)
            if eos_id is not None and bool(finished.all()):
                break
        return idx
training/utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training utilities: optimizer setup, LR schedule, checkpointing."""
2
+
3
+ import json
4
+ import math
5
+ import os
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+
11
def cosine_with_warmup(step, warmup, total, max_lr, min_lr_ratio=0.1):
    """Learning rate at ``step``: linear warmup, then cosine decay.

    Ramps linearly to ``max_lr`` over ``warmup`` steps, then follows a cosine
    from ``max_lr`` down to ``min_lr_ratio * max_lr`` at ``total`` steps and
    stays at that floor afterwards.
    """
    if step < warmup:
        return max_lr * (step + 1) / warmup
    floor_lr = min_lr_ratio * max_lr
    decay_steps = max(1, total - warmup)
    progress = min(1.0, (step - warmup) / decay_steps)
    cosine = 1 + math.cos(math.pi * progress)
    return floor_lr + 0.5 * (max_lr - floor_lr) * cosine
17
+
18
+
19
def make_optimizer(model, lr, weight_decay=0.1, betas=(0.9, 0.95), fused=True):
    """Build AdamW with weight decay restricted to 2D+ non-embedding weights.

    Trainable parameters with fewer than two dimensions (biases, norm gains)
    and any parameter whose name contains ``"tok_emb"`` get no weight decay.
    Frozen parameters are skipped entirely.

    Args:
        model: module whose ``named_parameters()`` are optimized.
        lr: learning rate for both groups.
        weight_decay: decay applied to the decayed group.
        betas: AdamW beta coefficients.
        fused: try the fused CUDA implementation when CUDA is available;
            silently fall back to the default implementation if it is rejected.

    Returns:
        A ``torch.optim.AdamW`` with two parameter groups (decay, no-decay).
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.dim() >= 2 and "tok_emb" not in name:
            decay.append(param)
        else:
            no_decay.append(param)

    def build_groups():
        # Fresh dicts per attempt: the optimizer constructor writes defaults
        # (including ``fused``) into the group dicts it is given, so a failed
        # fused attempt must not leak state into the fallback.
        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]

    if fused and torch.cuda.is_available():
        try:
            return torch.optim.AdamW(build_groups(), lr=lr, betas=betas, fused=True)
        except (TypeError, RuntimeError):
            # TypeError: torch build without the ``fused`` keyword.
            # RuntimeError: parameters not on a device/dtype the fused kernel supports
            # (e.g. CUDA is available but the model still lives on the CPU).
            pass
    return torch.optim.AdamW(build_groups(), lr=lr, betas=betas)
43
+
44
+
45
def save_checkpoint(path, model, optimizer, scheduler_state, step, extra=None):
    """Atomically write a training checkpoint to ``path``.

    The payload holds the model state dict, the optimizer state dict (or
    ``None``), ``scheduler_state``, ``step``, the fields of the ``model.cfg``
    dataclass as a plain dict, and ``extra`` (``{}`` when falsy). Parent
    directories are created. The data is first saved to ``<path>.tmp`` and
    then moved over ``path``.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    model_state = model.state_dict()
    optim_state = None if optimizer is None else optimizer.state_dict()
    cfg = model.cfg
    config = {name: getattr(cfg, name) for name in cfg.__dataclass_fields__}
    payload = dict(
        model=model_state,
        optimizer=optim_state,
        scheduler=scheduler_state,
        step=step,
        config=config,
        extra=extra or {},
    )

    # Save next to the destination, then rename into place in one step.
    staging = target.with_suffix(target.suffix + ".tmp")
    torch.save(payload, staging)
    os.replace(staging, target)
59
+
60
+
61
def load_checkpoint(path, model, optimizer=None, map_location="cpu"):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The state dict is loaded non-strictly when the checkpoint was written with
    ``tie_embeddings=True``, because a tied ``lm_head`` may not be stored
    separately from ``tok_emb``. The flag is looked up in the saved
    ``"config"`` dict (where ``save_checkpoint`` stores it), with a top-level
    ``"tie_embeddings"`` entry taking precedence if present.

    Args:
        path: checkpoint file to read.
        model: module receiving ``payload["model"]``.
        optimizer: optional optimizer restored when the checkpoint carries state.
        map_location: forwarded to ``torch.load``.

    Returns:
        ``(step, extra)`` from the checkpoint, defaulting to ``(0, {})``.

    Raises:
        RuntimeError: on key mismatch when loading strictly (untied checkpoints).
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from trusted sources.
    payload = torch.load(path, map_location=map_location, weights_only=False)
    saved_cfg = payload.get("config") or {}
    tied = bool(payload.get("tie_embeddings", saved_cfg.get("tie_embeddings", False)))
    missing, unexpected = model.load_state_dict(payload["model"], strict=not tied)
    if missing:
        print(f"[load_checkpoint] missing keys (expected with tie_embeddings): {missing[:3]}")
    if unexpected:
        # Non-strict loading would otherwise hide a checkpoint/model mismatch.
        print(f"[load_checkpoint] unexpected keys ignored: {unexpected[:3]}")
    if optimizer is not None and payload.get("optimizer"):
        optimizer.load_state_dict(payload["optimizer"])
    return payload.get("step", 0), payload.get("extra", {})
72
+
73
+
74
def count_tokens(loader_output_iter, n_steps, block_size, batch_size):
    """Approximate tokens consumed: ``n_steps * block_size * batch_size``.

    ``loader_output_iter`` is accepted but not used.
    """
    tokens = n_steps * block_size
    tokens = tokens * batch_size
    return tokens
77
+
78
+
79
def log_jsonl(path, record):
    """Append ``record`` to ``path`` as one UTF-8 JSON line (non-ASCII kept verbatim)."""
    with open(path, mode="a", encoding="utf-8") as sink:
        serialized = json.dumps(record, ensure_ascii=False)
        sink.write(f"{serialized}\n")