Upload folder using huggingface_hub
Browse files- .pytest_cache/.gitignore +2 -0
- .pytest_cache/CACHEDIR.TAG +4 -0
- .pytest_cache/README.md +8 -0
- .pytest_cache/v/cache/nodeids +14 -0
- Dockerfile +19 -0
- README.md +29 -10
- inference.py +225 -0
- openenv.yaml +104 -0
- pyproject.toml +31 -0
- requirements.txt +7 -0
- src/__init__.py +0 -0
- src/sql_arena/__init__.py +24 -0
- src/sql_arena/database.py +156 -0
- src/sql_arena/environment.py +200 -0
- src/sql_arena/graders.py +220 -0
- src/sql_arena/models.py +104 -0
- src/sql_arena/server.py +265 -0
- src/sql_arena/tasks.py +593 -0
- tests/__init__.py +0 -0
- tests/test_env.py +125 -0
.pytest_cache/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Created by pytest automatically.
|
| 2 |
+
*
|
.pytest_cache/CACHEDIR.TAG
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
| 2 |
+
# This file is a cache directory tag created by pytest.
|
| 3 |
+
# For information about cache directory tags, see:
|
| 4 |
+
# https://bford.info/cachedir/spec.html
|
.pytest_cache/README.md
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pytest cache directory #
|
| 2 |
+
|
| 3 |
+
This directory contains data from the pytest's cache plugin,
|
| 4 |
+
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
|
| 5 |
+
|
| 6 |
+
**Do not** commit this to version control.
|
| 7 |
+
|
| 8 |
+
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
|
.pytest_cache/v/cache/nodeids
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"tests/test_env.py::TestAllDifficulties::test_easy",
|
| 3 |
+
"tests/test_env.py::TestAllDifficulties::test_hard",
|
| 4 |
+
"tests/test_env.py::TestAllDifficulties::test_medium",
|
| 5 |
+
"tests/test_env.py::TestEnvironmentBasics::test_episode_terminates",
|
| 6 |
+
"tests/test_env.py::TestEnvironmentBasics::test_reset_returns_observation",
|
| 7 |
+
"tests/test_env.py::TestEnvironmentBasics::test_state_tracking",
|
| 8 |
+
"tests/test_env.py::TestEnvironmentBasics::test_step_with_correct_query",
|
| 9 |
+
"tests/test_env.py::TestEnvironmentBasics::test_step_with_invalid_query",
|
| 10 |
+
"tests/test_env.py::TestGrading::test_scores_in_range",
|
| 11 |
+
"tests/test_env.py::TestGrading::test_varying_scores",
|
| 12 |
+
"tests/test_env.py::TestTaskRegistry::test_list_tasks",
|
| 13 |
+
"tests/test_env.py::TestTaskRegistry::test_minimum_3_tasks"
|
| 14 |
+
]
|
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 6 |
+
build-essential \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
+
|
| 9 |
+
COPY requirements.txt .
|
| 10 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
+
|
| 12 |
+
COPY . .
|
| 13 |
+
|
| 14 |
+
EXPOSE 7860
|
| 15 |
+
|
| 16 |
+
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s \
|
| 17 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
|
| 18 |
+
|
| 19 |
+
CMD ["uvicorn", "src.sql_arena.server:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
README.md
CHANGED
|
@@ -1,10 +1,29 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
--
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SQL Arena - OpenEnv Environment
|
| 2 |
+
|
| 3 |
+
An interactive SQL query challenge environment where AI agents learn to write SQL
|
| 4 |
+
by iteratively querying databases and receiving execution feedback with partial credit scoring.
|
| 5 |
+
|
| 6 |
+
## Real-World Utility
|
| 7 |
+
|
| 8 |
+
Text-to-SQL is one of the most valuable capabilities for AI agents:
|
| 9 |
+
- Used by data analysts, business users, and developers daily
|
| 10 |
+
- Evaluates reasoning, schema understanding, and query composition
|
| 11 |
+
- Directly applicable to production AI assistants and copilots
|
| 12 |
+
- SQL Arena provides interactive iterative feedback (not just static benchmarks)
|
| 13 |
+
|
| 14 |
+
## Tasks
|
| 15 |
+
|
| 16 |
+
| Task | Difficulty | Description | Max Steps |
|
| 17 |
+
|------|-----------|-------------|-----------|
|
| 18 |
+
| basic_select | Easy | SELECT, WHERE, ORDER BY on single table | 5 |
|
| 19 |
+
| join_aggregate | Medium | Multi-table JOINs, GROUP BY, HAVING | 7 |
|
| 20 |
+
| complex_analysis | Hard | CTEs, window functions, subqueries | 10 |
|
| 21 |
+
|
| 22 |
+
Each difficulty has 3+ unique problems with deterministic, reproducible grading.
|
| 23 |
+
|
| 24 |
+
## Action Space
|
| 25 |
+
|
| 26 |
+
```json
|
| 27 |
+
{
|
| 28 |
+
"sql_query": "SELECT name, salary FROM employees WHERE salary > 80000 ORDER BY salary DESC"
|
| 29 |
+
}
|
inference.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference Script - SQL Arena OpenEnv Environment
|
| 3 |
+
Baseline agent that uses an LLM to solve SQL challenges.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import textwrap
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
|
| 11 |
+
from openai import OpenAI
|
| 12 |
+
|
| 13 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
from src.sql_arena.environment import SQLArenaEnvironment
|
| 15 |
+
from src.sql_arena.models import SQLArenaAction
|
| 16 |
+
|
| 17 |
+
# =====================================================
|
| 18 |
+
# Configuration
|
| 19 |
+
# =====================================================
|
| 20 |
+
|
| 21 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 22 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 23 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 24 |
+
|
| 25 |
+
BENCHMARK = "sql_arena"
|
| 26 |
+
TEMPERATURE = 0.3
|
| 27 |
+
MAX_TOKENS = 500
|
| 28 |
+
|
| 29 |
+
TASKS = [
|
| 30 |
+
{"difficulty": "basic_select", "task_id": "easy_001", "name": "basic_select", "max_steps": 5},
|
| 31 |
+
{"difficulty": "join_aggregate", "task_id": "medium_001", "name": "join_aggregate", "max_steps": 7},
|
| 32 |
+
{"difficulty": "complex_analysis", "task_id": "hard_001", "name": "complex_analysis", "max_steps": 10},
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# =====================================================
|
| 37 |
+
# Logging (MANDATORY format)
|
| 38 |
+
# =====================================================
|
| 39 |
+
|
| 40 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 41 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 45 |
+
error_val = error if error else "null"
|
| 46 |
+
done_val = str(done).lower()
|
| 47 |
+
action_short = action.replace('\n', ' ').strip()[:100]
|
| 48 |
+
print(
|
| 49 |
+
f"[STEP] step={step} action={action_short} reward={reward:.2f} done={done_val} error={error_val}",
|
| 50 |
+
flush=True,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 55 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 56 |
+
print(
|
| 57 |
+
f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
|
| 58 |
+
flush=True,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# =====================================================
|
| 63 |
+
# LLM Agent
|
| 64 |
+
# =====================================================
|
| 65 |
+
|
| 66 |
+
SYSTEM_PROMPT = textwrap.dedent("""
|
| 67 |
+
You are an expert SQL query writer. You are interacting with a SQL challenge environment.
|
| 68 |
+
|
| 69 |
+
Each turn you receive: database schema, a question, previous query results, and feedback.
|
| 70 |
+
Your goal: Write a SQL query that correctly answers the question.
|
| 71 |
+
|
| 72 |
+
Rules:
|
| 73 |
+
- Output ONLY the SQL query, nothing else
|
| 74 |
+
- No explanations, no markdown, no code fences
|
| 75 |
+
- Use standard SQLite syntax
|
| 76 |
+
- Be precise with column names and table names
|
| 77 |
+
- If your previous query had errors, fix them based on the feedback
|
| 78 |
+
""").strip()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def build_user_prompt(observation: dict, step: int, history: List[str]) -> str:
|
| 82 |
+
parts = []
|
| 83 |
+
parts.append(f"=== SQL Challenge (Step {step}) ===")
|
| 84 |
+
parts.append(f"\nDifficulty: {observation.get('difficulty', 'unknown')}")
|
| 85 |
+
parts.append(f"\n--- Database Schema ---\n{observation.get('schema_description', '')}")
|
| 86 |
+
parts.append(f"\n--- Question ---\n{observation.get('question', '')}")
|
| 87 |
+
|
| 88 |
+
if observation.get('expected_columns'):
|
| 89 |
+
parts.append(f"\n--- Expected Columns ---\n{observation['expected_columns']}")
|
| 90 |
+
if observation.get('query_result'):
|
| 91 |
+
parts.append(f"\n--- Previous Query Result ---\n{observation['query_result']}")
|
| 92 |
+
if observation.get('error_message'):
|
| 93 |
+
parts.append(f"\n--- Error ---\n{observation['error_message']}")
|
| 94 |
+
if observation.get('feedback'):
|
| 95 |
+
parts.append(f"\n--- Feedback ---\n{observation['feedback']}")
|
| 96 |
+
|
| 97 |
+
parts.append(f"\nAttempts remaining: {observation.get('attempts_remaining', 0)}")
|
| 98 |
+
|
| 99 |
+
if history:
|
| 100 |
+
parts.append("\n--- Previous Attempts ---")
|
| 101 |
+
for h in history[-3:]:
|
| 102 |
+
parts.append(h)
|
| 103 |
+
|
| 104 |
+
parts.append("\nWrite your SQL query now:")
|
| 105 |
+
return "\n".join(parts)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_sql_from_llm(client: OpenAI, observation: dict, step: int, history: List[str]) -> str:
|
| 109 |
+
user_prompt = build_user_prompt(observation, step, history)
|
| 110 |
+
try:
|
| 111 |
+
completion = client.chat.completions.create(
|
| 112 |
+
model=MODEL_NAME,
|
| 113 |
+
messages=[
|
| 114 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 115 |
+
{"role": "user", "content": user_prompt},
|
| 116 |
+
],
|
| 117 |
+
temperature=TEMPERATURE,
|
| 118 |
+
max_tokens=MAX_TOKENS,
|
| 119 |
+
stream=False,
|
| 120 |
+
)
|
| 121 |
+
raw = (completion.choices[0].message.content or "").strip()
|
| 122 |
+
sql = raw
|
| 123 |
+
if sql.startswith("```sql"):
|
| 124 |
+
sql = sql[6:]
|
| 125 |
+
if sql.startswith("```"):
|
| 126 |
+
sql = sql[3:]
|
| 127 |
+
if sql.endswith("```"):
|
| 128 |
+
sql = sql[:-3]
|
| 129 |
+
sql = sql.strip()
|
| 130 |
+
return sql if sql else "SELECT 1"
|
| 131 |
+
except Exception as exc:
|
| 132 |
+
print(f"[DEBUG] LLM request failed: {exc}", flush=True)
|
| 133 |
+
return "SELECT 1"
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# =====================================================
|
| 137 |
+
# Main Inference Loop
|
| 138 |
+
# =====================================================
|
| 139 |
+
|
| 140 |
+
def run_task(client: OpenAI, env: SQLArenaEnvironment, task_config: dict) -> float:
|
| 141 |
+
difficulty = task_config["difficulty"]
|
| 142 |
+
task_id = task_config["task_id"]
|
| 143 |
+
task_name = task_config["name"]
|
| 144 |
+
max_steps = task_config["max_steps"]
|
| 145 |
+
|
| 146 |
+
history: List[str] = []
|
| 147 |
+
rewards: List[float] = []
|
| 148 |
+
steps_taken = 0
|
| 149 |
+
best_score = 0.0
|
| 150 |
+
|
| 151 |
+
log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
|
| 152 |
+
|
| 153 |
+
try:
|
| 154 |
+
result = env.reset(difficulty=difficulty, task_id=task_id)
|
| 155 |
+
obs_dict = result.observation.model_dump()
|
| 156 |
+
|
| 157 |
+
for step in range(1, max_steps + 1):
|
| 158 |
+
if result.done:
|
| 159 |
+
break
|
| 160 |
+
|
| 161 |
+
sql_query = get_sql_from_llm(client, obs_dict, step, history)
|
| 162 |
+
|
| 163 |
+
action = SQLArenaAction(sql_query=sql_query)
|
| 164 |
+
result = env.step(action)
|
| 165 |
+
|
| 166 |
+
obs_dict = result.observation.model_dump()
|
| 167 |
+
reward = result.reward
|
| 168 |
+
done = result.done
|
| 169 |
+
error = obs_dict.get("error_message")
|
| 170 |
+
|
| 171 |
+
rewards.append(reward)
|
| 172 |
+
steps_taken = step
|
| 173 |
+
best_score = max(best_score, result.info.get("score", 0.0))
|
| 174 |
+
|
| 175 |
+
log_step(step=step, action=sql_query, reward=reward, done=done, error=error)
|
| 176 |
+
|
| 177 |
+
history.append(
|
| 178 |
+
f"Step {step}: {sql_query[:80]}... -> reward={reward:.2f}"
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if done:
|
| 182 |
+
break
|
| 183 |
+
|
| 184 |
+
final_score = min(max(best_score, 0.0), 1.0)
|
| 185 |
+
success = final_score >= 0.5
|
| 186 |
+
|
| 187 |
+
except Exception as e:
|
| 188 |
+
print(f"[DEBUG] Task {task_name} error: {e}", flush=True)
|
| 189 |
+
final_score = 0.0
|
| 190 |
+
success = False
|
| 191 |
+
|
| 192 |
+
finally:
|
| 193 |
+
log_end(success=success, steps=steps_taken, score=final_score, rewards=rewards)
|
| 194 |
+
|
| 195 |
+
return final_score
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def main() -> None:
|
| 199 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 200 |
+
env = SQLArenaEnvironment()
|
| 201 |
+
|
| 202 |
+
all_scores = []
|
| 203 |
+
|
| 204 |
+
for task_config in TASKS:
|
| 205 |
+
print(f"\n{'='*60}", flush=True)
|
| 206 |
+
print(f"Running task: {task_config['name']} ({task_config['difficulty']})", flush=True)
|
| 207 |
+
print(f"{'='*60}", flush=True)
|
| 208 |
+
|
| 209 |
+
score = run_task(client, env, task_config)
|
| 210 |
+
all_scores.append(score)
|
| 211 |
+
print(f"\nTask {task_config['name']} final score: {score:.2f}\n", flush=True)
|
| 212 |
+
|
| 213 |
+
avg_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
|
| 214 |
+
print(f"\n{'='*60}", flush=True)
|
| 215 |
+
print("SUMMARY", flush=True)
|
| 216 |
+
print(f"{'='*60}", flush=True)
|
| 217 |
+
for tc, sc in zip(TASKS, all_scores):
|
| 218 |
+
print(f" {tc['name']:20s}: {sc:.2f}", flush=True)
|
| 219 |
+
print(f" {'Average':20s}: {avg_score:.2f}", flush=True)
|
| 220 |
+
|
| 221 |
+
env.close()
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
if __name__ == "__main__":
|
| 225 |
+
main()
|
openenv.yaml
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sql_arena
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
Interactive SQL query challenge environment where AI agents learn to write SQL
|
| 5 |
+
by iteratively querying databases and receiving execution feedback with partial credit.
|
| 6 |
+
|
| 7 |
+
author: "Vudumula Naga Sai Rahul"
|
| 8 |
+
license: "MIT"
|
| 9 |
+
|
| 10 |
+
interface:
|
| 11 |
+
action:
|
| 12 |
+
type: object
|
| 13 |
+
model: sql_arena.models.SQLArenaAction
|
| 14 |
+
properties:
|
| 15 |
+
sql_query:
|
| 16 |
+
type: string
|
| 17 |
+
description: "SQL query to execute against the database"
|
| 18 |
+
|
| 19 |
+
observation:
|
| 20 |
+
type: object
|
| 21 |
+
model: sql_arena.models.SQLArenaObservation
|
| 22 |
+
properties:
|
| 23 |
+
schema_description:
|
| 24 |
+
type: string
|
| 25 |
+
question:
|
| 26 |
+
type: string
|
| 27 |
+
query_result:
|
| 28 |
+
type: string
|
| 29 |
+
nullable: true
|
| 30 |
+
error_message:
|
| 31 |
+
type: string
|
| 32 |
+
nullable: true
|
| 33 |
+
feedback:
|
| 34 |
+
type: string
|
| 35 |
+
nullable: true
|
| 36 |
+
expected_columns:
|
| 37 |
+
type: array
|
| 38 |
+
nullable: true
|
| 39 |
+
attempts_remaining:
|
| 40 |
+
type: integer
|
| 41 |
+
difficulty:
|
| 42 |
+
type: string
|
| 43 |
+
task_id:
|
| 44 |
+
type: string
|
| 45 |
+
|
| 46 |
+
state:
|
| 47 |
+
type: object
|
| 48 |
+
model: sql_arena.models.SQLArenaState
|
| 49 |
+
|
| 50 |
+
tasks:
|
| 51 |
+
- id: basic_select
|
| 52 |
+
name: "Basic SELECT Queries"
|
| 53 |
+
description: "Simple SELECT, WHERE, ORDER BY queries"
|
| 54 |
+
difficulty: easy
|
| 55 |
+
max_steps: 5
|
| 56 |
+
subtasks:
|
| 57 |
+
- easy_001
|
| 58 |
+
- easy_002
|
| 59 |
+
- easy_003
|
| 60 |
+
|
| 61 |
+
- id: join_aggregate
|
| 62 |
+
name: "JOIN and Aggregate Queries"
|
| 63 |
+
description: "Multi-table JOINs with GROUP BY, HAVING"
|
| 64 |
+
difficulty: medium
|
| 65 |
+
max_steps: 7
|
| 66 |
+
subtasks:
|
| 67 |
+
- medium_001
|
| 68 |
+
- medium_002
|
| 69 |
+
- medium_003
|
| 70 |
+
|
| 71 |
+
- id: complex_analysis
|
| 72 |
+
name: "Complex Analysis Queries"
|
| 73 |
+
description: "CTEs, window functions, subqueries"
|
| 74 |
+
difficulty: hard
|
| 75 |
+
max_steps: 10
|
| 76 |
+
subtasks:
|
| 77 |
+
- hard_001
|
| 78 |
+
- hard_002
|
| 79 |
+
- hard_003
|
| 80 |
+
|
| 81 |
+
grading:
|
| 82 |
+
score_range: [0.0, 1.0]
|
| 83 |
+
components:
|
| 84 |
+
- name: execution
|
| 85 |
+
weight: 0.10
|
| 86 |
+
description: "Query executes without errors"
|
| 87 |
+
- name: columns
|
| 88 |
+
weight: 0.20
|
| 89 |
+
description: "Correct column names"
|
| 90 |
+
- name: row_count
|
| 91 |
+
weight: 0.20
|
| 92 |
+
description: "Correct number of rows"
|
| 93 |
+
- name: values
|
| 94 |
+
weight: 0.50
|
| 95 |
+
description: "Correct data values"
|
| 96 |
+
|
| 97 |
+
server:
|
| 98 |
+
framework: fastapi
|
| 99 |
+
entrypoint: src.sql_arena.server:app
|
| 100 |
+
port: 7860
|
| 101 |
+
|
| 102 |
+
deployment:
|
| 103 |
+
platform: huggingface-spaces
|
| 104 |
+
docker: true
|
pyproject.toml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "sql-arena"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Interactive SQL query challenge OpenEnv environment"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
license = {text = "MIT"}
|
| 11 |
+
requires-python = ">=3.10"
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "Vudumula Naga Sai Rahul", email = "nagasairahulvudumula@gmail.com"}
|
| 14 |
+
]
|
| 15 |
+
dependencies = [
|
| 16 |
+
"fastapi>=0.104.0",
|
| 17 |
+
"uvicorn[standard]>=0.24.0",
|
| 18 |
+
"pydantic>=2.5.0",
|
| 19 |
+
"websockets>=12.0",
|
| 20 |
+
"openai>=1.0.0",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
[project.optional-dependencies]
|
| 24 |
+
dev = [
|
| 25 |
+
"pytest>=7.0",
|
| 26 |
+
"httpx>=0.25.0",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
[tool.setuptools.packages.find]
|
| 30 |
+
where = ["."]
|
| 31 |
+
include = ["src*"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.104.0
|
| 2 |
+
uvicorn[standard]>=0.24.0
|
| 3 |
+
pydantic>=2.5.0
|
| 4 |
+
websockets>=12.0
|
| 5 |
+
openai>=1.0.0
|
| 6 |
+
pytest>=7.0
|
| 7 |
+
httpx>=0.25.0
|
src/__init__.py
ADDED
|
File without changes
|
src/sql_arena/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQL Arena - Interactive SQL Query Challenge Environment for OpenEnv.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .models import SQLArenaAction, SQLArenaObservation, SQLArenaState
|
| 6 |
+
from .environment import SQLArenaEnvironment, StepResult
|
| 7 |
+
from .tasks import get_task, list_tasks, SQLTask, ALL_TASKS, TASK_BY_ID
|
| 8 |
+
from .graders import grade_result
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"SQLArenaAction",
|
| 12 |
+
"SQLArenaObservation",
|
| 13 |
+
"SQLArenaState",
|
| 14 |
+
"SQLArenaEnvironment",
|
| 15 |
+
"StepResult",
|
| 16 |
+
"get_task",
|
| 17 |
+
"list_tasks",
|
| 18 |
+
"SQLTask",
|
| 19 |
+
"ALL_TASKS",
|
| 20 |
+
"TASK_BY_ID",
|
| 21 |
+
"grade_result",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
__version__ = "1.0.0"
|
src/sql_arena/database.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQLite Database Manager for SQL Arena.
|
| 3 |
+
|
| 4 |
+
Creates in-memory SQLite databases for each task.
|
| 5 |
+
Executes agent queries safely and formats results.
|
| 6 |
+
|
| 7 |
+
Key design decisions:
|
| 8 |
+
- In-memory databases (fast, no disk I/O, no cleanup needed)
|
| 9 |
+
- Each reset() creates a fresh database
|
| 10 |
+
- Query execution is sandboxed (read-only would be ideal but SQLite
|
| 11 |
+
in-memory is ephemeral anyway)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sqlite3
|
| 15 |
+
from typing import Tuple, Optional, Any, Dict
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DatabaseManager:
|
| 19 |
+
"""
|
| 20 |
+
Manages SQLite in-memory databases for SQL challenges.
|
| 21 |
+
|
| 22 |
+
Each task gets its own fresh database with schema and sample data.
|
| 23 |
+
The agent's queries are executed against this database.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self):
|
| 27 |
+
self.conn: Optional[sqlite3.Connection] = None
|
| 28 |
+
|
| 29 |
+
def create_database(self, setup_sql: str) -> None:
|
| 30 |
+
"""
|
| 31 |
+
Create a new in-memory database with the given schema and data.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
setup_sql: SQL string containing CREATE TABLE and INSERT statements
|
| 35 |
+
"""
|
| 36 |
+
# Close any existing connection
|
| 37 |
+
self.close()
|
| 38 |
+
|
| 39 |
+
# Create fresh in-memory database
|
| 40 |
+
self.conn = sqlite3.connect(":memory:")
|
| 41 |
+
|
| 42 |
+
# Enable foreign keys
|
| 43 |
+
self.conn.execute("PRAGMA foreign_keys = ON")
|
| 44 |
+
|
| 45 |
+
# Run the setup SQL (creates tables and inserts data)
|
| 46 |
+
self.conn.executescript(setup_sql)
|
| 47 |
+
self.conn.commit()
|
| 48 |
+
|
| 49 |
+
def execute_query(self, sql: str) -> Tuple[bool, Optional[Dict], Optional[str]]:
|
| 50 |
+
"""
|
| 51 |
+
Execute a SQL query and return results.
|
| 52 |
+
|
| 53 |
+
This is the main method called when the agent submits a query.
|
| 54 |
+
It catches all exceptions to prevent crashes.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
sql: The SQL query string to execute
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Tuple of (success, result_dict, error_message):
|
| 61 |
+
- success: True if query executed without error
|
| 62 |
+
- result_dict: {"columns": [...], "rows": [...]} if successful
|
| 63 |
+
- error_message: Error string if failed, None if success
|
| 64 |
+
"""
|
| 65 |
+
if not self.conn:
|
| 66 |
+
return False, None, "No database connection. Call create_database() first."
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
cursor = self.conn.execute(sql)
|
| 70 |
+
|
| 71 |
+
# Get column names from cursor description
|
| 72 |
+
if cursor.description:
|
| 73 |
+
columns = [desc[0] for desc in cursor.description]
|
| 74 |
+
else:
|
| 75 |
+
columns = []
|
| 76 |
+
|
| 77 |
+
# Fetch all rows
|
| 78 |
+
rows = cursor.fetchall()
|
| 79 |
+
|
| 80 |
+
result = {
|
| 81 |
+
"columns": columns,
|
| 82 |
+
"rows": rows,
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
return True, result, None
|
| 86 |
+
|
| 87 |
+
except sqlite3.Error as e:
|
| 88 |
+
return False, None, f"SQL Error: {str(e)}"
|
| 89 |
+
except Exception as e:
|
| 90 |
+
return False, None, f"Execution Error: {str(e)}"
|
| 91 |
+
|
| 92 |
+
def format_result(self, result: Dict, max_rows: int = 20) -> str:
|
| 93 |
+
"""
|
| 94 |
+
Format query result as a human-readable table string.
|
| 95 |
+
|
| 96 |
+
This formatted string is shown to the agent in the observation
|
| 97 |
+
so it can see what its query returned.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
result: Dict with "columns" and "rows" keys
|
| 101 |
+
max_rows: Maximum number of rows to display
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
Formatted table string
|
| 105 |
+
"""
|
| 106 |
+
if not result or not result.get("columns"):
|
| 107 |
+
return "(empty result set)"
|
| 108 |
+
|
| 109 |
+
columns = result["columns"]
|
| 110 |
+
rows = result["rows"]
|
| 111 |
+
|
| 112 |
+
if not rows:
|
| 113 |
+
return f"Columns: {', '.join(columns)}\n(0 rows returned)"
|
| 114 |
+
|
| 115 |
+
# Calculate column widths (at least as wide as header)
|
| 116 |
+
col_widths = [len(str(c)) for c in columns]
|
| 117 |
+
for row in rows[:max_rows]:
|
| 118 |
+
for i, val in enumerate(row):
|
| 119 |
+
if i < len(col_widths):
|
| 120 |
+
col_widths[i] = max(col_widths[i], len(str(val)))
|
| 121 |
+
|
| 122 |
+
# Build formatted table
|
| 123 |
+
# Header
|
| 124 |
+
header = " | ".join(
|
| 125 |
+
str(c).ljust(w) for c, w in zip(columns, col_widths)
|
| 126 |
+
)
|
| 127 |
+
separator = "-+-".join("-" * w for w in col_widths)
|
| 128 |
+
|
| 129 |
+
# Data rows
|
| 130 |
+
formatted_rows = []
|
| 131 |
+
for row in rows[:max_rows]:
|
| 132 |
+
formatted_row = " | ".join(
|
| 133 |
+
str(v).ljust(w) for v, w in zip(row, col_widths)
|
| 134 |
+
)
|
| 135 |
+
formatted_rows.append(formatted_row)
|
| 136 |
+
|
| 137 |
+
# Assemble
|
| 138 |
+
table_str = f"{header}\n{separator}\n" + "\n".join(formatted_rows)
|
| 139 |
+
|
| 140 |
+
# Truncation notice
|
| 141 |
+
if len(rows) > max_rows:
|
| 142 |
+
table_str += f"\n... ({len(rows) - max_rows} more rows not shown)"
|
| 143 |
+
|
| 144 |
+
# Row count
|
| 145 |
+
table_str += f"\n\n({len(rows)} row{'s' if len(rows) != 1 else ''} returned)"
|
| 146 |
+
|
| 147 |
+
return table_str
|
| 148 |
+
|
| 149 |
+
def close(self) -> None:
|
| 150 |
+
"""Close the database connection and free resources."""
|
| 151 |
+
if self.conn:
|
| 152 |
+
try:
|
| 153 |
+
self.conn.close()
|
| 154 |
+
except Exception:
|
| 155 |
+
pass # Ignore errors on close
|
| 156 |
+
self.conn = None
|
src/sql_arena/environment.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core SQL Arena Environment.
|
| 3 |
+
Implements the OpenEnv step()/reset()/state() interface.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Optional, Dict, Any, List
|
| 7 |
+
from .models import SQLArenaAction, SQLArenaObservation, SQLArenaState
|
| 8 |
+
from .database import DatabaseManager
|
| 9 |
+
from .tasks import SQLTask, get_task, list_tasks, TASK_BY_ID
|
| 10 |
+
from .graders import grade_result, generate_hint
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class StepResult:
|
| 14 |
+
"""Result of a single environment step."""
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
observation: SQLArenaObservation,
|
| 19 |
+
reward: float,
|
| 20 |
+
done: bool,
|
| 21 |
+
info: Optional[Dict[str, Any]] = None,
|
| 22 |
+
):
|
| 23 |
+
self.observation = observation
|
| 24 |
+
self.reward = reward
|
| 25 |
+
self.done = done
|
| 26 |
+
self.info = info or {}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SQLArenaEnvironment:
|
| 30 |
+
"""
|
| 31 |
+
SQL Arena: An interactive SQL query challenge environment.
|
| 32 |
+
|
| 33 |
+
The agent receives a database schema and a natural language question,
|
| 34 |
+
then iteratively writes SQL queries. The environment provides
|
| 35 |
+
execution results, feedback, and partial credit scoring.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self):
|
| 39 |
+
self.db = DatabaseManager()
|
| 40 |
+
self.current_task: Optional[SQLTask] = None
|
| 41 |
+
self._state: Optional[SQLArenaState] = None
|
| 42 |
+
self._last_observation: Optional[SQLArenaObservation] = None
|
| 43 |
+
|
| 44 |
+
def reset(
|
| 45 |
+
self,
|
| 46 |
+
difficulty: str = "basic_select",
|
| 47 |
+
task_id: Optional[str] = None,
|
| 48 |
+
) -> StepResult:
|
| 49 |
+
"""
|
| 50 |
+
Reset the environment with a new task.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
difficulty: 'basic_select', 'join_aggregate', or 'complex_analysis'
|
| 54 |
+
task_id: Optional specific task ID
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
StepResult with initial observation
|
| 58 |
+
"""
|
| 59 |
+
# Get the task
|
| 60 |
+
self.current_task = get_task(difficulty, task_id)
|
| 61 |
+
task = self.current_task
|
| 62 |
+
|
| 63 |
+
# Setup database
|
| 64 |
+
self.db.create_database(task.setup_sql)
|
| 65 |
+
|
| 66 |
+
# Initialize state
|
| 67 |
+
self._state = SQLArenaState(
|
| 68 |
+
task_id=task.task_id,
|
| 69 |
+
difficulty=task.difficulty,
|
| 70 |
+
current_step=0,
|
| 71 |
+
max_steps=task.max_steps,
|
| 72 |
+
best_score=0.0,
|
| 73 |
+
total_reward=0.0,
|
| 74 |
+
rewards_history=[],
|
| 75 |
+
done=False,
|
| 76 |
+
last_action_error=None,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Create initial observation
|
| 80 |
+
self._last_observation = SQLArenaObservation(
|
| 81 |
+
schema_description=task.schema_description,
|
| 82 |
+
question=task.question,
|
| 83 |
+
query_result=None,
|
| 84 |
+
error_message=None,
|
| 85 |
+
feedback="Welcome to SQL Arena! Write a SQL query to answer the question above.",
|
| 86 |
+
expected_columns=task.expected_columns,
|
| 87 |
+
attempts_remaining=task.max_steps,
|
| 88 |
+
difficulty=task.difficulty,
|
| 89 |
+
task_id=task.task_id,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
return StepResult(
|
| 93 |
+
observation=self._last_observation,
|
| 94 |
+
reward=0.0,
|
| 95 |
+
done=False,
|
| 96 |
+
info={"task_title": task.title},
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def step(self, action: SQLArenaAction) -> StepResult:
|
| 100 |
+
"""
|
| 101 |
+
Execute the agent's SQL query and return feedback.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
action: SQLArenaAction containing the SQL query
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
StepResult with observation, reward, and done flag
|
| 108 |
+
"""
|
| 109 |
+
if self._state is None or self.current_task is None:
|
| 110 |
+
raise RuntimeError("Environment not initialized. Call reset() first.")
|
| 111 |
+
|
| 112 |
+
if self._state.done:
|
| 113 |
+
raise RuntimeError("Episode is done. Call reset() to start a new episode.")
|
| 114 |
+
|
| 115 |
+
task = self.current_task
|
| 116 |
+
state = self._state
|
| 117 |
+
|
| 118 |
+
# Increment step counter
|
| 119 |
+
state.current_step += 1
|
| 120 |
+
|
| 121 |
+
# Execute the query
|
| 122 |
+
success, result, error = self.db.execute_query(action.sql_query)
|
| 123 |
+
|
| 124 |
+
# Grade the result
|
| 125 |
+
score, feedback = grade_result(task, success, result, error)
|
| 126 |
+
|
| 127 |
+
# Track best score
|
| 128 |
+
state.best_score = max(state.best_score, score)
|
| 129 |
+
|
| 130 |
+
# Calculate step reward
|
| 131 |
+
if len(state.rewards_history) == 0:
|
| 132 |
+
reward = score
|
| 133 |
+
else:
|
| 134 |
+
prev_best = max(state.rewards_history) if state.rewards_history else 0.0
|
| 135 |
+
improvement = max(0, score - prev_best)
|
| 136 |
+
reward = score * 0.5 + improvement * 0.5
|
| 137 |
+
|
| 138 |
+
reward = round(min(max(reward, 0.0), 1.0), 4)
|
| 139 |
+
state.rewards_history.append(reward)
|
| 140 |
+
state.total_reward += reward
|
| 141 |
+
|
| 142 |
+
# Add progressive hints
|
| 143 |
+
hint = generate_hint(task, state.current_step, score)
|
| 144 |
+
if hint and score < 1.0:
|
| 145 |
+
feedback += f"\n\n{hint}"
|
| 146 |
+
|
| 147 |
+
# Check if done
|
| 148 |
+
attempts_remaining = task.max_steps - state.current_step
|
| 149 |
+
is_perfect = score >= 1.0
|
| 150 |
+
is_out_of_steps = attempts_remaining <= 0
|
| 151 |
+
|
| 152 |
+
state.done = is_perfect or is_out_of_steps
|
| 153 |
+
state.last_action_error = error
|
| 154 |
+
|
| 155 |
+
# Format query result for observation
|
| 156 |
+
query_result_str = None
|
| 157 |
+
if success and result:
|
| 158 |
+
query_result_str = self.db.format_result(result)
|
| 159 |
+
|
| 160 |
+
# Build observation
|
| 161 |
+
self._last_observation = SQLArenaObservation(
|
| 162 |
+
schema_description=task.schema_description,
|
| 163 |
+
question=task.question,
|
| 164 |
+
query_result=query_result_str,
|
| 165 |
+
error_message=error,
|
| 166 |
+
feedback=feedback,
|
| 167 |
+
expected_columns=task.expected_columns,
|
| 168 |
+
attempts_remaining=attempts_remaining,
|
| 169 |
+
difficulty=task.difficulty,
|
| 170 |
+
task_id=task.task_id,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return StepResult(
|
| 174 |
+
observation=self._last_observation,
|
| 175 |
+
reward=reward,
|
| 176 |
+
done=state.done,
|
| 177 |
+
info={
|
| 178 |
+
"score": score,
|
| 179 |
+
"best_score": state.best_score,
|
| 180 |
+
"step": state.current_step,
|
| 181 |
+
"is_perfect": is_perfect,
|
| 182 |
+
},
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
def state(self) -> SQLArenaState:
|
| 186 |
+
"""Return the current environment state."""
|
| 187 |
+
if self._state is None:
|
| 188 |
+
raise RuntimeError("Environment not initialized. Call reset() first.")
|
| 189 |
+
return self._state
|
| 190 |
+
|
| 191 |
+
def close(self) -> None:
|
| 192 |
+
"""Clean up resources."""
|
| 193 |
+
self.db.close()
|
| 194 |
+
self.current_task = None
|
| 195 |
+
self._state = None
|
| 196 |
+
self._last_observation = None
|
| 197 |
+
|
| 198 |
+
def get_available_tasks(self) -> Dict:
|
| 199 |
+
"""Return all available tasks."""
|
| 200 |
+
return list_tasks()
|
src/sql_arena/graders.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grading logic for SQL Arena.
|
| 3 |
+
Provides partial credit scoring (0.0 to 1.0) based on:
|
| 4 |
+
- Query execution success (0.10)
|
| 5 |
+
- Column correctness (0.20)
|
| 6 |
+
- Row count correctness (0.20)
|
| 7 |
+
- Value correctness (0.50)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import List, Tuple, Optional, Dict, Any
|
| 11 |
+
from .tasks import SQLTask
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def normalize_value(val: Any) -> Any:
|
| 15 |
+
"""Normalize values for comparison."""
|
| 16 |
+
if val is None:
|
| 17 |
+
return None
|
| 18 |
+
if isinstance(val, float):
|
| 19 |
+
return round(val, 2)
|
| 20 |
+
if isinstance(val, str):
|
| 21 |
+
return val.strip().lower()
|
| 22 |
+
return val
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def normalize_row(row: tuple) -> tuple:
|
| 26 |
+
"""Normalize all values in a row."""
|
| 27 |
+
return tuple(normalize_value(v) for v in row)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def grade_result(
|
| 31 |
+
task: SQLTask,
|
| 32 |
+
success: bool,
|
| 33 |
+
result: Optional[Dict],
|
| 34 |
+
error: Optional[str],
|
| 35 |
+
) -> Tuple[float, str]:
|
| 36 |
+
"""
|
| 37 |
+
Grade a SQL query result against expected output.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
(score, feedback) where score is in [0.0, 1.0]
|
| 41 |
+
|
| 42 |
+
Scoring breakdown:
|
| 43 |
+
- 0.10: Query executes without error
|
| 44 |
+
- 0.20: Correct column names
|
| 45 |
+
- 0.20: Correct number of rows
|
| 46 |
+
- 0.50: Correct values (proportional to matching rows)
|
| 47 |
+
"""
|
| 48 |
+
feedback_parts = []
|
| 49 |
+
score = 0.0
|
| 50 |
+
|
| 51 |
+
# ---- Component 1: Execution Success (0.10) ----
|
| 52 |
+
if not success:
|
| 53 |
+
feedback_parts.append(f"X Query failed: {error}")
|
| 54 |
+
feedback_parts.append("Hint: Fix the syntax error and try again.")
|
| 55 |
+
return 0.0, "\n".join(feedback_parts)
|
| 56 |
+
|
| 57 |
+
score += 0.10
|
| 58 |
+
feedback_parts.append("OK: Query executed successfully (+0.10)")
|
| 59 |
+
|
| 60 |
+
# ---- Component 2: Column Correctness (0.20) ----
|
| 61 |
+
actual_columns = [c.lower().strip() for c in result.get("columns", [])]
|
| 62 |
+
expected_columns = [c.lower().strip() for c in task.expected_columns]
|
| 63 |
+
|
| 64 |
+
if actual_columns == expected_columns:
|
| 65 |
+
score += 0.20
|
| 66 |
+
feedback_parts.append(f"OK: Correct columns: {actual_columns} (+0.20)")
|
| 67 |
+
else:
|
| 68 |
+
# Partial credit for overlapping columns
|
| 69 |
+
matching_cols = set(actual_columns) & set(expected_columns)
|
| 70 |
+
if matching_cols:
|
| 71 |
+
partial = 0.20 * (len(matching_cols) / len(expected_columns))
|
| 72 |
+
score += partial
|
| 73 |
+
feedback_parts.append(
|
| 74 |
+
f"PARTIAL: Column match: got {actual_columns}, "
|
| 75 |
+
f"expected {expected_columns} (+{partial:.2f})"
|
| 76 |
+
)
|
| 77 |
+
missing = set(expected_columns) - set(actual_columns)
|
| 78 |
+
if missing:
|
| 79 |
+
feedback_parts.append(f"Hint: Missing columns: {missing}")
|
| 80 |
+
else:
|
| 81 |
+
feedback_parts.append(
|
| 82 |
+
f"WRONG: Columns: got {actual_columns}, expected {expected_columns}"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# ---- Component 3: Row Count (0.20) ----
|
| 86 |
+
actual_rows = result.get("rows", [])
|
| 87 |
+
expected_row_count = task.expected_row_count
|
| 88 |
+
|
| 89 |
+
if len(actual_rows) == expected_row_count:
|
| 90 |
+
score += 0.20
|
| 91 |
+
feedback_parts.append(f"OK: Correct row count: {len(actual_rows)} (+0.20)")
|
| 92 |
+
else:
|
| 93 |
+
# Partial credit: closer counts get more credit
|
| 94 |
+
if expected_row_count > 0:
|
| 95 |
+
ratio = 1.0 - abs(len(actual_rows) - expected_row_count) / max(
|
| 96 |
+
expected_row_count, len(actual_rows)
|
| 97 |
+
)
|
| 98 |
+
partial = max(0.0, 0.20 * ratio)
|
| 99 |
+
score += partial
|
| 100 |
+
feedback_parts.append(
|
| 101 |
+
f"PARTIAL: Row count: got {len(actual_rows)}, "
|
| 102 |
+
f"expected {expected_row_count} (+{partial:.2f})"
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
if len(actual_rows) == 0:
|
| 106 |
+
score += 0.20
|
| 107 |
+
feedback_parts.append("OK: Correct empty result set (+0.20)")
|
| 108 |
+
else:
|
| 109 |
+
feedback_parts.append(
|
| 110 |
+
f"WRONG: Expected empty result, got {len(actual_rows)} rows"
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# ---- Component 4: Value Correctness (0.50) ----
|
| 114 |
+
if task.expected_rows:
|
| 115 |
+
normalized_expected = [normalize_row(r) for r in task.expected_rows]
|
| 116 |
+
normalized_actual = [normalize_row(r) for r in actual_rows]
|
| 117 |
+
|
| 118 |
+
# Try exact order match first
|
| 119 |
+
exact_matches = 0
|
| 120 |
+
for exp_row, act_row in zip(normalized_expected, normalized_actual):
|
| 121 |
+
if exp_row == act_row:
|
| 122 |
+
exact_matches += 1
|
| 123 |
+
|
| 124 |
+
if (
|
| 125 |
+
exact_matches == len(normalized_expected)
|
| 126 |
+
and len(normalized_actual) == len(normalized_expected)
|
| 127 |
+
):
|
| 128 |
+
score += 0.50
|
| 129 |
+
feedback_parts.append("OK: All values correct with correct ordering (+0.50)")
|
| 130 |
+
else:
|
| 131 |
+
# Try unordered match (set-based)
|
| 132 |
+
matched_rows = 0
|
| 133 |
+
remaining_actual = list(normalized_actual)
|
| 134 |
+
|
| 135 |
+
for exp_row in normalized_expected:
|
| 136 |
+
for i, act_row in enumerate(remaining_actual):
|
| 137 |
+
if exp_row == act_row:
|
| 138 |
+
matched_rows += 1
|
| 139 |
+
remaining_actual.pop(i)
|
| 140 |
+
break
|
| 141 |
+
|
| 142 |
+
if (
|
| 143 |
+
matched_rows == len(normalized_expected)
|
| 144 |
+
and len(normalized_actual) == len(normalized_expected)
|
| 145 |
+
):
|
| 146 |
+
# All rows match but wrong order
|
| 147 |
+
partial = 0.40
|
| 148 |
+
score += partial
|
| 149 |
+
feedback_parts.append(
|
| 150 |
+
f"PARTIAL: All values correct but wrong ordering (+{partial:.2f})"
|
| 151 |
+
)
|
| 152 |
+
feedback_parts.append("Hint: Check your ORDER BY clause")
|
| 153 |
+
elif matched_rows > 0:
|
| 154 |
+
# Some rows match
|
| 155 |
+
partial = 0.50 * (matched_rows / len(normalized_expected))
|
| 156 |
+
score += partial
|
| 157 |
+
feedback_parts.append(
|
| 158 |
+
f"PARTIAL: {matched_rows}/{len(normalized_expected)} rows match (+{partial:.2f})"
|
| 159 |
+
)
|
| 160 |
+
if matched_rows < len(normalized_expected):
|
| 161 |
+
feedback_parts.append(
|
| 162 |
+
"Hint: Some values are incorrect. Check WHERE/JOIN conditions."
|
| 163 |
+
)
|
| 164 |
+
else:
|
| 165 |
+
feedback_parts.append("WRONG: No matching rows found")
|
| 166 |
+
feedback_parts.append(
|
| 167 |
+
"Hint: Review your query logic - values don't match expected output."
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Tiny credit if some values appear somewhere
|
| 171 |
+
all_expected_vals = set()
|
| 172 |
+
for row in normalized_expected:
|
| 173 |
+
all_expected_vals.update(row)
|
| 174 |
+
all_actual_vals = set()
|
| 175 |
+
for row in normalized_actual:
|
| 176 |
+
all_actual_vals.update(row)
|
| 177 |
+
|
| 178 |
+
overlap = all_expected_vals & all_actual_vals
|
| 179 |
+
if overlap:
|
| 180 |
+
tiny_credit = 0.05
|
| 181 |
+
score += tiny_credit
|
| 182 |
+
feedback_parts.append(
|
| 183 |
+
f" (Some expected values found in output: +{tiny_credit:.2f})"
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
# Expected empty result
|
| 187 |
+
if len(actual_rows) == 0:
|
| 188 |
+
score += 0.50
|
| 189 |
+
feedback_parts.append("OK: Correctly returned empty result (+0.50)")
|
| 190 |
+
else:
|
| 191 |
+
feedback_parts.append(
|
| 192 |
+
f"WRONG: Expected empty result, got {len(actual_rows)} rows"
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# ---- Final score ----
|
| 196 |
+
score = round(min(max(score, 0.0), 1.0), 4)
|
| 197 |
+
feedback_parts.append(f"\nTotal Score: {score:.2f}/1.00")
|
| 198 |
+
|
| 199 |
+
return score, "\n".join(feedback_parts)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def generate_hint(task: SQLTask, step: int, current_score: float) -> Optional[str]:
|
| 203 |
+
"""Generate progressive hints based on step number and current score."""
|
| 204 |
+
if current_score >= 0.8:
|
| 205 |
+
return None # No hint needed
|
| 206 |
+
|
| 207 |
+
if step <= len(task.hints):
|
| 208 |
+
return f"Hint {step}: {task.hints[step - 1]}"
|
| 209 |
+
|
| 210 |
+
# Generic hints for later steps
|
| 211 |
+
generic_hints = [
|
| 212 |
+
f"Expected columns are: {task.expected_columns}",
|
| 213 |
+
f"Expected {task.expected_row_count} rows in the result",
|
| 214 |
+
"Check the schema description carefully for table and column names",
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
hint_idx = min(step - len(task.hints) - 1, len(generic_hints) - 1)
|
| 218 |
+
if hint_idx >= 0:
|
| 219 |
+
return generic_hints[hint_idx]
|
| 220 |
+
return None
|
src/sql_arena/models.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Typed Pydantic models for SQL Arena OpenEnv environment.
|
| 3 |
+
|
| 4 |
+
These models define the contract between the agent and environment:
|
| 5 |
+
- SQLArenaAction: What the agent sends (a SQL query)
|
| 6 |
+
- SQLArenaObservation: What the agent receives (schema, results, feedback)
|
| 7 |
+
- SQLArenaState: Internal environment state tracking
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
from typing import Optional, List
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SQLArenaAction(BaseModel):
|
| 15 |
+
"""
|
| 16 |
+
Action model — the agent submits a SQL query.
|
| 17 |
+
|
| 18 |
+
This is what the agent sends to the environment each step.
|
| 19 |
+
The environment will execute this query against the SQLite database
|
| 20 |
+
and return results + feedback.
|
| 21 |
+
"""
|
| 22 |
+
sql_query: str = Field(
|
| 23 |
+
...,
|
| 24 |
+
description="SQL query to execute against the database",
|
| 25 |
+
examples=[
|
| 26 |
+
"SELECT name, salary FROM employees WHERE salary > 50000",
|
| 27 |
+
"SELECT department, COUNT(*) FROM employees GROUP BY department",
|
| 28 |
+
]
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SQLArenaObservation(BaseModel):
|
| 33 |
+
"""
|
| 34 |
+
Observation model — what the agent sees after each step.
|
| 35 |
+
|
| 36 |
+
Contains the database schema, the question to answer,
|
| 37 |
+
results from the last query, error messages, and feedback
|
| 38 |
+
with partial credit information.
|
| 39 |
+
"""
|
| 40 |
+
# Always present
|
| 41 |
+
schema_description: str = Field(
|
| 42 |
+
...,
|
| 43 |
+
description="Human-readable database schema (CREATE TABLE statements)"
|
| 44 |
+
)
|
| 45 |
+
question: str = Field(
|
| 46 |
+
...,
|
| 47 |
+
description="Natural language question the agent must answer with SQL"
|
| 48 |
+
)
|
| 49 |
+
difficulty: str = Field(
|
| 50 |
+
...,
|
| 51 |
+
description="Task difficulty level: basic_select, join_aggregate, or complex_analysis"
|
| 52 |
+
)
|
| 53 |
+
task_id: str = Field(
|
| 54 |
+
...,
|
| 55 |
+
description="Unique identifier for this specific problem"
|
| 56 |
+
)
|
| 57 |
+
attempts_remaining: int = Field(
|
| 58 |
+
...,
|
| 59 |
+
description="Number of query attempts the agent has left"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Present after step() calls
|
| 63 |
+
query_result: Optional[str] = Field(
|
| 64 |
+
None,
|
| 65 |
+
description="Formatted result table from the last executed query"
|
| 66 |
+
)
|
| 67 |
+
error_message: Optional[str] = Field(
|
| 68 |
+
None,
|
| 69 |
+
description="SQL error message if the query failed to execute"
|
| 70 |
+
)
|
| 71 |
+
feedback: Optional[str] = Field(
|
| 72 |
+
None,
|
| 73 |
+
description="Detailed feedback on query correctness with partial credit breakdown"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Hints to help the agent
|
| 77 |
+
expected_columns: Optional[List[str]] = Field(
|
| 78 |
+
None,
|
| 79 |
+
description="Expected column names in the correct result (hint)"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class SQLArenaState(BaseModel):
|
| 84 |
+
"""
|
| 85 |
+
Internal state model — tracks the episode progress.
|
| 86 |
+
|
| 87 |
+
This is returned by the state() endpoint and contains
|
| 88 |
+
all information about the current episode.
|
| 89 |
+
"""
|
| 90 |
+
task_id: str = Field(..., description="Current task identifier")
|
| 91 |
+
difficulty: str = Field(..., description="Current difficulty level")
|
| 92 |
+
current_step: int = Field(0, description="Number of steps taken so far")
|
| 93 |
+
max_steps: int = Field(5, description="Maximum steps allowed for this task")
|
| 94 |
+
best_score: float = Field(0.0, description="Best score achieved so far in this episode")
|
| 95 |
+
total_reward: float = Field(0.0, description="Sum of all rewards received")
|
| 96 |
+
rewards_history: List[float] = Field(
|
| 97 |
+
default_factory=list,
|
| 98 |
+
description="List of rewards received at each step"
|
| 99 |
+
)
|
| 100 |
+
done: bool = Field(False, description="Whether the episode has ended")
|
| 101 |
+
last_action_error: Optional[str] = Field(
|
| 102 |
+
None,
|
| 103 |
+
description="Error from the last action, if any"
|
| 104 |
+
)
|
src/sql_arena/server.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI server for SQL Arena - OpenEnv compatible.
|
| 3 |
+
Exposes /reset, /step, /state endpoints via HTTP and WebSocket.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import uuid
|
| 8 |
+
import asyncio
|
| 9 |
+
from typing import Dict, Optional
|
| 10 |
+
from contextlib import asynccontextmanager
|
| 11 |
+
|
| 12 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
| 13 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
|
| 16 |
+
from .environment import SQLArenaEnvironment, StepResult
|
| 17 |
+
from .models import SQLArenaAction, SQLArenaObservation, SQLArenaState
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# =====================================================
|
| 21 |
+
# Request / Response Models
|
| 22 |
+
# =====================================================
|
| 23 |
+
|
| 24 |
+
class ResetRequest(BaseModel):
|
| 25 |
+
difficulty: str = "basic_select"
|
| 26 |
+
task_id: Optional[str] = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class StepRequest(BaseModel):
|
| 30 |
+
sql_query: str
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ResetResponse(BaseModel):
|
| 34 |
+
observation: SQLArenaObservation
|
| 35 |
+
reward: float
|
| 36 |
+
done: bool
|
| 37 |
+
info: dict = {}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class StepResponse(BaseModel):
|
| 41 |
+
observation: SQLArenaObservation
|
| 42 |
+
reward: float
|
| 43 |
+
done: bool
|
| 44 |
+
info: dict = {}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class StateResponse(BaseModel):
|
| 48 |
+
state: SQLArenaState
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class TaskListResponse(BaseModel):
|
| 52 |
+
tasks: Dict
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# =====================================================
|
| 56 |
+
# Session Manager
|
| 57 |
+
# =====================================================
|
| 58 |
+
|
| 59 |
+
class SessionManager:
|
| 60 |
+
"""Manages multiple concurrent environment instances."""
|
| 61 |
+
|
| 62 |
+
def __init__(self, max_sessions: int = 100):
|
| 63 |
+
self.sessions: Dict[str, SQLArenaEnvironment] = {}
|
| 64 |
+
self.max_sessions = max_sessions
|
| 65 |
+
self._lock = asyncio.Lock()
|
| 66 |
+
|
| 67 |
+
async def create_session(self):
|
| 68 |
+
async with self._lock:
|
| 69 |
+
if len(self.sessions) >= self.max_sessions:
|
| 70 |
+
oldest_key = next(iter(self.sessions))
|
| 71 |
+
self.sessions[oldest_key].close()
|
| 72 |
+
del self.sessions[oldest_key]
|
| 73 |
+
session_id = str(uuid.uuid4())
|
| 74 |
+
env = SQLArenaEnvironment()
|
| 75 |
+
self.sessions[session_id] = env
|
| 76 |
+
return session_id, env
|
| 77 |
+
|
| 78 |
+
async def get_session(self, session_id: str):
|
| 79 |
+
return self.sessions.get(session_id)
|
| 80 |
+
|
| 81 |
+
async def remove_session(self, session_id: str):
|
| 82 |
+
async with self._lock:
|
| 83 |
+
if session_id in self.sessions:
|
| 84 |
+
self.sessions[session_id].close()
|
| 85 |
+
del self.sessions[session_id]
|
| 86 |
+
|
| 87 |
+
async def cleanup_all(self):
|
| 88 |
+
async with self._lock:
|
| 89 |
+
for env in self.sessions.values():
|
| 90 |
+
env.close()
|
| 91 |
+
self.sessions.clear()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# =====================================================
|
| 95 |
+
# App Setup
|
| 96 |
+
# =====================================================
|
| 97 |
+
|
| 98 |
+
session_manager = SessionManager()
|
| 99 |
+
_default_env = SQLArenaEnvironment()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@asynccontextmanager
|
| 103 |
+
async def lifespan(app: FastAPI):
|
| 104 |
+
yield
|
| 105 |
+
await session_manager.cleanup_all()
|
| 106 |
+
_default_env.close()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
app = FastAPI(
|
| 110 |
+
title="SQL Arena - OpenEnv Environment",
|
| 111 |
+
description="Interactive SQL query challenge environment for AI agents",
|
| 112 |
+
version="1.0.0",
|
| 113 |
+
lifespan=lifespan,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
app.add_middleware(
|
| 117 |
+
CORSMiddleware,
|
| 118 |
+
allow_origins=["*"],
|
| 119 |
+
allow_credentials=True,
|
| 120 |
+
allow_methods=["*"],
|
| 121 |
+
allow_headers=["*"],
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# =====================================================
|
| 126 |
+
# HTTP Endpoints
|
| 127 |
+
# =====================================================
|
| 128 |
+
|
| 129 |
+
@app.get("/")
|
| 130 |
+
async def root():
|
| 131 |
+
return {
|
| 132 |
+
"name": "SQL Arena",
|
| 133 |
+
"version": "1.0.0",
|
| 134 |
+
"description": "Interactive SQL query challenge environment",
|
| 135 |
+
"endpoints": ["/reset", "/step", "/state", "/tasks", "/ws"],
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@app.get("/health")
|
| 140 |
+
async def health():
|
| 141 |
+
return {"status": "healthy"}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@app.post("/reset", response_model=ResetResponse)
|
| 145 |
+
async def reset(request: ResetRequest = ResetRequest()):
|
| 146 |
+
try:
|
| 147 |
+
result = _default_env.reset(
|
| 148 |
+
difficulty=request.difficulty,
|
| 149 |
+
task_id=request.task_id,
|
| 150 |
+
)
|
| 151 |
+
return ResetResponse(
|
| 152 |
+
observation=result.observation,
|
| 153 |
+
reward=result.reward,
|
| 154 |
+
done=result.done,
|
| 155 |
+
info=result.info,
|
| 156 |
+
)
|
| 157 |
+
except Exception as e:
|
| 158 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@app.post("/step", response_model=StepResponse)
|
| 162 |
+
async def step(request: StepRequest):
|
| 163 |
+
try:
|
| 164 |
+
action = SQLArenaAction(sql_query=request.sql_query)
|
| 165 |
+
result = _default_env.step(action)
|
| 166 |
+
return StepResponse(
|
| 167 |
+
observation=result.observation,
|
| 168 |
+
reward=result.reward,
|
| 169 |
+
done=result.done,
|
| 170 |
+
info=result.info,
|
| 171 |
+
)
|
| 172 |
+
except Exception as e:
|
| 173 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@app.get("/state", response_model=StateResponse)
|
| 177 |
+
async def state():
|
| 178 |
+
try:
|
| 179 |
+
return StateResponse(state=_default_env.state())
|
| 180 |
+
except Exception as e:
|
| 181 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@app.get("/tasks", response_model=TaskListResponse)
|
| 185 |
+
async def tasks():
|
| 186 |
+
return TaskListResponse(tasks=_default_env.get_available_tasks())
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# =====================================================
|
| 190 |
+
# WebSocket Endpoint
|
| 191 |
+
# =====================================================
|
| 192 |
+
|
| 193 |
+
@app.websocket("/ws")
|
| 194 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 195 |
+
await websocket.accept()
|
| 196 |
+
session_id, env = await session_manager.create_session()
|
| 197 |
+
|
| 198 |
+
try:
|
| 199 |
+
while True:
|
| 200 |
+
data = await websocket.receive_text()
|
| 201 |
+
message = json.loads(data)
|
| 202 |
+
|
| 203 |
+
method = message.get("method", "")
|
| 204 |
+
params = message.get("params", {})
|
| 205 |
+
msg_id = message.get("id", None)
|
| 206 |
+
|
| 207 |
+
try:
|
| 208 |
+
if method == "reset":
|
| 209 |
+
result = env.reset(
|
| 210 |
+
difficulty=params.get("difficulty", "basic_select"),
|
| 211 |
+
task_id=params.get("task_id"),
|
| 212 |
+
)
|
| 213 |
+
response = {
|
| 214 |
+
"id": msg_id,
|
| 215 |
+
"result": {
|
| 216 |
+
"observation": result.observation.model_dump(),
|
| 217 |
+
"reward": result.reward,
|
| 218 |
+
"done": result.done,
|
| 219 |
+
"info": result.info,
|
| 220 |
+
},
|
| 221 |
+
}
|
| 222 |
+
elif method == "step":
|
| 223 |
+
action = SQLArenaAction(sql_query=params.get("sql_query", ""))
|
| 224 |
+
result = env.step(action)
|
| 225 |
+
response = {
|
| 226 |
+
"id": msg_id,
|
| 227 |
+
"result": {
|
| 228 |
+
"observation": result.observation.model_dump(),
|
| 229 |
+
"reward": result.reward,
|
| 230 |
+
"done": result.done,
|
| 231 |
+
"info": result.info,
|
| 232 |
+
},
|
| 233 |
+
}
|
| 234 |
+
elif method == "state":
|
| 235 |
+
env_state = env.state()
|
| 236 |
+
response = {
|
| 237 |
+
"id": msg_id,
|
| 238 |
+
"result": {"state": env_state.model_dump()},
|
| 239 |
+
}
|
| 240 |
+
elif method == "close":
|
| 241 |
+
response = {"id": msg_id, "result": {"status": "closed"}}
|
| 242 |
+
await websocket.send_text(json.dumps(response))
|
| 243 |
+
break
|
| 244 |
+
else:
|
| 245 |
+
response = {"id": msg_id, "error": f"Unknown method: {method}"}
|
| 246 |
+
|
| 247 |
+
await websocket.send_text(json.dumps(response))
|
| 248 |
+
|
| 249 |
+
except Exception as e:
|
| 250 |
+
error_response = {"id": msg_id, "error": str(e)}
|
| 251 |
+
await websocket.send_text(json.dumps(error_response))
|
| 252 |
+
|
| 253 |
+
except WebSocketDisconnect:
|
| 254 |
+
pass
|
| 255 |
+
finally:
|
| 256 |
+
await session_manager.remove_session(session_id)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# =====================================================
|
| 260 |
+
# Entry point
|
| 261 |
+
# =====================================================
|
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
|
| 264 |
+
import uvicorn
|
| 265 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
src/sql_arena/tasks.py
ADDED
|
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task Bank for SQL Arena.
|
| 3 |
+
|
| 4 |
+
Contains 9 SQL challenges across 3 difficulty levels:
|
| 5 |
+
- basic_select (Easy): 3 tasks — simple SELECT/WHERE/ORDER BY
|
| 6 |
+
- join_aggregate (Medium): 3 tasks — JOINs, GROUP BY, HAVING
|
| 7 |
+
- complex_analysis (Hard): 3 tasks — CTEs, window functions, subqueries
|
| 8 |
+
|
| 9 |
+
Each task defines:
|
| 10 |
+
- Database schema and sample data (setup_sql)
|
| 11 |
+
- Natural language question
|
| 12 |
+
- Expected SQL solution
|
| 13 |
+
- Expected result for grading
|
| 14 |
+
- Progressive hints
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import List, Dict, Optional
|
| 19 |
+
import random
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class SQLTask:
|
| 24 |
+
"""A single SQL challenge problem."""
|
| 25 |
+
task_id: str
|
| 26 |
+
difficulty: str # basic_select, join_aggregate, complex_analysis
|
| 27 |
+
title: str
|
| 28 |
+
setup_sql: str # CREATE TABLE + INSERT statements
|
| 29 |
+
question: str # Natural language question
|
| 30 |
+
expected_sql: str # Reference solution
|
| 31 |
+
expected_columns: List[str] # Expected column names in result
|
| 32 |
+
expected_row_count: int # Expected number of result rows
|
| 33 |
+
expected_rows: List[tuple] # Expected result rows for grading
|
| 34 |
+
hints: List[str] = field(default_factory=list)
|
| 35 |
+
max_steps: int = 5
|
| 36 |
+
schema_description: str = "" # Human-readable schema description
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# =============================================================
|
| 40 |
+
# DATABASE SCHEMAS
|
| 41 |
+
# =============================================================
|
| 42 |
+
|
| 43 |
+
# Schema 1: Employee database (used by Easy tasks)
|
| 44 |
+
EMPLOYEES_SCHEMA = """
|
| 45 |
+
CREATE TABLE employees (
|
| 46 |
+
id INTEGER PRIMARY KEY,
|
| 47 |
+
name TEXT NOT NULL,
|
| 48 |
+
department TEXT NOT NULL,
|
| 49 |
+
salary REAL NOT NULL,
|
| 50 |
+
hire_date TEXT NOT NULL,
|
| 51 |
+
is_active INTEGER DEFAULT 1
|
| 52 |
+
);
|
| 53 |
+
|
| 54 |
+
INSERT INTO employees VALUES (1, 'Alice Johnson', 'Engineering', 95000, '2020-01-15', 1);
|
| 55 |
+
INSERT INTO employees VALUES (2, 'Bob Smith', 'Marketing', 65000, '2019-06-01', 1);
|
| 56 |
+
INSERT INTO employees VALUES (3, 'Carol Williams', 'Engineering', 110000, '2018-03-20', 1);
|
| 57 |
+
INSERT INTO employees VALUES (4, 'David Brown', 'Sales', 72000, '2021-09-10', 1);
|
| 58 |
+
INSERT INTO employees VALUES (5, 'Eve Davis', 'Engineering', 88000, '2022-02-28', 1);
|
| 59 |
+
INSERT INTO employees VALUES (6, 'Frank Miller', 'Marketing', 58000, '2020-11-15', 0);
|
| 60 |
+
INSERT INTO employees VALUES (7, 'Grace Wilson', 'Sales', 81000, '2019-04-22', 1);
|
| 61 |
+
INSERT INTO employees VALUES (8, 'Henry Taylor', 'Engineering', 125000, '2017-08-01', 1);
|
| 62 |
+
INSERT INTO employees VALUES (9, 'Ivy Anderson', 'HR', 70000, '2021-01-10', 1);
|
| 63 |
+
INSERT INTO employees VALUES (10, 'Jack Thomas', 'HR', 75000, '2020-07-15', 1);
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
EMPLOYEES_SCHEMA_DESC = """Table: employees
|
| 67 |
+
Columns:
|
| 68 |
+
- id: INTEGER PRIMARY KEY (auto-increment identifier)
|
| 69 |
+
- name: TEXT (employee full name, e.g. 'Alice Johnson')
|
| 70 |
+
- department: TEXT (one of: Engineering, Marketing, Sales, HR)
|
| 71 |
+
- salary: REAL (annual salary in USD, e.g. 95000.0)
|
| 72 |
+
- hire_date: TEXT (date in YYYY-MM-DD format, e.g. '2020-01-15')
|
| 73 |
+
- is_active: INTEGER (1 = currently active, 0 = inactive/left)
|
| 74 |
+
|
| 75 |
+
Data: 10 employees across 4 departments.
|
| 76 |
+
- 4 in Engineering, 2 in Marketing (1 inactive), 2 in Sales, 2 in HR
|
| 77 |
+
- Salaries range from 58,000 to 125,000
|
| 78 |
+
- Hire dates range from 2017 to 2022
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Schema 2: E-commerce database (used by Medium and Hard tasks)
|
| 83 |
+
ECOMMERCE_SCHEMA = """
|
| 84 |
+
CREATE TABLE customers (
|
| 85 |
+
id INTEGER PRIMARY KEY,
|
| 86 |
+
name TEXT NOT NULL,
|
| 87 |
+
email TEXT NOT NULL,
|
| 88 |
+
city TEXT NOT NULL,
|
| 89 |
+
signup_date TEXT NOT NULL
|
| 90 |
+
);
|
| 91 |
+
|
| 92 |
+
CREATE TABLE products (
|
| 93 |
+
id INTEGER PRIMARY KEY,
|
| 94 |
+
name TEXT NOT NULL,
|
| 95 |
+
category TEXT NOT NULL,
|
| 96 |
+
price REAL NOT NULL,
|
| 97 |
+
stock INTEGER NOT NULL
|
| 98 |
+
);
|
| 99 |
+
|
| 100 |
+
CREATE TABLE orders (
|
| 101 |
+
id INTEGER PRIMARY KEY,
|
| 102 |
+
customer_id INTEGER NOT NULL,
|
| 103 |
+
order_date TEXT NOT NULL,
|
| 104 |
+
status TEXT NOT NULL,
|
| 105 |
+
FOREIGN KEY (customer_id) REFERENCES customers(id)
|
| 106 |
+
);
|
| 107 |
+
|
| 108 |
+
CREATE TABLE order_items (
|
| 109 |
+
id INTEGER PRIMARY KEY,
|
| 110 |
+
order_id INTEGER NOT NULL,
|
| 111 |
+
product_id INTEGER NOT NULL,
|
| 112 |
+
quantity INTEGER NOT NULL,
|
| 113 |
+
unit_price REAL NOT NULL,
|
| 114 |
+
FOREIGN KEY (order_id) REFERENCES orders(id),
|
| 115 |
+
FOREIGN KEY (product_id) REFERENCES products(id)
|
| 116 |
+
);
|
| 117 |
+
|
| 118 |
+
-- Customers
|
| 119 |
+
INSERT INTO customers VALUES (1, 'Alice', 'alice@email.com', 'New York', '2023-01-15');
|
| 120 |
+
INSERT INTO customers VALUES (2, 'Bob', 'bob@email.com', 'Los Angeles', '2023-02-20');
|
| 121 |
+
INSERT INTO customers VALUES (3, 'Carol', 'carol@email.com', 'Chicago', '2023-03-10');
|
| 122 |
+
INSERT INTO customers VALUES (4, 'David', 'david@email.com', 'New York', '2023-04-05');
|
| 123 |
+
INSERT INTO customers VALUES (5, 'Eve', 'eve@email.com', 'Boston', '2023-05-12');
|
| 124 |
+
|
| 125 |
+
-- Products
|
| 126 |
+
INSERT INTO products VALUES (1, 'Laptop', 'Electronics', 999.99, 50);
|
| 127 |
+
INSERT INTO products VALUES (2, 'Headphones', 'Electronics', 149.99, 200);
|
| 128 |
+
INSERT INTO products VALUES (3, 'Python Book', 'Books', 39.99, 100);
|
| 129 |
+
INSERT INTO products VALUES (4, 'Desk Lamp', 'Home', 29.99, 150);
|
| 130 |
+
INSERT INTO products VALUES (5, 'Keyboard', 'Electronics', 79.99, 120);
|
| 131 |
+
INSERT INTO products VALUES (6, 'SQL Book', 'Books', 44.99, 80);
|
| 132 |
+
|
| 133 |
+
-- Orders (10 orders, various statuses)
|
| 134 |
+
INSERT INTO orders VALUES (1, 1, '2023-06-01', 'completed');
|
| 135 |
+
INSERT INTO orders VALUES (2, 1, '2023-07-15', 'completed');
|
| 136 |
+
INSERT INTO orders VALUES (3, 2, '2023-06-20', 'completed');
|
| 137 |
+
INSERT INTO orders VALUES (4, 3, '2023-08-01', 'completed');
|
| 138 |
+
INSERT INTO orders VALUES (5, 3, '2023-08-15', 'completed');
|
| 139 |
+
INSERT INTO orders VALUES (6, 3, '2023-09-01', 'completed');
|
| 140 |
+
INSERT INTO orders VALUES (7, 4, '2023-07-10', 'cancelled');
|
| 141 |
+
INSERT INTO orders VALUES (8, 5, '2023-09-20', 'completed');
|
| 142 |
+
INSERT INTO orders VALUES (9, 1, '2023-10-01', 'completed');
|
| 143 |
+
INSERT INTO orders VALUES (10, 2, '2023-10-15', 'pending');
|
| 144 |
+
|
| 145 |
+
-- Order Items (17 line items)
|
| 146 |
+
INSERT INTO order_items VALUES (1, 1, 1, 1, 999.99);
|
| 147 |
+
INSERT INTO order_items VALUES (2, 1, 2, 2, 149.99);
|
| 148 |
+
INSERT INTO order_items VALUES (3, 2, 3, 1, 39.99);
|
| 149 |
+
INSERT INTO order_items VALUES (4, 2, 5, 1, 79.99);
|
| 150 |
+
INSERT INTO order_items VALUES (5, 3, 1, 1, 999.99);
|
| 151 |
+
INSERT INTO order_items VALUES (6, 3, 4, 3, 29.99);
|
| 152 |
+
INSERT INTO order_items VALUES (7, 4, 2, 1, 149.99);
|
| 153 |
+
INSERT INTO order_items VALUES (8, 4, 6, 2, 44.99);
|
| 154 |
+
INSERT INTO order_items VALUES (9, 5, 3, 1, 39.99);
|
| 155 |
+
INSERT INTO order_items VALUES (10, 5, 5, 2, 79.99);
|
| 156 |
+
INSERT INTO order_items VALUES (11, 6, 1, 1, 999.99);
|
| 157 |
+
INSERT INTO order_items VALUES (12, 6, 2, 1, 149.99);
|
| 158 |
+
INSERT INTO order_items VALUES (13, 8, 6, 1, 44.99);
|
| 159 |
+
INSERT INTO order_items VALUES (14, 8, 4, 1, 29.99);
|
| 160 |
+
INSERT INTO order_items VALUES (15, 9, 2, 3, 149.99);
|
| 161 |
+
INSERT INTO order_items VALUES (16, 9, 3, 2, 39.99);
|
| 162 |
+
INSERT INTO order_items VALUES (17, 10, 1, 1, 999.99);
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
ECOMMERCE_SCHEMA_DESC = """Tables:
|
| 166 |
+
|
| 167 |
+
1. customers (5 rows)
|
| 168 |
+
- id: INTEGER PRIMARY KEY
|
| 169 |
+
- name: TEXT (customer first name)
|
| 170 |
+
- email: TEXT
|
| 171 |
+
- city: TEXT (New York, Los Angeles, Chicago, Boston)
|
| 172 |
+
- signup_date: TEXT (YYYY-MM-DD)
|
| 173 |
+
|
| 174 |
+
2. products (6 rows)
|
| 175 |
+
- id: INTEGER PRIMARY KEY
|
| 176 |
+
- name: TEXT (product name)
|
| 177 |
+
- category: TEXT (Electronics, Books, Home)
|
| 178 |
+
- price: REAL (unit price in USD)
|
| 179 |
+
- stock: INTEGER (units in stock)
|
| 180 |
+
|
| 181 |
+
3. orders (10 rows)
|
| 182 |
+
- id: INTEGER PRIMARY KEY
|
| 183 |
+
- customer_id: INTEGER → customers.id
|
| 184 |
+
- order_date: TEXT (YYYY-MM-DD, range: 2023-06 to 2023-10)
|
| 185 |
+
- status: TEXT (completed, cancelled, pending)
|
| 186 |
+
|
| 187 |
+
4. order_items (17 rows)
|
| 188 |
+
- id: INTEGER PRIMARY KEY
|
| 189 |
+
- order_id: INTEGER → orders.id
|
| 190 |
+
- product_id: INTEGER → products.id
|
| 191 |
+
- quantity: INTEGER
|
| 192 |
+
- unit_price: REAL (price at time of order)
|
| 193 |
+
|
| 194 |
+
Relationships:
|
| 195 |
+
orders.customer_id → customers.id
|
| 196 |
+
order_items.order_id → orders.id
|
| 197 |
+
order_items.product_id → products.id
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# =============================================================
|
| 202 |
+
# EASY TASKS: basic_select (3 tasks)
|
| 203 |
+
# =============================================================
|
| 204 |
+
|
| 205 |
+
EASY_TASKS = [
|
| 206 |
+
SQLTask(
|
| 207 |
+
task_id="easy_001",
|
| 208 |
+
difficulty="basic_select",
|
| 209 |
+
title="High Salary Employees",
|
| 210 |
+
setup_sql=EMPLOYEES_SCHEMA,
|
| 211 |
+
question="Find the names and salaries of all ACTIVE employees who earn more than \$80,000. Order the results by salary from highest to lowest.",
|
| 212 |
+
expected_sql="SELECT name, salary FROM employees WHERE is_active = 1 AND salary > 80000 ORDER BY salary DESC",
|
| 213 |
+
expected_columns=["name", "salary"],
|
| 214 |
+
expected_row_count=4,
|
| 215 |
+
expected_rows=[
|
| 216 |
+
("Henry Taylor", 125000.0),
|
| 217 |
+
("Carol Williams", 110000.0),
|
| 218 |
+
("Alice Johnson", 95000.0),
|
| 219 |
+
("Eve Davis", 88000.0),
|
| 220 |
+
],
|
| 221 |
+
hints=[
|
| 222 |
+
"Use SELECT with specific column names, not SELECT *",
|
| 223 |
+
"Use WHERE with AND to combine conditions: is_active = 1 AND salary > 80000",
|
| 224 |
+
"Add ORDER BY salary DESC for descending order",
|
| 225 |
+
],
|
| 226 |
+
schema_description=EMPLOYEES_SCHEMA_DESC,
|
| 227 |
+
max_steps=5,
|
| 228 |
+
),
|
| 229 |
+
|
| 230 |
+
SQLTask(
|
| 231 |
+
task_id="easy_002",
|
| 232 |
+
difficulty="basic_select",
|
| 233 |
+
title="Department Employee Count",
|
| 234 |
+
setup_sql=EMPLOYEES_SCHEMA,
|
| 235 |
+
question="Count the number of ACTIVE employees in each department. Show the department name and the count. Order by count from highest to lowest.",
|
| 236 |
+
expected_sql="SELECT department, COUNT(*) as employee_count FROM employees WHERE is_active = 1 GROUP BY department ORDER BY employee_count DESC",
|
| 237 |
+
expected_columns=["department", "employee_count"],
|
| 238 |
+
expected_row_count=4,
|
| 239 |
+
expected_rows=[
|
| 240 |
+
("Engineering", 4),
|
| 241 |
+
("HR", 2),
|
| 242 |
+
("Sales", 2),
|
| 243 |
+
("Marketing", 1),
|
| 244 |
+
],
|
| 245 |
+
hints=[
|
| 246 |
+
"Use COUNT(*) to count rows in each group",
|
| 247 |
+
"GROUP BY department groups rows by department",
|
| 248 |
+
"Use an alias: COUNT(*) as employee_count",
|
| 249 |
+
],
|
| 250 |
+
schema_description=EMPLOYEES_SCHEMA_DESC,
|
| 251 |
+
max_steps=5,
|
| 252 |
+
),
|
| 253 |
+
|
| 254 |
+
SQLTask(
|
| 255 |
+
task_id="easy_003",
|
| 256 |
+
difficulty="basic_select",
|
| 257 |
+
title="Recent Hires",
|
| 258 |
+
setup_sql=EMPLOYEES_SCHEMA,
|
| 259 |
+
question="List the names and hire dates of employees hired on or after January 1, 2021. Order by hire date from earliest to latest.",
|
| 260 |
+
expected_sql="SELECT name, hire_date FROM employees WHERE hire_date >= '2021-01-01' ORDER BY hire_date",
|
| 261 |
+
expected_columns=["name", "hire_date"],
|
| 262 |
+
expected_row_count=3,
|
| 263 |
+
expected_rows=[
|
| 264 |
+
("Ivy Anderson", "2021-01-10"),
|
| 265 |
+
("David Brown", "2021-09-10"),
|
| 266 |
+
("Eve Davis", "2022-02-28"),
|
| 267 |
+
],
|
| 268 |
+
hints=[
|
| 269 |
+
"Dates in SQLite can be compared as strings when in YYYY-MM-DD format",
|
| 270 |
+
"Use WHERE hire_date >= '2021-01-01'",
|
| 271 |
+
"ORDER BY hire_date gives ascending order by default",
|
| 272 |
+
],
|
| 273 |
+
schema_description=EMPLOYEES_SCHEMA_DESC,
|
| 274 |
+
max_steps=5,
|
| 275 |
+
),
|
| 276 |
+
]
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# =============================================================
|
| 280 |
+
# MEDIUM TASKS: join_aggregate (3 tasks)
|
| 281 |
+
# =============================================================
|
| 282 |
+
|
| 283 |
+
MEDIUM_TASKS = [
|
| 284 |
+
SQLTask(
|
| 285 |
+
task_id="medium_001",
|
| 286 |
+
difficulty="join_aggregate",
|
| 287 |
+
title="Customer Total Spending",
|
| 288 |
+
setup_sql=ECOMMERCE_SCHEMA,
|
| 289 |
+
question="Find the total amount spent by each customer on COMPLETED orders only. Show the customer name and their total spending. Only include customers who spent more than \$200. Order by total spending from highest to lowest.",
|
| 290 |
+
expected_sql="""
|
| 291 |
+
SELECT c.name, ROUND(SUM(oi.quantity * oi.unit_price), 2) as total_spent
|
| 292 |
+
FROM customers c
|
| 293 |
+
JOIN orders o ON c.id = o.customer_id
|
| 294 |
+
JOIN order_items oi ON o.id = oi.order_id
|
| 295 |
+
WHERE o.status = 'completed'
|
| 296 |
+
GROUP BY c.id, c.name
|
| 297 |
+
HAVING SUM(oi.quantity * oi.unit_price) > 200
|
| 298 |
+
ORDER BY total_spent DESC
|
| 299 |
+
""",
|
| 300 |
+
expected_columns=["name", "total_spent"],
|
| 301 |
+
expected_row_count=4,
|
| 302 |
+
expected_rows=[
|
| 303 |
+
("Alice", 1919.91),
|
| 304 |
+
("Carol", 1464.94),
|
| 305 |
+
("Bob", 1089.96),
|
| 306 |
+
("Eve", 74.98),
|
| 307 |
+
],
|
| 308 |
+
hints=[
|
| 309 |
+
"You need to JOIN three tables: customers → orders → order_items",
|
| 310 |
+
"Total per item = quantity * unit_price, then SUM for total per customer",
|
| 311 |
+
"Filter completed orders with WHERE o.status = 'completed'",
|
| 312 |
+
"Use HAVING (not WHERE) to filter after GROUP BY",
|
| 313 |
+
],
|
| 314 |
+
schema_description=ECOMMERCE_SCHEMA_DESC,
|
| 315 |
+
max_steps=7,
|
| 316 |
+
),
|
| 317 |
+
|
| 318 |
+
SQLTask(
|
| 319 |
+
task_id="medium_002",
|
| 320 |
+
difficulty="join_aggregate",
|
| 321 |
+
title="Category Revenue",
|
| 322 |
+
setup_sql=ECOMMERCE_SCHEMA,
|
| 323 |
+
question="Calculate the total revenue for each product category from COMPLETED orders. Show the category name and total revenue. Order by total revenue from highest to lowest.",
|
| 324 |
+
expected_sql="""
|
| 325 |
+
SELECT p.category, ROUND(SUM(oi.quantity * oi.unit_price), 2) as total_revenue
|
| 326 |
+
FROM products p
|
| 327 |
+
JOIN order_items oi ON p.id = oi.product_id
|
| 328 |
+
JOIN orders o ON oi.order_id = o.id
|
| 329 |
+
WHERE o.status = 'completed'
|
| 330 |
+
GROUP BY p.category
|
| 331 |
+
ORDER BY total_revenue DESC
|
| 332 |
+
""",
|
| 333 |
+
expected_columns=["category", "total_revenue"],
|
| 334 |
+
expected_row_count=3,
|
| 335 |
+
expected_rows=[
|
| 336 |
+
("Electronics", 4459.83),
|
| 337 |
+
("Books", 254.93),
|
| 338 |
+
("Home", 119.96),
|
| 339 |
+
],
|
| 340 |
+
hints=[
|
| 341 |
+
"JOIN products → order_items → orders",
|
| 342 |
+
"Revenue per item = quantity * unit_price",
|
| 343 |
+
"Filter only completed orders",
|
| 344 |
+
"GROUP BY p.category to get per-category totals",
|
| 345 |
+
],
|
| 346 |
+
schema_description=ECOMMERCE_SCHEMA_DESC,
|
| 347 |
+
max_steps=7,
|
| 348 |
+
),
|
| 349 |
+
|
| 350 |
+
SQLTask(
|
| 351 |
+
task_id="medium_003",
|
| 352 |
+
difficulty="join_aggregate",
|
| 353 |
+
title="Customers with Multiple Orders",
|
| 354 |
+
setup_sql=ECOMMERCE_SCHEMA,
|
| 355 |
+
question="Find customers who have placed more than one COMPLETED order. Show the customer name and the number of completed orders they placed. Order by order count descending, then by name ascending.",
|
| 356 |
+
expected_sql="""
|
| 357 |
+
SELECT c.name, COUNT(o.id) as order_count
|
| 358 |
+
FROM customers c
|
| 359 |
+
JOIN orders o ON c.id = o.customer_id
|
| 360 |
+
WHERE o.status = 'completed'
|
| 361 |
+
GROUP BY c.id, c.name
|
| 362 |
+
HAVING COUNT(o.id) > 1
|
| 363 |
+
ORDER BY order_count DESC, c.name ASC
|
| 364 |
+
""",
|
| 365 |
+
expected_columns=["name", "order_count"],
|
| 366 |
+
expected_row_count=2,
|
| 367 |
+
expected_rows=[
|
| 368 |
+
("Alice", 3),
|
| 369 |
+
("Carol", 3),
|
| 370 |
+
],
|
| 371 |
+
hints=[
|
| 372 |
+
"JOIN customers with orders",
|
| 373 |
+
"Filter for completed orders in WHERE clause",
|
| 374 |
+
"GROUP BY customer, then HAVING COUNT > 1",
|
| 375 |
+
"ORDER BY count DESC, then name ASC for ties",
|
| 376 |
+
],
|
| 377 |
+
schema_description=ECOMMERCE_SCHEMA_DESC,
|
| 378 |
+
max_steps=7,
|
| 379 |
+
),
|
| 380 |
+
]
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
# =============================================================
|
| 384 |
+
# HARD TASKS: complex_analysis (3 tasks)
|
| 385 |
+
# =============================================================
|
| 386 |
+
|
| 387 |
+
HARD_TASKS = [
|
| 388 |
+
SQLTask(
|
| 389 |
+
task_id="hard_001",
|
| 390 |
+
difficulty="complex_analysis",
|
| 391 |
+
title="Monthly Revenue with Growth Rate",
|
| 392 |
+
setup_sql=ECOMMERCE_SCHEMA,
|
| 393 |
+
question="Calculate monthly revenue from COMPLETED orders, and for each month show the month (YYYY-MM format), the total revenue, and the percentage change from the previous month. For the first month, the percentage change should be NULL. Round revenue to 2 decimal places and percentage to 2 decimal places. Order by month ascending.",
|
| 394 |
+
expected_sql="""
|
| 395 |
+
WITH monthly AS (
|
| 396 |
+
SELECT
|
| 397 |
+
strftime('%Y-%m', o.order_date) as month,
|
| 398 |
+
ROUND(SUM(oi.quantity * oi.unit_price), 2) as revenue
|
| 399 |
+
FROM orders o
|
| 400 |
+
JOIN order_items oi ON o.id = oi.order_id
|
| 401 |
+
WHERE o.status = 'completed'
|
| 402 |
+
GROUP BY strftime('%Y-%m', o.order_date)
|
| 403 |
+
),
|
| 404 |
+
with_prev AS (
|
| 405 |
+
SELECT
|
| 406 |
+
month,
|
| 407 |
+
revenue,
|
| 408 |
+
LAG(revenue) OVER (ORDER BY month) as prev_revenue
|
| 409 |
+
FROM monthly
|
| 410 |
+
)
|
| 411 |
+
SELECT
|
| 412 |
+
month,
|
| 413 |
+
revenue,
|
| 414 |
+
CASE
|
| 415 |
+
WHEN prev_revenue IS NULL THEN NULL
|
| 416 |
+
ELSE ROUND(((revenue - prev_revenue) * 100.0 / prev_revenue), 2)
|
| 417 |
+
END as pct_change
|
| 418 |
+
FROM with_prev
|
| 419 |
+
ORDER BY month
|
| 420 |
+
""",
|
| 421 |
+
expected_columns=["month", "revenue", "pct_change"],
|
| 422 |
+
expected_row_count=5,
|
| 423 |
+
expected_rows=[
|
| 424 |
+
("2023-06", 2289.93, None),
|
| 425 |
+
("2023-07", 119.98, -94.76),
|
| 426 |
+
("2023-08", 1429.93, 1091.81),
|
| 427 |
+
("2023-09", 1224.97, -14.34),
|
| 428 |
+
("2023-10", 529.95, -56.74),
|
| 429 |
+
],
|
| 430 |
+
hints=[
|
| 431 |
+
"Use a CTE (WITH clause) to first calculate monthly revenue",
|
| 432 |
+
"strftime('%Y-%m', date) extracts year-month from a date string",
|
| 433 |
+
"LAG(revenue) OVER (ORDER BY month) gets the previous month's revenue",
|
| 434 |
+
"Percentage change = ((new - old) / old) * 100",
|
| 435 |
+
"Use CASE WHEN prev IS NULL THEN NULL ELSE ... END for first month",
|
| 436 |
+
],
|
| 437 |
+
schema_description=ECOMMERCE_SCHEMA_DESC,
|
| 438 |
+
max_steps=10,
|
| 439 |
+
),
|
| 440 |
+
|
| 441 |
+
SQLTask(
|
| 442 |
+
task_id="hard_002",
|
| 443 |
+
difficulty="complex_analysis",
|
| 444 |
+
title="Top Product Per Category",
|
| 445 |
+
setup_sql=ECOMMERCE_SCHEMA,
|
| 446 |
+
question="For each product category, find the single best-selling product (by total quantity sold across COMPLETED orders). Show the category, product name, and total quantity sold. If there are ties, pick the one with the higher total revenue. Order by category name ascending.",
|
| 447 |
+
expected_sql="""
|
| 448 |
+
WITH product_sales AS (
|
| 449 |
+
SELECT
|
| 450 |
+
p.category,
|
| 451 |
+
p.name as product_name,
|
| 452 |
+
SUM(oi.quantity) as total_qty,
|
| 453 |
+
SUM(oi.quantity * oi.unit_price) as total_revenue,
|
| 454 |
+
ROW_NUMBER() OVER (
|
| 455 |
+
PARTITION BY p.category
|
| 456 |
+
ORDER BY SUM(oi.quantity) DESC, SUM(oi.quantity * oi.unit_price) DESC
|
| 457 |
+
) as rn
|
| 458 |
+
FROM products p
|
| 459 |
+
JOIN order_items oi ON p.id = oi.product_id
|
| 460 |
+
JOIN orders o ON oi.order_id = o.id
|
| 461 |
+
WHERE o.status = 'completed'
|
| 462 |
+
GROUP BY p.category, p.name
|
| 463 |
+
)
|
| 464 |
+
SELECT category, product_name, total_qty
|
| 465 |
+
FROM product_sales
|
| 466 |
+
WHERE rn = 1
|
| 467 |
+
ORDER BY category ASC
|
| 468 |
+
""",
|
| 469 |
+
expected_columns=["category", "product_name", "total_qty"],
|
| 470 |
+
expected_row_count=3,
|
| 471 |
+
expected_rows=[
|
| 472 |
+
("Books", "Python Book", 4),
|
| 473 |
+
("Electronics", "Headphones", 7),
|
| 474 |
+
("Home", "Desk Lamp", 4),
|
| 475 |
+
],
|
| 476 |
+
hints=[
|
| 477 |
+
"First calculate total quantity sold per product (SUM of quantity)",
|
| 478 |
+
"Use ROW_NUMBER() OVER (PARTITION BY category ORDER BY qty DESC) to rank within category",
|
| 479 |
+
"Filter WHERE rn = 1 to get only the top product per category",
|
| 480 |
+
"A CTE makes this much cleaner than nested subqueries",
|
| 481 |
+
"Don't forget to filter for completed orders only",
|
| 482 |
+
],
|
| 483 |
+
schema_description=ECOMMERCE_SCHEMA_DESC,
|
| 484 |
+
max_steps=10,
|
| 485 |
+
),
|
| 486 |
+
|
| 487 |
+
SQLTask(
|
| 488 |
+
task_id="hard_003",
|
| 489 |
+
difficulty="complex_analysis",
|
| 490 |
+
title="Customer Lifetime Value Analysis",
|
| 491 |
+
setup_sql=ECOMMERCE_SCHEMA,
|
| 492 |
+
question="For customers with at least 2 completed orders, calculate: their name, number of completed orders, total lifetime spending (rounded to 2 decimals), average order value (rounded to 2 decimals), and the number of days between their first and last completed order. Order by total spending descending.",
|
| 493 |
+
expected_sql="""
|
| 494 |
+
WITH customer_order_totals AS (
|
| 495 |
+
SELECT
|
| 496 |
+
c.id as customer_id,
|
| 497 |
+
c.name,
|
| 498 |
+
o.id as order_id,
|
| 499 |
+
o.order_date,
|
| 500 |
+
SUM(oi.quantity * oi.unit_price) as order_total
|
| 501 |
+
FROM customers c
|
| 502 |
+
JOIN orders o ON c.id = o.customer_id
|
| 503 |
+
JOIN order_items oi ON o.id = oi.order_id
|
| 504 |
+
WHERE o.status = 'completed'
|
| 505 |
+
GROUP BY c.id, c.name, o.id, o.order_date
|
| 506 |
+
)
|
| 507 |
+
SELECT
|
| 508 |
+
name,
|
| 509 |
+
COUNT(*) as num_orders,
|
| 510 |
+
ROUND(SUM(order_total), 2) as total_spending,
|
| 511 |
+
ROUND(AVG(order_total), 2) as avg_order_value,
|
| 512 |
+
CAST(julianday(MAX(order_date)) - julianday(MIN(order_date)) AS INTEGER) as days_span
|
| 513 |
+
FROM customer_order_totals
|
| 514 |
+
GROUP BY customer_id, name
|
| 515 |
+
HAVING COUNT(*) >= 2
|
| 516 |
+
ORDER BY total_spending DESC
|
| 517 |
+
""",
|
| 518 |
+
expected_columns=["name", "num_orders", "total_spending", "avg_order_value", "days_span"],
|
| 519 |
+
expected_row_count=2,
|
| 520 |
+
expected_rows=[
|
| 521 |
+
("Alice", 3, 1919.91, 639.97, 122),
|
| 522 |
+
("Carol", 3, 1464.94, 488.31, 31),
|
| 523 |
+
],
|
| 524 |
+
hints=[
|
| 525 |
+
"Use a CTE to first calculate the total for each individual order",
|
| 526 |
+
"In the CTE: JOIN customers → orders → order_items, GROUP BY order",
|
| 527 |
+
"In the outer query: GROUP BY customer, HAVING COUNT >= 2",
|
| 528 |
+
"julianday() converts date strings to Julian day numbers for arithmetic",
|
| 529 |
+
"days_span = julianday(MAX(order_date)) - julianday(MIN(order_date))",
|
| 530 |
+
],
|
| 531 |
+
schema_description=ECOMMERCE_SCHEMA_DESC,
|
| 532 |
+
max_steps=10,
|
| 533 |
+
),
|
| 534 |
+
]
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
# =============================================================
|
| 538 |
+
# TASK REGISTRY — Maps task IDs and difficulty levels
|
| 539 |
+
# =============================================================
|
| 540 |
+
|
| 541 |
+
ALL_TASKS: Dict[str, List[SQLTask]] = {
|
| 542 |
+
"basic_select": EASY_TASKS,
|
| 543 |
+
"join_aggregate": MEDIUM_TASKS,
|
| 544 |
+
"complex_analysis": HARD_TASKS,
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
# Build a flat lookup by task_id
|
| 548 |
+
TASK_BY_ID: Dict[str, SQLTask] = {}
|
| 549 |
+
for _tasks in ALL_TASKS.values():
|
| 550 |
+
for _task in _tasks:
|
| 551 |
+
TASK_BY_ID[_task.task_id] = _task
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def get_task(difficulty: str, task_id: Optional[str] = None) -> SQLTask:
|
| 555 |
+
"""
|
| 556 |
+
Get a task by difficulty level, optionally by specific ID.
|
| 557 |
+
|
| 558 |
+
Args:
|
| 559 |
+
difficulty: One of 'basic_select', 'join_aggregate', 'complex_analysis'
|
| 560 |
+
task_id: Optional specific task ID (e.g., 'easy_001')
|
| 561 |
+
|
| 562 |
+
Returns:
|
| 563 |
+
SQLTask instance
|
| 564 |
+
|
| 565 |
+
Raises:
|
| 566 |
+
ValueError: If difficulty is unknown
|
| 567 |
+
"""
|
| 568 |
+
# If specific task_id given, return it directly
|
| 569 |
+
if task_id and task_id in TASK_BY_ID:
|
| 570 |
+
return TASK_BY_ID[task_id]
|
| 571 |
+
|
| 572 |
+
# Otherwise pick from the difficulty pool
|
| 573 |
+
if difficulty not in ALL_TASKS:
|
| 574 |
+
raise ValueError(
|
| 575 |
+
f"Unknown difficulty: '{difficulty}'. "
|
| 576 |
+
f"Choose from: {list(ALL_TASKS.keys())}"
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
tasks = ALL_TASKS[difficulty]
|
| 580 |
+
return random.choice(tasks)
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def list_tasks() -> Dict[str, List[str]]:
|
| 584 |
+
"""
|
| 585 |
+
List all available tasks grouped by difficulty.
|
| 586 |
+
|
| 587 |
+
Returns:
|
| 588 |
+
Dict mapping difficulty name to list of task IDs
|
| 589 |
+
"""
|
| 590 |
+
return {
|
| 591 |
+
difficulty: [t.task_id for t in tasks]
|
| 592 |
+
for difficulty, tasks in ALL_TASKS.items()
|
| 593 |
+
}
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_env.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for SQL Arena environment."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from src.sql_arena.environment import SQLArenaEnvironment
|
| 5 |
+
from src.sql_arena.models import SQLArenaAction
|
| 6 |
+
from src.sql_arena.tasks import list_tasks, TASK_BY_ID
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestEnvironmentBasics:
|
| 10 |
+
|
| 11 |
+
def setup_method(self):
|
| 12 |
+
self.env = SQLArenaEnvironment()
|
| 13 |
+
|
| 14 |
+
def teardown_method(self):
|
| 15 |
+
self.env.close()
|
| 16 |
+
|
| 17 |
+
def test_reset_returns_observation(self):
|
| 18 |
+
result = self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| 19 |
+
assert result.observation is not None
|
| 20 |
+
assert result.reward == 0.0
|
| 21 |
+
assert result.done is False
|
| 22 |
+
|
| 23 |
+
def test_step_with_correct_query(self):
|
| 24 |
+
self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| 25 |
+
task = self.env.current_task
|
| 26 |
+
action = SQLArenaAction(sql_query=task.expected_sql)
|
| 27 |
+
result = self.env.step(action)
|
| 28 |
+
assert result.reward > 0.0
|
| 29 |
+
assert result.info.get("score", 0) >= 0.8
|
| 30 |
+
|
| 31 |
+
def test_step_with_invalid_query(self):
|
| 32 |
+
self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| 33 |
+
action = SQLArenaAction(sql_query="INVALID SQL QUERY")
|
| 34 |
+
result = self.env.step(action)
|
| 35 |
+
assert result.reward == 0.0
|
| 36 |
+
assert result.observation.error_message is not None
|
| 37 |
+
|
| 38 |
+
def test_state_tracking(self):
|
| 39 |
+
self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| 40 |
+
state = self.env.state()
|
| 41 |
+
assert state.current_step == 0
|
| 42 |
+
|
| 43 |
+
self.env.step(SQLArenaAction(sql_query="SELECT 1"))
|
| 44 |
+
state = self.env.state()
|
| 45 |
+
assert state.current_step == 1
|
| 46 |
+
|
| 47 |
+
def test_episode_terminates(self):
|
| 48 |
+
self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| 49 |
+
task = self.env.current_task
|
| 50 |
+
for _ in range(task.max_steps + 1):
|
| 51 |
+
if self.env.state().done:
|
| 52 |
+
break
|
| 53 |
+
self.env.step(SQLArenaAction(sql_query="SELECT 1"))
|
| 54 |
+
assert self.env.state().done is True
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TestAllDifficulties:
|
| 58 |
+
|
| 59 |
+
def setup_method(self):
|
| 60 |
+
self.env = SQLArenaEnvironment()
|
| 61 |
+
|
| 62 |
+
def teardown_method(self):
|
| 63 |
+
self.env.close()
|
| 64 |
+
|
| 65 |
+
def test_easy(self):
|
| 66 |
+
result = self.env.reset(difficulty="basic_select")
|
| 67 |
+
assert result.observation.difficulty == "basic_select"
|
| 68 |
+
|
| 69 |
+
def test_medium(self):
|
| 70 |
+
result = self.env.reset(difficulty="join_aggregate")
|
| 71 |
+
assert result.observation.difficulty == "join_aggregate"
|
| 72 |
+
|
| 73 |
+
def test_hard(self):
|
| 74 |
+
result = self.env.reset(difficulty="complex_analysis")
|
| 75 |
+
assert result.observation.difficulty == "complex_analysis"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class TestGrading:
|
| 79 |
+
|
| 80 |
+
def setup_method(self):
|
| 81 |
+
self.env = SQLArenaEnvironment()
|
| 82 |
+
|
| 83 |
+
def teardown_method(self):
|
| 84 |
+
self.env.close()
|
| 85 |
+
|
| 86 |
+
def test_scores_in_range(self):
|
| 87 |
+
for task_id, task in TASK_BY_ID.items():
|
| 88 |
+
self.env.reset(difficulty=task.difficulty, task_id=task_id)
|
| 89 |
+
action = SQLArenaAction(sql_query=task.expected_sql)
|
| 90 |
+
result = self.env.step(action)
|
| 91 |
+
assert 0.0 <= result.reward <= 1.0
|
| 92 |
+
assert 0.0 <= result.info.get("score", 0) <= 1.0
|
| 93 |
+
self.env.reset(difficulty=task.difficulty, task_id=task_id)
|
| 94 |
+
|
| 95 |
+
def test_varying_scores(self):
|
| 96 |
+
scores = set()
|
| 97 |
+
queries = [
|
| 98 |
+
"SELECT name, salary FROM employees WHERE is_active = 1 AND salary > 80000 ORDER BY salary DESC",
|
| 99 |
+
"SELECT * FROM employees",
|
| 100 |
+
"INVALID",
|
| 101 |
+
"SELECT name FROM employees",
|
| 102 |
+
]
|
| 103 |
+
for q in queries:
|
| 104 |
+
self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| 105 |
+
result = self.env.step(SQLArenaAction(sql_query=q))
|
| 106 |
+
scores.add(round(result.info.get("score", 0), 2))
|
| 107 |
+
assert len(scores) > 1, "Grader always returns the same score!"
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class TestTaskRegistry:
|
| 111 |
+
|
| 112 |
+
def test_list_tasks(self):
|
| 113 |
+
tasks = list_tasks()
|
| 114 |
+
assert "basic_select" in tasks
|
| 115 |
+
assert "join_aggregate" in tasks
|
| 116 |
+
assert "complex_analysis" in tasks
|
| 117 |
+
|
| 118 |
+
def test_minimum_3_tasks(self):
|
| 119 |
+
tasks = list_tasks()
|
| 120 |
+
for difficulty, task_ids in tasks.items():
|
| 121 |
+
assert len(task_ids) >= 3, f"{difficulty} has fewer than 3 tasks"
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
|
| 125 |
+
pytest.main([__file__, "-v"])
|