Spaces:
Paused
Paused
Commit ·
8d09986
1
Parent(s): a1be3fe
fix: robust notebook setup (no magic shell) + local CWD auto-detect
Browse files- Use subprocess+shutil for Colab git clone; surface clone errors
- Local mode finds repo root without /content
- Cell 3 auto-chdirs from training/ if needed
- Clear errors when project files are missing
Made-with: Cursor
- training/train_grpo.ipynb +203 -74
training/train_grpo.ipynb
CHANGED
|
@@ -23,20 +23,7 @@
|
|
| 23 |
},
|
| 24 |
{
|
| 25 |
"cell_type": "code",
|
| 26 |
-
"execution_count": null,
|
| 27 |
"metadata": {},
|
| 28 |
-
"outputs": [
|
| 29 |
-
{
|
| 30 |
-
"ename": "",
|
| 31 |
-
"evalue": "",
|
| 32 |
-
"output_type": "error",
|
| 33 |
-
"traceback": [
|
| 34 |
-
"\u001b[1;31mRunning cells with '.venv (Python 3.13.1)' requires the ipykernel package.\n",
|
| 35 |
-
"\u001b[1;31mInstall 'ipykernel' into the Python environment. \n",
|
| 36 |
-
"\u001b[1;31mCommand: '/Users/vaibhavkhandare/Projects/mernstack/openenv-course/viraltest/.venv/bin/python -m pip install ipykernel -U --force-reinstall'"
|
| 37 |
-
]
|
| 38 |
-
}
|
| 39 |
-
],
|
| 40 |
"source": [
|
| 41 |
"# Cell 1: Install dependencies\n",
|
| 42 |
"!pip install -q torch torchvision torchaudio\n",
|
|
@@ -44,41 +31,155 @@
|
|
| 44 |
"!pip install -q matplotlib pandas\n",
|
| 45 |
"!pip install -q pydantic httpx\n",
|
| 46 |
"!pip install -q \"openenv-core[core]>=0.2.2\""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
]
|
| 48 |
},
|
| 49 |
{
|
| 50 |
"cell_type": "code",
|
| 51 |
-
"execution_count": null,
|
| 52 |
"metadata": {},
|
| 53 |
-
"outputs": [],
|
| 54 |
"source": [
|
| 55 |
-
"# Cell 2:
|
| 56 |
-
"import os
|
| 57 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
"REPO_BRANCH = \"hack1\"\n",
|
| 59 |
-
"
|
| 60 |
-
"
|
| 61 |
-
"
|
| 62 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
"\n",
|
| 64 |
"PLOTS_DIR = os.path.join(REPO_DIR, \"plots\")\n",
|
| 65 |
"os.makedirs(PLOTS_DIR, exist_ok=True)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
"print(f\"Working dir: {os.getcwd()}\")\n",
|
|
|
|
|
|
|
| 67 |
"print(f\"Plots dir: {PLOTS_DIR}\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
]
|
| 69 |
},
|
| 70 |
{
|
| 71 |
"cell_type": "code",
|
| 72 |
-
"execution_count": null,
|
| 73 |
"metadata": {},
|
| 74 |
-
"outputs": [],
|
| 75 |
"source": [
|
| 76 |
-
"# Cell 3: Imports\n",
|
| 77 |
-
"import json, random, time, textwrap, copy\n",
|
| 78 |
"from pathlib import Path\n",
|
| 79 |
"from typing import Any, Dict, List, Optional, Tuple\n",
|
| 80 |
"from collections import defaultdict\n",
|
| 81 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
"import matplotlib.pyplot as plt\n",
|
| 83 |
"import numpy as np\n",
|
| 84 |
"import pandas as pd\n",
|
|
@@ -97,8 +198,16 @@
|
|
| 97 |
"TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n",
|
| 98 |
"\n",
|
| 99 |
"print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
|
| 100 |
-
"print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")"
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
},
|
| 103 |
{
|
| 104 |
"cell_type": "markdown",
|
|
@@ -111,9 +220,7 @@
|
|
| 111 |
},
|
| 112 |
{
|
| 113 |
"cell_type": "code",
|
| 114 |
-
"execution_count": null,
|
| 115 |
"metadata": {},
|
| 116 |
-
"outputs": [],
|
| 117 |
"source": [
|
| 118 |
"# Cell 4: Define heuristic agents + episode runner\n",
|
| 119 |
"_rng = random.Random(42)\n",
|
|
@@ -190,24 +297,36 @@
|
|
| 190 |
" \"rewards\": rewards, \"energies\": energies}\n",
|
| 191 |
"\n",
|
| 192 |
"print(\"Agents and episode runner defined.\")"
|
| 193 |
-
]
|
|
|
|
|
|
|
| 194 |
},
|
| 195 |
{
|
| 196 |
"cell_type": "code",
|
| 197 |
-
"execution_count": null,
|
| 198 |
"metadata": {},
|
| 199 |
-
"outputs": [],
|
| 200 |
"source": [
|
| 201 |
-
"# Cell 5: Run baselines\n",
|
| 202 |
"print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
|
| 203 |
"print(\"=\" * 70)\n",
|
| 204 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
"baseline_results = {}\n",
|
| 206 |
"for name, fn in BASELINE_AGENTS.items():\n",
|
| 207 |
" baseline_results[name] = {}\n",
|
| 208 |
" for task in TASKS:\n",
|
| 209 |
" _rng = random.Random(42)\n",
|
| 210 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
" baseline_results[name][task] = result\n",
|
| 212 |
" print(f\" {name:>12s} | {task:>22s} | score={result['grader_score']:.4f} \"\n",
|
| 213 |
" f\"| energy={result['final_energy']:.2f}\")\n",
|
|
@@ -219,13 +338,13 @@
|
|
| 219 |
"for name in BASELINE_AGENTS:\n",
|
| 220 |
" scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
|
| 221 |
" print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
|
| 222 |
-
]
|
|
|
|
|
|
|
| 223 |
},
|
| 224 |
{
|
| 225 |
"cell_type": "code",
|
| 226 |
-
"execution_count": null,
|
| 227 |
"metadata": {},
|
| 228 |
-
"outputs": [],
|
| 229 |
"source": [
|
| 230 |
"# Cell 6: Baseline plots\n",
|
| 231 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
|
|
@@ -243,7 +362,9 @@
|
|
| 243 |
"fig.tight_layout()\n",
|
| 244 |
"fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
|
| 245 |
"plt.show()"
|
| 246 |
-
]
|
|
|
|
|
|
|
| 247 |
},
|
| 248 |
{
|
| 249 |
"cell_type": "markdown",
|
|
@@ -256,9 +377,7 @@
|
|
| 256 |
},
|
| 257 |
{
|
| 258 |
"cell_type": "code",
|
| 259 |
-
"execution_count": null,
|
| 260 |
"metadata": {},
|
| 261 |
-
"outputs": [],
|
| 262 |
"source": [
|
| 263 |
"# Cell 7: Load model\n",
|
| 264 |
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
|
|
@@ -282,13 +401,13 @@
|
|
| 282 |
"model.eval()\n",
|
| 283 |
"print(f\"Model loaded. Device: {model.device}\")\n",
|
| 284 |
"print(f\"Memory: {torch.cuda.memory_allocated()/1e9:.1f} GB\")"
|
| 285 |
-
]
|
|
|
|
|
|
|
| 286 |
},
|
| 287 |
{
|
| 288 |
"cell_type": "code",
|
| 289 |
-
"execution_count": null,
|
| 290 |
"metadata": {},
|
| 291 |
-
"outputs": [],
|
| 292 |
"source": [
|
| 293 |
"# Cell 8: LLM agent functions\n",
|
| 294 |
"SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
|
|
@@ -404,7 +523,9 @@
|
|
| 404 |
" \"burned_out\": obs.creator_energy <= 0}\n",
|
| 405 |
"\n",
|
| 406 |
"print(\"LLM agent functions defined.\")"
|
| 407 |
-
]
|
|
|
|
|
|
|
| 408 |
},
|
| 409 |
{
|
| 410 |
"cell_type": "markdown",
|
|
@@ -417,9 +538,7 @@
|
|
| 417 |
},
|
| 418 |
{
|
| 419 |
"cell_type": "code",
|
| 420 |
-
"execution_count": null,
|
| 421 |
"metadata": {},
|
| 422 |
-
"outputs": [],
|
| 423 |
"source": [
|
| 424 |
"# Cell 9: Run untrained model\n",
|
| 425 |
"print(\"Running UNTRAINED base model on all tasks...\")\n",
|
|
@@ -436,7 +555,9 @@
|
|
| 436 |
"print(\"BEFORE TRAINING:\")\n",
|
| 437 |
"for t in TASKS:\n",
|
| 438 |
" print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
|
| 439 |
-
]
|
|
|
|
|
|
|
| 440 |
},
|
| 441 |
{
|
| 442 |
"cell_type": "markdown",
|
|
@@ -455,9 +576,7 @@
|
|
| 455 |
},
|
| 456 |
{
|
| 457 |
"cell_type": "code",
|
| 458 |
-
"execution_count": null,
|
| 459 |
"metadata": {},
|
| 460 |
-
"outputs": [],
|
| 461 |
"source": [
|
| 462 |
"# Cell 10: Attach LoRA adapter\n",
|
| 463 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
|
@@ -472,13 +591,13 @@
|
|
| 472 |
"model.enable_input_require_grads()\n",
|
| 473 |
"peft_model = get_peft_model(model, lora_config)\n",
|
| 474 |
"peft_model.print_trainable_parameters()"
|
| 475 |
-
]
|
|
|
|
|
|
|
| 476 |
},
|
| 477 |
{
|
| 478 |
"cell_type": "code",
|
| 479 |
-
"execution_count": null,
|
| 480 |
"metadata": {},
|
| 481 |
-
"outputs": [],
|
| 482 |
"source": [
|
| 483 |
"# Cell 11: Training loop\n",
|
| 484 |
"from trl import SFTTrainer, SFTConfig\n",
|
|
@@ -569,7 +688,9 @@
|
|
| 569 |
"elapsed = time.time() - t_start\n",
|
| 570 |
"print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
|
| 571 |
"print(pd.DataFrame(training_log).to_string(index=False))"
|
| 572 |
-
]
|
|
|
|
|
|
|
| 573 |
},
|
| 574 |
{
|
| 575 |
"cell_type": "markdown",
|
|
@@ -582,9 +703,7 @@
|
|
| 582 |
},
|
| 583 |
{
|
| 584 |
"cell_type": "code",
|
| 585 |
-
"execution_count": null,
|
| 586 |
"metadata": {},
|
| 587 |
-
"outputs": [],
|
| 588 |
"source": [
|
| 589 |
"# Cell 12: Run trained model\n",
|
| 590 |
"print(\"Running TRAINED model on all tasks...\")\n",
|
|
@@ -602,7 +721,9 @@
|
|
| 602 |
"print(\"AFTER TRAINING:\")\n",
|
| 603 |
"for t in TASKS:\n",
|
| 604 |
" print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
|
| 605 |
-
]
|
|
|
|
|
|
|
| 606 |
},
|
| 607 |
{
|
| 608 |
"cell_type": "markdown",
|
|
@@ -613,9 +734,7 @@
|
|
| 613 |
},
|
| 614 |
{
|
| 615 |
"cell_type": "code",
|
| 616 |
-
"execution_count": null,
|
| 617 |
"metadata": {},
|
| 618 |
-
"outputs": [],
|
| 619 |
"source": [
|
| 620 |
"# Cell 13: Training curves\n",
|
| 621 |
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
|
@@ -637,13 +756,13 @@
|
|
| 637 |
"fig.tight_layout()\n",
|
| 638 |
"fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
|
| 639 |
"plt.show()"
|
| 640 |
-
]
|
|
|
|
|
|
|
| 641 |
},
|
| 642 |
{
|
| 643 |
"cell_type": "code",
|
| 644 |
-
"execution_count": null,
|
| 645 |
"metadata": {},
|
| 646 |
-
"outputs": [],
|
| 647 |
"source": [
|
| 648 |
"# Cell 14: Before vs After\n",
|
| 649 |
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
|
|
@@ -673,13 +792,13 @@
|
|
| 673 |
"fig.tight_layout()\n",
|
| 674 |
"fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
|
| 675 |
"plt.show()"
|
| 676 |
-
]
|
|
|
|
|
|
|
| 677 |
},
|
| 678 |
{
|
| 679 |
"cell_type": "code",
|
| 680 |
-
"execution_count": null,
|
| 681 |
"metadata": {},
|
| 682 |
-
"outputs": [],
|
| 683 |
"source": [
|
| 684 |
"# Cell 15: Trajectory comparison\n",
|
| 685 |
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
|
|
@@ -703,7 +822,9 @@
|
|
| 703 |
"fig.tight_layout()\n",
|
| 704 |
"fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
|
| 705 |
"plt.show()"
|
| 706 |
-
]
|
|
|
|
|
|
|
| 707 |
},
|
| 708 |
{
|
| 709 |
"cell_type": "markdown",
|
|
@@ -714,9 +835,7 @@
|
|
| 714 |
},
|
| 715 |
{
|
| 716 |
"cell_type": "code",
|
| 717 |
-
"execution_count": null,
|
| 718 |
"metadata": {},
|
| 719 |
-
"outputs": [],
|
| 720 |
"source": [
|
| 721 |
"# Cell 16: Final summary\n",
|
| 722 |
"print(\"=\" * 67)\n",
|
|
@@ -753,13 +872,13 @@
|
|
| 753 |
"\n",
|
| 754 |
"print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
|
| 755 |
"print(\"All results are from real LoRA weight updates on real environment runs.\")"
|
| 756 |
-
]
|
|
|
|
|
|
|
| 757 |
},
|
| 758 |
{
|
| 759 |
"cell_type": "code",
|
| 760 |
-
"execution_count": null,
|
| 761 |
"metadata": {},
|
| 762 |
-
"outputs": [],
|
| 763 |
"source": [
|
| 764 |
"# Cell 17: Save adapter\n",
|
| 765 |
"save_path = \"./viraltest_trained_adapter\"\n",
|
|
@@ -767,7 +886,9 @@
|
|
| 767 |
"tokenizer.save_pretrained(save_path)\n",
|
| 768 |
"print(f\"LoRA adapter saved to {save_path}\")\n",
|
| 769 |
"print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
|
| 770 |
-
]
|
|
|
|
|
|
|
| 771 |
}
|
| 772 |
],
|
| 773 |
"metadata": {
|
|
@@ -779,10 +900,18 @@
|
|
| 779 |
"name": "python3"
|
| 780 |
},
|
| 781 |
"language_info": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 782 |
"name": "python",
|
| 783 |
-
"
|
|
|
|
|
|
|
| 784 |
}
|
| 785 |
},
|
| 786 |
"nbformat": 4,
|
| 787 |
"nbformat_minor": 4
|
| 788 |
-
}
|
|
|
|
| 23 |
},
|
| 24 |
{
|
| 25 |
"cell_type": "code",
|
|
|
|
| 26 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
"source": [
|
| 28 |
"# Cell 1: Install dependencies\n",
|
| 29 |
"!pip install -q torch torchvision torchaudio\n",
|
|
|
|
| 31 |
"!pip install -q matplotlib pandas\n",
|
| 32 |
"!pip install -q pydantic httpx\n",
|
| 33 |
"!pip install -q \"openenv-core[core]>=0.2.2\""
|
| 34 |
+
],
|
| 35 |
+
"execution_count": 5,
|
| 36 |
+
"outputs": [
|
| 37 |
+
{
|
| 38 |
+
"output_type": "stream",
|
| 39 |
+
"text": [
|
| 40 |
+
"\n",
|
| 41 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
|
| 42 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
| 43 |
+
"zsh:1: 4.45.0 not found\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
|
| 46 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
|
| 49 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
|
| 52 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
|
| 53 |
+
]
|
| 54 |
+
}
|
| 55 |
]
|
| 56 |
},
|
| 57 |
{
|
| 58 |
"cell_type": "code",
|
|
|
|
| 59 |
"metadata": {},
|
|
|
|
| 60 |
"source": [
|
| 61 |
+
"# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
|
| 62 |
+
"import os\n",
|
| 63 |
+
"import sys\n",
|
| 64 |
+
"import shutil\n",
|
| 65 |
+
"import subprocess\n",
|
| 66 |
+
"from pathlib import Path\n",
|
| 67 |
+
"\n",
|
| 68 |
"REPO_BRANCH = \"hack1\"\n",
|
| 69 |
+
"REPO_URL = \"https://github.com/VaibhavKhandare/viral-posts-env.git\"\n",
|
| 70 |
+
"COLAB_REPO = Path(\"/content/viral-posts-env\")\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"def _is_repo_root(p: Path) -> bool:\n",
|
| 74 |
+
" return (p / \"server\" / \"viraltest_environment.py\").is_file() and (p / \"models.py\").is_file()\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"def _find_local_root() -> Path:\n",
|
| 78 |
+
" here = Path.cwd().resolve()\n",
|
| 79 |
+
" for cand in (here, here.parent, here.parent.parent):\n",
|
| 80 |
+
" if _is_repo_root(cand):\n",
|
| 81 |
+
" return cand\n",
|
| 82 |
+
" raise FileNotFoundError(\n",
|
| 83 |
+
" \"Could not find project root. cd into viral-posts-env or run this notebook in Google Colab.\"\n",
|
| 84 |
+
" )\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"# --- Colab: always clone a clean copy (avoids stale 7-day code) ---\n",
|
| 88 |
+
"if Path(\"/content\").is_dir():\n",
|
| 89 |
+
" if COLAB_REPO.exists():\n",
|
| 90 |
+
" shutil.rmtree(COLAB_REPO, ignore_errors=True)\n",
|
| 91 |
+
" p = subprocess.run(\n",
|
| 92 |
+
" [\n",
|
| 93 |
+
" \"git\", \"clone\", \"--branch\", REPO_BRANCH, \"--depth\", \"1\",\n",
|
| 94 |
+
" REPO_URL, str(COLAB_REPO),\n",
|
| 95 |
+
" ],\n",
|
| 96 |
+
" capture_output=True,\n",
|
| 97 |
+
" text=True,\n",
|
| 98 |
+
" )\n",
|
| 99 |
+
" if p.returncode != 0:\n",
|
| 100 |
+
" raise RuntimeError(\n",
|
| 101 |
+
" \"git clone failed. Check network and branch name.\\n\"\n",
|
| 102 |
+
" f\"stdout:\\n{p.stdout}\\nstderr:\\n{p.stderr}\"\n",
|
| 103 |
+
" )\n",
|
| 104 |
+
" if not COLAB_REPO.is_dir():\n",
|
| 105 |
+
" raise FileNotFoundError(f\"Clone did not create {COLAB_REPO}\")\n",
|
| 106 |
+
" os.chdir(COLAB_REPO)\n",
|
| 107 |
+
" print(\"Mode: Colab (fresh clone)\")\n",
|
| 108 |
+
"else:\n",
|
| 109 |
+
" # --- Local machine: do not use /content ---\n",
|
| 110 |
+
" root = _find_local_root()\n",
|
| 111 |
+
" os.chdir(root)\n",
|
| 112 |
+
" print(\"Mode: local\")\n",
|
| 113 |
+
" print(f\"Repo root: {root}\")\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"REPO_DIR = str(Path.cwd().resolve())\n",
|
| 116 |
+
"if REPO_DIR not in sys.path:\n",
|
| 117 |
+
" sys.path.insert(0, REPO_DIR)\n",
|
| 118 |
"\n",
|
| 119 |
"PLOTS_DIR = os.path.join(REPO_DIR, \"plots\")\n",
|
| 120 |
"os.makedirs(PLOTS_DIR, exist_ok=True)\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"try:\n",
|
| 123 |
+
" commit = subprocess.check_output(\n",
|
| 124 |
+
" [\"git\", \"rev-parse\", \"--short\", \"HEAD\"],\n",
|
| 125 |
+
" stderr=subprocess.DEVNULL,\n",
|
| 126 |
+
" text=True,\n",
|
| 127 |
+
" ).strip()\n",
|
| 128 |
+
"except Exception:\n",
|
| 129 |
+
" commit = \"n/a\"\n",
|
| 130 |
+
"\n",
|
| 131 |
"print(f\"Working dir: {os.getcwd()}\")\n",
|
| 132 |
+
"print(f\"Branch: {REPO_BRANCH}\")\n",
|
| 133 |
+
"print(f\"Commit: {commit}\")\n",
|
| 134 |
"print(f\"Plots dir: {PLOTS_DIR}\")"
|
| 135 |
+
],
|
| 136 |
+
"execution_count": 6,
|
| 137 |
+
"outputs": [
|
| 138 |
+
{
|
| 139 |
+
"output_type": "stream",
|
| 140 |
+
"text": [
|
| 141 |
+
"fatal: could not create leading directories of '/content/viral-posts-env': Read-only file system\n"
|
| 142 |
+
]
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"output_type": "error",
|
| 146 |
+
"ename": "FileNotFoundError",
|
| 147 |
+
"evalue": "[Errno 2] No such file or directory: '/content/viral-posts-env'",
|
| 148 |
+
"traceback": [
|
| 149 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 150 |
+
"\u001b[31mFileNotFoundError\u001b[39m Traceback (most recent call last)",
|
| 151 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 13\u001b[39m\n\u001b[32m 9\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m os.path.exists(REPO_DIR):\n\u001b[32m 10\u001b[39m get_ipython().system(\u001b[33m'rm -rf {REPO_DIR}'\u001b[39m)\n\u001b[32m 11\u001b[39m \n\u001b[32m 12\u001b[39m get_ipython().system(\u001b[33m'git clone --branch {REPO_BRANCH} --depth 1 {REPO_URL} {REPO_DIR}'\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m13\u001b[39m os.chdir(REPO_DIR)\n\u001b[32m 14\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m REPO_DIR \u001b[38;5;28;01mnot\u001b[39;00m \u001b[38;5;28;01min\u001b[39;00m sys.path:\n\u001b[32m 15\u001b[39m sys.path.insert(\u001b[32m0\u001b[39m, REPO_DIR)\n\u001b[32m 16\u001b[39m \n",
|
| 152 |
+
"\u001b[31mFileNotFoundError\u001b[39m: [Errno 2] No such file or directory: '/content/viral-posts-env'"
|
| 153 |
+
]
|
| 154 |
+
}
|
| 155 |
]
|
| 156 |
},
|
| 157 |
{
|
| 158 |
"cell_type": "code",
|
|
|
|
| 159 |
"metadata": {},
|
|
|
|
| 160 |
"source": [
|
| 161 |
+
"# Cell 3: Imports (with runtime validation)\n",
|
| 162 |
+
"import json, random, time, textwrap, copy, os, sys\n",
|
| 163 |
"from pathlib import Path\n",
|
| 164 |
"from typing import Any, Dict, List, Optional, Tuple\n",
|
| 165 |
"from collections import defaultdict\n",
|
| 166 |
"\n",
|
| 167 |
+
"# Find repo root if notebook was opened from training/ and Cell 2 was skipped\n",
|
| 168 |
+
"if not Path(\"server/viraltest_environment.py\").is_file():\n",
|
| 169 |
+
" for cand in (Path.cwd(), Path.cwd().parent, Path.cwd().parent.parent):\n",
|
| 170 |
+
" if (cand / \"server\" / \"viraltest_environment.py\").is_file():\n",
|
| 171 |
+
" os.chdir(cand)\n",
|
| 172 |
+
" s = str(cand.resolve())\n",
|
| 173 |
+
" if s not in sys.path:\n",
|
| 174 |
+
" sys.path.insert(0, s)\n",
|
| 175 |
+
" print(\"Auto chdir to repo root:\", s)\n",
|
| 176 |
+
" break\n",
|
| 177 |
+
" else:\n",
|
| 178 |
+
" raise RuntimeError(\n",
|
| 179 |
+
" \"Project files not found. Run **Cell 2** first (Colab), or run from repo root.\\n\"\n",
|
| 180 |
+
" f\" cwd = {os.getcwd()!r}\\n\"\n",
|
| 181 |
+
" )\n",
|
| 182 |
+
"\n",
|
| 183 |
"import matplotlib.pyplot as plt\n",
|
| 184 |
"import numpy as np\n",
|
| 185 |
"import pandas as pd\n",
|
|
|
|
| 198 |
"TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n",
|
| 199 |
"\n",
|
| 200 |
"print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
|
| 201 |
+
"print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"# Hard stop if stale repo/code is loaded\n",
|
| 204 |
+
"assert TASK_HORIZON == 30, (\n",
|
| 205 |
+
" f\"Expected TASK_HORIZON=30, got {TASK_HORIZON}. \"\n",
|
| 206 |
+
" \"Restart runtime and run from Cell 1 again (clean clone on hack1).\"\n",
|
| 207 |
+
")"
|
| 208 |
+
],
|
| 209 |
+
"execution_count": null,
|
| 210 |
+
"outputs": []
|
| 211 |
},
|
| 212 |
{
|
| 213 |
"cell_type": "markdown",
|
|
|
|
| 220 |
},
|
| 221 |
{
|
| 222 |
"cell_type": "code",
|
|
|
|
| 223 |
"metadata": {},
|
|
|
|
| 224 |
"source": [
|
| 225 |
"# Cell 4: Define heuristic agents + episode runner\n",
|
| 226 |
"_rng = random.Random(42)\n",
|
|
|
|
| 297 |
" \"rewards\": rewards, \"energies\": energies}\n",
|
| 298 |
"\n",
|
| 299 |
"print(\"Agents and episode runner defined.\")"
|
| 300 |
+
],
|
| 301 |
+
"execution_count": null,
|
| 302 |
+
"outputs": []
|
| 303 |
},
|
| 304 |
{
|
| 305 |
"cell_type": "code",
|
|
|
|
| 306 |
"metadata": {},
|
|
|
|
| 307 |
"source": [
|
| 308 |
+
"# Cell 5: Run baselines (safe)\n",
|
| 309 |
"print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
|
| 310 |
"print(\"=\" * 70)\n",
|
| 311 |
"\n",
|
| 312 |
+
"required = [\"BASELINE_AGENTS\", \"run_episode\", \"TASKS\", \"random\"]\n",
|
| 313 |
+
"missing = [k for k in required if k not in globals()]\n",
|
| 314 |
+
"if missing:\n",
|
| 315 |
+
" raise RuntimeError(\n",
|
| 316 |
+
" f\"Missing prerequisites: {missing}. Run notebook from top (Cell 1 -> Cell 5).\"\n",
|
| 317 |
+
" )\n",
|
| 318 |
+
"\n",
|
| 319 |
"baseline_results = {}\n",
|
| 320 |
"for name, fn in BASELINE_AGENTS.items():\n",
|
| 321 |
" baseline_results[name] = {}\n",
|
| 322 |
" for task in TASKS:\n",
|
| 323 |
" _rng = random.Random(42)\n",
|
| 324 |
+
" try:\n",
|
| 325 |
+
" result = run_episode(task, fn, seed=42)\n",
|
| 326 |
+
" except Exception as e:\n",
|
| 327 |
+
" raise RuntimeError(\n",
|
| 328 |
+
" f\"Baseline failed for agent={name}, task={task}: {type(e).__name__}: {e}\"\n",
|
| 329 |
+
" ) from e\n",
|
| 330 |
" baseline_results[name][task] = result\n",
|
| 331 |
" print(f\" {name:>12s} | {task:>22s} | score={result['grader_score']:.4f} \"\n",
|
| 332 |
" f\"| energy={result['final_energy']:.2f}\")\n",
|
|
|
|
| 338 |
"for name in BASELINE_AGENTS:\n",
|
| 339 |
" scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
|
| 340 |
" print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
|
| 341 |
+
],
|
| 342 |
+
"execution_count": null,
|
| 343 |
+
"outputs": []
|
| 344 |
},
|
| 345 |
{
|
| 346 |
"cell_type": "code",
|
|
|
|
| 347 |
"metadata": {},
|
|
|
|
| 348 |
"source": [
|
| 349 |
"# Cell 6: Baseline plots\n",
|
| 350 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
|
|
|
|
| 362 |
"fig.tight_layout()\n",
|
| 363 |
"fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
|
| 364 |
"plt.show()"
|
| 365 |
+
],
|
| 366 |
+
"execution_count": null,
|
| 367 |
+
"outputs": []
|
| 368 |
},
|
| 369 |
{
|
| 370 |
"cell_type": "markdown",
|
|
|
|
| 377 |
},
|
| 378 |
{
|
| 379 |
"cell_type": "code",
|
|
|
|
| 380 |
"metadata": {},
|
|
|
|
| 381 |
"source": [
|
| 382 |
"# Cell 7: Load model\n",
|
| 383 |
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
|
|
|
|
| 401 |
"model.eval()\n",
|
| 402 |
"print(f\"Model loaded. Device: {model.device}\")\n",
|
| 403 |
"print(f\"Memory: {torch.cuda.memory_allocated()/1e9:.1f} GB\")"
|
| 404 |
+
],
|
| 405 |
+
"execution_count": null,
|
| 406 |
+
"outputs": []
|
| 407 |
},
|
| 408 |
{
|
| 409 |
"cell_type": "code",
|
|
|
|
| 410 |
"metadata": {},
|
|
|
|
| 411 |
"source": [
|
| 412 |
"# Cell 8: LLM agent functions\n",
|
| 413 |
"SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
|
|
|
|
| 523 |
" \"burned_out\": obs.creator_energy <= 0}\n",
|
| 524 |
"\n",
|
| 525 |
"print(\"LLM agent functions defined.\")"
|
| 526 |
+
],
|
| 527 |
+
"execution_count": null,
|
| 528 |
+
"outputs": []
|
| 529 |
},
|
| 530 |
{
|
| 531 |
"cell_type": "markdown",
|
|
|
|
| 538 |
},
|
| 539 |
{
|
| 540 |
"cell_type": "code",
|
|
|
|
| 541 |
"metadata": {},
|
|
|
|
| 542 |
"source": [
|
| 543 |
"# Cell 9: Run untrained model\n",
|
| 544 |
"print(\"Running UNTRAINED base model on all tasks...\")\n",
|
|
|
|
| 555 |
"print(\"BEFORE TRAINING:\")\n",
|
| 556 |
"for t in TASKS:\n",
|
| 557 |
" print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
|
| 558 |
+
],
|
| 559 |
+
"execution_count": null,
|
| 560 |
+
"outputs": []
|
| 561 |
},
|
| 562 |
{
|
| 563 |
"cell_type": "markdown",
|
|
|
|
| 576 |
},
|
| 577 |
{
|
| 578 |
"cell_type": "code",
|
|
|
|
| 579 |
"metadata": {},
|
|
|
|
| 580 |
"source": [
|
| 581 |
"# Cell 10: Attach LoRA adapter\n",
|
| 582 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
|
|
|
| 591 |
"model.enable_input_require_grads()\n",
|
| 592 |
"peft_model = get_peft_model(model, lora_config)\n",
|
| 593 |
"peft_model.print_trainable_parameters()"
|
| 594 |
+
],
|
| 595 |
+
"execution_count": null,
|
| 596 |
+
"outputs": []
|
| 597 |
},
|
| 598 |
{
|
| 599 |
"cell_type": "code",
|
|
|
|
| 600 |
"metadata": {},
|
|
|
|
| 601 |
"source": [
|
| 602 |
"# Cell 11: Training loop\n",
|
| 603 |
"from trl import SFTTrainer, SFTConfig\n",
|
|
|
|
| 688 |
"elapsed = time.time() - t_start\n",
|
| 689 |
"print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
|
| 690 |
"print(pd.DataFrame(training_log).to_string(index=False))"
|
| 691 |
+
],
|
| 692 |
+
"execution_count": null,
|
| 693 |
+
"outputs": []
|
| 694 |
},
|
| 695 |
{
|
| 696 |
"cell_type": "markdown",
|
|
|
|
| 703 |
},
|
| 704 |
{
|
| 705 |
"cell_type": "code",
|
|
|
|
| 706 |
"metadata": {},
|
|
|
|
| 707 |
"source": [
|
| 708 |
"# Cell 12: Run trained model\n",
|
| 709 |
"print(\"Running TRAINED model on all tasks...\")\n",
|
|
|
|
| 721 |
"print(\"AFTER TRAINING:\")\n",
|
| 722 |
"for t in TASKS:\n",
|
| 723 |
" print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
|
| 724 |
+
],
|
| 725 |
+
"execution_count": null,
|
| 726 |
+
"outputs": []
|
| 727 |
},
|
| 728 |
{
|
| 729 |
"cell_type": "markdown",
|
|
|
|
| 734 |
},
|
| 735 |
{
|
| 736 |
"cell_type": "code",
|
|
|
|
| 737 |
"metadata": {},
|
|
|
|
| 738 |
"source": [
|
| 739 |
"# Cell 13: Training curves\n",
|
| 740 |
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
|
|
|
| 756 |
"fig.tight_layout()\n",
|
| 757 |
"fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
|
| 758 |
"plt.show()"
|
| 759 |
+
],
|
| 760 |
+
"execution_count": null,
|
| 761 |
+
"outputs": []
|
| 762 |
},
|
| 763 |
{
|
| 764 |
"cell_type": "code",
|
|
|
|
| 765 |
"metadata": {},
|
|
|
|
| 766 |
"source": [
|
| 767 |
"# Cell 14: Before vs After\n",
|
| 768 |
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
|
|
|
|
| 792 |
"fig.tight_layout()\n",
|
| 793 |
"fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
|
| 794 |
"plt.show()"
|
| 795 |
+
],
|
| 796 |
+
"execution_count": null,
|
| 797 |
+
"outputs": []
|
| 798 |
},
|
| 799 |
{
|
| 800 |
"cell_type": "code",
|
|
|
|
| 801 |
"metadata": {},
|
|
|
|
| 802 |
"source": [
|
| 803 |
"# Cell 15: Trajectory comparison\n",
|
| 804 |
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
|
|
|
|
| 822 |
"fig.tight_layout()\n",
|
| 823 |
"fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
|
| 824 |
"plt.show()"
|
| 825 |
+
],
|
| 826 |
+
"execution_count": null,
|
| 827 |
+
"outputs": []
|
| 828 |
},
|
| 829 |
{
|
| 830 |
"cell_type": "markdown",
|
|
|
|
| 835 |
},
|
| 836 |
{
|
| 837 |
"cell_type": "code",
|
|
|
|
| 838 |
"metadata": {},
|
|
|
|
| 839 |
"source": [
|
| 840 |
"# Cell 16: Final summary\n",
|
| 841 |
"print(\"=\" * 67)\n",
|
|
|
|
| 872 |
"\n",
|
| 873 |
"print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
|
| 874 |
"print(\"All results are from real LoRA weight updates on real environment runs.\")"
|
| 875 |
+
],
|
| 876 |
+
"execution_count": null,
|
| 877 |
+
"outputs": []
|
| 878 |
},
|
| 879 |
{
|
| 880 |
"cell_type": "code",
|
|
|
|
| 881 |
"metadata": {},
|
|
|
|
| 882 |
"source": [
|
| 883 |
"# Cell 17: Save adapter\n",
|
| 884 |
"save_path = \"./viraltest_trained_adapter\"\n",
|
|
|
|
| 886 |
"tokenizer.save_pretrained(save_path)\n",
|
| 887 |
"print(f\"LoRA adapter saved to {save_path}\")\n",
|
| 888 |
"print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
|
| 889 |
+
],
|
| 890 |
+
"execution_count": null,
|
| 891 |
+
"outputs": []
|
| 892 |
}
|
| 893 |
],
|
| 894 |
"metadata": {
|
|
|
|
| 900 |
"name": "python3"
|
| 901 |
},
|
| 902 |
"language_info": {
|
| 903 |
+
"codemirror_mode": {
|
| 904 |
+
"name": "ipython",
|
| 905 |
+
"version": 3
|
| 906 |
+
},
|
| 907 |
+
"file_extension": ".py",
|
| 908 |
+
"mimetype": "text/x-python",
|
| 909 |
"name": "python",
|
| 910 |
+
"nbconvert_exporter": "python",
|
| 911 |
+
"pygments_lexer": "ipython3",
|
| 912 |
+
"version": "3.14.2"
|
| 913 |
}
|
| 914 |
},
|
| 915 |
"nbformat": 4,
|
| 916 |
"nbformat_minor": 4
|
| 917 |
+
}
|