vaibhav12332112312 committed on
Commit
b1c1732
·
1 Parent(s): afbf541

ReAct two-pass per day so model sees current-day tool results

Browse files

Phase A (discovery): model emits tool_calls only; we dispatch them via
env._dispatch_tool (read-only) and capture results.
Phase B (planning): same observation + fresh tool results, model emits
scheduled_actions. env.step then runs those actions only (tool_calls
already executed, not double-dispatched).

Both phases logged separately to plots/io_log.jsonl with /A and /B tags
so we can verify discovery actually fires.

Made-with: Cursor

Files changed (1) hide show
  1. training/train_grpo.ipynb +113 -68
training/train_grpo.ipynb CHANGED
@@ -443,19 +443,17 @@
443
  "\n",
444
  "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
445
  "\n",
446
- "TOOL POLICY (tool_calls cost nothing — call them aggressively):\n",
447
- "- The observation tells you ONLY your account stats. Audience peak hours, segment\n",
448
- " affinities, trending topics/tags and competitor schedules are NOT given. You must\n",
449
- " discover them via tool_calls and read them from `Tool results` next turn.\n",
450
- "- days_elapsed == 0 -> call EVERY discovery tool you might need, e.g.:\n",
451
- " {\"name\": \"query_trends\", \"arguments\": {\"niche\": \"<TOPIC_CATEGORIES key>\"}}\n",
452
- " {\"name\": \"query_audience\", \"arguments\": {\"segment_id\": \"young_professionals\"}}\n",
453
- " {\"name\": \"query_audience\", \"arguments\": {\"segment_id\": \"students\"}}\n",
454
- " {\"name\": \"query_creator_pool\", \"arguments\": {}}\n",
455
- " {\"name\": \"query_competitor\", \"arguments\": {\"competitor_id\": \"niche_expert\", \"window_days\": 7}}\n",
456
- "- days_elapsed >= 1 -> before scheduling posts, call:\n",
457
- " {\"name\": \"predict_engagement\", \"arguments\": {\"scheduled_actions\": [...]}}\n",
458
- " and any query_* whose result is missing from `Tool results`.\"\"\")\n",
459
  "SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
460
  "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
461
  "\n",
@@ -469,18 +467,18 @@
469
  " signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
470
  " f\"sends={signals.sends_per_reach:.3f} \"\n",
471
  " f\"saves={signals.saves:.3f}\\n\")\n",
472
- " tool_str = \"\"\n",
473
- " for tr in getattr(obs, \"tool_results\", []):\n",
474
- " if tr.success:\n",
475
- " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
476
- " if not tool_str:\n",
477
- " tool_str = \" (none — call query_* tools to discover)\\n\"\n",
478
- " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
479
- " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
480
- " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
481
- " f\"{signals_str}\"\n",
482
- " f\"Tool results:\\n{tool_str}\"\n",
483
- " f\"Plan today's actions (JSON only):\")\n",
484
  "\n",
485
  "\n",
486
  "def is_well_formed_response(text):\n",
@@ -578,49 +576,96 @@
578
  " f.write(json.dumps(rec) + \"\\n\")\n",
579
  "\n",
580
  "\n",
581
- "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None, log_tag=None):\n",
582
- " \"\"\"Run N episodes in parallel. tasks_seeds: list of (task, seed). One batched generate per day.\"\"\"\n",
583
- " sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
584
- " n = len(tasks_seeds)\n",
585
- " envs = [ViraltestEnvironment() for _ in range(n)]\n",
586
- " obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
587
- " rewards = [[] for _ in range(n)]\n",
588
- " energies = [[obs.creator_energy] for obs in obss]\n",
589
- " pairs = [[] for _ in range(n)]\n",
590
- " done_mask = [obs.done for obs in obss]\n",
591
- " rest_action = ViraltestAction(scheduled_actions=[])\n",
592
- "\n",
593
- " for day in range(1, TASK_HORIZON + 1):\n",
594
- " active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
595
- " rest = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy <= 0.25]\n",
596
- " if not active and not rest:\n",
597
- " break\n",
598
- "\n",
599
- " actions_by_idx = {i: rest_action for i in rest}\n",
600
- " if active:\n",
601
- " prompts = [format_obs(obss[i]) for i in active]\n",
602
- " chats = [_build_chat(sys_prompt, p) for p in prompts]\n",
603
- " texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
604
- " resps, ptok = _batched_generate(mdl, tok, texts, eval=eval)\n",
605
- " if verbose:\n",
606
- " print(f\" D{day:2d}: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
607
- " for j, i in enumerate(active):\n",
608
- " actions_by_idx[i] = parse_model_output(resps[j])\n",
609
- " pairs[i].append({\"prompt\": prompts[j], \"response\": resps[j],\n",
610
- " \"step\": len(rewards[i])})\n",
611
- " if log_tag is not None:\n",
612
- " t, s = tasks_seeds[i]\n",
613
- " _log_io(log_tag, i, day, t, s, prompts[j], resps[j])\n",
614
- "\n",
615
- " for i in range(n):\n",
616
- " if done_mask[i] or i not in actions_by_idx:\n",
617
- " continue\n",
618
- " obss[i] = envs[i].step(actions_by_idx[i])\n",
619
- " r = obss[i].reward or 0.0\n",
620
- " rewards[i].append(r)\n",
621
- " energies[i].append(obss[i].creator_energy)\n",
622
- " if obss[i].done:\n",
623
- " done_mask[i] = True\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
  "\n",
