anuragredbus committed on
Commit
8d09986
·
1 Parent(s): a1be3fe

fix: robust notebook setup (no magic shell) + local CWD auto-detect

Browse files

- Use subprocess+shutil for Colab git clone; surface clone errors
- Local mode finds repo root without /content
- Cell 3 auto-chdirs from training/ if needed
- Clear errors when project files are missing

Made-with: Cursor

Files changed (1) hide show
  1. training/train_grpo.ipynb +203 -74
training/train_grpo.ipynb CHANGED
@@ -23,20 +23,7 @@
23
  },
24
  {
25
  "cell_type": "code",
26
- "execution_count": null,
27
  "metadata": {},
28
- "outputs": [
29
- {
30
- "ename": "",
31
- "evalue": "",
32
- "output_type": "error",
33
- "traceback": [
34
- "\u001b[1;31mRunning cells with '.venv (Python 3.13.1)' requires the ipykernel package.\n",
35
- "\u001b[1;31mInstall 'ipykernel' into the Python environment. \n",
36
- "\u001b[1;31mCommand: '/Users/vaibhavkhandare/Projects/mernstack/openenv-course/viraltest/.venv/bin/python -m pip install ipykernel -U --force-reinstall'"
37
- ]
38
- }
39
- ],
40
  "source": [
41
  "# Cell 1: Install dependencies\n",
42
  "!pip install -q torch torchvision torchaudio\n",
@@ -44,41 +31,155 @@
44
  "!pip install -q matplotlib pandas\n",
45
  "!pip install -q pydantic httpx\n",
46
  "!pip install -q \"openenv-core[core]>=0.2.2\""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  ]
48
  },
49
  {
50
  "cell_type": "code",
51
- "execution_count": null,
52
  "metadata": {},
53
- "outputs": [],
54
  "source": [
55
- "# Cell 2: Clone the repo and set up paths\n",
56
- "import os, sys\n",
57
- "REPO_DIR = \"/content/viral-posts-env\"\n",
 
 
 
 
58
  "REPO_BRANCH = \"hack1\"\n",
59
- "if not os.path.exists(REPO_DIR):\n",
60
- " !git clone --branch {REPO_BRANCH} --depth 1 https://github.com/VaibhavKhandare/viral-posts-env.git {REPO_DIR}\n",
61
- "os.chdir(REPO_DIR)\n",
62
- "sys.path.insert(0, REPO_DIR)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  "\n",
64
  "PLOTS_DIR = os.path.join(REPO_DIR, \"plots\")\n",
65
  "os.makedirs(PLOTS_DIR, exist_ok=True)\n",
 
 
 
 
 
 
 
 
 
 
66
  "print(f\"Working dir: {os.getcwd()}\")\n",
 
 
67
  "print(f\"Plots dir: {PLOTS_DIR}\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  ]
69
  },
70
  {
71
  "cell_type": "code",
72
- "execution_count": null,
73
  "metadata": {},
74
- "outputs": [],
75
  "source": [
76
- "# Cell 3: Imports\n",
77
- "import json, random, time, textwrap, copy\n",
78
  "from pathlib import Path\n",
79
  "from typing import Any, Dict, List, Optional, Tuple\n",
80
  "from collections import defaultdict\n",
81
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  "import matplotlib.pyplot as plt\n",
83
  "import numpy as np\n",
84
  "import pandas as pd\n",
@@ -97,8 +198,16 @@
97
  "TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n",
98
  "\n",
99
  "print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
100
- "print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")"
101
- ]
 
 
 
 
 
 
 
 
102
  },
103
  {
104
  "cell_type": "markdown",
@@ -111,9 +220,7 @@
111
  },
112
  {
113
  "cell_type": "code",
114
- "execution_count": null,
115
  "metadata": {},
116
- "outputs": [],
117
  "source": [
118
  "# Cell 4: Define heuristic agents + episode runner\n",
119
  "_rng = random.Random(42)\n",
@@ -190,24 +297,36 @@
190
  " \"rewards\": rewards, \"energies\": energies}\n",
191
  "\n",
192
  "print(\"Agents and episode runner defined.\")"
193
- ]
 
 
194
  },
