#!/usr/bin/env python3
"""
Pre-flight training configuration checklist (SFT + GRPO).
Read-only: inspects training/sft_train.py and training/grpo_train.py; does not start training.
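
Example invocation (the file name below is an assumption; the --repo-root default
implies this script lives under scripts/):

    python scripts/check_training_config.py
    python scripts/check_training_config.py --repo-root /path/to/project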
"""
from __future__ import annotations
import argparse
import inspect
import os
import sys
from pathlib import Path


def main() -> None:
    parser = argparse.ArgumentParser(description="Pre-flight SFT/GRPO training config checklist")
    parser.add_argument(
        "--repo-root",
        type=Path,
        default=None,
        help="Project root (default: parent of scripts/)",
    )
    args = parser.parse_args()

    root = (args.repo_root or Path(__file__).resolve().parent.parent).resolve()
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))

    sft_path = root / "training" / "sft_train.py"
    grpo_path = root / "training" / "grpo_train.py"
    if not sft_path.is_file() or not grpo_path.is_file():
        print(f"Missing {sft_path} or {grpo_path}")
        return
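
    # Import only after the files above were confirmed to exist; the sys.path
    # insert makes the `training` package importable from any working directory.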
    import training.sft_train as sft
    import training.grpo_train as grpo

    sft_text = sft_path.read_text(encoding="utf-8")
    grpo_text = grpo_path.read_text(encoding="utf-8")
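    # The LoRA hyperparameters are hard-coded inside train_sft(), so pull its
    # source text for the string checks below.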
    sft_fn = inspect.getsource(sft.train_sft)

    checks: list[tuple[str, bool, str]] = []

    # Base model
    want_model = "Qwen/Qwen2.5-1.5B-Instruct"
    ok_model = sft.DEFAULT_MODEL == want_model
    checks.append(("[ ] Base model: Qwen/Qwen2.5-1.5B-Instruct", ok_model, f"found {sft.DEFAULT_MODEL!r}"))

    # LoRA
    ok_lora = "r=16" in sft_fn and "lora_alpha=32" in sft_fn
    checks.append(("[ ] SFT LoRA r=16, alpha=32", ok_lora, "in train_sft()"))

    # SFT training args
    ok_epochs = "num_train_epochs=3" in sft_text
    ok_b = "per_device_train_batch_size=4" in sft_text
    ok_g = "gradient_accumulation_steps=4" in sft_text
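    # Effective batch size = per_device_train_batch_size (4) * gradient_accumulation_steps (4).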
    eff = 4 * 4
    ok_sft = ok_epochs and ok_b and ok_g
    checks.append(
        (
            f"[ ] SFT epochs=3, batch=4, grad_accum=4 (effective ~{eff})",
            ok_sft,
            f"epochs={ok_epochs} batch={ok_b} grad={ok_g}",
        )
    )

    # Output dir
    want_out = "checkpoints/sft_1.5b/"
    ok_out = sft.DEFAULT_OUTPUT == want_out
    checks.append(("[ ] SFT output: checkpoints/sft_1.5b/", ok_out, f"default={sft.DEFAULT_OUTPUT!r}"))

    # GRPO BASE_MODEL (read at import time in grpo_train)
    base = os.getenv("BASE_MODEL", "")
    grpo_default = grpo.BASE_MODEL
    if not base:
        grpo_brief = f"BASE_MODEL env not set - will use module default {grpo_default!r}"
    else:
        grpo_brief = f"set to {base!r}"
    checks.append(
        (
            "[ ] GRPO reads BASE_MODEL from env",
            True,
            grpo_brief,
        )
    )

    # GRPO reward weights
    want_line = "reward_weights=[3.0, 1.5, 2.0, 0.5]"
    in_rw = want_line in grpo_text
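    # Keep the first reward_weights line found so a mismatch is easy to spot in the note.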
    w_line = next((ln.strip() for ln in grpo_text.splitlines() if "reward_weights" in ln), "")
    checks.append(
        (
            "[ ] GRPO reward weights [efficiency, tom, anti-cap, format] = [3.0, 1.5, 2.0, 0.5]",
            in_rw,
            w_line or "not found",
        )
    )

    # GRPO data path
    d_ok = 'default="data/episodes.jsonl"' in grpo_text
    checks.append(
        (
            "[ ] GRPO --data default: data/episodes.jsonl",
            d_ok,
            "see grpo_train.main argparse" if d_ok else "check grpo_train.py",
        )
    )

    checks.append(
        (
            "[ ] Estimated VRAM note (1.5B + LoRA r=16 ~6-8GB SFT; more for GRPO)",
            True,
            "informational (does not fail the checklist)",
        )
    )

    hf = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    ok_hf = bool(hf)
    checks.append(
        (
            "[ ] HF token for push (HF_TOKEN or HUGGING_FACE_HUB_TOKEN)",
            ok_hf,
            "set" if ok_hf else "not set - needed to push checkpoints",
        )
    )

    print("Training config pre-flight (read from training/sft_train.py, training/grpo_train.py)\n")
    for line, ok, note in checks:
        mark = "x" if ok else " "
        display = line.replace("[ ]", f"[{mark}]", 1) if line.startswith("[ ]") else line
        print(display)
        if note:
            print(f" -> {note}")
        print()
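    # core_ok covers the config checks only; the HF token is handled separately
    # below and the VRAM note is informational.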
    core_ok = ok_model and ok_lora and ok_sft and ok_out and in_rw and d_ok
    if core_ok and ok_hf:
        print("\nREADY FOR TRAINING (SFT + GRPO config strings match; HF token present for hub).")
    elif core_ok:
        print(
            "\nMOSTLY READY: set an HF token if you need push_to_hub; verify BASE_MODEL for the GRPO stage."
        )
    else:
        print("\nNEEDS FIXING: see failed [ ] items above (model path, LoRA, SFT args, or output dir).")


if __name__ == "__main__":
    main()