{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `train_grpo_smoke.ipynb` — syntax & environment smoke test\n",
    "\n",
    "Companion to `train_grpo.ipynb`. **Fast** (~1–2 min): checks dependencies, repo layout, core imports, `TASK_HORIZON`, and one full `TASK_HORIZON`-day environment episode. No model is downloaded.\n",
    "\n",
    "Run **all cells top to bottom** (Colab or local) before starting the full training notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 1: Minimal deps.\n",
    "# `%pip` (not `!pip`) installs into the environment of the running kernel.\n",
    "# The requirement is quoted so shells such as zsh don't expand `[...]` or treat `>` as a redirect.\n",
    "%pip install -q pydantic httpx \"openenv-core[core]>=0.2.2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 2: Repo path (same logic as main notebook).\n",
    "# Colab: fresh shallow clone of REPO_BRANCH into /content on every run (local edits there are discarded).\n",
    "# Local: use the cwd, its parent, or its grandparent — whichever is the repo root.\n",
    "import os\n",
    "import sys\n",
    "import shutil\n",
    "import subprocess\n",
    "from pathlib import Path\n",
    "\n",
    "REPO_BRANCH = \"main\"\n",
    "REPO_URL = \"https://github.com/VaibhavKhandare/viral-posts-env.git\"\n",
    "COLAB_ROOT = Path(\"/content\")\n",
    "COLAB_REPO = COLAB_ROOT / \"viral-posts-env\"\n",
    "\n",
    "\n",
    "def _is_repo_root(p: Path) -> bool:\n",
    "    \"\"\"Return True if ``p`` holds the two files this notebook imports from.\"\"\"\n",
    "    return (p / \"server\" / \"viraltest_environment.py\").is_file() and (p / \"models.py\").is_file()\n",
    "\n",
    "\n",
    "def _find_local_root() -> Path:\n",
    "    \"\"\"Return the repo root among the cwd, its parent and its grandparent.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if none of the three is the repo root.\n",
    "    \"\"\"\n",
    "    here = Path.cwd().resolve()\n",
    "    for cand in (here, here.parent, here.parent.parent):\n",
    "        if _is_repo_root(cand):\n",
    "            return cand\n",
    "    raise FileNotFoundError(\n",
    "        \"Could not find project root. cd into viral-posts-env or use Colab.\"\n",
    "    )\n",
    "\n",
    "\n",
    "if COLAB_ROOT.is_dir():\n",
    "    # On a re-run the kernel's cwd is inside COLAB_REPO; step out before deleting it.\n",
    "    os.chdir(COLAB_ROOT)\n",
    "    if COLAB_REPO.exists():\n",
    "        shutil.rmtree(COLAB_REPO, ignore_errors=True)\n",
    "    clone = subprocess.run(\n",
    "        [\"git\", \"clone\", \"--branch\", REPO_BRANCH, \"--depth\", \"1\", REPO_URL, str(COLAB_REPO)],\n",
    "        capture_output=True,\n",
    "        text=True,\n",
    "    )\n",
    "    if clone.returncode != 0:\n",
    "        raise RuntimeError(f\"git clone failed:\\n{clone.stderr}\")\n",
    "    os.chdir(COLAB_REPO)\n",
    "    print(\"Mode: Colab\")\n",
    "else:\n",
    "    os.chdir(_find_local_root())\n",
    "    print(\"Mode: local\")\n",
    "\n",
    "REPO_DIR = str(Path.cwd().resolve())\n",
    "if REPO_DIR not in sys.path:\n",
    "    sys.path.insert(0, REPO_DIR)\n",
    "print(\"REPO_DIR =\", REPO_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 3: Core imports + TASK_HORIZON check.\n",
    "# Self-contained: if Cell 2 was skipped, first look for the repo root in the cwd, its parent and grandparent.\n",
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "if not Path(\"server/viraltest_environment.py\").is_file():\n",
    "    for cand in (Path.cwd(), Path.cwd().parent, Path.cwd().parent.parent):\n",
    "        if (cand / \"server\" / \"viraltest_environment.py\").is_file():\n",
    "            os.chdir(cand)\n",
    "            print(\"Auto chdir:\", str(cand.resolve()))\n",
    "            break\n",
    "    else:\n",
    "        raise RuntimeError(\"Run Cell 2 first or open from repo root.\")\n",
    "\n",
    "# Ensure the repo root is importable even when the cwd was already correct.\n",
    "_repo_root = str(Path.cwd().resolve())\n",
    "if _repo_root not in sys.path:\n",
    "    sys.path.insert(0, _repo_root)\n",
    "\n",
    "# ToolCall is unused below; importing it checks that `models` still exports it.\n",
    "from models import ScheduledAction, ToolCall, ViraltestAction\n",
    "from server.viraltest_environment import (\n",
    "    ViraltestEnvironment,\n",
    "    TAG_POOL,\n",
    "    TASK_HORIZON,\n",
    "    TOPIC_CATEGORIES,\n",
    ")\n",
    "\n",
    "assert TASK_HORIZON == 30, f\"Expected TASK_HORIZON=30, got {TASK_HORIZON}\"\n",
    "print(\"OK: TASK_HORIZON =\", TASK_HORIZON)\n",
    "print(\"OK: tags =\", len(TAG_POOL), \"niches =\", len(TOPIC_CATEGORIES))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 4: One full episode with a fixed plan (syntax + env wiring).\n",
    "# The plan is deterministic, so no RNG is needed here; `seed` is only passed to env.reset.\n",
    "\n",
    "# Flattened once: every topic of every niche, in TOPIC_CATEGORIES iteration order.\n",
    "ALL_TOPICS = [topic for niche_topics in TOPIC_CATEGORIES.values() for topic in niche_topics]\n",
    "assert ALL_TOPICS, \"TOPIC_CATEGORIES contains no topics\"\n",
    "assert TAG_POOL, \"TAG_POOL is empty\"\n",
    "\n",
    "\n",
    "def plan_minimal(obs_dict, day):\n",
    "    \"\"\"Return a one-post ViraltestAction for ``day``.\n",
    "\n",
    "    Cycles through ALL_TOPICS and picks 3 consecutive (wrapping) TAG_POOL entries.\n",
    "    ``obs_dict`` is ignored: this is a wiring check, not a policy.\n",
    "    \"\"\"\n",
    "    topic = ALL_TOPICS[day % len(ALL_TOPICS)]\n",
    "    tags = [TAG_POOL[i % len(TAG_POOL)] for i in range(day, day + 3)]\n",
    "    return ViraltestAction(\n",
    "        scheduled_actions=[\n",
    "            ScheduledAction(\n",
    "                hour=12,\n",
    "                action_type=\"post\",\n",
    "                content_type=\"carousel\",\n",
    "                topic=topic,\n",
    "                tags=tags,\n",
    "                intent=\"save_bait\",\n",
    "            )\n",
    "        ]\n",
    "    )\n",
    "\n",
    "\n",
    "def run_episode(task, plan_fn, seed=42):\n",
    "    \"\"\"Run one episode of ``task``, calling ``plan_fn(obs_dict, day)`` for each day.\n",
    "\n",
    "    Steps at most TASK_HORIZON times (days 1..TASK_HORIZON) and stops early when an\n",
    "    observation reports ``done``.\n",
    "\n",
    "    Returns:\n",
    "        dict with ``steps`` (env.step calls made), ``total_reward`` (sum of step rewards,\n",
    "        a missing reward counted as 0.0), ``grader_score`` (from the last observation's\n",
    "        metadata, 0.0 if absent) and ``done`` (the last observation's done flag).\n",
    "    \"\"\"\n",
    "    env = ViraltestEnvironment()\n",
    "    obs = env.reset(task=task, seed=seed)\n",
    "    rewards = []\n",
    "    for day in range(1, TASK_HORIZON + 1):\n",
    "        obs = env.step(plan_fn(obs.model_dump(), day))\n",
    "        rewards.append(obs.reward or 0.0)\n",
    "        if obs.done:\n",
    "            break\n",
    "    grader_score = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
    "    return {\n",
    "        \"steps\": len(rewards),\n",
    "        \"total_reward\": sum(rewards),\n",
    "        \"grader_score\": grader_score,\n",
    "        \"done\": bool(obs.done),\n",
    "    }\n",
    "\n",
    "\n",
    "episode_result = run_episode(\"monthly_engage\", plan_minimal, seed=42)\n",
    "print(\"Episode result:\", episode_result)\n",
    "assert episode_result[\"steps\"] == TASK_HORIZON, f\"Expected {TASK_HORIZON} steps, got {episode_result['steps']}\"\n",
    "print(\"OK: full monthly episode completed\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 5: Optional ML stack check (imports only; no model download).\n",
    "# A missing package is reported, not fatal: the full notebook installs the ML stack.\n",
    "import importlib\n",
    "\n",
    "ML_MODULES = [\"torch\", \"transformers\", \"peft\", \"trl\", \"datasets\", \"accelerate\"]\n",
    "for module_name in ML_MODULES:\n",
    "    try:\n",
    "        module = importlib.import_module(module_name)\n",
    "    except ImportError as e:\n",
    "        print(\"MISSING (install in full notebook):\", module_name, \"—\", e)\n",
    "    else:\n",
    "        print(\"OK import:\", module_name, getattr(module, \"__version__\", \"(no __version__)\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If every assertion passes, open `train_grpo.ipynb` and run the full pipeline. Packages that Cell 5 reports as MISSING are fine here; the full notebook installs them."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}