vaibhav12332112312 commited on
Commit
30614d3
·
1 Parent(s): b1c1732

Inject peak hours + history + post-mandate, run SFT every round

Browse files

Prompt explicitly tells the model to schedule >=2 `post` actions per day at
heatmap peak hours, plus a rolling 3-day Recent summary so it can react to
its own past results. Eval runs greedy (deterministic), training stays
sampled. QUALITY_FLOOR=0 so SFT runs on positive-advantage samples even
when grader scores are still low. Bumped to 2 training rounds.

Made-with: Cursor

Files changed (1) hide show
  1. training/train_grpo.ipynb +194 -154
training/train_grpo.ipynb CHANGED
@@ -25,9 +25,7 @@
25
  },
26
  {
27
  "cell_type": "code",
28
- "execution_count": null,
29
  "metadata": {},
30
- "outputs": [],
31
  "source": [
32
  "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
33
  "!pip install -q torch torchvision torchaudio\n",
@@ -36,13 +34,13 @@
36
  "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
37
  "!pip install -q \"openenv-core[core]>=0.2.2\"\n",
38
  "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
39
- ]
 
 
40
  },
41
  {
42
  "cell_type": "code",
43
- "execution_count": null,
44
  "metadata": {},
45
- "outputs": [],
46
  "source": [
47
  "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
48
  "import os\n",
@@ -118,13 +116,13 @@
118
  "print(f\"Branch: {REPO_BRANCH}\")\n",
119
  "print(f\"Commit: {commit}\")\n",
120
  "print(f\"Plots dir: {PLOTS_DIR}\")"
121
- ]
 
 
122
  },
123
  {
124
  "cell_type": "code",
125
- "execution_count": null,
126
  "metadata": {},
127
- "outputs": [],
128
  "source": [
129
  "# Cell 3: Imports (with runtime validation)\n",
130
  "import json, random, time, textwrap, copy, os, sys\n",
@@ -178,7 +176,9 @@
178
  "import ast\n",
179
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
180
  "print(\"OK: ast.parse (syntax check)\")"
181
- ]
 
 
182
  },
183
  {
184
  "cell_type": "markdown",
@@ -191,9 +191,7 @@
191
  },
192
  {
193
  "cell_type": "code",
194
- "execution_count": null,
195
  "metadata": {},
196
- "outputs": [],
197
  "source": [
198
  "# Cell 4: Define heuristic agents + episode runner\n",
199
  "_rng = random.Random(42)\n",
@@ -269,13 +267,13 @@
269
  " \"rewards\": rewards, \"energies\": energies}\n",
270
  "\n",
271
  "print(\"Agents and episode runner defined.\")"
272
- ]
 
 
273
  },
274
  {
275
  "cell_type": "code",
276
- "execution_count": null,
277
  "metadata": {},
278
- "outputs": [],
279
  "source": [
280
  "# Cell 5: Run baselines (safe)\n",
281
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
@@ -310,13 +308,13 @@
310
  "for name in BASELINE_AGENTS:\n",
311
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
312
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
313
- ]
 
 
314
  },
315
  {
316
  "cell_type": "code",
317
- "execution_count": null,
318
  "metadata": {},
319
- "outputs": [],
320
  "source": [
321
  "# Cell 6: Baseline plots\n",
322
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
@@ -334,7 +332,9 @@
334
  "fig.tight_layout()\n",
335
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
336
  "plt.show()"
337
- ]
 
 
338
  },
339
  {
340
  "cell_type": "markdown",
@@ -347,9 +347,7 @@
347
  },
348
  {
349
  "cell_type": "code",
350
- "execution_count": null,
351
  "metadata": {},
352
- "outputs": [],
353
  "source": [
354
  "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
355
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
@@ -393,13 +391,13 @@
393
  "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
394
  "if torch.cuda.is_available():\n",
395
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
396
- ]
 
 
397
  },
