Upload folder using huggingface_hub
Browse files- training/aws_lora_base_tools_s3.py +144 -0
- training/aws_lora_nano_tools_s3.py +144 -0
- training/aws_tool_sft_train_s3.py +136 -0
- training/finetune_lora_tools.py +367 -0
- training/finetune_sft.py +255 -0
- training/finetune_tools.py +188 -0
- training/pretrain.py +228 -0
- training/sft_dataset.py +105 -0
- training/transformer.py +259 -0
- training/utils.py +81 -0
training/aws_lora_base_tools_s3.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""SageMaker entrypoint: LoRA tool-use SFT para VectraYX Base 260M - S3 ONLY.
|
| 3 |
+
|
| 4 |
+
Igual que aws_lora_nano_tools_s3.py pero con checkpoint y config de Base 260M.
|
| 5 |
+
|
| 6 |
+
Hyperparameters via env:
|
| 7 |
+
CORPUS_NAME = "v3_bash" (default)
|
| 8 |
+
EPOCHS = "5"
|
| 9 |
+
LR = "2e-4"
|
| 10 |
+
LORA_RANK = "16"
|
| 11 |
+
LORA_ALPHA = "32"
|
| 12 |
+
SEED = "42"
|
| 13 |
+
"""
|
| 14 |
+
import os, sys, json, subprocess, shutil
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
# S3 bucket that main() downloads code, tokenizer, checkpoint, corpus and eval data from.
S3_BUCKET = "s3://vectrayx-sagemaker-792811916323"
# Directory where job artifacts are written; defaults to /opt/ml/output/data
# when SM_OUTPUT_DATA_DIR is not set.
SM_OUTPUT = Path(os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
# Local working directory (default cwd for sh()).
WD = Path("/opt/ml/code/work")
# Extra environment variables that sh() overlays on os.environ for child processes.
ENV = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}

# Base 260M — post-P3 checkpoint (phase3_last.pt)
BASE_CKPT = f"{S3_BUCKET}/checkpoints/vectrayx-base-20260506-1901/phase3_last.pt"
# Config file name, looked up under training_v2/configs/ in the extracted code.
BASE_CFG = "base.json"
# Value passed as --batch-size to the fine-tuning module.
BASE_BATCH = 8
# Value passed as --grad-accum to the fine-tuning module.
BASE_ACCUM = 8  # effective batch = 64
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def die(m):
    """Print a ``[FATAL]`` banner for ``m`` on stdout and exit with status 1."""
    banner = f"\n[FATAL] {m}"
    print(banner, flush=True)
    sys.exit(1)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def s3_download(src, dst):
    """Copy one S3 object to a local path using ``aws s3 cp``.

    Parent directories of ``dst`` are created if needed. When the CLI exits
    with a non-zero status the failure (with captured stderr) is reported
    through ``die``; on success the downloaded size is printed in MB.
    """
    target = Path(dst)
    target.parent.mkdir(parents=True, exist_ok=True)
    command = ["aws", "s3", "cp", src, str(target)]
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode != 0:
        die(f"s3 download failed: {src}\n{result.stderr}")
    size_mb = target.stat().st_size / 1e6
    print(f"[s3] ✓ {src} ({size_mb:.1f}MB)", flush=True)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def sh(cmd, cwd=None):
    """Echo ``cmd``, run it through the shell, and abort via ``die`` on failure.

    The child process gets the current environment overlaid with ``ENV`` and
    runs in ``cwd``, falling back to ``WD`` when ``cwd`` is falsy.
    """
    print(f"$ {cmd}", flush=True)
    child_env = dict(os.environ)
    child_env.update(ENV)
    workdir = str(cwd if cwd else WD)
    completed = subprocess.run(cmd, shell=True, env=child_env, cwd=workdir)
    if completed.returncode != 0:
        die(f"Failed: {cmd}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def main():
    """Run the LoRA tool-use SFT job for VectraYX Base 260M end to end.

    Hyperparameters are read from the environment (CORPUS_NAME, EPOCHS, LR,
    LORA_RANK, LORA_ALPHA, SEED). The function then installs dependencies,
    downloads code, tokenizer, checkpoint, corpus and eval data from S3,
    launches the LoRA fine-tuning module, copies the resulting artifacts to
    ``SM_OUTPUT``, runs the benchmark module and writes ``manifest.json``.
    """
    corpus_name = os.environ.get("CORPUS_NAME", "v3_bash")
    epochs = int(os.environ.get("EPOCHS", "5"))
    lr = float(os.environ.get("LR", "2e-4"))
    lora_rank = int(os.environ.get("LORA_RANK", "16"))
    lora_alpha = float(os.environ.get("LORA_ALPHA", "32"))
    seed = int(os.environ.get("SEED", "42"))

    WD.mkdir(parents=True, exist_ok=True)
    SM_OUTPUT.mkdir(parents=True, exist_ok=True)

    print(f"[config] model=base corpus={corpus_name} epochs={epochs} lr={lr} "
          f"lora_rank={lora_rank} lora_alpha={lora_alpha} seed={seed}", flush=True)

    # 1. Dependencies
    subprocess.run([sys.executable, "-m", "pip", "install", "-q",
                    "sentencepiece", "tokenizers"], check=True)

    # 2. training_v2 code (includes the corrected finetune_lora_tools.py and utils.py)
    print("[code] Downloading training_v2 from S3...", flush=True)
    subprocess.run(["aws", "s3", "cp",
                    f"{S3_BUCKET}/code/training_v2.tar.gz",
                    "/tmp/tv2.tar.gz"], check=True)
    sh("tar xzf /tmp/tv2.tar.gz", cwd=WD)
    print("[code] ✓ training_v2 extracted", flush=True)

    # 3. Tokenizer (same as Nano — BPE 16384)
    s3_download(f"{S3_BUCKET}/tokenizers/vectrayx_bpe.model", WD / "tokenizer.model")

    # 4. Base 260M checkpoint (post-P3, pre-SFT)
    s3_download(BASE_CKPT, WD / "resume.pt")

    # 5. Tool-use corpus
    s3_download(f"{S3_BUCKET}/training-data/tool_sft_{corpus_name}.jsonl",
                WD / "tool_sft.jsonl")

    # 6. Eval data — b4_tooluse_v2 with basic bash (60%)
    eval_dir = WD / "eval_data"
    for b in ["b1_cveqa", "b2_classification", "b3_commands", "b5_conversational"]:
        try:
            s3_download(f"{S3_BUCKET}/eval-data/{b}.jsonl",
                        eval_dir / f"{b}.jsonl")
        except (Exception, SystemExit):
            # s3_download reports failures through die(), which calls
            # sys.exit(1). SystemExit is not an Exception subclass, so it has
            # to be caught explicitly for these files to really be optional.
            print(f"[s3] skip (optional) {b}.jsonl", flush=True)
    s3_download(f"{S3_BUCKET}/eval-data/b4_tooluse_v2.jsonl",
                eval_dir / "b4_tooluse.jsonl")

    # 7. LoRA fine-tune on top of Base 260M
    out_dir = WD / "checkpoints/lora_tool_sft"
    sh(f"{sys.executable} -m training_v2.train.finetune_lora_tools "
       f"--config {WD}/training_v2/configs/{BASE_CFG} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--resume {WD}/resume.pt "
       f"--tool-corpus {WD}/tool_sft.jsonl "
       f"--out {out_dir} "
       f"--lora-rank {lora_rank} "
       f"--lora-alpha {lora_alpha} "
       f"--batch-size {BASE_BATCH} "
       f"--grad-accum {BASE_ACCUM} "
       f"--epochs {epochs} "
       f"--lr {lr} "
       f"--seed {seed}")

    # 8. Copy artifacts to the job output directory
    shutil.copy(out_dir / "final.pt", SM_OUTPUT / "final.pt")
    shutil.copy(out_dir / "final_lora_only.pt", SM_OUTPUT / "final_lora_only.pt")
    shutil.copy(WD / f"training_v2/configs/{BASE_CFG}", SM_OUTPUT / "model_config.json")

    # 9. Benchmark B1–B5
    sh(f"{sys.executable} -m training_v2.eval.benchmark "
       f"--checkpoint {out_dir}/final.pt "
       f"--config {WD}/training_v2/configs/{BASE_CFG} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--data-dir {eval_dir} "
       f"--out {SM_OUTPUT}/bench_lora_tools.json")

    # 10. Manifest recording the settings this run used
    manifest = {
        "model": "base",
        "method": "lora",
        "corpus": corpus_name,
        "lora_rank": lora_rank,
        "lora_alpha": lora_alpha,
        "epochs": epochs,
        "lr": lr,
        "seed": seed,
        "resume_from": BASE_CKPT,
        "effective_batch": BASE_BATCH * BASE_ACCUM,
    }
    (SM_OUTPUT / "manifest.json").write_text(json.dumps(manifest, indent=2))
    print(f"[done] LoRA tool-SFT Base 260M → {SM_OUTPUT}", flush=True)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
|
| 144 |
+
main()
|
training/aws_lora_nano_tools_s3.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""SageMaker entrypoint: LoRA tool-use SFT para VectraYX Nano - S3 ONLY.
|
| 3 |
+
|
| 4 |
+
Hyperparameters via env:
|
| 5 |
+
CORPUS_NAME = "v3_bash" (default)
|
| 6 |
+
EPOCHS = "5"
|
| 7 |
+
LR = "2e-4"
|
| 8 |
+
LORA_RANK = "16"
|
| 9 |
+
LORA_ALPHA = "32"
|
| 10 |
+
SEED = "42"
|
| 11 |
+
"""
|
| 12 |
+
import os, sys, json, subprocess, shutil
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# S3 bucket that main() downloads code, tokenizer, checkpoint, corpus and eval data from.
S3_BUCKET = "s3://vectrayx-sagemaker-792811916323"
# Directory where job artifacts are written; defaults to /opt/ml/output/data
# when SM_OUTPUT_DATA_DIR is not set.
SM_OUTPUT = Path(os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
# Local working directory (default cwd for sh()).
WD = Path("/opt/ml/code/work")
# Extra environment variables that sh() overlays on os.environ for child processes.
ENV = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}

