Commit d8bb03f
Parent(s): 17149c8
train_grpo: rename monthly_* tasks to weekly_* (with env alias)
- training/train_grpo.ipynb: rename the TASKS list and the matching plot-label
prefix strips from monthly_* to weekly_*.
- server/viraltest_environment.py: add _TASK_ALIASES so the renamed
weekly_* identifiers route to the existing monthly_* graders /
baselines without breaking external callers. Also picks up the
TASK_HORIZON env-var override that was already staged locally.
Why: the configurable horizon (default 15 days) is closer to a weekly
than a monthly cycle, so the task names should reflect that. The alias
keeps run-output JSON, validate-submission scripts, and the dashboard
working unchanged.
Made-with: Cursor
- server/viraltest_environment.py +15 -1
- training/train_grpo.ipynb +59 -59
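
For reviewers, a minimal sketch of what the alias buys callers. The ALIASES / VALID names and the resolve_task helper are illustrative only; in the actual change the lookup is inlined in reset(), as the server diff below shows:

ALIASES = {"weekly_engage": "monthly_engage",
           "weekly_strategic": "monthly_strategic",
           "weekly_competitive": "monthly_competitive"}
VALID = ("monthly_engage", "monthly_strategic", "monthly_competitive")

def resolve_task(task: str) -> str:
    # Alias lookup first, then the pre-existing VALID_TASKS fallback.
    task = ALIASES.get(task, task)
    return task if task in VALID else "monthly_engage"

assert resolve_task("weekly_strategic") == "monthly_strategic"   # new spelling -> old grader
assert resolve_task("monthly_strategic") == "monthly_strategic"  # old spelling still works
assert resolve_task("not_a_task") == "monthly_engage"            # unchanged default fallback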
server/viraltest_environment.py
CHANGED
@@ -13,6 +13,7 @@ Multi-day creator optimization with:
 
 import json
 import math
+import os
 import random
 from collections import defaultdict
 from dataclasses import dataclass, field
@@ -102,7 +103,8 @@ _FOLLOWERS_BY_ARCHETYPE: Dict[str, int] = {
 # ---------------------------------------------------------------------------
 
 # Episode length in daily env steps. Graders and UI should stay consistent with this value.
-TASK_HORIZON = 15
+# Override via env var TASK_HORIZON (e.g. TASK_HORIZON=1 for ultra-fast local debug runs).
+TASK_HORIZON = int(os.environ.get("TASK_HORIZON", "15"))
 
 # Distinct positive tags for full tag_discovery score in strategic/competitive graders.
 # Caps at 30 (original month-scale bar); scales down only for very short horizons.
@@ -149,6 +151,16 @@ INTENT_MULTIPLIER = {
 
 VALID_TASKS = ("monthly_engage", "monthly_strategic", "monthly_competitive")
 
+# Backward-compatible aliases. The training notebook now uses `weekly_*` task names
+# (the configurable TASK_HORIZON defaults to 15 days, which is closer to a weekly
+# horizon than a monthly one). They route to the same graders / baselines as the
+# canonical `monthly_*` names, so external callers using either spelling work.
+_TASK_ALIASES: Dict[str, str] = {
+    "weekly_engage": "monthly_engage",
+    "weekly_strategic": "monthly_strategic",
+    "weekly_competitive": "monthly_competitive",
+}
+
 INITIAL_FOLLOWERS = 10000
 REST_RECOVERY = 0.12
 CREATE_CONTENT_COST = 0.05
@@ -1182,6 +1194,8 @@ class ViraltestEnvironment(Environment):
 
     def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> ViraltestObservation:
         self._task = kwargs.get("task", "monthly_engage")
+        # Accept the renamed `weekly_*` task identifiers used by the training notebook.
+        self._task = _TASK_ALIASES.get(self._task, self._task)
         if self._task not in VALID_TASKS:
             self._task = "monthly_engage"
 
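One operational note on the override above, sketched under the assumption that server/ is importable as a package from the repo root: TASK_HORIZON is bound once at module import, so the environment variable has to be exported before the module is first loaded (or the module reloaded afterwards).

import os
os.environ["TASK_HORIZON"] = "1"                 # ultra-fast local debug run, per the diff comment
from server import viraltest_environment as vt  # assumed import path
print(vt.TASK_HORIZON)                           # -> 1; defaults to 15 when the variable is unset
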
training/train_grpo.ipynb
CHANGED
@@ -25,7 +25,9 @@
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
    "!pip install -q torch torchvision torchaudio\n",
@@ -37,13 +39,13 @@
    "# This avoids the from-source build that fails when the container has no nvcc / CUDA_HOME.\n",
    "# Falls back to sdpa if the wheel install fails (e.g. on a different env).\n",
    "!pip install -q \"https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl\" || pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 2: Resolve repo path (Colab / Kaggle: fresh clone. Local: auto-detect project root)\n",
    "import os\n",
@@ -126,13 +128,13 @@
    "print(f\"Branch: {REPO_BRANCH}\")\n",
    "print(f\"Commit: {commit}\")\n",
    "print(f\"Plots dir: {PLOTS_DIR}\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 3: Imports (with runtime validation)\n",
    "import json, random, time, textwrap, copy, os, sys\n",
@@ -171,7 +173,7 @@
    "NICHES = list(TOPIC_CATEGORIES.keys())\n",
    "CONTENT_TYPES = [\"reel\", \"carousel\", \"story\", \"text_post\"]\n",
    "INTENTS = [\"send_bait\", \"save_bait\", \"watch_bait\", \"like_bait\"]\n",
-   "TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n",
+   "TASKS = [\"weekly_engage\", \"weekly_strategic\", \"weekly_competitive\"]\n",
    "\n",
    "print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
    "print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")\n",
@@ -198,9 +200,7 @@
    "# hint stays on for both (current behaviour preserved).\n",
    "HINT_ALWAYS = not TEST_ONLY\n",
    "print(f\"SMOKE_MODE={SMOKE_MODE} | TEST_ONLY={TEST_ONLY} | HINT_ALWAYS={HINT_ALWAYS}\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "markdown",
@@ -213,7 +213,9 @@
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 4: Define heuristic agents + episode runner\n",
    "_rng = random.Random(42)\n",
@@ -289,13 +291,13 @@
    " \"rewards\": rewards, \"energies\": energies}\n",
    "\n",
    "print(\"Agents and episode runner defined.\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 5: Run baselines (safe)\n",
    "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
@@ -330,13 +332,13 @@
    "for name in BASELINE_AGENTS:\n",
    "    scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
    "    print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 6: Baseline plots\n",
    "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
@@ -345,7 +347,7 @@
    "for i, task in enumerate(TASKS):\n",
    "    scores = [baseline_results[a][task][\"grader_score\"] for a in agent_names]\n",
    "    bars = axes[i].barh(agent_names, scores, color=colors)\n",
-   "    axes[i].set_title(task.replace(\"monthly_\", \"\").title(), fontsize=13, fontweight='bold')\n",
+   "    axes[i].set_title(task.replace(\"weekly_\", \"\").title(), fontsize=13, fontweight='bold')\n",
    "    for bar, score in zip(bars, scores):\n",
    "        axes[i].text(bar.get_width() + 0.005, bar.get_y() + bar.get_height()/2,\n",
    "            f\"{score:.4f}\", va='center', fontsize=9)\n",
@@ -354,9 +356,7 @@
    "fig.tight_layout()\n",
    "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
    "plt.show()"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "markdown",
@@ -369,7 +369,9 @@
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
@@ -413,13 +415,13 @@
    "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 8: LLM agent functions\n",
    "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
@@ -762,9 +764,7 @@
    "\n",
    "\n",
    "print(\"LLM agent functions defined (batched).\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "markdown",
@@ -777,7 +777,9 @@
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
    "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
@@ -791,9 +793,7 @@
    "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
    "for t in TASKS:\n",
    "    print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "markdown",
@@ -812,7 +812,9 @@
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 10: Attach LoRA adapter\n",
    "from peft import LoraConfig, get_peft_model, TaskType\n",
@@ -834,13 +836,13 @@
    "model.enable_input_require_grads()\n",
    "peft_model = get_peft_model(model, lora_config)\n",
    "peft_model.print_trainable_parameters()"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 11: Two-phase training loop (timing -> content)\n",
    "# Each phase: 3 rounds (round 0 = hardcoded peak-hours hint, rounds 1-2 = normal prompt).\n",
@@ -986,9 +988,7 @@
    "elapsed = time.time() - t_start\n",
    "print(f\"\\nTwo-phase training complete in {elapsed/60:.1f} min\")\n",
    "print(pd.DataFrame(training_log).to_string(index=False))"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "markdown",
@@ -1001,7 +1001,9 @@
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 12: Run trained model (batched)\n",
    "print(\"Running TRAINED model on all tasks (batched)...\")\n",
@@ -1045,13 +1047,13 @@
    "        print(f\" {t}: {a:.4f} -> {new_a:.4f} (was delta={a-b:+.4f}, now {new_a-b:+.4f})\")\n",
    "    else:\n",
    "        print(f\" {t}: {a:.4f} (organic delta {a-b:+.4f}, no boost needed)\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 12.5: Debug — analyse io_log.jsonl (before vs after, tool error rate, hint usage)\n",
    "import re\n",
@@ -1112,9 +1114,7 @@
    "if bk and ak:\n",
    "    print(\"BEFORE response head:\", bk[\"response\"][:300].replace(\"\\n\", \" \"))\n",
    "    print(\"AFTER response head:\", ak[\"response\"][:300].replace(\"\\n\", \" \"))"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "markdown",
@@ -1125,7 +1125,9 @@
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 13: Training curves (two-phase)\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
@@ -1153,16 +1155,16 @@
    "fig.tight_layout()\n",
    "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 14: Before vs After\n",
-   "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
+   "task_labels = [t.replace('weekly_', '').title() for t in TASKS]\n",
    "x = np.arange(len(TASKS))\n",
    "w = 0.25\n",
    "\n",
@@ -1189,13 +1191,13 @@
    "fig.tight_layout()\n",
    "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 15: Trajectory comparison\n",
    "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
@@ -1211,7 +1213,7 @@
    "    sr = baseline_results[\"smart\"][task]\n",
    "    axes[0, i].plot(sr[\"rewards\"], label=\"Smart\", color='#9E9E9E', lw=1, ls=':')\n",
    "    axes[1, i].plot(sr[\"energies\"], label=\"Smart\", color='#9E9E9E', lw=1, ls=':')\n",
-   "    t_name = task.replace('monthly_', '').title()\n",
+   "    t_name = task.replace('weekly_', '').title()\n",
    "    axes[0, i].set_title(f\"{t_name} — Rewards\"); axes[0, i].grid(True, alpha=0.3)\n",
    "    axes[1, i].set_title(f\"{t_name} — Energy\"); axes[1, i].grid(True, alpha=0.3)\n",
    "axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
@@ -1219,9 +1221,7 @@
    "fig.tight_layout()\n",
    "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "markdown",
@@ -1232,7 +1232,9 @@
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 16: Final summary\n",
    "print(\"=\" * 67)\n",
@@ -1271,13 +1273,13 @@
    "\n",
    "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
    "print(\"All results are from real LoRA weight updates on real environment runs.\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  },
  {
   "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Cell 17: Save adapter\n",
    "save_path = \"./viraltest_trained_adapter\"\n",
@@ -1285,9 +1287,7 @@
    "tokenizer.save_pretrained(save_path)\n",
    "print(f\"LoRA adapter saved to {save_path}\")\n",
    "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
-  ],
-  "execution_count": null,
-  "outputs": []
+  ]
  }
 ],
 "metadata": {
@@ -1313,4 +1313,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 4
-}
+}
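
Most of the churn in train_grpo.ipynb is mechanical: in every code cell the execution_count and outputs keys move ahead of source, which matches the sorted key order nbformat emits when it rewrites a notebook. A sketch of that kind of normalisation pass, assuming the standard nbformat package (not necessarily how this commit was produced):

import nbformat
# Read the notebook, clear execution state, and let nbformat re-serialise it
# with its canonical (sorted) cell-key order on write.
nb = nbformat.read("training/train_grpo.ipynb", as_version=4)
for cell in nb.cells:
    if cell.cell_type == "code":
        cell.execution_count = None
        cell.outputs = []
nbformat.write(nb, "training/train_grpo.ipynb")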