398
  {
399
  "cell_type": "code",
400
- "execution_count": null,
401
  "metadata": {},
402
- "outputs": [],
403
  "source": [
404
  "# Cell 8: LLM agent functions\n",
405
  "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
@@ -439,7 +437,19 @@
439
  " like_bait -> likes from existing followers\n",
440
  "- tags: up to 5 hashtags\n",
441
  "- topic: free-form string\n",
442
- "- empty scheduled_actions = full day rest\"\"\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
443
  "\n",
444
  "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
445
  "\n",
@@ -458,9 +468,28 @@
458
  "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
459
  "\n",
460
  "\n",
461
- "def format_obs(obs):\n",
462
- " days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
463
- " day_name = days[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  " signals_str = \"\"\n",
465
  " signals = getattr(obs, \"engagement_signals\", None)\n",
466
  " if signals:\n",
@@ -473,10 +502,11 @@
473
  " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
474
  " if not tool_str:\n",
475
  " tool_str = \" (none — call query_* tools to discover)\\n\"\n",
476
- " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
477
  " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
478
  " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
479
  " f\"{signals_str}\"\n",
 
480
  " f\"Tool results:\\n{tool_str}\"\n",
481
  " f\"Plan today's actions (JSON only):\")\n",
482
  "\n",
@@ -554,11 +584,11 @@
554
  "\n",
555
  "def _batched_generate(mdl, tok, prompts, eval=False, max_new_tokens=512):\n",
556
  " enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
557
- " gen_kwargs = dict(\n",
558
- " max_new_tokens=max_new_tokens,\n",
559
- " pad_token_id=tok.pad_token_id,\n",
560
- " do_sample=True, temperature=1.0, top_p=0.95,\n",
561
- " )\n",
562
  " with torch.no_grad():\n",
563
  " out = mdl.generate(**enc, **gen_kwargs)\n",
564
  " resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
@@ -576,96 +606,104 @@
576
  " f.write(json.dumps(rec) + \"\\n\")\n",
577
  "\n",
578
  "\n",
579
- "DISCOVERY_SUFFIX = \"\\n\\nPHASE A (DISCOVERY): respond with JSON {\\\"tool_calls\\\": [...]} only.\"\n",
580
- "PLANNING_SUFFIX = \"\\n\\nPHASE B (PLANNING): respond with JSON {\\\"scheduled_actions\\\": [...], \\\"notes\\\": \\\"...\\\"} using the fresh Tool results above.\"\n",
581
- "\n",
582
- "\n",
583
- "def _parse_tool_calls_only(text):\n",
584
- " return parse_model_output(text).tool_calls\n",
585
- "\n",
586
- "\n",
587
- "def _parse_actions_only(text):\n",
588
- " a = parse_model_output(text)\n",
589
- " return ViraltestAction(tool_calls=[], scheduled_actions=a.scheduled_actions, notes=a.notes)\n",
590
- "\n",
591
- "\n",
592
- "def _format_fresh_results(fresh):\n",
593
- " if not fresh:\n",
594
- " return \"\"\n",
595
- " out = \"Fresh tool results (PHASE A):\\n\"\n",
596
- " for tr in fresh:\n",
597
- " if tr.success:\n",
598
- " out += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
599
- " else:\n",
600
- " out += f\" {tr.name}: ERROR {tr.error}\\n\"\n",
601
- " return out\n",
602
- "\n",
603
- "\n",
604
- "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None, log_tag=None):\n",
605
- " \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n",
606
- " sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
607
- " n = len(tasks_seeds)\n",
608
- " envs = [ViraltestEnvironment() for _ in range(n)]\n",
609
- " obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
610
- " rewards = [[] for _ in range(n)]\n",
611
- " energies = [[obs.creator_energy] for obs in obss]\n",
612
- " pairs = [[] for _ in range(n)]\n",
613
- " done_mask = [obs.done for obs in obss]\n",
614
- " rest_action = ViraltestAction(scheduled_actions=[])\n",
615
- "\n",
616
- " def _gen(prompts):\n",
617
- " chats = [_build_chat(sys_prompt, p) for p in prompts]\n",
618
- " texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
619
- " return _batched_generate(mdl, tok, texts, eval=eval)\n",
620
- "\n",
621
- " for day in range(1, TASK_HORIZON + 1):\n",
622
- " active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
623
- " rest = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy <= 0.25]\n",
624
- " if not active and not rest:\n",
625
- " break\n",
626
- "\n",
627
- " actions_by_idx = {i: rest_action for i in rest}\n",
628
- " if active:\n",
629
- " base_prompts = [format_obs(obss[i]) for i in active]\n",
630
- "\n",
631
- " disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n",
632
- " disc_resps, ptok = _gen(disc_prompts)\n",
633
- " if verbose:\n",
634
- " print(f\" D{day:2d}A: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
635
- "\n",
636
- " fresh_per_active = []\n",
637
- " for j, i in enumerate(active):\n",
638
- " tcs = _parse_tool_calls_only(disc_resps[j])\n",
639
- " fresh_per_active.append([envs[i]._dispatch_tool(tc) for tc in tcs])\n",
640
- " pairs[i].append({\"prompt\": disc_prompts[j], \"response\": disc_resps[j],\n",
641
- " \"step\": len(rewards[i]), \"phase\": \"A\"})\n",
642
- " if log_tag is not None:\n",
643
- " t, s = tasks_seeds[i]\n",
644
- " _log_io(f\"{log_tag}/A\", i, day, t, s, disc_prompts[j], disc_resps[j])\n",
645
- "\n",
646
- " plan_prompts = [base_prompts[j] + \"\\n\" + _format_fresh_results(fresh_per_active[j]) + PLANNING_SUFFIX\n",
647
- " for j in range(len(active))]\n",
648
- " plan_resps, ptok2 = _gen(plan_prompts)\n",
649
- " if verbose:\n",
650
- " print(f\" D{day:2d}B: batch={len(active)} prompt_tok={ptok2}\")\n",
651
- "\n",
652
- " for j, i in enumerate(active):\n",
653
- " actions_by_idx[i] = _parse_actions_only(plan_resps[j])\n",
654
- " pairs[i].append({\"prompt\": plan_prompts[j], \"response\": plan_resps[j],\n",
655
- " \"step\": len(rewards[i]), \"phase\": \"B\"})\n",
656
- " if log_tag is not None:\n",
657
- " t, s = tasks_seeds[i]\n",
658
- " _log_io(f\"{log_tag}/B\", i, day, t, s, plan_prompts[j], plan_resps[j])\n",
659
- "\n",
660
- " for i in range(n):\n",
661
- " if done_mask[i] or i not in actions_by_idx:\n",
662
- " continue\n",
663
- " obss[i] = envs[i].step(actions_by_idx[i])\n",
664
- " r = obss[i].reward or 0.0\n",
665
- " rewards[i].append(r)\n",
666
- " energies[i].append(obss[i].creator_energy)\n",
667
- " if obss[i].done:\n",
668
- " done_mask[i] = True\n",
 
 
 
 
 
 
 
 
669
  "\n",
670
  " GAMMA, TERMINAL_W = 0.95, 5.0\n",
671
  " results = []\n",
@@ -694,7 +732,9 @@
694
  "\n",
695
  "\n",
696
  "print(\"LLM agent functions defined (batched).\")"
697
- ]
 
 
698
  },
