vaibhav12332112312 committed on
Commit
a6b8df0
·
1 Parent(s): 81cdb34

train: batched parallel rollouts on Qwen2.5-3B + parser hardening

Browse files

- Qwen2.5-3B-Instruct in bf16 + flash-attn-2 (sdpa fallback)
- New run_llm_episodes_batched: N parallel envs, one batched generate per day
(~10x faster rollouts than sequential)
- parse_model_output: per-tool-call try/except so a malformed `arguments` no
longer wipes the whole action (root cause of post-train follower collapse)
- is_well_formed_response filter on SFT data
- SFT: max_length=4096, batch=4 x accum=2, bf16
- Per-step credit assignment: each (prompt, response) pair is scored by its discounted return-to-go, and that score drives the top-K SFT sample filter

Made-with: Cursor

Files changed (1) hide show
  1. training/train_grpo.ipynb +251 -203
training/train_grpo.ipynb CHANGED
@@ -25,23 +25,22 @@
25
  },
26
  {
27
  "cell_type": "code",
28
- "execution_count": null,
29
  "metadata": {},
30
- "outputs": [],
31
  "source": [
32
  "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
33
  "!pip install -q torch torchvision torchaudio\n",
34
- "!pip install -q \"transformers>=4.45.0\" \"accelerate\" \"peft>=0.10.0\" \"trl>=0.20.0\" \"datasets\" \"bitsandbytes\"\n",
35
  "!pip install -q matplotlib pandas\n",
36
  "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
37
- "!pip install -q \"openenv-core[core]>=0.2.2\""
38
- ]
 
 
 
39
  },
40
  {
41
  "cell_type": "code",
42
- "execution_count": null,
43
  "metadata": {},
44
- "outputs": [],
45
  "source": [
46
  "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
47
  "import os\n",
@@ -117,13 +116,13 @@
117
  "print(f\"Branch: {REPO_BRANCH}\")\n",
118
  "print(f\"Commit: {commit}\")\n",
119
  "print(f\"Plots dir: {PLOTS_DIR}\")"
120
- ]
 
 
121
  },
122
  {
123
  "cell_type": "code",
124
- "execution_count": null,
125
  "metadata": {},
126
- "outputs": [],
127
  "source": [
128
  "# Cell 3: Imports (with runtime validation)\n",
129
  "import json, random, time, textwrap, copy, os, sys\n",
@@ -177,7 +176,9 @@
177
  "import ast\n",
178
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
179
  "print(\"OK: ast.parse (syntax check)\")"
180
- ]
 
 
181
  },
182
  {
183
  "cell_type": "markdown",
@@ -190,9 +191,7 @@
190
  },
191
  {
192
  "cell_type": "code",
193
- "execution_count": null,
194
  "metadata": {},
195
- "outputs": [],
196
  "source": [
197
  "# Cell 4: Define heuristic agents + episode runner\n",
198
  "_rng = random.Random(42)\n",
@@ -269,13 +268,13 @@
269
  " \"rewards\": rewards, \"energies\": energies}\n",
270
  "\n",
271
  "print(\"Agents and episode runner defined.\")"
272
- ]
 
 
273
  },
274
  {
275
  "cell_type": "code",
276
- "execution_count": null,
277
  "metadata": {},
278
- "outputs": [],
279
  "source": [
280
  "# Cell 5: Run baselines (safe)\n",
281
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
@@ -310,13 +309,13 @@
310
  "for name in BASELINE_AGENTS:\n",
311
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
312
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
313
- ]
 
 
314
  },
315
  {
316
  "cell_type": "code",
317
- "execution_count": null,
318
  "metadata": {},
319
- "outputs": [],
320
  "source": [
321
  "# Cell 6: Baseline plots\n",
322
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
@@ -334,7 +333,9 @@
334
  "fig.tight_layout()\n",
335
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
336
  "plt.show()"
337
- ]
 
 
338
  },
339
  {
340
  "cell_type": "markdown",
@@ -347,80 +348,57 @@
347
  },
