Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- Dockerfile +81 -0
- README.md +348 -5
- __init__.py +23 -0
- baseline.py +244 -0
- client.py +94 -0
- judge.py +414 -0
- models.py +126 -0
- openenv.yaml +26 -0
- openenv_queryforge.egg-info/PKG-INFO +11 -0
- openenv_queryforge.egg-info/SOURCES.txt +16 -0
- openenv_queryforge.egg-info/dependency_links.txt +1 -0
- openenv_queryforge.egg-info/entry_points.txt +2 -0
- openenv_queryforge.egg-info/requires.txt +7 -0
- openenv_queryforge.egg-info/top_level.txt +1 -0
- playbook.py +249 -0
- pyproject.toml +41 -0
- server/__init__.py +11 -0
- server/app.py +119 -0
- server/queryforge_environment.py +180 -0
- server/requirements.txt +8 -0
- tasks.py +415 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=queryforge
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 81 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,353 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: QueryForge Environment Server
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
- sql
|
| 13 |
+
- reinforcement-learning
|
| 14 |
---
|
| 15 |
|
| 16 |
+
# QueryForge β SQL Debugging & Optimisation Environment
|
| 17 |
+
|
| 18 |
+
SQL is the language that runs the world's data infrastructure. Yet SQL bugs are silent killers β a missing JOIN condition inflates totals by 3Γ, a correlated subquery scans a million rows once per row, a typo in a keyword stops production cold. These bugs are rarely caught by linters, rarely surfaced by error messages, and routinely shipped to production.
|
| 19 |
+
|
| 20 |
+
QueryForge is an **OpenEnv-compatible reinforcement learning environment** where an agent learns to debug and optimise SQL queries. The agent receives a broken or slow query, submits fixes, and receives graded feedback from a deterministic DuckDB engine combined with an Anthropic AI quality judge β a smooth, informative reward signal across the full 0.0 β 1.0 range.
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## Why SQL Debugging as an RL Environment?
|
| 25 |
+
|
| 26 |
+
LLMs can write SQL. What they struggle with is the **iterative, feedback-driven debugging loop** that real engineers do:
|
| 27 |
+
|
| 28 |
+
- Read the error message
|
| 29 |
+
- Form a hypothesis about the root cause
|
| 30 |
+
- Patch the query
|
| 31 |
+
- Check if the output is now correct
|
| 32 |
+
- Refine until it's both correct *and* efficient
|
| 33 |
+
|
| 34 |
+
This is precisely the loop that RL is built for. QueryForge provides the environment that closes this loop with a graded, multi-stage reward signal β not just "correct / incorrect" but partial credit for syntax validity, execution success, row correctness, and code quality.
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
|
| 38 |
+
## Environment Overview
|
| 39 |
+
|
| 40 |
+
| Property | Value |
|
| 41 |
+
|---|---|
|
| 42 |
+
| Task type | SQL debugging & optimisation |
|
| 43 |
+
| Action space | Single SQL query string |
|
| 44 |
+
| Observation space | Task description + graded feedback |
|
| 45 |
+
| Reward range | 0.0 β 1.0 (continuous) |
|
| 46 |
+
| Episode termination | Score β₯ 0.90, no improvement for 2 steps, or max steps |
|
| 47 |
+
| Grading engine | DuckDB (deterministic) + Anthropic AI judge |
|
| 48 |
+
| Concurrent sessions | Supported |
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## Reward Scale
|
| 53 |
+
|
| 54 |
+
The grading pipeline has four stages that produce a smooth partial-progress signal:
|
| 55 |
+
|
| 56 |
+
| Score | Meaning |
|
| 57 |
+
|---|---|
|
| 58 |
+
| **0.00** | Syntax error β query could not be parsed |
|
| 59 |
+
| **0.15** | Syntax valid but runtime error |
|
| 60 |
+
| **0.30** | Executes but returns 0 rows or wrong row count |
|
| 61 |
+
| **0.30 β 0.80** | Partial row correctness (deterministic, DuckDB) |
|
| 62 |
+
| **0.80 β 1.00** | Correct rows + AI quality assessment (Anthropic) |
|
| 63 |
+
|
| 64 |
+
The AI judge scores on three axes: **Correctness** (0β0.50), **Optimization** (0β0.30 β penalises cartesian products, correlated subqueries), **Code quality** (0β0.20 β readability, aliases, formatting).
|
| 65 |
+
|
| 66 |
+
> **Offline mode:** If `ANTHROPIC_API_KEY` is not set, the AI judge is skipped and scoring is fully deterministic (capped at 0.80). The done threshold self-adjusts to 0.80 in this case so episodes still terminate correctly.
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## Action Space
|
| 71 |
+
|
| 72 |
+
```python
|
| 73 |
+
class SQLAction(Action):
|
| 74 |
+
sql: str # The SQL query to submit for grading
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
One field. The agent submits a SQL string. No multi-statement queries (`;` separated) are allowed β rejected with score 0.0.
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
## Observation Space
|
| 82 |
+
|
| 83 |
+
```python
|
| 84 |
+
class SQLObservation(Observation):
|
| 85 |
+
# Task context (set on reset, constant within an episode)
|
| 86 |
+
task_id: str # e.g. "task_easy_syntax"
|
| 87 |
+
task_level: str # "easy" | "medium" | "hard" | "custom"
|
| 88 |
+
task_title: str # Human-readable title
|
| 89 |
+
task_description: str # Full context: schema, broken query, error, goal
|
| 90 |
+
|
| 91 |
+
# Per-step grading signals
|
| 92 |
+
syntax_valid: bool # True if query parsed without error
|
| 93 |
+
execution_success: bool # True if query ran to completion in DuckDB
|
| 94 |
+
execution_error: str # Runtime error message, if any
|
| 95 |
+
rows_returned: int # Number of rows returned
|
| 96 |
+
|
| 97 |
+
# Feedback
|
| 98 |
+
feedback: str # Detailed grading feedback (DuckDB + AI judge)
|
| 99 |
+
hint: str # Actionable hint (suppressed once score >= 0.90)
|
| 100 |
+
|
| 101 |
+
# Episode progress
|
| 102 |
+
attempt: int # Number of queries submitted this episode
|
| 103 |
+
best_score: float # Highest score achieved so far
|
| 104 |
+
done: bool
|
| 105 |
+
reward: float # Score for this specific step (0.0 β 1.0)
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## Built-in Tasks
|
| 111 |
+
|
| 112 |
+
### Easy β Fix Syntax Errors
|
| 113 |
+
Three SQL keywords are misspelled (`SELEC`, `FORM`, `WEHRE`). The agent must identify and correct them.
|
| 114 |
+
|
| 115 |
+
**Schema:** `users(id, name, age, city)` β 6 rows
|
| 116 |
+
**Goal:** Return name and age of users older than 30 in New York, ordered by name
|
| 117 |
+
**Max steps:** 5
|
| 118 |
+
|
| 119 |
+
### Medium β Fix the Cartesian JOIN
|
| 120 |
+
A missing `JOIN` condition (`o.product_id = p.id`) causes a cartesian product, inflating every total by 3Γ. The agent must rewrite using explicit `INNER JOIN β¦ ON` syntax.
|
| 121 |
+
|
| 122 |
+
**Schema:** `orders`, `users`, `products` β e-commerce dataset
|
| 123 |
+
**Goal:** Correct per-(user, product) total amount spent, ordered by total DESC
|
| 124 |
+
**Max steps:** 5
|
| 125 |
+
|
| 126 |
+
### Hard β Rewrite Correlated Subquery as CTE
|
| 127 |
+
A semantically correct but O(NΒ²) query re-executes `AVG(salary)` for every employee row. The agent must rewrite using a `WITH` clause that computes department averages exactly once.
|
| 128 |
+
|
| 129 |
+
**Schema:** `departments`, `employees` β 9 employees across 3 departments
|
| 130 |
+
**Goal:** Employees who earn strictly above their department average, ordered by dept/salary
|
| 131 |
+
**Max steps:** 6
|
| 132 |
+
|
| 133 |
+
> Tasks have **structural penalties**: the hard task requires a `WITH` clause (β0.30 if absent); the medium task requires explicit `JOIN` syntax (β0.20 if absent). This prevents an agent from gaming the score by submitting the broken query verbatim.
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
## Custom Tasks
|
| 138 |
+
|
| 139 |
+
Register any SQL task at runtime β no code changes needed.
|
| 140 |
+
|
| 141 |
+
### Via Python
|
| 142 |
+
```python
|
| 143 |
+
from tasks import REGISTRY, task_from_dict
|
| 144 |
+
|
| 145 |
+
REGISTRY.register(task_from_dict({
|
| 146 |
+
"id": "my_window_task",
|
| 147 |
+
"level": "hard",
|
| 148 |
+
"title": "Rank Employees by Salary",
|
| 149 |
+
"schema_ddl": "CREATE TABLE emp (id INT, name VARCHAR, dept VARCHAR, salary DECIMAL); INSERT INTO emp VALUES ...",
|
| 150 |
+
"broken_query": "SELECT name, salary FROM emp ORDER BY salary DESC",
|
| 151 |
+
"expected_rows": [{"name": "Alice", "rank": 1}, ...],
|
| 152 |
+
"hint": "Use ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC)",
|
| 153 |
+
"solution_query": "SELECT name, RANK() OVER (ORDER BY salary DESC) AS rank FROM emp",
|
| 154 |
+
}))
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
### Via REST API (when server is running)
|
| 158 |
+
```bash
|
| 159 |
+
# Register a custom task
|
| 160 |
+
curl -X POST http://localhost:8000/tasks \
|
| 161 |
+
-H "Content-Type: application/json" \
|
| 162 |
+
-d '{"id": "my_task", "schema_ddl": "...", "expected_rows": [...]}'
|
| 163 |
+
|
| 164 |
+
# List all tasks
|
| 165 |
+
curl http://localhost:8000/tasks
|
| 166 |
+
|
| 167 |
+
# Remove a custom task
|
| 168 |
+
curl -X DELETE http://localhost:8000/tasks/my_task
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
### Via JSON file
|
| 172 |
+
```python
|
| 173 |
+
REGISTRY.load_from_json("my_tasks.json")
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
---
|
| 177 |
+
|
| 178 |
+
## Quickstart
|
| 179 |
+
|
| 180 |
+
### Install dependencies
|
| 181 |
+
```bash
|
| 182 |
+
python -m venv .venv
|
| 183 |
+
.venv/bin/pip install -e ".[dev]"
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
### Run the local playbook (no server needed)
|
| 187 |
+
Tests all three built-in tasks directly, with progressive SQL attempts:
|
| 188 |
+
```bash
|
| 189 |
+
ANTHROPIC_API_KEY=your_key .venv/bin/python playbook.py
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
### Run the baseline inference script
|
| 193 |
+
Runs a Claude model as an agent against all tasks and reports scores:
|
| 194 |
+
```bash
|
| 195 |
+
# Default model (claude-haiku-4-5)
|
| 196 |
+
ANTHROPIC_API_KEY=your_key .venv/bin/python baseline.py
|
| 197 |
+
|
| 198 |
+
# Specific model
|
| 199 |
+
ANTHROPIC_API_KEY=your_key .venv/bin/python baseline.py --model claude-opus-4-6
|
| 200 |
+
|
| 201 |
+
# Single task with verbose output
|
| 202 |
+
ANTHROPIC_API_KEY=your_key .venv/bin/python baseline.py --task task_hard_cte --verbose
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
### Run the HTTP server
|
| 206 |
+
```bash
|
| 207 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
---
|
| 211 |
+
|
| 212 |
+
## Baseline Results
|
| 213 |
+
|
| 214 |
+
The following scores were produced by running `claude-haiku-4-5` as the agent against all three tasks with the full AI judge active. These serve as the reproducible baseline for this environment.
|
| 215 |
+
|
| 216 |
+
| Task | Level | Steps Used | Best Score |
|
| 217 |
+
|---|---|---|---|
|
| 218 |
+
| Fix the Syntax Errors | easy | 1 | **1.000** |
|
| 219 |
+
| Fix the Cartesian JOIN | medium | 1 | **0.900** |
|
| 220 |
+
| Rewrite Correlated Subquery as CTE | hard | 1 | **0.950** |
|
| 221 |
+
| **Average** | | | **0.950** |
|
| 222 |
+
|
| 223 |
+
All three tasks were solved (or near-solved) on the first step, demonstrating that:
|
| 224 |
+
- The reward pipeline returns meaningful signal immediately
|
| 225 |
+
- The environment terminates cleanly when the done threshold (β₯ 0.90) is met
|
| 226 |
+
- A stronger model or a harder task set would produce more training-relevant trajectories
|
| 227 |
+
|
| 228 |
+
---
|
| 229 |
+
|
| 230 |
+
## API Endpoints
|
| 231 |
+
|
| 232 |
+
| Method | Endpoint | Description |
|
| 233 |
+
|---|---|---|
|
| 234 |
+
| `POST` | `/reset` | Start a new episode. Pass `{"task_id": "..."}` to pin to a task |
|
| 235 |
+
| `POST` | `/step` | Submit a SQL query: `{"sql": "SELECT ..."}` |
|
| 236 |
+
| `GET` | `/state` | Current episode ID and step count |
|
| 237 |
+
| `GET` | `/schema` | Action and observation JSON schemas |
|
| 238 |
+
| `POST` | `/tasks` | Register a custom task |
|
| 239 |
+
| `GET` | `/tasks` | List all registered tasks |
|
| 240 |
+
| `DELETE` | `/tasks/{task_id}` | Remove a custom task (built-ins protected) |
|
| 241 |
+
| `WS` | `/ws` | WebSocket endpoint for persistent low-latency sessions |
|
| 242 |
+
| `GET` | `/health` | Container health check |
|
| 243 |
+
| `GET` | `/docs` | Interactive OpenAPI documentation |
|
| 244 |
+
|
| 245 |
+
### Examples
|
| 246 |
+
|
| 247 |
+
```bash
|
| 248 |
+
# Start an episode pinned to the hard task
|
| 249 |
+
curl -X POST http://localhost:8000/reset \
|
| 250 |
+
-H "Content-Type: application/json" \
|
| 251 |
+
-d '{"task_id": "task_hard_cte"}'
|
| 252 |
+
|
| 253 |
+
# Submit a query
|
| 254 |
+
curl -X POST http://localhost:8000/step \
|
| 255 |
+
-H "Content-Type: application/json" \
|
| 256 |
+
-d '{"sql": "WITH dept_avg AS (SELECT department_id, AVG(salary) AS avg_salary FROM employees GROUP BY department_id) SELECT e.name, e.department_id, e.salary FROM employees e JOIN dept_avg d ON e.department_id = d.department_id WHERE e.salary > d.avg_salary ORDER BY e.department_id, e.salary DESC"}'
|
| 257 |
+
|
| 258 |
+
# List all available tasks
|
| 259 |
+
curl http://localhost:8000/tasks
|
| 260 |
+
```
|
| 261 |
+
|
| 262 |
+
---
|
| 263 |
+
|
| 264 |
+
## Python Client
|
| 265 |
+
|
| 266 |
+
```python
|
| 267 |
+
from queryforge import QueryforgeEnv, SQLAction
|
| 268 |
+
|
| 269 |
+
with QueryforgeEnv(base_url="http://localhost:8000") as env:
|
| 270 |
+
# Pin to a specific task
|
| 271 |
+
obs = env.reset(task_id="task_medium_join")
|
| 272 |
+
print(obs.task_description)
|
| 273 |
+
|
| 274 |
+
# Submit a fix
|
| 275 |
+
result = env.step(SQLAction(sql="""
|
| 276 |
+
SELECT u.name, p.title, SUM(o.amount) AS total_spent
|
| 277 |
+
FROM orders o
|
| 278 |
+
INNER JOIN users u ON o.user_id = u.id
|
| 279 |
+
INNER JOIN products p ON o.product_id = p.id
|
| 280 |
+
GROUP BY u.name, p.title
|
| 281 |
+
ORDER BY total_spent DESC
|
| 282 |
+
"""))
|
| 283 |
+
print(f"Score: {result.reward:.3f}")
|
| 284 |
+
print(f"Feedback: {result.observation.feedback}")
|
| 285 |
+
print(f"Done: {result.done}")
|
| 286 |
+
|
| 287 |
+
# Register and use a custom task
|
| 288 |
+
env.register_task(TaskSpec(
|
| 289 |
+
id="my_task",
|
| 290 |
+
schema_ddl="CREATE TABLE ...; INSERT INTO ...",
|
| 291 |
+
expected_rows=[{"col": "val"}],
|
| 292 |
+
title="My Custom Task",
|
| 293 |
+
))
|
| 294 |
+
obs = env.reset(task_id="my_task")
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
---
|
| 298 |
+
|
| 299 |
+
## Project Structure
|
| 300 |
+
|
| 301 |
+
```
|
| 302 |
+
queryforge/
|
| 303 |
+
βββ __init__.py # Public exports (SQLAction, SQLObservation, TaskSpec, REGISTRY)
|
| 304 |
+
βββ models.py # SQLAction, SQLObservation, TaskSpec Pydantic models
|
| 305 |
+
βββ tasks.py # Built-in tasks + thread-safe TaskRegistry
|
| 306 |
+
βββ judge.py # 4-stage grading pipeline (DuckDB + Anthropic)
|
| 307 |
+
βββ client.py # QueryforgeEnv client with task management helpers
|
| 308 |
+
βββ playbook.py # Local test runner (no server required)
|
| 309 |
+
βββ baseline.py # Baseline inference script (Claude as agent)
|
| 310 |
+
βββ openenv.yaml # OpenEnv manifest
|
| 311 |
+
βββ pyproject.toml # Project metadata and dependencies
|
| 312 |
+
βββ uv.lock # Locked dependencies
|
| 313 |
+
βββ server/
|
| 314 |
+
βββ app.py # FastAPI app β core + /tasks REST endpoints
|
| 315 |
+
βββ queryforge_environment.py # Environment class (reset, step, state)
|
| 316 |
+
βββ Dockerfile # Container image
|
| 317 |
+
βββ requirements.txt # Server dependencies
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
---
|
| 321 |
+
|
| 322 |
+
## Deployment
|
| 323 |
+
|
| 324 |
+
### Hugging Face Spaces (recommended)
|
| 325 |
+
|
| 326 |
+
```bash
|
| 327 |
+
UV_CACHE_DIR=/tmp/uv-cache openenv push . --repo-id <hf-username>/queryforge
|
| 328 |
+
```
|
| 329 |
+
|
| 330 |
+
Add `ANTHROPIC_API_KEY` as a Space secret after deployment. Without it, the environment runs in deterministic-only mode (scores capped at 0.80, done threshold self-adjusts accordingly).
|
| 331 |
+
|
| 332 |
+
### Docker
|
| 333 |
+
|
| 334 |
+
```bash
|
| 335 |
+
docker build -t queryforge:latest -f server/Dockerfile .
|
| 336 |
+
docker run -p 8000:8000 -e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY queryforge:latest
|
| 337 |
+
```
|
| 338 |
+
|
| 339 |
+
The deployed environment exposes:
|
| 340 |
+
- **`/web`** β Interactive UI for exploring the environment
|
| 341 |
+
- **`/docs`** β Full OpenAPI / Swagger interface
|
| 342 |
+
- **`/ws`** β WebSocket endpoint for persistent agent sessions
|
| 343 |
+
- **`/health`** β Container health monitoring
|
| 344 |
+
|
| 345 |
+
---
|
| 346 |
+
|
| 347 |
+
## Environment Design Notes
|
| 348 |
+
|
| 349 |
+
**Why DuckDB?** DuckDB runs fully in-memory with no external process or network dependency. Each `step()` call creates an isolated connection, seeds it with the task's schema, runs the agent's query, then closes β complete isolation with zero shared state between steps.
|
| 350 |
+
|
| 351 |
+
**Why a 4-stage reward?** Binary correct/incorrect rewards give an agent no gradient to climb when its query is simply broken. The 4-stage pipeline means every improvement β fixing a typo, avoiding a runtime error, returning the right row count, getting the right rows, writing clean SQL β is rewarded. This produces a smooth loss landscape for policy gradient methods.
|
| 352 |
+
|
| 353 |
+
**Why structural penalties?** Without them, an agent could achieve 0.80 on the hard CTE task by submitting the original correlated subquery verbatim (rows match, but the task was never solved). Structural penalties enforce that the agent actually learned *what* to change, not just that rows matched.
|
__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""QueryForge β SQL Debugger & Optimiser Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import QueryforgeEnv
|
| 10 |
+
from .models import SQLAction, SQLObservation, TaskSpec
|
| 11 |
+
from .tasks import TASKS, TASK_BY_ID, SQLTask, REGISTRY, task_from_dict
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"SQLAction",
|
| 15 |
+
"SQLObservation",
|
| 16 |
+
"TaskSpec",
|
| 17 |
+
"QueryforgeEnv",
|
| 18 |
+
"TASKS",
|
| 19 |
+
"TASK_BY_ID",
|
| 20 |
+
"SQLTask",
|
| 21 |
+
"REGISTRY",
|
| 22 |
+
"task_from_dict",
|
| 23 |
+
]
|
baseline.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
QueryForge Baseline Inference Script
|
| 3 |
+
βββββββββββββββββββββββββββββββββββββ
|
| 4 |
+
Runs a Claude model as an agent against all 3 built-in tasks and reports
|
| 5 |
+
a reproducible baseline score.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
# All tasks, default model (claude-haiku-4-5):
|
| 9 |
+
python baseline.py
|
| 10 |
+
|
| 11 |
+
# Specific model:
|
| 12 |
+
python baseline.py --model claude-opus-4-6
|
| 13 |
+
|
| 14 |
+
# Single task:
|
| 15 |
+
python baseline.py --task task_easy_syntax
|
| 16 |
+
|
| 17 |
+
# More verbose output:
|
| 18 |
+
python baseline.py --verbose
|
| 19 |
+
|
| 20 |
+
Requirements:
|
| 21 |
+
ANTHROPIC_API_KEY must be set in the environment.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import os
|
| 26 |
+
import re
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
import anthropic
|
| 30 |
+
|
| 31 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 32 |
+
|
| 33 |
+
from models import SQLAction
|
| 34 |
+
from server.queryforge_environment import QueryforgeEnvironment
|
| 35 |
+
from tasks import REGISTRY
|
| 36 |
+
|
| 37 |
+
# ββ Constants βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
+
|
| 39 |
+
DEFAULT_MODEL = "claude-haiku-4-5"
|
| 40 |
+
|
| 41 |
+
SYSTEM_PROMPT = """\
|
| 42 |
+
You are an expert SQL engineer. You will be given a SQL debugging or \
|
| 43 |
+
optimisation challenge. Your job is to submit a corrected or improved SQL query.
|
| 44 |
+
|
| 45 |
+
Rules:
|
| 46 |
+
- Respond with ONLY a single SQL query inside a ```sql ... ``` code block.
|
| 47 |
+
- Do not explain your reasoning outside the code block.
|
| 48 |
+
- Do not include multiple statements (no semicolons except at the very end).
|
| 49 |
+
- If you receive feedback on a previous attempt, use it to improve your query.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
# ββ SQL extraction βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 53 |
+
|
| 54 |
+
_SQL_BLOCK = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _extract_sql(text: str) -> str:
|
| 58 |
+
"""Pull the first SQL code block out of Claude's response."""
|
| 59 |
+
match = _SQL_BLOCK.search(text)
|
| 60 |
+
if match:
|
| 61 |
+
return match.group(1).strip()
|
| 62 |
+
# Fallback: return the whole response stripped β better than crashing
|
| 63 |
+
return text.strip()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ββ Formatting helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
|
| 68 |
+
def _hr(char="β", width=70):
|
| 69 |
+
print(char * width)
|
| 70 |
+
|
| 71 |
+
def _score_bar(score: float, width: int = 25) -> str:
|
| 72 |
+
filled = int(score * width)
|
| 73 |
+
bar = "β" * filled + "β" * (width - filled)
|
| 74 |
+
return f"[{bar}] {score:.3f}"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ββ Per-task agent loop ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
|
| 79 |
+
def run_task(
|
| 80 |
+
task_id: str,
|
| 81 |
+
model: str,
|
| 82 |
+
client: anthropic.Anthropic,
|
| 83 |
+
verbose: bool = False,
|
| 84 |
+
) -> dict:
|
| 85 |
+
"""
|
| 86 |
+
Run one episode of a single task.
|
| 87 |
+
|
| 88 |
+
Returns a dict with keys:
|
| 89 |
+
task_id, task_title, task_level,
|
| 90 |
+
best_score, attempts, done
|
| 91 |
+
"""
|
| 92 |
+
env = QueryforgeEnvironment()
|
| 93 |
+
obs = env.reset(task_id=task_id)
|
| 94 |
+
|
| 95 |
+
if obs.done:
|
| 96 |
+
# reset() returned an error (unknown task_id)
|
| 97 |
+
print(f" ERROR: {obs.feedback}")
|
| 98 |
+
return {"task_id": task_id, "best_score": 0.0, "attempts": 0, "done": False}
|
| 99 |
+
|
| 100 |
+
print(f"\n Task : {obs.task_title} [{obs.task_level}] (max {env._current_task.max_steps} steps)")
|
| 101 |
+
if verbose:
|
| 102 |
+
print(f" ID : {obs.task_id}")
|
| 103 |
+
|
| 104 |
+
# ββ Build initial conversation ββββββββββββββββββββββββββββββββββββββββββββ
|
| 105 |
+
messages = [
|
| 106 |
+
{
|
| 107 |
+
"role": "user",
|
| 108 |
+
"content": (
|
| 109 |
+
f"Here is your SQL challenge:\n\n{obs.task_description}\n\n"
|
| 110 |
+
"Provide your fixed SQL query."
|
| 111 |
+
),
|
| 112 |
+
}
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
step = 0
|
| 116 |
+
while not obs.done:
|
| 117 |
+
step += 1
|
| 118 |
+
|
| 119 |
+
# ββ Call Claude βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 120 |
+
with client.messages.stream(
|
| 121 |
+
model=model,
|
| 122 |
+
max_tokens=512,
|
| 123 |
+
system=SYSTEM_PROMPT,
|
| 124 |
+
messages=messages,
|
| 125 |
+
) as stream:
|
| 126 |
+
response_text = ""
|
| 127 |
+
for text in stream.text_stream:
|
| 128 |
+
response_text += text
|
| 129 |
+
|
| 130 |
+
sql = _extract_sql(response_text)
|
| 131 |
+
|
| 132 |
+
if verbose:
|
| 133 |
+
print(f"\n ββ Step {step}")
|
| 134 |
+
short_sql = sql[:120] + ("β¦" if len(sql) > 120 else "")
|
| 135 |
+
print(f" SQL: {short_sql}")
|
| 136 |
+
|
| 137 |
+
# ββ Submit to environment βββββββββββββββββββββββββββββββββββββββββββββ
|
| 138 |
+
obs = env.step(SQLAction(sql=sql))
|
| 139 |
+
|
| 140 |
+
score_bar = _score_bar(obs.reward or 0.0)
|
| 141 |
+
status = "β DONE" if obs.done else f"step {step}/{env._current_task.max_steps}"
|
| 142 |
+
print(f" [{status}] Score: {score_bar}")
|
| 143 |
+
|
| 144 |
+
if verbose and obs.feedback:
|
| 145 |
+
fb = obs.feedback[:200] + ("β¦" if len(obs.feedback) > 200 else "")
|
| 146 |
+
print(f" Feedback: {fb}")
|
| 147 |
+
|
| 148 |
+
if obs.done:
|
| 149 |
+
break
|
| 150 |
+
|
| 151 |
+
# ββ Append exchange to conversation for next attempt ββββββββββββββββββ
|
| 152 |
+
messages.append({"role": "assistant", "content": response_text})
|
| 153 |
+
messages.append({
|
| 154 |
+
"role": "user",
|
| 155 |
+
"content": (
|
| 156 |
+
f"Your query scored {obs.reward:.3f}. Here is the feedback:\n\n"
|
| 157 |
+
f"{obs.feedback}\n\n"
|
| 158 |
+
f"Hint: {obs.hint}\n\n"
|
| 159 |
+
"Please try again with an improved SQL query."
|
| 160 |
+
),
|
| 161 |
+
})
|
| 162 |
+
|
| 163 |
+
return {
|
| 164 |
+
"task_id": task_id,
|
| 165 |
+
"task_title": obs.task_title,
|
| 166 |
+
"task_level": obs.task_level,
|
| 167 |
+
"best_score": obs.best_score,
|
| 168 |
+
"attempts": obs.attempt,
|
| 169 |
+
"done": obs.done,
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# ββ Main βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 174 |
+
|
| 175 |
+
def main():
|
| 176 |
+
parser = argparse.ArgumentParser(description="QueryForge Baseline Inference")
|
| 177 |
+
parser.add_argument(
|
| 178 |
+
"--model", default=DEFAULT_MODEL,
|
| 179 |
+
help=f"Anthropic model ID to use (default: {DEFAULT_MODEL})"
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--task", default=None,
|
| 183 |
+
help="Run a single task by ID instead of all built-in tasks"
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--verbose", action="store_true",
|
| 187 |
+
help="Print SQL queries and full feedback for each step"
|
| 188 |
+
)
|
| 189 |
+
args = parser.parse_args()
|
| 190 |
+
|
| 191 |
+
# ββ Validate API key ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 192 |
+
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
| 193 |
+
if not api_key:
|
| 194 |
+
print("ERROR: ANTHROPIC_API_KEY is not set.")
|
| 195 |
+
sys.exit(1)
|
| 196 |
+
|
| 197 |
+
client = anthropic.Anthropic(api_key=api_key)
|
| 198 |
+
|
| 199 |
+
# ββ Determine tasks to run ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 200 |
+
if args.task:
|
| 201 |
+
task_ids = [args.task]
|
| 202 |
+
else:
|
| 203 |
+
task_ids = ["task_easy_syntax", "task_medium_join", "task_hard_cte"]
|
| 204 |
+
|
| 205 |
+
# ββ Header ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 206 |
+
_hr()
|
| 207 |
+
print(" QueryForge β Baseline Inference")
|
| 208 |
+
print(f" Model : {args.model}")
|
| 209 |
+
print(f" Tasks : {', '.join(task_ids)}")
|
| 210 |
+
_hr()
|
| 211 |
+
|
| 212 |
+
# ββ Run each task βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 213 |
+
results = []
|
| 214 |
+
for task_id in task_ids:
|
| 215 |
+
print(f"\n{'β' * 70}")
|
| 216 |
+
result = run_task(task_id, args.model, client, verbose=args.verbose)
|
| 217 |
+
results.append(result)
|
| 218 |
+
|
| 219 |
+
# ββ Results table βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 220 |
+
print(f"\n{'β' * 70}")
|
| 221 |
+
print(" BASELINE RESULTS")
|
| 222 |
+
print(f" Model: {args.model}")
|
| 223 |
+
print(f"{'β' * 70}")
|
| 224 |
+
print(f" {'Task':<28} {'Level':<8} {'Steps':>5} {'Best Score'}")
|
| 225 |
+
print(f" {'β' * 28} {'β' * 8} {'β' * 5} {'β' * 30}")
|
| 226 |
+
|
| 227 |
+
total_score = 0.0
|
| 228 |
+
for r in results:
|
| 229 |
+
title = r.get("task_title", r["task_id"])[:27]
|
| 230 |
+
level = r.get("task_level", "?")
|
| 231 |
+
attempts = r.get("attempts", "?")
|
| 232 |
+
score = r["best_score"]
|
| 233 |
+
total_score += score
|
| 234 |
+
bar = _score_bar(score)
|
| 235 |
+
print(f" {title:<28} {level:<8} {attempts:>5} {bar}")
|
| 236 |
+
|
| 237 |
+
avg = total_score / len(results) if results else 0.0
|
| 238 |
+
print(f"{'β' * 70}")
|
| 239 |
+
print(f" {'AVERAGE':<28} {'':8} {'':5} {_score_bar(avg)}")
|
| 240 |
+
print(f"{'β' * 70}\n")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
if __name__ == "__main__":
|
| 244 |
+
main()
|
client.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""QueryForge Environment Client."""
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import httpx
|
| 6 |
+
from openenv.core import EnvClient
|
| 7 |
+
from openenv.core.client_types import StepResult
|
| 8 |
+
from openenv.core.env_server.types import State
|
| 9 |
+
|
| 10 |
+
from .models import SQLAction, SQLObservation, TaskSpec
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class QueryforgeEnv(EnvClient[SQLAction, SQLObservation, State]):
|
| 14 |
+
"""
|
| 15 |
+
Client for the QueryForge SQL Debugger & Optimiser environment.
|
| 16 |
+
|
| 17 |
+
Maintains a persistent WebSocket connection to the environment server.
|
| 18 |
+
Each client instance has its own dedicated session (isolated task state).
|
| 19 |
+
|
| 20 |
+
Example:
|
| 21 |
+
>>> with QueryforgeEnv(base_url="http://localhost:8000") as env:
|
| 22 |
+
... obs = env.reset()
|
| 23 |
+
... print(obs.task_title)
|
| 24 |
+
... print(obs.task_description)
|
| 25 |
+
...
|
| 26 |
+
... result = env.step(SQLAction(sql="SELECT name, age FROM users WHERE age > 30"))
|
| 27 |
+
... print(result.reward, result.observation.feedback)
|
| 28 |
+
|
| 29 |
+
Example with Docker:
|
| 30 |
+
>>> env = QueryforgeEnv.from_docker_image("queryforge-env:latest")
|
| 31 |
+
>>> try:
|
| 32 |
+
... obs = env.reset()
|
| 33 |
+
... result = env.step(SQLAction(sql="SELECT ..."))
|
| 34 |
+
... finally:
|
| 35 |
+
... env.close()
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def _step_payload(self, action: SQLAction) -> Dict:
|
| 39 |
+
return {"sql": action.sql}
|
| 40 |
+
|
| 41 |
+
def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
|
| 42 |
+
obs_data = payload.get("observation", {})
|
| 43 |
+
observation = SQLObservation(
|
| 44 |
+
task_id=obs_data.get("task_id", ""),
|
| 45 |
+
task_level=obs_data.get("task_level", ""),
|
| 46 |
+
task_title=obs_data.get("task_title", ""),
|
| 47 |
+
task_description=obs_data.get("task_description", ""),
|
| 48 |
+
syntax_valid=obs_data.get("syntax_valid", False),
|
| 49 |
+
execution_success=obs_data.get("execution_success", False),
|
| 50 |
+
execution_error=obs_data.get("execution_error"),
|
| 51 |
+
rows_returned=obs_data.get("rows_returned", 0),
|
| 52 |
+
feedback=obs_data.get("feedback", ""),
|
| 53 |
+
hint=obs_data.get("hint", ""),
|
| 54 |
+
attempt=obs_data.get("attempt", 0),
|
| 55 |
+
best_score=obs_data.get("best_score", 0.0),
|
| 56 |
+
done=payload.get("done", False),
|
| 57 |
+
reward=payload.get("reward", 0.0),
|
| 58 |
+
metadata=obs_data.get("metadata", {}),
|
| 59 |
+
)
|
| 60 |
+
return StepResult(
|
| 61 |
+
observation=observation,
|
| 62 |
+
reward=payload.get("reward", 0.0),
|
| 63 |
+
done=payload.get("done", False),
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
def _parse_state(self, payload: Dict) -> State:
|
| 67 |
+
return State(
|
| 68 |
+
episode_id=payload.get("episode_id"),
|
| 69 |
+
step_count=payload.get("step_count", 0),
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# ββ Task Registry helpers βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
+
|
| 74 |
+
def register_task(self, spec: TaskSpec) -> Dict[str, Any]:
|
| 75 |
+
"""Register a custom task on the server. Returns the server response dict."""
|
| 76 |
+
resp = httpx.post(
|
| 77 |
+
f"{self.base_url}/tasks",
|
| 78 |
+
json=spec.model_dump(),
|
| 79 |
+
timeout=10,
|
| 80 |
+
)
|
| 81 |
+
resp.raise_for_status()
|
| 82 |
+
return resp.json()
|
| 83 |
+
|
| 84 |
+
def list_tasks(self) -> List[Dict[str, Any]]:
|
| 85 |
+
"""Return all registered tasks (built-in + custom) as a list of dicts."""
|
| 86 |
+
resp = httpx.get(f"{self.base_url}/tasks", timeout=10)
|
| 87 |
+
resp.raise_for_status()
|
| 88 |
+
return resp.json()
|
| 89 |
+
|
| 90 |
+
def delete_task(self, task_id: str) -> Dict[str, Any]:
|
| 91 |
+
"""Remove a custom task by ID. Raises httpx.HTTPStatusError on 403/404."""
|
| 92 |
+
resp = httpx.delete(f"{self.base_url}/tasks/{task_id}", timeout=10)
|
| 93 |
+
resp.raise_for_status()
|
| 94 |
+
return resp.json()
|
judge.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
QueryForge Judge β deterministic DuckDB grading + Anthropic AI quality scoring.
|
| 3 |
+
|
| 4 |
+
Grading pipeline for each submitted SQL query:
|
| 5 |
+
|
| 6 |
+
Stage 1 β Syntax (0.0 β 0.15)
|
| 7 |
+
DuckDB EXPLAIN parses the query. Fail β score = 0.0.
|
| 8 |
+
|
| 9 |
+
Stage 2 β Execution (β 0.30)
|
| 10 |
+
Run the full query against in-memory DuckDB seeded with task data.
|
| 11 |
+
Fail β score = 0.15 (syntax was fine, runtime error).
|
| 12 |
+
|
| 13 |
+
Stage 3 β Correctness (β 0.80)
|
| 14 |
+
Compare returned rows against expected rows.
|
| 15 |
+
Perfect match β deterministic score reaches 0.80.
|
| 16 |
+
Partial credit for correct row count or partial row matches.
|
| 17 |
+
|
| 18 |
+
Stage 4 β AI Quality (β 1.0)
|
| 19 |
+
Anthropic claude-sonnet-4-6 evaluates optimization, code style, and
|
| 20 |
+
semantic correctness vs. the reference solution.
|
| 21 |
+
The AI score can move the final score up to 1.0 when rows are correct,
|
| 22 |
+
or provide nuanced feedback even when rows are partially wrong.
|
| 23 |
+
|
| 24 |
+
Environment variable required:
|
| 25 |
+
ANTHROPIC_API_KEY β standard Anthropic SDK key.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import json
|
| 29 |
+
import re
|
| 30 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 31 |
+
|
| 32 |
+
import anthropic
|
| 33 |
+
import duckdb
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from .tasks import SQLTask, TestCase
|
| 37 |
+
except ImportError:
|
| 38 |
+
from tasks import SQLTask, TestCase
|
| 39 |
+
|
| 40 |
+
JUDGE_MODEL = "claude-haiku-4-5-20251001"
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
# Stage 1 β Syntax check
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
def _reject_multi_statement(query: str) -> Optional[str]:
|
| 46 |
+
"""Return an error message if the query contains multiple statements."""
|
| 47 |
+
# Strip string literals and comments before checking for semicolons
|
| 48 |
+
stripped = re.sub(r"'[^']*'", "", query) # remove string literals
|
| 49 |
+
stripped = re.sub(r"--[^\n]*", "", stripped) # remove line comments
|
| 50 |
+
stripped = re.sub(r"/\*.*?\*/", "", stripped, flags=re.DOTALL) # block comments
|
| 51 |
+
stripped = stripped.strip().rstrip(";") # allow a single trailing semicolon
|
| 52 |
+
if ";" in stripped:
|
| 53 |
+
return "Multi-statement queries are not allowed."
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def check_syntax(query: str) -> Tuple[bool, Optional[str]]:
|
| 58 |
+
"""
|
| 59 |
+
Return (is_valid, error_message).
|
| 60 |
+
|
| 61 |
+
Strategy: run EXPLAIN against an empty in-memory DuckDB.
|
| 62 |
+
- "Parser Error" in the exception β genuine syntax error β invalid.
|
| 63 |
+
- "Catalog Error" / "Binder Error" β tables unknown but syntax is fine β valid.
|
| 64 |
+
- Any other exception β treat as syntax error to be safe.
|
| 65 |
+
"""
|
| 66 |
+
multi_err = _reject_multi_statement(query)
|
| 67 |
+
if multi_err:
|
| 68 |
+
return False, multi_err
|
| 69 |
+
|
| 70 |
+
conn = duckdb.connect(":memory:")
|
| 71 |
+
try:
|
| 72 |
+
conn.execute(f"EXPLAIN {query}")
|
| 73 |
+
return True, None
|
| 74 |
+
except Exception as exc:
|
| 75 |
+
msg = str(exc)
|
| 76 |
+
# Catalog/Binder errors mean the SQL parsed fine; tables just aren't seeded.
|
| 77 |
+
if any(
|
| 78 |
+
tag in msg
|
| 79 |
+
for tag in ("Catalog Error", "Binder Error", "Table with name",
|
| 80 |
+
"Referenced column", "does not exist", "column")
|
| 81 |
+
):
|
| 82 |
+
return True, None
|
| 83 |
+
return False, msg
|
| 84 |
+
finally:
|
| 85 |
+
conn.close()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
# Stage 2 β Execution
|
| 90 |
+
# ---------------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
def execute_query(
|
| 93 |
+
schema_ddl: str, query: str
|
| 94 |
+
) -> Tuple[bool, Optional[List[Dict[str, Any]]], Optional[str]]:
|
| 95 |
+
"""
|
| 96 |
+
Seed a fresh DuckDB in-memory DB with *schema_ddl*, then run *query*.
|
| 97 |
+
Returns (success, rows_as_list_of_dicts, error_message).
|
| 98 |
+
"""
|
| 99 |
+
conn = duckdb.connect(":memory:")
|
| 100 |
+
try:
|
| 101 |
+
conn.execute(schema_ddl)
|
| 102 |
+
result = conn.execute(query).fetchdf()
|
| 103 |
+
rows = result.to_dict(orient="records")
|
| 104 |
+
# Convert numpy types to native Python
|
| 105 |
+
clean: List[Dict[str, Any]] = []
|
| 106 |
+
for row in rows:
|
| 107 |
+
clean.append({k: _native(v) for k, v in row.items()})
|
| 108 |
+
return True, clean, None
|
| 109 |
+
except Exception as exc:
|
| 110 |
+
return False, None, str(exc)
|
| 111 |
+
finally:
|
| 112 |
+
conn.close()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _native(value: Any) -> Any:
|
| 116 |
+
"""Convert numpy scalars β native Python types for JSON-safe comparison."""
|
| 117 |
+
try:
|
| 118 |
+
import numpy as np # duckdb fetchdf() returns numpy types
|
| 119 |
+
if isinstance(value, (np.integer,)):
|
| 120 |
+
return int(value)
|
| 121 |
+
if isinstance(value, (np.floating,)):
|
| 122 |
+
return float(value)
|
| 123 |
+
if isinstance(value, np.bool_):
|
| 124 |
+
return bool(value)
|
| 125 |
+
except ImportError:
|
| 126 |
+
pass
|
| 127 |
+
return value
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ---------------------------------------------------------------------------
|
| 131 |
+
# Stage 3 β Row correctness
|
| 132 |
+
# ---------------------------------------------------------------------------
|
| 133 |
+
|
| 134 |
+
def _normalize(row: Dict[str, Any]) -> Dict[str, Any]:
|
| 135 |
+
"""Round floats to 2 dp so 999.99000000001 == 999.99."""
|
| 136 |
+
return {
|
| 137 |
+
k: (round(float(v), 2) if isinstance(v, float) else v)
|
| 138 |
+
for k, v in row.items()
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _sort_key(row: Dict[str, Any], order_by: Optional[str]) -> tuple:
|
| 143 |
+
if order_by:
|
| 144 |
+
cols = [c.strip() for c in order_by.split(",")]
|
| 145 |
+
return tuple(str(row.get(c, "")) for c in cols)
|
| 146 |
+
return tuple(str(v) for v in row.values())
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def rows_match(
|
| 150 |
+
actual: List[Dict[str, Any]],
|
| 151 |
+
expected: List[Dict[str, Any]],
|
| 152 |
+
order_by: Optional[str] = None,
|
| 153 |
+
) -> Tuple[float, str]:
|
| 154 |
+
"""
|
| 155 |
+
Compare *actual* vs *expected* rows.
|
| 156 |
+
|
| 157 |
+
Scoring:
|
| 158 |
+
1.0 β exact match
|
| 159 |
+
0.5β0.9 β row count matches, some rows differ
|
| 160 |
+
0.3 β row count wrong but partial overlap
|
| 161 |
+
0.0 β empty when non-empty expected
|
| 162 |
+
"""
|
| 163 |
+
if not expected:
|
| 164 |
+
return (1.0, "No expected rows β query accepted.") if not actual else (
|
| 165 |
+
0.8, f"Expected empty result but got {len(actual)} row(s)."
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if not actual:
|
| 169 |
+
return 0.0, f"Query returned 0 rows; expected {len(expected)}."
|
| 170 |
+
|
| 171 |
+
# Project actual rows to only the expected columns (agent may SELECT extra).
|
| 172 |
+
# Use case-insensitive matching: build a map from lower(actual_col) β actual_col.
|
| 173 |
+
expected_cols = list(expected[0].keys())
|
| 174 |
+
lower_map = {k.lower(): k for k in actual[0].keys()} if actual else {}
|
| 175 |
+
|
| 176 |
+
def _project(row: Dict[str, Any]) -> Dict[str, Any]:
|
| 177 |
+
out: Dict[str, Any] = {}
|
| 178 |
+
for ec in expected_cols:
|
| 179 |
+
actual_key = lower_map.get(ec.lower())
|
| 180 |
+
if actual_key is not None:
|
| 181 |
+
out[ec] = row[actual_key]
|
| 182 |
+
return out
|
| 183 |
+
|
| 184 |
+
projected = [_project(row) for row in actual]
|
| 185 |
+
|
| 186 |
+
if len(projected) != len(expected):
|
| 187 |
+
overlap_ratio = min(len(projected), len(expected)) / max(len(projected), len(expected))
|
| 188 |
+
score = 0.3 * overlap_ratio
|
| 189 |
+
return score, (
|
| 190 |
+
f"Row count mismatch: got {len(projected)}, expected {len(expected)}. "
|
| 191 |
+
f"({overlap_ratio:.0%} overlap ratio)"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
actual_sorted = sorted([_normalize(r) for r in projected], key=lambda r: _sort_key(r, order_by))
|
| 195 |
+
expected_sorted = sorted([_normalize(r) for r in expected], key=lambda r: _sort_key(r, order_by))
|
| 196 |
+
|
| 197 |
+
matches = sum(1 for a, e in zip(actual_sorted, expected_sorted) if a == e)
|
| 198 |
+
row_accuracy = matches / len(expected)
|
| 199 |
+
|
| 200 |
+
if row_accuracy == 1.0:
|
| 201 |
+
return 1.0, "All rows match perfectly."
|
| 202 |
+
|
| 203 |
+
score = 0.5 + 0.4 * row_accuracy
|
| 204 |
+
return score, f"{matches}/{len(expected)} rows match correctly."
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# ---------------------------------------------------------------------------
|
| 208 |
+
# Stage 4 β Anthropic AI judge
|
| 209 |
+
# ---------------------------------------------------------------------------
|
| 210 |
+
|
| 211 |
+
def call_anthropic_judge(
|
| 212 |
+
task: SQLTask,
|
| 213 |
+
agent_query: str,
|
| 214 |
+
execution_success: bool,
|
| 215 |
+
execution_error: Optional[str],
|
| 216 |
+
actual_rows: Optional[List[Dict[str, Any]]],
|
| 217 |
+
deterministic_score: float,
|
| 218 |
+
) -> Tuple[float, str, str]:
|
| 219 |
+
"""
|
| 220 |
+
Call claude-sonnet-4-6 to evaluate query quality across three axes:
|
| 221 |
+
- Correctness (0β0.50)
|
| 222 |
+
- Optimization (0β0.30) β avoids inefficiencies, uses best SQL patterns
|
| 223 |
+
- Code quality (0β0.20) β readable, well-aliased, idiomatic SQL
|
| 224 |
+
|
| 225 |
+
Returns (final_score, feedback, improvement_hint).
|
| 226 |
+
Falls back to deterministic_score if the API call fails.
|
| 227 |
+
"""
|
| 228 |
+
client = anthropic.Anthropic()
|
| 229 |
+
|
| 230 |
+
sample_actual = json.dumps(actual_rows[:5] if actual_rows else [], indent=2)
|
| 231 |
+
sample_expected = json.dumps(
|
| 232 |
+
task.test_cases[0].expected_rows if task.test_cases else [], indent=2
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
prompt = f"""\
|
| 236 |
+
You are a strict SQL expert judge scoring an agent's query for the task below.
|
| 237 |
+
|
| 238 |
+
## Task ({task.level})
|
| 239 |
+
{task.description}
|
| 240 |
+
|
| 241 |
+
## Agent Query
|
| 242 |
+
```sql
|
| 243 |
+
{agent_query}
|
| 244 |
+
```
|
| 245 |
+
|
| 246 |
+
## Execution
|
| 247 |
+
- Success: {execution_success}
|
| 248 |
+
- Error: {execution_error or "None"}
|
| 249 |
+
- Rows returned (first 5): {sample_actual}
|
| 250 |
+
- Expected rows: {sample_expected}
|
| 251 |
+
|
| 252 |
+
## Reference Solution
|
| 253 |
+
```sql
|
| 254 |
+
{task.solution_query}
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
## Deterministic row-match score (0.0β1.0): {deterministic_score:.3f}
|
| 258 |
+
|
| 259 |
+
Score the agent query on THREE axes and sum them for the final score:
|
| 260 |
+
|
| 261 |
+
| Axis | Max | Criteria |
|
| 262 |
+
|--------------|------|----------|
|
| 263 |
+
| Correctness | 0.50 | Produces the right rows for the stated goal |
|
| 264 |
+
| Optimization | 0.30 | Avoids cartesian products / correlated subqueries; uses efficient patterns (CTEs, explicit JOINs, proper GROUP BY) |
|
| 265 |
+
| Code quality | 0.20 | Readable aliases, clean formatting, no redundant clauses |
|
| 266 |
+
|
| 267 |
+
IMPORTANT rules:
|
| 268 |
+
- If execution failed with a runtime error, Correctness β€ 0.10.
|
| 269 |
+
- If rows are fully correct per deterministic score β₯ 0.95, Correctness β₯ 0.40.
|
| 270 |
+
- For the medium task: a query that still uses comma-join syntax scores Optimization β€ 0.05.
|
| 271 |
+
- For the hard task: a query without a CTE scores Optimization β€ 0.10.
|
| 272 |
+
|
| 273 |
+
Respond with ONLY valid JSON (no markdown fences):
|
| 274 |
+
{{
|
| 275 |
+
"correctness": <float 0.0β0.50>,
|
| 276 |
+
"optimization": <float 0.0β0.30>,
|
| 277 |
+
"code_quality": <float 0.0β0.20>,
|
| 278 |
+
"score": <sum of above, float 0.0β1.0>,
|
| 279 |
+
"feedback": "<2β3 sentences summarising what the agent did right/wrong>",
|
| 280 |
+
"hint": "<one concrete actionable improvement, or 'Excellent!' if score >= 0.95>"
|
| 281 |
+
}}"""
|
| 282 |
+
|
| 283 |
+
try:
|
| 284 |
+
message = client.messages.create(
|
| 285 |
+
model="claude-sonnet-4-6",
|
| 286 |
+
max_tokens=512,
|
| 287 |
+
messages=[{"role": "user", "content": prompt}],
|
| 288 |
+
)
|
| 289 |
+
raw = message.content[0].text.strip()
|
| 290 |
+
|
| 291 |
+
# Strip accidental markdown fences
|
| 292 |
+
if raw.startswith("```"):
|
| 293 |
+
raw = raw.split("```")[1]
|
| 294 |
+
if raw.startswith("json"):
|
| 295 |
+
raw = raw[4:]
|
| 296 |
+
raw = raw.rsplit("```", 1)[0].strip()
|
| 297 |
+
|
| 298 |
+
data = json.loads(raw)
|
| 299 |
+
score = float(data["score"])
|
| 300 |
+
score = max(0.0, min(1.0, score))
|
| 301 |
+
feedback = str(data.get("feedback", ""))
|
| 302 |
+
hint = str(data.get("hint", ""))
|
| 303 |
+
return score, feedback, hint
|
| 304 |
+
|
| 305 |
+
except Exception as exc:
|
| 306 |
+
# Graceful fallback β no API key, network error, or parse failure
|
| 307 |
+
msg = str(exc).lower()
|
| 308 |
+
reason = (
|
| 309 |
+
"no ANTHROPIC_API_KEY set"
|
| 310 |
+
if "api_key" in msg or "auth" in msg or "authentication" in msg
|
| 311 |
+
else type(exc).__name__
|
| 312 |
+
)
|
| 313 |
+
return (
|
| 314 |
+
deterministic_score,
|
| 315 |
+
f"AI judge offline ({reason}). Using deterministic score.",
|
| 316 |
+
task.hint,
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# ---------------------------------------------------------------------------
|
| 321 |
+
# Public entry point
|
| 322 |
+
# ---------------------------------------------------------------------------
|
| 323 |
+
|
| 324 |
+
def grade(
|
| 325 |
+
task: SQLTask, agent_query: str
|
| 326 |
+
) -> Tuple[float, str, Dict[str, Any]]:
|
| 327 |
+
"""
|
| 328 |
+
Full grading pipeline. Returns (score 0.0β1.0, feedback, details_dict).
|
| 329 |
+
|
| 330 |
+
Partial progress scoring:
|
| 331 |
+
0.00 β syntax error (unparseable)
|
| 332 |
+
0.15 β syntax valid, runtime error
|
| 333 |
+
0.30 β executes, but 0 rows returned
|
| 334 |
+
0.30β0.80 β partial row matches (deterministic)
|
| 335 |
+
0.80β1.00 β correct rows + AI quality assessment
|
| 336 |
+
"""
|
| 337 |
+
details: Dict[str, Any] = {}
|
| 338 |
+
|
| 339 |
+
# ββ Stage 1: syntax ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 340 |
+
syntax_ok, syntax_error = check_syntax(agent_query)
|
| 341 |
+
details["syntax_valid"] = syntax_ok
|
| 342 |
+
details["syntax_error"] = syntax_error
|
| 343 |
+
|
| 344 |
+
if not syntax_ok:
|
| 345 |
+
return 0.0, f"Syntax error: {syntax_error}", details
|
| 346 |
+
|
| 347 |
+
# ββ Stage 2: execution βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 348 |
+
exec_ok, rows, exec_error = execute_query(task.schema_ddl, agent_query)
|
| 349 |
+
details["execution_success"] = exec_ok
|
| 350 |
+
details["execution_error"] = exec_error
|
| 351 |
+
details["rows_returned"] = len(rows) if rows else 0
|
| 352 |
+
|
| 353 |
+
if not exec_ok:
|
| 354 |
+
# Syntax valid but runtime error β call AI for nuanced feedback
|
| 355 |
+
ai_score, ai_feedback, ai_hint = call_anthropic_judge(
|
| 356 |
+
task, agent_query, False, exec_error, None, 0.15
|
| 357 |
+
)
|
| 358 |
+
details["ai_score"] = ai_score
|
| 359 |
+
details["ai_feedback"] = ai_feedback
|
| 360 |
+
final = max(0.15, ai_score * 0.3) # cap at 0.3 when execution fails
|
| 361 |
+
return final, f"Runtime error: {exec_error} | AI: {ai_feedback}", details
|
| 362 |
+
|
| 363 |
+
# ββ Stage 3: row correctness βββββββββββββββββββββββββββββββββββββββββββββ
|
| 364 |
+
test_case = task.test_cases[0]
|
| 365 |
+
row_score, row_feedback = rows_match(rows, test_case.expected_rows, test_case.order_by)
|
| 366 |
+
details["row_match_score"] = row_score
|
| 367 |
+
details["row_match_feedback"] = row_feedback
|
| 368 |
+
|
| 369 |
+
# ββ Stage 3b: structural checks (task-specific) βββββββββββββββββββββββββ
|
| 370 |
+
# These prevent high scores when the agent submits the broken query verbatim
|
| 371 |
+
# or ignores the task's structural requirement.
|
| 372 |
+
structural_penalty = 0.0
|
| 373 |
+
query_upper = agent_query.upper()
|
| 374 |
+
|
| 375 |
+
if task.level == "hard" and "WITH " not in query_upper:
|
| 376 |
+
structural_penalty = 0.30 # hard task demands a CTE
|
| 377 |
+
row_feedback += " (Penalty: no CTE detected β task requires WITH clause.)"
|
| 378 |
+
elif task.level == "medium" and "JOIN " not in query_upper:
|
| 379 |
+
structural_penalty = 0.20 # medium task demands explicit JOINs
|
| 380 |
+
row_feedback += " (Penalty: no explicit JOIN β task requires JOIN β¦ ON syntax.)"
|
| 381 |
+
|
| 382 |
+
details["structural_penalty"] = structural_penalty
|
| 383 |
+
|
| 384 |
+
# Deterministic score: 0.30 base for executing + up to 0.50 for rows β penalty
|
| 385 |
+
deterministic_score = max(0.30, 0.30 + 0.50 * row_score - structural_penalty)
|
| 386 |
+
|
| 387 |
+
# ββ Stage 4: AI quality ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 388 |
+
ai_score, ai_feedback, ai_hint = call_anthropic_judge(
|
| 389 |
+
task, agent_query, True, None, rows, deterministic_score
|
| 390 |
+
)
|
| 391 |
+
details["ai_score"] = ai_score
|
| 392 |
+
details["ai_feedback"] = ai_feedback
|
| 393 |
+
details["ai_hint"] = ai_hint
|
| 394 |
+
|
| 395 |
+
# Final blending:
|
| 396 |
+
# rows fully correct β trust AI score (can reach 1.0)
|
| 397 |
+
# rows partially wrong β clamp AI score to not exceed deterministic
|
| 398 |
+
if row_score >= 0.95:
|
| 399 |
+
final_score = ai_score
|
| 400 |
+
elif row_score >= 0.5:
|
| 401 |
+
# Blend: AI provides nuance but can't exceed deterministic ceiling
|
| 402 |
+
final_score = min(deterministic_score, ai_score + 0.05)
|
| 403 |
+
else:
|
| 404 |
+
# Low row accuracy β stay near deterministic
|
| 405 |
+
final_score = min(deterministic_score, ai_score * 0.6)
|
| 406 |
+
|
| 407 |
+
final_score = max(0.0, min(1.0, final_score))
|
| 408 |
+
|
| 409 |
+
feedback = (
|
| 410 |
+
f"[Rows] {row_feedback} "
|
| 411 |
+
f"[AI Judge] {ai_feedback} "
|
| 412 |
+
f"[Hint] {ai_hint}"
|
| 413 |
+
)
|
| 414 |
+
return final_score, feedback, details
|
models.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data models for the QueryForge SQL environment.
|
| 3 |
+
|
| 4 |
+
SQLAction β the agent's submitted SQL query.
|
| 5 |
+
SQLObservation β task description + grading feedback returned after each step.
|
| 6 |
+
TaskSpec β payload for registering a custom task via POST /tasks.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Any, Dict, List, Optional
|
| 10 |
+
|
| 11 |
+
from openenv.core.env_server.types import Action, Observation
|
| 12 |
+
from pydantic import BaseModel, Field
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SQLAction(Action):
|
| 16 |
+
"""Action: submit a SQL query for evaluation."""
|
| 17 |
+
|
| 18 |
+
sql: str = Field(..., description="The SQL query to submit for grading")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SQLObservation(Observation):
|
| 22 |
+
"""Observation returned after reset() or step()."""
|
| 23 |
+
|
| 24 |
+
# ββ Task context βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
+
task_id: str = Field(default="", description="Active task identifier")
|
| 26 |
+
task_level: str = Field(
|
| 27 |
+
default="", description="Difficulty: easy | medium | hard"
|
| 28 |
+
)
|
| 29 |
+
task_title: str = Field(default="", description="Human-readable task title")
|
| 30 |
+
task_description: str = Field(
|
| 31 |
+
default="",
|
| 32 |
+
description=(
|
| 33 |
+
"Full task description: schema, broken query, error message, and goal"
|
| 34 |
+
),
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# ββ Per-step grading signals ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
+
syntax_valid: bool = Field(
|
| 39 |
+
default=False, description="True if the submitted query parsed without error"
|
| 40 |
+
)
|
| 41 |
+
execution_success: bool = Field(
|
| 42 |
+
default=False, description="True if the query ran to completion in DuckDB"
|
| 43 |
+
)
|
| 44 |
+
execution_error: Optional[str] = Field(
|
| 45 |
+
default=None, description="Runtime error message, if any"
|
| 46 |
+
)
|
| 47 |
+
rows_returned: int = Field(
|
| 48 |
+
default=0, description="Number of rows the query returned"
|
| 49 |
+
)
|
| 50 |
+
feedback: str = Field(
|
| 51 |
+
default="",
|
| 52 |
+
description="Detailed grading feedback from DuckDB + AI judge",
|
| 53 |
+
)
|
| 54 |
+
hint: str = Field(
|
| 55 |
+
default="", description="Actionable hint for the next attempt"
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# ββ Episode progress ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 59 |
+
attempt: int = Field(
|
| 60 |
+
default=0, description="Number of queries submitted this episode"
|
| 61 |
+
)
|
| 62 |
+
best_score: float = Field(
|
| 63 |
+
default=0.0, description="Highest score achieved so far this episode"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class TaskSpec(BaseModel):
|
| 68 |
+
"""
|
| 69 |
+
Payload for registering a custom SQL task via POST /tasks
|
| 70 |
+
or directly via REGISTRY.register(task_from_dict(spec.model_dump())).
|
| 71 |
+
|
| 72 |
+
Required: id, schema_ddl, expected_rows
|
| 73 |
+
Everything else has sensible defaults.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
id: str = Field(
|
| 77 |
+
..., description="Unique task identifier, e.g. 'null_handling_task'"
|
| 78 |
+
)
|
| 79 |
+
level: str = Field(
|
| 80 |
+
default="custom",
|
| 81 |
+
description="Difficulty label: easy | medium | hard | custom",
|
| 82 |
+
)
|
| 83 |
+
title: str = Field(..., description="Human-readable task title")
|
| 84 |
+
description: str = Field(
|
| 85 |
+
default="",
|
| 86 |
+
description="Full task description shown to the agent (schema, goal, etc.)",
|
| 87 |
+
)
|
| 88 |
+
schema_ddl: str = Field(
|
| 89 |
+
...,
|
| 90 |
+
description="CREATE TABLE + INSERT statements to seed the DuckDB test DB",
|
| 91 |
+
)
|
| 92 |
+
broken_query: str = Field(
|
| 93 |
+
default="",
|
| 94 |
+
description="The broken or slow query the agent must fix",
|
| 95 |
+
)
|
| 96 |
+
error_message: str = Field(
|
| 97 |
+
default="",
|
| 98 |
+
description="Error or performance warning shown to the agent alongside the task",
|
| 99 |
+
)
|
| 100 |
+
hint: str = Field(
|
| 101 |
+
default="",
|
| 102 |
+
description="Actionable hint surfaced in the observation after each wrong attempt",
|
| 103 |
+
)
|
| 104 |
+
expected_rows: List[Dict[str, Any]] = Field(
|
| 105 |
+
...,
|
| 106 |
+
description=(
|
| 107 |
+
"Exact rows the correct query must return. "
|
| 108 |
+
"Used for deterministic row-match scoring."
|
| 109 |
+
),
|
| 110 |
+
)
|
| 111 |
+
order_by: Optional[str] = Field(
|
| 112 |
+
default=None,
|
| 113 |
+
description="Comma-separated column names used to sort rows before comparison",
|
| 114 |
+
)
|
| 115 |
+
solution_query: str = Field(
|
| 116 |
+
default="",
|
| 117 |
+
description="Reference solution shown to the AI judge for quality scoring",
|
| 118 |
+
)
|
| 119 |
+
test_description: str = Field(
|
| 120 |
+
default="Custom test case",
|
| 121 |
+
description="One-line description of what the test case checks",
|
| 122 |
+
)
|
| 123 |
+
max_steps: int = Field(
|
| 124 |
+
default=5, ge=1, le=20,
|
| 125 |
+
description="Maximum number of step() calls allowed per episode",
|
| 126 |
+
)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: queryforge
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
| 8 |
+
description: |
|
| 9 |
+
SQL Query Debugger & Optimiser environment.
|
| 10 |
+
|
| 11 |
+
An agent receives a broken or slow SQL query together with the schema and an
|
| 12 |
+
error/performance warning. It must produce a working, optimised query.
|
| 13 |
+
|
| 14 |
+
Tasks (3 levels, cycled in order):
|
| 15 |
+
easy β fix three misspelled SQL keywords (SELECT / FROM / WHERE)
|
| 16 |
+
medium β fix a missing JOIN condition that causes a cartesian product
|
| 17 |
+
hard β rewrite a correlated subquery (O(NΒ²)) as a CTE (O(N))
|
| 18 |
+
|
| 19 |
+
Reward signal (0.0 β 1.0):
|
| 20 |
+
0.00 syntax error
|
| 21 |
+
0.15 syntax valid, runtime error
|
| 22 |
+
0.30 executes, wrong / empty results
|
| 23 |
+
0.30β0.80 partial row correctness (deterministic, DuckDB)
|
| 24 |
+
0.80β1.00 correct results + AI quality score (Anthropic claude-sonnet-4-6)
|
| 25 |
+
|
| 26 |
+
Required env var: ANTHROPIC_API_KEY
|
openenv_queryforge.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-queryforge
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Queryforge environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.1
|
| 7 |
+
Requires-Dist: duckdb>=0.10.0
|
| 8 |
+
Requires-Dist: anthropic>=0.25.0
|
| 9 |
+
Provides-Extra: dev
|
| 10 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 11 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_queryforge.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
./__init__.py
|
| 4 |
+
./client.py
|
| 5 |
+
./judge.py
|
| 6 |
+
./models.py
|
| 7 |
+
./tasks.py
|
| 8 |
+
openenv_queryforge.egg-info/PKG-INFO
|
| 9 |
+
openenv_queryforge.egg-info/SOURCES.txt
|
| 10 |
+
openenv_queryforge.egg-info/dependency_links.txt
|
| 11 |
+
openenv_queryforge.egg-info/entry_points.txt
|
| 12 |
+
openenv_queryforge.egg-info/requires.txt
|
| 13 |
+
openenv_queryforge.egg-info/top_level.txt
|
| 14 |
+
server/__init__.py
|
| 15 |
+
server/app.py
|
| 16 |
+
server/queryforge_environment.py
|
openenv_queryforge.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_queryforge.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = queryforge.server.app:main
|
openenv_queryforge.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.1
|
| 2 |
+
duckdb>=0.10.0
|
| 3 |
+
anthropic>=0.25.0
|
| 4 |
+
|
| 5 |
+
[dev]
|
| 6 |
+
pytest>=8.0.0
|
| 7 |
+
pytest-cov>=4.0.0
|
openenv_queryforge.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
queryforge
|
playbook.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
QueryForge Local Playbook
|
| 3 |
+
βββββββββββββββββββββββββ
|
| 4 |
+
Tests the environment directly (no HTTP server needed).
|
| 5 |
+
|
| 6 |
+
Run from the queryforge directory:
|
| 7 |
+
.venv/bin/python playbook.py
|
| 8 |
+
|
| 9 |
+
If ANTHROPIC_API_KEY is set, Stage 4 AI scoring is live.
|
| 10 |
+
If not set, the judge falls back to deterministic scoring (capped at 0.80).
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import textwrap
|
| 16 |
+
|
| 17 |
+
# Make imports work whether run directly or as a module
|
| 18 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 19 |
+
|
| 20 |
+
from server.queryforge_environment import QueryforgeEnvironment
|
| 21 |
+
from models import SQLAction
|
| 22 |
+
from tasks import REGISTRY, task_from_dict
|
| 23 |
+
|
| 24 |
+
# ββ Formatting helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
+
|
| 26 |
+
def _hr(char="β", width=70):
|
| 27 |
+
print(char * width)
|
| 28 |
+
|
| 29 |
+
def _section(title):
|
| 30 |
+
print()
|
| 31 |
+
_hr()
|
| 32 |
+
print(f" {title}")
|
| 33 |
+
_hr()
|
| 34 |
+
|
| 35 |
+
def _score_bar(score: float, width: int = 30) -> str:
|
| 36 |
+
filled = int(score * width)
|
| 37 |
+
bar = "β" * filled + "β" * (width - filled)
|
| 38 |
+
return f"[{bar}] {score:.2f}"
|
| 39 |
+
|
| 40 |
+
def _print_obs(obs, show_description=False):
|
| 41 |
+
if show_description:
|
| 42 |
+
print()
|
| 43 |
+
print(textwrap.indent(obs.task_description, " "))
|
| 44 |
+
print()
|
| 45 |
+
if obs.feedback and obs.feedback != "New task loaded. Submit your fixed/optimised SQL query.":
|
| 46 |
+
print(f" Syntax valid : {obs.syntax_valid}")
|
| 47 |
+
print(f" Execution OK : {obs.execution_success}")
|
| 48 |
+
if obs.execution_error:
|
| 49 |
+
print(f" Execution error : {obs.execution_error[:100]}")
|
| 50 |
+
print(f" Rows returned : {obs.rows_returned}")
|
| 51 |
+
print(f" Score : {_score_bar(obs.reward or 0.0)}")
|
| 52 |
+
print(f" Best this ep. : {_score_bar(obs.best_score)}")
|
| 53 |
+
# Print just the first 200 chars of feedback to keep output clean
|
| 54 |
+
fb = obs.feedback[:250] + ("β¦" if len(obs.feedback) > 250 else "")
|
| 55 |
+
print(f" Feedback : {fb}")
|
| 56 |
+
if obs.hint:
|
| 57 |
+
print(f" Hint : {obs.hint[:120]}")
|
| 58 |
+
|
| 59 |
+
def _attempt(env, label: str, sql: str):
|
| 60 |
+
print(f"\n ββ Attempt: {label}")
|
| 61 |
+
print(f" SQL: {sql[:100]}{'β¦' if len(sql) > 100 else ''}")
|
| 62 |
+
obs = env.step(SQLAction(sql=sql))
|
| 63 |
+
_print_obs(obs)
|
| 64 |
+
return obs
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ββ Task runners ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 68 |
+
|
| 69 |
+
def run_easy(env):
|
| 70 |
+
_section("TASK 1 Β· EASY β Fix Syntax Errors")
|
| 71 |
+
env._task_index = 0 # pin to easy
|
| 72 |
+
obs = env.reset()
|
| 73 |
+
print(f"\n Task : {obs.task_title} [{obs.task_level}]")
|
| 74 |
+
print(f" Steps: up to {5}")
|
| 75 |
+
_print_obs(obs, show_description=True)
|
| 76 |
+
|
| 77 |
+
_attempt(env, "still broken",
|
| 78 |
+
"SELEC name, age FORM users WEHRE age > 30")
|
| 79 |
+
|
| 80 |
+
_attempt(env, "one keyword fixed",
|
| 81 |
+
"SELECT name, age FORM users WEHRE age > 30")
|
| 82 |
+
|
| 83 |
+
_attempt(env, "all keywords fixed, no filter",
|
| 84 |
+
"SELECT name, age FROM users WHERE age > 30")
|
| 85 |
+
|
| 86 |
+
obs = _attempt(env, "correct solution",
|
| 87 |
+
"SELECT name, age FROM users "
|
| 88 |
+
"WHERE age > 30 AND city = 'New York' "
|
| 89 |
+
"ORDER BY name ASC")
|
| 90 |
+
|
| 91 |
+
print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def run_medium(env):
|
| 95 |
+
_section("TASK 2 Β· MEDIUM β Fix the Cartesian JOIN")
|
| 96 |
+
env._task_index = 1 # pin to medium
|
| 97 |
+
obs = env.reset()
|
| 98 |
+
print(f"\n Task : {obs.task_title} [{obs.task_level}]")
|
| 99 |
+
print(f" Steps: up to {5}")
|
| 100 |
+
_print_obs(obs, show_description=True)
|
| 101 |
+
|
| 102 |
+
_attempt(env, "broken verbatim (cartesian product)",
|
| 103 |
+
"SELECT u.name, p.title, SUM(o.amount) AS total_spent "
|
| 104 |
+
"FROM orders o, users u, products p "
|
| 105 |
+
"WHERE o.user_id = u.id "
|
| 106 |
+
"GROUP BY u.name, p.title "
|
| 107 |
+
"ORDER BY total_spent DESC")
|
| 108 |
+
|
| 109 |
+
_attempt(env, "comma-join but missing product condition",
|
| 110 |
+
"SELECT u.name, p.title, SUM(o.amount) AS total_spent "
|
| 111 |
+
"FROM orders o, users u, products p "
|
| 112 |
+
"WHERE o.user_id = u.id AND o.product_id = p.id "
|
| 113 |
+
"GROUP BY u.name, p.title "
|
| 114 |
+
"ORDER BY total_spent DESC")
|
| 115 |
+
|
| 116 |
+
obs = _attempt(env, "correct INNER JOINs",
|
| 117 |
+
"SELECT u.name, p.title, SUM(o.amount) AS total_spent\n"
|
| 118 |
+
"FROM orders o\n"
|
| 119 |
+
"INNER JOIN users u ON o.user_id = u.id\n"
|
| 120 |
+
"INNER JOIN products p ON o.product_id = p.id\n"
|
| 121 |
+
"GROUP BY u.name, p.title\n"
|
| 122 |
+
"ORDER BY total_spent DESC")
|
| 123 |
+
|
| 124 |
+
print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def run_hard(env):
|
| 128 |
+
_section("TASK 3 Β· HARD β Rewrite Correlated Subquery as CTE")
|
| 129 |
+
env._task_index = 2 # pin to hard
|
| 130 |
+
obs = env.reset()
|
| 131 |
+
print(f"\n Task : {obs.task_title} [{obs.task_level}]")
|
| 132 |
+
print(f" Steps: up to {6}")
|
| 133 |
+
_print_obs(obs, show_description=True)
|
| 134 |
+
|
| 135 |
+
_attempt(env, "broken verbatim (no CTE β penalised even though rows match)",
|
| 136 |
+
"SELECT e.name, e.department_id, e.salary\n"
|
| 137 |
+
"FROM employees e\n"
|
| 138 |
+
"WHERE e.salary > (\n"
|
| 139 |
+
" SELECT AVG(e2.salary) FROM employees e2\n"
|
| 140 |
+
" WHERE e2.department_id = e.department_id\n"
|
| 141 |
+
")\n"
|
| 142 |
+
"ORDER BY e.department_id, e.salary DESC")
|
| 143 |
+
|
| 144 |
+
_attempt(env, "halfway β CTE defined but wrong join",
|
| 145 |
+
"WITH dept_avg AS (\n"
|
| 146 |
+
" SELECT department_id, AVG(salary) AS avg_salary\n"
|
| 147 |
+
" FROM employees GROUP BY department_id\n"
|
| 148 |
+
")\n"
|
| 149 |
+
"SELECT e.name, e.department_id, e.salary\n"
|
| 150 |
+
"FROM employees e, dept_avg d\n"
|
| 151 |
+
"WHERE e.salary > d.avg_salary\n"
|
| 152 |
+
"ORDER BY e.department_id, e.salary DESC")
|
| 153 |
+
|
| 154 |
+
obs = _attempt(env, "correct CTE with proper JOIN",
|
| 155 |
+
"WITH dept_avg AS (\n"
|
| 156 |
+
" SELECT department_id, AVG(salary) AS avg_salary\n"
|
| 157 |
+
" FROM employees\n"
|
| 158 |
+
" GROUP BY department_id\n"
|
| 159 |
+
")\n"
|
| 160 |
+
"SELECT e.name, e.department_id, e.salary\n"
|
| 161 |
+
"FROM employees e\n"
|
| 162 |
+
"JOIN dept_avg d ON e.department_id = d.department_id\n"
|
| 163 |
+
"WHERE e.salary > d.avg_salary\n"
|
| 164 |
+
"ORDER BY e.department_id, e.salary DESC")
|
| 165 |
+
|
| 166 |
+
print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# ββ Custom task demo ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 170 |
+
|
| 171 |
+
def run_custom(env):
|
| 172 |
+
_section("TASK 4 Β· CUSTOM β NULL Handling in Aggregation")
|
| 173 |
+
|
| 174 |
+
# Register a brand-new task at runtime
|
| 175 |
+
custom_task = task_from_dict({
|
| 176 |
+
"id": "custom_null_avg",
|
| 177 |
+
"level": "custom",
|
| 178 |
+
"title": "Handle NULLs in Aggregation",
|
| 179 |
+
"description": """\
|
| 180 |
+
TASK: The query below skips NULL scores, making the class average look higher.
|
| 181 |
+
Fix it so NULL scores are treated as 0.
|
| 182 |
+
|
| 183 |
+
SCHEMA:
|
| 184 |
+
students(id INTEGER, name VARCHAR, score INTEGER)
|
| 185 |
+
|
| 186 |
+
BROKEN QUERY:
|
| 187 |
+
SELECT AVG(score) AS avg_score FROM students
|
| 188 |
+
|
| 189 |
+
ERROR:
|
| 190 |
+
NULL values are silently excluded by AVG(), inflating the result.
|
| 191 |
+
|
| 192 |
+
GOAL: Return a single row with avg_score that treats NULL as 0.
|
| 193 |
+
Expected result: avg_score = 72.5""",
|
| 194 |
+
"schema_ddl": """\
|
| 195 |
+
CREATE TABLE students (id INTEGER, name VARCHAR, score INTEGER);
|
| 196 |
+
INSERT INTO students VALUES
|
| 197 |
+
(1, 'Alice', 90),
|
| 198 |
+
(2, 'Bob', NULL),
|
| 199 |
+
(3, 'Carol', 80),
|
| 200 |
+
(4, 'Dave', NULL),
|
| 201 |
+
(5, 'Eve', 70),
|
| 202 |
+
(6, 'Frank', 50);
|
| 203 |
+
""",
|
| 204 |
+
"broken_query": "SELECT AVG(score) AS avg_score FROM students",
|
| 205 |
+
"error_message": "NULL scores are silently skipped by AVG().",
|
| 206 |
+
"hint": "Wrap score with COALESCE(score, 0) before averaging.",
|
| 207 |
+
"expected_rows": [{"avg_score": 65.0}],
|
| 208 |
+
"solution_query": "SELECT AVG(COALESCE(score, 0)) AS avg_score FROM students",
|
| 209 |
+
"test_description": "AVG treats NULL as 0 β 65.0",
|
| 210 |
+
"max_steps": 4,
|
| 211 |
+
})
|
| 212 |
+
REGISTRY.register(custom_task)
|
| 213 |
+
|
| 214 |
+
obs = env.reset(task_id="custom_null_avg")
|
| 215 |
+
print(f"\n Task : {obs.task_title} [{obs.task_level}]")
|
| 216 |
+
print(f" Steps: up to {custom_task.max_steps}")
|
| 217 |
+
_print_obs(obs, show_description=True)
|
| 218 |
+
|
| 219 |
+
_attempt(env, "broken (NULL excluded)",
|
| 220 |
+
"SELECT AVG(score) AS avg_score FROM students")
|
| 221 |
+
|
| 222 |
+
obs = _attempt(env, "correct (COALESCE)",
|
| 223 |
+
"SELECT AVG(COALESCE(score, 0)) AS avg_score FROM students")
|
| 224 |
+
|
| 225 |
+
print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
|
| 226 |
+
|
| 227 |
+
# Clean up: remove custom task from registry
|
| 228 |
+
REGISTRY.unregister("custom_null_avg")
|
| 229 |
+
print(" Custom task unregistered from registry.")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# ββ Main ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 233 |
+
|
| 234 |
+
if __name__ == "__main__":
|
| 235 |
+
ai_key = os.environ.get("ANTHROPIC_API_KEY")
|
| 236 |
+
|
| 237 |
+
_hr("β")
|
| 238 |
+
print(" QueryForge β Local Playbook")
|
| 239 |
+
print(f" AI judge : {'LIVE (ANTHROPIC_API_KEY set)' if ai_key else 'OFFLINE (fallback to deterministic, max 0.80)'}")
|
| 240 |
+
_hr("β")
|
| 241 |
+
|
| 242 |
+
# Create a fresh env for each task so cycling order never matters
|
| 243 |
+
run_easy(QueryforgeEnvironment())
|
| 244 |
+
run_medium(QueryforgeEnvironment())
|
| 245 |
+
run_hard(QueryforgeEnvironment())
|
| 246 |
+
run_custom(QueryforgeEnvironment())
|
| 247 |
+
|
| 248 |
+
_section("DONE")
|
| 249 |
+
print(" All 4 tasks completed.\n")
|
pyproject.toml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-queryforge"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Queryforge environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.1",
|
| 21 |
+
# SQL execution engine (in-memory, no external deps)
|
| 22 |
+
"duckdb>=0.10.0",
|
| 23 |
+
# AI judge β quality scoring via Anthropic API
|
| 24 |
+
"anthropic>=0.25.0",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
[project.optional-dependencies]
|
| 28 |
+
dev = [
|
| 29 |
+
"pytest>=8.0.0",
|
| 30 |
+
"pytest-cov>=4.0.0",
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
[project.scripts]
|
| 34 |
+
# Server entry point - enables running via: uv run --project . server
|
| 35 |
+
# or: python -m queryforge.server.app
|
| 36 |
+
server = "queryforge.server.app:main"
|
| 37 |
+
|
| 38 |
+
[tool.setuptools]
|
| 39 |
+
include-package-data = true
|
| 40 |
+
packages = ["queryforge", "queryforge.server"]
|
| 41 |
+
package-dir = { "queryforge" = ".", "queryforge.server" = "server" }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Queryforge environment server components."""
|
| 8 |
+
|
| 9 |
+
from .queryforge_environment import QueryforgeEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["QueryforgeEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the Queryforge Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes the QueryforgeEnvironment
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Endpoints:
|
| 14 |
+
- POST /reset: Reset the environment
|
| 15 |
+
- POST /step: Execute an action
|
| 16 |
+
- GET /state: Get current environment state
|
| 17 |
+
- GET /schema: Get action/observation schemas
|
| 18 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
# Development (with auto-reload):
|
| 22 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 23 |
+
|
| 24 |
+
# Production:
|
| 25 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
|
| 26 |
+
|
| 27 |
+
# Or run directly:
|
| 28 |
+
python -m server.app
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from openenv.core.env_server.http_server import create_app
|
| 33 |
+
except Exception as e: # pragma: no cover
|
| 34 |
+
raise ImportError(
|
| 35 |
+
"openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
|
| 36 |
+
) from e
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
from ..models import SQLAction, SQLObservation, TaskSpec
|
| 40 |
+
from ..tasks import REGISTRY, task_from_dict
|
| 41 |
+
from .queryforge_environment import QueryforgeEnvironment
|
| 42 |
+
except ImportError:
|
| 43 |
+
from models import SQLAction, SQLObservation, TaskSpec
|
| 44 |
+
from tasks import REGISTRY, task_from_dict
|
| 45 |
+
from server.queryforge_environment import QueryforgeEnvironment
|
| 46 |
+
|
| 47 |
+
from fastapi import HTTPException
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Create the app with web interface and README integration
|
| 51 |
+
app = create_app(
|
| 52 |
+
QueryforgeEnvironment,
|
| 53 |
+
SQLAction,
|
| 54 |
+
SQLObservation,
|
| 55 |
+
env_name="queryforge",
|
| 56 |
+
max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ββ Task Registry REST endpoints ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
|
| 62 |
+
@app.post("/tasks", tags=["Task Registry"], status_code=201)
|
| 63 |
+
async def register_task(spec: TaskSpec):
|
| 64 |
+
"""Register a custom SQL task. Replaces silently if the ID already exists."""
|
| 65 |
+
task = task_from_dict(spec.model_dump())
|
| 66 |
+
REGISTRY.register(task)
|
| 67 |
+
return {"ok": True, "task_id": task.id, "total_tasks": len(REGISTRY)}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@app.get("/tasks", tags=["Task Registry"])
|
| 71 |
+
async def list_tasks():
|
| 72 |
+
"""List all registered tasks (built-in + custom)."""
|
| 73 |
+
return [
|
| 74 |
+
{"id": t.id, "level": t.level, "title": t.title}
|
| 75 |
+
for t in REGISTRY.list_all()
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@app.delete("/tasks/{task_id}", tags=["Task Registry"])
|
| 80 |
+
async def delete_task(task_id: str):
|
| 81 |
+
"""Remove a custom task. Returns 403 for built-in tasks, 404 if not found."""
|
| 82 |
+
try:
|
| 83 |
+
REGISTRY.unregister(task_id)
|
| 84 |
+
return {"ok": True, "task_id": task_id}
|
| 85 |
+
except ValueError as exc:
|
| 86 |
+
raise HTTPException(status_code=403, detail=str(exc))
|
| 87 |
+
except KeyError:
|
| 88 |
+
raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found.")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
|
| 92 |
+
"""
|
| 93 |
+
Entry point for direct execution via uv run or python -m.
|
| 94 |
+
|
| 95 |
+
This function enables running the server without Docker:
|
| 96 |
+
uv run --project . server
|
| 97 |
+
uv run --project . server --port 8001
|
| 98 |
+
python -m queryforge.server.app
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
host: Host address to bind to (default: "0.0.0.0")
|
| 102 |
+
port: Port number to listen on (default: 8000)
|
| 103 |
+
|
| 104 |
+
For production deployments, consider using uvicorn directly with
|
| 105 |
+
multiple workers:
|
| 106 |
+
uvicorn queryforge.server.app:app --workers 4
|
| 107 |
+
"""
|
| 108 |
+
import uvicorn
|
| 109 |
+
|
| 110 |
+
uvicorn.run(app, host=host, port=port)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
if __name__ == "__main__":
|
| 114 |
+
import argparse
|
| 115 |
+
|
| 116 |
+
parser = argparse.ArgumentParser()
|
| 117 |
+
parser.add_argument("--port", type=int, default=8000)
|
| 118 |
+
args = parser.parse_args()
|
| 119 |
+
main(port=args.port)
|
server/queryforge_environment.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
QueryForge SQL Environment β server-side implementation.
|
| 3 |
+
|
| 4 |
+
The agent interacts with a SQL debugging and optimisation challenge:
|
| 5 |
+
reset() β next task in round-robin rotation
|
| 6 |
+
reset(task_id="x") β pin to a specific task by ID (built-in or custom)
|
| 7 |
+
step() β grade the submitted query, return scored observation
|
| 8 |
+
state β episode_id + step count
|
| 9 |
+
|
| 10 |
+
Reward scale:
|
| 11 |
+
0.00 syntax error
|
| 12 |
+
0.15 syntax valid, runtime error
|
| 13 |
+
0.30 executes, wrong / empty results
|
| 14 |
+
0.30β0.80 partial row correctness (deterministic, DuckDB)
|
| 15 |
+
0.80β1.00 correct results + AI quality assessment (Anthropic)
|
| 16 |
+
|
| 17 |
+
Episode ends when:
|
| 18 |
+
- score >= 0.90 (correct + high-quality solution)
|
| 19 |
+
- best_score has not improved for 2 consecutive steps (early stopping)
|
| 20 |
+
- max_steps for the task is exhausted
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from typing import Optional
|
| 24 |
+
from uuid import uuid4
|
| 25 |
+
|
| 26 |
+
from openenv.core.env_server.interfaces import Environment
|
| 27 |
+
from openenv.core.env_server.types import State
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
from ..models import SQLAction, SQLObservation
|
| 31 |
+
from ..tasks import REGISTRY, SQLTask
|
| 32 |
+
from ..judge import grade
|
| 33 |
+
except ImportError:
|
| 34 |
+
from models import SQLAction, SQLObservation
|
| 35 |
+
from tasks import REGISTRY, SQLTask
|
| 36 |
+
from judge import grade
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class QueryforgeEnvironment(Environment):
|
| 40 |
+
"""
|
| 41 |
+
SQL Query Debugger & Optimiser environment.
|
| 42 |
+
|
| 43 |
+
Built-in tasks (cycled in order by default):
|
| 44 |
+
1. easy β fix three misspelled SQL keywords
|
| 45 |
+
2. medium β fix a missing JOIN condition causing a cartesian product
|
| 46 |
+
3. hard β rewrite a correlated subquery as a CTE
|
| 47 |
+
|
| 48 |
+
Custom tasks can be registered at runtime via POST /tasks and then
|
| 49 |
+
requested by passing task_id to reset():
|
| 50 |
+
env.reset(task_id="my_custom_task")
|
| 51 |
+
|
| 52 |
+
Each episode ends when:
|
| 53 |
+
- The agent achieves score β₯ 0.90 (correct + high-quality solution), or
|
| 54 |
+
- best_score has not improved for 2 consecutive steps (early stopping), or
|
| 55 |
+
- The maximum steps for the current task is exhausted.
|
| 56 |
+
|
| 57 |
+
Supports concurrent WebSocket sessions (each client gets its own instance).
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 61 |
+
|
| 62 |
+
# Episode ends when score >= this threshold.
|
| 63 |
+
# Falls back to 0.80 when ANTHROPIC_API_KEY is unset (AI judge offline,
|
| 64 |
+
# deterministic scoring caps at 0.80).
|
| 65 |
+
DONE_THRESHOLD: float = 0.80 if not __import__("os").environ.get("ANTHROPIC_API_KEY") else 0.90
|
| 66 |
+
# Episode ends when best_score has not improved for this many consecutive steps
|
| 67 |
+
EARLY_STOP_STEPS: int = 2
|
| 68 |
+
|
| 69 |
+
def __init__(self) -> None:
|
| 70 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 71 |
+
self._current_task: Optional[SQLTask] = None
|
| 72 |
+
self._best_score: float = 0.0
|
| 73 |
+
self._attempt: int = 0
|
| 74 |
+
self._stale_steps: int = 0 # consecutive steps with no best_score improvement
|
| 75 |
+
|
| 76 |
+
# ββ OpenEnv interface βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
|
| 78 |
+
def reset(
|
| 79 |
+
self,
|
| 80 |
+
task_id: Optional[str] = None,
|
| 81 |
+
seed: Optional[int] = None,
|
| 82 |
+
episode_id: Optional[str] = None,
|
| 83 |
+
**kwargs,
|
| 84 |
+
) -> SQLObservation:
|
| 85 |
+
"""
|
| 86 |
+
Start a new episode.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
task_id: Pin to a specific task by ID. If None, the registry
|
| 90 |
+
cycles round-robin through all registered tasks.
|
| 91 |
+
seed: Ignored (reserved for future use).
|
| 92 |
+
episode_id: Optional custom episode identifier.
|
| 93 |
+
"""
|
| 94 |
+
ep_id = episode_id or str(uuid4())
|
| 95 |
+
self._state = State(episode_id=ep_id, step_count=0)
|
| 96 |
+
self._best_score = 0.0
|
| 97 |
+
self._attempt = 0
|
| 98 |
+
self._stale_steps = 0
|
| 99 |
+
|
| 100 |
+
if task_id is not None:
|
| 101 |
+
try:
|
| 102 |
+
self._current_task = REGISTRY.get(task_id)
|
| 103 |
+
except KeyError as exc:
|
| 104 |
+
# Unknown task_id β return an error observation so the caller
|
| 105 |
+
# gets clear feedback instead of a silent 500.
|
| 106 |
+
return SQLObservation(
|
| 107 |
+
feedback=str(exc),
|
| 108 |
+
hint=f"Available task IDs: {', '.join(REGISTRY.ids())}",
|
| 109 |
+
done=True,
|
| 110 |
+
reward=0.0,
|
| 111 |
+
)
|
| 112 |
+
else:
|
| 113 |
+
self._current_task = REGISTRY.cycle_next()
|
| 114 |
+
|
| 115 |
+
return SQLObservation(
|
| 116 |
+
task_id=self._current_task.id,
|
| 117 |
+
task_level=self._current_task.level,
|
| 118 |
+
task_title=self._current_task.title,
|
| 119 |
+
task_description=self._current_task.description,
|
| 120 |
+
syntax_valid=False,
|
| 121 |
+
execution_success=False,
|
| 122 |
+
execution_error=None,
|
| 123 |
+
rows_returned=0,
|
| 124 |
+
feedback="New task loaded. Submit your fixed/optimised SQL query.",
|
| 125 |
+
hint=self._current_task.hint,
|
| 126 |
+
attempt=0,
|
| 127 |
+
best_score=0.0,
|
| 128 |
+
done=False,
|
| 129 |
+
reward=0.0,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
def step(self, action: SQLAction) -> SQLObservation: # type: ignore[override]
|
| 133 |
+
"""Grade the submitted SQL query and return a scored observation."""
|
| 134 |
+
self._state.step_count += 1
|
| 135 |
+
self._attempt += 1
|
| 136 |
+
|
| 137 |
+
if self._current_task is None:
|
| 138 |
+
return SQLObservation(
|
| 139 |
+
feedback="No task active. Call reset() first.",
|
| 140 |
+
hint="Call reset() to start a new episode.",
|
| 141 |
+
done=True,
|
| 142 |
+
reward=0.0,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
score, feedback, details = grade(self._current_task, action.sql)
|
| 146 |
+
|
| 147 |
+
# Fix 1 β early stopping: track consecutive steps with no improvement
|
| 148 |
+
if score > self._best_score:
|
| 149 |
+
self._stale_steps = 0
|
| 150 |
+
else:
|
| 151 |
+
self._stale_steps += 1
|
| 152 |
+
self._best_score = max(self._best_score, score)
|
| 153 |
+
|
| 154 |
+
# Fix 3 β lower done threshold + early stopping condition
|
| 155 |
+
done = (
|
| 156 |
+
score >= self.DONE_THRESHOLD
|
| 157 |
+
or self._stale_steps >= self.EARLY_STOP_STEPS
|
| 158 |
+
or self._state.step_count >= self._current_task.max_steps
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
return SQLObservation(
|
| 162 |
+
task_id=self._current_task.id,
|
| 163 |
+
task_level=self._current_task.level,
|
| 164 |
+
task_title=self._current_task.title,
|
| 165 |
+
task_description=self._current_task.description,
|
| 166 |
+
syntax_valid=bool(details.get("syntax_valid", False)),
|
| 167 |
+
execution_success=bool(details.get("execution_success", False)),
|
| 168 |
+
execution_error=details.get("execution_error"),
|
| 169 |
+
rows_returned=int(details.get("rows_returned", 0)),
|
| 170 |
+
feedback=feedback,
|
| 171 |
+
hint="" if score >= 0.9 else self._current_task.hint,
|
| 172 |
+
attempt=self._attempt,
|
| 173 |
+
best_score=self._best_score,
|
| 174 |
+
done=done,
|
| 175 |
+
reward=score,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
@property
|
| 179 |
+
def state(self) -> State:
|
| 180 |
+
return self._state
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
duckdb>=0.10.0
|
| 5 |
+
anthropic>=0.25.0
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
tasks.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQL task definitions and runtime task registry for the QueryForge environment.
|
| 3 |
+
|
| 4 |
+
Built-in tasks:
|
| 5 |
+
easy β fix three misspelled SQL keywords
|
| 6 |
+
medium β fix a cartesian JOIN producing wrong results
|
| 7 |
+
hard β rewrite a correlated subquery as a CTE
|
| 8 |
+
|
| 9 |
+
Custom tasks can be added at runtime via REGISTRY.register() or
|
| 10 |
+
POST /tasks on the running server.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from threading import Lock
|
| 17 |
+
from typing import Any, Dict, List, Optional
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ββ Data classes ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class TestCase:
|
| 24 |
+
"""A single test case: expected output rows for correctness grading."""
|
| 25 |
+
|
| 26 |
+
description: str
|
| 27 |
+
expected_rows: List[Dict[str, Any]]
|
| 28 |
+
order_by: Optional[str] = None # comma-separated columns to sort by
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class SQLTask:
|
| 33 |
+
"""Full definition of one SQL challenge."""
|
| 34 |
+
|
| 35 |
+
id: str
|
| 36 |
+
level: str # "easy" | "medium" | "hard" | "custom"
|
| 37 |
+
title: str
|
| 38 |
+
description: str
|
| 39 |
+
schema_ddl: str # DDL + seed INSERT statements for DuckDB
|
| 40 |
+
broken_query: str # broken/slow query the agent must fix
|
| 41 |
+
error_message: str # error or performance warning shown to agent
|
| 42 |
+
hint: str
|
| 43 |
+
test_cases: List[TestCase]
|
| 44 |
+
solution_query: str # reference solution used by the AI judge
|
| 45 |
+
max_steps: int = 5
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ββ Built-in tasks ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
|
| 50 |
+
_TASK_EASY = SQLTask(
|
| 51 |
+
id="task_easy_syntax",
|
| 52 |
+
level="easy",
|
| 53 |
+
title="Fix the Syntax Errors",
|
| 54 |
+
description="""\
|
| 55 |
+
TASK: Fix the syntax errors in the query below so it runs correctly.
|
| 56 |
+
|
| 57 |
+
SCHEMA:
|
| 58 |
+
users(id INTEGER, name VARCHAR, age INTEGER, city VARCHAR)
|
| 59 |
+
|
| 60 |
+
BROKEN QUERY:
|
| 61 |
+
SELEC name, age FORM users WEHRE age > 30 AND city = 'New York'
|
| 62 |
+
|
| 63 |
+
ERROR:
|
| 64 |
+
Parser Error: syntax error at or near "SELEC"
|
| 65 |
+
|
| 66 |
+
GOAL: Return a valid SQL query that retrieves `name` and `age`
|
| 67 |
+
of users who are older than 30 AND live in New York.
|
| 68 |
+
Order by name ASC.""",
|
| 69 |
+
schema_ddl="""\
|
| 70 |
+
CREATE TABLE users (
|
| 71 |
+
id INTEGER,
|
| 72 |
+
name VARCHAR,
|
| 73 |
+
age INTEGER,
|
| 74 |
+
city VARCHAR
|
| 75 |
+
);
|
| 76 |
+
INSERT INTO users VALUES
|
| 77 |
+
(1, 'Alice', 35, 'New York'),
|
| 78 |
+
(2, 'Bob', 28, 'New York'),
|
| 79 |
+
(3, 'Carol', 42, 'Chicago'),
|
| 80 |
+
(4, 'Dave', 31, 'New York'),
|
| 81 |
+
(5, 'Eve', 25, 'New York'),
|
| 82 |
+
(6, 'Frank', 38, 'New York');
|
| 83 |
+
""",
|
| 84 |
+
broken_query="SELEC name, age FORM users WEHRE age > 30 AND city = 'New York'",
|
| 85 |
+
error_message='Parser Error: syntax error at or near "SELEC"',
|
| 86 |
+
hint="Three SQL keywords are misspelled: SELEC β SELECT, FORM β FROM, WEHRE β WHERE.",
|
| 87 |
+
test_cases=[
|
| 88 |
+
TestCase(
|
| 89 |
+
description="Users over 30 living in New York, ordered by name",
|
| 90 |
+
expected_rows=[
|
| 91 |
+
{"name": "Alice", "age": 35},
|
| 92 |
+
{"name": "Dave", "age": 31},
|
| 93 |
+
{"name": "Frank", "age": 38},
|
| 94 |
+
],
|
| 95 |
+
order_by="name",
|
| 96 |
+
)
|
| 97 |
+
],
|
| 98 |
+
solution_query=(
|
| 99 |
+
"SELECT name, age FROM users "
|
| 100 |
+
"WHERE age > 30 AND city = 'New York' "
|
| 101 |
+
"ORDER BY name ASC"
|
| 102 |
+
),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
_TASK_MEDIUM = SQLTask(
|
| 106 |
+
id="task_medium_join",
|
| 107 |
+
level="medium",
|
| 108 |
+
title="Fix the Cartesian JOIN",
|
| 109 |
+
description="""\
|
| 110 |
+
TASK: The query below produces wildly inflated totals because a JOIN condition
|
| 111 |
+
is missing, creating a cartesian product with the `products` table. Fix it.
|
| 112 |
+
|
| 113 |
+
SCHEMAS:
|
| 114 |
+
users(id INTEGER, name VARCHAR, age INTEGER)
|
| 115 |
+
products(id INTEGER, title VARCHAR, price DECIMAL)
|
| 116 |
+
orders(id INTEGER, user_id INTEGER, product_id INTEGER, amount DECIMAL)
|
| 117 |
+
|
| 118 |
+
BROKEN QUERY:
|
| 119 |
+
SELECT u.name, p.title, SUM(o.amount) AS total_spent
|
| 120 |
+
FROM orders o, users u, products p
|
| 121 |
+
WHERE o.user_id = u.id
|
| 122 |
+
GROUP BY u.name, p.title
|
| 123 |
+
ORDER BY total_spent DESC
|
| 124 |
+
|
| 125 |
+
PROBLEM:
|
| 126 |
+
Missing join condition `o.product_id = p.id`.
|
| 127 |
+
Every order row is multiplied by ALL products, inflating every total by 3Γ.
|
| 128 |
+
|
| 129 |
+
GOAL: Rewrite using explicit INNER JOIN β¦ ON syntax with all correct join
|
| 130 |
+
conditions. Return user name, product title, and true total amount spent per
|
| 131 |
+
(user, product) pair, ordered by total_spent DESC.""",
|
| 132 |
+
schema_ddl="""\
|
| 133 |
+
CREATE TABLE users (id INTEGER, name VARCHAR, age INTEGER);
|
| 134 |
+
CREATE TABLE products (id INTEGER, title VARCHAR, price DECIMAL);
|
| 135 |
+
CREATE TABLE orders (id INTEGER, user_id INTEGER, product_id INTEGER, amount DECIMAL);
|
| 136 |
+
|
| 137 |
+
INSERT INTO users VALUES (1,'Alice',30),(2,'Bob',25),(3,'Carol',35);
|
| 138 |
+
INSERT INTO products VALUES (1,'Laptop',999.99),(2,'Phone',599.99),(3,'Tablet',399.99);
|
| 139 |
+
INSERT INTO orders VALUES
|
| 140 |
+
(1,1,1,999.99),(2,1,2,599.99),
|
| 141 |
+
(3,2,1,999.99),(4,2,3,399.99),
|
| 142 |
+
(5,3,2,599.99),(6,3,1,999.99);
|
| 143 |
+
""",
|
| 144 |
+
broken_query="""\
|
| 145 |
+
SELECT u.name, p.title, SUM(o.amount) AS total_spent
|
| 146 |
+
FROM orders o, users u, products p
|
| 147 |
+
WHERE o.user_id = u.id
|
| 148 |
+
GROUP BY u.name, p.title
|
| 149 |
+
ORDER BY total_spent DESC""",
|
| 150 |
+
error_message=(
|
| 151 |
+
"Query runs but produces WRONG results: totals are 3Γ too high "
|
| 152 |
+
"because every order is joined to every product (cartesian product)."
|
| 153 |
+
),
|
| 154 |
+
hint=(
|
| 155 |
+
"Use INNER JOIN β¦ ON for every table. "
|
| 156 |
+
"You need both: o.user_id = u.id AND o.product_id = p.id."
|
| 157 |
+
),
|
| 158 |
+
test_cases=[
|
| 159 |
+
TestCase(
|
| 160 |
+
description="Correct per-(user, product) totals",
|
| 161 |
+
expected_rows=[
|
| 162 |
+
{"name": "Alice", "title": "Laptop", "total_spent": 999.99},
|
| 163 |
+
{"name": "Alice", "title": "Phone", "total_spent": 599.99},
|
| 164 |
+
{"name": "Bob", "title": "Laptop", "total_spent": 999.99},
|
| 165 |
+
{"name": "Bob", "title": "Tablet", "total_spent": 399.99},
|
| 166 |
+
{"name": "Carol", "title": "Laptop", "total_spent": 999.99},
|
| 167 |
+
{"name": "Carol", "title": "Phone", "total_spent": 599.99},
|
| 168 |
+
],
|
| 169 |
+
order_by="name,title",
|
| 170 |
+
)
|
| 171 |
+
],
|
| 172 |
+
solution_query="""\
|
| 173 |
+
SELECT u.name, p.title, SUM(o.amount) AS total_spent
|
| 174 |
+
FROM orders o
|
| 175 |
+
INNER JOIN users u ON o.user_id = u.id
|
| 176 |
+
INNER JOIN products p ON o.product_id = p.id
|
| 177 |
+
GROUP BY u.name, p.title
|
| 178 |
+
ORDER BY total_spent DESC""",
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
_TASK_HARD = SQLTask(
|
| 182 |
+
id="task_hard_cte",
|
| 183 |
+
level="hard",
|
| 184 |
+
title="Rewrite Correlated Subquery as CTE",
|
| 185 |
+
description="""\
|
| 186 |
+
TASK: The query below is semantically correct but executes the inner AVG(salary)
|
| 187 |
+
once per employee row β O(N) full scans. Rewrite it using a WITH (CTE) so the
|
| 188 |
+
department averages are computed exactly once.
|
| 189 |
+
|
| 190 |
+
SCHEMAS:
|
| 191 |
+
departments(id INTEGER, dept_name VARCHAR)
|
| 192 |
+
employees(id INTEGER, name VARCHAR, department_id INTEGER, salary DECIMAL)
|
| 193 |
+
|
| 194 |
+
SLOW QUERY:
|
| 195 |
+
SELECT e.name, e.department_id, e.salary
|
| 196 |
+
FROM employees e
|
| 197 |
+
WHERE e.salary > (
|
| 198 |
+
SELECT AVG(e2.salary)
|
| 199 |
+
FROM employees e2
|
| 200 |
+
WHERE e2.department_id = e.department_id
|
| 201 |
+
)
|
| 202 |
+
ORDER BY e.department_id, e.salary DESC
|
| 203 |
+
|
| 204 |
+
PERFORMANCE WARNING:
|
| 205 |
+
For 1 M employees the inner subquery executes 1 M times.
|
| 206 |
+
DuckDB's EXPLAIN shows: 'FILTER ... (subquery)' with nested loop.
|
| 207 |
+
|
| 208 |
+
GOAL: Rewrite using a CTE that computes per-department average salary once,
|
| 209 |
+
then join it to employees and filter. The result must be identical:
|
| 210 |
+
employees who earn strictly above their own department's average salary,
|
| 211 |
+
ordered by department_id ASC, salary DESC.""",
|
| 212 |
+
schema_ddl="""\
|
| 213 |
+
CREATE TABLE departments (id INTEGER, dept_name VARCHAR);
|
| 214 |
+
CREATE TABLE employees (id INTEGER, name VARCHAR, department_id INTEGER, salary DECIMAL);
|
| 215 |
+
|
| 216 |
+
INSERT INTO departments VALUES (1,'Engineering'),(2,'Marketing'),(3,'Sales');
|
| 217 |
+
INSERT INTO employees VALUES
|
| 218 |
+
(1,'Alice', 1, 95000),(2,'Bob', 1, 75000),(3,'Carol', 1, 85000),
|
| 219 |
+
(4,'Dave', 2, 65000),(5,'Eve', 2, 70000),(6,'Frank', 2, 60000),
|
| 220 |
+
(7,'Grace', 3, 55000),(8,'Hank', 3, 72000),(9,'Iris', 3, 58000);
|
| 221 |
+
""",
|
| 222 |
+
broken_query="""\
|
| 223 |
+
SELECT e.name, e.department_id, e.salary
|
| 224 |
+
FROM employees e
|
| 225 |
+
WHERE e.salary > (
|
| 226 |
+
SELECT AVG(e2.salary)
|
| 227 |
+
FROM employees e2
|
| 228 |
+
WHERE e2.department_id = e.department_id
|
| 229 |
+
)
|
| 230 |
+
ORDER BY e.department_id, e.salary DESC""",
|
| 231 |
+
error_message=(
|
| 232 |
+
"PERFORMANCE: Correlated subquery re-executes AVG() for every row. "
|
| 233 |
+
"On large tables this is O(NΒ²). Rewrite as a CTE for O(N) execution."
|
| 234 |
+
),
|
| 235 |
+
hint=(
|
| 236 |
+
"WITH dept_avg AS (SELECT department_id, AVG(salary) AS avg_salary "
|
| 237 |
+
"FROM employees GROUP BY department_id) β then JOIN employees to dept_avg "
|
| 238 |
+
"and filter WHERE e.salary > d.avg_salary."
|
| 239 |
+
),
|
| 240 |
+
test_cases=[
|
| 241 |
+
TestCase(
|
| 242 |
+
description="Employees strictly above their department's average salary",
|
| 243 |
+
expected_rows=[
|
| 244 |
+
{"name": "Alice", "department_id": 1, "salary": 95000.0},
|
| 245 |
+
{"name": "Eve", "department_id": 2, "salary": 70000.0},
|
| 246 |
+
{"name": "Hank", "department_id": 3, "salary": 72000.0},
|
| 247 |
+
],
|
| 248 |
+
order_by="department_id,name",
|
| 249 |
+
)
|
| 250 |
+
],
|
| 251 |
+
solution_query="""\
|
| 252 |
+
WITH dept_avg AS (
|
| 253 |
+
SELECT department_id, AVG(salary) AS avg_salary
|
| 254 |
+
FROM employees
|
| 255 |
+
GROUP BY department_id
|
| 256 |
+
)
|
| 257 |
+
SELECT e.name, e.department_id, e.salary
|
| 258 |
+
FROM employees e
|
| 259 |
+
JOIN dept_avg d ON e.department_id = d.department_id
|
| 260 |
+
WHERE e.salary > d.avg_salary
|
| 261 |
+
ORDER BY e.department_id, e.salary DESC""",
|
| 262 |
+
max_steps=6,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ββ Task Registry βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 267 |
+
|
| 268 |
+
class TaskRegistry:
|
| 269 |
+
"""
|
| 270 |
+
Thread-safe registry of SQL tasks, shared across all environment sessions.
|
| 271 |
+
|
| 272 |
+
Built-in tasks (easy / medium / hard) are always present and cannot be removed.
|
| 273 |
+
Custom tasks can be added via register(), load_from_json(), or POST /tasks.
|
| 274 |
+
"""
|
| 275 |
+
|
| 276 |
+
_BUILTIN_IDS: frozenset = frozenset(
|
| 277 |
+
["task_easy_syntax", "task_medium_join", "task_hard_cte"]
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
def __init__(self, initial_tasks: List[SQLTask]) -> None:
|
| 281 |
+
self._lock = Lock()
|
| 282 |
+
# Insertion-ordered dict preserves cycling order
|
| 283 |
+
self._tasks: Dict[str, SQLTask] = {t.id: t for t in initial_tasks}
|
| 284 |
+
self._cycle_index: int = 0
|
| 285 |
+
|
| 286 |
+
# ββ CRUD βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 287 |
+
|
| 288 |
+
def register(self, task: SQLTask) -> None:
|
| 289 |
+
"""Add or replace a task. Replaces silently if the ID already exists."""
|
| 290 |
+
with self._lock:
|
| 291 |
+
self._tasks[task.id] = task
|
| 292 |
+
|
| 293 |
+
def unregister(self, task_id: str) -> None:
|
| 294 |
+
"""
|
| 295 |
+
Remove a custom task.
|
| 296 |
+
Raises ValueError for built-in tasks, KeyError if not found.
|
| 297 |
+
"""
|
| 298 |
+
if task_id in self._BUILTIN_IDS:
|
| 299 |
+
raise ValueError(f"Built-in task '{task_id}' cannot be removed.")
|
| 300 |
+
with self._lock:
|
| 301 |
+
if task_id not in self._tasks:
|
| 302 |
+
raise KeyError(task_id)
|
| 303 |
+
del self._tasks[task_id]
|
| 304 |
+
|
| 305 |
+
def get(self, task_id: str) -> SQLTask:
|
| 306 |
+
"""Return a task by ID. Raises KeyError with available IDs if not found."""
|
| 307 |
+
with self._lock:
|
| 308 |
+
if task_id not in self._tasks:
|
| 309 |
+
available = ", ".join(self._tasks.keys())
|
| 310 |
+
raise KeyError(
|
| 311 |
+
f"Task '{task_id}' not found. "
|
| 312 |
+
f"Available: {available}"
|
| 313 |
+
)
|
| 314 |
+
return self._tasks[task_id]
|
| 315 |
+
|
| 316 |
+
def list_all(self) -> List[SQLTask]:
|
| 317 |
+
"""Return all registered tasks in insertion order."""
|
| 318 |
+
with self._lock:
|
| 319 |
+
return list(self._tasks.values())
|
| 320 |
+
|
| 321 |
+
def ids(self) -> List[str]:
|
| 322 |
+
"""Return all task IDs in insertion order."""
|
| 323 |
+
with self._lock:
|
| 324 |
+
return list(self._tasks.keys())
|
| 325 |
+
|
| 326 |
+
# ββ Cycling βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 327 |
+
|
| 328 |
+
def cycle_next(self) -> SQLTask:
|
| 329 |
+
"""Return the next task in round-robin order (wraps at end)."""
|
| 330 |
+
with self._lock:
|
| 331 |
+
tasks = list(self._tasks.values())
|
| 332 |
+
task = tasks[self._cycle_index % len(tasks)]
|
| 333 |
+
self._cycle_index += 1
|
| 334 |
+
return task
|
| 335 |
+
|
| 336 |
+
# ββ Bulk loading ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 337 |
+
|
| 338 |
+
def load_from_json(self, path: str) -> int:
|
| 339 |
+
"""
|
| 340 |
+
Load tasks from a JSON file (list of task spec objects).
|
| 341 |
+
Returns the number of tasks loaded.
|
| 342 |
+
|
| 343 |
+
Minimal required fields per task:
|
| 344 |
+
id, schema_ddl, expected_rows
|
| 345 |
+
|
| 346 |
+
Example file::
|
| 347 |
+
|
| 348 |
+
[
|
| 349 |
+
{
|
| 350 |
+
"id": "my_null_task",
|
| 351 |
+
"level": "medium",
|
| 352 |
+
"title": "Handle NULLs in aggregation",
|
| 353 |
+
"schema_ddl": "CREATE TABLE ...; INSERT ...",
|
| 354 |
+
"broken_query": "SELECT AVG(score) FROM ...",
|
| 355 |
+
"expected_rows": [{"avg_score": 72.5}],
|
| 356 |
+
"hint": "Use COALESCE to handle NULL scores."
|
| 357 |
+
}
|
| 358 |
+
]
|
| 359 |
+
"""
|
| 360 |
+
raw = json.loads(Path(path).read_text())
|
| 361 |
+
if isinstance(raw, dict):
|
| 362 |
+
raw = [raw]
|
| 363 |
+
for item in raw:
|
| 364 |
+
self.register(task_from_dict(item))
|
| 365 |
+
return len(raw)
|
| 366 |
+
|
| 367 |
+
# ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 368 |
+
|
| 369 |
+
def __len__(self) -> int:
|
| 370 |
+
with self._lock:
|
| 371 |
+
return len(self._tasks)
|
| 372 |
+
|
| 373 |
+
def __contains__(self, task_id: str) -> bool:
|
| 374 |
+
with self._lock:
|
| 375 |
+
return task_id in self._tasks
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# ββ Conversion helper βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 379 |
+
|
| 380 |
+
def task_from_dict(d: Dict[str, Any]) -> SQLTask:
|
| 381 |
+
"""
|
| 382 |
+
Construct an SQLTask from a plain dict (JSON payload or loaded file).
|
| 383 |
+
|
| 384 |
+
Required keys : id, schema_ddl, expected_rows
|
| 385 |
+
Optional keys : level, title, description, broken_query, error_message,
|
| 386 |
+
hint, order_by, solution_query, test_description, max_steps
|
| 387 |
+
"""
|
| 388 |
+
return SQLTask(
|
| 389 |
+
id=d["id"],
|
| 390 |
+
level=d.get("level", "custom"),
|
| 391 |
+
title=d.get("title", d["id"]),
|
| 392 |
+
description=d.get("description", ""),
|
| 393 |
+
schema_ddl=d["schema_ddl"],
|
| 394 |
+
broken_query=d.get("broken_query", ""),
|
| 395 |
+
error_message=d.get("error_message", ""),
|
| 396 |
+
hint=d.get("hint", ""),
|
| 397 |
+
test_cases=[
|
| 398 |
+
TestCase(
|
| 399 |
+
description=d.get("test_description", "Custom test case"),
|
| 400 |
+
expected_rows=d["expected_rows"],
|
| 401 |
+
order_by=d.get("order_by"),
|
| 402 |
+
)
|
| 403 |
+
],
|
| 404 |
+
solution_query=d.get("solution_query", ""),
|
| 405 |
+
max_steps=d.get("max_steps", 5),
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
# ββ Global singleton ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 410 |
+
|
| 411 |
+
REGISTRY = TaskRegistry([_TASK_EASY, _TASK_MEDIUM, _TASK_HARD])
|
| 412 |
+
|
| 413 |
+
# Backwards-compat: snapshot of the three built-in tasks at import time
|
| 414 |
+
TASKS: List[SQLTask] = [_TASK_EASY, _TASK_MEDIUM, _TASK_HARD]
|
| 415 |
+
TASK_BY_ID: Dict[str, SQLTask] = {t.id: t for t in TASKS}
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|