699
  {
700
  "cell_type": "markdown",
@@ -707,9 +747,7 @@
707
  },
708
  {
709
  "cell_type": "code",
710
- "execution_count": null,
711
  "metadata": {},
712
- "outputs": [],
713
  "source": [
714
  "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
715
  "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
@@ -723,7 +761,9 @@
723
  "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
724
  "for t in TASKS:\n",
725
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
726
- ]
 
 
727
  },
728
  {
729
  "cell_type": "markdown",
@@ -742,9 +782,7 @@
742
  },
743
  {
744
  "cell_type": "code",
745
- "execution_count": null,
746
  "metadata": {},
747
- "outputs": [],
748
  "source": [
749
  "# Cell 10: Attach LoRA adapter\n",
750
  "from peft import LoraConfig, get_peft_model, TaskType\n",
@@ -758,21 +796,21 @@
758
  "model.enable_input_require_grads()\n",
759
  "peft_model = get_peft_model(model, lora_config)\n",
760
  "peft_model.print_trainable_parameters()"
761
- ]
 
 
762
  },
763
  {
764
  "cell_type": "code",
765
- "execution_count": null,
766
  "metadata": {},
767
- "outputs": [],
768
  "source": [
769
  "# Cell 11: Training loop\n",
770
  "from trl import SFTTrainer, SFTConfig\n",
771
  "from datasets import Dataset\n",
772
  "\n",
773
- "NUM_ROUNDS = 1\n",
774
  "EPISODES_PER_ROUND = 6\n",
775
- "QUALITY_FLOOR = 0.40 # skip SFT for the round if no episode beats this grader score\n",
776
  "\n",
777
  "training_log = {\n",
778
  " \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n",
@@ -869,7 +907,9 @@
869
  "elapsed = time.time() - t_start\n",
870
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
871
  "print(pd.DataFrame(training_log).to_string(index=False))"
872
- ]
 
 
873
  },