348
  {
349
  "cell_type": "code",
350
- "execution_count": null,
351
  "metadata": {},
352
- "outputs": [],
353
  "source": [
354
- "# Cell 7: Load model (4-bit on CUDA Colab; fp16/fp32 fallback if bitsandbytes missing)\n",
355
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
356
  "\n",
357
- "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
358
  "\n",
359
  "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
 
 
 
360
  "\n",
361
- "_use_4bit = False\n",
362
- "try:\n",
363
- " from transformers.utils import is_bitsandbytes_available\n",
364
- "except Exception: # older transformers\n",
365
- " def is_bitsandbytes_available():\n",
366
- " try:\n",
367
- " import bitsandbytes # noqa: F401\n",
368
- " return True\n",
369
- " except ImportError:\n",
370
- " return False\n",
371
- "\n",
372
- "if torch.cuda.is_available() and is_bitsandbytes_available():\n",
373
- " from transformers import BitsAndBytesConfig\n",
374
- " _use_4bit = True\n",
375
- "\n",
376
- "if _use_4bit:\n",
377
- " print(f\"Loading {MODEL_NAME} (4-bit quantized, CUDA)...\")\n",
378
- " bnb_config = BitsAndBytesConfig(\n",
379
- " load_in_4bit=True,\n",
380
- " bnb_4bit_quant_type=\"nf4\",\n",
381
- " bnb_4bit_compute_dtype=torch.float16,\n",
382
- " bnb_4bit_use_double_quant=True,\n",
383
- " )\n",
384
- " model = AutoModelForCausalLM.from_pretrained(\n",
385
- " MODEL_NAME,\n",
386
- " trust_remote_code=True,\n",
387
- " quantization_config=bnb_config,\n",
388
- " device_map=\"auto\",\n",
389
- " )\n",
390
  "else:\n",
391
- " print(\n",
392
- " f\"Loading {MODEL_NAME} without 4-bit (bitsandbytes/CUDA unavailable).\\n\"\n",
393
- " \" On Colab: run `pip install -U bitsandbytes>=0.46.1` and use a GPU runtime.\\n\"\n",
394
- " \" On Mac: use fp16 on MPS or fp32 on CPU.\"\n",
395
- " )\n",
396
- " dtype = torch.float16 if (torch.cuda.is_available() or getattr(torch.backends, \"mps\", None) and torch.backends.mps.is_available()) else torch.float32\n",
397
- " model = AutoModelForCausalLM.from_pretrained(\n",
398
- " MODEL_NAME,\n",
399
- " trust_remote_code=True,\n",
400
- " dtype=dtype,\n",
401
- " device_map=\"auto\" if torch.cuda.is_available() else None,\n",
402
- " )\n",
403
- " if not torch.cuda.is_available():\n",
404
- " if getattr(torch.backends, \"mps\", None) and torch.backends.mps.is_available():\n",
405
- " model = model.to(\"mps\")\n",
406
- " else:\n",
407
- " model = model.to(\"cpu\")\n",
408
  "\n",
409
  "model.eval()\n",
410
- "print(f\"Model loaded. dtype={next(model.parameters()).dtype}\")\n",
411
- "try:\n",
412
- " print(f\"Device: {model.device}\")\n",
413
- "except Exception:\n",
414
- " print(\"Device: (see first parameter device)\")\n",
415
  "if torch.cuda.is_available():\n",
416
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
417
- ]
 
 
418
  },
419
  {
420
  "cell_type": "code",
421
- "execution_count": null,
422
  "metadata": {},
423
- "outputs": [],
424
  "source": [
425
  "# Cell 8: LLM agent functions\n",
426
  "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
@@ -468,6 +446,21 @@
468
  " f\"Plan your actions (JSON only):\")\n",
469
  "\n",
470
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  "def parse_model_output(text):\n",
472
  " text = text.strip()\n",
473
  " if \"```\" in text:\n",
@@ -478,24 +471,33 @@
478
  " text = text[start:end]\n",
479
  " try:\n",
480
  " data = json.loads(text)\n",
481
- " tool_calls = [ToolCall(name=tc[\"name\"], arguments=tc.get(\"arguments\", {}))\n",
482
- " for tc in data.get(\"tool_calls\", []) if isinstance(tc, dict) and \"name\" in tc]\n",
483
- " scheduled = []\n",
484
- " for a in data.get(\"scheduled_actions\", []):\n",
485
- " try:\n",
486
- " scheduled.append(ScheduledAction(**a))\n",
487
- " except Exception:\n",
488
- " # Same as original bare `except:`: skip invalid scheduled_actions entries\n",
489
- " pass\n",
490
- " return ViraltestAction(\n",
491
- " tool_calls=tool_calls,\n",
492
- " scheduled_actions=scheduled,\n",
493
- " replies=data.get(\"replies\", []),\n",
494
- " notes=data.get(\"notes\"),\n",
495
- " )\n",
496
  " except Exception:\n",
497
- " # Same behavior as original bare `except:`: any parse/validation failure -> empty action\n",
498
  " return ViraltestAction(scheduled_actions=[])\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  "\n",
500
  "\n",
501
  "def _infer_model_device(m):\n",
@@ -509,53 +511,101 @@
509
  " return torch.device(\"cpu\")\n",
510
  "\n",
511
  "\n",
512
- "def generate_action(mdl, tok, obs, history, temperature=0.7):\n",
513
- " prompt = format_obs(obs)\n",
514
- " messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
515
- " messages.extend(history[-4:])\n",
516
- " messages.append({\"role\": \"user\", \"content\": prompt})\n",
517
- " text_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
518
- " inputs = tok(text_input, return_tensors=\"pt\").to(_infer_model_device(mdl))\n",
 
 
519
  " with torch.no_grad():\n",
520
- " out = mdl.generate(**inputs, max_new_tokens=512, temperature=temperature,\n",
521
- " do_sample=True, top_p=0.9, pad_token_id=tok.eos_token_id)\n",
522
- " resp = tok.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
523
- " return resp, parse_model_output(resp)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  "\n",
525
  "\n",
526
  "def run_llm_episode(mdl, tok, task, seed=42, verbose=False):\n",
527
- " env = ViraltestEnvironment()\n",
528
- " obs = env.reset(task=task, seed=seed)\n",
529
- " rewards, energies = [], [obs.creator_energy]\n",
530
- " history, pairs = [], []\n",
531
- " for day in range(1, TASK_HORIZON + 1):\n",
532
- " if obs.done: break\n",
533
- " if obs.creator_energy <= 0.25:\n",
534
- " action = ViraltestAction(scheduled_actions=[])\n",
535
- " resp = '{\"scheduled_actions\": []}'\n",
536
- " else:\n",
537
- " resp, action = generate_action(mdl, tok, obs, history)\n",
538
- " prompt = format_obs(obs)\n",
539
- " pairs.append({\"prompt\": prompt, \"response\": resp})\n",
540
- " obs = env.step(action)\n",
541
- " r = obs.reward or 0.0\n",
542
- " rewards.append(r)\n",
543
- " energies.append(obs.creator_energy)\n",
544
- " history.extend([{\"role\": \"user\", \"content\": prompt},\n",
545
- " {\"role\": \"assistant\", \"content\": resp}])\n",
546
- " if verbose:\n",
547
- " n_p = len([s for s in action.scheduled_actions if s.action_type==\"post\"])\n",
548
- " print(f\" Day {day:2d}: r={r:.4f} e={obs.creator_energy:.2f} posts={n_p} tools={len(action.tool_calls)}\")\n",
549
- " if obs.done: break\n",
550
- " gs = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
551
- " return {\"task\": task, \"grader_score\": gs, \"total_reward\": sum(rewards),\n",
552
- " \"final_energy\": obs.creator_energy, \"rewards\": rewards,\n",
553
- " \"energies\": energies, \"pairs\": pairs,\n",
554
- " \"follower_delta\": obs.follower_count - 10000,\n",
555
- " \"burned_out\": obs.creator_energy <= 0}\n",
556
  "\n",
557
- "print(\"LLM agent functions defined.\")"
558
- ]
 
 
 
559
  },
560
  {
561
  "cell_type": "markdown",
@@ -568,26 +618,23 @@
568
  },
569
  {
570
  "cell_type": "code",
571
- "execution_count": null,
572
  "metadata": {},
573
- "outputs": [],
574
  "source": [
575
- "# Cell 9: Run untrained model\n",
576
- "print(\"Running UNTRAINED base model on all tasks...\")\n",
577
  "print(\"=\" * 60)\n",
578
  "\n",
579
- "before_results = {}\n",
580
- "for task in TASKS:\n",
581
- " print(f\"\\n Task: {task}\")\n",
582
- " result = run_llm_episode(model, tokenizer, task, seed=42, verbose=True)\n",
583
- " before_results[task] = result\n",
584
- " print(f\" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f}\")\n",
585
  "\n",
586
  "print(\"\\n\" + \"=\" * 60)\n",
587
- "print(\"BEFORE TRAINING:\")\n",
588
  "for t in TASKS:\n",
589
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
590
- ]
 
 
591
  },
592
  {
593
  "cell_type": "markdown",
@@ -606,9 +653,7 @@
606
  },
607
  {
608
  "cell_type": "code",
609
- "execution_count": null,
610
  "metadata": {},
611
- "outputs": [],
612
  "source": [
613
  "# Cell 10: Attach LoRA adapter\n",
614
  "from peft import LoraConfig, get_peft_model, TaskType\n",
@@ -623,13 +668,13 @@
623
  "model.enable_input_require_grads()\n",
624
  "peft_model = get_peft_model(model, lora_config)\n",
625
  "peft_model.print_trainable_parameters()"
626
- ]
 
 
627
  },
