File size: 4,770 Bytes
df724f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
"""
Pre-flight training configuration checklist (SFT + GRPO).
Read-only: inspects training/sft_train.py and training/grpo_train.py; does not start training.
"""
from __future__ import annotations

import argparse
import inspect
import os
import sys
from pathlib import Path


def main() -> None:
    """Run a read-only pre-flight checklist over the SFT/GRPO training configs.

    Inspects ``training/sft_train.py`` and ``training/grpo_train.py`` — both as
    raw text (string matching against the expected config literals) and as
    imported modules (module-level defaults) — and prints a ``[x]``/``[ ]``
    checkbox report.  Nothing is trained or modified.

    Exit status: 2 if the training scripts are missing, 1 if a core config
    check fails, 0 otherwise — so the script is usable from CI/shell.
    """
    parser = argparse.ArgumentParser(description="Pre-flight SFT/GRPO training config checklist")
    parser.add_argument(
        "--repo-root",
        type=Path,
        default=None,
        help="Project root (default: parent of scripts/)",
    )
    args = parser.parse_args()

    # Resolve the project root and make it importable so the
    # `import training.<module>` statements below find the local package.
    root = (args.repo_root or Path(__file__).resolve().parent.parent).resolve()
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))

    sft_path = root / "training" / "sft_train.py"
    grpo_path = root / "training" / "grpo_train.py"
    if not sft_path.is_file() or not grpo_path.is_file():
        # BUGFIX: previously printed to stdout and returned normally (exit 0),
        # so automation could not detect the failure.  Report on stderr and
        # exit nonzero instead.
        print(f"Missing {sft_path} or {grpo_path}", file=sys.stderr)
        raise SystemExit(2)

    # Imported late (after the sys.path fix-up above) on purpose.
    import training.sft_train as sft
    import training.grpo_train as grpo

    sft_text = sft_path.read_text(encoding="utf-8")
    grpo_text = grpo_path.read_text(encoding="utf-8")
    sft_fn = inspect.getsource(sft.train_sft)

    # Each entry: (checkbox label, pass/fail, explanatory note).
    checks: list[tuple[str, bool, str]] = []

    # Base model default must match the expected checkpoint exactly.
    want_model = "Qwen/Qwen2.5-1.5B-Instruct"
    ok_model = sft.DEFAULT_MODEL == want_model
    checks.append(("[ ] Base model: Qwen/Qwen2.5-1.5B-Instruct", ok_model, f"found {sft.DEFAULT_MODEL!r}"))

    # LoRA hyperparameters, matched as literals inside train_sft()'s source.
    ok_lora = "r=16" in sft_fn and "lora_alpha=32" in sft_fn
    checks.append(("[ ] SFT LoRA r=16, alpha=32", ok_lora, "in train_sft()"))

    # SFT training args (string match against the file text).
    ok_epochs = "num_train_epochs=3" in sft_text
    ok_b = "per_device_train_batch_size=4" in sft_text
    ok_g = "gradient_accumulation_steps=4" in sft_text
    eff = 4 * 4  # effective batch = per-device batch * grad accumulation
    ok_sft = ok_epochs and ok_b and ok_g
    checks.append(
        (
            f"[ ] SFT epochs=3, batch=4, grad_accum=4 (effective ~{eff})",
            ok_sft,
            f"epochs={ok_epochs} batch={ok_b} grad={ok_g}",
        )
    )

    # Checkpoint output directory default.
    want_out = "checkpoints/sft_1.5b/"
    ok_out = sft.DEFAULT_OUTPUT == want_out
    checks.append(("[ ] SFT output: checkpoints/sft_1.5b/", ok_out, f"default={sft.DEFAULT_OUTPUT!r}"))

    # GRPO BASE_MODEL (read at import time in grpo_train) — informational
    # only, so the check itself always passes; the note says what applies.
    base = os.getenv("BASE_MODEL", "")
    grpo_default = grpo.BASE_MODEL
    if not base:
        grpo_brief = f"BASE_MODEL env not set - will use module default {grpo_default!r}"
    else:
        grpo_brief = f"set to {base!r}"
    checks.append(
        (
            "[ ] GRPO reads BASE_MODEL from env",
            True,
            grpo_brief,
        )
    )

    # GRPO reward weights: require the exact literal; also surface whatever
    # reward_weights line is actually present for the note.
    want_line = "reward_weights=[3.0, 1.5, 2.0, 0.5]"
    in_rw = want_line in grpo_text
    w_line = next((ln.strip() for ln in grpo_text.splitlines() if "reward_weights" in ln), "")
    checks.append(
        (
            "[ ] GRPO reward weights [efficiency, tom, anti-cap, format] = [3.0, 1.5, 2.0, 0.5]",
            in_rw,
            w_line or "not found",
        )
    )

    # GRPO --data argparse default.
    d_ok = 'default="data/episodes.jsonl"' in grpo_text
    checks.append(
        (
            '[ ] GRPO --data default: data/episodes.jsonl',
            d_ok,
            "see grpo_train.main argparse" if d_ok else "check grpo_train.py",
        )
    )

    # Purely informational reminder; never fails.
    checks.append(
        (
            "[ ] Estimated VRAM note (1.5B + LoRA r=16 ~6-8GB SFT; more for GRPO)",
            True,
            "informational (not a failure if you skip the box)",
        )
    )

    # HF token is needed only for push_to_hub; tracked separately from the
    # core checks so its absence downgrades to "MOSTLY READY", not a failure.
    hf = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    ok_hf = bool(hf)
    checks.append(
        (
            "[ ] HF token for push (HF_TOKEN or HUGGING_FACE_HUB_TOKEN)",
            ok_hf,
            "set" if ok_hf else "not set - needed to push checkpoints",
        )
    )

    print("Training config pre-flight (read from training/sft_train.py, training/grpo_train.py)\n")
    for line, ok, note in checks:
        mark = "x" if ok else " "
        # Fill in only the first "[ ]" so labels containing literal brackets
        # elsewhere are left intact.
        display = line.replace("[ ]", f"[{mark}]", 1) if line.startswith("[ ]") else line
        print(display)
        if note:
            print(f"  -> {note}")
    print()

    core_ok = ok_model and ok_lora and ok_sft and ok_out and in_rw and d_ok
    if core_ok and ok_hf:
        print("\nREADY FOR TRAINING (SFT + GRPO config strings match; HF token present for hub).")
    elif core_ok:
        print(
            "\nMOSTLY READY: fix missing HF token if you need push_to_hub; verify BASE_MODEL for GRPO stage."
        )
    else:
        print("\nNEEDS FIXING: see failed [ ] items above (model path, LoRA, SFT args, or output dir).")
        # BUGFIX: signal the failed preflight to callers (CI, shell scripts)
        # instead of exiting 0 on every outcome.
        raise SystemExit(1)


# Script entry point: run the checklist only when executed directly,
# not when imported.
if __name__ == "__main__":
    main()