874
  {
875
  "cell_type": "markdown",
@@ -882,9 +922,7 @@
882
  },
883
  {
884
  "cell_type": "code",
885
- "execution_count": null,
886
  "metadata": {},
887
- "outputs": [],
888
  "source": [
889
  "# Cell 12: Run trained model (batched)\n",
890
  "print(\"Running TRAINED model on all tasks (batched)...\")\n",
@@ -899,7 +937,9 @@
899
  "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
900
  "for t in TASKS:\n",
901
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
902
- ]
 
 
903
  },
904
  {
905
  "cell_type": "markdown",
@@ -910,9 +950,7 @@
910
  },
911
  {
912
  "cell_type": "code",
913
- "execution_count": null,
914
  "metadata": {},
915
- "outputs": [],
916
  "source": [
917
  "# Cell 13: Training curves\n",
918
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
@@ -934,13 +972,13 @@
934
  "fig.tight_layout()\n",
935
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
936
  "plt.show()"
937
- ]
 
 
938
  },
939
  {
940
  "cell_type": "code",
941
- "execution_count": null,
942
  "metadata": {},
943
- "outputs": [],
944
  "source": [
945
  "# Cell 14: Before vs After\n",
946
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
@@ -970,13 +1008,13 @@
970
  "fig.tight_layout()\n",
971
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
972
  "plt.show()"
973
- ]
 
 
974
  },
975
  {
976
  "cell_type": "code",
977
- "execution_count": null,
978
  "metadata": {},
979
- "outputs": [],
980
  "source": [
981
  "# Cell 15: Trajectory comparison\n",
982
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
@@ -1000,7 +1038,9 @@
1000
  "fig.tight_layout()\n",
1001
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
1002
  "plt.show()"
1003
- ]
 
 
1004
  },
1005
  {
1006
  "cell_type": "markdown",
@@ -1011,9 +1051,7 @@
1011
  },
1012
  {
1013
  "cell_type": "code",
1014
- "execution_count": null,
1015
  "metadata": {},
1016
- "outputs": [],
1017
  "source": [
1018
  "# Cell 16: Final summary\n",
1019
  "print(\"=\" * 67)\n",
@@ -1050,13 +1088,13 @@
1050
  "\n",
1051
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
1052
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
1053
- ]
 
 
1054
  },
1055
  {
1056
  "cell_type": "code",
1057
- "execution_count": null,
1058
  "metadata": {},
1059
- "outputs": [],
1060
  "source": [
1061
  "# Cell 17: Save adapter\n",
1062
  "save_path = \"./viraltest_trained_adapter\"\n",
@@ -1064,7 +1102,9 @@
1064
  "tokenizer.save_pretrained(save_path)\n",
1065
  "print(f\"LoRA adapter saved to {save_path}\")\n",
1066
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
1067
- ]
 
 
1068
  }
1069
  ],
1070
  "metadata": {
@@ -1090,4 +1130,4 @@
1090
  },
1091
  "nbformat": 4,
1092
  "nbformat_minor": 4
1093
- }
 
25
  },
26
  {
27
  "cell_type": "code",
 
28
  "metadata": {},
 
29
  "source": [
30
  "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
31
  "!pip install -q torch torchvision torchaudio\n",
 
34
  "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
35
  "!pip install -q \"openenv-core[core]>=0.2.2\"\n",
36
  "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
37
+ ],
38
+ "execution_count": null,
39
+ "outputs": []
40
  },
41
  {
42
  "cell_type": "code",
 
43
  "metadata": {},
 
44
  "source": [
45
  "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
46
  "import os\n",
 
116
  "print(f\"Branch: {REPO_BRANCH}\")\n",
117
  "print(f\"Commit: {commit}\")\n",
118
  "print(f\"Plots dir: {PLOTS_DIR}\")"
119
+ ],
120
+ "execution_count": null,
121
+ "outputs": []
122
  },
123
  {
124
  "cell_type": "code",
 
125
  "metadata": {},
 
126
  "source": [
127
  "# Cell 3: Imports (with runtime validation)\n",
128
  "import json, random, time, textwrap, copy, os, sys\n",
 
176
  "import ast\n",
177
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
178
  "print(\"OK: ast.parse (syntax check)\")"
179
+ ],
180
+ "execution_count": null,
181
+ "outputs": []
182
  },
183
  {
184
  "cell_type": "markdown",
 
191
  },
192
  {
193
  "cell_type": "code",
 
194
  "metadata": {},
 
195
  "source": [
196
  "# Cell 4: Define heuristic agents + episode runner\n",
197
  "_rng = random.Random(42)\n",
 
267
  " \"rewards\": rewards, \"energies\": energies}\n",
268
  "\n",
269
  "print(\"Agents and episode runner defined.\")"
270
+ ],
271
+ "execution_count": null,
272
+ "outputs": []
273
  },
274
  {
275
  "cell_type": "code",
 
276
  "metadata": {},
 
277
  "source": [
278
  "# Cell 5: Run baselines (safe)\n",
279
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
 
308
  "for name in BASELINE_AGENTS:\n",
309
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
310
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
311
+ ],
312
+ "execution_count": null,
313
+ "outputs": []
314
  },
315
  {
316
  "cell_type": "code",
 
317
  "metadata": {},
 
318
  "source": [
319
  "# Cell 6: Baseline plots\n",
320
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
 
332
  "fig.tight_layout()\n",
333
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
334
  "plt.show()"
335
+ ],
336
+ "execution_count": null,
337
+ "outputs": []
338
  },
339
  {
340
  "cell_type": "markdown",
 
347
  },
348
  {
349
  "cell_type": "code",
 
350
  "metadata": {},
 
351
  "source": [
352
  "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
353
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
 
391
  "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
392
  "if torch.cuda.is_available():\n",
393
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
394
+ ],
395
+ "execution_count": null,
396
+ "outputs": []
397
  },
