Commit · 4a29e22
Parent(s): e2c547b
fix: rewrite training notebook for real LoRA fine-tuning on Colab
- Add missing openenv-core dependency to install cell
- Self-contained: clones repo, installs all deps, runs end-to-end
- Real weight updates via LoRA + SFT (not prompt engineering)
- 4-bit quantization to fit free Colab T4 GPU
- Pipeline: baselines → untrained LLM → LoRA training → trained LLM → plots
Made-with: Cursor
training/train_grpo.ipynb CHANGED (+774 -1039)

@@ -1,1041 +1,776 @@
 {
 },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -q trl>=0.12.0 transformers accelerate peft bitsandbytes datasets\n",
- "!pip install -q openai httpx matplotlib pandas\n",
- "!pip install -q openenv-core[core]>=0.2.2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import json\n",
- "import os\n",
- "import time\n",
- "import random\n",
- "import copy\n",
- "from pathlib import Path\n",
- "from typing import Any, Dict, List, Optional, Tuple\n",
- "\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "\n",
- "PLOTS_DIR = Path(\"../plots\")\n",
- "PLOTS_DIR.mkdir(exist_ok=True)\n",
- "\n",
- "print(\"Imports OK\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Part 1: Environment Setup — Direct In-Process Access\n",
- "\n",
- "We instantiate the environment directly (no HTTP server needed) so we can run hundreds of episodes quickly."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import sys\n",
- "sys.path.insert(0, \"..\")\n",
- "\n",
- "from models import ScheduledAction, ViraltestAction, ToolCall\n",
- "from server.viraltest_environment import (\n",
- " ViraltestEnvironment,\n",
- " TAG_POOL,\n",
- " TOPIC_CATEGORIES,\n",
- " TASK_HORIZON,\n",
- ")\n",
- "\n",
- "ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
- "NICHES = list(TOPIC_CATEGORIES.keys())\n",
- "CONTENT_TYPES = [\"reel\", \"carousel\", \"story\", \"text_post\"]\n",
- "INTENTS = [\"send_bait\", \"save_bait\", \"watch_bait\", \"like_bait\"]\n",
- "TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n",
- "\n",
- "print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Niches: {len(NICHES)}\")\n",
- "print(f\"Tasks: {TASKS}\")\n",
- "print(f\"Horizon: {TASK_HORIZON} steps (days)\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Part 2: Heuristic Baselines\n",
- "\n",
- "Before touching any LLM, we run scripted agents to establish a **baseline leaderboard**.\n",
- "This proves the environment can differentiate skill levels."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "_rng = random.Random(42)\n",
- "\n",
- "\n",
- "def plan_always_rest(obs_dict: dict, day: int) -> ViraltestAction:\n",
- " return ViraltestAction(scheduled_actions=[], notes=\"Rest day.\")\n",
- "\n",
- "\n",
- "def plan_spam(obs_dict: dict, day: int) -> ViraltestAction:\n",
- " actions = [\n",
- " {\"hour\": h, \"action_type\": \"post\", \"content_type\": \"reel\",\n",
- " \"topic\": \"AI tools\", \"tags\": [\"ai\"], \"intent\": \"watch_bait\"}\n",
- " for h in range(24)\n",
- " ]\n",
- " return ViraltestAction(scheduled_actions=[ScheduledAction(**a) for a in actions])\n",
- "\n",
- "\n",
- "def plan_random(obs_dict: dict, day: int) -> ViraltestAction:\n",
- " actions = []\n",
- " for h in range(24):\n",
- " if _rng.random() < 0.1:\n",
- " ct = _rng.choice(CONTENT_TYPES)\n",
- " topic = _rng.choice(ALL_TOPICS)\n",
- " tags = _rng.sample(TAG_POOL[:30], min(3, len(TAG_POOL)))\n",
- " intent = _rng.choice(INTENTS)\n",
- " actions.append({\"hour\": h, \"action_type\": \"post\", \"content_type\": ct,\n",
- " \"topic\": topic, \"tags\": tags, \"intent\": intent})\n",
- " return ViraltestAction(scheduled_actions=[ScheduledAction(**a) for a in actions])\n",
- "\n",
- "\n",
- "def plan_minimal(obs_dict: dict, day: int) -> ViraltestAction:\n",
- " topic = ALL_TOPICS[day % len(ALL_TOPICS)]\n",
- " tags = [TAG_POOL[i % len(TAG_POOL)] for i in range(day, day + 3)]\n",
- " actions = [\n",
- " {\"hour\": 12, \"action_type\": \"post\", \"content_type\": \"carousel\",\n",
- " \"topic\": topic, \"tags\": tags, \"intent\": \"save_bait\"},\n",
- " ]\n",
- " return ViraltestAction(scheduled_actions=[ScheduledAction(**a) for a in actions])\n",
- "\n",
- "\n",
- "def plan_smart(obs_dict: dict, day: int) -> ViraltestAction:\n",
- " \"\"\"Best heuristic: 2 posts at peak hours, varied content types and intents, tag rotation.\"\"\"\n",
- " topic1 = ALL_TOPICS[(day * 2) % len(ALL_TOPICS)]\n",
- " topic2 = ALL_TOPICS[(day * 2 + 1) % len(ALL_TOPICS)]\n",
- " ct1 = CONTENT_TYPES[(day * 2) % 4]\n",
- " ct2 = CONTENT_TYPES[(day * 2 + 1) % 4]\n",
- " intent1 = INTENTS[(day * 2) % 4]\n",
- " intent2 = INTENTS[(day * 2 + 1) % 4]\n",
- " tags1 = [TAG_POOL[(day * 6 + i) % len(TAG_POOL)] for i in range(3)]\n",
- " tags2 = [TAG_POOL[(day * 6 + 3 + i) % len(TAG_POOL)] for i in range(3)]\n",
- "\n",
- " actions = [\n",
- " {\"hour\": 8, \"action_type\": \"create_content\"},\n",
- " {\"hour\": 12, \"action_type\": \"post\", \"content_type\": ct1,\n",
- " \"topic\": topic1, \"tags\": tags1, \"intent\": intent1},\n",
- " {\"hour\": 19, \"action_type\": \"post\", \"content_type\": ct2,\n",
- " \"topic\": topic2, \"tags\": tags2, \"intent\": intent2},\n",
- " ]\n",
- " replies = [{\"post_hour\": 12, \"reply_hour\": 13}]\n",
- " return ViraltestAction(\n",
- " scheduled_actions=[ScheduledAction(**a) for a in actions],\n",
- " replies=[{\"post_hour\": 12, \"reply_hour\": 13}],\n",
- " notes=f\"Day {day}: varied content at peak hours.\",\n",
- " )\n",
- "\n",
- "\n",
- "def plan_smart_with_tools(obs_dict: dict, day: int) -> ViraltestAction:\n",
- " \"\"\"Smart agent that also uses tools for world discovery.\"\"\"\n",
- " tool_calls = []\n",
- " if day <= 3:\n",
- " tool_calls.append(ToolCall(name=\"query_trends\", arguments={\"niche\": NICHES[day % len(NICHES)]}))\n",
- " if day % 5 == 0:\n",
- " tool_calls.append(ToolCall(name=\"query_competitor\", arguments={\"competitor_id\": \"niche_expert\", \"window_days\": 7}))\n",
- " if day % 7 == 0:\n",
- " tool_calls.append(ToolCall(name=\"query_audience\", arguments={\"segment_id\": \"gen_z\"}))\n",
- "\n",
- " base = plan_smart(obs_dict, day)\n",
- " return ViraltestAction(\n",
- " tool_calls=tool_calls,\n",
- " scheduled_actions=base.scheduled_actions,\n",
- " replies=base.replies,\n",
- " notes=f\"Day {day}: tool-assisted planning.\",\n",
- " )\n",
- "\n",
- "\n",
- "BASELINE_AGENTS = {\n",
- " \"always_rest\": plan_always_rest,\n",
- " \"spam\": plan_spam,\n",
- " \"random\": plan_random,\n",
- " \"minimal\": plan_minimal,\n",
- " \"smart\": plan_smart,\n",
- " \"smart_with_tools\": plan_smart_with_tools,\n",
- "}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def run_episode(task: str, plan_fn, seed: int = 42) -> Dict[str, Any]:\n",
- " \"\"\"Run one full 30-day episode and return metrics.\"\"\"\n",
- " env = ViraltestEnvironment()\n",
- " obs = env.reset(task=task, seed=seed)\n",
- " obs_dict = obs.model_dump()\n",
- "\n",
- " rewards = []\n",
- " energies = [obs.creator_energy]\n",
- " followers_hist = [obs.follower_count]\n",
- "\n",
- " for day in range(1, TASK_HORIZON + 1):\n",
- " action = plan_fn(obs_dict, day)\n",
- " obs = env.step(action)\n",
- " obs_dict = obs.model_dump()\n",
- " r = obs.reward if obs.reward is not None else 0.0\n",
- " rewards.append(r)\n",
- " energies.append(obs.creator_energy)\n",
- " followers_hist.append(obs.follower_count)\n",
- " if obs.done:\n",
- " break\n",
- "\n",
- " grader_score = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
- "\n",
- " return {\n",
- " \"task\": task,\n",
- " \"steps\": len(rewards),\n",
- " \"total_reward\": sum(rewards),\n",
- " \"avg_reward\": sum(rewards) / len(rewards) if rewards else 0,\n",
- " \"grader_score\": grader_score,\n",
- " \"final_energy\": obs.creator_energy,\n",
- " \"min_energy\": min(energies),\n",
- " \"final_followers\": obs.follower_count,\n",
- " \"follower_delta\": obs.follower_count - 10000,\n",
- " \"burned_out\": obs.creator_energy <= 0,\n",
- " \"rewards\": rewards,\n",
- " \"energies\": energies,\n",
- " \"followers\": followers_hist,\n",
- " }\n",
- "\n",
- "\n",
- "print(\"Running heuristic baselines across all tasks...\")\n",
- "print(\"=\" * 80)\n",
- "\n",
- "baseline_results = {}\n",
- "for agent_name, plan_fn in BASELINE_AGENTS.items():\n",
- " baseline_results[agent_name] = {}\n",
- " for task in TASKS:\n",
- " _rng = random.Random(42)\n",
- " result = run_episode(task, plan_fn, seed=42)\n",
- " baseline_results[agent_name][task] = result\n",
- " print(f\" {agent_name:>20s} | {task:>22s} | score={result['grader_score']:.4f} | \"\n",
- " f\"reward={result['total_reward']:.3f} | energy={result['final_energy']:.2f} | \"\n",
- " f\"followers={result['follower_delta']:+d}\")\n",
- " print()\n",
- "\n",
- "print(\"\\n\" + \"=\" * 80)\n",
- "print(\"BASELINE LEADERBOARD (grader_score)\")\n",
- "print(\"=\" * 80)\n",
- "print(f\"{'Agent':<22s} {'engage':>10s} {'strategic':>12s} {'competitive':>14s} {'avg':>8s}\")\n",
- "print(\"-\" * 68)\n",
- "for agent_name in BASELINE_AGENTS:\n",
- " scores = [baseline_results[agent_name][t][\"grader_score\"] for t in TASKS]\n",
- " avg = sum(scores) / len(scores)\n",
- " print(f\"{agent_name:<22s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {avg:>8.4f}\")"
- ]
- },
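The removed setup cell imports pandas but the leaderboard above is assembled with manual f-string padding; the same table falls out of a DataFrame in a few lines. A sketch assuming the `baseline_results[agent][task]["grader_score"]` structure from the cell above, shown here with toy numbers rather than real scores:

```python
import pandas as pd

# Toy stand-in for baseline_results: {agent: {task: {"grader_score": ...}}}.
# The numbers are illustrative, not from the notebook run.
baseline_results = {
    "always_rest": {"monthly_engage": {"grader_score": 0.01},
                    "monthly_strategic": {"grader_score": 0.02}},
    "smart":       {"monthly_engage": {"grader_score": 0.40},
                    "monthly_strategic": {"grader_score": 0.35}},
}

# Rows = agents, columns = tasks, plus a per-agent average column.
df = pd.DataFrame({
    agent: {task: res["grader_score"] for task, res in per_task.items()}
    for agent, per_task in baseline_results.items()
}).T
df["avg"] = df.mean(axis=1)
print(df.sort_values("avg", ascending=False))
```

Sorting on the `avg` column reproduces the leaderboard ordering without any manual column-width bookkeeping.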
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Part 3: Baseline Visualization\n",
- "\n",
- "Plot the heuristic baseline results to show the environment differentiates skill levels."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
- "agent_names = list(BASELINE_AGENTS.keys())\n",
- "colors = ['#E53935', '#FF9800', '#9E9E9E', '#42A5F5', '#4CAF50', '#2E7D32']\n",
- "\n",
- "for i, task in enumerate(TASKS):\n",
- " scores = [baseline_results[a][task][\"grader_score\"] for a in agent_names]\n",
- " bars = axes[i].barh(agent_names, scores, color=colors)\n",
- " axes[i].set_title(task.replace(\"monthly_\", \"\").title(), fontsize=13, fontweight='bold')\n",
- " axes[i].set_xlim(0, max(max(scores) * 1.15, 0.01))\n",
- " for bar, score in zip(bars, scores):\n",
- " axes[i].text(bar.get_width() + 0.005, bar.get_y() + bar.get_height()/2,\n",
- " f\"{score:.3f}\", va='center', fontsize=9)\n",
- "\n",
- "axes[0].set_ylabel(\"Agent\")\n",
- "fig.suptitle(\"Viraltest v2 — Heuristic Baseline Leaderboard\", fontsize=14, fontweight='bold')\n",
- "fig.tight_layout()\n",
- "fig.savefig(PLOTS_DIR / \"baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
- "plt.show()\n",
- "print(f\"Saved {PLOTS_DIR / 'baseline_leaderboard.png'}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
- "\n",
- "for i, task in enumerate(TASKS):\n",
- " for j, agent_name in enumerate(agent_names):\n",
- " result = baseline_results[agent_name][task]\n",
- " axes[0, i].plot(result[\"rewards\"], label=agent_name, color=colors[j], alpha=0.8)\n",
- " axes[1, i].plot(result[\"energies\"], label=agent_name, color=colors[j], alpha=0.8)\n",
- "\n",
- " axes[0, i].set_title(f\"{task.replace('monthly_', '').title()} — Rewards\", fontsize=11)\n",
- " axes[0, i].set_xlabel(\"Day\")\n",
- " axes[0, i].set_ylabel(\"Reward\")\n",
- " axes[0, i].grid(True, alpha=0.3)\n",
- "\n",
- " axes[1, i].set_title(f\"{task.replace('monthly_', '').title()} — Energy\", fontsize=11)\n",
- " axes[1, i].set_xlabel(\"Day\")\n",
- " axes[1, i].set_ylabel(\"Energy\")\n",
- " axes[1, i].grid(True, alpha=0.3)\n",
- "\n",
- "axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)\n",
- "fig.suptitle(\"Viraltest v2 — Daily Rewards & Energy by Agent\", fontsize=14, fontweight='bold', y=1.01)\n",
- "fig.tight_layout()\n",
- "fig.savefig(PLOTS_DIR / \"baseline_trajectories.png\", dpi=150, bbox_inches='tight')\n",
- "plt.show()\n",
- "print(f\"Saved {PLOTS_DIR / 'baseline_trajectories.png'}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Part 4: LLM Evaluation — Untrained Baseline\n",
- "\n",
- "We run the base Qwen2.5-1.5B-Instruct model (no fine-tuning) against the environment\n",
- "using the same prompt format as `inference.py`. This gives us the **before** scores.\n",
- "\n",
- "### Option A: Via HTTP (if you have a running env server + model API)\n",
- "Set `ENV_BASE_URL` and `API_BASE_URL` environment variables.\n",
- "\n",
- "### Option B: Direct in-process (no server needed)\n",
- "We load the model locally and run the environment directly. This is what we do below."
- ]
- },
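For Option A, the two endpoints come from environment variables. A small hypothetical helper shows the intended lookup with local fallbacks; the variable names `ENV_BASE_URL` and `API_BASE_URL` are from the cell above, while the default URLs and ports are my assumption, not from the notebook:

```python
import os

def resolve_endpoints():
    """Read the env server and model API base URLs from the environment.

    Falls back to localhost defaults when the variables are unset
    (the default ports here are illustrative placeholders).
    """
    env_url = os.environ.get("ENV_BASE_URL", "http://localhost:8000")
    api_url = os.environ.get("API_BASE_URL", "http://localhost:8001/v1")
    return env_url, api_url

# Example: only ENV_BASE_URL is set, so the API URL keeps its default.
os.environ["ENV_BASE_URL"] = "http://my-env-server:9000"
os.environ.pop("API_BASE_URL", None)
print(resolve_endpoints())
```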
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import textwrap\n",
- "import torch\n",
- "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
- "\n",
- "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
- "\n",
- "print(f\"Loading {MODEL_NAME}...\")\n",
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
- "model = AutoModelForCausalLM.from_pretrained(\n",
- " MODEL_NAME,\n",
- " trust_remote_code=True,\n",
- " torch_dtype=torch.float16,\n",
- " device_map=\"auto\",\n",
- ")\n",
- "model.eval()\n",
- "print(f\"Model loaded on {model.device}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
- "You are an Instagram content strategy agent. Each step is one full day (24 hours).\n",
- "You manage a creator account over a 30-day monthly cycle.\n",
- "\n",
- "You receive a SPARSE observation (energy, followers, last reward, notes echo).\n",
- "To learn about the world, you MUST use TOOLS before planning your day.\n",
- "\n",
- "AVAILABLE TOOLS (call via tool_calls before scheduling posts):\n",
- "- query_trends(niche): Get trending topics and tags for a niche\n",
- "- query_competitor(competitor_id, window_days): See competitor activity\n",
- "- query_tag_history(tag): Check your past performance with a tag\n",
- "- query_audience(segment_id): Learn audience segment preferences\n",
- "- predict_engagement(scheduled_actions): Simulate engagement without committing\n",
- "- draft_review(scheduled_actions): Get feedback on a draft plan\n",
- "\n",
- "RESPONSE FORMAT (JSON only, no markdown, no prose):\n",
- "{\n",
- " \"tool_calls\": [\n",
- " {\"name\": \"query_trends\", \"arguments\": {\"niche\": \"tech\"}}\n",
- " ],\n",
- " \"scheduled_actions\": [\n",
- " {\"hour\": 12, \"action_type\": \"post\", \"content_type\": \"reel\", \"topic\": \"AI tools\", \"tags\": [\"ai\", \"coding\"], \"intent\": \"watch_bait\"},\n",
- " {\"hour\": 19, \"action_type\": \"post\", \"content_type\": \"carousel\", \"topic\": \"startup life\", \"tags\": [\"startup\"], \"intent\": \"save_bait\"}\n",
- " ],\n",
- " \"replies\": [{\"post_hour\": 12, \"reply_hour\": 13}],\n",
- " \"notes\": \"Day 3: tech niche trending up.\"\n",
- "}\n",
- "\n",
- "RULES:\n",
- "- hour: 0-23. content_type: reel|story|carousel|text_post. intent: send_bait|save_bait|watch_bait|like_bait\n",
- "- 1-2 posts per day is optimal. More causes audience fatigue.\n",
- "- Empty scheduled_actions = rest all day (recovers energy)\n",
- "- Use notes to track hypotheses across days\n",
- "- Tool calls cost API budget (starts at 100). Use wisely.\n",
- "- Reply within 90 minutes of a post for reach bonus\"\"\")\n",
- "\n",
- "\n",
- "def format_obs_for_prompt(obs) -> str:\n",
- " \"\"\"Format environment observation into a prompt string.\"\"\"\n",
- " days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
- " day_name = days[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
- " notes_echo = getattr(obs, \"agent_notes\", None) or \"none\"\n",
- " budget = getattr(obs, \"api_budget_remaining\", 100)\n",
- " burnout = getattr(obs, \"burnout_risk\", 0.0)\n",
- "\n",
- " tool_results_str = \"\"\n",
- " for tr in getattr(obs, \"tool_results\", []):\n",
- " if tr.success:\n",
- " tool_results_str += f\" {tr.name}: {json.dumps(tr.data)[:200]}\\n\"\n",
- " else:\n",
- " tool_results_str += f\" {tr.name}: ERROR - {tr.error}\\n\"\n",
- "\n",
- " coach = getattr(obs, \"coach_feedback\", None)\n",
- " coach_str = \"\"\n",
- " if coach:\n",
- " coach_str = f\"Coach: delta={coach.get('delta', 0):.3f}, suggestion={coach.get('suggestion', '')}\\n\"\n",
- "\n",
- " signals = getattr(obs, \"engagement_signals\", None)\n",
- " signals_str = \"\"\n",
- " if signals:\n",
- " signals_str = (\n",
- " f\"Signals: watch={signals.watch_time:.3f} sends={signals.sends_per_reach:.3f} \"\n",
- " f\"saves={signals.saves:.3f} likes={signals.likes_per_reach:.3f}\\n\"\n",
- " )\n",
- "\n",
- " return textwrap.dedent(f\"\"\"\\\n",
- "Day: {day_name} (day_of_week={obs.day_of_week}) | days_elapsed={obs.days_elapsed}\n",
- "Energy: {obs.creator_energy:.2f} | Burnout risk: {burnout:.2f} | Followers: {obs.follower_count}\n",
- "Engagement rate: {obs.engagement_rate:.3f} | Content queue: {obs.content_queue_size}\n",
- "API budget remaining: {budget}\n",
- "{signals_str}{coach_str}Tool results from last step:\n",
- "{tool_results_str if tool_results_str else ' (none)\\n'}Your notes from last step: {notes_echo}\n",
- "Plan your tool calls and actions for today:\"\"\")\n",
- "\n",
- "\n",
- "def parse_model_output(text: str) -> ViraltestAction:\n",
- " \"\"\"Parse model JSON output into a ViraltestAction.\"\"\"\n",
- " text = text.strip()\n",
- " if text.startswith(\"```\"):\n",
- " lines = text.split(\"\\n\")\n",
- " lines = [l for l in lines if not l.strip().startswith(\"```\")]\n",
- " text = \"\\n\".join(lines).strip()\n",
- "\n",
- " try:\n",
- " data = json.loads(text)\n",
- " tool_calls = []\n",
- " for tc in data.get(\"tool_calls\", []):\n",
- " if isinstance(tc, dict) and \"name\" in tc:\n",
- " tool_calls.append(ToolCall(name=tc[\"name\"], arguments=tc.get(\"arguments\", {})))\n",
- "\n",
- " scheduled = []\n",
- " for a in data.get(\"scheduled_actions\", []):\n",
- " if isinstance(a, dict):\n",
- " try:\n",
- " scheduled.append(ScheduledAction(**a))\n",
- " except Exception:\n",
- " pass\n",
- "\n",
- " return ViraltestAction(\n",
- " tool_calls=tool_calls,\n",
- " scheduled_actions=scheduled,\n",
- " replies=data.get(\"replies\", []),\n",
- " notes=data.get(\"notes\"),\n",
- " )\n",
- " except (json.JSONDecodeError, Exception):\n",
- " return ViraltestAction(scheduled_actions=[])\n",
- "\n",
- "\n",
- "def generate_action(model, tokenizer, obs, history: List[dict], temperature=0.7, max_new_tokens=512) -> Tuple[str, ViraltestAction]:\n",
- " \"\"\"Generate an action from the model given an observation.\"\"\"\n",
- " user_prompt = format_obs_for_prompt(obs)\n",
- " messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
- " messages.extend(history[-4:])\n",
- " messages.append({\"role\": \"user\", \"content\": user_prompt})\n",
- "\n",
- " text_input = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
- " inputs = tokenizer(text_input, return_tensors=\"pt\").to(model.device)\n",
- "\n",
- " with torch.no_grad():\n",
- " output_ids = model.generate(\n",
- " **inputs,\n",
- " max_new_tokens=max_new_tokens,\n",
- " temperature=temperature,\n",
- " do_sample=True,\n",
- " top_p=0.9,\n",
- " pad_token_id=tokenizer.eos_token_id,\n",
- " )\n",
- "\n",
- " new_tokens = output_ids[0][inputs[\"input_ids\"].shape[1]:]\n",
- " response = tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
- " action = parse_model_output(response)\n",
- " return response, action\n",
- "\n",
- "print(\"LLM agent functions defined.\")"
- ]
- },
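The fence-stripping step in the removed `parse_model_output` is worth seeing in isolation, since chat models often wrap JSON in markdown code fences despite a "JSON only" instruction. A standalone sketch of just that step, using the same approach as the cell above but without the pydantic action models:

```python
import json

def extract_json(text: str) -> dict:
    """Strip markdown code fences from model output, then parse as JSON.

    Falls back to an empty plan (rest day) on malformed output, mirroring
    the behaviour of parse_model_output above.
    """
    text = text.strip()
    if text.startswith("```"):
        # Drop every fence line ("```" or "```json"), keep the body.
        lines = [l for l in text.split("\n") if not l.strip().startswith("```")]
        text = "\n".join(lines).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return {"scheduled_actions": []}

fenced = '```json\n{"scheduled_actions": [], "notes": "rest"}\n```'
print(extract_json(fenced))
print(extract_json("not valid json at all"))
```

Keeping this fallback total (never raising) matters in the episode loop: one malformed generation then costs a rest day rather than crashing a 30-day rollout.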
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def run_llm_episode(model, tokenizer, task: str, seed: int = 42, verbose: bool = False) -> Dict[str, Any]:\n",
- " \"\"\"Run one full episode using the LLM agent.\"\"\"\n",
- " env = ViraltestEnvironment()\n",
- " obs = env.reset(task=task, seed=seed)\n",
- "\n",
- " rewards = []\n",
- " energies = [obs.creator_energy]\n",
- " history = []\n",
- " prompts_and_responses = []\n",
- "\n",
- " for day in range(1, TASK_HORIZON + 1):\n",
- " if obs.done:\n",
- " break\n",
- "\n",
- " if obs.creator_energy <= 0.25:\n",
- " action = ViraltestAction(scheduled_actions=[], notes=\"Low energy — forced rest.\")\n",
- " response_text = '{\"scheduled_actions\": [], \"notes\": \"Low energy — rest.\"}'\n",
- " else:\n",
- " response_text, action = generate_action(model, tokenizer, obs, history)\n",
- "\n",
- " prompt_text = format_obs_for_prompt(obs)\n",
- " prompts_and_responses.append({\n",
- " \"prompt\": prompt_text,\n",
- " \"response\": response_text,\n",
- " })\n",
- "\n",
- " obs = env.step(action)\n",
- " r = obs.reward if obs.reward is not None else 0.0\n",
- " rewards.append(r)\n",
- " energies.append(obs.creator_energy)\n",
- "\n",
- " history.append({\"role\": \"user\", \"content\": prompt_text})\n",
- " history.append({\"role\": \"assistant\", \"content\": response_text})\n",
- "\n",
- " if verbose:\n",
- " n_posts = len([sa for sa in action.scheduled_actions if sa.action_type == \"post\"])\n",
- " n_tools = len(action.tool_calls)\n",
- " print(f\" Day {day:2d}: reward={r:.4f} energy={obs.creator_energy:.2f} \"\n",
- " f\"posts={n_posts} tools={n_tools}\")\n",
- "\n",
- " if obs.done:\n",
- " break\n",
- "\n",
- " grader_score = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
- "\n",
- " return {\n",
- " \"task\": task,\n",
- " \"steps\": len(rewards),\n",
- " \"total_reward\": sum(rewards),\n",
- " \"avg_reward\": sum(rewards) / len(rewards) if rewards else 0,\n",
- " \"grader_score\": grader_score,\n",
- " \"final_energy\": obs.creator_energy,\n",
- " \"min_energy\": min(energies),\n",
- " \"final_followers\": obs.follower_count,\n",
- " \"follower_delta\": obs.follower_count - 10000,\n",
- " \"burned_out\": obs.creator_energy <= 0,\n",
- " \"rewards\": rewards,\n",
- " \"energies\": energies,\n",
- " \"prompts_and_responses\": prompts_and_responses,\n",
- " }\n",
- "\n",
- "print(\"LLM episode runner defined.\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(\"Running UNTRAINED base model...\")\n",
- "print(\"=\" * 60)\n",
- "\n",
- "before_results = {}\n",
- "for task in TASKS:\n",
- " print(f\"\\nTask: {task}\")\n",
- " result = run_llm_episode(model, tokenizer, task, seed=42, verbose=True)\n",
- " before_results[task] = result\n",
- " print(f\" => grader_score={result['grader_score']:.4f}, \"\n",
- " f\"total_reward={result['total_reward']:.3f}, \"\n",
- " f\"burned_out={result['burned_out']}\")\n",
- "\n",
- "print(\"\\n\" + \"=\" * 60)\n",
- "print(\"BEFORE TRAINING SCORES\")\n",
- "print(\"=\" * 60)\n",
- "for task in TASKS:\n",
- " r = before_results[task]\n",
-
" print(f\" {task}: grader={r['grader_score']:.4f} reward={r['total_reward']:.3f} energy={r['final_energy']:.2f}\")"
|
| 623 |
-
]
|
| 624 |
-
},
|
| 625 |
-
{
|
| 626 |
-
"cell_type": "markdown",
|
| 627 |
-
"metadata": {},
|
| 628 |
-
"source": [
|
| 629 |
-
"## Part 5: GRPO Training\n",
|
| 630 |
-
"\n",
|
| 631 |
-
"We use TRL's GRPO trainer to optimize the model on environment rewards.\n",
|
| 632 |
-
"\n",
|
| 633 |
-
"**Approach:** For each training step, we collect a batch of episodes, score them with the environment reward, and use GRPO to reinforce high-reward responses relative to the group.\n",
|
| 634 |
-
"\n",
|
| 635 |
-
"Since full multi-step GRPO with TRL requires careful integration, we use a **reward-weighted SFT** approach that achieves similar results:\n",
|
| 636 |
-
"1. Collect N episodes with the current model\n",
|
| 637 |
-
"2. Weight each (prompt, response) pair by its environment reward\n",
|
| 638 |
-
"3. Fine-tune on the reward-weighted dataset\n",
|
| 639 |
-
"4. Repeat for multiple rounds"
|
| 640 |
-
]
|
| 641 |
-
},
|
| 642 |
-
{
|
| 643 |
-
"cell_type": "code",
|
| 644 |
-
"execution_count": null,
|
| 645 |
-
"metadata": {},
|
| 646 |
-
"outputs": [],
|
| 647 |
-
"source": [
|
| 648 |
-
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
| 649 |
-
"from transformers import TrainingArguments\n",
|
| 650 |
-
"from trl import SFTTrainer, SFTConfig\n",
|
| 651 |
-
"from datasets import Dataset\n",
|
| 652 |
-
"\n",
|
| 653 |
-
"lora_config = LoraConfig(\n",
|
| 654 |
-
" r=16,\n",
|
| 655 |
-
" lora_alpha=32,\n",
|
| 656 |
-
" lora_dropout=0.05,\n",
|
| 657 |
-
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 658 |
-
" task_type=TaskType.CAUSAL_LM,\n",
|
| 659 |
-
" bias=\"none\",\n",
|
| 660 |
-
")\n",
|
| 661 |
-
"\n",
|
| 662 |
-
"model.enable_input_require_grads()\n",
|
| 663 |
-
"peft_model = get_peft_model(model, lora_config)\n",
|
| 664 |
-
"peft_model.print_trainable_parameters()\n",
|
| 665 |
-
"print(\"LoRA adapter attached.\")"
|
| 666 |
-
]
|
| 667 |
-
},
|
| 668 |
-
{
|
| 669 |
-
"cell_type": "code",
|
| 670 |
-
"execution_count": null,
|
| 671 |
-
"metadata": {},
|
| 672 |
-
"outputs": [],
|
| 673 |
-
"source": [
|
| 674 |
-
"def collect_training_data(\n",
|
| 675 |
-
" model, tokenizer, n_episodes: int = 8, tasks: List[str] = None\n",
|
| 676 |
-
") -> Tuple[List[Dict], List[float]]:\n",
|
| 677 |
-
" \"\"\"Collect episodes and build reward-weighted training pairs.\"\"\"\n",
|
| 678 |
-
" tasks = tasks or TASKS\n",
|
| 679 |
-
" all_pairs = []\n",
|
| 680 |
-
" all_episode_rewards = []\n",
|
| 681 |
-
"\n",
|
| 682 |
-
" for ep in range(n_episodes):\n",
|
| 683 |
-
" task = tasks[ep % len(tasks)]\n",
|
| 684 |
-
" seed = 42 + ep\n",
|
| 685 |
-
" result = run_llm_episode(model, tokenizer, task, seed=seed)\n",
|
| 686 |
-
" episode_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
|
| 687 |
-
" all_episode_rewards.append(episode_reward)\n",
|
| 688 |
-
"\n",
|
| 689 |
-
" for pr in result[\"prompts_and_responses\"]:\n",
|
| 690 |
-
" step_text = (\n",
|
| 691 |
-
" f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
|
| 692 |
-
" f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
|
| 693 |
-
" f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\"\n",
|
| 694 |
-
" )\n",
|
| 695 |
-
" all_pairs.append({\n",
|
| 696 |
-
" \"text\": step_text,\n",
|
| 697 |
-
" \"reward\": episode_reward,\n",
|
| 698 |
-
" })\n",
|
| 699 |
-
"\n",
|
| 700 |
-
" return all_pairs, all_episode_rewards\n",
|
| 701 |
-
"\n",
|
| 702 |
-
"print(\"Data collection function defined.\")"
|
| 703 |
-
]
|
| 704 |
-
},
|
| 705 |
-
{
|
| 706 |
-
"cell_type": "code",
|
| 707 |
-
"execution_count": null,
|
| 708 |
-
"metadata": {},
|
| 709 |
-
"outputs": [],
|
| 710 |
-
"source": [
|
| 711 |
-
"NUM_ROUNDS = 4\n",
|
| 712 |
-
"EPISODES_PER_ROUND = 6\n",
|
| 713 |
-
"TOP_K_FRACTION = 0.5\n",
|
| 714 |
-
"\n",
|
| 715 |
-
"training_log = {\n",
|
| 716 |
-
" \"round\": [],\n",
|
| 717 |
-
" \"avg_episode_reward\": [],\n",
|
| 718 |
-
" \"max_episode_reward\": [],\n",
|
| 719 |
-
" \"min_episode_reward\": [],\n",
|
| 720 |
-
" \"n_training_samples\": [],\n",
|
| 721 |
-
" \"train_loss\": [],\n",
|
| 722 |
-
"}\n",
|
| 723 |
-
"\n",
|
| 724 |
-
"for round_idx in range(1, NUM_ROUNDS + 1):\n",
|
| 725 |
-
" print(f\"\\n{'=' * 60}\")\n",
|
| 726 |
-
" print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n",
|
| 727 |
-
" print(f\"{'=' * 60}\")\n",
|
| 728 |
-
"\n",
|
| 729 |
-
" print(f\"Collecting {EPISODES_PER_ROUND} episodes...\")\n",
|
| 730 |
-
" peft_model.eval()\n",
|
| 731 |
-
" pairs, episode_rewards = collect_training_data(\n",
|
| 732 |
-
" peft_model, tokenizer, n_episodes=EPISODES_PER_ROUND\n",
|
| 733 |
-
" )\n",
|
| 734 |
-
" avg_reward = sum(episode_rewards) / len(episode_rewards)\n",
|
| 735 |
-
" print(f\" Episode rewards: {[f'{r:.3f}' for r in episode_rewards]}\")\n",
|
| 736 |
-
" print(f\" Avg: {avg_reward:.3f}, Max: {max(episode_rewards):.3f}, Min: {min(episode_rewards):.3f}\")\n",
|
| 737 |
-
"\n",
|
| 738 |
-
" if not pairs:\n",
|
| 739 |
-
" print(\" No training pairs collected, skipping round.\")\n",
|
| 740 |
-
" continue\n",
|
| 741 |
-
"\n",
|
| 742 |
-
" reward_threshold = np.percentile(\n",
|
| 743 |
-
" [p[\"reward\"] for p in pairs],\n",
|
| 744 |
-
" (1 - TOP_K_FRACTION) * 100\n",
|
| 745 |
-
" )\n",
|
| 746 |
-
" filtered = [p for p in pairs if p[\"reward\"] >= reward_threshold]\n",
|
| 747 |
-
" print(f\" Filtered to {len(filtered)}/{len(pairs)} samples (reward >= {reward_threshold:.3f})\")\n",
|
| 748 |
-
"\n",
|
| 749 |
-
" if not filtered:\n",
|
| 750 |
-
" print(\" No samples above threshold, using all.\")\n",
|
| 751 |
-
" filtered = pairs\n",
|
| 752 |
-
"\n",
|
| 753 |
-
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 754 |
-
"\n",
|
| 755 |
-
" output_dir = f\"./viraltest_checkpoints/round_{round_idx}\"\n",
|
| 756 |
-
" sft_config = SFTConfig(\n",
|
| 757 |
-
" output_dir=output_dir,\n",
|
| 758 |
-
" num_train_epochs=2,\n",
|
| 759 |
-
" per_device_train_batch_size=1,\n",
|
| 760 |
-
" gradient_accumulation_steps=4,\n",
|
| 761 |
-
" learning_rate=2e-5,\n",
|
| 762 |
-
" warmup_steps=5,\n",
|
| 763 |
-
" logging_steps=5,\n",
|
| 764 |
-
" save_strategy=\"no\",\n",
|
| 765 |
-
" max_seq_length=1024,\n",
|
| 766 |
-
" fp16=True,\n",
|
| 767 |
-
" report_to=\"none\",\n",
|
| 768 |
-
" )\n",
|
| 769 |
-
"\n",
|
| 770 |
-
" print(f\" Training on {len(dataset)} samples...\")\n",
|
| 771 |
-
" peft_model.train()\n",
|
| 772 |
-
" trainer = SFTTrainer(\n",
|
| 773 |
-
" model=peft_model,\n",
|
| 774 |
-
" tokenizer=tokenizer,\n",
|
| 775 |
-
" train_dataset=dataset,\n",
|
| 776 |
-
" args=sft_config,\n",
|
| 777 |
-
" )\n",
|
| 778 |
-
" train_result = trainer.train()\n",
|
| 779 |
-
" train_loss = train_result.training_loss\n",
|
| 780 |
-
" print(f\" Training loss: {train_loss:.4f}\")\n",
|
| 781 |
-
"\n",
|
| 782 |
-
" training_log[\"round\"].append(round_idx)\n",
|
| 783 |
-
" training_log[\"avg_episode_reward\"].append(avg_reward)\n",
|
| 784 |
-
" training_log[\"max_episode_reward\"].append(max(episode_rewards))\n",
|
| 785 |
-
" training_log[\"min_episode_reward\"].append(min(episode_rewards))\n",
|
| 786 |
-
" training_log[\"n_training_samples\"].append(len(filtered))\n",
|
| 787 |
-
" training_log[\"train_loss\"].append(train_loss)\n",
|
| 788 |
-
"\n",
|
| 789 |
-
"print(\"\\n\" + \"=\" * 60)\n",
|
| 790 |
-
"print(\"TRAINING COMPLETE\")\n",
|
| 791 |
-
"print(\"=\" * 60)\n",
|
| 792 |
-
"\n",
|
| 793 |
-
"train_df = pd.DataFrame(training_log)\n",
|
| 794 |
-
"print(train_df.to_string(index=False))\n",
|
| 795 |
-
"\n",
|
| 796 |
-
"train_df.to_csv(PLOTS_DIR / \"training_log.csv\", index=False)\n",
|
| 797 |
-
"print(f\"\\nSaved training log to {PLOTS_DIR / 'training_log.csv'}\")"
|
| 798 |
-
]
|
| 799 |
-
},
|
| 800 |
-
{
|
| 801 |
-
"cell_type": "markdown",
|
| 802 |
-
"metadata": {},
|
| 803 |
-
"source": [
|
| 804 |
-
"## Part 6: Post-Training Evaluation\n",
|
| 805 |
-
"\n",
|
| 806 |
-
"Run the trained model on all three tasks and compare with before-training scores."
|
| 807 |
-
]
|
| 808 |
-
},
|
| 809 |
-
{
|
| 810 |
-
"cell_type": "code",
|
| 811 |
-
"execution_count": null,
|
| 812 |
-
"metadata": {},
|
| 813 |
-
"outputs": [],
|
| 814 |
-
"source": [
|
| 815 |
-
"print(\"Running TRAINED model...\")\n",
|
| 816 |
-
"print(\"=\" * 60)\n",
|
| 817 |
-
"\n",
|
| 818 |
-
"peft_model.eval()\n",
|
| 819 |
-
"\n",
|
| 820 |
-
"after_results = {}\n",
|
| 821 |
-
"for task in TASKS:\n",
|
| 822 |
-
" print(f\"\\nTask: {task}\")\n",
|
| 823 |
-
" result = run_llm_episode(peft_model, tokenizer, task, seed=42, verbose=True)\n",
|
| 824 |
-
" after_results[task] = result\n",
|
| 825 |
-
" print(f\" => grader_score={result['grader_score']:.4f}, \"\n",
|
| 826 |
-
" f\"total_reward={result['total_reward']:.3f}, \"\n",
|
| 827 |
-
" f\"burned_out={result['burned_out']}\")\n",
|
| 828 |
-
"\n",
|
| 829 |
-
"print(\"\\n\" + \"=\" * 60)\n",
|
| 830 |
-
"print(\"AFTER TRAINING SCORES\")\n",
|
| 831 |
-
"print(\"=\" * 60)\n",
|
| 832 |
-
"for task in TASKS:\n",
|
| 833 |
-
" r = after_results[task]\n",
|
| 834 |
-
" print(f\" {task}: grader={r['grader_score']:.4f} reward={r['total_reward']:.3f} energy={r['final_energy']:.2f}\")"
|
| 835 |
-
]
|
| 836 |
-
},
|
| 837 |
-
{
|
| 838 |
-
"cell_type": "markdown",
|
| 839 |
-
"metadata": {},
|
| 840 |
-
"source": [
|
| 841 |
-
"## Part 7: Result Plots — Real Training Evidence"
|
| 842 |
-
]
|
| 843 |
-
},
|
| 844 |
-
{
|
| 845 |
-
"cell_type": "code",
|
| 846 |
-
"execution_count": null,
|
| 847 |
-
"metadata": {},
|
| 848 |
-
"outputs": [],
|
| 849 |
-
"source": [
|
| 850 |
-
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 851 |
-
"\n",
|
| 852 |
-
"rounds = training_log[\"round\"]\n",
|
| 853 |
-
"axes[0].plot(rounds, training_log[\"avg_episode_reward\"], 'o-', color='#2196F3', linewidth=2, label='Avg reward')\n",
|
| 854 |
-
"axes[0].fill_between(rounds, training_log[\"min_episode_reward\"], training_log[\"max_episode_reward\"],\n",
|
| 855 |
-
" alpha=0.2, color='#2196F3', label='Min-Max range')\n",
|
| 856 |
-
"axes[0].set_xlabel('Training Round', fontsize=12)\n",
|
| 857 |
-
"axes[0].set_ylabel('Episode Reward', fontsize=12)\n",
|
| 858 |
-
"axes[0].set_title('Training Reward Over Rounds', fontsize=13, fontweight='bold')\n",
|
| 859 |
-
"axes[0].legend()\n",
|
| 860 |
-
"axes[0].grid(True, alpha=0.3)\n",
|
| 861 |
-
"\n",
|
| 862 |
-
"axes[1].plot(rounds, training_log[\"train_loss\"], 's-', color='#E53935', linewidth=2)\n",
|
| 863 |
-
"axes[1].set_xlabel('Training Round', fontsize=12)\n",
|
| 864 |
-
"axes[1].set_ylabel('Training Loss', fontsize=12)\n",
|
| 865 |
-
"axes[1].set_title('Training Loss Over Rounds', fontsize=13, fontweight='bold')\n",
|
| 866 |
-
"axes[1].grid(True, alpha=0.3)\n",
|
| 867 |
-
"\n",
|
| 868 |
-
"fig.suptitle('Viraltest v2 — GRPO Training Progress', fontsize=14, fontweight='bold', y=1.02)\n",
|
| 869 |
-
"fig.tight_layout()\n",
|
| 870 |
-
"fig.savefig(PLOTS_DIR / 'reward_curve.png', dpi=150, bbox_inches='tight')\n",
|
| 871 |
-
"plt.show()\n",
|
| 872 |
-
"print(f\"Saved {PLOTS_DIR / 'reward_curve.png'}\")"
|
| 873 |
-
]
|
| 874 |
-
},
|
| 875 |
-
{
|
| 876 |
-
"cell_type": "code",
|
| 877 |
-
"execution_count": null,
|
| 878 |
-
"metadata": {},
|
| 879 |
-
"outputs": [],
|
| 880 |
-
"source": [
|
| 881 |
-
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
|
| 882 |
-
"before_scores = [before_results[t][\"grader_score\"] for t in TASKS]\n",
|
| 883 |
-
"after_scores = [after_results[t][\"grader_score\"] for t in TASKS]\n",
|
| 884 |
-
"smart_scores = [baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS]\n",
|
| 885 |
-
"\n",
|
| 886 |
-
"x = np.arange(len(TASKS))\n",
|
| 887 |
-
"width = 0.25\n",
|
| 888 |
-
"\n",
|
| 889 |
-
"fig, ax = plt.subplots(figsize=(10, 6))\n",
|
| 890 |
-
"bars1 = ax.bar(x - width, before_scores, width, label='Base Model (Before)', color='#FF9800')\n",
|
| 891 |
-
"bars2 = ax.bar(x, after_scores, width, label='Trained Model (After)', color='#4CAF50')\n",
|
| 892 |
-
"bars3 = ax.bar(x + width, smart_scores, width, label='Smart Heuristic', color='#9E9E9E', alpha=0.7)\n",
|
| 893 |
-
"\n",
|
| 894 |
-
"ax.set_ylabel('Grader Score', fontsize=12)\n",
|
| 895 |
-
"ax.set_title('Before vs After Training — Grader Scores', fontsize=14, fontweight='bold')\n",
|
| 896 |
-
"ax.set_xticks(x)\n",
|
| 897 |
-
"ax.set_xticklabels(task_labels, fontsize=11)\n",
|
| 898 |
-
"ax.legend(fontsize=10)\n",
|
| 899 |
-
"ax.grid(True, alpha=0.3, axis='y')\n",
|
| 900 |
-
"\n",
|
| 901 |
-
"for bars in [bars1, bars2, bars3]:\n",
|
| 902 |
-
" for bar in bars:\n",
|
| 903 |
-
" height = bar.get_height()\n",
|
| 904 |
-
" if height > 0:\n",
|
| 905 |
-
" ax.text(bar.get_x() + bar.get_width()/2., height + 0.005,\n",
|
| 906 |
-
" f'{height:.3f}', ha='center', va='bottom', fontsize=9)\n",
|
| 907 |
-
"\n",
|
| 908 |
-
"fig.tight_layout()\n",
|
| 909 |
-
"fig.savefig(PLOTS_DIR / 'before_after.png', dpi=150, bbox_inches='tight')\n",
|
| 910 |
-
"plt.show()\n",
|
| 911 |
-
"print(f\"Saved {PLOTS_DIR / 'before_after.png'}\")"
|
| 912 |
-
]
|
| 913 |
-
},
|
| 914 |
-
{
|
| 915 |
-
"cell_type": "code",
|
| 916 |
-
"execution_count": null,
|
| 917 |
-
"metadata": {},
|
| 918 |
-
"outputs": [],
|
| 919 |
-
"source": [
|
| 920 |
-
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
|
| 921 |
-
"\n",
|
| 922 |
-
"labels_and_data = [\n",
|
| 923 |
-
" (\"Base Model\", before_results, '#FF9800'),\n",
|
| 924 |
-
" (\"Trained Model\", after_results, '#4CAF50'),\n",
|
| 925 |
-
"]\n",
|
| 926 |
-
"\n",
|
| 927 |
-
"for i, task in enumerate(TASKS):\n",
|
| 928 |
-
" for label, results, color in labels_and_data:\n",
|
| 929 |
-
" r = results[task]\n",
|
| 930 |
-
" axes[0, i].plot(r[\"rewards\"], label=label, color=color, linewidth=1.5, alpha=0.9)\n",
|
| 931 |
-
" axes[1, i].plot(r[\"energies\"], label=label, color=color, linewidth=1.5, alpha=0.9)\n",
|
| 932 |
-
"\n",
|
| 933 |
-
" smart_r = baseline_results[\"smart\"][task]\n",
|
| 934 |
-
" axes[0, i].plot(smart_r[\"rewards\"], label=\"Smart Heuristic\", color='#9E9E9E',\n",
|
| 935 |
-
" linewidth=1, alpha=0.5, linestyle='--')\n",
|
| 936 |
-
" axes[1, i].plot(smart_r[\"energies\"], label=\"Smart Heuristic\", color='#9E9E9E',\n",
|
| 937 |
-
" linewidth=1, alpha=0.5, linestyle='--')\n",
|
| 938 |
-
"\n",
|
| 939 |
-
" task_title = task.replace('monthly_', '').title()\n",
|
| 940 |
-
" axes[0, i].set_title(f\"{task_title} — Daily Rewards\", fontsize=11)\n",
|
| 941 |
-
" axes[0, i].set_xlabel(\"Day\")\n",
|
| 942 |
-
" axes[0, i].set_ylabel(\"Reward\")\n",
|
| 943 |
-
" axes[0, i].grid(True, alpha=0.3)\n",
|
| 944 |
-
"\n",
|
| 945 |
-
" axes[1, i].set_title(f\"{task_title} — Energy\", fontsize=11)\n",
|
| 946 |
-
" axes[1, i].set_xlabel(\"Day\")\n",
|
| 947 |
-
" axes[1, i].set_ylabel(\"Energy\")\n",
|
| 948 |
-
" axes[1, i].grid(True, alpha=0.3)\n",
|
| 949 |
-
"\n",
|
| 950 |
-
"axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)\n",
|
| 951 |
-
"fig.suptitle('Viraltest v2 — Before vs After Training Trajectories', fontsize=14, fontweight='bold', y=1.01)\n",
|
| 952 |
-
"fig.tight_layout()\n",
|
| 953 |
-
"fig.savefig(PLOTS_DIR / 'training_trajectories.png', dpi=150, bbox_inches='tight')\n",
|
| 954 |
-
"plt.show()\n",
|
| 955 |
-
"print(f\"Saved {PLOTS_DIR / 'training_trajectories.png'}\")"
|
| 956 |
-
]
|
| 957 |
-
},
|
| 958 |
-
{
|
| 959 |
-
"cell_type": "markdown",
|
| 960 |
-
"metadata": {},
|
| 961 |
-
"source": [
|
| 962 |
-
"## Part 8: Summary & Export"
|
| 963 |
-
]
|
| 964 |
-
},
|
| 965 |
-
{
|
| 966 |
-
"cell_type": "code",
|
| 967 |
-
"execution_count": null,
|
| 968 |
-
"metadata": {},
|
| 969 |
-
"outputs": [],
|
| 970 |
-
"source": [
|
| 971 |
-
"print(\"=\" * 70)\n",
|
| 972 |
-
"print(\"FINAL RESULTS SUMMARY\")\n",
|
| 973 |
-
"print(\"=\" * 70)\n",
|
| 974 |
-
"print()\n",
|
| 975 |
-
"print(f\"{'Task':<25s} {'Before':>10s} {'After':>10s} {'Delta':>10s} {'Smart':>10s}\")\n",
|
| 976 |
-
"print(\"-\" * 67)\n",
|
| 977 |
-
"for task in TASKS:\n",
|
| 978 |
-
" b = before_results[task][\"grader_score\"]\n",
|
| 979 |
-
" a = after_results[task][\"grader_score\"]\n",
|
| 980 |
-
" s = baseline_results[\"smart\"][task][\"grader_score\"]\n",
|
| 981 |
-
" delta = a - b\n",
|
| 982 |
-
" print(f\"{task:<25s} {b:>10.4f} {a:>10.4f} {delta:>+10.4f} {s:>10.4f}\")\n",
|
| 983 |
-
"\n",
|
| 984 |
-
"avg_before = np.mean([before_results[t][\"grader_score\"] for t in TASKS])\n",
|
| 985 |
-
"avg_after = np.mean([after_results[t][\"grader_score\"] for t in TASKS])\n",
|
| 986 |
-
"avg_smart = np.mean([baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS])\n",
|
| 987 |
-
"print(\"-\" * 67)\n",
|
| 988 |
-
"print(f\"{'AVERAGE':<25s} {avg_before:>10.4f} {avg_after:>10.4f} {avg_after - avg_before:>+10.4f} {avg_smart:>10.4f}\")\n",
|
| 989 |
-
"print()\n",
|
| 990 |
-
"\n",
|
| 991 |
-
"summary = {\n",
|
| 992 |
-
" \"model\": MODEL_NAME,\n",
|
| 993 |
-
" \"training_rounds\": NUM_ROUNDS,\n",
|
| 994 |
-
" \"episodes_per_round\": EPISODES_PER_ROUND,\n",
|
| 995 |
-
" \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 996 |
-
" \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 997 |
-
" \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n",
|
| 998 |
-
" \"improvement\": {t: after_results[t][\"grader_score\"] - before_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 999 |
-
" \"training_log\": training_log,\n",
|
| 1000 |
-
"}\n",
|
| 1001 |
-
"\n",
|
| 1002 |
-
"with open(PLOTS_DIR / \"training_summary.json\", \"w\") as f:\n",
|
| 1003 |
-
" json.dump(summary, f, indent=2)\n",
|
| 1004 |
-
"\n",
|
| 1005 |
-
"print(f\"Saved summary to {PLOTS_DIR / 'training_summary.json'}\")\n",
|
| 1006 |
-
"print()\n",
|
| 1007 |
-
"print(\"Plots saved:\")\n",
|
| 1008 |
-
"for p in sorted(PLOTS_DIR.glob(\"*.png\")):\n",
|
| 1009 |
-
" print(f\" {p}\")\n",
|
| 1010 |
-
"print()\n",
|
| 1011 |
-
"print(\"Training evidence is now real and reproducible.\")"
|
| 1012 |
-
]
|
| 1013 |
-
},
|
| 1014 |
-
{
|
| 1015 |
-
"cell_type": "code",
|
| 1016 |
-
"execution_count": null,
|
| 1017 |
-
"metadata": {},
|
| 1018 |
-
"outputs": [],
|
| 1019 |
-
"source": [
|
| 1020 |
-
"save_path = \"./viraltest_trained_adapter\"\n",
|
| 1021 |
-
"peft_model.save_pretrained(save_path)\n",
|
| 1022 |
-
"tokenizer.save_pretrained(save_path)\n",
|
| 1023 |
-
"print(f\"Trained adapter saved to {save_path}\")\n",
|
| 1024 |
-
"print(\"To load: model = AutoModelForCausalLM.from_pretrained(...); model = PeftModel.from_pretrained(model, save_path)\")"
|
| 1025 |
-
]
|
| 1026 |
-
}
|
| 1027 |
-
],
|
| 1028 |
-
"metadata": {
|
| 1029 |
-
"kernelspec": {
|
| 1030 |
-
"display_name": "Python 3",
|
| 1031 |
-
"language": "python",
|
| 1032 |
-
"name": "python3"
|
| 1033 |
-
},
|
| 1034 |
-
"language_info": {
|
| 1035 |
-
"name": "python",
|
| 1036 |
-
"version": "3.10.0"
|
| 1037 |
-
}
|
| 1038 |
-
},
|
| 1039 |
-
"nbformat": 4,
|
| 1040 |
-
"nbformat_minor": 4
|
| 1041 |
-
}
|
| 1 |
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Viraltest v2 — Real LLM Training with LoRA + Environment Rewards\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook **actually trains** an LLM (Qwen2.5-1.5B-Instruct) to play our Instagram creator simulation.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**Pipeline:**\n",
|
| 12 |
+
"1. Clone repo & install deps\n",
|
| 13 |
+
"2. Run 5 heuristic baselines × 3 tasks (15 runs) → leaderboard\n",
|
| 14 |
+
"3. Run **untrained** LLM on all 3 tasks → \"before\" scores\n",
|
| 15 |
+
"4. **LoRA fine-tune** with reward-weighted SFT (4 rounds × 6 episodes = real weight updates)\n",
|
| 16 |
+
"5. Run **trained** LLM on all 3 tasks → \"after\" scores\n",
|
| 17 |
+
"6. Generate real plots from real numbers\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"**Requirements:** Colab T4 GPU (free tier), ~45 min total.\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"**What makes this real training:** LoRA adapter weights are actually updated via gradient descent. The model's behavior changes because its weights change, not because we edit the prompt."
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"source": [
|
| 28 |
+
"# Cell 1: Install dependencies\n",
|
| 29 |
+
"!pip install -q torch torchvision torchaudio\n",
|
| 30 |
+
"!pip install -q transformers>=4.40.0 accelerate peft>=0.10.0 trl>=0.8.0 datasets bitsandbytes\n",
|
| 31 |
+
"!pip install -q matplotlib pandas\n",
|
| 32 |
+
"!pip install -q pydantic httpx\n",
|
| 33 |
+
"!pip install -q \"openenv-core[core]>=0.2.2\""
|
| 34 |
+
],
|
| 35 |
+
"execution_count": null,
|
| 36 |
+
"outputs": []
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"source": [
|
| 42 |
+
"# Cell 2: Clone the repo and set up paths\n",
|
| 43 |
+
"import os, sys\n",
|
| 44 |
+
"REPO_DIR = \"/content/viral-posts-env\"\n",
|
| 45 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 46 |
+
" !git clone https://github.com/VaibhavKhandare/viral-posts-env.git {REPO_DIR}\n",
|
| 47 |
+
"os.chdir(REPO_DIR)\n",
|
| 48 |
+
"sys.path.insert(0, REPO_DIR)\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"PLOTS_DIR = os.path.join(REPO_DIR, \"plots\")\n",
|
| 51 |
+
"os.makedirs(PLOTS_DIR, exist_ok=True)\n",
|
| 52 |
+
"print(f\"Working dir: {os.getcwd()}\")\n",
|
| 53 |
+
"print(f\"Plots dir: {PLOTS_DIR}\")"
|
| 54 |
+
],
|
| 55 |
+
"execution_count": null,
|
| 56 |
+
"outputs": []
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"source": [
|
| 62 |
+
"# Cell 3: Imports\n",
|
| 63 |
+
"import json, random, time, textwrap, copy\n",
|
| 64 |
+
"from pathlib import Path\n",
|
| 65 |
+
"from typing import Any, Dict, List, Optional, Tuple\n",
|
| 66 |
+
"from collections import defaultdict\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"import matplotlib.pyplot as plt\n",
|
| 69 |
+
"import numpy as np\n",
|
| 70 |
+
"import pandas as pd\n",
|
| 71 |
+
"import torch\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"from models import ScheduledAction, ToolCall, ViraltestAction\n",
|
| 74 |
+
"from server.viraltest_environment import (\n",
|
| 75 |
+
" ViraltestEnvironment, TAG_POOL, TASK_HORIZON,\n",
|
| 76 |
+
" TOPIC_CATEGORIES,\n",
|
| 77 |
+
")\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
|
| 80 |
+
"NICHES = list(TOPIC_CATEGORIES.keys())\n",
|
| 81 |
+
"CONTENT_TYPES = [\"reel\", \"carousel\", \"story\", \"text_post\"]\n",
|
| 82 |
+
"INTENTS = [\"send_bait\", \"save_bait\", \"watch_bait\", \"like_bait\"]\n",
|
| 83 |
+
"TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
|
| 86 |
+
"print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")"
|
| 87 |
+
],
|
| 88 |
+
"execution_count": null,
|
| 89 |
+
"outputs": []
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "markdown",
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"source": [
|
| 95 |
+
"## Part 1: Heuristic Baselines\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"5 scripted agents prove the environment differentiates skill levels."
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"source": [
|
| 104 |
+
"# Cell 4: Define heuristic agents + episode runner\n",
|
| 105 |
+
"_rng = random.Random(42)\n",
|
| 106 |
+
"\n",
|
| 107 |
+
"def plan_always_rest(obs_dict, day):\n",
|
| 108 |
+
" return ViraltestAction(scheduled_actions=[])\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"def plan_spam(obs_dict, day):\n",
|
| 111 |
+
" return ViraltestAction(scheduled_actions=[\n",
|
| 112 |
+
" ScheduledAction(hour=h, action_type=\"post\", content_type=\"reel\",\n",
|
| 113 |
+
" topic=\"AI tools\", tags=[\"ai\"], intent=\"watch_bait\")\n",
|
| 114 |
+
" for h in range(24)])\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"def plan_random(obs_dict, day):\n",
|
| 117 |
+
" actions = []\n",
|
| 118 |
+
" for h in range(24):\n",
|
| 119 |
+
" if _rng.random() < 0.1:\n",
|
| 120 |
+
" actions.append(ScheduledAction(\n",
|
| 121 |
+
" hour=h, action_type=\"post\",\n",
|
| 122 |
+
" content_type=_rng.choice(CONTENT_TYPES),\n",
|
| 123 |
+
" topic=_rng.choice(ALL_TOPICS),\n",
|
| 124 |
+
" tags=_rng.sample(TAG_POOL[:30], 3),\n",
|
| 125 |
+
" intent=_rng.choice(INTENTS)))\n",
|
| 126 |
+
" return ViraltestAction(scheduled_actions=actions)\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"def plan_minimal(obs_dict, day):\n",
|
| 129 |
+
" return ViraltestAction(scheduled_actions=[\n",
|
| 130 |
+
" ScheduledAction(hour=12, action_type=\"post\", content_type=\"carousel\",\n",
|
| 131 |
+
" topic=ALL_TOPICS[day % len(ALL_TOPICS)],\n",
|
| 132 |
+
" tags=[TAG_POOL[i % len(TAG_POOL)] for i in range(day, day+3)],\n",
|
| 133 |
+
" intent=\"save_bait\")])\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"def plan_smart(obs_dict, day):\n",
|
| 136 |
+
" return ViraltestAction(\n",
|
| 137 |
+
" tool_calls=[ToolCall(name=\"query_trends\",\n",
|
| 138 |
+
" arguments={\"niche\": NICHES[day % len(NICHES)]})] if day <= 3 else [],\n",
|
| 139 |
+
" scheduled_actions=[\n",
|
| 140 |
+
" ScheduledAction(hour=8, action_type=\"create_content\"),\n",
|
| 141 |
+
" ScheduledAction(hour=12, action_type=\"post\",\n",
|
| 142 |
+
" content_type=CONTENT_TYPES[(day*2)%4],\n",
|
| 143 |
+
" topic=ALL_TOPICS[(day*2)%len(ALL_TOPICS)],\n",
|
| 144 |
+
" tags=[TAG_POOL[(day*6+i)%len(TAG_POOL)] for i in range(3)],\n",
|
| 145 |
+
" intent=INTENTS[(day*2)%4]),\n",
|
| 146 |
+
" ScheduledAction(hour=19, action_type=\"post\",\n",
|
| 147 |
+
" content_type=CONTENT_TYPES[(day*2+1)%4],\n",
|
| 148 |
+
" topic=ALL_TOPICS[(day*2+1)%len(ALL_TOPICS)],\n",
|
| 149 |
+
" tags=[TAG_POOL[(day*6+3+i)%len(TAG_POOL)] for i in range(3)],\n",
|
| 150 |
+
" intent=INTENTS[(day*2+1)%4]),\n",
|
| 151 |
+
" ],\n",
|
| 152 |
+
" replies=[{\"post_hour\": 12, \"reply_hour\": 13}])\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"BASELINE_AGENTS = {\n",
|
| 155 |
+
" \"always_rest\": plan_always_rest, \"spam\": plan_spam,\n",
|
| 156 |
+
" \"random\": plan_random, \"minimal\": plan_minimal, \"smart\": plan_smart,\n",
|
| 157 |
+
"}\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"def run_episode(task, plan_fn, seed=42):\n",
|
| 160 |
+
" env = ViraltestEnvironment()\n",
|
| 161 |
+
" obs = env.reset(task=task, seed=seed)\n",
|
| 162 |
+
" obs_dict = obs.model_dump()\n",
|
| 163 |
+
" rewards, energies = [], [obs.creator_energy]\n",
|
| 164 |
+
" for day in range(1, TASK_HORIZON + 1):\n",
|
| 165 |
+
" action = plan_fn(obs_dict, day)\n",
|
| 166 |
+
" obs = env.step(action)\n",
|
| 167 |
+
" obs_dict = obs.model_dump()\n",
|
| 168 |
+
" rewards.append(obs.reward or 0.0)\n",
|
| 169 |
+
" energies.append(obs.creator_energy)\n",
|
| 170 |
+
" if obs.done: break\n",
|
| 171 |
+
" grader = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
|
| 172 |
+
" return {\"grader_score\": grader, \"total_reward\": sum(rewards),\n",
|
| 173 |
+
" \"steps\": len(rewards), \"final_energy\": obs.creator_energy,\n",
|
| 174 |
+
" \"follower_delta\": obs.follower_count - 10000,\n",
|
| 175 |
+
" \"burned_out\": obs.creator_energy <= 0,\n",
|
| 176 |
+
" \"rewards\": rewards, \"energies\": energies}\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"print(\"Agents and episode runner defined.\")"
|
| 179 |
+
],
|
| 180 |
+
"execution_count": null,
|
| 181 |
+
"outputs": []
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"cell_type": "code",
|
| 185 |
+
"metadata": {},
|
| 186 |
+
"source": [
|
| 187 |
+
"# Cell 5: Run baselines\n",
|
| 188 |
+
"print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
|
| 189 |
+
"print(\"=\" * 70)\n",
|
| 190 |
+
"\n",
|
| 191 |
+
"baseline_results = {}\n",
|
| 192 |
+
"for name, fn in BASELINE_AGENTS.items():\n",
|
| 193 |
+
" baseline_results[name] = {}\n",
|
| 194 |
+
" for task in TASKS:\n",
|
| 195 |
+
" _rng = random.Random(42)\n",
|
| 196 |
+
" result = run_episode(task, fn, seed=42)\n",
|
| 197 |
+
" baseline_results[name][task] = result\n",
|
| 198 |
+
" print(f\" {name:>12s} | {task:>22s} | score={result['grader_score']:.4f} \"\n",
|
| 199 |
+
" f\"| energy={result['final_energy']:.2f}\")\n",
|
| 200 |
+
" print()\n",
|
| 201 |
+
"\n",
|
| 202 |
+
"print(\"\\nLEADERBOARD\")\n",
|
| 203 |
+
"print(f\"{'Agent':<14s} {'Engage':>10s} {'Strategic':>12s} {'Competitive':>14s} {'Avg':>8s}\")\n",
|
| 204 |
+
"print(\"-\" * 60)\n",
|
| 205 |
+
"for name in BASELINE_AGENTS:\n",
|
| 206 |
+
" scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
|
| 207 |
+
" print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
|
| 208 |
+
],
|
| 209 |
+
"execution_count": null,
|
| 210 |
+
"outputs": []
|
| 211 |
+
},
|
| 212 |
+
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 6: Baseline plots\n",
"fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
"agent_names = list(BASELINE_AGENTS.keys())\n",
"colors = ['#E53935', '#FF9800', '#9E9E9E', '#42A5F5', '#4CAF50']\n",
"for i, task in enumerate(TASKS):\n",
"    scores = [baseline_results[a][task][\"grader_score\"] for a in agent_names]\n",
"    bars = axes[i].barh(agent_names, scores, color=colors)\n",
"    axes[i].set_title(task.replace(\"monthly_\", \"\").title(), fontsize=13, fontweight='bold')\n",
"    for bar, score in zip(bars, scores):\n",
"        axes[i].text(bar.get_width() + 0.005, bar.get_y() + bar.get_height()/2,\n",
"                     f\"{score:.4f}\", va='center', fontsize=9)\n",
"axes[0].set_ylabel(\"Agent\")\n",
"fig.suptitle(\"Viraltest v2 — Heuristic Baseline Leaderboard\", fontsize=14, fontweight='bold')\n",
"fig.tight_layout()\n",
"fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 2: Load LLM (Qwen2.5-1.5B-Instruct)\n",
"\n",
"We load the base model with 4-bit quantization to fit in free Colab's T4 GPU (16GB VRAM)."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 7: Load model\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
"\n",
"MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
"\n",
"bnb_config = BitsAndBytesConfig(\n",
"    load_in_4bit=True,\n",
"    bnb_4bit_quant_type=\"nf4\",\n",
"    bnb_4bit_compute_dtype=torch.float16,\n",
"    bnb_4bit_use_double_quant=True,\n",
")\n",
"\n",
"print(f\"Loading {MODEL_NAME} (4-bit quantized)...\")\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    MODEL_NAME, trust_remote_code=True,\n",
"    quantization_config=bnb_config,\n",
"    device_map=\"auto\",\n",
")\n",
"model.eval()\n",
"print(f\"Model loaded. Device: {model.device}\")\n",
"print(f\"Memory: {torch.cuda.memory_allocated()/1e9:.1f} GB\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 8: LLM agent functions\n",
"SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
"You are an Instagram content strategy agent. Each step is one day.\n",
"You manage a creator account over a 30-day cycle.\n",
"\n",
"RESPONSE FORMAT — return ONLY valid JSON, no markdown:\n",
"{\n",
"  \"tool_calls\": [{\"name\": \"query_trends\", \"arguments\": {\"niche\": \"tech\"}}],\n",
"  \"scheduled_actions\": [\n",
"    {\"hour\": 12, \"action_type\": \"post\", \"content_type\": \"reel\",\n",
"     \"topic\": \"AI tools\", \"tags\": [\"ai\", \"coding\"], \"intent\": \"watch_bait\"}\n",
"  ],\n",
"  \"replies\": [{\"post_hour\": 12, \"reply_hour\": 13}],\n",
"  \"notes\": \"strategy notes\"\n",
"}\n",
"\n",
"RULES:\n",
"- content_type: reel|story|carousel|text_post\n",
"- intent: send_bait|save_bait|watch_bait|like_bait\n",
"- 1-2 posts/day optimal. More = fatigue.\n",
"- Empty scheduled_actions = rest (recovers energy).\n",
"- Vary content types and topics for diversity bonus.\n",
"- Reply within 90 min of post for reach bonus.\"\"\")\n",
"\n",
"\n",
"def format_obs(obs):\n",
"    days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
"    day_name = days[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
"    signals_str = \"\"\n",
"    signals = getattr(obs, \"engagement_signals\", None)\n",
"    if signals:\n",
"        signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
"                       f\"sends={signals.sends_per_reach:.3f} \"\n",
"                       f\"saves={signals.saves:.3f}\\n\")\n",
"    tool_str = \"\"\n",
"    for tr in getattr(obs, \"tool_results\", []):\n",
"        if tr.success:\n",
"            tool_str += f\" {tr.name}: {json.dumps(tr.data)[:200]}\\n\"\n",
"    # Precompute the fallback: a backslash inside an f-string expression\n",
"    # is a SyntaxError on Python < 3.12.\n",
"    tool_block = tool_str if tool_str else \" (none)\\n\"\n",
"    return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
"            f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
"            f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
"            f\"{signals_str}\"\n",
"            f\"Tool results:\\n{tool_block}\"\n",
"            f\"Plan your actions (JSON only):\")\n",
"\n",
"\n",
"def parse_model_output(text):\n",
"    text = text.strip()\n",
"    if \"```\" in text:\n",
"        lines = [l for l in text.split(\"\\n\") if not l.strip().startswith(\"```\")]\n",
"        text = \"\\n\".join(lines).strip()\n",
"    start, end = text.find(\"{\"), text.rfind(\"}\") + 1\n",
"    if start >= 0 and end > start:\n",
"        text = text[start:end]\n",
"    try:\n",
"        data = json.loads(text)\n",
"        tool_calls = [ToolCall(name=tc[\"name\"], arguments=tc.get(\"arguments\", {}))\n",
"                      for tc in data.get(\"tool_calls\", []) if isinstance(tc, dict) and \"name\" in tc]\n",
"        scheduled = []\n",
"        for a in data.get(\"scheduled_actions\", []):\n",
"            try:\n",
"                scheduled.append(ScheduledAction(**a))\n",
"            except Exception:\n",
"                pass\n",
"        return ViraltestAction(tool_calls=tool_calls, scheduled_actions=scheduled,\n",
"                               replies=data.get(\"replies\", []), notes=data.get(\"notes\"))\n",
"    except Exception:\n",
"        return ViraltestAction(scheduled_actions=[])\n",
"\n",
"\n",
"def generate_action(mdl, tok, obs, history, temperature=0.7):\n",
"    prompt = format_obs(obs)\n",
"    messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
"    messages.extend(history[-4:])\n",
"    messages.append({\"role\": \"user\", \"content\": prompt})\n",
"    text_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"    inputs = tok(text_input, return_tensors=\"pt\").to(mdl.device)\n",
"    with torch.no_grad():\n",
"        out = mdl.generate(**inputs, max_new_tokens=512, temperature=temperature,\n",
"                           do_sample=True, top_p=0.9, pad_token_id=tok.eos_token_id)\n",
"    resp = tok.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
"    return resp, parse_model_output(resp)\n",
"\n",
"\n",
"def run_llm_episode(mdl, tok, task, seed=42, verbose=False):\n",
"    env = ViraltestEnvironment()\n",
"    obs = env.reset(task=task, seed=seed)\n",
"    rewards, energies = [], [obs.creator_energy]\n",
"    history, pairs = [], []\n",
"    for day in range(1, TASK_HORIZON + 1):\n",
"        if obs.done: break\n",
"        if obs.creator_energy <= 0.25:\n",
"            action = ViraltestAction(scheduled_actions=[])\n",
"            resp = '{\"scheduled_actions\": []}'\n",
"        else:\n",
"            resp, action = generate_action(mdl, tok, obs, history)\n",
"        prompt = format_obs(obs)\n",
"        pairs.append({\"prompt\": prompt, \"response\": resp})\n",
"        obs = env.step(action)\n",
"        r = obs.reward or 0.0\n",
"        rewards.append(r)\n",
"        energies.append(obs.creator_energy)\n",
"        history.extend([{\"role\": \"user\", \"content\": prompt},\n",
"                        {\"role\": \"assistant\", \"content\": resp}])\n",
"        if verbose:\n",
"            n_p = len([s for s in action.scheduled_actions if s.action_type == \"post\"])\n",
"            print(f\" Day {day:2d}: r={r:.4f} e={obs.creator_energy:.2f} posts={n_p} tools={len(action.tool_calls)}\")\n",
"        if obs.done: break\n",
"    gs = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
"    return {\"task\": task, \"grader_score\": gs, \"total_reward\": sum(rewards),\n",
"            \"final_energy\": obs.creator_energy, \"rewards\": rewards,\n",
"            \"energies\": energies, \"pairs\": pairs,\n",
"            \"follower_delta\": obs.follower_count - 10000,\n",
"            \"burned_out\": obs.creator_energy <= 0}\n",
"\n",
"print(\"LLM agent functions defined.\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 3: Untrained LLM Baseline (“Before”)\n",
"\n",
"Run the base model with NO fine-tuning. This establishes the untrained baseline."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 9: Run untrained model\n",
"print(\"Running UNTRAINED base model on all tasks...\")\n",
"print(\"=\" * 60)\n",
"\n",
"before_results = {}\n",
"for task in TASKS:\n",
"    print(f\"\\n Task: {task}\")\n",
"    result = run_llm_episode(model, tokenizer, task, seed=42, verbose=True)\n",
"    before_results[task] = result\n",
"    print(f\" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f}\")\n",
"\n",
"print(\"\\n\" + \"=\" * 60)\n",
"print(\"BEFORE TRAINING:\")\n",
"for t in TASKS:\n",
"    print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 4: LoRA Fine-Tuning (Real Weight Updates)\n",
"\n",
"This is the core training loop. For each round:\n",
"1. Collect episodes with current model\n",
"2. Score each (prompt, response) pair by episode reward\n",
"3. Keep top 50% highest-reward samples\n",
"4. Fine-tune LoRA weights via SFT on those samples\n",
"\n",
"The model's actual weights change via gradient descent — this is real training."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 10: Attach LoRA adapter\n",
"from peft import LoraConfig, get_peft_model, TaskType\n",
"\n",
"lora_config = LoraConfig(\n",
"    r=16, lora_alpha=32, lora_dropout=0.05,\n",
"    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
"                    \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
"    task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
")\n",
"\n",
"model.enable_input_require_grads()\n",
"peft_model = get_peft_model(model, lora_config)\n",
"peft_model.print_trainable_parameters()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 11: Training loop\n",
"from trl import SFTTrainer, SFTConfig\n",
"from datasets import Dataset\n",
"\n",
"NUM_ROUNDS = 4\n",
"EPISODES_PER_ROUND = 6\n",
"TOP_K_FRACTION = 0.5\n",
"\n",
"training_log = {\n",
"    \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n",
"    \"min_episode_reward\": [], \"avg_grader\": [], \"max_grader\": [],\n",
"    \"n_training_samples\": [], \"train_loss\": [],\n",
"}\n",
"\n",
"t_start = time.time()\n",
"\n",
"for round_idx in range(1, NUM_ROUNDS + 1):\n",
"    print(f\"\\n{'=' * 60}\")\n",
"    print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n",
"    print(f\"{'=' * 60}\")\n",
"\n",
"    # Collect episodes\n",
"    peft_model.eval()\n",
"    all_pairs, episode_rewards, episode_graders = [], [], []\n",
"\n",
"    for ep in range(EPISODES_PER_ROUND):\n",
"        task = TASKS[ep % len(TASKS)]\n",
"        seed = 42 + (round_idx - 1) * 100 + ep\n",
"        result = run_llm_episode(peft_model, tokenizer, task, seed=seed)\n",
"        ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
"        episode_rewards.append(ep_reward)\n",
"        episode_graders.append(result[\"grader_score\"])\n",
"\n",
"        for pr in result[\"pairs\"]:\n",
"            text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
"                    f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
"                    f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
"            all_pairs.append({\"text\": text, \"reward\": ep_reward})\n",
"\n",
"        print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {task.split('_')[-1]:>11s} \"\n",
"              f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f}\")\n",
"\n",
"    avg_r = np.mean(episode_rewards)\n",
"    avg_g = np.mean(episode_graders)\n",
"    print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f}\")\n",
"\n",
"    # Filter to top-K\n",
"    threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
"    filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
"    print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples\")\n",
"\n",
"    dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
"\n",
"    # SFT training (real gradient updates)\n",
"    sft_config = SFTConfig(\n",
"        output_dir=f\"./checkpoints/round_{round_idx}\",\n",
"        num_train_epochs=2,\n",
"        per_device_train_batch_size=1,\n",
"        gradient_accumulation_steps=4,\n",
"        learning_rate=2e-5,\n",
"        warmup_steps=5,\n",
"        logging_steps=5,\n",
"        save_strategy=\"no\",\n",
"        max_seq_length=1024,\n",
"        fp16=True,\n",
"        report_to=\"none\",\n",
"    )\n",
"\n",
"    peft_model.train()\n",
"    trainer = SFTTrainer(\n",
"        model=peft_model, tokenizer=tokenizer,\n",
"        train_dataset=dataset, args=sft_config,\n",
"    )\n",
"    train_result = trainer.train()\n",
"    loss = train_result.training_loss\n",
"    print(f\" Training loss: {loss:.4f}\")\n",
"\n",
"    training_log[\"round\"].append(round_idx)\n",
"    training_log[\"avg_episode_reward\"].append(round(float(avg_r), 3))\n",
"    training_log[\"max_episode_reward\"].append(round(float(max(episode_rewards)), 3))\n",
"    training_log[\"min_episode_reward\"].append(round(float(min(episode_rewards)), 3))\n",
"    training_log[\"avg_grader\"].append(round(float(avg_g), 4))\n",
"    training_log[\"max_grader\"].append(round(float(max(episode_graders)), 4))\n",
"    training_log[\"n_training_samples\"].append(len(filtered))\n",
"    training_log[\"train_loss\"].append(round(loss, 4))\n",
"\n",
"elapsed = time.time() - t_start\n",
"print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
"print(pd.DataFrame(training_log).to_string(index=False))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 5: Trained LLM Evaluation (“After”)\n",
"\n",
"Same model, same seeds, same environment — but now with updated LoRA weights."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 12: Run trained model\n",
"print(\"Running TRAINED model on all tasks...\")\n",
"print(\"=\" * 60)\n",
"\n",
"peft_model.eval()\n",
"after_results = {}\n",
"for task in TASKS:\n",
"    print(f\"\\n Task: {task}\")\n",
"    result = run_llm_episode(peft_model, tokenizer, task, seed=42, verbose=True)\n",
"    after_results[task] = result\n",
"    print(f\" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f}\")\n",
"\n",
"print(\"\\n\" + \"=\" * 60)\n",
"print(\"AFTER TRAINING:\")\n",
"for t in TASKS:\n",
"    print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 6: Result Plots — Real Training Evidence"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 13: Training curves\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"rounds = training_log[\"round\"]\n",
"\n",
"axes[0].plot(rounds, training_log[\"avg_grader\"], 'o-', color='#2196F3', lw=2, label='Avg grader')\n",
"axes[0].fill_between(rounds, training_log[\"avg_grader\"],\n",
"                     training_log[\"max_grader\"], alpha=0.2, color='#2196F3')\n",
"axes[0].set_xlabel('Round'); axes[0].set_ylabel('Grader Score')\n",
"axes[0].set_title('Grader Score Over Rounds', fontweight='bold')\n",
"axes[0].legend(); axes[0].grid(True, alpha=0.3)\n",
"\n",
"axes[1].plot(rounds, training_log[\"train_loss\"], 's-', color='#E53935', lw=2)\n",
"axes[1].set_xlabel('Round'); axes[1].set_ylabel('Loss')\n",
"axes[1].set_title('Training Loss', fontweight='bold')\n",
"axes[1].grid(True, alpha=0.3)\n",
"\n",
"fig.suptitle('Viraltest v2 — LoRA Training Progress (Qwen 1.5B)', fontsize=14, fontweight='bold')\n",
"fig.tight_layout()\n",
"fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 14: Before vs After\n",
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
"x = np.arange(len(TASKS))\n",
"w = 0.25\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"b_scores = [before_results[t][\"grader_score\"] for t in TASKS]\n",
"a_scores = [after_results[t][\"grader_score\"] for t in TASKS]\n",
"s_scores = [baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS]\n",
"\n",
"ax.bar(x - w, b_scores, w, label='Base Model (Before)', color='#FF9800')\n",
"ax.bar(x, a_scores, w, label='LoRA Trained (After)', color='#4CAF50')\n",
"ax.bar(x + w, s_scores, w, label='Smart Heuristic', color='#9E9E9E', alpha=0.7)\n",
"\n",
"ax.set_ylabel('Grader Score'); ax.set_xticks(x); ax.set_xticklabels(task_labels)\n",
"ax.set_title('Before vs After LoRA Training — Grader Scores', fontsize=14, fontweight='bold')\n",
"ax.legend(); ax.grid(True, alpha=0.3, axis='y')\n",
"\n",
"for container in ax.containers:\n",
"    for bar in container:\n",
"        h = bar.get_height()\n",
"        if h > 0:\n",
"            ax.text(bar.get_x() + bar.get_width()/2., h + 0.005,\n",
"                    f'{h:.4f}', ha='center', va='bottom', fontsize=9)\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 15: Trajectory comparison\n",
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
"comparisons = [\n",
"    (\"Base Model\", before_results, '#FF9800', '--'),\n",
"    (\"LoRA Trained\", after_results, '#4CAF50', '-'),\n",
"]\n",
"for i, task in enumerate(TASKS):\n",
"    for label, res, color, ls in comparisons:\n",
"        lw = 2.5 if 'Trained' in label else 1.5\n",
"        axes[0, i].plot(res[task][\"rewards\"], label=label, color=color, lw=lw, ls=ls)\n",
"        axes[1, i].plot(res[task][\"energies\"], label=label, color=color, lw=lw, ls=ls)\n",
"    sr = baseline_results[\"smart\"][task]\n",
"    axes[0, i].plot(sr[\"rewards\"], label=\"Smart\", color='#9E9E9E', lw=1, ls=':')\n",
"    axes[1, i].plot(sr[\"energies\"], label=\"Smart\", color='#9E9E9E', lw=1, ls=':')\n",
"    t_name = task.replace('monthly_', '').title()\n",
"    axes[0, i].set_title(f\"{t_name} — Rewards\"); axes[0, i].grid(True, alpha=0.3)\n",
"    axes[1, i].set_title(f\"{t_name} — Energy\"); axes[1, i].grid(True, alpha=0.3)\n",
"axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
"fig.suptitle('Before vs After — Daily Trajectories', fontsize=14, fontweight='bold', y=1.01)\n",
"fig.tight_layout()\n",
"fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 7: Summary & Export"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 16: Final summary\n",
"print(\"=\" * 67)\n",
"print(\"FINAL RESULTS\")\n",
"print(\"=\" * 67)\n",
"print(f\"\\n{'Task':<25s} {'Before':>10s} {'After':>10s} {'Delta':>10s} {'Smart':>10s}\")\n",
"print(\"-\" * 67)\n",
"for task in TASKS:\n",
"    b = before_results[task][\"grader_score\"]\n",
"    a = after_results[task][\"grader_score\"]\n",
"    s = baseline_results[\"smart\"][task][\"grader_score\"]\n",
"    print(f\"{task:<25s} {b:>10.4f} {a:>10.4f} {a-b:>+10.4f} {s:>10.4f}\")\n",
"\n",
"avg_b = np.mean([before_results[t][\"grader_score\"] for t in TASKS])\n",
"avg_a = np.mean([after_results[t][\"grader_score\"] for t in TASKS])\n",
"avg_s = np.mean([baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS])\n",
"print(\"-\" * 67)\n",
"print(f\"{'AVERAGE':<25s} {avg_b:>10.4f} {avg_a:>10.4f} {avg_a-avg_b:>+10.4f} {avg_s:>10.4f}\")\n",
"\n",
"summary = {\n",
"    \"model\": MODEL_NAME,\n",
"    \"training\": \"LoRA SFT (real weight updates)\",\n",
"    \"rounds\": NUM_ROUNDS, \"episodes_per_round\": EPISODES_PER_ROUND,\n",
"    \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n",
"    \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n",
"    \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n",
"    \"improvement\": {t: after_results[t][\"grader_score\"] - before_results[t][\"grader_score\"] for t in TASKS},\n",
"    \"training_log\": training_log,\n",
"}\n",
"with open(f\"{PLOTS_DIR}/training_summary.json\", \"w\") as f:\n",
"    json.dump(summary, f, indent=2)\n",
"\n",
"pd.DataFrame(training_log).to_csv(f\"{PLOTS_DIR}/training_log.csv\", index=False)\n",
"\n",
"print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
"print(\"All results are from real LoRA weight updates on real environment runs.\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 17: Save adapter\n",
"save_path = \"./viraltest_trained_adapter\"\n",
"peft_model.save_pretrained(save_path)\n",
"tokenizer.save_pretrained(save_path)\n",
"print(f\"LoRA adapter saved to {save_path}\")\n",
"print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
],
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 4
}