anuragredbus committed
Commit: b55c1ff · Parent: eb1d764

add train_grpo_smoke notebook; quote pip versions in train_grpo


- Smoke notebook: repo setup, imports, TASK_HORIZON=30, one episode, optional ML imports
- Fix zsh redirect bug from unquoted transformers>= in pip cell

Made-with: Cursor
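
For context, the zsh bug the second bullet fixes, sketched with illustrative commands (not lines from this diff): left unquoted, zsh parses `>` as an output redirect, so the version constraint never reaches pip.

```zsh
# Unquoted: the word splits at `>`; pip sees only `transformers` (unpinned)
# and the shell treats `=4.45.0` as a redirect target.
pip install -q transformers>=4.45.0

# Quoted: the full requirement string reaches pip intact.
pip install -q "transformers>=4.45.0"
```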

training/train_grpo.ipynb CHANGED
@@ -25,9 +25,9 @@
    "cell_type": "code",
    "metadata": {},
    "source": [
-    "# Cell 1: Install dependencies\n",
+    "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
     "!pip install -q torch torchvision torchaudio\n",
-    "!pip install -q transformers>=4.45.0 accelerate peft>=0.10.0 trl>=0.20.0 datasets bitsandbytes\n",
+    "!pip install -q \"transformers>=4.45.0\" \"accelerate\" \"peft>=0.10.0\" \"trl>=0.20.0\" \"datasets\" \"bitsandbytes\"\n",
     "!pip install -q matplotlib pandas\n",
     "!pip install -q pydantic httpx\n",
     "!pip install -q \"openenv-core[core]>=0.2.2\""
@@ -142,7 +142,7 @@
      "Repo root: /Users/anurag.c/viral-posts-env\n",
      "Working dir: /Users/anurag.c/viral-posts-env\n",
      "Branch: hack1\n",
-     "Commit: b5ad200\n",
+     "Commit: b2fc6b6\n",
      "Plots dir: /Users/anurag.c/viral-posts-env/plots\n"
     ]
    }
@@ -506,27 +506,14 @@
     "if torch.cuda.is_available():\n",
     "    print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
    ],
-   "execution_count": 7,
+   "execution_count": null,
    "outputs": [
     {
      "output_type": "stream",
      "text": [
-      "Loading Qwen/Qwen2.5-1.5B-Instruct (4-bit quantized)...\n"
-     ]
-    },
-    {
-     "output_type": "error",
-     "ename": "ImportError",
-     "evalue": "Using `bitsandbytes` 4-bit quantization requires bitsandbytes: `pip install -U bitsandbytes>=0.46.1`",
-     "traceback": [
-      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
-      "\u001b[31mImportError\u001b[39m Traceback (most recent call last)",
-      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 15\u001b[39m\n\u001b[32m 11\u001b[39m )\n\u001b[32m 12\u001b[39m \n\u001b[32m 13\u001b[39m print(f\"Loading {MODEL_NAME} (4-bit quantized)...\")\n\u001b[32m 14\u001b[39m tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m---> \u001b[39m\u001b[32m15\u001b[39m model = AutoModelForCausalLM.from_pretrained(\n\u001b[32m 16\u001b[39m MODEL_NAME, trust_remote_code=\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[32m 17\u001b[39m quantization_config=bnb_config,\n\u001b[32m 18\u001b[39m device_map=\u001b[33m\"auto\"\u001b[39m,\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/transformers/models/auto/auto_factory.py:394\u001b[39m, in \u001b[36m_BaseAutoModelClass.from_pretrained\u001b[39m\u001b[34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[39m\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(parent_config, \u001b[33m\"\u001b[39m\u001b[33mquantization_config\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m 393\u001b[39m config.quantization_config = parent_config.quantization_config\n\u001b[32m--> \u001b[39m\u001b[32m394\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43mmodel_class\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mfrom_pretrained\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 395\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mpretrained_model_name_or_path\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mmodel_args\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mconfig\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mconfig\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mhub_kwargs\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\n\u001b[32m 396\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 397\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 398\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mUnrecognized configuration class \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig.\u001b[34m__class__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m for this kind of AutoModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m.\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mModel type should be one of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m, \u001b[39m\u001b[33m'\u001b[39m.join(c.\u001b[34m__name__\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mc\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28mcls\u001b[39m._model_mapping)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 400\u001b[39m )\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/transformers/modeling_utils.py:4095\u001b[39m, in \u001b[36mPreTrainedModel.from_pretrained\u001b[39m\u001b[34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, fusion_config, disable_mmap, *model_args, **kwargs)\u001b[39m\n\u001b[32m 4092\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mexperts_implementation\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m kwargs:\n\u001b[32m 4093\u001b[39m config._experts_implementation = kwargs.pop(\u001b[33m\"\u001b[39m\u001b[33mexperts_implementation\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m-> \u001b[39m\u001b[32m4095\u001b[39m hf_quantizer, config, device_map = \u001b[30;43mget_hf_quantizer\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 4096\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mconfig\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mquantization_config\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mdevice_map\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mweights_only\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43muser_agent\u001b[39;49m\n\u001b[32m 4097\u001b[39m \u001b[30;43m\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 4099\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m gguf_file:\n\u001b[32m 4100\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m hf_quantizer \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/transformers/quantizers/auto.py:342\u001b[39m, in \u001b[36mget_hf_quantizer\u001b[39m\u001b[34m(config, quantization_config, device_map, weights_only, user_agent)\u001b[39m\n\u001b[32m 339\u001b[39m hf_quantizer = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 341\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m hf_quantizer \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m342\u001b[39m \u001b[30;43mhf_quantizer\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mvalidate_environment\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 343\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdevice_map\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mdevice_map\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 344\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mweights_only\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mweights_only\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 345\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 346\u001b[39m device_map = hf_quantizer.update_device_map(device_map)\n\u001b[32m 347\u001b[39m config = hf_quantizer.update_tp_plan(config)\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/viral-posts-env/.venv/lib/python3.14/site-packages/transformers/quantizers/quantizer_bnb_4bit.py:62\u001b[39m, in \u001b[36mBnb4BitHfQuantizer.validate_environment\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 58\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[32m 59\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mUsing `bitsandbytes` 4-bit quantization requires accelerate: `pip install \u001b[39m\u001b[33m'\u001b[39m\u001b[33maccelerate>=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mACCELERATE_MIN_VERSION\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m`\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 60\u001b[39m )\n\u001b[32m 61\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_bitsandbytes_available():\n\u001b[32m---> \u001b[39m\u001b[32m62\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[32m 63\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mUsing `bitsandbytes` 4-bit quantization requires bitsandbytes: `pip install -U bitsandbytes>=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mBITSANDBYTES_MIN_VERSION\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m`\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 64\u001b[39m )\n\u001b[32m 66\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mintegrations\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m validate_bnb_backend_availability\n\u001b[32m 68\u001b[39m validate_bnb_backend_availability(raise_exception=\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
-      "\u001b[31mImportError\u001b[39m: Using `bitsandbytes` 4-bit quantization requires bitsandbytes: `pip install -U bitsandbytes>=0.46.1`"
+      "Loading Qwen/Qwen2.5-1.5B-Instruct without 4-bit (bitsandbytes/CUDA unavailable).\n",
+      " On Colab: run `pip install -U bitsandbytes>=0.46.1` and use a GPU runtime.\n",
+      " On Mac: use fp16 on MPS or fp32 on CPU.\n"
     ]
    }
   ]
training/train_grpo_smoke.ipynb ADDED
@@ -0,0 +1,210 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 4,
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.10.0"
+  }
+ },
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# `train_grpo_smoke.ipynb` — syntax & environment smoke test\n",
+    "\n",
+    "Companion to `train_grpo.ipynb`. **Fast** (~1–2 min): checks imports, repo layout, `TASK_HORIZON`, and one short env run.\n",
+    "\n",
+    "Run **all cells top to bottom** in Colab or locally before starting the full training notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# Cell 1: Minimal deps (quoted versions for zsh / shell safety)\n",
+    "!pip install -q pydantic httpx\n",
+    "!pip install -q \"openenv-core[core]>=0.2.2\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# Cell 2: Repo path (same logic as main notebook)\n",
+    "import os\n",
+    "import sys\n",
+    "import shutil\n",
+    "import subprocess\n",
+    "from pathlib import Path\n",
+    "\n",
+    "REPO_BRANCH = \"hack1\"\n",
+    "REPO_URL = \"https://github.com/VaibhavKhandare/viral-posts-env.git\"\n",
+    "COLAB_REPO = Path(\"/content/viral-posts-env\")\n",
+    "\n",
+    "\n",
+    "def _is_repo_root(p: Path) -> bool:\n",
+    "    return (p / \"server\" / \"viraltest_environment.py\").is_file() and (p / \"models.py\").is_file()\n",
+    "\n",
+    "\n",
+    "def _find_local_root() -> Path:\n",
+    "    here = Path.cwd().resolve()\n",
+    "    for cand in (here, here.parent, here.parent.parent):\n",
+    "        if _is_repo_root(cand):\n",
+    "            return cand\n",
+    "    raise FileNotFoundError(\n",
+    "        \"Could not find project root. cd into viral-posts-env or use Colab.\"\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "if Path(\"/content\").is_dir():\n",
+    "    if COLAB_REPO.exists():\n",
+    "        shutil.rmtree(COLAB_REPO, ignore_errors=True)\n",
+    "    p = subprocess.run(\n",
+    "        [\"git\", \"clone\", \"--branch\", REPO_BRANCH, \"--depth\", \"1\", REPO_URL, str(COLAB_REPO)],\n",
+    "        capture_output=True,\n",
+    "        text=True,\n",
+    "    )\n",
+    "    if p.returncode != 0:\n",
+    "        raise RuntimeError(f\"git clone failed:\\n{p.stderr}\")\n",
+    "    os.chdir(COLAB_REPO)\n",
+    "    print(\"Mode: Colab\")\n",
+    "else:\n",
+    "    os.chdir(_find_local_root())\n",
+    "    print(\"Mode: local\")\n",
+    "\n",
+    "REPO_DIR = str(Path.cwd().resolve())\n",
+    "if REPO_DIR not in sys.path:\n",
+    "    sys.path.insert(0, REPO_DIR)\n",
+    "print(\"REPO_DIR =\", REPO_DIR)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# Cell 3: Core imports + TASK_HORIZON check\n",
+    "import os\n",
+    "import sys\n",
+    "from pathlib import Path\n",
+    "\n",
+    "if not Path(\"server/viraltest_environment.py\").is_file():\n",
+    "    for cand in (Path.cwd(), Path.cwd().parent, Path.cwd().parent.parent):\n",
+    "        if (cand / \"server\" / \"viraltest_environment.py\").is_file():\n",
+    "            os.chdir(cand)\n",
+    "            s = str(cand.resolve())\n",
+    "            if s not in sys.path:\n",
+    "                sys.path.insert(0, s)\n",
+    "            print(\"Auto chdir:\", s)\n",
+    "            break\n",
+    "    else:\n",
+    "        raise RuntimeError(\"Run Cell 2 first or open from repo root.\")\n",
+    "\n",
+    "from models import ScheduledAction, ToolCall, ViraltestAction\n",
+    "from server.viraltest_environment import (\n",
+    "    ViraltestEnvironment,\n",
+    "    TAG_POOL,\n",
+    "    TASK_HORIZON,\n",
+    "    TOPIC_CATEGORIES,\n",
+    ")\n",
+    "\n",
+    "assert TASK_HORIZON == 30, f\"Expected TASK_HORIZON=30, got {TASK_HORIZON}\"\n",
+    "print(\"OK: TASK_HORIZON =\", TASK_HORIZON)\n",
+    "print(\"OK: tags =\", len(TAG_POOL), \"niches =\", len(TOPIC_CATEGORIES))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# Cell 4: One minimal episode (syntax + env wiring)\n",
+    "import random\n",
+    "\n",
+    "_rng = random.Random(42)\n",
+    "\n",
+    "\n",
+    "def plan_minimal(obs_dict, day):\n",
+    "    topics = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
+    "    topic = topics[day % len(topics)]\n",
+    "    tags = [TAG_POOL[i % len(TAG_POOL)] for i in range(day, day + 3)]\n",
+    "    return ViraltestAction(\n",
+    "        scheduled_actions=[\n",
+    "            ScheduledAction(\n",
+    "                hour=12,\n",
+    "                action_type=\"post\",\n",
+    "                content_type=\"carousel\",\n",
+    "                topic=topic,\n",
+    "                tags=tags,\n",
+    "                intent=\"save_bait\",\n",
+    "            )\n",
+    "        ]\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "def run_episode(task, plan_fn, seed=42):\n",
+    "    env = ViraltestEnvironment()\n",
+    "    obs = env.reset(task=task, seed=seed)\n",
+    "    obs_dict = obs.model_dump()\n",
+    "    rewards = []\n",
+    "    for day in range(1, TASK_HORIZON + 1):\n",
+    "        obs = env.step(plan_fn(obs_dict, day))\n",
+    "        obs_dict = obs.model_dump()\n",
+    "        rewards.append(obs.reward or 0.0)\n",
+    "        if obs.done:\n",
+    "            break\n",
+    "    gs = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
+    "    return {\"steps\": len(rewards), \"total_reward\": sum(rewards), \"grader_score\": gs}\n",
+    "\n",
+    "\n",
+    "r = run_episode(\"monthly_engage\", plan_minimal, seed=42)\n",
+    "print(\"Episode result:\", r)\n",
+    "assert r[\"steps\"] == TASK_HORIZON, f\"Expected {TASK_HORIZON} steps, got {r['steps']}\"\n",
+    "print(\"OK: full monthly episode completed\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# Cell 5: Optional ML stack (no model download)\n",
+    "mods = [\n",
+    "    \"torch\",\n",
+    "    \"transformers\",\n",
+    "    \"peft\",\n",
+    "    \"trl\",\n",
+    "    \"datasets\",\n",
+    "    \"accelerate\",\n",
+    "]\n",
+    "for m in mods:\n",
+    "    try:\n",
+    "        __import__(m)\n",
+    "        print(\"OK import:\", m)\n",
+    "    except ImportError as e:\n",
+    "        print(\"MISSING (install in full notebook):\", m, \"—\", e)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If all cells pass, open `train_grpo.ipynb` and run the full pipeline."
+   ]
+  }
+ ]
+}
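
To sanity-check the new smoke notebook headlessly, one option (assuming Jupyter's `nbconvert` is installed; this command is not part of the commit):

```zsh
# Execute all cells in order; exits non-zero if any cell raises.
jupyter nbconvert --to notebook --execute --inplace training/train_grpo_smoke.ipynb
```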