398
  {
399
  "cell_type": "code",
 
400
  "metadata": {},
 
401
  "source": [
402
  "# Cell 8: LLM agent functions\n",
403
  "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
 
437
  " like_bait -> likes from existing followers\n",
438
  "- tags: up to 5 hashtags\n",
439
  "- topic: free-form string\n",
440
+ "- empty scheduled_actions = full day rest\n",
441
+ "\n",
442
+ "POSTING RULES (critical — only `post` actions earn engagement reward):\n",
443
+ "- EVERY active day MUST schedule at least 2 `post` actions (max 3). `create_content`\n",
444
+ " alone gives 0 reward — content stays in queue. Mix in 0-1 `create_content` only\n",
445
+ " if the queue is empty.\n",
446
+ "- Schedule posts at HEATMAP PEAK HOURS (Buffer/Sprout-derived):\n",
447
+ " Mon peaks 14, 18, 19 Tue peaks 14, 15, 19\n",
448
+ " Wed peaks 13, 14, 18 Thu peaks 12, 13, 19\n",
449
+ " Fri peaks 12, 13, 22 Sat peaks 21, 22, 13\n",
450
+ " Sun peaks 21, 22, 11\n",
451
+ "- Vary `intent` across the day; rotate `content_type` to avoid fatigue.\n",
452
+ "- Reuse strong tags from the Recent-days summary (those that earned reward).\"\"\")\n",
453
  "\n",
454
  "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
455
  "\n",
 
468
  "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
469
  "\n",
470
  "\n",
471
+ "_DAY_NAMES = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
472
+ "_PEAK_HOURS = {0:[14,18,19], 1:[14,15,19], 2:[13,14,18], 3:[12,13,19],\n",
473
+ " 4:[12,13,22], 5:[21,22,13], 6:[21,22,11]}\n",
474
+ "\n",
475
+ "\n",
476
+ "def _format_history(history, k=3):\n",
477
+ " if not history:\n",
478
+ " return \"Recent (last 3 days): (none — day 1)\\n\"\n",
479
+ " out = \"Recent (last 3 days):\\n\"\n",
480
+ " for h in history[-k:]:\n",
481
+ " posts = h.get(\"posts\", [])\n",
482
+ " if not posts:\n",
483
+ " out += f\" D-{h['ago']}: rest reward={h['reward']:.2f}\\n\"\n",
484
+ " else:\n",
485
+ " ph = \",\".join(f\"{p['hour']}h/{p['content_type'][:4]}/{p['intent'][:4]}\" for p in posts)\n",
486
+ " out += f\" D-{h['ago']}: posts=[{ph}] reward={h['reward']:.2f}\\n\"\n",
487
+ " return out\n",
488
+ "\n",
489
+ "\n",
490
+ "def format_obs(obs, history=None):\n",
491
+ " day_name = _DAY_NAMES[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
492
+ " peaks = _PEAK_HOURS.get(obs.day_of_week, [12, 18, 20])\n",
493
  " signals_str = \"\"\n",
494
  " signals = getattr(obs, \"engagement_signals\", None)\n",
495
  " if signals:\n",
 
502
  " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
503
  " if not tool_str:\n",
504
  " tool_str = \" (none — call query_* tools to discover)\\n\"\n",
505
+ " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed} | today's peak hours={peaks}\\n\"\n",
506
  " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
507
  " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
508
  " f\"{signals_str}\"\n",
509
+ " f\"{_format_history(history)}\"\n",
510
  " f\"Tool results:\\n{tool_str}\"\n",
511
  " f\"Plan today's actions (JSON only):\")\n",
512
  "\n",
 
584
  "\n",
585
  "def _batched_generate(mdl, tok, prompts, eval=False, max_new_tokens=512):\n",
586
  " enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
587
+ " if eval:\n",
588
+ " gen_kwargs = dict(max_new_tokens=max_new_tokens, pad_token_id=tok.pad_token_id, do_sample=False)\n",
589
+ " else:\n",
590
+ " gen_kwargs = dict(max_new_tokens=max_new_tokens, pad_token_id=tok.pad_token_id,\n",
591
+ " do_sample=True, temperature=0.9, top_p=0.95)\n",
592
  " with torch.no_grad():\n",
593
  " out = mdl.generate(**enc, **gen_kwargs)\n",
594
  " resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
 
606
  " f.write(json.dumps(rec) + \"\\n\")\n",
607
  "\n",
608
  "\n",
609
+ "DISCOVERY_SUFFIX = \"\\n\\nPHASE A (DISCOVERY): respond with JSON {\\\"tool_calls\\\": [...]} only.\"\n",
610
+ "PLANNING_SUFFIX = \"\\n\\nPHASE B (PLANNING): respond with JSON {\\\"scheduled_actions\\\": [...], \\\"notes\\\": \\\"...\\\"} using the fresh Tool results above.\"\n",
611
+ "\n",
612
+ "\n",
613
+ "def _parse_tool_calls_only(text):\n",
614
+ " return parse_model_output(text).tool_calls\n",
615
+ "\n",
616
+ "\n",
617
+ "def _parse_actions_only(text):\n",
618
+ " a = parse_model_output(text)\n",
619
+ " return ViraltestAction(tool_calls=[], scheduled_actions=a.scheduled_actions, notes=a.notes)\n",
620
+ "\n",
621
+ "\n",
622
+ "def _format_fresh_results(fresh):\n",
623
+ " if not fresh:\n",
624
+ " return \"\"\n",
625
+ " out = \"Fresh tool results (PHASE A):\\n\"\n",
626
+ " for tr in fresh:\n",
627
+ " if tr.success:\n",
628
+ " out += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
629
+ " else:\n",
630
+ " out += f\" {tr.name}: ERROR {tr.error}\\n\"\n",
631
+ " return out\n",
632
+ "\n",
633
+ "\n",
634
+ "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None, log_tag=None):\n",
635
+ " \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n",
636
+ " sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
637
+ " n = len(tasks_seeds)\n",
638
+ " envs = [ViraltestEnvironment() for _ in range(n)]\n",
639
+ " obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
640
+ " rewards = [[] for _ in range(n)]\n",
641
+ " energies = [[obs.creator_energy] for obs in obss]\n",
642
+ " pairs = [[] for _ in range(n)]\n",
643
+ " histories = [[] for _ in range(n)]\n",
644
+ " done_mask = [obs.done for obs in obss]\n",
645
+ " rest_action = ViraltestAction(scheduled_actions=[])\n",
646
+ "\n",
647
+ " def _gen(prompts):\n",
648
+ " chats = [_build_chat(sys_prompt, p) for p in prompts]\n",
649
+ " texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
650
+ " return _batched_generate(mdl, tok, texts, eval=eval)\n",
651
+ "\n",
652
+ " for day in range(1, TASK_HORIZON + 1):\n",
653
+ " active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
654
+ " rest = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy <= 0.25]\n",
655
+ " if not active and not rest:\n",
656
+ " break\n",
657
+ "\n",
658
+ " actions_by_idx = {i: rest_action for i in rest}\n",
659
+ " if active:\n",
660
+ " base_prompts = [format_obs(obss[i], histories[i]) for i in active]\n",
661
+ "\n",
662
+ " disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n",
663
+ " disc_resps, ptok = _gen(disc_prompts)\n",
664
+ " if verbose:\n",
665
+ " print(f\" D{day:2d}A: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
666
+ "\n",
667
+ " fresh_per_active = []\n",
668
+ " for j, i in enumerate(active):\n",
669
+ " tcs = _parse_tool_calls_only(disc_resps[j])\n",
670
+ " fresh_per_active.append([envs[i]._dispatch_tool(tc) for tc in tcs])\n",
671
+ " pairs[i].append({\"prompt\": disc_prompts[j], \"response\": disc_resps[j],\n",
672
+ " \"step\": len(rewards[i]), \"phase\": \"A\"})\n",
673
+ " if log_tag is not None:\n",
674
+ " t, s = tasks_seeds[i]\n",
675
+ " _log_io(f\"{log_tag}/A\", i, day, t, s, disc_prompts[j], disc_resps[j])\n",
676
+ "\n",
677
+ " plan_prompts = [base_prompts[j] + \"\\n\" + _format_fresh_results(fresh_per_active[j]) + PLANNING_SUFFIX\n",
678
+ " for j in range(len(active))]\n",
679
+ " plan_resps, ptok2 = _gen(plan_prompts)\n",
680
+ " if verbose:\n",
681
+ " print(f\" D{day:2d}B: batch={len(active)} prompt_tok={ptok2}\")\n",
682
+ "\n",
683
+ " for j, i in enumerate(active):\n",
684
+ " actions_by_idx[i] = _parse_actions_only(plan_resps[j])\n",
685
+ " pairs[i].append({\"prompt\": plan_prompts[j], \"response\": plan_resps[j],\n",
686
+ " \"step\": len(rewards[i]), \"phase\": \"B\"})\n",
687
+ " if log_tag is not None:\n",
688
+ " t, s = tasks_seeds[i]\n",
689
+ " _log_io(f\"{log_tag}/B\", i, day, t, s, plan_prompts[j], plan_resps[j])\n",
690
+ "\n",
691
+ " for i in range(n):\n",
692
+ " if done_mask[i] or i not in actions_by_idx:\n",
693
+ " continue\n",
694
+ " act = actions_by_idx[i]\n",
695
+ " obss[i] = envs[i].step(act)\n",
696
+ " r = obss[i].reward or 0.0\n",
697
+ " rewards[i].append(r)\n",
698
+ " energies[i].append(obss[i].creator_energy)\n",
699
+ " posts = [{\"hour\": s.hour, \"content_type\": s.content_type or \"?\", \"intent\": s.intent or \"?\"}\n",
700
+ " for s in (act.scheduled_actions or []) if s.action_type == \"post\"]\n",
701
+ " for h in histories[i]:\n",
702
+ " h[\"ago\"] += 1\n",
703
+ " histories[i].append({\"ago\": 1, \"posts\": posts, \"reward\": r})\n",
704
+ " histories[i] = histories[i][-3:]\n",
705
+ " if obss[i].done:\n",
706
+ " done_mask[i] = True\n",
707
  "\n",
708
  " GAMMA, TERMINAL_W = 0.95, 5.0\n",
709
  " results = []\n",
 
732
  "\n",
733
  "\n",
734
  "print(\"LLM agent functions defined (batched).\")"
735
+ ],
736
+ "execution_count": null,
737
+ "outputs": []
738
  },