# Nano config — checkpoint after the mixed SFT stage
NANO_CKPT = f"{S3_BUCKET}/checkpoints/nano_sft_v5.pt"
# Config file name, looked up under training_v2/configs/ in the extracted code.
NANO_CFG = "nano.json"
# Value passed as --batch-size to the fine-tuning module.
NANO_BATCH = 16
# Value passed as --grad-accum to the fine-tuning module.
NANO_ACCUM = 4  # effective batch = 64
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def die(m):
    """Print a ``[FATAL]`` banner for ``m`` on stdout and exit with status 1."""
    banner = f"\n[FATAL] {m}"
    print(banner, flush=True)
    sys.exit(1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def s3_download(src, dst):
    """Copy one S3 object to a local path using ``aws s3 cp``.

    Parent directories of ``dst`` are created if needed. When the CLI exits
    with a non-zero status the failure (with captured stderr) is reported
    through ``die``; on success the downloaded size is printed in MB.
    """
    target = Path(dst)
    target.parent.mkdir(parents=True, exist_ok=True)
    command = ["aws", "s3", "cp", src, str(target)]
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode != 0:
        die(f"s3 download failed: {src}\n{result.stderr}")
    size_mb = target.stat().st_size / 1e6
    print(f"[s3] ✓ {src} ({size_mb:.1f}MB)", flush=True)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def sh(cmd, cwd=None):
    """Echo ``cmd``, run it through the shell, and abort via ``die`` on failure.

    The child process gets the current environment overlaid with ``ENV`` and
    runs in ``cwd``, falling back to ``WD`` when ``cwd`` is falsy.
    """
    print(f"$ {cmd}", flush=True)
    child_env = dict(os.environ)
    child_env.update(ENV)
    workdir = str(cwd if cwd else WD)
    completed = subprocess.run(cmd, shell=True, env=child_env, cwd=workdir)
    if completed.returncode != 0:
        die(f"Failed: {cmd}")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def main():
    """Run the LoRA tool-use SFT job for VectraYX Nano end to end.

    Hyperparameters are read from the environment (CORPUS_NAME, EPOCHS, LR,
    LORA_RANK, LORA_ALPHA, SEED). The function then installs dependencies,
    downloads code, tokenizer, checkpoint, corpus and eval data from S3,
    launches the LoRA fine-tuning module, copies the resulting artifacts to
    ``SM_OUTPUT``, runs the benchmark module and writes ``manifest.json``.
    """
    corpus_name = os.environ.get("CORPUS_NAME", "v3_bash")
    epochs = int(os.environ.get("EPOCHS", "5"))
    lr = float(os.environ.get("LR", "2e-4"))
    lora_rank = int(os.environ.get("LORA_RANK", "16"))
    lora_alpha = float(os.environ.get("LORA_ALPHA", "32"))
    seed = int(os.environ.get("SEED", "42"))

    WD.mkdir(parents=True, exist_ok=True)
    SM_OUTPUT.mkdir(parents=True, exist_ok=True)

    print(f"[config] corpus={corpus_name} epochs={epochs} lr={lr} "
          f"lora_rank={lora_rank} lora_alpha={lora_alpha} seed={seed}", flush=True)

    # 1. Dependencies
    subprocess.run([sys.executable, "-m", "pip", "install", "-q",
                    "sentencepiece", "tokenizers"], check=True)

    # 2. training_v2 code
    print("[code] Downloading training_v2 from S3...", flush=True)
    subprocess.run(["aws", "s3", "cp",
                    f"{S3_BUCKET}/code/training_v2.tar.gz",
                    "/tmp/tv2.tar.gz"], check=True)
    sh("tar xzf /tmp/tv2.tar.gz", cwd=WD)
    print("[code] ✓ training_v2 extracted", flush=True)

    # 3. Tokenizer
    s3_download(f"{S3_BUCKET}/tokenizers/vectrayx_bpe.model", WD / "tokenizer.model")

    # 4. Nano starting checkpoint (after the mixed SFT stage)
    s3_download(NANO_CKPT, WD / "resume.pt")

    # 5. Tool-use corpus
    s3_download(f"{S3_BUCKET}/training-data/tool_sft_{corpus_name}.jsonl",
                WD / "tool_sft.jsonl")

    # 6. Eval data — b4_tooluse_v2 has 50 questions with basic bash
    eval_dir = WD / "eval_data"
    for b in ["b1_cveqa", "b2_classification", "b3_commands",
              "b5_conversational"]:
        try:
            s3_download(f"{S3_BUCKET}/eval-data/{b}.jsonl",
                        eval_dir / f"{b}.jsonl")
        except (Exception, SystemExit):
            # s3_download reports failures through die(), which calls
            # sys.exit(1). SystemExit is not an Exception subclass, so it has
            # to be caught explicitly for these files to really be optional.
            print(f"[s3] skip (optional) {b}.jsonl", flush=True)
    # B4 v2 — extended benchmark with basic bash (60%) + MCP (40%)
    s3_download(f"{S3_BUCKET}/eval-data/b4_tooluse_v2.jsonl",
                eval_dir / "b4_tooluse.jsonl")  # same name so benchmark.py finds it

    # 7. LoRA fine-tune
    out_dir = WD / "checkpoints/lora_tool_sft"
    sh(f"{sys.executable} -m training_v2.train.finetune_lora_tools "
       f"--config {WD}/training_v2/configs/{NANO_CFG} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--resume {WD}/resume.pt "
       f"--tool-corpus {WD}/tool_sft.jsonl "
       f"--out {out_dir} "
       f"--lora-rank {lora_rank} "
       f"--lora-alpha {lora_alpha} "
       f"--batch-size {NANO_BATCH} "
       f"--grad-accum {NANO_ACCUM} "
       f"--epochs {epochs} "
       f"--lr {lr} "
       f"--seed {seed}")

    # 8. Copy artifacts to the job output directory
    shutil.copy(out_dir / "final.pt", SM_OUTPUT / "final.pt")
    shutil.copy(out_dir / "final_lora_only.pt", SM_OUTPUT / "final_lora_only.pt")
    shutil.copy(WD / f"training_v2/configs/{NANO_CFG}", SM_OUTPUT / "model_config.json")

    # 9. Benchmark B1–B5 (uses the merged final.pt)
    sh(f"{sys.executable} -m training_v2.eval.benchmark "
       f"--checkpoint {out_dir}/final.pt "
       f"--config {WD}/training_v2/configs/{NANO_CFG} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--data-dir {eval_dir} "
       f"--out {SM_OUTPUT}/bench_lora_tools.json")

    # 10. Manifest recording the settings this run used
    manifest = {
        "model": "nano",
        "method": "lora",
        "corpus": corpus_name,
        "lora_rank": lora_rank,
        "lora_alpha": lora_alpha,
        "epochs": epochs,
        "lr": lr,
        "seed": seed,
        "resume_from": NANO_CKPT,
        "effective_batch": NANO_BATCH * NANO_ACCUM,
    }
    (SM_OUTPUT / "manifest.json").write_text(json.dumps(manifest, indent=2))
    print(f"[done] LoRA tool-SFT Nano → {SM_OUTPUT}", flush=True)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
|
| 144 |
+
main()
|
training/aws_tool_sft_train_s3.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""SageMaker entrypoint: tool-use mini-SFT focalizado (Nano o Base) - S3 ONLY.
|
| 3 |
+
|
| 4 |
+
Hyperparameters via env:
|
| 5 |
+
MODEL = "nano" | "base"
|
| 6 |
+
CORPUS_NAME = "v1" | "v2"
|
| 7 |
+
EPOCHS = "2"
|
| 8 |
+
LR = "1e-5"
|
| 9 |
+
SEED = "42"
|
| 10 |
+
"""
|
| 11 |
+
import os, sys, json, subprocess, shutil
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# S3 bucket that main() downloads code, tokenizer, checkpoint, corpus and eval data from.
S3_BUCKET = "s3://vectrayx-sagemaker-792811916323"
# Directory where job artifacts are written; defaults to /opt/ml/output/data
# when SM_OUTPUT_DATA_DIR is not set.
SM_OUTPUT = Path(os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
# Local working directory (default cwd for sh()).
WD = Path("/opt/ml/code/work")
# Extra environment variables that sh() overlays on os.environ for child processes.
ENV = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}

# Per-model settings selected by the MODEL environment variable: config file
# name, S3 checkpoint to resume from, and the --batch-size / --grad-accum
# values handed to the fine-tuning module.
MODEL_CFG = {
    "nano": {
        "config": "nano.json",
        "ckpt_src": f"{S3_BUCKET}/checkpoints/nano_sft_v5.pt",
        "batch": 16,
        "accum": 4,
    },
    "base": {
        "config": "base.json",
        "ckpt_src": f"{S3_BUCKET}/checkpoints/vectrayx-base-20260506-1901/phase3_last.pt",
        "batch": 8,
        "accum": 8,
    },
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def die(m):
    """Print a ``[FATAL]`` banner for ``m`` on stdout and exit with status 1."""
    banner = f"\n[FATAL] {m}"
    print(banner, flush=True)
    sys.exit(1)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def s3_download(src, dst):
    """Copy one S3 object to a local path using the AWS CLI.

    Parent directories of ``dst`` are created if needed. When ``aws s3 cp``
    exits with a non-zero status the failure (with captured stderr) is
    reported through ``die``; on success the downloaded size is printed in MB.
    """
    target = Path(dst)
    target.parent.mkdir(parents=True, exist_ok=True)
    command = ["aws", "s3", "cp", src, str(target)]
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode != 0:
        die(f"s3 download failed: {src}\n{result.stderr}")
    size_mb = target.stat().st_size / 1e6
    print(f"[s3] ✓ {src} ({size_mb:.1f}MB)", flush=True)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def sh(cmd, cwd=None):
    """Echo ``cmd``, run it through the shell, and abort via ``die`` on failure.

    The child process gets the current environment overlaid with ``ENV`` and
    runs in ``cwd``, falling back to ``WD`` when ``cwd`` is falsy.
    """
    print(f"$ {cmd}", flush=True)
    child_env = dict(os.environ)
    child_env.update(ENV)
    workdir = str(cwd if cwd else WD)
    completed = subprocess.run(cmd, shell=True, env=child_env, cwd=workdir)
    if completed.returncode != 0:
        die(f"Failed: {cmd}")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def main():
    """Run the focused tool-use mini-SFT job for the selected model.

    Hyperparameters are read from the environment (MODEL, CORPUS_NAME,
    EPOCHS, LR, SEED); an unknown MODEL aborts via ``die``. The function then
    installs dependencies, downloads code, tokenizer, checkpoint, corpus and
    eval data from S3, launches the fine-tuning module, copies the resulting
    artifacts to ``SM_OUTPUT``, runs the benchmark module and writes
    ``manifest.json``.
    """
    model_name = os.environ.get("MODEL", "nano")
    corpus_name = os.environ.get("CORPUS_NAME", "v1")
    epochs = int(os.environ.get("EPOCHS", "2"))
    lr = float(os.environ.get("LR", "1e-5"))
    seed = int(os.environ.get("SEED", "42"))

    if model_name not in MODEL_CFG:
        die(f"Unknown MODEL={model_name}")
    cfg = MODEL_CFG[model_name]

    WD.mkdir(parents=True, exist_ok=True)
    SM_OUTPUT.mkdir(parents=True, exist_ok=True)

    # 1. Dependencies
    subprocess.run([sys.executable, "-m", "pip", "install", "-q",
                    "sentencepiece", "tokenizers"], check=True)

    # 2. Download and extract training_v2 code
    print("[code] Downloading training_v2 from S3...", flush=True)
    subprocess.run(["aws", "s3", "cp",
                    f"{S3_BUCKET}/code/training_v2.tar.gz",
                    "/tmp/tv2.tar.gz"], check=True)
    sh("tar xzf /tmp/tv2.tar.gz", cwd=WD)
    print(f"[code] ✓ training_v2 extracted to {WD}", flush=True)

    # 3. Tokenizer
    s3_download(f"{S3_BUCKET}/tokenizers/vectrayx_bpe.model", WD / "tokenizer.model")

    # 4. Starting checkpoint
    s3_download(cfg["ckpt_src"], WD / "resume.pt")

    # 5. Tool SFT corpus
    s3_download(f"{S3_BUCKET}/training-data/tool_sft_{corpus_name}.jsonl",
                WD / "tool_sft.jsonl")

    # 6. Eval data
    eval_dir = WD / "eval_data"
    for b in ["b1_cveqa", "b2_classification", "b3_commands",
              "b4_tooluse", "b5_conversational"]:
        try:
            s3_download(f"{S3_BUCKET}/eval-data/{b}.jsonl",
                        eval_dir / f"{b}.jsonl")
        except (Exception, SystemExit):
            # s3_download reports failures through die() -> sys.exit(1), so
            # SystemExit must be caught for the file to be optional. Unlike a
            # bare ``except:``, this still lets KeyboardInterrupt propagate.
            print(f"[s3] skip (optional) {b}.jsonl", flush=True)

    # 7. Focused mini-SFT
    out_dir = WD / "checkpoints/tool_sft"
    sh(f"{sys.executable} -m training_v2.train.finetune_tools "
       f"--config {WD}/training_v2/configs/{cfg['config']} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--resume {WD}/resume.pt "
       f"--tool-corpus {WD}/tool_sft.jsonl "
       f"--out {out_dir} "
       f"--batch-size {cfg['batch']} --grad-accum {cfg['accum']} "
       f"--epochs {epochs} --lr {lr} --seed {seed}")

    # 8. Copy the final checkpoint and its config to the job output directory
    shutil.copy(out_dir / "final.pt", SM_OUTPUT / "final.pt")
    shutil.copy(WD / f"training_v2/configs/{cfg['config']}",
                SM_OUTPUT / "model_config.json")

    # 9. Benchmark B1–B5
    sh(f"{sys.executable} -m training_v2.eval.benchmark "
       f"--checkpoint {out_dir}/final.pt "
       f"--config {WD}/training_v2/configs/{cfg['config']} "
       f"--tokenizer {WD}/tokenizer.model "
       f"--data-dir {eval_dir} "
       f"--out {SM_OUTPUT}/bench_tool_sft.json")

    # 10. Manifest recording the settings this run used
    manifest = {
        "model": model_name,
        "corpus": corpus_name,
        "epochs": epochs, "lr": lr, "seed": seed,
        "resume_from": cfg["ckpt_src"],
    }
    (SM_OUTPUT / "manifest.json").write_text(json.dumps(manifest, indent=2))
    print(f"[done] tool-SFT {model_name}/{corpus_name}/seed={seed} → {SM_OUTPUT}", flush=True)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
|
| 136 |
+
main()
|
training/finetune_lora_tools.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LoRA tool-use SFT para VectraYX Nano.
|
| 2 |
+
|
| 3 |
+
Aplica LoRA sobre las proyecciones de atención (wq, wk, wv, wo) del modelo
|
| 4 |
+
custom VectraYXNano. Congela todos los pesos base y solo entrena los adaptadores.
|
| 5 |
+
|
| 6 |
+
Ventaja sobre full fine-tune:
|
| 7 |
+
- Solo ~0.5% de parámetros entrenables (~200K vs 42M)
|
| 8 |
+
- Menos riesgo de catastrophic forgetting en B1/B2/B5
|
| 9 |
+
- SmolLM2-135M logra B4=0.16 con LoRA — probamos si Nano puede hacer lo mismo
|
| 10 |
+
|
| 11 |
+
Run example:
|
| 12 |
+
python -m training_v2.train.finetune_lora_tools \
|
| 13 |
+
--config training_v2/configs/nano.json \
|
| 14 |
+
--tokenizer models/vectrayx_bpe.model \
|
| 15 |
+
--resume checkpoints/nano_sft_v5.pt \
|
| 16 |
+
--tool-corpus corpus/tool_sft_v2_simple.jsonl \
|
| 17 |
+
--out checkpoints/nano_lora_tools \
|
| 18 |
+
--lora-rank 16 --lora-alpha 32 \
|
| 19 |
+
--batch-size 16 --grad-accum 4 --epochs 5 --lr 2e-4
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
import math
|
| 25 |
+
import sys
|
| 26 |
+
import time
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import sentencepiece as spm
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
from torch.utils.data import DataLoader
|
| 34 |
+
|
| 35 |
+
ROOT = Path(__file__).resolve().parents[2]
|
| 36 |
+
sys.path.insert(0, str(ROOT))
|
| 37 |
+
|
| 38 |
+
from training_v2.data.sft_dataset import SFTDataset
|
| 39 |
+
from training_v2.model.transformer import VectraYXNano, ModelConfig
|
| 40 |
+
from training_v2.train.utils import (
|
| 41 |
+
cosine_with_warmup, log_jsonl,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# LoRA implementation
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank update.

    The effective weight is ``W' = W + (B @ A) * scale`` where
    ``scale = alpha / rank``. The wrapped layer's own parameters are frozen;
    only ``lora_A`` and ``lora_B`` remain trainable.
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.linear = linear  # base weights — frozen below
        self.rank = rank
        self.scale = alpha / rank

        # A gets a Kaiming-uniform init while B starts at zero, so the
        # low-rank term is exactly zero until B has been trained.
        self.lora_A = nn.Parameter(torch.empty(rank, linear.in_features))
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, rank))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        # Freeze every parameter of the wrapped layer.
        self.linear.requires_grad_(False)

    def forward(self, x):
        """Return the frozen layer's output plus the scaled low-rank term."""
        device = x.device
        # Move the adapter matrices to the input's device before multiplying.
        down = self.lora_A.to(device)
        up = self.lora_B.to(device)
        delta = (x @ down.T) @ up.T
        return self.linear(x) + delta * self.scale
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def inject_lora(model: nn.Module, rank: int, alpha: float,
                target_modules=("wq", "wk", "wv", "wo")) -> int:
    """Wrap the targeted ``nn.Linear`` attributes of ``model`` with LoRA.

    Every module in ``model`` is inspected; any attribute whose name is in
    ``target_modules`` and whose value is an ``nn.Linear`` is replaced by a
    ``LoRALinear`` wrapper. Afterwards all parameters whose name contains
    neither ``lora_A`` nor ``lora_B`` are frozen.

    Args:
        model: Model to modify in place.
        rank: LoRA rank passed to each wrapper.
        alpha: LoRA alpha passed to each wrapper.
        target_modules: Attribute names to look for on each module.

    Returns:
        The number of trainable parameters after injection.
    """
    # Snapshot the module tree before touching it. Replacing children while
    # lazily walking model.named_modules() makes the walk descend into the
    # freshly created wrappers, and a target name that matches a wrapper
    # attribute (e.g. "linear") would then be wrapped again and again.
    replaced = 0
    for module in list(model.modules()):
        if isinstance(module, LoRALinear):
            continue  # never re-wrap the inner layer of an existing adapter
        for attr_name in target_modules:
            original = getattr(module, attr_name, None)
            if isinstance(original, nn.Linear):
                setattr(module, attr_name, LoRALinear(original, rank, alpha))
                replaced += 1

    if replaced == 0:
        # Nothing matched: every parameter gets frozen below, so training
        # would have nothing to optimise. Make that visible early.
        print(f"[lora] WARNING: no nn.Linear found for targets {list(target_modules)}")

    # Freeze everything except the LoRA matrices.
    for name, param in model.named_parameters():
        if "lora_A" not in name and "lora_B" not in name:
            param.requires_grad_(False)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    # Guard the percentage against a model without parameters.
    pct = trainable / total * 100 if total else 0.0
    print(f"[lora] Inyectado en {replaced} módulos | "
          f"Entrenables: {trainable/1e3:.1f}K / {total/1e6:.2f}M "
          f"({pct:.2f}%)")
    return trainable
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def save_lora_checkpoint(path: Path, model: nn.Module, optimizer, step: int,
                         extra: dict = None):
    """Save only the LoRA weights of ``model`` (not the base model).

    The checkpoint holds the state-dict entries whose key contains ``lora_A``
    or ``lora_B``, the optimizer state (``None`` when ``optimizer`` is
    falsy), the step counter, and any ``extra`` entries merged on top.
    """
    lora_state = {}
    for key, tensor in model.state_dict().items():
        if "lora_A" in key or "lora_B" in key:
            lora_state[key] = tensor

    payload = {
        "lora_state_dict": lora_state,
        "optimizer_state_dict": optimizer.state_dict() if optimizer else None,
        "step": step,
    }
    if extra:
        # Extra entries are applied last, so they win on key collisions.
        payload.update(extra)

    torch.save(payload, path)
    size_mb = path.stat().st_size / 1e6
    print(f"[save] LoRA checkpoint → {path} ({size_mb:.1f}MB)")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def load_lora_checkpoint(path: Path, model: nn.Module, optimizer=None,
                         map_location="cpu"):
    """Load LoRA weights from ``path`` into ``model``.

    The stored ``lora_state_dict`` is applied non-strictly and the key counts
    are printed. When an optimizer is given and the checkpoint has a saved
    optimizer state, that state is restored too.

    Returns:
        The ``step`` stored in the checkpoint, or 0 when it has none.
    """
    ckpt = torch.load(path, map_location=map_location)
    lora_state = ckpt["lora_state_dict"]
    # strict=False: the checkpoint only holds the LoRA keys, not the full model.
    outcome = model.load_state_dict(lora_state, strict=False)
    print(f"[load] LoRA: {len(lora_state)} keys loaded, "
          f"{len(outcome.missing_keys)} missing, {len(outcome.unexpected_keys)} unexpected")

    saved_optimizer = ckpt.get("optimizer_state_dict")
    if optimizer and saved_optimizer:
        optimizer.load_state_dict(saved_optimizer)
    return ckpt.get("step", 0)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# ---------------------------------------------------------------------------
|
| 133 |
+
# Main
|
| 134 |
+
# ---------------------------------------------------------------------------
|
| 135 |
+
|
| 136 |
+
def main():
|
| 137 |
+
p = argparse.ArgumentParser()
|
| 138 |
+
p.add_argument("--config", required=True)
|
| 139 |
+
p.add_argument("--tokenizer", required=True)
|
| 140 |
+
p.add_argument("--resume", required=True, help="checkpoint base a fine-tunear")
|
| 141 |
+
p.add_argument("--tool-corpus", required=True, help="tool-use JSONL corpus")
|
| 142 |
+
p.add_argument("--out", required=True)
|
| 143 |
+
# LoRA
|
| 144 |
+
p.add_argument("--lora-rank", type=int, default=16,
|
| 145 |
+
help="LoRA rank r (default 16)")
|
| 146 |
+
p.add_argument("--lora-alpha", type=float, default=32.0,
|
| 147 |
+
help="LoRA alpha (default 32, scale=alpha/rank=2)")
|
| 148 |
+
p.add_argument("--lora-targets", nargs="+",
|
| 149 |
+
default=["wq", "wk", "wv", "wo"],
|
| 150 |
+
help="Módulos de atención a inyectar LoRA")
|
| 151 |
+
# Training
|
| 152 |
+
p.add_argument("--batch-size", type=int, default=16)
|
| 153 |
+
p.add_argument("--grad-accum", type=int, default=4)
|
| 154 |
+
p.add_argument("--epochs", type=int, default=5)
|
| 155 |
+
p.add_argument("--lr", type=float, default=2e-4,
|
| 156 |
+
help="LR más alto que full FT (LoRA converge más rápido)")
|
| 157 |
+
p.add_argument("--weight-decay", type=float, default=0.01)
|
| 158 |
+
p.add_argument("--grad-clip", type=float, default=1.0)
|
| 159 |
+
p.add_argument("--warmup-frac", type=float, default=0.05)
|
| 160 |
+
p.add_argument("--num-workers", type=int, default=2)
|
| 161 |
+
p.add_argument("--log-every", type=int, default=10)
|
| 162 |
+
p.add_argument("--save-every", type=int, default=200)
|
| 163 |
+
p.add_argument("--seed", type=int, default=42)
|
| 164 |
+
p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
|
| 165 |
+
p.add_argument("--dtype", default="bfloat16",
|
| 166 |
+
choices=["bfloat16", "float16", "float32"])
|
| 167 |
+
p.add_argument("--max-steps", type=int, default=None)
|
| 168 |
+
args = p.parse_args()
|
| 169 |
+
|
| 170 |
+
torch.manual_seed(args.seed)
|
| 171 |
+
np.random.seed(args.seed)
|
| 172 |
+
|
| 173 |
+
# 1. Cargar modelo base
|
| 174 |
+
cfg = ModelConfig.from_json(args.config)
|
| 175 |
+
model = VectraYXNano(cfg).to(args.device)
|
| 176 |
+
total_params = model.num_params()
|
| 177 |
+
print(f"[model] {total_params/1e6:.2f}M params (base)")
|
| 178 |
+
|
| 179 |
+
# Cargar checkpoint base (full weights) usando load_checkpoint de utils
|
| 180 |
+
from training_v2.train.utils import load_checkpoint as _load_ckpt
|
| 181 |
+
_load_ckpt(args.resume, model, optimizer=None, map_location=args.device)
|
| 182 |
+
print(f"[resume] {args.resume}")
|
| 183 |
+
|
| 184 |
+
# 2. Inyectar LoRA
|
| 185 |
+
trainable = inject_lora(model, rank=args.lora_rank, alpha=args.lora_alpha,
|
| 186 |
+
target_modules=args.lora_targets)
|
| 187 |
+
# Mover parámetros LoRA al mismo device que el modelo
|
| 188 |
+
model = model.to(args.device)
|
| 189 |
+
|
| 190 |
+
# 3. Tokenizer
|
| 191 |
+
sp = spm.SentencePieceProcessor()
|
| 192 |
+
sp.load(args.tokenizer)
|
| 193 |
+
pad_id = sp.pad_id() if sp.pad_id() >= 0 else 0
|
| 194 |
+
|
| 195 |
+
# 4. Dataset
|
| 196 |
+
block_size = cfg.max_seq_len
|
| 197 |
+
tool_corpus = Path(args.tool_corpus)
|
| 198 |
+
if not tool_corpus.exists():
|
| 199 |
+
raise FileNotFoundError(f"Tool corpus not found: {tool_corpus}")
|
| 200 |
+
|
| 201 |
+
dataset = SFTDataset([tool_corpus], sp, block_size, pad_id=pad_id, seed=args.seed)
|
| 202 |
+
print(f"[dataset] {len(dataset)} ejemplos de {tool_corpus.name}")
|
| 203 |
+
|
| 204 |
+
# 5. Output dir
|
| 205 |
+
out_dir = Path(args.out)
|
| 206 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 207 |
+
log_path = out_dir / "train_log.jsonl"
|
| 208 |
+
|
| 209 |
+
# 6. Optimizer — solo parámetros LoRA
|
| 210 |
+
lora_params = [p for p in model.parameters() if p.requires_grad]
|
| 211 |
+
optimizer = torch.optim.AdamW(lora_params, lr=args.lr,
|
| 212 |
+
weight_decay=args.weight_decay,
|
| 213 |
+
betas=(0.9, 0.95))
|
| 214 |
+
|
| 215 |
+
# 7. AMP
|
| 216 |
+
dtype = {"bfloat16": torch.bfloat16,
|
| 217 |
+
"float16": torch.float16,
|
| 218 |
+
"float32": torch.float32}[args.dtype]
|
| 219 |
+
use_amp = args.device == "cuda" and dtype != torch.float32
|
| 220 |
+
|
| 221 |
+
# 8. Training loop
|
| 222 |
+
def collate(batch):
|
| 223 |
+
xs = torch.stack([b[0] for b in batch])
|
| 224 |
+
ys = torch.stack([b[1] for b in batch])
|
| 225 |
+
ms = torch.stack([b[2] for b in batch])
|
| 226 |
+
return xs, ys, ms
|
| 227 |
+
|
| 228 |
+
loader = DataLoader(
|
| 229 |
+
dataset, batch_size=args.batch_size, shuffle=True,
|
| 230 |
+
num_workers=args.num_workers, collate_fn=collate, pin_memory=True,
|
| 231 |
+
persistent_workers=args.num_workers > 0,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
steps_per_epoch = max(1, len(loader) // args.grad_accum)
|
| 235 |
+
total_steps = steps_per_epoch * args.epochs
|
| 236 |
+
if args.max_steps:
|
| 237 |
+
total_steps = min(total_steps, args.max_steps)
|
| 238 |
+
warmup = max(20, int(args.warmup_frac * total_steps))
|
| 239 |
+
|
| 240 |
+
print(f"\n[train] LoRA rank={args.lora_rank} alpha={args.lora_alpha} "
|
| 241 |
+
f"scale={args.lora_alpha/args.lora_rank:.1f}")
|
| 242 |
+
print(f"[train] epochs={args.epochs} steps/epoch≈{steps_per_epoch} "
|
| 243 |
+
f"total={total_steps} warmup={warmup}")
|
| 244 |
+
print(f"[train] lr={args.lr} batch={args.batch_size} accum={args.grad_accum} "
|
| 245 |
+
f"effective_batch={args.batch_size * args.grad_accum}")
|
| 246 |
+
|
| 247 |
+
model.train()
|
| 248 |
+
t_start = time.time()
|
| 249 |
+
step = 0
|
| 250 |
+
running_loss = 0.0
|
| 251 |
+
running_n = 0
|
| 252 |
+
|
| 253 |
+
for ep in range(args.epochs):
|
| 254 |
+
print(f"\n=== epoch {ep+1}/{args.epochs} (LoRA tool-SFT) ===")
|
| 255 |
+
data_iter = iter(loader)
|
| 256 |
+
|
| 257 |
+
for _ in range(steps_per_epoch):
|
| 258 |
+
if args.max_steps and step >= args.max_steps:
|
| 259 |
+
break
|
| 260 |
+
|
| 261 |
+
cur_lr = cosine_with_warmup(step, warmup, total_steps, args.lr)
|
| 262 |
+
for g in optimizer.param_groups:
|
| 263 |
+
g["lr"] = cur_lr
|
| 264 |
+
|
| 265 |
+
optimizer.zero_grad(set_to_none=True)
|
| 266 |
+
loss_accum = 0.0
|
| 267 |
+
|
| 268 |
+
for _micro in range(args.grad_accum):
|
| 269 |
+
try:
|
| 270 |
+
xs, ys, ms = next(data_iter)
|
| 271 |
+
except StopIteration:
|
| 272 |
+
data_iter = iter(loader)
|
| 273 |
+
xs, ys, ms = next(data_iter)
|
| 274 |
+
|
| 275 |
+
xs = xs.to(args.device, non_blocking=True)
|
| 276 |
+
ys = ys.to(args.device, non_blocking=True)
|
| 277 |
+
ms = ms.to(args.device, non_blocking=True)
|
| 278 |
+
|
| 279 |
+
with torch.amp.autocast("cuda", dtype=dtype, enabled=use_amp):
|
| 280 |
+
_, loss = model(xs, targets=ys, loss_mask=ms)
|
| 281 |
+
loss = loss / args.grad_accum
|
| 282 |
+
loss.backward()
|
| 283 |
+
loss_accum += loss.item() * args.grad_accum
|
| 284 |
+
|
| 285 |
+
gnorm = torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
|
| 286 |
+
optimizer.step()
|
| 287 |
+
step += 1
|
| 288 |
+
running_loss += loss_accum / args.grad_accum
|
| 289 |
+
running_n += 1
|
| 290 |
+
|
| 291 |
+
if step % args.log_every == 0:
|
| 292 |
+
elapsed = time.time() - t_start
|
| 293 |
+
avg = running_loss / running_n
|
| 294 |
+
print(f"[lora ep{ep+1} step {step:>4}/{total_steps}] "
|
| 295 |
+
f"loss={avg:.4f} lr={cur_lr:.2e} "
|
| 296 |
+
f"gnorm={gnorm:.2f} {elapsed/60:.1f}min")
|
| 297 |
+
log_jsonl(log_path, {"epoch": ep+1, "step": step, "loss": avg,
|
| 298 |
+
"lr": cur_lr, "gnorm": float(gnorm)})
|
| 299 |
+
running_loss = 0.0
|
| 300 |
+
running_n = 0
|
| 301 |
+
|
| 302 |
+
if step % args.save_every == 0:
|
| 303 |
+
save_lora_checkpoint(out_dir / "last_lora.pt", model, optimizer,
|
| 304 |
+
step, {"epoch": ep+1})
|
| 305 |
+
|
| 306 |
+
if args.max_steps and step >= args.max_steps:
|
| 307 |
+
break
|
| 308 |
+
|
| 309 |
+
save_lora_checkpoint(out_dir / f"epoch{ep+1}_lora.pt", model, optimizer,
|
| 310 |
+
step, {"epoch": ep+1})
|
| 311 |
+
print(f"[save] epoch{ep+1}_lora.pt")
|
| 312 |
+
|
| 313 |
+
# Guardar checkpoint final con pesos COMPLETOS (base + LoRA merged)
|
| 314 |
+
# Estrategia: construir state_dict manualmente fusionando LoRA
|
| 315 |
+
print("\n[merge] Mergeando LoRA en pesos base...")
|
| 316 |
+
|
| 317 |
+
# Primero recolectar todos los módulos LoRA con sus rutas
|
| 318 |
+
lora_modules = {}
|
| 319 |
+
for mod_name, mod in model.named_modules():
|
| 320 |
+
if isinstance(mod, LoRALinear):
|
| 321 |
+
lora_modules[mod_name] = mod
|
| 322 |
+
|
| 323 |
+
# Construir state_dict fusionado
|
| 324 |
+
merged_state = {}
|
| 325 |
+
for param_name, param in model.named_parameters():
|
| 326 |
+
# Detectar si este parámetro pertenece a un LoRALinear
|
| 327 |
+
is_lora_internal = False
|
| 328 |
+
for lora_path in lora_modules:
|
| 329 |
+
if param_name.startswith(lora_path + ".lora_"):
|
| 330 |
+
is_lora_internal = True # saltar lora_A y lora_B
|
| 331 |
+
break
|
| 332 |
+
if param_name == lora_path + ".linear.weight":
|
| 333 |
+
# Fusionar con LoRA
|
| 334 |
+
lora_mod = lora_modules[lora_path]
|
| 335 |
+
fused = param.data + (lora_mod.lora_B.data @ lora_mod.lora_A.data) * lora_mod.scale
|
| 336 |
+
# Guardar con nombre limpio (sin .linear)
|
| 337 |
+
clean = lora_path + ".weight"
|
| 338 |
+
merged_state[clean] = fused
|
| 339 |
+
is_lora_internal = True
|
| 340 |
+
break
|
| 341 |
+
if param_name == lora_path + ".linear.bias":
|
| 342 |
+
clean = lora_path + ".bias"
|
| 343 |
+
merged_state[clean] = param.data
|
| 344 |
+
is_lora_internal = True
|
| 345 |
+
break
|
| 346 |
+
if not is_lora_internal:
|
| 347 |
+
merged_state[param_name] = param.data
|
| 348 |
+
|
| 349 |
+
print(f"[merge] {len(merged_state)} keys en merged state_dict")
|
| 350 |
+
|
| 351 |
+
# Guardar solo LoRA ANTES de modificar el modelo
|
| 352 |
+
save_lora_checkpoint(out_dir / "final_lora_only.pt", model, optimizer,
|
| 353 |
+
step, {"done": True, "lora_rank": args.lora_rank,
|
| 354 |
+
"lora_alpha": args.lora_alpha})
|
| 355 |
+
|
| 356 |
+
# Guardar merged (full model) para benchmark — usar clave "model" que espera load_checkpoint
|
| 357 |
+
# strict=False en benchmark porque lm_head comparte pesos con tok_emb (tie_embeddings)
|
| 358 |
+
torch.save({"model": merged_state, "step": step,
|
| 359 |
+
"lora_rank": args.lora_rank, "lora_alpha": args.lora_alpha,
|
| 360 |
+
"merged": True, "tie_embeddings": True},
|
| 361 |
+
out_dir / "final.pt")
|
| 362 |
+
print(f"[done] final.pt (merged) → {out_dir}")
|
| 363 |
+
print(f"[done] final_lora_only.pt (adapter only) → {out_dir}")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
if __name__ == "__main__":
|
| 367 |
+
main()
|
training/finetune_sft.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SFT fine-tuning with assistant-only loss masking and an internal mini-curriculum.
|
| 2 |
+
|
| 3 |
+
Mini-curriculum (within SFT):
|
| 4 |
+
Epoch 1: 100% conversational (OASST1 ES + sft_conv)
Epoch 2: 70% conversational + 30% CVE Q&A
Epoch 3+: add tool-use (55% conv + 30% CVE + 15% tool_use)
|
| 6 |
+
|
| 7 |
+
This avoids drowning the chat behavior in JSON tool-call patterns the way SFT v3 did.
|
| 8 |
+
|
| 9 |
+
Run example:
|
| 10 |
+
python -m training_v2.train.finetune_sft \
|
| 11 |
+
--config training_v2/configs/nano.json \
|
| 12 |
+
--tokenizer training_v2/tokenizer/out/vectrayx_bpe.model \
|
| 13 |
+
--resume training_v2/checkpoints/phase3/last.pt \
|
| 14 |
+
--out training_v2/checkpoints/sft_v4 \
|
| 15 |
+
--batch-size 16 --grad-accum 4 --epochs 3 --lr 2e-5
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
import sys
|
| 21 |
+
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import sentencepiece as spm
|
| 26 |
+
import torch
|
| 27 |
+
from torch.utils.data import DataLoader, ConcatDataset
|
| 28 |
+
|
| 29 |
+
ROOT = Path(__file__).resolve().parents[2]
|
| 30 |
+
sys.path.insert(0, str(ROOT))
|
| 31 |
+
|
| 32 |
+
from training_v2.data.sft_dataset import SFTDataset
|
| 33 |
+
from training_v2.model.transformer import VectraYXNano, ModelConfig
|
| 34 |
+
from training_v2.train.utils import (
|
| 35 |
+
cosine_with_warmup, make_optimizer, save_checkpoint, load_checkpoint, log_jsonl,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Default SFT corpus files per category. Each entry is a path that discover()
# joins onto --corpus-root; load_sft_corpus_config() can replace any of the
# three lists from a JSON file.
SFT_FILES = {
    "conversational": [
        "corpus/sft_conversational.jsonl",
        "sft_v2_data/oasst1_es.jsonl",
    ],
    "cve_qa": [
        "corpus/sft_v2_dataset.jsonl",
    ],
    "tool_use": [
        "corpus/tooluse_dataset.jsonl",
    ],
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_sft_corpus_config(path):
    """Override the module-level ``SFT_FILES`` mapping from a JSON file.

    The JSON object may define ``sft_conversational``, ``sft_cve_qa`` and
    ``sft_tool_use``; any key that is absent keeps the current value of the
    corresponding ``SFT_FILES`` entry.

    Args:
        path: Path to the JSON corpus configuration file.
    """
    global SFT_FILES
    # JSON is UTF-8 by specification; reading with the platform default
    # encoding would corrupt non-ASCII file names on non-UTF-8 locales.
    cfg = json.loads(Path(path).read_text(encoding="utf-8"))
    key_map = (
        ("conversational", "sft_conversational"),
        ("cve_qa", "sft_cve_qa"),
        ("tool_use", "sft_tool_use"),
    )
    SFT_FILES = {
        group: cfg.get(cfg_key, SFT_FILES[group]) for group, cfg_key in key_map
    }
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def discover(paths, root):
    """Return the entries of *paths* that exist under *root*.

    Each relative path is joined onto ``root``; existing ones are returned as
    ``Path`` objects in input order, and a ``[skip missing]`` line is printed
    for every one that does not exist.
    """
    base = Path(root)
    existing = []
    for relative in paths:
        candidate = base / relative
        if not candidate.exists():
            print(f" [skip missing] {candidate}")
            continue
        existing.append(candidate)
    return existing
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def build_dataset(args, sp, include_tools):
    """Build one ``SFTDataset`` per corpus category that has files on disk.

    Args:
        args: Parsed CLI namespace; ``config``, ``corpus_root`` and ``seed``
            are read.
        sp: Loaded SentencePiece processor, passed to each dataset.
        include_tools: When falsy, the tool-use category is left out even if
            its files exist.

    Returns:
        ``(parts, pad_id)`` where ``parts`` is a list of ``(name, dataset)``
        pairs named ``"conv"``, ``"cve"`` and ``"tools"``, and ``pad_id`` is
        the tokenizer pad id (0 when the tokenizer reports a negative id).
    """
    seq_len = ModelConfig.from_json(args.config).max_seq_len
    pad_id = max(sp.pad_id(), 0)

    # Resolve every category first so the "[skip missing]" messages keep
    # their order (conversational, CVE, tools).
    located = {
        key: discover(SFT_FILES[key], args.corpus_root)
        for key in ("conversational", "cve_qa", "tool_use")
    }

    # (part name, SFT_FILES key, enabled); the position doubles as the seed
    # offset so each category gets its own seed.
    plan = (
        ("conv", "conversational", True),
        ("cve", "cve_qa", True),
        ("tools", "tool_use", include_tools),
    )
    parts = []
    for seed_offset, (label, key, enabled) in enumerate(plan):
        files = located[key]
        if enabled and files:
            dataset = SFTDataset(files, sp, seq_len, pad_id=pad_id,
                                 seed=args.seed + seed_offset)
            parts.append((label, dataset))
    return parts, pad_id
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def make_loader(parts, weights, batch_size, num_workers):
    """Build a DataLoader that samples the named parts with the given mix.

    Every part receives a total sampling mass of ``weights[name]`` (1.0 when
    the name is absent), spread evenly over its examples. One pass over the
    loader draws ``sum(len(part))`` indices with replacement.

    Args:
        parts: List of ``(name, dataset)`` pairs; each dataset item is
            indexable as ``(x, y, mask)`` tensors.
        weights: Mapping from part name to its relative sampling weight.
        batch_size: Number of examples per batch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` yielding ``(xs, ys, ms)`` stacked batches.

    Raises:
        ValueError: If the mix gives zero total weight to the available
            parts (for example, the only weighted corpus is missing).
    """
    names = [name for name, _ in parts]
    datasets = [dataset for _, dataset in parts]
    sizes = [len(dataset) for dataset in datasets]
    big = ConcatDataset(datasets)

    # offsets[i]:offsets[i + 1] is the index range of part i inside `big`.
    offsets = np.cumsum([0] + sizes)
    weight_per_idx = np.zeros(offsets[-1], dtype=np.float64)
    for i, name in enumerate(names):
        per_example = weights.get(name, 1.0) / max(1, sizes[i])
        weight_per_idx[offsets[i]:offsets[i + 1]] = per_example

    # Without this check the sampler is created fine and only fails on the
    # first iteration with an opaque torch.multinomial error.
    if not weight_per_idx.sum() > 0:
        raise ValueError(
            f"sampling weights {weights} give zero total weight to the "
            f"available parts {names}"
        )

    sampler = torch.utils.data.WeightedRandomSampler(
        weights=weight_per_idx,
        num_samples=int(sum(sizes)),
        replacement=True,
    )

    def collate(batch):
        xs = torch.stack([b[0] for b in batch], 0)
        ys = torch.stack([b[1] for b in batch], 0)
        ms = torch.stack([b[2] for b in batch], 0)
        return xs, ys, ms

    return DataLoader(
        big, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, collate_fn=collate, pin_memory=True,
        persistent_workers=num_workers > 0,
    )
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def main():
    """Fine-tune a pre-trained checkpoint with the SFT mini-curriculum.

    Parses the command line, loads the model, checkpoint and SentencePiece
    tokenizer, builds the SFT datasets and trains for ``--epochs`` epochs.
    Each epoch re-creates the weighted loader with the mix taken from
    ``epoch_plans``; epochs beyond the last plan reuse the last plan. The
    learning rate is set before every optimizer step from
    ``cosine_with_warmup``, gradients are accumulated over ``--grad-accum``
    micro-batches and clipped to ``--grad-clip``.

    Under ``--out`` this writes ``train_log.jsonl``, a rolling ``last.pt``,
    one ``epoch{N}.pt`` per epoch and ``final.pt``.

    Raises:
        RuntimeError: If none of the configured SFT files exist.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--tokenizer", required=True)
    p.add_argument("--resume", required=True, help="pre-training checkpoint to fine-tune")
    p.add_argument("--out", required=True)
    p.add_argument("--corpus-root", default=".")
    p.add_argument("--corpus-config", default=None)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--grad-accum", type=int, default=4)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--weight-decay", type=float, default=0.0)
    p.add_argument("--grad-clip", type=float, default=1.0)
    p.add_argument("--warmup-frac", type=float, default=0.03)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--log-every", type=int, default=20)
    p.add_argument("--save-every", type=int, default=500)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
    args = p.parse_args()

    if args.corpus_config:
        load_sft_corpus_config(args.corpus_config)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Model and starting weights.
    cfg = ModelConfig.from_json(args.config)
    model = VectraYXNano(cfg).to(args.device)
    print(f"[model] {model.num_params()/1e6:.2f}M params")
    load_checkpoint(args.resume, model, optimizer=None, map_location=args.device)
    print(f"[resume] {args.resume}")

    # Tokenizer and datasets. Tool-use data is always loaded; whether it is
    # sampled in a given epoch is decided by that epoch's weights.
    sp = spm.SentencePieceProcessor()
    sp.load(args.tokenizer)
    parts, _ = build_dataset(args, sp, include_tools=True)
    if not parts:
        raise RuntimeError("no SFT files found")

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "train_log.jsonl"

    optimizer = make_optimizer(model, lr=args.lr, weight_decay=args.weight_decay)

    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    # Compare the device *type* so that "cuda:0", "cuda:1", ... also enable
    # autocast; a plain string comparison only matched the bare "cuda".
    use_amp = torch.device(args.device).type == "cuda" and dtype != torch.float32
    # NOTE(review): float16 autocast runs without a GradScaler here; consider
    # adding loss scaling if --dtype float16 is actually used.

    epoch_plans = [
        {"conv": 1.00, "cve": 0.00, "tools": 0.0},   # epoch 1: conversational ONLY
        {"conv": 0.70, "cve": 0.30, "tools": 0.00},  # epoch 2: + CVE Q&A
        {"conv": 0.55, "cve": 0.30, "tools": 0.15},  # epoch 3: + tool use
    ]

    # Dry pass over the epochs to size the LR schedule before training starts.
    total_steps = 0
    for ep in range(args.epochs):
        weights = epoch_plans[min(ep, len(epoch_plans) - 1)]
        print(f"\n=== epoch {ep+1}/{args.epochs} | mix={weights} ===")
        loader = make_loader(parts, weights, args.batch_size, args.num_workers)
        steps_per_epoch = max(1, len(loader) // args.grad_accum)
        total_steps += steps_per_epoch
    warmup = max(50, int(args.warmup_frac * total_steps))
    print(f"[sft] total_steps≈{total_steps} warmup={warmup}")

    model.train()
    t_start = time.time()
    step = 0
    running_loss = 0.0
    running_n = 0

    for ep in range(args.epochs):
        weights = epoch_plans[min(ep, len(epoch_plans) - 1)]
        loader = make_loader(parts, weights, args.batch_size, args.num_workers)
        data_iter = iter(loader)
        steps_per_epoch = max(1, len(loader) // args.grad_accum)

        for _ in range(steps_per_epoch):
            cur_lr = cosine_with_warmup(step, warmup, total_steps, args.lr)
            for g in optimizer.param_groups:
                g["lr"] = cur_lr

            optimizer.zero_grad(set_to_none=True)
            loss_accum = 0.0
            for _micro in range(args.grad_accum):
                try:
                    xs, ys, ms = next(data_iter)
                except StopIteration:
                    # Loader exhausted mid-step: restart it so the step
                    # still gets its full set of micro-batches.
                    data_iter = iter(loader)
                    xs, ys, ms = next(data_iter)
                xs = xs.to(args.device, non_blocking=True)
                ys = ys.to(args.device, non_blocking=True)
                ms = ms.to(args.device, non_blocking=True)
                with torch.amp.autocast("cuda", dtype=dtype, enabled=use_amp):
                    _, loss = model(xs, targets=ys, loss_mask=ms)
                    # Scale so the accumulated gradient is the micro-batch mean.
                    loss = loss / args.grad_accum
                loss.backward()
                # Undo the scaling for reporting purposes only.
                loss_accum += loss.item() * args.grad_accum

            gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            step += 1
            running_loss += loss_accum / args.grad_accum
            running_n += 1

            if step % args.log_every == 0:
                elapsed = time.time() - t_start
                avg = running_loss / running_n
                print(f"[sft ep{ep+1} step {step:>5}/{total_steps}] loss={avg:.4f} "
                      f"lr={cur_lr:.2e} gnorm={gnorm:.2f} elapsed={elapsed/60:.1f}min")
                log_jsonl(log_path, {"epoch": ep + 1, "step": step, "loss": avg,
                                     "lr": cur_lr, "gnorm": float(gnorm)})
                running_loss = 0.0
                running_n = 0

            if step % args.save_every == 0:
                save_checkpoint(out_dir / "last.pt", model, optimizer,
                                {"step": step}, step,
                                extra={"epoch": ep + 1, "weights": weights})

        save_checkpoint(out_dir / f"epoch{ep+1}.pt", model, optimizer,
                        {"step": step}, step,
                        extra={"epoch": ep + 1, "weights": weights})
        print(f"[save] {out_dir}/epoch{ep+1}.pt")

    save_checkpoint(out_dir / "final.pt", model, optimizer, {"step": step}, step,
                    extra={"done": True})
    print(f"[done] SFT → {out_dir}/final.pt")
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
if __name__ == "__main__":
|
| 255 |
+
main()
|
training/finetune_tools.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool-use focused SFT for VectraYX Nano/Base.
|
| 2 |
+
|
| 3 |
+
This is a simplified version of finetune_sft.py that trains ONLY on tool-use examples.
|
| 4 |
+
The goal is to test the hypothesis that B4=0.000 is due to diluted tool-call gradients
|
| 5 |
+
in the mixed SFT corpus, not a capacity gate.
|
| 6 |
+
|
| 7 |
+
Run example:
|
| 8 |
+
python -m training_v2.train.finetune_tools \
|
| 9 |
+
--config training_v2/configs/nano.json \
|
| 10 |
+
--tokenizer models/vectrayx_bpe.model \
|
| 11 |
+
--resume checkpoints/nano_final.pt \
|
| 12 |
+
--tool-corpus /tmp/tool_sft_v1.jsonl \
|
| 13 |
+
--out checkpoints/tool_sft_nano \
|
| 14 |
+
--batch-size 16 --grad-accum 4 --epochs 2 --lr 1e-5
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import sys
|
| 20 |
+
import time
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import sentencepiece as spm
|
| 25 |
+
import torch
|
| 26 |
+
from torch.utils.data import DataLoader
|
| 27 |
+
|
| 28 |
+
ROOT = Path(__file__).resolve().parents[2]
|
| 29 |
+
sys.path.insert(0, str(ROOT))
|
| 30 |
+
|
| 31 |
+
from training_v2.data.sft_dataset import SFTDataset
|
| 32 |
+
from training_v2.model.transformer import VectraYXNano, ModelConfig
|
| 33 |
+
from training_v2.train.utils import (
|
| 34 |
+
cosine_with_warmup, make_optimizer, save_checkpoint, load_checkpoint, log_jsonl,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main():
    """Fine-tune a checkpoint on a tool-use-only SFT corpus.

    Parses the command line, loads the model, checkpoint and SentencePiece
    tokenizer, builds a single ``SFTDataset`` from ``--tool-corpus`` and
    trains for ``--epochs`` epochs (or until ``--max-steps`` optimizer steps).
    The learning rate is set before every optimizer step from
    ``cosine_with_warmup``, gradients are accumulated over ``--grad-accum``
    micro-batches and clipped to ``--grad-clip``.

    Under ``--out`` this writes ``train_log.jsonl``, a rolling ``last.pt``,
    one ``epoch{N}.pt`` per epoch and ``final.pt``.

    Raises:
        FileNotFoundError: If the ``--tool-corpus`` file does not exist.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--tokenizer", required=True)
    p.add_argument("--resume", required=True, help="checkpoint to fine-tune from")
    p.add_argument("--tool-corpus", required=True, help="tool-use JSONL corpus")
    p.add_argument("--out", required=True)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--grad-accum", type=int, default=4)
    p.add_argument("--epochs", type=int, default=2)
    p.add_argument("--lr", type=float, default=1e-5)
    p.add_argument("--weight-decay", type=float, default=0.0)
    p.add_argument("--grad-clip", type=float, default=1.0)
    p.add_argument("--warmup-frac", type=float, default=0.03)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--log-every", type=int, default=20)
    p.add_argument("--save-every", type=int, default=500)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
    p.add_argument("--max-steps", type=int, default=None, help="for testing")
    args = p.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Load model
    cfg = ModelConfig.from_json(args.config)
    model = VectraYXNano(cfg).to(args.device)
    print(f"[model] {model.num_params()/1e6:.2f}M params")
    load_checkpoint(args.resume, model, optimizer=None, map_location=args.device)
    print(f"[resume] {args.resume}")

    # Load tokenizer; fall back to id 0 when it reports no pad token.
    sp = spm.SentencePieceProcessor()
    sp.load(args.tokenizer)
    pad_id = sp.pad_id() if sp.pad_id() >= 0 else 0

    # Build tool-only dataset
    block_size = cfg.max_seq_len
    tool_corpus = Path(args.tool_corpus)
    if not tool_corpus.exists():
        raise FileNotFoundError(f"Tool corpus not found: {tool_corpus}")

    dataset = SFTDataset([tool_corpus], sp, block_size, pad_id=pad_id, seed=args.seed)
    print(f"[dataset] {len(dataset)} tool-use examples from {tool_corpus}")

    # Setup output
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "train_log.jsonl"

    # Optimizer
    optimizer = make_optimizer(model, lr=args.lr, weight_decay=args.weight_decay)

    # AMP setup
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    # Compare the device *type* so that "cuda:0", "cuda:1", ... also enable
    # autocast; a plain string comparison only matched the bare "cuda".
    use_amp = torch.device(args.device).type == "cuda" and dtype != torch.float32
    # NOTE(review): float16 autocast runs without a GradScaler here; consider
    # adding loss scaling if --dtype float16 is actually used.

    # Training loop
    def collate(batch):
        xs = torch.stack([b[0] for b in batch], 0)
        ys = torch.stack([b[1] for b in batch], 0)
        ms = torch.stack([b[2] for b in batch], 0)
        return xs, ys, ms

    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, collate_fn=collate, pin_memory=True,
        persistent_workers=args.num_workers > 0,
    )

    steps_per_epoch = max(1, len(loader) // args.grad_accum)
    total_steps = steps_per_epoch * args.epochs
    if args.max_steps:
        total_steps = min(total_steps, args.max_steps)
    warmup = max(50, int(args.warmup_frac * total_steps))
    print(f"[train] epochs={args.epochs} steps_per_epoch≈{steps_per_epoch} total_steps={total_steps} warmup={warmup}")

    model.train()
    t_start = time.time()
    step = 0
    running_loss = 0.0
    running_n = 0

    for ep in range(args.epochs):
        print(f"\n=== epoch {ep+1}/{args.epochs} (tool-only) ===")
        data_iter = iter(loader)

        for _ in range(steps_per_epoch):
            if args.max_steps and step >= args.max_steps:
                break

            cur_lr = cosine_with_warmup(step, warmup, total_steps, args.lr)
            for g in optimizer.param_groups:
                g["lr"] = cur_lr

            optimizer.zero_grad(set_to_none=True)
            loss_accum = 0.0
            for _micro in range(args.grad_accum):
                try:
                    xs, ys, ms = next(data_iter)
                except StopIteration:
                    # Loader exhausted mid-step: restart it so the step
                    # still gets its full set of micro-batches.
                    data_iter = iter(loader)
                    xs, ys, ms = next(data_iter)
                xs = xs.to(args.device, non_blocking=True)
                ys = ys.to(args.device, non_blocking=True)
                ms = ms.to(args.device, non_blocking=True)

                with torch.amp.autocast("cuda", dtype=dtype, enabled=use_amp):
                    _, loss = model(xs, targets=ys, loss_mask=ms)
                    # Scale so the accumulated gradient is the micro-batch mean.
                    loss = loss / args.grad_accum
                loss.backward()
                # Undo the scaling for reporting purposes only.
                loss_accum += loss.item() * args.grad_accum

            gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            step += 1
            running_loss += loss_accum / args.grad_accum
            running_n += 1

            if step % args.log_every == 0:
                elapsed = time.time() - t_start
                avg = running_loss / running_n
                print(f"[tool-sft ep{ep+1} step {step:>5}/{total_steps}] loss={avg:.4f} "
                      f"lr={cur_lr:.2e} gnorm={gnorm:.2f} elapsed={elapsed/60:.1f}min")
                log_jsonl(log_path, {"epoch": ep + 1, "step": step, "loss": avg,
                                     "lr": cur_lr, "gnorm": float(gnorm)})
                running_loss = 0.0
                running_n = 0

            if step % args.save_every == 0:
                save_checkpoint(out_dir / "last.pt", model, optimizer,
                                {"step": step}, step,
                                extra={"epoch": ep + 1, "tool_only": True})

        # The inner break only leaves the step loop; stop the epochs too.
        if args.max_steps and step >= args.max_steps:
            break

        save_checkpoint(out_dir / f"epoch{ep+1}.pt", model, optimizer,
                        {"step": step}, step,
                        extra={"epoch": ep + 1, "tool_only": True})
        print(f"[save] {out_dir}/epoch{ep+1}.pt")

    save_checkpoint(out_dir / "final.pt", model, optimizer, {"step": step}, step,
                    extra={"done": True, "tool_only": True})
    print(f"[done] {out_dir}/final.pt")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
if __name__ == "__main__":
|
| 188 |
+
main()
|
training/pretrain.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Curriculum pre-training driver for VectraYX-Nano v2.
|
| 2 |
+
|
| 3 |
+
Phase 1: 100% conversational (LR 3e-4 from scratch)
|
| 4 |
+
Phase 2: 75% tech + 25% conv (LR 1.5e-4, resumed from phase 1)
|
| 5 |
+
Phase 3: 70% tools + 20% tech + 10% conv (LR 8e-5, resumed from phase 2)
|
| 6 |
+
|
| 7 |
+
Run example:
|
| 8 |
+
python -m training_v2.train.pretrain \
|
| 9 |
+
--config training_v2/configs/nano.json \
|
| 10 |
+
--bins training_v2/data/bins \
|
| 11 |
+
--out training_v2/checkpoints \
|
| 12 |
+
--phase 1 --max-steps 8000 --batch-size 16 --grad-accum 8
|
| 13 |
+
|
| 14 |
+
Then:
|
| 15 |
+
--phase 2 --resume training_v2/checkpoints/phase1/last.pt
|
| 16 |
+
--phase 3 --resume training_v2/checkpoints/phase2/last.pt
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import sys
|
| 22 |
+
import time
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from torch.utils.data import DataLoader
|
| 29 |
+
|
| 30 |
+
ROOT = Path(__file__).resolve().parents[2]
|
| 31 |
+
sys.path.insert(0, str(ROOT))
|
| 32 |
+
|
| 33 |
+
from training_v2.data.curriculum_dataset import (
|
| 34 |
+
MixedCurriculumDataset, make_phase_mix, load_phase_summary,
|
| 35 |
+
)
|
| 36 |
+
from training_v2.model.transformer import VectraYXNano, ModelConfig
|
| 37 |
+
from training_v2.train.utils import (
|
| 38 |
+
cosine_with_warmup, make_optimizer, save_checkpoint, load_checkpoint, log_jsonl,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
PHASE_LR = {1: 3.0e-4, 2: 1.5e-4, 3: 8.0e-5}
|
| 43 |
+
PHASE_WARMUP_FRAC = {1: 0.05, 2: 0.02, 3: 0.02}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def build_dataloader(args, mix, block_size):
    """Create the training ``DataLoader`` for one curriculum phase.

    Only shard directories whose weight in ``mix`` is strictly positive are
    handed to ``MixedCurriculumDataset``; the complete ``mix`` mapping is still
    forwarded as the sampling weights.

    Args:
        args: parsed CLI namespace; reads ``bins``, ``seed``, ``batch_size``
            and ``num_workers``.
        mix: mapping from shard-directory name to its sampling weight.
        block_size: sequence length forwarded to the dataset.

    Returns:
        A ``DataLoader`` whose batches are ``(xs, ys)`` tensors stacked on dim 0.
    """
    bins_root = Path(args.bins)
    active_dirs = {}
    for name in ("phase1_conv", "phase2_tech", "phase3_tools"):
        if mix.get(name, 0) > 0:
            active_dirs[name] = bins_root / name

    dataset = MixedCurriculumDataset(
        phase_dirs=active_dirs,
        weights=mix,
        block_size=block_size,
        dtype=np.uint16,
        seed=args.seed,
    )

    def stack_pairs(samples):
        # Each sample is indexable; element 0 is the input, element 1 the target.
        inputs = torch.stack([sample[0] for sample in samples], 0)
        targets = torch.stack([sample[1] for sample in samples], 0)
        return inputs, targets

    worker_count = args.num_workers
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=worker_count,
        collate_fn=stack_pairs,
        pin_memory=True,
        # persistent workers are only valid when worker processes exist
        persistent_workers=worker_count > 0,
    )
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def estimate_phase_tokens(phase_idx, mix, summary):
    """Sum the token counts of every dataset that is active in ``mix``.

    A dataset contributes its full ``n_tokens`` (taken from ``summary``) when
    both its mix weight and its token count are positive; the weight itself is
    not used to scale the contribution.

    Args:
        phase_idx: curriculum phase number (accepted for interface symmetry;
            not used in the computation).
        mix: mapping dataset name -> sampling weight.
        summary: mapping dataset name -> dict that may contain ``"n_tokens"``.

    Returns:
        The total as an ``int``.
    """
    running = 0.0
    for name, weight in mix.items():
        n_tokens = summary.get(name, {}).get("n_tokens", 0)
        if not (weight > 0 and n_tokens > 0):
            continue
        running += n_tokens
    return int(running)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def main():
    """Run one curriculum pre-training phase from the command line.

    Steps: parse CLI flags, seed torch/numpy, build the model from the JSON
    config, derive the data mix and step budget for the phase, optionally load
    weights from ``--resume`` (the optimizer always starts fresh), then run the
    gradient-accumulation training loop with a warmup + cosine LR schedule,
    periodic JSONL logging and checkpointing to ``<out>/phase<N>/last.pt``.

    Checkpoints are always written from the un-compiled module: a module
    wrapped by ``torch.compile`` prefixes its state-dict keys with
    ``_orig_mod.``, which would make the file unloadable by the next phase's
    ``--resume`` into a plain model.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--bins", required=True, help="root of binary shard dirs")
    p.add_argument("--out", required=True, help="checkpoint output root")
    p.add_argument("--phase", type=int, choices=[1, 2, 3], required=True)
    p.add_argument("--resume", type=str, default=None)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--grad-accum", type=int, default=8)
    p.add_argument("--max-steps", type=int, default=None)
    p.add_argument("--epochs", type=float, default=2.0,
                   help="estimate steps as epochs*phase_tokens/(batch*ga*block)")
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--weight-decay", type=float, default=0.1)
    p.add_argument("--grad-clip", type=float, default=1.0)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--log-every", type=int, default=20)
    p.add_argument("--save-every", type=int, default=1000)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--replay-conv", type=float, default=None,
                   help="override replay ratio of conversational data in phase 2/3")
    p.add_argument("--replay-tech", type=float, default=None,
                   help="override replay ratio of technical data in phase 3")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
    p.add_argument("--compile", action="store_true")
    args = p.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    cfg = ModelConfig.from_json(args.config)
    model = VectraYXNano(cfg).to(args.device)
    n_params = model.num_params()
    print(f"[model] {n_params/1e6:.2f}M params · cfg={cfg}")

    mix = make_phase_mix(args.phase, replay_conv=args.replay_conv, replay_tech=args.replay_tech)
    summary = load_phase_summary(args.bins)
    phase_tokens = estimate_phase_tokens(args.phase, mix, summary)
    tokens_per_step = args.batch_size * args.grad_accum * cfg.max_seq_len
    if args.max_steps is None:
        # Derive the step budget from the requested epochs, with a floor of 1000.
        args.max_steps = max(1000, int(args.epochs * phase_tokens / tokens_per_step))
    print(f"[phase {args.phase}] mix={mix}")
    print(f"[phase {args.phase}] phase_tokens={phase_tokens:,} tokens/step={tokens_per_step:,} steps={args.max_steps}")

    lr = args.lr if args.lr is not None else PHASE_LR[args.phase]
    warmup = max(50, int(PHASE_WARMUP_FRAC[args.phase] * args.max_steps))
    optimizer = make_optimizer(model, lr=lr, weight_decay=args.weight_decay)

    if args.resume:
        # Only the weights are restored; optimizer state is deliberately not loaded.
        loaded_step, _ = load_checkpoint(args.resume, model, optimizer=None, map_location=args.device)
        print(f"[resume] loaded weights from {args.resume} (step={loaded_step})")
    start_step = 0  # fresh optimizer and schedule for the new phase

    loader = build_dataloader(args, mix, cfg.max_seq_len)
    data_iter = iter(loader)

    out_dir = Path(args.out) / f"phase{args.phase}"
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "train_log.jsonl"

    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    use_amp = args.device == "cuda" and dtype != torch.float32
    # Loss scaling is only enabled for float16.
    scaler = torch.amp.GradScaler("cuda", enabled=(dtype == torch.float16))

    # Keep a handle on the un-compiled module: it shares parameters with the
    # compiled wrapper but produces state-dict keys without `_orig_mod.`.
    raw_model = model
    if args.compile:
        try:
            model = torch.compile(model)
        except Exception as e:  # best effort: fall back to eager execution
            print(f"[compile] skipped: {e}")

    model.train()
    t_start = time.time()
    tokens_seen = 0
    running_loss = 0.0
    running_n = 0

    for step in range(start_step, args.max_steps):
        cur_lr = cosine_with_warmup(step, warmup, args.max_steps, lr)
        for g in optimizer.param_groups:
            g["lr"] = cur_lr

        optimizer.zero_grad(set_to_none=True)
        loss_accum = 0.0
        for micro in range(args.grad_accum):
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: start another pass over the data.
                data_iter = iter(loader)
                batch = next(data_iter)
            xs, ys = batch[0], batch[1]
            xs = xs.to(args.device, non_blocking=True)
            ys = ys.to(args.device, non_blocking=True)
            with torch.amp.autocast("cuda", dtype=dtype, enabled=use_amp):
                _, loss = model(xs, targets=ys)
                # Scale so the accumulated gradient is the mean over micro-batches.
                loss = loss / args.grad_accum
            if scaler.is_enabled():
                scaler.scale(loss).backward()
            else:
                loss.backward()
            # Undo the scaling so loss_accum sums the un-scaled micro-batch losses.
            loss_accum += loss.item() * args.grad_accum

        if scaler.is_enabled():
            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
        gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        if scaler.is_enabled():
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        tokens_seen += tokens_per_step
        running_loss += loss_accum / args.grad_accum
        running_n += 1

        if (step + 1) % args.log_every == 0:
            elapsed = time.time() - t_start
            tps = tokens_seen / max(1.0, elapsed)
            avg_loss = running_loss / running_n
            print(f"[p{args.phase} step {step+1:>6}/{args.max_steps}] "
                  f"loss={avg_loss:.4f} lr={cur_lr:.2e} gnorm={gnorm:.2f} "
                  f"tok/s={tps:>7,.0f} elapsed={elapsed/60:.1f}min")
            log_jsonl(log_path, {
                "phase": args.phase, "step": step + 1, "loss": avg_loss,
                "lr": cur_lr, "gnorm": float(gnorm), "tok_per_s": tps,
                "tokens_seen": tokens_seen,
            })
            running_loss = 0.0
            running_n = 0

        if (step + 1) % args.save_every == 0 or (step + 1) == args.max_steps:
            ckpt_path = out_dir / "last.pt"
            save_checkpoint(ckpt_path, raw_model, optimizer, {"step": step + 1}, step + 1,
                            extra={"phase": args.phase, "mix": mix, "lr": lr})
            print(f"[save] {ckpt_path}")

    # Final write marks the phase as finished via extra["done"].
    final = out_dir / "last.pt"
    save_checkpoint(final, raw_model, optimizer, {"step": args.max_steps}, args.max_steps,
                    extra={"phase": args.phase, "mix": mix, "lr": lr, "done": True})
    print(f"[done] phase {args.phase} → {final}")
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__":
|
| 228 |
+
main()
|
training/sft_dataset.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SFT dataset with proper assistant-only loss masking and safe packing.
|
| 2 |
+
|
| 3 |
+
Each example is a chat-formatted string with `<|system|> <|user|> <|assistant|> <|end|>`
|
| 4 |
+
turn delimiters. We tokenize on the fly (corpus is small, ~25M tokens) and build a
|
| 5 |
+
mask=1 only on tokens that are part of an assistant response (everything between
|
| 6 |
+
`<|assistant|>` and the next `<|end|>`).
|
| 7 |
+
|
| 8 |
+
For pre-training-style packing without cross-example contamination we group multiple
|
| 9 |
+
short examples into a fixed-length window using `cu_seqlens`-style document boundaries
|
| 10 |
+
implemented via per-document attention reset. Here we keep it simple: pad/truncate
|
| 11 |
+
each example to `block_size`. Throughput is still high (>40k tok/s on L4) for this
|
| 12 |
+
volume.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import random
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
from torch.utils.data import Dataset
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _read_jsonl(path):
|
| 25 |
+
out = []
|
| 26 |
+
with open(path, "r", encoding="utf-8", errors="replace") as f:
|
| 27 |
+
for line in f:
|
| 28 |
+
line = line.strip()
|
| 29 |
+
if not line:
|
| 30 |
+
continue
|
| 31 |
+
try:
|
| 32 |
+
obj = json.loads(line)
|
| 33 |
+
except json.JSONDecodeError:
|
| 34 |
+
continue
|
| 35 |
+
t = obj.get("text") or ""
|
| 36 |
+
if t:
|
| 37 |
+
out.append({"text": t, "source": obj.get("source", Path(path).stem)})
|
| 38 |
+
return out
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def build_assistant_mask(token_ids, assistant_id, end_id):
    """Return an int64 array flagging the tokens of assistant responses.

    ``mask[i]`` is 1 for every token that follows an ``assistant_id`` token up
    to and including the next ``end_id`` token; the opening ``assistant_id``
    itself stays 0. Including the closing delimiter lets the model learn to
    emit it.
    """
    mask = np.zeros(len(token_ids), dtype=np.int64)
    in_reply = False
    for pos, tok in enumerate(token_ids):
        if not in_reply:
            # Outside a response: only an assistant tag changes state,
            # and the tag itself is never marked.
            if tok == assistant_id:
                in_reply = True
            continue
        mask[pos] = 1
        if tok == end_id:
            in_reply = False
    return mask
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class SFTDataset(Dataset):
    """Chat SFT dataset that tokenizes on the fly and masks non-assistant tokens.

    Every item is truncated or padded to ``block_size + 1`` tokens and returned
    as ``(x, y, m)``: inputs, next-token targets and the assistant mask aligned
    with the targets. Targets whose mask is 0 are replaced by ``-100``.
    """

    def __init__(self, jsonl_paths, sp, block_size, assistant_token="<|assistant|>",
                 end_token="<|end|>", pad_id=0, seed=42, mix_weights=None):
        """Load and shuffle the examples of all ``jsonl_paths``.

        Args:
            jsonl_paths: iterable of JSONL files read with ``_read_jsonl``.
            sp: tokenizer exposing ``piece_to_id`` and ``encode``.
            block_size: number of input tokens per item.
            assistant_token: piece that opens an assistant response.
            end_token: piece that closes a turn.
            pad_id: token id used for right padding.
            seed: seed of the private ``random.Random`` used for sampling
                and shuffling.
            mix_weights: optional mapping file name -> fraction of that file
                to keep; a sample never exceeds the file's own size.

        Raises:
            ValueError: if either special token resolves to a negative id.
        """
        self.sp = sp
        self.block_size = block_size
        self.pad_id = pad_id
        self.assistant_id = sp.piece_to_id(assistant_token)
        self.end_id = sp.piece_to_id(end_token)
        if self.assistant_id < 0 or self.end_id < 0:
            raise ValueError(f"missing special tokens in tokenizer: "
                             f"{assistant_token}={self.assistant_id} {end_token}={self.end_id}")

        weights = mix_weights or {}
        self.examples = []
        rng = random.Random(seed)
        for p in jsonl_paths:
            recs = _read_jsonl(p)
            w = weights.get(Path(p).name, 1.0)
            if w != 1.0:
                # Subsample without replacement, capped at the file size.
                keep = min(int(len(recs) * w), len(recs))
                recs = rng.sample(recs, keep)
            self.examples.extend(recs)
            print(f" [sft] {p}: {len(recs):,} ex (w={w})")
        rng.shuffle(self.examples)
        print(f"[sft] total: {len(self.examples):,} examples")

    def __len__(self):
        """Number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return ``(x, y, m)`` tensors."""
        window = self.block_size + 1  # one extra token for the shifted targets
        ids = self.sp.encode(self.examples[idx]["text"], out_type=int)[:window]
        mask = build_assistant_mask(ids, self.assistant_id, self.end_id)

        shortfall = window - len(ids)
        if shortfall > 0:
            ids = ids + [self.pad_id] * shortfall
            mask = np.concatenate([mask, np.zeros(shortfall, dtype=np.int64)])

        ids = np.asarray(ids, dtype=np.int64)
        x = torch.from_numpy(ids[:-1])
        y = torch.from_numpy(ids[1:].copy())
        m = torch.from_numpy(mask[1:].copy())  # mask aligned with targets
        # Ignore every target outside an assistant span (padding included).
        y[m == 0] = -100
        return x, y, m
|
training/transformer.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""VectraYX-Nano transformer (decoder-only, ~42M params).
|
| 2 |
+
|
| 3 |
+
Modern small-LLM stack:
|
| 4 |
+
RMSNorm (pre-norm) · SwiGLU FFN · RoPE · GQA (8q/2kv)
|
| 5 |
+
QK-Norm · no biases · tied embeddings · z-loss
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import math
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class ModelConfig:
    """Hyper-parameters of the VectraYX decoder-only transformer."""

    vocab_size: int = 16384       # rows of the token embedding / width of the LM head
    n_layers: int = 8             # number of transformer blocks
    n_heads: int = 8              # query heads
    n_kv_heads: int = 2           # key/value heads (must divide n_heads)
    d_model: int = 512            # residual stream width
    d_ffn: int = 2048             # hidden width of the SwiGLU feed-forward
    max_seq_len: int = 1024       # longest sequence accepted; size of the RoPE tables
    rope_theta: float = 10000.0   # RoPE frequency base
    rms_eps: float = 1e-6         # epsilon used by every RMSNorm
    init_std: float = 0.02        # std of the normal weight initialisation
    dropout: float = 0.0          # attention dropout probability while training
    tie_embeddings: bool = True   # share the LM head weight with the token embedding
    qk_norm: bool = True          # RMS-normalise queries and keys per head
    z_loss_coef: float = 1e-4     # weight of the z-loss term (0 disables it)

    @classmethod
    def from_json(cls, path):
        """Build a config from the ``"model"`` section of a JSON file.

        Keys that are not dataclass fields are ignored; fields absent from the
        file keep their defaults.

        Raises:
            KeyError: if the file has no top-level ``"model"`` entry.
        """
        # Context manager so the handle is closed deterministically.
        with open(path, "r", encoding="utf-8") as fh:
            cfg = json.load(fh)["model"]
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in cfg.items() if k in known})
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Scale ``x`` by ``rsqrt(mean(x^2) + eps)`` and multiply by the gain."""
        mean_square = torch.pow(x, 2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = x * inv_rms
        # Cast to the gain's dtype before applying it.
        return normed.to(self.weight.dtype) * self.weight
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def precompute_rope(head_dim, max_seq_len, theta=10000.0, device=None):
    """Build the rotary-embedding cosine and sine tables.

    Args:
        head_dim: per-head dimension; one frequency is generated per even index.
        max_seq_len: number of positions to tabulate.
        theta: frequency base.
        device: optional device the tables are moved to.

    Returns:
        ``(cos, sin)`` float32 tensors of shape ``(max_seq_len, ceil(head_dim / 2))``.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    cos, sin = angles.cos(), angles.sin()
    if device is None:
        return cos, sin
    return cos.to(device), sin.to(device)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def apply_rope(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by position.

    ``x`` is ``(B, H, T, D)`` with even ``D``; ``cos``/``sin`` are tables with
    at least ``T`` rows of ``D / 2`` entries. The first half of the feature
    dimension is paired with the second half.
    """
    seq_len, dim = x.shape[-2], x.shape[-1]
    half = dim // 2
    cos_t = cos[:seq_len].view(1, 1, seq_len, half)
    sin_t = sin[:seq_len].view(1, 1, seq_len, half)
    lower, upper = x[..., :half], x[..., half:]
    rotated_lower = lower * cos_t - upper * sin_t
    rotated_upper = lower * sin_t + upper * cos_t
    return torch.cat([rotated_lower, rotated_upper], dim=-1)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class GQAttention(nn.Module):
    """Causal grouped-query self-attention with RoPE and optional QK-norm.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads; the K/V heads
    are repeated to match the query heads before attention.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        assert cfg.n_heads % cfg.n_kv_heads == 0
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        # how many query heads read from each key/value head
        self.repeat = self.n_heads // self.n_kv_heads

        q_width = cfg.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.wq = nn.Linear(cfg.d_model, q_width, bias=False)
        self.wk = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.wv = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.wo = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

        self.qk_norm = cfg.qk_norm
        if self.qk_norm:
            self.q_norm = RMSNorm(self.head_dim, eps=cfg.rms_eps)
            self.k_norm = RMSNorm(self.head_dim, eps=cfg.rms_eps)

        self.dropout = cfg.dropout

    def forward(self, x, cos, sin):
        """Attend over ``x`` of shape ``(B, T, d_model)`` and return the same shape."""
        bsz, seqlen, _ = x.shape

        def split_heads(projected, n_heads):
            # (B, T, n*hd) -> (B, n, T, hd)
            return projected.view(bsz, seqlen, n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.wq(x), self.n_heads)
        k = split_heads(self.wk(x), self.n_kv_heads)
        v = split_heads(self.wv(x), self.n_kv_heads)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

        if self.repeat > 1:
            # expand K/V so every query head has a matching key/value head
            k = k.repeat_interleave(self.repeat, dim=1)
            v = v.repeat_interleave(self.repeat, dim=1)

        drop_p = self.dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(merged)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class SwiGLU(nn.Module):
    """Bias-free gated feed-forward layer: ``w_down(silu(w_gate(x)) * w_up(x))``."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        width, hidden = cfg.d_model, cfg.d_ffn
        self.w_gate = nn.Linear(width, hidden, bias=False)
        self.w_up = nn.Linear(width, hidden, bias=False)
        self.w_down = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        """Apply the gated projection and map back to the model width."""
        gate = F.silu(self.w_gate(x))
        value = self.w_up(x)
        return self.w_down(gate * value)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU, each with a residual."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.attn = GQAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.ffn = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_out = self.attn(self.attn_norm(x), cos, sin)
        hidden = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_out
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class VectraYXNano(nn.Module):
    """Decoder-only language model built from ``Block`` layers.

    Token embedding -> ``n_layers`` blocks -> final RMSNorm -> LM head. The LM
    head shares its weight with the embedding when ``cfg.tie_embeddings`` is
    set. RoPE tables are stored as non-persistent buffers, so they are not
    part of the state dict.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.layers = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        if cfg.tie_embeddings:
            self.lm_head.weight = self.tok_emb.weight

        head_dim = cfg.d_model // cfg.n_heads
        cos, sin = precompute_rope(head_dim, cfg.max_seq_len, cfg.rope_theta)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)

        self.apply(self._init_weights)
        # Projections that write into the residual stream get a smaller std,
        # scaled by 1/sqrt(2 * n_layers).
        residual_std = cfg.init_std / math.sqrt(2 * cfg.n_layers)
        for n, p in self.named_parameters():
            if n.endswith("wo.weight") or n.endswith("w_down.weight"):
                nn.init.normal_(p, mean=0.0, std=residual_std)

    def _init_weights(self, m):
        """Normal(0, init_std) init for Linear/Embedding weights; zero biases."""
        std = self.cfg.init_std
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=std)

    def num_params(self, exclude_embedding=False):
        """Count parameters; optionally subtract the (tied) embedding matrix.

        The subtraction only happens when ``exclude_embedding`` is true and the
        embeddings are tied.
        """
        n = sum(p.numel() for p in self.parameters())
        if exclude_embedding and self.cfg.tie_embeddings:
            n -= self.tok_emb.weight.numel()
        return n

    def forward(self, idx, targets=None, loss_mask=None):
        """Compute logits and, when ``targets`` is given, the training loss.

        Args:
            idx: ``(B, T)`` token ids with ``T <= cfg.max_seq_len``.
            targets: optional ``(B, T)`` target ids; ``-100`` entries are
                ignored by the cross-entropy.
            loss_mask: optional per-target weights; when given they replace
                the ``targets != -100`` mask in the loss average.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` without targets,
            otherwise mean cross-entropy plus ``z_loss_coef`` times the mean
            squared log-partition (z-loss).
        """
        B, T = idx.shape
        assert T <= self.cfg.max_seq_len, f"seq {T} > max {self.cfg.max_seq_len}"
        x = self.tok_emb(idx)
        cos = self.rope_cos
        sin = self.rope_sin
        for layer in self.layers:
            x = layer(x, cos, sin)
        x = self.final_norm(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        # cross-entropy + z-loss for stability
        flat_logits = logits.view(-1, logits.size(-1))
        flat_tgt = targets.view(-1)
        ce = F.cross_entropy(flat_logits, flat_tgt, reduction="none", ignore_index=-100)
        if loss_mask is not None:
            weights = loss_mask.view(-1).float()
        else:
            weights = (flat_tgt != -100).float()
        # clamp avoids a division by zero when nothing is supervised
        denom = weights.sum().clamp_min(1.0)
        ce_loss = (ce * weights).sum() / denom

        if self.cfg.z_loss_coef > 0:
            lse = torch.logsumexp(flat_logits.float(), dim=-1)
            z = ((lse ** 2) * weights).sum() / denom
            loss = ce_loss + self.cfg.z_loss_coef * z
        else:
            loss = ce_loss
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=40, top_p=0.9,
                 eos_id=None, repeat_penalty=1.0):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        Args:
            idx: ``(B, T)`` prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: sampling temperature; ``<= 0`` selects greedy argmax.
            top_k: keep only the ``top_k`` highest logits (falsy disables).
            top_p: nucleus threshold (falsy or ``>= 1.0`` disables).
            eos_id: stop once every row of the batch has just produced this id.
            repeat_penalty: penalty applied, per row, to every token already
                present in that row (``1.0`` disables).

        Returns:
            The ``(B, T + generated)`` tensor of token ids.

        Note:
            The module is switched to eval mode and left there.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Only the last max_seq_len tokens fit in the context window.
            cond = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(cond)
            logits = logits[:, -1, :].float()

            if repeat_penalty != 1.0:
                # Penalise each row with its own history: positive logits are
                # divided, non-positive ones multiplied.
                for row in range(idx.size(0)):
                    seen = torch.unique(idx[row])
                    vals = logits[row, seen]
                    logits[row, seen] = torch.where(
                        vals > 0, vals / repeat_penalty, vals * repeat_penalty
                    )

            if temperature <= 0:
                next_id = logits.argmax(-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")
                if top_p and top_p < 1.0:
                    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                    probs = F.softmax(sorted_logits, dim=-1)
                    cumprobs = probs.cumsum(-1)
                    drop = cumprobs > top_p
                    # Shift right so the token that crosses the threshold is kept,
                    # and always keep the most likely token.
                    drop[..., 1:] = drop[..., :-1].clone()
                    drop[..., 0] = False
                    sorted_logits[drop] = -float("inf")
                    logits = torch.full_like(logits, -float("inf")).scatter(-1, sorted_idx, sorted_logits)
                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, 1)
            idx = torch.cat([idx, next_id], dim=-1)
            # `.all()` keeps batch size 1 behaviour and works for larger batches.
            if eos_id is not None and bool((next_id == eos_id).all()):
                break
        return idx
|
training/utils.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training utilities: optimizer setup, LR schedule, checkpointing."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def cosine_with_warmup(step, warmup, total, max_lr, min_lr_ratio=0.1):
    """Learning rate at ``step``: linear warmup followed by cosine decay.

    During the first ``warmup`` steps the rate rises linearly and reaches
    ``max_lr`` on the last warmup step. Afterwards it follows a half cosine
    from ``max_lr`` down to ``min_lr_ratio * max_lr`` at ``total`` and stays at
    that floor for later steps.
    """
    if step < warmup:
        return max_lr * (step + 1) / warmup
    floor_lr = min_lr_ratio * max_lr
    decay_steps = max(1, total - warmup)
    progress = min(1.0, (step - warmup) / decay_steps)
    cosine = 1 + math.cos(math.pi * progress)
    return floor_lr + 0.5 * (max_lr - floor_lr) * cosine
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def make_optimizer(model, lr, weight_decay=0.1, betas=(0.9, 0.95), fused=True):
    """Build AdamW with weight decay restricted to matrix-shaped weights.

    Trainable parameters with at least two dimensions whose name does not
    contain ``"tok_emb"`` go into the decayed group; everything else (1-D
    parameters such as norm gains and biases, and the token embedding) gets
    ``weight_decay=0``.

    Args:
        model: module whose ``named_parameters()`` are optimised.
        lr: learning rate.
        weight_decay: decay applied to the first parameter group.
        betas: AdamW betas.
        fused: try the fused CUDA implementation when CUDA is available; falls
            back to the default implementation if it cannot be constructed.

    Returns:
        A ``torch.optim.AdamW`` with two parameter groups (decay, no-decay).
    """
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.dim() >= 2 and "tok_emb" not in n:
            decay.append(p)
        else:
            no_decay.append(p)

    def param_groups():
        # Fresh dicts per attempt: the optimizer constructor writes its
        # defaults (including `fused`) into the dicts it is given, so a failed
        # fused attempt must not leak settings into the fallback.
        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]

    if fused and torch.cuda.is_available():
        try:
            return torch.optim.AdamW(param_groups(), lr=lr, betas=betas, fused=True)
        except (TypeError, RuntimeError):
            # TypeError: this torch build has no `fused` argument.
            # RuntimeError: parameters live on a device or dtype the fused
            # kernel rejects (e.g. a CPU model on a CUDA-capable host).
            pass
    return torch.optim.AdamW(param_groups(), lr=lr, betas=betas)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def save_checkpoint(path, model, optimizer, scheduler_state, step, extra=None):
    """Write a training checkpoint to ``path``.

    The payload holds the model and (optional) optimizer state dicts, the
    caller-supplied ``scheduler_state``, ``step``, a plain-dict copy of every
    field of ``model.cfg`` and ``extra`` (``{}`` when omitted). The file is
    first written next to the target with an added ``.tmp`` suffix and then
    moved into place with ``os.replace``, so readers do not see a partially
    written checkpoint at ``path``.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    cfg = model.cfg
    payload = {
        "model": model.state_dict(),
        "optimizer": None if optimizer is None else optimizer.state_dict(),
        "scheduler": scheduler_state,
        "step": step,
        "config": {field: getattr(cfg, field) for field in cfg.__dataclass_fields__},
        "extra": extra or {},
    }

    staging = target.with_suffix(target.suffix + ".tmp")
    torch.save(payload, staging)
    os.replace(staging, target)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_checkpoint(path, model, optimizer=None, map_location="cpu"):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    When the checkpoint was saved from a model with tied embeddings the state
    dict is loaded non-strictly, because the LM head shares its weight with the
    token embedding and may not be stored separately. The flag is read from the
    ``"config"`` section of the payload; a top-level ``"tie_embeddings"`` key
    is still honoured as a fallback.

    Args:
        path: checkpoint file to read.
        model: module receiving the weights.
        optimizer: optional optimizer whose state is restored when the
            checkpoint contains one.
        map_location: forwarded to ``torch.load``.

    Returns:
        ``(step, extra)`` from the payload, defaulting to ``0`` and ``{}``.

    Raises:
        RuntimeError: from ``load_state_dict`` on key mismatches when loading
            strictly (embeddings not tied).
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    payload = torch.load(path, map_location=map_location, weights_only=False)
    config = payload.get("config") or {}
    tied = bool(config.get("tie_embeddings", payload.get("tie_embeddings", False)))
    missing, unexpected = model.load_state_dict(payload["model"], strict=not tied)
    if missing:
        print(f"[load_checkpoint] missing keys (expected with tie_embeddings): {missing[:3]}")
    if unexpected:
        # Only reachable in non-strict mode; surface it instead of ignoring it.
        print(f"[load_checkpoint] unexpected keys: {unexpected[:3]}")
    if optimizer is not None and payload.get("optimizer"):
        optimizer.load_state_dict(payload["optimizer"])
    return payload.get("step", 0), payload.get("extra", {})
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def count_tokens(loader_output_iter, n_steps, block_size, batch_size):
    """Approximate number of tokens consumed over ``n_steps`` steps.

    Computed as ``n_steps * block_size * batch_size``; ``loader_output_iter``
    is accepted but not inspected.
    """
    tokens_at_batch_one = n_steps * block_size
    return tokens_at_batch_one * batch_size
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def log_jsonl(path, record):
    """Append ``record`` to ``path`` as one UTF-8 JSON line.

    Non-ASCII characters are written verbatim rather than escaped.
    """
    with open(path, "a", encoding="utf-8") as sink:
        # print() emits the serialised record followed by a single "\n"
        print(json.dumps(record, ensure_ascii=False), file=sink)
|