vaibhav12332112312 commited on
Commit
e82b235
·
1 Parent(s): 3419724

Strip heatmap leak from prompt; let model discover peak hours via tools

Browse files

- Remove explicit Mon..Sun peak-hour table from system prompt
- Drop "today's peak hours=..." from format_obs
- Compress two-phase + posting rules to essentials
- Forces model to learn timing via query_audience/query_trends

Made-with: Cursor

Files changed (1) hide show
  1. training/train_grpo.ipynb +60 -78
training/train_grpo.ipynb CHANGED
@@ -25,7 +25,9 @@
25
  },
26
  {
27
  "cell_type": "code",
 
28
  "metadata": {},
 
29
  "source": [
30
  "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
31
  "!pip install -q torch torchvision torchaudio\n",
@@ -34,13 +36,13 @@
34
  "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
35
  "!pip install -q \"openenv-core[core]>=0.2.2\"\n",
36
  "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
37
- ],
38
- "execution_count": null,
39
- "outputs": []
40
  },
41
  {
42
  "cell_type": "code",
 
43
  "metadata": {},
 
44
  "source": [
45
  "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
46
  "import os\n",
@@ -116,13 +118,13 @@
116
  "print(f\"Branch: {REPO_BRANCH}\")\n",
117
  "print(f\"Commit: {commit}\")\n",
118
  "print(f\"Plots dir: {PLOTS_DIR}\")"
119
- ],
120
- "execution_count": null,
121
- "outputs": []
122
  },
123
  {
124
  "cell_type": "code",
 
125
  "metadata": {},
 
126
  "source": [
127
  "# Cell 3: Imports (with runtime validation)\n",
128
  "import json, random, time, textwrap, copy, os, sys\n",
@@ -176,9 +178,7 @@
176
  "import ast\n",
177
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
178
  "print(\"OK: ast.parse (syntax check)\")"
179
- ],
180
- "execution_count": null,
181
- "outputs": []
182
  },
183
  {
184
  "cell_type": "markdown",
@@ -191,7 +191,9 @@
191
  },
192
  {
193
  "cell_type": "code",
 
194
  "metadata": {},
 
195
  "source": [
196
  "# Cell 4: Define heuristic agents + episode runner\n",
197
  "_rng = random.Random(42)\n",
@@ -267,13 +269,13 @@
267
  " \"rewards\": rewards, \"energies\": energies}\n",
268
  "\n",
269
  "print(\"Agents and episode runner defined.\")"
270
- ],
271
- "execution_count": null,
272
- "outputs": []
273
  },
274
  {
275
  "cell_type": "code",
 
276
  "metadata": {},
 
277
  "source": [
278
  "# Cell 5: Run baselines (safe)\n",
279
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
@@ -308,13 +310,13 @@
308
  "for name in BASELINE_AGENTS:\n",
309
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
310
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
311
- ],
312
- "execution_count": null,
313
- "outputs": []
314
  },
315
  {
316
  "cell_type": "code",
 
317
  "metadata": {},
 
318
  "source": [
319
  "# Cell 6: Baseline plots\n",
320
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
@@ -332,9 +334,7 @@
332
  "fig.tight_layout()\n",
333
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
334
  "plt.show()"
335
- ],
336
- "execution_count": null,
337
- "outputs": []
338
  },
339
  {
340
  "cell_type": "markdown",
@@ -347,7 +347,9 @@
347
  },
348
  {
349
  "cell_type": "code",
 
350
  "metadata": {},
 
351
  "source": [
352
  "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
353
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
@@ -391,13 +393,13 @@
391
  "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
392
  "if torch.cuda.is_available():\n",
393
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
394
- ],
395
- "execution_count": null,
396
- "outputs": []
397
  },