195
  {
196
  "cell_type": "code",
197
- "execution_count": null,
198
  "metadata": {},
199
- "outputs": [],
200
  "source": [
201
- "# Cell 5: Run baselines\n",
202
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
203
  "print(\"=\" * 70)\n",
204
  "\n",
 
 
 
 
 
 
 
205
  "baseline_results = {}\n",
206
  "for name, fn in BASELINE_AGENTS.items():\n",
207
  " baseline_results[name] = {}\n",
208
  " for task in TASKS:\n",
209
  " _rng = random.Random(42)\n",
210
- " result = run_episode(task, fn, seed=42)\n",
 
 
 
 
 
211
  " baseline_results[name][task] = result\n",
212
  " print(f\" {name:>12s} | {task:>22s} | score={result['grader_score']:.4f} \"\n",
213
  " f\"| energy={result['final_energy']:.2f}\")\n",
@@ -219,13 +338,13 @@
219
  "for name in BASELINE_AGENTS:\n",
220
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
221
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
222
- ]
 
 
223
  },
224
  {
225
  "cell_type": "code",
226
- "execution_count": null,
227
  "metadata": {},
228
- "outputs": [],
229
  "source": [
230
  "# Cell 6: Baseline plots\n",
231
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
@@ -243,7 +362,9 @@
243
  "fig.tight_layout()\n",
244
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
245
  "plt.show()"
246
- ]
 
 
247
  },
248
  {
249
  "cell_type": "markdown",
@@ -256,9 +377,7 @@
256
  },
257
  {
258
  "cell_type": "code",
259
- "execution_count": null,
260
  "metadata": {},
261
- "outputs": [],
262
  "source": [
263
  "# Cell 7: Load model\n",
264
  "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
@@ -282,13 +401,13 @@
282
  "model.eval()\n",
283
  "print(f\"Model loaded. Device: {model.device}\")\n",
284
  "print(f\"Memory: {torch.cuda.memory_allocated()/1e9:.1f} GB\")"
285
- ]
 
 
286
  },
287
  {
288
  "cell_type": "code",
289
- "execution_count": null,
290
  "metadata": {},
291
- "outputs": [],
292
  "source": [
293
  "# Cell 8: LLM agent functions\n",
294
  "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
@@ -404,7 +523,9 @@
404
  " \"burned_out\": obs.creator_energy <= 0}\n",
405
  "\n",
406
  "print(\"LLM agent functions defined.\")"
407
- ]
 
 
408
  },
409
  {
410
  "cell_type": "markdown",
@@ -417,9 +538,7 @@
417
  },
418
  {
419
  "cell_type": "code",
420
- "execution_count": null,
421
  "metadata": {},
422
- "outputs": [],
423
  "source": [
424
  "# Cell 9: Run untrained model\n",
425
  "print(\"Running UNTRAINED base model on all tasks...\")\n",
@@ -436,7 +555,9 @@
436
  "print(\"BEFORE TRAINING:\")\n",
437
  "for t in TASKS:\n",
438
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
439
- ]
 
 
440
  },
441
  {
442
  "cell_type": "markdown",
@@ -455,9 +576,7 @@
455
  },
456
  {
457
  "cell_type": "code",
458
- "execution_count": null,
459
  "metadata": {},
460
- "outputs": [],
461
  "source": [
462
  "# Cell 10: Attach LoRA adapter\n",
463
  "from peft import LoraConfig, get_peft_model, TaskType\n",
@@ -472,13 +591,13 @@
472
  "model.enable_input_require_grads()\n",
473
  "peft_model = get_peft_model(model, lora_config)\n",
474
  "peft_model.print_trainable_parameters()"
475
- ]
 
 
476
  },