628
  {
629
  "cell_type": "code",
630
- "execution_count": null,
631
  "metadata": {},
632
- "outputs": [],
633
  "source": [
634
  "# Cell 11: Training loop\n",
635
  "from trl import SFTTrainer, SFTConfig\n",
@@ -652,35 +697,39 @@
652
  " print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n",
653
  " print(f\"{'=' * 60}\")\n",
654
  "\n",
655
- " # Collect episodes\n",
656
  " peft_model.eval()\n",
657
- " all_pairs, episode_rewards, episode_graders = [], [], []\n",
 
 
 
658
  "\n",
659
- " for ep in range(EPISODES_PER_ROUND):\n",
660
- " task = TASKS[ep % len(TASKS)]\n",
661
- " seed = 42 + (round_idx - 1) * 100 + ep\n",
662
- " result = run_llm_episode(peft_model, tokenizer, task, seed=seed)\n",
663
  " ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
664
  " episode_rewards.append(ep_reward)\n",
665
  " episode_graders.append(result[\"grader_score\"])\n",
666
- "\n",
667
  " for pr in result[\"pairs\"]:\n",
 
 
668
  " text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
669
  " f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
670
  " f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
671
- " all_pairs.append({\"text\": text, \"reward\": ep_reward})\n",
672
- "\n",
673
- " print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {task.split('_')[-1]:>11s} \"\n",
674
- " f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f}\")\n",
675
- "\n",
676
- " avg_r = np.mean(episode_rewards)\n",
677
- " avg_g = np.mean(episode_graders)\n",
678
- " print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f}\")\n",
 
 
 
679
  "\n",
680
- " # Filter to top-K\n",
681
  " threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
682
  " filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
683
- " print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples\")\n",
684
  "\n",
685
  " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
686
  "\n",
@@ -688,14 +737,14 @@
688
  " sft_config = SFTConfig(\n",
689
  " output_dir=f\"./checkpoints/round_{round_idx}\",\n",
690
  " num_train_epochs=2,\n",
691
- " per_device_train_batch_size=1,\n",
692
- " gradient_accumulation_steps=4,\n",
693
  " learning_rate=2e-5,\n",
694
  " warmup_steps=5,\n",
695
  " logging_steps=5,\n",
696
  " save_strategy=\"no\",\n",
697
- " max_length=1024,\n",
698
- " fp16=True,\n",
699
  " report_to=\"none\",\n",
700
  " )\n",
701
  "\n",
@@ -720,7 +769,9 @@
720
  "elapsed = time.time() - t_start\n",
721
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
722
  "print(pd.DataFrame(training_log).to_string(index=False))"
723
- ]
 
 
724
  },
725
  {
726
  "cell_type": "markdown",
@@ -733,27 +784,24 @@
733
  },
734
  {
735
  "cell_type": "code",
736
- "execution_count": null,
737
  "metadata": {},
738
- "outputs": [],
739
  "source": [
740
- "# Cell 12: Run trained model\n",
741
- "print(\"Running TRAINED model on all tasks...\")\n",
742
  "print(\"=\" * 60)\n",
743
  "\n",
744
  "peft_model.eval()\n",
745
- "after_results = {}\n",
746
- "for task in TASKS:\n",
747
- " print(f\"\\n Task: {task}\")\n",
748
- " result = run_llm_episode(peft_model, tokenizer, task, seed=42, verbose=True)\n",
749
- " after_results[task] = result\n",
750
- " print(f\" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f}\")\n",
751
  "\n",
752
  "print(\"\\n\" + \"=\" * 60)\n",
753
- "print(\"AFTER TRAINING:\")\n",
754
  "for t in TASKS:\n",
755
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
756
- ]
 
 
757
  },
758
  {
759
  "cell_type": "markdown",
@@ -764,9 +812,7 @@
764
  },
765
  {
766
  "cell_type": "code",
767
- "execution_count": null,
768
  "metadata": {},
769
- "outputs": [],
770
  "source": [
771
  "# Cell 13: Training curves\n",
772
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
@@ -788,13 +834,13 @@
788
  "fig.tight_layout()\n",
789
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
790
  "plt.show()"
791
- ]
 
 
792
  },
793
  {
794
  "cell_type": "code",
795
- "execution_count": null,
796
  "metadata": {},
797
- "outputs": [],
798
  "source": [
799
  "# Cell 14: Before vs After\n",
800
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
@@ -824,13 +870,13 @@
824
  "fig.tight_layout()\n",
825
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
826
  "plt.show()"
827
- ]
 
 
828
  },
829
  {
830
  "cell_type": "code",
831
- "execution_count": null,
832
  "metadata": {},
833
- "outputs": [],
834
  "source": [
835
  "# Cell 15: Trajectory comparison\n",
836
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
@@ -854,7 +900,9 @@
854
  "fig.tight_layout()\n",
855
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
856
  "plt.show()"
857
- ]
 
 
858
  },
859
  {
860
  "cell_type": "markdown",
@@ -865,9 +913,7 @@
865
  },
866
  {
867
  "cell_type": "code",
868
- "execution_count": null,
869
  "metadata": {},
870
- "outputs": [],
871
  "source": [
872
  "# Cell 16: Final summary\n",
873
  "print(\"=\" * 67)\n",
@@ -904,13 +950,13 @@
904
  "\n",
905
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
906
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
907
- ]
 
 
908
  },
909
  {
910
  "cell_type": "code",
911
- "execution_count": null,
912
  "metadata": {},
913
- "outputs": [],
914
  "source": [
915
  "# Cell 17: Save adapter\n",
916
  "save_path = \"./viraltest_trained_adapter\"\n",
@@ -918,7 +964,9 @@
918
  "tokenizer.save_pretrained(save_path)\n",
919
  "print(f\"LoRA adapter saved to {save_path}\")\n",
920
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
921
- ]
 
 
922
  }
923
  ],
924
  "metadata": {
@@ -944,4 +992,4 @@
944
  },
945
  "nbformat": 4,
946
  "nbformat_minor": 4
947
- }
 
25
  },
26
  {
27
  "cell_type": "code",
 
28
  "metadata": {},
 
29
  "source": [
30
  "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
31
  "!pip install -q torch torchvision torchaudio\n",
32
+ "!pip install -q \"transformers>=4.45.0\" \"accelerate\" \"peft>=0.10.0\" \"trl>=0.20.0\" \"datasets\"\n",
33
  "!pip install -q matplotlib pandas\n",
34
  "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
35
+ "!pip install -q \"openenv-core[core]>=0.2.2\"\n",
36
+ "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
37
+ ],
38
+ "execution_count": null,
39
+ "outputs": []
40
  },
41
  {
42
  "cell_type": "code",
 
43
  "metadata": {},
 
44
  "source": [
45
  "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
46
  "import os\n",
 
116
  "print(f\"Branch: {REPO_BRANCH}\")\n",
117
  "print(f\"Commit: {commit}\")\n",
118
  "print(f\"Plots dir: {PLOTS_DIR}\")"
119
+ ],
120
+ "execution_count": null,
121
+ "outputs": []
122
  },
123
  {
124
  "cell_type": "code",
 
125
  "metadata": {},
 
126
  "source": [
127
  "# Cell 3: Imports (with runtime validation)\n",
128
  "import json, random, time, textwrap, copy, os, sys\n",
 
176
  "import ast\n",
177
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
178
  "print(\"OK: ast.parse (syntax check)\")"
179
+ ],
180
+ "execution_count": null,
181
+ "outputs": []
182
  },
183
  {
184
  "cell_type": "markdown",
 
191
  },
192
  {
193
  "cell_type": "code",
 
194
  "metadata": {},
 
195
  "source": [
196
  "# Cell 4: Define heuristic agents + episode runner\n",
197
  "_rng = random.Random(42)\n",
 
268
  " \"rewards\": rewards, \"energies\": energies}\n",
269
  "\n",
270
  "print(\"Agents and episode runner defined.\")"
271
+ ],
272
+ "execution_count": null,
273
+ "outputs": []
274
  },
275
  {
276
  "cell_type": "code",
 
277
  "metadata": {},
 
278
  "source": [
279
  "# Cell 5: Run baselines (safe)\n",
280
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
 
309
  "for name in BASELINE_AGENTS:\n",
310
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
311
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
312
+ ],
313
+ "execution_count": null,
314
+ "outputs": []
315
  },
316
  {
317
  "cell_type": "code",
 
318
  "metadata": {},
 
319
  "source": [
320
  "# Cell 6: Baseline plots\n",
321
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
 
333
  "fig.tight_layout()\n",
334
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
335
  "plt.show()"
336
+ ],
337
+ "execution_count": null,
338
+ "outputs": []
339
  },
340
  {
341
  "cell_type": "markdown",
 
348
  },
349
  {
350
  "cell_type": "code",
 
351
  "metadata": {},
 
352
  "source": [
353
+ "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
354
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
355
  "\n",
356
+ "MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
357
  "\n",
358
  "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
359
+ "if tokenizer.pad_token is None:\n",
360
+ " tokenizer.pad_token = tokenizer.eos_token\n",
361
+ "tokenizer.padding_side = \"left\"\n",
362
  "\n",
363
+ "\n",
364
+ "def _has_flash_attn():\n",
365
+ " try:\n",
366
+ " import flash_attn # noqa: F401\n",
367
+ " return torch.cuda.is_available()\n",
368
+ " except Exception:\n",
369
+ " return False\n",
370
+ "\n",
371
+ "\n",
372
+ "if torch.cuda.is_available():\n",
373
+ " dtype = torch.bfloat16\n",
374
+ " attn_impl = \"flash_attention_2\" if _has_flash_attn() else \"sdpa\"\n",
375
+ "elif getattr(torch.backends, \"mps\", None) and torch.backends.mps.is_available():\n",
376
+ " dtype, attn_impl = torch.float16, \"sdpa\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  "else:\n",
378
+ " dtype, attn_impl = torch.float32, \"eager\"\n",
379
+ "\n",
380
+ "print(f\"Loading {MODEL_NAME} (dtype={dtype}, attn={attn_impl})...\")\n",
381
+ "model = AutoModelForCausalLM.from_pretrained(\n",
382
+ " MODEL_NAME,\n",
383
+ " trust_remote_code=True,\n",
384
+ " dtype=dtype,\n",
385
+ " attn_implementation=attn_impl,\n",
386
+ " device_map=\"cuda:0\" if torch.cuda.is_available() else None,\n",
387
+ ")\n",
388
+ "if not torch.cuda.is_available():\n",
389
+ " model = model.to(\"mps\") if (getattr(torch.backends, \"mps\", None) and torch.backends.mps.is_available()) else model.to(\"cpu\")\n",
 
 
 
 
 
390
  "\n",
391
  "model.eval()\n",
392
+ "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
 
 
 
 
393
  "if torch.cuda.is_available():\n",
394
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
395
+ ],
396
+ "execution_count": null,
397
+ "outputs": []
398
  },