398
  {
399
  "cell_type": "code",
 
400
  "metadata": {},
 
401
  "source": [
402
  "# Cell 8: LLM agent functions\n",
403
  "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
@@ -439,38 +441,21 @@
439
  "- topic: free-form string\n",
440
  "- empty scheduled_actions = full day rest\n",
441
  "\n",
442
- "POSTING RULES (critical — only `post` actions earn engagement reward):\n",
443
- "- EVERY active day MUST schedule at least 2 `post` actions (max 3). `create_content`\n",
444
- " alone gives 0 reward — content stays in queue. Mix in 0-1 `create_content` only\n",
445
- " if the queue is empty.\n",
446
- "- Schedule posts at HEATMAP PEAK HOURS (Buffer/Sprout-derived):\n",
447
- " Mon peaks 14, 18, 19 Tue peaks 14, 15, 19\n",
448
- " Wed peaks 13, 14, 18 Thu peaks 12, 13, 19\n",
449
- " Fri peaks 12, 13, 22 Sat peaks 21, 22, 13\n",
450
- " Sun peaks 21, 22, 11\n",
451
- "- Vary `intent` across the day; rotate `content_type` to avoid fatigue.\n",
452
- "- Reuse strong tags from the Recent-days summary (those that earned reward).\"\"\")\n",
453
  "\n",
454
  "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
455
  "\n",
456
- "TWO-PHASE FLOW (each day has two turns — same observation, two responses):\n",
457
- "PHASE A — DISCOVERY: respond with {\"tool_calls\": [...]} only. Tools cost nothing,\n",
458
- " call as many query_* / predict_engagement / draft_review as useful. Their results\n",
459
- " are dispatched immediately and shown to you in PHASE B of the SAME day.\n",
460
- "PHASE B — PLANNING: respond with {\"scheduled_actions\": [...], \"notes\": \"...\"}\n",
461
- " using the freshly returned Tool results.\n",
462
- "Audience peak hours, segment affinities, trends, competitor schedules are NOT in\n",
463
- "the observation — discover them in PHASE A. Useful PHASE-A starter set:\n",
464
- " query_trends(niche), query_audience(segment_id), query_creator_pool(),\n",
465
- " query_competitor(competitor_id, window_days), and on later days also\n",
466
- " predict_engagement(scheduled_actions=[...candidate plan...]).\"\"\")\n",
467
  "SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
468
  "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
469
  "\n",
470
  "\n",
471
  "_DAY_NAMES = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
472
- "_PEAK_HOURS = {0:[14,18,19], 1:[14,15,19], 2:[13,14,18], 3:[12,13,19],\n",
473
- " 4:[12,13,22], 5:[21,22,13], 6:[21,22,11]}\n",
474
  "\n",
475
  "\n",
476
  "def _format_history(history, k=3):\n",
@@ -489,7 +474,6 @@
489
  "\n",
490
  "def format_obs(obs, history=None):\n",
491
  " day_name = _DAY_NAMES[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
492
- " peaks = _PEAK_HOURS.get(obs.day_of_week, [12, 18, 20])\n",
493
  " signals_str = \"\"\n",
494
  " signals = getattr(obs, \"engagement_signals\", None)\n",
495
  " if signals:\n",
@@ -502,7 +486,7 @@
502
  " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
503
  " if not tool_str:\n",
504
  " tool_str = \" (none — call query_* tools to discover)\\n\"\n",
505
- " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed} | today's peak hours={peaks}\\n\"\n",
506
  " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
507
  " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
508
  " f\"{signals_str}\"\n",
@@ -732,9 +716,7 @@
732
  "\n",
733
  "\n",
734
  "print(\"LLM agent functions defined (batched).\")"
735
- ],
736
- "execution_count": null,
737
- "outputs": []
738
  },
739
  {
740
  "cell_type": "markdown",
@@ -747,7 +729,9 @@
747
  },
748
  {
749
  "cell_type": "code",
 
750
  "metadata": {},
 
751
  "source": [
752
  "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
753
  "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
@@ -761,9 +745,7 @@
761
  "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
762
  "for t in TASKS:\n",
763
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
764
- ],
765
- "execution_count": null,
766
- "outputs": []
767
  },
768
  {
769
  "cell_type": "markdown",
@@ -782,7 +764,9 @@
782
  },
783
  {
784
  "cell_type": "code",
 
785
  "metadata": {},
 
786
  "source": [
787
  "# Cell 10: Attach LoRA adapter\n",
788
  "from peft import LoraConfig, get_peft_model, TaskType\n",
@@ -796,13 +780,13 @@
796
  "model.enable_input_require_grads()\n",
797
  "peft_model = get_peft_model(model, lora_config)\n",
798
  "peft_model.print_trainable_parameters()"
799
- ],
800
- "execution_count": null,
801
- "outputs": []
802
  },
