{ "cells": [ { "cell_type": "markdown", "id": "91ad4325", "metadata": {}, "source": "# Training a SQL Agent with GRPO + SQLEnv\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hjerpe/sql-env/blob/main/notebooks/train_grpo.ipynb)\n\nThis notebook trains a small language model to explore SQL databases using GRPO. It runs end-to-end on a Colab GPU runtime — no external server needed." }, { "cell_type": "markdown", "id": "fc80922f", "metadata": {}, "source": [ "## 1) Setup\n", "Install dependencies, clone the repo, and download the Spider databases. On Colab this happens automatically; locally, run `uv sync` instead." ] }, { "cell_type": "code", "execution_count": null, "id": "bd3d42c7", "metadata": {}, "outputs": [], "source": [ "import os\n", "import subprocess\n", "import sys\n", "from pathlib import Path\n", "\n", "IN_COLAB = \"google.colab\" in sys.modules\n", "\n", "if IN_COLAB:\n", " from google.colab import userdata\n", "\n", " token = userdata.get(\"GITHUB_TOKEN\")\n", "\n", " BRANCH = \"main\" # @param {type:\"string\"}\n", " repo_url = f\"https://{token}@github.com/hjerpe/sql-env.git\"\n", "\n", " # Clone or update repo\n", " if Path(\"sql-env\").exists():\n", " subprocess.check_call([\"git\", \"-C\", \"sql-env\", \"pull\", \"-q\"])\n", " else:\n", " subprocess.check_call([\"git\", \"clone\", \"-q\", \"-b\", BRANCH, repo_url])\n", " os.chdir(\"sql-env\")\n", " # Install sql-env (always reinstall to pick up latest changes)\n", " subprocess.check_call(\n", " [\n", " sys.executable,\n", " \"-m\",\n", " \"pip\",\n", " \"install\",\n", " \"-q\",\n", " \"--no-deps\",\n", " \"--force-reinstall\",\n", " \".\",\n", " ]\n", " )\n", " # TRL 0.29+ and transformers from main (>=5.2.0 for environment_factory)\n", " # Uninstall vllm first to avoid version conflict with transformers 5.x\n", " subprocess.call([sys.executable, \"-m\", \"pip\", \"uninstall\", \"-y\", \"vllm\"])\n", " 
subprocess.check_call(\n", " [\n", " sys.executable,\n", " \"-m\",\n", " \"pip\",\n", " \"install\",\n", " \"-q\",\n", " \"trl>=0.29.0\",\n", " \"accelerate>=0.34.0\",\n", " \"openenv-core[core]>=0.2.1\",\n", " \"pydantic>=2.0.0\",\n", " \"jmespath\",\n", " \"git+https://github.com/huggingface/transformers.git@main\",\n", " ]\n", " )\n", " # Download Spider databases + generate SFT data\n", " subprocess.check_call([sys.executable, \"scripts/download_spider_databases.py\"])\n", " # SFT data generation — pass --enable-thinking if ENABLE_THINKING is set\n", " # (ENABLE_THINKING is defined in the Configuration cell; default False)\n", " _sft_cmd = [sys.executable, \"scripts/generate_sft_data.py\"]\n", " if globals().get(\"ENABLE_THINKING\", False):\n", " _sft_cmd.append(\"--enable-thinking\")\n", " subprocess.check_call(_sft_cmd)\n", "\n", " # ── Pre-authenticate Drive + HF Hub up front ─────────────────────\n", " # Both require interactive prompts on first use. Doing them now\n", " # means a long training run won't get blocked at the end waiting\n", " # for human input — and a runtime crash mid-push won't lose work.\n", " print(\"\\nPre-authenticating Drive and HuggingFace Hub...\")\n", " try:\n", " from google.colab import drive\n", "\n", " drive.mount(\"/content/drive\", force_remount=False)\n", " print(\" Drive mounted\")\n", " except Exception as exc:\n", " print(f\" Drive mount skipped: {exc}\")\n", "\n", " try:\n", " from huggingface_hub import login\n", "\n", " hf_token = userdata.get(\"HF_TOKEN\")\n", " if hf_token:\n", " login(token=hf_token, add_to_git_credential=False)\n", " print(\" HuggingFace Hub authenticated\")\n", " else:\n", " print(\" HF_TOKEN secret not set — push to Hub will fail later\")\n", " except Exception as exc:\n", " print(f\" HF login skipped: {exc}\")\n", "\n", "from sql_env.training.config import find_project_root\n", "\n", "PROJECT_ROOT = find_project_root()\n", "os.chdir(PROJECT_ROOT)\n", "if str(PROJECT_ROOT) not in sys.path:\n", " 
sys.path.insert(0, str(PROJECT_ROOT))\n", "\n", "print(f\"Project root: {PROJECT_ROOT}\")\n", "print(f\"Running on: {'Colab' if IN_COLAB else 'local'}\")" ] }, { "cell_type": "markdown", "id": "dc177126", "metadata": {}, "source": [ "## 2) Configuration\n", "Set the model and training hyperparameters." ] }, { "cell_type": "code", "execution_count": null, "id": "978bc98f", "metadata": {}, "outputs": [], "source": [ "from __future__ import annotations\n", "import os\n", "\n", "# Reduce CUDA fragmentation (helps after SFT → GRPO handoff)\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "from sql_env.training.config import GRPOConfig\n", "from sql_env.training.data_loading import (\n", " load_model_and_tokenizer,\n", " load_question_prompts,\n", ")\n", "from sql_env.training.notebook_pipeline import (\n", " build_trainer,\n", " run_training_with_metrics,\n", " sample_random_baseline,\n", ")\n", "from sql_env.training.trl_adapter import SQLEnvTRL, sql_env_reward_func\n", "\n", "try:\n", " from trl import GRPOConfig as TRLGRPOConfig\n", " from trl import GRPOTrainer\n", " import trl\n", " import transformers\n", "\n", " print(f\"TRL: {trl.__version__}, Transformers: {transformers.__version__}\")\n", "except Exception as exc:\n", " raise RuntimeError(\n", " \"TRL is required. Install dependencies in the Setup cell first.\"\n", " ) from exc\n", "\n", "# ── Pick your model ────────────────────────────────────────────────\n", "# Qwen3-0.6B → fits T4 (16 GB) comfortably\n", "# Qwen3-1.7B → fits L4 (24 GB) with gradient checkpointing\n", "MODEL_NAME = \"Qwen/Qwen3-1.7B\" # @param [\"Qwen/Qwen3-0.6B\", \"Qwen/Qwen3-1.7B\"]\n", "\n", "# ── Thinking mode ──────────────────────────────────────────────────\n", "# When True, the model can generate ... 
blocks before\n", "# tool calls, enabling reasoning about SQL errors and query strategy.\n", "# Requires more tokens — max_new_tokens is auto-adjusted below.\n", "ENABLE_THINKING = False # @param {type:\"boolean\"}\n", "\n", "# ── SFT warmup ─────────────────────────────────────────────────────\n", "# Teaches one-tool-per-turn pattern before GRPO. Recommended for\n", "# first training run; can skip if resuming from an SFT checkpoint.\n", "RUN_SFT_WARMUP = True # @param {type:\"boolean\"}\n", "\n", "# ── Auto-tune batch settings per model size ────────────────────────\n", "# Note: beta>0 loads a reference model, doubling memory.\n", "# max_new_tokens=512 and num_generations=4 to fit L4 24GB.\n", "_MODEL_PRESETS = {\n", " \"Qwen/Qwen3-0.6B\": dict(\n", " per_device_train_batch_size=6,\n", " gradient_accumulation_steps=1,\n", " num_generations=6,\n", " gradient_checkpointing=False,\n", " ),\n", " \"Qwen/Qwen3-1.7B\": dict(\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=4,\n", " num_generations=4,\n", " gradient_checkpointing=True,\n", " ),\n", "}\n", "\n", "\n", "# Match on model size so HF checkpoint names (e.g. 
hjerpe/sqlenv-qwen3-1.7b-grpo)\n", "# pick the right preset instead of silently falling back to 0.6B defaults.\n", "def _get_preset(name):\n", " if \"1.7b\" in name.lower():\n", " return _MODEL_PRESETS[\"Qwen/Qwen3-1.7B\"]\n", " return _MODEL_PRESETS[\"Qwen/Qwen3-0.6B\"]\n", "\n", "\n", "_preset = _get_preset(MODEL_NAME)\n", "\n", "# Thinking mode needs more tokens for blocks + tool calls.\n", "# Phase 1 base: 512 (no-think) → 768 (thinking)\n", "_phase1_tokens = 768 if ENABLE_THINKING else 512\n", "\n", "config = GRPOConfig(\n", " questions_path=\"data/questions/questions_train.json\",\n", " db_dir=\"data/databases\",\n", " output_dir=\"outputs/grpo_run\",\n", " model_name=MODEL_NAME,\n", " num_train_epochs=1,\n", " max_new_tokens=_phase1_tokens,\n", " step_budget=10,\n", " precision=\"bf16\",\n", " beta=0.04,\n", " enable_thinking=ENABLE_THINKING,\n", " **_preset,\n", ")\n", "print(f\"Model: {config.model_name}\")\n", "print(f\"Thinking mode: {'ON' if ENABLE_THINKING else 'OFF'}\")\n", "print(\n", " f\"Batch: {config.per_device_train_batch_size} × {config.gradient_accumulation_steps} accum\"\n", ")\n", "print(\n", " f\"Generations: {config.num_generations}, Gradient ckpt: {config.gradient_checkpointing}\"\n", ")\n", "print(f\"Max tokens: {config.max_new_tokens}, Beta: {config.beta}\")\n", "print(f\"Epochs: {config.num_train_epochs}\")\n", "print(f\"SFT warmup: {'ON' if RUN_SFT_WARMUP else 'OFF'}\")" ] }, { "cell_type": "markdown", "id": "f8711d45", "metadata": {}, "source": [ "## 3) Smoke Test\n", "Verify the environment works in-process before starting the training loop." 
] },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "559dbd33",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sql_env.server.sql_environment import SQLEnvironment\n",
    "from sql_env.server.mock_tokenizer import MockTokenizer\n",
    "from sql_env.models import SQLAction\n",
    "from sql_env.training.data_loading import validate_no_data_leak\n",
    "\n",
    "# Build an in-process environment from the training configuration.\n",
    "_smoke_env_kwargs = dict(\n",
    "    questions_path=config.questions_path,\n",
    "    db_dir=config.db_dir,\n",
    "    tokenizer=MockTokenizer(),\n",
    "    step_budget=config.step_budget,\n",
    ")\n",
    "env = SQLEnvironment(**_smoke_env_kwargs)\n",
    "\n",
    "# Run the train/eval overlap check only when an eval split is present.\n",
    "eval_path = \"data/questions/questions_eval.json\"\n",
    "_eval_file = Path(eval_path)\n",
    "if _eval_file.exists():\n",
    "    validate_no_data_leak(config.questions_path, eval_path)\n",
    "    print(\"Data leak check: PASSED (0 question overlap)\")\n",
    "\n",
    "# Start one episode with a fixed seed and show what was loaded.\n",
    "obs = env.reset(seed=42)\n",
    "print(f\"Loaded {len(env.questions)} questions\")\n",
    "print(f\"Question: {obs.question}\")\n",
    "\n",
    "# DESCRIBE whatever is named on the second line of the schema listing\n",
    "# (presumably the first table entry, written as a \"- name\" bullet).\n",
    "_schema_lines = obs.schema_info.split(\"\\n\")\n",
    "_describe_target = _schema_lines[1].strip(\"- \").strip()\n",
    "_describe_action = SQLAction(action_type=\"DESCRIBE\", argument=_describe_target)\n",
    "obs = env.step(_describe_action)\n",
    "\n",
    "print(f\"Schema: {obs.schema_info[:200]}\")\n",
    "print(\"Smoke test passed.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fe676db",
   "metadata": {},
   "source": "## 4) Load Model\nDownload the model and tokenizer, load training prompts."
},
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "979208ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, tokenizer = load_model_and_tokenizer(config.model_name)\n",
    "raw_prompts = load_question_prompts(config.questions_path, config.difficulty_filter)\n",
    "\n",
    "# System prompt: uses get_system_prompt() which prepends /no_think\n",
    "# when thinking mode is disabled, or omits it when enabled.\n",
    "from scripts.generate_sft_data import get_system_prompt\n",
    "\n",
    "SYSTEM_PROMPT = get_system_prompt(enable_thinking=config.enable_thinking)\n",
    "\n",
    "# question_text is passed to reset() so the environment loads the correct DB\n",
    "prompts = [\n",
    "    {\n",
    "        \"prompt\": [\n",
    "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "            {\"role\": \"user\", \"content\": item[\"prompt\"]},\n",
    "        ],\n",
    "        \"question_text\": item[\"prompt\"],\n",
    "    }\n",
    "    for item in raw_prompts\n",
    "]\n",
    "\n",
    "# Verify SFT/GRPO prompt alignment: render one prompt through the chat\n",
    "# template with the tool schemas attached and confirm they reach the text.\n",
    "from sql_env.training.trl_adapter import get_tool_definitions\n",
    "\n",
    "_tool_defs = get_tool_definitions()\n",
    "_rendered = tokenizer.apply_chat_template(\n",
    "    [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}, {\"role\": \"user\", \"content\": \"test\"}],\n",
    "    tools=_tool_defs,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "# An empty-string membership test is always True, so check for the literal\n",
    "# <tools> marker that the Qwen3 chat template is expected to emit around the\n",
    "# tool schemas. The two conditions are asserted separately so a failure\n",
    "# says which one broke.\n",
    "assert \"<tools>\" in _rendered, \"Chat template did not render a <tools> block\"\n",
    "assert '\"name\": \"query\"' in _rendered, \"'query' tool missing from rendered prompt\"\n",
    "print(f\"Model: {config.model_name}\")\n",
    "print(f\"Thinking mode: {'ON' if config.enable_thinking else 'OFF'}\")\n",
    "print(f\"Prompts: {len(prompts)} questions\")\n",
    "print(f\"SFT/GRPO alignment: OK ({len(_tool_defs)} tools in prompt)\")\n",
    "print(f\"System prompt: {len(SYSTEM_PROMPT)} chars\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sxvb8e0qufq",
   "metadata": {},
   "source": "## 5) Live Visualization\nThe reward plot updates in place every `logging_steps`. TRL prints completion tables showing actual model output and tool calls."
}, { "cell_type": "code", "execution_count": null, "id": "d8c31628", "metadata": {}, "outputs": [], "source": [ "from sql_env.training.visualization import LiveVisualizationCallback, SFTMonitorCallback\n", "from sql_env.training.trl_adapter import get_tool_definitions\n", "\n", "viz = LiveVisualizationCallback()\n", "\n", "# Pick diverse sample prompts (every Nth to span different databases)\n", "_sft_tools = get_tool_definitions()\n", "_step = max(1, len(prompts) // 3)\n", "_sft_monitor_prompts = [\n", " prompts[i * _step][\"prompt\"] for i in range(3) if i * _step < len(prompts)\n", "]\n", "sft_monitor = SFTMonitorCallback(\n", " tokenizer,\n", " _sft_monitor_prompts,\n", " tools=_sft_tools,\n", " eval_every_steps=40,\n", ")\n", "print(f\"Visualization + SFT monitor ({len(_sft_tools)} tools) ready.\")" ] }, { "cell_type": "markdown", "id": "b5jsw6qm1ml", "metadata": {}, "source": "## 6) SFT Warmup\nShort supervised fine-tuning on ideal trajectories to teach the model basic tool-calling patterns: one tool per turn, read responses, submit real values. GRPO then refines the reasoning." }, { "cell_type": "code", "execution_count": null, "id": "gc6hjtd3eje", "metadata": {}, "outputs": [], "source": [ "if RUN_SFT_WARMUP:\n", " import json as _json\n", " from datasets import Dataset\n", " from sql_env.training.trl_adapter import get_tool_definitions\n", "\n", " TOOL_DEFS = get_tool_definitions()\n", " print(f\"Tools extracted: {[t['function']['name'] for t in TOOL_DEFS]}\")\n", "\n", " # ── Patch Qwen3 chat template for assistant_only_loss ──────────\n", " # Qwen3's template lacks {% generation %} tags that TRL needs to\n", " # build assistant_masks. 
We patch it for SFT only, then restore\n", " # the original before GRPO (which does its own template checks).\n", " # See: QwenLM/Qwen3#1522, huggingface/trl#4879\n", " _original_chat_template = tokenizer.chat_template\n", " _use_assistant_only = False\n", "\n", " if \"{% generation %}\" not in _original_chat_template:\n", " _ASST_START = '{%- elif message.role == \"assistant\" %}'\n", " _ASST_END = \"{{- '<|im_end|>\\\\n' }}\\n {%- elif message.role == \\\"tool\\\" %}\"\n", "\n", " _patched = _original_chat_template.replace(\n", " _ASST_START,\n", " _ASST_START + \"\\n {% generation %}\",\n", " ).replace(\n", " _ASST_END,\n", " \"{% endgeneration %}\" + _ASST_END,\n", " )\n", " if \"{% generation %}\" in _patched and \"{% endgeneration %}\" in _patched:\n", " tokenizer.chat_template = _patched\n", " print(\"Qwen3 template patched with {% generation %} tags\")\n", "\n", " # Verify the patch produces assistant masks\n", " _test_out = tokenizer.apply_chat_template(\n", " [\n", " {\"role\": \"system\", \"content\": \"test\"},\n", " {\"role\": \"user\", \"content\": \"hello\"},\n", " {\"role\": \"assistant\", \"content\": \"world\"},\n", " ],\n", " tokenize=True,\n", " return_dict=True,\n", " return_assistant_tokens_mask=True,\n", " )\n", " _has_mask = (\n", " \"assistant_masks\" in _test_out and 1 in _test_out[\"assistant_masks\"]\n", " )\n", " _use_assistant_only = _has_mask\n", " else:\n", " print(\"WARNING: Could not patch template\")\n", "\n", " print(f\"Assistant-only loss: {'ENABLED' if _use_assistant_only else 'DISABLED'}\")\n", "\n", " # Load SFT trajectories (generated by scripts/generate_sft_data.py)\n", " sft_path = PROJECT_ROOT / \"data\" / \"sft\" / \"sft_trajectories.json\"\n", " if not sft_path.exists():\n", " print(\"Generating SFT trajectories...\")\n", " subprocess.check_call(\n", " [sys.executable, str(PROJECT_ROOT / \"scripts\" / \"generate_sft_data.py\")]\n", " )\n", "\n", " with open(sft_path) as f:\n", " sft_data = _json.load(f)\n", "\n", " # 
Ensure every trajectory has tool definitions\n", " for row in sft_data:\n", " if \"tools\" not in row:\n", " row[\"tools\"] = TOOL_DEFS\n", "\n", " sft_dataset = Dataset.from_list(sft_data)\n", " print(f\"SFT dataset: {len(sft_dataset)} multi-turn trajectories\")\n", "\n", " # ── Render training data preview ───────────────────────────────\n", " # Uses the exact same apply_chat_template call as TRL's SFTTrainer,\n", " # so the output file matches training input with zero divergence.\n", " _render_path = PROJECT_ROOT / \"data\" / \"sft\" / \"sft_rendered.txt\"\n", " _render_n = min(5, len(sft_data))\n", " _render_parts = []\n", " _total_tok, _total_asst = 0, 0\n", "\n", " for _ri in range(_render_n):\n", " _ex = sft_data[_ri]\n", " _msgs = _ex[\"messages\"]\n", " _tools = _ex.get(\"tools\", TOOL_DEFS)\n", "\n", " _text = tokenizer.apply_chat_template(\n", " _msgs,\n", " tools=_tools,\n", " tokenize=False,\n", " )\n", " _tok_out = tokenizer.apply_chat_template(\n", " _msgs,\n", " tools=_tools,\n", " tokenize=True,\n", " return_dict=True,\n", " return_assistant_tokens_mask=True,\n", " )\n", " _n_tok = len(_tok_out[\"input_ids\"])\n", " _mask = _tok_out.get(\"assistant_masks\", [])\n", " _n_asst = sum(_mask) if _mask else 0\n", " _total_tok += _n_tok\n", " _total_asst += _n_asst\n", "\n", " _header = (\n", " f\"{'=' * 70}\\n\"\n", " f\"Example {_ri} | {_n_tok} tokens | \"\n", " f\"{_n_asst} assistant tokens \"\n", " f\"({_n_asst / _n_tok:.0%} of sequence)\\n\"\n", " f\"{'=' * 70}\"\n", " )\n", " _render_parts.append(f\"{_header}\\n{_text}\")\n", "\n", " _summary = (\n", " f\"SFT Training Data Preview \"\n", " f\"(rendered by training tokenizer)\\n\"\n", " f\"Model: {config.model_name}\\n\"\n", " f\"Template patched: \"\n", " f\"{'yes' if _use_assistant_only else 'no'}\\n\"\n", " f\"Examples shown: {_render_n} / {len(sft_data)}\\n\"\n", " f\"Avg tokens/example: {_total_tok / _render_n:.0f} | \"\n", " f\"Avg assistant tokens: {_total_asst / _render_n:.0f} \"\n", " 
f\"({_total_asst / _total_tok:.0%})\\n\"\n", " )\n", " _full = _summary + \"\\n\" + \"\\n\\n\".join(_render_parts) + \"\\n\"\n", " _render_path.parent.mkdir(parents=True, exist_ok=True)\n", " _render_path.write_text(_full)\n", " print(f\"\\nRendered {_render_n} examples -> {_render_path}\")\n", " print(\n", " f\"Avg: {_total_tok // _render_n} tok/example, \"\n", " f\"{_total_asst // _render_n} assistant tok \"\n", " f\"({_total_asst / _total_tok:.0%})\"\n", " )\n", "\n", " # Show first example inline\n", " print(f\"\\n{'=' * 70}\")\n", " print(\"Example 0 preview (first 1500 chars):\")\n", " print(f\"{'=' * 70}\")\n", " print(_render_parts[0][:1500])\n", " if len(_render_parts[0]) > 1500:\n", " print(f\"... (see {_render_path} for full output)\")\n", "\n", " # ── Train ──────────────────────────────────────────────────────\n", " sft_monitor.train_dataset = sft_dataset\n", "\n", " from trl import SFTConfig as TRLSFTConfig, SFTTrainer\n", "\n", " sft_output_dir = str(PROJECT_ROOT / \"outputs\" / \"sft_warmup\")\n", "\n", " _sft_batch = 2 if \"1.7B\" in config.model_name else 4\n", "\n", " _sft_kwargs = dict(\n", " output_dir=sft_output_dir,\n", " num_train_epochs=2,\n", " per_device_train_batch_size=_sft_batch,\n", " gradient_accumulation_steps=1,\n", " learning_rate=2e-5,\n", " lr_scheduler_type=\"cosine\",\n", " warmup_steps=10,\n", " logging_steps=10,\n", " save_strategy=\"no\",\n", " bf16=True,\n", " gradient_checkpointing=config.gradient_checkpointing,\n", " assistant_only_loss=_use_assistant_only,\n", " )\n", " import inspect\n", "\n", " _sft_params = inspect.signature(TRLSFTConfig).parameters\n", " if \"max_seq_length\" in _sft_params:\n", " _sft_kwargs[\"max_seq_length\"] = 1024\n", " elif \"max_length\" in _sft_params:\n", " _sft_kwargs[\"max_length\"] = 1024\n", "\n", " sft_config = TRLSFTConfig(**_sft_kwargs)\n", "\n", " sft_trainer = SFTTrainer(\n", " model=model,\n", " args=sft_config,\n", " train_dataset=sft_dataset,\n", " 
processing_class=tokenizer,\n", " callbacks=[viz, sft_monitor],\n", " )\n", "\n", " _total_steps = len(sft_dataset) // _sft_batch * _sft_kwargs[\"num_train_epochs\"]\n", " print(\n", " f\"\\nSFT: {_sft_batch} batch x {_total_steps} steps, \"\n", " f\"grad_ckpt: {config.gradient_checkpointing}\"\n", " )\n", " print(f\"Assistant-only loss: {_use_assistant_only}\")\n", " print(\"Starting SFT warmup (2 epochs, LR=2e-5)...\")\n", " sft_trainer.train()\n", " print(\"SFT warmup done.\")\n", "\n", " # ── Restore original template for GRPO ─────────────────────────\n", " # GRPO's get_training_chat_template() does exact-match detection\n", " # and will fail on our patched template. Restore the original so\n", " # GRPO can apply its own prefix-preserving template.\n", " tokenizer.chat_template = _original_chat_template\n", " print(\"Template restored to original for GRPO\")\n", "\n", " # ── Aggressive cleanup for GRPO ────────────────────────────────\n", " import gc\n", " import torch\n", "\n", " model = sft_trainer.model\n", " if hasattr(model, \"module\"):\n", " model = model.module\n", " if hasattr(model, \"gradient_checkpointing_disable\"):\n", " model.gradient_checkpointing_disable()\n", "\n", " del sft_trainer, sft_dataset, sft_data, sft_config\n", " sft_monitor.train_dataset = None\n", " gc.collect()\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " _free = torch.cuda.mem_get_info()[0] / 1e9\n", " print(f\"GPU memory freed: {_free:.1f} GB available\")\n", "\n", " viz = LiveVisualizationCallback()\n", " print(\"Model handed off to GRPO.\")\n", "else:\n", " print(\"Skipping SFT warmup (RUN_SFT_WARMUP = False)\")" ] }, { "cell_type": "markdown", "id": "ro5feofkx6q", "metadata": {}, "source": "## 6b) Post-SFT Format Check\nQuick generation to verify the model learned tool-call format before starting GRPO." 
},
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "m6nk5n7zwg",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Quick format check: generate one completion to verify tool-call quality.\n",
    "# Full eval via play_episode is too expensive before GRPO (causes OOM).\n",
    "from sql_env.training.trl_adapter import get_tool_definitions\n",
    "\n",
    "_eval_tools = get_tool_definitions()\n",
    "_eval_msgs = [\n",
    "    {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "    {\"role\": \"user\", \"content\": prompts[0][\"question_text\"]},\n",
    "]\n",
    "_eval_rendered = tokenizer.apply_chat_template(\n",
    "    _eval_msgs,\n",
    "    tools=_eval_tools,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "_eval_inputs = tokenizer(_eval_rendered, return_tensors=\"pt\")\n",
    "_eval_inputs = {k: v.to(model.device) for k, v in _eval_inputs.items()}\n",
    "\n",
    "# Greedy decoding in eval mode; keep only the newly generated tokens.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    _eval_out = model.generate(**_eval_inputs, max_new_tokens=200, do_sample=False)\n",
    "    _eval_new = _eval_out[0][_eval_inputs[\"input_ids\"].shape[1] :]\n",
    "    _eval_text = tokenizer.decode(_eval_new, skip_special_tokens=True)\n",
    "model.train()\n",
    "\n",
    "print(\"Post-SFT format check:\")\n",
    "print(f\" Q: {prompts[0]['question_text'][:80]}...\")\n",
    "print(f\" → {_eval_text[:300]}\")\n",
    "# An empty-string membership test is always True and would report YES for\n",
    "# plain-text output. Look for the literal <tool_call> marker instead,\n",
    "# decoding with special tokens kept so the marker cannot be stripped.\n",
    "_eval_raw = tokenizer.decode(_eval_new, skip_special_tokens=False)\n",
    "has_tool_call = \"<tool_call>\" in _eval_raw\n",
    "print(f\" Tool-call format: {'YES' if has_tool_call else 'NO (plain text)'}\")\n",
    "\n",
    "# Free memory before GRPO\n",
    "del _eval_out, _eval_inputs, _eval_new\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9zjonqe6dbg",
   "metadata": {},
   "source": "## 7) Train with GRPO\nStarting from the SFT-warmed model, GRPO refines SQL reasoning via reward signal from the environment."
}, { "cell_type": "code", "execution_count": null, "id": "xpwe5woqqcb", "metadata": {}, "outputs": [], "source": [ "# Phase 1: easy only — stabilize format with KL penalty\n", "easy_prompts = [\n", " {\n", " \"prompt\": [\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": item[\"prompt\"]},\n", " ],\n", " \"question_text\": item[\"prompt\"],\n", " }\n", " for item in load_question_prompts(config.questions_path, [\"easy\"])\n", "]\n", "\n", "print(\n", " f\"Phase 1: {len(easy_prompts)} easy questions (beta={config.beta}, max_tokens={config.max_new_tokens})\"\n", ")\n", "if config.enable_thinking:\n", " print(\" Thinking mode ON — model can reason in blocks\")\n", "\n", "trainer = build_trainer(\n", " trl_grpo_config_cls=TRLGRPOConfig,\n", " grpo_trainer_cls=GRPOTrainer,\n", " model=model,\n", " tokenizer=tokenizer,\n", " prompts=easy_prompts,\n", " config=config,\n", " reward_funcs=[sql_env_reward_func],\n", " environment_factory=SQLEnvTRL,\n", " callbacks=[viz],\n", ")\n", "\n", "before_rollouts = sample_random_baseline(\n", " [item[\"question_text\"] for item in easy_prompts],\n", " step_budget=config.step_budget,\n", " seed=config.seed,\n", ")\n", "print(f\"Random baseline episodes: {len(before_rollouts)}\")\n", "\n", "train_output, steps, rewards = run_training_with_metrics(trainer)\n", "print(f\"Phase 1 done: {len(steps)} logged steps\")\n", "\n", "# Free reference model before phase 2\n", "import gc\n", "import torch\n", "\n", "del trainer\n", "gc.collect()\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " _free = torch.cuda.mem_get_info()[0] / 1e9\n", " print(f\"GPU memory freed: {_free:.1f} GB available\")\n", "\n", "# Phase 2: easy + medium — no KL penalty, longer tokens\n", "# Thinking mode gets extra headroom: 1024 → 1280\n", "from dataclasses import replace\n", "\n", "_phase2_tokens = 1280 if config.enable_thinking else 1024\n", "config2 = replace(config, beta=0.0, 
max_new_tokens=_phase2_tokens)\n", "\n", "print(\n", " f\"\\nPhase 2: {len(prompts)} easy+medium questions (beta={config2.beta}, max_tokens={config2.max_new_tokens})\"\n", ")\n", "\n", "viz = LiveVisualizationCallback()\n", "\n", "trainer2 = build_trainer(\n", " trl_grpo_config_cls=TRLGRPOConfig,\n", " grpo_trainer_cls=GRPOTrainer,\n", " model=model,\n", " tokenizer=tokenizer,\n", " prompts=prompts,\n", " config=config2,\n", " reward_funcs=[sql_env_reward_func],\n", " environment_factory=SQLEnvTRL,\n", " callbacks=[viz],\n", ")\n", "\n", "train_output2, steps2, rewards2 = run_training_with_metrics(trainer2)\n", "print(f\"Phase 2 done: {len(steps2)} logged steps\")\n", "\n", "# Combine metrics\n", "all_steps = steps + [s + (steps[-1] if steps else 0) for s in steps2]\n", "all_rewards = rewards + rewards2\n", "print(\n", " f\"Total: {len(all_steps)} steps, final reward: {all_rewards[-1] if all_rewards else 'N/A'}\"\n", ")" ] }, { "cell_type": "markdown", "id": "c68fc96d", "metadata": {}, "source": "## 8) Final Summary\nStatic plot of the full training run." }, { "cell_type": "code", "execution_count": null, "id": "5e691b70", "metadata": {}, "outputs": [], "source": [ "if all_steps and all_rewards:\n", " plt.figure(figsize=(8, 4))\n", " plt.plot(all_steps, all_rewards, marker=\"o\", linewidth=1.5)\n", " # Mark phase boundary\n", " if steps:\n", " plt.axvline(\n", " x=steps[-1], color=\"red\", linestyle=\"--\", alpha=0.5, label=\"Phase 1→2\"\n", " )\n", " plt.legend()\n", " plt.title(\"GRPO Reward Trend (Easy → Easy+Medium)\")\n", " plt.xlabel(\"Training Step\")\n", " plt.ylabel(\"Reward\")\n", " plt.grid(alpha=0.3)\n", " plt.show()\n", "else:\n", " print(\"No reward points available yet.\")" ] }, { "cell_type": "markdown", "id": "98ce823b", "metadata": {}, "source": "## 9) Save and Push to Hub\nSave the trained model locally and optionally push to HuggingFace Hub." 
}, { "cell_type": "code", "execution_count": null, "id": "26maekrxzz1", "metadata": {}, "outputs": [], "source": [ "# Save locally — use model directly (trainer was deleted during phase handoff)\n", "model.save_pretrained(config.output_dir)\n", "tokenizer.save_pretrained(config.output_dir)\n", "print(f\"Model saved to {config.output_dir}\")\n", "\n", "# ── Run identifier (used for both Drive backup and HF repo) ────────\n", "HF_SUFFIX = \"\" # @param {type:\"string\"}\n", "\n", "# Extract base model name, stripping any existing sqlenv-/grpo- prefixes\n", "# so resuming from \"hjerpe/sqlenv-qwen3-0.6b-grpo\" doesn't double them.\n", "import re\n", "\n", "_raw_short = config.model_name.split(\"/\")[-1].lower()\n", "_model_short = re.sub(r\"^sqlenv-\", \"\", _raw_short)\n", "_model_short = re.sub(r\"-grpo.*$\", \"\", _model_short)\n", "_suffix = f\"-{HF_SUFFIX}\" if HF_SUFFIX and not HF_SUFFIX.startswith(\"-\") else HF_SUFFIX\n", "_run_name = f\"sqlenv-{_model_short}-grpo{_suffix}\"\n", "HF_REPO = f\"hjerpe/{_run_name}\"\n", "print(f\"Run name: {_run_name} → {HF_REPO}\")\n", "\n", "# ── Backup to Google Drive (Drive was mounted in Setup cell) ───────\n", "BACKUP_TO_DRIVE = True # @param {type:\"boolean\"}\n", "\n", "if IN_COLAB and BACKUP_TO_DRIVE:\n", " _drive_backup_dir = f\"/content/drive/MyDrive/sqlenv-checkpoints/{_run_name}\"\n", " if not Path(\"/content/drive/MyDrive\").exists():\n", " # Drive wasn't pre-mounted in Setup — mount now (interactive)\n", " from google.colab import drive\n", "\n", " drive.mount(\"/content/drive\", force_remount=False)\n", " subprocess.check_call([\"mkdir\", \"-p\", _drive_backup_dir])\n", " subprocess.check_call([\"cp\", \"-r\", f\"{config.output_dir}/.\", _drive_backup_dir])\n", " print(f\"Backup saved to Drive: {_drive_backup_dir}\")\n", "\n", "# ── Push to HuggingFace Hub (HF Hub was authenticated in Setup) ────\n", "PUSH_TO_HUB = True # @param {type:\"boolean\"}\n", "\n", "if PUSH_TO_HUB:\n", " # Free GPU memory before push to 
avoid OOM during serialization\n", " import gc\n", "\n", " if torch.cuda.is_available():\n", " model = model.to(\"cpu\")\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", " print(\"Moved model to CPU for safer push\")\n", "\n", " # Retry push on transient HF Hub errors (504, etc.)\n", " _max_retries = 3\n", " for _attempt in range(1, _max_retries + 1):\n", " try:\n", " model.push_to_hub(HF_REPO, commit_message=\"GRPO trained on SQLEnv\")\n", " tokenizer.push_to_hub(HF_REPO)\n", " print(f\"Pushed to https://huggingface.co/{HF_REPO}\")\n", " break\n", " except Exception as exc:\n", " print(f\"Push attempt {_attempt}/{_max_retries} failed: {exc}\")\n", " if _attempt == _max_retries:\n", " print(f\"Push failed after {_max_retries} attempts.\")\n", " if BACKUP_TO_DRIVE and IN_COLAB:\n", " print(f\"Recover from Drive backup: {_drive_backup_dir}\")\n", " else:\n", " print(f\"Local checkpoint still at: {config.output_dir}\")\n", " raise\n", " import time\n", "\n", " time.sleep(10 * _attempt)\n", "else:\n", " print(f\"Set PUSH_TO_HUB = True to push to {HF_REPO}\")" ] } ], "metadata": { "colab": { "name": "train_grpo.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }