rahul2124 commited on
Commit
72805b8
·
verified ·
1 Parent(s): 87e2ef6

Upload folder using huggingface_hub

Browse files
.pytest_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Created by pytest automatically.
2
+ *
.pytest_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
2
+ # This file is a cache directory tag created by pytest.
3
+ # For information about cache directory tags, see:
4
+ # https://bford.info/cachedir/spec.html
.pytest_cache/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # pytest cache directory #
2
+
3
+ This directory contains data from pytest's cache plugin,
4
+ which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
5
+
6
+ **Do not** commit this to version control.
7
+
8
+ See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
.pytest_cache/v/cache/nodeids ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "tests/test_env.py::TestAllDifficulties::test_easy",
3
+ "tests/test_env.py::TestAllDifficulties::test_hard",
4
+ "tests/test_env.py::TestAllDifficulties::test_medium",
5
+ "tests/test_env.py::TestEnvironmentBasics::test_episode_terminates",
6
+ "tests/test_env.py::TestEnvironmentBasics::test_reset_returns_observation",
7
+ "tests/test_env.py::TestEnvironmentBasics::test_state_tracking",
8
+ "tests/test_env.py::TestEnvironmentBasics::test_step_with_correct_query",
9
+ "tests/test_env.py::TestEnvironmentBasics::test_step_with_invalid_query",
10
+ "tests/test_env.py::TestGrading::test_scores_in_range",
11
+ "tests/test_env.py::TestGrading::test_varying_scores",
12
+ "tests/test_env.py::TestTaskRegistry::test_list_tasks",
13
+ "tests/test_env.py::TestTaskRegistry::test_minimum_3_tasks"
14
+ ]
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y --no-install-recommends \
6
+ build-essential \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ COPY requirements.txt .
10
+ RUN pip install --no-cache-dir -r requirements.txt
11
+
12
+ COPY . .
13
+
14
+ EXPOSE 7860
15
+
16
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=10s \
17
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
18
+
19
+ CMD ["uvicorn", "src.sql_arena.server:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
README.md CHANGED
@@ -1,10 +1,29 @@
1
- ---
2
- title: Sql Arena
3
- emoji: 🌖
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SQL Arena - OpenEnv Environment
2
+
3
+ An interactive SQL query challenge environment where AI agents learn to write SQL
4
+ by iteratively querying databases and receiving execution feedback with partial credit scoring.
5
+
6
+ ## Real-World Utility
7
+
8
+ Text-to-SQL is one of the most valuable capabilities for AI agents:
9
+ - Used by data analysts, business users, and developers daily
10
+ - Evaluates reasoning, schema understanding, and query composition
11
+ - Directly applicable to production AI assistants and copilots
12
+ - SQL Arena provides interactive iterative feedback (not just static benchmarks)
13
+
14
+ ## Tasks
15
+
16
+ | Task | Difficulty | Description | Max Steps |
17
+ |------|-----------|-------------|-----------|
18
+ | basic_select | Easy | SELECT, WHERE, ORDER BY on single table | 5 |
19
+ | join_aggregate | Medium | Multi-table JOINs, GROUP BY, HAVING | 7 |
20
+ | complex_analysis | Hard | CTEs, window functions, subqueries | 10 |
21
+
22
+ Each difficulty has 3+ unique problems with deterministic, reproducible grading.
23
+
24
+ ## Action Space
25
+
26
+ ```json
27
+ {
28
+ "sql_query": "SELECT name, salary FROM employees WHERE salary > 80000 ORDER BY salary DESC"
29
+ }
inference.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference Script - SQL Arena OpenEnv Environment
3
+ Baseline agent that uses an LLM to solve SQL challenges.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import textwrap
9
+ from typing import List, Optional
10
+
11
+ from openai import OpenAI
12
+
13
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
14
+ from src.sql_arena.environment import SQLArenaEnvironment
15
+ from src.sql_arena.models import SQLArenaAction
16
+
17
+ # =====================================================
18
+ # Configuration
19
+ # =====================================================
20
+
21
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
22
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
23
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
24
+
25
+ BENCHMARK = "sql_arena"
26
+ TEMPERATURE = 0.3
27
+ MAX_TOKENS = 500
28
+
29
+ TASKS = [
30
+ {"difficulty": "basic_select", "task_id": "easy_001", "name": "basic_select", "max_steps": 5},
31
+ {"difficulty": "join_aggregate", "task_id": "medium_001", "name": "join_aggregate", "max_steps": 7},
32
+ {"difficulty": "complex_analysis", "task_id": "hard_001", "name": "complex_analysis", "max_steps": 10},
33
+ ]
34
+
35
+
36
+ # =====================================================
37
+ # Logging (MANDATORY format)
38
+ # =====================================================
39
+
40
+ def log_start(task: str, env: str, model: str) -> None:
41
+ print(f"[START] task={task} env={env} model={model}", flush=True)
42
+
43
+
44
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
45
+ error_val = error if error else "null"
46
+ done_val = str(done).lower()
47
+ action_short = action.replace('\n', ' ').strip()[:100]
48
+ print(
49
+ f"[STEP] step={step} action={action_short} reward={reward:.2f} done={done_val} error={error_val}",
50
+ flush=True,
51
+ )
52
+
53
+
54
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
55
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
56
+ print(
57
+ f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
58
+ flush=True,
59
+ )
60
+
61
+
62
+ # =====================================================
63
+ # LLM Agent
64
+ # =====================================================
65
+
66
+ SYSTEM_PROMPT = textwrap.dedent("""
67
+ You are an expert SQL query writer. You are interacting with a SQL challenge environment.
68
+
69
+ Each turn you receive: database schema, a question, previous query results, and feedback.
70
+ Your goal: Write a SQL query that correctly answers the question.
71
+
72
+ Rules:
73
+ - Output ONLY the SQL query, nothing else
74
+ - No explanations, no markdown, no code fences
75
+ - Use standard SQLite syntax
76
+ - Be precise with column names and table names
77
+ - If your previous query had errors, fix them based on the feedback
78
+ """).strip()
79
+
80
+
81
+ def build_user_prompt(observation: dict, step: int, history: List[str]) -> str:
82
+ parts = []
83
+ parts.append(f"=== SQL Challenge (Step {step}) ===")
84
+ parts.append(f"\nDifficulty: {observation.get('difficulty', 'unknown')}")
85
+ parts.append(f"\n--- Database Schema ---\n{observation.get('schema_description', '')}")
86
+ parts.append(f"\n--- Question ---\n{observation.get('question', '')}")
87
+
88
+ if observation.get('expected_columns'):
89
+ parts.append(f"\n--- Expected Columns ---\n{observation['expected_columns']}")
90
+ if observation.get('query_result'):
91
+ parts.append(f"\n--- Previous Query Result ---\n{observation['query_result']}")
92
+ if observation.get('error_message'):
93
+ parts.append(f"\n--- Error ---\n{observation['error_message']}")
94
+ if observation.get('feedback'):
95
+ parts.append(f"\n--- Feedback ---\n{observation['feedback']}")
96
+
97
+ parts.append(f"\nAttempts remaining: {observation.get('attempts_remaining', 0)}")
98
+
99
+ if history:
100
+ parts.append("\n--- Previous Attempts ---")
101
+ for h in history[-3:]:
102
+ parts.append(h)
103
+
104
+ parts.append("\nWrite your SQL query now:")
105
+ return "\n".join(parts)
106
+
107
+
108
+ def get_sql_from_llm(client: OpenAI, observation: dict, step: int, history: List[str]) -> str:
109
+ user_prompt = build_user_prompt(observation, step, history)
110
+ try:
111
+ completion = client.chat.completions.create(
112
+ model=MODEL_NAME,
113
+ messages=[
114
+ {"role": "system", "content": SYSTEM_PROMPT},
115
+ {"role": "user", "content": user_prompt},
116
+ ],
117
+ temperature=TEMPERATURE,
118
+ max_tokens=MAX_TOKENS,
119
+ stream=False,
120
+ )
121
+ raw = (completion.choices[0].message.content or "").strip()
122
+ sql = raw
123
+ if sql.startswith("```sql"):
124
+ sql = sql[6:]
125
+ if sql.startswith("```"):
126
+ sql = sql[3:]
127
+ if sql.endswith("```"):
128
+ sql = sql[:-3]
129
+ sql = sql.strip()
130
+ return sql if sql else "SELECT 1"
131
+ except Exception as exc:
132
+ print(f"[DEBUG] LLM request failed: {exc}", flush=True)
133
+ return "SELECT 1"
134
+
135
+
136
+ # =====================================================
137
+ # Main Inference Loop
138
+ # =====================================================
139
+
140
+ def run_task(client: OpenAI, env: SQLArenaEnvironment, task_config: dict) -> float:
141
+ difficulty = task_config["difficulty"]
142
+ task_id = task_config["task_id"]
143
+ task_name = task_config["name"]
144
+ max_steps = task_config["max_steps"]
145
+
146
+ history: List[str] = []
147
+ rewards: List[float] = []
148
+ steps_taken = 0
149
+ best_score = 0.0
150
+
151
+ log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
152
+
153
+ try:
154
+ result = env.reset(difficulty=difficulty, task_id=task_id)
155
+ obs_dict = result.observation.model_dump()
156
+
157
+ for step in range(1, max_steps + 1):
158
+ if result.done:
159
+ break
160
+
161
+ sql_query = get_sql_from_llm(client, obs_dict, step, history)
162
+
163
+ action = SQLArenaAction(sql_query=sql_query)
164
+ result = env.step(action)
165
+
166
+ obs_dict = result.observation.model_dump()
167
+ reward = result.reward
168
+ done = result.done
169
+ error = obs_dict.get("error_message")
170
+
171
+ rewards.append(reward)
172
+ steps_taken = step
173
+ best_score = max(best_score, result.info.get("score", 0.0))
174
+
175
+ log_step(step=step, action=sql_query, reward=reward, done=done, error=error)
176
+
177
+ history.append(
178
+ f"Step {step}: {sql_query[:80]}... -> reward={reward:.2f}"
179
+ )
180
+
181
+ if done:
182
+ break
183
+
184
+ final_score = min(max(best_score, 0.0), 1.0)
185
+ success = final_score >= 0.5
186
+
187
+ except Exception as e:
188
+ print(f"[DEBUG] Task {task_name} error: {e}", flush=True)
189
+ final_score = 0.0
190
+ success = False
191
+
192
+ finally:
193
+ log_end(success=success, steps=steps_taken, score=final_score, rewards=rewards)
194
+
195
+ return final_score
196
+
197
+
198
+ def main() -> None:
199
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
200
+ env = SQLArenaEnvironment()
201
+
202
+ all_scores = []
203
+
204
+ for task_config in TASKS:
205
+ print(f"\n{'='*60}", flush=True)
206
+ print(f"Running task: {task_config['name']} ({task_config['difficulty']})", flush=True)
207
+ print(f"{'='*60}", flush=True)
208
+
209
+ score = run_task(client, env, task_config)
210
+ all_scores.append(score)
211
+ print(f"\nTask {task_config['name']} final score: {score:.2f}\n", flush=True)
212
+
213
+ avg_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
214
+ print(f"\n{'='*60}", flush=True)
215
+ print("SUMMARY", flush=True)
216
+ print(f"{'='*60}", flush=True)
217
+ for tc, sc in zip(TASKS, all_scores):
218
+ print(f" {tc['name']:20s}: {sc:.2f}", flush=True)
219
+ print(f" {'Average':20s}: {avg_score:.2f}", flush=True)
220
+
221
+ env.close()
222
+
223
+
224
+ if __name__ == "__main__":
225
+ main()
openenv.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sql_arena
2
+ version: "1.0.0"
3
+ description: >
4
+ Interactive SQL query challenge environment where AI agents learn to write SQL
5
+ by iteratively querying databases and receiving execution feedback with partial credit.
6
+
7
+ author: "Vudumula Naga Sai Rahul"
8
+ license: "MIT"
9
+
10
+ interface:
11
+ action:
12
+ type: object
13
+ model: sql_arena.models.SQLArenaAction
14
+ properties:
15
+ sql_query:
16
+ type: string
17
+ description: "SQL query to execute against the database"
18
+
19
+ observation:
20
+ type: object
21
+ model: sql_arena.models.SQLArenaObservation
22
+ properties:
23
+ schema_description:
24
+ type: string
25
+ question:
26
+ type: string
27
+ query_result:
28
+ type: string
29
+ nullable: true
30
+ error_message:
31
+ type: string
32
+ nullable: true
33
+ feedback:
34
+ type: string
35
+ nullable: true
36
+ expected_columns:
37
+ type: array
38
+ nullable: true
39
+ attempts_remaining:
40
+ type: integer
41
+ difficulty:
42
+ type: string
43
+ task_id:
44
+ type: string
45
+
46
+ state:
47
+ type: object
48
+ model: sql_arena.models.SQLArenaState
49
+
50
+ tasks:
51
+ - id: basic_select
52
+ name: "Basic SELECT Queries"
53
+ description: "Simple SELECT, WHERE, ORDER BY queries"
54
+ difficulty: easy
55
+ max_steps: 5
56
+ subtasks:
57
+ - easy_001
58
+ - easy_002
59
+ - easy_003
60
+
61
+ - id: join_aggregate
62
+ name: "JOIN and Aggregate Queries"
63
+ description: "Multi-table JOINs with GROUP BY, HAVING"
64
+ difficulty: medium
65
+ max_steps: 7
66
+ subtasks:
67
+ - medium_001
68
+ - medium_002
69
+ - medium_003
70
+
71
+ - id: complex_analysis
72
+ name: "Complex Analysis Queries"
73
+ description: "CTEs, window functions, subqueries"
74
+ difficulty: hard
75
+ max_steps: 10
76
+ subtasks:
77
+ - hard_001
78
+ - hard_002
79
+ - hard_003
80
+
81
+ grading:
82
+ score_range: [0.0, 1.0]
83
+ components:
84
+ - name: execution
85
+ weight: 0.10
86
+ description: "Query executes without errors"
87
+ - name: columns
88
+ weight: 0.20
89
+ description: "Correct column names"
90
+ - name: row_count
91
+ weight: 0.20
92
+ description: "Correct number of rows"
93
+ - name: values
94
+ weight: 0.50
95
+ description: "Correct data values"
96
+
97
+ server:
98
+ framework: fastapi
99
+ entrypoint: src.sql_arena.server:app
100
+ port: 7860
101
+
102
+ deployment:
103
+ platform: huggingface-spaces
104
+ docker: true
pyproject.toml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sql-arena"
7
+ version = "1.0.0"
8
+ description = "Interactive SQL query challenge OpenEnv environment"
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ requires-python = ">=3.10"
12
+ authors = [
13
+ {name = "Vudumula Naga Sai Rahul", email = "nagasairahulvudumula@gmail.com"}
14
+ ]
15
+ dependencies = [
16
+ "fastapi>=0.104.0",
17
+ "uvicorn[standard]>=0.24.0",
18
+ "pydantic>=2.5.0",
19
+ "websockets>=12.0",
20
+ "openai>=1.0.0",
21
+ ]
22
+
23
+ [project.optional-dependencies]
24
+ dev = [
25
+ "pytest>=7.0",
26
+ "httpx>=0.25.0",
27
+ ]
28
+
29
+ [tool.setuptools.packages.find]
30
+ where = ["."]
31
+ include = ["src*"]
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi>=0.104.0
2
+ uvicorn[standard]>=0.24.0
3
+ pydantic>=2.5.0
4
+ websockets>=12.0
5
+ openai>=1.0.0
6
+ pytest>=7.0
7
+ httpx>=0.25.0
src/__init__.py ADDED
File without changes
src/sql_arena/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQL Arena - Interactive SQL Query Challenge Environment for OpenEnv.
3
+ """
4
+
5
+ from .models import SQLArenaAction, SQLArenaObservation, SQLArenaState
6
+ from .environment import SQLArenaEnvironment, StepResult
7
+ from .tasks import get_task, list_tasks, SQLTask, ALL_TASKS, TASK_BY_ID
8
+ from .graders import grade_result
9
+
10
+ __all__ = [
11
+ "SQLArenaAction",
12
+ "SQLArenaObservation",
13
+ "SQLArenaState",
14
+ "SQLArenaEnvironment",
15
+ "StepResult",
16
+ "get_task",
17
+ "list_tasks",
18
+ "SQLTask",
19
+ "ALL_TASKS",
20
+ "TASK_BY_ID",
21
+ "grade_result",
22
+ ]
23
+
24
+ __version__ = "1.0.0"
src/sql_arena/database.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLite Database Manager for SQL Arena.
3
+
4
+ Creates in-memory SQLite databases for each task.
5
+ Executes agent queries safely and formats results.
6
+
7
+ Key design decisions:
8
+ - In-memory databases (fast, no disk I/O, no cleanup needed)
9
+ - Each reset() creates a fresh database
10
+ - Query execution is sandboxed (read-only would be ideal but SQLite
11
+ in-memory is ephemeral anyway)
12
+ """
13
+
14
+ import sqlite3
15
+ from typing import Tuple, Optional, Any, Dict
16
+
17
+
18
+ class DatabaseManager:
19
+ """
20
+ Manages SQLite in-memory databases for SQL challenges.
21
+
22
+ Each task gets its own fresh database with schema and sample data.
23
+ The agent's queries are executed against this database.
24
+ """
25
+
26
+ def __init__(self):
27
+ self.conn: Optional[sqlite3.Connection] = None
28
+
29
+ def create_database(self, setup_sql: str) -> None:
30
+ """
31
+ Create a new in-memory database with the given schema and data.
32
+
33
+ Args:
34
+ setup_sql: SQL string containing CREATE TABLE and INSERT statements
35
+ """
36
+ # Close any existing connection
37
+ self.close()
38
+
39
+ # Create fresh in-memory database
40
+ self.conn = sqlite3.connect(":memory:")
41
+
42
+ # Enable foreign keys
43
+ self.conn.execute("PRAGMA foreign_keys = ON")
44
+
45
+ # Run the setup SQL (creates tables and inserts data)
46
+ self.conn.executescript(setup_sql)
47
+ self.conn.commit()
48
+
49
+ def execute_query(self, sql: str) -> Tuple[bool, Optional[Dict], Optional[str]]:
50
+ """
51
+ Execute a SQL query and return results.
52
+
53
+ This is the main method called when the agent submits a query.
54
+ It catches all exceptions to prevent crashes.
55
+
56
+ Args:
57
+ sql: The SQL query string to execute
58
+
59
+ Returns:
60
+ Tuple of (success, result_dict, error_message):
61
+ - success: True if query executed without error
62
+ - result_dict: {"columns": [...], "rows": [...]} if successful
63
+ - error_message: Error string if failed, None if success
64
+ """
65
+ if not self.conn:
66
+ return False, None, "No database connection. Call create_database() first."
67
+
68
+ try:
69
+ cursor = self.conn.execute(sql)
70
+
71
+ # Get column names from cursor description
72
+ if cursor.description:
73
+ columns = [desc[0] for desc in cursor.description]
74
+ else:
75
+ columns = []
76
+
77
+ # Fetch all rows
78
+ rows = cursor.fetchall()
79
+
80
+ result = {
81
+ "columns": columns,
82
+ "rows": rows,
83
+ }
84
+
85
+ return True, result, None
86
+
87
+ except sqlite3.Error as e:
88
+ return False, None, f"SQL Error: {str(e)}"
89
+ except Exception as e:
90
+ return False, None, f"Execution Error: {str(e)}"
91
+
92
+ def format_result(self, result: Dict, max_rows: int = 20) -> str:
93
+ """
94
+ Format query result as a human-readable table string.
95
+
96
+ This formatted string is shown to the agent in the observation
97
+ so it can see what its query returned.
98
+
99
+ Args:
100
+ result: Dict with "columns" and "rows" keys
101
+ max_rows: Maximum number of rows to display
102
+
103
+ Returns:
104
+ Formatted table string
105
+ """
106
+ if not result or not result.get("columns"):
107
+ return "(empty result set)"
108
+
109
+ columns = result["columns"]
110
+ rows = result["rows"]
111
+
112
+ if not rows:
113
+ return f"Columns: {', '.join(columns)}\n(0 rows returned)"
114
+
115
+ # Calculate column widths (at least as wide as header)
116
+ col_widths = [len(str(c)) for c in columns]
117
+ for row in rows[:max_rows]:
118
+ for i, val in enumerate(row):
119
+ if i < len(col_widths):
120
+ col_widths[i] = max(col_widths[i], len(str(val)))
121
+
122
+ # Build formatted table
123
+ # Header
124
+ header = " | ".join(
125
+ str(c).ljust(w) for c, w in zip(columns, col_widths)
126
+ )
127
+ separator = "-+-".join("-" * w for w in col_widths)
128
+
129
+ # Data rows
130
+ formatted_rows = []
131
+ for row in rows[:max_rows]:
132
+ formatted_row = " | ".join(
133
+ str(v).ljust(w) for v, w in zip(row, col_widths)
134
+ )
135
+ formatted_rows.append(formatted_row)
136
+
137
+ # Assemble
138
+ table_str = f"{header}\n{separator}\n" + "\n".join(formatted_rows)
139
+
140
+ # Truncation notice
141
+ if len(rows) > max_rows:
142
+ table_str += f"\n... ({len(rows) - max_rows} more rows not shown)"
143
+
144
+ # Row count
145
+ table_str += f"\n\n({len(rows)} row{'s' if len(rows) != 1 else ''} returned)"
146
+
147
+ return table_str
148
+
149
+ def close(self) -> None:
150
+ """Close the database connection and free resources."""
151
+ if self.conn:
152
+ try:
153
+ self.conn.close()
154
+ except Exception:
155
+ pass # Ignore errors on close
156
+ self.conn = None
src/sql_arena/environment.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core SQL Arena Environment.
3
+ Implements the OpenEnv step()/reset()/state() interface.
4
+ """
5
+
6
+ from typing import Optional, Dict, Any, List
7
+ from .models import SQLArenaAction, SQLArenaObservation, SQLArenaState
8
+ from .database import DatabaseManager
9
+ from .tasks import SQLTask, get_task, list_tasks, TASK_BY_ID
10
+ from .graders import grade_result, generate_hint
11
+
12
+
13
+ class StepResult:
14
+ """Result of a single environment step."""
15
+
16
+ def __init__(
17
+ self,
18
+ observation: SQLArenaObservation,
19
+ reward: float,
20
+ done: bool,
21
+ info: Optional[Dict[str, Any]] = None,
22
+ ):
23
+ self.observation = observation
24
+ self.reward = reward
25
+ self.done = done
26
+ self.info = info or {}
27
+
28
+
29
+ class SQLArenaEnvironment:
30
+ """
31
+ SQL Arena: An interactive SQL query challenge environment.
32
+
33
+ The agent receives a database schema and a natural language question,
34
+ then iteratively writes SQL queries. The environment provides
35
+ execution results, feedback, and partial credit scoring.
36
+ """
37
+
38
+ def __init__(self):
39
+ self.db = DatabaseManager()
40
+ self.current_task: Optional[SQLTask] = None
41
+ self._state: Optional[SQLArenaState] = None
42
+ self._last_observation: Optional[SQLArenaObservation] = None
43
+
44
+ def reset(
45
+ self,
46
+ difficulty: str = "basic_select",
47
+ task_id: Optional[str] = None,
48
+ ) -> StepResult:
49
+ """
50
+ Reset the environment with a new task.
51
+
52
+ Args:
53
+ difficulty: 'basic_select', 'join_aggregate', or 'complex_analysis'
54
+ task_id: Optional specific task ID
55
+
56
+ Returns:
57
+ StepResult with initial observation
58
+ """
59
+ # Get the task
60
+ self.current_task = get_task(difficulty, task_id)
61
+ task = self.current_task
62
+
63
+ # Setup database
64
+ self.db.create_database(task.setup_sql)
65
+
66
+ # Initialize state
67
+ self._state = SQLArenaState(
68
+ task_id=task.task_id,
69
+ difficulty=task.difficulty,
70
+ current_step=0,
71
+ max_steps=task.max_steps,
72
+ best_score=0.0,
73
+ total_reward=0.0,
74
+ rewards_history=[],
75
+ done=False,
76
+ last_action_error=None,
77
+ )
78
+
79
+ # Create initial observation
80
+ self._last_observation = SQLArenaObservation(
81
+ schema_description=task.schema_description,
82
+ question=task.question,
83
+ query_result=None,
84
+ error_message=None,
85
+ feedback="Welcome to SQL Arena! Write a SQL query to answer the question above.",
86
+ expected_columns=task.expected_columns,
87
+ attempts_remaining=task.max_steps,
88
+ difficulty=task.difficulty,
89
+ task_id=task.task_id,
90
+ )
91
+
92
+ return StepResult(
93
+ observation=self._last_observation,
94
+ reward=0.0,
95
+ done=False,
96
+ info={"task_title": task.title},
97
+ )
98
+
99
+ def step(self, action: SQLArenaAction) -> StepResult:
100
+ """
101
+ Execute the agent's SQL query and return feedback.
102
+
103
+ Args:
104
+ action: SQLArenaAction containing the SQL query
105
+
106
+ Returns:
107
+ StepResult with observation, reward, and done flag
108
+ """
109
+ if self._state is None or self.current_task is None:
110
+ raise RuntimeError("Environment not initialized. Call reset() first.")
111
+
112
+ if self._state.done:
113
+ raise RuntimeError("Episode is done. Call reset() to start a new episode.")
114
+
115
+ task = self.current_task
116
+ state = self._state
117
+
118
+ # Increment step counter
119
+ state.current_step += 1
120
+
121
+ # Execute the query
122
+ success, result, error = self.db.execute_query(action.sql_query)
123
+
124
+ # Grade the result
125
+ score, feedback = grade_result(task, success, result, error)
126
+
127
+ # Track best score
128
+ state.best_score = max(state.best_score, score)
129
+
130
+ # Calculate step reward
131
+ if len(state.rewards_history) == 0:
132
+ reward = score
133
+ else:
134
+ prev_best = max(state.rewards_history) if state.rewards_history else 0.0
135
+ improvement = max(0, score - prev_best)
136
+ reward = score * 0.5 + improvement * 0.5
137
+
138
+ reward = round(min(max(reward, 0.0), 1.0), 4)
139
+ state.rewards_history.append(reward)
140
+ state.total_reward += reward
141
+
142
+ # Add progressive hints
143
+ hint = generate_hint(task, state.current_step, score)
144
+ if hint and score < 1.0:
145
+ feedback += f"\n\n{hint}"
146
+
147
+ # Check if done
148
+ attempts_remaining = task.max_steps - state.current_step
149
+ is_perfect = score >= 1.0
150
+ is_out_of_steps = attempts_remaining <= 0
151
+
152
+ state.done = is_perfect or is_out_of_steps
153
+ state.last_action_error = error
154
+
155
+ # Format query result for observation
156
+ query_result_str = None
157
+ if success and result:
158
+ query_result_str = self.db.format_result(result)
159
+
160
+ # Build observation
161
+ self._last_observation = SQLArenaObservation(
162
+ schema_description=task.schema_description,
163
+ question=task.question,
164
+ query_result=query_result_str,
165
+ error_message=error,
166
+ feedback=feedback,
167
+ expected_columns=task.expected_columns,
168
+ attempts_remaining=attempts_remaining,
169
+ difficulty=task.difficulty,
170
+ task_id=task.task_id,
171
+ )
172
+
173
+ return StepResult(
174
+ observation=self._last_observation,
175
+ reward=reward,
176
+ done=state.done,
177
+ info={
178
+ "score": score,
179
+ "best_score": state.best_score,
180
+ "step": state.current_step,
181
+ "is_perfect": is_perfect,
182
+ },
183
+ )
184
+
185
+ def state(self) -> SQLArenaState:
186
+ """Return the current environment state."""
187
+ if self._state is None:
188
+ raise RuntimeError("Environment not initialized. Call reset() first.")
189
+ return self._state
190
+
191
+ def close(self) -> None:
192
+ """Clean up resources."""
193
+ self.db.close()
194
+ self.current_task = None
195
+ self._state = None
196
+ self._last_observation = None
197
+
198
+ def get_available_tasks(self) -> Dict:
199
+ """Return all available tasks."""
200
+ return list_tasks()
src/sql_arena/graders.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grading logic for SQL Arena.
3
+ Provides partial credit scoring (0.0 to 1.0) based on:
4
+ - Query execution success (0.10)
5
+ - Column correctness (0.20)
6
+ - Row count correctness (0.20)
7
+ - Value correctness (0.50)
8
+ """
9
+
10
+ from typing import List, Tuple, Optional, Dict, Any
11
+ from .tasks import SQLTask
12
+
13
+
14
+ def normalize_value(val: Any) -> Any:
15
+ """Normalize values for comparison."""
16
+ if val is None:
17
+ return None
18
+ if isinstance(val, float):
19
+ return round(val, 2)
20
+ if isinstance(val, str):
21
+ return val.strip().lower()
22
+ return val
23
+
24
+
25
+ def normalize_row(row: tuple) -> tuple:
26
+ """Normalize all values in a row."""
27
+ return tuple(normalize_value(v) for v in row)
28
+
29
+
30
+ def grade_result(
31
+ task: SQLTask,
32
+ success: bool,
33
+ result: Optional[Dict],
34
+ error: Optional[str],
35
+ ) -> Tuple[float, str]:
36
+ """
37
+ Grade a SQL query result against expected output.
38
+
39
+ Returns:
40
+ (score, feedback) where score is in [0.0, 1.0]
41
+
42
+ Scoring breakdown:
43
+ - 0.10: Query executes without error
44
+ - 0.20: Correct column names
45
+ - 0.20: Correct number of rows
46
+ - 0.50: Correct values (proportional to matching rows)
47
+ """
48
+ feedback_parts = []
49
+ score = 0.0
50
+
51
+ # ---- Component 1: Execution Success (0.10) ----
52
+ if not success:
53
+ feedback_parts.append(f"X Query failed: {error}")
54
+ feedback_parts.append("Hint: Fix the syntax error and try again.")
55
+ return 0.0, "\n".join(feedback_parts)
56
+
57
+ score += 0.10
58
+ feedback_parts.append("OK: Query executed successfully (+0.10)")
59
+
60
+ # ---- Component 2: Column Correctness (0.20) ----
61
+ actual_columns = [c.lower().strip() for c in result.get("columns", [])]
62
+ expected_columns = [c.lower().strip() for c in task.expected_columns]
63
+
64
+ if actual_columns == expected_columns:
65
+ score += 0.20
66
+ feedback_parts.append(f"OK: Correct columns: {actual_columns} (+0.20)")
67
+ else:
68
+ # Partial credit for overlapping columns
69
+ matching_cols = set(actual_columns) & set(expected_columns)
70
+ if matching_cols:
71
+ partial = 0.20 * (len(matching_cols) / len(expected_columns))
72
+ score += partial
73
+ feedback_parts.append(
74
+ f"PARTIAL: Column match: got {actual_columns}, "
75
+ f"expected {expected_columns} (+{partial:.2f})"
76
+ )
77
+ missing = set(expected_columns) - set(actual_columns)
78
+ if missing:
79
+ feedback_parts.append(f"Hint: Missing columns: {missing}")
80
+ else:
81
+ feedback_parts.append(
82
+ f"WRONG: Columns: got {actual_columns}, expected {expected_columns}"
83
+ )
84
+
85
+ # ---- Component 3: Row Count (0.20) ----
86
+ actual_rows = result.get("rows", [])
87
+ expected_row_count = task.expected_row_count
88
+
89
+ if len(actual_rows) == expected_row_count:
90
+ score += 0.20
91
+ feedback_parts.append(f"OK: Correct row count: {len(actual_rows)} (+0.20)")
92
+ else:
93
+ # Partial credit: closer counts get more credit
94
+ if expected_row_count > 0:
95
+ ratio = 1.0 - abs(len(actual_rows) - expected_row_count) / max(
96
+ expected_row_count, len(actual_rows)
97
+ )
98
+ partial = max(0.0, 0.20 * ratio)
99
+ score += partial
100
+ feedback_parts.append(
101
+ f"PARTIAL: Row count: got {len(actual_rows)}, "
102
+ f"expected {expected_row_count} (+{partial:.2f})"
103
+ )
104
+ else:
105
+ if len(actual_rows) == 0:
106
+ score += 0.20
107
+ feedback_parts.append("OK: Correct empty result set (+0.20)")
108
+ else:
109
+ feedback_parts.append(
110
+ f"WRONG: Expected empty result, got {len(actual_rows)} rows"
111
+ )
112
+
113
+ # ---- Component 4: Value Correctness (0.50) ----
114
+ if task.expected_rows:
115
+ normalized_expected = [normalize_row(r) for r in task.expected_rows]
116
+ normalized_actual = [normalize_row(r) for r in actual_rows]
117
+
118
+ # Try exact order match first
119
+ exact_matches = 0
120
+ for exp_row, act_row in zip(normalized_expected, normalized_actual):
121
+ if exp_row == act_row:
122
+ exact_matches += 1
123
+
124
+ if (
125
+ exact_matches == len(normalized_expected)
126
+ and len(normalized_actual) == len(normalized_expected)
127
+ ):
128
+ score += 0.50
129
+ feedback_parts.append("OK: All values correct with correct ordering (+0.50)")
130
+ else:
131
+ # Try unordered match (set-based)
132
+ matched_rows = 0
133
+ remaining_actual = list(normalized_actual)
134
+
135
+ for exp_row in normalized_expected:
136
+ for i, act_row in enumerate(remaining_actual):
137
+ if exp_row == act_row:
138
+ matched_rows += 1
139
+ remaining_actual.pop(i)
140
+ break
141
+
142
+ if (
143
+ matched_rows == len(normalized_expected)
144
+ and len(normalized_actual) == len(normalized_expected)
145
+ ):
146
+ # All rows match but wrong order
147
+ partial = 0.40
148
+ score += partial
149
+ feedback_parts.append(
150
+ f"PARTIAL: All values correct but wrong ordering (+{partial:.2f})"
151
+ )
152
+ feedback_parts.append("Hint: Check your ORDER BY clause")
153
+ elif matched_rows > 0:
154
+ # Some rows match
155
+ partial = 0.50 * (matched_rows / len(normalized_expected))
156
+ score += partial
157
+ feedback_parts.append(
158
+ f"PARTIAL: {matched_rows}/{len(normalized_expected)} rows match (+{partial:.2f})"
159
+ )
160
+ if matched_rows < len(normalized_expected):
161
+ feedback_parts.append(
162
+ "Hint: Some values are incorrect. Check WHERE/JOIN conditions."
163
+ )
164
+ else:
165
+ feedback_parts.append("WRONG: No matching rows found")
166
+ feedback_parts.append(
167
+ "Hint: Review your query logic - values don't match expected output."
168
+ )
169
+
170
+ # Tiny credit if some values appear somewhere
171
+ all_expected_vals = set()
172
+ for row in normalized_expected:
173
+ all_expected_vals.update(row)
174
+ all_actual_vals = set()
175
+ for row in normalized_actual:
176
+ all_actual_vals.update(row)
177
+
178
+ overlap = all_expected_vals & all_actual_vals
179
+ if overlap:
180
+ tiny_credit = 0.05
181
+ score += tiny_credit
182
+ feedback_parts.append(
183
+ f" (Some expected values found in output: +{tiny_credit:.2f})"
184
+ )
185
+ else:
186
+ # Expected empty result
187
+ if len(actual_rows) == 0:
188
+ score += 0.50
189
+ feedback_parts.append("OK: Correctly returned empty result (+0.50)")
190
+ else:
191
+ feedback_parts.append(
192
+ f"WRONG: Expected empty result, got {len(actual_rows)} rows"
193
+ )
194
+
195
+ # ---- Final score ----
196
+ score = round(min(max(score, 0.0), 1.0), 4)
197
+ feedback_parts.append(f"\nTotal Score: {score:.2f}/1.00")
198
+
199
+ return score, "\n".join(feedback_parts)
200
+
201
+
202
def generate_hint(task: SQLTask, step: int, current_score: float) -> Optional[str]:
    """Return a progressive hint for the current step, or None.

    Hints are suppressed once the agent is already scoring well
    (current_score >= 0.8).  Task-specific hints are served first, one
    per step; once exhausted, generic fallback hints about the expected
    result shape are used, clamping at the last one.
    """
    # Agent is close enough to the answer — stay quiet.
    if current_score >= 0.8:
        return None

    curated = task.hints
    if step <= len(curated):
        return f"Hint {step}: {curated[step - 1]}"

    # Fallbacks once the curated hints run out.
    fallbacks = (
        f"Expected columns are: {task.expected_columns}",
        f"Expected {task.expected_row_count} rows in the result",
        "Check the schema description carefully for table and column names",
    )
    overflow = step - len(curated) - 1
    if overflow < 0:
        return None
    return fallbacks[min(overflow, len(fallbacks) - 1)]
src/sql_arena/models.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Typed Pydantic models for SQL Arena OpenEnv environment.
3
+
4
+ These models define the contract between the agent and environment:
5
+ - SQLArenaAction: What the agent sends (a SQL query)
6
+ - SQLArenaObservation: What the agent receives (schema, results, feedback)
7
+ - SQLArenaState: Internal environment state tracking
8
+ """
9
+
10
+ from pydantic import BaseModel, Field
11
+ from typing import Optional, List
12
+
13
+
14
class SQLArenaAction(BaseModel):
    """
    Action model — the agent submits a SQL query.

    This is what the agent sends to the environment each step.
    The environment will execute this query against the SQLite database
    and return results + feedback.
    """
    # Single-field action: the raw SQL text to run; the examples are
    # surfaced in the generated OpenAPI schema.
    sql_query: str = Field(
        ...,
        description="SQL query to execute against the database",
        examples=[
            "SELECT name, salary FROM employees WHERE salary > 50000",
            "SELECT department, COUNT(*) FROM employees GROUP BY department",
        ]
    )
30
+
31
+
32
class SQLArenaObservation(BaseModel):
    """
    Observation model — what the agent sees after each step.

    Contains the database schema, the question to answer,
    results from the last query, error messages, and feedback
    with partial credit information.
    """
    # Always present — the static framing of the current task.
    schema_description: str = Field(
        ...,
        description="Human-readable database schema (CREATE TABLE statements)"
    )
    question: str = Field(
        ...,
        description="Natural language question the agent must answer with SQL"
    )
    difficulty: str = Field(
        ...,
        description="Task difficulty level: basic_select, join_aggregate, or complex_analysis"
    )
    task_id: str = Field(
        ...,
        description="Unique identifier for this specific problem"
    )
    attempts_remaining: int = Field(
        ...,
        description="Number of query attempts the agent has left"
    )

    # Present after step() calls — None right after reset().
    query_result: Optional[str] = Field(
        None,
        description="Formatted result table from the last executed query"
    )
    error_message: Optional[str] = Field(
        None,
        description="SQL error message if the query failed to execute"
    )
    feedback: Optional[str] = Field(
        None,
        description="Detailed feedback on query correctness with partial credit breakdown"
    )

    # Hints to help the agent narrow down the expected output shape.
    expected_columns: Optional[List[str]] = Field(
        None,
        description="Expected column names in the correct result (hint)"
    )
81
+
82
+
83
class SQLArenaState(BaseModel):
    """
    Internal state model — tracks the episode progress.

    This is returned by the state() endpoint and contains
    all information about the current episode.
    """
    # Episode bookkeeping; every optional field defaults to its
    # "fresh episode" value so the model can be built before any step.
    task_id: str = Field(..., description="Current task identifier")
    difficulty: str = Field(..., description="Current difficulty level")
    current_step: int = Field(0, description="Number of steps taken so far")
    max_steps: int = Field(5, description="Maximum steps allowed for this task")
    best_score: float = Field(0.0, description="Best score achieved so far in this episode")
    total_reward: float = Field(0.0, description="Sum of all rewards received")
    rewards_history: List[float] = Field(
        default_factory=list,
        description="List of rewards received at each step"
    )
    done: bool = Field(False, description="Whether the episode has ended")
    last_action_error: Optional[str] = Field(
        None,
        description="Error from the last action, if any"
    )
src/sql_arena/server.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI server for SQL Arena - OpenEnv compatible.
3
+ Exposes /reset, /step, /state endpoints via HTTP and WebSocket.
4
+ """
5
+
6
+ import json
7
+ import uuid
8
+ import asyncio
9
+ from typing import Dict, Optional
10
+ from contextlib import asynccontextmanager
11
+
12
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from pydantic import BaseModel
15
+
16
+ from .environment import SQLArenaEnvironment, StepResult
17
+ from .models import SQLArenaAction, SQLArenaObservation, SQLArenaState
18
+
19
+
20
+ # =====================================================
21
+ # Request / Response Models
22
+ # =====================================================
23
+
24
class ResetRequest(BaseModel):
    """Body for POST /reset — selects the episode to start."""
    # Difficulty bucket to draw the task from.
    difficulty: str = "basic_select"
    # Specific task to load; when None, selection is left to the environment.
    task_id: Optional[str] = None
27
+
28
+
29
class StepRequest(BaseModel):
    """Body for POST /step — one SQL query attempt."""
    sql_query: str
31
+
32
+
33
class ResetResponse(BaseModel):
    """Standard OpenEnv (observation, reward, done, info) tuple for POST /reset."""
    observation: SQLArenaObservation
    reward: float
    done: bool
    # Mutable default is safe here: pydantic copies field defaults per instance.
    info: dict = {}
38
+
39
+
40
class StepResponse(BaseModel):
    """Standard OpenEnv (observation, reward, done, info) tuple for POST /step."""
    observation: SQLArenaObservation
    reward: float
    done: bool
    # Mutable default is safe here: pydantic copies field defaults per instance.
    info: dict = {}
45
+
46
+
47
class StateResponse(BaseModel):
    """Wrapper for GET /state."""
    state: SQLArenaState
49
+
50
+
51
class TaskListResponse(BaseModel):
    """Wrapper for GET /tasks — mapping of available tasks."""
    tasks: Dict
53
+
54
+
55
+ # =====================================================
56
+ # Session Manager
57
+ # =====================================================
58
+
59
+ class SessionManager:
60
+ """Manages multiple concurrent environment instances."""
61
+
62
+ def __init__(self, max_sessions: int = 100):
63
+ self.sessions: Dict[str, SQLArenaEnvironment] = {}
64
+ self.max_sessions = max_sessions
65
+ self._lock = asyncio.Lock()
66
+
67
+ async def create_session(self):
68
+ async with self._lock:
69
+ if len(self.sessions) >= self.max_sessions:
70
+ oldest_key = next(iter(self.sessions))
71
+ self.sessions[oldest_key].close()
72
+ del self.sessions[oldest_key]
73
+ session_id = str(uuid.uuid4())
74
+ env = SQLArenaEnvironment()
75
+ self.sessions[session_id] = env
76
+ return session_id, env
77
+
78
+ async def get_session(self, session_id: str):
79
+ return self.sessions.get(session_id)
80
+
81
+ async def remove_session(self, session_id: str):
82
+ async with self._lock:
83
+ if session_id in self.sessions:
84
+ self.sessions[session_id].close()
85
+ del self.sessions[session_id]
86
+
87
+ async def cleanup_all(self):
88
+ async with self._lock:
89
+ for env in self.sessions.values():
90
+ env.close()
91
+ self.sessions.clear()
92
+
93
+
94
# =====================================================
# App Setup
# =====================================================

# Pool of per-WebSocket environments (one per connected client).
session_manager = SessionManager()
# Single shared environment backing the plain HTTP endpoints.
# NOTE(review): concurrent HTTP callers share this instance and would
# interleave episodes — fine for a single agent; confirm for multi-agent use.
_default_env = SQLArenaEnvironment()


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Nothing to do at startup; on shutdown release every environment.
    yield
    await session_manager.cleanup_all()
    _default_env.close()


app = FastAPI(
    title="SQL Arena - OpenEnv Environment",
    description="Interactive SQL query challenge environment for AI agents",
    version="1.0.0",
    lifespan=lifespan,
)

# Wide-open CORS so browser-based clients can reach the environment directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
123
+
124
+
125
+ # =====================================================
126
+ # HTTP Endpoints
127
+ # =====================================================
128
+
129
@app.get("/")
async def root():
    """Service banner: name, version, and the available endpoints."""
    return {
        "name": "SQL Arena",
        "version": "1.0.0",
        "description": "Interactive SQL query challenge environment",
        "endpoints": ["/reset", "/step", "/state", "/tasks", "/ws"],
    }
137
+
138
+
139
@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "healthy"}
142
+
143
+
144
@app.post("/reset", response_model=ResetResponse)
async def reset(request: ResetRequest = ResetRequest()):
    """Start a new episode on the shared HTTP environment.

    The body is optional; the default request selects the
    'basic_select' difficulty.  Any exception raised by the
    environment surfaces as HTTP 400 with the error text.
    """
    # NOTE(review): the default ResetRequest() instance is created once at
    # import time and shared across calls; it is only read here, so safe.
    try:
        result = _default_env.reset(
            difficulty=request.difficulty,
            task_id=request.task_id,
        )
        return ResetResponse(
            observation=result.observation,
            reward=result.reward,
            done=result.done,
            info=result.info,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
159
+
160
+
161
@app.post("/step", response_model=StepResponse)
async def step(request: StepRequest):
    """Execute one SQL query attempt against the current episode.

    Environment exceptions surface as HTTP 400 with the error text.
    """
    try:
        action = SQLArenaAction(sql_query=request.sql_query)
        result = _default_env.step(action)
        return StepResponse(
            observation=result.observation,
            reward=result.reward,
            done=result.done,
            info=result.info,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
174
+
175
+
176
@app.get("/state", response_model=StateResponse)
async def state():
    """Return the shared environment's internal episode state."""
    try:
        return StateResponse(state=_default_env.state())
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
182
+
183
+
184
@app.get("/tasks", response_model=TaskListResponse)
async def tasks():
    """List every task the environment can serve."""
    return TaskListResponse(tasks=_default_env.get_available_tasks())
187
+
188
+
189
+ # =====================================================
190
+ # WebSocket Endpoint
191
+ # =====================================================
192
+
193
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """RPC-style protocol over a WebSocket, one fresh environment per connection.

    Messages are JSON objects of the form
    {"id": ..., "method": "reset"|"step"|"state"|"close", "params": {...}};
    replies echo the id and carry either "result" or "error".  Per-method
    failures are reported back on the socket; invalid JSON (json.loads
    failure) escapes the inner try and tears the connection down, after
    which the session is removed in the finally block.
    """
    await websocket.accept()
    # Dedicated environment for this client, evicting the oldest if full.
    session_id, env = await session_manager.create_session()

    try:
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)

            method = message.get("method", "")
            params = message.get("params", {})
            msg_id = message.get("id", None)

            try:
                if method == "reset":
                    result = env.reset(
                        difficulty=params.get("difficulty", "basic_select"),
                        task_id=params.get("task_id"),
                    )
                    response = {
                        "id": msg_id,
                        "result": {
                            "observation": result.observation.model_dump(),
                            "reward": result.reward,
                            "done": result.done,
                            "info": result.info,
                        },
                    }
                elif method == "step":
                    action = SQLArenaAction(sql_query=params.get("sql_query", ""))
                    result = env.step(action)
                    response = {
                        "id": msg_id,
                        "result": {
                            "observation": result.observation.model_dump(),
                            "reward": result.reward,
                            "done": result.done,
                            "info": result.info,
                        },
                    }
                elif method == "state":
                    env_state = env.state()
                    response = {
                        "id": msg_id,
                        "result": {"state": env_state.model_dump()},
                    }
                elif method == "close":
                    # Acknowledge, then leave the receive loop; cleanup
                    # happens in the finally block below.
                    response = {"id": msg_id, "result": {"status": "closed"}}
                    await websocket.send_text(json.dumps(response))
                    break
                else:
                    response = {"id": msg_id, "error": f"Unknown method: {method}"}

                await websocket.send_text(json.dumps(response))

            except Exception as e:
                # Method-level failure: report it but keep the connection open.
                error_response = {"id": msg_id, "error": str(e)}
                await websocket.send_text(json.dumps(error_response))

    except WebSocketDisconnect:
        pass
    finally:
        # Always release this client's environment, however the loop ended.
        await session_manager.remove_session(session_id)
257
+
258
+
259
# =====================================================
# Entry point
# =====================================================

if __name__ == "__main__":
    import uvicorn
    # Port 7860 — presumably chosen as the Hugging Face Spaces default;
    # confirm before changing.
    uvicorn.run(app, host="0.0.0.0", port=7860)
src/sql_arena/tasks.py ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Task Bank for SQL Arena.
3
+
4
+ Contains 9 SQL challenges across 3 difficulty levels:
5
+ - basic_select (Easy): 3 tasks — simple SELECT/WHERE/ORDER BY
6
+ - join_aggregate (Medium): 3 tasks — JOINs, GROUP BY, HAVING
7
+ - complex_analysis (Hard): 3 tasks — CTEs, window functions, subqueries
8
+
9
+ Each task defines:
10
+ - Database schema and sample data (setup_sql)
11
+ - Natural language question
12
+ - Expected SQL solution
13
+ - Expected result for grading
14
+ - Progressive hints
15
+ """
16
+
17
+ from dataclasses import dataclass, field
18
+ from typing import List, Dict, Optional
19
+ import random
20
+
21
+
22
@dataclass
class SQLTask:
    """A single SQL challenge problem.

    Bundles everything needed to run and grade one challenge: the
    database to build (setup_sql), the question shown to the agent,
    a reference solution, and the expected result used for scoring.
    """
    task_id: str  # Unique identifier, e.g. "easy_001"
    difficulty: str  # basic_select, join_aggregate, complex_analysis
    title: str  # Short human-readable name
    setup_sql: str  # CREATE TABLE + INSERT statements
    question: str  # Natural language question
    expected_sql: str  # Reference solution
    expected_columns: List[str]  # Expected column names in result
    expected_row_count: int  # Expected number of result rows
    expected_rows: List[tuple]  # Expected result rows for grading
    hints: List[str] = field(default_factory=list)  # Progressive task-specific hints
    max_steps: int = 5  # Query attempts allowed per episode
    schema_description: str = ""  # Human-readable schema description
37
+
38
+
39
+ # =============================================================
40
+ # DATABASE SCHEMAS
41
+ # =============================================================
42
+
43
+ # Schema 1: Employee database (used by Easy tasks)
44
+ EMPLOYEES_SCHEMA = """
45
+ CREATE TABLE employees (
46
+ id INTEGER PRIMARY KEY,
47
+ name TEXT NOT NULL,
48
+ department TEXT NOT NULL,
49
+ salary REAL NOT NULL,
50
+ hire_date TEXT NOT NULL,
51
+ is_active INTEGER DEFAULT 1
52
+ );
53
+
54
+ INSERT INTO employees VALUES (1, 'Alice Johnson', 'Engineering', 95000, '2020-01-15', 1);
55
+ INSERT INTO employees VALUES (2, 'Bob Smith', 'Marketing', 65000, '2019-06-01', 1);
56
+ INSERT INTO employees VALUES (3, 'Carol Williams', 'Engineering', 110000, '2018-03-20', 1);
57
+ INSERT INTO employees VALUES (4, 'David Brown', 'Sales', 72000, '2021-09-10', 1);
58
+ INSERT INTO employees VALUES (5, 'Eve Davis', 'Engineering', 88000, '2022-02-28', 1);
59
+ INSERT INTO employees VALUES (6, 'Frank Miller', 'Marketing', 58000, '2020-11-15', 0);
60
+ INSERT INTO employees VALUES (7, 'Grace Wilson', 'Sales', 81000, '2019-04-22', 1);
61
+ INSERT INTO employees VALUES (8, 'Henry Taylor', 'Engineering', 125000, '2017-08-01', 1);
62
+ INSERT INTO employees VALUES (9, 'Ivy Anderson', 'HR', 70000, '2021-01-10', 1);
63
+ INSERT INTO employees VALUES (10, 'Jack Thomas', 'HR', 75000, '2020-07-15', 1);
64
+ """
65
+
66
+ EMPLOYEES_SCHEMA_DESC = """Table: employees
67
+ Columns:
68
+ - id: INTEGER PRIMARY KEY (auto-increment identifier)
69
+ - name: TEXT (employee full name, e.g. 'Alice Johnson')
70
+ - department: TEXT (one of: Engineering, Marketing, Sales, HR)
71
+ - salary: REAL (annual salary in USD, e.g. 95000.0)
72
+ - hire_date: TEXT (date in YYYY-MM-DD format, e.g. '2020-01-15')
73
+ - is_active: INTEGER (1 = currently active, 0 = inactive/left)
74
+
75
+ Data: 10 employees across 4 departments.
76
+ - 4 in Engineering, 2 in Marketing (1 inactive), 2 in Sales, 2 in HR
77
+ - Salaries range from 58,000 to 125,000
78
+ - Hire dates range from 2017 to 2022
79
+ """
80
+
81
+
82
+ # Schema 2: E-commerce database (used by Medium and Hard tasks)
83
+ ECOMMERCE_SCHEMA = """
84
+ CREATE TABLE customers (
85
+ id INTEGER PRIMARY KEY,
86
+ name TEXT NOT NULL,
87
+ email TEXT NOT NULL,
88
+ city TEXT NOT NULL,
89
+ signup_date TEXT NOT NULL
90
+ );
91
+
92
+ CREATE TABLE products (
93
+ id INTEGER PRIMARY KEY,
94
+ name TEXT NOT NULL,
95
+ category TEXT NOT NULL,
96
+ price REAL NOT NULL,
97
+ stock INTEGER NOT NULL
98
+ );
99
+
100
+ CREATE TABLE orders (
101
+ id INTEGER PRIMARY KEY,
102
+ customer_id INTEGER NOT NULL,
103
+ order_date TEXT NOT NULL,
104
+ status TEXT NOT NULL,
105
+ FOREIGN KEY (customer_id) REFERENCES customers(id)
106
+ );
107
+
108
+ CREATE TABLE order_items (
109
+ id INTEGER PRIMARY KEY,
110
+ order_id INTEGER NOT NULL,
111
+ product_id INTEGER NOT NULL,
112
+ quantity INTEGER NOT NULL,
113
+ unit_price REAL NOT NULL,
114
+ FOREIGN KEY (order_id) REFERENCES orders(id),
115
+ FOREIGN KEY (product_id) REFERENCES products(id)
116
+ );
117
+
118
+ -- Customers
119
+ INSERT INTO customers VALUES (1, 'Alice', 'alice@email.com', 'New York', '2023-01-15');
120
+ INSERT INTO customers VALUES (2, 'Bob', 'bob@email.com', 'Los Angeles', '2023-02-20');
121
+ INSERT INTO customers VALUES (3, 'Carol', 'carol@email.com', 'Chicago', '2023-03-10');
122
+ INSERT INTO customers VALUES (4, 'David', 'david@email.com', 'New York', '2023-04-05');
123
+ INSERT INTO customers VALUES (5, 'Eve', 'eve@email.com', 'Boston', '2023-05-12');
124
+
125
+ -- Products
126
+ INSERT INTO products VALUES (1, 'Laptop', 'Electronics', 999.99, 50);
127
+ INSERT INTO products VALUES (2, 'Headphones', 'Electronics', 149.99, 200);
128
+ INSERT INTO products VALUES (3, 'Python Book', 'Books', 39.99, 100);
129
+ INSERT INTO products VALUES (4, 'Desk Lamp', 'Home', 29.99, 150);
130
+ INSERT INTO products VALUES (5, 'Keyboard', 'Electronics', 79.99, 120);
131
+ INSERT INTO products VALUES (6, 'SQL Book', 'Books', 44.99, 80);
132
+
133
+ -- Orders (10 orders, various statuses)
134
+ INSERT INTO orders VALUES (1, 1, '2023-06-01', 'completed');
135
+ INSERT INTO orders VALUES (2, 1, '2023-07-15', 'completed');
136
+ INSERT INTO orders VALUES (3, 2, '2023-06-20', 'completed');
137
+ INSERT INTO orders VALUES (4, 3, '2023-08-01', 'completed');
138
+ INSERT INTO orders VALUES (5, 3, '2023-08-15', 'completed');
139
+ INSERT INTO orders VALUES (6, 3, '2023-09-01', 'completed');
140
+ INSERT INTO orders VALUES (7, 4, '2023-07-10', 'cancelled');
141
+ INSERT INTO orders VALUES (8, 5, '2023-09-20', 'completed');
142
+ INSERT INTO orders VALUES (9, 1, '2023-10-01', 'completed');
143
+ INSERT INTO orders VALUES (10, 2, '2023-10-15', 'pending');
144
+
145
+ -- Order Items (17 line items)
146
+ INSERT INTO order_items VALUES (1, 1, 1, 1, 999.99);
147
+ INSERT INTO order_items VALUES (2, 1, 2, 2, 149.99);
148
+ INSERT INTO order_items VALUES (3, 2, 3, 1, 39.99);
149
+ INSERT INTO order_items VALUES (4, 2, 5, 1, 79.99);
150
+ INSERT INTO order_items VALUES (5, 3, 1, 1, 999.99);
151
+ INSERT INTO order_items VALUES (6, 3, 4, 3, 29.99);
152
+ INSERT INTO order_items VALUES (7, 4, 2, 1, 149.99);
153
+ INSERT INTO order_items VALUES (8, 4, 6, 2, 44.99);
154
+ INSERT INTO order_items VALUES (9, 5, 3, 1, 39.99);
155
+ INSERT INTO order_items VALUES (10, 5, 5, 2, 79.99);
156
+ INSERT INTO order_items VALUES (11, 6, 1, 1, 999.99);
157
+ INSERT INTO order_items VALUES (12, 6, 2, 1, 149.99);
158
+ INSERT INTO order_items VALUES (13, 8, 6, 1, 44.99);
159
+ INSERT INTO order_items VALUES (14, 8, 4, 1, 29.99);
160
+ INSERT INTO order_items VALUES (15, 9, 2, 3, 149.99);
161
+ INSERT INTO order_items VALUES (16, 9, 3, 2, 39.99);
162
+ INSERT INTO order_items VALUES (17, 10, 1, 1, 999.99);
163
+ """
164
+
165
+ ECOMMERCE_SCHEMA_DESC = """Tables:
166
+
167
+ 1. customers (5 rows)
168
+ - id: INTEGER PRIMARY KEY
169
+ - name: TEXT (customer first name)
170
+ - email: TEXT
171
+ - city: TEXT (New York, Los Angeles, Chicago, Boston)
172
+ - signup_date: TEXT (YYYY-MM-DD)
173
+
174
+ 2. products (6 rows)
175
+ - id: INTEGER PRIMARY KEY
176
+ - name: TEXT (product name)
177
+ - category: TEXT (Electronics, Books, Home)
178
+ - price: REAL (unit price in USD)
179
+ - stock: INTEGER (units in stock)
180
+
181
+ 3. orders (10 rows)
182
+ - id: INTEGER PRIMARY KEY
183
+ - customer_id: INTEGER → customers.id
184
+ - order_date: TEXT (YYYY-MM-DD, range: 2023-06 to 2023-10)
185
+ - status: TEXT (completed, cancelled, pending)
186
+
187
+ 4. order_items (17 rows)
188
+ - id: INTEGER PRIMARY KEY
189
+ - order_id: INTEGER → orders.id
190
+ - product_id: INTEGER → products.id
191
+ - quantity: INTEGER
192
+ - unit_price: REAL (price at time of order)
193
+
194
+ Relationships:
195
+ orders.customer_id → customers.id
196
+ order_items.order_id → orders.id
197
+ order_items.product_id → products.id
198
+ """
199
+
200
+
201
+ # =============================================================
202
+ # EASY TASKS: basic_select (3 tasks)
203
+ # =============================================================
204
+
205
# Easy tier: single-table SELECT / WHERE / ORDER BY problems over the
# employees schema.  Expected results below are derived directly from the
# INSERT data in EMPLOYEES_SCHEMA.
EASY_TASKS = [
    SQLTask(
        task_id="easy_001",
        difficulty="basic_select",
        title="High Salary Employees",
        setup_sql=EMPLOYEES_SCHEMA,
        # Fixed: "$80,000" was written "\$80,000" — a leaked markdown escape
        # that showed up verbatim in the question text.
        question="Find the names and salaries of all ACTIVE employees who earn more than $80,000. Order the results by salary from highest to lowest.",
        expected_sql="SELECT name, salary FROM employees WHERE is_active = 1 AND salary > 80000 ORDER BY salary DESC",
        expected_columns=["name", "salary"],
        # Fixed grading key: Grace Wilson (Sales, 81000, active) also
        # satisfies salary > 80000, so the reference query returns 5 rows.
        expected_row_count=5,
        expected_rows=[
            ("Henry Taylor", 125000.0),
            ("Carol Williams", 110000.0),
            ("Alice Johnson", 95000.0),
            ("Eve Davis", 88000.0),
            ("Grace Wilson", 81000.0),
        ],
        hints=[
            "Use SELECT with specific column names, not SELECT *",
            "Use WHERE with AND to combine conditions: is_active = 1 AND salary > 80000",
            "Add ORDER BY salary DESC for descending order",
        ],
        schema_description=EMPLOYEES_SCHEMA_DESC,
        max_steps=5,
    ),

    SQLTask(
        task_id="easy_002",
        difficulty="basic_select",
        title="Department Employee Count",
        setup_sql=EMPLOYEES_SCHEMA,
        question="Count the number of ACTIVE employees in each department. Show the department name and the count. Order by count from highest to lowest.",
        expected_sql="SELECT department, COUNT(*) as employee_count FROM employees WHERE is_active = 1 GROUP BY department ORDER BY employee_count DESC",
        expected_columns=["department", "employee_count"],
        expected_row_count=4,
        # NOTE(review): HR and Sales tie at 2, so row order between them is
        # not fixed by the reference ORDER BY; the grader's unordered
        # fallback is expected to absorb this.
        expected_rows=[
            ("Engineering", 4),
            ("HR", 2),
            ("Sales", 2),
            ("Marketing", 1),
        ],
        hints=[
            "Use COUNT(*) to count rows in each group",
            "GROUP BY department groups rows by department",
            "Use an alias: COUNT(*) as employee_count",
        ],
        schema_description=EMPLOYEES_SCHEMA_DESC,
        max_steps=5,
    ),

    SQLTask(
        task_id="easy_003",
        difficulty="basic_select",
        title="Recent Hires",
        setup_sql=EMPLOYEES_SCHEMA,
        question="List the names and hire dates of employees hired on or after January 1, 2021. Order by hire date from earliest to latest.",
        expected_sql="SELECT name, hire_date FROM employees WHERE hire_date >= '2021-01-01' ORDER BY hire_date",
        expected_columns=["name", "hire_date"],
        expected_row_count=3,
        expected_rows=[
            ("Ivy Anderson", "2021-01-10"),
            ("David Brown", "2021-09-10"),
            ("Eve Davis", "2022-02-28"),
        ],
        hints=[
            "Dates in SQLite can be compared as strings when in YYYY-MM-DD format",
            "Use WHERE hire_date >= '2021-01-01'",
            "ORDER BY hire_date gives ascending order by default",
        ],
        schema_description=EMPLOYEES_SCHEMA_DESC,
        max_steps=5,
    ),
]
277
+
278
+
279
+ # =============================================================
280
+ # MEDIUM TASKS: join_aggregate (3 tasks)
281
+ # =============================================================
282
+
283
# Medium tier: multi-table JOIN / GROUP BY / HAVING problems over the
# e-commerce schema.  Expected results below were recomputed from the
# INSERT data in ECOMMERCE_SCHEMA (completed orders only); as a sanity
# check, the per-customer totals and per-category totals both sum to
# 4704.76.
MEDIUM_TASKS = [
    SQLTask(
        task_id="medium_001",
        difficulty="join_aggregate",
        title="Customer Total Spending",
        setup_sql=ECOMMERCE_SCHEMA,
        # Fixed: "$200" was written "\$200" — a leaked markdown escape.
        question="Find the total amount spent by each customer on COMPLETED orders only. Show the customer name and their total spending. Only include customers who spent more than $200. Order by total spending from highest to lowest.",
        expected_sql="""
            SELECT c.name, ROUND(SUM(oi.quantity * oi.unit_price), 2) as total_spent
            FROM customers c
            JOIN orders o ON c.id = o.customer_id
            JOIN order_items oi ON o.id = oi.order_id
            WHERE o.status = 'completed'
            GROUP BY c.id, c.name
            HAVING SUM(oi.quantity * oi.unit_price) > 200
            ORDER BY total_spent DESC
        """,
        expected_columns=["name", "total_spent"],
        # Fixed grading key, recomputed from the data: Alice 1949.90
        # (orders 1, 2, 9), Carol 1589.92 (orders 4, 5, 6), Bob 1089.96
        # (order 3).  Eve's only completed order totals 74.98, which the
        # HAVING ... > 200 clause excludes — so 3 rows, not 4.
        expected_row_count=3,
        expected_rows=[
            ("Alice", 1949.90),
            ("Carol", 1589.92),
            ("Bob", 1089.96),
        ],
        hints=[
            "You need to JOIN three tables: customers → orders → order_items",
            "Total per item = quantity * unit_price, then SUM for total per customer",
            "Filter completed orders with WHERE o.status = 'completed'",
            "Use HAVING (not WHERE) to filter after GROUP BY",
        ],
        schema_description=ECOMMERCE_SCHEMA_DESC,
        max_steps=7,
    ),

    SQLTask(
        task_id="medium_002",
        difficulty="join_aggregate",
        title="Category Revenue",
        setup_sql=ECOMMERCE_SCHEMA,
        question="Calculate the total revenue for each product category from COMPLETED orders. Show the category name and total revenue. Order by total revenue from highest to lowest.",
        expected_sql="""
            SELECT p.category, ROUND(SUM(oi.quantity * oi.unit_price), 2) as total_revenue
            FROM products p
            JOIN order_items oi ON p.id = oi.product_id
            JOIN orders o ON oi.order_id = o.id
            WHERE o.status = 'completed'
            GROUP BY p.category
            ORDER BY total_revenue DESC
        """,
        expected_columns=["category", "total_revenue"],
        expected_row_count=3,
        # Fixed grading key, recomputed from the data (orders 7 and 10
        # excluded): Electronics 4289.87, Books 294.93, Home 119.96.
        expected_rows=[
            ("Electronics", 4289.87),
            ("Books", 294.93),
            ("Home", 119.96),
        ],
        hints=[
            "JOIN products → order_items → orders",
            "Revenue per item = quantity * unit_price",
            "Filter only completed orders",
            "GROUP BY p.category to get per-category totals",
        ],
        schema_description=ECOMMERCE_SCHEMA_DESC,
        max_steps=7,
    ),

    SQLTask(
        task_id="medium_003",
        difficulty="join_aggregate",
        title="Customers with Multiple Orders",
        setup_sql=ECOMMERCE_SCHEMA,
        question="Find customers who have placed more than one COMPLETED order. Show the customer name and the number of completed orders they placed. Order by order count descending, then by name ascending.",
        expected_sql="""
            SELECT c.name, COUNT(o.id) as order_count
            FROM customers c
            JOIN orders o ON c.id = o.customer_id
            WHERE o.status = 'completed'
            GROUP BY c.id, c.name
            HAVING COUNT(o.id) > 1
            ORDER BY order_count DESC, c.name ASC
        """,
        expected_columns=["name", "order_count"],
        expected_row_count=2,
        expected_rows=[
            ("Alice", 3),
            ("Carol", 3),
        ],
        hints=[
            "JOIN customers with orders",
            "Filter for completed orders in WHERE clause",
            "GROUP BY customer, then HAVING COUNT > 1",
            "ORDER BY count DESC, then name ASC for ties",
        ],
        schema_description=ECOMMERCE_SCHEMA_DESC,
        max_steps=7,
    ),
]
381
+
382
+
383
+ # =============================================================
384
+ # HARD TASKS: complex_analysis (3 tasks)
385
+ # =============================================================
386
+
387
# Pure data: each SQLTask bundles the prompt, the canonical solution, and the
# exact result set the grader compares against. All three hard tasks require
# CTEs and/or window functions over the shared e-commerce fixture schema.
HARD_TASKS: List[SQLTask] = [
    # Task 1: month-over-month revenue growth via LAG() window function.
    SQLTask(
        task_id="hard_001",
        difficulty="complex_analysis",
        title="Monthly Revenue with Growth Rate",
        setup_sql=ECOMMERCE_SCHEMA,
        question="Calculate monthly revenue from COMPLETED orders, and for each month show the month (YYYY-MM format), the total revenue, and the percentage change from the previous month. For the first month, the percentage change should be NULL. Round revenue to 2 decimal places and percentage to 2 decimal places. Order by month ascending.",
        expected_sql="""
            WITH monthly AS (
                SELECT
                    strftime('%Y-%m', o.order_date) as month,
                    ROUND(SUM(oi.quantity * oi.unit_price), 2) as revenue
                FROM orders o
                JOIN order_items oi ON o.id = oi.order_id
                WHERE o.status = 'completed'
                GROUP BY strftime('%Y-%m', o.order_date)
            ),
            with_prev AS (
                SELECT
                    month,
                    revenue,
                    LAG(revenue) OVER (ORDER BY month) as prev_revenue
                FROM monthly
            )
            SELECT
                month,
                revenue,
                CASE
                    WHEN prev_revenue IS NULL THEN NULL
                    ELSE ROUND(((revenue - prev_revenue) * 100.0 / prev_revenue), 2)
                END as pct_change
            FROM with_prev
            ORDER BY month
        """,
        expected_columns=["month", "revenue", "pct_change"],
        expected_row_count=5,
        # First month has no predecessor, so pct_change is NULL (None).
        expected_rows=[
            ("2023-06", 2289.93, None),
            ("2023-07", 119.98, -94.76),
            ("2023-08", 1429.93, 1091.81),
            ("2023-09", 1224.97, -14.34),
            ("2023-10", 529.95, -56.74),
        ],
        hints=[
            "Use a CTE (WITH clause) to first calculate monthly revenue",
            "strftime('%Y-%m', date) extracts year-month from a date string",
            "LAG(revenue) OVER (ORDER BY month) gets the previous month's revenue",
            "Percentage change = ((new - old) / old) * 100",
            "Use CASE WHEN prev IS NULL THEN NULL ELSE ... END for first month",
        ],
        schema_description=ECOMMERCE_SCHEMA_DESC,
        max_steps=10,
    ),

    # Task 2: top-1 per group via ROW_NUMBER() PARTITION BY, with a tiebreak.
    SQLTask(
        task_id="hard_002",
        difficulty="complex_analysis",
        title="Top Product Per Category",
        setup_sql=ECOMMERCE_SCHEMA,
        question="For each product category, find the single best-selling product (by total quantity sold across COMPLETED orders). Show the category, product name, and total quantity sold. If there are ties, pick the one with the higher total revenue. Order by category name ascending.",
        expected_sql="""
            WITH product_sales AS (
                SELECT
                    p.category,
                    p.name as product_name,
                    SUM(oi.quantity) as total_qty,
                    SUM(oi.quantity * oi.unit_price) as total_revenue,
                    ROW_NUMBER() OVER (
                        PARTITION BY p.category
                        ORDER BY SUM(oi.quantity) DESC, SUM(oi.quantity * oi.unit_price) DESC
                    ) as rn
                FROM products p
                JOIN order_items oi ON p.id = oi.product_id
                JOIN orders o ON oi.order_id = o.id
                WHERE o.status = 'completed'
                GROUP BY p.category, p.name
            )
            SELECT category, product_name, total_qty
            FROM product_sales
            WHERE rn = 1
            ORDER BY category ASC
        """,
        expected_columns=["category", "product_name", "total_qty"],
        expected_row_count=3,
        expected_rows=[
            ("Books", "Python Book", 4),
            ("Electronics", "Headphones", 7),
            ("Home", "Desk Lamp", 4),
        ],
        hints=[
            "First calculate total quantity sold per product (SUM of quantity)",
            "Use ROW_NUMBER() OVER (PARTITION BY category ORDER BY qty DESC) to rank within category",
            "Filter WHERE rn = 1 to get only the top product per category",
            "A CTE makes this much cleaner than nested subqueries",
            "Don't forget to filter for completed orders only",
        ],
        schema_description=ECOMMERCE_SCHEMA_DESC,
        max_steps=10,
    ),

    # Task 3: two-level aggregation (per-order, then per-customer) plus
    # julianday() date arithmetic for the first-to-last-order span.
    SQLTask(
        task_id="hard_003",
        difficulty="complex_analysis",
        title="Customer Lifetime Value Analysis",
        setup_sql=ECOMMERCE_SCHEMA,
        question="For customers with at least 2 completed orders, calculate: their name, number of completed orders, total lifetime spending (rounded to 2 decimals), average order value (rounded to 2 decimals), and the number of days between their first and last completed order. Order by total spending descending.",
        expected_sql="""
            WITH customer_order_totals AS (
                SELECT
                    c.id as customer_id,
                    c.name,
                    o.id as order_id,
                    o.order_date,
                    SUM(oi.quantity * oi.unit_price) as order_total
                FROM customers c
                JOIN orders o ON c.id = o.customer_id
                JOIN order_items oi ON o.id = oi.order_id
                WHERE o.status = 'completed'
                GROUP BY c.id, c.name, o.id, o.order_date
            )
            SELECT
                name,
                COUNT(*) as num_orders,
                ROUND(SUM(order_total), 2) as total_spending,
                ROUND(AVG(order_total), 2) as avg_order_value,
                CAST(julianday(MAX(order_date)) - julianday(MIN(order_date)) AS INTEGER) as days_span
            FROM customer_order_totals
            GROUP BY customer_id, name
            HAVING COUNT(*) >= 2
            ORDER BY total_spending DESC
        """,
        expected_columns=["name", "num_orders", "total_spending", "avg_order_value", "days_span"],
        expected_row_count=2,
        expected_rows=[
            ("Alice", 3, 1919.91, 639.97, 122),
            ("Carol", 3, 1464.94, 488.31, 31),
        ],
        hints=[
            "Use a CTE to first calculate the total for each individual order",
            "In the CTE: JOIN customers → orders → order_items, GROUP BY order",
            "In the outer query: GROUP BY customer, HAVING COUNT >= 2",
            "julianday() converts date strings to Julian day numbers for arithmetic",
            "days_span = julianday(MAX(order_date)) - julianday(MIN(order_date))",
        ],
        schema_description=ECOMMERCE_SCHEMA_DESC,
        max_steps=10,
    ),
]
535
+
536
+
537
+ # =============================================================
538
+ # TASK REGISTRY — Maps task IDs and difficulty levels
539
+ # =============================================================
540
+
541
# Master registry: difficulty name -> ordered pool of tasks for that level.
ALL_TASKS: Dict[str, List[SQLTask]] = {
    "basic_select": EASY_TASKS,
    "join_aggregate": MEDIUM_TASKS,
    "complex_analysis": HARD_TASKS,
}

# Flat lookup by task_id (e.g. "easy_001" -> its SQLTask). Built with a dict
# comprehension instead of nested for-loops so no throwaway loop variables
# (`_tasks`, `_task`) leak into the module namespace at import time.
# Task ids are assumed unique across pools; a duplicate would silently win last.
TASK_BY_ID: Dict[str, SQLTask] = {
    task.task_id: task
    for pool in ALL_TASKS.values()
    for task in pool
}
552
+
553
+
554
def get_task(difficulty: str, task_id: Optional[str] = None) -> SQLTask:
    """
    Get a task by difficulty level, optionally by specific ID.

    Args:
        difficulty: One of 'basic_select', 'join_aggregate', 'complex_analysis'
        task_id: Optional specific task ID (e.g., 'easy_001')

    Returns:
        SQLTask instance

    Raises:
        ValueError: If difficulty is unknown, or if a non-empty task_id does
            not match any registered task.
    """
    # An explicit task_id takes precedence over the difficulty pool.
    # Bug fix: previously an unknown task_id fell through to a *random* task
    # from the difficulty pool, silently masking typos — fail loudly instead.
    if task_id:
        try:
            return TASK_BY_ID[task_id]
        except KeyError:
            raise ValueError(
                f"Unknown task_id: '{task_id}'. "
                f"Available: {sorted(TASK_BY_ID)}"
            ) from None

    if difficulty not in ALL_TASKS:
        raise ValueError(
            f"Unknown difficulty: '{difficulty}'. "
            f"Choose from: {list(ALL_TASKS.keys())}"
        )

    # No specific task requested: sample uniformly from the pool.
    return random.choice(ALL_TASKS[difficulty])
581
+
582
+
583
def list_tasks() -> Dict[str, List[str]]:
    """
    List all available tasks grouped by difficulty.

    Returns:
        Dict mapping difficulty name to list of task IDs
    """
    catalog: Dict[str, List[str]] = {}
    for level, pool in ALL_TASKS.items():
        catalog[level] = [task.task_id for task in pool]
    return catalog
tests/__init__.py ADDED
File without changes
tests/test_env.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for SQL Arena environment."""
2
+
3
+ import pytest
4
+ from src.sql_arena.environment import SQLArenaEnvironment
5
+ from src.sql_arena.models import SQLArenaAction
6
+ from src.sql_arena.tasks import list_tasks, TASK_BY_ID
7
+
8
+
9
class TestEnvironmentBasics:
    """Smoke tests for reset/step/state on a single known easy task."""

    def setup_method(self):
        # Fresh environment per test so episodes never leak state between tests.
        self.env = SQLArenaEnvironment()

    def teardown_method(self):
        self.env.close()

    def test_reset_returns_observation(self):
        outcome = self.env.reset(difficulty="basic_select", task_id="easy_001")
        assert outcome.observation is not None
        assert outcome.reward == 0.0
        assert outcome.done is False

    def test_step_with_correct_query(self):
        # Submitting the task's own reference SQL should earn a high score.
        self.env.reset(difficulty="basic_select", task_id="easy_001")
        gold_sql = self.env.current_task.expected_sql
        outcome = self.env.step(SQLArenaAction(sql_query=gold_sql))
        assert outcome.reward > 0.0
        assert outcome.info.get("score", 0) >= 0.8

    def test_step_with_invalid_query(self):
        # Unparseable SQL must yield zero reward plus an error observation.
        self.env.reset(difficulty="basic_select", task_id="easy_001")
        outcome = self.env.step(SQLArenaAction(sql_query="INVALID SQL QUERY"))
        assert outcome.reward == 0.0
        assert outcome.observation.error_message is not None

    def test_state_tracking(self):
        # The step counter starts at zero and advances by one per step.
        self.env.reset(difficulty="basic_select", task_id="easy_001")
        assert self.env.state().current_step == 0
        self.env.step(SQLArenaAction(sql_query="SELECT 1"))
        assert self.env.state().current_step == 1

    def test_episode_terminates(self):
        # Stepping past the task's budget must mark the episode done.
        self.env.reset(difficulty="basic_select", task_id="easy_001")
        budget = self.env.current_task.max_steps + 1
        for _ in range(budget):
            if self.env.state().done:
                break
            self.env.step(SQLArenaAction(sql_query="SELECT 1"))
        assert self.env.state().done is True
55
+
56
+
57
class TestAllDifficulties:
    """Reset must honor each of the three supported difficulty levels."""

    def setup_method(self):
        self.env = SQLArenaEnvironment()

    def teardown_method(self):
        self.env.close()

    def _assert_difficulty(self, level):
        # Shared check: resetting at `level` yields an observation tagged with it.
        outcome = self.env.reset(difficulty=level)
        assert outcome.observation.difficulty == level

    def test_easy(self):
        self._assert_difficulty("basic_select")

    def test_medium(self):
        self._assert_difficulty("join_aggregate")

    def test_hard(self):
        self._assert_difficulty("complex_analysis")
76
+
77
+
78
class TestGrading:
    """Sanity checks on the grader: scores stay in [0, 1] and vary with quality."""

    def setup_method(self):
        self.env = SQLArenaEnvironment()

    def teardown_method(self):
        self.env.close()

    def test_scores_in_range(self):
        """Every task's own reference solution must score within [0, 1].

        Fix: the original issued a second `reset` at the bottom of each loop
        iteration — redundant work, since every iteration already begins with
        a reset, and the final trailing reset was simply wasted.
        """
        for task_id, task in TASK_BY_ID.items():
            self.env.reset(difficulty=task.difficulty, task_id=task_id)
            result = self.env.step(SQLArenaAction(sql_query=task.expected_sql))
            assert 0.0 <= result.reward <= 1.0
            assert 0.0 <= result.info.get("score", 0) <= 1.0

    def test_varying_scores(self):
        """The grader must discriminate: different queries yield different scores."""
        queries = [
            "SELECT name, salary FROM employees WHERE is_active = 1 AND salary > 80000 ORDER BY salary DESC",
            "SELECT * FROM employees",
            "INVALID",
            "SELECT name FROM employees",
        ]
        scores = set()
        for q in queries:
            self.env.reset(difficulty="basic_select", task_id="easy_001")
            result = self.env.step(SQLArenaAction(sql_query=q))
            scores.add(round(result.info.get("score", 0), 2))
        assert len(scores) > 1, "Grader always returns the same score!"
108
+
109
+
110
class TestTaskRegistry:
    """The registry must expose all three difficulties, each with >= 3 tasks."""

    def test_list_tasks(self):
        catalog = list_tasks()
        for level in ("basic_select", "join_aggregate", "complex_analysis"):
            assert level in catalog

    def test_minimum_3_tasks(self):
        for difficulty, task_ids in list_tasks().items():
            assert len(task_ids) >= 3, f"{difficulty} has fewer than 3 tasks"
122
+
123
+
124
# Allow running this file directly (python tests/test_env.py) in addition
# to pytest discovery.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])