vaibhav12332112312 commited on
Commit
271bf42
·
1 Parent(s): ad5d3b3

Match eval sampling to training, log all I/O, single round

Browse files

- Drop greedy at eval; always sample with temperature=1.0, top_p=0.95 so
eval reflects the same distribution training optimized
- Add IO logging: every prompt/response written to plots/io_log.jsonl
with phase tag (before / train_roundN / after) for inspection
- NUM_ROUNDS = 1 to iterate quickly while debugging the no-transfer issue

Made-with: Cursor

Files changed (1) hide show
  1. training/train_grpo.ipynb +78 -63
training/train_grpo.ipynb CHANGED
@@ -25,7 +25,9 @@
25
  },
26
  {
27
  "cell_type": "code",
 
28
  "metadata": {},
 
29
  "source": [
30
  "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
31
  "!pip install -q torch torchvision torchaudio\n",
@@ -34,13 +36,13 @@
34
  "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
35
  "!pip install -q \"openenv-core[core]>=0.2.2\"\n",
36
  "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
37
- ],
38
- "execution_count": null,
39
- "outputs": []
40
  },
41
  {
42
  "cell_type": "code",
 
43
  "metadata": {},
 
44
  "source": [
45
  "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
46
  "import os\n",
@@ -116,13 +118,13 @@
116
  "print(f\"Branch: {REPO_BRANCH}\")\n",
117
  "print(f\"Commit: {commit}\")\n",
118
  "print(f\"Plots dir: {PLOTS_DIR}\")"
119
- ],
120
- "execution_count": null,
121
- "outputs": []
122
  },
123
  {
124
  "cell_type": "code",
 
125
  "metadata": {},
 
126
  "source": [
127
  "# Cell 3: Imports (with runtime validation)\n",
128
  "import json, random, time, textwrap, copy, os, sys\n",
@@ -176,9 +178,7 @@
176
  "import ast\n",
177
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
178
  "print(\"OK: ast.parse (syntax check)\")"
179
- ],
180
- "execution_count": null,
181
- "outputs": []
182
  },
183
  {
184
  "cell_type": "markdown",
@@ -191,7 +191,9 @@
191
  },
192
  {
193
  "cell_type": "code",
 
194
  "metadata": {},
 
195
  "source": [
196
  "# Cell 4: Define heuristic agents + episode runner\n",
197
  "_rng = random.Random(42)\n",
@@ -267,13 +269,13 @@
267
  " \"rewards\": rewards, \"energies\": energies}\n",
268
  "\n",
269
  "print(\"Agents and episode runner defined.\")"
270
- ],
271
- "execution_count": null,
272
- "outputs": []
273
  },
274
  {
275
  "cell_type": "code",
 
276
  "metadata": {},
 
277
  "source": [
278
  "# Cell 5: Run baselines (safe)\n",
279
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
@@ -308,13 +310,13 @@
308
  "for name in BASELINE_AGENTS:\n",
309
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
310
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
311
- ],
312
- "execution_count": null,
313
- "outputs": []
314
  },
315
  {
316
  "cell_type": "code",
 
317
  "metadata": {},
 
318
  "source": [
319
  "# Cell 6: Baseline plots\n",
320
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
@@ -332,9 +334,7 @@
332
  "fig.tight_layout()\n",
333
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
334
  "plt.show()"
335
- ],
336
- "execution_count": null,
337
- "outputs": []
338
  },
339
  {
340
  "cell_type": "markdown",
@@ -347,7 +347,9 @@
347
  },
348
  {
349
  "cell_type": "code",
 
350
  "metadata": {},
 
351
  "source": [
352
  "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
353
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
@@ -391,13 +393,13 @@
391
  "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
392
  "if torch.cuda.is_available():\n",
393
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
394
- ],
395
- "execution_count": null,
396
- "outputs": []
397
  },
398
  {
399
  "cell_type": "code",
 
400
  "metadata": {},
 
401
  "source": [
402
  "# Cell 8: LLM agent functions\n",
403
  "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
@@ -547,18 +549,29 @@
547
  "\n",
548
  "def _batched_generate(mdl, tok, prompts, eval=False, max_new_tokens=512):\n",
549
  " enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
550
- " gen_kwargs = dict(max_new_tokens=max_new_tokens, pad_token_id=tok.pad_token_id)\n",
551
- " if eval:\n",
552
- " gen_kwargs.update(do_sample=False)\n",
553
- " else:\n",
554
- " gen_kwargs.update(do_sample=True, temperature=1.0, top_p=0.95)\n",
555
  " with torch.no_grad():\n",
556
  " out = mdl.generate(**enc, **gen_kwargs)\n",
557
  " resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
558
  " return resps, enc[\"input_ids\"].shape[1]\n",
559
  "\n",
560
  "\n",
561
- "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None):\n",
 
 
 
 
 
 
 
 
 
 
 
562
  " \"\"\"Run N episodes in parallel. tasks_seeds: list of (task, seed). One batched generate per day.\"\"\"\n",
563
  " sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
564
  " n = len(tasks_seeds)\n",
@@ -588,6 +601,9 @@
588
  " actions_by_idx[i] = parse_model_output(resps[j])\n",
589
  " pairs[i].append({\"prompt\": prompts[j], \"response\": resps[j],\n",
590
  " \"step\": len(rewards[i])})\n",
 
 
 
591
  "\n",
592
  " for i in range(n):\n",
593
  " if done_mask[i] or i not in actions_by_idx:\n",
@@ -626,9 +642,7 @@
626
  "\n",
627
  "\n",
628
  "print(\"LLM agent functions defined (batched).\")"
629
- ],
630
- "execution_count": null,
631
- "outputs": []
632
  },
633
  {
634
  "cell_type": "markdown",
@@ -641,23 +655,23 @@
641
  },
642
  {
643
  "cell_type": "code",
 
644
  "metadata": {},
 
645
  "source": [
646
  "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
647
  "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
648
  "print(\"=\" * 60)\n",
649
  "\n",
650
  "t0 = time.time()\n",
651
- "results = run_llm_episodes_batched(model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True)\n",
652
  "before_results = {r[\"task\"]: r for r in results}\n",
653
  "\n",
654
  "print(\"\\n\" + \"=\" * 60)\n",
655
  "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
656
  "for t in TASKS:\n",
657
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
658
- ],
659
- "execution_count": null,
660
- "outputs": []
661
  },
662
  {
663
  "cell_type": "markdown",
@@ -676,7 +690,9 @@
676
  },
677
  {
678
  "cell_type": "code",
 
679
  "metadata": {},
 
680
  "source": [
681
  "# Cell 10: Attach LoRA adapter\n",
682
  "from peft import LoraConfig, get_peft_model, TaskType\n",
@@ -690,19 +706,19 @@
690
  "model.enable_input_require_grads()\n",
691
  "peft_model = get_peft_model(model, lora_config)\n",
692
  "peft_model.print_trainable_parameters()"
693
- ],
694
- "execution_count": null,
695
- "outputs": []
696
  },
697
  {
698
  "cell_type": "code",
 
699
  "metadata": {},
 
700
  "source": [
701
  "# Cell 11: Training loop\n",
702
  "from trl import SFTTrainer, SFTConfig\n",
703
  "from datasets import Dataset\n",
704
  "\n",
705
- "NUM_ROUNDS = 4\n",
706
  "EPISODES_PER_ROUND = 6\n",
707
  "QUALITY_FLOOR = 0.40 # skip SFT for the round if no episode beats this grader score\n",
708
  "\n",
@@ -723,7 +739,8 @@
723
  " tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n",
724
  " t_roll = time.time()\n",
725
  " results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False,\n",
726
- " eval=False, system=SYSTEM_PROMPT_TRAIN)\n",
 
727
  " print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
728
  "\n",
729
  " all_pairs, episode_rewards, episode_graders = [], [], []\n",
@@ -800,9 +817,7 @@
800
  "elapsed = time.time() - t_start\n",
801
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
802
  "print(pd.DataFrame(training_log).to_string(index=False))"
803
- ],
804
- "execution_count": null,
805
- "outputs": []
806
  },
807
  {
808
  "cell_type": "markdown",
@@ -815,7 +830,9 @@
815
  },
816
  {
817
  "cell_type": "code",
 
818
  "metadata": {},
 
819
  "source": [
820
  "# Cell 12: Run trained model (batched)\n",
821
  "print(\"Running TRAINED model on all tasks (batched)...\")\n",
@@ -823,16 +840,14 @@
823
  "\n",
824
  "peft_model.eval()\n",
825
  "t0 = time.time()\n",
826
- "results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True)\n",
827
- "after_results = {r[\"task\"]: r for r in results}\n",
828
  "\n",
829
  "print(\"\\n\" + \"=\" * 60)\n",
830
  "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
831
  "for t in TASKS:\n",
832
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
833
- ],
834
- "execution_count": null,
835
- "outputs": []
836
  },
