{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# SQLEnv: Interactive SQL Exploration for RL\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/)\n", "\n", "Data analysts don't write perfect queries from scratch. They explore schemas, run test queries, observe results, and refine. SQLEnv turns this process into an RL environment.\n", "\n", "| Text-to-SQL benchmarks | SQLEnv |\n", "|---|---|\n", "| Full schema given upfront | Schema hidden — agent must DESCRIBE tables |\n", "| One-shot query generation | Multi-turn exploration with a step budget |\n", "| Binary correct/wrong | Dense 3-layer reward per step |\n", "| Static evaluation | Interactive RL environment |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Setup\n", "\n", "Install the environment and download the Spider databases." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import subprocess\n", "import sys\n", "from pathlib import Path\n", "\n", "IN_COLAB = \"google.colab\" in sys.modules\n", "\n", "if IN_COLAB:\n", " # Colab: clone the repo, install the package, fetch Spider databases.\n", " # Requires a GITHUB_TOKEN in Colab userdata if the repo is private.\n", " from google.colab import userdata\n", "\n", " token = userdata.get(\"GITHUB_TOKEN\")\n", " BRANCH = \"main\" # @param {type:\"string\"}\n", " repo_url = f\"https://{token}@github.com/hjerpe/sql-env.git\"\n", "\n", " if Path(\"sql-env\").exists():\n", " subprocess.check_call([\"git\", \"-C\", \"sql-env\", \"pull\", \"-q\"])\n", " else:\n", " subprocess.check_call([\"git\", \"clone\", \"-q\", \"-b\", BRANCH, repo_url])\n", " os.chdir(\"sql-env\")\n", "\n", " print(\"Colab detected: installing dependencies...\")\n", " subprocess.check_call(\n", " [sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--upgrade\", \"pip\"]\n", " )\n", " subprocess.check_call(\n", " [\n", " 
sys.executable,\n", " \"-m\",\n", " \"pip\",\n", " \"install\",\n", " \"-q\",\n", " \"--no-deps\",\n", " \"--force-reinstall\",\n", " \".\",\n", " ]\n", " )\n", " subprocess.check_call(\n", " [\n", " sys.executable,\n", " \"-m\",\n", " \"pip\",\n", " \"install\",\n", " \"-q\",\n", " \"openenv-core[core]>=0.2.1\",\n", " \"pydantic>=2.0.0\",\n", " \"jmespath\",\n", " ]\n", " )\n", " # Download Spider SQLite databases the notebook reads from\n", " subprocess.check_call(\n", " [sys.executable, \"scripts/download_spider_databases.py\", \"--db-id\", \"all\"]\n", " )\n", "\n", " PROJECT_ROOT = Path.cwd()\n", "else:\n", " # Local: walk up from CWD to find the project root\n", " def find_project_root() -> Path:\n", " \"\"\"Walk up from CWD until pyproject.toml is found.\"\"\"\n", " for parent in [Path.cwd(), *Path.cwd().parents]:\n", " if (parent / \"pyproject.toml\").exists():\n", " return parent\n", " raise FileNotFoundError(\n", " \"Could not locate project root (no pyproject.toml found)\"\n", " )\n", "\n", " PROJECT_ROOT = find_project_root()\n", " os.chdir(PROJECT_ROOT)\n", "\n", "if str(PROJECT_ROOT) not in sys.path:\n", " sys.path.insert(0, str(PROJECT_ROOT))\n", "\n", "print(f\"Project root: {PROJECT_ROOT}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Create the Environment\n", "\n", "No server needed: instantiate the environment directly. `MockTokenizer` avoids downloading HuggingFace models." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded 473 questions\n" ] } ], "source": [ "from sql_env.server.sql_environment import SQLEnvironment\n", "from sql_env.server.mock_tokenizer import MockTokenizer\n", "from sql_env.models import SQLAction\n", "\n", "env = SQLEnvironment(\n", " questions_path=\"data/questions/questions_train.json\",\n", " db_dir=\"data/databases\",\n", " tokenizer=MockTokenizer(),\n", " step_budget=15,\n", ")\n", "print(f\"Loaded {len(env.questions)} questions\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Start an Episode\n", "\n", "`reset()` picks a random question and loads its SQLite database. The agent sees the question and table names, but column details stay hidden until it DESCRIBEs them." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Question: Find the names of employees who never won any award in the evaluation.\n", "\n", "Schema (table names only — columns hidden):\n", "Available tables:\n", "- employee\n", "- evaluation\n", "- hiring\n", "- shop\n", "\n", "Budget: 15 steps\n" ] } ], "source": [ "obs = env.reset(seed=42)\n", "\n", "print(\"Question:\", obs.question)\n", "print(\"\\nSchema (table names only — columns hidden):\")\n", "print(obs.schema_info)\n", "print(f\"\\nBudget: {obs.budget_remaining} steps\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 
Explore the Action Space\n", "\n", "Four actions available: **DESCRIBE**, **SAMPLE**, **QUERY**, **ANSWER**.\n", "\n", "### DESCRIBE — reveal column details" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Describing table: employee\n", "\n", "Schema after DESCRIBE:\n", "Available tables:\n", "- employee\n", "- evaluation\n", "- hiring\n", "- shop\n", "\n", "Described tables:\n", "- employee: Employee_ID INT, Name TEXT, Age INT, City TEXT\n", "\n", "Result: Table 'employee' columns:\n", "- Employee_ID: INT\n", "- Name: TEXT\n", "- Age: INT\n", "- City: TEXT\n", "Row count: 10\n", "Reward: 0.0150 | Budget: 14\n" ] } ], "source": [ "# Pick a table from the current episode's schema\n", "first_table = obs.schema_info.split(\"\\n\")[1].strip(\"- \").strip()\n", "print(f\"Describing table: {first_table}\\n\")\n", "\n", "obs = env.step(SQLAction(action_type=\"DESCRIBE\", argument=first_table))\n", "\n", "print(\"Schema after DESCRIBE:\")\n", "print(obs.schema_info)\n", "print(f\"\\nResult: {obs.result}\")\n", "print(f\"Reward: {obs.reward:.4f} | Budget: {obs.budget_remaining}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SAMPLE — preview actual data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sample rows:\n", "Sample from 'employee':\n", "1. 1 | George Chuter | 23 | Bristol\n", "2. 2 | Lee Mears | 29 | Bath\n", "3. 3 | Mark Regan | 43 | Bristol\n", "4. 4 | Jason Hobson | 30 | Bristol\n", "5. 
5 | Tim Payne | 29 | Wasps\n", "\n", "Reward: 0.0150 | Budget: 13\n" ] } ], "source": [ "obs = env.step(SQLAction(action_type=\"SAMPLE\", argument=first_table))\n", "\n", "print(\"Sample rows:\")\n", "print(obs.result[:500])\n", "print(f\"\\nReward: {obs.reward:.4f} | Budget: {obs.budget_remaining}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### QUERY — execute SQL" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Query result:\n", "1. 10\n", "\n", "Reward: 0.0625 | Budget: 12\n" ] } ], "source": [ "obs = env.step(\n", " SQLAction(action_type=\"QUERY\", argument=f\"SELECT COUNT(*) FROM {first_table}\")\n", ")\n", "\n", "print(\"Query result:\")\n", "print(obs.result)\n", "print(f\"\\nReward: {obs.reward:.4f} | Budget: {obs.budget_remaining}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ANSWER — submit and get scored" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done: True\n", "Terminal reward: 0.0\n", "Action history: [\"DESCRIBE -> Table 'employee' columns:\", \"SAMPLE -> Sample from 'employee':\", 'QUERY -> 1. 10', 'ANSWER 42 -> incorrect']\n" ] } ], "source": [ "obs = env.step(SQLAction(action_type=\"ANSWER\", argument=\"42\"))\n", "\n", "print(f\"Done: {obs.done}\")\n", "print(f\"Terminal reward: {obs.reward}\")\n", "print(f\"Action history: {obs.action_history}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": "### What just happened?\n\nEach step costs 1 from the **budget** (default 15). When budget hits 0 without an ANSWER, the episode ends with no terminal reward.\n\nThe environment gives small **rewards** after every step, not just at the end — see the live values below." 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from server.reward import (\n", " _EXEC_OK_REWARD,\n", " _NEW_INFO_REWARD,\n", " _REPEAT_PENALTY,\n", " _STEP_COST,\n", " _PER_STEP_FLOOR,\n", " _PER_STEP_CAP,\n", ")\n", "\n", "print(\"Reward constants (from server/reward.py):\")\n", "print(f\" +{_EXEC_OK_REWARD} successful execution (no errors)\")\n", "print(f\" +{_NEW_INFO_REWARD} new information (unique query)\")\n", "print(f\" -{_STEP_COST} step cost (every action)\")\n", "print(f\" -{_REPEAT_PENALTY} repeat penalty (duplicate SQL)\")\n", "print(f\" [{_PER_STEP_FLOOR}, +{_PER_STEP_CAP}] per-step clipping range\")\n", "print(\" +1.0 correct answer (terminal)\")\n", "print(\" +0.0 wrong answer (terminal)\")\n", "print()\n", "print(\"Terminal correctness dominates: max exploration over 15 steps\")\n", "print(f\"is ~{15 * _PER_STEP_CAP:.1f}, while a correct answer adds 1.0.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Random Policy Baseline\n", "\n", "The `RandomPolicy` explores randomly and submits an answer on its last step. This is the lower bound." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random Policy (50 episodes):\n", " Success rate: 0.0%\n", " Avg reward: 0.247\n", " Avg steps: 15.0\n" ] } ], "source": [ "from sql_env.evaluation import RandomPolicy, evaluate\n", "\n", "random_result = evaluate(env, RandomPolicy(seed=0), n_episodes=50, seed=0)\n", "\n", "print(\"Random Policy (50 episodes):\")\n", "print(f\" Success rate: {random_result.success_rate:.1%}\")\n", "print(f\" Avg reward: {random_result.avg_reward:.3f}\")\n", "print(f\" Avg steps: {random_result.avg_steps:.1f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sample random episode transcript" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Q: Count the number of paragraphs.\n", "\n", " SAMPLE Paragraphs\n", " → reward=0.0150 result=Sample from 'Paragraphs':\n", "1. 7 | 2394 | Korea | None\n", "2. 9 | 3 | Somalia | None\n", "3...\n", " SAMPLE Documents\n", " → reward=0.0150 result=Sample from 'Documents':\n", "1. 0 | 7 | Introduction of OS | n | None\n", "2. 1 | 25 | Un...\n", " DESCRIBE Documents\n", " → reward=0.0150 result=Table 'Documents' columns:\n", "- Document_ID: INTEGER\n", "- Template_ID: INTEGER\n", "- Docum...\n", " SAMPLE Documents\n", " → reward=0.0150 result=Sample from 'Documents':\n", "1. 0 | 7 | Introduction of OS | n | None\n", "2. 
1 | 25 | Un...\n", " DESCRIBE Documents\n", " → reward=0.0150 result=Table 'Documents' columns:\n", "- Document_ID: INTEGER\n", "- Template_ID: INTEGER\n", "- Docum...\n", " DESCRIBE Documents\n", " → reward=0.0150 result=Table 'Documents' columns:\n", "- Document_ID: INTEGER\n", "- Template_ID: INTEGER\n", "- Docum...\n", " DESCRIBE Templates\n", " → reward=0.0150 result=Table 'Templates' columns:\n", "- Template_ID: INTEGER\n", "- Version_Number: INTEGER\n", "- Te...\n", " SAMPLE Documents\n", " → reward=0.0150 result=Sample from 'Documents':\n", "1. 0 | 7 | Introduction of OS | n | None\n", "2. 1 | 25 | Un...\n", " DESCRIBE Documents\n", " → reward=0.0150 result=Table 'Documents' columns:\n", "- Document_ID: INTEGER\n", "- Template_ID: INTEGER\n", "- Docum...\n", " QUERY SELECT * FROM \"Templates\" LIMIT 5\n", " → reward=0.0625 result=1. 0 | 5 | PP | 2005-11-12 07:09:48 | 2008-01-05 14:19:28 | \n", "2. 1 | 9 | PP | 201...\n", " DESCRIBE Documents\n", " → reward=0.0150 result=Table 'Documents' columns:\n", "- Document_ID: INTEGER\n", "- Template_ID: INTEGER\n", "- Docum...\n", " DESCRIBE Paragraphs\n", " → reward=0.0150 result=Table 'Paragraphs' columns:\n", "- Paragraph_ID: INTEGER\n", "- Document_ID: INTEGER\n", "- Par...\n", " QUERY SELECT * FROM \"Paragraphs\" LIMIT 5\n", " → reward=0.0250 result=1. 7 | 2394 | Korea | None\n", "2. 9 | 3 | Somalia | None\n", "3. 65 | 50123 | Palestinian...\n", " QUERY SELECT * FROM \"Documents\" LIMIT 5\n", " → reward=0.0250 result=1. 0 | 7 | Introduction of OS | n | None\n", "2. 
1 | 25 | Understanding DB | y | None...\n", " ANSWER 76 | 20 | Robbin CV | y | None\n", " → reward=0.0000 result=Answer submitted: incorrect.\n", "\n", "Total reward: 0.278\n" ] } ], "source": [ "# Play one episode and print the transcript\n", "obs = env.reset(seed=7)\n", "policy = RandomPolicy(seed=7)\n", "total_reward = 0.0\n", "\n", "print(f\"Q: {obs.question}\\n\")\n", "while not obs.done:\n", " action = policy.select_action(obs)\n", " obs = env.step(action)\n", " reward = obs.reward or 0.0\n", " total_reward += reward\n", " result_preview = obs.result[:80] + \"...\" if len(obs.result) > 80 else obs.result\n", " print(f\" {action.action_type} {action.argument[:40]}\")\n", " print(f\" → reward={reward:.4f} result={result_preview}\")\n", "\n", "print(f\"\\nTotal reward: {total_reward:.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Oracle Policy — Upper Bound\n", "\n", "The `OraclePolicy` knows the gold SQL and answer. It plays the optimal strategy: DESCRIBE relevant tables, run the gold query, submit the correct answer. If the environment works, the oracle should score 100%." 
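, "\n", "\n",
 "The lookup trick can be sketched like this (a hypothetical simplification for illustration — per the strategy above, the shipped `OraclePolicy` also DESCRIBEs the relevant tables before answering):\n",
 "\n",
 "```python\n",
 "from sql_env.models import SQLAction\n",
 "\n",
 "\n",
 "class MiniOracle:  # hypothetical sketch, not the shipped class\n",
 "    def __init__(self, questions):\n",
 "        # Index gold records by question text for lookup at play time.\n",
 "        self.gold = {q.question_text: q for q in questions}\n",
 "\n",
 "    def select_action(self, obs):\n",
 "        record = self.gold[obs.question]\n",
 "        # First run the known-correct SQL, then submit the gold answer.\n",
 "        if not obs.action_history:\n",
 "            return SQLAction(action_type=\"QUERY\", argument=record.gold_sql)\n",
 "        return SQLAction(action_type=\"ANSWER\", argument=record.gold_answer)\n",
 "```"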
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Oracle Policy (50 episodes):\n", " Success rate: 100.0%\n", " Avg reward: 1.168\n", " Avg steps: 3.5\n" ] } ], "source": [ "from sql_env.evaluation import OraclePolicy\n", "from sql_env.models import QuestionRecord\n", "import json\n", "\n", "# Load questions so the oracle knows the gold answers\n", "with open(\"data/questions/questions_train.json\") as f:\n", " raw_questions = json.load(f)\n", "\n", "questions = [\n", " QuestionRecord(\n", " question_id=q[\"question_id\"],\n", " question_text=q[\"question_text\"],\n", " database_name=q[\"database_name\"],\n", " gold_sql=q[\"gold_sql\"],\n", " gold_answer=str(q[\"gold_answer\"]),\n", " answer_type=q[\"answer_type\"],\n", " difficulty=q[\"difficulty\"],\n", " tables_involved=q[\"tables_involved\"],\n", " )\n", " for q in raw_questions\n", "]\n", "\n", "oracle_result = evaluate(env, OraclePolicy(questions), n_episodes=50, seed=0)\n", "\n", "print(\"Oracle Policy (50 episodes):\")\n", "print(f\" Success rate: {oracle_result.success_rate:.1%}\")\n", "print(f\" Avg reward: {oracle_result.avg_reward:.3f}\")\n", "print(f\" Avg steps: {oracle_result.avg_steps:.1f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Side-by-side comparison" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Reward Breakdown\n", "\n", "Without dense rewards, a small model (<0.5B parameters) gets no useful signal: the correct answer is too far away. 
SQLEnv solves this with three reward layers:\n", "\n", "| Layer | What it rewards | Clipping |\n", "|-------|----------------|----------|\n", "| **L1: Operational** | Successful execution (+0.02), new info (+0.01), penalizes repeats (-0.03) and idle steps (-0.02) | per step [-0.10, 0.15] |\n", "| **L2: Progress** | Delta from previous query: rewards getting closer, penalizes regression (cardinality, value overlap, numeric proximity) | per step [-0.10, 0.15] |\n", "| **L3: Terminal** | Correct answer: +1.0. Wrong: 0.0 | one-shot |\n", "\n", "Terminal correctness dominates: max exploration over 15 steps is ~0.3, while a correct answer adds 1.0." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import re\n", "\n", "from sql_env.server.reward import (\n", " _EXEC_OK_REWARD,\n", " _NEW_INFO_REWARD,\n", " _REPEAT_PENALTY,\n", " _STEP_COST,\n", ")\n", "\n", "\n", "def format_sql(sql):\n", " \"\"\"Simple SQL formatter for display.\"\"\"\n", " kw = r\"\\b(SELECT|FROM|JOIN|ON|WHERE|GROUP BY|ORDER BY|HAVING|LIMIT|AND|OR|LEFT JOIN|RIGHT JOIN|INNER JOIN)\\b\"\n", " formatted = re.sub(kw, r\"\\n \\1\", sql, flags=re.IGNORECASE).strip()\n", " return formatted\n", "\n", "\n", "def explain_reward(action_type, error, is_repeat_query, total_reward):\n", " \"\"\"Decompose a step reward into labeled components.\n", "\n", " Layer 1 components (step_cost, exec_ok, new_info, repeat_penalty) are\n", " deterministic from action type + state, so we reconstruct them exactly\n", " from the reward constants imported above. Layer 2 (progress delta on\n", " QUERY) and Layer 3 (terminal on ANSWER) are not exposed in the\n", " observation, so we recover them as 'total reward minus L1 sum' and\n", " label them accordingly. 
The clip range [-0.10, +0.15] may adjust the\n", " final value — any residual after layer reconstruction is labeled\n", " 'clip_adjust'.\n", " \"\"\"\n", " at = action_type.upper()\n", " parts = [(\"step_cost\", -_STEP_COST)] # always applied\n", "\n", " if error:\n", " pass # no exec_ok when the action errored\n", " elif at == \"QUERY\" and is_repeat_query:\n", " parts.append((\"repeat_penalty\", -_REPEAT_PENALTY))\n", " else:\n", " parts.append((\"exec_ok\", +_EXEC_OK_REWARD))\n", " if at == \"QUERY\":\n", " parts.append((\"new_info\", +_NEW_INFO_REWARD))\n", "\n", " l1_sum = sum(v for _, v in parts)\n", " remainder = total_reward - l1_sum\n", "\n", " if abs(remainder) > 1e-9:\n", " if at == \"ANSWER\":\n", " parts.append((\"terminal\", remainder))\n", " elif at == \"QUERY\":\n", " parts.append((\"layer2_progress\", remainder))\n", " else:\n", " parts.append((\"clip_adjust\", remainder))\n", "\n", " labels = \" + \".join(f\"{name}({v:+.3f})\" for name, v in parts)\n", " return f\"{labels} = {total_reward:+.4f}\"\n", "\n", "\n", "# Run one oracle episode and show per-step rewards with component breakdown\n", "obs = env.reset(seed=0)\n", "oracle = OraclePolicy(questions)\n", "step_rewards = []\n", "seen_queries: set[str] = set()\n", "\n", "print(f\"Q: {obs.question}\\n\")\n", "while not obs.done:\n", " action = oracle.select_action(obs)\n", " is_repeat = action.action_type.upper() == \"QUERY\" and action.argument in seen_queries\n", " if action.action_type.upper() == \"QUERY\":\n", " seen_queries.add(action.argument)\n", "\n", " obs = env.step(action)\n", " reward = obs.reward or 0.0\n", " step_rewards.append(reward)\n", "\n", " print(f\" Step {len(step_rewards)}: {action.action_type}\")\n", " if action.action_type == \"QUERY\":\n", " print(\n", " f\" SQL:\\n {format_sql(action.argument).replace(chr(10), chr(10) + ' ')}\"\n", " )\n", " else:\n", " print(f\" Action: {action.argument}\")\n", " if obs.result:\n", " print(f\" Result: {obs.result}\")\n", " if 
obs.error:\n", " print(f\" Error: {obs.error}\")\n", " print(f\" Reward: {explain_reward(action.action_type, obs.error, is_repeat, reward)}\")\n", " print()\n", "\n", "exploration = sum(step_rewards[:-1]) if len(step_rewards) > 1 else 0.0\n", "terminal = step_rewards[-1] if step_rewards else 0.0\n", "\n", "print(f\"Total reward: {sum(step_rewards):.3f}\")\n", "print(f\" Exploration (L1+L2): {exploration:.3f} ({len(step_rewards) - 1} steps)\")\n", "print(f\" Terminal (L3): {terminal:.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. Same Environment, Over the Wire\n", "\n", "The same `SQLEnvironment` runs as a Docker container on HuggingFace Spaces:\n", "[**huggingface.co/spaces/hjerpe/sql_env**](https://huggingface.co/spaces/hjerpe/sql_env).\n", "`SQLEnvClient` connects via WebSocket and provides the same `reset()` /\n", "`step()` interface we used above — same action space, same observation shape,\n", "same reward model. The only difference is that the SQLite database and\n", "reward computation now live on a remote container instead of in this\n", "Python process.\n", "\n", "The cell below drives one full episode against the live Space." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sql_env.client import SQLEnvClient\n", "from sql_env.models import SQLAction\n", "\n", "# Live hosted Space. This is the URL anyone in the world can point a client\n", "# at — no local setup required. The first request may take ~30s if the\n", "# container is cold-starting.\n", "SPACE_URL = \"https://hjerpe-sql-env.hf.space\"\n", "\n", "print(f\"Connecting to {SPACE_URL} ...\\n\")\n", "\n", "# openenv-core's SQLEnvClient is sync-by-default in older versions but\n", "# async-by-default in newer ones (the newer API exposes .sync() as an\n", "# explicit synchronous wrapper). 
Detect at runtime so the cell works on\n", "# both local dev installs and Colab's pinned >=0.2.1 version.\n", "_remote_client = SQLEnvClient(base_url=SPACE_URL)\n", "_remote_ctx = _remote_client.sync() if hasattr(_remote_client, \"sync\") else _remote_client\n", "\n", "try:\n", " with _remote_ctx as remote_env:\n", " # --- reset ---\n", " result = remote_env.reset()\n", " remote_obs = result.observation\n", " print(f\"Q: {remote_obs.question}\")\n", " tables = [\n", " line.lstrip(\"- \").strip()\n", " for line in remote_obs.schema_info.splitlines()[1:]\n", " if line.strip()\n", " ]\n", " print(f\"Tables: {tables}\\n\")\n", "\n", " first_table = tables[0]\n", "\n", " # --- describe ---\n", " result = remote_env.step(\n", " SQLAction(action_type=\"DESCRIBE\", argument=first_table)\n", " )\n", " print(f\"DESCRIBE {first_table}\")\n", " print(f\" reward: {result.observation.reward:+.4f}\")\n", " # Line-based preview so truncation never cuts mid-word\n", " _lines = result.observation.result.splitlines()\n", " _preview = \"\\n \".join(_lines[:6])\n", " _more = (\n", " f\"\\n ... 
({len(_lines) - 6} more lines)\"\n", " if len(_lines) > 6\n", " else \"\"\n", " )\n", " print(f\" result: {_preview}{_more}\\n\")\n", "\n", " # --- query ---\n", " query_sql = f\"SELECT COUNT(*) FROM {first_table}\"\n", " result = remote_env.step(\n", " SQLAction(action_type=\"QUERY\", argument=query_sql)\n", " )\n", " print(f\"QUERY {query_sql}\")\n", " print(f\" reward: {result.observation.reward:+.4f}\")\n", " print(f\" result: {result.observation.result}\\n\")\n", "\n", " # --- answer (intentionally wrong — we're demoing plumbing, not correctness) ---\n", " result = remote_env.step(\n", " SQLAction(action_type=\"ANSWER\", argument=\"demo\")\n", " )\n", " print(\"ANSWER demo\")\n", " print(f\" done: {result.observation.done}\")\n", " print(f\" reward: {result.observation.reward:+.4f}\")\n", " print(\"\\nSame action space, same observation shape, same rewards — just running remotely.\")\n", "except Exception as exc: # noqa: BLE001 — demo cell should not crash the notebook\n", " print(f\"Remote call failed: {type(exc).__name__}: {exc}\")\n", " print(\n", " \"If the Space is sleeping, the first request usually wakes it. \"\n", " \"Retry in ~30s, or skip this cell to run the notebook fully offline.\"\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What you just saw\n", "\n", "The gap between random (0% success, reward 0.25) and oracle (100% success, reward 1.17) defines the learning space. 
A trained agent lands somewhere between, and where it lands measures how well it learned to explore.\n", "\n", "- **Partial observability**: schema hidden until the agent earns it\n", "- **Dense rewards**: signal on every step, not just at the end\n", "- **Step budget**: forces strategic allocation of exploration\n", "- **676 questions** (473 train / 203 eval) across 10 Spider databases with difficulty labels\n", "\n", "**Next steps:**\n", "- Train a model: `notebooks/train_grpo.ipynb`\n", "- Read the design: `vision/VISION.md`\n", "- Try the live Space: [huggingface.co/spaces/hjerpe/sql_env](https://huggingface.co/spaces/hjerpe/sql_env)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 4 }