{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c9e1656a",
   "metadata": {},
   "source": [
    "# Prompting Baseline vs GRPO Comparison\n",
    "\n",
    "This notebook compares up to six SQL agent methods on the same evaluation set: zero-shot, 1-shot, and 3-shot prompting of the base model, GRPO v1 and v2 (no-think), and GRPO thinking. GRPO conditions whose checkpoint is set to `\"none\"` in the configuration cell are skipped."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d215e06",
   "metadata": {},
   "source": [
    "## 1) Setup\n",
    "Detect Colab, install dependencies when needed, and ensure Spider databases are available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ff66f07",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "IN_COLAB = \"google.colab\" in sys.modules\n",
    "\n",
    "if IN_COLAB:\n",
    "    from google.colab import userdata\n",
    "\n",
    "    token = userdata.get(\"GITHUB_TOKEN\")\n",
    "\n",
    "    BRANCH = \"main\"  # @param {type:\"string\"}\n",
    "    repo_url = f\"https://{token}@github.com/hjerpe/sql-env.git\"\n",
    "\n",
    "    # Clone or update the repo. On a re-run the kernel is already inside the\n",
    "    # checkout (os.chdir below), so detect that case instead of cloning a\n",
    "    # second copy into sql-env/sql-env and descending into it.\n",
    "    cwd = Path.cwd()\n",
    "    if cwd.name == \"sql-env\" and (cwd / \".git\").exists():\n",
    "        repo_dir = cwd\n",
    "    else:\n",
    "        repo_dir = cwd / \"sql-env\"\n",
    "    if repo_dir.exists():\n",
    "        subprocess.check_call([\"git\", \"-C\", str(repo_dir), \"pull\", \"-q\"])\n",
    "    else:\n",
    "        subprocess.check_call(\n",
    "            [\"git\", \"clone\", \"-q\", \"-b\", BRANCH, repo_url, str(repo_dir)]\n",
    "        )\n",
    "    os.chdir(repo_dir)\n",
    "\n",
    "    print(\"Colab detected: installing dependencies...\")\n",
    "    subprocess.check_call(\n",
    "        [sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--upgrade\", \"pip\"]\n",
    "    )\n",
    "    # Uninstall vllm first to avoid version conflict with transformers 5.x\n",
    "    subprocess.call([sys.executable, \"-m\", \"pip\", \"uninstall\", \"-y\", \"vllm\"])\n",
    "    subprocess.check_call(\n",
    "        [\n",
    "            sys.executable,\n",
    "            \"-m\",\n",
    "            \"pip\",\n",
    "            \"install\",\n",
    "            \"-q\",\n",
    "            \"--no-deps\",\n",
    "            \"--force-reinstall\",\n",
    "            \".\",\n",
    "        ]\n",
    "    )\n",
    "    subprocess.check_call(\n",
    "        [\n",
    "            sys.executable,\n",
    "            \"-m\",\n",
    "            \"pip\",\n",
    "            \"install\",\n",
    "            \"-q\",\n",
    "            \"openenv-core[core]>=0.2.1\",\n",
    "            \"torch>=2.2.0\",\n",
    "            \"pandas>=2.0.0\",\n",
    "            \"matplotlib>=3.7.0\",\n",
    "            \"huggingface_hub>=0.37\",\n",
    "            \"jmespath\",\n",
    "            \"git+https://github.com/huggingface/transformers.git@main\",\n",
    "        ]\n",
    "    )\n",
    "    subprocess.check_call([sys.executable, \"scripts/download_spider_databases.py\"])\n",
    "    # Generate SFT data for few-shot examples\n",
    "    subprocess.check_call([sys.executable, \"scripts/generate_sft_data.py\"])\n",
    "\n",
    "project_root = Path.cwd().resolve()\n",
    "if (\n",
    "    not (project_root / \"pyproject.toml\").exists()\n",
    "    and (project_root / \"sql-env\").exists()\n",
    "):\n",
    "    project_root = (project_root / \"sql-env\").resolve()\n",
    "    os.chdir(project_root)\n",
    "\n",
    "if str(project_root) not in sys.path:\n",
    "    sys.path.insert(0, str(project_root))\n",
    "\n",
    "print(f\"Project root: {project_root}\")\n",
    "print(f\"Running in Colab: {IN_COLAB}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83a8bfbb",
   "metadata": {},
   "source": [
    "## 2) Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65decdf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_EVAL_EPISODES = 50  # @param {type:\"integer\"}\n",
    "STEP_BUDGET = 15\n",
    "SEED = 42\n",
    "\n",
    "QUESTIONS_PATH = \"data/questions/questions_eval.json\"\n",
    "DB_DIR = \"data/databases\"\n",
    "\n",
    "# ── Pick your base model ───────────────────────────────────────────\n",
    "# Qwen3-0.6B → fits T4 (16 GB) comfortably\n",
    "# Qwen3-1.7B → fits L4 (24 GB) with gradient checkpointing\n",
    "BASE_MODEL_NAME = \"Qwen/Qwen3-0.6B\"  # @param [\"Qwen/Qwen3-0.6B\", \"Qwen/Qwen3-1.7B\"]\n",
    "\n",
    "# ── Pick your trained checkpoints ─────────────────────────────────\n",
    "# Should match the base model size you trained with.\n",
    "# Set to \"none\" to skip that condition.\n",
    "# Multiple GRPO checkpoints can be compared (e.g. v1 vs v2).\n",
    "GRPO_MODEL_REPO = \"hjerpe/sqlenv-qwen3-0.6b-grpo\"  # @param [\"hjerpe/sqlenv-qwen3-0.6b-grpo\", \"hjerpe/sqlenv-qwen3-1.7b-grpo\", \"none\"]\n",
    "GRPO_V2_MODEL_REPO = \"none\"  # @param [\"hjerpe/sqlenv-qwen3-0.6b-grpo-v2\", \"none\"]\n",
    "GRPO_THINKING_MODEL_REPO = \"none\"  # @param [\"hjerpe/sqlenv-qwen3-0.6b-grpo-think\", \"hjerpe/sqlenv-qwen3-1.7b-grpo-think\", \"none\"]\n",
    "\n",
    "print(f\"Base model: {BASE_MODEL_NAME}\")\n",
    "print(f\"GRPO checkpoint (v1): {GRPO_MODEL_REPO}\")\n",
    "print(f\"GRPO checkpoint (v2): {GRPO_V2_MODEL_REPO}\")\n",
    "print(f\"GRPO thinking checkpoint: {GRPO_THINKING_MODEL_REPO}\")\n",
    "print(f\"Eval episodes: {N_EVAL_EPISODES}, Step budget: {STEP_BUDGET}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28bf9de3",
   "metadata": {},
   "source": [
    "## 3) Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "118aafdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import gc\n",
    "import json\n",
    "import random\n",
    "import re\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from sql_env import SQLAction, SQLObservation\n",
    "from sql_env.evaluation.policies import EvaluationResult, evaluate\n",
    "from sql_env.server.sql_environment import SQLEnvironment\n",
    "from sql_env.server.mock_tokenizer import MockTokenizer\n",
    "from sql_env.training.data_loading import load_model_and_tokenizer\n",
    "from sql_env.training.trl_adapter import get_tool_definitions\n",
    "from scripts.generate_sft_data import get_system_prompt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6ed47e5",
   "metadata": {},
   "source": [
    "## 4) Environment and Eval Data\n",
    "Create an environment instance and load the fixed evaluation set used by all comparison methods."
] }, { "cell_type": "code", "execution_count": null, "id": "00579e5b", "metadata": {}, "outputs": [], "source": [ "from sql_env.training.data_loading import validate_no_data_leak\n", "\n", "questions_path = Path(QUESTIONS_PATH)\n", "db_dir = Path(DB_DIR)\n", "\n", "if not questions_path.exists():\n", " raise FileNotFoundError(f\"Questions file not found: {questions_path}\")\n", "\n", "if not db_dir.exists():\n", " print(\"Database directory missing, downloading Spider databases...\")\n", " subprocess.check_call([sys.executable, \"scripts/download_spider_databases.py\"])\n", "\n", "# Guard against train/eval data leakage\n", "train_path = Path(\"data/questions/questions_train.json\")\n", "if train_path.exists():\n", " validate_no_data_leak(str(train_path), str(questions_path))\n", " print(\"Data leak check: PASSED (0 question overlap)\")\n", "\n", "env = SQLEnvironment(\n", " questions_path=str(questions_path),\n", " db_dir=str(db_dir),\n", " tokenizer=MockTokenizer(),\n", " step_budget=STEP_BUDGET,\n", ")\n", "\n", "with questions_path.open(\"r\", encoding=\"utf-8\") as handle:\n", " eval_questions = json.load(handle)\n", "\n", "print(f\"Loaded {len(eval_questions)} eval questions from {questions_path}\")\n", "print(f\"Environment ready with step budget = {STEP_BUDGET}\")" ] }, { "cell_type": "markdown", "id": "84d2e123", "metadata": {}, "source": [ "## 6) LLMToolCallingPolicy\n", "Drive model inference using tool-calling chat templates with episode-aware history and parse-error fallback." ] }, { "cell_type": "code", "execution_count": null, "id": "814a69dc", "metadata": {}, "outputs": [], "source": [ "class LLMToolCallingPolicy:\n", " \"\"\"Policy that mirrors TRL environment_factory rollouts exactly.\n", "\n", " TRL's rollout (see trl/trainer/grpo_trainer.py _tool_call_loop):\n", " 1. Generate until EOS / max_new_tokens (no stop at ).\n", " 2. Parse ALL blocks from the completion.\n", " 3. Append ONE assistant message with structured tool_calls list.\n", " 4. 
Execute each call in order, appending one role:\"tool\" message each.\n", " 5. Regenerate.\n", "\n", " The model never saw intermediate tool results while emitting a multi-call\n", " turn during training — it committed to all N calls up-front. We buffer\n", " parsed calls and drain them across N select_action invocations so the\n", " per-step evaluation loop (evaluate() in policies.py) sees exactly the\n", " same history the training rollout produced.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " model,\n", " tokenizer,\n", " tool_definitions: list[dict],\n", " system_prompt: str,\n", " few_shot_messages: list[dict] | None = None,\n", " enable_thinking: bool = False,\n", " max_new_tokens: int = 512,\n", " verbose: bool = False,\n", " ) -> None:\n", " self.model = model\n", " self.tokenizer = tokenizer\n", " self.tool_definitions = tool_definitions\n", " self.system_prompt = system_prompt\n", " self.few_shot_messages = list(few_shot_messages or [])\n", " self.enable_thinking = enable_thinking\n", " self.max_new_tokens = max_new_tokens\n", " self.verbose = verbose\n", " self._messages: list[dict] = []\n", " self._last_question: str | None = None\n", " # Pending actions parsed from a single model generation, drained\n", " # one-per-select_action. 
`pending_count` is the total number of\n", " # tool calls that the model emitted in the current turn (so we\n", " # can distinguish \"this is the last one\" from \"more to come\").\n", " self._pending_actions: list[dict] = []\n", " self._pending_count: int = 0\n", " # True on the very first select_action of an episode, so we know\n", " # NOT to append a tool message for a nonexistent previous action.\n", " self._expect_tool_result: bool = False\n", " # Logging counters\n", " self.stats = {\n", " \"total_calls\": 0, # select_action invocations\n", " \"generations\": 0, # model.generate calls\n", " \"parse_ok\": 0, # individual tool calls parsed from generations\n", " \"parse_fail\": 0, # generations that produced zero parseable calls\n", " \"multi_call_turns\": 0, # generations that produced >1 tool call\n", " \"budget_exhaust\": 0,\n", " \"parse_retries\": 0,\n", " }\n", " self.parse_errors: list[str] = []\n", " self._current_question: str | None = None\n", " self.failed_answers: list[dict] = []\n", " self._episode_count = 0\n", " self.reset()\n", "\n", " def reset(self) -> None:\n", " self._messages = [{\"role\": \"system\", \"content\": self.system_prompt}]\n", " self._messages.extend(self.few_shot_messages)\n", " self._last_question = None\n", " self._pending_actions = []\n", " self._pending_count = 0\n", " self._expect_tool_result = False\n", "\n", " def _to_device(self, value, device):\n", " return value.to(device) if hasattr(value, \"to\") else value\n", "\n", " def _render_and_tokenize(self):\n", " try:\n", " rendered = self.tokenizer.apply_chat_template(\n", " self._messages,\n", " tools=self.tool_definitions,\n", " add_generation_prompt=True,\n", " tokenize=False,\n", " enable_thinking=self.enable_thinking,\n", " )\n", " except TypeError:\n", " rendered = self.tokenizer.apply_chat_template(\n", " self._messages,\n", " tools=self.tool_definitions,\n", " add_generation_prompt=True,\n", " tokenize=False,\n", " )\n", " encoded = self.tokenizer(rendered, 
return_tensors=\"pt\")\n", " return encoded[\"input_ids\"], encoded.get(\"attention_mask\")\n", "\n", " def _parse_all_tool_calls(self, text: str) -> list[dict]:\n", " \"\"\"Extract every JSON block from a completion.\n", "\n", " Returns a list of {\"name\": str, \"arguments\": dict} dicts. Invalid\n", " blocks are silently skipped (they'd be dropped by TRL's parser too).\n", " \"\"\"\n", " if not text:\n", " return []\n", " pattern = re.compile(r\"\\s*(\\{.*?\\})\\s*\", re.DOTALL)\n", " out: list[dict] = []\n", " for raw_json in pattern.findall(text):\n", " try:\n", " obj = json.loads(raw_json)\n", " except json.JSONDecodeError:\n", " continue\n", " if not isinstance(obj, dict):\n", " continue\n", " name = obj.get(\"name\")\n", " args = obj.get(\"arguments\")\n", " if isinstance(name, str) and isinstance(args, dict):\n", " out.append({\"name\": name, \"arguments\": args})\n", " return out\n", "\n", " def select_action(self, observation: SQLObservation) -> SQLAction:\n", " self.stats[\"total_calls\"] += 1\n", "\n", " if observation.budget_remaining <= 1:\n", " self.stats[\"budget_exhaust\"] += 1\n", " # Flush any pending work; the episode is about to end anyway.\n", " self._pending_actions = []\n", " return SQLAction(action_type=\"ANSWER\", argument=\"budget_exhausted\")\n", "\n", " # New episode — reset message history and post the first user turn.\n", " if observation.question != self._last_question:\n", " self.reset()\n", " self._last_question = observation.question\n", " self._current_question = observation.question\n", " self._episode_count += 1\n", "\n", " tables = []\n", " for line in (observation.schema_info or \"\").split(\"\\n\"):\n", " stripped = line.strip().lstrip(\"- \").strip()\n", " if stripped and stripped.lower() != \"available tables:\":\n", " tables.append(stripped)\n", " # Matches TRL: reset() return string is appended to the last\n", " # user message. See SQLEnvTRL.reset().\n", " table_hint = (\n", " f\"Tables: {', '.join(tables)}. 
\"\n", " \"Use describe, sample, query, and answer tools.\"\n", " )\n", " self._messages.append(\n", " {\n", " \"role\": \"user\",\n", " \"content\": f\"{observation.question}{table_hint}\",\n", " }\n", " )\n", " self._expect_tool_result = False\n", " elif self._expect_tool_result:\n", " # The previous select_action returned an action; env.step just\n", " # executed it and passed us its observation. Append the result\n", " # as a role:\"tool\" message, matching TRL's per-call append.\n", " result_text = observation.result or observation.error or \"\"\n", " self._messages.append({\"role\": \"tool\", \"content\": result_text})\n", " self._expect_tool_result = False\n", "\n", " # If we still have tool calls buffered from a multi-call generation,\n", " # return the next one WITHOUT regenerating. This preserves training\n", " # semantics: the model committed to all calls before seeing any\n", " # results, so regenerating mid-batch would be a protocol violation.\n", " if self._pending_actions:\n", " tool_call = self._pending_actions.pop(0)\n", " action = _tool_call_to_action(tool_call)\n", " self._expect_tool_result = True\n", " if self.verbose:\n", " remaining = len(self._pending_actions)\n", " tag = f\"[OK/buf+{remaining}]\" if remaining else \"[OK/buf]\"\n", " print(f\" {tag} {action.action_type}: {str(action.argument)[:80]}\")\n", " return action\n", "\n", " # Otherwise generate a fresh turn.\n", " input_ids, attention_mask = self._render_and_tokenize()\n", "\n", " model_device = getattr(self.model, \"device\", None)\n", " if model_device is None:\n", " model_device = next(self.model.parameters()).device\n", " input_ids = self._to_device(input_ids, model_device)\n", " if attention_mask is not None:\n", " attention_mask = self._to_device(attention_mask, model_device)\n", "\n", " if self.verbose and self.stats[\"generations\"] < 3:\n", " print(\n", " f\" [ctx] {input_ids.shape[-1]} input tokens, max_new={self.max_new_tokens}\"\n", " )\n", "\n", " generate_kwargs = 
{\n", " \"input_ids\": input_ids,\n", " \"max_new_tokens\": self.max_new_tokens,\n", " }\n", " if attention_mask is not None:\n", " generate_kwargs[\"attention_mask\"] = attention_mask\n", "\n", " with torch.no_grad():\n", " output_ids = self.model.generate(**generate_kwargs)\n", " self.stats[\"generations\"] += 1\n", "\n", " generated_ids = output_ids[0, input_ids.shape[-1] :]\n", " generated_text = self.tokenizer.decode(\n", " generated_ids, skip_special_tokens=True\n", " ).strip()\n", " generated_text_full = self.tokenizer.decode(\n", " generated_ids, skip_special_tokens=False\n", " ).strip()\n", "\n", " # Parse ALL tool calls from this generation (matches TRL). Try the\n", " # skip-special-tokens=False variant as a fallback because sometimes\n", " # the tool_call text sits next to a special token boundary.\n", " parsed: list[dict] = []\n", " for text_variant in (generated_text, generated_text_full):\n", " parsed = self._parse_all_tool_calls(text_variant)\n", " if parsed:\n", " break\n", "\n", " if not parsed:\n", " # Parse failure — model didn't emit a valid tool call.\n", " self.stats[\"parse_fail\"] += 1\n", " if len(self.parse_errors) < 5:\n", " self.parse_errors.append(\n", " f\"--- Parse error #{len(self.parse_errors) + 1} ---\\n\"\n", " f\"text={generated_text[:300]}\"\n", " )\n", " if self.verbose:\n", " print(f\" [PARSE FAIL] raw: {generated_text[:200]}\")\n", "\n", " # Don't end the episode on parse failure — append the\n", " # failed output as an assistant message and let the\n", " # episode continue. 
The evaluate() loop will call\n", " # select_action again with the next observation.\n", " # No extra coaching prompt — keeps format identical\n", " # to what trained models see.\n", " self._messages.append({\"role\": \"assistant\", \"content\": generated_text})\n", " self.failed_answers.append(\n", " {\n", " \"question\": self._current_question,\n", " \"raw_text\": generated_text[:500],\n", " \"episode\": self._episode_count,\n", " }\n", " )\n", " self._expect_tool_result = False\n", " # Return a no-op DESCRIBE on the first available table\n", " # so the episode doesn't end. The model wasted a turn\n", " # but gets to keep exploring.\n", " tables = []\n", " for line in (observation.schema_info or \"\").split(\"\\n\"):\n", " stripped = line.strip().lstrip(\"- \").strip()\n", " if (\n", " stripped\n", " and stripped.lower() != \"available tables:\"\n", " and \":\" not in stripped\n", " ):\n", " tables.append(stripped)\n", " if tables:\n", " return SQLAction(action_type=\"DESCRIBE\", argument=tables[0])\n", " return SQLAction(action_type=\"ANSWER\", argument=\"parse_error\")\n", "\n", " # Append ONE assistant message containing the full tool_calls list\n", " # (matches what TRL appends after a multi-call generation).\n", " self._messages.append(\n", " {\n", " \"role\": \"assistant\",\n", " \"tool_calls\": [\n", " {\n", " \"type\": \"function\",\n", " \"function\": {\n", " \"name\": tc[\"name\"],\n", " \"arguments\": json.dumps(tc[\"arguments\"]),\n", " },\n", " }\n", " for tc in parsed\n", " ],\n", " }\n", " )\n", " self.stats[\"parse_ok\"] += len(parsed)\n", " if len(parsed) > 1:\n", " self.stats[\"multi_call_turns\"] += 1\n", "\n", " # First call goes out now; the rest are buffered.\n", " self._pending_actions = parsed[1:]\n", " self._pending_count = len(parsed)\n", " first = parsed[0]\n", " action = _tool_call_to_action(first)\n", " self._expect_tool_result = True\n", " if self.verbose:\n", " tag = f\"[OK/{len(parsed)}]\" if len(parsed) > 1 else \"[OK]\"\n", " 
print(f\" {tag} {action.action_type}: {str(action.argument)[:80]}\")\n", " return action\n", "\n", " def print_stats(self, label: str = \"\") -> None:\n", " \"\"\"Print action statistics for debugging.\"\"\"\n", " s = self.stats\n", " gens = s[\"generations\"] or 1\n", " print(f\"\\n{'=' * 60}\")\n", " print(f\"Policy stats{f' ({label})' if label else ''}:\")\n", " print(f\" select_action calls: {s['total_calls']}\")\n", " print(f\" model.generate calls: {s['generations']}\")\n", " print(f\" Tool calls parsed: {s['parse_ok']}\")\n", " print(\n", " f\" Multi-call turns: {s['multi_call_turns']} \"\n", " f\"({s['multi_call_turns'] / gens:.0%} of generations)\"\n", " )\n", " print(\n", " f\" Parse failures: {s['parse_fail']} \"\n", " f\"({s['parse_fail'] / gens:.0%} of generations)\"\n", " )\n", " print(f\" Budget exhaust: {s['budget_exhaust']}\")\n", " print(f\" Parse retries: {s['parse_retries']}\")\n", " print(f\" Failed answer attempts logged: {len(self.failed_answers)}\")\n", " if self.parse_errors:\n", " print(f\"\\nFirst {len(self.parse_errors)} parse failure samples:\")\n", " for sample in self.parse_errors:\n", " print(sample)\n", " print(f\"{'=' * 60}\")\n", "\n", "\n", "def _tool_call_to_action(tool_call: dict) -> SQLAction:\n", " \"\"\"Convert a parsed {name, arguments} dict into an SQLAction.\"\"\"\n", " name = str(tool_call[\"name\"]).strip().lower()\n", " arguments = tool_call[\"arguments\"]\n", " if not isinstance(arguments, dict):\n", " raise ValueError(\"Tool call arguments must be a dictionary\")\n", "\n", " if name == \"describe\":\n", " argument = arguments.get(\"table_name\", arguments.get(\"table\"))\n", " action_type = \"DESCRIBE\"\n", " elif name == \"sample\":\n", " argument = arguments.get(\"table_name\", arguments.get(\"table\"))\n", " action_type = \"SAMPLE\"\n", " elif name == \"query\":\n", " argument = arguments.get(\"sql\")\n", " action_type = \"QUERY\"\n", " elif name == \"answer\":\n", " argument = arguments.get(\"value\", 
arguments.get(\"answer\"))\n", " action_type = \"ANSWER\"\n", " else:\n", " raise ValueError(f\"Unsupported tool name: {name}\")\n", "\n", " if argument is None:\n", " raise ValueError(f\"Missing required argument for tool: {name}\")\n", " return SQLAction(action_type=action_type, argument=str(argument))" ] }, { "cell_type": "markdown", "id": "e04e13a4", "metadata": {}, "source": "## 7) Few-Shot Example Builder\nBuild few-shot examples that demonstrate the complete tool-calling loop:\nquestion → describe → result → query → result → answer.\n\nExamples use the same message format as the evaluation policy (user/assistant roles,\n`` tags) so the model learns the exact pattern it needs to follow." }, { "cell_type": "code", "execution_count": null, "id": "8bc81b8d", "metadata": {}, "outputs": [], "source": [ "def build_few_shot_messages(\n", " sft_path: str,\n", " n_examples: int,\n", " seed: int = 42,\n", ") -> list[dict]:\n", " \"\"\"Build few-shot messages from SFT trajectories.\n", "\n", " SFT trajectories already use the exact training format\n", " (user → assistant-with-tool_calls → tool → ...), so we just pass the\n", " messages through verbatim. 
This guarantees zero drift between what the\n", " model saw during SFT warmup / GRPO and what it sees in evaluation.\n", " \"\"\"\n", " if n_examples <= 0:\n", " return []\n", "\n", " sft_file = Path(sft_path)\n", " if not sft_file.exists():\n", " raise FileNotFoundError(f\"SFT trajectories not found: {sft_file}\")\n", "\n", " with sft_file.open(\"r\", encoding=\"utf-8\") as f:\n", " trajectories = json.load(f)\n", "\n", " # Prefer trajectories explicitly marked correct if the flag is present.\n", " has_correct = any(\"correct\" in t for t in trajectories if isinstance(t, dict))\n", " if has_correct:\n", " candidates = [\n", " t\n", " for t in trajectories\n", " if isinstance(t, dict)\n", " and t.get(\"correct\") is True\n", " and isinstance(t.get(\"messages\"), list)\n", " ]\n", " else:\n", " candidates = [\n", " t\n", " for t in trajectories\n", " if isinstance(t, dict) and isinstance(t.get(\"messages\"), list)\n", " ]\n", "\n", " if not candidates:\n", " print(\"Warning: no valid SFT trajectories found for few-shot examples\")\n", " return []\n", "\n", " chosen = random.Random(seed).sample(candidates, min(n_examples, len(candidates)))\n", "\n", " few_shot: list[dict] = []\n", " for traj in chosen:\n", " for msg in traj[\"messages\"]:\n", " role = msg.get(\"role\")\n", " # Skip the system prompt — the policy adds its own.\n", " if role == \"system\":\n", " continue\n", " if role in (\"user\", \"assistant\", \"tool\"):\n", " few_shot.append(dict(msg))\n", "\n", " n_msgs = len(few_shot)\n", " n_asst = sum(1 for m in few_shot if m.get(\"role\") == \"assistant\")\n", " print(\n", " f\"Few-shot: {len(chosen)} trajectories -> {n_msgs} messages ({n_asst} assistant turns)\"\n", " )\n", " return few_shot" ] }, { "cell_type": "markdown", "id": "9aaeab16", "metadata": {}, "source": [ "## 8) Base Model 3-Condition Evaluation\n", "Load the base Qwen model once, build zero-shot/1-shot/3-shot policies, and run a fair comparison on the same eval split." 
] }, { "cell_type": "code", "execution_count": null, "id": "5070d45a", "metadata": {}, "outputs": [], "source": [ "# ── Diagnostic: verify template + generation before full eval ──\n", "# This cell runs ONE episode step to confirm the pipeline works end-to-end.\n", "\n", "_diag_model, _diag_tokenizer = load_model_and_tokenizer(BASE_MODEL_NAME)\n", "if torch.cuda.is_available():\n", " _diag_model = _diag_model.to(torch.device(\"cuda\"))\n", "\n", "_diag_prompt = get_system_prompt(enable_thinking=False)\n", "_diag_tools = get_tool_definitions()\n", "\n", "# 1) Test template rendering\n", "_diag_messages = [\n", " {\"role\": \"system\", \"content\": _diag_prompt},\n", " {\n", " \"role\": \"user\",\n", " \"content\": \"How many students are there?\\n\\nTables: students. Use describe, sample, query, and answer tools.\",\n", " },\n", "]\n", "try:\n", " _rendered = _diag_tokenizer.apply_chat_template(\n", " _diag_messages,\n", " tools=_diag_tools,\n", " tokenize=False,\n", " add_generation_prompt=True,\n", " enable_thinking=False,\n", " )\n", " print(f\"Template OK ({len(_rendered)} chars)\")\n", " print(f\"First 500 chars:\\n{_rendered[:500]}\")\n", " print(f\"...\\nLast 200 chars:\\n{_rendered[-200:]}\")\n", "except Exception as e:\n", " print(f\"Template FAILED: {e}\")\n", " _rendered = None\n", "\n", "# 2) Test generation\n", "if _rendered:\n", " _inputs = _diag_tokenizer(_rendered, return_tensors=\"pt\")\n", " _inputs = {k: v.to(_diag_model.device) for k, v in _inputs.items()}\n", " print(f\"\\nInput tokens: {_inputs['input_ids'].shape[-1]}\")\n", "\n", " with torch.no_grad():\n", " _out = _diag_model.generate(**_inputs, max_new_tokens=200)\n", " _new_ids = _out[0][_inputs[\"input_ids\"].shape[-1] :]\n", " _text = _diag_tokenizer.decode(_new_ids, skip_special_tokens=True)\n", " _text_full = _diag_tokenizer.decode(_new_ids, skip_special_tokens=False)\n", " print(f\"\\nGenerated (skip_special=True):\\n{_text[:500]}\")\n", " print(f\"\\nGenerated 
(skip_special=False):\\n{_text_full[:500]}\")\n", " print(f\"\\nHas : {'' in _text_full}\")\n", "\n", "# 3) Test one full episode\n", "print(\"\\n--- One episode test ---\")\n", "_obs = env.reset(seed=42)\n", "print(f\"Question: {_obs.question[:80]}\")\n", "print(f\"Schema: {_obs.schema_info[:100]}\")\n", "print(f\"Budget: {_obs.budget_remaining}, Done: {_obs.done}\")\n", "\n", "_policy = LLMToolCallingPolicy(\n", " model=_diag_model,\n", " tokenizer=_diag_tokenizer,\n", " tool_definitions=_diag_tools,\n", " system_prompt=_diag_prompt,\n", " verbose=True,\n", ")\n", "try:\n", " _action = _policy.select_action(_obs)\n", " print(f\"Action: {_action.action_type} = {_action.argument[:100]}\")\n", "except Exception as e:\n", " print(f\"select_action FAILED: {type(e).__name__}: {e}\")\n", "\n", "_policy.print_stats(\"diagnostic\")\n", "\n", "# Cleanup\n", "del _diag_model, _diag_tokenizer, _policy\n", "gc.collect()\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", "print(\"\\nDiagnostic complete.\")" ] }, { "cell_type": "code", "execution_count": null, "id": "faacf08b", "metadata": {}, "outputs": [], "source": [ "SFT_TRAJECTORIES_PATH = \"data/sft/sft_trajectories.json\"\n", "\n", "tool_definitions = get_tool_definitions()\n", "system_prompt_nothink = get_system_prompt(enable_thinking=False)\n", "\n", "few_shot_1 = build_few_shot_messages(SFT_TRAJECTORIES_PATH, n_examples=1, seed=SEED)\n", "few_shot_3 = build_few_shot_messages(SFT_TRAJECTORIES_PATH, n_examples=3, seed=SEED)\n", "\n", "print(f\"Loaded few-shot messages: 1-shot={len(few_shot_1)}, 3-shot={len(few_shot_3)}\")\n", "\n", "base_model, base_tokenizer = load_model_and_tokenizer(BASE_MODEL_NAME)\n", "if torch.cuda.is_available():\n", " base_model = base_model.to(torch.device(\"cuda\"))\n", "print(f\"Loaded base model: {BASE_MODEL_NAME}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "337bff7b", "metadata": {}, "outputs": [], "source": [ "def _progress(name: str):\n", " def 
_callback(done: int, total: int) -> None:\n",
    "        if done == 1 or done == total or done % max(1, total // 5) == 0:\n",
    "            print(f\"[{name}] {done}/{total} episodes\")\n",
    "\n",
    "    return _callback\n",
    "\n",
    "\n",
    "# verbose=True on zero-shot so we can see raw model output for first episode\n",
    "base_conditions = [\n",
    "    {\n",
    "        \"name\": \"zero-shot\",\n",
    "        \"policy\": LLMToolCallingPolicy(\n",
    "            model=base_model,\n",
    "            tokenizer=base_tokenizer,\n",
    "            tool_definitions=tool_definitions,\n",
    "            system_prompt=system_prompt_nothink,\n",
    "            few_shot_messages=None,\n",
    "            enable_thinking=False,\n",
    "            verbose=True,\n",
    "        ),\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"1-shot\",\n",
    "        \"policy\": LLMToolCallingPolicy(\n",
    "            model=base_model,\n",
    "            tokenizer=base_tokenizer,\n",
    "            tool_definitions=tool_definitions,\n",
    "            system_prompt=system_prompt_nothink,\n",
    "            few_shot_messages=few_shot_1,\n",
    "            enable_thinking=False,\n",
    "        ),\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"3-shot\",\n",
    "        \"policy\": LLMToolCallingPolicy(\n",
    "            model=base_model,\n",
    "            tokenizer=base_tokenizer,\n",
    "            tool_definitions=tool_definitions,\n",
    "            system_prompt=system_prompt_nothink,\n",
    "            few_shot_messages=few_shot_3,\n",
    "            enable_thinking=False,\n",
    "        ),\n",
    "    },\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4f9c222",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_results: dict[str, EvaluationResult] = {}\n",
    "base_policies: dict[str, LLMToolCallingPolicy] = {}\n",
    "\n",
    "for condition in base_conditions:\n",
    "    name = condition[\"name\"]\n",
    "    policy = condition[\"policy\"]\n",
    "    print(f\"\\nRunning condition: {name}\")\n",
    "    # model.generate() may sample, depending on the checkpoint's generation\n",
    "    # config; reseed before every condition so each run is reproducible.\n",
    "    torch.manual_seed(SEED)\n",
    "    result = evaluate(\n",
    "        env,\n",
    "        policy,\n",
    "        n_episodes=N_EVAL_EPISODES,\n",
    "        seed=SEED,\n",
    "        progress_callback=_progress(name),\n",
    "    )\n",
    "    base_results[name] = result\n",
    "    base_policies[name] = policy\n",
    "    print(\n",
    "        f\"[{name}] accuracy={result.success_rate:.3f} avg_reward={result.avg_reward:.3f} \"\n",
    "        f\"avg_steps={result.avg_steps:.2f} completed={result.n_completed}/{result.n_episodes}\"\n",
    "    )\n",
    "    policy.print_stats(label=name)\n",
    "\n",
    "    error_eps = [ep for ep in result.episodes if ep.error]\n",
    "    if error_eps:\n",
    "        print(f\" Episodes with errors: {len(error_eps)}\")\n",
    "        for ep in error_eps[:3]:\n",
    "            print(f\" ep#{ep.episode_index}: {ep.error[:120]}\")\n",
    "\n",
    "    # Turn off verbose after first condition\n",
    "    policy.verbose = False\n",
    "\n",
    "print(\"\\nBase-model comparison complete.\")\n",
    "print(f\"Collected results for: {', '.join(base_results.keys())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7988b07b",
   "metadata": {},
   "source": [
    "## 9) GRPO Checkpoint Evaluation\n",
    "Load GRPO checkpoints with graceful fallback so unavailable models do not block comparison runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1ea7b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT_THINK = get_system_prompt(enable_thinking=True)\n",
    "\n",
    "# Release the base model before loading GRPO checkpoints. Deleting\n",
    "# `base_model` alone frees nothing: every policy in `base_policies` (and the\n",
    "# `policy` / `condition` loop variables left by the previous cell) still\n",
    "# references it. Detach the weights but keep the policies, whose stats feed\n",
    "# the comparison table.\n",
    "for _base_policy in globals().get(\"base_policies\", {}).values():\n",
    "    _base_policy.model = None\n",
    "for var_name in (\n",
    "    \"base_conditions\",\n",
    "    \"base_model\",\n",
    "    \"base_tokenizer\",\n",
    "    \"condition\",\n",
    "    \"policy\",\n",
    "    \"_base_policy\",\n",
    "):\n",
    "    globals().pop(var_name, None)\n",
    "# Collect first so the CUDA caching allocator has freed blocks to release.\n",
    "gc.collect()\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "grpo_conditions = [\n",
    "    {\n",
    "        \"name\": \"grpo-v1\",\n",
    "        \"repo_id\": GRPO_MODEL_REPO,\n",
    "        \"enable_thinking\": False,\n",
    "        \"system_prompt\": system_prompt_nothink,\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"grpo-v2\",\n",
    "        \"repo_id\": GRPO_V2_MODEL_REPO,\n",
    "        \"enable_thinking\": False,\n",
    "        \"system_prompt\": system_prompt_nothink,\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"grpo-thinking\",\n",
    "        \"repo_id\": GRPO_THINKING_MODEL_REPO,\n",
    "        \"enable_thinking\": True,\n",
    "        \"system_prompt\": SYSTEM_PROMPT_THINK,\n",
    "    },\n",
    "]\n",
    "\n",
    "grpo_results: dict[str, EvaluationResult] = {}\n",
    "grpo_policies: dict[str, LLMToolCallingPolicy] = {}\n",
    "\n",
    "for cfg in grpo_conditions:\n",
    "    if cfg[\"repo_id\"] == \"none\":\n",
    "        print(f\"\\nSkipping {cfg['name']} (set to 'none')\")\n",
    "        continue\n",
    "\n",
    "    model = None\n",
    "    tokenizer = None\n",
    "    policy = None\n",
    "    print(f\"\\nLoading checkpoint: {cfg['repo_id']}\")\n",
    "    try:\n",
    "        model, tokenizer = load_model_and_tokenizer(cfg[\"repo_id\"])\n",
    "    except (RuntimeError, OSError) as exc:\n",
    "        # OSError is presumably what a missing / unreachable Hub repo raises\n",
    "        # if the loader does not wrap it; skip the condition either way.\n",
    "        print(f\"Warning: Could not load {cfg['repo_id']}. Skipping condition. ({exc})\")\n",
    "        continue\n",
    "\n",
    "    try:\n",
    "        if torch.cuda.is_available():\n",
    "            model = model.to(torch.device(\"cuda\"))\n",
    "\n",
    "        policy = LLMToolCallingPolicy(\n",
    "            model=model,\n",
    "            tokenizer=tokenizer,\n",
    "            tool_definitions=tool_definitions,\n",
    "            system_prompt=cfg[\"system_prompt\"],\n",
    "            few_shot_messages=None,\n",
    "            enable_thinking=cfg[\"enable_thinking\"],\n",
    "            verbose=True,\n",
    "        )\n",
    "\n",
    "        # Same reseed as the base conditions so runs are reproducible.\n",
    "        torch.manual_seed(SEED)\n",
    "        result = evaluate(\n",
    "            env,\n",
    "            policy,\n",
    "            n_episodes=N_EVAL_EPISODES,\n",
    "            seed=SEED,\n",
    "            progress_callback=_progress(cfg[\"name\"]),\n",
    "        )\n",
    "        grpo_results[cfg[\"name\"]] = result\n",
    "        grpo_policies[cfg[\"name\"]] = policy\n",
    "        print(\n",
    "            f\"[{cfg['name']}] accuracy={result.success_rate:.3f} avg_reward={result.avg_reward:.3f} \"\n",
    "            f\"avg_steps={result.avg_steps:.2f} completed={result.n_completed}/{result.n_episodes}\"\n",
    "        )\n",
    "        policy.print_stats(label=cfg[\"name\"])\n",
    "\n",
    "        error_eps = [ep for ep in result.episodes if ep.error]\n",
    "        if error_eps:\n",
    "            print(f\" Episodes with errors: {len(error_eps)}\")\n",
    "            for ep in error_eps[:3]:\n",
    "                print(f\" ep#{ep.episode_index}: {ep.error[:120]}\")\n",
    "    finally:\n",
    "        # Keep the policy for analysis (its stats feed the comparison table)\n",
    "        # but drop its reference to the weights; otherwise `del model` below\n",
    "        # frees nothing and every checkpoint stays resident on the GPU.\n",
    "        if policy is not None:\n",
    "            policy.model = None\n",
    "        if model is not None:\n",
    "            del model\n",
    "        if tokenizer is not None:\n",
    "            del tokenizer\n",
    "        gc.collect()\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "all_results: dict[str, EvaluationResult] = {**base_results, **grpo_results}\n",
    "print(\"\\nGRPO checkpoint evaluation complete.\")\n",
    "checkpoint_names = \", \".join(grpo_results.keys()) if grpo_results else \"none\"\n",
    "print(f\"Checkpoint results collected for: {checkpoint_names}\")\n",
    "print(f\"Total methods available for comparison: {', '.join(all_results.keys())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0428cd13",
   "metadata": {},
   "source": [
    "## 10) Comparison Results\n",
    "\n",
    "This section compares all available methods on the same evaluation subset:\n",
    "- **zero-shot**: base model, tool calling, no examples\n",
    "- **1-shot**: base model with one successful trajectory example\n",
    "- **3-shot**: base model with three successful trajectory examples\n",
    "- **grpo-v1** / **grpo-v2**: GRPO checkpoints without thinking mode\n",
    "- **grpo-thinking**: GRPO checkpoint with thinking mode\n",
    "\n",
    "If a GRPO checkpoint is set to `\"none\"` or fails to load, that row is omitted automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edacc7b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def results_to_dataframe(\n",
    "    results: dict[str, EvaluationResult],\n",
    "    policies: dict[str, LLMToolCallingPolicy] | None = None,\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"Convert evaluation results + policy stats to a comparison DataFrame.\n",
    "\n",
    "    Rows follow the canonical condition order, then any other conditions\n",
    "    present in ``results``. \"Parse Rate (%)\" is the share of model\n",
    "    generations that produced at least one parseable tool call.\n",
    "    \"\"\"\n",
    "    if not results:\n",
    "        return pd.DataFrame()\n",
    "\n",
    "    ordered_names = [\n",
    "        \"zero-shot\",\n",
    "        \"1-shot\",\n",
    "        \"3-shot\",\n",
    "        \"grpo-v1\",\n",
    "        \"grpo-v2\",\n",
    "        \"grpo-thinking\",\n",
    "    ]\n",
    "\n",
    "    rows = []\n",
    "    for name in list(ordered_names) + [n for n in results if n not in ordered_names]:\n",
    "        if name not in results:\n",
    "            continue\n",
    "        item = results[name]\n",
    "\n",
    "        row = {\n",
    "            \"Method\": name,\n",
    "            \"Accuracy (%)\": round(item.success_rate * 100.0, 1),\n",
    "            \"Avg Reward\": round(item.avg_reward, 3),\n",
    "            \"Avg Steps\": round(item.avg_steps, 1),\n",
    "        }\n",
    "\n",
    "        # Add policy stats if available\n",
    "        policy = (policies or {}).get(name)\n",
    "        if policy is not None:\n",
    "            stats = policy.stats\n",
    "            # Measure format compliance per generation. A parse_ok / total_calls\n",
    "            # ratio would mix units: parse_ok counts individual tool calls\n",
    "            # (several per multi-call turn, including buffered ones that are\n",
    "            # flushed unexecuted), while total_calls counts select_action\n",
    "            # invocations (including forced budget answers), so it could\n",
    "            # exceed 100%.\n",
    "            generations = stats[\"generations\"]\n",
    "            parsed_generations = generations - stats[\"parse_fail\"]\n",
    "            row[\"Parse Rate (%)\"] = round(\n",
    "                parsed_generations / (generations or 1) * 100, 1\n",
    "            )\n",
    "            row[\"Parse Fails\"] = stats[\"parse_fail\"]\n",
    "            row[\"Budget Exhaust\"] = stats[\"budget_exhaust\"]\n",
"\n", " row[\"Completed\"] = f\"{item.n_completed}/{item.n_episodes}\"\n", " rows.append(row)\n", "\n", " return pd.DataFrame(rows)\n", "\n", "\n", "def plot_comparison(df: pd.DataFrame) -> None:\n", " \"\"\"Display grouped bar chart comparing accuracy and parse rate.\"\"\"\n", " if df.empty or \"Accuracy (%)\" not in df.columns:\n", " print(\"No results to plot.\")\n", " return\n", "\n", " has_parse = \"Parse Rate (%)\" in df.columns\n", " plot_df = df.sort_values(\"Accuracy (%)\", ascending=True).reset_index(drop=True)\n", " colors_acc = [\n", " \"#2b6cb0\" if not m.startswith(\"grpo\") else \"#2f855a\" for m in plot_df[\"Method\"]\n", " ]\n", "\n", " if has_parse:\n", " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n", " else:\n", " fig, ax1 = plt.subplots(figsize=(10, 5))\n", "\n", " # Accuracy chart\n", " bars = ax1.barh(plot_df[\"Method\"], plot_df[\"Accuracy (%)\"], color=colors_acc)\n", " ax1.set_xlim(0, 100)\n", " ax1.set_xlabel(\"Accuracy (%)\")\n", " ax1.set_title(\"Answer Accuracy\")\n", " ax1.grid(axis=\"x\", alpha=0.25)\n", " for bar, value in zip(bars, plot_df[\"Accuracy (%)\"]):\n", " ax1.text(\n", " min(value + 1, 95),\n", " bar.get_y() + bar.get_height() / 2,\n", " f\"{value:.1f}%\",\n", " va=\"center\",\n", " )\n", "\n", " # Parse rate chart\n", " if has_parse:\n", " colors_parse = [\n", " \"#805ad5\" if not m.startswith(\"grpo\") else \"#38a169\"\n", " for m in plot_df[\"Method\"]\n", " ]\n", " bars2 = ax2.barh(\n", " plot_df[\"Method\"], plot_df[\"Parse Rate (%)\"], color=colors_parse\n", " )\n", " ax2.set_xlim(0, 100)\n", " ax2.set_xlabel(\"Parse Rate (%)\")\n", " ax2.set_title(\"Tool-Call Format Compliance\")\n", " ax2.grid(axis=\"x\", alpha=0.25)\n", " for bar, value in zip(bars2, plot_df[\"Parse Rate (%)\"]):\n", " ax2.text(\n", " min(value + 1, 95),\n", " bar.get_y() + bar.get_height() / 2,\n", " f\"{value:.1f}%\",\n", " va=\"center\",\n", " )\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", "# Merge all policies for 
the table\n", "all_policies = dict(base_policies)\n", "if \"grpo_policies\" in dir() and grpo_policies:\n", " all_policies.update(grpo_policies)\n", "\n", "comparison_df = results_to_dataframe(all_results, all_policies)\n", "if comparison_df.empty:\n", " print(\"No comparison results available. Run evaluation cells first.\")\n", "else:\n", " display(comparison_df)\n", " plot_comparison(comparison_df)" ] }, { "cell_type": "markdown", "id": "e0c352d8", "metadata": {}, "source": [ "### Result Interpretation (Template)\n", "\n", "- Compare **zero-shot -> 1-shot -> 3-shot** to measure pure prompt-engineering gains.\n", "- Compare **best prompting method vs GRPO checkpoints** to quantify training value.\n", "- If only prompting rows are present, verify checkpoint availability and rerun GRPO cells.\n", "- Keep `N_EVAL_EPISODES` and `SEED` fixed when comparing future runs for fairness." ] }, { "cell_type": "markdown", "id": "6f374466", "metadata": {}, "source": "### Failed Answer Analysis\nWhen the model fails to produce `` tags, it often still outputs an answer in natural language or bare JSON.\nThis analysis checks: **was the reasoning correct even when the format was wrong?**\n\nThis separates two distinct failure modes:\n- **Format failure only** — model knew the answer but didn't use `` tags\n- **Both format and reasoning failure** — model didn't know the answer either" }, { "cell_type": "code", "execution_count": null, "id": "2e1bee19", "metadata": {}, "outputs": [], "source": [ "from sql_env.server.verifier import verify_answer\n", "\n", "# Load gold answers for post-hoc checking\n", "with questions_path.open(\"r\", encoding=\"utf-8\") as f:\n", " _gold_data = json.load(f)\n", "_gold_lookup = {q[\"question_text\"]: q for q in _gold_data}\n", "\n", "\n", "def _extract_candidate_answer(raw_text: str) -> str | None:\n", " \"\"\"Try to extract an answer value from unstructured model output.\"\"\"\n", " text = raw_text.strip()\n", "\n", " # Try bare JSON: {\"name\": 
\"answer\", \"arguments\": {\"value\": \"...\"}}\n", " try:\n", " obj = json.loads(text)\n", " if isinstance(obj, dict):\n", " if \"value\" in obj:\n", " return str(obj[\"value\"])\n", " args = obj.get(\"arguments\", {})\n", " if isinstance(args, dict) and \"value\" in args:\n", " return str(args[\"value\"])\n", " except json.JSONDecodeError:\n", " pass\n", "\n", " # Try to find a short numeric or simple answer in the text\n", " # Skip long explanatory text — only extract if it looks like a direct value\n", " lines = [ln.strip() for ln in text.split(\"\\n\") if ln.strip()]\n", " if len(lines) == 1 and len(lines[0]) < 50:\n", " return lines[0]\n", "\n", " return None\n", "\n", "\n", "def analyze_failed_answers(policy: LLMToolCallingPolicy, label: str) -> dict:\n", " \"\"\"Check failed answer attempts against gold answers.\"\"\"\n", " results = {\n", " \"total_failures\": len(policy.failed_answers),\n", " \"extractable\": 0,\n", " \"correct_answer_wrong_format\": 0,\n", " \"wrong_answer_wrong_format\": 0,\n", " \"not_extractable\": 0,\n", " \"examples_correct\": [],\n", " \"examples_wrong\": [],\n", " }\n", "\n", " for entry in policy.failed_answers:\n", " question = entry[\"question\"]\n", " raw_text = entry[\"raw_text\"]\n", " gold = _gold_lookup.get(question)\n", "\n", " if gold is None:\n", " continue\n", "\n", " candidate = _extract_candidate_answer(raw_text)\n", " if candidate is None:\n", " results[\"not_extractable\"] += 1\n", " continue\n", "\n", " results[\"extractable\"] += 1\n", " gold_answer = gold[\"gold_answer\"]\n", " answer_type = gold.get(\"answer_type\", \"string\")\n", "\n", " try:\n", " is_correct = verify_answer(candidate, gold_answer, answer_type)\n", " except Exception:\n", " is_correct = False\n", "\n", " if is_correct:\n", " results[\"correct_answer_wrong_format\"] += 1\n", " if len(results[\"examples_correct\"]) < 3:\n", " results[\"examples_correct\"].append(\n", " {\n", " \"question\": question[:80],\n", " \"predicted\": 
candidate[:60],\n", " \"gold\": str(gold_answer)[:60],\n", " }\n", " )\n", " else:\n", " results[\"wrong_answer_wrong_format\"] += 1\n", " if len(results[\"examples_wrong\"]) < 3:\n", " results[\"examples_wrong\"].append(\n", " {\n", " \"question\": question[:80],\n", " \"predicted\": candidate[:60],\n", " \"gold\": str(gold_answer)[:60],\n", " }\n", " )\n", "\n", " return results\n", "\n", "\n", "# Analyze all conditions\n", "all_policies = dict(base_policies)\n", "# Add GRPO policies if they were saved\n", "if \"grpo_policies\" in dir():\n", " all_policies.update(grpo_policies)\n", "\n", "analysis_rows = []\n", "for name, policy in all_policies.items():\n", " result = all_results.get(name)\n", " analysis = analyze_failed_answers(policy, name)\n", "\n", " formal_correct = int(result.success_rate * result.n_completed) if result else 0\n", " total_episodes = result.n_completed if result else 0\n", "\n", " print(f\"\\n{'=' * 60}\")\n", " print(f\"Failed answer analysis: {name}\")\n", " print(f\" Episodes: {total_episodes}\")\n", " print(f\" Formally correct (format + answer): {formal_correct}\")\n", " print(f\" Parse failures total: {analysis['total_failures']}\")\n", " print(f\" - Extractable answer: {analysis['extractable']}\")\n", " print(\n", " f\" - Correct answer, wrong format: {analysis['correct_answer_wrong_format']}\"\n", " )\n", " print(\n", " f\" - Wrong answer, wrong format: {analysis['wrong_answer_wrong_format']}\"\n", " )\n", " print(f\" - Not extractable: {analysis['not_extractable']}\")\n", "\n", " if analysis[\"examples_correct\"]:\n", " print(\"\\n Examples — correct answer but wrong format:\")\n", " for ex in analysis[\"examples_correct\"]:\n", " print(f\" Q: {ex['question']}\")\n", " print(f\" Predicted: {ex['predicted']} | Gold: {ex['gold']}\")\n", "\n", " # Count episodes where model answered in correct format but got wrong answer.\n", " # Use episodes_with_failures (capped at total) not raw failure count.\n", " episodes_with_parse_fail = (\n", 
" min(\n", " len([fa for fa in policy.failed_answers if fa.get(\"episode\", 0) > 0]),\n", " total_episodes,\n", " )\n", " if hasattr(policy, \"failed_answers\")\n", " else analysis[\"total_failures\"]\n", " )\n", " # Deduplicate by episode number\n", " _fail_episodes = (\n", " set(fa.get(\"episode\", 0) for fa in policy.failed_answers)\n", " if hasattr(policy, \"failed_answers\")\n", " else set()\n", " )\n", " episodes_with_parse_fail = len(_fail_episodes)\n", " format_ok_answer_wrong = total_episodes - formal_correct - episodes_with_parse_fail\n", " print(f\" Format OK, answer wrong: {format_ok_answer_wrong}\")\n", " analysis_rows.append(\n", " {\n", " \"Method\": name,\n", " \"Format+Answer OK\": formal_correct,\n", " \"Format OK, Answer wrong\": format_ok_answer_wrong,\n", " \"Format fail, Answer OK\": analysis[\"correct_answer_wrong_format\"],\n", " \"Format fail, Answer wrong\": analysis[\"wrong_answer_wrong_format\"],\n", " \"Not extractable\": analysis[\"not_extractable\"],\n", " \"Total episodes\": total_episodes,\n", " }\n", " )\n", "\n", "print(f\"\\n{'=' * 60}\")\n", "analysis_df = pd.DataFrame(analysis_rows)\n", "if not analysis_df.empty:\n", " display(analysis_df)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }