Spaces:
Paused
Paused
Commit ·
b1c1732
1
Parent(s): afbf541
ReAct two-pass per day so model sees current-day tool results
Browse filesPhase A (discovery): model emits tool_calls only; we dispatch them via
env._dispatch_tool (read-only) and capture results.
Phase B (planning): same observation + fresh tool results, model emits
scheduled_actions. env.step then runs those actions only (tool_calls
already executed, not double-dispatched).
Both phases logged separately to plots/io_log.jsonl with /A and /B tags
so we can verify discovery actually fires.
Made-with: Cursor
- training/train_grpo.ipynb +113 -68
training/train_grpo.ipynb
CHANGED
|
@@ -443,19 +443,17 @@
|
|
| 443 |
"\n",
|
| 444 |
"SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
|
| 445 |
"\n",
|
| 446 |
-
"
|
| 447 |
-
"
|
| 448 |
-
"
|
| 449 |
-
"
|
| 450 |
-
"
|
| 451 |
-
"
|
| 452 |
-
"
|
| 453 |
-
"
|
| 454 |
-
"
|
| 455 |
-
"
|
| 456 |
-
"
|
| 457 |
-
" {\"name\": \"predict_engagement\", \"arguments\": {\"scheduled_actions\": [...]}}\n",
|
| 458 |
-
" and any query_* whose result is missing from `Tool results`.\"\"\")\n",
|
| 459 |
"SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
|
| 460 |
"SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
|
| 461 |
"\n",
|
|
@@ -469,18 +467,18 @@
|
|
| 469 |
" signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
|
| 470 |
" f\"sends={signals.sends_per_reach:.3f} \"\n",
|
| 471 |
" f\"saves={signals.saves:.3f}\\n\")\n",
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
"\n",
|
| 485 |
"\n",
|
| 486 |
"def is_well_formed_response(text):\n",
|
|
@@ -578,49 +576,96 @@
|
|
| 578 |
" f.write(json.dumps(rec) + \"\\n\")\n",
|
| 579 |
"\n",
|
| 580 |
"\n",
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
"\n",
|
| 625 |
" GAMMA, TERMINAL_W = 0.95, 5.0\n",
|
| 626 |
" results = []\n",
|
|
|
|
| 443 |
"\n",
|
| 444 |
"SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
|
| 445 |
"\n",
|
| 446 |
+
"TWO-PHASE FLOW (each day has two turns — same observation, two responses):\n",
|
| 447 |
+
"PHASE A — DISCOVERY: respond with {\"tool_calls\": [...]} only. Tools cost nothing,\n",
|
| 448 |
+
" call as many query_* / predict_engagement / draft_review as useful. Their results\n",
|
| 449 |
+
" are dispatched immediately and shown to you in PHASE B of the SAME day.\n",
|
| 450 |
+
"PHASE B — PLANNING: respond with {\"scheduled_actions\": [...], \"notes\": \"...\"}\n",
|
| 451 |
+
" using the freshly returned Tool results.\n",
|
| 452 |
+
"Audience peak hours, segment affinities, trends, competitor schedules are NOT in\n",
|
| 453 |
+
"the observation — discover them in PHASE A. Useful PHASE-A starter set:\n",
|
| 454 |
+
" query_trends(niche), query_audience(segment_id), query_creator_pool(),\n",
|
| 455 |
+
" query_competitor(competitor_id, window_days), and on later days also\n",
|
| 456 |
+
" predict_engagement(scheduled_actions=[...candidate plan...]).\"\"\")\n",
|
|
|
|
|
|
|
| 457 |
"SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
|
| 458 |
"SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
|
| 459 |
"\n",
|
|
|
|
| 467 |
" signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
|
| 468 |
" f\"sends={signals.sends_per_reach:.3f} \"\n",
|
| 469 |
" f\"saves={signals.saves:.3f}\\n\")\n",
|
| 470 |
+
" tool_str = \"\"\n",
|
| 471 |
+
" for tr in getattr(obs, \"tool_results\", []):\n",
|
| 472 |
+
" if tr.success:\n",
|
| 473 |
+
" tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
|
| 474 |
+
" if not tool_str:\n",
|
| 475 |
+
" tool_str = \" (none — call query_* tools to discover)\\n\"\n",
|
| 476 |
+
" return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
|
| 477 |
+
" f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
|
| 478 |
+
" f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
|
| 479 |
+
" f\"{signals_str}\"\n",
|
| 480 |
+
" f\"Tool results:\\n{tool_str}\"\n",
|
| 481 |
+
" f\"Plan today's actions (JSON only):\")\n",
|
| 482 |
"\n",
|
| 483 |
"\n",
|
| 484 |
"def is_well_formed_response(text):\n",
|
|
|
|
| 576 |
" f.write(json.dumps(rec) + \"\\n\")\n",
|
| 577 |
"\n",
|
| 578 |
"\n",
|
| 579 |
+
"DISCOVERY_SUFFIX = \"\\n\\nPHASE A (DISCOVERY): respond with JSON {\\\"tool_calls\\\": [...]} only.\"\n",
|
| 580 |
+
"PLANNING_SUFFIX = \"\\n\\nPHASE B (PLANNING): respond with JSON {\\\"scheduled_actions\\\": [...], \\\"notes\\\": \\\"...\\\"} using the fresh Tool results above.\"\n",
|
| 581 |
+
"\n",
|
| 582 |
+
"\n",
|
| 583 |
+
"def _parse_tool_calls_only(text):\n",
|
| 584 |
+
" return parse_model_output(text).tool_calls\n",
|
| 585 |
+
"\n",
|
| 586 |
+
"\n",
|
| 587 |
+
"def _parse_actions_only(text):\n",
|
| 588 |
+
" a = parse_model_output(text)\n",
|
| 589 |
+
" return ViraltestAction(tool_calls=[], scheduled_actions=a.scheduled_actions, notes=a.notes)\n",
|
| 590 |
+
"\n",
|
| 591 |
+
"\n",
|
| 592 |
+
"def _format_fresh_results(fresh):\n",
|
| 593 |
+
" if not fresh:\n",
|
| 594 |
+
" return \"\"\n",
|
| 595 |
+
" out = \"Fresh tool results (PHASE A):\\n\"\n",
|
| 596 |
+
" for tr in fresh:\n",
|
| 597 |
+
" if tr.success:\n",
|
| 598 |
+
" out += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
|
| 599 |
+
" else:\n",
|
| 600 |
+
" out += f\" {tr.name}: ERROR {tr.error}\\n\"\n",
|
| 601 |
+
" return out\n",
|
| 602 |
+
"\n",
|
| 603 |
+
"\n",
|
| 604 |
+
"def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None, log_tag=None):\n",
|
| 605 |
+
" \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n",
|
| 606 |
+
" sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
|
| 607 |
+
" n = len(tasks_seeds)\n",
|
| 608 |
+
" envs = [ViraltestEnvironment() for _ in range(n)]\n",
|
| 609 |
+
" obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
|
| 610 |
+
" rewards = [[] for _ in range(n)]\n",
|
| 611 |
+
" energies = [[obs.creator_energy] for obs in obss]\n",
|
| 612 |
+
" pairs = [[] for _ in range(n)]\n",
|
| 613 |
+
" done_mask = [obs.done for obs in obss]\n",
|
| 614 |
+
" rest_action = ViraltestAction(scheduled_actions=[])\n",
|
| 615 |
+
"\n",
|
| 616 |
+
" def _gen(prompts):\n",
|
| 617 |
+
" chats = [_build_chat(sys_prompt, p) for p in prompts]\n",
|
| 618 |
+
" texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
|
| 619 |
+
" return _batched_generate(mdl, tok, texts, eval=eval)\n",
|
| 620 |
+
"\n",
|
| 621 |
+
" for day in range(1, TASK_HORIZON + 1):\n",
|
| 622 |
+
" active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
|
| 623 |
+
" rest = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy <= 0.25]\n",
|
| 624 |
+
" if not active and not rest:\n",
|
| 625 |
+
" break\n",
|
| 626 |
+
"\n",
|
| 627 |
+
" actions_by_idx = {i: rest_action for i in rest}\n",
|
| 628 |
+
" if active:\n",
|
| 629 |
+
" base_prompts = [format_obs(obss[i]) for i in active]\n",
|
| 630 |
+
"\n",
|
| 631 |
+
" disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n",
|
| 632 |
+
" disc_resps, ptok = _gen(disc_prompts)\n",
|
| 633 |
+
" if verbose:\n",
|
| 634 |
+
" print(f\" D{day:2d}A: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
|
| 635 |
+
"\n",
|
| 636 |
+
" fresh_per_active = []\n",
|
| 637 |
+
" for j, i in enumerate(active):\n",
|
| 638 |
+
" tcs = _parse_tool_calls_only(disc_resps[j])\n",
|
| 639 |
+
" fresh_per_active.append([envs[i]._dispatch_tool(tc) for tc in tcs])\n",
|
| 640 |
+
" pairs[i].append({\"prompt\": disc_prompts[j], \"response\": disc_resps[j],\n",
|
| 641 |
+
" \"step\": len(rewards[i]), \"phase\": \"A\"})\n",
|
| 642 |
+
" if log_tag is not None:\n",
|
| 643 |
+
" t, s = tasks_seeds[i]\n",
|
| 644 |
+
" _log_io(f\"{log_tag}/A\", i, day, t, s, disc_prompts[j], disc_resps[j])\n",
|
| 645 |
+
"\n",
|
| 646 |
+
" plan_prompts = [base_prompts[j] + \"\\n\" + _format_fresh_results(fresh_per_active[j]) + PLANNING_SUFFIX\n",
|
| 647 |
+
" for j in range(len(active))]\n",
|
| 648 |
+
" plan_resps, ptok2 = _gen(plan_prompts)\n",
|
| 649 |
+
" if verbose:\n",
|
| 650 |
+
" print(f\" D{day:2d}B: batch={len(active)} prompt_tok={ptok2}\")\n",
|
| 651 |
+
"\n",
|
| 652 |
+
" for j, i in enumerate(active):\n",
|
| 653 |
+
" actions_by_idx[i] = _parse_actions_only(plan_resps[j])\n",
|
| 654 |
+
" pairs[i].append({\"prompt\": plan_prompts[j], \"response\": plan_resps[j],\n",
|
| 655 |
+
" \"step\": len(rewards[i]), \"phase\": \"B\"})\n",
|
| 656 |
+
" if log_tag is not None:\n",
|
| 657 |
+
" t, s = tasks_seeds[i]\n",
|
| 658 |
+
" _log_io(f\"{log_tag}/B\", i, day, t, s, plan_prompts[j], plan_resps[j])\n",
|
| 659 |
+
"\n",
|
| 660 |
+
" for i in range(n):\n",
|
| 661 |
+
" if done_mask[i] or i not in actions_by_idx:\n",
|
| 662 |
+
" continue\n",
|
| 663 |
+
" obss[i] = envs[i].step(actions_by_idx[i])\n",
|
| 664 |
+
" r = obss[i].reward or 0.0\n",
|
| 665 |
+
" rewards[i].append(r)\n",
|
| 666 |
+
" energies[i].append(obss[i].creator_energy)\n",
|
| 667 |
+
" if obss[i].done:\n",
|
| 668 |
+
" done_mask[i] = True\n",
|
| 669 |
"\n",
|
| 670 |
" GAMMA, TERMINAL_W = 0.95, 5.0\n",
|
| 671 |
" results = []\n",
|