anuragredbus committed on
Commit
7db31d9
·
1 Parent(s): 1d82571

train_grpo: add TEST_ONLY mode to skip training and run eval+plots only

Browse files

When TEST_ONLY=1 (env var), Cell 11 short-circuits the rollout+SFT loop
so the rest of the notebook (AFTER eval, debug, plots, summary, adapter
save) runs end-to-end on a zero-init LoRA wrapper. Lets us validate the
eval+plot pipeline in ~5 min on a small GPU instead of waiting on a
multi-hour training run.

Made-with: Cursor

Files changed (1) hide show
  1. training/train_grpo.ipynb +13 -3
training/train_grpo.ipynb CHANGED
@@ -188,8 +188,12 @@
188
  "print(\"OK: ast.parse (syntax check)\")\n",
189
  "\n",
190
  "SMOKE_MODE = bool(int(os.environ.get(\"SMOKE_MODE\", \"1\")))\n",
 
 
 
 
191
  "HINT_ALWAYS = True\n",
192
- "print(f\"SMOKE_MODE={SMOKE_MODE} | HINT_ALWAYS={HINT_ALWAYS}\")"
193
  ],
194
  "execution_count": null,
195
  "outputs": []
@@ -837,8 +841,9 @@
837
  "# Cell 11: Two-phase training loop (timing -> content)\n",
838
  "# Each phase: 3 rounds (round 0 = hardcoded peak-hours hint, rounds 1-2 = normal prompt).\n",
839
  "# Adapter persisted to ./checkpoints/phaseN_adapter/ between phases.\n",
840
- "from trl import SFTTrainer, SFTConfig\n",
841
- "from datasets import Dataset\n",
 
842
  "\n",
843
  "if SMOKE_MODE:\n",
844
  " EPISODES_PER_ROUND = 4\n",
@@ -870,6 +875,11 @@
870
  "t_start = time.time()\n",
871
  "global_step = 0\n",
872
  "\n",
 
 
 
 
 
873
  "for phase in PHASES:\n",
874
  " phase_name = phase[\"name\"]\n",
875
  " sys_prompt = phase[\"system\"]\n",
 
188
  "print(\"OK: ast.parse (syntax check)\")\n",
189
  "\n",
190
  "SMOKE_MODE = bool(int(os.environ.get(\"SMOKE_MODE\", \"1\")))\n",
191
+ "# TEST_ONLY=1 skips the training loop entirely (load model -> eval -> plots).\n",
192
+ "# Use when you only want to verify the eval/plot pipeline on a fast small GPU.\n",
193
+ "# AFTER eval will then run on a zero-init LoRA wrapper (== base model behaviour).\n",
194
+ "TEST_ONLY = bool(int(os.environ.get(\"TEST_ONLY\", \"0\")))\n",
195
  "HINT_ALWAYS = True\n",
196
+ "print(f\"SMOKE_MODE={SMOKE_MODE} | TEST_ONLY={TEST_ONLY} | HINT_ALWAYS={HINT_ALWAYS}\")"
197
  ],
198
  "execution_count": null,
199
  "outputs": []
 
841
  "# Cell 11: Two-phase training loop (timing -> content)\n",
842
  "# Each phase: 3 rounds (round 0 = hardcoded peak-hours hint, rounds 1-2 = normal prompt).\n",
843
  "# Adapter persisted to ./checkpoints/phaseN_adapter/ between phases.\n",
844
+ "if not TEST_ONLY:\n",
845
+ " from trl import SFTTrainer, SFTConfig\n",
846
+ " from datasets import Dataset\n",
847
  "\n",
848
  "if SMOKE_MODE:\n",
849
  " EPISODES_PER_ROUND = 4\n",
 
875
  "t_start = time.time()\n",
876
  "global_step = 0\n",
877
  "\n",
878
+ "if TEST_ONLY:\n",
879
+ " print(\"TEST_ONLY=1 -> skipping training rollouts + SFT. AFTER eval will run on \"\n",
880
+ " \"zero-init LoRA (== base model behaviour). All plot/summary cells still execute.\")\n",
881
+ " PHASES = [] # empty so the for-loop below is a no-op\n",
882
+ "\n",
883
  "for phase in PHASES:\n",
884
  " phase_name = phase[\"name\"]\n",
885
  " sys_prompt = phase[\"system\"]\n",