803
  {
804
  "cell_type": "code",
 
805
  "metadata": {},
 
806
  "source": [
807
  "# Cell 11: Training loop\n",
808
  "from trl import SFTTrainer, SFTConfig\n",
@@ -907,9 +891,7 @@
907
  "elapsed = time.time() - t_start\n",
908
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
909
  "print(pd.DataFrame(training_log).to_string(index=False))"
910
- ],
911
- "execution_count": null,
912
- "outputs": []
913
  },
914
  {
915
  "cell_type": "markdown",
@@ -922,7 +904,9 @@
922
  },
923
  {
924
  "cell_type": "code",
 
925
  "metadata": {},
 
926
  "source": [
927
  "# Cell 12: Run trained model (batched)\n",
928
  "print(\"Running TRAINED model on all tasks (batched)...\")\n",
@@ -937,9 +921,7 @@
937
  "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
938
  "for t in TASKS:\n",
939
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
940
- ],
941
- "execution_count": null,
942
- "outputs": []
943
  },
944
  {
945
  "cell_type": "markdown",
@@ -950,7 +932,9 @@
950
  },
951
  {
952
  "cell_type": "code",
 
953
  "metadata": {},
 
954
  "source": [
955
  "# Cell 13: Training curves\n",
956
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
@@ -972,13 +956,13 @@
972
  "fig.tight_layout()\n",
973
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
974
  "plt.show()"
975
- ],
976
- "execution_count": null,
977
- "outputs": []
978
  },
979
  {
980
  "cell_type": "code",
 
981
  "metadata": {},
 
982
  "source": [
983
  "# Cell 14: Before vs After\n",
984
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
@@ -1008,13 +992,13 @@
1008
  "fig.tight_layout()\n",
1009
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
1010
  "plt.show()"
1011
- ],
1012
- "execution_count": null,
1013
- "outputs": []
1014
  },
1015
  {
1016
  "cell_type": "code",
 
1017
  "metadata": {},
 
1018
  "source": [
1019
  "# Cell 15: Trajectory comparison\n",
1020
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
@@ -1038,9 +1022,7 @@
1038
  "fig.tight_layout()\n",
1039
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
1040
  "plt.show()"
1041
- ],
1042
- "execution_count": null,
1043
- "outputs": []
1044
  },
1045
  {
1046
  "cell_type": "markdown",
@@ -1051,7 +1033,9 @@
1051
  },
1052
  {
1053
  "cell_type": "code",
 
1054
  "metadata": {},
 
1055
  "source": [
1056
  "# Cell 16: Final summary\n",
1057
  "print(\"=\" * 67)\n",
@@ -1088,13 +1072,13 @@
1088
  "\n",
1089
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
1090
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
1091
- ],
1092
- "execution_count": null,
1093
- "outputs": []
1094
  },
1095
  {
1096
  "cell_type": "code",
 
1097
  "metadata": {},
 
1098
  "source": [
1099
  "# Cell 17: Save adapter\n",
1100
  "save_path = \"./viraltest_trained_adapter\"\n",
@@ -1102,9 +1086,7 @@
1102
  "tokenizer.save_pretrained(save_path)\n",
1103
  "print(f\"LoRA adapter saved to {save_path}\")\n",
1104
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
1105
- ],
1106
- "execution_count": null,
1107
- "outputs": []
1108
  }
1109
  ],
1110
  "metadata": {
@@ -1130,4 +1112,4 @@
1130
  },
1131
  "nbformat": 4,
1132
  "nbformat_minor": 4
1133
- }
 
25
  },
26
  {
27
  "cell_type": "code",
28
+ "execution_count": null,
29
  "metadata": {},
30
+ "outputs": [],
31
  "source": [
32
  "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
33
  "!pip install -q torch torchvision torchaudio\n",
 
36
  "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
37
  "!pip install -q \"openenv-core[core]>=0.2.2\"\n",
38
  "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
39
+ ]
 
 
40
  },
41
  {
42
  "cell_type": "code",
43
+ "execution_count": null,
44
  "metadata": {},
45
+ "outputs": [],
46
  "source": [
47
  "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
48
  "import os\n",
 
118
  "print(f\"Branch: {REPO_BRANCH}\")\n",
119
  "print(f\"Commit: {commit}\")\n",
120
  "print(f\"Plots dir: {PLOTS_DIR}\")"
121
+ ]
 
 
122
  },
123
  {
124
  "cell_type": "code",
125
+ "execution_count": null,
126
  "metadata": {},
127
+ "outputs": [],
128
  "source": [
129
  "# Cell 3: Imports (with runtime validation)\n",
130
  "import json, random, time, textwrap, copy, os, sys\n",
 
178
  "import ast\n",
179
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
180
  "print(\"OK: ast.parse (syntax check)\")"
181
+ ]
 
 
182
  },
183
  {
184
  "cell_type": "markdown",
 
191
  },
192
  {
193
  "cell_type": "code",
194
+ "execution_count": null,
195
  "metadata": {},
196
+ "outputs": [],
197
  "source": [
198
  "# Cell 4: Define heuristic agents + episode runner\n",
199
  "_rng = random.Random(42)\n",
 
269
  " \"rewards\": rewards, \"energies\": energies}\n",
270
  "\n",
271
  "print(\"Agents and episode runner defined.\")"
272
+ ]
 
 
273
  },
274
  {
275
  "cell_type": "code",
276
+ "execution_count": null,
277
  "metadata": {},
278
+ "outputs": [],
279
  "source": [
280
  "# Cell 5: Run baselines (safe)\n",
281
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
 
310
  "for name in BASELINE_AGENTS:\n",
311
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
312
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
313
+ ]
 
 
314
  },
315
  {
316
  "cell_type": "code",
317
+ "execution_count": null,
318
  "metadata": {},
319
+ "outputs": [],
320
  "source": [
321
  "# Cell 6: Baseline plots\n",
322
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
 
334
  "fig.tight_layout()\n",
335
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
336
  "plt.show()"
337
+ ]
 
 
338
  },
339
  {
340
  "cell_type": "markdown",
 
347
  },
348
  {
349
  "cell_type": "code",
350
+ "execution_count": null,
351
  "metadata": {},
352
+ "outputs": [],
353
  "source": [
354
  "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
355
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
 
393
  "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
394
  "if torch.cuda.is_available():\n",
395
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
396
+ ]
 
 
397
  },
398
  {
399
  "cell_type": "code",
400
+ "execution_count": null,
401
  "metadata": {},
402
+ "outputs": [],
403
  "source": [
404
  "# Cell 8: LLM agent functions\n",
405
  "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
 
441
  "- topic: free-form string\n",
442
  "- empty scheduled_actions = full day rest\n",
443
  "\n",
444
+ "POSTING RULES:\n",
445
+ "- Each active day: 2-3 `post` actions at the audience's peak hours.\n",
446
+ "- `create_content` alone earns 0 reward.\n",
447
+ "- Vary `intent` and `content_type`.\"\"\")\n",
 
 
 
 
 
 
 
448
  "\n",
449
  "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
450
  "\n",
451
+ "TWO-PHASE FLOW per day (same observation, two responses):\n",
452
+ "PHASE A: respond with {\"tool_calls\": [...]} only.\n",
453
+ "PHASE B: respond with {\"scheduled_actions\": [...], \"notes\": \"...\"} using the tool results.\"\"\")\n",
 
 
 
 
 
 
 
 
454
  "SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
455
  "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
456
  "\n",
457
  "\n",
458
  "_DAY_NAMES = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
 
 
459
  "\n",
460
  "\n",
461
  "def _format_history(history, k=3):\n",
 
474
  "\n",
475
  "def format_obs(obs, history=None):\n",
476
  " day_name = _DAY_NAMES[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
 
477
  " signals_str = \"\"\n",
478
  " signals = getattr(obs, \"engagement_signals\", None)\n",
479
  " if signals:\n",
 
486
  " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
487
  " if not tool_str:\n",
488
  " tool_str = \" (none — call query_* tools to discover)\\n\"\n",
489
+ " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
490
  " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
491
  " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
492
  " f\"{signals_str}\"\n",
 
716
  "\n",
717
  "\n",
718
  "print(\"LLM agent functions defined (batched).\")"
719
+ ]
 
 
720
  },