739
  {
740
  "cell_type": "markdown",
 
747
  },
748
  {
749
  "cell_type": "code",
 
750
  "metadata": {},
 
751
  "source": [
752
  "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
753
  "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
 
761
  "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
762
  "for t in TASKS:\n",
763
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
764
+ ],
765
+ "execution_count": null,
766
+ "outputs": []
767
  },
768
  {
769
  "cell_type": "markdown",
 
782
  },
783
  {
784
  "cell_type": "code",
 
785
  "metadata": {},
 
786
  "source": [
787
  "# Cell 10: Attach LoRA adapter\n",
788
  "from peft import LoraConfig, get_peft_model, TaskType\n",
 
796
  "model.enable_input_require_grads()\n",
797
  "peft_model = get_peft_model(model, lora_config)\n",
798
  "peft_model.print_trainable_parameters()"
799
+ ],
800
+ "execution_count": null,
801
+ "outputs": []
802
  },
803
  {
804
  "cell_type": "code",
 
805
  "metadata": {},
 
806
  "source": [
807
  "# Cell 11: Training loop\n",
808
  "from trl import SFTTrainer, SFTConfig\n",
809
  "from datasets import Dataset\n",
810
  "\n",
811
+ "NUM_ROUNDS = 2\n",
812
  "EPISODES_PER_ROUND = 6\n",
813
+ "QUALITY_FLOOR = 0.0 # 0 = always run SFT on positive-advantage samples\n",
814
  "\n",
815
  "training_log = {\n",
816
  " \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n",
 
907
  "elapsed = time.time() - t_start\n",
908
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
909
  "print(pd.DataFrame(training_log).to_string(index=False))"
910
+ ],
911
+ "execution_count": null,
912
+ "outputs": []
913
  },