477
  {
478
  "cell_type": "code",
479
- "execution_count": null,
480
  "metadata": {},
481
- "outputs": [],
482
  "source": [
483
  "# Cell 11: Training loop\n",
484
  "from trl import SFTTrainer, SFTConfig\n",
@@ -569,7 +688,9 @@
569
  "elapsed = time.time() - t_start\n",
570
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
571
  "print(pd.DataFrame(training_log).to_string(index=False))"
572
- ]
 
 
573
  },
574
  {
575
  "cell_type": "markdown",
@@ -582,9 +703,7 @@
582
  },
583
  {
584
  "cell_type": "code",
585
- "execution_count": null,
586
  "metadata": {},
587
- "outputs": [],
588
  "source": [
589
  "# Cell 12: Run trained model\n",
590
  "print(\"Running TRAINED model on all tasks...\")\n",
@@ -602,7 +721,9 @@
602
  "print(\"AFTER TRAINING:\")\n",
603
  "for t in TASKS:\n",
604
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
605
- ]
 
 
606
  },
607
  {
608
  "cell_type": "markdown",
@@ -613,9 +734,7 @@
613
  },
614
  {
615
  "cell_type": "code",
616
- "execution_count": null,
617
  "metadata": {},
618
- "outputs": [],
619
  "source": [
620
  "# Cell 13: Training curves\n",
621
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
@@ -637,13 +756,13 @@
637
  "fig.tight_layout()\n",
638
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
639
  "plt.show()"
640
- ]
 
 
641
  },
642
  {
643
  "cell_type": "code",
644
- "execution_count": null,
645
  "metadata": {},
646
- "outputs": [],
647
  "source": [
648
  "# Cell 14: Before vs After\n",
649
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
@@ -673,13 +792,13 @@
673
  "fig.tight_layout()\n",
674
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
675
  "plt.show()"
676
- ]
 
 
677
  },
678
  {
679
  "cell_type": "code",
680
- "execution_count": null,
681
  "metadata": {},
682
- "outputs": [],
683
  "source": [
684
  "# Cell 15: Trajectory comparison\n",
685
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
@@ -703,7 +822,9 @@
703
  "fig.tight_layout()\n",
704
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
705
  "plt.show()"
706
- ]
 
 
707
  },
708
  {
709
  "cell_type": "markdown",
@@ -714,9 +835,7 @@
714
  },
715
  {
716
  "cell_type": "code",
717
- "execution_count": null,
718
  "metadata": {},
719
- "outputs": [],
720
  "source": [
721
  "# Cell 16: Final summary\n",
722
  "print(\"=\" * 67)\n",
@@ -753,13 +872,13 @@
753
  "\n",
754
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
755
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
756
- ]
 
 
757
  },
758
  {
759
  "cell_type": "code",
760
- "execution_count": null,
761
  "metadata": {},
762
- "outputs": [],
763
  "source": [
764
  "# Cell 17: Save adapter\n",
765
  "save_path = \"./viraltest_trained_adapter\"\n",
@@ -767,7 +886,9 @@
767
  "tokenizer.save_pretrained(save_path)\n",
768
  "print(f\"LoRA adapter saved to {save_path}\")\n",
769
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
770
- ]
 
 
771
  }
772
  ],
773
  "metadata": {
@@ -779,10 +900,18 @@
779
  "name": "python3"
780
  },
781
  "language_info": {
 
 
 
 
 
 
782
  "name": "python",
783
- "version": "3.13.1"
 
 
784
  }
785
  },
786
  "nbformat": 4,
787
  "nbformat_minor": 4
788
- }
 
23
  },
24
  {
25
  "cell_type": "code",
 
26
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
27
  "source": [
28
  "# Cell 1: Install dependencies\n",
29
  "!pip install -q torch torchvision torchaudio\n",
 
31
  "!pip install -q matplotlib pandas\n",
32
  "!pip install -q pydantic httpx\n",
33
  "!pip install -q \"openenv-core[core]>=0.2.2\""
34
+ ],
35
+ "execution_count": 5,
36
+ "outputs": [
37
+ {
38
+ "output_type": "stream",
39
+ "text": [
40
+ "\n",
41
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
42
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
43
+ "zsh:1: 4.45.0 not found\n",
44
+ "\n",
45
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
46
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
47
+ "\n",
48
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
49
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
50
+ "\n",
51
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
52
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
53
+ ]
54
+ }
55
  ]
56
  },