721
  {
722
  "cell_type": "markdown",
 
729
  },
730
  {
731
  "cell_type": "code",
732
+ "execution_count": null,
733
  "metadata": {},
734
+ "outputs": [],
735
  "source": [
736
  "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
737
  "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
 
745
  "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
746
  "for t in TASKS:\n",
747
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
748
+ ]
 
 
749
  },
750
  {
751
  "cell_type": "markdown",
 
764
  },
765
  {
766
  "cell_type": "code",
767
+ "execution_count": null,
768
  "metadata": {},
769
+ "outputs": [],
770
  "source": [
771
  "# Cell 10: Attach LoRA adapter\n",
772
  "from peft import LoraConfig, get_peft_model, TaskType\n",
 
780
  "model.enable_input_require_grads()\n",
781
  "peft_model = get_peft_model(model, lora_config)\n",
782
  "peft_model.print_trainable_parameters()"
783
+ ]
 
 
784
  },
785
  {
786
  "cell_type": "code",
787
+ "execution_count": null,
788
  "metadata": {},
789
+ "outputs": [],
790
  "source": [
791
  "# Cell 11: Training loop\n",
792
  "from trl import SFTTrainer, SFTConfig\n",
 
891
  "elapsed = time.time() - t_start\n",
892
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
893
  "print(pd.DataFrame(training_log).to_string(index=False))"
894
+ ]
 
 
895
  },
896
  {
897
  "cell_type": "markdown",
 
904
  },
905
  {
906
  "cell_type": "code",
907
+ "execution_count": null,
908
  "metadata": {},
909
+ "outputs": [],
910
  "source": [
911
  "# Cell 12: Run trained model (batched)\n",
912
  "print(\"Running TRAINED model on all tasks (batched)...\")\n",
 
921
  "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
922
  "for t in TASKS:\n",
923
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
924
+ ]
 
 
925
  },
926
  {
927
  "cell_type": "markdown",
 
932
  },
933
  {
934
  "cell_type": "code",
935
+ "execution_count": null,
936
  "metadata": {},
937
+ "outputs": [],
938
  "source": [
939
  "# Cell 13: Training curves\n",
940
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
 
956
  "fig.tight_layout()\n",
957
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
958
  "plt.show()"
959
+ ]
 
 
960
  },
961
  {
962
  "cell_type": "code",
963
+ "execution_count": null,
964
  "metadata": {},
965
+ "outputs": [],
966
  "source": [
967
  "# Cell 14: Before vs After\n",
968
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
 
992
  "fig.tight_layout()\n",
993
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
994
  "plt.show()"
995
+ ]
 
 
996
  },
997
  {
998
  "cell_type": "code",
999
+ "execution_count": null,
1000
  "metadata": {},
1001
+ "outputs": [],
1002
  "source": [
1003
  "# Cell 15: Trajectory comparison\n",
1004
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
 
1022
  "fig.tight_layout()\n",
1023
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
1024
  "plt.show()"
1025
+ ]
 
 
1026
  },
1027
  {
1028
  "cell_type": "markdown",
 
1033
  },
1034
  {
1035
  "cell_type": "code",
1036
+ "execution_count": null,
1037
  "metadata": {},
1038
+ "outputs": [],
1039
  "source": [
1040
  "# Cell 16: Final summary\n",
1041
  "print(\"=\" * 67)\n",
 
1072
  "\n",
1073
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
1074
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
1075
+ ]
 
 
1076
  },
1077
  {
1078
  "cell_type": "code",
1079
+ "execution_count": null,
1080
  "metadata": {},
1081
+ "outputs": [],
1082
  "source": [
1083
  "# Cell 17: Save adapter\n",
1084
  "save_path = \"./viraltest_trained_adapter\"\n",
 
1086
  "tokenizer.save_pretrained(save_path)\n",
1087
  "print(f\"LoRA adapter saved to {save_path}\")\n",
1088
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
1089
+ ]
 
 
1090
  }
1091
  ],
1092
  "metadata": {
 
1112
  },
1113
  "nbformat": 4,
1114
  "nbformat_minor": 4
1115
+ }