837
  {
838
  "cell_type": "markdown",
@@ -843,7 +858,9 @@
843
  },
844
  {
845
  "cell_type": "code",
 
846
  "metadata": {},
 
847
  "source": [
848
  "# Cell 13: Training curves\n",
849
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
@@ -865,13 +882,13 @@
865
  "fig.tight_layout()\n",
866
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
867
  "plt.show()"
868
- ],
869
- "execution_count": null,
870
- "outputs": []
871
  },
872
  {
873
  "cell_type": "code",
 
874
  "metadata": {},
 
875
  "source": [
876
  "# Cell 14: Before vs After\n",
877
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
@@ -901,13 +918,13 @@
901
  "fig.tight_layout()\n",
902
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
903
  "plt.show()"
904
- ],
905
- "execution_count": null,
906
- "outputs": []
907
  },
908
  {
909
  "cell_type": "code",
 
910
  "metadata": {},
 
911
  "source": [
912
  "# Cell 15: Trajectory comparison\n",
913
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
@@ -931,9 +948,7 @@
931
  "fig.tight_layout()\n",
932
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
933
  "plt.show()"
934
- ],
935
- "execution_count": null,
936
- "outputs": []
937
  },
938
  {
939
  "cell_type": "markdown",
@@ -944,7 +959,9 @@
944
  },
945
  {
946
  "cell_type": "code",
 
947
  "metadata": {},
 
948
  "source": [
949
  "# Cell 16: Final summary\n",
950
  "print(\"=\" * 67)\n",
@@ -981,13 +998,13 @@
981
  "\n",
982
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
983
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
984
- ],
985
- "execution_count": null,
986
- "outputs": []
987
  },
988
  {
989
  "cell_type": "code",
 
990
  "metadata": {},
 
991
  "source": [
992
  "# Cell 17: Save adapter\n",
993
  "save_path = \"./viraltest_trained_adapter\"\n",
@@ -995,9 +1012,7 @@
995
  "tokenizer.save_pretrained(save_path)\n",
996
  "print(f\"LoRA adapter saved to {save_path}\")\n",
997
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
998
- ],
999
- "execution_count": null,
1000
- "outputs": []
1001
  }
1002
  ],
1003
  "metadata": {
@@ -1023,4 +1038,4 @@
1023
  },
1024
  "nbformat": 4,
1025
  "nbformat_minor": 4
1026
- }
 
25
  },
26
  {
27
  "cell_type": "code",
28
+ "execution_count": null,
29
  "metadata": {},
30
+ "outputs": [],
31
  "source": [
32
  "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
33
  "!pip install -q torch torchvision torchaudio\n",
 
36
  "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
37
  "!pip install -q \"openenv-core[core]>=0.2.2\"\n",
38
  "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
39
+ ]
 
 
40
  },
41
  {
42
  "cell_type": "code",
43
+ "execution_count": null,
44
  "metadata": {},
45
+ "outputs": [],
46
  "source": [
47
  "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
48
  "import os\n",
 
118
  "print(f\"Branch: {REPO_BRANCH}\")\n",
119
  "print(f\"Commit: {commit}\")\n",
120
  "print(f\"Plots dir: {PLOTS_DIR}\")"
121
+ ]
 
 
122
  },
123
  {
124
  "cell_type": "code",
125
+ "execution_count": null,
126
  "metadata": {},
127
+ "outputs": [],
128
  "source": [
129
  "# Cell 3: Imports (with runtime validation)\n",
130
  "import json, random, time, textwrap, copy, os, sys\n",
 
178
  "import ast\n",
179
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
180
  "print(\"OK: ast.parse (syntax check)\")"
181
+ ]
 
 
182
  },
183
  {
184
  "cell_type": "markdown",
 
191
  },
192
  {
193
  "cell_type": "code",
194
+ "execution_count": null,
195
  "metadata": {},
196
+ "outputs": [],
197
  "source": [
198
  "# Cell 4: Define heuristic agents + episode runner\n",
199
  "_rng = random.Random(42)\n",
 
269
  " \"rewards\": rewards, \"energies\": energies}\n",
270
  "\n",
271
  "print(\"Agents and episode runner defined.\")"
272
+ ]
 
 
273
  },
