anuragredbus commited on
Commit
0587f05
·
1 Parent(s): 0e50d91

chore: align train_grpo.ipynb with smoke/syntax patterns for Colab

Browse files

- ast.parse sanity after TASK_HORIZON assert (like syntax_only)
- Replace bare except in parse_model_output
- Use _infer_model_device for token inputs (Peft/4-bit safe)
- Doc note: run syntax_only + smoke first; quoted pip already

Made-with: Cursor

Files changed (1) hide show
  1. training/train_grpo.ipynb +59 -11
training/train_grpo.ipynb CHANGED
@@ -18,7 +18,9 @@
18
  "\n",
19
  "**Requirements:** Colab T4 GPU (free tier), ~45 min total.\n",
20
  "\n",
21
- "**What makes this real training:** LoRA adapter weights are actually updated via gradient descent. The model's behavior changes because its weights change, not because we edit the prompt."
 
 
22
  ]
23
  },
24
  {
@@ -40,7 +42,9 @@
40
  "\n",
41
  "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
42
  "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
43
- "zsh:1: 4.45.0 not found\n",
 
 
44
  "\n",
45
  "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
46
  "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
@@ -142,7 +146,7 @@
142
  "Repo root: /Users/anurag.c/viral-posts-env\n",
143
  "Working dir: /Users/anurag.c/viral-posts-env\n",
144
  "Branch: hack1\n",
145
- "Commit: b2fc6b6\n",
146
  "Plots dir: /Users/anurag.c/viral-posts-env/plots\n"
147
  ]
148
  }
@@ -198,7 +202,12 @@
198
  "assert TASK_HORIZON == 30, (\n",
199
  " f\"Expected TASK_HORIZON=30, got {TASK_HORIZON}. \"\n",
200
  " \"Restart runtime and run from Cell 1 again (clean clone on hack1).\"\n",
201
- ")"
 
 
 
 
 
202
  ],
203
  "execution_count": 3,
204
  "outputs": [
@@ -506,7 +515,7 @@
506
  "if torch.cuda.is_available():\n",
507
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
508
  ],
509
- "execution_count": null,
510
  "outputs": [
511
  {
512
  "output_type": "stream",
@@ -515,6 +524,28 @@
515
  " On Colab: run `pip install -U bitsandbytes>=0.46.1` and use a GPU runtime.\n",
516
  " On Mac: use fp16 on MPS or fp32 on CPU.\n"
517
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  }
519
  ]
520
  },
@@ -582,21 +613,38 @@
582
  " for tc in data.get(\"tool_calls\", []) if isinstance(tc, dict) and \"name\" in tc]\n",
583
  " scheduled = []\n",
584
  " for a in data.get(\"scheduled_actions\", []):\n",
585
- " try: scheduled.append(ScheduledAction(**a))\n",
586
- " except: pass\n",
587
- " return ViraltestAction(tool_calls=tool_calls, scheduled_actions=scheduled,\n",
588
- " replies=data.get(\"replies\", []), notes=data.get(\"notes\"))\n",
589
- " except:\n",
 
 
 
 
 
 
590
  " return ViraltestAction(scheduled_actions=[])\n",
591
  "\n",
592
  "\n",
 
 
 
 
 
 
 
 
 
 
 
593
  "def generate_action(mdl, tok, obs, history, temperature=0.7):\n",
594
  " prompt = format_obs(obs)\n",
595
  " messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
596
  " messages.extend(history[-4:])\n",
597
  " messages.append({\"role\": \"user\", \"content\": prompt})\n",
598
  " text_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
599
- " inputs = tok(text_input, return_tensors=\"pt\").to(mdl.device)\n",
600
  " with torch.no_grad():\n",
601
  " out = mdl.generate(**inputs, max_new_tokens=512, temperature=temperature,\n",
602
  " do_sample=True, top_p=0.9, pad_token_id=tok.eos_token_id)\n",
 
18
  "\n",
19
  "**Requirements:** Colab T4 GPU (free tier), ~45 min total.\n",
20
  "\n",
21
+ "**What makes this real training:** LoRA adapter weights are actually updated via gradient descent. The model's behavior changes because its weights change, not because we edit the prompt.\n",
22
+ "\n",
23
+ "**Before this notebook:** run `training/syntax_only.ipynb` (kernel + syntax only) and `training/train_grpo_smoke.ipynb` (repo + env). Pip lines use quoted package specs so Colab/zsh does not break on `>=`."
24
  ]
25
  },
26
  {
 
42
  "\n",
43
  "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
44
  "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
45
+ "\n",
46
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
47
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
48
  "\n",
49
  "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
50
  "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
 
146
  "Repo root: /Users/anurag.c/viral-posts-env\n",
147
  "Working dir: /Users/anurag.c/viral-posts-env\n",
148
  "Branch: hack1\n",
149
+ "Commit: aedc9c7\n",
150
  "Plots dir: /Users/anurag.c/viral-posts-env/plots\n"
151
  ]
152
  }
 
202
  "assert TASK_HORIZON == 30, (\n",
203
  " f\"Expected TASK_HORIZON=30, got {TASK_HORIZON}. \"\n",
204
  " \"Restart runtime and run from Cell 1 again (clean clone on hack1).\"\n",
205
+ ")\n",
206
+ "\n",
207
+ "# Same sanity as syntax_only.ipynb (kernel parses modern Python)\n",
208
+ "import ast\n",
209
+ "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
210
+ "print(\"OK: ast.parse (syntax check)\")"
211
  ],
212
  "execution_count": 3,
213
  "outputs": [
 
515
  "if torch.cuda.is_available():\n",
516
  " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
517
  ],
518
+ "execution_count": 7,
519
  "outputs": [
520
  {
521
  "output_type": "stream",
 
524
  " On Colab: run `pip install -U bitsandbytes>=0.46.1` and use a GPU runtime.\n",
525
  " On Mac: use fp16 on MPS or fp32 on CPU.\n"
526
  ]
527
+ },
528
+ {
529
+ "output_type": "error",
530
+ "ename": "KeyboardInterrupt",
531
+ "evalue": "",
532
+ "traceback": [
533
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
534
+ "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
535
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 44\u001b[39m\n\u001b[32m 40\u001b[39m \u001b[33m\" On Colab: run `pip install -U bitsandbytes>=0.46.1` and use a GPU runtime.\\n\"\u001b[39m\n\u001b[32m 41\u001b[39m \u001b[33m\" On Mac: use fp16 on MPS or fp32 on CPU.\"\u001b[39m\n\u001b[32m 42\u001b[39m )\n\u001b[32m 43\u001b[39m dtype = torch.float16 \u001b[38;5;28;01mif\u001b[39;00m (torch.cuda.is_available() \u001b[38;5;28;01mor\u001b[39;00m getattr(torch.backends, \u001b[33m\"mps\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mand\u001b[39;00m torch.backends.mps.is_available()) \u001b[38;5;28;01melse\u001b[39;00m torch.float32\n\u001b[32m---> \u001b[39m\u001b[32m44\u001b[39m model = AutoModelForCausalLM.from_pretrained(\n\u001b[32m 45\u001b[39m MODEL_NAME,\n\u001b[32m 46\u001b[39m trust_remote_code=\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[32m 47\u001b[39m dtype=dtype,\n",
536
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/transformers/models/auto/auto_factory.py:394\u001b[39m, in \u001b[36m_BaseAutoModelClass.from_pretrained\u001b[39m\u001b[34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[39m\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(parent_config, \u001b[33m\"\u001b[39m\u001b[33mquantization_config\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m 393\u001b[39m config.quantization_config = parent_config.quantization_config\n\u001b[32m--> \u001b[39m\u001b[32m394\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43mmodel_class\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mfrom_pretrained\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 395\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mpretrained_model_name_or_path\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mmodel_args\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mconfig\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mconfig\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mhub_kwargs\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\n\u001b[32m 396\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 397\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 398\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mUnrecognized configuration class \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig.\u001b[34m__class__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m for this kind of AutoModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m.\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mModel type should be one of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m, \u001b[39m\u001b[33m'\u001b[39m.join(c.\u001b[34m__name__\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mc\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28mcls\u001b[39m._model_mapping)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 400\u001b[39m )\n",
537
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/transformers/modeling_utils.py:4118\u001b[39m, in \u001b[36mPreTrainedModel.from_pretrained\u001b[39m\u001b[34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, fusion_config, disable_mmap, *model_args, **kwargs)\u001b[39m\n\u001b[32m 4113\u001b[39m logger.warning_once(\n\u001b[32m 4114\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mA kernel_config was provided but use_kernels is False; setting use_kernels=True automatically. To suppress this warning, explicitly set use_kernels to True.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 4115\u001b[39m )\n\u001b[32m 4116\u001b[39m use_kernels = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m4118\u001b[39m checkpoint_files, sharded_metadata = \u001b[30;43m_get_resolved_checkpoint_files\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 4119\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mpretrained_model_name_or_path\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mpretrained_model_name_or_path\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 4120\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mvariant\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mvariant\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 4121\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mgguf_file\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mgguf_file\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 4122\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43muse_safetensors\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43muse_safetensors\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 4123\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdownload_kwargs\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mdownload_kwargs_with_commit\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 4124\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43muser_agent\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43muser_agent\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 4125\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mis_remote_code\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mcls\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mis_remote_code\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 4126\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtransformers_explicit_filename\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mgetattr\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mconfig\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m\"\u001b[39;49m\u001b[30;43mtransformers_weights\u001b[39;49m\u001b[30;43m\"\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43;01mNone\u001b[39;49;00m\u001b[30;43m)\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 4127\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtqdm_class\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtqdm_class\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 4128\u001b[39m \u001b[30;43m\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 4130\u001b[39m is_quantized = hf_quantizer \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 4132\u001b[39m \u001b[38;5;66;03m# Find the correct dtype based on current state\u001b[39;00m\n",
538
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/transformers/modeling_utils.py:660\u001b[39m, in \u001b[36m_get_resolved_checkpoint_files\u001b[39m\u001b[34m(pretrained_model_name_or_path, variant, gguf_file, use_safetensors, user_agent, is_remote_code, transformers_explicit_filename, download_kwargs, tqdm_class)\u001b[39m\n\u001b[32m 648\u001b[39m can_auto_convert = (\n\u001b[32m 649\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m is_offline_mode() \u001b[38;5;66;03m# for obvious reasons\u001b[39;00m\n\u001b[32m 650\u001b[39m \u001b[38;5;66;03m# If we are in a CI environment or in a pytest run, we prevent the conversion\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 653\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m subfolder == \u001b[33m\"\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;66;03m# converter bot does not work on subfolders\u001b[39;00m\n\u001b[32m 654\u001b[39m )\n\u001b[32m 656\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 657\u001b[39m \u001b[38;5;66;03m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[32m 658\u001b[39m \u001b[38;5;66;03m# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None\u001b[39;00m\n\u001b[32m 659\u001b[39m \u001b[38;5;66;03m# result when internet is up, the repo and revision exist, but the file does not.\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m660\u001b[39m resolved_archive_file = \u001b[30;43mcached_file\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mpretrained_model_name_or_path\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mfilename\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mcached_file_kwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 662\u001b[39m \u001b[38;5;66;03m# Try safetensors files first if not already found\u001b[39;00m\n\u001b[32m 663\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m resolved_archive_file \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m filename == _add_variant(SAFE_WEIGHTS_NAME, variant):\n\u001b[32m 664\u001b[39m \u001b[38;5;66;03m# Maybe the checkpoint is sharded, we try to grab the index name in this case.\u001b[39;00m\n",
539
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/transformers/utils/hub.py:278\u001b[39m, in \u001b[36mcached_file\u001b[39m\u001b[34m(path_or_repo_id, filename, **kwargs)\u001b[39m\n\u001b[32m 223\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcached_file\u001b[39m(\n\u001b[32m 224\u001b[39m path_or_repo_id: \u001b[38;5;28mstr\u001b[39m | os.PathLike,\n\u001b[32m 225\u001b[39m filename: \u001b[38;5;28mstr\u001b[39m,\n\u001b[32m 226\u001b[39m **kwargs,\n\u001b[32m 227\u001b[39m ) -> \u001b[38;5;28mstr\u001b[39m | \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 228\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 229\u001b[39m \u001b[33;03m Tries to locate a file in a local folder and repo, downloads and cache it if necessary.\u001b[39;00m\n\u001b[32m 230\u001b[39m \n\u001b[32m (...)\u001b[39m\u001b[32m 276\u001b[39m \u001b[33;03m ```\u001b[39;00m\n\u001b[32m 277\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m278\u001b[39m file = \u001b[30;43mcached_files\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mpath_or_repo_id\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mpath_or_repo_id\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mfilenames\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43m[\u001b[39;49m\u001b[30;43mfilename\u001b[39;49m\u001b[30;43m]\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 279\u001b[39m file = file[\u001b[32m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m file\n\u001b[32m 280\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m file\n",
540
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/transformers/utils/hub.py:422\u001b[39m, in \u001b[36mcached_files\u001b[39m\u001b[34m(path_or_repo_id, filenames, cache_dir, force_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_gated_repo, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, tqdm_class, **deprecated_kwargs)\u001b[39m\n\u001b[32m 419\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 420\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(full_filenames) == \u001b[32m1\u001b[39m:\n\u001b[32m 421\u001b[39m \u001b[38;5;66;03m# This is slightly better for only 1 file\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m422\u001b[39m \u001b[30;43mhf_hub_download\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 423\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mpath_or_repo_id\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 424\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mfilenames\u001b[39;49m\u001b[30;43m[\u001b[39;49m\u001b[30;43m0\u001b[39;49m\u001b[30;43m]\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 425\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43msubfolder\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mNone\u001b[39;49;00m\u001b[30;43m \u001b[39;49m\u001b[30;43;01mif\u001b[39;49;00m\u001b[30;43m \u001b[39;49m\u001b[30;43mlen\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43msubfolder\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m==\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m0\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43;01melse\u001b[39;49;00m\u001b[30;43m \u001b[39;49m\u001b[30;43msubfolder\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 426\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mrepo_type\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrepo_type\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 427\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mrevision\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrevision\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 428\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mcache_dir\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mcache_dir\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 429\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43muser_agent\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43muser_agent\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 430\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mforce_download\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mforce_download\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 431\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mproxies\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mproxies\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 432\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtoken\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtoken\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 433\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mlocal_files_only\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mlocal_files_only\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 434\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtqdm_class\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtqdm_class\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 435\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 436\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 437\u001b[39m snapshot_download(\n\u001b[32m 438\u001b[39m path_or_repo_id,\n\u001b[32m 439\u001b[39m allow_patterns=full_filenames,\n\u001b[32m (...)\u001b[39m\u001b[32m 448\u001b[39m tqdm_class=tqdm_class,\n\u001b[32m 449\u001b[39m )\n",
541
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/huggingface_hub/utils/_validators.py:88\u001b[39m, in \u001b[36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 84\u001b[39m validate_repo_id(arg_value)\n\u001b[32m 86\u001b[39m kwargs = smoothly_deprecate_legacy_arguments(fn_name=fn.\u001b[34m__name__\u001b[39m, kwargs=kwargs)\n\u001b[32m---> \u001b[39m\u001b[32m88\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43mfn\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43margs\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n",
542
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/huggingface_hub/file_download.py:995\u001b[39m, in \u001b[36mhf_hub_download\u001b[39m\u001b[34m(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, user_agent, force_download, etag_timeout, token, local_files_only, headers, endpoint, tqdm_class, dry_run)\u001b[39m\n\u001b[32m 974\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m _hf_hub_download_to_local_dir(\n\u001b[32m 975\u001b[39m \u001b[38;5;66;03m# Destination\u001b[39;00m\n\u001b[32m 976\u001b[39m local_dir=local_dir,\n\u001b[32m (...)\u001b[39m\u001b[32m 992\u001b[39m dry_run=dry_run,\n\u001b[32m 993\u001b[39m )\n\u001b[32m 994\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m995\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43m_hf_hub_download_to_cache_dir\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 996\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43;03m# Destination\u001b[39;49;00m\n\u001b[32m 997\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mcache_dir\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mcache_dir\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 998\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43;03m# File info\u001b[39;49;00m\n\u001b[32m 999\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mrepo_id\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrepo_id\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1000\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mfilename\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mfilename\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1001\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mrepo_type\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrepo_type\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1002\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mrevision\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrevision\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1003\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43;03m# HTTP info\u001b[39;49;00m\n\u001b[32m 1004\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mendpoint\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mendpoint\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1005\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43metag_timeout\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43metag_timeout\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1006\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mhf_headers\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1007\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtoken\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtoken\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1008\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43;03m# Additional options\u001b[39;49;00m\n\u001b[32m 1009\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mlocal_files_only\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mlocal_files_only\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1010\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mforce_download\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mforce_download\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1011\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtqdm_class\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtqdm_class\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1012\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdry_run\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mdry_run\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1013\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n",
543
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/huggingface_hub/file_download.py:1213\u001b[39m, in \u001b[36m_hf_hub_download_to_cache_dir\u001b[39m\u001b[34m(cache_dir, repo_id, filename, repo_type, revision, endpoint, etag_timeout, headers, token, local_files_only, force_download, tqdm_class, dry_run)\u001b[39m\n\u001b[32m 1209\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m pointer_path\n\u001b[32m 1211\u001b[39m \u001b[38;5;66;03m# Local file doesn't exist or etag isn't a match => retrieve file from remote (or cache)\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1213\u001b[39m \u001b[30;43m\u001b[39;49m\u001b[30;43;01mwith\u001b[39;49;00m\u001b[30;43m \u001b[39;49m\u001b[30;43mWeakFileLock\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mlock_path\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m:\u001b[39;49m\n\u001b[32m 1214\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m_download_to_tmp_and_move\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 1215\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mincomplete_path\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mPath\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mblob_path\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m+\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m\"\u001b[39;49m\u001b[30;43m.incomplete\u001b[39;49m\u001b[30;43m\"\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1216\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdestination_path\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mPath\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mblob_path\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m (...)\u001b[39m\u001b[32m 1224\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtqdm_class\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtqdm_class\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 1225\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 1226\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43;01mif\u001b[39;49;00m\u001b[30;43m \u001b[39;49m\u001b[30;43;01mnot\u001b[39;49;00m\u001b[30;43m \u001b[39;49m\u001b[30;43mos\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mpath\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mexists\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mpointer_path\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m:\u001b[39;49m\n",
544
+ "\u001b[36mFile \u001b[39m\u001b[32m/opt/homebrew/Cellar/python@3.14/3.14.2_1/Frameworks/Python.framework/Versions/3.14/lib/python3.14/contextlib.py:141\u001b[39m, in \u001b[36m_GeneratorContextManager.__enter__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 139\u001b[39m \u001b[38;5;28;01mdel\u001b[39;00m \u001b[38;5;28mself\u001b[39m.args, \u001b[38;5;28mself\u001b[39m.kwds, \u001b[38;5;28mself\u001b[39m.func\n\u001b[32m 140\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m141\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43mnext\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mgen\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 142\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[32m 143\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mgenerator didn\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt yield\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n",
545
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/huggingface_hub/utils/_fixes.py:99\u001b[39m, in \u001b[36mWeakFileLock\u001b[39m\u001b[34m(lock_file, timeout)\u001b[39m\n\u001b[32m 96\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m Timeout(\u001b[38;5;28mstr\u001b[39m(lock_file))\n\u001b[32m 98\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m99\u001b[39m \u001b[30;43mlock\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43macquire\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mmin\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mlog_interval\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m-\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43melapsed_time\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43;01mif\u001b[39;49;00m\u001b[30;43m \u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43;01melse\u001b[39;49;00m\u001b[30;43m \u001b[39;49m\u001b[30;43mlog_interval\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 100\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m Timeout:\n\u001b[32m 101\u001b[39m logger.info(\n\u001b[32m 102\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mStill waiting to acquire lock on \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlock_file\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m (elapsed: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtime.time()\u001b[38;5;250m \u001b[39m-\u001b[38;5;250m \u001b[39mstart_time\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m seconds)\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 103\u001b[39m )\n",
546
+ "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/filelock/_api.py:513\u001b[39m, in \u001b[36mBaseFileLock.acquire\u001b[39m\u001b[34m(self, timeout, poll_interval, poll_intervall, blocking, cancel_check)\u001b[39m\n\u001b[32m 511\u001b[39m msg = \u001b[33m\"\u001b[39m\u001b[33mLock \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[33m not acquired on \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[33m, waiting \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[33m seconds ...\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 512\u001b[39m _LOGGER.debug(msg, lock_id, lock_filename, poll_interval)\n\u001b[32m--> \u001b[39m\u001b[32m513\u001b[39m \u001b[30;43mtime\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msleep\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mpoll_interval\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 514\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m:\n\u001b[32m 515\u001b[39m \u001b[38;5;28mself\u001b[39m._context.lock_counter = \u001b[38;5;28mmax\u001b[39m(\u001b[32m0\u001b[39m, \u001b[38;5;28mself\u001b[39m._context.lock_counter - \u001b[32m1\u001b[39m)\n",
547
+ "\u001b[31mKeyboardInterrupt\u001b[39m: "
548
+ ]
549
  }
550
  ]
551
  },
 
613
  " for tc in data.get(\"tool_calls\", []) if isinstance(tc, dict) and \"name\" in tc]\n",
614
  " scheduled = []\n",
615
  " for a in data.get(\"scheduled_actions\", []):\n",
616
+ " try:\n",
617
+ " scheduled.append(ScheduledAction(**a))\n",
618
+ " except (TypeError, ValueError, KeyError):\n",
619
+ " pass\n",
620
+ " return ViraltestAction(\n",
621
+ " tool_calls=tool_calls,\n",
622
+ " scheduled_actions=scheduled,\n",
623
+ " replies=data.get(\"replies\", []),\n",
624
+ " notes=data.get(\"notes\"),\n",
625
+ " )\n",
626
+ " except (json.JSONDecodeError, TypeError, ValueError, KeyError):\n",
627
  " return ViraltestAction(scheduled_actions=[])\n",
628
  "\n",
629
  "\n",
630
+ "def _infer_model_device(m):\n",
631
+ " \"\"\"Works for single/multi-device models (Peft, 4-bit) where m.device may be missing.\"\"\"\n",
632
+ " p = next(m.parameters(), None)\n",
633
+ " if p is not None:\n",
634
+ " return p.device\n",
635
+ " d = getattr(m, \"device\", None)\n",
636
+ " if d is not None:\n",
637
+ " return d\n",
638
+ " return torch.device(\"cpu\")\n",
639
+ "\n",
640
+ "\n",
641
  "def generate_action(mdl, tok, obs, history, temperature=0.7):\n",
642
  " prompt = format_obs(obs)\n",
643
  " messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
644
  " messages.extend(history[-4:])\n",
645
  " messages.append({\"role\": \"user\", \"content\": prompt})\n",
646
  " text_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
647
+ " inputs = tok(text_input, return_tensors=\"pt\").to(_infer_model_device(mdl))\n",
648
  " with torch.no_grad():\n",
649
  " out = mdl.generate(**inputs, max_new_tokens=512, temperature=temperature,\n",
650
  " do_sample=True, top_p=0.9, pad_token_id=tok.eos_token_id)\n",