57
  {
58
  "cell_type": "code",
 
59
  "metadata": {},
 
60
  "source": [
61
+ "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
62
+ "import os\n",
63
+ "import sys\n",
64
+ "import shutil\n",
65
+ "import subprocess\n",
66
+ "from pathlib import Path\n",
67
+ "\n",
68
  "REPO_BRANCH = \"hack1\"\n",
69
+ "REPO_URL = \"https://github.com/VaibhavKhandare/viral-posts-env.git\"\n",
70
+ "COLAB_REPO = Path(\"/content/viral-posts-env\")\n",
71
+ "\n",
72
+ "\n",
73
+ "def _is_repo_root(p: Path) -> bool:\n",
74
+ " return (p / \"server\" / \"viraltest_environment.py\").is_file() and (p / \"models.py\").is_file()\n",
75
+ "\n",
76
+ "\n",
77
+ "def _find_local_root() -> Path:\n",
78
+ " here = Path.cwd().resolve()\n",
79
+ " for cand in (here, here.parent, here.parent.parent):\n",
80
+ " if _is_repo_root(cand):\n",
81
+ " return cand\n",
82
+ " raise FileNotFoundError(\n",
83
+ " \"Could not find project root. cd into viral-posts-env or run this notebook in Google Colab.\"\n",
84
+ " )\n",
85
+ "\n",
86
+ "\n",
87
+ "# --- Colab: always clone a clean copy (avoids stale 7-day code) ---\n",
88
+ "if Path(\"/content\").is_dir():\n",
89
+ " if COLAB_REPO.exists():\n",
90
+ " shutil.rmtree(COLAB_REPO, ignore_errors=True)\n",
91
+ " p = subprocess.run(\n",
92
+ " [\n",
93
+ " \"git\", \"clone\", \"--branch\", REPO_BRANCH, \"--depth\", \"1\",\n",
94
+ " REPO_URL, str(COLAB_REPO),\n",
95
+ " ],\n",
96
+ " capture_output=True,\n",
97
+ " text=True,\n",
98
+ " )\n",
99
+ " if p.returncode != 0:\n",
100
+ " raise RuntimeError(\n",
101
+ " \"git clone failed. Check network and branch name.\\n\"\n",
102
+ " f\"stdout:\\n{p.stdout}\\nstderr:\\n{p.stderr}\"\n",
103
+ " )\n",
104
+ " if not COLAB_REPO.is_dir():\n",
105
+ " raise FileNotFoundError(f\"Clone did not create {COLAB_REPO}\")\n",
106
+ " os.chdir(COLAB_REPO)\n",
107
+ " print(\"Mode: Colab (fresh clone)\")\n",
108
+ "else:\n",
109
+ " # --- Local machine: do not use /content ---\n",
110
+ " root = _find_local_root()\n",
111
+ " os.chdir(root)\n",
112
+ " print(\"Mode: local\")\n",
113
+ " print(f\"Repo root: {root}\")\n",
114
+ "\n",
115
+ "REPO_DIR = str(Path.cwd().resolve())\n",
116
+ "if REPO_DIR not in sys.path:\n",
117
+ " sys.path.insert(0, REPO_DIR)\n",
118
  "\n",
119
  "PLOTS_DIR = os.path.join(REPO_DIR, \"plots\")\n",
120
  "os.makedirs(PLOTS_DIR, exist_ok=True)\n",
121
+ "\n",
122
+ "try:\n",
123
+ " commit = subprocess.check_output(\n",
124
+ " [\"git\", \"rev-parse\", \"--short\", \"HEAD\"],\n",
125
+ " stderr=subprocess.DEVNULL,\n",
126
+ " text=True,\n",
127
+ " ).strip()\n",
128
+ "except Exception:\n",
129
+ " commit = \"n/a\"\n",
130
+ "\n",
131
  "print(f\"Working dir: {os.getcwd()}\")\n",
132
+ "print(f\"Branch: {REPO_BRANCH}\")\n",
133
+ "print(f\"Commit: {commit}\")\n",
134
  "print(f\"Plots dir: {PLOTS_DIR}\")"
135
+ ],
136
+ "execution_count": 6,
137
+ "outputs": [
138
+ {
139
+ "output_type": "stream",
140
+ "text": [
141
+ "fatal: could not create leading directories of '/content/viral-posts-env': Read-only file system\n"
142
+ ]
143
+ },
144
+ {
145
+ "output_type": "error",
146
+ "ename": "FileNotFoundError",
147
+ "evalue": "[Errno 2] No such file or directory: '/content/viral-posts-env'",
148
+ "traceback": [
149
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
150
+ "\u001b[31mFileNotFoundError\u001b[39m Traceback (most recent call last)",
151
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 13\u001b[39m\n\u001b[32m 9\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m os.path.exists(REPO_DIR):\n\u001b[32m 10\u001b[39m get_ipython().system(\u001b[33m'rm -rf {REPO_DIR}'\u001b[39m)\n\u001b[32m 11\u001b[39m \n\u001b[32m 12\u001b[39m get_ipython().system(\u001b[33m'git clone --branch {REPO_BRANCH} --depth 1 {REPO_URL} {REPO_DIR}'\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m13\u001b[39m os.chdir(REPO_DIR)\n\u001b[32m 14\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m REPO_DIR \u001b[38;5;28;01mnot\u001b[39;00m \u001b[38;5;28;01min\u001b[39;00m sys.path:\n\u001b[32m 15\u001b[39m sys.path.insert(\u001b[32m0\u001b[39m, REPO_DIR)\n\u001b[32m 16\u001b[39m \n",
152
+ "\u001b[31mFileNotFoundError\u001b[39m: [Errno 2] No such file or directory: '/content/viral-posts-env'"
153
+ ]
154
+ }
155
  ]
156
  },
157
  {
158
  "cell_type": "code",
 
159
  "metadata": {},
 
160
  "source": [
161
+ "# Cell 3: Imports (with runtime validation)\n",
162
+ "import json, random, time, textwrap, copy, os, sys\n",
163
  "from pathlib import Path\n",
164
  "from typing import Any, Dict, List, Optional, Tuple\n",
165
  "from collections import defaultdict\n",
166
  "\n",
167
+ "# Find repo root if notebook was opened from training/ and Cell 2 was skipped\n",
168
+ "if not Path(\"server/viraltest_environment.py\").is_file():\n",
169
+ " for cand in (Path.cwd(), Path.cwd().parent, Path.cwd().parent.parent):\n",
170
+ " if (cand / \"server\" / \"viraltest_environment.py\").is_file():\n",
171
+ " os.chdir(cand)\n",
172
+ " s = str(cand.resolve())\n",
173
+ " if s not in sys.path:\n",
174
+ " sys.path.insert(0, s)\n",
175
+ " print(\"Auto chdir to repo root:\", s)\n",
176
+ " break\n",
177
+ " else:\n",
178
+ " raise RuntimeError(\n",
179
+ " \"Project files not found. Run **Cell 2** first (Colab), or run from repo root.\\n\"\n",
180
+ " f\" cwd = {os.getcwd()!r}\\n\"\n",
181
+ " )\n",
182
+ "\n",
183
  "import matplotlib.pyplot as plt\n",
184
  "import numpy as np\n",
185
  "import pandas as pd\n",
 
198
  "TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n",
199
  "\n",
200
  "print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
201
+ "print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")\n",
202
+ "\n",
203
+ "# Hard stop if stale repo/code is loaded\n",
204
+ "assert TASK_HORIZON == 30, (\n",
205
+ " f\"Expected TASK_HORIZON=30, got {TASK_HORIZON}. \"\n",
206
+ " \"Restart runtime and run from Cell 1 again (clean clone on hack1).\"\n",
207
+ ")"
208
+ ],
209
+ "execution_count": null,
210
+ "outputs": []
211
  },
212
  {
213
  "cell_type": "markdown",
 
220
  },
221
  {
222
  "cell_type": "code",
 
223
  "metadata": {},
 
224
  "source": [
225
  "# Cell 4: Define heuristic agents + episode runner\n",
226
  "_rng = random.Random(42)\n",
 
297
  " \"rewards\": rewards, \"energies\": energies}\n",
298
  "\n",
299
  "print(\"Agents and episode runner defined.\")"
300
+ ],
301
+ "execution_count": null,
302
+ "outputs": []
303
  },
304
  {
305
  "cell_type": "code",
 
306
  "metadata": {},
 
307
  "source": [
308
+ "# Cell 5: Run baselines (safe)\n",
309
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
310
  "print(\"=\" * 70)\n",
311
  "\n",
312
+ "required = [\"BASELINE_AGENTS\", \"run_episode\", \"TASKS\", \"random\"]\n",
313
+ "missing = [k for k in required if k not in globals()]\n",
314
+ "if missing:\n",
315
+ " raise RuntimeError(\n",
316
+ " f\"Missing prerequisites: {missing}. Run notebook from top (Cell 1 -> Cell 5).\"\n",
317
+ " )\n",
318
+ "\n",
319
  "baseline_results = {}\n",
320
  "for name, fn in BASELINE_AGENTS.items():\n",
321
  " baseline_results[name] = {}\n",
322
  " for task in TASKS:\n",
323
  " _rng = random.Random(42)\n",
324
+ " try:\n",
325
+ " result = run_episode(task, fn, seed=42)\n",
326
+ " except Exception as e:\n",
327
+ " raise RuntimeError(\n",
328
+ " f\"Baseline failed for agent={name}, task={task}: {type(e).__name__}: {e}\"\n",
329
+ " ) from e\n",
330
  " baseline_results[name][task] = result\n",
331
  " print(f\" {name:>12s} | {task:>22s} | score={result['grader_score']:.4f} \"\n",
332
  " f\"| energy={result['final_energy']:.2f}\")\n",
 
338
  "for name in BASELINE_AGENTS:\n",
339
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
340
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
341
+ ],
342
+ "execution_count": null,
343
+ "outputs": []
344
  },
345
  {
346
  "cell_type": "code",
 
347
  "metadata": {},
 
348
  "source": [
349
  "# Cell 6: Baseline plots\n",
350
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
 
362
  "fig.tight_layout()\n",
363
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
364
  "plt.show()"
365
+ ],
366
+ "execution_count": null,
367
+ "outputs": []
368
  },
369
  {
370
  "cell_type": "markdown",
 
377
  },
378
  {
379
  "cell_type": "code",
 
380
  "metadata": {},
 
381
  "source": [
382
  "# Cell 7: Load model\n",
383
  "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
 
401
  "model.eval()\n",
402
  "print(f\"Model loaded. Device: {model.device}\")\n",
403
  "print(f\"Memory: {torch.cuda.memory_allocated()/1e9:.1f} GB\")"
404
+ ],
405
+ "execution_count": null,
406
+ "outputs": []
407
  },
408
  {
409
  "cell_type": "code",
 
410
  "metadata": {},
 
411
  "source": [
412
  "# Cell 8: LLM agent functions\n",
413
  "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
 
523
  " \"burned_out\": obs.creator_energy <= 0}\n",
524
  "\n",
525
  "print(\"LLM agent functions defined.\")"
526
+ ],
527
+ "execution_count": null,
528
+ "outputs": []
529
  },
530
  {
531
  "cell_type": "markdown",
 
538
  },
539
  {
540
  "cell_type": "code",
 
541
  "metadata": {},
 
542
  "source": [
543
  "# Cell 9: Run untrained model\n",
544
  "print(\"Running UNTRAINED base model on all tasks...\")\n",
 
555
  "print(\"BEFORE TRAINING:\")\n",
556
  "for t in TASKS:\n",
557
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
558
+ ],
559
+ "execution_count": null,
560
+ "outputs": []
561
  },
562
  {
563
  "cell_type": "markdown",
 
576
  },
577
  {
578
  "cell_type": "code",
 
579
  "metadata": {},
 
580
  "source": [
581
  "# Cell 10: Attach LoRA adapter\n",
582
  "from peft import LoraConfig, get_peft_model, TaskType\n",
 
591
  "model.enable_input_require_grads()\n",
592
  "peft_model = get_peft_model(model, lora_config)\n",
593
  "peft_model.print_trainable_parameters()"
594
+ ],
595
+ "execution_count": null,
596
+ "outputs": []
597
  },
598
  {
599
  "cell_type": "code",
 
600
  "metadata": {},
 
601
  "source": [
602
  "# Cell 11: Training loop\n",
603
  "from trl import SFTTrainer, SFTConfig\n",
 
688
  "elapsed = time.time() - t_start\n",
689
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
690
  "print(pd.DataFrame(training_log).to_string(index=False))"
691
+ ],
692
+ "execution_count": null,
693
+ "outputs": []
694
  },
695
  {
696
  "cell_type": "markdown",
 
703
  },
704
  {
705
  "cell_type": "code",
 
706
  "metadata": {},
 
707
  "source": [
708
  "# Cell 12: Run trained model\n",
709
  "print(\"Running TRAINED model on all tasks...\")\n",
 
721
  "print(\"AFTER TRAINING:\")\n",
722
  "for t in TASKS:\n",
723
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
724
+ ],
725
+ "execution_count": null,
726
+ "outputs": []
727
  },
728
  {
729
  "cell_type": "markdown",
 
734
  },
735
  {
736
  "cell_type": "code",
 
737
  "metadata": {},
 
738
  "source": [
739
  "# Cell 13: Training curves\n",
740
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
 
756
  "fig.tight_layout()\n",
757
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
758
  "plt.show()"
759
+ ],
760
+ "execution_count": null,
761
+ "outputs": []
762
  },
763
  {
764
  "cell_type": "code",
 
765
  "metadata": {},
 
766
  "source": [
767
  "# Cell 14: Before vs After\n",
768
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
 
792
  "fig.tight_layout()\n",
793
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
794
  "plt.show()"
795
+ ],
796
+ "execution_count": null,
797
+ "outputs": []
798
  },
799
  {
800
  "cell_type": "code",
 
801
  "metadata": {},
 
802
  "source": [
803
  "# Cell 15: Trajectory comparison\n",
804
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
 
822
  "fig.tight_layout()\n",
823
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
824
  "plt.show()"
825
+ ],
826
+ "execution_count": null,
827
+ "outputs": []
828
  },
829
  {
830
  "cell_type": "markdown",
 
835
  },
836
  {
837
  "cell_type": "code",
 
838
  "metadata": {},
 
839
  "source": [
840
  "# Cell 16: Final summary\n",
841
  "print(\"=\" * 67)\n",
 
872
  "\n",
873
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
874
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
875
+ ],
876
+ "execution_count": null,
877
+ "outputs": []
878
  },
879
  {
880
  "cell_type": "code",
 
881
  "metadata": {},
 
882
  "source": [
883
  "# Cell 17: Save adapter\n",
884
  "save_path = \"./viraltest_trained_adapter\"\n",
 
886
  "tokenizer.save_pretrained(save_path)\n",
887
  "print(f\"LoRA adapter saved to {save_path}\")\n",
888
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
889
+ ],
890
+ "execution_count": null,
891
+ "outputs": []
892
  }
893
  ],
894
  "metadata": {
 
900
  "name": "python3"
901
  },
902
  "language_info": {
903
+ "codemirror_mode": {
904
+ "name": "ipython",
905
+ "version": 3
906
+ },
907
+ "file_extension": ".py",
908
+ "mimetype": "text/x-python",
909
  "name": "python",
910
+ "nbconvert_exporter": "python",
911
+ "pygments_lexer": "ipython3",
912
+ "version": "3.14.2"
913
  }
914
  },
915
  "nbformat": 4,
916
  "nbformat_minor": 4
917
+ }