625
  " GAMMA, TERMINAL_W = 0.95, 5.0\n",
626
  " results = []\n",
 
443
  "\n",
444
  "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
445
  "\n",
446
+ "TWO-PHASE FLOW (each day has two turns — same observation, two responses):\n",
447
+ "PHASE A DISCOVERY: respond with {\"tool_calls\": [...]} only. Tools cost nothing,\n",
448
+ " call as many query_* / predict_engagement / draft_review as useful. Their results\n",
449
+ " are dispatched immediately and shown to you in PHASE B of the SAME day.\n",
450
+ "PHASE B PLANNING: respond with {\"scheduled_actions\": [...], \"notes\": \"...\"}\n",
451
+ " using the freshly returned Tool results.\n",
452
+ "Audience peak hours, segment affinities, trends, competitor schedules are NOT in\n",
453
+ "the observation — discover them in PHASE A. Useful PHASE-A starter set:\n",
454
+ " query_trends(niche), query_audience(segment_id), query_creator_pool(),\n",
455
+ " query_competitor(competitor_id, window_days), and on later days also\n",
456
+ " predict_engagement(scheduled_actions=[...candidate plan...]).\"\"\")\n",
 
 
457
  "SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
458
  "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
459
  "\n",
 
467
  " signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
468
  " f\"sends={signals.sends_per_reach:.3f} \"\n",
469
  " f\"saves={signals.saves:.3f}\\n\")\n",
470
+ " tool_str = \"\"\n",
471
+ " for tr in getattr(obs, \"tool_results\", []):\n",
472
+ " if tr.success:\n",
473
+ " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
474
+ " if not tool_str:\n",
475
+ " tool_str = \" (none — call query_* tools to discover)\\n\"\n",
476
+ " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
477
+ " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
478
+ " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
479
+ " f\"{signals_str}\"\n",
480
+ " f\"Tool results:\\n{tool_str}\"\n",
481
+ " f\"Plan today's actions (JSON only):\")\n",
482
  "\n",
483
  "\n",
484
  "def is_well_formed_response(text):\n",
 
576
  " f.write(json.dumps(rec) + \"\\n\")\n",
577
  "\n",
578
  "\n",