274
  {
275
  "cell_type": "code",
276
+ "execution_count": null,
277
  "metadata": {},
278
+ "outputs": [],
279
  "source": [
280
  "# Cell 5: Run baselines (safe)\n",
281
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
 
310
  "for name in BASELINE_AGENTS:\n",
311
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
312
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
313
+ ]
 
 
314
  },
315
  {
316
  "cell_type": "code",
317
+ "execution_count": null,
318
  "metadata": {},
319
+ "outputs": [],
320
  "source": [
321
  "# Cell 6: Baseline plots\n",
322
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
 
334
  "fig.tight_layout()\n",
335
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
336
  "plt.show()"
337
+ ]
 
 
338
  },
339
  {
340
  "cell_type": "markdown",
 
347
  },
348
  {
349
  "cell_type": "code",
350
+ "execution_count": null,
351
  "metadata": {},
352
+ "outputs": [],
353
  "source": [
354
  "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
355
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
 
393
  "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
394
  "if torch.cuda.is_available():\n",
395
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
396
+ ]
 
 
397
  },
398
  {
399
  "cell_type": "code",
400
+ "execution_count": null,
401
  "metadata": {},
402
+ "outputs": [],
403
  "source": [
404
  "# Cell 8: LLM agent functions\n",
405
  "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
 
549
  "\n",
550
  "def _batched_generate(mdl, tok, prompts, eval=False, max_new_tokens=512):\n",
551
  " enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
552
+ " gen_kwargs = dict(\n",
553
+ " max_new_tokens=max_new_tokens,\n",
554
+ " pad_token_id=tok.pad_token_id,\n",
555
+ " do_sample=True, temperature=1.0, top_p=0.95,\n",
556
+ " )\n",
557
  " with torch.no_grad():\n",
558
  " out = mdl.generate(**enc, **gen_kwargs)\n",
559
  " resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
560
  " return resps, enc[\"input_ids\"].shape[1]\n",
561
  "\n",
562
  "\n",
563
+ "IO_LOG_PATH = os.path.join(PLOTS_DIR, \"io_log.jsonl\")\n",
564
+ "open(IO_LOG_PATH, \"w\").close() # truncate\n",
565
+ "\n",
566
+ "\n",
567
+ "def _log_io(tag, ep_idx, day, task, seed, prompt, response):\n",
568
+ " rec = {\"tag\": tag, \"ep\": ep_idx, \"day\": day, \"task\": task, \"seed\": seed,\n",
569
+ " \"prompt\": prompt, \"response\": response}\n",
570
+ " with open(IO_LOG_PATH, \"a\") as f:\n",
571
+ " f.write(json.dumps(rec) + \"\\n\")\n",
572
+ "\n",
573
+ "\n",
574
+ "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None, log_tag=None):\n",
575
  " \"\"\"Run N episodes in parallel. tasks_seeds: list of (task, seed). One batched generate per day.\"\"\"\n",
576
  " sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
577
  " n = len(tasks_seeds)\n",
 
601
  " actions_by_idx[i] = parse_model_output(resps[j])\n",
602
  " pairs[i].append({\"prompt\": prompts[j], \"response\": resps[j],\n",
603
  " \"step\": len(rewards[i])})\n",
604
+ " if log_tag is not None:\n",
605
+ " t, s = tasks_seeds[i]\n",
606
+ " _log_io(log_tag, i, day, t, s, prompts[j], resps[j])\n",
607
  "\n",
608
  " for i in range(n):\n",
609
  " if done_mask[i] or i not in actions_by_idx:\n",
 
642
  "\n",
643
  "\n",
644
  "print(\"LLM agent functions defined (batched).\")"
645
+ ]
 
 
646
  },
647
  {
648
  "cell_type": "markdown",
 
655
  },
656
  {
657
  "cell_type": "code",
658
+ "execution_count": null,
659
  "metadata": {},
660
+ "outputs": [],
661
  "source": [
662
  "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
663
  "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
664
  "print(\"=\" * 60)\n",
665
  "\n",
666
  "t0 = time.time()\n",
667
+ "results = run_llm_episodes_batched(model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True, log_tag=\"before\")\n",
668
  "before_results = {r[\"task\"]: r for r in results}\n",
669
  "\n",
670
  "print(\"\\n\" + \"=\" * 60)\n",
671
  "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
672
  "for t in TASKS:\n",
673
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
674
+ ]
 
 
675
  },
676
  {
677
  "cell_type": "markdown",
 
690
  },