399
  {
400
  "cell_type": "code",
 
401
  "metadata": {},
 
402
  "source": [
403
  "# Cell 8: LLM agent functions\n",
404
  "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
 
446
  " f\"Plan your actions (JSON only):\")\n",
447
  "\n",
448
  "\n",
449
+ "def is_well_formed_response(text):\n",
450
+ " try:\n",
451
+ " t = text.strip()\n",
452
+ " if \"```\" in t:\n",
453
+ " t = \"\\n\".join(l for l in t.split(\"\\n\") if not l.strip().startswith(\"```\")).strip()\n",
454
+ " s, e = t.find(\"{\"), t.rfind(\"}\") + 1\n",
455
+ " d = json.loads(t[s:e])\n",
456
+ " for tc in d.get(\"tool_calls\", []):\n",
457
+ " if not isinstance(tc, dict) or not isinstance(tc.get(\"arguments\", {}), dict):\n",
458
+ " return False\n",
459
+ " return True\n",
460
+ " except Exception:\n",
461
+ " return False\n",
462
+ "\n",
463
+ "\n",
464
  "def parse_model_output(text):\n",
465
  " text = text.strip()\n",
466
  " if \"```\" in text:\n",
 
471
  " text = text[start:end]\n",
472
  " try:\n",
473
  " data = json.loads(text)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  " except Exception:\n",
 
475
  " return ViraltestAction(scheduled_actions=[])\n",
476
+ " tool_calls = []\n",
477
+ " for tc in data.get(\"tool_calls\", []):\n",
478
+ " if not isinstance(tc, dict) or \"name\" not in tc:\n",
479
+ " continue\n",
480
+ " args = tc.get(\"arguments\", {})\n",
481
+ " if isinstance(args, list) and args and isinstance(args[0], dict):\n",
482
+ " args = args[0]\n",
483
+ " if not isinstance(args, dict):\n",
484
+ " continue\n",
485
+ " try:\n",
486
+ " tool_calls.append(ToolCall(name=tc[\"name\"], arguments=args))\n",
487
+ " except Exception:\n",
488
+ " pass\n",
489
+ " scheduled = []\n",
490
+ " for a in data.get(\"scheduled_actions\", []):\n",
491
+ " try:\n",
492
+ " scheduled.append(ScheduledAction(**a))\n",
493
+ " except Exception:\n",
494
+ " pass\n",
495
+ " return ViraltestAction(\n",
496
+ " tool_calls=tool_calls,\n",
497
+ " scheduled_actions=scheduled,\n",
498
+ " replies=data.get(\"replies\", []),\n",
499
+ " notes=data.get(\"notes\"),\n",
500
+ " )\n",
501
  "\n",
502
  "\n",
503
  "def _infer_model_device(m):\n",
 
511
  " return torch.device(\"cpu\")\n",
512
  "\n",
513
  "\n",
514
+ "def _build_chat(history, prompt):\n",
515
+ " msgs = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
516
+ " msgs.extend(history[-14:])\n",
517
+ " msgs.append({\"role\": \"user\", \"content\": prompt})\n",
518
+ " return msgs\n",
519
+ "\n",
520
+ "\n",
521
+ "def _batched_generate(mdl, tok, prompts, temperature=0.7, max_new_tokens=512):\n",
522
+ " enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
523
  " with torch.no_grad():\n",
524
+ " out = mdl.generate(\n",
525
+ " **enc, max_new_tokens=max_new_tokens, temperature=temperature,\n",
526
+ " do_sample=True, top_p=0.9, pad_token_id=tok.pad_token_id,\n",
527
+ " )\n",
528
+ " resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
529
+ " return resps, enc[\"input_ids\"].shape[1]\n",
530
+ "\n",
531
+ "\n",
532
+ "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True):\n",
533
+ " \"\"\"Run N episodes in parallel. tasks_seeds: list of (task, seed). One batched generate per day.\"\"\"\n",
534
+ " n = len(tasks_seeds)\n",
535
+ " envs = [ViraltestEnvironment() for _ in range(n)]\n",
536
+ " obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
537
+ " histories = [[] for _ in range(n)]\n",
538
+ " rewards = [[] for _ in range(n)]\n",
539
+ " energies = [[obs.creator_energy] for obs in obss]\n",
540
+ " pairs = [[] for _ in range(n)]\n",
541
+ " done_mask = [obs.done for obs in obss]\n",
542
+ " rest_resp = '{\"scheduled_actions\": []}'\n",
543
+ "\n",
544
+ " for day in range(1, TASK_HORIZON + 1):\n",
545
+ " active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
546
+ " rest = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy <= 0.25]\n",
547
+ " if not active and not rest:\n",
548
+ " break\n",
549
+ "\n",
550
+ " resps_by_idx = {}\n",
551
+ " if active:\n",
552
+ " prompts = [format_obs(obss[i]) for i in active]\n",
553
+ " chats = [_build_chat(histories[i], p) for i, p in zip(active, prompts)]\n",
554
+ " texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
555
+ " resps, ptok = _batched_generate(mdl, tok, texts)\n",
556
+ " if verbose:\n",
557
+ " print(f\" D{day:2d}: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
558
+ " for j, i in enumerate(active):\n",
559
+ " resps_by_idx[i] = (resps[j], prompts[j])\n",
560
+ " for i in rest:\n",
561
+ " resps_by_idx[i] = (rest_resp, format_obs(obss[i]))\n",
562
+ "\n",
563
+ " for i in range(n):\n",
564
+ " if done_mask[i] or i not in resps_by_idx:\n",
565
+ " continue\n",
566
+ " resp, prompt = resps_by_idx[i]\n",
567
+ " action = parse_model_output(resp)\n",
568
+ " pairs[i].append({\"prompt\": prompt, \"response\": resp})\n",
569
+ " obss[i] = envs[i].step(action)\n",
570
+ " r = obss[i].reward or 0.0\n",
571
+ " rewards[i].append(r)\n",
572
+ " energies[i].append(obss[i].creator_energy)\n",
573
+ " histories[i].extend([\n",
574
+ " {\"role\": \"user\", \"content\": prompt},\n",
575
+ " {\"role\": \"assistant\", \"content\": resp},\n",
576
+ " ])\n",
577
+ " if obss[i].done:\n",
578
+ " done_mask[i] = True\n",
579
+ "\n",
580
+ " GAMMA, TERMINAL_W = 0.95, 5.0\n",
581
+ " results = []\n",
582
+ " for i, (task, seed) in enumerate(tasks_seeds):\n",
583
+ " gs = (obss[i].metadata or {}).get(\"grader_score\", 0.0)\n",
584
+ " rets = [0.0] * len(rewards[i])\n",
585
+ " G = gs * TERMINAL_W\n",
586
+ " for t in reversed(range(len(rewards[i]))):\n",
587
+ " G = rewards[i][t] + GAMMA * G\n",
588
+ " rets[t] = G\n",
589
+ " for k, pr in enumerate(pairs[i]):\n",
590
+ " pr[\"return\"] = rets[k] if k < len(rets) else 0.0\n",
591
+ " results.append({\n",
592
+ " \"task\": task, \"seed\": seed, \"grader_score\": gs,\n",
593
+ " \"total_reward\": sum(rewards[i]), \"final_energy\": obss[i].creator_energy,\n",
594
+ " \"rewards\": rewards[i], \"returns\": rets, \"energies\": energies[i],\n",
595
+ " \"pairs\": pairs[i], \"follower_delta\": obss[i].follower_count - 10000,\n",
596
+ " \"burned_out\": obss[i].creator_energy <= 0,\n",
597
+ " })\n",
598
+ " return results\n",
599
  "\n",
600
  "\n",
601
  "def run_llm_episode(mdl, tok, task, seed=42, verbose=False):\n",
602
+ " return run_llm_episodes_batched(mdl, tok, [(task, seed)], verbose=verbose)[0]\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
603
  "\n",
604
+ "\n",
605
+ "print(\"LLM agent functions defined (batched).\")"
606
+ ],
607
+ "execution_count": null,
608
+ "outputs": []
609
  },
610
  {
611
  "cell_type": "markdown",
 
618
  },
619
  {
620
  "cell_type": "code",
 
621
  "metadata": {},
 
622
  "source": [
623
+ "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
624
+ "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
625
  "print(\"=\" * 60)\n",
626
  "\n",
627
+ "t0 = time.time()\n",
628
+ "results = run_llm_episodes_batched(model, tokenizer, [(t, 42) for t in TASKS], verbose=True)\n",
629
+ "before_results = {r[\"task\"]: r for r in results}\n",
 
 
 
630
  "\n",
631
  "print(\"\\n\" + \"=\" * 60)\n",
632
+ "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
633
  "for t in TASKS:\n",
634
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
635
+ ],
636
+ "execution_count": null,
637
+ "outputs": []
638
  },
639
  {
640
  "cell_type": "markdown",
 
653
  },
654
  {
655
  "cell_type": "code",
 
656
  "metadata": {},
 
657
  "source": [
658
  "# Cell 10: Attach LoRA adapter\n",
659
  "from peft import LoraConfig, get_peft_model, TaskType\n",
 
668
  "model.enable_input_require_grads()\n",
669
  "peft_model = get_peft_model(model, lora_config)\n",
670
  "peft_model.print_trainable_parameters()"
671
+ ],
672
+ "execution_count": null,
673
+ "outputs": []
674
  },
675
  {
676
  "cell_type": "code",
 
677
  "metadata": {},
 
678
  "source": [
679
  "# Cell 11: Training loop\n",
680
  "from trl import SFTTrainer, SFTConfig\n",
 
697
  " print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n",
698
  " print(f\"{'=' * 60}\")\n",
699
  "\n",
 
700
  " peft_model.eval()\n",
701
+ " tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n",
702
+ " t_roll = time.time()\n",
703
+ " results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False)\n",
704
+ " print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
705
  "\n",
706
+ " all_pairs, episode_rewards, episode_graders = [], [], []\n",
707
+ " for ep, result in enumerate(results):\n",
 
 
708
  " ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
709
  " episode_rewards.append(ep_reward)\n",
710
  " episode_graders.append(result[\"grader_score\"])\n",
711
+ " kept = 0\n",
712
  " for pr in result[\"pairs\"]:\n",
713
+ " if not is_well_formed_response(pr[\"response\"]):\n",
714
+ " continue\n",
715
  " text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
716
  " f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
717
  " f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
718
+ " all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
719
+ " kept += 1\n",
720
+ " print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {result['task'].split('_')[-1]:>11s} \"\n",
721
+ " f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f} kept={kept}/{len(result['pairs'])}\")\n",
722
+ "\n",
723
+ " avg_r = float(np.mean(episode_rewards))\n",
724
+ " avg_g = float(np.mean(episode_graders))\n",
725
+ " print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f} | total pairs={len(all_pairs)}\")\n",
726
+ " if not all_pairs:\n",
727
+ " print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
728
+ " continue\n",
729
  "\n",
 
730
  " threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
731
  " filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
732
+ " print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples (return >= {threshold:.3f})\")\n",
733
  "\n",
734
  " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
735
  "\n",
 
737
  " sft_config = SFTConfig(\n",
738
  " output_dir=f\"./checkpoints/round_{round_idx}\",\n",
739
  " num_train_epochs=2,\n",
740
+ " per_device_train_batch_size=4,\n",
741
+ " gradient_accumulation_steps=2,\n",
742
  " learning_rate=2e-5,\n",
743
  " warmup_steps=5,\n",
744
  " logging_steps=5,\n",
745
  " save_strategy=\"no\",\n",
746
+ " max_length=4096,\n",
747
+ " bf16=True,\n",
748
  " report_to=\"none\",\n",
749
  " )\n",
750
  "\n",
 
769
  "elapsed = time.time() - t_start\n",
770
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
771
  "print(pd.DataFrame(training_log).to_string(index=False))"
772
+ ],
773
+ "execution_count": null,
774
+ "outputs": []
775
  },
776
  {
777
  "cell_type": "markdown",
 
784
  },
785
  {
786
  "cell_type": "code",
 
787
  "metadata": {},
 
788
  "source": [
789
+ "# Cell 12: Run trained model (batched)\n",
790
+ "print(\"Running TRAINED model on all tasks (batched)...\")\n",
791
  "print(\"=\" * 60)\n",
792
  "\n",
793
  "peft_model.eval()\n",
794
+ "t0 = time.time()\n",
795
+ "results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True)\n",
796
+ "after_results = {r[\"task\"]: r for r in results}\n",
 
 
 
797
  "\n",
798
  "print(\"\\n\" + \"=\" * 60)\n",
799
+ "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
800
  "for t in TASKS:\n",
801
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
802
+ ],
803
+ "execution_count": null,
804
+ "outputs": []
805
  },
806
  {
807
  "cell_type": "markdown",
 
812
  },
813
  {
814
  "cell_type": "code",
 
815
  "metadata": {},
 
816
  "source": [
817
  "# Cell 13: Training curves\n",
818
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
 
834
  "fig.tight_layout()\n",
835
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
836
  "plt.show()"
837
+ ],
838
+ "execution_count": null,
839
+ "outputs": []
840
  },
841
  {
842
  "cell_type": "code",
 
843
  "metadata": {},
 
844
  "source": [
845
  "# Cell 14: Before vs After\n",
846
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
 
870
  "fig.tight_layout()\n",
871
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
872
  "plt.show()"
873
+ ],
874
+ "execution_count": null,
875
+ "outputs": []
876
  },
877
  {
878
  "cell_type": "code",
 
879
  "metadata": {},
 
880
  "source": [
881
  "# Cell 15: Trajectory comparison\n",
882
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
 
900
  "fig.tight_layout()\n",
901
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
902
  "plt.show()"
903
+ ],
904
+ "execution_count": null,
905
+ "outputs": []
906
  },
907
  {
908
  "cell_type": "markdown",
 
913
  },
914
  {
915
  "cell_type": "code",
 
916
  "metadata": {},
 
917
  "source": [
918
  "# Cell 16: Final summary\n",
919
  "print(\"=\" * 67)\n",
 
950
  "\n",
951
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
952
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
953
+ ],
954
+ "execution_count": null,
955
+ "outputs": []
956
  },
957
  {
958
  "cell_type": "code",
 
959
  "metadata": {},
 
960
  "source": [
961
  "# Cell 17: Save adapter\n",
962
  "save_path = \"./viraltest_trained_adapter\"\n",
 
964
  "tokenizer.save_pretrained(save_path)\n",
965
  "print(f\"LoRA adapter saved to {save_path}\")\n",
966
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
967
+ ],
968
+ "execution_count": null,
969
+ "outputs": []
970
  }
971
  ],
972
  "metadata": {
 
992
  },
993
  "nbformat": 4,
994
  "nbformat_minor": 4
995
+ }