ronitraj committed · verified · Commit bd1a695 · Parent: 6ac3b26

Upload scripts/train_sft.py with huggingface_hub

Files changed (1): scripts/train_sft.py (+1198 lines)
scripts/train_sft.py (new file, 1198 lines):

"""scripts/train_sft.py - SFT warm-up phase (master spec, sections 1-3).

Loads ``Qwen/Qwen2.5-3B-Instruct`` in 4-bit (NF4) via Unsloth, attaches a
LoRA adapter (rank 16, alpha 32, dropout 0.05, on q/k/v/o projections),
and runs a single epoch of supervised fine-tuning on
``data/sft_dataset.jsonl`` (3,000 examples).

Goal: take the base model from ~0% format compliance to >=95% so the GRPO
trainer has a non-zero probability of getting parseable rewards.

Locked hyperparameters (master spec, section 1):

* batch=4, grad_accum=4 -> effective batch 16
* lr=2e-4 with 20-step linear warmup -> constant
* weight_decay=0.01, optimizer=adamw_8bit, mixed precision=bf16
* max_seq_len=1024, epochs=1, max_steps=200
* checkpoint every 50, eval every 50, log every 10
* seed=42

Designed to run on a Colab T4 in <=30 minutes.

Usage::

    pip install -r requirements-train.txt
    python -m scripts.train_sft \
        --dataset data/sft_dataset.jsonl \
        --val-dataset data/sft_validation.jsonl \
        --output checkpoints/sft_warmup \
        --report-to wandb

W&B logging (master spec, section 2)
------------------------------------
* Every 10 steps: TRL's built-in train/loss, learning_rate, grad_norm,
  epoch, global_step.
* Every 50 steps (validation pass on 100 held-out syndromes):

      eval/format_compliance
      eval/logical_correction_rate
      eval/exact_match_pymatching
      eval/hamming_overlap_mean
      eval/output_length_mean
      eval/output_diversity (10 samples of one prompt @ T=0.7)
      eval/syndrome_consistency

* End-of-train: ``run.summary`` dump of final eval scores; LoRA adapter
  uploaded as a W&B artifact.

Early stopping (master spec, section 3)
---------------------------------------
Training halts as soon as ANY of these is true after a validation pass:

1. format_compliance >= 0.95 AND logical_correction_rate >= 0.80
   AND output_diversity >= 3 (success)
2. global_step >= 200 (hard cap)
3. wall-clock >= 30 minutes (hard cap)
4. train/loss has NaN or inf (failure)
"""
from __future__ import annotations

import argparse
import json
import os
import random
import re
import statistics
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Iterable, Optional


# --------------------------------------------------------------------------- #
# Pre-flight dataset audit                                                    #
# --------------------------------------------------------------------------- #
# Runs as the FIRST step of main(), before any model/tokenizer/heavy imports.
# Catches dataset regressions (class collapse, format drift, parse breakage,
# size mismatches) in a few seconds, before burning ~30 min of GPU on a run
# that was doomed at row 0.
#
# Nine checks in all: eight on the train split, plus one combined check that
# re-applies most of the same thresholds to the validation split. Any failure
# raises SystemExit(2) so the Colab/Lightning shell pipeline exits with a
# non-zero status and won't proceed to model loading.

_FORMAT_ANCHOR_RE = re.compile(r"X_ERRORS=\[[\d,\s]*\]\s*Z_ERRORS=\[[\d,\s]*\]\s*$")
_FORMAT_ONLY_RE = re.compile(r"^\s*X_ERRORS=\[[\d,\s]*\]\s*Z_ERRORS=\[[\d,\s]*\]\s*$")
_TAIL_RE = re.compile(r"X_ERRORS=\[([^\]]*)\]\s*Z_ERRORS=\[([^\]]*)\]\s*$")
_LEVEL_P_RE = re.compile(r"Physical error rate:\s*([\d.]+)")
_LEVEL_D_RE = re.compile(r"Code distance:\s*(\d+)")
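
# Illustrative strings (hypothetical, not from the dataset) that the anchors
# above accept: a completion ending in "X_ERRORS=[0, 3] Z_ERRORS=[]" matches
# _TAIL_RE and _FORMAT_ANCHOR_RE; _FORMAT_ONLY_RE additionally requires that
# format line to be the entire completion.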


def _detect_level_from_prompt(prompt: str) -> str:
    """Return ``"L1"``/``"L2"``/``"L3"``/``"unknown"`` for an SFT prompt.

    Used as a fallback for legacy datasets that didn't write a ``level``
    field into each record. We read the L1/L2/L3 ``p`` and ``distance``
    values straight from :mod:`qubit_medic.config` rather than hardcoding
    them, so the audit keeps working when the curriculum is tuned (e.g.
    L1's ``p`` was bumped from 0.0001 -> 0.0005, which broke the old
    hardcoded check and made every L1 row read as ``unknown``).
    """
    m_p = _LEVEL_P_RE.search(prompt)
    m_d = _LEVEL_D_RE.search(prompt)
    if not m_p or not m_d:
        return "unknown"
    p = float(m_p.group(1))
    d = int(m_d.group(1))
    try:
        from qubit_medic.config import level_by_name
        l3 = level_by_name("L3_stretch")
        l2 = level_by_name("L2_target")
        l1 = level_by_name("L1_warmup")
        if d == l3.distance and abs(p - l3.p) < 1e-9:
            return "L3"
        if d == l2.distance and abs(p - l2.p) < 1e-9:
            return "L2"
        if d == l1.distance and abs(p - l1.p) < 1e-9:
            return "L1"
    except Exception:
        pass
    return "unknown"


def _level_label_from_record(rec: dict) -> str:
    """Return ``"L1"``/``"L2"``/``"L3"``/``"unknown"`` for an SFT record.

    Prefers the explicit ``level`` field written by
    ``scripts/generate_sft_data.py`` (e.g. ``"L1_warmup"``). Falls back
    to :func:`_detect_level_from_prompt` for legacy records that lack
    that field.
    """
    raw = rec.get("level")
    if isinstance(raw, str):
        if raw.startswith("L1"):
            return "L1"
        if raw.startswith("L2"):
            return "L2"
        if raw.startswith("L3"):
            return "L3"
    prompt = rec.get("prompt")
    if isinstance(prompt, str):
        return _detect_level_from_prompt(prompt)
    return "unknown"


def _has_nonempty_correction(completion: str) -> bool:
    """True iff the completion's trailing format line predicts at least one
    error (X or Z). Robust to a leading reasoning prefix.
    """
    m = _TAIL_RE.search(completion.rstrip())
    if m is None:
        return False
    return bool(m.group(1).strip()) or bool(m.group(2).strip())
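
# Hypothetical examples of the helper above (values illustrative only):
#   _has_nonempty_correction("... analysis ...\nX_ERRORS=[2] Z_ERRORS=[]") -> True
#   _has_nonempty_correction("X_ERRORS=[] Z_ERRORS=[]")                    -> False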


def _audit_file(path: Path) -> dict:
    """Compute raw audit metrics for one JSONL file."""
    if not path.exists():
        return {"error": f"missing file: {path}"}
    rows: list[dict] = []
    parse_failures = 0
    with path.open() as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                parse_failures += 1
                continue
            if "prompt" not in rec or "completion" not in rec:
                parse_failures += 1
                continue
            rows.append(rec)
    n = len(rows)
    total_lines = n + parse_failures
    parse_rate = (n / total_lines) if total_lines else 0.0
    nonempty = sum(_has_nonempty_correction(r["completion"]) for r in rows)
    anchor = sum(1 for r in rows if _FORMAT_ANCHOR_RE.search(r["completion"].rstrip()))
    levels = {"L1": 0, "L2": 0, "L3": 0, "unknown": 0}
    for r in rows:
        levels[_level_label_from_record(r)] += 1
    plens = [len(r["prompt"]) for r in rows]
    clens = [len(r["completion"]) for r in rows]
    format_only = sum(1 for r in rows if _FORMAT_ONLY_RE.fullmatch(r["completion"].strip()))
    return {
        "n": n,
        "parse_failures": parse_failures,
        "parse_rate": parse_rate,
        "nonempty_frac": (nonempty / n) if n else 0.0,
        "anchor_frac": (anchor / n) if n else 0.0,
        "level_pct": {k: ((v / n) if n else 0.0) for k, v in levels.items()},
        "plens": plens,
        "clens": clens,
        "format_only_frac": (format_only / n) if n else 0.0,
    }


def audit_sft_dataset(
    train_path: str = "data/sft_dataset.jsonl",
    val_path: str = "data/sft_validation.jsonl",
) -> None:
    """Pre-flight audit of the SFT dataset. Halts (SystemExit) on violation.

    Runs eight checks against ``train_path`` plus one combined parallel
    check against ``val_path`` (which re-applies the row-count, parse,
    non-empty, anchor, format-only, and curriculum thresholds). Designed
    to run in seconds on the CPU before any heavy ML deps are imported,
    so a broken dataset never reaches the GPU.

    Locked thresholds:
        Total rows:           train=3000, val=100
        JSON parse rate:      100%
        Non-empty correction: 65-75%
        Format anchor:        100%
        Curriculum L1/L2/L3:  38-42% / 48-52% / 8-12%
                              (tightened from the spec's 35-45/45-55/7-15)
        Prompt length:        min>=800, median in [1100,1600], max<=2200
        Completion length:    min>=22, median in [22,80], max<=120
        Format-only target:   100%
        Validation parallel:  same thresholds applied to the val split
    """
    EXPECTED_TRAIN = 3000
    EXPECTED_VAL = 100
    NONEMPTY_LO, NONEMPTY_HI = 0.65, 0.75
    # Tightened to match quota-based per-level generation in
    # scripts/generate_sft_data.py, which produces the 40/50/10 split
    # exactly (no rejection-sampling drift).
    L1_LO, L1_HI = 0.38, 0.42
    L2_LO, L2_HI = 0.48, 0.52
    L3_LO, L3_HI = 0.08, 0.12
    PLEN_MIN, PLEN_MED_LO, PLEN_MED_HI, PLEN_MAX = 800, 1100, 1600, 2200
    # Targets are deliberately one-line format strings. The earlier
    # reasoning-prefix targets made the base model burn the full eval token
    # budget on analysis and never reach the required parseable answer line.
    CLEN_MIN, CLEN_MED_LO, CLEN_MED_HI, CLEN_MAX = 22, 22, 80, 120
    FORMAT_ONLY_MIN = 1.0

    train = _audit_file(Path(train_path))
    if "error" in train:
        print(f"[audit] FATAL: {train['error']}")
        raise SystemExit(2)

    # ------------------------------- train checks ------------------------- #
    checks: list[tuple[str, str, bool]] = []

    checks.append((
        "Total rows",
        f"{train['n']} (expected {EXPECTED_TRAIN})",
        train["n"] == EXPECTED_TRAIN,
    ))
    checks.append((
        "JSON parse rate",
        f"{train['parse_rate'] * 100:.1f}% ({train['parse_failures']} failures)",
        abs(train["parse_rate"] - 1.0) < 1e-9,
    ))
    checks.append((
        "Non-empty correction",
        f"{train['nonempty_frac'] * 100:.1f}% (target 65-75%)",
        NONEMPTY_LO <= train["nonempty_frac"] <= NONEMPTY_HI,
    ))
    checks.append((
        "Format anchor",
        f"{train['anchor_frac'] * 100:.1f}%",
        abs(train["anchor_frac"] - 1.0) < 1e-9,
    ))

    p1 = train["level_pct"]["L1"]
    p2 = train["level_pct"]["L2"]
    p3 = train["level_pct"]["L3"]
    p_unknown = train["level_pct"]["unknown"]
    checks.append((
        "Curriculum L1/L2/L3",
        f"{p1*100:.1f}/{p2*100:.1f}/{p3*100:.1f}% (unknown={p_unknown*100:.1f}%)",
        (L1_LO <= p1 <= L1_HI
         and L2_LO <= p2 <= L2_HI
         and L3_LO <= p3 <= L3_HI),
    ))

    pmin = min(train["plens"]) if train["plens"] else 0
    pmed = int(statistics.median(train["plens"])) if train["plens"] else 0
    pmax = max(train["plens"]) if train["plens"] else 0
    checks.append((
        "Prompt length",
        f"min={pmin} median={pmed} max={pmax}",
        (pmin >= PLEN_MIN
         and PLEN_MED_LO <= pmed <= PLEN_MED_HI
         and pmax <= PLEN_MAX),
    ))

    cmin = min(train["clens"]) if train["clens"] else 0
    cmed = int(statistics.median(train["clens"])) if train["clens"] else 0
    cmax = max(train["clens"]) if train["clens"] else 0
    checks.append((
        "Completion length",
        f"min={cmin} median={cmed} max={cmax}",
        (cmin >= CLEN_MIN
         and CLEN_MED_LO <= cmed <= CLEN_MED_HI
         and cmax <= CLEN_MAX),
    ))

    checks.append((
        "Format-only completions",
        f"{train['format_only_frac'] * 100:.1f}% (target 100%)",
        abs(train["format_only_frac"] - FORMAT_ONLY_MIN) < 1e-9,
    ))

    # ------------------------------- val parallel ------------------------- #
    val = _audit_file(Path(val_path))
    if "error" in val:
        checks.append(("Validation parallel", val["error"], False))
    else:
        v1 = val["level_pct"]["L1"]
        v2 = val["level_pct"]["L2"]
        v3 = val["level_pct"]["L3"]
        val_pass = (
            val["n"] == EXPECTED_VAL
            and abs(val["parse_rate"] - 1.0) < 1e-9
            and NONEMPTY_LO <= val["nonempty_frac"] <= NONEMPTY_HI
            and abs(val["anchor_frac"] - 1.0) < 1e-9
            and abs(val["format_only_frac"] - FORMAT_ONLY_MIN) < 1e-9
            and L1_LO <= v1 <= L1_HI
            and L2_LO <= v2 <= L2_HI
            and L3_LO <= v3 <= L3_HI
        )
        val_summary = (
            f"rows={val['n']} parse={val['parse_rate']*100:.0f}% "
            f"nonempty={val['nonempty_frac']*100:.1f}% "
            f"anchor={val['anchor_frac']*100:.0f}% "
            f"format_only={val['format_only_frac']*100:.0f}% "
            f"L1/L2/L3={v1*100:.1f}/{v2*100:.1f}/{v3*100:.1f}%"
        )
        checks.append(("Validation parallel", val_summary, val_pass))

    # ------------------------------- print banner ------------------------- #
    print()
    print("DATASET AUDIT SUMMARY")
    print("=" * 21)
    label_w = max(len(label) for label, _, _ in checks) + 1
    val_w = max(len(val_str) for _, val_str, _ in checks)
    for label, val_str, passed in checks:
        mark = "✓" if passed else "✗"
        print(f"{(label + ':').ljust(label_w + 1)} {val_str.ljust(val_w)} [{mark}]")

    all_passed = all(passed for _, _, passed in checks)
    print()
    if all_passed:
        print("ALL CHECKS PASSED — DATASET READY FOR TRAINING")
        print()
        return
    print("AUDIT FAILED — FIX DATASET BEFORE TRAINING")
    print()
    raise SystemExit(2)
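
# The audit can also be exercised standalone, without any GPU deps (it uses
# the default data/sft_dataset.jsonl + data/sft_validation.jsonl paths):
#   python -c "from scripts.train_sft import audit_sft_dataset; audit_sft_dataset()"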


# --------------------------------------------------------------------------- #
# Validation-record loading                                                   #
# --------------------------------------------------------------------------- #


def _load_jsonl(path: str) -> list[dict]:
    rows: list[dict] = []
    with open(path) as f:
        for line in f:
            rows.append(json.loads(line))
    return rows


def _load_train_dataset(path: str, tokenizer):
    """Load the SFT JSONL into a HuggingFace Dataset.

    Master spec (section 4): the chat template is applied via the
    tokenizer (``apply_chat_template``), NOT by manually inserting
    ``<|im_start|>`` markers - that way the same template works across
    Qwen2.5 / Qwen3 / etc. without surprises.
    """
    from datasets import Dataset

    rows = _load_jsonl(path)
    out = []
    for rec in rows:
        messages = [
            {"role": "user", "content": rec["prompt"]},
            {"role": "assistant", "content": rec["completion"]},
        ]
        try:
            text = tokenizer.apply_chat_template(messages, tokenize=False)
        except Exception:
            # Defensive fallback if apply_chat_template ever misbehaves.
            text = (
                "<|im_start|>user\n"
                f"{rec['prompt']}\n<|im_end|>\n"
                "<|im_start|>assistant\n"
                f"{rec['completion']}<|im_end|>"
            )
        out.append({
            "prompt": rec["prompt"],
            "completion": rec["completion"],
            "text": text,
        })
    return Dataset.from_list(out)
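
# With the Qwen2.5 chat template, the rendered ``text`` field looks roughly
# like the manual fallback above, i.e.
#   <|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{completion}<|im_end|>
# (modulo a system block the template may prepend). Treat this as a sketch,
# not the byte-exact template output.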


# --------------------------------------------------------------------------- #
# Per-level physics caches (used by the validation callback)                  #
# --------------------------------------------------------------------------- #


def _build_level_caches(needed_levels: set[str]) -> dict[str, dict]:
    """Pre-build circuit / matching / layout / supports per curriculum level."""
    import pymatching

    from qubit_medic.config import level_by_name
    from qubit_medic.server.physics import (
        build_circuit, build_dem, extract_layout, per_round_x_z_counts,
    )
    from qubit_medic.server.rewards import compute_final_detector_supports

    caches: dict[str, dict] = {}
    for name in needed_levels:
        lvl = level_by_name(name)
        circuit = build_circuit(lvl)
        dem = build_dem(circuit)
        matching = pymatching.Matching.from_detector_error_model(dem)
        layout = extract_layout(circuit)
        n_x, n_z = per_round_x_z_counts(layout)
        supports = compute_final_detector_supports(layout)
        caches[name] = {
            "level": lvl,
            "circuit": circuit,
            "dem": dem,
            "matching": matching,
            "layout": layout,
            "supports": supports,
            "num_x_stab": n_x,
            "num_z_stab": n_z,
        }
    return caches


# --------------------------------------------------------------------------- #
# Validation callback (master spec, section 2 + section 3)                    #
# --------------------------------------------------------------------------- #


def _build_validation_callback(
    *,
    model,
    tokenizer,
    val_records: list[dict],
    eval_every: int,
    eval_schedule: tuple[tuple[int, int, str], ...] | None,
    print_sample_outputs: int,
    output_dir: str,
    max_new_tokens: int,
    diversity_n_samples: int,
    diversity_temperature: float,
    early_stop_format: float,
    early_stop_correction: float,
    early_stop_diversity: int,
    max_wall_seconds: float,
    started_wall: float,
    diversity_floor: int = 2,
    diversity_run_len: int = 2,
):
    """Returns a ``TrainerCallback`` that:

    * fires at every step in ``eval_schedule`` (or every ``eval_every``
      steps if no schedule is given) with a per-step sample size,
    * logs the spec metrics + new diagnostic metrics to W&B,
    * prints the first ``print_sample_outputs`` raw model outputs to
      stdout AND to ``{output_dir}/eval_samples_step{N}.txt`` so a
      broken parser / generation drift can be diagnosed in seconds,
    * stops training when the success criterion or hard caps fire.

    Metric semantics changed in this revision:

    * Parse failures NO LONGER default to "predict no errors". Failed
      rows contribute logical_correction=0, hamming=0,
      syndrome_consistency=0 to the aggregates. This stops trivial
      syndromes (~95% at p=0.001) from inflating logical_correction_rate
      to 0.98 while format_compliance sits at 0.01.
    * New ``eval/parse_failure_rate`` = 1 - format_compliance, so a
      broken parser is impossible to miss.
    * New ``eval/format_compliance_strict`` reports the share of
      outputs that hit the canonical ``X_ERRORS=[...] Z_ERRORS=[...]``
      form (Reward 4 == 1.0). The looser ``eval/format_compliance``
      reports the share where the model's answer was extractable at all.
    """
    from transformers import TrainerCallback

    from qubit_medic import wandb_utils
    from qubit_medic.prompts import parse_action
    from qubit_medic.server.physics import SyndromeSample
    from qubit_medic.server.rewards import compute_all_rewards

    if not val_records:
        return None

    # Pre-build per-level physics for fast scoring.
    needed = {r["level"] for r in val_records}
    level_caches = _build_level_caches(needed)

    # Pick one stable prompt for the diversity probe (always the same record
    # so the diversity number is comparable across checkpoints).
    diversity_record = val_records[0]
    diversity_messages = [{"role": "user", "content": diversity_record["prompt"]}]

    # Index the schedule: step -> (sample_size, mode). Sample sizes are
    # capped at len(val_records) so a small held-out set still works.
    if eval_schedule:
        schedule = {
            step: (min(size, len(val_records)), mode)
            for step, size, mode in eval_schedule
        }
    else:
        schedule = {}

    sample_dir = Path(output_dir)
    sample_dir.mkdir(parents=True, exist_ok=True)

    # 2026-04 (FIX 2) diversity-collapse rolling buffer. We track the
    # last ``diversity_run_len`` full-eval ``output_diversity`` values
    # and stop training when every entry is below ``diversity_floor``.
    from collections import deque as _deque
    recent_diversity = _deque(maxlen=diversity_run_len)
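
    # Shape reminder: ``eval_schedule`` is a tuple of (step, sample_size, mode)
    # triples, e.g. ((5, 20, "format_only"), (50, 100, "full")) -- those
    # concrete steps/sizes are illustrative only; the real values come from
    # qubit_medic.config.SFT_EVAL_SCHEDULE.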

    class _ValidationCallback(TrainerCallback):
        # Stamp the most recent eval here so the on_train_end hook can avoid
        # re-running if the eval step coincided with the final step.
        last_eval_step: int = -1

        def on_step_end(self, args, state, control, **kwargs):  # noqa: D401
            now = time.time() - started_wall
            if now >= max_wall_seconds:
                print(f"[sft] wall-clock cap {max_wall_seconds:.0f}s hit at step "
                      f"{state.global_step}; stopping.")
                control.should_training_stop = True
                return

            step = state.global_step
            if step == 0:
                return
            if schedule:
                if step not in schedule:
                    return
            else:
                if step % eval_every != 0:
                    return
            self._run_eval(state, control)

        def on_train_end(self, args, state, control, **kwargs):  # noqa: D401
            if state.global_step != self.last_eval_step:
                self._run_eval(state, control, final=True)

        # ------------------------------------------------------------------ #
        # Core evaluation                                                     #
        # ------------------------------------------------------------------ #
        def _generate_greedy(self, messages: list[dict]) -> tuple[str, int]:
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
            )
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            try:
                out = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
                )
                gen_ids = out[0][inputs["input_ids"].shape[1]:]
                completion = tokenizer.decode(gen_ids, skip_special_tokens=True)
                return completion, int(gen_ids.shape[0])
            except Exception as exc:
                return f"<gen-error: {exc}>", 0

        def _generate_sampled(self, messages: list[dict]) -> str:
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
            )
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            try:
                out = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=diversity_temperature,
                    top_p=0.95,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
                )
                return tokenizer.decode(
                    out[0][inputs["input_ids"].shape[1]:],
                    skip_special_tokens=True,
                )
            except Exception as exc:
                return f"<gen-error: {exc}>"

        def _run_eval(self, state, control, *, final: bool = False) -> None:
            self.last_eval_step = state.global_step
            try:
                from unsloth import FastLanguageModel
                FastLanguageModel.for_inference(model)
            except Exception:
                model.eval()  # type: ignore[attr-defined]

            step = state.global_step
            # Resolve sample size + mode for this step.
            if final and step in schedule:
                sample_size, mode = schedule[step]
            elif final:
                sample_size, mode = len(val_records), "full"
            elif step in schedule:
                sample_size, mode = schedule[step]
            else:
                sample_size, mode = len(val_records), "full"

            # Deterministic slice so the same prompts are used across checkpoints.
            records = val_records[:sample_size]
            n = len(records)
            full_eval = (mode == "full")

            n_format = 0         # lenient parse_success
            n_format_strict = 0  # canonical "=" + "[]"
            n_logical = n_exact = 0
            sum_hamming = 0.0
            sum_syndrome = 0.0
            sum_length = 0
            rows: list[dict] = []
            sample_dump_lines: list[str] = [
                f"=== eval samples @ step {step} (mode={mode}, n={n}) ===",
            ]

            for idx, rec in enumerate(records):
                num_data = int(rec["num_data_qubits"])
                messages = [{"role": "user", "content": rec["prompt"]}]
                completion, n_tokens = self._generate_greedy(messages)
                sum_length += n_tokens

                parsed = parse_action(completion, num_data_qubits=num_data)
                fmt_ok = parsed.parse_success
                fmt_strict_ok = bool(parsed.strict_format)
                n_format += int(fmt_ok)
                n_format_strict += int(fmt_strict_ok)

                # Physics-heavy metrics only in "full" mode AND only when
                # the parse actually succeeded. A failed parse means the
                # model didn't produce a usable prediction; we score that
                # as a miss (0) for every downstream metric instead of
                # silently substituting an empty Pauli frame, which would
                # accidentally score correct on the ~95% of trivial
                # syndromes at p=0.001.
                logical_ok = False
                exact_ok = False
                hamming = 0.0
                syndrome = 0.0
                if full_eval and fmt_ok:
                    cache = level_caches[rec["level"]]
                    layout = cache["layout"]
                    supports = cache["supports"]
                    sample = SyndromeSample(
                        syndrome_bits=list(map(int, rec["syndrome_bits"])),
                        actual_observable_flip=int(rec["actual_observable_flip"]),
                        pymatching_observable_pred=int(rec["pymatching_observable_pred"]),
                        pymatching_x_errors=list(map(int, rec["true_x_errors"])),
                        pymatching_z_errors=list(map(int, rec["true_z_errors"])),
                    )
                    breakdown = compute_all_rewards(parsed, sample, layout, supports)
                    logical_ok = breakdown.logical_correction >= 0.5
                    hamming = float(breakdown.hamming_overlap)
                    syndrome = float(breakdown.syndrome_consistency)
                    exact_ok = (
                        parsed.x_errors == sorted(set(rec["true_x_errors"]))
                        and parsed.z_errors == sorted(set(rec["true_z_errors"]))
                    )

                n_logical += int(logical_ok)
                n_exact += int(exact_ok)
                sum_hamming += hamming
                sum_syndrome += syndrome

                if idx < print_sample_outputs:
                    sample_dump_lines.append(
                        f"\n--- sample {idx} (level={rec['level']}, "
                        f"true_x={rec['true_x_errors']}, true_z={rec['true_z_errors']}, "
                        f"fmt_ok={fmt_ok}, fmt_strict={fmt_strict_ok}, "
                        f"n_tokens={n_tokens}) ---\n"
                        f">>> RAW MODEL OUTPUT:\n{completion}\n"
                        f">>> PARSED: x={parsed.x_errors} z={parsed.z_errors}"
                    )

                if idx < 4:  # keep W&B table tiny
                    rows.append({
                        "step": step,
                        "prompt": rec["prompt"][:600],
                        "gold": rec["completion"],
                        "model": completion[:300],
                        "x_pred": ",".join(map(str, parsed.x_errors)),
                        "z_pred": ",".join(map(str, parsed.z_errors)),
                        "format_ok": fmt_ok,
                        "format_strict_ok": fmt_strict_ok,
                        "logical_ok": logical_ok,
                        "exact_match": exact_ok,
                        "hamming_overlap": hamming,
                    })

            # ---------- print + persist raw output samples -------------- #
            sample_blob = "\n".join(sample_dump_lines)
            print(sample_blob)
            try:
                (sample_dir / f"eval_samples_step{step}.txt").write_text(sample_blob)
            except OSError as exc:
                print(f"[sft][eval@{step}] could not persist sample outputs: {exc}")

            # ---------- diversity probe (skip in format_only mode) ------ #
            if full_eval:
                diverse_outputs: list[str] = []
                for _ in range(diversity_n_samples):
                    diverse_outputs.append(self._generate_sampled(diversity_messages))
                output_diversity = len(set(diverse_outputs))
            else:
                output_diversity = 0  # not measured this step

            # ---------- aggregate + log to W&B ------------------------- #
            metrics: dict[str, float | int] = {
                "eval/format_compliance": n_format / max(1, n),
                "eval/format_compliance_strict": n_format_strict / max(1, n),
                "eval/parse_failure_rate": 1.0 - (n_format / max(1, n)),
                "eval/output_length_mean": sum_length / max(1, n),
                "eval/episodes": n,
                "eval/mode_full": int(full_eval),
            }
            if full_eval:
                metrics.update({
                    "eval/logical_correction_rate": n_logical / max(1, n),
                    "eval/exact_match_pymatching": n_exact / max(1, n),
                    "eval/hamming_overlap_mean": sum_hamming / max(1, n),
                    "eval/syndrome_consistency": sum_syndrome / max(1, n),
                    "eval/output_diversity": output_diversity,
                })
            print(f"[sft][eval@{step}] " + ", ".join(
                f"{k.split('/')[-1]}={v:.3f}" if isinstance(v, float) else f"{k.split('/')[-1]}={v}"
                for k, v in metrics.items()
            ))
            wandb_utils.log(metrics, step=step)
            wandb_utils.log_generation_table(
                rows, step=step,
                table_name=("sft/final_validation" if final else "sft/validation"),
                columns=["step", "prompt", "gold", "model", "x_pred", "z_pred",
                         "format_ok", "format_strict_ok", "logical_ok",
                         "exact_match", "hamming_overlap"],
            )

            # Fail fast on the known broken-SFT pattern: the model burns the
            # whole generation budget on prose and never emits the format line.
            # These thresholds mirror the runbook table in the issue analysis.
            format_floor_by_step = {5: 0.10, 15: 0.30, 30: 0.60, 50: 0.80}
            floor = format_floor_by_step.get(step)
            if (
                floor is not None
                and not final
                and metrics["eval/format_compliance"] < floor
            ):
                print(
                    f"[sft] format guard tripped at step {step}: "
                    f"format_compliance={metrics['eval/format_compliance']:.3f} "
                    f"< {floor:.2f}. Stop and inspect raw outputs / data."
                )
                control.should_training_stop = True
                wandb_utils.update_summary({
                    "sft/early_stop_reason": "format_guard",
                    "sft/format_guard_step": step,
                    "sft/format_guard_floor": floor,
                })

            # ---------- early stop checks ------------------------------ #
            # Only meaningful on full evals: logical_correction_rate and
            # output_diversity are not measured in format_only mode.
            if full_eval:
                success = (
                    metrics["eval/format_compliance"] >= early_stop_format
                    and metrics["eval/logical_correction_rate"] >= early_stop_correction
                    and metrics["eval/output_diversity"] >= early_stop_diversity
                )
                if success and not final:
                    print(f"[sft] success criterion hit at step {state.global_step}: "
                          f"format={metrics['eval/format_compliance']:.3f} >= {early_stop_format}, "
                          f"correction={metrics['eval/logical_correction_rate']:.3f} >= {early_stop_correction}, "
                          f"diversity={int(metrics['eval/output_diversity'])} >= {early_stop_diversity}; "
                          f"stopping.")
                    control.should_training_stop = True
                    wandb_utils.update_summary({"sft/early_stop_reason": "success_criterion"})

                # 2026-04 (FIX 2) diversity-collapse early stop. Pushed
                # AFTER the success check so a model that satisfies both
                # criteria still wins; only sustained low diversity
                # without convergence triggers the regression stop.
                recent_diversity.append(int(metrics["eval/output_diversity"]))
                if (
                    not final
                    and not control.should_training_stop
                    and len(recent_diversity) >= diversity_run_len
                    and all(d < diversity_floor for d in recent_diversity)
                ):
                    history = list(recent_diversity)
                    print(
                        f"[sft] diversity collapse early stop at step "
                        f"{state.global_step}: eval/output_diversity has "
                        f"been < {diversity_floor} for {diversity_run_len} "
                        f"consecutive full evals (history={history}). "
                        f"Stopping. Bump --lora-dropout (e.g. 0.15) or "
                        f"increase label smoothing and rerun."
                    )
                    control.should_training_stop = True
                    wandb_utils.update_summary({
                        "sft/early_stop_reason": "diversity_collapse",
                        "sft/diversity_collapse_step": state.global_step,
                        "sft/diversity_collapse_history": history,
                    })

            try:
                from unsloth import FastLanguageModel
                FastLanguageModel.for_training(model)
            except Exception:
                model.train()  # type: ignore[attr-defined]

    return _ValidationCallback()


# --------------------------------------------------------------------------- #
# Loss-divergence guard (failure mode early stop)                             #
# --------------------------------------------------------------------------- #


def _build_loss_guard_callback():
    import math

    from transformers import TrainerCallback

    class _LossGuard(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):  # noqa: D401
            if not logs:
                return
            loss = logs.get("loss")
            if loss is None:
                return
            try:
                lf = float(loss)
            except (TypeError, ValueError):
                return
            if math.isnan(lf) or math.isinf(lf):
                print(f"[sft] loss={loss} is NaN/inf at step {state.global_step}; "
                      f"stopping training.")
                control.should_training_stop = True

    return _LossGuard()


# --------------------------------------------------------------------------- #
# Main                                                                        #
# --------------------------------------------------------------------------- #


def main(argv: Iterable[str] = ()) -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--dataset", type=str, default="data/sft_dataset.jsonl")
    parser.add_argument("--val-dataset", type=str,
                        default="data/sft_validation.jsonl",
                        help="held-out validation JSONL (rich records). "
                             "If missing, validation is skipped.")
    parser.add_argument("--output", type=str, default="checkpoints/sft_warmup")
    parser.add_argument("--model", type=str,
                        default=os.getenv("QUBIT_MEDIC_MODEL",
                                          "Qwen/Qwen2.5-3B-Instruct"))
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--grad-accum", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--max-seq-len", type=int, default=None)
    parser.add_argument("--max-steps", type=int, default=None,
                        help="hard cap on training steps (default 200)")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--lora-r", type=int, default=None)
    parser.add_argument("--lora-alpha", type=int, default=None)
    parser.add_argument("--lora-dropout", type=float, default=None)
    parser.add_argument("--report-to", type=str, default="wandb")
    parser.add_argument("--wandb-run-name", type=str, default=None)
    parser.add_argument("--wandb-group", type=str, default=None)
    parser.add_argument("--wandb-tags", type=str, nargs="*", default=("sft",))
    parser.add_argument("--wandb-notes", type=str, default=None)
    parser.add_argument("--eval-every", type=int, default=None,
                        help="run validation pass every N steps (legacy "
                             "fallback when --no-eval-schedule is set)")
    parser.add_argument("--no-eval-schedule", action="store_true",
                        help="disable the variable-cadence schedule "
                             "(SFT_EVAL_SCHEDULE) and fall back to "
                             "uniform --eval-every spacing")
    parser.add_argument("--print-sample-outputs", type=int,
                        default=None,
                        help="N raw model outputs to print + persist per eval "
                             "(defaults to SFT_PRINT_SAMPLE_OUTPUTS from config)")
    parser.add_argument("--diversity-samples", type=int, default=10,
                        help="N samples for the output_diversity probe")
    parser.add_argument("--diversity-temperature", type=float, default=0.7)
    parser.add_argument("--no-artifact", action="store_true")
    args = parser.parse_args(list(argv))

    # Pre-flight dataset audit. Runs in seconds on the CPU before any heavy
    # ML deps are imported, so a broken dataset never reaches the GPU. Halts
    # via SystemExit(2) on any threshold violation.
    audit_sft_dataset(args.dataset, args.val_dataset)

    # Heavy imports are lazy so this module is importable without GPU deps.
    try:
        from unsloth import FastLanguageModel
    except ImportError:
        print("ERROR: unsloth not installed. Run `pip install -r requirements-train.txt`",
              file=sys.stderr)
        return 1
    import torch
    from transformers import TrainingArguments
    from trl import SFTTrainer

    from qubit_medic import wandb_utils
    from qubit_medic.config import (
        LORA_ALPHA, LORA_DROPOUT, LORA_R, LORA_TARGET_MODULES, MODEL_ID,
        PRIMARY_SEED, SFT_BATCH_SIZE, SFT_DIVERSITY_COLLAPSE_RUN_LEN,
        SFT_EARLY_STOP_CORRECTION, SFT_EARLY_STOP_DIVERSITY,
        SFT_EARLY_STOP_FORMAT, SFT_EPOCHS, SFT_EVAL_EVERY, SFT_EVAL_SCHEDULE,
        SFT_GRAD_ACCUM, SFT_LABEL_SMOOTHING, SFT_LOG_EVERY, SFT_LR,
        SFT_LR_SCHEDULER, SFT_MAX_NEW_TOKENS, SFT_MAX_SEQ_LEN, SFT_MAX_STEPS,
        SFT_MAX_WALL_SECONDS, SFT_OPTIMIZER, SFT_PREFLIGHT_DIVERSITY_FLOOR,
        SFT_PRINT_SAMPLE_OUTPUTS, SFT_SAVE_EVERY, SFT_WARMUP_STEPS,
        SFT_WEIGHT_DECAY,
    )

    epochs = args.epochs if args.epochs is not None else SFT_EPOCHS
    batch_size = args.batch_size if args.batch_size is not None else SFT_BATCH_SIZE
    grad_accum = args.grad_accum if args.grad_accum is not None else SFT_GRAD_ACCUM
    lr = args.lr if args.lr is not None else SFT_LR
    max_seq_len = args.max_seq_len if args.max_seq_len is not None else SFT_MAX_SEQ_LEN
    max_steps = args.max_steps if args.max_steps is not None else SFT_MAX_STEPS
    seed = args.seed if args.seed is not None else PRIMARY_SEED
    lora_r = args.lora_r if args.lora_r is not None else LORA_R
    lora_alpha = args.lora_alpha if args.lora_alpha is not None else LORA_ALPHA
    lora_dropout = args.lora_dropout if args.lora_dropout is not None else LORA_DROPOUT
    eval_every = args.eval_every if args.eval_every is not None else SFT_EVAL_EVERY
    print_sample_outputs = (
        args.print_sample_outputs
        if args.print_sample_outputs is not None
        else SFT_PRINT_SAMPLE_OUTPUTS
    )
    model_id = args.model if args.model else MODEL_ID

    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # ---- W&B init (no-op if unavailable / disabled) --------------------- #
    report_to = wandb_utils.derive_report_to(args.report_to)
    run_name = args.wandb_run_name or wandb_utils.make_run_name("sft")
    wandb_utils.init_run(
        run_name=run_name,
        job_type="sft",
        tags=args.wandb_tags,
        notes=args.wandb_notes,
        group=args.wandb_group,
        extra_config={
            "cli": {
                "epochs": epochs,
                "batch_size": batch_size,
                "grad_accum": grad_accum,
                "effective_batch": batch_size * grad_accum,
                "lr": lr,
                "lr_scheduler": SFT_LR_SCHEDULER,
                "warmup_steps": SFT_WARMUP_STEPS,
                "weight_decay": SFT_WEIGHT_DECAY,
                "optimizer": SFT_OPTIMIZER,
                "max_seq_len": max_seq_len,
                "max_steps": max_steps,
                "lora_r": lora_r,
                "lora_alpha": lora_alpha,
                "lora_dropout": lora_dropout,
                "lora_target_modules": list(LORA_TARGET_MODULES),
                "dataset_path": args.dataset,
                "val_dataset_path": args.val_dataset,
                "model": model_id,
                "seed": seed,
                "report_to": report_to,
                "eval_every": eval_every,
                "save_every": SFT_SAVE_EVERY,
                "log_every": SFT_LOG_EVERY,
                "early_stop_format": SFT_EARLY_STOP_FORMAT,
                "early_stop_correction": SFT_EARLY_STOP_CORRECTION,
                "early_stop_diversity": SFT_EARLY_STOP_DIVERSITY,
                "max_wall_seconds": SFT_MAX_WALL_SECONDS,
            },
        },
    )

    # ---- Preflight: refuse to run with the known-bad Unsloth+TF combo -- #
    # (unsloth >= 2026.4.0) + (transformers < 4.55.0) silently misparses
    # the Qwen2.5-3B config: it instantiates a 7B-shaped model
    # (hidden=4096) and crashes when the 3B checkpoint (hidden=2048)
    # starts loading, with:
    #   RuntimeError: size mismatch for weight: copying a param with
    #   shape torch.Size([151936, 2048]) from checkpoint, the shape in
    #   current model is torch.Size([151936, 4096]).
    # We catch this BEFORE downloading >5GB of weights so the user does
    # not burn GPU minutes on a deterministic failure.
    import unsloth as _unsloth
    import transformers as _transformers

    def _parse_ver(v: str) -> tuple[int, ...]:
        out: list[int] = []
        for part in v.split("+", 1)[0].split("."):
            digits = "".join(ch for ch in part if ch.isdigit())
            out.append(int(digits) if digits else 0)
        return tuple(out)
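
    # e.g. _parse_ver("4.57.2") == (4, 57, 2); _parse_ver("2026.4.0+cu121")
    # == (2026, 4, 0) -- local build tags after "+" are ignored and
    # non-numeric characters are stripped from each part.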
    _u = _parse_ver(_unsloth.__version__)
    _t = _parse_ver(_transformers.__version__)
    _is_qwen25_3b = "qwen2.5-3b" in model_id.lower()
    _bad_combo = _u >= (2026, 4, 0) and _t < (4, 55, 0)
    if _is_qwen25_3b and _bad_combo:
        print(
            "[train_sft] FATAL: detected the unsloth/transformers combo that\n"
            f"  silently misparses {model_id} into a 7B-shaped model.\n"
            f"  Installed: unsloth=={_unsloth.__version__} "
            f"transformers=={_transformers.__version__}\n"
            "  This exact pair produces the\n"
            "    'size mismatch ... [151936, 2048] vs [151936, 4096]'\n"
            "  error during model load on Lightning AI / Colab.\n"
            "  Fix: pin to a known-good combination, e.g.\n"
            "    pip install --no-deps --force-reinstall \\\n"
            "      unsloth==2025.11.1 unsloth_zoo==2026.4.9\n"
            "    pip install --force-reinstall \\\n"
            "      transformers==4.57.2 trl==0.20.0\n"
            "  Or re-run scripts/run_lightning_pipeline.sh, which\n"
            "  pins these correctly and now hard-fails if the pins\n"
            "  do not stick.",
            file=sys.stderr,
        )
        return 1

    # ---- Load model + datasets ------------------------------------------ #
    print(f"loading {model_id} via Unsloth (4-bit NF4)")
    print(f"  unsloth={_unsloth.__version__} "
          f"transformers={_transformers.__version__}")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_len,
        load_in_4bit=True,
        dtype=None,  # Unsloth auto-selects bf16/fp16
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=list(LORA_TARGET_MODULES),
        lora_dropout=lora_dropout,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=seed,
    )

    print(f"loading train dataset from {args.dataset}")
    train_dataset = _load_train_dataset(args.dataset, tokenizer)
    print(f"  {len(train_dataset)} samples; first text len = "
          f"{len(train_dataset[0]['text'])}")

    val_records: list[dict] = []
    val_path = Path(args.val_dataset)
    if val_path.exists():
        val_records = _load_jsonl(args.val_dataset)
        print(f"loaded {len(val_records)} held-out validation records "
              f"from {args.val_dataset}")
    else:
        print(f"WARNING: no validation file at {args.val_dataset}; "
              f"running without eval / early-stop.")

    wandb_utils.log({
        "sft/train_dataset_size": len(train_dataset),
        "sft/val_dataset_size": len(val_records),
        "sft/first_text_len": len(train_dataset[0]["text"]),
    })

    # Dataset preview to W&B (sanity check the chat-template wrapping).
    wandb_utils.log_generation_table(
        [
            {"split": "train", "prompt": train_dataset[i]["prompt"][:600],
             "completion": train_dataset[i]["completion"]}
            for i in range(min(8, len(train_dataset)))
        ],
        step=0,
        table_name="sft/train_preview",
        columns=["split", "prompt", "completion"],
    )

    # ---- TrainingArguments (locked spec) -------------------------------- #
    Path(args.output).mkdir(parents=True, exist_ok=True)
    bf16_supported = (
        torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    )
    training_args = TrainingArguments(
        output_dir=args.output,
        num_train_epochs=epochs,
        max_steps=max_steps,  # hard cap; wins over epochs
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=lr,
        weight_decay=SFT_WEIGHT_DECAY,
        # Label smoothing was added in the 2026-04 SFT regularisation
        # rewrite (FIX 2) to combat mode collapse: spreading the loss
        # across non-target tokens makes the model less sharply rewarded
        # for memorising one canonical completion, which is what kept
        # output_diversity at 1 across every prior checkpoint.
        label_smoothing_factor=SFT_LABEL_SMOOTHING,
        warmup_steps=SFT_WARMUP_STEPS,
        lr_scheduler_type=SFT_LR_SCHEDULER,
        optim=SFT_OPTIMIZER,
        bf16=bf16_supported,
        fp16=torch.cuda.is_available() and not bf16_supported,
        logging_steps=SFT_LOG_EVERY,
        save_steps=SFT_SAVE_EVERY,
        save_total_limit=4,
        seed=seed,
        report_to=report_to,
        run_name=run_name,
    )

    # ---- Callbacks ------------------------------------------------------ #
    started_wall = time.time()
    callbacks = [_build_loss_guard_callback()]
    eval_schedule = None if args.no_eval_schedule else SFT_EVAL_SCHEDULE
    val_cb = _build_validation_callback(
        model=model,
        tokenizer=tokenizer,
        val_records=val_records,
        eval_every=eval_every,
        eval_schedule=eval_schedule,
        print_sample_outputs=print_sample_outputs,
        output_dir=args.output,
        max_new_tokens=SFT_MAX_NEW_TOKENS,
        diversity_n_samples=args.diversity_samples,
        diversity_temperature=args.diversity_temperature,
        early_stop_format=SFT_EARLY_STOP_FORMAT,
        early_stop_correction=SFT_EARLY_STOP_CORRECTION,
        early_stop_diversity=SFT_EARLY_STOP_DIVERSITY,
        max_wall_seconds=SFT_MAX_WALL_SECONDS,
        started_wall=started_wall,
        # 2026-04 (FIX 2) diversity-collapse regression early stop.
        diversity_floor=SFT_PREFLIGHT_DIVERSITY_FLOOR,
        diversity_run_len=SFT_DIVERSITY_COLLAPSE_RUN_LEN,
    )
    if val_cb is not None:
        callbacks.append(val_cb)

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=max_seq_len,
        args=training_args,
        packing=False,
        callbacks=callbacks,
    )

    print(f"training (max_steps={max_steps}, eval_every={eval_every}) ...")
    train_result = trainer.train()
    elapsed = time.time() - started_wall
    metrics = getattr(train_result, "metrics", {}) or {}
    wandb_utils.update_summary({
        "sft/wall_seconds": elapsed,
        **{f"sft/final/{k}": v for k, v in metrics.items()
           if isinstance(v, (int, float))},
    })
    print(f"training finished in {elapsed:.1f}s "
          f"(max_wall_seconds={SFT_MAX_WALL_SECONDS:.0f})")

    print(f"saving adapters to {args.output}")
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)

    # ---- Upload adapter as W&B artifact --------------------------------- #
    if not args.no_artifact:
        wandb_utils.log_artifact(
            args.output,
            name=f"sft-adapter-{run_name}",
            artifact_type="model",
            description="SFT-warmed Qwen2.5-3B + LoRA adapter (Qubit-Medic).",
        )

    wandb_utils.finish_run()
    print("done")
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))