691
  {
692
  "cell_type": "code",
693
+ "execution_count": null,
694
  "metadata": {},
695
+ "outputs": [],
696
  "source": [
697
  "# Cell 10: Attach LoRA adapter\n",
698
  "from peft import LoraConfig, get_peft_model, TaskType\n",
 
706
  "model.enable_input_require_grads()\n",
707
  "peft_model = get_peft_model(model, lora_config)\n",
708
  "peft_model.print_trainable_parameters()"
709
+ ]
 
 
710
  },
711
  {
712
  "cell_type": "code",
713
+ "execution_count": null,
714
  "metadata": {},
715
+ "outputs": [],
716
  "source": [
717
  "# Cell 11: Training loop\n",
718
  "from trl import SFTTrainer, SFTConfig\n",
719
  "from datasets import Dataset\n",
720
  "\n",
721
+ "NUM_ROUNDS = 1\n",
722
  "EPISODES_PER_ROUND = 6\n",
723
  "QUALITY_FLOOR = 0.40 # skip SFT for the round if no episode beats this grader score\n",
724
  "\n",
 
739
  " tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n",
740
  " t_roll = time.time()\n",
741
  " results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False,\n",
742
+ " eval=False, system=SYSTEM_PROMPT_TRAIN,\n",
743
+ " log_tag=f\"train_round{round_idx}\")\n",
744
  " print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
745
  "\n",
746
  " all_pairs, episode_rewards, episode_graders = [], [], []\n",
 
817
  "elapsed = time.time() - t_start\n",
818
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
819
  "print(pd.DataFrame(training_log).to_string(index=False))"
820
+ ]
 
 
821
  },
822
  {
823
  "cell_type": "markdown",
 
830
  },
831
  {
832
  "cell_type": "code",
833
+ "execution_count": null,
834
  "metadata": {},
835
+ "outputs": [],
836
  "source": [
837
  "# Cell 12: Run trained model (batched)\n",
838
  "print(\"Running TRAINED model on all tasks (batched)...\")\n",
 
840
  "\n",
841
  "peft_model.eval()\n",
842
  "t0 = time.time()\n",
843
+ "results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True, log_tag=\"after\")\n",
844
+ "after_results = {r[\"task\"]: r for r in results}\n",
845
  "\n",
846
  "print(\"\\n\" + \"=\" * 60)\n",
847
  "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
848
  "for t in TASKS:\n",
849
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
850
+ ]
 
 
851
  },
852
  {
853
  "cell_type": "markdown",
 
858
  },
859
  {
860
  "cell_type": "code",
861
+ "execution_count": null,
862
  "metadata": {},
863
+ "outputs": [],
864
  "source": [
865
  "# Cell 13: Training curves\n",
866
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
 
882
  "fig.tight_layout()\n",
883
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
884
  "plt.show()"
885
+ ]
 
 
886
  },
887
  {
888
  "cell_type": "code",
889
+ "execution_count": null,
890
  "metadata": {},
891
+ "outputs": [],
892
  "source": [
893
  "# Cell 14: Before vs After\n",
894
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
 
918
  "fig.tight_layout()\n",
919
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
920
  "plt.show()"
921
+ ]
 
 
922
  },
923
  {
924
  "cell_type": "code",
925
+ "execution_count": null,
926
  "metadata": {},
927
+ "outputs": [],
928
  "source": [
929
  "# Cell 15: Trajectory comparison\n",
930
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
 
948
  "fig.tight_layout()\n",
949
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
950
  "plt.show()"
951
+ ]
 
 
952
  },
953
  {
954
  "cell_type": "markdown",
 
959
  },
960
  {
961
  "cell_type": "code",
962
+ "execution_count": null,
963
  "metadata": {},
964
+ "outputs": [],
965
  "source": [
966
  "# Cell 16: Final summary\n",
967
  "print(\"=\" * 67)\n",
 
998
  "\n",
999
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
1000
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
1001
+ ]
 
 
1002
  },
1003
  {
1004
  "cell_type": "code",
1005
+ "execution_count": null,
1006
  "metadata": {},
1007
+ "outputs": [],
1008
  "source": [
1009
  "# Cell 17: Save adapter\n",
1010
  "save_path = \"./viraltest_trained_adapter\"\n",
 
1012
  "tokenizer.save_pretrained(save_path)\n",
1013
  "print(f\"LoRA adapter saved to {save_path}\")\n",
1014
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
1015
+ ]
 
 
1016
  }
1017
  ],
1018
  "metadata": {
 
1038
  },
1039
  "nbformat": 4,
1040
  "nbformat_minor": 4
1041
+ }