579
+ "DISCOVERY_SUFFIX = \"\\n\\nPHASE A (DISCOVERY): respond with JSON {\\\"tool_calls\\\": [...]} only.\"\n",
580
+ "PLANNING_SUFFIX = \"\\n\\nPHASE B (PLANNING): respond with JSON {\\\"scheduled_actions\\\": [...], \\\"notes\\\": \\\"...\\\"} using the fresh Tool results above.\"\n",
581
+ "\n",
582
+ "\n",
583
+ "def _parse_tool_calls_only(text):\n",
584
+ " return parse_model_output(text).tool_calls\n",
585
+ "\n",
586
+ "\n",
587
+ "def _parse_actions_only(text):\n",
588
+ " a = parse_model_output(text)\n",
589
+ " return ViraltestAction(tool_calls=[], scheduled_actions=a.scheduled_actions, notes=a.notes)\n",
590
+ "\n",
591
+ "\n",
592
+ "def _format_fresh_results(fresh):\n",
593
+ " if not fresh:\n",
594
+ " return \"\"\n",
595
+ " out = \"Fresh tool results (PHASE A):\\n\"\n",
596
+ " for tr in fresh:\n",
597
+ " if tr.success:\n",
598
+ " out += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
599
+ " else:\n",
600
+ " out += f\" {tr.name}: ERROR {tr.error}\\n\"\n",
601
+ " return out\n",
602
+ "\n",
603
+ "\n",
604
+ "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None, log_tag=None):\n",
605
+ " \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n",
606
+ " sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
607
+ " n = len(tasks_seeds)\n",
608
+ " envs = [ViraltestEnvironment() for _ in range(n)]\n",
609
+ " obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
610
+ " rewards = [[] for _ in range(n)]\n",
611
+ " energies = [[obs.creator_energy] for obs in obss]\n",
612
+ " pairs = [[] for _ in range(n)]\n",
613
+ " done_mask = [obs.done for obs in obss]\n",
614
+ " rest_action = ViraltestAction(scheduled_actions=[])\n",
615
+ "\n",
616
+ " def _gen(prompts):\n",
617
+ " chats = [_build_chat(sys_prompt, p) for p in prompts]\n",
618
+ " texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
619
+ " return _batched_generate(mdl, tok, texts, eval=eval)\n",
620
+ "\n",
621
+ " for day in range(1, TASK_HORIZON + 1):\n",
622
+ " active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
623
+ " rest = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy <= 0.25]\n",
624
+ " if not active and not rest:\n",
625
+ " break\n",
626
+ "\n",
627
+ " actions_by_idx = {i: rest_action for i in rest}\n",
628
+ " if active:\n",
629
+ " base_prompts = [format_obs(obss[i]) for i in active]\n",
630
+ "\n",
631
+ " disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n",
632
+ " disc_resps, ptok = _gen(disc_prompts)\n",
633
+ " if verbose:\n",
634
+ " print(f\" D{day:2d}A: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
635
+ "\n",
636
+ " fresh_per_active = []\n",
637
+ " for j, i in enumerate(active):\n",
638
+ " tcs = _parse_tool_calls_only(disc_resps[j])\n",
639
+ " fresh_per_active.append([envs[i]._dispatch_tool(tc) for tc in tcs])\n",
640
+ " pairs[i].append({\"prompt\": disc_prompts[j], \"response\": disc_resps[j],\n",
641
+ " \"step\": len(rewards[i]), \"phase\": \"A\"})\n",
642
+ " if log_tag is not None:\n",
643
+ " t, s = tasks_seeds[i]\n",
644
+ " _log_io(f\"{log_tag}/A\", i, day, t, s, disc_prompts[j], disc_resps[j])\n",
645
+ "\n",
646
+ " plan_prompts = [base_prompts[j] + \"\\n\" + _format_fresh_results(fresh_per_active[j]) + PLANNING_SUFFIX\n",
647
+ " for j in range(len(active))]\n",
648
+ " plan_resps, ptok2 = _gen(plan_prompts)\n",
649
+ " if verbose:\n",
650
+ " print(f\" D{day:2d}B: batch={len(active)} prompt_tok={ptok2}\")\n",
651
+ "\n",
652
+ " for j, i in enumerate(active):\n",
653
+ " actions_by_idx[i] = _parse_actions_only(plan_resps[j])\n",
654
+ " pairs[i].append({\"prompt\": plan_prompts[j], \"response\": plan_resps[j],\n",
655
+ " \"step\": len(rewards[i]), \"phase\": \"B\"})\n",
656
+ " if log_tag is not None:\n",
657
+ " t, s = tasks_seeds[i]\n",
658
+ " _log_io(f\"{log_tag}/B\", i, day, t, s, plan_prompts[j], plan_resps[j])\n",
659
+ "\n",
660
+ " for i in range(n):\n",
661
+ " if done_mask[i] or i not in actions_by_idx:\n",
662
+ " continue\n",
663
+ " obss[i] = envs[i].step(actions_by_idx[i])\n",
664
+ " r = obss[i].reward or 0.0\n",
665
+ " rewards[i].append(r)\n",
666
+ " energies[i].append(obss[i].creator_energy)\n",
667
+ " if obss[i].done:\n",
668
+ " done_mask[i] = True\n",
669
  "\n",
670
  " GAMMA, TERMINAL_W = 0.95, 5.0\n",
671
  " results = []\n",