{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PhysiX RLVR \u2014 SFT + GRPO Training (Colab)\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/)  \n", "[Trained model](https://huggingface.co/Pratyush-01/physix-3b-rl)  | \n", "[Live demo Space](https://huggingface.co/spaces/Pratyush-01/physix-live)  | \n", "[W&B runs](https://wandb.ai/pratyush01/physix-live)\n", "\n", "This notebook reproduces the **PhysiX** training pipeline end-to-end on a single Colab GPU:\n", "\n", "1. **SFT warm-start** on synthetic ground-truth equations so the model emits the right JSON grammar.\n", "2. **GRPO RLVR** on the [PhysiX OpenEnv](https://huggingface.co/spaces/Pratyush-01/physix-live) \u2014 reward comes from `scipy.odeint` + per-step R\u00b2, no LLM judge.\n", "3. Plot loss / reward curves and (optionally) push the merged model to the HF Hub.\n", "\n", "Built on **[OpenEnv](https://github.com/openenv-hackathon/openenv)** + **Unsloth** + **TRL** \u2014 same code path as our cloud `job_train_single.py`, just driven from a notebook.\n", "\n", "**Recommended runtime:** Runtime \u2192 Change runtime type \u2192 **A100** (or L4/T4 \u2014 see the *Profile* cell below; a T4 will work for `1.5b` only).\n", "\n", "**Optional secrets** (Colab \u2192 \ud83d\udd11 sidebar \u2192 Secrets, or paste interactively when prompted in cell 3):\n", "- `HF_TOKEN` \u2014 write-scoped token from https://huggingface.co/settings/tokens. Only needed if you set `PUSH_TO_HUB=True` to push the trained model to the Hub. The source bundle is public, so the fetch in cell 5 works without it.\n", "- `WANDB_API_KEY` \u2014 from https://wandb.ai/authorize. Optional: enables live W&B logging (you can also disable W&B in the *Profile* cell)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 
GPU sanity check" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Environment hardening\n", "\n", "Same hardening as the HF Jobs script:\n", "- Route every cache (HF, torch inductor, triton, W&B) under `/tmp` so we don't blow Colab's disk quota.\n", "- Disable `torch.compile` / inductor at four layers \u2014 Unsloth GRPO triggers an inductor CPU-SIMD probe on the first step that hard-fails on minimal envs, and we don't need compile speedups for a small LoRA run.\n", "- Set `WANDB_PROJECT` and disable wandb model-artifact uploads (we push to HF Hub instead)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "from pathlib import Path\n", "\n", "os.environ.setdefault(\"USER\", \"physix\")\n", "os.environ.setdefault(\"LOGNAME\", \"physix\")\n", "os.environ.setdefault(\"HOME\", \"/tmp/home\")\n", "\n", "os.environ.setdefault(\"HF_HOME\", \"/tmp/hf_cache\")\n", "os.environ.setdefault(\"TORCHINDUCTOR_CACHE_DIR\", \"/tmp/torchinductor_cache\")\n", "os.environ.setdefault(\"TRITON_CACHE_DIR\", \"/tmp/triton_cache\")\n", "os.environ.setdefault(\"XDG_CACHE_HOME\", \"/tmp/xdg-cache\")\n", "os.environ.setdefault(\"WANDB_DIR\", \"/tmp/wandb\")\n", "os.environ.setdefault(\"WANDB_CACHE_DIR\", \"/tmp/wandb-cache\")\n", "os.environ.setdefault(\"WANDB_DATA_DIR\", \"/tmp/wandb-data\")\n", "os.environ.setdefault(\"WANDB_ARTIFACT_DIR\", \"/tmp/wandb-artifacts\")\n", "os.environ.setdefault(\"WANDB_CONFIG_DIR\", \"/tmp/wandb-config\")\n", "\n", "os.environ.setdefault(\"WANDB_DISABLE_ARTIFACTS\", \"true\")\n", "os.environ.setdefault(\"WANDB_LOG_MODEL\", \"false\")\n", "os.environ.setdefault(\"WANDB_PROJECT\", \"physix-live\")\n", "\n", "os.environ.setdefault(\"UNSLOTH_COMPILE_DISABLE\", \"1\")\n", "os.environ.setdefault(\"TORCH_COMPILE_DISABLE\", \"1\")\n", 
"os.environ.setdefault(\"TORCHINDUCTOR_DISABLE\", \"1\")\n", "os.environ.setdefault(\"TORCHDYNAMO_DISABLE\", \"1\")\n", "os.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"expandable_segments:True\")\n", "os.environ.setdefault(\"TOKENIZERS_PARALLELISM\", \"false\")\n", "os.environ.setdefault(\"PYTHONUNBUFFERED\", \"1\")\n", "\n", "for d in (\n", " os.environ[\"HOME\"],\n", " os.environ[\"HF_HOME\"],\n", " os.environ[\"TORCHINDUCTOR_CACHE_DIR\"],\n", " os.environ[\"TRITON_CACHE_DIR\"],\n", " os.environ[\"XDG_CACHE_HOME\"],\n", " os.environ[\"WANDB_DIR\"],\n", " os.environ[\"WANDB_CACHE_DIR\"],\n", " os.environ[\"WANDB_DATA_DIR\"],\n", " os.environ[\"WANDB_ARTIFACT_DIR\"],\n", " os.environ[\"WANDB_CONFIG_DIR\"],\n", "):\n", " Path(d).mkdir(parents=True, exist_ok=True)\n", "\n", "print(\"caches under /tmp ready\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Load secrets (`HF_TOKEN`, `WANDB_API_KEY`) \u2014 both optional\n\nBoth secrets are optional \u2014 leave them blank and the notebook still trains end-to-end, just without push-to-Hub or live W&B logging. Resolved in this order, first hit wins, no overwrites:\n\n1. **Existing environment variable** \u2014 anything already in `os.environ` (e.g. set with `%env HF_TOKEN=...` in a prior cell, or inherited from the host shell). `HUGGINGFACE_HUB_TOKEN` and `HF_HUB_TOKEN` are accepted as aliases for `HF_TOKEN`.\n2. **Colab Secrets** (\ud83d\udd11 sidebar) \u2014 recommended; survives runtime restarts. Add `HF_TOKEN` and/or `WANDB_API_KEY` and enable notebook access for each.\n3. **Interactive prompt** \u2014 masked `getpass` input; press Enter to skip. The token is *not* echoed to the cell output and *not* persisted in the notebook file.\n\n`HF_TOKEN` is only needed if `PUSH_TO_HUB=True` in the *Profile* cell (use a write-scoped token then). 
`WANDB_API_KEY` is only needed for live W&B logging \u2014 loss / reward are saved to `trainer_state.json` regardless and we plot from there in cell 9." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "from getpass import getpass\n", "\n", "PROMPT_FOR_MISSING_SECRETS = True\n", "\n", "# Resolution order for each secret:\n", "# 1. Already in os.environ (e.g. you ran `%env HF_TOKEN=...` earlier, or\n", "# the variable was inherited from the host shell).\n", "# 2. Colab Secrets (\ud83d\udd11 sidebar) \u2014 only available in Colab runtimes.\n", "# 3. Interactive getpass() prompt (masked input, not echoed to the cell\n", "# output and not persisted in the notebook).\n", "# The first hit wins; we never overwrite a value already present.\n", "\n", "ALIASES: dict[str, tuple[str, ...]] = {\n", " \"HF_TOKEN\": (\"HF_TOKEN\", \"HUGGINGFACE_HUB_TOKEN\", \"HF_HUB_TOKEN\"),\n", " \"WANDB_API_KEY\": (\"WANDB_API_KEY\",),\n", "}\n", "\n", "def _from_env(key: str) -> str | None:\n", " \"\"\"Return the first non-empty env var across `key` and its aliases.\"\"\"\n", " for name in ALIASES[key]:\n", " val = os.environ.get(name)\n", " if val:\n", " return val\n", " return None\n", "\n", "for key in (\"HF_TOKEN\", \"WANDB_API_KEY\"):\n", " val = _from_env(key)\n", " if val:\n", " os.environ[key] = val\n", " print(f\"{key}: loaded from environment\")\n", "\n", "try:\n", " from google.colab import userdata # type: ignore[import-not-found]\n", " for key in (\"HF_TOKEN\", \"WANDB_API_KEY\"):\n", " if os.environ.get(key):\n", " continue\n", " try:\n", " val = userdata.get(key)\n", " except Exception:\n", " val = None\n", " if val:\n", " os.environ[key] = val\n", " print(f\"{key}: loaded from Colab Secrets\")\n", "except ImportError:\n", " pass\n", "\n", "def _prompt(key: str, blurb: str, url: str) -> None:\n", " if os.environ.get(key) or not PROMPT_FOR_MISSING_SECRETS:\n", " return\n", " print(f\"\\n{key} not found in environment 
or Colab Secrets.\")\n", " print(f\" {blurb}\")\n", " print(f\" Get one at: {url}\")\n", " print(f\" Press Enter to skip, or paste your token (input is hidden):\")\n", " try:\n", " val = getpass(f\" {key} = \").strip()\n", " except (EOFError, KeyboardInterrupt):\n", " val = \"\"\n", " if val:\n", " os.environ[key] = val\n", " print(f\"{key}: set from interactive prompt\")\n", "\n", "_prompt(\n", " \"HF_TOKEN\",\n", " \"Optional: only needed if PUSH_TO_HUB=True in the Profile cell.\",\n", " \"https://huggingface.co/settings/tokens\",\n", ")\n", "_prompt(\n", " \"WANDB_API_KEY\",\n", " \"Optional: enables live W&B logging; metrics are still saved to disk regardless.\",\n", " \"https://wandb.ai/authorize\",\n", ")\n", "\n", "print()\n", "if os.environ.get(\"HF_TOKEN\"):\n", " os.environ[\"HUGGINGFACE_HUB_TOKEN\"] = os.environ[\"HF_TOKEN\"]\n", " print(\"HF_TOKEN: ready (push-to-Hub enabled)\")\n", "else:\n", " print(\"HF_TOKEN: NOT set (push-to-Hub will be skipped; source fetch is public)\")\n", "\n", "if os.environ.get(\"WANDB_API_KEY\"):\n", " print(\"WANDB_API_KEY: ready\")\n", "else:\n", " print(\"WANDB_API_KEY: NOT set (W&B logging will be disabled)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Install dependencies\n", "\n", "Pinned the same way as the HF Jobs container:\n", "- `trl==0.24.0` \u2014 Unsloth's `patch_trl_openenv()` does `inspect.getsource(...)` on a TRL internal that breaks on newer TRL.\n", "- `unsloth` latest \u2014 provides the GRPO patch and FastLanguageModel.\n", "- `openenv-core[core]>=0.2.2` \u2014 the OpenEnv SDK we build on.\n", "- The rest are runtime deps used by the verifier (`scipy`, `sympy`, `pydantic`).\n", "\n", "Colab pre-installs torch with CUDA, so we let pip pick a self-consistent set rather than re-pin torch." 
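 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Those three verifier deps are the entire reward path \u2014 no LLM judge anywhere. As a toy sketch of the idea only (illustrative, **not** the `physix.training.reward_fns` implementation \u2014 the real verifier also parses the model's JSON with `sympy`/`pydantic` first): integrate a candidate right-hand side with `scipy.odeint` and score it with R\u00b2 against the observed trajectory. `numpy`/`scipy` ship with Colab, so this cell runs even before the install below." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Toy verifier sketch (illustrative only, NOT physix.training.reward_fns):\n# integrate a candidate ODE and score it with R^2 against the observation.\nimport numpy as np\nfrom scipy.integrate import odeint\n\ndef r_squared(observed, predicted):\n    ss_res = float(np.sum((observed - predicted) ** 2))\n    ss_tot = float(np.sum((observed - observed.mean()) ** 2))\n    return 1.0 - ss_res / ss_tot\n\n# 'Ground truth': damped spring  x'' = -k*x - c*x'\ndef true_rhs(state, t, k=4.0, c=0.5):\n    x, v = state\n    return [v, -k * x - c * v]\n\nt = np.linspace(0.0, 10.0, 200)\nobserved = odeint(true_rhs, [1.0, 0.0], t)[:, 0]\n\n# Hypothetical candidate equation with a slightly wrong stiffness k\ndef candidate_rhs(state, t):\n    x, v = state\n    return [v, -3.8 * x - 0.5 * v]\n\npredicted = odeint(candidate_rhs, [1.0, 0.0], t)[:, 0]\nprint(f\"R^2 = {r_squared(observed, predicted):.3f}\")  # strictly < 1.0 here" 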
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!pip install -q --upgrade pip\n", "!pip install -q \\\n", " \"unsloth\" \\\n", " \"trl==0.24.0\" \\\n", " \"transformers\" \\\n", " \"datasets\" \\\n", " \"peft\" \\\n", " \"accelerate\" \\\n", " \"bitsandbytes\" \\\n", " \"wandb\" \\\n", " \"setuptools\" \\\n", " \"wheel\" \\\n", " \"scipy>=1.10,<2.0\" \\\n", " \"sympy>=1.12,<2.0\" \\\n", " \"pydantic>=2.5,<3.0\" \\\n", " \"numpy>=1.24,<3.0\" \\\n", " \"openenv-core[core]>=0.2.2\" \\\n", " \"huggingface_hub>=0.24,<1.0\" \\\n", " \"matplotlib>=3.7,<4.0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Fetch the `physix-live` source\n\nDownloads the latest source directly from the live HF Space (`Pratyush-01/physix-live`) so the notebook always has the up-to-date code. No token required \u2014 the Space is public." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\nimport shutil\nfrom huggingface_hub import snapshot_download\n\nPHYSIX_REPO = \"Pratyush-01/physix-live\"\nPHYSIX_LOCAL = Path(\"/tmp/src/physix-live\")\n\nif PHYSIX_LOCAL.exists():\n shutil.rmtree(PHYSIX_LOCAL)\nPHYSIX_LOCAL.parent.mkdir(parents=True, exist_ok=True)\n\nsnapshot_download(\n repo_id=PHYSIX_REPO,\n repo_type=\"space\",\n local_dir=str(PHYSIX_LOCAL),\n token=os.environ.get(\"HF_TOKEN\"),\n)\n\nassert (PHYSIX_LOCAL / \"pyproject.toml\").is_file(), f\"missing pyproject.toml in {PHYSIX_LOCAL}\"\nassert (PHYSIX_LOCAL / \"physix\" / \"__init__.py\").is_file(), f\"missing physix/ package in {PHYSIX_LOCAL}\"\nprint(\"physix-live source at\", PHYSIX_LOCAL)\n!ls {PHYSIX_LOCAL}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import subprocess, sys, importlib, site\n\ncmd = [sys.executable, \"-m\", \"pip\", \"install\", \"--no-deps\", \"-e\", str(PHYSIX_LOCAL)]\nprint(\"$\", \" 
\".join(cmd))\nsubprocess.run(cmd, check=True)\n\n# Refresh import caches so the freshly-installed package is discoverable.\nimportlib.invalidate_caches()\nsite.main()\n\n# Verify physix is now importable before moving on.\ntry:\n import physix\n print(f\"\\nphysix imported OK from {physix.__file__}\")\nexcept ImportError as e:\n print(f\"\\nFAILED to import physix: {e}\")\n print(\"Site-packages dirs:\")\n for d in site.getsitepackages():\n print(f\" {d}\")\n raise" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import unsloth # MUST come before trl / transformers so its monkey-patches land first\n", "import torch, trl, transformers, datasets, wandb\n", "import physix\n", "\n", "device = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None\n", "print(f\"torch={torch.__version__} cuda={torch.cuda.is_available()} device={device}\")\n", "print(f\"unsloth={unsloth.__version__} trl={trl.__version__} transformers={transformers.__version__} datasets={datasets.__version__}\")\n", "print(f\"physix loaded from {physix.__file__}\")\n", "assert trl.__version__ == \"0.24.0\", f\"trl must be pinned to 0.24.0, got {trl.__version__}\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Pick a profile\n\n| Profile | Base model | LoRA r | SFT epochs | GRPO steps | Recommended GPU |\n|--------|------------|:-:|:-:|:-:|---------|\n| `1.5b` | Qwen2.5-1.5B-Instruct | 32 | 3 | 300 | T4, L4, A100 |\n| `3b` | Qwen2.5-3B-Instruct | 32 | 3 | 200 | L4, A100 |\n| `7b` | Qwen2.5-7B-Instruct | 16 | 3 | 200 | A100 only |\n\nWe default to `3b` \u2014 the same configuration used for the published W&B runs. Switch to `1.5b` if you only have a T4.\n\nSet `SYSTEM_ID` to a single system (e.g. `\"damped_spring\"`) for a focused-reward run, or `None` to train across all 3 trained systems.\n\nSet `PUSH_TO_HUB = False` if you don't want to push the merged model \u2014 local-only runs need no `HF_TOKEN`." 
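 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One knob worth understanding before launching: GRPO samples `num_generations` completions per prompt and computes each advantage *relative to that group* \u2014 no learned value model. A minimal numpy sketch of that normalization (illustrative, not TRL's exact implementation):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Group-relative advantage sketch (illustrative, not TRL's exact code):\n# each group of num_generations rewards is standardized against its own\n# mean/std, so a completion is only 'good' relative to its siblings.\nimport numpy as np\n\ndef group_relative_advantages(rewards, num_generations):\n    groups = np.asarray(rewards, dtype=float).reshape(-1, num_generations)\n    mean = groups.mean(axis=1, keepdims=True)\n    std = groups.std(axis=1, keepdims=True)\n    return ((groups - mean) / (std + 1e-4)).reshape(-1)\n\n# One prompt, 4 sampled completions (num_generations=4, as in the profiles)\nprint(group_relative_advantages([0.9, 0.1, 0.5, 0.5], 4))  # roughly [1.41, -1.41, 0, 0]" 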
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PROFILES: dict[str, dict] = {\n \"1.5b\": {\n \"base_model\": \"Qwen/Qwen2.5-1.5B-Instruct\",\n \"sft_lora_r\": \"32\",\n \"grpo_lora_r\": \"32\",\n \"sft_lr\": \"2e-5\",\n \"grpo_lr\": \"5e-6\",\n \"sft_epochs\": \"3\",\n \"num_steps\": \"300\",\n \"num_generations\": \"4\",\n \"max_completion\": \"256\",\n \"hub_final_repo\": \"Pratyush-01/physix-1.5b-rl-colab\",\n \"hub_ckpt_repo\": \"Pratyush-01/physix-1.5b-rl-colab-ckpt\",\n \"sft_run_name\": \"physix-sft-1.5b-colab\",\n \"grpo_run_name\": \"physix-grpo-1.5b-colab\",\n },\n \"3b\": {\n \"base_model\": \"Qwen/Qwen2.5-3B-Instruct\",\n \"sft_lora_r\": \"32\",\n \"grpo_lora_r\": \"32\",\n \"sft_lr\": \"1.5e-5\",\n \"grpo_lr\": \"3e-6\",\n \"sft_epochs\": \"3\",\n \"num_steps\": \"200\",\n \"num_generations\": \"4\",\n \"max_completion\": \"384\",\n \"hub_final_repo\": \"Pratyush-01/physix-3b-rl-colab\",\n \"hub_ckpt_repo\": \"Pratyush-01/physix-3b-rl-colab-ckpt\",\n \"sft_run_name\": \"physix-sft-3b-colab\",\n \"grpo_run_name\": \"physix-grpo-3b-colab\",\n },\n \"7b\": {\n \"base_model\": \"Qwen/Qwen2.5-7B-Instruct\",\n \"sft_lora_r\": \"16\",\n \"grpo_lora_r\": \"16\",\n \"sft_lr\": \"1e-5\",\n \"grpo_lr\": \"2e-6\",\n \"sft_epochs\": \"3\",\n \"num_steps\": \"200\",\n \"num_generations\": \"4\",\n \"max_completion\": \"256\",\n \"hub_final_repo\": \"Pratyush-01/physix-7b-rl-colab\",\n \"hub_ckpt_repo\": \"Pratyush-01/physix-7b-rl-colab-ckpt\",\n \"sft_run_name\": \"physix-sft-7b-colab\",\n \"grpo_run_name\": \"physix-grpo-7b-colab\",\n },\n}\n\nACTIVE_PROFILE = \"3b\"\nSYSTEM_ID: str | None = \"damped_spring\"\nINSTANCES_PER_SYSTEM = 32\nPUSH_TO_HUB = False\n\np = PROFILES[ACTIVE_PROFILE]\n\nif not os.environ.get(\"WANDB_API_KEY\"):\n os.environ[\"WANDB_MODE\"] = \"disabled\"\n print(\"WANDB_MODE=disabled (no WANDB_API_KEY). 
Loss/reward will still be logged to disk.\")\n\nprint(f\"Profile: {ACTIVE_PROFILE} base={p['base_model']}\")\nprint(f\"System: {SYSTEM_ID or 'all 3 trained systems'}\")\nprint(f\"Push to Hub: {PUSH_TO_HUB}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. SFT warm-start\n", "\n", "Trains LoRA adapters on synthetic `(prompt, ground_truth_equation)` pairs generated from the env. After SFT, the model emits valid JSON in the right grammar \u2014 without this step ~80 % of GRPO completions are unparseable and reward variance collapses to zero (= flat loss curve, wasted GPU credits).\n", "\n", "Calls the same `physix.training.sft` module the cloud job uses." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import subprocess, sys\n\nSFT_OUT = \"/tmp/physix-sft\"\n\nsft_cmd = [\n sys.executable, \"-m\", \"physix.training.sft\",\n \"--model\", p[\"base_model\"],\n \"--output-dir\", SFT_OUT,\n \"--epochs\", p[\"sft_epochs\"],\n \"--instances-per-system\", str(INSTANCES_PER_SYSTEM),\n \"--lora-r\", p[\"sft_lora_r\"],\n \"--learning-rate\", p[\"sft_lr\"],\n \"--wandb-run-name\", p[\"sft_run_name\"],\n \"--seed\", \"0\",\n]\nif SYSTEM_ID:\n sft_cmd += [\"--system-ids\", SYSTEM_ID]\nif PUSH_TO_HUB and os.environ.get(\"HF_TOKEN\"):\n sft_cmd += [\"--hub-checkpoint-repo-id\", p[\"hub_ckpt_repo\"]]\n\nprint(\"=\" * 78)\nprint(\" SFT is launched as a subprocess. Per-step loss WILL NOT stream to this cell.\")\nprint(\" \u2192 Live loss curve: https://wandb.ai/pratyush01/physix-live\")\nprint(f\" run name: {p['sft_run_name']}\")\nprint(\" (If WANDB_API_KEY is unset, metrics are still saved to trainer_state.json\")\nprint(\" under SFT_OUT and plotted in the final cell of this notebook.)\")\nprint(\"=\" * 78)\nprint(\"$\", \" \".join(sft_cmd))\nsubprocess.run(sft_cmd, check=True)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. 
GRPO RLVR\n", "\n", "Reward functions live in `physix.training.reward_fns` and run inside the OpenEnv:\n", "\n", "| Component | Weight | Formula | Role |\n", "|-----------|:-:|---------|------|\n", "| `match` | 0.50 | R\u00b2 (observed vs predicted) | primary accuracy |\n", "| `match_dense` | \u2014 | \u221aR\u00b2 | non-trivial gradient near zero R\u00b2 |\n", "| `correctness` | \u2014 | 1 if R\u00b2 \u2265 0.70 else 0 | binary bonus past plateau |\n", "| `simplicity` | 0.20 | `1 \u2212 ops/12`, gated on R\u00b2 \u2265 0.10 | prefer shorter equations |\n", "| `format` | 0.10 | parses **and** simulates | syntactic + numerical validity |\n", "\n", "We deliberately do **not** include `progress` in the GRPO reward set \u2014 every GRPO row starts with `previous_r_match=0`, so `progress` would just be a redundant copy of `match`. It's only used in multi-turn live episodes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "GRPO_OUT = \"/tmp/physix-grpo\"\n\ngrpo_cmd = [\n sys.executable, \"-m\", \"physix.training.loop\",\n \"--model\", p[\"base_model\"],\n \"--output-dir\", GRPO_OUT,\n \"--num-steps\", p[\"num_steps\"],\n \"--num-generations\", p[\"num_generations\"],\n \"--max-completion-length\", p[\"max_completion\"],\n \"--learning-rate\", p[\"grpo_lr\"],\n \"--instances-per-system\", str(INSTANCES_PER_SYSTEM),\n \"--lora-r\", p[\"grpo_lora_r\"],\n \"--save-method\", \"merged_16bit\",\n \"--wandb-project\", \"physix-live\",\n \"--wandb-run-name\", p[\"grpo_run_name\"],\n \"--sft-checkpoint\", f\"{SFT_OUT}/merged\",\n \"--seed\", \"0\",\n]\nif SYSTEM_ID:\n grpo_cmd += [\"--system-ids\", SYSTEM_ID]\nif PUSH_TO_HUB and os.environ.get(\"HF_TOKEN\"):\n grpo_cmd += [\n \"--push-to-hub\",\n \"--hub-repo-id\", p[\"hub_final_repo\"],\n \"--hub-checkpoint-repo-id\", p[\"hub_ckpt_repo\"],\n ]\n\nprint(\"=\" * 78)\nprint(\" GRPO is launched as a subprocess. 
Per-step reward / loss WILL NOT stream here.\")\nprint(\" \u2192 Live curves (reward, reward_std, KL, per-component reward, loss):\")\nprint(\" https://wandb.ai/pratyush01/physix-live\")\nprint(f\" run name: {p['grpo_run_name']}\")\nprint(\" (If WANDB_API_KEY is unset, all metrics are still saved to\")\nprint(\" trainer_state.json under GRPO_OUT and plotted in the final cell.)\")\nprint(\"=\" * 78)\nprint(\"$\", \" \".join(grpo_cmd))\nsubprocess.run(grpo_cmd, check=True)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. Plot loss & reward curves\n", "\n", "TRL's `GRPOTrainer` writes a `trainer_state.json` containing all logged metrics \u2014 even when W&B is disabled, the loss / reward / per-component breakdown is on disk. This cell reads it and produces the same plots that live under `docs/plots/` in the README." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json, glob\n", "from pathlib import Path\n", "import matplotlib.pyplot as plt\n", "\n", "candidates = sorted(glob.glob(f\"{GRPO_OUT}/**/trainer_state.json\", recursive=True))\n", "if not candidates:\n", " candidates = [f\"{GRPO_OUT}/trainer_state.json\"]\n", "state_path = candidates[-1]\n", "print(\"Reading\", state_path)\n", "\n", "log = json.loads(Path(state_path).read_text())[\"log_history\"]\n", "\n", "def series(key: str):\n", " xs, ys = [], []\n", " for row in log:\n", " if key in row and \"step\" in row:\n", " xs.append(row[\"step\"])\n", " ys.append(row[key])\n", " return xs, ys\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", "\n", "xs, ys = series(\"loss\")\n", "axes[0].plot(xs, ys, color=\"#444\")\n", "axes[0].set_title(\"GRPO loss\")\n", "axes[0].set_xlabel(\"step\"); axes[0].set_ylabel(\"loss\"); axes[0].grid(alpha=0.3)\n", "\n", "xs, ys = series(\"reward\")\n", "axes[1].plot(xs, ys, color=\"#1f77b4\")\n", "axes[1].set_title(\"GRPO reward (aggregate)\")\n", "axes[1].set_xlabel(\"step\"); 
axes[1].set_ylabel(\"reward\"); axes[1].grid(alpha=0.3)\n", "\n", "plt.tight_layout(); plt.savefig(\"/tmp/colab_loss_reward.png\", dpi=120); plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "components = [\n", " (\"rewards/reward_match/mean\", \"match (R\u00b2)\"),\n", " (\"rewards/reward_match_dense/mean\", \"match_dense (\u221aR\u00b2)\"),\n", " (\"rewards/reward_correctness/mean\", \"correctness (R\u00b2\u22650.70)\"),\n", " (\"rewards/reward_simplicity/mean\", \"simplicity\"),\n", " (\"rewards/reward_format/mean\", \"format\"),\n", "]\n", "plt.figure(figsize=(10, 5))\n", "for key, label in components:\n", " xs, ys = series(key)\n", " if xs:\n", " plt.plot(xs, ys, label=label, alpha=0.85)\n", "plt.title(\"Per-component reward\")\n", "plt.xlabel(\"step\"); plt.ylabel(\"mean reward\"); plt.grid(alpha=0.3); plt.legend()\n", "plt.tight_layout(); plt.savefig(\"/tmp/colab_reward_components.png\", dpi=120); plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 10. Try the trained model on a fresh trajectory (optional)\n", "\n", "Quick smoke test: spin up the same OpenEnv server in-process, give the trained model one new noisy trajectory, and check the reward breakdown." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "from physix.server.environment import PhysiXEnvironment\n", "from physix.training.prompt import build_prompt\n", "from physix.models import PhysiXAction\n", "from unsloth import FastLanguageModel\n", "\n", "# GRPO writes to {output_dir}/{save_method}; we used merged_16bit.\n", "# SFT writes to {output_dir}/merged. 
Prefer GRPO if it ran to completion,\n", "# otherwise fall back to the SFT checkpoint so this cell still has something\n", "# to demo even if you stopped GRPO early.\n", "GRPO_MERGED = Path(GRPO_OUT) / \"merged_16bit\"\n", "SFT_MERGED = Path(SFT_OUT) / \"merged\"\n", "checkpoint = GRPO_MERGED if GRPO_MERGED.exists() else SFT_MERGED\n", "print(f\"Loading checkpoint: {checkpoint}\")\n", "\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", " model_name=str(checkpoint),\n", " max_seq_length=4096, load_in_4bit=False,\n", ")\n", "FastLanguageModel.for_inference(model)\n", "\n", "env = PhysiXEnvironment()\n", "obs = env.reset(seed=99, system_id=SYSTEM_ID or \"damped_spring\")\n", "messages = build_prompt(obs) # chat-format [{\"role\": \"system\", ...}, {\"role\": \"user\", ...}]\n", "prompt_text = tokenizer.apply_chat_template(\n", " messages, tokenize=False, add_generation_prompt=True,\n", ")\n", "inputs = tokenizer(prompt_text, return_tensors=\"pt\").to(\"cuda\")\n", "out = model.generate(**inputs, max_new_tokens=256, do_sample=False)\n", "completion = tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)\n", "print(\"=== Model output ===\\n\", completion)\n", "\n", "try:\n", " parsed = json.loads(completion)\n", " action = PhysiXAction(**parsed)\n", " step_obs = env.step(action)\n", " print(\"\\n=== Reward breakdown ===\\n\", step_obs.reward_breakdown)\n", "except Exception as exc:\n", " print(\"Failed to parse / step:\", exc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 11. 
Done\n", "\n", "If `PUSH_TO_HUB=True` and `HF_TOKEN` was provided, the merged 16-bit model is now at:\n", "\n", "- `https://huggingface.co/{p['hub_final_repo']}` (final)\n", "- `https://huggingface.co/{p['hub_ckpt_repo']}` (mid-run checkpoints)\n", "\n", "Plots are saved to `/tmp/colab_loss_reward.png` and `/tmp/colab_reward_components.png` \u2014 drop them into the README under `docs/plots/`.\n", "\n", "**Next:** explore the live demo at https://huggingface.co/spaces/Pratyush-01/physix-live or run `python -m physix.server.app` locally and point a fresh OpenEnv client at it." ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "A100", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }