Spaces:
Paused
Paused
Commit ·
8970072
1
Parent(s): 1a2a407
pounteradds
Browse files- server/viraltest_environment.py +72 -20
- training/train_grpo.ipynb +208 -158
server/viraltest_environment.py
CHANGED
|
@@ -387,6 +387,8 @@ class ViraltestEnvironment(Environment):
|
|
| 387 |
self._hours_since_sleep = 2
|
| 388 |
self._sleep_debt = 0.0
|
| 389 |
|
|
|
|
|
|
|
| 390 |
def _load_competitors(self) -> List[CompetitorState]:
|
| 391 |
archetypes = _COMPETITORS_DATA.get("archetypes", [])
|
| 392 |
return [
|
|
@@ -1136,6 +1138,8 @@ class ViraltestEnvironment(Environment):
|
|
| 1136 |
|
| 1137 |
self._shift_label = kwargs.get("shift_label")
|
| 1138 |
self._chain_id = kwargs.get("episode_chain_id")
|
|
|
|
|
|
|
| 1139 |
|
| 1140 |
if self._chain_id and self._chain_id in _BRAND_STORE:
|
| 1141 |
brand = _BRAND_STORE[self._chain_id]
|
|
@@ -1439,20 +1443,29 @@ class ViraltestEnvironment(Environment):
|
|
| 1439 |
# ----- reward -----
|
| 1440 |
|
| 1441 |
def _compute_hourly_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1443 |
|
|
|
|
| 1444 |
prev_energy = self._energy_history[-2] if len(self._energy_history) >= 2 else 1.0
|
| 1445 |
energy_delta = self._energy - prev_energy
|
| 1446 |
-
|
| 1447 |
|
|
|
|
| 1448 |
day_posts = self._posts_per_day.get(self._day, 0)
|
| 1449 |
if 1 <= day_posts <= 2:
|
| 1450 |
-
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
|
|
|
|
|
|
|
|
|
|
| 1456 |
|
| 1457 |
tag_component = 0.0
|
| 1458 |
if sa.action_type == "post" and sa.tags:
|
|
@@ -1474,22 +1487,54 @@ class ViraltestEnvironment(Environment):
|
|
| 1474 |
)
|
| 1475 |
return max(0.0, min(1.0, raw))
|
| 1476 |
|
| 1477 |
-
def
|
| 1478 |
-
|
| 1479 |
-
|
| 1480 |
-
|
|
|
|
| 1481 |
|
| 1482 |
-
|
| 1483 |
-
|
| 1484 |
-
|
| 1485 |
-
|
| 1486 |
-
|
| 1487 |
-
else
|
| 1488 |
-
consistency = 0.0
|
| 1489 |
-
consistency_component = consistency * 0.15
|
| 1490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1491 |
burnout_penalty = 0.1 if self._energy < 0.2 else 0.0
|
| 1492 |
raw = energy_component + consistency_component - burnout_penalty
|
|
|
|
|
|
|
| 1493 |
return max(0.0, min(1.0, raw))
|
| 1494 |
|
| 1495 |
def _advance_time(self) -> None:
|
|
@@ -1700,6 +1745,13 @@ class ViraltestEnvironment(Environment):
|
|
| 1700 |
return max(0.0, min(1.0, raw))
|
| 1701 |
|
| 1702 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1703 |
def _topic_overlap(topic_a: str, topic_b: str) -> bool:
|
| 1704 |
words_a = set(topic_a.split())
|
| 1705 |
words_b = set(topic_b.split())
|
|
|
|
| 387 |
self._hours_since_sleep = 2
|
| 388 |
self._sleep_debt = 0.0
|
| 389 |
|
| 390 |
+
self._reward_mode = "combined"
|
| 391 |
+
|
| 392 |
def _load_competitors(self) -> List[CompetitorState]:
|
| 393 |
archetypes = _COMPETITORS_DATA.get("archetypes", [])
|
| 394 |
return [
|
|
|
|
| 1138 |
|
| 1139 |
self._shift_label = kwargs.get("shift_label")
|
| 1140 |
self._chain_id = kwargs.get("episode_chain_id")
|
| 1141 |
+
mode = kwargs.get("reward_mode", "combined")
|
| 1142 |
+
self._reward_mode = mode if mode in ("timing", "content", "combined") else "combined"
|
| 1143 |
|
| 1144 |
if self._chain_id and self._chain_id in _BRAND_STORE:
|
| 1145 |
brand = _BRAND_STORE[self._chain_id]
|
|
|
|
| 1443 |
# ----- reward -----
|
| 1444 |
|
| 1445 |
def _compute_hourly_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1446 |
+
if self._reward_mode == "timing":
|
| 1447 |
+
return self._compute_timing_reward(sa, engagement)
|
| 1448 |
+
if self._reward_mode == "content":
|
| 1449 |
+
return self._compute_content_reward(sa, engagement)
|
| 1450 |
+
return self._compute_combined_reward(sa, engagement)
|
| 1451 |
|
| 1452 |
+
def _energy_component(self) -> float:
|
| 1453 |
prev_energy = self._energy_history[-2] if len(self._energy_history) >= 2 else 1.0
|
| 1454 |
energy_delta = self._energy - prev_energy
|
| 1455 |
+
return max(0.0, min(1.0, (energy_delta + 0.3) / 0.6))
|
| 1456 |
|
| 1457 |
+
def _consistency_score(self) -> float:
|
| 1458 |
day_posts = self._posts_per_day.get(self._day, 0)
|
| 1459 |
if 1 <= day_posts <= 2:
|
| 1460 |
+
return 1.0
|
| 1461 |
+
if day_posts == 0 or day_posts == 3:
|
| 1462 |
+
return 0.5
|
| 1463 |
+
return 0.0
|
| 1464 |
+
|
| 1465 |
+
def _compute_combined_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1466 |
+
eng_component = min(1.0, engagement / 2.0) * 0.3
|
| 1467 |
+
energy_component = self._energy_component() * 0.15
|
| 1468 |
+
consistency_component = self._consistency_score() * 0.15
|
| 1469 |
|
| 1470 |
tag_component = 0.0
|
| 1471 |
if sa.action_type == "post" and sa.tags:
|
|
|
|
| 1487 |
)
|
| 1488 |
return max(0.0, min(1.0, raw))
|
| 1489 |
|
| 1490 |
+
def _compute_timing_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1491 |
+
is_post = sa.action_type == "post"
|
| 1492 |
+
peak_hour_mult = 1.3 if is_post and self._get_hour_multiplier() >= 1.2 else 1.0
|
| 1493 |
+
trending_topic_mult = 1.5 if is_post and self._is_topic_trending(sa.topic) else 1.0
|
| 1494 |
+
eng_component = min(1.0, engagement / 2.0) * 0.40 * trending_topic_mult * peak_hour_mult
|
| 1495 |
|
| 1496 |
+
peak_bonus = min(1.0, self._get_hour_multiplier() / 1.3) if is_post else 0.0
|
| 1497 |
+
peak_component = peak_bonus * 0.20
|
| 1498 |
+
|
| 1499 |
+
energy_component = self._energy_component() * 0.20
|
| 1500 |
+
consistency_component = self._consistency_score() * 0.20
|
| 1501 |
+
burnout_penalty = 0.1 if self._energy < 0.2 else 0.0
|
|
|
|
|
|
|
| 1502 |
|
| 1503 |
+
raw = eng_component + peak_component + energy_component + consistency_component - burnout_penalty
|
| 1504 |
+
return max(0.0, min(1.0, raw))
|
| 1505 |
+
|
| 1506 |
+
def _compute_content_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1507 |
+
is_post = sa.action_type == "post"
|
| 1508 |
+
trending_topic_mult = 1.5 if is_post and self._is_topic_trending(sa.topic) else 1.0
|
| 1509 |
+
eng_component = min(1.0, engagement / 2.0) * 0.20 * trending_topic_mult
|
| 1510 |
+
|
| 1511 |
+
tag_component = 0.0
|
| 1512 |
+
if is_post and sa.tags:
|
| 1513 |
+
trending_match = sum(1 for t in sa.tags if t.lower() in self._trending_tags) / 5.0
|
| 1514 |
+
tag_component = min(1.0, trending_match + 0.3) * 0.25
|
| 1515 |
+
|
| 1516 |
+
comp_component = 0.0
|
| 1517 |
+
if is_post:
|
| 1518 |
+
diff = self._calc_competitor_diff(sa.topic)
|
| 1519 |
+
comp_component = min(1.0, diff / 1.3) * 0.25
|
| 1520 |
+
|
| 1521 |
+
variety_component = 0.0
|
| 1522 |
+
intent_component = 0.0
|
| 1523 |
+
if is_post:
|
| 1524 |
+
variety_component = min(1.0, len(self._unique_content_types) / 4.0) * 0.15
|
| 1525 |
+
intent_component = (0.15 if sa.intent in INTENT_MULTIPLIER else 0.0)
|
| 1526 |
+
|
| 1527 |
+
burnout_penalty = 0.05 if self._energy < 0.2 else 0.0
|
| 1528 |
+
raw = eng_component + tag_component + comp_component + variety_component + intent_component - burnout_penalty
|
| 1529 |
+
return max(0.0, min(1.0, raw))
|
| 1530 |
+
|
| 1531 |
+
def _compute_rest_reward(self) -> float:
|
| 1532 |
+
energy_component = self._energy_component() * 0.15
|
| 1533 |
+
consistency_component = self._consistency_score() * 0.15
|
| 1534 |
burnout_penalty = 0.1 if self._energy < 0.2 else 0.0
|
| 1535 |
raw = energy_component + consistency_component - burnout_penalty
|
| 1536 |
+
if self._reward_mode == "content":
|
| 1537 |
+
raw *= 0.5
|
| 1538 |
return max(0.0, min(1.0, raw))
|
| 1539 |
|
| 1540 |
def _advance_time(self) -> None:
|
|
|
|
| 1745 |
return max(0.0, min(1.0, raw))
|
| 1746 |
|
| 1747 |
|
| 1748 |
+
def get_peak_hours(day_of_week: int, top_k: int = 2) -> List[int]:
|
| 1749 |
+
row = _HEATMAP_GRID.get(day_of_week % 7, [])
|
| 1750 |
+
if not row:
|
| 1751 |
+
return []
|
| 1752 |
+
return sorted(range(len(row)), key=lambda h: row[h], reverse=True)[:top_k]
|
| 1753 |
+
|
| 1754 |
+
|
| 1755 |
def _topic_overlap(topic_a: str, topic_b: str) -> bool:
|
| 1756 |
words_a = set(topic_a.split())
|
| 1757 |
words_b = set(topic_b.split())
|
training/train_grpo.ipynb
CHANGED
|
@@ -25,9 +25,7 @@
|
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
| 28 |
-
"execution_count": null,
|
| 29 |
"metadata": {},
|
| 30 |
-
"outputs": [],
|
| 31 |
"source": [
|
| 32 |
"# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
|
| 33 |
"!pip install -q torch torchvision torchaudio\n",
|
|
@@ -36,13 +34,13 @@
|
|
| 36 |
"!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
|
| 37 |
"!pip install -q \"openenv-core[core]>=0.2.2\"\n",
|
| 38 |
"!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
|
| 39 |
-
]
|
|
|
|
|
|
|
| 40 |
},
|
| 41 |
{
|
| 42 |
"cell_type": "code",
|
| 43 |
-
"execution_count": null,
|
| 44 |
"metadata": {},
|
| 45 |
-
"outputs": [],
|
| 46 |
"source": [
|
| 47 |
"# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
|
| 48 |
"import os\n",
|
|
@@ -118,13 +116,13 @@
|
|
| 118 |
"print(f\"Branch: {REPO_BRANCH}\")\n",
|
| 119 |
"print(f\"Commit: {commit}\")\n",
|
| 120 |
"print(f\"Plots dir: {PLOTS_DIR}\")"
|
| 121 |
-
]
|
|
|
|
|
|
|
| 122 |
},
|
| 123 |
{
|
| 124 |
"cell_type": "code",
|
| 125 |
-
"execution_count": null,
|
| 126 |
"metadata": {},
|
| 127 |
-
"outputs": [],
|
| 128 |
"source": [
|
| 129 |
"# Cell 3: Imports (with runtime validation)\n",
|
| 130 |
"import json, random, time, textwrap, copy, os, sys\n",
|
|
@@ -156,7 +154,7 @@
|
|
| 156 |
"from models import ScheduledAction, ToolCall, ViraltestAction\n",
|
| 157 |
"from server.viraltest_environment import (\n",
|
| 158 |
" ViraltestEnvironment, TAG_POOL, TASK_HORIZON,\n",
|
| 159 |
-
" TOPIC_CATEGORIES,\n",
|
| 160 |
")\n",
|
| 161 |
"\n",
|
| 162 |
"ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
|
|
@@ -178,7 +176,9 @@
|
|
| 178 |
"import ast\n",
|
| 179 |
"ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
|
| 180 |
"print(\"OK: ast.parse (syntax check)\")"
|
| 181 |
-
]
|
|
|
|
|
|
|
| 182 |
},
|
| 183 |
{
|
| 184 |
"cell_type": "markdown",
|
|
@@ -191,9 +191,7 @@
|
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
| 194 |
-
"execution_count": null,
|
| 195 |
"metadata": {},
|
| 196 |
-
"outputs": [],
|
| 197 |
"source": [
|
| 198 |
"# Cell 4: Define heuristic agents + episode runner\n",
|
| 199 |
"_rng = random.Random(42)\n",
|
|
@@ -269,13 +267,13 @@
|
|
| 269 |
" \"rewards\": rewards, \"energies\": energies}\n",
|
| 270 |
"\n",
|
| 271 |
"print(\"Agents and episode runner defined.\")"
|
| 272 |
-
]
|
|
|
|
|
|
|
| 273 |
},
|
| 274 |
{
|
| 275 |
"cell_type": "code",
|
| 276 |
-
"execution_count": null,
|
| 277 |
"metadata": {},
|
| 278 |
-
"outputs": [],
|
| 279 |
"source": [
|
| 280 |
"# Cell 5: Run baselines (safe)\n",
|
| 281 |
"print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
|
|
@@ -310,13 +308,13 @@
|
|
| 310 |
"for name in BASELINE_AGENTS:\n",
|
| 311 |
" scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
|
| 312 |
" print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
|
| 313 |
-
]
|
|
|
|
|
|
|
| 314 |
},
|
| 315 |
{
|
| 316 |
"cell_type": "code",
|
| 317 |
-
"execution_count": null,
|
| 318 |
"metadata": {},
|
| 319 |
-
"outputs": [],
|
| 320 |
"source": [
|
| 321 |
"# Cell 6: Baseline plots\n",
|
| 322 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
|
|
@@ -334,7 +332,9 @@
|
|
| 334 |
"fig.tight_layout()\n",
|
| 335 |
"fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
|
| 336 |
"plt.show()"
|
| 337 |
-
]
|
|
|
|
|
|
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"cell_type": "markdown",
|
|
@@ -347,9 +347,7 @@
|
|
| 347 |
},
|
| 348 |
{
|
| 349 |
"cell_type": "code",
|
| 350 |
-
"execution_count": null,
|
| 351 |
"metadata": {},
|
| 352 |
-
"outputs": [],
|
| 353 |
"source": [
|
| 354 |
"# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
|
| 355 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
|
@@ -393,13 +391,13 @@
|
|
| 393 |
"print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
|
| 394 |
"if torch.cuda.is_available():\n",
|
| 395 |
" print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
|
| 396 |
-
]
|
|
|
|
|
|
|
| 397 |
},
|
| 398 |
{
|
| 399 |
"cell_type": "code",
|
| 400 |
-
"execution_count": null,
|
| 401 |
"metadata": {},
|
| 402 |
-
"outputs": [],
|
| 403 |
"source": [
|
| 404 |
"# Cell 8: LLM agent functions\n",
|
| 405 |
"_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
|
|
@@ -454,6 +452,16 @@
|
|
| 454 |
"SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
|
| 455 |
"SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
|
| 456 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
"\n",
|
| 458 |
"_DAY_NAMES = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
|
| 459 |
"\n",
|
|
@@ -472,7 +480,7 @@
|
|
| 472 |
" return out\n",
|
| 473 |
"\n",
|
| 474 |
"\n",
|
| 475 |
-
"def format_obs(obs, history=None):\n",
|
| 476 |
" day_name = _DAY_NAMES[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
|
| 477 |
" signals_str = \"\"\n",
|
| 478 |
" signals = getattr(obs, \"engagement_signals\", None)\n",
|
|
@@ -486,12 +494,14 @@
|
|
| 486 |
" tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
|
| 487 |
" if not tool_str:\n",
|
| 488 |
" tool_str = \" (none — call query_* tools to discover)\\n\"\n",
|
|
|
|
| 489 |
" return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
|
| 490 |
" f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
|
| 491 |
" f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
|
| 492 |
" f\"{signals_str}\"\n",
|
| 493 |
" f\"{_format_history(history)}\"\n",
|
| 494 |
" f\"Tool results:\\n{tool_str}\"\n",
|
|
|
|
| 495 |
" f\"Plan today's actions (JSON only):\")\n",
|
| 496 |
"\n",
|
| 497 |
"\n",
|
|
@@ -615,12 +625,13 @@
|
|
| 615 |
" return out\n",
|
| 616 |
"\n",
|
| 617 |
"\n",
|
| 618 |
-
"def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None,
|
|
|
|
| 619 |
" \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n",
|
| 620 |
" sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
|
| 621 |
" n = len(tasks_seeds)\n",
|
| 622 |
" envs = [ViraltestEnvironment() for _ in range(n)]\n",
|
| 623 |
-
" obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
|
| 624 |
" rewards = [[] for _ in range(n)]\n",
|
| 625 |
" energies = [[obs.creator_energy] for obs in obss]\n",
|
| 626 |
" pairs = [[] for _ in range(n)]\n",
|
|
@@ -641,7 +652,12 @@
|
|
| 641 |
"\n",
|
| 642 |
" actions_by_idx = {i: rest_action for i in rest}\n",
|
| 643 |
" if active:\n",
|
| 644 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
"\n",
|
| 646 |
" disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n",
|
| 647 |
" disc_resps, ptok = _gen(disc_prompts)\n",
|
|
@@ -716,7 +732,9 @@
|
|
| 716 |
"\n",
|
| 717 |
"\n",
|
| 718 |
"print(\"LLM agent functions defined (batched).\")"
|
| 719 |
-
]
|
|
|
|
|
|
|
| 720 |
},
|
| 721 |
{
|
| 722 |
"cell_type": "markdown",
|
|
@@ -729,9 +747,7 @@
|
|
| 729 |
},
|
| 730 |
{
|
| 731 |
"cell_type": "code",
|
| 732 |
-
"execution_count": null,
|
| 733 |
"metadata": {},
|
| 734 |
-
"outputs": [],
|
| 735 |
"source": [
|
| 736 |
"# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
|
| 737 |
"print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
|
|
@@ -745,7 +761,9 @@
|
|
| 745 |
"print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
|
| 746 |
"for t in TASKS:\n",
|
| 747 |
" print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
|
| 748 |
-
]
|
|
|
|
|
|
|
| 749 |
},
|
| 750 |
{
|
| 751 |
"cell_type": "markdown",
|
|
@@ -764,9 +782,7 @@
|
|
| 764 |
},
|
| 765 |
{
|
| 766 |
"cell_type": "code",
|
| 767 |
-
"execution_count": null,
|
| 768 |
"metadata": {},
|
| 769 |
-
"outputs": [],
|
| 770 |
"source": [
|
| 771 |
"# Cell 10: Attach LoRA adapter\n",
|
| 772 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
|
@@ -780,118 +796,144 @@
|
|
| 780 |
"model.enable_input_require_grads()\n",
|
| 781 |
"peft_model = get_peft_model(model, lora_config)\n",
|
| 782 |
"peft_model.print_trainable_parameters()"
|
| 783 |
-
]
|
|
|
|
|
|
|
| 784 |
},
|
| 785 |
{
|
| 786 |
"cell_type": "code",
|
| 787 |
-
"execution_count": null,
|
| 788 |
"metadata": {},
|
| 789 |
-
"outputs": [],
|
| 790 |
"source": [
|
| 791 |
-
"# Cell 11:
|
|
|
|
|
|
|
| 792 |
"from trl import SFTTrainer, SFTConfig\n",
|
| 793 |
"from datasets import Dataset\n",
|
| 794 |
"\n",
|
| 795 |
-
"NUM_ROUNDS = 2\n",
|
| 796 |
"EPISODES_PER_ROUND = 6\n",
|
| 797 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 798 |
"\n",
|
| 799 |
"training_log = {\n",
|
| 800 |
-
" \"round\": [], \"
|
| 801 |
-
" \"
|
|
|
|
| 802 |
" \"n_training_samples\": [], \"train_loss\": [],\n",
|
| 803 |
"}\n",
|
| 804 |
"\n",
|
| 805 |
"t_start = time.time()\n",
|
| 806 |
-
"\n",
|
| 807 |
-
"
|
| 808 |
-
"
|
| 809 |
-
"
|
| 810 |
-
"
|
| 811 |
-
"\n",
|
| 812 |
-
"
|
| 813 |
-
"
|
| 814 |
-
"
|
| 815 |
-
"
|
| 816 |
-
"
|
| 817 |
-
"
|
| 818 |
-
"
|
| 819 |
-
"\n",
|
| 820 |
-
"
|
| 821 |
-
"
|
| 822 |
-
"
|
| 823 |
-
"
|
| 824 |
-
"
|
| 825 |
-
"
|
| 826 |
-
"
|
| 827 |
-
"
|
| 828 |
-
"
|
| 829 |
-
"
|
| 830 |
-
"
|
| 831 |
-
"
|
| 832 |
-
"
|
| 833 |
-
" kept
|
| 834 |
-
"
|
| 835 |
-
"
|
| 836 |
-
"\n",
|
| 837 |
-
"
|
| 838 |
-
"
|
| 839 |
-
"
|
| 840 |
-
"
|
| 841 |
-
"
|
| 842 |
-
"
|
| 843 |
-
"
|
| 844 |
-
"
|
| 845 |
-
"
|
| 846 |
-
"
|
| 847 |
-
"\n",
|
| 848 |
-
"
|
| 849 |
-
"
|
| 850 |
-
"
|
| 851 |
-
"
|
| 852 |
-
"
|
| 853 |
-
"
|
| 854 |
-
"
|
| 855 |
-
"\n",
|
| 856 |
-
"
|
| 857 |
-
"\n",
|
| 858 |
-
"
|
| 859 |
-
"
|
| 860 |
-
"
|
| 861 |
-
"
|
| 862 |
-
"
|
| 863 |
-
"
|
| 864 |
-
"
|
| 865 |
-
"
|
| 866 |
-
"
|
| 867 |
-
"
|
| 868 |
-
"
|
| 869 |
-
"
|
| 870 |
-
"
|
| 871 |
-
"
|
| 872 |
-
"\n",
|
| 873 |
-
"
|
| 874 |
-
"
|
| 875 |
-
"
|
| 876 |
-
"
|
| 877 |
-
"
|
| 878 |
-
"
|
| 879 |
-
"
|
| 880 |
-
"
|
| 881 |
-
"\n",
|
| 882 |
-
"
|
| 883 |
-
"
|
| 884 |
-
"
|
| 885 |
-
"
|
| 886 |
-
"
|
| 887 |
-
"
|
| 888 |
-
"
|
| 889 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 890 |
"\n",
|
| 891 |
"elapsed = time.time() - t_start\n",
|
| 892 |
-
"print(f\"\\
|
| 893 |
"print(pd.DataFrame(training_log).to_string(index=False))"
|
| 894 |
-
]
|
|
|
|
|
|
|
| 895 |
},
|
| 896 |
{
|
| 897 |
"cell_type": "markdown",
|
|
@@ -904,9 +946,7 @@
|
|
| 904 |
},
|
| 905 |
{
|
| 906 |
"cell_type": "code",
|
| 907 |
-
"execution_count": null,
|
| 908 |
"metadata": {},
|
| 909 |
-
"outputs": [],
|
| 910 |
"source": [
|
| 911 |
"# Cell 12: Run trained model (batched)\n",
|
| 912 |
"print(\"Running TRAINED model on all tasks (batched)...\")\n",
|
|
@@ -921,7 +961,9 @@
|
|
| 921 |
"print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
|
| 922 |
"for t in TASKS:\n",
|
| 923 |
" print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
|
| 924 |
-
]
|
|
|
|
|
|
|
| 925 |
},
|
| 926 |
{
|
| 927 |
"cell_type": "markdown",
|
|
@@ -932,37 +974,41 @@
|
|
| 932 |
},
|
| 933 |
{
|
| 934 |
"cell_type": "code",
|
| 935 |
-
"execution_count": null,
|
| 936 |
"metadata": {},
|
| 937 |
-
"outputs": [],
|
| 938 |
"source": [
|
| 939 |
-
"# Cell 13: Training curves\n",
|
| 940 |
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 941 |
-
"
|
|
|
|
|
|
|
| 942 |
"\n",
|
| 943 |
-
"axes[0].plot(
|
| 944 |
-
"axes[0].fill_between(
|
| 945 |
" training_log[\"max_grader\"], alpha=0.2, color='#2196F3')\n",
|
| 946 |
-
"
|
| 947 |
-
"axes[0].
|
|
|
|
|
|
|
| 948 |
"axes[0].legend(); axes[0].grid(True, alpha=0.3)\n",
|
| 949 |
"\n",
|
| 950 |
-
"axes[1].plot(
|
| 951 |
-
"
|
|
|
|
|
|
|
| 952 |
"axes[1].set_title('Training Loss', fontweight='bold')\n",
|
| 953 |
"axes[1].grid(True, alpha=0.3)\n",
|
| 954 |
"\n",
|
| 955 |
-
"fig.suptitle('Viraltest v2 — LoRA Training
|
| 956 |
"fig.tight_layout()\n",
|
| 957 |
"fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
|
| 958 |
"plt.show()"
|
| 959 |
-
]
|
|
|
|
|
|
|
| 960 |
},
|
| 961 |
{
|
| 962 |
"cell_type": "code",
|
| 963 |
-
"execution_count": null,
|
| 964 |
"metadata": {},
|
| 965 |
-
"outputs": [],
|
| 966 |
"source": [
|
| 967 |
"# Cell 14: Before vs After\n",
|
| 968 |
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
|
|
@@ -992,13 +1038,13 @@
|
|
| 992 |
"fig.tight_layout()\n",
|
| 993 |
"fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
|
| 994 |
"plt.show()"
|
| 995 |
-
]
|
|
|
|
|
|
|
| 996 |
},
|
| 997 |
{
|
| 998 |
"cell_type": "code",
|
| 999 |
-
"execution_count": null,
|
| 1000 |
"metadata": {},
|
| 1001 |
-
"outputs": [],
|
| 1002 |
"source": [
|
| 1003 |
"# Cell 15: Trajectory comparison\n",
|
| 1004 |
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
|
|
@@ -1022,7 +1068,9 @@
|
|
| 1022 |
"fig.tight_layout()\n",
|
| 1023 |
"fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
|
| 1024 |
"plt.show()"
|
| 1025 |
-
]
|
|
|
|
|
|
|
| 1026 |
},
|
| 1027 |
{
|
| 1028 |
"cell_type": "markdown",
|
|
@@ -1033,9 +1081,7 @@
|
|
| 1033 |
},
|
| 1034 |
{
|
| 1035 |
"cell_type": "code",
|
| 1036 |
-
"execution_count": null,
|
| 1037 |
"metadata": {},
|
| 1038 |
-
"outputs": [],
|
| 1039 |
"source": [
|
| 1040 |
"# Cell 16: Final summary\n",
|
| 1041 |
"print(\"=\" * 67)\n",
|
|
@@ -1057,8 +1103,10 @@
|
|
| 1057 |
"\n",
|
| 1058 |
"summary = {\n",
|
| 1059 |
" \"model\": MODEL_NAME,\n",
|
| 1060 |
-
" \"training\": \"LoRA SFT (
|
| 1061 |
-
" \"
|
|
|
|
|
|
|
| 1062 |
" \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 1063 |
" \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 1064 |
" \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n",
|
|
@@ -1072,13 +1120,13 @@
|
|
| 1072 |
"\n",
|
| 1073 |
"print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
|
| 1074 |
"print(\"All results are from real LoRA weight updates on real environment runs.\")"
|
| 1075 |
-
]
|
|
|
|
|
|
|
| 1076 |
},
|
| 1077 |
{
|
| 1078 |
"cell_type": "code",
|
| 1079 |
-
"execution_count": null,
|
| 1080 |
"metadata": {},
|
| 1081 |
-
"outputs": [],
|
| 1082 |
"source": [
|
| 1083 |
"# Cell 17: Save adapter\n",
|
| 1084 |
"save_path = \"./viraltest_trained_adapter\"\n",
|
|
@@ -1086,7 +1134,9 @@
|
|
| 1086 |
"tokenizer.save_pretrained(save_path)\n",
|
| 1087 |
"print(f\"LoRA adapter saved to {save_path}\")\n",
|
| 1088 |
"print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
|
| 1089 |
-
]
|
|
|
|
|
|
|
| 1090 |
}
|
| 1091 |
],
|
| 1092 |
"metadata": {
|
|
@@ -1112,4 +1162,4 @@
|
|
| 1112 |
},
|
| 1113 |
"nbformat": 4,
|
| 1114 |
"nbformat_minor": 4
|
| 1115 |
-
}
|
|
|
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
|
|
|
| 28 |
"metadata": {},
|
|
|
|
| 29 |
"source": [
|
| 30 |
"# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
|
| 31 |
"!pip install -q torch torchvision torchaudio\n",
|
|
|
|
| 34 |
"!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
|
| 35 |
"!pip install -q \"openenv-core[core]>=0.2.2\"\n",
|
| 36 |
"!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
|
| 37 |
+
],
|
| 38 |
+
"execution_count": null,
|
| 39 |
+
"outputs": []
|
| 40 |
},
|
| 41 |
{
|
| 42 |
"cell_type": "code",
|
|
|
|
| 43 |
"metadata": {},
|
|
|
|
| 44 |
"source": [
|
| 45 |
"# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
|
| 46 |
"import os\n",
|
|
|
|
| 116 |
"print(f\"Branch: {REPO_BRANCH}\")\n",
|
| 117 |
"print(f\"Commit: {commit}\")\n",
|
| 118 |
"print(f\"Plots dir: {PLOTS_DIR}\")"
|
| 119 |
+
],
|
| 120 |
+
"execution_count": null,
|
| 121 |
+
"outputs": []
|
| 122 |
},
|
| 123 |
{
|
| 124 |
"cell_type": "code",
|
|
|
|
| 125 |
"metadata": {},
|
|
|
|
| 126 |
"source": [
|
| 127 |
"# Cell 3: Imports (with runtime validation)\n",
|
| 128 |
"import json, random, time, textwrap, copy, os, sys\n",
|
|
|
|
| 154 |
"from models import ScheduledAction, ToolCall, ViraltestAction\n",
|
| 155 |
"from server.viraltest_environment import (\n",
|
| 156 |
" ViraltestEnvironment, TAG_POOL, TASK_HORIZON,\n",
|
| 157 |
+
" TOPIC_CATEGORIES, get_peak_hours,\n",
|
| 158 |
")\n",
|
| 159 |
"\n",
|
| 160 |
"ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
|
|
|
|
| 176 |
"import ast\n",
|
| 177 |
"ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
|
| 178 |
"print(\"OK: ast.parse (syntax check)\")"
|
| 179 |
+
],
|
| 180 |
+
"execution_count": null,
|
| 181 |
+
"outputs": []
|
| 182 |
},
|
| 183 |
{
|
| 184 |
"cell_type": "markdown",
|
|
|
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
|
|
|
| 194 |
"metadata": {},
|
|
|
|
| 195 |
"source": [
|
| 196 |
"# Cell 4: Define heuristic agents + episode runner\n",
|
| 197 |
"_rng = random.Random(42)\n",
|
|
|
|
| 267 |
" \"rewards\": rewards, \"energies\": energies}\n",
|
| 268 |
"\n",
|
| 269 |
"print(\"Agents and episode runner defined.\")"
|
| 270 |
+
],
|
| 271 |
+
"execution_count": null,
|
| 272 |
+
"outputs": []
|
| 273 |
},
|
| 274 |
{
|
| 275 |
"cell_type": "code",
|
|
|
|
| 276 |
"metadata": {},
|
|
|
|
| 277 |
"source": [
|
| 278 |
"# Cell 5: Run baselines (safe)\n",
|
| 279 |
"print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
|
|
|
|
| 308 |
"for name in BASELINE_AGENTS:\n",
|
| 309 |
" scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
|
| 310 |
" print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
|
| 311 |
+
],
|
| 312 |
+
"execution_count": null,
|
| 313 |
+
"outputs": []
|
| 314 |
},
|
| 315 |
{
|
| 316 |
"cell_type": "code",
|
|
|
|
| 317 |
"metadata": {},
|
|
|
|
| 318 |
"source": [
|
| 319 |
"# Cell 6: Baseline plots\n",
|
| 320 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
|
|
|
|
| 332 |
"fig.tight_layout()\n",
|
| 333 |
"fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
|
| 334 |
"plt.show()"
|
| 335 |
+
],
|
| 336 |
+
"execution_count": null,
|
| 337 |
+
"outputs": []
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"cell_type": "markdown",
|
|
|
|
| 347 |
},
|
| 348 |
{
|
| 349 |
"cell_type": "code",
|
|
|
|
| 350 |
"metadata": {},
|
|
|
|
| 351 |
"source": [
|
| 352 |
"# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
|
| 353 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
|
|
|
| 391 |
"print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
|
| 392 |
"if torch.cuda.is_available():\n",
|
| 393 |
" print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
|
| 394 |
+
],
|
| 395 |
+
"execution_count": null,
|
| 396 |
+
"outputs": []
|
| 397 |
},
|
| 398 |
{
|
| 399 |
"cell_type": "code",
|
|
|
|
| 400 |
"metadata": {},
|
|
|
|
| 401 |
"source": [
|
| 402 |
"# Cell 8: LLM agent functions\n",
|
| 403 |
"_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
|
|
|
|
| 452 |
"SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
|
| 453 |
"SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
|
| 454 |
"\n",
|
| 455 |
+
"SYSTEM_PROMPT_TIMING = SYSTEM_PROMPT + textwrap.dedent(\"\"\"\n",
|
| 456 |
+
"\n",
|
| 457 |
+
"FOCUS: optimise WHEN to post. Identify peak hours for the audience (use query_audience / query_trends).\n",
|
| 458 |
+
"2 posts/day at peak hours beats 4 posts at random hours.\"\"\")\n",
|
| 459 |
+
"\n",
|
| 460 |
+
"SYSTEM_PROMPT_CONTENT = SYSTEM_PROMPT + textwrap.dedent(\"\"\"\n",
|
| 461 |
+
"\n",
|
| 462 |
+
"FOCUS: optimise WHAT to post. Vary content_type and intent across the week,\n",
|
| 463 |
+
"pick differentiated topics, exploit trending tags.\"\"\")\n",
|
| 464 |
+
"\n",
|
| 465 |
"\n",
|
| 466 |
"_DAY_NAMES = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
|
| 467 |
"\n",
|
|
|
|
| 480 |
" return out\n",
|
| 481 |
"\n",
|
| 482 |
"\n",
|
| 483 |
+
"def format_obs(obs, history=None, extra_hint=None):\n",
|
| 484 |
" day_name = _DAY_NAMES[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
|
| 485 |
" signals_str = \"\"\n",
|
| 486 |
" signals = getattr(obs, \"engagement_signals\", None)\n",
|
|
|
|
| 494 |
" tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
|
| 495 |
" if not tool_str:\n",
|
| 496 |
" tool_str = \" (none — call query_* tools to discover)\\n\"\n",
|
| 497 |
+
" hint_str = f\"Coach hint: today's peak hours are {extra_hint}.\\n\" if extra_hint else \"\"\n",
|
| 498 |
" return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
|
| 499 |
" f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
|
| 500 |
" f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
|
| 501 |
" f\"{signals_str}\"\n",
|
| 502 |
" f\"{_format_history(history)}\"\n",
|
| 503 |
" f\"Tool results:\\n{tool_str}\"\n",
|
| 504 |
+
" f\"{hint_str}\"\n",
|
| 505 |
" f\"Plan today's actions (JSON only):\")\n",
|
| 506 |
"\n",
|
| 507 |
"\n",
|
|
|
|
| 625 |
" return out\n",
|
| 626 |
"\n",
|
| 627 |
"\n",
|
| 628 |
+
"def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None,\n",
|
| 629 |
+
" log_tag=None, hint_peak_hours=False, reward_mode=\"combined\"):\n",
|
| 630 |
" \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n",
|
| 631 |
" sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
|
| 632 |
" n = len(tasks_seeds)\n",
|
| 633 |
" envs = [ViraltestEnvironment() for _ in range(n)]\n",
|
| 634 |
+
" obss = [envs[i].reset(task=t, seed=s, reward_mode=reward_mode) for i, (t, s) in enumerate(tasks_seeds)]\n",
|
| 635 |
" rewards = [[] for _ in range(n)]\n",
|
| 636 |
" energies = [[obs.creator_energy] for obs in obss]\n",
|
| 637 |
" pairs = [[] for _ in range(n)]\n",
|
|
|
|
| 652 |
"\n",
|
| 653 |
" actions_by_idx = {i: rest_action for i in rest}\n",
|
| 654 |
" if active:\n",
|
| 655 |
+
" def _hint_for(i):\n",
|
| 656 |
+
" if not hint_peak_hours:\n",
|
| 657 |
+
" return None\n",
|
| 658 |
+
" hrs = get_peak_hours(obss[i].day_of_week, top_k=2)\n",
|
| 659 |
+
" return \", \".join(f\"{h:02d}:00\" for h in hrs) if hrs else None\n",
|
| 660 |
+
" base_prompts = [format_obs(obss[i], histories[i], extra_hint=_hint_for(i)) for i in active]\n",
|
| 661 |
"\n",
|
| 662 |
" disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n",
|
| 663 |
" disc_resps, ptok = _gen(disc_prompts)\n",
|
|
|
|
| 732 |
"\n",
|
| 733 |
"\n",
|
| 734 |
"print(\"LLM agent functions defined (batched).\")"
|
| 735 |
+
],
|
| 736 |
+
"execution_count": null,
|
| 737 |
+
"outputs": []
|
| 738 |
},
|
| 739 |
{
|
| 740 |
"cell_type": "markdown",
|
|
|
|
| 747 |
},
|
| 748 |
{
|
| 749 |
"cell_type": "code",
|
|
|
|
| 750 |
"metadata": {},
|
|
|
|
| 751 |
"source": [
|
| 752 |
"# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
|
| 753 |
"print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
|
|
|
|
| 761 |
"print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
|
| 762 |
"for t in TASKS:\n",
|
| 763 |
" print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
|
| 764 |
+
],
|
| 765 |
+
"execution_count": null,
|
| 766 |
+
"outputs": []
|
| 767 |
},
|
| 768 |
{
|
| 769 |
"cell_type": "markdown",
|
|
|
|
| 782 |
},
|
| 783 |
{
|
| 784 |
"cell_type": "code",
|
|
|
|
| 785 |
"metadata": {},
|
|
|
|
| 786 |
"source": [
|
| 787 |
"# Cell 10: Attach LoRA adapter\n",
|
| 788 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
|
|
|
| 796 |
"model.enable_input_require_grads()\n",
|
| 797 |
"peft_model = get_peft_model(model, lora_config)\n",
|
| 798 |
"peft_model.print_trainable_parameters()"
|
| 799 |
+
],
|
| 800 |
+
"execution_count": null,
|
| 801 |
+
"outputs": []
|
| 802 |
},
|
| 803 |
{
|
| 804 |
"cell_type": "code",
|
|
|
|
| 805 |
"metadata": {},
|
|
|
|
| 806 |
"source": [
|
| 807 |
+
"# Cell 11: Two-phase training loop (timing -> content)\n",
|
| 808 |
+
"# Each phase: 3 rounds (round 0 = hardcoded peak-hours hint, rounds 1-2 = normal prompt).\n",
|
| 809 |
+
"# Adapter persisted to ./checkpoints/phaseN_adapter/ between phases.\n",
|
| 810 |
"from trl import SFTTrainer, SFTConfig\n",
|
| 811 |
"from datasets import Dataset\n",
|
| 812 |
"\n",
|
|
|
|
| 813 |
"EPISODES_PER_ROUND = 6\n",
|
| 814 |
+
"ROUNDS_PER_PHASE = 3\n",
|
| 815 |
+
"QUALITY_FLOOR = 0.0\n",
|
| 816 |
+
"\n",
|
| 817 |
+
"PHASES = [\n",
|
| 818 |
+
" {\"name\": \"phase1_timing\", \"reward_mode\": \"timing\", \"system\": SYSTEM_PROMPT_TIMING},\n",
|
| 819 |
+
" {\"name\": \"phase2_content\", \"reward_mode\": \"content\", \"system\": SYSTEM_PROMPT_CONTENT},\n",
|
| 820 |
+
"]\n",
|
| 821 |
"\n",
|
| 822 |
"training_log = {\n",
|
| 823 |
+
" \"phase\": [], \"round\": [], \"global_step\": [], \"use_hint\": [],\n",
|
| 824 |
+
" \"avg_episode_reward\": [], \"max_episode_reward\": [], \"min_episode_reward\": [],\n",
|
| 825 |
+
" \"avg_grader\": [], \"max_grader\": [],\n",
|
| 826 |
" \"n_training_samples\": [], \"train_loss\": [],\n",
|
| 827 |
"}\n",
|
| 828 |
"\n",
|
| 829 |
"t_start = time.time()\n",
|
| 830 |
+
"global_step = 0\n",
|
| 831 |
+
"\n",
|
| 832 |
+
"for phase in PHASES:\n",
|
| 833 |
+
" phase_name = phase[\"name\"]\n",
|
| 834 |
+
" sys_prompt = phase[\"system\"]\n",
|
| 835 |
+
" reward_mode = phase[\"reward_mode\"]\n",
|
| 836 |
+
" print(f\"\\n{'#' * 60}\\n# PHASE {phase_name} (reward_mode={reward_mode})\\n{'#' * 60}\")\n",
|
| 837 |
+
"\n",
|
| 838 |
+
" for round_idx in range(ROUNDS_PER_PHASE):\n",
|
| 839 |
+
" use_hint = (round_idx == 0)\n",
|
| 840 |
+
" print(f\"\\n{'=' * 60}\\n{phase_name} | ROUND {round_idx+1}/{ROUNDS_PER_PHASE} | hint={use_hint}\\n{'=' * 60}\")\n",
|
| 841 |
+
"\n",
|
| 842 |
+
" peft_model.eval()\n",
|
| 843 |
+
" tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + ep + round_idx * 10) for ep in range(EPISODES_PER_ROUND)]\n",
|
| 844 |
+
" t_roll = time.time()\n",
|
| 845 |
+
" results = run_llm_episodes_batched(\n",
|
| 846 |
+
" peft_model, tokenizer, tasks_seeds, verbose=False, eval=False,\n",
|
| 847 |
+
" system=sys_prompt, hint_peak_hours=use_hint, reward_mode=reward_mode,\n",
|
| 848 |
+
" log_tag=f\"{phase_name}_r{round_idx}\",\n",
|
| 849 |
+
" )\n",
|
| 850 |
+
" print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
|
| 851 |
+
"\n",
|
| 852 |
+
" all_pairs, episode_rewards, episode_graders = [], [], []\n",
|
| 853 |
+
" for ep, result in enumerate(results):\n",
|
| 854 |
+
" ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
|
| 855 |
+
" episode_rewards.append(ep_reward)\n",
|
| 856 |
+
" episode_graders.append(result[\"grader_score\"])\n",
|
| 857 |
+
" kept = 0\n",
|
| 858 |
+
" for pr in result[\"pairs\"]:\n",
|
| 859 |
+
" if not is_well_formed_response(pr[\"response\"]):\n",
|
| 860 |
+
" continue\n",
|
| 861 |
+
" text = (f\"<|im_start|>system\\n{sys_prompt}<|im_end|>\\n\"\n",
|
| 862 |
+
" f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
|
| 863 |
+
" f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
|
| 864 |
+
" all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
|
| 865 |
+
" kept += 1\n",
|
| 866 |
+
" print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {result['task'].split('_')[-1]:>11s} \"\n",
|
| 867 |
+
" f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f} kept={kept}/{len(result['pairs'])}\")\n",
|
| 868 |
+
"\n",
|
| 869 |
+
" avg_r = float(np.mean(episode_rewards))\n",
|
| 870 |
+
" avg_g = float(np.mean(episode_graders))\n",
|
| 871 |
+
" max_g = float(max(episode_graders))\n",
|
| 872 |
+
" print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f} max_grader={max_g:.4f} | pairs={len(all_pairs)}\")\n",
|
| 873 |
+
"\n",
|
| 874 |
+
" loss = float(\"nan\")\n",
|
| 875 |
+
" n_filtered = 0\n",
|
| 876 |
+
" if not all_pairs:\n",
|
| 877 |
+
" print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
|
| 878 |
+
" elif max_g < QUALITY_FLOOR:\n",
|
| 879 |
+
" print(f\" SKIP SFT: no episode beat quality_floor={QUALITY_FLOOR:.2f}\")\n",
|
| 880 |
+
" else:\n",
|
| 881 |
+
" rets = np.array([p[\"reward\"] for p in all_pairs], dtype=float)\n",
|
| 882 |
+
" adv = (rets - rets.mean()) / (rets.std() + 1e-6)\n",
|
| 883 |
+
" filtered = [p for p, a in zip(all_pairs, adv) if a > 0.0]\n",
|
| 884 |
+
" if not filtered:\n",
|
| 885 |
+
" print(\" SKIP SFT: zero positive-advantage samples\")\n",
|
| 886 |
+
" else:\n",
|
| 887 |
+
" n_filtered = len(filtered)\n",
|
| 888 |
+
" print(f\" Kept {n_filtered}/{len(all_pairs)} positive-advantage samples\")\n",
|
| 889 |
+
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 890 |
+
" sft_config = SFTConfig(\n",
|
| 891 |
+
" output_dir=f\"./checkpoints/{phase_name}_r{round_idx}\",\n",
|
| 892 |
+
" num_train_epochs=1,\n",
|
| 893 |
+
" per_device_train_batch_size=2,\n",
|
| 894 |
+
" gradient_accumulation_steps=4,\n",
|
| 895 |
+
" learning_rate=5e-6,\n",
|
| 896 |
+
" warmup_steps=5,\n",
|
| 897 |
+
" logging_steps=1,\n",
|
| 898 |
+
" save_strategy=\"no\",\n",
|
| 899 |
+
" max_length=2048,\n",
|
| 900 |
+
" bf16=True,\n",
|
| 901 |
+
" report_to=\"none\",\n",
|
| 902 |
+
" )\n",
|
| 903 |
+
" peft_model.train()\n",
|
| 904 |
+
" trainer = SFTTrainer(\n",
|
| 905 |
+
" model=peft_model, processing_class=tokenizer,\n",
|
| 906 |
+
" train_dataset=dataset, args=sft_config,\n",
|
| 907 |
+
" )\n",
|
| 908 |
+
" train_result = trainer.train()\n",
|
| 909 |
+
" loss = float(train_result.training_loss)\n",
|
| 910 |
+
" print(f\" Training loss: {loss:.4f}\")\n",
|
| 911 |
+
"\n",
|
| 912 |
+
" global_step += 1\n",
|
| 913 |
+
" training_log[\"phase\"].append(phase_name)\n",
|
| 914 |
+
" training_log[\"round\"].append(round_idx + 1)\n",
|
| 915 |
+
" training_log[\"global_step\"].append(global_step)\n",
|
| 916 |
+
" training_log[\"use_hint\"].append(use_hint)\n",
|
| 917 |
+
" training_log[\"avg_episode_reward\"].append(round(float(avg_r), 3))\n",
|
| 918 |
+
" training_log[\"max_episode_reward\"].append(round(float(max(episode_rewards)), 3))\n",
|
| 919 |
+
" training_log[\"min_episode_reward\"].append(round(float(min(episode_rewards)), 3))\n",
|
| 920 |
+
" training_log[\"avg_grader\"].append(round(float(avg_g), 4))\n",
|
| 921 |
+
" training_log[\"max_grader\"].append(round(float(max(episode_graders)), 4))\n",
|
| 922 |
+
" training_log[\"n_training_samples\"].append(n_filtered)\n",
|
| 923 |
+
" training_log[\"train_loss\"].append(round(loss, 4) if loss == loss else float(\"nan\"))\n",
|
| 924 |
+
"\n",
|
| 925 |
+
" save_dir = f\"./checkpoints/{phase_name}_adapter\"\n",
|
| 926 |
+
" os.makedirs(save_dir, exist_ok=True)\n",
|
| 927 |
+
" peft_model.save_pretrained(save_dir)\n",
|
| 928 |
+
" tokenizer.save_pretrained(save_dir)\n",
|
| 929 |
+
" print(f\"\\n Saved {phase_name} adapter -> {save_dir}\")\n",
|
| 930 |
"\n",
|
| 931 |
"elapsed = time.time() - t_start\n",
|
| 932 |
+
"print(f\"\\nTwo-phase training complete in {elapsed/60:.1f} min\")\n",
|
| 933 |
"print(pd.DataFrame(training_log).to_string(index=False))"
|
| 934 |
+
],
|
| 935 |
+
"execution_count": null,
|
| 936 |
+
"outputs": []
|
| 937 |
},
|
| 938 |
{
|
| 939 |
"cell_type": "markdown",
|
|
|
|
| 946 |
},
|
| 947 |
{
|
| 948 |
"cell_type": "code",
|
|
|
|
| 949 |
"metadata": {},
|
|
|
|
| 950 |
"source": [
|
| 951 |
"# Cell 12: Run trained model (batched)\n",
|
| 952 |
"print(\"Running TRAINED model on all tasks (batched)...\")\n",
|
|
|
|
| 961 |
"print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
|
| 962 |
"for t in TASKS:\n",
|
| 963 |
" print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
|
| 964 |
+
],
|
| 965 |
+
"execution_count": null,
|
| 966 |
+
"outputs": []
|
| 967 |
},
|
| 968 |
{
|
| 969 |
"cell_type": "markdown",
|
|
|
|
| 974 |
},
|
| 975 |
{
|
| 976 |
"cell_type": "code",
|
|
|
|
| 977 |
"metadata": {},
|
|
|
|
| 978 |
"source": [
|
| 979 |
+
"# Cell 13: Training curves (two-phase)\n",
|
| 980 |
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 981 |
+
"steps = training_log[\"global_step\"]\n",
|
| 982 |
+
"phases = training_log[\"phase\"]\n",
|
| 983 |
+
"phase1_end = max([s for s, p in zip(steps, phases) if p == \"phase1_timing\"], default=0)\n",
|
| 984 |
"\n",
|
| 985 |
+
"axes[0].plot(steps, training_log[\"avg_grader\"], 'o-', color='#2196F3', lw=2, label='Avg grader')\n",
|
| 986 |
+
"axes[0].fill_between(steps, training_log[\"avg_grader\"],\n",
|
| 987 |
" training_log[\"max_grader\"], alpha=0.2, color='#2196F3')\n",
|
| 988 |
+
"if phase1_end > 0:\n",
|
| 989 |
+
" axes[0].axvline(phase1_end + 0.5, color='gray', ls='--', alpha=0.6, label='phase split')\n",
|
| 990 |
+
"axes[0].set_xlabel('Global step'); axes[0].set_ylabel('Grader Score')\n",
|
| 991 |
+
"axes[0].set_title('Grader Score (timing -> content)', fontweight='bold')\n",
|
| 992 |
"axes[0].legend(); axes[0].grid(True, alpha=0.3)\n",
|
| 993 |
"\n",
|
| 994 |
+
"axes[1].plot(steps, training_log[\"train_loss\"], 's-', color='#E53935', lw=2)\n",
|
| 995 |
+
"if phase1_end > 0:\n",
|
| 996 |
+
" axes[1].axvline(phase1_end + 0.5, color='gray', ls='--', alpha=0.6)\n",
|
| 997 |
+
"axes[1].set_xlabel('Global step'); axes[1].set_ylabel('Loss')\n",
|
| 998 |
"axes[1].set_title('Training Loss', fontweight='bold')\n",
|
| 999 |
"axes[1].grid(True, alpha=0.3)\n",
|
| 1000 |
"\n",
|
| 1001 |
+
"fig.suptitle('Viraltest v2 — Two-Phase LoRA Training (timing -> content)', fontsize=14, fontweight='bold')\n",
|
| 1002 |
"fig.tight_layout()\n",
|
| 1003 |
"fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
|
| 1004 |
"plt.show()"
|
| 1005 |
+
],
|
| 1006 |
+
"execution_count": null,
|
| 1007 |
+
"outputs": []
|
| 1008 |
},
|
| 1009 |
{
|
| 1010 |
"cell_type": "code",
|
|
|
|
| 1011 |
"metadata": {},
|
|
|
|
| 1012 |
"source": [
|
| 1013 |
"# Cell 14: Before vs After\n",
|
| 1014 |
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
|
|
|
|
| 1038 |
"fig.tight_layout()\n",
|
| 1039 |
"fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
|
| 1040 |
"plt.show()"
|
| 1041 |
+
],
|
| 1042 |
+
"execution_count": null,
|
| 1043 |
+
"outputs": []
|
| 1044 |
},
|
| 1045 |
{
|
| 1046 |
"cell_type": "code",
|
|
|
|
| 1047 |
"metadata": {},
|
|
|
|
| 1048 |
"source": [
|
| 1049 |
"# Cell 15: Trajectory comparison\n",
|
| 1050 |
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
|
|
|
|
| 1068 |
"fig.tight_layout()\n",
|
| 1069 |
"fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
|
| 1070 |
"plt.show()"
|
| 1071 |
+
],
|
| 1072 |
+
"execution_count": null,
|
| 1073 |
+
"outputs": []
|
| 1074 |
},
|
| 1075 |
{
|
| 1076 |
"cell_type": "markdown",
|
|
|
|
| 1081 |
},
|
| 1082 |
{
|
| 1083 |
"cell_type": "code",
|
|
|
|
| 1084 |
"metadata": {},
|
|
|
|
| 1085 |
"source": [
|
| 1086 |
"# Cell 16: Final summary\n",
|
| 1087 |
"print(\"=\" * 67)\n",
|
|
|
|
| 1103 |
"\n",
|
| 1104 |
"summary = {\n",
|
| 1105 |
" \"model\": MODEL_NAME,\n",
|
| 1106 |
+
" \"training\": \"Two-phase LoRA SFT (timing -> content) with hardcoded peak-hours hint on round 1 of each phase\",\n",
|
| 1107 |
+
" \"phases\": [p[\"name\"] for p in PHASES],\n",
|
| 1108 |
+
" \"rounds_per_phase\": ROUNDS_PER_PHASE,\n",
|
| 1109 |
+
" \"episodes_per_round\": EPISODES_PER_ROUND,\n",
|
| 1110 |
" \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 1111 |
" \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 1112 |
" \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n",
|
|
|
|
| 1120 |
"\n",
|
| 1121 |
"print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
|
| 1122 |
"print(\"All results are from real LoRA weight updates on real environment runs.\")"
|
| 1123 |
+
],
|
| 1124 |
+
"execution_count": null,
|
| 1125 |
+
"outputs": []
|
| 1126 |
},
|
| 1127 |
{
|
| 1128 |
"cell_type": "code",
|
|
|
|
| 1129 |
"metadata": {},
|
|
|
|
| 1130 |
"source": [
|
| 1131 |
"# Cell 17: Save adapter\n",
|
| 1132 |
"save_path = \"./viraltest_trained_adapter\"\n",
|
|
|
|
| 1134 |
"tokenizer.save_pretrained(save_path)\n",
|
| 1135 |
"print(f\"LoRA adapter saved to {save_path}\")\n",
|
| 1136 |
"print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
|
| 1137 |
+
],
|
| 1138 |
+
"execution_count": null,
|
| 1139 |
+
"outputs": []
|
| 1140 |
}
|
| 1141 |
],
|
| 1142 |
"metadata": {
|
|
|
|
| 1162 |
},
|
| 1163 |
"nbformat": 4,
|
| 1164 |
"nbformat_minor": 4
|
| 1165 |
+
}
|