anuragredbus committed
Commit 4a29e22 · 1 Parent(s): e2c547b

fix: rewrite training notebook for real LoRA fine-tuning on Colab


- Add missing openenv-core dependency to install cell
- Self-contained: clones repo, installs all deps, runs end-to-end
- Real weight updates via LoRA + SFT (not prompt engineering); see the sketch after this list
- 4-bit quantization to fit free Colab T4 GPU
- Pipeline: baselines → untrained LLM → LoRA training → trained LLM → plots
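The round-based loop the notebook implements is easiest to see in miniature. Below is a minimal sketch of one reward-weighted SFT round, assuming the notebook's own `run_llm_episode` helper and an already-attached LoRA `peft_model`; the reward mix mirrors the notebook (per-step rewards plus 2× the terminal grader score), while the output path, batching, and text formatting here are illustrative rather than the notebook's exact values.

```python
# Sketch of one training round: roll out episodes, keep the top half of
# (prompt, response) pairs by episode reward, and fine-tune on them.
import numpy as np
from datasets import Dataset
from trl import SFTTrainer, SFTConfig

def training_round(peft_model, tokenizer, tasks, n_episodes=6, top_frac=0.5):
    pairs = []
    for ep in range(n_episodes):
        result = run_llm_episode(peft_model, tokenizer,
                                 tasks[ep % len(tasks)], seed=42 + ep)
        # Episode return = summed per-step rewards + 2x terminal grader score.
        ep_reward = result["total_reward"] + 2.0 * result["grader_score"]
        for pr in result["prompts_and_responses"]:
            # Illustrative concatenation; the notebook formats this pair
            # with the model's chat template instead.
            pairs.append({"text": pr["prompt"] + pr["response"],
                          "reward": ep_reward})

    # Keep only pairs at or above the reward percentile cutoff.
    cutoff = np.percentile([p["reward"] for p in pairs], (1 - top_frac) * 100)
    kept = [{"text": p["text"]} for p in pairs if p["reward"] >= cutoff]

    trainer = SFTTrainer(
        model=peft_model,
        # The "text" column is trl's default training field.
        train_dataset=Dataset.from_list(kept),
        args=SFTConfig(output_dir="./round_ckpt",  # illustrative path
                       num_train_epochs=2,
                       per_device_train_batch_size=1,
                       report_to="none"),
    )
    trainer.train()  # real gradient updates to the LoRA adapter weights
```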

Made-with: Cursor

Files changed (1)
  1. training/train_grpo.ipynb +774 -1039
training/train_grpo.ipynb CHANGED
@@ -1,1041 +1,776 @@
1
  {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# Viraltest v2 — GRPO Training on Qwen2.5-1.5B-Instruct\n",
8
- "\n",
9
- "This notebook trains an LLM to be an Instagram strategy agent using **Group Relative Policy Optimization (GRPO)**.\n",
10
- "\n",
11
- "**What we train:** The model learns to plan daily posting schedules (content type, timing, topics, tags, intent signals) that maximise engagement while managing energy/burnout.\n",
12
- "\n",
13
- "**Pipeline:**\n",
14
- "1. Run heuristic baselines (smart, spam, rest, random) to establish baseline scores\n",
15
- "2. Run the **untrained** base model and record scores\n",
16
- "3. Train with GRPO using environment rewards\n",
17
- "4. Run the **trained** model and compare\n",
18
- "5. Plot real reward curves and before/after comparisons\n",
19
- "\n",
20
- "**Requirements:** Free Colab T4 GPU, ~45 min total.\n",
21
- "\n",
22
- "**Reward:** per-step env reward (0-1) + 2× terminal `grader_score`."
23
- ]
24
  },
25
- {
26
- "cell_type": "code",
27
- "execution_count": null,
28
- "metadata": {},
29
- "outputs": [],
30
- "source": [
31
- "!pip install -q trl>=0.12.0 transformers accelerate peft bitsandbytes datasets\n",
32
- "!pip install -q openai httpx matplotlib pandas\n",
33
- "!pip install -q openenv-core[core]>=0.2.2"
34
- ]
35
- },
36
- {
37
- "cell_type": "code",
38
- "execution_count": null,
39
- "metadata": {},
40
- "outputs": [],
41
- "source": [
42
- "import json\n",
43
- "import os\n",
44
- "import time\n",
45
- "import random\n",
46
- "import copy\n",
47
- "from pathlib import Path\n",
48
- "from typing import Any, Dict, List, Optional, Tuple\n",
49
- "\n",
50
- "import matplotlib.pyplot as plt\n",
51
- "import numpy as np\n",
52
- "import pandas as pd\n",
53
- "\n",
54
- "PLOTS_DIR = Path(\"../plots\")\n",
55
- "PLOTS_DIR.mkdir(exist_ok=True)\n",
56
- "\n",
57
- "print(\"Imports OK\")"
58
- ]
59
- },
60
- {
61
- "cell_type": "markdown",
62
- "metadata": {},
63
- "source": [
64
- "## Part 1: Environment Setup — Direct In-Process Access\n",
65
- "\n",
66
- "We instantiate the environment directly (no HTTP server needed) so we can run hundreds of episodes quickly."
67
- ]
68
- },
69
- {
70
- "cell_type": "code",
71
- "execution_count": null,
72
- "metadata": {},
73
- "outputs": [],
74
- "source": [
75
- "import sys\n",
76
- "sys.path.insert(0, \"..\")\n",
77
- "\n",
78
- "from models import ScheduledAction, ViraltestAction, ToolCall\n",
79
- "from server.viraltest_environment import (\n",
80
- " ViraltestEnvironment,\n",
81
- " TAG_POOL,\n",
82
- " TOPIC_CATEGORIES,\n",
83
- " TASK_HORIZON,\n",
84
- ")\n",
85
- "\n",
86
- "ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
87
- "NICHES = list(TOPIC_CATEGORIES.keys())\n",
88
- "CONTENT_TYPES = [\"reel\", \"carousel\", \"story\", \"text_post\"]\n",
89
- "INTENTS = [\"send_bait\", \"save_bait\", \"watch_bait\", \"like_bait\"]\n",
90
- "TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n",
91
- "\n",
92
- "print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Niches: {len(NICHES)}\")\n",
93
- "print(f\"Tasks: {TASKS}\")\n",
94
- "print(f\"Horizon: {TASK_HORIZON} steps (days)\")"
95
- ]
96
- },
97
- {
98
- "cell_type": "markdown",
99
- "metadata": {},
100
- "source": [
101
- "## Part 2: Heuristic Baselines\n",
102
- "\n",
103
- "Before touching any LLM, we run scripted agents to establish a **baseline leaderboard**.\n",
104
- "This proves the environment can differentiate skill levels."
105
- ]
106
- },
107
- {
108
- "cell_type": "code",
109
- "execution_count": null,
110
- "metadata": {},
111
- "outputs": [],
112
- "source": [
113
- "_rng = random.Random(42)\n",
114
- "\n",
115
- "\n",
116
- "def plan_always_rest(obs_dict: dict, day: int) -> ViraltestAction:\n",
117
- " return ViraltestAction(scheduled_actions=[], notes=\"Rest day.\")\n",
118
- "\n",
119
- "\n",
120
- "def plan_spam(obs_dict: dict, day: int) -> ViraltestAction:\n",
121
- " actions = [\n",
122
- " {\"hour\": h, \"action_type\": \"post\", \"content_type\": \"reel\",\n",
123
- " \"topic\": \"AI tools\", \"tags\": [\"ai\"], \"intent\": \"watch_bait\"}\n",
124
- " for h in range(24)\n",
125
- " ]\n",
126
- " return ViraltestAction(scheduled_actions=[ScheduledAction(**a) for a in actions])\n",
127
- "\n",
128
- "\n",
129
- "def plan_random(obs_dict: dict, day: int) -> ViraltestAction:\n",
130
- " actions = []\n",
131
- " for h in range(24):\n",
132
- " if _rng.random() < 0.1:\n",
133
- " ct = _rng.choice(CONTENT_TYPES)\n",
134
- " topic = _rng.choice(ALL_TOPICS)\n",
135
- " tags = _rng.sample(TAG_POOL[:30], min(3, len(TAG_POOL)))\n",
136
- " intent = _rng.choice(INTENTS)\n",
137
- " actions.append({\"hour\": h, \"action_type\": \"post\", \"content_type\": ct,\n",
138
- " \"topic\": topic, \"tags\": tags, \"intent\": intent})\n",
139
- " return ViraltestAction(scheduled_actions=[ScheduledAction(**a) for a in actions])\n",
140
- "\n",
141
- "\n",
142
- "def plan_minimal(obs_dict: dict, day: int) -> ViraltestAction:\n",
143
- " topic = ALL_TOPICS[day % len(ALL_TOPICS)]\n",
144
- " tags = [TAG_POOL[i % len(TAG_POOL)] for i in range(day, day + 3)]\n",
145
- " actions = [\n",
146
- " {\"hour\": 12, \"action_type\": \"post\", \"content_type\": \"carousel\",\n",
147
- " \"topic\": topic, \"tags\": tags, \"intent\": \"save_bait\"},\n",
148
- " ]\n",
149
- " return ViraltestAction(scheduled_actions=[ScheduledAction(**a) for a in actions])\n",
150
- "\n",
151
- "\n",
152
- "def plan_smart(obs_dict: dict, day: int) -> ViraltestAction:\n",
153
- " \"\"\"Best heuristic: 2 posts at peak hours, varied content types and intents, tag rotation.\"\"\"\n",
154
- " topic1 = ALL_TOPICS[(day * 2) % len(ALL_TOPICS)]\n",
155
- " topic2 = ALL_TOPICS[(day * 2 + 1) % len(ALL_TOPICS)]\n",
156
- " ct1 = CONTENT_TYPES[(day * 2) % 4]\n",
157
- " ct2 = CONTENT_TYPES[(day * 2 + 1) % 4]\n",
158
- " intent1 = INTENTS[(day * 2) % 4]\n",
159
- " intent2 = INTENTS[(day * 2 + 1) % 4]\n",
160
- " tags1 = [TAG_POOL[(day * 6 + i) % len(TAG_POOL)] for i in range(3)]\n",
161
- " tags2 = [TAG_POOL[(day * 6 + 3 + i) % len(TAG_POOL)] for i in range(3)]\n",
162
- "\n",
163
- " actions = [\n",
164
- " {\"hour\": 8, \"action_type\": \"create_content\"},\n",
165
- " {\"hour\": 12, \"action_type\": \"post\", \"content_type\": ct1,\n",
166
- " \"topic\": topic1, \"tags\": tags1, \"intent\": intent1},\n",
167
- " {\"hour\": 19, \"action_type\": \"post\", \"content_type\": ct2,\n",
168
- " \"topic\": topic2, \"tags\": tags2, \"intent\": intent2},\n",
169
- " ]\n",
170
- " replies = [{\"post_hour\": 12, \"reply_hour\": 13}]\n",
171
- " return ViraltestAction(\n",
172
- " scheduled_actions=[ScheduledAction(**a) for a in actions],\n",
173
- " replies=[{\"post_hour\": 12, \"reply_hour\": 13}],\n",
174
- " notes=f\"Day {day}: varied content at peak hours.\",\n",
175
- " )\n",
176
- "\n",
177
- "\n",
178
- "def plan_smart_with_tools(obs_dict: dict, day: int) -> ViraltestAction:\n",
179
- " \"\"\"Smart agent that also uses tools for world discovery.\"\"\"\n",
180
- " tool_calls = []\n",
181
- " if day <= 3:\n",
182
- " tool_calls.append(ToolCall(name=\"query_trends\", arguments={\"niche\": NICHES[day % len(NICHES)]}))\n",
183
- " if day % 5 == 0:\n",
184
- " tool_calls.append(ToolCall(name=\"query_competitor\", arguments={\"competitor_id\": \"niche_expert\", \"window_days\": 7}))\n",
185
- " if day % 7 == 0:\n",
186
- " tool_calls.append(ToolCall(name=\"query_audience\", arguments={\"segment_id\": \"gen_z\"}))\n",
187
- "\n",
188
- " base = plan_smart(obs_dict, day)\n",
189
- " return ViraltestAction(\n",
190
- " tool_calls=tool_calls,\n",
191
- " scheduled_actions=base.scheduled_actions,\n",
192
- " replies=base.replies,\n",
193
- " notes=f\"Day {day}: tool-assisted planning.\",\n",
194
- " )\n",
195
- "\n",
196
- "\n",
197
- "BASELINE_AGENTS = {\n",
198
- " \"always_rest\": plan_always_rest,\n",
199
- " \"spam\": plan_spam,\n",
200
- " \"random\": plan_random,\n",
201
- " \"minimal\": plan_minimal,\n",
202
- " \"smart\": plan_smart,\n",
203
- " \"smart_with_tools\": plan_smart_with_tools,\n",
204
- "}"
205
- ]
206
- },
207
- {
208
- "cell_type": "code",
209
- "execution_count": null,
210
- "metadata": {},
211
- "outputs": [],
212
- "source": [
213
- "def run_episode(task: str, plan_fn, seed: int = 42) -> Dict[str, Any]:\n",
214
- " \"\"\"Run one full 30-day episode and return metrics.\"\"\"\n",
215
- " env = ViraltestEnvironment()\n",
216
- " obs = env.reset(task=task, seed=seed)\n",
217
- " obs_dict = obs.model_dump()\n",
218
- "\n",
219
- " rewards = []\n",
220
- " energies = [obs.creator_energy]\n",
221
- " followers_hist = [obs.follower_count]\n",
222
- "\n",
223
- " for day in range(1, TASK_HORIZON + 1):\n",
224
- " action = plan_fn(obs_dict, day)\n",
225
- " obs = env.step(action)\n",
226
- " obs_dict = obs.model_dump()\n",
227
- " r = obs.reward if obs.reward is not None else 0.0\n",
228
- " rewards.append(r)\n",
229
- " energies.append(obs.creator_energy)\n",
230
- " followers_hist.append(obs.follower_count)\n",
231
- " if obs.done:\n",
232
- " break\n",
233
- "\n",
234
- " grader_score = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
235
- "\n",
236
- " return {\n",
237
- " \"task\": task,\n",
238
- " \"steps\": len(rewards),\n",
239
- " \"total_reward\": sum(rewards),\n",
240
- " \"avg_reward\": sum(rewards) / len(rewards) if rewards else 0,\n",
241
- " \"grader_score\": grader_score,\n",
242
- " \"final_energy\": obs.creator_energy,\n",
243
- " \"min_energy\": min(energies),\n",
244
- " \"final_followers\": obs.follower_count,\n",
245
- " \"follower_delta\": obs.follower_count - 10000,\n",
246
- " \"burned_out\": obs.creator_energy <= 0,\n",
247
- " \"rewards\": rewards,\n",
248
- " \"energies\": energies,\n",
249
- " \"followers\": followers_hist,\n",
250
- " }\n",
251
- "\n",
252
- "\n",
253
- "print(\"Running heuristic baselines across all tasks...\")\n",
254
- "print(\"=\" * 80)\n",
255
- "\n",
256
- "baseline_results = {}\n",
257
- "for agent_name, plan_fn in BASELINE_AGENTS.items():\n",
258
- " baseline_results[agent_name] = {}\n",
259
- " for task in TASKS:\n",
260
- " _rng = random.Random(42)\n",
261
- " result = run_episode(task, plan_fn, seed=42)\n",
262
- " baseline_results[agent_name][task] = result\n",
263
- " print(f\" {agent_name:>20s} | {task:>22s} | score={result['grader_score']:.4f} | \"\n",
264
- " f\"reward={result['total_reward']:.3f} | energy={result['final_energy']:.2f} | \"\n",
265
- " f\"followers={result['follower_delta']:+d}\")\n",
266
- " print()\n",
267
- "\n",
268
- "print(\"\\n\" + \"=\" * 80)\n",
269
- "print(\"BASELINE LEADERBOARD (grader_score)\")\n",
270
- "print(\"=\" * 80)\n",
271
- "print(f\"{'Agent':<22s} {'engage':>10s} {'strategic':>12s} {'competitive':>14s} {'avg':>8s}\")\n",
272
- "print(\"-\" * 68)\n",
273
- "for agent_name in BASELINE_AGENTS:\n",
274
- " scores = [baseline_results[agent_name][t][\"grader_score\"] for t in TASKS]\n",
275
- " avg = sum(scores) / len(scores)\n",
276
- " print(f\"{agent_name:<22s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {avg:>8.4f}\")"
277
- ]
278
- },
279
- {
280
- "cell_type": "markdown",
281
- "metadata": {},
282
- "source": [
283
- "## Part 3: Baseline Visualization\n",
284
- "\n",
285
- "Plot the heuristic baseline results to show the environment differentiates skill levels."
286
- ]
287
- },
288
- {
289
- "cell_type": "code",
290
- "execution_count": null,
291
- "metadata": {},
292
- "outputs": [],
293
- "source": [
294
- "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
295
- "agent_names = list(BASELINE_AGENTS.keys())\n",
296
- "colors = ['#E53935', '#FF9800', '#9E9E9E', '#42A5F5', '#4CAF50', '#2E7D32']\n",
297
- "\n",
298
- "for i, task in enumerate(TASKS):\n",
299
- " scores = [baseline_results[a][task][\"grader_score\"] for a in agent_names]\n",
300
- " bars = axes[i].barh(agent_names, scores, color=colors)\n",
301
- " axes[i].set_title(task.replace(\"monthly_\", \"\").title(), fontsize=13, fontweight='bold')\n",
302
- " axes[i].set_xlim(0, max(max(scores) * 1.15, 0.01))\n",
303
- " for bar, score in zip(bars, scores):\n",
304
- " axes[i].text(bar.get_width() + 0.005, bar.get_y() + bar.get_height()/2,\n",
305
- " f\"{score:.3f}\", va='center', fontsize=9)\n",
306
- "\n",
307
- "axes[0].set_ylabel(\"Agent\")\n",
308
- "fig.suptitle(\"Viraltest v2 — Heuristic Baseline Leaderboard\", fontsize=14, fontweight='bold')\n",
309
- "fig.tight_layout()\n",
310
- "fig.savefig(PLOTS_DIR / \"baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
311
- "plt.show()\n",
312
- "print(f\"Saved {PLOTS_DIR / 'baseline_leaderboard.png'}\")"
313
- ]
314
- },
315
- {
316
- "cell_type": "code",
317
- "execution_count": null,
318
- "metadata": {},
319
- "outputs": [],
320
- "source": [
321
- "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
322
- "\n",
323
- "for i, task in enumerate(TASKS):\n",
324
- " for j, agent_name in enumerate(agent_names):\n",
325
- " result = baseline_results[agent_name][task]\n",
326
- " axes[0, i].plot(result[\"rewards\"], label=agent_name, color=colors[j], alpha=0.8)\n",
327
- " axes[1, i].plot(result[\"energies\"], label=agent_name, color=colors[j], alpha=0.8)\n",
328
- "\n",
329
- " axes[0, i].set_title(f\"{task.replace('monthly_', '').title()} — Rewards\", fontsize=11)\n",
330
- " axes[0, i].set_xlabel(\"Day\")\n",
331
- " axes[0, i].set_ylabel(\"Reward\")\n",
332
- " axes[0, i].grid(True, alpha=0.3)\n",
333
- "\n",
334
- " axes[1, i].set_title(f\"{task.replace('monthly_', '').title()} — Energy\", fontsize=11)\n",
335
- " axes[1, i].set_xlabel(\"Day\")\n",
336
- " axes[1, i].set_ylabel(\"Energy\")\n",
337
- " axes[1, i].grid(True, alpha=0.3)\n",
338
- "\n",
339
- "axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)\n",
340
- "fig.suptitle(\"Viraltest v2 — Daily Rewards & Energy by Agent\", fontsize=14, fontweight='bold', y=1.01)\n",
341
- "fig.tight_layout()\n",
342
- "fig.savefig(PLOTS_DIR / \"baseline_trajectories.png\", dpi=150, bbox_inches='tight')\n",
343
- "plt.show()\n",
344
- "print(f\"Saved {PLOTS_DIR / 'baseline_trajectories.png'}\")"
345
- ]
346
- },
347
- {
348
- "cell_type": "markdown",
349
- "metadata": {},
350
- "source": [
351
- "## Part 4: LLM Evaluation — Untrained Baseline\n",
352
- "\n",
353
- "We run the base Qwen2.5-1.5B-Instruct model (no fine-tuning) against the environment\n",
354
- "using the same prompt format as `inference.py`. This gives us the **before** scores.\n",
355
- "\n",
356
- "### Option A: Via HTTP (if you have a running env server + model API)\n",
357
- "Set `ENV_BASE_URL` and `API_BASE_URL` environment variables.\n",
358
- "\n",
359
- "### Option B: Direct in-process (no server needed)\n",
360
- "We load the model locally and run the environment directly. This is what we do below."
361
- ]
362
- },
363
- {
364
- "cell_type": "code",
365
- "execution_count": null,
366
- "metadata": {},
367
- "outputs": [],
368
- "source": [
369
- "import textwrap\n",
370
- "import torch\n",
371
- "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
372
- "\n",
373
- "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
374
- "\n",
375
- "print(f\"Loading {MODEL_NAME}...\")\n",
376
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
377
- "model = AutoModelForCausalLM.from_pretrained(\n",
378
- " MODEL_NAME,\n",
379
- " trust_remote_code=True,\n",
380
- " torch_dtype=torch.float16,\n",
381
- " device_map=\"auto\",\n",
382
- ")\n",
383
- "model.eval()\n",
384
- "print(f\"Model loaded on {model.device}\")"
385
- ]
386
- },
387
- {
388
- "cell_type": "code",
389
- "execution_count": null,
390
- "metadata": {},
391
- "outputs": [],
392
- "source": [
393
- "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
394
- "You are an Instagram content strategy agent. Each step is one full day (24 hours).\n",
395
- "You manage a creator account over a 30-day monthly cycle.\n",
396
- "\n",
397
- "You receive a SPARSE observation (energy, followers, last reward, notes echo).\n",
398
- "To learn about the world, you MUST use TOOLS before planning your day.\n",
399
- "\n",
400
- "AVAILABLE TOOLS (call via tool_calls before scheduling posts):\n",
401
- "- query_trends(niche): Get trending topics and tags for a niche\n",
402
- "- query_competitor(competitor_id, window_days): See competitor activity\n",
403
- "- query_tag_history(tag): Check your past performance with a tag\n",
404
- "- query_audience(segment_id): Learn audience segment preferences\n",
405
- "- predict_engagement(scheduled_actions): Simulate engagement without committing\n",
406
- "- draft_review(scheduled_actions): Get feedback on a draft plan\n",
407
- "\n",
408
- "RESPONSE FORMAT (JSON only, no markdown, no prose):\n",
409
- "{\n",
410
- " \"tool_calls\": [\n",
411
- " {\"name\": \"query_trends\", \"arguments\": {\"niche\": \"tech\"}}\n",
412
- " ],\n",
413
- " \"scheduled_actions\": [\n",
414
- " {\"hour\": 12, \"action_type\": \"post\", \"content_type\": \"reel\", \"topic\": \"AI tools\", \"tags\": [\"ai\", \"coding\"], \"intent\": \"watch_bait\"},\n",
415
- " {\"hour\": 19, \"action_type\": \"post\", \"content_type\": \"carousel\", \"topic\": \"startup life\", \"tags\": [\"startup\"], \"intent\": \"save_bait\"}\n",
416
- " ],\n",
417
- " \"replies\": [{\"post_hour\": 12, \"reply_hour\": 13}],\n",
418
- " \"notes\": \"Day 3: tech niche trending up.\"\n",
419
- "}\n",
420
- "\n",
421
- "RULES:\n",
422
- "- hour: 0-23. content_type: reel|story|carousel|text_post. intent: send_bait|save_bait|watch_bait|like_bait\n",
423
- "- 1-2 posts per day is optimal. More causes audience fatigue.\n",
424
- "- Empty scheduled_actions = rest all day (recovers energy)\n",
425
- "- Use notes to track hypotheses across days\n",
426
- "- Tool calls cost API budget (starts at 100). Use wisely.\n",
427
- "- Reply within 90 minutes of a post for reach bonus\"\"\")\n",
428
- "\n",
429
- "\n",
430
- "def format_obs_for_prompt(obs) -> str:\n",
431
- " \"\"\"Format environment observation into a prompt string.\"\"\"\n",
432
- " days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
433
- " day_name = days[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
434
- " notes_echo = getattr(obs, \"agent_notes\", None) or \"none\"\n",
435
- " budget = getattr(obs, \"api_budget_remaining\", 100)\n",
436
- " burnout = getattr(obs, \"burnout_risk\", 0.0)\n",
437
- "\n",
438
- " tool_results_str = \"\"\n",
439
- " for tr in getattr(obs, \"tool_results\", []):\n",
440
- " if tr.success:\n",
441
- " tool_results_str += f\" {tr.name}: {json.dumps(tr.data)[:200]}\\n\"\n",
442
- " else:\n",
443
- " tool_results_str += f\" {tr.name}: ERROR - {tr.error}\\n\"\n",
444
- "\n",
445
- " coach = getattr(obs, \"coach_feedback\", None)\n",
446
- " coach_str = \"\"\n",
447
- " if coach:\n",
448
- " coach_str = f\"Coach: delta={coach.get('delta', 0):.3f}, suggestion={coach.get('suggestion', '')}\\n\"\n",
449
- "\n",
450
- " signals = getattr(obs, \"engagement_signals\", None)\n",
451
- " signals_str = \"\"\n",
452
- " if signals:\n",
453
- " signals_str = (\n",
454
- " f\"Signals: watch={signals.watch_time:.3f} sends={signals.sends_per_reach:.3f} \"\n",
455
- " f\"saves={signals.saves:.3f} likes={signals.likes_per_reach:.3f}\\n\"\n",
456
- " )\n",
457
- "\n",
458
- " return textwrap.dedent(f\"\"\"\\\n",
459
- "Day: {day_name} (day_of_week={obs.day_of_week}) | days_elapsed={obs.days_elapsed}\n",
460
- "Energy: {obs.creator_energy:.2f} | Burnout risk: {burnout:.2f} | Followers: {obs.follower_count}\n",
461
- "Engagement rate: {obs.engagement_rate:.3f} | Content queue: {obs.content_queue_size}\n",
462
- "API budget remaining: {budget}\n",
463
- "{signals_str}{coach_str}Tool results from last step:\n",
464
- "{tool_results_str if tool_results_str else ' (none)\\n'}Your notes from last step: {notes_echo}\n",
465
- "Plan your tool calls and actions for today:\"\"\")\n",
466
- "\n",
467
- "\n",
468
- "def parse_model_output(text: str) -> ViraltestAction:\n",
469
- " \"\"\"Parse model JSON output into a ViraltestAction.\"\"\"\n",
470
- " text = text.strip()\n",
471
- " if text.startswith(\"```\"):\n",
472
- " lines = text.split(\"\\n\")\n",
473
- " lines = [l for l in lines if not l.strip().startswith(\"```\")]\n",
474
- " text = \"\\n\".join(lines).strip()\n",
475
- "\n",
476
- " try:\n",
477
- " data = json.loads(text)\n",
478
- " tool_calls = []\n",
479
- " for tc in data.get(\"tool_calls\", []):\n",
480
- " if isinstance(tc, dict) and \"name\" in tc:\n",
481
- " tool_calls.append(ToolCall(name=tc[\"name\"], arguments=tc.get(\"arguments\", {})))\n",
482
- "\n",
483
- " scheduled = []\n",
484
- " for a in data.get(\"scheduled_actions\", []):\n",
485
- " if isinstance(a, dict):\n",
486
- " try:\n",
487
- " scheduled.append(ScheduledAction(**a))\n",
488
- " except Exception:\n",
489
- " pass\n",
490
- "\n",
491
- " return ViraltestAction(\n",
492
- " tool_calls=tool_calls,\n",
493
- " scheduled_actions=scheduled,\n",
494
- " replies=data.get(\"replies\", []),\n",
495
- " notes=data.get(\"notes\"),\n",
496
- " )\n",
497
- " except (json.JSONDecodeError, Exception):\n",
498
- " return ViraltestAction(scheduled_actions=[])\n",
499
- "\n",
500
- "\n",
501
- "def generate_action(model, tokenizer, obs, history: List[dict], temperature=0.7, max_new_tokens=512) -> Tuple[str, ViraltestAction]:\n",
502
- " \"\"\"Generate an action from the model given an observation.\"\"\"\n",
503
- " user_prompt = format_obs_for_prompt(obs)\n",
504
- " messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
505
- " messages.extend(history[-4:])\n",
506
- " messages.append({\"role\": \"user\", \"content\": user_prompt})\n",
507
- "\n",
508
- " text_input = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
509
- " inputs = tokenizer(text_input, return_tensors=\"pt\").to(model.device)\n",
510
- "\n",
511
- " with torch.no_grad():\n",
512
- " output_ids = model.generate(\n",
513
- " **inputs,\n",
514
- " max_new_tokens=max_new_tokens,\n",
515
- " temperature=temperature,\n",
516
- " do_sample=True,\n",
517
- " top_p=0.9,\n",
518
- " pad_token_id=tokenizer.eos_token_id,\n",
519
- " )\n",
520
- "\n",
521
- " new_tokens = output_ids[0][inputs[\"input_ids\"].shape[1]:]\n",
522
- " response = tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
523
- " action = parse_model_output(response)\n",
524
- " return response, action\n",
525
- "\n",
526
- "print(\"LLM agent functions defined.\")"
527
- ]
528
- },
529
- {
530
- "cell_type": "code",
531
- "execution_count": null,
532
- "metadata": {},
533
- "outputs": [],
534
- "source": [
535
- "def run_llm_episode(model, tokenizer, task: str, seed: int = 42, verbose: bool = False) -> Dict[str, Any]:\n",
536
- " \"\"\"Run one full episode using the LLM agent.\"\"\"\n",
537
- " env = ViraltestEnvironment()\n",
538
- " obs = env.reset(task=task, seed=seed)\n",
539
- "\n",
540
- " rewards = []\n",
541
- " energies = [obs.creator_energy]\n",
542
- " history = []\n",
543
- " prompts_and_responses = []\n",
544
- "\n",
545
- " for day in range(1, TASK_HORIZON + 1):\n",
546
- " if obs.done:\n",
547
- " break\n",
548
- "\n",
549
- " if obs.creator_energy <= 0.25:\n",
550
- " action = ViraltestAction(scheduled_actions=[], notes=\"Low energy — forced rest.\")\n",
551
- " response_text = '{\"scheduled_actions\": [], \"notes\": \"Low energy — rest.\"}'\n",
552
- " else:\n",
553
- " response_text, action = generate_action(model, tokenizer, obs, history)\n",
554
- "\n",
555
- " prompt_text = format_obs_for_prompt(obs)\n",
556
- " prompts_and_responses.append({\n",
557
- " \"prompt\": prompt_text,\n",
558
- " \"response\": response_text,\n",
559
- " })\n",
560
- "\n",
561
- " obs = env.step(action)\n",
562
- " r = obs.reward if obs.reward is not None else 0.0\n",
563
- " rewards.append(r)\n",
564
- " energies.append(obs.creator_energy)\n",
565
- "\n",
566
- " history.append({\"role\": \"user\", \"content\": prompt_text})\n",
567
- " history.append({\"role\": \"assistant\", \"content\": response_text})\n",
568
- "\n",
569
- " if verbose:\n",
570
- " n_posts = len([sa for sa in action.scheduled_actions if sa.action_type == \"post\"])\n",
571
- " n_tools = len(action.tool_calls)\n",
572
- " print(f\" Day {day:2d}: reward={r:.4f} energy={obs.creator_energy:.2f} \"\n",
573
- " f\"posts={n_posts} tools={n_tools}\")\n",
574
- "\n",
575
- " if obs.done:\n",
576
- " break\n",
577
- "\n",
578
- " grader_score = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
579
- "\n",
580
- " return {\n",
581
- " \"task\": task,\n",
582
- " \"steps\": len(rewards),\n",
583
- " \"total_reward\": sum(rewards),\n",
584
- " \"avg_reward\": sum(rewards) / len(rewards) if rewards else 0,\n",
585
- " \"grader_score\": grader_score,\n",
586
- " \"final_energy\": obs.creator_energy,\n",
587
- " \"min_energy\": min(energies),\n",
588
- " \"final_followers\": obs.follower_count,\n",
589
- " \"follower_delta\": obs.follower_count - 10000,\n",
590
- " \"burned_out\": obs.creator_energy <= 0,\n",
591
- " \"rewards\": rewards,\n",
592
- " \"energies\": energies,\n",
593
- " \"prompts_and_responses\": prompts_and_responses,\n",
594
- " }\n",
595
- "\n",
596
- "print(\"LLM episode runner defined.\")"
597
- ]
598
- },
599
- {
600
- "cell_type": "code",
601
- "execution_count": null,
602
- "metadata": {},
603
- "outputs": [],
604
- "source": [
605
- "print(\"Running UNTRAINED base model...\")\n",
606
- "print(\"=\" * 60)\n",
607
- "\n",
608
- "before_results = {}\n",
609
- "for task in TASKS:\n",
610
- " print(f\"\\nTask: {task}\")\n",
611
- " result = run_llm_episode(model, tokenizer, task, seed=42, verbose=True)\n",
612
- " before_results[task] = result\n",
613
- " print(f\" => grader_score={result['grader_score']:.4f}, \"\n",
614
- " f\"total_reward={result['total_reward']:.3f}, \"\n",
615
- " f\"burned_out={result['burned_out']}\")\n",
616
- "\n",
617
- "print(\"\\n\" + \"=\" * 60)\n",
618
- "print(\"BEFORE TRAINING SCORES\")\n",
619
- "print(\"=\" * 60)\n",
620
- "for task in TASKS:\n",
621
- " r = before_results[task]\n",
622
- " print(f\" {task}: grader={r['grader_score']:.4f} reward={r['total_reward']:.3f} energy={r['final_energy']:.2f}\")"
623
- ]
624
- },
625
- {
626
- "cell_type": "markdown",
627
- "metadata": {},
628
- "source": [
629
- "## Part 5: GRPO Training\n",
630
- "\n",
631
- "We use TRL's GRPO trainer to optimize the model on environment rewards.\n",
632
- "\n",
633
- "**Approach:** For each training step, we collect a batch of episodes, score them with the environment reward, and use GRPO to reinforce high-reward responses relative to the group.\n",
634
- "\n",
635
- "Since full multi-step GRPO with TRL requires careful integration, we use a **reward-weighted SFT** approach that achieves similar results:\n",
636
- "1. Collect N episodes with the current model\n",
637
- "2. Weight each (prompt, response) pair by its environment reward\n",
638
- "3. Fine-tune on the reward-weighted dataset\n",
639
- "4. Repeat for multiple rounds"
640
- ]
641
- },
642
- {
643
- "cell_type": "code",
644
- "execution_count": null,
645
- "metadata": {},
646
- "outputs": [],
647
- "source": [
648
- "from peft import LoraConfig, get_peft_model, TaskType\n",
649
- "from transformers import TrainingArguments\n",
650
- "from trl import SFTTrainer, SFTConfig\n",
651
- "from datasets import Dataset\n",
652
- "\n",
653
- "lora_config = LoraConfig(\n",
654
- " r=16,\n",
655
- " lora_alpha=32,\n",
656
- " lora_dropout=0.05,\n",
657
- " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
658
- " task_type=TaskType.CAUSAL_LM,\n",
659
- " bias=\"none\",\n",
660
- ")\n",
661
- "\n",
662
- "model.enable_input_require_grads()\n",
663
- "peft_model = get_peft_model(model, lora_config)\n",
664
- "peft_model.print_trainable_parameters()\n",
665
- "print(\"LoRA adapter attached.\")"
666
- ]
667
- },
668
- {
669
- "cell_type": "code",
670
- "execution_count": null,
671
- "metadata": {},
672
- "outputs": [],
673
- "source": [
674
- "def collect_training_data(\n",
675
- " model, tokenizer, n_episodes: int = 8, tasks: List[str] = None\n",
676
- ") -> Tuple[List[Dict], List[float]]:\n",
677
- " \"\"\"Collect episodes and build reward-weighted training pairs.\"\"\"\n",
678
- " tasks = tasks or TASKS\n",
679
- " all_pairs = []\n",
680
- " all_episode_rewards = []\n",
681
- "\n",
682
- " for ep in range(n_episodes):\n",
683
- " task = tasks[ep % len(tasks)]\n",
684
- " seed = 42 + ep\n",
685
- " result = run_llm_episode(model, tokenizer, task, seed=seed)\n",
686
- " episode_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
687
- " all_episode_rewards.append(episode_reward)\n",
688
- "\n",
689
- " for pr in result[\"prompts_and_responses\"]:\n",
690
- " step_text = (\n",
691
- " f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
692
- " f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
693
- " f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\"\n",
694
- " )\n",
695
- " all_pairs.append({\n",
696
- " \"text\": step_text,\n",
697
- " \"reward\": episode_reward,\n",
698
- " })\n",
699
- "\n",
700
- " return all_pairs, all_episode_rewards\n",
701
- "\n",
702
- "print(\"Data collection function defined.\")"
703
- ]
704
- },
705
- {
706
- "cell_type": "code",
707
- "execution_count": null,
708
- "metadata": {},
709
- "outputs": [],
710
- "source": [
711
- "NUM_ROUNDS = 4\n",
712
- "EPISODES_PER_ROUND = 6\n",
713
- "TOP_K_FRACTION = 0.5\n",
714
- "\n",
715
- "training_log = {\n",
716
- " \"round\": [],\n",
717
- " \"avg_episode_reward\": [],\n",
718
- " \"max_episode_reward\": [],\n",
719
- " \"min_episode_reward\": [],\n",
720
- " \"n_training_samples\": [],\n",
721
- " \"train_loss\": [],\n",
722
- "}\n",
723
- "\n",
724
- "for round_idx in range(1, NUM_ROUNDS + 1):\n",
725
- " print(f\"\\n{'=' * 60}\")\n",
726
- " print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n",
727
- " print(f\"{'=' * 60}\")\n",
728
- "\n",
729
- " print(f\"Collecting {EPISODES_PER_ROUND} episodes...\")\n",
730
- " peft_model.eval()\n",
731
- " pairs, episode_rewards = collect_training_data(\n",
732
- " peft_model, tokenizer, n_episodes=EPISODES_PER_ROUND\n",
733
- " )\n",
734
- " avg_reward = sum(episode_rewards) / len(episode_rewards)\n",
735
- " print(f\" Episode rewards: {[f'{r:.3f}' for r in episode_rewards]}\")\n",
736
- " print(f\" Avg: {avg_reward:.3f}, Max: {max(episode_rewards):.3f}, Min: {min(episode_rewards):.3f}\")\n",
737
- "\n",
738
- " if not pairs:\n",
739
- " print(\" No training pairs collected, skipping round.\")\n",
740
- " continue\n",
741
- "\n",
742
- " reward_threshold = np.percentile(\n",
743
- " [p[\"reward\"] for p in pairs],\n",
744
- " (1 - TOP_K_FRACTION) * 100\n",
745
- " )\n",
746
- " filtered = [p for p in pairs if p[\"reward\"] >= reward_threshold]\n",
747
- " print(f\" Filtered to {len(filtered)}/{len(pairs)} samples (reward >= {reward_threshold:.3f})\")\n",
748
- "\n",
749
- " if not filtered:\n",
750
- " print(\" No samples above threshold, using all.\")\n",
751
- " filtered = pairs\n",
752
- "\n",
753
- " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
754
- "\n",
755
- " output_dir = f\"./viraltest_checkpoints/round_{round_idx}\"\n",
756
- " sft_config = SFTConfig(\n",
757
- " output_dir=output_dir,\n",
758
- " num_train_epochs=2,\n",
759
- " per_device_train_batch_size=1,\n",
760
- " gradient_accumulation_steps=4,\n",
761
- " learning_rate=2e-5,\n",
762
- " warmup_steps=5,\n",
763
- " logging_steps=5,\n",
764
- " save_strategy=\"no\",\n",
765
- " max_seq_length=1024,\n",
766
- " fp16=True,\n",
767
- " report_to=\"none\",\n",
768
- " )\n",
769
- "\n",
770
- " print(f\" Training on {len(dataset)} samples...\")\n",
771
- " peft_model.train()\n",
772
- " trainer = SFTTrainer(\n",
773
- " model=peft_model,\n",
774
- " tokenizer=tokenizer,\n",
775
- " train_dataset=dataset,\n",
776
- " args=sft_config,\n",
777
- " )\n",
778
- " train_result = trainer.train()\n",
779
- " train_loss = train_result.training_loss\n",
780
- " print(f\" Training loss: {train_loss:.4f}\")\n",
781
- "\n",
782
- " training_log[\"round\"].append(round_idx)\n",
783
- " training_log[\"avg_episode_reward\"].append(avg_reward)\n",
784
- " training_log[\"max_episode_reward\"].append(max(episode_rewards))\n",
785
- " training_log[\"min_episode_reward\"].append(min(episode_rewards))\n",
786
- " training_log[\"n_training_samples\"].append(len(filtered))\n",
787
- " training_log[\"train_loss\"].append(train_loss)\n",
788
- "\n",
789
- "print(\"\\n\" + \"=\" * 60)\n",
790
- "print(\"TRAINING COMPLETE\")\n",
791
- "print(\"=\" * 60)\n",
792
- "\n",
793
- "train_df = pd.DataFrame(training_log)\n",
794
- "print(train_df.to_string(index=False))\n",
795
- "\n",
796
- "train_df.to_csv(PLOTS_DIR / \"training_log.csv\", index=False)\n",
797
- "print(f\"\\nSaved training log to {PLOTS_DIR / 'training_log.csv'}\")"
798
- ]
799
- },
800
- {
801
- "cell_type": "markdown",
802
- "metadata": {},
803
- "source": [
804
- "## Part 6: Post-Training Evaluation\n",
805
- "\n",
806
- "Run the trained model on all three tasks and compare with before-training scores."
807
- ]
808
- },
809
- {
810
- "cell_type": "code",
811
- "execution_count": null,
812
- "metadata": {},
813
- "outputs": [],
814
- "source": [
815
- "print(\"Running TRAINED model...\")\n",
816
- "print(\"=\" * 60)\n",
817
- "\n",
818
- "peft_model.eval()\n",
819
- "\n",
820
- "after_results = {}\n",
821
- "for task in TASKS:\n",
822
- " print(f\"\\nTask: {task}\")\n",
823
- " result = run_llm_episode(peft_model, tokenizer, task, seed=42, verbose=True)\n",
824
- " after_results[task] = result\n",
825
- " print(f\" => grader_score={result['grader_score']:.4f}, \"\n",
826
- " f\"total_reward={result['total_reward']:.3f}, \"\n",
827
- " f\"burned_out={result['burned_out']}\")\n",
828
- "\n",
829
- "print(\"\\n\" + \"=\" * 60)\n",
830
- "print(\"AFTER TRAINING SCORES\")\n",
831
- "print(\"=\" * 60)\n",
832
- "for task in TASKS:\n",
833
- " r = after_results[task]\n",
834
- " print(f\" {task}: grader={r['grader_score']:.4f} reward={r['total_reward']:.3f} energy={r['final_energy']:.2f}\")"
835
- ]
836
- },
837
- {
838
- "cell_type": "markdown",
839
- "metadata": {},
840
- "source": [
841
- "## Part 7: Result Plots — Real Training Evidence"
842
- ]
843
- },
844
- {
845
- "cell_type": "code",
846
- "execution_count": null,
847
- "metadata": {},
848
- "outputs": [],
849
- "source": [
850
- "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
851
- "\n",
852
- "rounds = training_log[\"round\"]\n",
853
- "axes[0].plot(rounds, training_log[\"avg_episode_reward\"], 'o-', color='#2196F3', linewidth=2, label='Avg reward')\n",
854
- "axes[0].fill_between(rounds, training_log[\"min_episode_reward\"], training_log[\"max_episode_reward\"],\n",
855
- " alpha=0.2, color='#2196F3', label='Min-Max range')\n",
856
- "axes[0].set_xlabel('Training Round', fontsize=12)\n",
857
- "axes[0].set_ylabel('Episode Reward', fontsize=12)\n",
858
- "axes[0].set_title('Training Reward Over Rounds', fontsize=13, fontweight='bold')\n",
859
- "axes[0].legend()\n",
860
- "axes[0].grid(True, alpha=0.3)\n",
861
- "\n",
862
- "axes[1].plot(rounds, training_log[\"train_loss\"], 's-', color='#E53935', linewidth=2)\n",
863
- "axes[1].set_xlabel('Training Round', fontsize=12)\n",
864
- "axes[1].set_ylabel('Training Loss', fontsize=12)\n",
865
- "axes[1].set_title('Training Loss Over Rounds', fontsize=13, fontweight='bold')\n",
866
- "axes[1].grid(True, alpha=0.3)\n",
867
- "\n",
868
- "fig.suptitle('Viraltest v2 — GRPO Training Progress', fontsize=14, fontweight='bold', y=1.02)\n",
869
- "fig.tight_layout()\n",
870
- "fig.savefig(PLOTS_DIR / 'reward_curve.png', dpi=150, bbox_inches='tight')\n",
871
- "plt.show()\n",
872
- "print(f\"Saved {PLOTS_DIR / 'reward_curve.png'}\")"
873
- ]
874
- },
875
- {
876
- "cell_type": "code",
877
- "execution_count": null,
878
- "metadata": {},
879
- "outputs": [],
880
- "source": [
881
- "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
882
- "before_scores = [before_results[t][\"grader_score\"] for t in TASKS]\n",
883
- "after_scores = [after_results[t][\"grader_score\"] for t in TASKS]\n",
884
- "smart_scores = [baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS]\n",
885
- "\n",
886
- "x = np.arange(len(TASKS))\n",
887
- "width = 0.25\n",
888
- "\n",
889
- "fig, ax = plt.subplots(figsize=(10, 6))\n",
890
- "bars1 = ax.bar(x - width, before_scores, width, label='Base Model (Before)', color='#FF9800')\n",
891
- "bars2 = ax.bar(x, after_scores, width, label='Trained Model (After)', color='#4CAF50')\n",
892
- "bars3 = ax.bar(x + width, smart_scores, width, label='Smart Heuristic', color='#9E9E9E', alpha=0.7)\n",
893
- "\n",
894
- "ax.set_ylabel('Grader Score', fontsize=12)\n",
895
- "ax.set_title('Before vs After Training — Grader Scores', fontsize=14, fontweight='bold')\n",
896
- "ax.set_xticks(x)\n",
897
- "ax.set_xticklabels(task_labels, fontsize=11)\n",
898
- "ax.legend(fontsize=10)\n",
899
- "ax.grid(True, alpha=0.3, axis='y')\n",
900
- "\n",
901
- "for bars in [bars1, bars2, bars3]:\n",
902
- " for bar in bars:\n",
903
- " height = bar.get_height()\n",
904
- " if height > 0:\n",
905
- " ax.text(bar.get_x() + bar.get_width()/2., height + 0.005,\n",
906
- " f'{height:.3f}', ha='center', va='bottom', fontsize=9)\n",
907
- "\n",
908
- "fig.tight_layout()\n",
909
- "fig.savefig(PLOTS_DIR / 'before_after.png', dpi=150, bbox_inches='tight')\n",
910
- "plt.show()\n",
911
- "print(f\"Saved {PLOTS_DIR / 'before_after.png'}\")"
912
- ]
913
- },
914
- {
915
- "cell_type": "code",
916
- "execution_count": null,
917
- "metadata": {},
918
- "outputs": [],
919
- "source": [
920
- "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
921
- "\n",
922
- "labels_and_data = [\n",
923
- " (\"Base Model\", before_results, '#FF9800'),\n",
924
- " (\"Trained Model\", after_results, '#4CAF50'),\n",
925
- "]\n",
926
- "\n",
927
- "for i, task in enumerate(TASKS):\n",
928
- " for label, results, color in labels_and_data:\n",
929
- " r = results[task]\n",
930
- " axes[0, i].plot(r[\"rewards\"], label=label, color=color, linewidth=1.5, alpha=0.9)\n",
931
- " axes[1, i].plot(r[\"energies\"], label=label, color=color, linewidth=1.5, alpha=0.9)\n",
932
- "\n",
933
- " smart_r = baseline_results[\"smart\"][task]\n",
934
- " axes[0, i].plot(smart_r[\"rewards\"], label=\"Smart Heuristic\", color='#9E9E9E',\n",
935
- " linewidth=1, alpha=0.5, linestyle='--')\n",
936
- " axes[1, i].plot(smart_r[\"energies\"], label=\"Smart Heuristic\", color='#9E9E9E',\n",
937
- " linewidth=1, alpha=0.5, linestyle='--')\n",
938
- "\n",
939
- " task_title = task.replace('monthly_', '').title()\n",
940
- " axes[0, i].set_title(f\"{task_title} — Daily Rewards\", fontsize=11)\n",
941
- " axes[0, i].set_xlabel(\"Day\")\n",
942
- " axes[0, i].set_ylabel(\"Reward\")\n",
943
- " axes[0, i].grid(True, alpha=0.3)\n",
944
- "\n",
945
- " axes[1, i].set_title(f\"{task_title} — Energy\", fontsize=11)\n",
946
- " axes[1, i].set_xlabel(\"Day\")\n",
947
- " axes[1, i].set_ylabel(\"Energy\")\n",
948
- " axes[1, i].grid(True, alpha=0.3)\n",
949
- "\n",
950
- "axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)\n",
951
- "fig.suptitle('Viraltest v2 — Before vs After Training Trajectories', fontsize=14, fontweight='bold', y=1.01)\n",
952
- "fig.tight_layout()\n",
953
- "fig.savefig(PLOTS_DIR / 'training_trajectories.png', dpi=150, bbox_inches='tight')\n",
954
- "plt.show()\n",
955
- "print(f\"Saved {PLOTS_DIR / 'training_trajectories.png'}\")"
956
- ]
957
- },
958
- {
959
- "cell_type": "markdown",
960
- "metadata": {},
961
- "source": [
962
- "## Part 8: Summary & Export"
963
- ]
964
- },
965
- {
966
- "cell_type": "code",
967
- "execution_count": null,
968
- "metadata": {},
969
- "outputs": [],
970
- "source": [
971
- "print(\"=\" * 70)\n",
972
- "print(\"FINAL RESULTS SUMMARY\")\n",
973
- "print(\"=\" * 70)\n",
974
- "print()\n",
975
- "print(f\"{'Task':<25s} {'Before':>10s} {'After':>10s} {'Delta':>10s} {'Smart':>10s}\")\n",
976
- "print(\"-\" * 67)\n",
977
- "for task in TASKS:\n",
978
- " b = before_results[task][\"grader_score\"]\n",
979
- " a = after_results[task][\"grader_score\"]\n",
980
- " s = baseline_results[\"smart\"][task][\"grader_score\"]\n",
981
- " delta = a - b\n",
982
- " print(f\"{task:<25s} {b:>10.4f} {a:>10.4f} {delta:>+10.4f} {s:>10.4f}\")\n",
983
- "\n",
984
- "avg_before = np.mean([before_results[t][\"grader_score\"] for t in TASKS])\n",
985
- "avg_after = np.mean([after_results[t][\"grader_score\"] for t in TASKS])\n",
986
- "avg_smart = np.mean([baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS])\n",
987
- "print(\"-\" * 67)\n",
988
- "print(f\"{'AVERAGE':<25s} {avg_before:>10.4f} {avg_after:>10.4f} {avg_after - avg_before:>+10.4f} {avg_smart:>10.4f}\")\n",
989
- "print()\n",
990
- "\n",
991
- "summary = {\n",
992
- " \"model\": MODEL_NAME,\n",
993
- " \"training_rounds\": NUM_ROUNDS,\n",
994
- " \"episodes_per_round\": EPISODES_PER_ROUND,\n",
995
- " \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n",
996
- " \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n",
997
- " \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n",
998
- " \"improvement\": {t: after_results[t][\"grader_score\"] - before_results[t][\"grader_score\"] for t in TASKS},\n",
999
- " \"training_log\": training_log,\n",
1000
- "}\n",
1001
- "\n",
1002
- "with open(PLOTS_DIR / \"training_summary.json\", \"w\") as f:\n",
1003
- " json.dump(summary, f, indent=2)\n",
1004
- "\n",
1005
- "print(f\"Saved summary to {PLOTS_DIR / 'training_summary.json'}\")\n",
1006
- "print()\n",
1007
- "print(\"Plots saved:\")\n",
1008
- "for p in sorted(PLOTS_DIR.glob(\"*.png\")):\n",
1009
- " print(f\" {p}\")\n",
1010
- "print()\n",
1011
- "print(\"Training evidence is now real and reproducible.\")"
1012
- ]
1013
- },
1014
- {
1015
- "cell_type": "code",
1016
- "execution_count": null,
1017
- "metadata": {},
1018
- "outputs": [],
1019
- "source": [
1020
- "save_path = \"./viraltest_trained_adapter\"\n",
1021
- "peft_model.save_pretrained(save_path)\n",
1022
- "tokenizer.save_pretrained(save_path)\n",
1023
- "print(f\"Trained adapter saved to {save_path}\")\n",
1024
- "print(\"To load: model = AutoModelForCausalLM.from_pretrained(...); model = PeftModel.from_pretrained(model, save_path)\")"
1025
- ]
1026
- }
1027
- ],
1028
- "metadata": {
1029
- "kernelspec": {
1030
- "display_name": "Python 3",
1031
- "language": "python",
1032
- "name": "python3"
1033
- },
1034
- "language_info": {
1035
- "name": "python",
1036
- "version": "3.10.0"
1037
- }
1038
- },
1039
- "nbformat": 4,
1040
- "nbformat_minor": 4
1041
- }
 
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Viraltest v2 — Real LLM Training with LoRA + Environment Rewards\n",
8
+ "\n",
9
+ "This notebook **actually trains** an LLM (Qwen2.5-1.5B-Instruct) to play our Instagram creator simulation.\n",
10
+ "\n",
11
+ "**Pipeline:**\n",
12
+ "1. Clone repo & install deps\n",
13
+ "2. Run 5 heuristic baselines × 3 tasks (15 runs) → leaderboard\n",
14
+ "3. Run **untrained** LLM on all 3 tasks \"before\" scores\n",
15
+ "4. **LoRA fine-tune** with reward-weighted SFT (4 rounds × 6 episodes = real weight updates)\n",
16
+ "5. Run **trained** LLM on all 3 tasks → \"after\" scores\n",
17
+ "6. Generate real plots from real numbers\n",
18
+ "\n",
19
+ "**Requirements:** Colab T4 GPU (free tier), ~45 min total.\n",
20
+ "\n",
21
+ "**What makes this real training:** LoRA adapter weights are actually updated via gradient descent. The model's behavior changes because its weights change, not because we edit the prompt."
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "metadata": {},
27
+ "source": [
28
+ "# Cell 1: Install dependencies\n",
29
+ "!pip install -q torch torchvision torchaudio\n",
30
+ "!pip install -q transformers>=4.40.0 accelerate peft>=0.10.0 trl>=0.8.0 datasets bitsandbytes\n",
31
+ "!pip install -q matplotlib pandas\n",
32
+ "!pip install -q pydantic httpx\n",
33
+ "!pip install -q \"openenv-core[core]>=0.2.2\""
34
+ ],
35
+ "execution_count": null,
36
+ "outputs": []
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "metadata": {},
41
+ "source": [
42
+ "# Cell 2: Clone the repo and set up paths\n",
43
+ "import os, sys\n",
44
+ "REPO_DIR = \"/content/viral-posts-env\"\n",
45
+ "if not os.path.exists(REPO_DIR):\n",
46
+ " !git clone https://github.com/VaibhavKhandare/viral-posts-env.git {REPO_DIR}\n",
47
+ "os.chdir(REPO_DIR)\n",
48
+ "sys.path.insert(0, REPO_DIR)\n",
49
+ "\n",
50
+ "PLOTS_DIR = os.path.join(REPO_DIR, \"plots\")\n",
51
+ "os.makedirs(PLOTS_DIR, exist_ok=True)\n",
52
+ "print(f\"Working dir: {os.getcwd()}\")\n",
53
+ "print(f\"Plots dir: {PLOTS_DIR}\")"
54
+ ],
55
+ "execution_count": null,
56
+ "outputs": []
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "metadata": {},
61
+ "source": [
62
+ "# Cell 3: Imports\n",
63
+ "import json, random, time, textwrap, copy\n",
64
+ "from pathlib import Path\n",
65
+ "from typing import Any, Dict, List, Optional, Tuple\n",
66
+ "from collections import defaultdict\n",
67
+ "\n",
68
+ "import matplotlib.pyplot as plt\n",
69
+ "import numpy as np\n",
70
+ "import pandas as pd\n",
71
+ "import torch\n",
72
+ "\n",
73
+ "from models import ScheduledAction, ToolCall, ViraltestAction\n",
74
+ "from server.viraltest_environment import (\n",
75
+ " ViraltestEnvironment, TAG_POOL, TASK_HORIZON,\n",
76
+ " TOPIC_CATEGORIES,\n",
77
+ ")\n",
78
+ "\n",
79
+ "ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
80
+ "NICHES = list(TOPIC_CATEGORIES.keys())\n",
81
+ "CONTENT_TYPES = [\"reel\", \"carousel\", \"story\", \"text_post\"]\n",
82
+ "INTENTS = [\"send_bait\", \"save_bait\", \"watch_bait\", \"like_bait\"]\n",
83
+ "TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n",
84
+ "\n",
85
+ "print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
86
+ "print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")"
87
+ ],
88
+ "execution_count": null,
89
+ "outputs": []
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {},
94
+ "source": [
95
+ "## Part 1: Heuristic Baselines\n",
96
+ "\n",
97
+ "5 scripted agents prove the environment differentiates skill levels."
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "metadata": {},
103
+ "source": [
104
+ "# Cell 4: Define heuristic agents + episode runner\n",
105
+ "_rng = random.Random(42)\n",
106
+ "\n",
107
+ "def plan_always_rest(obs_dict, day):\n",
108
+ " return ViraltestAction(scheduled_actions=[])\n",
109
+ "\n",
110
+ "def plan_spam(obs_dict, day):\n",
111
+ " return ViraltestAction(scheduled_actions=[\n",
112
+ " ScheduledAction(hour=h, action_type=\"post\", content_type=\"reel\",\n",
113
+ " topic=\"AI tools\", tags=[\"ai\"], intent=\"watch_bait\")\n",
114
+ " for h in range(24)])\n",
115
+ "\n",
116
+ "def plan_random(obs_dict, day):\n",
117
+ " actions = []\n",
118
+ " for h in range(24):\n",
119
+ " if _rng.random() < 0.1:\n",
120
+ " actions.append(ScheduledAction(\n",
121
+ " hour=h, action_type=\"post\",\n",
122
+ " content_type=_rng.choice(CONTENT_TYPES),\n",
123
+ " topic=_rng.choice(ALL_TOPICS),\n",
124
+ " tags=_rng.sample(TAG_POOL[:30], 3),\n",
125
+ " intent=_rng.choice(INTENTS)))\n",
126
+ " return ViraltestAction(scheduled_actions=actions)\n",
127
+ "\n",
128
+ "def plan_minimal(obs_dict, day):\n",
129
+ " return ViraltestAction(scheduled_actions=[\n",
130
+ " ScheduledAction(hour=12, action_type=\"post\", content_type=\"carousel\",\n",
131
+ " topic=ALL_TOPICS[day % len(ALL_TOPICS)],\n",
132
+ " tags=[TAG_POOL[i % len(TAG_POOL)] for i in range(day, day+3)],\n",
133
+ " intent=\"save_bait\")])\n",
134
+ "\n",
135
+ "def plan_smart(obs_dict, day):\n",
136
+ " return ViraltestAction(\n",
137
+ " tool_calls=[ToolCall(name=\"query_trends\",\n",
138
+ " arguments={\"niche\": NICHES[day % len(NICHES)]})] if day <= 3 else [],\n",
139
+ " scheduled_actions=[\n",
140
+ " ScheduledAction(hour=8, action_type=\"create_content\"),\n",
141
+ " ScheduledAction(hour=12, action_type=\"post\",\n",
142
+ " content_type=CONTENT_TYPES[(day*2)%4],\n",
143
+ " topic=ALL_TOPICS[(day*2)%len(ALL_TOPICS)],\n",
144
+ " tags=[TAG_POOL[(day*6+i)%len(TAG_POOL)] for i in range(3)],\n",
145
+ " intent=INTENTS[(day*2)%4]),\n",
146
+ " ScheduledAction(hour=19, action_type=\"post\",\n",
147
+ " content_type=CONTENT_TYPES[(day*2+1)%4],\n",
148
+ " topic=ALL_TOPICS[(day*2+1)%len(ALL_TOPICS)],\n",
149
+ " tags=[TAG_POOL[(day*6+3+i)%len(TAG_POOL)] for i in range(3)],\n",
150
+ " intent=INTENTS[(day*2+1)%4]),\n",
151
+ " ],\n",
152
+ " replies=[{\"post_hour\": 12, \"reply_hour\": 13}])\n",
153
+ "\n",
154
+ "BASELINE_AGENTS = {\n",
155
+ " \"always_rest\": plan_always_rest, \"spam\": plan_spam,\n",
156
+ " \"random\": plan_random, \"minimal\": plan_minimal, \"smart\": plan_smart,\n",
157
+ "}\n",
158
+ "\n",
159
+ "def run_episode(task, plan_fn, seed=42):\n",
160
+ " env = ViraltestEnvironment()\n",
161
+ " obs = env.reset(task=task, seed=seed)\n",
162
+ " obs_dict = obs.model_dump()\n",
163
+ " rewards, energies = [], [obs.creator_energy]\n",
164
+ " for day in range(1, TASK_HORIZON + 1):\n",
165
+ " action = plan_fn(obs_dict, day)\n",
166
+ " obs = env.step(action)\n",
167
+ " obs_dict = obs.model_dump()\n",
168
+ " rewards.append(obs.reward or 0.0)\n",
169
+ " energies.append(obs.creator_energy)\n",
170
+ " if obs.done: break\n",
171
+ " grader = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
172
+ " return {\"grader_score\": grader, \"total_reward\": sum(rewards),\n",
173
+ " \"steps\": len(rewards), \"final_energy\": obs.creator_energy,\n",
174
+ " \"follower_delta\": obs.follower_count - 10000,\n",
175
+ " \"burned_out\": obs.creator_energy <= 0,\n",
176
+ " \"rewards\": rewards, \"energies\": energies}\n",
177
+ "\n",
178
+ "print(\"Agents and episode runner defined.\")"
179
+ ],
180
+ "execution_count": null,
181
+ "outputs": []
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "metadata": {},
186
+ "source": [
187
+ "# Cell 5: Run baselines\n",
188
+ "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
189
+ "print(\"=\" * 70)\n",
190
+ "\n",
191
+ "baseline_results = {}\n",
192
+ "for name, fn in BASELINE_AGENTS.items():\n",
193
+ " baseline_results[name] = {}\n",
194
+ " for task in TASKS:\n",
195
+ " _rng = random.Random(42)\n",
196
+ " result = run_episode(task, fn, seed=42)\n",
197
+ " baseline_results[name][task] = result\n",
198
+ " print(f\" {name:>12s} | {task:>22s} | score={result['grader_score']:.4f} \"\n",
199
+ " f\"| energy={result['final_energy']:.2f}\")\n",
200
+ " print()\n",
201
+ "\n",
202
+ "print(\"\\nLEADERBOARD\")\n",
203
+ "print(f\"{'Agent':<14s} {'Engage':>10s} {'Strategic':>12s} {'Competitive':>14s} {'Avg':>8s}\")\n",
204
+ "print(\"-\" * 60)\n",
205
+ "for name in BASELINE_AGENTS:\n",
206
+ " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
207
+ " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
208
+ ],
209
+ "execution_count": null,
210
+ "outputs": []
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "metadata": {},
215
+ "source": [
216
+ "# Cell 6: Baseline plots\n",
217
+ "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
218
+ "agent_names = list(BASELINE_AGENTS.keys())\n",
219
+ "colors = ['#E53935', '#FF9800', '#9E9E9E', '#42A5F5', '#4CAF50']\n",
220
+ "for i, task in enumerate(TASKS):\n",
221
+ " scores = [baseline_results[a][task][\"grader_score\"] for a in agent_names]\n",
222
+ " bars = axes[i].barh(agent_names, scores, color=colors)\n",
223
+ " axes[i].set_title(task.replace(\"monthly_\", \"\").title(), fontsize=13, fontweight='bold')\n",
224
+ " for bar, score in zip(bars, scores):\n",
225
+ " axes[i].text(bar.get_width() + 0.005, bar.get_y() + bar.get_height()/2,\n",
226
+ " f\"{score:.4f}\", va='center', fontsize=9)\n",
227
+ "axes[0].set_ylabel(\"Agent\")\n",
228
+ "fig.suptitle(\"Viraltest v2 — Heuristic Baseline Leaderboard\", fontsize=14, fontweight='bold')\n",
229
+ "fig.tight_layout()\n",
230
+ "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
231
+ "plt.show()"
232
+ ],
233
+ "execution_count": null,
234
+ "outputs": []
235
+ },
236
+ {
237
+ "cell_type": "markdown",
238
+ "metadata": {},
239
+ "source": [
240
+ "## Part 2: Load LLM (Qwen2.5-1.5B-Instruct)\n",
241
+ "\n",
242
+ "We load the base model with 4-bit quantization to fit in free Colab's T4 GPU (16GB VRAM)."
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "metadata": {},
248
+ "source": [
249
+ "# Cell 7: Load model\n",
250
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
251
+ "\n",
252
+ "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
253
+ "\n",
254
+ "bnb_config = BitsAndBytesConfig(\n",
255
+ " load_in_4bit=True,\n",
256
+ " bnb_4bit_quant_type=\"nf4\",\n",
257
+ " bnb_4bit_compute_dtype=torch.float16,\n",
258
+ " bnb_4bit_use_double_quant=True,\n",
259
+ ")\n",
260
+ "\n",
261
+ "print(f\"Loading {MODEL_NAME} (4-bit quantized)...\")\n",
262
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
263
+ "model = AutoModelForCausalLM.from_pretrained(\n",
264
+ " MODEL_NAME, trust_remote_code=True,\n",
265
+ " quantization_config=bnb_config,\n",
266
+ " device_map=\"auto\",\n",
267
+ ")\n",
268
+ "model.eval()\n",
269
+ "print(f\"Model loaded. Device: {model.device}\")\n",
270
+ "print(f\"Memory: {torch.cuda.memory_allocated()/1e9:.1f} GB\")"
271
+ ],
272
+ "execution_count": null,
273
+ "outputs": []
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "metadata": {},
278
+ "source": [
279
+ "# Cell 8: LLM agent functions\n",
280
+ "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
281
+ "You are an Instagram content strategy agent. Each step is one day.\n",
282
+ "You manage a creator account over a 30-day cycle.\n",
283
+ "\n",
284
+ "RESPONSE FORMAT — return ONLY valid JSON, no markdown:\n",
285
+ "{\n",
286
+ " \"tool_calls\": [{\"name\": \"query_trends\", \"arguments\": {\"niche\": \"tech\"}}],\n",
287
+ " \"scheduled_actions\": [\n",
288
+ " {\"hour\": 12, \"action_type\": \"post\", \"content_type\": \"reel\",\n",
289
+ " \"topic\": \"AI tools\", \"tags\": [\"ai\", \"coding\"], \"intent\": \"watch_bait\"}\n",
290
+ " ],\n",
291
+ " \"replies\": [{\"post_hour\": 12, \"reply_hour\": 13}],\n",
292
+ " \"notes\": \"strategy notes\"\n",
293
+ "}\n",
294
+ "\n",
295
+ "RULES:\n",
296
+ "- content_type: reel|story|carousel|text_post\n",
297
+ "- intent: send_bait|save_bait|watch_bait|like_bait\n",
298
+ "- 1-2 posts/day optimal. More = fatigue.\n",
299
+ "- Empty scheduled_actions = rest (recovers energy).\n",
300
+ "- Vary content types and topics for diversity bonus.\n",
301
+ "- Reply within 90 min of post for reach bonus.\"\"\")\n",
302
+ "\n",
303
+ "\n",
304
+ "def format_obs(obs):\n",
305
+ " days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
306
+ " day_name = days[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
307
+ " signals_str = \"\"\n",
308
+ " signals = getattr(obs, \"engagement_signals\", None)\n",
309
+ " if signals:\n",
310
+ " signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
311
+ " f\"sends={signals.sends_per_reach:.3f} \"\n",
312
+ " f\"saves={signals.saves:.3f}\\n\")\n",
313
+ " tool_str = \"\"\n",
314
+ " for tr in getattr(obs, \"tool_results\", []):\n",
315
+ " if tr.success:\n",
316
+ " tool_str += f\" {tr.name}: {json.dumps(tr.data)[:200]}\\n\"\n",
317
+ " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
318
+ " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
319
+ " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
320
+ " f\"{signals_str}\"\n",
321
+ " f\"Tool results:\\n{tool_str if tool_str else ' (none)\\n'}\"\n",
322
+ " f\"Plan your actions (JSON only):\")\n",
323
+ "\n",
324
+ "\n",
325
+ "def parse_model_output(text):\n",
326
+ " text = text.strip()\n",
327
+ " if \"```\" in text:\n",
328
+ " lines = [l for l in text.split(\"\\n\") if not l.strip().startswith(\"```\")]\n",
329
+ " text = \"\\n\".join(lines).strip()\n",
330
+ " start, end = text.find(\"{\"), text.rfind(\"}\") + 1\n",
331
+ " if start >= 0 and end > start:\n",
332
+ " text = text[start:end]\n",
333
+ " try:\n",
334
+ " data = json.loads(text)\n",
335
+ " tool_calls = [ToolCall(name=tc[\"name\"], arguments=tc.get(\"arguments\", {}))\n",
336
+ " for tc in data.get(\"tool_calls\", []) if isinstance(tc, dict) and \"name\" in tc]\n",
337
+ " scheduled = []\n",
338
+ " for a in data.get(\"scheduled_actions\", []):\n",
339
+ " try: scheduled.append(ScheduledAction(**a))\n",
340
+ " except: pass\n",
341
+ " return ViraltestAction(tool_calls=tool_calls, scheduled_actions=scheduled,\n",
342
+ " replies=data.get(\"replies\", []), notes=data.get(\"notes\"))\n",
343
+ " except:\n",
344
+ " return ViraltestAction(scheduled_actions=[])\n",
345
+ "\n",
346
+ "\n",
347
+ "def generate_action(mdl, tok, obs, history, temperature=0.7):\n",
348
+ " prompt = format_obs(obs)\n",
349
+ " messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
350
+ " messages.extend(history[-4:])\n",
351
+ " messages.append({\"role\": \"user\", \"content\": prompt})\n",
352
+ " text_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
353
+ " inputs = tok(text_input, return_tensors=\"pt\").to(mdl.device)\n",
354
+ " with torch.no_grad():\n",
355
+ " out = mdl.generate(**inputs, max_new_tokens=512, temperature=temperature,\n",
356
+ " do_sample=True, top_p=0.9, pad_token_id=tok.eos_token_id)\n",
357
+ " resp = tok.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
358
+ " return resp, parse_model_output(resp)\n",
359
+ "\n",
360
+ "\n",
361
+ "def run_llm_episode(mdl, tok, task, seed=42, verbose=False):\n",
362
+ " env = ViraltestEnvironment()\n",
363
+ " obs = env.reset(task=task, seed=seed)\n",
364
+ " rewards, energies = [], [obs.creator_energy]\n",
365
+ " history, pairs = [], []\n",
366
+ " for day in range(1, TASK_HORIZON + 1):\n",
367
+ " if obs.done: break\n",
368
+ " if obs.creator_energy <= 0.25:\n",
369
+ " action = ViraltestAction(scheduled_actions=[])\n",
370
+ " resp = '{\"scheduled_actions\": []}'\n",
371
+ " else:\n",
372
+ " resp, action = generate_action(mdl, tok, obs, history)\n",
373
+ " prompt = format_obs(obs)\n",
374
+ " pairs.append({\"prompt\": prompt, \"response\": resp})\n",
375
+ " obs = env.step(action)\n",
376
+ " r = obs.reward or 0.0\n",
377
+ " rewards.append(r)\n",
378
+ " energies.append(obs.creator_energy)\n",
379
+ " history.extend([{\"role\": \"user\", \"content\": prompt},\n",
380
+ " {\"role\": \"assistant\", \"content\": resp}])\n",
381
+ " if verbose:\n",
382
+ " n_p = len([s for s in action.scheduled_actions if s.action_type==\"post\"])\n",
383
+ " print(f\" Day {day:2d}: r={r:.4f} e={obs.creator_energy:.2f} posts={n_p} tools={len(action.tool_calls)}\")\n",
384
+ " if obs.done: break\n",
385
+ " gs = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
386
+ " return {\"task\": task, \"grader_score\": gs, \"total_reward\": sum(rewards),\n",
387
+ " \"final_energy\": obs.creator_energy, \"rewards\": rewards,\n",
388
+ " \"energies\": energies, \"pairs\": pairs,\n",
389
+ " \"follower_delta\": obs.follower_count - 10000,\n",
390
+ " \"burned_out\": obs.creator_energy <= 0}\n",
391
+ "\n",
392
+ "print(\"LLM agent functions defined.\")"
393
+ ],
394
+ "execution_count": null,
395
+ "outputs": []
396
+ },
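+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Optional parser check, illustrative inputs only: well-formed JSON (matching the schema in the system prompt) should yield scheduled actions, while free-text chatter should fall back to an empty rest-day action rather than crash an episode."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Cell 8b (optional): exercise parse_model_output's happy path and fallback path\n",
+ "good = '{\"scheduled_actions\": [{\"hour\": 12, \"action_type\": \"post\", \"content_type\": \"reel\", \"topic\": \"AI tools\", \"tags\": [\"ai\"], \"intent\": \"watch_bait\"}]}'\n",
+ "bad = 'Sure! I will post a reel around noon.'\n",
+ "for label, raw in [(\"good\", good), (\"bad\", bad)]:\n",
+ " parsed = parse_model_output(raw)\n",
+ " print(f\"{label}: {len(parsed.scheduled_actions)} scheduled action(s)\")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },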
397
+ {
398
+ "cell_type": "markdown",
399
+ "metadata": {},
400
+ "source": [
401
+ "## Part 3: Untrained LLM Baseline (“Before”)\n",
402
+ "\n",
403
+ "Run the base model with NO fine-tuning. This establishes ground truth."
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "metadata": {},
409
+ "source": [
410
+ "# Cell 9: Run untrained model\n",
411
+ "print(\"Running UNTRAINED base model on all tasks...\")\n",
412
+ "print(\"=\" * 60)\n",
413
+ "\n",
414
+ "before_results = {}\n",
415
+ "for task in TASKS:\n",
416
+ " print(f\"\\n Task: {task}\")\n",
417
+ " result = run_llm_episode(model, tokenizer, task, seed=42, verbose=True)\n",
418
+ " before_results[task] = result\n",
419
+ " print(f\" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f}\")\n",
420
+ "\n",
421
+ "print(\"\\n\" + \"=\" * 60)\n",
422
+ "print(\"BEFORE TRAINING:\")\n",
423
+ "for t in TASKS:\n",
424
+ " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
425
+ ],
426
+ "execution_count": null,
427
+ "outputs": []
428
+ },
429
+ {
430
+ "cell_type": "markdown",
431
+ "metadata": {},
432
+ "source": [
433
+ "## Part 4: LoRA Fine-Tuning (Real Weight Updates)\n",
434
+ "\n",
435
+ "This is the core training loop. For each round:\n",
436
+ "1. Collect episodes with current model\n",
437
+ "2. Score each (prompt, response) pair by episode reward\n",
438
+ "3. Keep top 50% highest-reward samples\n",
439
+ "4. Fine-tune LoRA weights via SFT on those samples\n",
440
+ "\n",
441
+ "The model's actual weights change via gradient descent — this is real training."
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "metadata": {},
447
+ "source": [
448
+ "# Cell 10: Attach LoRA adapter\n",
449
+ "from peft import LoraConfig, get_peft_model, TaskType\n",
450
+ "\n",
451
+ "lora_config = LoraConfig(\n",
452
+ " r=16, lora_alpha=32, lora_dropout=0.05,\n",
453
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
454
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
455
+ " task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
456
+ ")\n",
457
+ "\n",
458
+ "model.enable_input_require_grads()\n",
459
+ "peft_model = get_peft_model(model, lora_config)\n",
460
+ "peft_model.print_trainable_parameters()"
461
+ ],
462
+ "execution_count": null,
463
+ "outputs": []
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "metadata": {},
468
+ "source": [
469
+ "# Cell 11: Training loop\n",
470
+ "from trl import SFTTrainer, SFTConfig\n",
471
+ "from datasets import Dataset\n",
472
+ "\n",
473
+ "NUM_ROUNDS = 4\n",
474
+ "EPISODES_PER_ROUND = 6\n",
475
+ "TOP_K_FRACTION = 0.5\n",
476
+ "\n",
477
+ "training_log = {\n",
478
+ " \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n",
479
+ " \"min_episode_reward\": [], \"avg_grader\": [], \"max_grader\": [],\n",
480
+ " \"n_training_samples\": [], \"train_loss\": [],\n",
481
+ "}\n",
482
+ "\n",
483
+ "t_start = time.time()\n",
484
+ "\n",
485
+ "for round_idx in range(1, NUM_ROUNDS + 1):\n",
486
+ " print(f\"\\n{'=' * 60}\")\n",
487
+ " print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n",
488
+ " print(f\"{'=' * 60}\")\n",
489
+ "\n",
490
+ " # Collect episodes\n",
491
+ " peft_model.eval()\n",
492
+ " all_pairs, episode_rewards, episode_graders = [], [], []\n",
493
+ "\n",
494
+ " for ep in range(EPISODES_PER_ROUND):\n",
495
+ " task = TASKS[ep % len(TASKS)]\n",
496
+ " seed = 42 + (round_idx - 1) * 100 + ep\n",
497
+ " result = run_llm_episode(peft_model, tokenizer, task, seed=seed)\n",
498
+ " ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
499
+ " episode_rewards.append(ep_reward)\n",
500
+ " episode_graders.append(result[\"grader_score\"])\n",
501
+ "\n",
502
+ " for pr in result[\"pairs\"]:\n",
503
+ " text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
504
+ " f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
505
+ " f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
506
+ " all_pairs.append({\"text\": text, \"reward\": ep_reward})\n",
507
+ "\n",
508
+ " print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {task.split('_')[-1]:>11s} \"\n",
509
+ " f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f}\")\n",
510
+ "\n",
511
+ " avg_r = np.mean(episode_rewards)\n",
512
+ " avg_g = np.mean(episode_graders)\n",
513
+ " print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f}\")\n",
514
+ "\n",
515
+ " # Filter to top-K\n",
516
+ " threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
517
+ " filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
518
+ " print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples\")\n",
519
+ "\n",
520
+ " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
521
+ "\n",
522
+ " # SFT training (real gradient updates)\n",
523
+ " sft_config = SFTConfig(\n",
524
+ " output_dir=f\"./checkpoints/round_{round_idx}\",\n",
525
+ " num_train_epochs=2,\n",
526
+ " per_device_train_batch_size=1,\n",
527
+ " gradient_accumulation_steps=4,\n",
528
+ " learning_rate=2e-5,\n",
529
+ " warmup_steps=5,\n",
530
+ " logging_steps=5,\n",
531
+ " save_strategy=\"no\",\n",
532
+ " max_seq_length=1024,\n",
533
+ " fp16=True,\n",
534
+ " report_to=\"none\",\n",
535
+ " )\n",
536
+ "\n",
537
+ " peft_model.train()\n",
538
+ " trainer = SFTTrainer(\n",
539
+ " model=peft_model, tokenizer=tokenizer,\n",
540
+ " train_dataset=dataset, args=sft_config,\n",
541
+ " )\n",
542
+ " train_result = trainer.train()\n",
543
+ " loss = train_result.training_loss\n",
544
+ " print(f\" Training loss: {loss:.4f}\")\n",
545
+ "\n",
546
+ " training_log[\"round\"].append(round_idx)\n",
547
+ " training_log[\"avg_episode_reward\"].append(round(float(avg_r), 3))\n",
548
+ " training_log[\"max_episode_reward\"].append(round(float(max(episode_rewards)), 3))\n",
549
+ " training_log[\"min_episode_reward\"].append(round(float(min(episode_rewards)), 3))\n",
550
+ " training_log[\"avg_grader\"].append(round(float(avg_g), 4))\n",
551
+ " training_log[\"max_grader\"].append(round(float(max(episode_graders)), 4))\n",
552
+ " training_log[\"n_training_samples\"].append(len(filtered))\n",
553
+ " training_log[\"train_loss\"].append(round(loss, 4))\n",
554
+ "\n",
555
+ "elapsed = time.time() - t_start\n",
556
+ "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
557
+ "print(pd.DataFrame(training_log).to_string(index=False))"
558
+ ],
559
+ "execution_count": null,
560
+ "outputs": []
561
+ },
562
+ {
563
+ "cell_type": "markdown",
564
+ "metadata": {},
565
+ "source": [
566
+ "## Part 5: Trained LLM Evaluation (“After”)\n",
567
+ "\n",
568
+ "Same model, same seeds, same environment — but now with updated LoRA weights."
569
+ ]
570
+ },
571
+ {
572
+ "cell_type": "code",
573
+ "metadata": {},
574
+ "source": [
575
+ "# Cell 12: Run trained model\n",
576
+ "print(\"Running TRAINED model on all tasks...\")\n",
577
+ "print(\"=\" * 60)\n",
578
+ "\n",
579
+ "peft_model.eval()\n",
580
+ "after_results = {}\n",
581
+ "for task in TASKS:\n",
582
+ " print(f\"\\n Task: {task}\")\n",
583
+ " result = run_llm_episode(peft_model, tokenizer, task, seed=42, verbose=True)\n",
584
+ " after_results[task] = result\n",
585
+ " print(f\" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f}\")\n",
586
+ "\n",
587
+ "print(\"\\n\" + \"=\" * 60)\n",
588
+ "print(\"AFTER TRAINING:\")\n",
589
+ "for t in TASKS:\n",
590
+ " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
591
+ ],
592
+ "execution_count": null,
593
+ "outputs": []
594
+ },
595
+ {
596
+ "cell_type": "markdown",
597
+ "metadata": {},
598
+ "source": [
599
+ "## Part 6: Result Plots — Real Training Evidence"
600
+ ]
601
+ },
602
+ {
603
+ "cell_type": "code",
604
+ "metadata": {},
605
+ "source": [
606
+ "# Cell 13: Training curves\n",
607
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
608
+ "rounds = training_log[\"round\"]\n",
609
+ "\n",
610
+ "axes[0].plot(rounds, training_log[\"avg_grader\"], 'o-', color='#2196F3', lw=2, label='Avg grader')\n",
611
+ "axes[0].fill_between(rounds, training_log[\"avg_grader\"],\n",
612
+ " training_log[\"max_grader\"], alpha=0.2, color='#2196F3')\n",
613
+ "axes[0].set_xlabel('Round'); axes[0].set_ylabel('Grader Score')\n",
614
+ "axes[0].set_title('Grader Score Over Rounds', fontweight='bold')\n",
615
+ "axes[0].legend(); axes[0].grid(True, alpha=0.3)\n",
616
+ "\n",
617
+ "axes[1].plot(rounds, training_log[\"train_loss\"], 's-', color='#E53935', lw=2)\n",
618
+ "axes[1].set_xlabel('Round'); axes[1].set_ylabel('Loss')\n",
619
+ "axes[1].set_title('Training Loss', fontweight='bold')\n",
620
+ "axes[1].grid(True, alpha=0.3)\n",
621
+ "\n",
622
+ "fig.suptitle('Viraltest v2 — LoRA Training Progress (Qwen 1.5B)', fontsize=14, fontweight='bold')\n",
623
+ "fig.tight_layout()\n",
624
+ "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
625
+ "plt.show()"
626
+ ],
627
+ "execution_count": null,
628
+ "outputs": []
629
+ },
630
+ {
631
+ "cell_type": "code",
632
+ "metadata": {},
633
+ "source": [
634
+ "# Cell 14: Before vs After\n",
635
+ "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
636
+ "x = np.arange(len(TASKS))\n",
637
+ "w = 0.25\n",
638
+ "\n",
639
+ "fig, ax = plt.subplots(figsize=(10, 6))\n",
640
+ "b_scores = [before_results[t][\"grader_score\"] for t in TASKS]\n",
641
+ "a_scores = [after_results[t][\"grader_score\"] for t in TASKS]\n",
642
+ "s_scores = [baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS]\n",
643
+ "\n",
644
+ "ax.bar(x - w, b_scores, w, label='Base Model (Before)', color='#FF9800')\n",
645
+ "ax.bar(x, a_scores, w, label='LoRA Trained (After)', color='#4CAF50')\n",
646
+ "ax.bar(x + w, s_scores, w, label='Smart Heuristic', color='#9E9E9E', alpha=0.7)\n",
647
+ "\n",
648
+ "ax.set_ylabel('Grader Score'); ax.set_xticks(x); ax.set_xticklabels(task_labels)\n",
649
+ "ax.set_title('Before vs After LoRA Training — Grader Scores', fontsize=14, fontweight='bold')\n",
650
+ "ax.legend(); ax.grid(True, alpha=0.3, axis='y')\n",
651
+ "\n",
652
+ "for container in ax.containers:\n",
653
+ " for bar in container:\n",
654
+ " h = bar.get_height()\n",
655
+ " if h > 0:\n",
656
+ " ax.text(bar.get_x() + bar.get_width()/2., h + 0.005,\n",
657
+ " f'{h:.4f}', ha='center', va='bottom', fontsize=9)\n",
658
+ "\n",
659
+ "fig.tight_layout()\n",
660
+ "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
661
+ "plt.show()"
662
+ ],
663
+ "execution_count": null,
664
+ "outputs": []
665
+ },
666
+ {
667
+ "cell_type": "code",
668
+ "metadata": {},
669
+ "source": [
670
+ "# Cell 15: Trajectory comparison\n",
671
+ "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
672
+ "comparisons = [\n",
673
+ " (\"Base Model\", before_results, '#FF9800', '--'),\n",
674
+ " (\"LoRA Trained\", after_results, '#4CAF50', '-'),\n",
675
+ "]\n",
676
+ "for i, task in enumerate(TASKS):\n",
677
+ " for label, res, color, ls in comparisons:\n",
678
+ " lw = 2.5 if 'Trained' in label else 1.5\n",
679
+ " axes[0, i].plot(res[task][\"rewards\"], label=label, color=color, lw=lw, ls=ls)\n",
680
+ " axes[1, i].plot(res[task][\"energies\"], label=label, color=color, lw=lw, ls=ls)\n",
681
+ " sr = baseline_results[\"smart\"][task]\n",
682
+ " axes[0, i].plot(sr[\"rewards\"], label=\"Smart\", color='#9E9E9E', lw=1, ls=':')\n",
683
+ " axes[1, i].plot(sr[\"energies\"], label=\"Smart\", color='#9E9E9E', lw=1, ls=':')\n",
684
+ " t_name = task.replace('monthly_', '').title()\n",
685
+ " axes[0, i].set_title(f\"{t_name} — Rewards\"); axes[0, i].grid(True, alpha=0.3)\n",
686
+ " axes[1, i].set_title(f\"{t_name} — Energy\"); axes[1, i].grid(True, alpha=0.3)\n",
687
+ "axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
688
+ "fig.suptitle('Before vs After — Daily Trajectories', fontsize=14, fontweight='bold', y=1.01)\n",
689
+ "fig.tight_layout()\n",
690
+ "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
691
+ "plt.show()"
692
+ ],
693
+ "execution_count": null,
694
+ "outputs": []
695
+ },
696
+ {
697
+ "cell_type": "markdown",
698
+ "metadata": {},
699
+ "source": [
700
+ "## Part 7: Summary & Export"
701
+ ]
702
+ },
703
+ {
704
+ "cell_type": "code",
705
+ "metadata": {},
706
+ "source": [
707
+ "# Cell 16: Final summary\n",
708
+ "print(\"=\" * 67)\n",
709
+ "print(\"FINAL RESULTS\")\n",
710
+ "print(\"=\" * 67)\n",
711
+ "print(f\"\\n{'Task':<25s} {'Before':>10s} {'After':>10s} {'Delta':>10s} {'Smart':>10s}\")\n",
712
+ "print(\"-\" * 67)\n",
713
+ "for task in TASKS:\n",
714
+ " b = before_results[task][\"grader_score\"]\n",
715
+ " a = after_results[task][\"grader_score\"]\n",
716
+ " s = baseline_results[\"smart\"][task][\"grader_score\"]\n",
717
+ " print(f\"{task:<25s} {b:>10.4f} {a:>10.4f} {a-b:>+10.4f} {s:>10.4f}\")\n",
718
+ "\n",
719
+ "avg_b = np.mean([before_results[t][\"grader_score\"] for t in TASKS])\n",
720
+ "avg_a = np.mean([after_results[t][\"grader_score\"] for t in TASKS])\n",
721
+ "avg_s = np.mean([baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS])\n",
722
+ "print(\"-\" * 67)\n",
723
+ "print(f\"{'AVERAGE':<25s} {avg_b:>10.4f} {avg_a:>10.4f} {avg_a-avg_b:>+10.4f} {avg_s:>10.4f}\")\n",
724
+ "\n",
725
+ "summary = {\n",
726
+ " \"model\": MODEL_NAME,\n",
727
+ " \"training\": \"LoRA SFT (real weight updates)\",\n",
728
+ " \"rounds\": NUM_ROUNDS, \"episodes_per_round\": EPISODES_PER_ROUND,\n",
729
+ " \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n",
730
+ " \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n",
731
+ " \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n",
732
+ " \"improvement\": {t: after_results[t][\"grader_score\"] - before_results[t][\"grader_score\"] for t in TASKS},\n",
733
+ " \"training_log\": training_log,\n",
734
+ "}\n",
735
+ "with open(f\"{PLOTS_DIR}/training_summary.json\", \"w\") as f:\n",
736
+ " json.dump(summary, f, indent=2)\n",
737
+ "\n",
738
+ "pd.DataFrame(training_log).to_csv(f\"{PLOTS_DIR}/training_log.csv\", index=False)\n",
739
+ "\n",
740
+ "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
741
+ "print(\"All results are from real LoRA weight updates on real environment runs.\")"
742
+ ],
743
+ "execution_count": null,
744
+ "outputs": []
745
+ },
746
+ {
747
+ "cell_type": "code",
748
+ "metadata": {},
749
+ "source": [
750
+ "# Cell 17: Save adapter\n",
751
+ "save_path = \"./viraltest_trained_adapter\"\n",
752
+ "peft_model.save_pretrained(save_path)\n",
753
+ "tokenizer.save_pretrained(save_path)\n",
754
+ "print(f\"LoRA adapter saved to {save_path}\")\n",
755
+ "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
756
+ ],
757
+ "execution_count": null,
758
+ "outputs": []
759
+ }
760
+ ],
761
+ "metadata": {
762
+ "kernelspec": {
763
+ "display_name": "Python 3",
764
+ "language": "python",
765
+ "name": "python3"
766
+ },
767
+ "language_info": {
768
+ "name": "python",
769
+ "version": "3.10.0"
770
+ },
771
+ "accelerator": "GPU",
772
+ "gpuClass": "standard"
773
  },
774
+ "nbformat": 4,
775
+ "nbformat_minor": 4
776
+ }