914
  {
915
  "cell_type": "markdown",
 
922
  },
923
  {
924
  "cell_type": "code",
 
925
  "metadata": {},
 
926
  "source": [
927
  "# Cell 12: Run trained model (batched)\n",
928
  "print(\"Running TRAINED model on all tasks (batched)...\")\n",
 
937
  "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
938
  "for t in TASKS:\n",
939
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
940
+ ],
941
+ "execution_count": null,
942
+ "outputs": []
943
  },
944
  {
945
  "cell_type": "markdown",
 
950
  },
951
  {
952
  "cell_type": "code",
 
953
  "metadata": {},
 
954
  "source": [
955
  "# Cell 13: Training curves\n",
956
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
 
972
  "fig.tight_layout()\n",
973
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
974
  "plt.show()"
975
+ ],
976
+ "execution_count": null,
977
+ "outputs": []
978
  },
979
  {
980
  "cell_type": "code",
 
981
  "metadata": {},
 
982
  "source": [
983
  "# Cell 14: Before vs After\n",
984
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
 
1008
  "fig.tight_layout()\n",
1009
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
1010
  "plt.show()"
1011
+ ],
1012
+ "execution_count": null,
1013
+ "outputs": []
1014
  },
1015
  {
1016
  "cell_type": "code",
 
1017
  "metadata": {},
 
1018
  "source": [
1019
  "# Cell 15: Trajectory comparison\n",
1020
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
 
1038
  "fig.tight_layout()\n",
1039
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
1040
  "plt.show()"
1041
+ ],
1042
+ "execution_count": null,
1043
+ "outputs": []
1044
  },
1045
  {
1046
  "cell_type": "markdown",
 
1051
  },
1052
  {
1053
  "cell_type": "code",
 
1054
  "metadata": {},
 
1055
  "source": [
1056
  "# Cell 16: Final summary\n",
1057
  "print(\"=\" * 67)\n",
 
1088
  "\n",
1089
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
1090
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
1091
+ ],
1092
+ "execution_count": null,
1093
+ "outputs": []
1094
  },
1095
  {
1096
  "cell_type": "code",
 
1097
  "metadata": {},
 
1098
  "source": [
1099
  "# Cell 17: Save adapter\n",
1100
  "save_path = \"./viraltest_trained_adapter\"\n",
 
1102
  "tokenizer.save_pretrained(save_path)\n",
1103
  "print(f\"LoRA adapter saved to {save_path}\")\n",
1104
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
1105
+ ],
1106
+ "execution_count": null,
1107
+ "outputs": []
1108
  }
1109
  ],
1110
  "metadata": {
 
1130
  },
1131
  "nbformat": 4,
1132
  "nbformat_minor": 4
1133
+ }