Prithvigg committed on
Commit a8a3c90 · verified · 1 Parent(s): 452be68

Upload folder using huggingface_hub
Dockerfile ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Multi-stage build using openenv-base
+ # This Dockerfile is flexible and works for both:
+ #   - In-repo environments (with local OpenEnv sources)
+ #   - Standalone environments (with openenv from PyPI/Git)
+ # The build script (openenv build) handles context detection and sets appropriate build args.
+
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ WORKDIR /app
+
+ # Ensure git is available (required for installing dependencies from VCS)
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends git && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Build argument to control whether we're building standalone or in-repo
+ ARG BUILD_MODE=in-repo
+ ARG ENV_NAME=queryforge
+
+ # Copy environment code (always at root of build context)
+ COPY . /app/env
+
+ # For in-repo builds, openenv is already vendored in the build context
+ # For standalone builds, openenv will be installed via pyproject.toml
+ WORKDIR /app/env
+
+ # Ensure uv is available (for local builds where the base image lacks it)
+ RUN if ! command -v uv >/dev/null 2>&1; then \
+         curl -LsSf https://astral.sh/uv/install.sh | sh && \
+         mv /root/.local/bin/uv /usr/local/bin/uv && \
+         mv /root/.local/bin/uvx /usr/local/bin/uvx; \
+     fi
+
+ # Install dependencies using uv sync
+ # If uv.lock exists, use it; otherwise resolve on the fly
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-install-project --no-editable; \
+     else \
+         uv sync --no-install-project --no-editable; \
+     fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-editable; \
+     else \
+         uv sync --no-editable; \
+     fi
+
+ # Final runtime stage
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app
+
+ # Copy the virtual environment from the builder
+ COPY --from=builder /app/env/.venv /app/.venv
+
+ # Copy the environment code
+ COPY --from=builder /app/env /app/env
+
+ # Set PATH to use the virtual environment
+ ENV PATH="/app/.venv/bin:$PATH"
+
+ # Set PYTHONPATH so imports work correctly
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:8000/health || exit 1
+
+ # Run the FastAPI server
+ # The module path is constructed to work with the /app/env structure
+ ENV ENABLE_WEB_INTERFACE=true
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,353 @@
  ---
- title: Queryforge
- emoji: 🏆
- colorFrom: green
- colorTo: blue
  sdk: docker
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: QueryForge Environment Server
+ emoji: 🔍
+ colorFrom: blue
+ colorTo: indigo
  sdk: docker
  pinned: false
+ app_port: 8000
+ base_path: /web
+ tags:
+ - openenv
+ - sql
+ - reinforcement-learning
  ---

+ # QueryForge — SQL Debugging & Optimisation Environment
+
+ SQL is the language that runs the world's data infrastructure. Yet SQL bugs are silent killers — a missing JOIN condition inflates totals by 3×, a correlated subquery scans a million rows once per row, a typo in a keyword stops production cold. These bugs are rarely caught by linters, rarely surfaced by error messages, and routinely shipped to production.
+
+ QueryForge is an **OpenEnv-compatible reinforcement learning environment** in which an agent learns to debug and optimise SQL queries. The agent receives a broken or slow query, submits fixes, and gets graded feedback from a deterministic DuckDB engine combined with an Anthropic AI quality judge — a smooth, informative reward signal across the full 0.0 → 1.0 range.
+
+ ---
+
+ ## Why SQL Debugging as an RL Environment?
+
+ LLMs can write SQL. What they struggle with is the **iterative, feedback-driven debugging loop** that real engineers work through:
+
+ - Read the error message
+ - Form a hypothesis about the root cause
+ - Patch the query
+ - Check whether the output is now correct
+ - Refine until it's both correct *and* efficient
+
+ This is precisely the loop that RL is built for. QueryForge provides the environment that closes this loop with a graded, multi-stage reward signal — not just "correct / incorrect" but partial credit for syntax validity, execution success, row correctness, and code quality.
+
+ ---
+
+ ## Environment Overview
+
+ | Property | Value |
+ |---|---|
+ | Task type | SQL debugging & optimisation |
+ | Action space | Single SQL query string |
+ | Observation space | Task description + graded feedback |
+ | Reward range | 0.0 – 1.0 (continuous) |
+ | Episode termination | Score ≥ 0.90, no improvement for 2 steps, or max steps |
+ | Grading engine | DuckDB (deterministic) + Anthropic AI judge |
+ | Concurrent sessions | Supported |
+
+ ---
+
+ ## Reward Scale
+
+ The grading pipeline has four stages that produce a smooth partial-progress signal:
+
+ | Score | Meaning |
+ |---|---|
+ | **0.00** | Syntax error — query could not be parsed |
+ | **0.15** | Syntax valid but runtime error |
+ | **0.30** | Executes but returns 0 rows or the wrong row count |
+ | **0.30 – 0.80** | Partial row correctness (deterministic, DuckDB) |
+ | **0.80 – 1.00** | Correct rows + AI quality assessment (Anthropic) |
+
+ The AI judge scores on three axes: **Correctness** (0–0.50), **Optimization** (0–0.30 — penalises cartesian products and correlated subqueries), and **Code quality** (0–0.20 — readability, aliases, formatting).
+
+ > **Offline mode:** If `ANTHROPIC_API_KEY` is not set, the AI judge is skipped and scoring is fully deterministic (capped at 0.80). The done threshold self-adjusts to 0.80 in this case so episodes still terminate correctly.
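The reward table can be read as a piecewise score function. Below is a minimal sketch of that mapping, for illustration only: the actual judge lives in `judge.py`, and the `row_match` and `ai_quality` inputs here are hypothetical stand-ins for its internal row-comparison and AI-judge results.

```python
def staged_score(syntax_ok: bool, runs: bool, row_match: float,
                 ai_quality: float = 0.0) -> float:
    """Illustrative staged reward: maps grading outcomes to [0.0, 1.0].

    row_match  -- fraction of expected rows matched (hypothetical input)
    ai_quality -- AI judge score in [0, 1], applied once rows are correct
    """
    if not syntax_ok:
        return 0.0                      # Stage 1: query could not be parsed
    if not runs:
        return 0.15                     # Stage 2: parsed, but runtime error
    if row_match <= 0.0:
        return 0.30                     # Stage 3 floor: ran, wrong rows
    if row_match < 1.0:
        return 0.30 + 0.50 * row_match  # partial row credit, capped at 0.80
    return 0.80 + 0.20 * ai_quality     # Stage 4: AI judge lifts toward 1.00
```

In offline mode, `ai_quality` stays 0.0, so the maximum reachable score under this sketch is 0.80, matching the deterministic cap.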
+
+ ---
+
+ ## Action Space
+
+ ```python
+ class SQLAction(Action):
+     sql: str  # The SQL query to submit for grading
+ ```
+
+ One field: the agent submits a SQL string. Multi-statement queries (`;`-separated) are not allowed and are rejected with score 0.0.
+
+ ---
+
+ ## Observation Space
+
+ ```python
+ class SQLObservation(Observation):
+     # Task context (set on reset, constant within an episode)
+     task_id: str           # e.g. "task_easy_syntax"
+     task_level: str        # "easy" | "medium" | "hard" | "custom"
+     task_title: str        # Human-readable title
+     task_description: str  # Full context: schema, broken query, error, goal
+
+     # Per-step grading signals
+     syntax_valid: bool       # True if the query parsed without error
+     execution_success: bool  # True if the query ran to completion in DuckDB
+     execution_error: str     # Runtime error message, if any
+     rows_returned: int       # Number of rows returned
+
+     # Feedback
+     feedback: str  # Detailed grading feedback (DuckDB + AI judge)
+     hint: str      # Actionable hint (suppressed once score >= 0.90)
+
+     # Episode progress
+     attempt: int       # Number of queries submitted this episode
+     best_score: float  # Highest score achieved so far
+     done: bool
+     reward: float      # Score for this specific step (0.0 – 1.0)
+ ```
+
+ ---
+
+ ## Built-in Tasks
+
+ ### Easy — Fix Syntax Errors
+ Three SQL keywords are misspelled (`SELEC`, `FORM`, `WEHRE`). The agent must identify and correct them.
+
+ **Schema:** `users(id, name, age, city)` — 6 rows
+ **Goal:** Return the name and age of users older than 30 in New York, ordered by name
+ **Max steps:** 5
+
+ ### Medium — Fix the Cartesian JOIN
+ A missing `JOIN` condition (`o.product_id = p.id`) causes a cartesian product, inflating every total by 3×. The agent must rewrite the query using explicit `INNER JOIN … ON` syntax.
+
+ **Schema:** `orders`, `users`, `products` — e-commerce dataset
+ **Goal:** Correct per-(user, product) total amount spent, ordered by total DESC
+ **Max steps:** 5
+
+ ### Hard — Rewrite Correlated Subquery as CTE
+ A semantically correct but O(N²) query re-executes `AVG(salary)` for every employee row. The agent must rewrite it using a `WITH` clause that computes department averages exactly once.
+
+ **Schema:** `departments`, `employees` — 9 employees across 3 departments
+ **Goal:** Employees who earn strictly above their department average, ordered by dept/salary
+ **Max steps:** 6
+
+ > Tasks have **structural penalties**: the hard task requires a `WITH` clause (−0.30 if absent); the medium task requires explicit `JOIN` syntax (−0.20 if absent). This prevents an agent from gaming the score by submitting the broken query verbatim.
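Checks like these can be implemented as simple pattern tests on the submitted SQL before the deterministic score is finalised. A hedged sketch (the function name and exact penalty mechanics here are illustrative, not the environment's actual code):

```python
import re

def structural_penalty(task_id: str, sql: str) -> float:
    """Illustrative structural check: returns a penalty to subtract
    from the deterministic score (0.0 means no penalty)."""
    normalized = sql.upper()
    # Hard task: the point is to rewrite with a CTE, so demand a WITH clause
    if task_id == "task_hard_cte" and not re.search(r"\bWITH\b", normalized):
        return 0.30
    # Medium task: demand explicit JOIN syntax instead of comma joins
    if task_id == "task_medium_join" and not re.search(r"\bJOIN\b", normalized):
        return 0.20
    return 0.0
```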
+
+ ---
+
+ ## Custom Tasks
+
+ Register any SQL task at runtime — no code changes needed.
+
+ ### Via Python
+ ```python
+ from tasks import REGISTRY, task_from_dict
+
+ REGISTRY.register(task_from_dict({
+     "id": "my_window_task",
+     "level": "hard",
+     "title": "Rank Employees by Salary",
+     "schema_ddl": "CREATE TABLE emp (id INT, name VARCHAR, dept VARCHAR, salary DECIMAL); INSERT INTO emp VALUES ...",
+     "broken_query": "SELECT name, salary FROM emp ORDER BY salary DESC",
+     "expected_rows": [{"name": "Alice", "rank": 1}, ...],
+     "hint": "Use ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC)",
+     "solution_query": "SELECT name, RANK() OVER (ORDER BY salary DESC) AS rank FROM emp",
+ }))
+ ```
+
+ ### Via REST API (when the server is running)
+ ```bash
+ # Register a custom task
+ curl -X POST http://localhost:8000/tasks \
+     -H "Content-Type: application/json" \
+     -d '{"id": "my_task", "schema_ddl": "...", "expected_rows": [...]}'
+
+ # List all tasks
+ curl http://localhost:8000/tasks
+
+ # Remove a custom task
+ curl -X DELETE http://localhost:8000/tasks/my_task
+ ```
+
+ ### Via JSON file
+ ```python
+ REGISTRY.load_from_json("my_tasks.json")
+ ```
+
+ ---
+
+ ## Quickstart
+
+ ### Install dependencies
+ ```bash
+ python -m venv .venv
+ .venv/bin/pip install -e ".[dev]"
+ ```
+
+ ### Run the local playbook (no server needed)
+ Tests all three built-in tasks directly, with progressive SQL attempts:
+ ```bash
+ ANTHROPIC_API_KEY=your_key .venv/bin/python playbook.py
+ ```
+
+ ### Run the baseline inference script
+ Runs a Claude model as an agent against all tasks and reports scores:
+ ```bash
+ # Default model (claude-haiku-4-5)
+ ANTHROPIC_API_KEY=your_key .venv/bin/python baseline.py
+
+ # Specific model
+ ANTHROPIC_API_KEY=your_key .venv/bin/python baseline.py --model claude-opus-4-6
+
+ # Single task with verbose output
+ ANTHROPIC_API_KEY=your_key .venv/bin/python baseline.py --task task_hard_cte --verbose
+ ```
+
+ ### Run the HTTP server
+ ```bash
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
+ ```
+
+ ---
+
+ ## Baseline Results
+
+ The following scores were produced by running `claude-haiku-4-5` as the agent against all three tasks with the full AI judge active. These serve as the reproducible baseline for this environment.
+
+ | Task | Level | Steps Used | Best Score |
+ |---|---|---|---|
+ | Fix the Syntax Errors | easy | 1 | **1.000** |
+ | Fix the Cartesian JOIN | medium | 1 | **0.900** |
+ | Rewrite Correlated Subquery as CTE | hard | 1 | **0.950** |
+ | **Average** | | | **0.950** |
+
+ All three tasks were solved (or near-solved) on the first step, demonstrating that:
+ - The reward pipeline returns a meaningful signal immediately
+ - The environment terminates cleanly when the done threshold (≥ 0.90) is met
+ - A stronger model or a harder task set would produce more training-relevant trajectories
+
+ ---
+
+ ## API Endpoints
+
+ | Method | Endpoint | Description |
+ |---|---|---|
+ | `POST` | `/reset` | Start a new episode. Pass `{"task_id": "..."}` to pin to a task |
+ | `POST` | `/step` | Submit a SQL query: `{"sql": "SELECT ..."}` |
+ | `GET` | `/state` | Current episode ID and step count |
+ | `GET` | `/schema` | Action and observation JSON schemas |
+ | `POST` | `/tasks` | Register a custom task |
+ | `GET` | `/tasks` | List all registered tasks |
+ | `DELETE` | `/tasks/{task_id}` | Remove a custom task (built-ins protected) |
+ | `WS` | `/ws` | WebSocket endpoint for persistent low-latency sessions |
+ | `GET` | `/health` | Container health check |
+ | `GET` | `/docs` | Interactive OpenAPI documentation |
+
+ ### Examples
+
+ ```bash
+ # Start an episode pinned to the hard task
+ curl -X POST http://localhost:8000/reset \
+     -H "Content-Type: application/json" \
+     -d '{"task_id": "task_hard_cte"}'
+
+ # Submit a query
+ curl -X POST http://localhost:8000/step \
+     -H "Content-Type: application/json" \
+     -d '{"sql": "WITH dept_avg AS (SELECT department_id, AVG(salary) AS avg_salary FROM employees GROUP BY department_id) SELECT e.name, e.department_id, e.salary FROM employees e JOIN dept_avg d ON e.department_id = d.department_id WHERE e.salary > d.avg_salary ORDER BY e.department_id, e.salary DESC"}'
+
+ # List all available tasks
+ curl http://localhost:8000/tasks
+ ```
+
+ ---
+
+ ## Python Client
+
+ ```python
+ from queryforge import QueryforgeEnv, SQLAction, TaskSpec
+
+ with QueryforgeEnv(base_url="http://localhost:8000") as env:
+     # Pin to a specific task
+     obs = env.reset(task_id="task_medium_join")
+     print(obs.task_description)
+
+     # Submit a fix
+     result = env.step(SQLAction(sql="""
+         SELECT u.name, p.title, SUM(o.amount) AS total_spent
+         FROM orders o
+         INNER JOIN users u ON o.user_id = u.id
+         INNER JOIN products p ON o.product_id = p.id
+         GROUP BY u.name, p.title
+         ORDER BY total_spent DESC
+     """))
+     print(f"Score: {result.reward:.3f}")
+     print(f"Feedback: {result.observation.feedback}")
+     print(f"Done: {result.done}")
+
+     # Register and use a custom task
+     env.register_task(TaskSpec(
+         id="my_task",
+         schema_ddl="CREATE TABLE ...; INSERT INTO ...",
+         expected_rows=[{"col": "val"}],
+         title="My Custom Task",
+     ))
+     obs = env.reset(task_id="my_task")
+ ```
+
+ ---
+
+ ## Project Structure
+
+ ```
+ queryforge/
+ ├── __init__.py                   # Public exports (SQLAction, SQLObservation, TaskSpec, REGISTRY)
+ ├── models.py                     # SQLAction, SQLObservation, TaskSpec Pydantic models
+ ├── tasks.py                      # Built-in tasks + thread-safe TaskRegistry
+ ├── judge.py                      # 4-stage grading pipeline (DuckDB + Anthropic)
+ ├── client.py                     # QueryforgeEnv client with task management helpers
+ ├── playbook.py                   # Local test runner (no server required)
+ ├── baseline.py                   # Baseline inference script (Claude as agent)
+ ├── openenv.yaml                  # OpenEnv manifest
+ ├── pyproject.toml                # Project metadata and dependencies
+ ├── uv.lock                       # Locked dependencies
+ └── server/
+     ├── app.py                    # FastAPI app — core + /tasks REST endpoints
+     ├── queryforge_environment.py # Environment class (reset, step, state)
+     ├── Dockerfile                # Container image
+     └── requirements.txt          # Server dependencies
+ ```
+
+ ---
+
+ ## Deployment
+
+ ### Hugging Face Spaces (recommended)
+
+ ```bash
+ UV_CACHE_DIR=/tmp/uv-cache openenv push . --repo-id <hf-username>/queryforge
+ ```
+
+ Add `ANTHROPIC_API_KEY` as a Space secret after deployment. Without it, the environment runs in deterministic-only mode (scores capped at 0.80; the done threshold self-adjusts accordingly).
+
+ ### Docker
+
+ ```bash
+ docker build -t queryforge:latest -f server/Dockerfile .
+ docker run -p 8000:8000 -e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY queryforge:latest
+ ```
+
+ The deployed environment exposes:
+ - **`/web`** — Interactive UI for exploring the environment
+ - **`/docs`** — Full OpenAPI / Swagger interface
+ - **`/ws`** — WebSocket endpoint for persistent agent sessions
+ - **`/health`** — Container health monitoring
+
+ ---
+
+ ## Environment Design Notes
+
+ **Why DuckDB?** DuckDB runs fully in-memory with no external process or network dependency. Each `step()` call creates an isolated connection, seeds it with the task's schema, runs the agent's query, then closes — complete isolation with zero shared state between steps.
+
+ **Why a 4-stage reward?** Binary correct/incorrect rewards give an agent no gradient to climb when its query is simply broken. The 4-stage pipeline means every improvement — fixing a typo, avoiding a runtime error, returning the right row count, getting the right rows, writing clean SQL — is rewarded. This produces a smooth loss landscape for policy-gradient methods.
+
+ **Why structural penalties?** Without them, an agent could achieve 0.80 on the hard CTE task by submitting the original correlated subquery verbatim (rows match, but the task was never solved). Structural penalties enforce that the agent actually learned *what* to change, not just that rows matched.
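The connect/seed/run/close pattern behind per-step isolation can be sketched with the stdlib `sqlite3` module for portability (the environment itself uses DuckDB, but the isolation idea is the same; the helper name is illustrative):

```python
import sqlite3

def run_isolated(schema_ddl: str, sql: str) -> list:
    """Run `sql` against a fresh in-memory database seeded with `schema_ddl`.

    Illustrative sketch of per-step isolation: every call gets its own
    database, so no state leaks between grading steps.
    """
    conn = sqlite3.connect(":memory:")  # fresh database per call
    try:
        conn.executescript(schema_ddl)  # seed the task's tables
        return conn.execute(sql).fetchall()
    finally:
        conn.close()                    # discard everything after grading

rows = run_isolated(
    "CREATE TABLE users (id INT, name TEXT);"
    "INSERT INTO users VALUES (1, 'Ada'), (2, 'Bo');",
    "SELECT name FROM users WHERE id = 1",
)
```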
__init__.py ADDED
@@ -0,0 +1,23 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """QueryForge — SQL Debugger & Optimiser Environment."""
+
+ from .client import QueryforgeEnv
+ from .models import SQLAction, SQLObservation, TaskSpec
+ from .tasks import TASKS, TASK_BY_ID, SQLTask, REGISTRY, task_from_dict
+
+ __all__ = [
+     "SQLAction",
+     "SQLObservation",
+     "TaskSpec",
+     "QueryforgeEnv",
+     "TASKS",
+     "TASK_BY_ID",
+     "SQLTask",
+     "REGISTRY",
+     "task_from_dict",
+ ]
baseline.py ADDED
@@ -0,0 +1,244 @@
+ """
+ QueryForge Baseline Inference Script
+ ────────────────────────────────────
+ Runs a Claude model as an agent against all 3 built-in tasks and reports
+ a reproducible baseline score.
+
+ Usage:
+     # All tasks, default model (claude-haiku-4-5):
+     python baseline.py
+
+     # Specific model:
+     python baseline.py --model claude-opus-4-6
+
+     # Single task:
+     python baseline.py --task task_easy_syntax
+
+     # More verbose output:
+     python baseline.py --verbose
+
+ Requirements:
+     ANTHROPIC_API_KEY must be set in the environment.
+ """
+
+ import argparse
+ import os
+ import re
+ import sys
+
+ import anthropic
+
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+ from models import SQLAction
+ from server.queryforge_environment import QueryforgeEnvironment
+ from tasks import REGISTRY
+
+ # ── Constants ─────────────────────────────────────────────────────────────────
+
+ DEFAULT_MODEL = "claude-haiku-4-5"
+
+ SYSTEM_PROMPT = """\
+ You are an expert SQL engineer. You will be given a SQL debugging or \
+ optimisation challenge. Your job is to submit a corrected or improved SQL query.
+
+ Rules:
+ - Respond with ONLY a single SQL query inside a ```sql ... ``` code block.
+ - Do not explain your reasoning outside the code block.
+ - Do not include multiple statements (no semicolons except at the very end).
+ - If you receive feedback on a previous attempt, use it to improve your query.
+ """
+
+ # ── SQL extraction ────────────────────────────────────────────────────────────
+
+ _SQL_BLOCK = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)
+
+
+ def _extract_sql(text: str) -> str:
+     """Pull the first SQL code block out of Claude's response."""
+     match = _SQL_BLOCK.search(text)
+     if match:
+         return match.group(1).strip()
+     # Fallback: return the whole response stripped — better than crashing
+     return text.strip()
+
+
+ # ── Formatting helpers ────────────────────────────────────────────────────────
+
+ def _hr(char="═", width=70):
+     print(char * width)
+
+
+ def _score_bar(score: float, width: int = 25) -> str:
+     filled = int(score * width)
+     bar = "█" * filled + "░" * (width - filled)
+     return f"[{bar}] {score:.3f}"
+
+
+ # ── Per-task agent loop ───────────────────────────────────────────────────────
+
+ def run_task(
+     task_id: str,
+     model: str,
+     client: anthropic.Anthropic,
+     verbose: bool = False,
+ ) -> dict:
+     """
+     Run one episode of a single task.
+
+     Returns a dict with keys:
+         task_id, task_title, task_level,
+         best_score, attempts, done
+     """
+     env = QueryforgeEnvironment()
+     obs = env.reset(task_id=task_id)
+
+     if obs.done:
+         # reset() returned an error (unknown task_id)
+         print(f"  ERROR: {obs.feedback}")
+         return {"task_id": task_id, "best_score": 0.0, "attempts": 0, "done": False}
+
+     print(f"\n  Task : {obs.task_title} [{obs.task_level}] (max {env._current_task.max_steps} steps)")
+     if verbose:
+         print(f"  ID   : {obs.task_id}")
+
+     # ── Build initial conversation ────────────────────────────────────────
+     messages = [
+         {
+             "role": "user",
+             "content": (
+                 f"Here is your SQL challenge:\n\n{obs.task_description}\n\n"
+                 "Provide your fixed SQL query."
+             ),
+         }
+     ]
+
+     step = 0
+     while not obs.done:
+         step += 1
+
+         # ── Call Claude ───────────────────────────────────────────────────
+         with client.messages.stream(
+             model=model,
+             max_tokens=512,
+             system=SYSTEM_PROMPT,
+             messages=messages,
+         ) as stream:
+             response_text = ""
+             for text in stream.text_stream:
+                 response_text += text
+
+         sql = _extract_sql(response_text)
+
+         if verbose:
+             print(f"\n  ── Step {step}")
+             short_sql = sql[:120] + ("…" if len(sql) > 120 else "")
+             print(f"  SQL: {short_sql}")
+
+         # ── Submit to environment ─────────────────────────────────────────
+         obs = env.step(SQLAction(sql=sql))
+
+         score_bar = _score_bar(obs.reward or 0.0)
+         status = "✓ DONE" if obs.done else f"step {step}/{env._current_task.max_steps}"
+         print(f"  [{status}] Score: {score_bar}")
+
+         if verbose and obs.feedback:
+             fb = obs.feedback[:200] + ("…" if len(obs.feedback) > 200 else "")
+             print(f"  Feedback: {fb}")
+
+         if obs.done:
+             break
+
+         # ── Append exchange to conversation for next attempt ──────────────
+         messages.append({"role": "assistant", "content": response_text})
+         messages.append({
+             "role": "user",
+             "content": (
+                 f"Your query scored {obs.reward:.3f}. Here is the feedback:\n\n"
+                 f"{obs.feedback}\n\n"
+                 f"Hint: {obs.hint}\n\n"
+                 "Please try again with an improved SQL query."
+             ),
+         })
+
+     return {
+         "task_id": task_id,
+         "task_title": obs.task_title,
+         "task_level": obs.task_level,
+         "best_score": obs.best_score,
+         "attempts": obs.attempt,
+         "done": obs.done,
+     }
+
+
+ # ── Main ──────────────────────────────────────────────────────────────────────
+
+ def main():
+     parser = argparse.ArgumentParser(description="QueryForge Baseline Inference")
+     parser.add_argument(
+         "--model", default=DEFAULT_MODEL,
+         help=f"Anthropic model ID to use (default: {DEFAULT_MODEL})"
+     )
+     parser.add_argument(
+         "--task", default=None,
+         help="Run a single task by ID instead of all built-in tasks"
+     )
+     parser.add_argument(
+         "--verbose", action="store_true",
+         help="Print SQL queries and full feedback for each step"
+     )
+     args = parser.parse_args()
+
+     # ── Validate API key ──────────────────────────────────────────────────
+     api_key = os.environ.get("ANTHROPIC_API_KEY")
+     if not api_key:
+         print("ERROR: ANTHROPIC_API_KEY is not set.")
+         sys.exit(1)
+
+     client = anthropic.Anthropic(api_key=api_key)
+
+     # ── Determine tasks to run ────────────────────────────────────────────
+     if args.task:
+         task_ids = [args.task]
+     else:
+         task_ids = ["task_easy_syntax", "task_medium_join", "task_hard_cte"]
+
+     # ── Header ────────────────────────────────────────────────────────────
+     _hr()
+     print("  QueryForge — Baseline Inference")
+     print(f"  Model : {args.model}")
+     print(f"  Tasks : {', '.join(task_ids)}")
+     _hr()
+
+     # ── Run each task ─────────────────────────────────────────────────────
+     results = []
+     for task_id in task_ids:
+         print(f"\n{'─' * 70}")
+         result = run_task(task_id, args.model, client, verbose=args.verbose)
+         results.append(result)
+
+     # ── Results table ─────────────────────────────────────────────────────
+     print(f"\n{'═' * 70}")
+     print("  BASELINE RESULTS")
+     print(f"  Model: {args.model}")
+     print(f"{'═' * 70}")
+     print(f"  {'Task':<28} {'Level':<8} {'Steps':>5} {'Best Score'}")
+     print(f"  {'─' * 28} {'─' * 8} {'─' * 5} {'─' * 30}")
+
+     total_score = 0.0
+     for r in results:
+         title = r.get("task_title", r["task_id"])[:27]
+         level = r.get("task_level", "?")
+         attempts = r.get("attempts", "?")
+         score = r["best_score"]
+         total_score += score
+         bar = _score_bar(score)
+         print(f"  {title:<28} {level:<8} {attempts:>5} {bar}")
+
+     avg = total_score / len(results) if results else 0.0
+     print(f"{'─' * 70}")
+     print(f"  {'AVERAGE':<28} {'':8} {'':5} {_score_bar(avg)}")
+     print(f"{'═' * 70}\n")
+
+
+ if __name__ == "__main__":
+     main()
client.py ADDED
@@ -0,0 +1,94 @@
+ """QueryForge Environment Client."""
+
+ from typing import Any, Dict, List, Optional
+
+ import httpx
+ from openenv.core import EnvClient
+ from openenv.core.client_types import StepResult
+ from openenv.core.env_server.types import State
+
+ from .models import SQLAction, SQLObservation, TaskSpec
+
+
+ class QueryforgeEnv(EnvClient[SQLAction, SQLObservation, State]):
+     """
+     Client for the QueryForge SQL Debugger & Optimiser environment.
+
+     Maintains a persistent WebSocket connection to the environment server.
+     Each client instance has its own dedicated session (isolated task state).
+
+     Example:
+         >>> with QueryforgeEnv(base_url="http://localhost:8000") as env:
+         ...     obs = env.reset()
+         ...     print(obs.task_title)
+         ...     print(obs.task_description)
+         ...
+         ...     result = env.step(SQLAction(sql="SELECT name, age FROM users WHERE age > 30"))
+         ...     print(result.reward, result.observation.feedback)
+
+     Example with Docker:
+         >>> env = QueryforgeEnv.from_docker_image("queryforge-env:latest")
+         >>> try:
+         ...     obs = env.reset()
+         ...     result = env.step(SQLAction(sql="SELECT ..."))
+         ... finally:
+         ...     env.close()
+     """
+
+     def _step_payload(self, action: SQLAction) -> Dict:
+         return {"sql": action.sql}
+
+     def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
+         obs_data = payload.get("observation", {})
+         observation = SQLObservation(
+             task_id=obs_data.get("task_id", ""),
+             task_level=obs_data.get("task_level", ""),
+             task_title=obs_data.get("task_title", ""),
+             task_description=obs_data.get("task_description", ""),
+             syntax_valid=obs_data.get("syntax_valid", False),
+             execution_success=obs_data.get("execution_success", False),
+             execution_error=obs_data.get("execution_error"),
+             rows_returned=obs_data.get("rows_returned", 0),
+             feedback=obs_data.get("feedback", ""),
+             hint=obs_data.get("hint", ""),
+             attempt=obs_data.get("attempt", 0),
+             best_score=obs_data.get("best_score", 0.0),
+             done=payload.get("done", False),
+             reward=payload.get("reward", 0.0),
+             metadata=obs_data.get("metadata", {}),
+         )
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward", 0.0),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict) -> State:
+         return State(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+         )
+
+     # ── Task Registry helpers ─────────────────────────────────────────────
+
+     def register_task(self, spec: TaskSpec) -> Dict[str, Any]:
+         """Register a custom task on the server. Returns the server response dict."""
+         resp = httpx.post(
+             f"{self.base_url}/tasks",
+             json=spec.model_dump(),
+             timeout=10,
+         )
+         resp.raise_for_status()
+         return resp.json()
+
+     def list_tasks(self) -> List[Dict[str, Any]]:
+         """Return all registered tasks (built-in + custom) as a list of dicts."""
+         resp = httpx.get(f"{self.base_url}/tasks", timeout=10)
+         resp.raise_for_status()
+         return resp.json()
+
+     def delete_task(self, task_id: str) -> Dict[str, Any]:
+         """Remove a custom task by ID. Raises httpx.HTTPStatusError on 403/404."""
+         resp = httpx.delete(f"{self.base_url}/tasks/{task_id}", timeout=10)
+         resp.raise_for_status()
+         return resp.json()
judge.py ADDED
@@ -0,0 +1,414 @@
+ """
+ QueryForge Judge — deterministic DuckDB grading + Anthropic AI quality scoring.
+
+ Grading pipeline for each submitted SQL query:
+
+ Stage 1 — Syntax (0.0 → 0.15)
+ DuckDB EXPLAIN parses the query. Fail → score = 0.0.
+
+ Stage 2 — Execution (→ 0.30)
+ Run the full query against in-memory DuckDB seeded with task data.
+ Fail → score floor of 0.15 (syntax was fine, runtime error); AI feedback can lift it to at most 0.30.
+
+ Stage 3 — Correctness (→ 0.80)
+ Compare returned rows against expected rows.
+ Perfect match → deterministic score reaches 0.80.
+ Partial credit for correct row count or partial row matches.
+
+ Stage 4 — AI Quality (→ 1.0)
+ Anthropic claude-sonnet-4-6 evaluates optimization, code style, and
+ semantic correctness vs. the reference solution.
+ The AI score can move the final score up to 1.0 when rows are correct,
+ or provide nuanced feedback even when rows are partially wrong.
+
+ Environment variable required:
+ ANTHROPIC_API_KEY — standard Anthropic SDK key.
+ """
+
+ import json
+ import re
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import anthropic
+ import duckdb
+
+ try:
+ from .tasks import SQLTask, TestCase
+ except ImportError:
+ from tasks import SQLTask, TestCase
+
+ JUDGE_MODEL = "claude-sonnet-4-6"
+ # ---------------------------------------------------------------------------
+ # Stage 1 — Syntax check
+ # ---------------------------------------------------------------------------
+
+ def _reject_multi_statement(query: str) -> Optional[str]:
+ """Return an error message if the query contains multiple statements."""
+ # Strip string literals and comments before checking for semicolons
+ stripped = re.sub(r"'[^']*'", "", query) # remove string literals
+ stripped = re.sub(r"--[^\n]*", "", stripped) # remove line comments
+ stripped = re.sub(r"/\*.*?\*/", "", stripped, flags=re.DOTALL) # block comments
+ stripped = stripped.strip().rstrip(";") # allow a single trailing semicolon
+ if ";" in stripped:
+ return "Multi-statement queries are not allowed."
+ return None
+
+
+ def check_syntax(query: str) -> Tuple[bool, Optional[str]]:
+ """
+ Return (is_valid, error_message).
+
+ Strategy: run EXPLAIN against an empty in-memory DuckDB.
+ - "Parser Error" in the exception → genuine syntax error → invalid.
+ - "Catalog Error" / "Binder Error" → tables unknown but syntax is fine → valid.
+ - Any other exception → treat as syntax error to be safe.
+ """
+ multi_err = _reject_multi_statement(query)
+ if multi_err:
+ return False, multi_err
+
+ conn = duckdb.connect(":memory:")
+ try:
+ conn.execute(f"EXPLAIN {query}")
+ return True, None
+ except Exception as exc:
+ msg = str(exc)
+ # Catalog/Binder errors mean the SQL parsed fine; tables just aren't seeded.
+ if any(
+ tag in msg
+ for tag in ("Catalog Error", "Binder Error", "Table with name",
+ "Referenced column", "does not exist", "column")
+ ):
+ return True, None
+ return False, msg
+ finally:
+ conn.close()
+
+
+ # ---------------------------------------------------------------------------
+ # Stage 2 — Execution
+ # ---------------------------------------------------------------------------
+
+ def execute_query(
+ schema_ddl: str, query: str
+ ) -> Tuple[bool, Optional[List[Dict[str, Any]]], Optional[str]]:
+ """
+ Seed a fresh DuckDB in-memory DB with *schema_ddl*, then run *query*.
+ Returns (success, rows_as_list_of_dicts, error_message).
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ conn.execute(schema_ddl)
+ result = conn.execute(query).fetchdf()
+ rows = result.to_dict(orient="records")
+ # Convert numpy types to native Python
+ clean: List[Dict[str, Any]] = []
+ for row in rows:
+ clean.append({k: _native(v) for k, v in row.items()})
+ return True, clean, None
+ except Exception as exc:
+ return False, None, str(exc)
+ finally:
+ conn.close()
+
+
+ def _native(value: Any) -> Any:
+ """Convert numpy scalars → native Python types for JSON-safe comparison."""
+ try:
+ import numpy as np # duckdb fetchdf() returns numpy types
+ if isinstance(value, (np.integer,)):
+ return int(value)
+ if isinstance(value, (np.floating,)):
+ return float(value)
+ if isinstance(value, np.bool_):
+ return bool(value)
+ except ImportError:
+ pass
+ return value
+
+
+ # ---------------------------------------------------------------------------
+ # Stage 3 — Row correctness
+ # ---------------------------------------------------------------------------
+
+ def _normalize(row: Dict[str, Any]) -> Dict[str, Any]:
+ """Round floats to 2 dp so 999.99000000001 == 999.99."""
+ return {
+ k: (round(float(v), 2) if isinstance(v, float) else v)
+ for k, v in row.items()
+ }
+
+
+ def _sort_key(row: Dict[str, Any], order_by: Optional[str]) -> tuple:
+ if order_by:
+ cols = [c.strip() for c in order_by.split(",")]
+ return tuple(str(row.get(c, "")) for c in cols)
+ return tuple(str(v) for v in row.values())
+
+
+ def rows_match(
+ actual: List[Dict[str, Any]],
+ expected: List[Dict[str, Any]],
+ order_by: Optional[str] = None,
+ ) -> Tuple[float, str]:
+ """
+ Compare *actual* vs *expected* rows.
+
+ Scoring:
+ 1.0 — exact match
+ 0.5–0.9 — row count matches, some rows differ
+ up to 0.3 — row count wrong; 0.3 scaled by the overlap ratio
+ 0.0 — empty when non-empty expected
+ """
+ if not expected:
+ return (1.0, "No expected rows — query accepted.") if not actual else (
+ 0.8, f"Expected empty result but got {len(actual)} row(s)."
+ )
+
+ if not actual:
+ return 0.0, f"Query returned 0 rows; expected {len(expected)}."
+
+ # Project actual rows to only the expected columns (agent may SELECT extra).
+ # Use case-insensitive matching: build a map from lower(actual_col) → actual_col.
+ expected_cols = list(expected[0].keys())
+ lower_map = {k.lower(): k for k in actual[0].keys()} if actual else {}
+
+ def _project(row: Dict[str, Any]) -> Dict[str, Any]:
+ out: Dict[str, Any] = {}
+ for ec in expected_cols:
+ actual_key = lower_map.get(ec.lower())
+ if actual_key is not None:
+ out[ec] = row[actual_key]
+ return out
+
+ projected = [_project(row) for row in actual]
+
+ if len(projected) != len(expected):
+ overlap_ratio = min(len(projected), len(expected)) / max(len(projected), len(expected))
+ score = 0.3 * overlap_ratio
+ return score, (
+ f"Row count mismatch: got {len(projected)}, expected {len(expected)}. "
+ f"({overlap_ratio:.0%} overlap ratio)"
+ )
+
+ actual_sorted = sorted([_normalize(r) for r in projected], key=lambda r: _sort_key(r, order_by))
+ expected_sorted = sorted([_normalize(r) for r in expected], key=lambda r: _sort_key(r, order_by))
+
+ matches = sum(1 for a, e in zip(actual_sorted, expected_sorted) if a == e)
+ row_accuracy = matches / len(expected)
+
+ if row_accuracy == 1.0:
+ return 1.0, "All rows match perfectly."
+
+ score = 0.5 + 0.4 * row_accuracy
+ return score, f"{matches}/{len(expected)} rows match correctly."
+
+
+ # ---------------------------------------------------------------------------
+ # Stage 4 — Anthropic AI judge
+ # ---------------------------------------------------------------------------
+
+ def call_anthropic_judge(
+ task: SQLTask,
+ agent_query: str,
+ execution_success: bool,
+ execution_error: Optional[str],
+ actual_rows: Optional[List[Dict[str, Any]]],
+ deterministic_score: float,
+ ) -> Tuple[float, str, str]:
+ """
+ Call claude-sonnet-4-6 to evaluate query quality across three axes:
+ - Correctness (0–0.50)
+ - Optimization (0–0.30) — avoids inefficiencies, uses best SQL patterns
+ - Code quality (0–0.20) — readable, well-aliased, idiomatic SQL
+
+ Returns (final_score, feedback, improvement_hint).
+ Falls back to deterministic_score if the API call fails.
+ """
+ client = anthropic.Anthropic()
+
+ sample_actual = json.dumps(actual_rows[:5] if actual_rows else [], indent=2)
+ sample_expected = json.dumps(
+ task.test_cases[0].expected_rows if task.test_cases else [], indent=2
+ )
+
+ prompt = f"""\
+ You are a strict SQL expert judge scoring an agent's query for the task below.
+
+ ## Task ({task.level})
+ {task.description}
+
+ ## Agent Query
+ ```sql
+ {agent_query}
+ ```
+
+ ## Execution
+ - Success: {execution_success}
+ - Error: {execution_error or "None"}
+ - Rows returned (first 5): {sample_actual}
+ - Expected rows: {sample_expected}
+
+ ## Reference Solution
+ ```sql
+ {task.solution_query}
+ ```
+
+ ## Deterministic row-match score (0.0–1.0): {deterministic_score:.3f}
+
+ Score the agent query on THREE axes and sum them for the final score:
+
+ | Axis | Max | Criteria |
+ |--------------|------|----------|
+ | Correctness | 0.50 | Produces the right rows for the stated goal |
+ | Optimization | 0.30 | Avoids cartesian products / correlated subqueries; uses efficient patterns (CTEs, explicit JOINs, proper GROUP BY) |
+ | Code quality | 0.20 | Readable aliases, clean formatting, no redundant clauses |
+
+ IMPORTANT rules:
+ - If execution failed with a runtime error, Correctness ≤ 0.10.
+ - If the deterministic row-match score is ≥ 0.95 (rows fully correct), Correctness ≥ 0.40.
+ - For the medium task: a query that still uses comma-join syntax scores Optimization ≤ 0.05.
+ - For the hard task: a query without a CTE scores Optimization ≤ 0.10.
+
+ Respond with ONLY valid JSON (no markdown fences):
+ {{
+ "correctness": <float 0.0–0.50>,
+ "optimization": <float 0.0–0.30>,
+ "code_quality": <float 0.0–0.20>,
+ "score": <sum of above, float 0.0–1.0>,
+ "feedback": "<2–3 sentences summarising what the agent did right/wrong>",
+ "hint": "<one concrete actionable improvement, or 'Excellent!' if score >= 0.95>"
+ }}"""
+
+ try:
+ message = client.messages.create(
+ model=JUDGE_MODEL,
+ max_tokens=512,
+ messages=[{"role": "user", "content": prompt}],
+ )
+ raw = message.content[0].text.strip()
+
+ # Strip accidental markdown fences
+ if raw.startswith("```"):
+ raw = raw.split("```")[1]
+ if raw.startswith("json"):
+ raw = raw[4:]
+ raw = raw.rsplit("```", 1)[0].strip()
+
+ data = json.loads(raw)
+ score = float(data["score"])
+ score = max(0.0, min(1.0, score))
+ feedback = str(data.get("feedback", ""))
+ hint = str(data.get("hint", ""))
+ return score, feedback, hint
+
+ except Exception as exc:
+ # Graceful fallback — no API key, network error, or parse failure
+ msg = str(exc).lower()
+ reason = (
+ "no ANTHROPIC_API_KEY set"
+ if "api_key" in msg or "auth" in msg or "authentication" in msg
+ else type(exc).__name__
+ )
+ return (
+ deterministic_score,
+ f"AI judge offline ({reason}). Using deterministic score.",
+ task.hint,
+ )
+
+
+ # ---------------------------------------------------------------------------
+ # Public entry point
+ # ---------------------------------------------------------------------------
+
+ def grade(
+ task: SQLTask, agent_query: str
+ ) -> Tuple[float, str, Dict[str, Any]]:
+ """
+ Full grading pipeline. Returns (score 0.0–1.0, feedback, details_dict).
+
+ Partial progress scoring:
+ 0.00 — syntax error (unparseable)
+ 0.15–0.30 — syntax valid, runtime error (AI-adjusted)
+ 0.30 — executes, but 0 rows returned
+ 0.30–0.80 — partial row matches (deterministic)
+ 0.80–1.00 — correct rows + AI quality assessment
+ """
+ details: Dict[str, Any] = {}
+
+ # ── Stage 1: syntax ──────────────────────────────────────────────────────
+ syntax_ok, syntax_error = check_syntax(agent_query)
+ details["syntax_valid"] = syntax_ok
+ details["syntax_error"] = syntax_error
+
+ if not syntax_ok:
+ return 0.0, f"Syntax error: {syntax_error}", details
+
+ # ── Stage 2: execution ───────────────────────────────────────────────────
+ exec_ok, rows, exec_error = execute_query(task.schema_ddl, agent_query)
+ details["execution_success"] = exec_ok
+ details["execution_error"] = exec_error
+ details["rows_returned"] = len(rows) if rows else 0
+
+ if not exec_ok:
+ # Syntax valid but runtime error — call AI for nuanced feedback
+ ai_score, ai_feedback, ai_hint = call_anthropic_judge(
+ task, agent_query, False, exec_error, None, 0.15
+ )
+ details["ai_score"] = ai_score
+ details["ai_feedback"] = ai_feedback
+ final = max(0.15, ai_score * 0.3) # floor 0.15; ai_score * 0.3 keeps this branch ≤ 0.30
+ return final, f"Runtime error: {exec_error} | AI: {ai_feedback}", details
+
+ # ── Stage 3: row correctness ─────────────────────────────────────────────
+ test_case = task.test_cases[0]
+ row_score, row_feedback = rows_match(rows, test_case.expected_rows, test_case.order_by)
+ details["row_match_score"] = row_score
+ details["row_match_feedback"] = row_feedback
+
+ # ── Stage 3b: structural checks (task-specific) ─────────────────────────
+ # These prevent high scores when the agent submits the broken query verbatim
+ # or ignores the task's structural requirement.
+ structural_penalty = 0.0
+ query_upper = agent_query.upper()
+
+ if task.level == "hard" and "WITH " not in query_upper:
+ structural_penalty = 0.30 # hard task demands a CTE
+ row_feedback += " (Penalty: no CTE detected — task requires WITH clause.)"
+ elif task.level == "medium" and "JOIN " not in query_upper:
+ structural_penalty = 0.20 # medium task demands explicit JOINs
+ row_feedback += " (Penalty: no explicit JOIN — task requires JOIN … ON syntax.)"
+
+ details["structural_penalty"] = structural_penalty
+
+ # Deterministic score: 0.30 base for executing + up to 0.50 for rows − penalty
+ deterministic_score = max(0.30, 0.30 + 0.50 * row_score - structural_penalty)
+
+ # ── Stage 4: AI quality ──────────────────────────────────────────────────
+ ai_score, ai_feedback, ai_hint = call_anthropic_judge(
+ task, agent_query, True, None, rows, deterministic_score
+ )
+ details["ai_score"] = ai_score
+ details["ai_feedback"] = ai_feedback
+ details["ai_hint"] = ai_hint
+
+ # Final blending:
+ # rows fully correct → trust AI score (can reach 1.0)
+ # rows partially wrong → clamp AI score to not exceed deterministic
+ if row_score >= 0.95:
+ final_score = ai_score
+ elif row_score >= 0.5:
+ # Blend: AI provides nuance but can't exceed deterministic ceiling
+ final_score = min(deterministic_score, ai_score + 0.05)
+ else:
+ # Low row accuracy — stay near deterministic
+ final_score = min(deterministic_score, ai_score * 0.6)
+
+ final_score = max(0.0, min(1.0, final_score))
+
+ feedback = (
+ f"[Rows] {row_feedback} "
+ f"[AI Judge] {ai_feedback} "
+ f"[Hint] {ai_hint}"
+ )
+ return final_score, feedback, details
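The deterministic half of the ladder above can be exercised on its own. A minimal sketch, assuming nothing beyond the thresholds stated in `grade()`'s docstring; `deterministic_score` is an illustrative helper name, not part of judge.py:

```python
# Illustrative sketch of grade()'s deterministic score ladder (not the judge.py API).
def deterministic_score(syntax_ok: bool, exec_ok: bool,
                        row_score: float, penalty: float = 0.0) -> float:
    if not syntax_ok:   # Stage 1 failed: unparseable SQL
        return 0.0
    if not exec_ok:     # Stage 2 failed: runtime-error floor
        return 0.15
    # Stage 3: 0.30 base for executing, up to 0.50 for row correctness,
    # minus any structural penalty, floored at 0.30.
    return max(0.30, 0.30 + 0.50 * row_score - penalty)

print(round(deterministic_score(False, False, 0.0), 2))       # 0.0
print(round(deterministic_score(True, False, 0.0), 2))        # 0.15
print(round(deterministic_score(True, True, 0.0), 2))         # 0.3
print(round(deterministic_score(True, True, 1.0), 2))         # 0.8
print(round(deterministic_score(True, True, 1.0, 0.30), 2))   # 0.5 (missing-CTE penalty)
```

This mirrors why a perfect row match caps at 0.80 before the AI stage: only Stage 4 can close the remaining 0.20.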
models.py ADDED
@@ -0,0 +1,126 @@
+ """
+ Data models for the QueryForge SQL environment.
+
+ SQLAction — the agent's submitted SQL query.
+ SQLObservation — task description + grading feedback returned after each step.
+ TaskSpec — payload for registering a custom task via POST /tasks.
+ """
+
+ from typing import Any, Dict, List, Optional
+
+ from openenv.core.env_server.types import Action, Observation
+ from pydantic import BaseModel, Field
+
+
+ class SQLAction(Action):
+ """Action: submit a SQL query for evaluation."""
+
+ sql: str = Field(..., description="The SQL query to submit for grading")
+
+
+ class SQLObservation(Observation):
+ """Observation returned after reset() or step()."""
+
+ # ── Task context ─────────────────────────────────────────────────────────
+ task_id: str = Field(default="", description="Active task identifier")
+ task_level: str = Field(
+ default="", description="Difficulty: easy | medium | hard"
+ )
+ task_title: str = Field(default="", description="Human-readable task title")
+ task_description: str = Field(
+ default="",
+ description=(
+ "Full task description: schema, broken query, error message, and goal"
+ ),
+ )
+
+ # ── Per-step grading signals ──────────────────────────────────────────────
+ syntax_valid: bool = Field(
+ default=False, description="True if the submitted query parsed without error"
+ )
+ execution_success: bool = Field(
+ default=False, description="True if the query ran to completion in DuckDB"
+ )
+ execution_error: Optional[str] = Field(
+ default=None, description="Runtime error message, if any"
+ )
+ rows_returned: int = Field(
+ default=0, description="Number of rows the query returned"
+ )
+ feedback: str = Field(
+ default="",
+ description="Detailed grading feedback from DuckDB + AI judge",
+ )
+ hint: str = Field(
+ default="", description="Actionable hint for the next attempt"
+ )
+
+ # ── Episode progress ──────────────────────────────────────────────────────
+ attempt: int = Field(
+ default=0, description="Number of queries submitted this episode"
+ )
+ best_score: float = Field(
+ default=0.0, description="Highest score achieved so far this episode"
+ )
+
+
+ class TaskSpec(BaseModel):
+ """
+ Payload for registering a custom SQL task via POST /tasks
+ or directly via REGISTRY.register(task_from_dict(spec.model_dump())).
+
+ Required: id, title, schema_ddl, expected_rows
+ Everything else has sensible defaults.
+ """
+
+ id: str = Field(
+ ..., description="Unique task identifier, e.g. 'null_handling_task'"
+ )
+ level: str = Field(
+ default="custom",
+ description="Difficulty label: easy | medium | hard | custom",
+ )
+ title: str = Field(..., description="Human-readable task title")
+ description: str = Field(
+ default="",
+ description="Full task description shown to the agent (schema, goal, etc.)",
+ )
+ schema_ddl: str = Field(
+ ...,
+ description="CREATE TABLE + INSERT statements to seed the DuckDB test DB",
+ )
+ broken_query: str = Field(
+ default="",
+ description="The broken or slow query the agent must fix",
+ )
+ error_message: str = Field(
+ default="",
+ description="Error or performance warning shown to the agent alongside the task",
+ )
+ hint: str = Field(
+ default="",
+ description="Actionable hint surfaced in the observation after each wrong attempt",
+ )
+ expected_rows: List[Dict[str, Any]] = Field(
+ ...,
+ description=(
+ "Exact rows the correct query must return. "
+ "Used for deterministic row-match scoring."
+ ),
+ )
+ order_by: Optional[str] = Field(
+ default=None,
+ description="Comma-separated column names used to sort rows before comparison",
+ )
+ solution_query: str = Field(
+ default="",
+ description="Reference solution shown to the AI judge for quality scoring",
+ )
+ test_description: str = Field(
+ default="Custom test case",
+ description="One-line description of what the test case checks",
+ )
+ max_steps: int = Field(
+ default=5, ge=1, le=20,
+ description="Maximum number of step() calls allowed per episode",
+ )
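As a quick illustration of the TaskSpec contract, the required-versus-defaulted split can be checked on a plain dict. A hedged sketch: `REQUIRED` and `missing_fields` are illustrative names, not part of models.py, and this does not replace pydantic validation:

```python
# Illustrative check of a TaskSpec-style payload's required fields
# (the fields marked with Field(...) above: id, title, schema_ddl, expected_rows).
REQUIRED = ("id", "title", "schema_ddl", "expected_rows")

def missing_fields(payload: dict) -> list:
    return [k for k in REQUIRED if k not in payload]

spec = {
    "id": "custom_null_avg",
    "title": "Handle NULLs in Aggregation",
    "schema_ddl": "CREATE TABLE students (id INTEGER, score INTEGER);",
    "expected_rows": [{"avg_score": 48.33}],
}
print(missing_fields(spec))           # []
print(missing_fields({"id": "x"}))    # ['title', 'schema_ddl', 'expected_rows']
```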
openenv.yaml ADDED
@@ -0,0 +1,26 @@
+ spec_version: 1
+ name: queryforge
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
+
+ description: |
+ SQL Query Debugger & Optimiser environment.
+
+ An agent receives a broken or slow SQL query together with the schema and an
+ error/performance warning. It must produce a working, optimised query.
+
+ Tasks (3 levels, cycled in order):
+ easy — fix three misspelled SQL keywords (SELECT / FROM / WHERE)
+ medium — fix a missing JOIN condition that causes a cartesian product
+ hard — rewrite a correlated subquery (O(N²)) as a CTE (O(N))
+
+ Reward signal (0.0 – 1.0):
+ 0.00 syntax error
+ 0.15 syntax valid, runtime error
+ 0.30 executes, wrong / empty results
+ 0.30–0.80 partial row correctness (deterministic, DuckDB)
+ 0.80–1.00 correct results + AI quality score (Anthropic claude-sonnet-4-6)
+
+ Required env var: ANTHROPIC_API_KEY
openenv_queryforge.egg-info/PKG-INFO ADDED
@@ -0,0 +1,11 @@
+ Metadata-Version: 2.4
+ Name: openenv-queryforge
+ Version: 0.1.0
+ Summary: Queryforge environment for OpenEnv
+ Requires-Python: >=3.10
+ Requires-Dist: openenv-core[core]>=0.2.1
+ Requires-Dist: duckdb>=0.10.0
+ Requires-Dist: anthropic>=0.25.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_queryforge.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,16 @@
+ README.md
+ pyproject.toml
+ ./__init__.py
+ ./client.py
+ ./judge.py
+ ./models.py
+ ./tasks.py
+ openenv_queryforge.egg-info/PKG-INFO
+ openenv_queryforge.egg-info/SOURCES.txt
+ openenv_queryforge.egg-info/dependency_links.txt
+ openenv_queryforge.egg-info/entry_points.txt
+ openenv_queryforge.egg-info/requires.txt
+ openenv_queryforge.egg-info/top_level.txt
+ server/__init__.py
+ server/app.py
+ server/queryforge_environment.py
openenv_queryforge.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
openenv_queryforge.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ server = queryforge.server.app:main
openenv_queryforge.egg-info/requires.txt ADDED
@@ -0,0 +1,7 @@
+ openenv-core[core]>=0.2.1
+ duckdb>=0.10.0
+ anthropic>=0.25.0
+
+ [dev]
+ pytest>=8.0.0
+ pytest-cov>=4.0.0
openenv_queryforge.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ queryforge
playbook.py ADDED
@@ -0,0 +1,249 @@
+ """
+ QueryForge Local Playbook
+ ─────────────────────────
+ Tests the environment directly (no HTTP server needed).
+
+ Run from the queryforge directory:
+ .venv/bin/python playbook.py
+
+ If ANTHROPIC_API_KEY is set, Stage 4 AI scoring is live.
+ If not set, the judge falls back to deterministic scoring (capped at 0.80).
+ """
+
+ import os
+ import sys
+ import textwrap
+
+ # Make imports work whether run directly or as a module
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+ from server.queryforge_environment import QueryforgeEnvironment
+ from models import SQLAction
+ from tasks import REGISTRY, task_from_dict
+
+ # ── Formatting helpers ────────────────────────────────────────────────────────
+
+ def _hr(char="═", width=70):
+ print(char * width)
+
+ def _section(title):
+ print()
+ _hr()
+ print(f" {title}")
+ _hr()
+
+ def _score_bar(score: float, width: int = 30) -> str:
+ filled = int(score * width)
+ bar = "█" * filled + "░" * (width - filled)
+ return f"[{bar}] {score:.2f}"
+
+ def _print_obs(obs, show_description=False):
+ if show_description:
+ print()
+ print(textwrap.indent(obs.task_description, " "))
+ print()
+ if obs.feedback and obs.feedback != "New task loaded. Submit your fixed/optimised SQL query.":
+ print(f" Syntax valid : {obs.syntax_valid}")
+ print(f" Execution OK : {obs.execution_success}")
+ if obs.execution_error:
+ print(f" Execution error : {obs.execution_error[:100]}")
+ print(f" Rows returned : {obs.rows_returned}")
+ print(f" Score : {_score_bar(obs.reward or 0.0)}")
+ print(f" Best this ep. : {_score_bar(obs.best_score)}")
+ # Print just the first 250 chars of feedback to keep output clean
+ fb = obs.feedback[:250] + ("…" if len(obs.feedback) > 250 else "")
+ print(f" Feedback : {fb}")
+ if obs.hint:
+ print(f" Hint : {obs.hint[:120]}")
+
+ def _attempt(env, label: str, sql: str):
+ print(f"\n ── Attempt: {label}")
+ print(f" SQL: {sql[:100]}{'…' if len(sql) > 100 else ''}")
+ obs = env.step(SQLAction(sql=sql))
+ _print_obs(obs)
+ return obs
+
+
+ # ── Task runners ──────────────────────────────────────────────────────────────
+
+ def run_easy(env):
+ _section("TASK 1 · EASY — Fix Syntax Errors")
+ env._task_index = 0 # pin to easy
+ obs = env.reset()
+ print(f"\n Task : {obs.task_title} [{obs.task_level}]")
+ print(" Steps: up to 5")
+ _print_obs(obs, show_description=True)
+
+ _attempt(env, "still broken",
+ "SELEC name, age FORM users WEHRE age > 30")
+
+ _attempt(env, "one keyword fixed",
+ "SELECT name, age FORM users WEHRE age > 30")
+
+ _attempt(env, "all keywords fixed, no filter",
+ "SELECT name, age FROM users WHERE age > 30")
+
+ obs = _attempt(env, "correct solution",
+ "SELECT name, age FROM users "
+ "WHERE age > 30 AND city = 'New York' "
+ "ORDER BY name ASC")
+
+ print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
+
+
+ def run_medium(env):
+ _section("TASK 2 · MEDIUM — Fix the Cartesian JOIN")
+ env._task_index = 1 # pin to medium
+ obs = env.reset()
+ print(f"\n Task : {obs.task_title} [{obs.task_level}]")
+ print(" Steps: up to 5")
+ _print_obs(obs, show_description=True)
+
+ _attempt(env, "broken verbatim (cartesian product)",
+ "SELECT u.name, p.title, SUM(o.amount) AS total_spent "
+ "FROM orders o, users u, products p "
+ "WHERE o.user_id = u.id "
+ "GROUP BY u.name, p.title "
+ "ORDER BY total_spent DESC")
+
+ _attempt(env, "comma-join but missing product condition",
+ "SELECT u.name, p.title, SUM(o.amount) AS total_spent "
+ "FROM orders o, users u, products p "
+ "WHERE o.user_id = u.id AND o.product_id = p.id "
+ "GROUP BY u.name, p.title "
+ "ORDER BY total_spent DESC")
+
+ obs = _attempt(env, "correct INNER JOINs",
+ "SELECT u.name, p.title, SUM(o.amount) AS total_spent\n"
+ "FROM orders o\n"
+ "INNER JOIN users u ON o.user_id = u.id\n"
+ "INNER JOIN products p ON o.product_id = p.id\n"
+ "GROUP BY u.name, p.title\n"
+ "ORDER BY total_spent DESC")
+
+ print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
+
+
+ def run_hard(env):
+ _section("TASK 3 · HARD — Rewrite Correlated Subquery as CTE")
+ env._task_index = 2 # pin to hard
+ obs = env.reset()
+ print(f"\n Task : {obs.task_title} [{obs.task_level}]")
+ print(" Steps: up to 6")
+ _print_obs(obs, show_description=True)
+
+ _attempt(env, "broken verbatim (no CTE — penalised even though rows match)",
+ "SELECT e.name, e.department_id, e.salary\n"
+ "FROM employees e\n"
+ "WHERE e.salary > (\n"
+ " SELECT AVG(e2.salary) FROM employees e2\n"
+ " WHERE e2.department_id = e.department_id\n"
+ ")\n"
+ "ORDER BY e.department_id, e.salary DESC")
+
+ _attempt(env, "halfway — CTE defined but wrong join",
+ "WITH dept_avg AS (\n"
+ " SELECT department_id, AVG(salary) AS avg_salary\n"
+ " FROM employees GROUP BY department_id\n"
+ ")\n"
+ "SELECT e.name, e.department_id, e.salary\n"
+ "FROM employees e, dept_avg d\n"
+ "WHERE e.salary > d.avg_salary\n"
+ "ORDER BY e.department_id, e.salary DESC")
+
+ obs = _attempt(env, "correct CTE with proper JOIN",
+ "WITH dept_avg AS (\n"
+ " SELECT department_id, AVG(salary) AS avg_salary\n"
+ " FROM employees\n"
+ " GROUP BY department_id\n"
+ ")\n"
+ "SELECT e.name, e.department_id, e.salary\n"
+ "FROM employees e\n"
+ "JOIN dept_avg d ON e.department_id = d.department_id\n"
+ "WHERE e.salary > d.avg_salary\n"
+ "ORDER BY e.department_id, e.salary DESC")
+
+ print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
+
+
+ # ── Custom task demo ──────────────────────────────────────────────────────────
+
+ def run_custom(env):
+ _section("TASK 4 · CUSTOM — NULL Handling in Aggregation")
+
+ # Register a brand-new task at runtime
+ custom_task = task_from_dict({
+ "id": "custom_null_avg",
+ "level": "custom",
+ "title": "Handle NULLs in Aggregation",
+ "description": """\
+ TASK: The query below skips NULL scores, making the class average look higher.
+ Fix it so NULL scores are treated as 0.
+
+ SCHEMA:
+ students(id INTEGER, name VARCHAR, score INTEGER)
+
+ BROKEN QUERY:
+ SELECT AVG(score) AS avg_score FROM students
+
+ ERROR:
+ NULL values are silently excluded by AVG(), inflating the result.
+
+ GOAL: Return a single row with avg_score that treats NULL as 0.
+ Expected result: avg_score ≈ 48.33""",
+ "schema_ddl": """\
+ CREATE TABLE students (id INTEGER, name VARCHAR, score INTEGER);
+ INSERT INTO students VALUES
+ (1, 'Alice', 90),
+ (2, 'Bob', NULL),
+ (3, 'Carol', 80),
+ (4, 'Dave', NULL),
+ (5, 'Eve', 70),
+ (6, 'Frank', 50);
+ """,
+ "broken_query": "SELECT AVG(score) AS avg_score FROM students",
+ "error_message": "NULL scores are silently skipped by AVG().",
+ "hint": "Wrap score with COALESCE(score, 0) before averaging.",
+ "expected_rows": [{"avg_score": 48.33}],
+ "solution_query": "SELECT AVG(COALESCE(score, 0)) AS avg_score FROM students",
+ "test_description": "AVG treats NULL as 0 → 290 / 6 ≈ 48.33",
+ "max_steps": 4,
+ })
+ REGISTRY.register(custom_task)
+
+ obs = env.reset(task_id="custom_null_avg")
+ print(f"\n Task : {obs.task_title} [{obs.task_level}]")
+ print(f" Steps: up to {custom_task.max_steps}")
+ _print_obs(obs, show_description=True)
+
+ _attempt(env, "broken (NULL excluded)",
+ "SELECT AVG(score) AS avg_score FROM students")
+
+ obs = _attempt(env, "correct (COALESCE)",
+ "SELECT AVG(COALESCE(score, 0)) AS avg_score FROM students")
+
+ print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
+
+ # Clean up: remove custom task from registry
+ REGISTRY.unregister("custom_null_avg")
+ print(" Custom task unregistered from registry.")
+
+
+ # ── Main ──────────────────────────────────────────────────────────────────────
+
+ if __name__ == "__main__":
+ ai_key = os.environ.get("ANTHROPIC_API_KEY")
+
+ _hr("═")
+ print(" QueryForge — Local Playbook")
+ print(f" AI judge : {'LIVE (ANTHROPIC_API_KEY set)' if ai_key else 'OFFLINE (fallback to deterministic, max 0.80)'}")
+ _hr("═")
+
+ # Create a fresh env for each task so cycling order never matters
+ run_easy(QueryforgeEnvironment())
+ run_medium(QueryforgeEnvironment())
+ run_hard(QueryforgeEnvironment())
+ run_custom(QueryforgeEnvironment())
+
+ _section("DONE")
+ print(" All 4 tasks completed.\n")
pyproject.toml ADDED
@@ -0,0 +1,41 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-queryforge"
+ version = "0.1.0"
+ description = "QueryForge environment for OpenEnv"
+ requires-python = ">=3.10"
+ dependencies = [
+     # Core OpenEnv runtime (provides the FastAPI server + HTTP client types)
+     # To install from GitHub instead, use:
+     # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
+     "openenv-core[core]>=0.2.1",
+     # SQL execution engine (in-memory, no external deps)
+     "duckdb>=0.10.0",
+     # AI judge — quality scoring via the Anthropic API
+     "anthropic>=0.25.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+
+ [project.scripts]
+ # Server entry point - enables running via: uv run --project . server
+ # or: python -m queryforge.server.app
+ server = "queryforge.server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = ["queryforge", "queryforge.server"]
+ package-dir = { "queryforge" = ".", "queryforge.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """QueryForge environment server components."""
+
+ from .queryforge_environment import QueryforgeEnvironment
+
+ __all__ = ["QueryforgeEnvironment"]
server/app.py ADDED
@@ -0,0 +1,119 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ FastAPI application for the QueryForge Environment.
+
+ This module creates an HTTP server that exposes the QueryforgeEnvironment
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
+
+ Endpoints:
+     - POST /reset: Reset the environment
+     - POST /step: Execute an action
+     - GET /state: Get current environment state
+     - GET /schema: Get action/observation schemas
+     - WS /ws: WebSocket endpoint for persistent sessions
+
+ Usage:
+     # Development (with auto-reload):
+     uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
+
+     # Production:
+     uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
+
+     # Or run directly:
+     python -m server.app
+ """
+
+ try:
+     from openenv.core.env_server.http_server import create_app
+ except Exception as e:  # pragma: no cover
+     raise ImportError(
+         "openenv-core is required for the web interface. Install dependencies with 'uv sync'."
+     ) from e
+
+ try:
+     from ..models import SQLAction, SQLObservation, TaskSpec
+     from ..tasks import REGISTRY, task_from_dict
+     from .queryforge_environment import QueryforgeEnvironment
+ except ImportError:
+     from models import SQLAction, SQLObservation, TaskSpec
+     from tasks import REGISTRY, task_from_dict
+     from server.queryforge_environment import QueryforgeEnvironment
+
+ from fastapi import HTTPException
+
+
+ # Create the app with web interface and README integration
+ app = create_app(
+     QueryforgeEnvironment,
+     SQLAction,
+     SQLObservation,
+     env_name="queryforge",
+     max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
+ )
+
+
+ # ── Task Registry REST endpoints ──────────────────────────────────────────────
+
+ @app.post("/tasks", tags=["Task Registry"], status_code=201)
+ async def register_task(spec: TaskSpec):
+     """Register a custom SQL task. Replaces silently if the ID already exists."""
+     task = task_from_dict(spec.model_dump())
+     REGISTRY.register(task)
+     return {"ok": True, "task_id": task.id, "total_tasks": len(REGISTRY)}
+
+
+ @app.get("/tasks", tags=["Task Registry"])
+ async def list_tasks():
+     """List all registered tasks (built-in + custom)."""
+     return [
+         {"id": t.id, "level": t.level, "title": t.title}
+         for t in REGISTRY.list_all()
+     ]
+
+
+ @app.delete("/tasks/{task_id}", tags=["Task Registry"])
+ async def delete_task(task_id: str):
+     """Remove a custom task. Returns 403 for built-in tasks, 404 if not found."""
+     try:
+         REGISTRY.unregister(task_id)
+         return {"ok": True, "task_id": task_id}
+     except ValueError as exc:
+         raise HTTPException(status_code=403, detail=str(exc)) from exc
+     except KeyError as exc:
+         raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found.") from exc
+
+
+ def main(host: str = "0.0.0.0", port: int = 8000):
+     """
+     Entry point for direct execution via uv run or python -m.
+
+     This function enables running the server without Docker:
+         uv run --project . server
+         uv run --project . server --port 8001
+         python -m queryforge.server.app
+
+     Args:
+         host: Host address to bind to (default: "0.0.0.0")
+         port: Port number to listen on (default: 8000)
+
+     For production deployments, consider using uvicorn directly with
+     multiple workers:
+         uvicorn queryforge.server.app:app --workers 4
+     """
+     import uvicorn
+
+     uvicorn.run(app, host=host, port=port)
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", type=int, default=8000)
+     args = parser.parse_args()
+     main(port=args.port)
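A registration payload for POST /tasks can be sketched as plain JSON. The field names below mirror `task_from_dict` (only `id`, `schema_ddl`, and `expected_rows` are required; the rest default); the concrete values are illustrative, not part of the repo:

```python
import json

# Hypothetical minimal payload for POST /tasks (values are made up for
# illustration; required keys per task_from_dict: id, schema_ddl, expected_rows)
payload = {
    "id": "my_null_task",
    "level": "custom",
    "title": "Handle NULLs in aggregation",
    "schema_ddl": "CREATE TABLE t (x INTEGER); INSERT INTO t VALUES (1), (NULL);",
    "broken_query": "SELECT AVG(x) AS avg_x FROM t",
    "expected_rows": [{"avg_x": 0.5}],
    "hint": "Use COALESCE(x, 0).",
}

body = json.dumps(payload)
required = {"id", "schema_ddl", "expected_rows"}
print(required.issubset(payload))  # True
```

With the server running on the uvicorn defaults above, this body would be POSTed to `http://localhost:8000/tasks` with a `Content-Type: application/json` header; the new task is then available via `reset(task_id="my_null_task")`.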
server/queryforge_environment.py ADDED
@@ -0,0 +1,180 @@
+ """
+ QueryForge SQL Environment — server-side implementation.
+
+ The agent interacts with a SQL debugging and optimisation challenge:
+     reset()            → next task in round-robin rotation
+     reset(task_id="x") → pin to a specific task by ID (built-in or custom)
+     step()             → grade the submitted query, return scored observation
+     state              → episode_id + step count
+
+ Reward scale:
+     0.00       syntax error
+     0.15       syntax valid, runtime error
+     0.30       executes, wrong / empty results
+     0.30–0.80  partial row correctness (deterministic, DuckDB)
+     0.80–1.00  correct results + AI quality assessment (Anthropic)
+
+ Episode ends when:
+     - score >= 0.90 (correct + high-quality solution)
+     - best_score has not improved for 2 consecutive steps (early stopping)
+     - max_steps for the task is exhausted
+ """
+
+ import os
+ from typing import Optional
+ from uuid import uuid4
+
+ from openenv.core.env_server.interfaces import Environment
+ from openenv.core.env_server.types import State
+
+ try:
+     from ..models import SQLAction, SQLObservation
+     from ..tasks import REGISTRY, SQLTask
+     from ..judge import grade
+ except ImportError:
+     from models import SQLAction, SQLObservation
+     from tasks import REGISTRY, SQLTask
+     from judge import grade
+
+
+ class QueryforgeEnvironment(Environment):
+     """
+     SQL Query Debugger & Optimiser environment.
+
+     Built-in tasks (cycled in order by default):
+         1. easy   — fix three misspelled SQL keywords
+         2. medium — fix a missing JOIN condition causing a cartesian product
+         3. hard   — rewrite a correlated subquery as a CTE
+
+     Custom tasks can be registered at runtime via POST /tasks and then
+     requested by passing task_id to reset():
+         env.reset(task_id="my_custom_task")
+
+     Each episode ends when:
+         - The agent achieves score ≥ 0.90 (correct + high-quality solution), or
+         - best_score has not improved for 2 consecutive steps (early stopping), or
+         - The maximum steps for the current task are exhausted.
+
+     Supports concurrent WebSocket sessions (each client gets its own instance).
+     """
+
+     SUPPORTS_CONCURRENT_SESSIONS: bool = True
+
+     # Episode ends when score >= this threshold.
+     # Falls back to 0.80 when ANTHROPIC_API_KEY is unset (AI judge offline,
+     # deterministic scoring caps at 0.80).
+     DONE_THRESHOLD: float = 0.90 if os.environ.get("ANTHROPIC_API_KEY") else 0.80
+     # Episode ends when best_score has not improved for this many consecutive steps
+     EARLY_STOP_STEPS: int = 2
+
+     def __init__(self) -> None:
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._current_task: Optional[SQLTask] = None
+         self._best_score: float = 0.0
+         self._attempt: int = 0
+         self._stale_steps: int = 0  # consecutive steps with no best_score improvement
+
+     # ── OpenEnv interface ─────────────────────────────────────────────────────
+
+     def reset(
+         self,
+         task_id: Optional[str] = None,
+         seed: Optional[int] = None,
+         episode_id: Optional[str] = None,
+         **kwargs,
+     ) -> SQLObservation:
+         """
+         Start a new episode.
+
+         Args:
+             task_id: Pin to a specific task by ID. If None, the registry
+                 cycles round-robin through all registered tasks.
+             seed: Ignored (reserved for future use).
+             episode_id: Optional custom episode identifier.
+         """
+         ep_id = episode_id or str(uuid4())
+         self._state = State(episode_id=ep_id, step_count=0)
+         self._best_score = 0.0
+         self._attempt = 0
+         self._stale_steps = 0
+
+         if task_id is not None:
+             try:
+                 self._current_task = REGISTRY.get(task_id)
+             except KeyError as exc:
+                 # Unknown task_id — return an error observation so the caller
+                 # gets clear feedback instead of a silent 500.
+                 return SQLObservation(
+                     feedback=str(exc),
+                     hint=f"Available task IDs: {', '.join(REGISTRY.ids())}",
+                     done=True,
+                     reward=0.0,
+                 )
+         else:
+             self._current_task = REGISTRY.cycle_next()
+
+         return SQLObservation(
+             task_id=self._current_task.id,
+             task_level=self._current_task.level,
+             task_title=self._current_task.title,
+             task_description=self._current_task.description,
+             syntax_valid=False,
+             execution_success=False,
+             execution_error=None,
+             rows_returned=0,
+             feedback="New task loaded. Submit your fixed/optimised SQL query.",
+             hint=self._current_task.hint,
+             attempt=0,
+             best_score=0.0,
+             done=False,
+             reward=0.0,
+         )
+
+     def step(self, action: SQLAction) -> SQLObservation:  # type: ignore[override]
+         """Grade the submitted SQL query and return a scored observation."""
+         self._state.step_count += 1
+         self._attempt += 1
+
+         if self._current_task is None:
+             return SQLObservation(
+                 feedback="No task active. Call reset() first.",
+                 hint="Call reset() to start a new episode.",
+                 done=True,
+                 reward=0.0,
+             )
+
+         score, feedback, details = grade(self._current_task, action.sql)
+
+         # Early stopping: track consecutive steps with no improvement
+         if score > self._best_score:
+             self._stale_steps = 0
+         else:
+             self._stale_steps += 1
+         self._best_score = max(self._best_score, score)
+
+         # Done when the threshold is reached, progress has stalled,
+         # or the step budget is exhausted
+         done = (
+             score >= self.DONE_THRESHOLD
+             or self._stale_steps >= self.EARLY_STOP_STEPS
+             or self._state.step_count >= self._current_task.max_steps
+         )
+
+         return SQLObservation(
+             task_id=self._current_task.id,
+             task_level=self._current_task.level,
+             task_title=self._current_task.title,
+             task_description=self._current_task.description,
+             syntax_valid=bool(details.get("syntax_valid", False)),
+             execution_success=bool(details.get("execution_success", False)),
+             execution_error=details.get("execution_error"),
+             rows_returned=int(details.get("rows_returned", 0)),
+             feedback=feedback,
+             hint="" if score >= 0.9 else self._current_task.hint,
+             attempt=self._attempt,
+             best_score=self._best_score,
+             done=done,
+             reward=score,
+         )
+
+     @property
+     def state(self) -> State:
+         return self._state
server/requirements.txt ADDED
@@ -0,0 +1,8 @@
+ openenv-core[core]>=0.2.1
+ fastapi>=0.115.0
+ uvicorn>=0.24.0
+ duckdb>=0.10.0
+ anthropic>=0.25.0
+
+
+
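The episode-termination rule in `QueryforgeEnvironment.step()` above combines three conditions. A standalone sketch of the same logic (a hypothetical helper, not part of the environment code):

```python
def episode_done(score, stale_steps, step_count,
                 threshold=0.90, early_stop=2, max_steps=5):
    """Mirror of the done condition in QueryforgeEnvironment.step():
    threshold reached, progress stalled, or step budget exhausted.
    stale_steps counts consecutive steps where score did not beat best_score."""
    return (
        score >= threshold
        or stale_steps >= early_stop
        or step_count >= max_steps
    )

print(episode_done(0.95, 0, 1))  # True  (threshold reached)
print(episode_done(0.30, 2, 3))  # True  (two stale steps in a row)
print(episode_done(0.30, 0, 5))  # True  (step budget exhausted)
print(episode_done(0.50, 0, 2))  # False (still improving, budget left)
```

Note that with the AI judge offline the threshold drops to 0.80, so a deterministically perfect answer still terminates the episode.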
tasks.py ADDED
@@ -0,0 +1,415 @@
+ """
+ SQL task definitions and runtime task registry for the QueryForge environment.
+
+ Built-in tasks:
+     easy   — fix three misspelled SQL keywords
+     medium — fix a cartesian JOIN producing wrong results
+     hard   — rewrite a correlated subquery as a CTE
+
+ Custom tasks can be added at runtime via REGISTRY.register() or
+ POST /tasks on the running server.
+ """
+
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from threading import Lock
+ from typing import Any, Dict, List, Optional
+
+
+ # ── Data classes ──────────────────────────────────────────────────────────────
+
+ @dataclass
+ class TestCase:
+     """A single test case: expected output rows for correctness grading."""
+
+     description: str
+     expected_rows: List[Dict[str, Any]]
+     order_by: Optional[str] = None  # comma-separated columns to sort by
+
+
+ @dataclass
+ class SQLTask:
+     """Full definition of one SQL challenge."""
+
+     id: str
+     level: str            # "easy" | "medium" | "hard" | "custom"
+     title: str
+     description: str
+     schema_ddl: str       # DDL + seed INSERT statements for DuckDB
+     broken_query: str     # broken/slow query the agent must fix
+     error_message: str    # error or performance warning shown to the agent
+     hint: str
+     test_cases: List[TestCase]
+     solution_query: str   # reference solution used by the AI judge
+     max_steps: int = 5
+
+
+ # ── Built-in tasks ────────────────────────────────────────────────────────────
+
+ _TASK_EASY = SQLTask(
+     id="task_easy_syntax",
+     level="easy",
+     title="Fix the Syntax Errors",
+     description="""\
+ TASK: Fix the syntax errors in the query below so it runs correctly.
+
+ SCHEMA:
+   users(id INTEGER, name VARCHAR, age INTEGER, city VARCHAR)
+
+ BROKEN QUERY:
+   SELEC name, age FORM users WEHRE age > 30 AND city = 'New York'
+
+ ERROR:
+   Parser Error: syntax error at or near "SELEC"
+
+ GOAL: Return a valid SQL query that retrieves `name` and `age`
+ of users who are older than 30 AND live in New York.
+ Order by name ASC.""",
+     schema_ddl="""\
+ CREATE TABLE users (
+     id INTEGER,
+     name VARCHAR,
+     age INTEGER,
+     city VARCHAR
+ );
+ INSERT INTO users VALUES
+     (1, 'Alice', 35, 'New York'),
+     (2, 'Bob', 28, 'New York'),
+     (3, 'Carol', 42, 'Chicago'),
+     (4, 'Dave', 31, 'New York'),
+     (5, 'Eve', 25, 'New York'),
+     (6, 'Frank', 38, 'New York');
+ """,
+     broken_query="SELEC name, age FORM users WEHRE age > 30 AND city = 'New York'",
+     error_message='Parser Error: syntax error at or near "SELEC"',
+     hint="Three SQL keywords are misspelled: SELEC → SELECT, FORM → FROM, WEHRE → WHERE.",
+     test_cases=[
+         TestCase(
+             description="Users over 30 living in New York, ordered by name",
+             expected_rows=[
+                 {"name": "Alice", "age": 35},
+                 {"name": "Dave", "age": 31},
+                 {"name": "Frank", "age": 38},
+             ],
+             order_by="name",
+         )
+     ],
+     solution_query=(
+         "SELECT name, age FROM users "
+         "WHERE age > 30 AND city = 'New York' "
+         "ORDER BY name ASC"
+     ),
+ )
+
+ _TASK_MEDIUM = SQLTask(
+     id="task_medium_join",
+     level="medium",
+     title="Fix the Cartesian JOIN",
+     description="""\
+ TASK: The query below produces wildly inflated totals because a JOIN condition
+ is missing, creating a cartesian product with the `products` table. Fix it.
+
+ SCHEMAS:
+   users(id INTEGER, name VARCHAR, age INTEGER)
+   products(id INTEGER, title VARCHAR, price DECIMAL)
+   orders(id INTEGER, user_id INTEGER, product_id INTEGER, amount DECIMAL)
+
+ BROKEN QUERY:
+   SELECT u.name, p.title, SUM(o.amount) AS total_spent
+   FROM orders o, users u, products p
+   WHERE o.user_id = u.id
+   GROUP BY u.name, p.title
+   ORDER BY total_spent DESC
+
+ PROBLEM:
+   Missing join condition `o.product_id = p.id`.
+   Every order row is multiplied by ALL products, inflating every total by 3×.
+
+ GOAL: Rewrite using explicit INNER JOIN … ON syntax with all correct join
+ conditions. Return user name, product title, and true total amount spent per
+ (user, product) pair, ordered by total_spent DESC.""",
+     schema_ddl="""\
+ CREATE TABLE users (id INTEGER, name VARCHAR, age INTEGER);
+ CREATE TABLE products (id INTEGER, title VARCHAR, price DECIMAL);
+ CREATE TABLE orders (id INTEGER, user_id INTEGER, product_id INTEGER, amount DECIMAL);
+
+ INSERT INTO users VALUES (1,'Alice',30),(2,'Bob',25),(3,'Carol',35);
+ INSERT INTO products VALUES (1,'Laptop',999.99),(2,'Phone',599.99),(3,'Tablet',399.99);
+ INSERT INTO orders VALUES
+     (1,1,1,999.99),(2,1,2,599.99),
+     (3,2,1,999.99),(4,2,3,399.99),
+     (5,3,2,599.99),(6,3,1,999.99);
+ """,
+     broken_query="""\
+ SELECT u.name, p.title, SUM(o.amount) AS total_spent
+ FROM orders o, users u, products p
+ WHERE o.user_id = u.id
+ GROUP BY u.name, p.title
+ ORDER BY total_spent DESC""",
+     error_message=(
+         "Query runs but produces WRONG results: totals are 3× too high "
+         "because every order is joined to every product (cartesian product)."
+     ),
+     hint=(
+         "Use INNER JOIN … ON for every table. "
+         "You need both: o.user_id = u.id AND o.product_id = p.id."
+     ),
+     test_cases=[
+         TestCase(
+             description="Correct per-(user, product) totals",
+             expected_rows=[
+                 {"name": "Alice", "title": "Laptop", "total_spent": 999.99},
+                 {"name": "Alice", "title": "Phone", "total_spent": 599.99},
+                 {"name": "Bob", "title": "Laptop", "total_spent": 999.99},
+                 {"name": "Bob", "title": "Tablet", "total_spent": 399.99},
+                 {"name": "Carol", "title": "Laptop", "total_spent": 999.99},
+                 {"name": "Carol", "title": "Phone", "total_spent": 599.99},
+             ],
+             order_by="name,title",
+         )
+     ],
+     solution_query="""\
+ SELECT u.name, p.title, SUM(o.amount) AS total_spent
+ FROM orders o
+ INNER JOIN users u ON o.user_id = u.id
+ INNER JOIN products p ON o.product_id = p.id
+ GROUP BY u.name, p.title
+ ORDER BY total_spent DESC""",
+ )
+
+ _TASK_HARD = SQLTask(
+     id="task_hard_cte",
+     level="hard",
+     title="Rewrite Correlated Subquery as CTE",
+     description="""\
+ TASK: The query below is semantically correct but executes the inner AVG(salary)
+ once per employee row — O(N) full scans. Rewrite it using a WITH (CTE) so the
+ department averages are computed exactly once.
+
+ SCHEMAS:
+   departments(id INTEGER, dept_name VARCHAR)
+   employees(id INTEGER, name VARCHAR, department_id INTEGER, salary DECIMAL)
+
+ SLOW QUERY:
+   SELECT e.name, e.department_id, e.salary
+   FROM employees e
+   WHERE e.salary > (
+       SELECT AVG(e2.salary)
+       FROM employees e2
+       WHERE e2.department_id = e.department_id
+   )
+   ORDER BY e.department_id, e.salary DESC
+
+ PERFORMANCE WARNING:
+   For 1 M employees the inner subquery executes 1 M times.
+   DuckDB's EXPLAIN shows: 'FILTER ... (subquery)' with nested loop.
+
+ GOAL: Rewrite using a CTE that computes per-department average salary once,
+ then join it to employees and filter. The result must be identical:
+ employees who earn strictly above their own department's average salary,
+ ordered by department_id ASC, salary DESC.""",
+     schema_ddl="""\
+ CREATE TABLE departments (id INTEGER, dept_name VARCHAR);
+ CREATE TABLE employees (id INTEGER, name VARCHAR, department_id INTEGER, salary DECIMAL);
+
+ INSERT INTO departments VALUES (1,'Engineering'),(2,'Marketing'),(3,'Sales');
+ INSERT INTO employees VALUES
+     (1,'Alice', 1, 95000),(2,'Bob', 1, 75000),(3,'Carol', 1, 85000),
+     (4,'Dave', 2, 65000),(5,'Eve', 2, 70000),(6,'Frank', 2, 60000),
+     (7,'Grace', 3, 55000),(8,'Hank', 3, 72000),(9,'Iris', 3, 58000);
+ """,
+     broken_query="""\
+ SELECT e.name, e.department_id, e.salary
+ FROM employees e
+ WHERE e.salary > (
+     SELECT AVG(e2.salary)
+     FROM employees e2
+     WHERE e2.department_id = e.department_id
+ )
+ ORDER BY e.department_id, e.salary DESC""",
+     error_message=(
+         "PERFORMANCE: Correlated subquery re-executes AVG() for every row. "
+         "On large tables this is O(N²). Rewrite as a CTE for O(N) execution."
+     ),
+     hint=(
+         "WITH dept_avg AS (SELECT department_id, AVG(salary) AS avg_salary "
+         "FROM employees GROUP BY department_id) — then JOIN employees to dept_avg "
+         "and filter WHERE e.salary > d.avg_salary."
+     ),
+     test_cases=[
+         TestCase(
+             description="Employees strictly above their department's average salary",
+             expected_rows=[
+                 {"name": "Alice", "department_id": 1, "salary": 95000.0},
+                 {"name": "Eve", "department_id": 2, "salary": 70000.0},
+                 {"name": "Hank", "department_id": 3, "salary": 72000.0},
+             ],
+             order_by="department_id,name",
+         )
+     ],
+     solution_query="""\
+ WITH dept_avg AS (
+     SELECT department_id, AVG(salary) AS avg_salary
+     FROM employees
+     GROUP BY department_id
+ )
+ SELECT e.name, e.department_id, e.salary
+ FROM employees e
+ JOIN dept_avg d ON e.department_id = d.department_id
+ WHERE e.salary > d.avg_salary
+ ORDER BY e.department_id, e.salary DESC""",
+     max_steps=6,
+ )
+
+
+ # ── Task Registry ─────────────────────────────────────────────────────────────
+
+ class TaskRegistry:
+     """
+     Thread-safe registry of SQL tasks, shared across all environment sessions.
+
+     Built-in tasks (easy / medium / hard) are always present and cannot be removed.
+     Custom tasks can be added via register(), load_from_json(), or POST /tasks.
+     """
+
+     _BUILTIN_IDS: frozenset = frozenset(
+         ["task_easy_syntax", "task_medium_join", "task_hard_cte"]
+     )
+
+     def __init__(self, initial_tasks: List[SQLTask]) -> None:
+         self._lock = Lock()
+         # Insertion-ordered dict preserves cycling order
+         self._tasks: Dict[str, SQLTask] = {t.id: t for t in initial_tasks}
+         self._cycle_index: int = 0
+
+     # ── CRUD ─────────────────────────────────────────────────────────────────
+
+     def register(self, task: SQLTask) -> None:
+         """Add or replace a task. Replaces silently if the ID already exists."""
+         with self._lock:
+             self._tasks[task.id] = task
+
+     def unregister(self, task_id: str) -> None:
+         """
+         Remove a custom task.
+         Raises ValueError for built-in tasks, KeyError if not found.
+         """
+         if task_id in self._BUILTIN_IDS:
+             raise ValueError(f"Built-in task '{task_id}' cannot be removed.")
+         with self._lock:
+             if task_id not in self._tasks:
+                 raise KeyError(task_id)
+             del self._tasks[task_id]
+
+     def get(self, task_id: str) -> SQLTask:
+         """Return a task by ID. Raises KeyError with available IDs if not found."""
+         with self._lock:
+             if task_id not in self._tasks:
+                 available = ", ".join(self._tasks.keys())
+                 raise KeyError(
+                     f"Task '{task_id}' not found. "
+                     f"Available: {available}"
+                 )
+             return self._tasks[task_id]
+
+     def list_all(self) -> List[SQLTask]:
+         """Return all registered tasks in insertion order."""
+         with self._lock:
+             return list(self._tasks.values())
+
+     def ids(self) -> List[str]:
+         """Return all task IDs in insertion order."""
+         with self._lock:
+             return list(self._tasks.keys())
+
+     # ── Cycling ───────────────────────────────────────────────────────────────
+
+     def cycle_next(self) -> SQLTask:
+         """Return the next task in round-robin order (wraps at end)."""
+         with self._lock:
+             tasks = list(self._tasks.values())
+             task = tasks[self._cycle_index % len(tasks)]
+             self._cycle_index += 1
+             return task
+
+     # ── Bulk loading ──────────────────────────────────────────────────────────
+
+     def load_from_json(self, path: str) -> int:
+         """
+         Load tasks from a JSON file (list of task spec objects).
+         Returns the number of tasks loaded.
+
+         Minimal required fields per task:
+             id, schema_ddl, expected_rows
+
+         Example file::
+
+             [
+                 {
+                     "id": "my_null_task",
+                     "level": "medium",
+                     "title": "Handle NULLs in aggregation",
+                     "schema_ddl": "CREATE TABLE ...; INSERT ...",
+                     "broken_query": "SELECT AVG(score) FROM ...",
+                     "expected_rows": [{"avg_score": 72.5}],
+                     "hint": "Use COALESCE to handle NULL scores."
+                 }
+             ]
+         """
+         raw = json.loads(Path(path).read_text())
+         if isinstance(raw, dict):
+             raw = [raw]
+         for item in raw:
+             self.register(task_from_dict(item))
+         return len(raw)
+
+     # ── Helpers ───────────────────────────────────────────────────────────────
+
+     def __len__(self) -> int:
+         with self._lock:
+             return len(self._tasks)
+
+     def __contains__(self, task_id: str) -> bool:
+         with self._lock:
+             return task_id in self._tasks
+
+
+ # ── Conversion helper ─────────────────────────────────────────────────────────
+
+ def task_from_dict(d: Dict[str, Any]) -> SQLTask:
+     """
+     Construct an SQLTask from a plain dict (JSON payload or loaded file).
+
+     Required keys : id, schema_ddl, expected_rows
+     Optional keys : level, title, description, broken_query, error_message,
+                     hint, order_by, solution_query, test_description, max_steps
+     """
+     return SQLTask(
+         id=d["id"],
+         level=d.get("level", "custom"),
+         title=d.get("title", d["id"]),
+         description=d.get("description", ""),
+         schema_ddl=d["schema_ddl"],
+         broken_query=d.get("broken_query", ""),
+         error_message=d.get("error_message", ""),
+         hint=d.get("hint", ""),
+         test_cases=[
+             TestCase(
+                 description=d.get("test_description", "Custom test case"),
+                 expected_rows=d["expected_rows"],
+                 order_by=d.get("order_by"),
+             )
+         ],
+         solution_query=d.get("solution_query", ""),
+         max_steps=d.get("max_steps", 5),
+     )
+
+
+ # ── Global singleton ──────────────────────────────────────────────────────────
+
+ REGISTRY = TaskRegistry([_TASK_EASY, _TASK_MEDIUM, _TASK_HARD])
+
+ # Backwards-compat: snapshot of the three built-in tasks at import time
+ TASKS: List[SQLTask] = [_TASK_EASY, _TASK_MEDIUM, _TASK_HARD]
+ TASK_BY_ID: Dict[str, SQLTask] = {t.id: t for t in TASKS}
uv.lock ADDED
The diff for this file is too large to render. See raw diff