vaibhav12332112312 committed on
Commit
3326716
·
1 Parent(s): e955a2d

train(grpo): unified hint prompt, no-history chat, positive-advantage filter

Browse files

- env: surface audience_active_hours + competitor_recent_post_hours in obs metadata
- prompt: single audience-hours hint, same for train + eval (clean delta = LoRA only)
- runner: drop assistant history (kills 4712-tok bloat); never append synthetic rest into training pairs; carry step idx for return back-up
- decode: greedy at eval, sampled (T=1.0, top_p=0.95) at rollout
- filter: positive group-relative advantage only; QUALITY_FLOOR=0.40 skips bad rounds
- LoRA: r=8 attn-only; lr 5e-6, 1 epoch, max_len 2048 (less drift)

Made-with: Cursor

server/viraltest_environment.py CHANGED
@@ -1097,6 +1097,19 @@ class ViraltestEnvironment(Environment):
1097
  if grader_score is not None:
1098
  meta["grader_score"] = round(grader_score, 4)
1099
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1100
  burnout_risk = min(1.0, self._low_energy_days / 5.0)
1101
 
1102
  return ViraltestObservation(
 
1097
  if grader_score is not None:
1098
  meta["grader_score"] = round(grader_score, 4)
1099
 
1100
+ audience_hours: set = set()
1101
+ for seg in _AUDIENCE_DATA.get("segments", []):
1102
+ audience_hours.update(seg.get("active_hours", []))
1103
+ meta["audience_active_hours"] = sorted(audience_hours)
1104
+
1105
+ comp_hours = [
1106
+ (self._hour - p["hours_ago"]) % 24
1107
+ for comp in self._competitors
1108
+ for p in comp.recent_posts
1109
+ if p["hours_ago"] < 48
1110
+ ]
1111
+ meta["competitor_recent_post_hours"] = sorted(comp_hours)
1112
+
1113
  burnout_risk = min(1.0, self._low_energy_days / 5.0)
1114
 
1115
  return ViraltestObservation(
training/train_grpo.ipynb CHANGED
@@ -400,7 +400,7 @@
400
  "metadata": {},
401
  "source": [
402
  "# Cell 8: LLM agent functions\n",
403
- "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
404
  "You are an Instagram content strategy agent. Each step is one day.\n",
405
  "You manage a creator account over a 15-day cycle.\n",
406
  "\n",
@@ -439,6 +439,12 @@
439
  "- topic: free-form string\n",
440
  "- empty scheduled_actions = full day rest\"\"\")\n",
441
  "\n",
 
 
 
 
 
 
442
  "\n",
443
  "def format_obs(obs):\n",
444
  " days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
@@ -449,6 +455,9 @@
449
  " signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
450
  " f\"sends={signals.sends_per_reach:.3f} \"\n",
451
  " f\"saves={signals.saves:.3f}\\n\")\n",
 
 
 
452
  " tool_str = \"\"\n",
453
  " for tr in getattr(obs, \"tool_results\", []):\n",
454
  " if tr.success:\n",
@@ -459,8 +468,10 @@
459
  " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
460
  " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
461
  " f\"{signals_str}\"\n",
 
 
462
  " f\"Tool results:\\n{tool_str}\"\n",
463
- " f\"Plan your actions (JSON only):\")\n",
464
  "\n",
465
  "\n",
466
  "def is_well_formed_response(text):\n",
@@ -527,35 +538,37 @@
527
  " return torch.device(\"cpu\")\n",
528
  "\n",
529
  "\n",
530
- "def _build_chat(history, prompt):\n",
531
- " msgs = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
532
- " msgs.extend(history[-14:])\n",
533
- " msgs.append({\"role\": \"user\", \"content\": prompt})\n",
534
- " return msgs\n",
535
  "\n",
536
  "\n",
537
- "def _batched_generate(mdl, tok, prompts, temperature=0.7, max_new_tokens=512):\n",
538
  " enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
 
 
 
 
 
539
  " with torch.no_grad():\n",
540
- " out = mdl.generate(\n",
541
- " **enc, max_new_tokens=max_new_tokens, temperature=temperature,\n",
542
- " do_sample=True, top_p=0.9, pad_token_id=tok.pad_token_id,\n",
543
- " )\n",
544
  " resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
545
  " return resps, enc[\"input_ids\"].shape[1]\n",
546
  "\n",
547
  "\n",
548
- "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True):\n",
549
  " \"\"\"Run N episodes in parallel. tasks_seeds: list of (task, seed). One batched generate per day.\"\"\"\n",
 
550
  " n = len(tasks_seeds)\n",
551
  " envs = [ViraltestEnvironment() for _ in range(n)]\n",
552
  " obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
553
- " histories = [[] for _ in range(n)]\n",
554
  " rewards = [[] for _ in range(n)]\n",
555
  " energies = [[obs.creator_energy] for obs in obss]\n",
556
  " pairs = [[] for _ in range(n)]\n",
557
  " done_mask = [obs.done for obs in obss]\n",
558
- " rest_resp = '{\"scheduled_actions\": []}'\n",
559
  "\n",
560
  " for day in range(1, TASK_HORIZON + 1):\n",
561
  " active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
@@ -563,33 +576,26 @@
563
  " if not active and not rest:\n",
564
  " break\n",
565
  "\n",
566
- " resps_by_idx = {}\n",
567
  " if active:\n",
568
  " prompts = [format_obs(obss[i]) for i in active]\n",
569
- " chats = [_build_chat(histories[i], p) for i, p in zip(active, prompts)]\n",
570
  " texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
571
- " resps, ptok = _batched_generate(mdl, tok, texts)\n",
572
  " if verbose:\n",
573
  " print(f\" D{day:2d}: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
574
  " for j, i in enumerate(active):\n",
575
- " resps_by_idx[i] = (resps[j], prompts[j])\n",
576
- " for i in rest:\n",
577
- " resps_by_idx[i] = (rest_resp, format_obs(obss[i]))\n",
578
  "\n",
579
  " for i in range(n):\n",
580
- " if done_mask[i] or i not in resps_by_idx:\n",
581
  " continue\n",
582
- " resp, prompt = resps_by_idx[i]\n",
583
- " action = parse_model_output(resp)\n",
584
- " pairs[i].append({\"prompt\": prompt, \"response\": resp})\n",
585
- " obss[i] = envs[i].step(action)\n",
586
  " r = obss[i].reward or 0.0\n",
587
  " rewards[i].append(r)\n",
588
  " energies[i].append(obss[i].creator_energy)\n",
589
- " histories[i].extend([\n",
590
- " {\"role\": \"user\", \"content\": prompt},\n",
591
- " {\"role\": \"assistant\", \"content\": resp},\n",
592
- " ])\n",
593
  " if obss[i].done:\n",
594
  " done_mask[i] = True\n",
595
  "\n",
@@ -602,8 +608,9 @@
602
  " for t in reversed(range(len(rewards[i]))):\n",
603
  " G = rewards[i][t] + GAMMA * G\n",
604
  " rets[t] = G\n",
605
- " for k, pr in enumerate(pairs[i]):\n",
606
- " pr[\"return\"] = rets[k] if k < len(rets) else 0.0\n",
 
607
  " results.append({\n",
608
  " \"task\": task, \"seed\": seed, \"grader_score\": gs,\n",
609
  " \"total_reward\": sum(rewards[i]), \"final_energy\": obss[i].creator_energy,\n",
@@ -641,7 +648,7 @@
641
  "print(\"=\" * 60)\n",
642
  "\n",
643
  "t0 = time.time()\n",
644
- "results = run_llm_episodes_batched(model, tokenizer, [(t, 42) for t in TASKS], verbose=True)\n",
645
  "before_results = {r[\"task\"]: r for r in results}\n",
646
  "\n",
647
  "print(\"\\n\" + \"=\" * 60)\n",
@@ -675,9 +682,8 @@
675
  "from peft import LoraConfig, get_peft_model, TaskType\n",
676
  "\n",
677
  "lora_config = LoraConfig(\n",
678
- " r=16, lora_alpha=32, lora_dropout=0.05,\n",
679
- " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
680
- " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
681
  " task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
682
  ")\n",
683
  "\n",
@@ -698,7 +704,7 @@
698
  "\n",
699
  "NUM_ROUNDS = 4\n",
700
  "EPISODES_PER_ROUND = 6\n",
701
- "TOP_K_FRACTION = 0.5\n",
702
  "\n",
703
  "training_log = {\n",
704
  " \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n",
@@ -716,7 +722,8 @@
716
  " peft_model.eval()\n",
717
  " tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n",
718
  " t_roll = time.time()\n",
719
- " results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False)\n",
 
720
  " print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
721
  "\n",
722
  " all_pairs, episode_rewards, episode_graders = [], [], []\n",
@@ -728,7 +735,7 @@
728
  " for pr in result[\"pairs\"]:\n",
729
  " if not is_well_formed_response(pr[\"response\"]):\n",
730
  " continue\n",
731
- " text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
732
  " f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
733
  " f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
734
  " all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
@@ -738,28 +745,36 @@
738
  "\n",
739
  " avg_r = float(np.mean(episode_rewards))\n",
740
  " avg_g = float(np.mean(episode_graders))\n",
741
- " print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f} | total pairs={len(all_pairs)}\")\n",
 
742
  " if not all_pairs:\n",
743
  " print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
744
  " continue\n",
 
 
 
745
  "\n",
746
- " threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
747
- " filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
748
- " print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples (return >= {threshold:.3f})\")\n",
 
 
 
 
749
  "\n",
750
  " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
751
  "\n",
752
  " # SFT training (real gradient updates)\n",
753
  " sft_config = SFTConfig(\n",
754
  " output_dir=f\"./checkpoints/round_{round_idx}\",\n",
755
- " num_train_epochs=2,\n",
756
- " per_device_train_batch_size=4,\n",
757
- " gradient_accumulation_steps=2,\n",
758
- " learning_rate=2e-5,\n",
759
- " warmup_ratio=0.1,\n",
760
  " logging_steps=1,\n",
761
  " save_strategy=\"no\",\n",
762
- " max_length=4096,\n",
763
  " bf16=True,\n",
764
  " report_to=\"none\",\n",
765
  " )\n",
@@ -808,7 +823,7 @@
808
  "\n",
809
  "peft_model.eval()\n",
810
  "t0 = time.time()\n",
811
- "results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True)\n",
812
  "after_results = {r[\"task\"]: r for r in results}\n",
813
  "\n",
814
  "print(\"\\n\" + \"=\" * 60)\n",
 
400
  "metadata": {},
401
  "source": [
402
  "# Cell 8: LLM agent functions\n",
403
+ "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
404
  "You are an Instagram content strategy agent. Each step is one day.\n",
405
  "You manage a creator account over a 15-day cycle.\n",
406
  "\n",
 
439
  "- topic: free-form string\n",
440
  "- empty scheduled_actions = full day rest\"\"\")\n",
441
  "\n",
442
+ "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
443
+ "\n",
444
+ "HINT: schedule posts during/just before the audience_active_hours window — that is when your target users are online.\"\"\")\n",
445
+ "SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
446
+ "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
447
+ "\n",
448
  "\n",
449
  "def format_obs(obs):\n",
450
  " days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
 
455
  " signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
456
  " f\"sends={signals.sends_per_reach:.3f} \"\n",
457
  " f\"saves={signals.saves:.3f}\\n\")\n",
458
+ " meta = getattr(obs, \"metadata\", None) or {}\n",
459
+ " aud = meta.get(\"audience_active_hours\") or []\n",
460
+ " comp = meta.get(\"competitor_recent_post_hours\") or []\n",
461
  " tool_str = \"\"\n",
462
  " for tr in getattr(obs, \"tool_results\", []):\n",
463
  " if tr.success:\n",
 
468
  " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
469
  " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
470
  " f\"{signals_str}\"\n",
471
+ " f\"audience_active_hours: {aud}\\n\"\n",
472
+ " f\"competitor_recent_post_hours: {comp}\\n\"\n",
473
  " f\"Tool results:\\n{tool_str}\"\n",
474
+ " f\"Plan today's actions (JSON only):\")\n",
475
  "\n",
476
  "\n",
477
  "def is_well_formed_response(text):\n",
 
538
  " return torch.device(\"cpu\")\n",
539
  "\n",
540
  "\n",
541
+ "def _build_chat(system, prompt):\n",
542
+ " return [\n",
543
+ " {\"role\": \"system\", \"content\": system},\n",
544
+ " {\"role\": \"user\", \"content\": prompt},\n",
545
+ " ]\n",
546
  "\n",
547
  "\n",
548
+ "def _batched_generate(mdl, tok, prompts, eval=False, max_new_tokens=512):\n",
549
  " enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
550
+ " gen_kwargs = dict(max_new_tokens=max_new_tokens, pad_token_id=tok.pad_token_id)\n",
551
+ " if eval:\n",
552
+ " gen_kwargs.update(do_sample=False)\n",
553
+ " else:\n",
554
+ " gen_kwargs.update(do_sample=True, temperature=1.0, top_p=0.95)\n",
555
  " with torch.no_grad():\n",
556
+ " out = mdl.generate(**enc, **gen_kwargs)\n",
 
 
 
557
  " resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
558
  " return resps, enc[\"input_ids\"].shape[1]\n",
559
  "\n",
560
  "\n",
561
+ "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None):\n",
562
  " \"\"\"Run N episodes in parallel. tasks_seeds: list of (task, seed). One batched generate per day.\"\"\"\n",
563
+ " sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
564
  " n = len(tasks_seeds)\n",
565
  " envs = [ViraltestEnvironment() for _ in range(n)]\n",
566
  " obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
 
567
  " rewards = [[] for _ in range(n)]\n",
568
  " energies = [[obs.creator_energy] for obs in obss]\n",
569
  " pairs = [[] for _ in range(n)]\n",
570
  " done_mask = [obs.done for obs in obss]\n",
571
+ " rest_action = ViraltestAction(scheduled_actions=[])\n",
572
  "\n",
573
  " for day in range(1, TASK_HORIZON + 1):\n",
574
  " active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
 
576
  " if not active and not rest:\n",
577
  " break\n",
578
  "\n",
579
+ " actions_by_idx = {i: rest_action for i in rest}\n",
580
  " if active:\n",
581
  " prompts = [format_obs(obss[i]) for i in active]\n",
582
+ " chats = [_build_chat(sys_prompt, p) for p in prompts]\n",
583
  " texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
584
+ " resps, ptok = _batched_generate(mdl, tok, texts, eval=eval)\n",
585
  " if verbose:\n",
586
  " print(f\" D{day:2d}: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
587
  " for j, i in enumerate(active):\n",
588
+ " actions_by_idx[i] = parse_model_output(resps[j])\n",
589
+ " pairs[i].append({\"prompt\": prompts[j], \"response\": resps[j],\n",
590
+ " \"step\": len(rewards[i])})\n",
591
  "\n",
592
  " for i in range(n):\n",
593
+ " if done_mask[i] or i not in actions_by_idx:\n",
594
  " continue\n",
595
+ " obss[i] = envs[i].step(actions_by_idx[i])\n",
 
 
 
596
  " r = obss[i].reward or 0.0\n",
597
  " rewards[i].append(r)\n",
598
  " energies[i].append(obss[i].creator_energy)\n",
 
 
 
 
599
  " if obss[i].done:\n",
600
  " done_mask[i] = True\n",
601
  "\n",
 
608
  " for t in reversed(range(len(rewards[i]))):\n",
609
  " G = rewards[i][t] + GAMMA * G\n",
610
  " rets[t] = G\n",
611
+ " for pr in pairs[i]:\n",
612
+ " k = pr.get(\"step\", 0)\n",
613
+ " pr[\"return\"] = rets[k] if 0 <= k < len(rets) else 0.0\n",
614
  " results.append({\n",
615
  " \"task\": task, \"seed\": seed, \"grader_score\": gs,\n",
616
  " \"total_reward\": sum(rewards[i]), \"final_energy\": obss[i].creator_energy,\n",
 
648
  "print(\"=\" * 60)\n",
649
  "\n",
650
  "t0 = time.time()\n",
651
+ "results = run_llm_episodes_batched(model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True)\n",
652
  "before_results = {r[\"task\"]: r for r in results}\n",
653
  "\n",
654
  "print(\"\\n\" + \"=\" * 60)\n",
 
682
  "from peft import LoraConfig, get_peft_model, TaskType\n",
683
  "\n",
684
  "lora_config = LoraConfig(\n",
685
+ " r=8, lora_alpha=16, lora_dropout=0.05,\n",
686
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
 
687
  " task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
688
  ")\n",
689
  "\n",
 
704
  "\n",
705
  "NUM_ROUNDS = 4\n",
706
  "EPISODES_PER_ROUND = 6\n",
707
+ "QUALITY_FLOOR = 0.40 # skip SFT for the round if no episode beats this grader score\n",
708
  "\n",
709
  "training_log = {\n",
710
  " \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n",
 
722
  " peft_model.eval()\n",
723
  " tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n",
724
  " t_roll = time.time()\n",
725
+ " results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False,\n",
726
+ " eval=False, system=SYSTEM_PROMPT_TRAIN)\n",
727
  " print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
728
  "\n",
729
  " all_pairs, episode_rewards, episode_graders = [], [], []\n",
 
735
  " for pr in result[\"pairs\"]:\n",
736
  " if not is_well_formed_response(pr[\"response\"]):\n",
737
  " continue\n",
738
+ " text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT_TRAIN}<|im_end|>\\n\"\n",
739
  " f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
740
  " f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
741
  " all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
 
745
  "\n",
746
  " avg_r = float(np.mean(episode_rewards))\n",
747
  " avg_g = float(np.mean(episode_graders))\n",
748
+ " max_g = float(max(episode_graders))\n",
749
+ " print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f} max_grader={max_g:.4f} | pairs={len(all_pairs)}\")\n",
750
  " if not all_pairs:\n",
751
  " print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
752
  " continue\n",
753
+ " if max_g < QUALITY_FLOOR:\n",
754
+ " print(f\" SKIP SFT: no episode beat quality_floor={QUALITY_FLOOR:.2f}\")\n",
755
+ " continue\n",
756
  "\n",
757
+ " rets = np.array([p[\"reward\"] for p in all_pairs], dtype=float)\n",
758
+ " adv = (rets - rets.mean()) / (rets.std() + 1e-6)\n",
759
+ " filtered = [p for p, a in zip(all_pairs, adv) if a > 0.0]\n",
760
+ " if not filtered:\n",
761
+ " print(\" SKIP SFT: zero positive-advantage samples\")\n",
762
+ " continue\n",
763
+ " print(f\" Kept {len(filtered)}/{len(all_pairs)} positive-advantage samples\")\n",
764
  "\n",
765
  " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
766
  "\n",
767
  " # SFT training (real gradient updates)\n",
768
  " sft_config = SFTConfig(\n",
769
  " output_dir=f\"./checkpoints/round_{round_idx}\",\n",
770
+ " num_train_epochs=1,\n",
771
+ " per_device_train_batch_size=2,\n",
772
+ " gradient_accumulation_steps=4,\n",
773
+ " learning_rate=5e-6,\n",
774
+ " warmup_steps=5,\n",
775
  " logging_steps=1,\n",
776
  " save_strategy=\"no\",\n",
777
+ " max_length=2048,\n",
778
  " bf16=True,\n",
779
  " report_to=\"none\",\n",
780
  " )\n",
 
823
  "\n",
824
  "peft_model.eval()\n",
825
  "t0 = time.time()\n",
826
+ "results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True)\n",
827
  "after_results = {r[\"task\"]: r for r in results}\n",
828
  "\n",
829
  "print(\"\\n\" + \"=\" * 60)\n",