vaibhav12332112312 committed
Commit 8970072 · Parent(s): 1a2a407

pounteradds
server/viraltest_environment.py CHANGED
@@ -387,6 +387,8 @@ class ViraltestEnvironment(Environment):
         self._hours_since_sleep = 2
         self._sleep_debt = 0.0
 
+        self._reward_mode = "combined"
+
     def _load_competitors(self) -> List[CompetitorState]:
         archetypes = _COMPETITORS_DATA.get("archetypes", [])
         return [
@@ -1136,6 +1138,8 @@ class ViraltestEnvironment(Environment):
 
         self._shift_label = kwargs.get("shift_label")
         self._chain_id = kwargs.get("episode_chain_id")
+        mode = kwargs.get("reward_mode", "combined")
+        self._reward_mode = mode if mode in ("timing", "content", "combined") else "combined"
 
         if self._chain_id and self._chain_id in _BRAND_STORE:
            brand = _BRAND_STORE[self._chain_id]
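Note: the reward scheme is now selectable per episode through a reset kwarg, with invalid values falling back to "combined". A minimal usage sketch, assuming the reset signature shown in this hunk (the task name is illustrative, not taken from the diff):

```python
# Sketch: pick the reward scheme when resetting the environment.
env = ViraltestEnvironment()
obs = env.reset(task="monthly_followers", seed=42, reward_mode="timing")  # hypothetical task name
```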
@@ -1439,20 +1443,29 @@ class ViraltestEnvironment(Environment):
     # ----- reward -----
 
     def _compute_hourly_reward(self, sa: ScheduledAction, engagement: float) -> float:
-        eng_component = min(1.0, engagement / 2.0) * 0.3
+        if self._reward_mode == "timing":
+            return self._compute_timing_reward(sa, engagement)
+        if self._reward_mode == "content":
+            return self._compute_content_reward(sa, engagement)
+        return self._compute_combined_reward(sa, engagement)
 
+    def _energy_component(self) -> float:
         prev_energy = self._energy_history[-2] if len(self._energy_history) >= 2 else 1.0
         energy_delta = self._energy - prev_energy
-        energy_component = max(0.0, min(1.0, (energy_delta + 0.3) / 0.6)) * 0.15
+        return max(0.0, min(1.0, (energy_delta + 0.3) / 0.6))
 
+    def _consistency_score(self) -> float:
         day_posts = self._posts_per_day.get(self._day, 0)
         if 1 <= day_posts <= 2:
-            consistency = 1.0
-        elif day_posts == 0 or day_posts == 3:
-            consistency = 0.5
-        else:
-            consistency = 0.0
-        consistency_component = consistency * 0.15
+            return 1.0
+        if day_posts == 0 or day_posts == 3:
+            return 0.5
+        return 0.0
+
+    def _compute_combined_reward(self, sa: ScheduledAction, engagement: float) -> float:
+        eng_component = min(1.0, engagement / 2.0) * 0.3
+        energy_component = self._energy_component() * 0.15
+        consistency_component = self._consistency_score() * 0.15
 
         tag_component = 0.0
         if sa.action_type == "post" and sa.tags:
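The refactor above extracts the shared math into _energy_component() and _consistency_score(). A standalone check of the clamp behaviour, using only constants visible in the diff:

```python
# The energy term maps an energy delta in [-0.3, +0.3] onto [0, 1].
def clamp01(x: float) -> float:
    return max(0.0, min(1.0, x))

assert clamp01((-0.3 + 0.3) / 0.6) == 0.0  # worst-case drain -> 0
assert clamp01((0.0 + 0.3) / 0.6) == 0.5   # flat energy -> 0.5
assert clamp01((0.3 + 0.3) / 0.6) == 1.0   # strong recovery -> 1
```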
@@ -1474,22 +1487,54 @@ class ViraltestEnvironment(Environment):
         )
         return max(0.0, min(1.0, raw))
 
-    def _compute_rest_reward(self) -> float:
-        prev_energy = self._energy_history[-2] if len(self._energy_history) >= 2 else 1.0
-        energy_delta = self._energy - prev_energy
-        energy_component = max(0.0, min(1.0, (energy_delta + 0.3) / 0.6)) * 0.15
+    def _compute_timing_reward(self, sa: ScheduledAction, engagement: float) -> float:
+        is_post = sa.action_type == "post"
+        peak_hour_mult = 1.3 if is_post and self._get_hour_multiplier() >= 1.2 else 1.0
+        trending_topic_mult = 1.5 if is_post and self._is_topic_trending(sa.topic) else 1.0
+        eng_component = min(1.0, engagement / 2.0) * 0.40 * trending_topic_mult * peak_hour_mult
 
-        day_posts = self._posts_per_day.get(self._day, 0)
-        if 1 <= day_posts <= 2:
-            consistency = 1.0
-        elif day_posts == 0 or day_posts == 3:
-            consistency = 0.5
-        else:
-            consistency = 0.0
-        consistency_component = consistency * 0.15
+        peak_bonus = min(1.0, self._get_hour_multiplier() / 1.3) if is_post else 0.0
+        peak_component = peak_bonus * 0.20
+
+        energy_component = self._energy_component() * 0.20
+        consistency_component = self._consistency_score() * 0.20
+        burnout_penalty = 0.1 if self._energy < 0.2 else 0.0
 
+        raw = eng_component + peak_component + energy_component + consistency_component - burnout_penalty
+        return max(0.0, min(1.0, raw))
+
+    def _compute_content_reward(self, sa: ScheduledAction, engagement: float) -> float:
+        is_post = sa.action_type == "post"
+        trending_topic_mult = 1.5 if is_post and self._is_topic_trending(sa.topic) else 1.0
+        eng_component = min(1.0, engagement / 2.0) * 0.20 * trending_topic_mult
+
+        tag_component = 0.0
+        if is_post and sa.tags:
+            trending_match = sum(1 for t in sa.tags if t.lower() in self._trending_tags) / 5.0
+            tag_component = min(1.0, trending_match + 0.3) * 0.25
+
+        comp_component = 0.0
+        if is_post:
+            diff = self._calc_competitor_diff(sa.topic)
+            comp_component = min(1.0, diff / 1.3) * 0.25
+
+        variety_component = 0.0
+        intent_component = 0.0
+        if is_post:
+            variety_component = min(1.0, len(self._unique_content_types) / 4.0) * 0.15
+            intent_component = (0.15 if sa.intent in INTENT_MULTIPLIER else 0.0)
+
+        burnout_penalty = 0.05 if self._energy < 0.2 else 0.0
+        raw = eng_component + tag_component + comp_component + variety_component + intent_component - burnout_penalty
+        return max(0.0, min(1.0, raw))
+
+    def _compute_rest_reward(self) -> float:
+        energy_component = self._energy_component() * 0.15
+        consistency_component = self._consistency_score() * 0.15
         burnout_penalty = 0.1 if self._energy < 0.2 else 0.0
         raw = energy_component + consistency_component - burnout_penalty
+        if self._reward_mode == "content":
+            raw *= 0.5
         return max(0.0, min(1.0, raw))
 
     def _advance_time(self) -> None:
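A quick consistency check on the new weightings: both modes can saturate, so the final clamp is load-bearing. Worked arithmetic using only the multipliers and weights from this hunk:

```python
# Pre-clamp maxima implied by the two reward modes (constants from the diff above).
timing_max = 1.0 * 0.40 * 1.5 * 1.3 + 0.20 + 0.20 + 0.20    # engagement*trending*peak + peak + energy + consistency
content_max = 1.0 * 0.20 * 1.5 + 0.25 + 0.25 + 0.15 + 0.15  # engagement*trending + tags + competitors + variety + intent
print(round(timing_max, 2), round(content_max, 2))  # 1.38 1.1 -> both clamp to 1.0
```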
@@ -1700,6 +1745,13 @@ class ViraltestEnvironment(Environment):
     return max(0.0, min(1.0, raw))
 
 
+def get_peak_hours(day_of_week: int, top_k: int = 2) -> List[int]:
+    row = _HEATMAP_GRID.get(day_of_week % 7, [])
+    if not row:
+        return []
+    return sorted(range(len(row)), key=lambda h: row[h], reverse=True)[:top_k]
+
+
 def _topic_overlap(topic_a: str, topic_b: str) -> bool:
     words_a = set(topic_a.split())
     words_b = set(topic_b.split())
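get_peak_hours() exposes the engagement heatmap at module level (the notebook imports it for prompt hints below). A hedged usage sketch, assuming _HEATMAP_GRID maps day-of-week to a list of 24 hourly multipliers, which is what the indexing suggests:

```python
# Top-2 posting hours for Wednesday (day_of_week=2), highest multiplier first.
peak = get_peak_hours(day_of_week=2, top_k=2)
print(peak)  # e.g. [19, 20]; values are hour-of-day indices into the row
```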
training/train_grpo.ipynb CHANGED
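The notebook changes below thread two new knobs, hint_peak_hours and reward_mode, from the training loop through the episode runner into the environment. For orientation, a sketch of the resulting call as wired in this diff (model, tokenizer, and prompt come from earlier cells; the task name is illustrative):

```python
# One batch of rollouts in timing mode with the round-0 coach hint enabled.
results = run_llm_episodes_batched(
    peft_model, tokenizer,
    [("monthly_followers", 42)],   # (task, seed) pairs; task name hypothetical
    verbose=False, eval=False,
    system=SYSTEM_PROMPT_TIMING,   # phase-specific system prompt
    hint_peak_hours=True,          # inject "Coach hint: ..." peak hours into prompts
    reward_mode="timing",          # forwarded to ViraltestEnvironment.reset(...)
)
```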
@@ -25,9 +25,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
     "!pip install -q torch torchvision torchaudio\n",
@@ -36,13 +34,13 @@
     "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
     "!pip install -q \"openenv-core[core]>=0.2.2\"\n",
     "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
     "import os\n",
@@ -118,13 +116,13 @@
     "print(f\"Branch: {REPO_BRANCH}\")\n",
     "print(f\"Commit: {commit}\")\n",
     "print(f\"Plots dir: {PLOTS_DIR}\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 3: Imports (with runtime validation)\n",
     "import json, random, time, textwrap, copy, os, sys\n",
@@ -156,7 +154,7 @@
     "from models import ScheduledAction, ToolCall, ViraltestAction\n",
     "from server.viraltest_environment import (\n",
     "    ViraltestEnvironment, TAG_POOL, TASK_HORIZON,\n",
-    "    TOPIC_CATEGORIES,\n",
+    "    TOPIC_CATEGORIES, get_peak_hours,\n",
     ")\n",
     "\n",
     "ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
@@ -178,7 +176,9 @@
     "import ast\n",
     "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
     "print(\"OK: ast.parse (syntax check)\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -191,9 +191,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 4: Define heuristic agents + episode runner\n",
     "_rng = random.Random(42)\n",
@@ -269,13 +267,13 @@
     "        \"rewards\": rewards, \"energies\": energies}\n",
     "\n",
     "print(\"Agents and episode runner defined.\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 5: Run baselines (safe)\n",
     "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
@@ -310,13 +308,13 @@
     "for name in BASELINE_AGENTS:\n",
     "    scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
     "    print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 6: Baseline plots\n",
     "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
@@ -334,7 +332,9 @@
     "fig.tight_layout()\n",
     "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
     "plt.show()"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -347,9 +347,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
     "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
@@ -393,13 +391,13 @@
     "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
     "if torch.cuda.is_available():\n",
     "    print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
  },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 8: LLM agent functions\n",
     "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
@@ -454,6 +452,16 @@
     "SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
     "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
     "\n",
+    "SYSTEM_PROMPT_TIMING = SYSTEM_PROMPT + textwrap.dedent(\"\"\"\n",
+    "\n",
+    "FOCUS: optimise WHEN to post. Identify peak hours for the audience (use query_audience / query_trends).\n",
+    "2 posts/day at peak hours beats 4 posts at random hours.\"\"\")\n",
+    "\n",
+    "SYSTEM_PROMPT_CONTENT = SYSTEM_PROMPT + textwrap.dedent(\"\"\"\n",
+    "\n",
+    "FOCUS: optimise WHAT to post. Vary content_type and intent across the week,\n",
+    "pick differentiated topics, exploit trending tags.\"\"\")\n",
+    "\n",
     "\n",
     "_DAY_NAMES = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
     "\n",
@@ -472,7 +480,7 @@
     "    return out\n",
     "\n",
     "\n",
-    "def format_obs(obs, history=None):\n",
+    "def format_obs(obs, history=None, extra_hint=None):\n",
     "    day_name = _DAY_NAMES[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
     "    signals_str = \"\"\n",
     "    signals = getattr(obs, \"engagement_signals\", None)\n",
@@ -486,12 +494,14 @@
     "        tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
     "    if not tool_str:\n",
     "        tool_str = \" (none — call query_* tools to discover)\\n\"\n",
+    "    hint_str = f\"Coach hint: today's peak hours are {extra_hint}.\\n\" if extra_hint else \"\"\n",
     "    return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
     "            f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
     "            f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
     "            f\"{signals_str}\"\n",
     "            f\"{_format_history(history)}\"\n",
     "            f\"Tool results:\\n{tool_str}\"\n",
+    "            f\"{hint_str}\"\n",
     "            f\"Plan today's actions (JSON only):\")\n",
     "\n",
     "\n",
@@ -615,12 +625,13 @@
     "    return out\n",
     "\n",
     "\n",
-    "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None, log_tag=None):\n",
+    "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None,\n",
+    "                             log_tag=None, hint_peak_hours=False, reward_mode=\"combined\"):\n",
     "    \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n",
     "    sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
     "    n = len(tasks_seeds)\n",
     "    envs = [ViraltestEnvironment() for _ in range(n)]\n",
-    "    obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
+    "    obss = [envs[i].reset(task=t, seed=s, reward_mode=reward_mode) for i, (t, s) in enumerate(tasks_seeds)]\n",
     "    rewards = [[] for _ in range(n)]\n",
     "    energies = [[obs.creator_energy] for obs in obss]\n",
     "    pairs = [[] for _ in range(n)]\n",
@@ -641,7 +652,12 @@
     "\n",
     "        actions_by_idx = {i: rest_action for i in rest}\n",
     "        if active:\n",
-    "            base_prompts = [format_obs(obss[i], histories[i]) for i in active]\n",
+    "            def _hint_for(i):\n",
+    "                if not hint_peak_hours:\n",
+    "                    return None\n",
+    "                hrs = get_peak_hours(obss[i].day_of_week, top_k=2)\n",
+    "                return \", \".join(f\"{h:02d}:00\" for h in hrs) if hrs else None\n",
+    "            base_prompts = [format_obs(obss[i], histories[i], extra_hint=_hint_for(i)) for i in active]\n",
     "\n",
     "            disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n",
     "            disc_resps, ptok = _gen(disc_prompts)\n",
@@ -716,7 +732,9 @@
     "\n",
     "\n",
     "print(\"LLM agent functions defined (batched).\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -729,9 +747,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
     "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
@@ -745,7 +761,9 @@
     "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
     "for t in TASKS:\n",
     "    print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -764,9 +782,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 10: Attach LoRA adapter\n",
     "from peft import LoraConfig, get_peft_model, TaskType\n",
@@ -780,118 +796,144 @@
     "model.enable_input_require_grads()\n",
     "peft_model = get_peft_model(model, lora_config)\n",
     "peft_model.print_trainable_parameters()"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
-    "# Cell 11: Training loop\n",
+    "# Cell 11: Two-phase training loop (timing -> content)\n",
+    "# Each phase: 3 rounds (round 0 = hardcoded peak-hours hint, rounds 1-2 = normal prompt).\n",
+    "# Adapter persisted to ./checkpoints/phaseN_adapter/ between phases.\n",
     "from trl import SFTTrainer, SFTConfig\n",
     "from datasets import Dataset\n",
     "\n",
-    "NUM_ROUNDS = 2\n",
     "EPISODES_PER_ROUND = 6\n",
-    "QUALITY_FLOOR = 0.0 # 0 = always run SFT on positive-advantage samples\n",
+    "ROUNDS_PER_PHASE = 3\n",
+    "QUALITY_FLOOR = 0.0\n",
+    "\n",
+    "PHASES = [\n",
+    "    {\"name\": \"phase1_timing\", \"reward_mode\": \"timing\", \"system\": SYSTEM_PROMPT_TIMING},\n",
+    "    {\"name\": \"phase2_content\", \"reward_mode\": \"content\", \"system\": SYSTEM_PROMPT_CONTENT},\n",
+    "]\n",
     "\n",
     "training_log = {\n",
-    "    \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n",
-    "    \"min_episode_reward\": [], \"avg_grader\": [], \"max_grader\": [],\n",
+    "    \"phase\": [], \"round\": [], \"global_step\": [], \"use_hint\": [],\n",
+    "    \"avg_episode_reward\": [], \"max_episode_reward\": [], \"min_episode_reward\": [],\n",
+    "    \"avg_grader\": [], \"max_grader\": [],\n",
     "    \"n_training_samples\": [], \"train_loss\": [],\n",
     "}\n",
     "\n",
     "t_start = time.time()\n",
-    "\n",
-    "for round_idx in range(1, NUM_ROUNDS + 1):\n",
-    "    print(f\"\\n{'=' * 60}\")\n",
-    "    print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n",
-    "    print(f\"{'=' * 60}\")\n",
-    "\n",
-    "    peft_model.eval()\n",
-    "    tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n",
-    "    t_roll = time.time()\n",
-    "    results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False,\n",
-    "                                       eval=False, system=SYSTEM_PROMPT_TRAIN,\n",
-    "                                       log_tag=f\"train_round{round_idx}\")\n",
-    "    print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
-    "\n",
-    "    all_pairs, episode_rewards, episode_graders = [], [], []\n",
-    "    for ep, result in enumerate(results):\n",
-    "        ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
-    "        episode_rewards.append(ep_reward)\n",
-    "        episode_graders.append(result[\"grader_score\"])\n",
-    "        kept = 0\n",
-    "        for pr in result[\"pairs\"]:\n",
-    "            if not is_well_formed_response(pr[\"response\"]):\n",
-    "                continue\n",
-    "            text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT_TRAIN}<|im_end|>\\n\"\n",
-    "                    f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
-    "                    f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
-    "            all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
-    "            kept += 1\n",
-    "        print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {result['task'].split('_')[-1]:>11s} \"\n",
-    "              f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f} kept={kept}/{len(result['pairs'])}\")\n",
-    "\n",
-    "    avg_r = float(np.mean(episode_rewards))\n",
-    "    avg_g = float(np.mean(episode_graders))\n",
-    "    max_g = float(max(episode_graders))\n",
-    "    print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f} max_grader={max_g:.4f} | pairs={len(all_pairs)}\")\n",
-    "    if not all_pairs:\n",
-    "        print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
-    "        continue\n",
-    "    if max_g < QUALITY_FLOOR:\n",
-    "        print(f\" SKIP SFT: no episode beat quality_floor={QUALITY_FLOOR:.2f}\")\n",
-    "        continue\n",
-    "\n",
-    "    rets = np.array([p[\"reward\"] for p in all_pairs], dtype=float)\n",
-    "    adv = (rets - rets.mean()) / (rets.std() + 1e-6)\n",
-    "    filtered = [p for p, a in zip(all_pairs, adv) if a > 0.0]\n",
-    "    if not filtered:\n",
-    "        print(\" SKIP SFT: zero positive-advantage samples\")\n",
-    "        continue\n",
-    "    print(f\" Kept {len(filtered)}/{len(all_pairs)} positive-advantage samples\")\n",
-    "\n",
-    "    dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
-    "\n",
-    "    # SFT training (real gradient updates)\n",
-    "    sft_config = SFTConfig(\n",
-    "        output_dir=f\"./checkpoints/round_{round_idx}\",\n",
-    "        num_train_epochs=1,\n",
-    "        per_device_train_batch_size=2,\n",
-    "        gradient_accumulation_steps=4,\n",
-    "        learning_rate=5e-6,\n",
-    "        warmup_steps=5,\n",
-    "        logging_steps=1,\n",
-    "        save_strategy=\"no\",\n",
-    "        max_length=2048,\n",
-    "        bf16=True,\n",
-    "        report_to=\"none\",\n",
-    "    )\n",
-    "\n",
-    "    peft_model.train()\n",
-    "    trainer = SFTTrainer(\n",
-    "        model=peft_model, processing_class=tokenizer,\n",
-    "        train_dataset=dataset, args=sft_config,\n",
-    "    )\n",
-    "    train_result = trainer.train()\n",
-    "    loss = train_result.training_loss\n",
-    "    print(f\" Training loss: {loss:.4f}\")\n",
-    "\n",
-    "    training_log[\"round\"].append(round_idx)\n",
-    "    training_log[\"avg_episode_reward\"].append(round(float(avg_r), 3))\n",
-    "    training_log[\"max_episode_reward\"].append(round(float(max(episode_rewards)), 3))\n",
-    "    training_log[\"min_episode_reward\"].append(round(float(min(episode_rewards)), 3))\n",
-    "    training_log[\"avg_grader\"].append(round(float(avg_g), 4))\n",
-    "    training_log[\"max_grader\"].append(round(float(max(episode_graders)), 4))\n",
-    "    training_log[\"n_training_samples\"].append(len(filtered))\n",
-    "    training_log[\"train_loss\"].append(round(loss, 4))\n",
+    "global_step = 0\n",
+    "\n",
+    "for phase in PHASES:\n",
+    "    phase_name = phase[\"name\"]\n",
+    "    sys_prompt = phase[\"system\"]\n",
+    "    reward_mode = phase[\"reward_mode\"]\n",
+    "    print(f\"\\n{'#' * 60}\\n# PHASE {phase_name} (reward_mode={reward_mode})\\n{'#' * 60}\")\n",
+    "\n",
+    "    for round_idx in range(ROUNDS_PER_PHASE):\n",
+    "        use_hint = (round_idx == 0)\n",
+    "        print(f\"\\n{'=' * 60}\\n{phase_name} | ROUND {round_idx+1}/{ROUNDS_PER_PHASE} | hint={use_hint}\\n{'=' * 60}\")\n",
+    "\n",
+    "        peft_model.eval()\n",
+    "        tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + ep + round_idx * 10) for ep in range(EPISODES_PER_ROUND)]\n",
+    "        t_roll = time.time()\n",
+    "        results = run_llm_episodes_batched(\n",
+    "            peft_model, tokenizer, tasks_seeds, verbose=False, eval=False,\n",
+    "            system=sys_prompt, hint_peak_hours=use_hint, reward_mode=reward_mode,\n",
+    "            log_tag=f\"{phase_name}_r{round_idx}\",\n",
+    "        )\n",
+    "        print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
+    "\n",
+    "        all_pairs, episode_rewards, episode_graders = [], [], []\n",
+    "        for ep, result in enumerate(results):\n",
+    "            ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
+    "            episode_rewards.append(ep_reward)\n",
+    "            episode_graders.append(result[\"grader_score\"])\n",
+    "            kept = 0\n",
+    "            for pr in result[\"pairs\"]:\n",
+    "                if not is_well_formed_response(pr[\"response\"]):\n",
+    "                    continue\n",
+    "                text = (f\"<|im_start|>system\\n{sys_prompt}<|im_end|>\\n\"\n",
+    "                        f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
+    "                        f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
+    "                all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
+    "                kept += 1\n",
+    "            print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {result['task'].split('_')[-1]:>11s} \"\n",
+    "                  f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f} kept={kept}/{len(result['pairs'])}\")\n",
+    "\n",
+    "        avg_r = float(np.mean(episode_rewards))\n",
+    "        avg_g = float(np.mean(episode_graders))\n",
+    "        max_g = float(max(episode_graders))\n",
+    "        print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f} max_grader={max_g:.4f} | pairs={len(all_pairs)}\")\n",
+    "\n",
+    "        loss = float(\"nan\")\n",
+    "        n_filtered = 0\n",
+    "        if not all_pairs:\n",
+    "            print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
+    "        elif max_g < QUALITY_FLOOR:\n",
+    "            print(f\" SKIP SFT: no episode beat quality_floor={QUALITY_FLOOR:.2f}\")\n",
+    "        else:\n",
+    "            rets = np.array([p[\"reward\"] for p in all_pairs], dtype=float)\n",
+    "            adv = (rets - rets.mean()) / (rets.std() + 1e-6)\n",
+    "            filtered = [p for p, a in zip(all_pairs, adv) if a > 0.0]\n",
+    "            if not filtered:\n",
+    "                print(\" SKIP SFT: zero positive-advantage samples\")\n",
+    "            else:\n",
+    "                n_filtered = len(filtered)\n",
+    "                print(f\" Kept {n_filtered}/{len(all_pairs)} positive-advantage samples\")\n",
+    "                dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
+    "                sft_config = SFTConfig(\n",
+    "                    output_dir=f\"./checkpoints/{phase_name}_r{round_idx}\",\n",
+    "                    num_train_epochs=1,\n",
+    "                    per_device_train_batch_size=2,\n",
+    "                    gradient_accumulation_steps=4,\n",
+    "                    learning_rate=5e-6,\n",
+    "                    warmup_steps=5,\n",
+    "                    logging_steps=1,\n",
+    "                    save_strategy=\"no\",\n",
+    "                    max_length=2048,\n",
+    "                    bf16=True,\n",
+    "                    report_to=\"none\",\n",
+    "                )\n",
+    "                peft_model.train()\n",
+    "                trainer = SFTTrainer(\n",
+    "                    model=peft_model, processing_class=tokenizer,\n",
+    "                    train_dataset=dataset, args=sft_config,\n",
+    "                )\n",
+    "                train_result = trainer.train()\n",
+    "                loss = float(train_result.training_loss)\n",
+    "                print(f\" Training loss: {loss:.4f}\")\n",
+    "\n",
+    "        global_step += 1\n",
+    "        training_log[\"phase\"].append(phase_name)\n",
+    "        training_log[\"round\"].append(round_idx + 1)\n",
+    "        training_log[\"global_step\"].append(global_step)\n",
+    "        training_log[\"use_hint\"].append(use_hint)\n",
+    "        training_log[\"avg_episode_reward\"].append(round(float(avg_r), 3))\n",
+    "        training_log[\"max_episode_reward\"].append(round(float(max(episode_rewards)), 3))\n",
+    "        training_log[\"min_episode_reward\"].append(round(float(min(episode_rewards)), 3))\n",
+    "        training_log[\"avg_grader\"].append(round(float(avg_g), 4))\n",
+    "        training_log[\"max_grader\"].append(round(float(max(episode_graders)), 4))\n",
+    "        training_log[\"n_training_samples\"].append(n_filtered)\n",
+    "        training_log[\"train_loss\"].append(round(loss, 4) if loss == loss else float(\"nan\"))\n",
+    "\n",
+    "    save_dir = f\"./checkpoints/{phase_name}_adapter\"\n",
+    "    os.makedirs(save_dir, exist_ok=True)\n",
+    "    peft_model.save_pretrained(save_dir)\n",
+    "    tokenizer.save_pretrained(save_dir)\n",
+    "    print(f\"\\n Saved {phase_name} adapter -> {save_dir}\")\n",
     "\n",
     "elapsed = time.time() - t_start\n",
-    "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
+    "print(f\"\\nTwo-phase training complete in {elapsed/60:.1f} min\")\n",
    "print(pd.DataFrame(training_log).to_string(index=False))"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -904,9 +946,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 12: Run trained model (batched)\n",
     "print(\"Running TRAINED model on all tasks (batched)...\")\n",
@@ -921,7 +961,9 @@
     "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
     "for t in TASKS:\n",
     "    print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -932,37 +974,41 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
-    "# Cell 13: Training curves\n",
+    "# Cell 13: Training curves (two-phase)\n",
     "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
-    "rounds = training_log[\"round\"]\n",
+    "steps = training_log[\"global_step\"]\n",
+    "phases = training_log[\"phase\"]\n",
+    "phase1_end = max([s for s, p in zip(steps, phases) if p == \"phase1_timing\"], default=0)\n",
     "\n",
-    "axes[0].plot(rounds, training_log[\"avg_grader\"], 'o-', color='#2196F3', lw=2, label='Avg grader')\n",
-    "axes[0].fill_between(rounds, training_log[\"avg_grader\"],\n",
+    "axes[0].plot(steps, training_log[\"avg_grader\"], 'o-', color='#2196F3', lw=2, label='Avg grader')\n",
+    "axes[0].fill_between(steps, training_log[\"avg_grader\"],\n",
     "                     training_log[\"max_grader\"], alpha=0.2, color='#2196F3')\n",
-    "axes[0].set_xlabel('Round'); axes[0].set_ylabel('Grader Score')\n",
-    "axes[0].set_title('Grader Score Over Rounds', fontweight='bold')\n",
+    "if phase1_end > 0:\n",
+    "    axes[0].axvline(phase1_end + 0.5, color='gray', ls='--', alpha=0.6, label='phase split')\n",
+    "axes[0].set_xlabel('Global step'); axes[0].set_ylabel('Grader Score')\n",
+    "axes[0].set_title('Grader Score (timing -> content)', fontweight='bold')\n",
     "axes[0].legend(); axes[0].grid(True, alpha=0.3)\n",
     "\n",
-    "axes[1].plot(rounds, training_log[\"train_loss\"], 's-', color='#E53935', lw=2)\n",
-    "axes[1].set_xlabel('Round'); axes[1].set_ylabel('Loss')\n",
+    "axes[1].plot(steps, training_log[\"train_loss\"], 's-', color='#E53935', lw=2)\n",
+    "if phase1_end > 0:\n",
+    "    axes[1].axvline(phase1_end + 0.5, color='gray', ls='--', alpha=0.6)\n",
+    "axes[1].set_xlabel('Global step'); axes[1].set_ylabel('Loss')\n",
     "axes[1].set_title('Training Loss', fontweight='bold')\n",
     "axes[1].grid(True, alpha=0.3)\n",
     "\n",
-    "fig.suptitle('Viraltest v2 — LoRA Training Progress (Qwen 1.5B)', fontsize=14, fontweight='bold')\n",
+    "fig.suptitle('Viraltest v2 — Two-Phase LoRA Training (timing -> content)', fontsize=14, fontweight='bold')\n",
     "fig.tight_layout()\n",
     "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
     "plt.show()"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 14: Before vs After\n",
     "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
@@ -992,13 +1038,13 @@
     "fig.tight_layout()\n",
     "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
     "plt.show()"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 15: Trajectory comparison\n",
     "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
@@ -1022,7 +1068,9 @@
     "fig.tight_layout()\n",
     "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
     "plt.show()"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -1033,9 +1081,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 16: Final summary\n",
     "print(\"=\" * 67)\n",
@@ -1057,8 +1103,10 @@
     "\n",
     "summary = {\n",
     "    \"model\": MODEL_NAME,\n",
-    "    \"training\": \"LoRA SFT (real weight updates)\",\n",
-    "    \"rounds\": NUM_ROUNDS, \"episodes_per_round\": EPISODES_PER_ROUND,\n",
+    "    \"training\": \"Two-phase LoRA SFT (timing -> content) with hardcoded peak-hours hint on round 1 of each phase\",\n",
+    "    \"phases\": [p[\"name\"] for p in PHASES],\n",
+    "    \"rounds_per_phase\": ROUNDS_PER_PHASE,\n",
+    "    \"episodes_per_round\": EPISODES_PER_ROUND,\n",
     "    \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n",
     "    \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n",
     "    \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n",
@@ -1072,13 +1120,13 @@
     "\n",
     "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
     "print(\"All results are from real LoRA weight updates on real environment runs.\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# Cell 17: Save adapter\n",
     "save_path = \"./viraltest_trained_adapter\"\n",
@@ -1086,7 +1134,9 @@
     "tokenizer.save_pretrained(save_path)\n",
     "print(f\"LoRA adapter saved to {save_path}\")\n",
     "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
-   ]
+   ],
+   "execution_count": null,
+   "outputs": []
   }
  ],
  "metadata": {
@@ -1112,4 +1162,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 4
-}
+}
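One subtlety in the new logging: rounds that skip SFT keep loss as NaN, and the guard `round(loss, 4) if loss == loss else float("nan")` relies on NaN being the only float that compares unequal to itself. Equivalent spelling for reference:

```python
import math

loss = float("nan")
assert (loss == loss) is False  # NaN != NaN, the idiom used in the loop
assert math.isnan(loss)         # the explicit alternative
```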