{
"cells": [
{
"cell_type": "markdown",
"id": "91ad4325",
"metadata": {},
"source": "# Training a SQL Agent with GRPO + SQLEnv\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hjerpe/sql-env/blob/main/notebooks/train_grpo.ipynb)\n\nThis notebook trains a small language model to explore SQL databases using GRPO. It runs end-to-end on a Colab GPU runtime — no external server needed."
},
{
"cell_type": "markdown",
"id": "fc80922f",
"metadata": {},
"source": [
"## 1) Setup\n",
"Install dependencies, clone the repo, and download the Spider databases. On Colab this happens automatically; locally, run `uv sync` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd3d42c7",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import subprocess\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"IN_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"if IN_COLAB:\n",
" from google.colab import userdata\n",
"\n",
" token = userdata.get(\"GITHUB_TOKEN\")\n",
"\n",
" BRANCH = \"main\" # @param {type:\"string\"}\n",
" repo_url = f\"https://{token}@github.com/hjerpe/sql-env.git\"\n",
"\n",
" # Clone or update repo\n",
" if Path(\"sql-env\").exists():\n",
" subprocess.check_call([\"git\", \"-C\", \"sql-env\", \"pull\", \"-q\"])\n",
" else:\n",
" subprocess.check_call([\"git\", \"clone\", \"-q\", \"-b\", BRANCH, repo_url])\n",
" os.chdir(\"sql-env\")\n",
" # Install sql-env (always reinstall to pick up latest changes)\n",
" subprocess.check_call(\n",
" [\n",
" sys.executable,\n",
" \"-m\",\n",
" \"pip\",\n",
" \"install\",\n",
" \"-q\",\n",
" \"--no-deps\",\n",
" \"--force-reinstall\",\n",
" \".\",\n",
" ]\n",
" )\n",
" # TRL 0.29+ and transformers from main (>=5.2.0 for environment_factory)\n",
" # Uninstall vllm first to avoid version conflict with transformers 5.x\n",
" subprocess.call([sys.executable, \"-m\", \"pip\", \"uninstall\", \"-y\", \"vllm\"])\n",
" subprocess.check_call(\n",
" [\n",
" sys.executable,\n",
" \"-m\",\n",
" \"pip\",\n",
" \"install\",\n",
" \"-q\",\n",
" \"trl>=0.29.0\",\n",
" \"accelerate>=0.34.0\",\n",
" \"openenv-core[core]>=0.2.1\",\n",
" \"pydantic>=2.0.0\",\n",
" \"jmespath\",\n",
" \"git+https://github.com/huggingface/transformers.git@main\",\n",
" ]\n",
" )\n",
" # Download Spider databases + generate SFT data\n",
" subprocess.check_call([sys.executable, \"scripts/download_spider_databases.py\"])\n",
" # SFT data generation — pass --enable-thinking if ENABLE_THINKING is set\n",
" # (ENABLE_THINKING is defined later, in the Configuration cell — on a fresh\n",
" # top-to-bottom run it is not yet set here, so this always defaults to False;\n",
" # only a re-run of this Setup cell after Configuration picks up True)\n",
" _sft_cmd = [sys.executable, \"scripts/generate_sft_data.py\"]\n",
" if globals().get(\"ENABLE_THINKING\", False):\n",
" _sft_cmd.append(\"--enable-thinking\")\n",
" subprocess.check_call(_sft_cmd)\n",
"\n",
" # ── Pre-authenticate Drive + HF Hub up front ─────────────────────\n",
" # Both require interactive prompts on first use. Doing them now\n",
" # means a long training run won't get blocked at the end waiting\n",
" # for human input — and a runtime crash mid-push won't lose work.\n",
" print(\"\\nPre-authenticating Drive and HuggingFace Hub...\")\n",
" try:\n",
" from google.colab import drive\n",
"\n",
" drive.mount(\"/content/drive\", force_remount=False)\n",
" print(\" Drive mounted\")\n",
" except Exception as exc:\n",
" print(f\" Drive mount skipped: {exc}\")\n",
"\n",
" try:\n",
" from huggingface_hub import login\n",
"\n",
" hf_token = userdata.get(\"HF_TOKEN\")\n",
" if hf_token:\n",
" login(token=hf_token, add_to_git_credential=False)\n",
" print(\" HuggingFace Hub authenticated\")\n",
" else:\n",
" print(\" HF_TOKEN secret not set — push to Hub will fail later\")\n",
" except Exception as exc:\n",
" print(f\" HF login skipped: {exc}\")\n",
"\n",
"from sql_env.training.config import find_project_root\n",
"\n",
"PROJECT_ROOT = find_project_root()\n",
"os.chdir(PROJECT_ROOT)\n",
"if str(PROJECT_ROOT) not in sys.path:\n",
" sys.path.insert(0, str(PROJECT_ROOT))\n",
"\n",
"print(f\"Project root: {PROJECT_ROOT}\")\n",
"print(f\"Running on: {'Colab' if IN_COLAB else 'local'}\")"
]
},
{
"cell_type": "markdown",
"id": "dc177126",
"metadata": {},
"source": [
"## 2) Configuration\n",
"Set the model and training hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "978bc98f",
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"import os\n",
"\n",
"# Reduce CUDA fragmentation (helps after SFT → GRPO handoff)\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from sql_env.training.config import GRPOConfig\n",
"from sql_env.training.data_loading import (\n",
" load_model_and_tokenizer,\n",
" load_question_prompts,\n",
")\n",
"from sql_env.training.notebook_pipeline import (\n",
" build_trainer,\n",
" run_training_with_metrics,\n",
" sample_random_baseline,\n",
")\n",
"from sql_env.training.trl_adapter import SQLEnvTRL, sql_env_reward_func\n",
"\n",
"try:\n",
" from trl import GRPOConfig as TRLGRPOConfig\n",
" from trl import GRPOTrainer\n",
" import trl\n",
" import transformers\n",
"\n",
" print(f\"TRL: {trl.__version__}, Transformers: {transformers.__version__}\")\n",
"except Exception as exc:\n",
" raise RuntimeError(\n",
" \"TRL is required. Install dependencies in the Setup cell first.\"\n",
" ) from exc\n",
"\n",
"# ── Pick your model ────────────────────────────────────────────────\n",
"# Qwen3-0.6B → fits T4 (16 GB) comfortably\n",
"# Qwen3-1.7B → fits L4 (24 GB) with gradient checkpointing\n",
"MODEL_NAME = \"Qwen/Qwen3-1.7B\" # @param [\"Qwen/Qwen3-0.6B\", \"Qwen/Qwen3-1.7B\"]\n",
"\n",
"# ── Thinking mode ──────────────────────────────────────────────────\n",
"# When True, the model can generate <think>...</think> blocks before\n",
"# tool calls, enabling reasoning about SQL errors and query strategy.\n",
"# Requires more tokens — max_new_tokens is auto-adjusted below.\n",
"ENABLE_THINKING = False # @param {type:\"boolean\"}\n",
"\n",
"# ── SFT warmup ─────────────────────────────────────────────────────\n",
"# Teaches one-tool-per-turn pattern before GRPO. Recommended for\n",
"# first training run; can skip if resuming from an SFT checkpoint.\n",
"RUN_SFT_WARMUP = True # @param {type:\"boolean\"}\n",
"\n",
"# ── Auto-tune batch settings per model size ────────────────────────\n",
"# Note: beta>0 loads a reference model, doubling memory.\n",
"# max_new_tokens=512 and num_generations=4 to fit L4 24GB.\n",
"_MODEL_PRESETS = {\n",
" \"Qwen/Qwen3-0.6B\": dict(\n",
" per_device_train_batch_size=6,\n",
" gradient_accumulation_steps=1,\n",
" num_generations=6,\n",
" gradient_checkpointing=False,\n",
" ),\n",
" \"Qwen/Qwen3-1.7B\": dict(\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=4,\n",
" num_generations=4,\n",
" gradient_checkpointing=True,\n",
" ),\n",
"}\n",
"\n",
"\n",
"# Match on model size so HF checkpoint names (e.g. hjerpe/sqlenv-qwen3-1.7b-grpo)\n",
"# pick the right preset instead of silently falling back to 0.6B defaults.\n",
"def _get_preset(name):\n",
" if \"1.7b\" in name.lower():\n",
" return _MODEL_PRESETS[\"Qwen/Qwen3-1.7B\"]\n",
" return _MODEL_PRESETS[\"Qwen/Qwen3-0.6B\"]\n",
"\n",
"\n",
"_preset = _get_preset(MODEL_NAME)\n",
"\n",
"# Thinking mode needs more tokens for <think> blocks + tool calls.\n",
"# Phase 1 base: 512 (no-think) → 768 (thinking)\n",
"_phase1_tokens = 768 if ENABLE_THINKING else 512\n",
"\n",
"config = GRPOConfig(\n",
" questions_path=\"data/questions/questions_train.json\",\n",
" db_dir=\"data/databases\",\n",
" output_dir=\"outputs/grpo_run\",\n",
" model_name=MODEL_NAME,\n",
" num_train_epochs=1,\n",
" max_new_tokens=_phase1_tokens,\n",
" step_budget=10,\n",
" precision=\"bf16\",\n",
" beta=0.04,\n",
" enable_thinking=ENABLE_THINKING,\n",
" **_preset,\n",
")\n",
"print(f\"Model: {config.model_name}\")\n",
"print(f\"Thinking mode: {'ON' if ENABLE_THINKING else 'OFF'}\")\n",
"print(\n",
" f\"Batch: {config.per_device_train_batch_size} × {config.gradient_accumulation_steps} accum\"\n",
")\n",
"print(\n",
" f\"Generations: {config.num_generations}, Gradient ckpt: {config.gradient_checkpointing}\"\n",
")\n",
"print(f\"Max tokens: {config.max_new_tokens}, Beta: {config.beta}\")\n",
"print(f\"Epochs: {config.num_train_epochs}\")\n",
"print(f\"SFT warmup: {'ON' if RUN_SFT_WARMUP else 'OFF'}\")"
]
},
{
"cell_type": "markdown",
"id": "f8711d45",
"metadata": {},
"source": [
"## 3) Smoke Test\n",
"Verify the environment works in-process before starting the training loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "559dbd33",
"metadata": {},
"outputs": [],
"source": [
"from sql_env.server.sql_environment import SQLEnvironment\n",
"from sql_env.server.mock_tokenizer import MockTokenizer\n",
"from sql_env.models import SQLAction\n",
"from sql_env.training.data_loading import validate_no_data_leak\n",
"\n",
"env = SQLEnvironment(\n",
" questions_path=config.questions_path,\n",
" db_dir=config.db_dir,\n",
" tokenizer=MockTokenizer(),\n",
" step_budget=config.step_budget,\n",
")\n",
"\n",
"# Guard against train/eval data leakage\n",
"eval_path = \"data/questions/questions_eval.json\"\n",
"if Path(eval_path).exists():\n",
" validate_no_data_leak(config.questions_path, eval_path)\n",
" print(\"Data leak check: PASSED (0 question overlap)\")\n",
"\n",
"obs = env.reset(seed=42)\n",
"print(f\"Loaded {len(env.questions)} questions\")\n",
"print(f\"Question: {obs.question}\")\n",
"obs = env.step(\n",
" SQLAction(\n",
" action_type=\"DESCRIBE\",\n",
" argument=obs.schema_info.split(chr(10))[1].strip(\"- \").strip(),\n",
" )\n",
")\n",
"print(f\"Schema: {obs.schema_info[:200]}\")\n",
"print(\"Smoke test passed.\")"
]
},
{
"cell_type": "markdown",
"id": "7fe676db",
"metadata": {},
"source": "## 4) Load Model\nDownload the model and tokenizer, load training prompts."
},
{
"cell_type": "code",
"execution_count": null,
"id": "979208ee",
"metadata": {},
"outputs": [],
"source": [
"model, tokenizer = load_model_and_tokenizer(config.model_name)\n",
"raw_prompts = load_question_prompts(config.questions_path, config.difficulty_filter)\n",
"\n",
"# System prompt: uses get_system_prompt() which prepends /no_think\n",
"# when thinking mode is disabled, or omits it when enabled.\n",
"from scripts.generate_sft_data import get_system_prompt\n",
"\n",
"SYSTEM_PROMPT = get_system_prompt(enable_thinking=config.enable_thinking)\n",
"\n",
"# question_text is passed to reset() so the environment loads the correct DB\n",
"prompts = [\n",
" {\n",
" \"prompt\": [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": item[\"prompt\"]},\n",
" ],\n",
" \"question_text\": item[\"prompt\"],\n",
" }\n",
" for item in raw_prompts\n",
"]\n",
"\n",
"# Verify SFT/GRPO prompt alignment\n",
"from sql_env.training.trl_adapter import get_tool_definitions\n",
"\n",
"_tool_defs = get_tool_definitions()\n",
"_rendered = tokenizer.apply_chat_template(\n",
" [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}, {\"role\": \"user\", \"content\": \"test\"}],\n",
" tools=_tool_defs,\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"assert \"<tools>\" in _rendered and '\"name\": \"query\"' in _rendered  # Qwen renders tool defs inside <tools> tags; '\"\" in s' was vacuously True\n",
"print(f\"Model: {config.model_name}\")\n",
"print(f\"Thinking mode: {'ON' if config.enable_thinking else 'OFF'}\")\n",
"print(f\"Prompts: {len(prompts)} questions\")\n",
"print(f\"SFT/GRPO alignment: OK ({len(_tool_defs)} tools in prompt)\")\n",
"print(f\"System prompt: {len(SYSTEM_PROMPT)} chars\")"
]
},
{
"cell_type": "markdown",
"id": "sxvb8e0qufq",
"metadata": {},
"source": "## 5) Live Visualization\nThe reward plot updates in place every `logging_steps`. TRL prints completion tables showing actual model output and tool calls."
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8c31628",
"metadata": {},
"outputs": [],
"source": [
"from sql_env.training.visualization import LiveVisualizationCallback, SFTMonitorCallback\n",
"from sql_env.training.trl_adapter import get_tool_definitions\n",
"\n",
"viz = LiveVisualizationCallback()\n",
"\n",
"# Pick diverse sample prompts (every Nth to span different databases)\n",
"_sft_tools = get_tool_definitions()\n",
"_step = max(1, len(prompts) // 3)\n",
"_sft_monitor_prompts = [\n",
" prompts[i * _step][\"prompt\"] for i in range(3) if i * _step < len(prompts)\n",
"]\n",
"sft_monitor = SFTMonitorCallback(\n",
" tokenizer,\n",
" _sft_monitor_prompts,\n",
" tools=_sft_tools,\n",
" eval_every_steps=40,\n",
")\n",
"print(f\"Visualization + SFT monitor ({len(_sft_tools)} tools) ready.\")"
]
},
{
"cell_type": "markdown",
"id": "b5jsw6qm1ml",
"metadata": {},
"source": "## 6) SFT Warmup\nShort supervised fine-tuning on ideal trajectories to teach the model basic tool-calling patterns: one tool per turn, read responses, submit real values. GRPO then refines the reasoning."
},
{
"cell_type": "code",
"execution_count": null,
"id": "gc6hjtd3eje",
"metadata": {},
"outputs": [],
"source": [
"if RUN_SFT_WARMUP:\n",
" import json as _json\n",
" from datasets import Dataset\n",
" from sql_env.training.trl_adapter import get_tool_definitions\n",
"\n",
" TOOL_DEFS = get_tool_definitions()\n",
" print(f\"Tools extracted: {[t['function']['name'] for t in TOOL_DEFS]}\")\n",
"\n",
" # ── Patch Qwen3 chat template for assistant_only_loss ──────────\n",
" # Qwen3's template lacks {% generation %} tags that TRL needs to\n",
" # build assistant_masks. We patch it for SFT only, then restore\n",
" # the original before GRPO (which does its own template checks).\n",
" # See: QwenLM/Qwen3#1522, huggingface/trl#4879\n",
" _original_chat_template = tokenizer.chat_template\n",
" _use_assistant_only = False\n",
"\n",
" if \"{% generation %}\" not in _original_chat_template:\n",
" _ASST_START = '{%- elif message.role == \"assistant\" %}'\n",
" _ASST_END = \"{{- '<|im_end|>\\\\n' }}\\n {%- elif message.role == \\\"tool\\\" %}\"\n",
"\n",
" _patched = _original_chat_template.replace(\n",
" _ASST_START,\n",
" _ASST_START + \"\\n {% generation %}\",\n",
" ).replace(\n",
" _ASST_END,\n",
" \"{% endgeneration %}\" + _ASST_END,\n",
" )\n",
" if \"{% generation %}\" in _patched and \"{% endgeneration %}\" in _patched:\n",
" tokenizer.chat_template = _patched\n",
" print(\"Qwen3 template patched with {% generation %} tags\")\n",
"\n",
" # Verify the patch produces assistant masks\n",
" _test_out = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"system\", \"content\": \"test\"},\n",
" {\"role\": \"user\", \"content\": \"hello\"},\n",
" {\"role\": \"assistant\", \"content\": \"world\"},\n",
" ],\n",
" tokenize=True,\n",
" return_dict=True,\n",
" return_assistant_tokens_mask=True,\n",
" )\n",
" _has_mask = (\n",
" \"assistant_masks\" in _test_out and 1 in _test_out[\"assistant_masks\"]\n",
" )\n",
" _use_assistant_only = _has_mask\n",
" else:\n",
" print(\"WARNING: Could not patch template\")\n",
"\n",
" print(f\"Assistant-only loss: {'ENABLED' if _use_assistant_only else 'DISABLED'}\")\n",
"\n",
" # Load SFT trajectories (generated by scripts/generate_sft_data.py)\n",
" sft_path = PROJECT_ROOT / \"data\" / \"sft\" / \"sft_trajectories.json\"\n",
" if not sft_path.exists():\n",
" print(\"Generating SFT trajectories...\")\n",
" subprocess.check_call(\n",
" [sys.executable, str(PROJECT_ROOT / \"scripts\" / \"generate_sft_data.py\")]\n",
" )\n",
"\n",
" with open(sft_path) as f:\n",
" sft_data = _json.load(f)\n",
"\n",
" # Ensure every trajectory has tool definitions\n",
" for row in sft_data:\n",
" if \"tools\" not in row:\n",
" row[\"tools\"] = TOOL_DEFS\n",
"\n",
" sft_dataset = Dataset.from_list(sft_data)\n",
" print(f\"SFT dataset: {len(sft_dataset)} multi-turn trajectories\")\n",
"\n",
" # ── Render training data preview ───────────────────────────────\n",
" # Uses the exact same apply_chat_template call as TRL's SFTTrainer,\n",
" # so the output file matches training input with zero divergence.\n",
" _render_path = PROJECT_ROOT / \"data\" / \"sft\" / \"sft_rendered.txt\"\n",
" _render_n = min(5, len(sft_data))\n",
" _render_parts = []\n",
" _total_tok, _total_asst = 0, 0\n",
"\n",
" for _ri in range(_render_n):\n",
" _ex = sft_data[_ri]\n",
" _msgs = _ex[\"messages\"]\n",
" _tools = _ex.get(\"tools\", TOOL_DEFS)\n",
"\n",
" _text = tokenizer.apply_chat_template(\n",
" _msgs,\n",
" tools=_tools,\n",
" tokenize=False,\n",
" )\n",
" _tok_out = tokenizer.apply_chat_template(\n",
" _msgs,\n",
" tools=_tools,\n",
" tokenize=True,\n",
" return_dict=True,\n",
" return_assistant_tokens_mask=True,\n",
" )\n",
" _n_tok = len(_tok_out[\"input_ids\"])\n",
" _mask = _tok_out.get(\"assistant_masks\", [])\n",
" _n_asst = sum(_mask) if _mask else 0\n",
" _total_tok += _n_tok\n",
" _total_asst += _n_asst\n",
"\n",
" _header = (\n",
" f\"{'=' * 70}\\n\"\n",
" f\"Example {_ri} | {_n_tok} tokens | \"\n",
" f\"{_n_asst} assistant tokens \"\n",
" f\"({_n_asst / _n_tok:.0%} of sequence)\\n\"\n",
" f\"{'=' * 70}\"\n",
" )\n",
" _render_parts.append(f\"{_header}\\n{_text}\")\n",
"\n",
" _summary = (\n",
" f\"SFT Training Data Preview \"\n",
" f\"(rendered by training tokenizer)\\n\"\n",
" f\"Model: {config.model_name}\\n\"\n",
" f\"Template patched: \"\n",
" f\"{'yes' if _use_assistant_only else 'no'}\\n\"\n",
" f\"Examples shown: {_render_n} / {len(sft_data)}\\n\"\n",
" f\"Avg tokens/example: {_total_tok / _render_n:.0f} | \"\n",
" f\"Avg assistant tokens: {_total_asst / _render_n:.0f} \"\n",
" f\"({_total_asst / _total_tok:.0%})\\n\"\n",
" )\n",
" _full = _summary + \"\\n\" + \"\\n\\n\".join(_render_parts) + \"\\n\"\n",
" _render_path.parent.mkdir(parents=True, exist_ok=True)\n",
" _render_path.write_text(_full)\n",
" print(f\"\\nRendered {_render_n} examples -> {_render_path}\")\n",
" print(\n",
" f\"Avg: {_total_tok // _render_n} tok/example, \"\n",
" f\"{_total_asst // _render_n} assistant tok \"\n",
" f\"({_total_asst / _total_tok:.0%})\"\n",
" )\n",
"\n",
" # Show first example inline\n",
" print(f\"\\n{'=' * 70}\")\n",
" print(\"Example 0 preview (first 1500 chars):\")\n",
" print(f\"{'=' * 70}\")\n",
" print(_render_parts[0][:1500])\n",
" if len(_render_parts[0]) > 1500:\n",
" print(f\"... (see {_render_path} for full output)\")\n",
"\n",
" # ── Train ──────────────────────────────────────────────────────\n",
" sft_monitor.train_dataset = sft_dataset\n",
"\n",
" from trl import SFTConfig as TRLSFTConfig, SFTTrainer\n",
"\n",
" sft_output_dir = str(PROJECT_ROOT / \"outputs\" / \"sft_warmup\")\n",
"\n",
" _sft_batch = 2 if \"1.7B\" in config.model_name else 4\n",
"\n",
" _sft_kwargs = dict(\n",
" output_dir=sft_output_dir,\n",
" num_train_epochs=2,\n",
" per_device_train_batch_size=_sft_batch,\n",
" gradient_accumulation_steps=1,\n",
" learning_rate=2e-5,\n",
" lr_scheduler_type=\"cosine\",\n",
" warmup_steps=10,\n",
" logging_steps=10,\n",
" save_strategy=\"no\",\n",
" bf16=True,\n",
" gradient_checkpointing=config.gradient_checkpointing,\n",
" assistant_only_loss=_use_assistant_only,\n",
" )\n",
" import inspect\n",
"\n",
" _sft_params = inspect.signature(TRLSFTConfig).parameters\n",
" if \"max_seq_length\" in _sft_params:\n",
" _sft_kwargs[\"max_seq_length\"] = 1024\n",
" elif \"max_length\" in _sft_params:\n",
" _sft_kwargs[\"max_length\"] = 1024\n",
"\n",
" sft_config = TRLSFTConfig(**_sft_kwargs)\n",
"\n",
" sft_trainer = SFTTrainer(\n",
" model=model,\n",
" args=sft_config,\n",
" train_dataset=sft_dataset,\n",
" processing_class=tokenizer,\n",
" callbacks=[viz, sft_monitor],\n",
" )\n",
"\n",
" _total_steps = len(sft_dataset) // _sft_batch * _sft_kwargs[\"num_train_epochs\"]\n",
" print(\n",
" f\"\\nSFT: {_sft_batch} batch x {_total_steps} steps, \"\n",
" f\"grad_ckpt: {config.gradient_checkpointing}\"\n",
" )\n",
" print(f\"Assistant-only loss: {_use_assistant_only}\")\n",
" print(\"Starting SFT warmup (2 epochs, LR=2e-5)...\")\n",
" sft_trainer.train()\n",
" print(\"SFT warmup done.\")\n",
"\n",
" # ── Restore original template for GRPO ─────────────────────────\n",
" # GRPO's get_training_chat_template() does exact-match detection\n",
" # and will fail on our patched template. Restore the original so\n",
" # GRPO can apply its own prefix-preserving template.\n",
" tokenizer.chat_template = _original_chat_template\n",
" print(\"Template restored to original for GRPO\")\n",
"\n",
" # ── Aggressive cleanup for GRPO ────────────────────────────────\n",
" import gc\n",
" import torch\n",
"\n",
" model = sft_trainer.model\n",
" if hasattr(model, \"module\"):\n",
" model = model.module\n",
" if hasattr(model, \"gradient_checkpointing_disable\"):\n",
" model.gradient_checkpointing_disable()\n",
"\n",
" del sft_trainer, sft_dataset, sft_data, sft_config\n",
" sft_monitor.train_dataset = None\n",
" gc.collect()\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" _free = torch.cuda.mem_get_info()[0] / 1e9\n",
" print(f\"GPU memory freed: {_free:.1f} GB available\")\n",
"\n",
" viz = LiveVisualizationCallback()\n",
" print(\"Model handed off to GRPO.\")\n",
"else:\n",
" print(\"Skipping SFT warmup (RUN_SFT_WARMUP = False)\")"
]
},
{
"cell_type": "markdown",
"id": "ro5feofkx6q",
"metadata": {},
"source": "## 6b) Post-SFT Format Check\nQuick generation to verify the model learned tool-call format before starting GRPO."
},
{
"cell_type": "code",
"execution_count": null,
"id": "m6nk5n7zwg",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Quick format check: generate one completion to verify tool-call quality.\n",
"# Full eval via play_episode is too expensive before GRPO (causes OOM).\n",
"from sql_env.training.trl_adapter import get_tool_definitions\n",
"\n",
"_eval_tools = get_tool_definitions()\n",
"_eval_msgs = [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": prompts[0][\"question_text\"]},\n",
"]\n",
"_eval_rendered = tokenizer.apply_chat_template(\n",
" _eval_msgs,\n",
" tools=_eval_tools,\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"_eval_inputs = tokenizer(_eval_rendered, return_tensors=\"pt\")\n",
"_eval_inputs = {k: v.to(model.device) for k, v in _eval_inputs.items()}\n",
"\n",
"model.eval()\n",
"with torch.no_grad():\n",
" _eval_out = model.generate(**_eval_inputs, max_new_tokens=200, do_sample=False)\n",
" _eval_new = _eval_out[0][_eval_inputs[\"input_ids\"].shape[1] :]\n",
" _eval_text = tokenizer.decode(_eval_new, skip_special_tokens=True)\n",
"model.train()\n",
"\n",
"print(\"Post-SFT format check:\")\n",
"print(f\" Q: {prompts[0]['question_text'][:80]}...\")\n",
"print(f\" → {_eval_text[:300]}\")\n",
"has_tool_call = \"<tool_call>\" in tokenizer.decode(_eval_new, skip_special_tokens=False)  # '\"\" in s' is always True — check the real Qwen tool-call marker\n",
"print(f\" Tool-call format: {'YES' if has_tool_call else 'NO (plain text)'}\")\n",
"\n",
"# Free memory before GRPO\n",
"del _eval_out, _eval_inputs, _eval_new\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "markdown",
"id": "9zjonqe6dbg",
"metadata": {},
"source": "## 7) Train with GRPO\nStarting from the SFT-warmed model, GRPO refines SQL reasoning via reward signal from the environment."
},
{
"cell_type": "code",
"execution_count": null,
"id": "xpwe5woqqcb",
"metadata": {},
"outputs": [],
"source": [
"# Phase 1: easy only — stabilize format with KL penalty\n",
"easy_prompts = [\n",
" {\n",
" \"prompt\": [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": item[\"prompt\"]},\n",
" ],\n",
" \"question_text\": item[\"prompt\"],\n",
" }\n",
" for item in load_question_prompts(config.questions_path, [\"easy\"])\n",
"]\n",
"\n",
"print(\n",
" f\"Phase 1: {len(easy_prompts)} easy questions (beta={config.beta}, max_tokens={config.max_new_tokens})\"\n",
")\n",
"if config.enable_thinking:\n",
"    # Restore the <think> tag that was stripped from the original message text\n",
"    print(\"  Thinking mode ON — model can reason in <think> blocks\")\n",
"\n",
"trainer = build_trainer(\n",
" trl_grpo_config_cls=TRLGRPOConfig,\n",
" grpo_trainer_cls=GRPOTrainer,\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" prompts=easy_prompts,\n",
" config=config,\n",
" reward_funcs=[sql_env_reward_func],\n",
" environment_factory=SQLEnvTRL,\n",
" callbacks=[viz],\n",
")\n",
"\n",
"before_rollouts = sample_random_baseline(\n",
" [item[\"question_text\"] for item in easy_prompts],\n",
" step_budget=config.step_budget,\n",
" seed=config.seed,\n",
")\n",
"print(f\"Random baseline episodes: {len(before_rollouts)}\")\n",
"\n",
"train_output, steps, rewards = run_training_with_metrics(trainer)\n",
"print(f\"Phase 1 done: {len(steps)} logged steps\")\n",
"\n",
"# Free reference model before phase 2\n",
"import gc\n",
"import torch\n",
"\n",
"del trainer\n",
"gc.collect()\n",
"if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" _free = torch.cuda.mem_get_info()[0] / 1e9\n",
" print(f\"GPU memory freed: {_free:.1f} GB available\")\n",
"\n",
"# Phase 2: easy + medium — no KL penalty, longer tokens\n",
"# Thinking mode gets extra headroom: 1024 → 1280\n",
"from dataclasses import replace\n",
"\n",
"_phase2_tokens = 1280 if config.enable_thinking else 1024\n",
"config2 = replace(config, beta=0.0, max_new_tokens=_phase2_tokens)\n",
"\n",
"print(\n",
" f\"\\nPhase 2: {len(prompts)} easy+medium questions (beta={config2.beta}, max_tokens={config2.max_new_tokens})\"\n",
")\n",
"\n",
"viz = LiveVisualizationCallback()\n",
"\n",
"trainer2 = build_trainer(\n",
" trl_grpo_config_cls=TRLGRPOConfig,\n",
" grpo_trainer_cls=GRPOTrainer,\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" prompts=prompts,\n",
" config=config2,\n",
" reward_funcs=[sql_env_reward_func],\n",
" environment_factory=SQLEnvTRL,\n",
" callbacks=[viz],\n",
")\n",
"\n",
"train_output2, steps2, rewards2 = run_training_with_metrics(trainer2)\n",
"print(f\"Phase 2 done: {len(steps2)} logged steps\")\n",
"\n",
"# Combine metrics\n",
"all_steps = steps + [s + (steps[-1] if steps else 0) for s in steps2]\n",
"all_rewards = rewards + rewards2\n",
"print(\n",
" f\"Total: {len(all_steps)} steps, final reward: {all_rewards[-1] if all_rewards else 'N/A'}\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c68fc96d",
"metadata": {},
"source": "## 8) Final Summary\nStatic plot of the full training run."
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e691b70",
"metadata": {},
"outputs": [],
"source": [
"if all_steps and all_rewards:\n",
" plt.figure(figsize=(8, 4))\n",
" plt.plot(all_steps, all_rewards, marker=\"o\", linewidth=1.5)\n",
" # Mark phase boundary\n",
" if steps:\n",
" plt.axvline(\n",
" x=steps[-1], color=\"red\", linestyle=\"--\", alpha=0.5, label=\"Phase 1→2\"\n",
" )\n",
" plt.legend()\n",
" plt.title(\"GRPO Reward Trend (Easy → Easy+Medium)\")\n",
" plt.xlabel(\"Training Step\")\n",
" plt.ylabel(\"Reward\")\n",
" plt.grid(alpha=0.3)\n",
" plt.show()\n",
"else:\n",
" print(\"No reward points available yet.\")"
]
},
{
"cell_type": "markdown",
"id": "98ce823b",
"metadata": {},
"source": "## 9) Save and Push to Hub\nSave the trained model locally and optionally push to HuggingFace Hub."
},
{
"cell_type": "code",
"execution_count": null,
"id": "26maekrxzz1",
"metadata": {},
"outputs": [],
"source": [
"# Save locally — use model directly (trainer was deleted during phase handoff)\n",
"model.save_pretrained(config.output_dir)\n",
"tokenizer.save_pretrained(config.output_dir)\n",
"print(f\"Model saved to {config.output_dir}\")\n",
"\n",
"# ── Run identifier (used for both Drive backup and HF repo) ────────\n",
"HF_SUFFIX = \"\" # @param {type:\"string\"}\n",
"\n",
"# Extract base model name, stripping any existing sqlenv-/grpo- prefixes\n",
"# so resuming from \"hjerpe/sqlenv-qwen3-0.6b-grpo\" doesn't double them.\n",
"import re\n",
"\n",
"_raw_short = config.model_name.split(\"/\")[-1].lower()\n",
"_model_short = re.sub(r\"^sqlenv-\", \"\", _raw_short)\n",
"_model_short = re.sub(r\"-grpo.*$\", \"\", _model_short)\n",
"_suffix = f\"-{HF_SUFFIX}\" if HF_SUFFIX and not HF_SUFFIX.startswith(\"-\") else HF_SUFFIX\n",
"_run_name = f\"sqlenv-{_model_short}-grpo{_suffix}\"\n",
"HF_REPO = f\"hjerpe/{_run_name}\"\n",
"print(f\"Run name: {_run_name} → {HF_REPO}\")\n",
"\n",
"# ── Backup to Google Drive (Drive was mounted in Setup cell) ───────\n",
"BACKUP_TO_DRIVE = True # @param {type:\"boolean\"}\n",
"\n",
"if IN_COLAB and BACKUP_TO_DRIVE:\n",
" _drive_backup_dir = f\"/content/drive/MyDrive/sqlenv-checkpoints/{_run_name}\"\n",
" if not Path(\"/content/drive/MyDrive\").exists():\n",
" # Drive wasn't pre-mounted in Setup — mount now (interactive)\n",
" from google.colab import drive\n",
"\n",
" drive.mount(\"/content/drive\", force_remount=False)\n",
" subprocess.check_call([\"mkdir\", \"-p\", _drive_backup_dir])\n",
" subprocess.check_call([\"cp\", \"-r\", f\"{config.output_dir}/.\", _drive_backup_dir])\n",
" print(f\"Backup saved to Drive: {_drive_backup_dir}\")\n",
"\n",
"# ── Push to HuggingFace Hub (HF Hub was authenticated in Setup) ────\n",
"PUSH_TO_HUB = True # @param {type:\"boolean\"}\n",
"\n",
"if PUSH_TO_HUB:\n",
" # Free GPU memory before push to avoid OOM during serialization\n",
" import gc\n",
"\n",
" if torch.cuda.is_available():\n",
" model = model.to(\"cpu\")\n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n",
" print(\"Moved model to CPU for safer push\")\n",
"\n",
" # Retry push on transient HF Hub errors (504, etc.)\n",
" _max_retries = 3\n",
" for _attempt in range(1, _max_retries + 1):\n",
" try:\n",
" model.push_to_hub(HF_REPO, commit_message=\"GRPO trained on SQLEnv\")\n",
" tokenizer.push_to_hub(HF_REPO)\n",
" print(f\"Pushed to https://huggingface.co/{HF_REPO}\")\n",
" break\n",
" except Exception as exc:\n",
" print(f\"Push attempt {_attempt}/{_max_retries} failed: {exc}\")\n",
" if _attempt == _max_retries:\n",
" print(f\"Push failed after {_max_retries} attempts.\")\n",
" if BACKUP_TO_DRIVE and IN_COLAB:\n",
" print(f\"Recover from Drive backup: {_drive_backup_dir}\")\n",
" else:\n",
" print(f\"Local checkpoint still at: {config.output_dir}\")\n",
" raise\n",
" import time\n",
"\n",
" time.sleep(10 * _attempt)\n",
"else:\n",
" print(f\"Set PUSH_TO_HUB = True to push to {HF_REPO}\")"
]
}
],
"metadata": {
"colab": {
"name": "train_grpo.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}