Spaces:
Paused
Paused
Commit ·
7db31d9
1
Parent(s): 1d82571
train_grpo: add TEST_ONLY mode to skip training and run eval+plots only
Browse files
When TEST_ONLY=1 (env var), Cell 11 short-circuits the rollout+SFT loop
so the rest of the notebook (AFTER eval, debug, plots, summary, adapter
save) runs end-to-end on a zero-init LoRA wrapper. Lets us validate the
eval+plot pipeline in ~5 min on a small GPU instead of waiting on a
multi-hour training run.
Made-with: Cursor
- training/train_grpo.ipynb +13 -3
training/train_grpo.ipynb
CHANGED
|
@@ -188,8 +188,12 @@
|
|
| 188 |
"print(\"OK: ast.parse (syntax check)\")\n",
|
| 189 |
"\n",
|
| 190 |
"SMOKE_MODE = bool(int(os.environ.get(\"SMOKE_MODE\", \"1\")))\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
"HINT_ALWAYS = True\n",
|
| 192 |
-
"print(f\"SMOKE_MODE={SMOKE_MODE} | HINT_ALWAYS={HINT_ALWAYS}\")"
|
| 193 |
],
|
| 194 |
"execution_count": null,
|
| 195 |
"outputs": []
|
|
@@ -837,8 +841,9 @@
|
|
| 837 |
"# Cell 11: Two-phase training loop (timing -> content)\n",
|
| 838 |
"# Each phase: 3 rounds (round 0 = hardcoded peak-hours hint, rounds 1-2 = normal prompt).\n",
|
| 839 |
"# Adapter persisted to ./checkpoints/phaseN_adapter/ between phases.\n",
|
| 840 |
-
"
|
| 841 |
-
"from
|
|
|
|
| 842 |
"\n",
|
| 843 |
"if SMOKE_MODE:\n",
|
| 844 |
" EPISODES_PER_ROUND = 4\n",
|
|
@@ -870,6 +875,11 @@
|
|
| 870 |
"t_start = time.time()\n",
|
| 871 |
"global_step = 0\n",
|
| 872 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
"for phase in PHASES:\n",
|
| 874 |
" phase_name = phase[\"name\"]\n",
|
| 875 |
" sys_prompt = phase[\"system\"]\n",
|
|
|
|
| 188 |
"print(\"OK: ast.parse (syntax check)\")\n",
|
| 189 |
"\n",
|
| 190 |
"SMOKE_MODE = bool(int(os.environ.get(\"SMOKE_MODE\", \"1\")))\n",
|
| 191 |
+
"# TEST_ONLY=1 skips the training loop entirely (load model -> eval -> plots).\n",
|
| 192 |
+
"# Use when you only want to verify the eval/plot pipeline on a fast small GPU.\n",
|
| 193 |
+
"# AFTER eval will then run on a zero-init LoRA wrapper (== base model behaviour).\n",
|
| 194 |
+
"TEST_ONLY = bool(int(os.environ.get(\"TEST_ONLY\", \"0\")))\n",
|
| 195 |
"HINT_ALWAYS = True\n",
|
| 196 |
+
"print(f\"SMOKE_MODE={SMOKE_MODE} | TEST_ONLY={TEST_ONLY} | HINT_ALWAYS={HINT_ALWAYS}\")"
|
| 197 |
],
|
| 198 |
"execution_count": null,
|
| 199 |
"outputs": []
|
|
|
|
| 841 |
"# Cell 11: Two-phase training loop (timing -> content)\n",
|
| 842 |
"# Each phase: 3 rounds (round 0 = hardcoded peak-hours hint, rounds 1-2 = normal prompt).\n",
|
| 843 |
"# Adapter persisted to ./checkpoints/phaseN_adapter/ between phases.\n",
|
| 844 |
+
"if not TEST_ONLY:\n",
|
| 845 |
+
" from trl import SFTTrainer, SFTConfig\n",
|
| 846 |
+
" from datasets import Dataset\n",
|
| 847 |
"\n",
|
| 848 |
"if SMOKE_MODE:\n",
|
| 849 |
" EPISODES_PER_ROUND = 4\n",
|
|
|
|
| 875 |
"t_start = time.time()\n",
|
| 876 |
"global_step = 0\n",
|
| 877 |
"\n",
|
| 878 |
+
"if TEST_ONLY:\n",
|
| 879 |
+
" print(\"TEST_ONLY=1 -> skipping training rollouts + SFT. AFTER eval will run on \"\n",
|
| 880 |
+
" \"zero-init LoRA (== base model behaviour). All plot/summary cells still execute.\")\n",
|
| 881 |
+
" PHASES = [] # empty so the for-loop below is a no-op\n",
|
| 882 |
+
"\n",
|
| 883 |
"for phase in PHASES:\n",
|
| 884 |
" phase_name = phase[\"name\"]\n",
|
| 885 |
" sys_prompt = phase[\"system\"]\n",
|