Spaces:
Sleeping
Sleeping
Upload scripts/train_sft.py with huggingface_hub
Browse files- scripts/train_sft.py +1198 -0
scripts/train_sft.py
ADDED
|
@@ -0,0 +1,1198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""scripts/train_sft.py - SFT warm-up phase (master spec, sections 1-3).
|
| 2 |
+
|
| 3 |
+
Loads ``Qwen/Qwen2.5-3B-Instruct`` in 4-bit (NF4) via Unsloth, attaches a
|
| 4 |
+
LoRA adapter (rank 16, alpha 32, dropout 0.05, on q/k/v/o projections),
|
| 5 |
+
and runs a single epoch of supervised fine-tuning on
|
| 6 |
+
``data/sft_dataset.jsonl`` (3,000 examples).
|
| 7 |
+
|
| 8 |
+
Goal: take the base model from ~0% format compliance to >=95% so the GRPO
|
| 9 |
+
trainer has a non-zero probability of getting parseable rewards.
|
| 10 |
+
|
| 11 |
+
Locked hyperparameters (master spec, section 1):
|
| 12 |
+
* batch=4, grad_accum=4 -> effective batch 16
|
| 13 |
+
* lr=2e-4 with 20-step linear warmup -> constant
|
| 14 |
+
* weight_decay=0.01, optimizer=adamw_8bit, mixed precision=bf16
|
| 15 |
+
* max_seq_len=1024, epochs=1, max_steps=200
|
| 16 |
+
* checkpoint every 50, eval every 50, log every 10
|
| 17 |
+
* seed=42
|
| 18 |
+
|
| 19 |
+
Designed to run on a Colab T4 in <=30 minutes.
|
| 20 |
+
|
| 21 |
+
Usage::
|
| 22 |
+
|
| 23 |
+
pip install -r requirements-train.txt
|
| 24 |
+
python -m scripts.train_sft \
|
| 25 |
+
--dataset data/sft_dataset.jsonl \
|
| 26 |
+
--val-dataset data/sft_validation.jsonl \
|
| 27 |
+
--output checkpoints/sft_warmup \
|
| 28 |
+
--report-to wandb
|
| 29 |
+
|
| 30 |
+
W&B logging (master spec, section 2)
|
| 31 |
+
------------------------------------
|
| 32 |
+
* Every 10 steps: TRL's built-in train/loss, learning_rate, grad_norm,
|
| 33 |
+
epoch, global_step.
|
| 34 |
+
* Every 50 steps (validation pass on 100 held-out syndromes):
|
| 35 |
+
|
| 36 |
+
eval/format_compliance
|
| 37 |
+
eval/logical_correction_rate
|
| 38 |
+
eval/exact_match_pymatching
|
| 39 |
+
eval/hamming_overlap_mean
|
| 40 |
+
eval/output_length_mean
|
| 41 |
+
eval/output_diversity (10 samples of one prompt @ T=0.7)
|
| 42 |
+
eval/syndrome_consistency
|
| 43 |
+
|
| 44 |
+
* End-of-train: ``run.summary`` dump of final eval scores; LoRA adapter
|
| 45 |
+
uploaded as a W&B artifact.
|
| 46 |
+
|
| 47 |
+
Early stopping (master spec, section 3)
|
| 48 |
+
---------------------------------------
|
| 49 |
+
Training halts as soon as ANY of these is true after a validation pass:
|
| 50 |
+
|
| 51 |
+
1. format_compliance >= 0.95 AND logical_correction_rate >= 0.80
|
| 52 |
+
AND output_diversity >= 3 (success)
|
| 53 |
+
2. global_step >= 200 (hard cap)
|
| 54 |
+
3. wall-clock >= 30 minutes (hard cap)
|
| 55 |
+
4. train/loss has NaN or inf (failure)
|
| 56 |
+
"""
|
| 57 |
+
from __future__ import annotations
|
| 58 |
+
|
| 59 |
+
import argparse
|
| 60 |
+
import json
|
| 61 |
+
import os
|
| 62 |
+
import random
|
| 63 |
+
import re
|
| 64 |
+
import statistics
|
| 65 |
+
import sys
|
| 66 |
+
import time
|
| 67 |
+
from collections import defaultdict
|
| 68 |
+
from pathlib import Path
|
| 69 |
+
from typing import Iterable, Optional
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# --------------------------------------------------------------------------- #
# Pre-flight dataset audit                                                    #
# --------------------------------------------------------------------------- #
# Runs as the FIRST step of main(), before any model/tokenizer/heavy imports.
# Catches dataset regressions (class collapse, format drift, parse breakage,
# size mismatches) in a few seconds, before burning ~30 min of GPU on a run
# that was doomed at row 0.
#
# 9 checks, 3 of them duplicated on the validation split. Any failure raises
# SystemExit(2) so the Colab/Lightning shell pipeline exits with a non-zero
# status and won't proceed to model loading.

# Completion *ends* with the canonical answer line (a reasoning prefix may
# precede it).
_FORMAT_ANCHOR_RE = re.compile(r"X_ERRORS=\[[\d,\s]*\]\s*Z_ERRORS=\[[\d,\s]*\]\s*$")
# Completion is *only* the canonical answer line (no reasoning prefix).
_FORMAT_ONLY_RE = re.compile(r"^\s*X_ERRORS=\[[\d,\s]*\]\s*Z_ERRORS=\[[\d,\s]*\]\s*$")
# Captures the raw X / Z index lists from the trailing answer line.
_TAIL_RE = re.compile(r"X_ERRORS=\[([^\]]*)\]\s*Z_ERRORS=\[([^\]]*)\]\s*$")
# Pull ``p`` / ``distance`` out of a prompt; fallback level detection for
# legacy records that lack an explicit ``level`` field.
_LEVEL_P_RE = re.compile(r"Physical error rate:\s*([\d.]+)")
_LEVEL_D_RE = re.compile(r"Code distance:\s*(\d+)")
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _detect_level_from_prompt(prompt: str) -> str:
    """Return ``"L1"``/``"L2"``/``"L3"``/``"unknown"`` for an SFT prompt.

    Fallback for legacy datasets whose records carry no ``level`` field.
    The per-level ``p`` and ``distance`` values are read straight from
    :mod:`qubit_medic.config` rather than hardcoded, so this keeps working
    when the curriculum is retuned (e.g. L1's ``p`` moving from
    0.0001 -> 0.0005 broke an older hardcoded check and made every L1 row
    read as ``unknown``).
    """
    p_match = _LEVEL_P_RE.search(prompt)
    d_match = _LEVEL_D_RE.search(prompt)
    if p_match is None or d_match is None:
        return "unknown"
    rate = float(p_match.group(1))
    dist = int(d_match.group(1))
    try:
        from qubit_medic.config import level_by_name

        # Probe L3 first, then L2, then L1 (same precedence as before).
        for label, cfg_name in (
            ("L3", "L3_stretch"),
            ("L2", "L2_target"),
            ("L1", "L1_warmup"),
        ):
            lvl = level_by_name(cfg_name)
            if dist == lvl.distance and abs(rate - lvl.p) < 1e-9:
                return label
    except Exception:
        # Config unavailable (e.g. audit run outside the package) -> unknown.
        pass
    return "unknown"
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _level_label_from_record(rec: dict) -> str:
|
| 124 |
+
"""Return ``"L1"``/``"L2"``/``"L3"``/``"unknown"`` for an SFT record.
|
| 125 |
+
|
| 126 |
+
Prefers the explicit ``level`` field written by
|
| 127 |
+
``scripts/generate_sft_data.py`` (e.g. ``"L1_warmup"``). Falls back
|
| 128 |
+
to :func:`_detect_level_from_prompt` for legacy records that lack
|
| 129 |
+
that field.
|
| 130 |
+
"""
|
| 131 |
+
raw = rec.get("level")
|
| 132 |
+
if isinstance(raw, str):
|
| 133 |
+
if raw.startswith("L1"):
|
| 134 |
+
return "L1"
|
| 135 |
+
if raw.startswith("L2"):
|
| 136 |
+
return "L2"
|
| 137 |
+
if raw.startswith("L3"):
|
| 138 |
+
return "L3"
|
| 139 |
+
prompt = rec.get("prompt")
|
| 140 |
+
if isinstance(prompt, str):
|
| 141 |
+
return _detect_level_from_prompt(prompt)
|
| 142 |
+
return "unknown"
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _has_nonempty_correction(completion: str) -> bool:
    """True iff the completion's trailing format line predicts at least one
    error (X or Z). Robust to a leading reasoning prefix.
    """
    tail = _TAIL_RE.search(completion.rstrip())
    if tail is None:
        return False
    x_list = tail.group(1).strip()
    z_list = tail.group(2).strip()
    return bool(x_list or z_list)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _audit_file(path: Path) -> dict:
    """Compute raw audit metrics for one JSONL file.

    Returns either ``{"error": ...}`` when the file is missing, or a dict
    with row count, parse statistics, completion-format fractions, the
    per-level curriculum split, and prompt/completion length lists that
    :func:`audit_sft_dataset` checks against its locked thresholds.

    Fixes vs. the original:
    * the file is read explicitly as UTF-8 instead of the platform's
      locale default (which can mis-decode on Windows), and
    * a line whose JSON parses to a non-dict (e.g. ``42`` or ``[1]``)
      is counted as a parse failure instead of raising ``TypeError``
      on the ``"prompt" not in rec`` membership test.
    """
    if not path.exists():
        return {"error": f"missing file: {path}"}
    rows: list[dict] = []
    parse_failures = 0
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # blank lines (e.g. trailing newline) are not rows
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                parse_failures += 1
                continue
            # A valid row is a JSON object carrying both required keys.
            if not isinstance(rec, dict) or "prompt" not in rec or "completion" not in rec:
                parse_failures += 1
                continue
            rows.append(rec)
    n = len(rows)
    total_lines = n + parse_failures
    parse_rate = (n / total_lines) if total_lines else 0.0
    # sum() over bools counts the True values.
    nonempty = sum(_has_nonempty_correction(r["completion"]) for r in rows)
    anchor = sum(1 for r in rows if _FORMAT_ANCHOR_RE.search(r["completion"].rstrip()))
    levels = {"L1": 0, "L2": 0, "L3": 0, "unknown": 0}
    for r in rows:
        levels[_level_label_from_record(r)] += 1
    plens = [len(r["prompt"]) for r in rows]
    clens = [len(r["completion"]) for r in rows]
    format_only = sum(1 for r in rows if _FORMAT_ONLY_RE.fullmatch(r["completion"].strip()))
    return {
        "n": n,
        "parse_failures": parse_failures,
        "parse_rate": parse_rate,
        "nonempty_frac": (nonempty / n) if n else 0.0,
        "anchor_frac": (anchor / n) if n else 0.0,
        "level_pct": {k: ((v / n) if n else 0.0) for k, v in levels.items()},
        "plens": plens,
        "clens": clens,
        "format_only_frac": (format_only / n) if n else 0.0,
    }
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def audit_sft_dataset(
    train_path: str = "data/sft_dataset.jsonl",
    val_path: str = "data/sft_validation.jsonl",
) -> None:
    """Pre-flight audit of the SFT dataset. Halts (SystemExit) on violation.

    Runs 9 checks against ``train_path`` plus 4 parallel checks against
    ``val_path``. Designed to run in seconds on the CPU before any heavy
    ML deps are imported, so a broken dataset never reaches the GPU.

    Locked thresholds:
        Total rows:           train=3000, val=100
        JSON parse rate:      100%
        Non-empty correction: 65-75%
        Format anchor:        100%
        Curriculum L1/L2/L3:  35-45% / 45-55% / 7-15%
        Prompt length:        min>=800, median in [1100,1600], max<=2200
        Completion length:    min>=22, median in [22,80], max<=120
        Format-only target:   100%
        Validation parallel:  same thresholds applied to val split
    """
    # Locked thresholds (see docstring). Constants live here, not at module
    # scope, because they are contractually tied to THIS audit only.
    EXPECTED_TRAIN = 3000
    EXPECTED_VAL = 100
    NONEMPTY_LO, NONEMPTY_HI = 0.65, 0.75
    # Tightened to match quota-based per-level generation in
    # scripts/generate_sft_data.py, which produces the 40/50/10 split
    # exactly (no rejection-sampling drift).
    L1_LO, L1_HI = 0.38, 0.42
    L2_LO, L2_HI = 0.48, 0.52
    L3_LO, L3_HI = 0.08, 0.12
    PLEN_MIN, PLEN_MED_LO, PLEN_MED_HI, PLEN_MAX = 800, 1100, 1600, 2200
    # Targets are deliberately one-line format strings. The earlier
    # reasoning-prefix targets made the base model burn the full eval token
    # budget on analysis and never reach the required parseable answer line.
    CLEN_MIN, CLEN_MED_LO, CLEN_MED_HI, CLEN_MAX = 22, 22, 80, 120
    FORMAT_ONLY_MIN = 1.0

    train = _audit_file(Path(train_path))
    if "error" in train:
        # Missing train file is unrecoverable: bail before building checks.
        print(f"[audit] FATAL: {train['error']}")
        raise SystemExit(2)

    # ------------------------------- train checks ------------------------- #
    # Each check is (label, human-readable value, passed?) — collected first,
    # then rendered as one aligned banner below.
    checks: list[tuple[str, str, bool]] = []

    checks.append((
        "Total rows",
        f"{train['n']} (expected {EXPECTED_TRAIN})",
        train["n"] == EXPECTED_TRAIN,
    ))
    checks.append((
        "JSON parse rate",
        f"{train['parse_rate'] * 100:.1f}% ({train['parse_failures']} failures)",
        # Float-safe equality with 1.0 (rate is a ratio of ints).
        abs(train["parse_rate"] - 1.0) < 1e-9,
    ))
    checks.append((
        "Non-empty correction",
        f"{train['nonempty_frac'] * 100:.1f}% (target 65-75%)",
        NONEMPTY_LO <= train["nonempty_frac"] <= NONEMPTY_HI,
    ))
    checks.append((
        "Format anchor",
        f"{train['anchor_frac'] * 100:.1f}%",
        abs(train["anchor_frac"] - 1.0) < 1e-9,
    ))

    p1 = train["level_pct"]["L1"]
    p2 = train["level_pct"]["L2"]
    p3 = train["level_pct"]["L3"]
    p_unknown = train["level_pct"]["unknown"]
    checks.append((
        "Curriculum L1/L2/L3",
        f"{p1*100:.1f}/{p2*100:.1f}/{p3*100:.1f}% (unknown={p_unknown*100:.1f}%)",
        # NOTE: "unknown" rows are displayed but implicitly fail the check,
        # since they eat into the L1/L2/L3 fractions.
        (L1_LO <= p1 <= L1_HI
         and L2_LO <= p2 <= L2_HI
         and L3_LO <= p3 <= L3_HI),
    ))

    # Length stats guard against 0 rows (plens/clens empty -> report 0s).
    pmin = min(train["plens"]) if train["plens"] else 0
    pmed = int(statistics.median(train["plens"])) if train["plens"] else 0
    pmax = max(train["plens"]) if train["plens"] else 0
    checks.append((
        "Prompt length",
        f"min={pmin} median={pmed} max={pmax}",
        (pmin >= PLEN_MIN
         and PLEN_MED_LO <= pmed <= PLEN_MED_HI
         and pmax <= PLEN_MAX),
    ))

    cmin = min(train["clens"]) if train["clens"] else 0
    cmed = int(statistics.median(train["clens"])) if train["clens"] else 0
    cmax = max(train["clens"]) if train["clens"] else 0
    checks.append((
        "Completion length",
        f"min={cmin} median={cmed} max={cmax}",
        (cmin >= CLEN_MIN
         and CLEN_MED_LO <= cmed <= CLEN_MED_HI
         and cmax <= CLEN_MAX),
    ))

    checks.append((
        "Format-only completions",
        f"{train['format_only_frac'] * 100:.1f}% (target 100%)",
        abs(train["format_only_frac"] - FORMAT_ONLY_MIN) < 1e-9,
    ))

    # ------------------------------- val parallel ------------------------- #
    # The validation split gets the same thresholds, collapsed into a single
    # banner row (pass only if ALL of them hold).
    val = _audit_file(Path(val_path))
    if "error" in val:
        checks.append(("Validation parallel", val["error"], False))
    else:
        v1 = val["level_pct"]["L1"]
        v2 = val["level_pct"]["L2"]
        v3 = val["level_pct"]["L3"]
        val_pass = (
            val["n"] == EXPECTED_VAL
            and abs(val["parse_rate"] - 1.0) < 1e-9
            and NONEMPTY_LO <= val["nonempty_frac"] <= NONEMPTY_HI
            and abs(val["anchor_frac"] - 1.0) < 1e-9
            and abs(val["format_only_frac"] - FORMAT_ONLY_MIN) < 1e-9
            and L1_LO <= v1 <= L1_HI
            and L2_LO <= v2 <= L2_HI
            and L3_LO <= v3 <= L3_HI
        )
        val_summary = (
            f"rows={val['n']} parse={val['parse_rate']*100:.0f}% "
            f"nonempty={val['nonempty_frac']*100:.1f}% "
            f"anchor={val['anchor_frac']*100:.0f}% "
            f"format_only={val['format_only_frac']*100:.0f}% "
            f"L1/L2/L3={v1*100:.1f}/{v2*100:.1f}/{v3*100:.1f}%"
        )
        checks.append(("Validation parallel", val_summary, val_pass))

    # ------------------------------- print banner ------------------------- #
    print()
    print("DATASET AUDIT SUMMARY")
    print("=" * 21)
    # Column widths are computed from the widest label/value so the [✓]/[✗]
    # marks line up in one column.
    label_w = max(len(label) for label, _, _ in checks) + 1
    val_w = max(len(val_str) for _, val_str, _ in checks)
    for label, val_str, passed in checks:
        mark = "✓" if passed else "✗"  # pass / fail glyph for the banner
        print(f"{(label + ':').ljust(label_w + 1)} {val_str.ljust(val_w)} [{mark}]")

    all_passed = all(passed for _, _, passed in checks)
    print()
    if all_passed:
        print("ALL CHECKS PASSED — DATASET READY FOR TRAINING")
        print()
        return
    print("AUDIT FAILED — FIX DATASET BEFORE TRAINING")
    print()
    # Non-zero exit so shell pipelines stop before touching the GPU.
    raise SystemExit(2)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# --------------------------------------------------------------------------- #
|
| 354 |
+
# Validation-record loading #
|
| 355 |
+
# --------------------------------------------------------------------------- #
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def _load_jsonl(path: str) -> list[dict]:
|
| 359 |
+
rows: list[dict] = []
|
| 360 |
+
with open(path) as f:
|
| 361 |
+
for line in f:
|
| 362 |
+
rows.append(json.loads(line))
|
| 363 |
+
return rows
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def _load_train_dataset(path: str, tokenizer):
    """Load the SFT JSONL into a HuggingFace Dataset.

    Master spec (section 4): the chat template is applied via the
    tokenizer (``apply_chat_template``), NOT by manually inserting
    ``<|im_start|>`` markers - that way the same template works across
    Qwen2.5 / Qwen3 / etc. without surprises.
    """
    from datasets import Dataset

    def render(rec: dict) -> str:
        # Prefer the tokenizer's own chat template; keep a hand-written
        # Qwen-style fallback in case apply_chat_template ever misbehaves.
        chat = [
            {"role": "user", "content": rec["prompt"]},
            {"role": "assistant", "content": rec["completion"]},
        ]
        try:
            return tokenizer.apply_chat_template(chat, tokenize=False)
        except Exception:
            return (
                "<|im_start|>user\n"
                f"{rec['prompt']}\n<|im_end|>\n"
                "<|im_start|>assistant\n"
                f"{rec['completion']}<|im_end|>"
            )

    examples = [
        {
            "prompt": rec["prompt"],
            "completion": rec["completion"],
            "text": render(rec),
        }
        for rec in _load_jsonl(path)
    ]
    return Dataset.from_list(examples)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
# --------------------------------------------------------------------------- #
|
| 402 |
+
# Per-level physics caches (used by the validation callback) #
|
| 403 |
+
# --------------------------------------------------------------------------- #
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def _build_level_caches(needed_levels: set[str]) -> dict[str, dict]:
    """Pre-build circuit / matching / layout / supports per curriculum level."""
    # Heavy physics deps are imported lazily so the cheap dataset audit can
    # run without them installed.
    import pymatching

    from qubit_medic.config import level_by_name
    from qubit_medic.server.physics import (
        build_circuit, build_dem, extract_layout, per_round_x_z_counts,
    )
    from qubit_medic.server.rewards import compute_final_detector_supports

    caches: dict[str, dict] = {}
    for level_name in needed_levels:
        level = level_by_name(level_name)
        circ = build_circuit(level)
        error_model = build_dem(circ)
        matcher = pymatching.Matching.from_detector_error_model(error_model)
        layout = extract_layout(circ)
        x_count, z_count = per_round_x_z_counts(layout)
        det_supports = compute_final_detector_supports(layout)
        caches[level_name] = {
            "level": level,
            "circuit": circ,
            "dem": error_model,
            "matching": matcher,
            "layout": layout,
            "supports": det_supports,
            "num_x_stab": x_count,
            "num_z_stab": z_count,
        }
    return caches
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# --------------------------------------------------------------------------- #
|
| 439 |
+
# Validation callback (master spec, section 2 + section 3) #
|
| 440 |
+
# --------------------------------------------------------------------------- #
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def _build_validation_callback(
|
| 444 |
+
*,
|
| 445 |
+
model,
|
| 446 |
+
tokenizer,
|
| 447 |
+
val_records: list[dict],
|
| 448 |
+
eval_every: int,
|
| 449 |
+
eval_schedule: tuple[tuple[int, int, str], ...] | None,
|
| 450 |
+
print_sample_outputs: int,
|
| 451 |
+
output_dir: str,
|
| 452 |
+
max_new_tokens: int,
|
| 453 |
+
diversity_n_samples: int,
|
| 454 |
+
diversity_temperature: float,
|
| 455 |
+
early_stop_format: float,
|
| 456 |
+
early_stop_correction: float,
|
| 457 |
+
early_stop_diversity: int,
|
| 458 |
+
max_wall_seconds: float,
|
| 459 |
+
started_wall: float,
|
| 460 |
+
diversity_floor: int = 2,
|
| 461 |
+
diversity_run_len: int = 2,
|
| 462 |
+
):
|
| 463 |
+
"""Returns a ``TrainerCallback`` that:
|
| 464 |
+
* fires at every step in ``eval_schedule`` (or every ``eval_every``
|
| 465 |
+
steps if no schedule is given) with a per-step sample size,
|
| 466 |
+
* logs the spec metrics + new diagnostic metrics to W&B,
|
| 467 |
+
* prints the first ``print_sample_outputs`` raw model outputs to
|
| 468 |
+
stdout AND to ``{output_dir}/eval_samples_step{N}.txt`` so a
|
| 469 |
+
broken parser / generation drift can be diagnosed in seconds,
|
| 470 |
+
* stops training when the success criterion or hard caps fire.
|
| 471 |
+
|
| 472 |
+
Metric semantics changed in this revision:
|
| 473 |
+
* Parse failures NO LONGER default to "predict no errors". Failed
|
| 474 |
+
rows contribute logical_correction=0, hamming=0,
|
| 475 |
+
syndrome_consistency=0 to the aggregates. This stops trivial
|
| 476 |
+
syndromes (~95% at p=0.001) from inflating logical_correction_rate
|
| 477 |
+
to 0.98 while format_compliance sits at 0.01.
|
| 478 |
+
* New ``eval/parse_failure_rate`` = 1 - format_compliance, so a
|
| 479 |
+
broken parser is impossible to miss.
|
| 480 |
+
* New ``eval/format_compliance_strict`` reports the share of
|
| 481 |
+
outputs that hit the canonical ``X_ERRORS=[...] Z_ERRORS=[...]``
|
| 482 |
+
form (Reward 4 == 1.0). The looser ``eval/format_compliance``
|
| 483 |
+
reports the share where the model's answer was extractable at all.
|
| 484 |
+
"""
|
| 485 |
+
from transformers import TrainerCallback
|
| 486 |
+
|
| 487 |
+
from qubit_medic import wandb_utils
|
| 488 |
+
from qubit_medic.prompts import parse_action
|
| 489 |
+
from qubit_medic.server.physics import SyndromeSample
|
| 490 |
+
from qubit_medic.server.rewards import compute_all_rewards
|
| 491 |
+
|
| 492 |
+
if not val_records:
|
| 493 |
+
return None
|
| 494 |
+
|
| 495 |
+
# Pre-build per-level physics for fast scoring.
|
| 496 |
+
needed = {r["level"] for r in val_records}
|
| 497 |
+
level_caches = _build_level_caches(needed)
|
| 498 |
+
|
| 499 |
+
# Pick one stable prompt for the diversity probe (always the same record
|
| 500 |
+
# so the diversity number is comparable across checkpoints).
|
| 501 |
+
diversity_record = val_records[0]
|
| 502 |
+
diversity_messages = [{"role": "user", "content": diversity_record["prompt"]}]
|
| 503 |
+
|
| 504 |
+
# Index the schedule: step -> (sample_size, mode). Sample sizes are
|
| 505 |
+
# capped at len(val_records) so a small held-out set still works.
|
| 506 |
+
if eval_schedule:
|
| 507 |
+
schedule = {
|
| 508 |
+
step: (min(size, len(val_records)), mode)
|
| 509 |
+
for step, size, mode in eval_schedule
|
| 510 |
+
}
|
| 511 |
+
else:
|
| 512 |
+
schedule = {}
|
| 513 |
+
|
| 514 |
+
sample_dir = Path(output_dir)
|
| 515 |
+
sample_dir.mkdir(parents=True, exist_ok=True)
|
| 516 |
+
|
| 517 |
+
# 2026-04 (FIX 2) diversity-collapse rolling buffer. We track the
|
| 518 |
+
# last ``diversity_run_len`` full-eval ``output_diversity`` values
|
| 519 |
+
# and stop training when every entry is below ``diversity_floor``.
|
| 520 |
+
from collections import deque as _deque
|
| 521 |
+
recent_diversity = _deque(maxlen=diversity_run_len)
|
| 522 |
+
|
| 523 |
+
class _ValidationCallback(TrainerCallback):
|
| 524 |
+
# Stamp the most recent eval here so the on_train_end hook can avoid
|
| 525 |
+
# re-running if the eval step coincided with the final step.
|
| 526 |
+
last_eval_step: int = -1
|
| 527 |
+
|
| 528 |
+
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
|
| 529 |
+
now = time.time() - started_wall
|
| 530 |
+
if now >= max_wall_seconds:
|
| 531 |
+
print(f"[sft] wall-clock cap {max_wall_seconds:.0f}s hit at step "
|
| 532 |
+
f"{state.global_step}; stopping.")
|
| 533 |
+
control.should_training_stop = True
|
| 534 |
+
return
|
| 535 |
+
|
| 536 |
+
step = state.global_step
|
| 537 |
+
if step == 0:
|
| 538 |
+
return
|
| 539 |
+
if schedule:
|
| 540 |
+
if step not in schedule:
|
| 541 |
+
return
|
| 542 |
+
else:
|
| 543 |
+
if step % eval_every != 0:
|
| 544 |
+
return
|
| 545 |
+
self._run_eval(state, control)
|
| 546 |
+
|
| 547 |
+
def on_train_end(self, args, state, control, **kwargs): # noqa: D401
|
| 548 |
+
if state.global_step != self.last_eval_step:
|
| 549 |
+
self._run_eval(state, control, final=True)
|
| 550 |
+
|
| 551 |
+
# ------------------------------------------------------------------ #
|
| 552 |
+
# Core evaluation #
|
| 553 |
+
# ------------------------------------------------------------------ #
|
| 554 |
+
def _generate_greedy(self, messages: list[dict]) -> tuple[str, int]:
    """Greedy-decode one completion for *messages*.

    Returns ``(completion_text, n_generated_tokens)``. Generation failures
    are soft: an ``<gen-error: ...>`` marker string plus a zero token count
    is returned instead of raising.
    """
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
    )
    encoded = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[1]
    try:
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
        # Strip the echoed prompt; only the newly generated ids remain.
        new_ids = generated[0][prompt_len:]
        decoded = tokenizer.decode(new_ids, skip_special_tokens=True)
        return decoded, int(new_ids.shape[0])
    except Exception as exc:
        return f"<gen-error: {exc}>", 0
|
| 572 |
+
|
| 573 |
+
def _generate_sampled(self, messages: list[dict]) -> str:
    """Sample one completion (temperature + nucleus) for the diversity probe.

    Generation failures are soft and return an ``<gen-error: ...>`` marker.
    """
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
    )
    encoded = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    try:
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=diversity_temperature,
            top_p=0.95,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
        # Drop the echoed prompt and decode only the generated suffix.
        new_ids = generated[0][encoded["input_ids"].shape[1]:]
        return tokenizer.decode(new_ids, skip_special_tokens=True)
    except Exception as exc:
        return f"<gen-error: {exc}>"
|
| 594 |
+
|
| 595 |
+
def _run_eval(self, state, control, *, final: bool = False) -> None:
    """Run one validation pass at the current ``state.global_step``.

    Flow: toggle the model into inference mode, greedy-decode every
    selected validation prompt, accumulate format/physics metrics, run
    the sampled diversity probe (full evals only), log metrics + sample
    tables, apply the format-guard / success / diversity-collapse early
    stops, then restore training mode.

    Args:
        state: HF ``TrainerState``; only ``global_step`` is read.
        control: HF ``TrainerControl``; ``should_training_stop`` may be set.
        final: True when called from ``on_train_end``; forces a full eval
            (when the step has no schedule entry) and suppresses the
            non-final early-stop guards.
    """
    # Stamp the step so on_train_end can skip a duplicate final eval.
    self.last_eval_step = state.global_step
    # Prefer Unsloth's fast-inference toggle; fall back to plain eval().
    try:
        from unsloth import FastLanguageModel
        FastLanguageModel.for_inference(model)
    except Exception:
        model.eval()  # type: ignore[attr-defined]

    step = state.global_step
    # Resolve sample size + mode for this step.
    if final and step in schedule:
        sample_size, mode = schedule[step]
    elif final:
        sample_size, mode = len(val_records), "full"
    elif step in schedule:
        sample_size, mode = schedule[step]
    else:
        sample_size, mode = len(val_records), "full"

    # Deterministic slice so the same prompts are used across checkpoints.
    records = val_records[:sample_size]
    n = len(records)
    full_eval = (mode == "full")

    # Per-eval accumulators (reset every call).
    n_format = 0  # lenient parse_success
    n_format_strict = 0  # canonical "=" + "[]"
    n_logical = n_exact = 0
    sum_hamming = 0.0
    sum_syndrome = 0.0
    sum_length = 0
    rows: list[dict] = []  # tiny per-sample table destined for W&B
    sample_dump_lines: list[str] = [
        f"=== eval samples @ step {step} (mode={mode}, n={n}) ===",
    ]

    for idx, rec in enumerate(records):
        num_data = int(rec["num_data_qubits"])
        messages = [{"role": "user", "content": rec["prompt"]}]
        completion, n_tokens = self._generate_greedy(messages)
        sum_length += n_tokens

        parsed = parse_action(completion, num_data_qubits=num_data)
        fmt_ok = parsed.parse_success
        fmt_strict_ok = bool(parsed.strict_format)
        n_format += int(fmt_ok)
        n_format_strict += int(fmt_strict_ok)

        # Physics-heavy metrics only in "full" mode AND only when
        # the parse actually succeeded. A failed parse means the
        # model didn't produce a usable prediction; we score that
        # as a miss (0) for every downstream metric instead of
        # silently substituting an empty Pauli frame, which would
        # accidentally score correct on the ~95% of trivial
        # syndromes at p=0.001.
        logical_ok = False
        exact_ok = False
        hamming = 0.0
        syndrome = 0.0
        if full_eval and fmt_ok:
            cache = level_caches[rec["level"]]
            layout = cache["layout"]
            supports = cache["supports"]
            sample = SyndromeSample(
                syndrome_bits=list(map(int, rec["syndrome_bits"])),
                actual_observable_flip=int(rec["actual_observable_flip"]),
                pymatching_observable_pred=int(rec["pymatching_observable_pred"]),
                pymatching_x_errors=list(map(int, rec["true_x_errors"])),
                pymatching_z_errors=list(map(int, rec["true_z_errors"])),
            )
            breakdown = compute_all_rewards(parsed, sample, layout, supports)
            logical_ok = breakdown.logical_correction >= 0.5
            hamming = float(breakdown.hamming_overlap)
            syndrome = float(breakdown.syndrome_consistency)
            # Exact match is judged against the deduplicated, sorted
            # ground-truth qubit index lists.
            exact_ok = (
                parsed.x_errors == sorted(set(rec["true_x_errors"]))
                and parsed.z_errors == sorted(set(rec["true_z_errors"]))
            )

        n_logical += int(logical_ok)
        n_exact += int(exact_ok)
        sum_hamming += hamming
        sum_syndrome += syndrome

        if idx < print_sample_outputs:
            sample_dump_lines.append(
                f"\n--- sample {idx} (level={rec['level']}, "
                f"true_x={rec['true_x_errors']}, true_z={rec['true_z_errors']}, "
                f"fmt_ok={fmt_ok}, fmt_strict={fmt_strict_ok}, "
                f"n_tokens={n_tokens}) ---\n"
                f">>> RAW MODEL OUTPUT:\n{completion}\n"
                f">>> PARSED: x={parsed.x_errors} z={parsed.z_errors}"
            )

        if idx < 4:  # keep W&B table tiny
            rows.append({
                "step": step,
                "prompt": rec["prompt"][:600],
                "gold": rec["completion"],
                "model": completion[:300],
                "x_pred": ",".join(map(str, parsed.x_errors)),
                "z_pred": ",".join(map(str, parsed.z_errors)),
                "format_ok": fmt_ok,
                "format_strict_ok": fmt_strict_ok,
                "logical_ok": logical_ok,
                "exact_match": exact_ok,
                "hamming_overlap": hamming,
            })

    # ---------- print + persist raw output samples -------------- #
    sample_blob = "\n".join(sample_dump_lines)
    print(sample_blob)
    try:
        (sample_dir / f"eval_samples_step{step}.txt").write_text(sample_blob)
    except OSError as exc:
        # Persistence is best-effort; never fail the eval on disk errors.
        print(f"[sft][eval@{step}] could not persist sample outputs: {exc}")

    # ---------- diversity probe (skip in format_only mode) ------ #
    if full_eval:
        diverse_outputs: list[str] = []
        for _ in range(diversity_n_samples):
            diverse_outputs.append(self._generate_sampled(diversity_messages))
        # Number of distinct sampled completions for the one fixed probe prompt.
        output_diversity = len(set(diverse_outputs))
    else:
        output_diversity = 0  # not measured this step

    # ---------- aggregate + log to W&B ------------------------- #
    # max(1, n) guards the division when the record list is empty.
    metrics: dict[str, float | int] = {
        "eval/format_compliance": n_format / max(1, n),
        "eval/format_compliance_strict": n_format_strict / max(1, n),
        "eval/parse_failure_rate": 1.0 - (n_format / max(1, n)),
        "eval/output_length_mean": sum_length / max(1, n),
        "eval/episodes": n,
        "eval/mode_full": int(full_eval),
    }
    if full_eval:
        metrics.update({
            "eval/logical_correction_rate": n_logical / max(1, n),
            "eval/exact_match_pymatching": n_exact / max(1, n),
            "eval/hamming_overlap_mean": sum_hamming / max(1, n),
            "eval/syndrome_consistency": sum_syndrome / max(1, n),
            "eval/output_diversity": output_diversity,
        })
    print(f"[sft][eval@{step}] " + ", ".join(
        f"{k.split('/')[-1]}={v:.3f}" if isinstance(v, float) else f"{k.split('/')[-1]}={v}"
        for k, v in metrics.items()
    ))
    wandb_utils.log(metrics, step=step)
    wandb_utils.log_generation_table(
        rows, step=step,
        table_name=("sft/final_validation" if final else "sft/validation"),
        columns=["step", "prompt", "gold", "model", "x_pred", "z_pred",
                 "format_ok", "format_strict_ok", "logical_ok",
                 "exact_match", "hamming_overlap"],
    )

    # Fail fast on the known broken-SFT pattern: the model burns the
    # whole generation budget on prose and never emits the format line.
    # These thresholds mirror the runbook table in the issue analysis.
    format_floor_by_step = {5: 0.10, 15: 0.30, 30: 0.60, 50: 0.80}
    floor = format_floor_by_step.get(step)
    if (
        floor is not None
        and not final
        and metrics["eval/format_compliance"] < floor
    ):
        print(
            f"[sft] format guard tripped at step {step}: "
            f"format_compliance={metrics['eval/format_compliance']:.3f} "
            f"< {floor:.2f}. Stop and inspect raw outputs / data."
        )
        control.should_training_stop = True
        wandb_utils.update_summary({
            "sft/early_stop_reason": "format_guard",
            "sft/format_guard_step": step,
            "sft/format_guard_floor": floor,
        })

    # ---------- early stop checks ------------------------------ #
    # Only meaningful on full evals: logical_correction_rate and
    # output_diversity are not measured in format_only mode.
    if full_eval:
        success = (
            metrics["eval/format_compliance"] >= early_stop_format
            and metrics["eval/logical_correction_rate"] >= early_stop_correction
            and metrics["eval/output_diversity"] >= early_stop_diversity
        )
        if success and not final:
            print(f"[sft] success criterion hit at step {state.global_step}: "
                  f"format={metrics['eval/format_compliance']:.3f} >= {early_stop_format}, "
                  f"correction={metrics['eval/logical_correction_rate']:.3f} >= {early_stop_correction}, "
                  f"diversity={int(metrics['eval/output_diversity'])} >= {early_stop_diversity}; "
                  f"stopping.")
            control.should_training_stop = True
            wandb_utils.update_summary({"sft/early_stop_reason": "success_criterion"})

        # 2026-04 (FIX 2) diversity-collapse early stop. Pushed
        # AFTER the success check so a model that satisfies both
        # criteria still wins; only sustained low diversity
        # without convergence triggers the regression stop.
        recent_diversity.append(int(metrics["eval/output_diversity"]))
        if (
            not final
            and not control.should_training_stop
            and len(recent_diversity) >= diversity_run_len
            and all(d < diversity_floor for d in recent_diversity)
        ):
            history = list(recent_diversity)
            print(
                f"[sft] diversity collapse early stop at step "
                f"{state.global_step}: eval/output_diversity has "
                f"been < {diversity_floor} for {diversity_run_len} "
                f"consecutive full evals (history={history}). "
                f"Stopping. Bump --lora-dropout (e.g. 0.15) or "
                f"increase label smoothing and rerun."
            )
            control.should_training_stop = True
            wandb_utils.update_summary({
                "sft/early_stop_reason": "diversity_collapse",
                "sft/diversity_collapse_step": state.global_step,
                "sft/diversity_collapse_history": history,
            })

    # Restore training mode (mirror of the inference toggle above).
    try:
        from unsloth import FastLanguageModel
        FastLanguageModel.for_training(model)
    except Exception:
        model.train()  # type: ignore[attr-defined]
|
| 822 |
+
|
| 823 |
+
return _ValidationCallback()
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
# --------------------------------------------------------------------------- #
|
| 827 |
+
# Loss-divergence guard (failure mode early stop) #
|
| 828 |
+
# --------------------------------------------------------------------------- #
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
def _build_loss_guard_callback():
    """Build a TrainerCallback that halts training on a NaN/inf train loss."""
    import math

    from transformers import TrainerCallback

    class _LossGuard(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):  # noqa: D401
            # Nothing to check when no loss was reported this log event.
            reported = (logs or {}).get("loss")
            if reported is None:
                return
            try:
                loss_value = float(reported)
            except (TypeError, ValueError):
                return
            # isfinite is False exactly for NaN and +/-inf.
            if not math.isfinite(loss_value):
                print(f"[sft] loss={reported} is NaN/inf at step {state.global_step}; "
                      f"stopping training.")
                control.should_training_stop = True

    return _LossGuard()
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
# --------------------------------------------------------------------------- #
|
| 856 |
+
# Main #
|
| 857 |
+
# --------------------------------------------------------------------------- #
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
def main(argv: Iterable[str] = ()) -> int:
    """CLI entry point: SFT-warm a Qwen LoRA adapter on the decoder dataset.

    Pipeline: parse flags -> CPU dataset audit (fails fast) -> lazy GPU
    imports -> resolve CLI/config overrides -> W&B init -> version-combo
    preflight -> load model + LoRA -> build datasets and callbacks ->
    train -> save adapter (+ optional W&B artifact).

    Args:
        argv: command-line arguments WITHOUT the program name.

    Returns:
        Process exit code: 0 on success, 1 on environment errors.
        (The dataset audit may also abort earlier via ``SystemExit(2)``.)
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--dataset", type=str, default="data/sft_dataset.jsonl")
    parser.add_argument("--val-dataset", type=str,
                        default="data/sft_validation.jsonl",
                        help="held-out validation JSONL (rich records). "
                             "If missing, validation is skipped.")
    parser.add_argument("--output", type=str, default="checkpoints/sft_warmup")
    parser.add_argument("--model", type=str,
                        default=os.getenv("QUBIT_MEDIC_MODEL",
                                          "Qwen/Qwen2.5-3B-Instruct"))
    # None defaults below mean "fall back to the qubit_medic.config value".
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--grad-accum", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--max-seq-len", type=int, default=None)
    parser.add_argument("--max-steps", type=int, default=None,
                        help="hard cap on training steps (default 200)")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--lora-r", type=int, default=None)
    parser.add_argument("--lora-alpha", type=int, default=None)
    parser.add_argument("--lora-dropout", type=float, default=None)
    parser.add_argument("--report-to", type=str, default="wandb")
    parser.add_argument("--wandb-run-name", type=str, default=None)
    parser.add_argument("--wandb-group", type=str, default=None)
    parser.add_argument("--wandb-tags", type=str, nargs="*", default=("sft",))
    parser.add_argument("--wandb-notes", type=str, default=None)
    parser.add_argument("--eval-every", type=int, default=None,
                        help="run validation pass every N steps (legacy "
                             "fallback when --no-eval-schedule is set)")
    parser.add_argument("--no-eval-schedule", action="store_true",
                        help="disable the variable-cadence schedule "
                             "(SFT_EVAL_SCHEDULE) and fall back to "
                             "uniform --eval-every spacing")
    parser.add_argument("--print-sample-outputs", type=int,
                        default=None,
                        help="N raw model outputs to print + persist per eval "
                             "(defaults to SFT_PRINT_SAMPLE_OUTPUTS from config)")
    parser.add_argument("--diversity-samples", type=int, default=10,
                        help="N samples for the output_diversity probe")
    parser.add_argument("--diversity-temperature", type=float, default=0.7)
    parser.add_argument("--no-artifact", action="store_true")
    args = parser.parse_args(list(argv))

    # Pre-flight dataset audit. Runs in seconds on the CPU before any heavy
    # ML deps are imported, so a broken dataset never reaches the GPU. Halts
    # via SystemExit(2) on any threshold violation.
    audit_sft_dataset(args.dataset, args.val_dataset)

    # Heavy imports are lazy so this module is importable without GPU deps.
    try:
        from unsloth import FastLanguageModel
    except ImportError:
        print("ERROR: unsloth not installed. Run `pip install -r requirements-train.txt`",
              file=sys.stderr)
        return 1
    import torch
    from transformers import TrainingArguments
    from trl import SFTTrainer

    from qubit_medic import wandb_utils
    from qubit_medic.config import (
        LORA_ALPHA, LORA_DROPOUT, LORA_R, LORA_TARGET_MODULES, MODEL_ID,
        PRIMARY_SEED, SFT_BATCH_SIZE, SFT_DIVERSITY_COLLAPSE_RUN_LEN,
        SFT_EARLY_STOP_CORRECTION, SFT_EARLY_STOP_DIVERSITY,
        SFT_EARLY_STOP_FORMAT, SFT_EPOCHS, SFT_EVAL_EVERY, SFT_EVAL_SCHEDULE,
        SFT_GRAD_ACCUM, SFT_LABEL_SMOOTHING, SFT_LOG_EVERY, SFT_LR,
        SFT_LR_SCHEDULER, SFT_MAX_NEW_TOKENS, SFT_MAX_SEQ_LEN, SFT_MAX_STEPS,
        SFT_MAX_WALL_SECONDS, SFT_OPTIMIZER, SFT_PREFLIGHT_DIVERSITY_FLOOR,
        SFT_PRINT_SAMPLE_OUTPUTS, SFT_SAVE_EVERY, SFT_WARMUP_STEPS,
        SFT_WEIGHT_DECAY,
    )

    # CLI flags win over config constants; None means "use config default".
    epochs = args.epochs if args.epochs is not None else SFT_EPOCHS
    batch_size = args.batch_size if args.batch_size is not None else SFT_BATCH_SIZE
    grad_accum = args.grad_accum if args.grad_accum is not None else SFT_GRAD_ACCUM
    lr = args.lr if args.lr is not None else SFT_LR
    max_seq_len = args.max_seq_len if args.max_seq_len is not None else SFT_MAX_SEQ_LEN
    max_steps = args.max_steps if args.max_steps is not None else SFT_MAX_STEPS
    seed = args.seed if args.seed is not None else PRIMARY_SEED
    lora_r = args.lora_r if args.lora_r is not None else LORA_R
    lora_alpha = args.lora_alpha if args.lora_alpha is not None else LORA_ALPHA
    lora_dropout = args.lora_dropout if args.lora_dropout is not None else LORA_DROPOUT
    eval_every = args.eval_every if args.eval_every is not None else SFT_EVAL_EVERY
    print_sample_outputs = (
        args.print_sample_outputs
        if args.print_sample_outputs is not None
        else SFT_PRINT_SAMPLE_OUTPUTS
    )
    model_id = args.model if args.model else MODEL_ID

    # Seed every RNG source used downstream for reproducibility.
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # ---- W&B init (no-op if unavailable / disabled) -------------------- #
    report_to = wandb_utils.derive_report_to(args.report_to)
    run_name = args.wandb_run_name or wandb_utils.make_run_name("sft")
    wandb_utils.init_run(
        run_name=run_name,
        job_type="sft",
        tags=args.wandb_tags,
        notes=args.wandb_notes,
        group=args.wandb_group,
        extra_config={
            "cli": {
                "epochs": epochs,
                "batch_size": batch_size,
                "grad_accum": grad_accum,
                "effective_batch": batch_size * grad_accum,
                "lr": lr,
                "lr_scheduler": SFT_LR_SCHEDULER,
                "warmup_steps": SFT_WARMUP_STEPS,
                "weight_decay": SFT_WEIGHT_DECAY,
                "optimizer": SFT_OPTIMIZER,
                "max_seq_len": max_seq_len,
                "max_steps": max_steps,
                "lora_r": lora_r,
                "lora_alpha": lora_alpha,
                "lora_dropout": lora_dropout,
                "lora_target_modules": list(LORA_TARGET_MODULES),
                "dataset_path": args.dataset,
                "val_dataset_path": args.val_dataset,
                "model": model_id,
                "seed": seed,
                "report_to": report_to,
                "eval_every": eval_every,
                "save_every": SFT_SAVE_EVERY,
                "log_every": SFT_LOG_EVERY,
                "early_stop_format": SFT_EARLY_STOP_FORMAT,
                "early_stop_correction": SFT_EARLY_STOP_CORRECTION,
                "early_stop_diversity": SFT_EARLY_STOP_DIVERSITY,
                "max_wall_seconds": SFT_MAX_WALL_SECONDS,
            },
        },
    )

    # ---- Preflight: refuse to run with the known-bad Unsloth+TF combo #
    # (unsloth >= 2026.4.0) + (transformers < 4.55.0) silently misparses
    # the Qwen2.5-3B config: it instantiates a 7B-shaped model
    # (hidden=4096) and crashes when the 3B checkpoint (hidden=2048)
    # starts loading, with:
    #   RuntimeError: size mismatch for weight: copying a param with
    #   shape torch.Size([151936, 2048]) from checkpoint, the shape in
    #   current model is torch.Size([151936, 4096]).
    # We catch this BEFORE downloading >5GB of weights so the user does
    # not burn GPU minutes on a deterministic failure.
    import unsloth as _unsloth
    import transformers as _transformers

    def _parse_ver(v: str) -> tuple[int, ...]:
        # Drop any "+local" build suffix, then keep only the digits in each
        # dot-separated part so "4.57.2rc1" still compares numerically.
        out: list[int] = []
        for part in v.split("+", 1)[0].split("."):
            digits = "".join(ch for ch in part if ch.isdigit())
            out.append(int(digits) if digits else 0)
        return tuple(out)

    _u = _parse_ver(_unsloth.__version__)
    _t = _parse_ver(_transformers.__version__)
    _is_qwen25_3b = "qwen2.5-3b" in model_id.lower()
    _bad_combo = _u >= (2026, 4, 0) and _t < (4, 55, 0)
    if _is_qwen25_3b and _bad_combo:
        print(
            "[train_sft] FATAL: detected the unsloth/transformers combo that\n"
            f" silently misparses {model_id} into a 7B-shaped model.\n"
            f" Installed: unsloth=={_unsloth.__version__} "
            f"transformers=={_transformers.__version__}\n"
            " This exact pair produces the\n"
            " 'size mismatch ... [151936, 2048] vs [151936, 4096]'\n"
            " error during model load on Lightning AI / Colab.\n"
            " Fix: pin to a known-good combination, e.g.\n"
            " pip install --no-deps --force-reinstall \\\n"
            " unsloth==2025.11.1 unsloth_zoo==2026.4.9\n"
            " pip install --force-reinstall \\\n"
            " transformers==4.57.2 trl==0.20.0\n"
            " Or re-run scripts/run_lightning_pipeline.sh which\n"
            " pins these correctly and now hard-fails if the pins\n"
            " do not stick.",
            file=sys.stderr,
        )
        return 1

    # ---- Load model + datasets --------------------------------------- #
    print(f"loading {model_id} via Unsloth (4-bit NF4)")
    print(f" unsloth={_unsloth.__version__} "
          f"transformers={_transformers.__version__}")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_len,
        load_in_4bit=True,
        dtype=None,  # Unsloth auto-selects bf16/fp16
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=list(LORA_TARGET_MODULES),
        lora_dropout=lora_dropout,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=seed,
    )

    print(f"loading train dataset from {args.dataset}")
    train_dataset = _load_train_dataset(args.dataset, tokenizer)
    print(f" {len(train_dataset)} samples; first text len = "
          f"{len(train_dataset[0]['text'])}")

    # Validation is optional: without the file we train with no eval loop.
    val_records: list[dict] = []
    val_path = Path(args.val_dataset)
    if val_path.exists():
        val_records = _load_jsonl(args.val_dataset)
        print(f"loaded {len(val_records)} held-out validation records "
              f"from {args.val_dataset}")
    else:
        print(f"WARNING: no validation file at {args.val_dataset}; "
              f"running without eval / early-stop.")

    wandb_utils.log({
        "sft/train_dataset_size": len(train_dataset),
        "sft/val_dataset_size": len(val_records),
        "sft/first_text_len": len(train_dataset[0]["text"]),
    })

    # Dataset preview to W&B (sanity check the chat-template wrapping).
    wandb_utils.log_generation_table(
        [
            {"split": "train", "prompt": train_dataset[i]["prompt"][:600],
             "completion": train_dataset[i]["completion"]}
            for i in range(min(8, len(train_dataset)))
        ],
        step=0,
        table_name="sft/train_preview",
        columns=["split", "prompt", "completion"],
    )

    # ---- TrainingArguments (locked spec) ----------------------------- #
    Path(args.output).mkdir(parents=True, exist_ok=True)
    bf16_supported = (
        torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    )
    training_args = TrainingArguments(
        output_dir=args.output,
        num_train_epochs=epochs,
        max_steps=max_steps,  # hard cap; wins over epochs
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=lr,
        weight_decay=SFT_WEIGHT_DECAY,
        # Label smoothing was added in the 2026-04 SFT regularisation
        # rewrite (FIX 2) to combat mode collapse: spreading the loss
        # across non-target tokens makes the model less sharply rewarded
        # for memorising one canonical completion, which is what kept
        # output_diversity at 1 across every prior checkpoint.
        label_smoothing_factor=SFT_LABEL_SMOOTHING,
        warmup_steps=SFT_WARMUP_STEPS,
        lr_scheduler_type=SFT_LR_SCHEDULER,
        optim=SFT_OPTIMIZER,
        bf16=bf16_supported,
        fp16=torch.cuda.is_available() and not bf16_supported,
        logging_steps=SFT_LOG_EVERY,
        save_steps=SFT_SAVE_EVERY,
        save_total_limit=4,
        seed=seed,
        report_to=report_to,
        run_name=run_name,
    )

    # ---- Callbacks --------------------------------------------------- #
    started_wall = time.time()
    callbacks = [_build_loss_guard_callback()]
    eval_schedule = None if args.no_eval_schedule else SFT_EVAL_SCHEDULE
    val_cb = _build_validation_callback(
        model=model,
        tokenizer=tokenizer,
        val_records=val_records,
        eval_every=eval_every,
        eval_schedule=eval_schedule,
        print_sample_outputs=print_sample_outputs,
        output_dir=args.output,
        max_new_tokens=SFT_MAX_NEW_TOKENS,
        diversity_n_samples=args.diversity_samples,
        diversity_temperature=args.diversity_temperature,
        early_stop_format=SFT_EARLY_STOP_FORMAT,
        early_stop_correction=SFT_EARLY_STOP_CORRECTION,
        early_stop_diversity=SFT_EARLY_STOP_DIVERSITY,
        max_wall_seconds=SFT_MAX_WALL_SECONDS,
        started_wall=started_wall,
        # 2026-04 (FIX 2) diversity-collapse regression early stop.
        diversity_floor=SFT_PREFLIGHT_DIVERSITY_FLOOR,
        diversity_run_len=SFT_DIVERSITY_COLLAPSE_RUN_LEN,
    )
    if val_cb is not None:
        callbacks.append(val_cb)

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=max_seq_len,
        args=training_args,
        packing=False,
        callbacks=callbacks,
    )

    print(f"training (max_steps={max_steps}, eval_every={eval_every}) ...")
    train_result = trainer.train()
    elapsed = time.time() - started_wall
    # Push numeric final trainer metrics into the W&B run summary.
    metrics = getattr(train_result, "metrics", {}) or {}
    wandb_utils.update_summary({
        "sft/wall_seconds": elapsed,
        **{f"sft/final/{k}": v for k, v in metrics.items()
           if isinstance(v, (int, float))},
    })
    print(f"training finished in {elapsed:.1f}s "
          f"(max_wall_seconds={SFT_MAX_WALL_SECONDS:.0f})")

    print(f"saving adapters to {args.output}")
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)

    # ---- Upload adapter as W&B artifact ------------------------------ #
    if not args.no_artifact:
        wandb_utils.log_artifact(
            args.output,
            name=f"sft-adapter-{run_name}",
            artifact_type="model",
            description="SFT-warmed Qwen2.5-3B + LoRA adapter (Qubit-Medic).",
        )

    wandb_utils.finish_run()
    print("done")
    return 0
|
| 1195 |
+
|
| 1196 |
+
|
| 1197 |
+
if __name__ == "__main__":
    # CLI entry: argv[0] is the script path, so pass only the flags to main().
    sys.exit(main(sys.argv[1:]))
|