Spaces:
Paused
Paused
Commit ·
a6b8df0
1
Parent(s): 81cdb34
train: batched parallel rollouts on Qwen2.5-3B + parser hardening
Browse files- Qwen2.5-3B-Instruct in bf16 + flash-attn-2 (sdpa fallback)
- New run_llm_episodes_batched: N parallel envs, one batched generate per day
(~10x faster rollouts than sequential)
- parse_model_output: per-tool-call try/except so a malformed `arguments` no
longer wipes the whole action (root cause of post-train follower collapse)
- is_well_formed_response filter on SFT data
- SFT: max_length=4096, batch=4 x accum=2, bf16
- Per-step credit assignment for SFT sample weights
Made-with: Cursor
- training/train_grpo.ipynb +251 -203
training/train_grpo.ipynb
CHANGED
|
@@ -25,23 +25,22 @@
|
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
| 28 |
-
"execution_count": null,
|
| 29 |
"metadata": {},
|
| 30 |
-
"outputs": [],
|
| 31 |
"source": [
|
| 32 |
"# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
|
| 33 |
"!pip install -q torch torchvision torchaudio\n",
|
| 34 |
-
"!pip install -q \"transformers>=4.45.0\" \"accelerate\" \"peft>=0.10.0\" \"trl>=0.20.0\" \"datasets\"
|
| 35 |
"!pip install -q matplotlib pandas\n",
|
| 36 |
"!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
|
| 37 |
-
"!pip install -q \"openenv-core[core]>=0.2.2\""
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
| 39 |
},
|
| 40 |
{
|
| 41 |
"cell_type": "code",
|
| 42 |
-
"execution_count": null,
|
| 43 |
"metadata": {},
|
| 44 |
-
"outputs": [],
|
| 45 |
"source": [
|
| 46 |
"# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
|
| 47 |
"import os\n",
|
|
@@ -117,13 +116,13 @@
|
|
| 117 |
"print(f\"Branch: {REPO_BRANCH}\")\n",
|
| 118 |
"print(f\"Commit: {commit}\")\n",
|
| 119 |
"print(f\"Plots dir: {PLOTS_DIR}\")"
|
| 120 |
-
]
|
|
|
|
|
|
|
| 121 |
},
|
| 122 |
{
|
| 123 |
"cell_type": "code",
|
| 124 |
-
"execution_count": null,
|
| 125 |
"metadata": {},
|
| 126 |
-
"outputs": [],
|
| 127 |
"source": [
|
| 128 |
"# Cell 3: Imports (with runtime validation)\n",
|
| 129 |
"import json, random, time, textwrap, copy, os, sys\n",
|
|
@@ -177,7 +176,9 @@
|
|
| 177 |
"import ast\n",
|
| 178 |
"ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
|
| 179 |
"print(\"OK: ast.parse (syntax check)\")"
|
| 180 |
-
]
|
|
|
|
|
|
|
| 181 |
},
|
| 182 |
{
|
| 183 |
"cell_type": "markdown",
|
|
@@ -190,9 +191,7 @@
|
|
| 190 |
},
|
| 191 |
{
|
| 192 |
"cell_type": "code",
|
| 193 |
-
"execution_count": null,
|
| 194 |
"metadata": {},
|
| 195 |
-
"outputs": [],
|
| 196 |
"source": [
|
| 197 |
"# Cell 4: Define heuristic agents + episode runner\n",
|
| 198 |
"_rng = random.Random(42)\n",
|
|
@@ -269,13 +268,13 @@
|
|
| 269 |
" \"rewards\": rewards, \"energies\": energies}\n",
|
| 270 |
"\n",
|
| 271 |
"print(\"Agents and episode runner defined.\")"
|
| 272 |
-
]
|
|
|
|
|
|
|
| 273 |
},
|
| 274 |
{
|
| 275 |
"cell_type": "code",
|
| 276 |
-
"execution_count": null,
|
| 277 |
"metadata": {},
|
| 278 |
-
"outputs": [],
|
| 279 |
"source": [
|
| 280 |
"# Cell 5: Run baselines (safe)\n",
|
| 281 |
"print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
|
|
@@ -310,13 +309,13 @@
|
|
| 310 |
"for name in BASELINE_AGENTS:\n",
|
| 311 |
" scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
|
| 312 |
" print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
|
| 313 |
-
]
|
|
|
|
|
|
|
| 314 |
},
|
| 315 |
{
|
| 316 |
"cell_type": "code",
|
| 317 |
-
"execution_count": null,
|
| 318 |
"metadata": {},
|
| 319 |
-
"outputs": [],
|
| 320 |
"source": [
|
| 321 |
"# Cell 6: Baseline plots\n",
|
| 322 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
|
|
@@ -334,7 +333,9 @@
|
|
| 334 |
"fig.tight_layout()\n",
|
| 335 |
"fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
|
| 336 |
"plt.show()"
|
| 337 |
-
]
|
|
|
|
|
|
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"cell_type": "markdown",
|
|
@@ -347,80 +348,57 @@
|
|
| 347 |
},
|
| 348 |
{
|
| 349 |
"cell_type": "code",
|
| 350 |
-
"execution_count": null,
|
| 351 |
"metadata": {},
|
| 352 |
-
"outputs": [],
|
| 353 |
"source": [
|
| 354 |
-
"# Cell 7: Load model (
|
| 355 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 356 |
"\n",
|
| 357 |
-
"MODEL_NAME = \"Qwen/Qwen2.5-
|
| 358 |
"\n",
|
| 359 |
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
|
|
|
|
|
|
|
|
|
|
| 360 |
"\n",
|
| 361 |
-
"
|
| 362 |
-
"
|
| 363 |
-
"
|
| 364 |
-
"
|
| 365 |
-
"
|
| 366 |
-
"
|
| 367 |
-
"
|
| 368 |
-
"
|
| 369 |
-
"
|
| 370 |
-
"
|
| 371 |
-
"\n",
|
| 372 |
-
"if
|
| 373 |
-
"
|
| 374 |
-
"
|
| 375 |
-
"\n",
|
| 376 |
-
"if _use_4bit:\n",
|
| 377 |
-
" print(f\"Loading {MODEL_NAME} (4-bit quantized, CUDA)...\")\n",
|
| 378 |
-
" bnb_config = BitsAndBytesConfig(\n",
|
| 379 |
-
" load_in_4bit=True,\n",
|
| 380 |
-
" bnb_4bit_quant_type=\"nf4\",\n",
|
| 381 |
-
" bnb_4bit_compute_dtype=torch.float16,\n",
|
| 382 |
-
" bnb_4bit_use_double_quant=True,\n",
|
| 383 |
-
" )\n",
|
| 384 |
-
" model = AutoModelForCausalLM.from_pretrained(\n",
|
| 385 |
-
" MODEL_NAME,\n",
|
| 386 |
-
" trust_remote_code=True,\n",
|
| 387 |
-
" quantization_config=bnb_config,\n",
|
| 388 |
-
" device_map=\"auto\",\n",
|
| 389 |
-
" )\n",
|
| 390 |
"else:\n",
|
| 391 |
-
"
|
| 392 |
-
"
|
| 393 |
-
"
|
| 394 |
-
"
|
| 395 |
-
"
|
| 396 |
-
"
|
| 397 |
-
"
|
| 398 |
-
"
|
| 399 |
-
"
|
| 400 |
-
"
|
| 401 |
-
"
|
| 402 |
-
" )\n",
|
| 403 |
-
" if not torch.cuda.is_available():\n",
|
| 404 |
-
" if getattr(torch.backends, \"mps\", None) and torch.backends.mps.is_available():\n",
|
| 405 |
-
" model = model.to(\"mps\")\n",
|
| 406 |
-
" else:\n",
|
| 407 |
-
" model = model.to(\"cpu\")\n",
|
| 408 |
"\n",
|
| 409 |
"model.eval()\n",
|
| 410 |
-
"print(f\"Model loaded. dtype={next(model.parameters()).dtype}\")\n",
|
| 411 |
-
"try:\n",
|
| 412 |
-
" print(f\"Device: {model.device}\")\n",
|
| 413 |
-
"except Exception:\n",
|
| 414 |
-
" print(\"Device: (see first parameter device)\")\n",
|
| 415 |
"if torch.cuda.is_available():\n",
|
| 416 |
" print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
|
| 417 |
-
]
|
|
|
|
|
|
|
| 418 |
},
|
| 419 |
{
|
| 420 |
"cell_type": "code",
|
| 421 |
-
"execution_count": null,
|
| 422 |
"metadata": {},
|
| 423 |
-
"outputs": [],
|
| 424 |
"source": [
|
| 425 |
"# Cell 8: LLM agent functions\n",
|
| 426 |
"SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
|
|
@@ -468,6 +446,21 @@
|
|
| 468 |
" f\"Plan your actions (JSON only):\")\n",
|
| 469 |
"\n",
|
| 470 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
"def parse_model_output(text):\n",
|
| 472 |
" text = text.strip()\n",
|
| 473 |
" if \"```\" in text:\n",
|
|
@@ -478,24 +471,33 @@
|
|
| 478 |
" text = text[start:end]\n",
|
| 479 |
" try:\n",
|
| 480 |
" data = json.loads(text)\n",
|
| 481 |
-
" tool_calls = [ToolCall(name=tc[\"name\"], arguments=tc.get(\"arguments\", {}))\n",
|
| 482 |
-
" for tc in data.get(\"tool_calls\", []) if isinstance(tc, dict) and \"name\" in tc]\n",
|
| 483 |
-
" scheduled = []\n",
|
| 484 |
-
" for a in data.get(\"scheduled_actions\", []):\n",
|
| 485 |
-
" try:\n",
|
| 486 |
-
" scheduled.append(ScheduledAction(**a))\n",
|
| 487 |
-
" except Exception:\n",
|
| 488 |
-
" # Same as original bare `except:`: skip invalid scheduled_actions entries\n",
|
| 489 |
-
" pass\n",
|
| 490 |
-
" return ViraltestAction(\n",
|
| 491 |
-
" tool_calls=tool_calls,\n",
|
| 492 |
-
" scheduled_actions=scheduled,\n",
|
| 493 |
-
" replies=data.get(\"replies\", []),\n",
|
| 494 |
-
" notes=data.get(\"notes\"),\n",
|
| 495 |
-
" )\n",
|
| 496 |
" except Exception:\n",
|
| 497 |
-
" # Same behavior as original bare `except:`: any parse/validation failure -> empty action\n",
|
| 498 |
" return ViraltestAction(scheduled_actions=[])\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
"\n",
|
| 500 |
"\n",
|
| 501 |
"def _infer_model_device(m):\n",
|
|
@@ -509,53 +511,101 @@
|
|
| 509 |
" return torch.device(\"cpu\")\n",
|
| 510 |
"\n",
|
| 511 |
"\n",
|
| 512 |
-
"def
|
| 513 |
-
"
|
| 514 |
-
"
|
| 515 |
-
"
|
| 516 |
-
"
|
| 517 |
-
"
|
| 518 |
-
"
|
|
|
|
|
|
|
| 519 |
" with torch.no_grad():\n",
|
| 520 |
-
" out = mdl.generate(
|
| 521 |
-
"
|
| 522 |
-
"
|
| 523 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
"\n",
|
| 525 |
"\n",
|
| 526 |
"def run_llm_episode(mdl, tok, task, seed=42, verbose=False):\n",
|
| 527 |
-
"
|
| 528 |
-
" obs = env.reset(task=task, seed=seed)\n",
|
| 529 |
-
" rewards, energies = [], [obs.creator_energy]\n",
|
| 530 |
-
" history, pairs = [], []\n",
|
| 531 |
-
" for day in range(1, TASK_HORIZON + 1):\n",
|
| 532 |
-
" if obs.done: break\n",
|
| 533 |
-
" if obs.creator_energy <= 0.25:\n",
|
| 534 |
-
" action = ViraltestAction(scheduled_actions=[])\n",
|
| 535 |
-
" resp = '{\"scheduled_actions\": []}'\n",
|
| 536 |
-
" else:\n",
|
| 537 |
-
" resp, action = generate_action(mdl, tok, obs, history)\n",
|
| 538 |
-
" prompt = format_obs(obs)\n",
|
| 539 |
-
" pairs.append({\"prompt\": prompt, \"response\": resp})\n",
|
| 540 |
-
" obs = env.step(action)\n",
|
| 541 |
-
" r = obs.reward or 0.0\n",
|
| 542 |
-
" rewards.append(r)\n",
|
| 543 |
-
" energies.append(obs.creator_energy)\n",
|
| 544 |
-
" history.extend([{\"role\": \"user\", \"content\": prompt},\n",
|
| 545 |
-
" {\"role\": \"assistant\", \"content\": resp}])\n",
|
| 546 |
-
" if verbose:\n",
|
| 547 |
-
" n_p = len([s for s in action.scheduled_actions if s.action_type==\"post\"])\n",
|
| 548 |
-
" print(f\" Day {day:2d}: r={r:.4f} e={obs.creator_energy:.2f} posts={n_p} tools={len(action.tool_calls)}\")\n",
|
| 549 |
-
" if obs.done: break\n",
|
| 550 |
-
" gs = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
|
| 551 |
-
" return {\"task\": task, \"grader_score\": gs, \"total_reward\": sum(rewards),\n",
|
| 552 |
-
" \"final_energy\": obs.creator_energy, \"rewards\": rewards,\n",
|
| 553 |
-
" \"energies\": energies, \"pairs\": pairs,\n",
|
| 554 |
-
" \"follower_delta\": obs.follower_count - 10000,\n",
|
| 555 |
-
" \"burned_out\": obs.creator_energy <= 0}\n",
|
| 556 |
"\n",
|
| 557 |
-
"
|
| 558 |
-
|
|
|
|
|
|
|
|
|
|
| 559 |
},
|
| 560 |
{
|
| 561 |
"cell_type": "markdown",
|
|
@@ -568,26 +618,23 @@
|
|
| 568 |
},
|
| 569 |
{
|
| 570 |
"cell_type": "code",
|
| 571 |
-
"execution_count": null,
|
| 572 |
"metadata": {},
|
| 573 |
-
"outputs": [],
|
| 574 |
"source": [
|
| 575 |
-
"# Cell 9: Run untrained model\n",
|
| 576 |
-
"print(\"Running UNTRAINED base model on all tasks...\")\n",
|
| 577 |
"print(\"=\" * 60)\n",
|
| 578 |
"\n",
|
| 579 |
-
"
|
| 580 |
-
"for
|
| 581 |
-
"
|
| 582 |
-
" result = run_llm_episode(model, tokenizer, task, seed=42, verbose=True)\n",
|
| 583 |
-
" before_results[task] = result\n",
|
| 584 |
-
" print(f\" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f}\")\n",
|
| 585 |
"\n",
|
| 586 |
"print(\"\\n\" + \"=\" * 60)\n",
|
| 587 |
-
"print(\"BEFORE TRAINING:\")\n",
|
| 588 |
"for t in TASKS:\n",
|
| 589 |
" print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
|
| 590 |
-
]
|
|
|
|
|
|
|
| 591 |
},
|
| 592 |
{
|
| 593 |
"cell_type": "markdown",
|
|
@@ -606,9 +653,7 @@
|
|
| 606 |
},
|
| 607 |
{
|
| 608 |
"cell_type": "code",
|
| 609 |
-
"execution_count": null,
|
| 610 |
"metadata": {},
|
| 611 |
-
"outputs": [],
|
| 612 |
"source": [
|
| 613 |
"# Cell 10: Attach LoRA adapter\n",
|
| 614 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
|
@@ -623,13 +668,13 @@
|
|
| 623 |
"model.enable_input_require_grads()\n",
|
| 624 |
"peft_model = get_peft_model(model, lora_config)\n",
|
| 625 |
"peft_model.print_trainable_parameters()"
|
| 626 |
-
]
|
|
|
|
|
|
|
| 627 |
},
|
| 628 |
{
|
| 629 |
"cell_type": "code",
|
| 630 |
-
"execution_count": null,
|
| 631 |
"metadata": {},
|
| 632 |
-
"outputs": [],
|
| 633 |
"source": [
|
| 634 |
"# Cell 11: Training loop\n",
|
| 635 |
"from trl import SFTTrainer, SFTConfig\n",
|
|
@@ -652,35 +697,39 @@
|
|
| 652 |
" print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n",
|
| 653 |
" print(f\"{'=' * 60}\")\n",
|
| 654 |
"\n",
|
| 655 |
-
" # Collect episodes\n",
|
| 656 |
" peft_model.eval()\n",
|
| 657 |
-
"
|
|
|
|
|
|
|
|
|
|
| 658 |
"\n",
|
| 659 |
-
"
|
| 660 |
-
"
|
| 661 |
-
" seed = 42 + (round_idx - 1) * 100 + ep\n",
|
| 662 |
-
" result = run_llm_episode(peft_model, tokenizer, task, seed=seed)\n",
|
| 663 |
" ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
|
| 664 |
" episode_rewards.append(ep_reward)\n",
|
| 665 |
" episode_graders.append(result[\"grader_score\"])\n",
|
| 666 |
-
"\n",
|
| 667 |
" for pr in result[\"pairs\"]:\n",
|
|
|
|
|
|
|
| 668 |
" text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
|
| 669 |
" f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
|
| 670 |
" f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
|
| 671 |
-
" all_pairs.append({\"text\": text, \"reward\":
|
| 672 |
-
"\n",
|
| 673 |
-
" print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {task.split('_')[-1]:>11s} \"\n",
|
| 674 |
-
" f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f}\")\n",
|
| 675 |
-
"\n",
|
| 676 |
-
" avg_r = np.mean(episode_rewards)\n",
|
| 677 |
-
" avg_g = np.mean(episode_graders)\n",
|
| 678 |
-
" print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f}\")\n",
|
|
|
|
|
|
|
|
|
|
| 679 |
"\n",
|
| 680 |
-
" # Filter to top-K\n",
|
| 681 |
" threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
|
| 682 |
" filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
|
| 683 |
-
" print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples\")\n",
|
| 684 |
"\n",
|
| 685 |
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 686 |
"\n",
|
|
@@ -688,14 +737,14 @@
|
|
| 688 |
" sft_config = SFTConfig(\n",
|
| 689 |
" output_dir=f\"./checkpoints/round_{round_idx}\",\n",
|
| 690 |
" num_train_epochs=2,\n",
|
| 691 |
-
" per_device_train_batch_size=
|
| 692 |
-
" gradient_accumulation_steps=
|
| 693 |
" learning_rate=2e-5,\n",
|
| 694 |
" warmup_steps=5,\n",
|
| 695 |
" logging_steps=5,\n",
|
| 696 |
" save_strategy=\"no\",\n",
|
| 697 |
-
" max_length=
|
| 698 |
-
"
|
| 699 |
" report_to=\"none\",\n",
|
| 700 |
" )\n",
|
| 701 |
"\n",
|
|
@@ -720,7 +769,9 @@
|
|
| 720 |
"elapsed = time.time() - t_start\n",
|
| 721 |
"print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
|
| 722 |
"print(pd.DataFrame(training_log).to_string(index=False))"
|
| 723 |
-
]
|
|
|
|
|
|
|
| 724 |
},
|
| 725 |
{
|
| 726 |
"cell_type": "markdown",
|
|
@@ -733,27 +784,24 @@
|
|
| 733 |
},
|
| 734 |
{
|
| 735 |
"cell_type": "code",
|
| 736 |
-
"execution_count": null,
|
| 737 |
"metadata": {},
|
| 738 |
-
"outputs": [],
|
| 739 |
"source": [
|
| 740 |
-
"# Cell 12: Run trained model\n",
|
| 741 |
-
"print(\"Running TRAINED model on all tasks...\")\n",
|
| 742 |
"print(\"=\" * 60)\n",
|
| 743 |
"\n",
|
| 744 |
"peft_model.eval()\n",
|
| 745 |
-
"
|
| 746 |
-
"for
|
| 747 |
-
"
|
| 748 |
-
" result = run_llm_episode(peft_model, tokenizer, task, seed=42, verbose=True)\n",
|
| 749 |
-
" after_results[task] = result\n",
|
| 750 |
-
" print(f\" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f}\")\n",
|
| 751 |
"\n",
|
| 752 |
"print(\"\\n\" + \"=\" * 60)\n",
|
| 753 |
-
"print(\"AFTER TRAINING:\")\n",
|
| 754 |
"for t in TASKS:\n",
|
| 755 |
" print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
|
| 756 |
-
]
|
|
|
|
|
|
|
| 757 |
},
|
| 758 |
{
|
| 759 |
"cell_type": "markdown",
|
|
@@ -764,9 +812,7 @@
|
|
| 764 |
},
|
| 765 |
{
|
| 766 |
"cell_type": "code",
|
| 767 |
-
"execution_count": null,
|
| 768 |
"metadata": {},
|
| 769 |
-
"outputs": [],
|
| 770 |
"source": [
|
| 771 |
"# Cell 13: Training curves\n",
|
| 772 |
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
|
@@ -788,13 +834,13 @@
|
|
| 788 |
"fig.tight_layout()\n",
|
| 789 |
"fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
|
| 790 |
"plt.show()"
|
| 791 |
-
]
|
|
|
|
|
|
|
| 792 |
},
|
| 793 |
{
|
| 794 |
"cell_type": "code",
|
| 795 |
-
"execution_count": null,
|
| 796 |
"metadata": {},
|
| 797 |
-
"outputs": [],
|
| 798 |
"source": [
|
| 799 |
"# Cell 14: Before vs After\n",
|
| 800 |
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
|
|
@@ -824,13 +870,13 @@
|
|
| 824 |
"fig.tight_layout()\n",
|
| 825 |
"fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
|
| 826 |
"plt.show()"
|
| 827 |
-
]
|
|
|
|
|
|
|
| 828 |
},
|
| 829 |
{
|
| 830 |
"cell_type": "code",
|
| 831 |
-
"execution_count": null,
|
| 832 |
"metadata": {},
|
| 833 |
-
"outputs": [],
|
| 834 |
"source": [
|
| 835 |
"# Cell 15: Trajectory comparison\n",
|
| 836 |
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
|
|
@@ -854,7 +900,9 @@
|
|
| 854 |
"fig.tight_layout()\n",
|
| 855 |
"fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
|
| 856 |
"plt.show()"
|
| 857 |
-
]
|
|
|
|
|
|
|
| 858 |
},
|
| 859 |
{
|
| 860 |
"cell_type": "markdown",
|
|
@@ -865,9 +913,7 @@
|
|
| 865 |
},
|
| 866 |
{
|
| 867 |
"cell_type": "code",
|
| 868 |
-
"execution_count": null,
|
| 869 |
"metadata": {},
|
| 870 |
-
"outputs": [],
|
| 871 |
"source": [
|
| 872 |
"# Cell 16: Final summary\n",
|
| 873 |
"print(\"=\" * 67)\n",
|
|
@@ -904,13 +950,13 @@
|
|
| 904 |
"\n",
|
| 905 |
"print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
|
| 906 |
"print(\"All results are from real LoRA weight updates on real environment runs.\")"
|
| 907 |
-
]
|
|
|
|
|
|
|
| 908 |
},
|
| 909 |
{
|
| 910 |
"cell_type": "code",
|
| 911 |
-
"execution_count": null,
|
| 912 |
"metadata": {},
|
| 913 |
-
"outputs": [],
|
| 914 |
"source": [
|
| 915 |
"# Cell 17: Save adapter\n",
|
| 916 |
"save_path = \"./viraltest_trained_adapter\"\n",
|
|
@@ -918,7 +964,9 @@
|
|
| 918 |
"tokenizer.save_pretrained(save_path)\n",
|
| 919 |
"print(f\"LoRA adapter saved to {save_path}\")\n",
|
| 920 |
"print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
|
| 921 |
-
]
|
|
|
|
|
|
|
| 922 |
}
|
| 923 |
],
|
| 924 |
"metadata": {
|
|
@@ -944,4 +992,4 @@
|
|
| 944 |
},
|
| 945 |
"nbformat": 4,
|
| 946 |
"nbformat_minor": 4
|
| 947 |
-
}
|
|
|
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
|
|
|
| 28 |
"metadata": {},
|
|
|
|
| 29 |
"source": [
|
| 30 |
"# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
|
| 31 |
"!pip install -q torch torchvision torchaudio\n",
|
| 32 |
+
"!pip install -q \"transformers>=4.45.0\" \"accelerate\" \"peft>=0.10.0\" \"trl>=0.20.0\" \"datasets\"\n",
|
| 33 |
"!pip install -q matplotlib pandas\n",
|
| 34 |
"!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
|
| 35 |
+
"!pip install -q \"openenv-core[core]>=0.2.2\"\n",
|
| 36 |
+
"!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
|
| 37 |
+
],
|
| 38 |
+
"execution_count": null,
|
| 39 |
+
"outputs": []
|
| 40 |
},
|
| 41 |
{
|
| 42 |
"cell_type": "code",
|
|
|
|
| 43 |
"metadata": {},
|
|
|
|
| 44 |
"source": [
|
| 45 |
"# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
|
| 46 |
"import os\n",
|
|
|
|
| 116 |
"print(f\"Branch: {REPO_BRANCH}\")\n",
|
| 117 |
"print(f\"Commit: {commit}\")\n",
|
| 118 |
"print(f\"Plots dir: {PLOTS_DIR}\")"
|
| 119 |
+
],
|
| 120 |
+
"execution_count": null,
|
| 121 |
+
"outputs": []
|
| 122 |
},
|
| 123 |
{
|
| 124 |
"cell_type": "code",
|
|
|
|
| 125 |
"metadata": {},
|
|
|
|
| 126 |
"source": [
|
| 127 |
"# Cell 3: Imports (with runtime validation)\n",
|
| 128 |
"import json, random, time, textwrap, copy, os, sys\n",
|
|
|
|
| 176 |
"import ast\n",
|
| 177 |
"ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
|
| 178 |
"print(\"OK: ast.parse (syntax check)\")"
|
| 179 |
+
],
|
| 180 |
+
"execution_count": null,
|
| 181 |
+
"outputs": []
|
| 182 |
},
|
| 183 |
{
|
| 184 |
"cell_type": "markdown",
|
|
|
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
|
|
|
| 194 |
"metadata": {},
|
|
|
|
| 195 |
"source": [
|
| 196 |
"# Cell 4: Define heuristic agents + episode runner\n",
|
| 197 |
"_rng = random.Random(42)\n",
|
|
|
|
| 268 |
" \"rewards\": rewards, \"energies\": energies}\n",
|
| 269 |
"\n",
|
| 270 |
"print(\"Agents and episode runner defined.\")"
|
| 271 |
+
],
|
| 272 |
+
"execution_count": null,
|
| 273 |
+
"outputs": []
|
| 274 |
},
|
| 275 |
{
|
| 276 |
"cell_type": "code",
|
|
|
|
| 277 |
"metadata": {},
|
|
|
|
| 278 |
"source": [
|
| 279 |
"# Cell 5: Run baselines (safe)\n",
|
| 280 |
"print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
|
|
|
|
| 309 |
"for name in BASELINE_AGENTS:\n",
|
| 310 |
" scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
|
| 311 |
" print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
|
| 312 |
+
],
|
| 313 |
+
"execution_count": null,
|
| 314 |
+
"outputs": []
|
| 315 |
},
|
| 316 |
{
|
| 317 |
"cell_type": "code",
|
|
|
|
| 318 |
"metadata": {},
|
|
|
|
| 319 |
"source": [
|
| 320 |
"# Cell 6: Baseline plots\n",
|
| 321 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
|
|
|
|
| 333 |
"fig.tight_layout()\n",
|
| 334 |
"fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
|
| 335 |
"plt.show()"
|
| 336 |
+
],
|
| 337 |
+
"execution_count": null,
|
| 338 |
+
"outputs": []
|
| 339 |
},
|
| 340 |
{
|
| 341 |
"cell_type": "markdown",
|
|
|
|
| 348 |
},
|
| 349 |
{
|
| 350 |
"cell_type": "code",
|
|
|
|
| 351 |
"metadata": {},
|
|
|
|
| 352 |
"source": [
|
| 353 |
+
"# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
|
| 354 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 355 |
"\n",
|
| 356 |
+
"MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
|
| 357 |
"\n",
|
| 358 |
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
|
| 359 |
+
"if tokenizer.pad_token is None:\n",
|
| 360 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 361 |
+
"tokenizer.padding_side = \"left\"\n",
|
| 362 |
"\n",
|
| 363 |
+
"\n",
|
| 364 |
+
"def _has_flash_attn():\n",
|
| 365 |
+
" try:\n",
|
| 366 |
+
" import flash_attn # noqa: F401\n",
|
| 367 |
+
" return torch.cuda.is_available()\n",
|
| 368 |
+
" except Exception:\n",
|
| 369 |
+
" return False\n",
|
| 370 |
+
"\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"if torch.cuda.is_available():\n",
|
| 373 |
+
" dtype = torch.bfloat16\n",
|
| 374 |
+
" attn_impl = \"flash_attention_2\" if _has_flash_attn() else \"sdpa\"\n",
|
| 375 |
+
"elif getattr(torch.backends, \"mps\", None) and torch.backends.mps.is_available():\n",
|
| 376 |
+
" dtype, attn_impl = torch.float16, \"sdpa\"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
"else:\n",
|
| 378 |
+
" dtype, attn_impl = torch.float32, \"eager\"\n",
|
| 379 |
+
"\n",
|
| 380 |
+
"print(f\"Loading {MODEL_NAME} (dtype={dtype}, attn={attn_impl})...\")\n",
|
| 381 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 382 |
+
" MODEL_NAME,\n",
|
| 383 |
+
" trust_remote_code=True,\n",
|
| 384 |
+
" dtype=dtype,\n",
|
| 385 |
+
" attn_implementation=attn_impl,\n",
|
| 386 |
+
" device_map=\"cuda:0\" if torch.cuda.is_available() else None,\n",
|
| 387 |
+
")\n",
|
| 388 |
+
"if not torch.cuda.is_available():\n",
|
| 389 |
+
" model = model.to(\"mps\") if (getattr(torch.backends, \"mps\", None) and torch.backends.mps.is_available()) else model.to(\"cpu\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
"\n",
|
| 391 |
"model.eval()\n",
|
| 392 |
+
"print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
"if torch.cuda.is_available():\n",
|
| 394 |
" print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
|
| 395 |
+
],
|
| 396 |
+
"execution_count": null,
|
| 397 |
+
"outputs": []
|
| 398 |
},
|
| 399 |
{
|
| 400 |
"cell_type": "code",
|
|
|
|
| 401 |
"metadata": {},
|
|
|
|
| 402 |
"source": [
|
| 403 |
"# Cell 8: LLM agent functions\n",
|
| 404 |
"SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
|
|
|
|
| 446 |
" f\"Plan your actions (JSON only):\")\n",
|
| 447 |
"\n",
|
| 448 |
"\n",
|
| 449 |
+
"def is_well_formed_response(text):\n",
|
| 450 |
+
" try:\n",
|
| 451 |
+
" t = text.strip()\n",
|
| 452 |
+
" if \"```\" in t:\n",
|
| 453 |
+
" t = \"\\n\".join(l for l in t.split(\"\\n\") if not l.strip().startswith(\"```\")).strip()\n",
|
| 454 |
+
" s, e = t.find(\"{\"), t.rfind(\"}\") + 1\n",
|
| 455 |
+
" d = json.loads(t[s:e])\n",
|
| 456 |
+
" for tc in d.get(\"tool_calls\", []):\n",
|
| 457 |
+
" if not isinstance(tc, dict) or not isinstance(tc.get(\"arguments\", {}), dict):\n",
|
| 458 |
+
" return False\n",
|
| 459 |
+
" return True\n",
|
| 460 |
+
" except Exception:\n",
|
| 461 |
+
" return False\n",
|
| 462 |
+
"\n",
|
| 463 |
+
"\n",
|
| 464 |
"def parse_model_output(text):\n",
|
| 465 |
" text = text.strip()\n",
|
| 466 |
" if \"```\" in text:\n",
|
|
|
|
| 471 |
" text = text[start:end]\n",
|
| 472 |
" try:\n",
|
| 473 |
" data = json.loads(text)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
" except Exception:\n",
|
|
|
|
| 475 |
" return ViraltestAction(scheduled_actions=[])\n",
|
| 476 |
+
" tool_calls = []\n",
|
| 477 |
+
" for tc in data.get(\"tool_calls\", []):\n",
|
| 478 |
+
" if not isinstance(tc, dict) or \"name\" not in tc:\n",
|
| 479 |
+
" continue\n",
|
| 480 |
+
" args = tc.get(\"arguments\", {})\n",
|
| 481 |
+
" if isinstance(args, list) and args and isinstance(args[0], dict):\n",
|
| 482 |
+
" args = args[0]\n",
|
| 483 |
+
" if not isinstance(args, dict):\n",
|
| 484 |
+
" continue\n",
|
| 485 |
+
" try:\n",
|
| 486 |
+
" tool_calls.append(ToolCall(name=tc[\"name\"], arguments=args))\n",
|
| 487 |
+
" except Exception:\n",
|
| 488 |
+
" pass\n",
|
| 489 |
+
" scheduled = []\n",
|
| 490 |
+
" for a in data.get(\"scheduled_actions\", []):\n",
|
| 491 |
+
" try:\n",
|
| 492 |
+
" scheduled.append(ScheduledAction(**a))\n",
|
| 493 |
+
" except Exception:\n",
|
| 494 |
+
" pass\n",
|
| 495 |
+
" return ViraltestAction(\n",
|
| 496 |
+
" tool_calls=tool_calls,\n",
|
| 497 |
+
" scheduled_actions=scheduled,\n",
|
| 498 |
+
" replies=data.get(\"replies\", []),\n",
|
| 499 |
+
" notes=data.get(\"notes\"),\n",
|
| 500 |
+
" )\n",
|
| 501 |
"\n",
|
| 502 |
"\n",
|
| 503 |
"def _infer_model_device(m):\n",
|
|
|
|
| 511 |
" return torch.device(\"cpu\")\n",
|
| 512 |
"\n",
|
| 513 |
"\n",
|
| 514 |
+
"def _build_chat(history, prompt):\n",
|
| 515 |
+
" msgs = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
|
| 516 |
+
" msgs.extend(history[-14:])\n",
|
| 517 |
+
" msgs.append({\"role\": \"user\", \"content\": prompt})\n",
|
| 518 |
+
" return msgs\n",
|
| 519 |
+
"\n",
|
| 520 |
+
"\n",
|
| 521 |
+
"def _batched_generate(mdl, tok, prompts, temperature=0.7, max_new_tokens=512):\n",
|
| 522 |
+
" enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
|
| 523 |
" with torch.no_grad():\n",
|
| 524 |
+
" out = mdl.generate(\n",
|
| 525 |
+
" **enc, max_new_tokens=max_new_tokens, temperature=temperature,\n",
|
| 526 |
+
" do_sample=True, top_p=0.9, pad_token_id=tok.pad_token_id,\n",
|
| 527 |
+
" )\n",
|
| 528 |
+
" resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
|
| 529 |
+
" return resps, enc[\"input_ids\"].shape[1]\n",
|
| 530 |
+
"\n",
|
| 531 |
+
"\n",
|
| 532 |
+
"def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True):\n",
|
| 533 |
+
" \"\"\"Run N episodes in parallel. tasks_seeds: list of (task, seed). One batched generate per day.\"\"\"\n",
|
| 534 |
+
" n = len(tasks_seeds)\n",
|
| 535 |
+
" envs = [ViraltestEnvironment() for _ in range(n)]\n",
|
| 536 |
+
" obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
|
| 537 |
+
" histories = [[] for _ in range(n)]\n",
|
| 538 |
+
" rewards = [[] for _ in range(n)]\n",
|
| 539 |
+
" energies = [[obs.creator_energy] for obs in obss]\n",
|
| 540 |
+
" pairs = [[] for _ in range(n)]\n",
|
| 541 |
+
" done_mask = [obs.done for obs in obss]\n",
|
| 542 |
+
" rest_resp = '{\"scheduled_actions\": []}'\n",
|
| 543 |
+
"\n",
|
| 544 |
+
" for day in range(1, TASK_HORIZON + 1):\n",
|
| 545 |
+
" active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
|
| 546 |
+
" rest = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy <= 0.25]\n",
|
| 547 |
+
" if not active and not rest:\n",
|
| 548 |
+
" break\n",
|
| 549 |
+
"\n",
|
| 550 |
+
" resps_by_idx = {}\n",
|
| 551 |
+
" if active:\n",
|
| 552 |
+
" prompts = [format_obs(obss[i]) for i in active]\n",
|
| 553 |
+
" chats = [_build_chat(histories[i], p) for i, p in zip(active, prompts)]\n",
|
| 554 |
+
" texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
|
| 555 |
+
" resps, ptok = _batched_generate(mdl, tok, texts)\n",
|
| 556 |
+
" if verbose:\n",
|
| 557 |
+
" print(f\" D{day:2d}: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
|
| 558 |
+
" for j, i in enumerate(active):\n",
|
| 559 |
+
" resps_by_idx[i] = (resps[j], prompts[j])\n",
|
| 560 |
+
" for i in rest:\n",
|
| 561 |
+
" resps_by_idx[i] = (rest_resp, format_obs(obss[i]))\n",
|
| 562 |
+
"\n",
|
| 563 |
+
" for i in range(n):\n",
|
| 564 |
+
" if done_mask[i] or i not in resps_by_idx:\n",
|
| 565 |
+
" continue\n",
|
| 566 |
+
" resp, prompt = resps_by_idx[i]\n",
|
| 567 |
+
" action = parse_model_output(resp)\n",
|
| 568 |
+
" pairs[i].append({\"prompt\": prompt, \"response\": resp})\n",
|
| 569 |
+
" obss[i] = envs[i].step(action)\n",
|
| 570 |
+
" r = obss[i].reward or 0.0\n",
|
| 571 |
+
" rewards[i].append(r)\n",
|
| 572 |
+
" energies[i].append(obss[i].creator_energy)\n",
|
| 573 |
+
" histories[i].extend([\n",
|
| 574 |
+
" {\"role\": \"user\", \"content\": prompt},\n",
|
| 575 |
+
" {\"role\": \"assistant\", \"content\": resp},\n",
|
| 576 |
+
" ])\n",
|
| 577 |
+
" if obss[i].done:\n",
|
| 578 |
+
" done_mask[i] = True\n",
|
| 579 |
+
"\n",
|
| 580 |
+
" GAMMA, TERMINAL_W = 0.95, 5.0\n",
|
| 581 |
+
" results = []\n",
|
| 582 |
+
" for i, (task, seed) in enumerate(tasks_seeds):\n",
|
| 583 |
+
" gs = (obss[i].metadata or {}).get(\"grader_score\", 0.0)\n",
|
| 584 |
+
" rets = [0.0] * len(rewards[i])\n",
|
| 585 |
+
" G = gs * TERMINAL_W\n",
|
| 586 |
+
" for t in reversed(range(len(rewards[i]))):\n",
|
| 587 |
+
" G = rewards[i][t] + GAMMA * G\n",
|
| 588 |
+
" rets[t] = G\n",
|
| 589 |
+
" for k, pr in enumerate(pairs[i]):\n",
|
| 590 |
+
" pr[\"return\"] = rets[k] if k < len(rets) else 0.0\n",
|
| 591 |
+
" results.append({\n",
|
| 592 |
+
" \"task\": task, \"seed\": seed, \"grader_score\": gs,\n",
|
| 593 |
+
" \"total_reward\": sum(rewards[i]), \"final_energy\": obss[i].creator_energy,\n",
|
| 594 |
+
" \"rewards\": rewards[i], \"returns\": rets, \"energies\": energies[i],\n",
|
| 595 |
+
" \"pairs\": pairs[i], \"follower_delta\": obss[i].follower_count - 10000,\n",
|
| 596 |
+
" \"burned_out\": obss[i].creator_energy <= 0,\n",
|
| 597 |
+
" })\n",
|
| 598 |
+
" return results\n",
|
| 599 |
"\n",
|
| 600 |
"\n",
|
def run_llm_episode(mdl, tok, task, seed=42, verbose=False):
    """Run one (task, seed) episode through the batched rollout helper.

    The pair is wrapped in a one-element batch, handed to
    ``run_llm_episodes_batched`` together with ``mdl``, ``tok`` and
    ``verbose``, and the first (only) entry of the returned list is
    given back to the caller.
    """
    single_batch = [(task, seed)]
    episode_results = run_llm_episodes_batched(mdl, tok, single_batch, verbose=verbose)
    return episode_results[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
"\n",
|
| 604 |
+
"\n",
|
| 605 |
+
"print(\"LLM agent functions defined (batched).\")"
|
| 606 |
+
],
|
| 607 |
+
"execution_count": null,
|
| 608 |
+
"outputs": []
|
| 609 |
},
|
| 610 |
{
|
| 611 |
"cell_type": "markdown",
|
|
|
|
| 618 |
},
|
| 619 |
{
|
| 620 |
"cell_type": "code",
|
|
|
|
| 621 |
"metadata": {},
|
|
|
|
| 622 |
"source": [
|
# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)
# Baseline pass: every task in TASKS is rolled out once with seed 42 in a
# single batched call; results are then indexed by task name.
print("Running UNTRAINED base model on all tasks (batched)...")
print(60 * "=")

t0 = time.time()
results = run_llm_episodes_batched(
    model,
    tokenizer,
    list(zip(TASKS, [42] * len(TASKS))),
    verbose=True,
)
before_results = dict((r["task"], r) for r in results)

print("\n" + 60 * "=")
# Elapsed time is measured at print time, so it covers the whole rollout.
print(f"BEFORE TRAINING (took {time.time()-t0:.1f}s):")
for t in TASKS:
    print(f" {t}: grader={before_results[t]['grader_score']:.4f}")
| 635 |
+
],
|
| 636 |
+
"execution_count": null,
|
| 637 |
+
"outputs": []
|
| 638 |
},
|
| 639 |
{
|
| 640 |
"cell_type": "markdown",
|
|
|
|
| 653 |
},
|
| 654 |
{
|
| 655 |
"cell_type": "code",
|
|
|
|
| 656 |
"metadata": {},
|
|
|
|
| 657 |
"source": [
|
| 658 |
"# Cell 10: Attach LoRA adapter\n",
|
| 659 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
|
|
|
| 668 |
"model.enable_input_require_grads()\n",
|
| 669 |
"peft_model = get_peft_model(model, lora_config)\n",
|
| 670 |
"peft_model.print_trainable_parameters()"
|
| 671 |
+
],
|
| 672 |
+
"execution_count": null,
|
| 673 |
+
"outputs": []
|
| 674 |
},
|
| 675 |
{
|
| 676 |
"cell_type": "code",
|
|
|
|
| 677 |
"metadata": {},
|
|
|
|
| 678 |
"source": [
|
| 679 |
"# Cell 11: Training loop\n",
|
| 680 |
"from trl import SFTTrainer, SFTConfig\n",
|
|
|
|
| 697 |
" print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n",
|
| 698 |
" print(f\"{'=' * 60}\")\n",
|
| 699 |
"\n",
|
|
|
|
| 700 |
" peft_model.eval()\n",
|
| 701 |
+
" tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n",
|
| 702 |
+
" t_roll = time.time()\n",
|
| 703 |
+
" results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False)\n",
|
| 704 |
+
" print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
|
| 705 |
"\n",
|
| 706 |
+
" all_pairs, episode_rewards, episode_graders = [], [], []\n",
|
| 707 |
+
" for ep, result in enumerate(results):\n",
|
|
|
|
|
|
|
| 708 |
" ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
|
| 709 |
" episode_rewards.append(ep_reward)\n",
|
| 710 |
" episode_graders.append(result[\"grader_score\"])\n",
|
| 711 |
+
" kept = 0\n",
|
| 712 |
" for pr in result[\"pairs\"]:\n",
|
| 713 |
+
" if not is_well_formed_response(pr[\"response\"]):\n",
|
| 714 |
+
" continue\n",
|
| 715 |
" text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
|
| 716 |
" f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
|
| 717 |
" f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
|
| 718 |
+
" all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
|
| 719 |
+
" kept += 1\n",
|
| 720 |
+
" print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {result['task'].split('_')[-1]:>11s} \"\n",
|
| 721 |
+
" f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f} kept={kept}/{len(result['pairs'])}\")\n",
|
| 722 |
+
"\n",
|
| 723 |
+
" avg_r = float(np.mean(episode_rewards))\n",
|
| 724 |
+
" avg_g = float(np.mean(episode_graders))\n",
|
| 725 |
+
" print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f} | total pairs={len(all_pairs)}\")\n",
|
| 726 |
+
" if not all_pairs:\n",
|
| 727 |
+
" print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
|
| 728 |
+
" continue\n",
|
| 729 |
"\n",
|
|
|
|
| 730 |
" threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
|
| 731 |
" filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
|
| 732 |
+
" print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples (return >= {threshold:.3f})\")\n",
|
| 733 |
"\n",
|
| 734 |
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 735 |
"\n",
|
|
|
|
| 737 |
" sft_config = SFTConfig(\n",
|
| 738 |
" output_dir=f\"./checkpoints/round_{round_idx}\",\n",
|
| 739 |
" num_train_epochs=2,\n",
|
| 740 |
+
" per_device_train_batch_size=4,\n",
|
| 741 |
+
" gradient_accumulation_steps=2,\n",
|
| 742 |
" learning_rate=2e-5,\n",
|
| 743 |
" warmup_steps=5,\n",
|
| 744 |
" logging_steps=5,\n",
|
| 745 |
" save_strategy=\"no\",\n",
|
| 746 |
+
" max_length=4096,\n",
|
| 747 |
+
" bf16=True,\n",
|
| 748 |
" report_to=\"none\",\n",
|
| 749 |
" )\n",
|
| 750 |
"\n",
|
|
|
|
| 769 |
"elapsed = time.time() - t_start\n",
|
| 770 |
"print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
|
| 771 |
"print(pd.DataFrame(training_log).to_string(index=False))"
|
| 772 |
+
],
|
| 773 |
+
"execution_count": null,
|
| 774 |
+
"outputs": []
|
| 775 |
},
|
| 776 |
{
|
| 777 |
"cell_type": "markdown",
|
|
|
|
| 784 |
},
|
| 785 |
{
|
| 786 |
"cell_type": "code",
|
|
|
|
| 787 |
"metadata": {},
|
|
|
|
| 788 |
"source": [
|
# Cell 12: Run trained model (batched)
# Post-training pass: same tasks and seed (42) as the baseline cell, but the
# rollouts are generated by the LoRA-wrapped `peft_model` in eval mode.
print("Running TRAINED model on all tasks (batched)...")
print(60 * "=")

peft_model.eval()
t0 = time.time()
results = run_llm_episodes_batched(
    peft_model,
    tokenizer,
    list(zip(TASKS, [42] * len(TASKS))),
    verbose=True,
)
after_results = dict((r["task"], r) for r in results)

print("\n" + 60 * "=")
# Elapsed time is measured at print time, so it covers the whole rollout.
print(f"AFTER TRAINING (took {time.time()-t0:.1f}s):")
for t in TASKS:
    print(f" {t}: grader={after_results[t]['grader_score']:.4f}")
| 802 |
+
],
|
| 803 |
+
"execution_count": null,
|
| 804 |
+
"outputs": []
|
| 805 |
},
|
| 806 |
{
|
| 807 |
"cell_type": "markdown",
|
|
|
|
| 812 |
},
|
| 813 |
{
|
| 814 |
"cell_type": "code",
|
|
|
|
| 815 |
"metadata": {},
|
|
|
|
| 816 |
"source": [
|
| 817 |
"# Cell 13: Training curves\n",
|
| 818 |
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
|
|
|
| 834 |
"fig.tight_layout()\n",
|
| 835 |
"fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
|
| 836 |
"plt.show()"
|
| 837 |
+
],
|
| 838 |
+
"execution_count": null,
|
| 839 |
+
"outputs": []
|
| 840 |
},
|
| 841 |
{
|
| 842 |
"cell_type": "code",
|
|
|
|
| 843 |
"metadata": {},
|
|
|
|
| 844 |
"source": [
|
| 845 |
"# Cell 14: Before vs After\n",
|
| 846 |
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
|
|
|
|
| 870 |
"fig.tight_layout()\n",
|
| 871 |
"fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
|
| 872 |
"plt.show()"
|
| 873 |
+
],
|
| 874 |
+
"execution_count": null,
|
| 875 |
+
"outputs": []
|
| 876 |
},
|
| 877 |
{
|
| 878 |
"cell_type": "code",
|
|
|
|
| 879 |
"metadata": {},
|
|
|
|
| 880 |
"source": [
|
| 881 |
"# Cell 15: Trajectory comparison\n",
|
| 882 |
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
|
|
|
|
| 900 |
"fig.tight_layout()\n",
|
| 901 |
"fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
|
| 902 |
"plt.show()"
|
| 903 |
+
],
|
| 904 |
+
"execution_count": null,
|
| 905 |
+
"outputs": []
|
| 906 |
},
|
| 907 |
{
|
| 908 |
"cell_type": "markdown",
|
|
|
|
| 913 |
},
|
| 914 |
{
|
| 915 |
"cell_type": "code",
|
|
|
|
| 916 |
"metadata": {},
|
|
|
|
| 917 |
"source": [
|
| 918 |
"# Cell 16: Final summary\n",
|
| 919 |
"print(\"=\" * 67)\n",
|
|
|
|
| 950 |
"\n",
|
| 951 |
"print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
|
| 952 |
"print(\"All results are from real LoRA weight updates on real environment runs.\")"
|
| 953 |
+
],
|
| 954 |
+
"execution_count": null,
|
| 955 |
+
"outputs": []
|
| 956 |
},
|
| 957 |
{
|
| 958 |
"cell_type": "code",
|
|
|
|
| 959 |
"metadata": {},
|
|
|
|
| 960 |
"source": [
|
| 961 |
"# Cell 17: Save adapter\n",
|
| 962 |
"save_path = \"./viraltest_trained_adapter\"\n",
|
|
|
|
| 964 |
"tokenizer.save_pretrained(save_path)\n",
|
| 965 |
"print(f\"LoRA adapter saved to {save_path}\")\n",
|
| 966 |
"print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
|
| 967 |
+
],
|
| 968 |
+
"execution_count": null,
|
| 969 |
+
"outputs": []
|
| 970 |
}
|
| 971 |
],
|
| 972 |
"metadata": {
|
|
|
|
| 992 |
},
|
| 993 |
"nbformat": 4,
|
| 994 |
"nbformat_minor": 4
|
| 995 |
+
}
|