Prithvigg committed on
Commit
3867c62
Β·
verified Β·
1 Parent(s): a8a3c90

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. client.py +4 -1
  2. inference.py +241 -0
  3. judge.py +27 -15
  4. playbook.py +98 -99
  5. server/queryforge_environment.py +24 -0
  6. tasks.py +290 -6
client.py CHANGED
@@ -7,7 +7,10 @@ from openenv.core import EnvClient
7
  from openenv.core.client_types import StepResult
8
  from openenv.core.env_server.types import State
9
 
10
- from .models import SQLAction, SQLObservation, TaskSpec
 
 
 
11
 
12
 
13
  class QueryforgeEnv(EnvClient[SQLAction, SQLObservation, State]):
 
7
  from openenv.core.client_types import StepResult
8
  from openenv.core.env_server.types import State
9
 
10
+ try:
11
+ from .models import SQLAction, SQLObservation, TaskSpec
12
+ except ImportError:
13
+ from models import SQLAction, SQLObservation, TaskSpec
14
 
15
 
16
  class QueryforgeEnv(EnvClient[SQLAction, SQLObservation, State]):
inference.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ QueryForge Inference Script
3
+ ===================================
4
+ MANDATORY env vars:
5
+ API_BASE_URL The API endpoint for the LLM (e.g. https://router.huggingface.co/v1)
6
+ MODEL_NAME The model identifier to use for inference
7
+ HF_TOKEN Your Hugging Face / API key
8
+
9
+ Optional env vars:
10
+ ENV_URL QueryForge environment server URL (default: http://localhost:8000)
11
+ ANTHROPIC_API_KEY Enables AI judge for scores up to 1.0 (default: deterministic mode)
12
+ """
13
+
14
+ import os
15
+ import re
16
+ import sys
17
+ import textwrap
18
+ from typing import List, Optional
19
+
20
+ from openai import OpenAI
21
+
22
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
23
+
24
+ from client import QueryforgeEnv
25
+ from models import SQLAction
26
+
27
+ # ── Configuration ─────────────────────────────────────────────────────────────
28
+
29
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
30
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
31
+ MODEL_NAME = os.getenv("MODEL_NAME")
32
+ ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
33
+
34
+ MAX_STEPS = 5 # max attempts per task (overridden by task's own max_steps)
35
+ TEMPERATURE = 0.2
36
+ MAX_TOKENS = 512
37
+
38
+ TASK_IDS = [
39
+ "task_easy_syntax",
40
+ "task_medium_join",
41
+ "task_hard_cte",
42
+ "task_expert_rank",
43
+ "task_expert_recursive",
44
+ "task_expert_window",
45
+ ]
46
+
47
+ # ── Prompts ───────────────────────────────────────────────────────────────────
48
+
49
+ SYSTEM_PROMPT = textwrap.dedent("""
50
+ You are an expert SQL engineer tasked with debugging and optimising SQL queries.
51
+ You will be given a SQL challenge that includes a schema, a broken or slow query,
52
+ and a description of what the correct output should be.
53
+
54
+ Rules:
55
+ - Respond with ONLY a single SQL query inside a ```sql ... ``` code block.
56
+ - Do not explain your reasoning outside the code block.
57
+ - Do not include multiple statements separated by semicolons.
58
+ - If you receive grading feedback on a previous attempt, use it to improve.
59
+ """).strip()
60
+
61
+ # ── SQL extraction ─────────────────────────────────────────────────────────────
62
+
63
# Pre-compiled pattern for a fenced code block, optionally tagged ``sql``.
# DOTALL lets ``.`` cross newlines; IGNORECASE also accepts ```SQL fences.
_SQL_BLOCK = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def extract_sql(text: str) -> str:
    """Return the first fenced code block in *text*.

    Falls back to the whole stripped response when the model answered
    without markdown fences.
    """
    found = _SQL_BLOCK.search(text)
    return found.group(1).strip() if found else text.strip()
72
+
73
+
74
+ # ── Formatting ────────────────────────────────────────────────────────────────
75
+
76
def score_bar(score: float, width: int = 25) -> str:
    """Render *score* (expected 0.0–1.0) as a fixed-width unicode bar."""
    n_filled = int(score * width)
    cells = "β–ˆ" * n_filled + "β–‘" * (width - n_filled)
    return f"[{cells}] {score:.3f}"
79
+
80
+
81
def hr(char="═", width=70):
    """Print a horizontal rule made of *width* repetitions of *char*."""
    rule = char * width
    print(rule)
83
+
84
+
85
+ # ── Per-task agent loop ────────────────────────────────────────────────────────
86
+
87
def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
    """
    Run one episode for a single task.

    Resets the environment to *task_id*, then loops: ask the LLM for a SQL
    query, submit it via env_client.step, print the grading outcome, and
    feed the feedback back to the model until the episode reports done or
    the LLM call fails.

    Returns dict with task_id, task_title, task_level, best_score,
    attempts, done.
    """
    result = env_client.reset(task_id=task_id)
    obs = result.observation

    if result.done:
        # Episode ended immediately — the task could not be loaded.
        print(f" ERROR loading task: {obs.feedback}")
        return {"task_id": task_id, "best_score": 0.0, "attempts": 0, "done": False}

    print(f"\n Task : {obs.task_title} [{obs.task_level}]")

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                f"Here is your SQL challenge:\n\n{obs.task_description}\n\n"
                "Provide your fixed SQL query."
            ),
        },
    ]

    step = 0
    while not result.done:
        step += 1

        try:
            completion = llm.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                stream=False,
            )
            response_text = completion.choices[0].message.content or ""
        except Exception as exc:
            # Network/auth failures abort the episode rather than loop forever.
            print(f" LLM call failed at step {step}: {exc}")
            break

        sql = extract_sql(response_text)

        # ── Print generated SQL ───────────────────────────────────────────────
        print(f"\n β”Œβ”€ Step {step} Β· SQL submitted {'─' * (50 - len(str(step)))}")
        for line in sql.splitlines():
            print(f" β”‚ {line}")
        print(f" β””{'─' * 56}")

        result = env_client.step(SQLAction(sql=sql))
        obs = result.observation

        # result.reward may be None β€” normalise once and reuse below.
        score = result.reward or 0.0
        done_marker = " βœ“ DONE" if result.done else ""
        print(f" Score : {score_bar(score)}{done_marker}")

        if not obs.syntax_valid:
            print(f" βœ— Syntax error β€” query could not be parsed")
        elif not obs.execution_success:
            print(f" βœ— Execution failed β€” {(obs.execution_error or '')[:80]}")
        else:
            print(f" βœ“ Executed Β· rows returned: {obs.rows_returned}")

        if result.done:
            break

        # ── Why are we going to the next step? ───────────────────────────────
        print(f"\n ↻ Retrying β€” score {score:.3f} below threshold")
        if obs.feedback:
            # Split the feedback into its tagged sections for readable multi-line output.
            # NOTE(review): splitting on a single space prints one word per line;
            # presumably the real section delimiter is wider β€” confirm the
            # feedback format emitted by the judge.
            for part in obs.feedback.split(" "):
                part = part.strip()
                if part:
                    print(f" {part}")
        if obs.hint:
            print(f" Hint : {obs.hint[:120]}")

        # Feed grading result back to the model for the next attempt.
        messages.append({"role": "assistant", "content": response_text})
        messages.append({
            "role": "user",
            "content": (
                # Fix: use the normalised `score` β€” formatting a None
                # result.reward with :.3f raised TypeError.
                f"Your query scored {score:.3f}.\n\n"
                f"Feedback: {obs.feedback}\n\n"
                f"Hint: {obs.hint}\n\n"
                "Please submit an improved SQL query."
            ),
        })

    return {
        "task_id": task_id,
        "task_title": obs.task_title,
        "task_level": obs.task_level,
        "best_score": obs.best_score,
        "attempts": obs.attempt,
        "done": result.done,
    }
185
+
186
+
187
+ # ── Main ───────────────────────────────────────────────────────────────────────
188
+
189
def main() -> None:
    """Entry point: validate configuration, run every task, print results.

    Exits with status 1 when mandatory configuration is missing.
    """
    # ── Validate required config ──────────────────────────────────────────────
    # Fix: API_BASE_URL has a module-level default, so only MODEL_NAME is
    # genuinely mandatory; the old check rejected runs relying on the default.
    missing = [v for v in ("MODEL_NAME",) if not os.getenv(v)]
    if missing:
        print(f"ERROR: missing required env vars: {', '.join(missing)}")
        sys.exit(1)

    if not API_KEY:
        print("ERROR: HF_TOKEN (or API_KEY) is not set.")
        sys.exit(1)

    llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    hr()
    print(" QueryForge β€” Inference")
    print(f" Model : {MODEL_NAME}")
    print(f" Env : {ENV_URL}")
    print(f" Tasks : {', '.join(TASK_IDS)}")
    hr()

    results = []

    # One episode per task id, all through a single client session.
    with QueryforgeEnv(base_url=ENV_URL).sync() as env_client:
        for task_id in TASK_IDS:
            print(f"\n{'─' * 70}")
            result = run_task(task_id, llm, env_client)
            results.append(result)

    # ── Results table ─────────────────────────────────────────────────────────
    print(f"\n{'═' * 70}")
    print(" RESULTS")
    print(f" Model: {MODEL_NAME}")
    print(f"{'═' * 70}")
    print(f" {'Task':<28} {'Level':<8} {'Steps':>5} {'Best Score'}")
    print(f" {'─' * 28} {'─' * 8} {'─' * 5} {'─' * 30}")

    total = 0.0
    for r in results:
        # Error episodes lack title/level/attempts β€” fall back gracefully.
        title = r.get("task_title", r["task_id"])[:27]
        level = r.get("task_level", "?")
        steps = r.get("attempts", "?")
        score = r["best_score"]
        total += score
        print(f" {title:<28} {level:<8} {steps:>5} {score_bar(score)}")

    avg = total / len(results) if results else 0.0
    print(f"{'─' * 70}")
    print(f" {'AVERAGE':<28} {'':8} {'':5} {score_bar(avg)}")
    print(f"{'═' * 70}\n")
238
+
239
+
240
+ if __name__ == "__main__":
241
+ main()
judge.py CHANGED
@@ -282,18 +282,22 @@ Respond with ONLY valid JSON (no markdown fences):
282
 
283
  try:
284
  message = client.messages.create(
285
- model="claude-sonnet-4-6",
286
  max_tokens=512,
287
- messages=[{"role": "user", "content": prompt}],
 
 
 
288
  )
289
- raw = message.content[0].text.strip()
 
 
290
 
291
- # Strip accidental markdown fences
292
- if raw.startswith("```"):
293
- raw = raw.split("```")[1]
294
- if raw.startswith("json"):
295
- raw = raw[4:]
296
- raw = raw.rsplit("```", 1)[0].strip()
297
 
298
  data = json.loads(raw)
299
  score = float(data["score"])
@@ -305,14 +309,13 @@ Respond with ONLY valid JSON (no markdown fences):
305
  except Exception as exc:
306
  # Graceful fallback β€” no API key, network error, or parse failure
307
  msg = str(exc).lower()
308
- reason = (
309
- "no ANTHROPIC_API_KEY set"
310
- if "api_key" in msg or "auth" in msg or "authentication" in msg
311
- else type(exc).__name__
312
- )
313
  return (
314
  deterministic_score,
315
- f"AI judge offline ({reason}). Using deterministic score.",
316
  task.hint,
317
  )
318
 
@@ -378,6 +381,15 @@ def grade(
378
  elif task.level == "medium" and "JOIN " not in query_upper:
379
  structural_penalty = 0.20 # medium task demands explicit JOINs
380
  row_feedback += " (Penalty: no explicit JOIN β€” task requires JOIN … ON syntax.)"
 
 
 
 
 
 
 
 
 
381
 
382
  details["structural_penalty"] = structural_penalty
383
 
 
282
 
283
  try:
284
  message = client.messages.create(
285
+ model=JUDGE_MODEL,
286
  max_tokens=512,
287
+ messages=[
288
+ {"role": "user", "content": prompt},
289
+ {"role": "assistant", "content": "{"}, # prefill forces JSON-only reply
290
+ ],
291
  )
292
+ print("Anthropic judge response:", message.content)
293
+ # Prepend the prefilled "{" back before parsing
294
+ raw = "{" + message.content[0].text.strip()
295
 
296
+ # Belt-and-suspenders: extract the first {...} block in case of any preamble
297
+ brace_start = raw.find("{")
298
+ brace_end = raw.rfind("}") + 1
299
+ if brace_start != -1 and brace_end > brace_start:
300
+ raw = raw[brace_start:brace_end]
 
301
 
302
  data = json.loads(raw)
303
  score = float(data["score"])
 
309
  except Exception as exc:
310
  # Graceful fallback β€” no API key, network error, or parse failure
311
  msg = str(exc).lower()
312
+ if "api_key" in msg or "auth" in msg or "authentication" in msg:
313
+ reason = "ANTHROPIC_API_KEY not set β€” deterministic scoring only (max 0.80)"
314
+ else:
315
+ reason = f"AI judge call failed ({type(exc).__name__}) β€” fell back to deterministic score"
 
316
  return (
317
  deterministic_score,
318
+ f"[AI Judge unavailable] {reason}.",
319
  task.hint,
320
  )
321
 
 
381
  elif task.level == "medium" and "JOIN " not in query_upper:
382
  structural_penalty = 0.20 # medium task demands explicit JOINs
383
  row_feedback += " (Penalty: no explicit JOIN β€” task requires JOIN … ON syntax.)"
384
+ elif task.id == "task_expert_recursive" and "RECURSIVE" not in query_upper:
385
+ structural_penalty = 0.30 # must use recursive CTE, not repeated JOINs
386
+ row_feedback += " (Penalty: WITH RECURSIVE required β€” plain JOIN only fetches one level.)"
387
+ elif task.id == "task_expert_rank" and "ROW_NUMBER" in query_upper:
388
+ structural_penalty = 0.20 # ROW_NUMBER breaks ties β€” must use RANK/DENSE_RANK
389
+ row_feedback += " (Penalty: ROW_NUMBER() drops tied rows β€” use RANK() or DENSE_RANK().)"
390
+ elif task.id == "task_expert_window" and "PARTITION BY" not in query_upper:
391
+ structural_penalty = 0.20 # both window functions need PARTITION BY region
392
+ row_feedback += " (Penalty: missing PARTITION BY β€” both SUM and RANK must be partitioned per region.)"
393
 
394
  details["structural_penalty"] = structural_penalty
395
 
playbook.py CHANGED
@@ -1,10 +1,13 @@
1
  """
2
- QueryForge Local Playbook
3
- ─────────────────────────
4
- Tests the environment directly (no HTTP server needed).
5
 
6
- Run from the queryforge directory:
7
- .venv/bin/python playbook.py
 
 
 
8
 
9
  If ANTHROPIC_API_KEY is set, Stage 4 AI scoring is live.
10
  If not set, the judge falls back to deterministic scoring (capped at 0.80).
@@ -14,13 +17,14 @@ import os
14
  import sys
15
  import textwrap
16
 
17
- # Make imports work whether run directly or as a module
18
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
19
 
20
- from server.queryforge_environment import QueryforgeEnvironment
21
- from models import SQLAction
22
  from tasks import REGISTRY, task_from_dict
23
 
 
 
24
  # ── Formatting helpers ────────────────────────────────────────────────────────
25
 
26
  def _hr(char="═", width=70):
@@ -37,8 +41,9 @@ def _score_bar(score: float, width: int = 30) -> str:
37
  bar = "β–ˆ" * filled + "β–‘" * (width - filled)
38
  return f"[{bar}] {score:.2f}"
39
 
40
- def _print_obs(obs, show_description=False):
41
- if show_description:
 
42
  print()
43
  print(textwrap.indent(obs.task_description, " "))
44
  print()
@@ -48,91 +53,87 @@ def _print_obs(obs, show_description=False):
48
  if obs.execution_error:
49
  print(f" Execution error : {obs.execution_error[:100]}")
50
  print(f" Rows returned : {obs.rows_returned}")
51
- print(f" Score : {_score_bar(obs.reward or 0.0)}")
52
  print(f" Best this ep. : {_score_bar(obs.best_score)}")
53
- # Print just the first 200 chars of feedback to keep output clean
54
  fb = obs.feedback[:250] + ("…" if len(obs.feedback) > 250 else "")
55
  print(f" Feedback : {fb}")
56
  if obs.hint:
57
  print(f" Hint : {obs.hint[:120]}")
58
 
59
- def _attempt(env, label: str, sql: str):
60
  print(f"\n ── Attempt: {label}")
61
  print(f" SQL: {sql[:100]}{'…' if len(sql) > 100 else ''}")
62
- obs = env.step(SQLAction(sql=sql))
63
- _print_obs(obs)
64
- return obs
65
 
66
 
67
  # ── Task runners ──────────────────────────────────────────────────────────────
68
 
69
- def run_easy(env):
70
  _section("TASK 1 Β· EASY β€” Fix Syntax Errors")
71
- env._task_index = 0 # pin to easy
72
- obs = env.reset()
73
  print(f"\n Task : {obs.task_title} [{obs.task_level}]")
74
- print(f" Steps: up to {5}")
75
- _print_obs(obs, show_description=True)
76
 
77
- _attempt(env, "still broken",
78
  "SELEC name, age FORM users WEHRE age > 30")
79
 
80
- _attempt(env, "one keyword fixed",
81
  "SELECT name, age FORM users WEHRE age > 30")
82
 
83
- _attempt(env, "all keywords fixed, no filter",
84
  "SELECT name, age FROM users WHERE age > 30")
85
 
86
- obs = _attempt(env, "correct solution",
87
- "SELECT name, age FROM users "
88
- "WHERE age > 30 AND city = 'New York' "
89
- "ORDER BY name ASC")
90
 
91
- print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
92
 
93
 
94
- def run_medium(env):
95
  _section("TASK 2 Β· MEDIUM β€” Fix the Cartesian JOIN")
96
- env._task_index = 1 # pin to medium
97
- obs = env.reset()
98
  print(f"\n Task : {obs.task_title} [{obs.task_level}]")
99
- print(f" Steps: up to {5}")
100
- _print_obs(obs, show_description=True)
101
 
102
- _attempt(env, "broken verbatim (cartesian product)",
103
  "SELECT u.name, p.title, SUM(o.amount) AS total_spent "
104
  "FROM orders o, users u, products p "
105
  "WHERE o.user_id = u.id "
106
  "GROUP BY u.name, p.title "
107
  "ORDER BY total_spent DESC")
108
 
109
- _attempt(env, "comma-join but missing product condition",
110
  "SELECT u.name, p.title, SUM(o.amount) AS total_spent "
111
  "FROM orders o, users u, products p "
112
  "WHERE o.user_id = u.id AND o.product_id = p.id "
113
  "GROUP BY u.name, p.title "
114
  "ORDER BY total_spent DESC")
115
 
116
- obs = _attempt(env, "correct INNER JOINs",
117
- "SELECT u.name, p.title, SUM(o.amount) AS total_spent\n"
118
- "FROM orders o\n"
119
- "INNER JOIN users u ON o.user_id = u.id\n"
120
- "INNER JOIN products p ON o.product_id = p.id\n"
121
- "GROUP BY u.name, p.title\n"
122
- "ORDER BY total_spent DESC")
123
 
124
- print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
125
 
126
 
127
- def run_hard(env):
128
  _section("TASK 3 Β· HARD β€” Rewrite Correlated Subquery as CTE")
129
- env._task_index = 2 # pin to hard
130
- obs = env.reset()
131
  print(f"\n Task : {obs.task_title} [{obs.task_level}]")
132
- print(f" Steps: up to {6}")
133
- _print_obs(obs, show_description=True)
134
 
135
- _attempt(env, "broken verbatim (no CTE β€” penalised even though rows match)",
136
  "SELECT e.name, e.department_id, e.salary\n"
137
  "FROM employees e\n"
138
  "WHERE e.salary > (\n"
@@ -141,7 +142,7 @@ def run_hard(env):
141
  ")\n"
142
  "ORDER BY e.department_id, e.salary DESC")
143
 
144
- _attempt(env, "halfway β€” CTE defined but wrong join",
145
  "WITH dept_avg AS (\n"
146
  " SELECT department_id, AVG(salary) AS avg_salary\n"
147
  " FROM employees GROUP BY department_id\n"
@@ -151,32 +152,30 @@ def run_hard(env):
151
  "WHERE e.salary > d.avg_salary\n"
152
  "ORDER BY e.department_id, e.salary DESC")
153
 
154
- obs = _attempt(env, "correct CTE with proper JOIN",
155
- "WITH dept_avg AS (\n"
156
- " SELECT department_id, AVG(salary) AS avg_salary\n"
157
- " FROM employees\n"
158
- " GROUP BY department_id\n"
159
- ")\n"
160
- "SELECT e.name, e.department_id, e.salary\n"
161
- "FROM employees e\n"
162
- "JOIN dept_avg d ON e.department_id = d.department_id\n"
163
- "WHERE e.salary > d.avg_salary\n"
164
- "ORDER BY e.department_id, e.salary DESC")
165
-
166
- print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
167
 
 
168
 
169
- # ── Custom task demo ──────────────────────────────────────────────────────────
170
 
171
- def run_custom(env):
172
  _section("TASK 4 Β· CUSTOM β€” NULL Handling in Aggregation")
173
 
174
- # Register a brand-new task at runtime
175
- custom_task = task_from_dict({
176
- "id": "custom_null_avg",
177
- "level": "custom",
178
- "title": "Handle NULLs in Aggregation",
179
- "description": """\
180
  TASK: The query below skips NULL scores, making the class average look higher.
181
  Fix it so NULL scores are treated as 0.
182
 
@@ -190,8 +189,8 @@ ERROR:
190
  NULL values are silently excluded by AVG(), inflating the result.
191
 
192
  GOAL: Return a single row with avg_score that treats NULL as 0.
193
- Expected result: avg_score = 72.5""",
194
- "schema_ddl": """\
195
  CREATE TABLE students (id INTEGER, name VARCHAR, score INTEGER);
196
  INSERT INTO students VALUES
197
  (1, 'Alice', 90),
@@ -201,31 +200,30 @@ INSERT INTO students VALUES
201
  (5, 'Eve', 70),
202
  (6, 'Frank', 50);
203
  """,
204
- "broken_query": "SELECT AVG(score) AS avg_score FROM students",
205
- "error_message": "NULL scores are silently skipped by AVG().",
206
- "hint": "Wrap score with COALESCE(score, 0) before averaging.",
207
- "expected_rows": [{"avg_score": 65.0}],
208
- "solution_query": "SELECT AVG(COALESCE(score, 0)) AS avg_score FROM students",
209
- "test_description": "AVG treats NULL as 0 β†’ 65.0",
210
- "max_steps": 4,
211
- })
212
- REGISTRY.register(custom_task)
213
-
214
- obs = env.reset(task_id="custom_null_avg")
215
  print(f"\n Task : {obs.task_title} [{obs.task_level}]")
216
- print(f" Steps: up to {custom_task.max_steps}")
217
- _print_obs(obs, show_description=True)
218
 
219
- _attempt(env, "broken (NULL excluded)",
220
  "SELECT AVG(score) AS avg_score FROM students")
221
 
222
- obs = _attempt(env, "correct (COALESCE)",
223
- "SELECT AVG(COALESCE(score, 0)) AS avg_score FROM students")
224
 
225
- print(f"\n Episode done: {obs.done} | Best score: {obs.best_score:.2f}")
226
 
227
- # Clean up: remove custom task from registry
228
- REGISTRY.unregister("custom_null_avg")
229
  print(" Custom task unregistered from registry.")
230
 
231
 
@@ -235,15 +233,16 @@ if __name__ == "__main__":
235
  ai_key = os.environ.get("ANTHROPIC_API_KEY")
236
 
237
  _hr("═")
238
- print(" QueryForge β€” Local Playbook")
 
239
  print(f" AI judge : {'LIVE (ANTHROPIC_API_KEY set)' if ai_key else 'OFFLINE (fallback to deterministic, max 0.80)'}")
240
  _hr("═")
241
 
242
- # Create a fresh env for each task so cycling order never matters
243
- run_easy(QueryforgeEnvironment())
244
- run_medium(QueryforgeEnvironment())
245
- run_hard(QueryforgeEnvironment())
246
- run_custom(QueryforgeEnvironment())
247
 
248
  _section("DONE")
249
  print(" All 4 tasks completed.\n")
 
1
  """
2
+ QueryForge Client Playbook
3
+ ──────────────────────────
4
+ Tests the environment through the HTTP server using the QueryforgeEnv client.
5
 
6
+ Requires the server to be running first:
7
+ uvicorn server.app:app --host 0.0.0.0 --port 8000
8
+
9
+ Then run:
10
+ python playbook.py
11
 
12
  If ANTHROPIC_API_KEY is set, Stage 4 AI scoring is live.
13
  If not set, the judge falls back to deterministic scoring (capped at 0.80).
 
17
  import sys
18
  import textwrap
19
 
 
20
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
21
 
22
+ from client import QueryforgeEnv
23
+ from models import SQLAction, TaskSpec
24
  from tasks import REGISTRY, task_from_dict
25
 
26
+ BASE_URL = "https://prithvigg-queryforge.hf.space"
27
+
28
  # ── Formatting helpers ────────────────────────────────────────────────────────
29
 
30
  def _hr(char="═", width=70):
 
41
  bar = "β–ˆ" * filled + "β–‘" * (width - filled)
42
  return f"[{bar}] {score:.2f}"
43
 
44
+ def _print_result(result, show_description=False):
45
+ obs = result.observation
46
+ if show_description and obs.task_description:
47
  print()
48
  print(textwrap.indent(obs.task_description, " "))
49
  print()
 
53
  if obs.execution_error:
54
  print(f" Execution error : {obs.execution_error[:100]}")
55
  print(f" Rows returned : {obs.rows_returned}")
56
+ print(f" Score : {_score_bar(result.reward or 0.0)}")
57
  print(f" Best this ep. : {_score_bar(obs.best_score)}")
 
58
  fb = obs.feedback[:250] + ("…" if len(obs.feedback) > 250 else "")
59
  print(f" Feedback : {fb}")
60
  if obs.hint:
61
  print(f" Hint : {obs.hint[:120]}")
62
 
63
+ def _attempt(client, label: str, sql: str):
64
  print(f"\n ── Attempt: {label}")
65
  print(f" SQL: {sql[:100]}{'…' if len(sql) > 100 else ''}")
66
+ result = client.step(SQLAction(sql=sql))
67
+ _print_result(result)
68
+ return result
69
 
70
 
71
  # ── Task runners ──────────────────────────────────────────────────────────────
72
 
73
+ def run_easy(client):
74
  _section("TASK 1 Β· EASY β€” Fix Syntax Errors")
75
+ result = client.reset(task_id="task_easy_syntax")
76
+ obs = result.observation
77
  print(f"\n Task : {obs.task_title} [{obs.task_level}]")
78
+ _print_result(result, show_description=True)
 
79
 
80
+ _attempt(client, "still broken",
81
  "SELEC name, age FORM users WEHRE age > 30")
82
 
83
+ _attempt(client, "one keyword fixed",
84
  "SELECT name, age FORM users WEHRE age > 30")
85
 
86
+ _attempt(client, "all keywords fixed, no filter",
87
  "SELECT name, age FROM users WHERE age > 30")
88
 
89
+ result = _attempt(client, "correct solution",
90
+ "SELECT name, age FROM users "
91
+ "WHERE age > 30 AND city = 'New York' "
92
+ "ORDER BY name ASC")
93
 
94
+ print(f"\n Episode done: {result.done} | Best score: {result.observation.best_score:.2f}")
95
 
96
 
97
+ def run_medium(client):
98
  _section("TASK 2 Β· MEDIUM β€” Fix the Cartesian JOIN")
99
+ result = client.reset(task_id="task_medium_join")
100
+ obs = result.observation
101
  print(f"\n Task : {obs.task_title} [{obs.task_level}]")
102
+ _print_result(result, show_description=True)
 
103
 
104
+ _attempt(client, "broken verbatim (cartesian product)",
105
  "SELECT u.name, p.title, SUM(o.amount) AS total_spent "
106
  "FROM orders o, users u, products p "
107
  "WHERE o.user_id = u.id "
108
  "GROUP BY u.name, p.title "
109
  "ORDER BY total_spent DESC")
110
 
111
+ _attempt(client, "comma-join with product condition (no explicit JOIN)",
112
  "SELECT u.name, p.title, SUM(o.amount) AS total_spent "
113
  "FROM orders o, users u, products p "
114
  "WHERE o.user_id = u.id AND o.product_id = p.id "
115
  "GROUP BY u.name, p.title "
116
  "ORDER BY total_spent DESC")
117
 
118
+ result = _attempt(client, "correct INNER JOINs",
119
+ "SELECT u.name, p.title, SUM(o.amount) AS total_spent\n"
120
+ "FROM orders o\n"
121
+ "INNER JOIN users u ON o.user_id = u.id\n"
122
+ "INNER JOIN products p ON o.product_id = p.id\n"
123
+ "GROUP BY u.name, p.title\n"
124
+ "ORDER BY total_spent DESC")
125
 
126
+ print(f"\n Episode done: {result.done} | Best score: {result.observation.best_score:.2f}")
127
 
128
 
129
+ def run_hard(client):
130
  _section("TASK 3 Β· HARD β€” Rewrite Correlated Subquery as CTE")
131
+ result = client.reset(task_id="task_hard_cte")
132
+ obs = result.observation
133
  print(f"\n Task : {obs.task_title} [{obs.task_level}]")
134
+ _print_result(result, show_description=True)
 
135
 
136
+ _attempt(client, "broken verbatim (no CTE)",
137
  "SELECT e.name, e.department_id, e.salary\n"
138
  "FROM employees e\n"
139
  "WHERE e.salary > (\n"
 
142
  ")\n"
143
  "ORDER BY e.department_id, e.salary DESC")
144
 
145
+ _attempt(client, "halfway β€” CTE defined but wrong join",
146
  "WITH dept_avg AS (\n"
147
  " SELECT department_id, AVG(salary) AS avg_salary\n"
148
  " FROM employees GROUP BY department_id\n"
 
152
  "WHERE e.salary > d.avg_salary\n"
153
  "ORDER BY e.department_id, e.salary DESC")
154
 
155
+ result = _attempt(client, "correct CTE with proper JOIN",
156
+ "WITH dept_avg AS (\n"
157
+ " SELECT department_id, AVG(salary) AS avg_salary\n"
158
+ " FROM employees\n"
159
+ " GROUP BY department_id\n"
160
+ ")\n"
161
+ "SELECT e.name, e.department_id, e.salary\n"
162
+ "FROM employees e\n"
163
+ "JOIN dept_avg d ON e.department_id = d.department_id\n"
164
+ "WHERE e.salary > d.avg_salary\n"
165
+ "ORDER BY e.department_id, e.salary DESC")
 
 
166
 
167
+ print(f"\n Episode done: {result.done} | Best score: {result.observation.best_score:.2f}")
168
 
 
169
 
170
+ def run_custom(client):
171
  _section("TASK 4 Β· CUSTOM β€” NULL Handling in Aggregation")
172
 
173
+ # Register a brand-new task at runtime via the REST API
174
+ client.register_task(TaskSpec(
175
+ id="custom_null_avg",
176
+ level="custom",
177
+ title="Handle NULLs in Aggregation",
178
+ description="""\
179
  TASK: The query below skips NULL scores, making the class average look higher.
180
  Fix it so NULL scores are treated as 0.
181
 
 
189
  NULL values are silently excluded by AVG(), inflating the result.
190
 
191
  GOAL: Return a single row with avg_score that treats NULL as 0.
192
+ Expected result: avg_score = 65.0""",
193
+ schema_ddl="""\
194
  CREATE TABLE students (id INTEGER, name VARCHAR, score INTEGER);
195
  INSERT INTO students VALUES
196
  (1, 'Alice', 90),
 
200
  (5, 'Eve', 70),
201
  (6, 'Frank', 50);
202
  """,
203
+ broken_query="SELECT AVG(score) AS avg_score FROM students",
204
+ error_message="NULL scores are silently skipped by AVG().",
205
+ hint="Wrap score with COALESCE(score, 0) before averaging.",
206
+ expected_rows=[{"avg_score": 65.0}],
207
+ solution_query="SELECT AVG(COALESCE(score, 0)) AS avg_score FROM students",
208
+ test_description="AVG treats NULL as 0 β†’ 65.0",
209
+ max_steps=4,
210
+ ))
211
+
212
+ result = client.reset(task_id="custom_null_avg")
213
+ obs = result.observation
214
  print(f"\n Task : {obs.task_title} [{obs.task_level}]")
215
+ _print_result(result, show_description=True)
 
216
 
217
+ _attempt(client, "broken (NULL excluded)",
218
  "SELECT AVG(score) AS avg_score FROM students")
219
 
220
+ result = _attempt(client, "correct (COALESCE)",
221
+ "SELECT AVG(COALESCE(score, 0)) AS avg_score FROM students")
222
 
223
+ print(f"\n Episode done: {result.done} | Best score: {result.observation.best_score:.2f}")
224
 
225
+ # Clean up
226
+ client.delete_task("custom_null_avg")
227
  print(" Custom task unregistered from registry.")
228
 
229
 
 
233
  ai_key = os.environ.get("ANTHROPIC_API_KEY")
234
 
235
  _hr("═")
236
+ print(" QueryForge β€” Client Playbook")
237
+ print(f" Server : {BASE_URL}")
238
  print(f" AI judge : {'LIVE (ANTHROPIC_API_KEY set)' if ai_key else 'OFFLINE (fallback to deterministic, max 0.80)'}")
239
  _hr("═")
240
 
241
+ with QueryforgeEnv(base_url=BASE_URL).sync() as client:
242
+ # run_easy(client)
243
+ run_medium(client)
244
+ run_hard(client)
245
+ # run_custom(client)
246
 
247
  _section("DONE")
248
  print(" All 4 tasks completed.\n")
server/queryforge_environment.py CHANGED
@@ -20,6 +20,8 @@ Episode ends when:
20
  - max_steps for the task is exhausted
21
  """
22
 
 
 
23
  from typing import Optional
24
  from uuid import uuid4
25
 
@@ -35,6 +37,16 @@ except ImportError:
35
  from tasks import REGISTRY, SQLTask
36
  from judge import grade
37
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  class QueryforgeEnvironment(Environment):
40
  """
@@ -97,6 +109,12 @@ class QueryforgeEnvironment(Environment):
97
  self._attempt = 0
98
  self._stale_steps = 0
99
 
 
 
 
 
 
 
100
  if task_id is not None:
101
  try:
102
  self._current_task = REGISTRY.get(task_id)
@@ -142,6 +160,12 @@ class QueryforgeEnvironment(Environment):
142
  reward=0.0,
143
  )
144
 
 
 
 
 
 
 
145
  score, feedback, details = grade(self._current_task, action.sql)
146
 
147
  # Fix 1 β€” early stopping: track consecutive steps with no improvement
 
20
  - max_steps for the task is exhausted
21
  """
22
 
23
+ import logging
24
+ import os
25
  from typing import Optional
26
  from uuid import uuid4
27
 
 
37
  from tasks import REGISTRY, SQLTask
38
  from judge import grade
39
 
40
+ logger = logging.getLogger(__name__)
41
+ _AI_JUDGE_ACTIVE = bool(os.environ.get("ANTHROPIC_API_KEY"))
42
+
43
+ print("here", os.environ.get("ANTHROPIC_API_KEY"))
44
+ logger.info(
45
+ "QueryForge environment loaded | AI judge: %s | done_threshold: %s",
46
+ "ACTIVE (scores up to 1.0)" if _AI_JUDGE_ACTIVE else "OFFLINE β€” deterministic only (max score 0.80)",
47
+ "0.90" if _AI_JUDGE_ACTIVE else "0.80",
48
+ )
49
+
50
 
51
  class QueryforgeEnvironment(Environment):
52
  """
 
109
  self._attempt = 0
110
  self._stale_steps = 0
111
 
112
+ logger.info(
113
+ "reset() | task_id=%s | AI judge: %s",
114
+ task_id or "round-robin",
115
+ "ACTIVE" if _AI_JUDGE_ACTIVE else "OFFLINE",
116
+ )
117
+
118
  if task_id is not None:
119
  try:
120
  self._current_task = REGISTRY.get(task_id)
 
160
  reward=0.0,
161
  )
162
 
163
+ logger.info(
164
+ "step() | task=%s | attempt=%d | AI judge: %s",
165
+ self._current_task.id,
166
+ self._attempt,
167
+ "ACTIVE" if _AI_JUDGE_ACTIVE else "OFFLINE",
168
+ )
169
  score, feedback, details = grade(self._current_task, action.sql)
170
 
171
  # Fix 1 β€” early stopping: track consecutive steps with no improvement
tasks.py CHANGED
@@ -263,6 +263,283 @@ ORDER BY e.department_id, e.salary DESC""",
263
  )
264
 
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  # ── Task Registry ─────────────────────────────────────────────────────────────
267
 
268
  class TaskRegistry:
@@ -273,9 +550,10 @@ class TaskRegistry:
273
  Custom tasks can be added via register(), load_from_json(), or POST /tasks.
274
  """
275
 
276
- _BUILTIN_IDS: frozenset = frozenset(
277
- ["task_easy_syntax", "task_medium_join", "task_hard_cte"]
278
- )
 
279
 
280
  def __init__(self, initial_tasks: List[SQLTask]) -> None:
281
  self._lock = Lock()
@@ -408,8 +686,14 @@ def task_from_dict(d: Dict[str, Any]) -> SQLTask:
408
 
409
  # ── Global singleton ──────────────────────────────────────────────────────────
410
 
411
- REGISTRY = TaskRegistry([_TASK_EASY, _TASK_MEDIUM, _TASK_HARD])
 
 
 
412
 
413
- # Backwards-compat: snapshot of the three built-in tasks at import time
414
- TASKS: List[SQLTask] = [_TASK_EASY, _TASK_MEDIUM, _TASK_HARD]
 
 
 
415
  TASK_BY_ID: Dict[str, SQLTask] = {t.id: t for t in TASKS}
 
263
  )
264
 
265
 
266
# ── Expert tasks ──────────────────────────────────────────────────────────────

# Expert task 1: ROW_NUMBER() silently drops rows tied for first place;
# the intended fix is RANK() (or DENSE_RANK()), which keeps all tied rows.
# NOTE(review): string interiors reproduced flush-left as rendered — confirm
# original in-file indentation of the SQL snippets.
_TASK_EXPERT_RANK = SQLTask(
    id="task_expert_rank",
    level="expert",
    title="Fix the Tie-Breaking Window Function",
    # Prompt shown to the agent: schema, broken query, diagnosis, and goal.
    description="""\
TASK: The query below finds the top-earning sales rep per region, but it
silently drops reps who are tied for first place. Fix it so ALL reps
tied at rank 1 are returned.

SCHEMA:
sales_reps(id INTEGER, name VARCHAR, region VARCHAR, revenue DECIMAL)

BROKEN QUERY:
SELECT name, region, revenue
FROM (
SELECT name, region, revenue,
ROW_NUMBER() OVER (PARTITION BY region ORDER BY revenue DESC) AS rn
FROM sales_reps
) ranked
WHERE rn = 1
ORDER BY region, name

PROBLEM:
ROW_NUMBER() assigns unique sequential numbers even for tied revenue values.
When two reps share the top revenue in a region, ROW_NUMBER arbitrarily
picks one and discards the other.

GOAL: Return ALL reps whose revenue is the highest in their region.
Use RANK() or DENSE_RANK() instead of ROW_NUMBER().
Order by region ASC, name ASC.""",
    # Fixture data: North has Alice/Carol tied at 95000, South has Dave/Eve
    # tied at 88000 — both ties must survive into the result.
    schema_ddl="""\
CREATE TABLE sales_reps (id INTEGER, name VARCHAR, region VARCHAR, revenue DECIMAL);
INSERT INTO sales_reps VALUES
(1, 'Alice', 'North', 95000),
(2, 'Bob', 'North', 87000),
(3, 'Carol', 'North', 95000),
(4, 'Dave', 'South', 88000),
(5, 'Eve', 'South', 88000),
(6, 'Frank', 'South', 75000);
""",
    broken_query="""\
SELECT name, region, revenue
FROM (
SELECT name, region, revenue,
ROW_NUMBER() OVER (PARTITION BY region ORDER BY revenue DESC) AS rn
FROM sales_reps
) ranked
WHERE rn = 1
ORDER BY region, name""",
    error_message=(
        "Query runs but returns only 2 rows — one per region. "
        "Tied reps at the top are silently dropped by ROW_NUMBER()."
    ),
    hint="Replace ROW_NUMBER() with RANK() or DENSE_RANK(). Both include all tied rows.",
    test_cases=[
        TestCase(
            description="All reps tied at rank 1 per region",
            # Four rows: both North ties and both South ties.
            expected_rows=[
                {"name": "Alice", "region": "North", "revenue": 95000.0},
                {"name": "Carol", "region": "North", "revenue": 95000.0},
                {"name": "Dave", "region": "South", "revenue": 88000.0},
                {"name": "Eve", "region": "South", "revenue": 88000.0},
            ],
            order_by="region,name",
        )
    ],
    # Reference answer used by the grader.
    solution_query="""\
SELECT name, region, revenue
FROM (
SELECT name, region, revenue,
RANK() OVER (PARTITION BY region ORDER BY revenue DESC) AS rk
FROM sales_reps
) ranked
WHERE rk = 1
ORDER BY region, name""",
    max_steps=6,  # episode ends after 6 attempts
)
345
+
346
+
347
# Expert task 2: a hardcoded two-level CTE expansion must be rewritten as a
# recursive CTE so subordinates at any depth (8 total under id=3) are found.
# NOTE(review): string interiors reproduced flush-left as rendered — confirm
# original in-file indentation of the SQL snippets.
_TASK_EXPERT_RECURSIVE = SQLTask(
    id="task_expert_recursive",
    level="expert",
    title="Traverse Org Chart with Recursive CTE",
    # Prompt shown to the agent: org-chart data, broken query, and goal.
    description="""\
TASK: The query below attempts to find all subordinates of the VP of Engineering
(id=3) using a two-level CTE expansion. It misses employees more than two levels
deep. Rewrite it using a recursive CTE that traverses all levels.

SCHEMA:
employees(id INTEGER, name VARCHAR, manager_id INTEGER)

DATA (partial):
VP Eng (id=3) → Lead A (id=5), Lead B (id=6)
Lead A (id=5) → Dev 1 (id=8), Dev 2 (id=9)
Lead B (id=6) → Dev 3 (id=10), Dev 4 (id=11)
Dev 1 (id=8) → Junior 1 (id=13), Junior 2 (id=14)

BROKEN QUERY:
WITH direct AS (
SELECT id, name, manager_id FROM employees WHERE manager_id = 3
),
level2 AS (
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN direct d ON e.manager_id = d.id
)
SELECT id, name, manager_id FROM direct
UNION ALL
SELECT id, name, manager_id FROM level2
ORDER BY id

PROBLEM:
This hardcoded two-level expansion returns 6 rows but misses Junior 1 (id=13)
and Junior 2 (id=14), who report to Dev 1 — three levels below VP Eng.
Adding a level3 CTE would help for now but still break if the tree grows deeper.

GOAL: Use WITH RECURSIVE to return ALL 8 subordinates of VP Eng (id=3)
at any depth. Return id, name, manager_id columns, ordered by id ASC.""",
    # Fixture data: 14 employees; ids 5,6,8,9,10,11,13,14 sit under VP Eng.
    schema_ddl="""\
CREATE TABLE employees (id INTEGER, name VARCHAR, manager_id INTEGER);
INSERT INTO employees VALUES
(1, 'CEO', NULL),
(2, 'CFO', 1),
(3, 'VP Eng', 1),
(4, 'VP Sales', 1),
(5, 'Lead A', 3),
(6, 'Lead B', 3),
(7, 'Sales Mgr',4),
(8, 'Dev 1', 5),
(9, 'Dev 2', 5),
(10, 'Dev 3', 6),
(11, 'Dev 4', 6),
(12, 'Sales Rep',7),
(13, 'Junior 1', 8),
(14, 'Junior 2', 8);
""",
    broken_query="""\
WITH direct AS (
SELECT id, name, manager_id FROM employees WHERE manager_id = 3
),
level2 AS (
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN direct d ON e.manager_id = d.id
)
SELECT id, name, manager_id FROM direct
UNION ALL
SELECT id, name, manager_id FROM level2
ORDER BY id""",
    error_message=(
        "Query returns only 6 rows — two levels under VP Eng. "
        "Junior 1 (id=13) and Junior 2 (id=14) who report to Dev 1 are missing. "
        "A hardcoded level3 CTE would fix this instance but not scale to deeper trees."
    ),
    hint="Use WITH RECURSIVE. Start from manager_id = 3, then JOIN employees to the CTE itself on manager_id = cte.id.",
    test_cases=[
        TestCase(
            description="All 8 subordinates of VP Eng at any depth",
            # Every employee transitively reporting to id=3, ordered by id.
            expected_rows=[
                {"id": 5, "name": "Lead A", "manager_id": 3},
                {"id": 6, "name": "Lead B", "manager_id": 3},
                {"id": 8, "name": "Dev 1", "manager_id": 5},
                {"id": 9, "name": "Dev 2", "manager_id": 5},
                {"id": 10, "name": "Dev 3", "manager_id": 6},
                {"id": 11, "name": "Dev 4", "manager_id": 6},
                {"id": 13, "name": "Junior 1", "manager_id": 8},
                {"id": 14, "name": "Junior 2", "manager_id": 8},
            ],
            order_by="id",
        )
    ],
    # Reference answer: anchor = direct reports of 3; recursive step joins
    # employees against the CTE to walk down one level per iteration.
    solution_query="""\
WITH RECURSIVE subordinates AS (
SELECT id, name, manager_id
FROM employees
WHERE manager_id = 3
UNION ALL
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN subordinates s ON e.manager_id = s.id
)
SELECT id, name, manager_id
FROM subordinates
ORDER BY id""",
    max_steps=7,  # one extra attempt vs. the other expert tasks
)
454
+
455
+
456
# Expert task 3: both window functions are missing PARTITION BY region.
#
# FIXES vs. the committed version:
#   1. The description claimed West Q1's broken running_total "shows 65000";
#      with SUM(revenue) OVER (ORDER BY region, quarter) the value is 76000
#      (East's full-year 65000 + West Q1's 11000), matching error_message.
#   2. expected_rows had West Q2/Q4 revenue ranks swapped (14000 was listed
#      rank 3 and 13000 rank 2). Within West, 16000>14000>13000>11000, so
#      14000 must rank 2 and 13000 rank 3 — as written, solution_query could
#      never pass its own test case.
_TASK_EXPERT_WINDOW = SQLTask(
    id="task_expert_window",
    level="expert",
    title="Fix Two Broken Window Functions: Running Total and Revenue Rank",
    # Prompt shown to the agent: both OVER clauses need PARTITION BY region,
    # but with different ORDER BY columns.
    description="""\
TASK: The query below computes a cumulative running total and a
within-region revenue rank for each quarter, but BOTH window functions
are broken — neither has a PARTITION BY, so they treat all rows as one
giant partition instead of computing independently per region.

SCHEMA:
quarterly_sales(region VARCHAR, quarter INTEGER, revenue DECIMAL)

BROKEN QUERY:
SELECT region, quarter, revenue,
SUM(revenue) OVER (ORDER BY region, quarter) AS running_total,
RANK() OVER (ORDER BY revenue DESC) AS revenue_rank
FROM quarterly_sales
ORDER BY region, quarter

PROBLEM:
- running_total accumulates across both regions: West's Q1 shows 76000
(continuing from East's Q4) instead of resetting to 11000.
- revenue_rank ranks revenue across ALL regions globally, so East Q4 (20000)
and West Q3 (16000) compete directly instead of being ranked within their
own region.

GOAL: Fix BOTH window functions so they operate independently per region.
- running_total must reset to 0 at the start of each region (ORDER BY quarter).
- revenue_rank must rank revenue within each region (ORDER BY revenue DESC).
Both OVER clauses need PARTITION BY region, but with different ORDER BY columns.
Final output: ORDER BY region ASC, quarter ASC.""",
    # Fixture data: two regions x four quarters.
    schema_ddl="""\
CREATE TABLE quarterly_sales (region VARCHAR, quarter INTEGER, revenue DECIMAL);
INSERT INTO quarterly_sales VALUES
('East', 1, 15000),
('East', 2, 18000),
('East', 3, 12000),
('East', 4, 20000),
('West', 1, 11000),
('West', 2, 14000),
('West', 3, 16000),
('West', 4, 13000);
""",
    broken_query="""\
SELECT region, quarter, revenue,
SUM(revenue) OVER (ORDER BY region, quarter) AS running_total,
RANK() OVER (ORDER BY revenue DESC) AS revenue_rank
FROM quarterly_sales
ORDER BY region, quarter""",
    error_message=(
        "Query runs but both window functions are wrong. "
        "West Q1 running_total shows 76000 (continuing from East) instead of 11000. "
        "revenue_rank is a global ranking across all 8 rows instead of per-region. "
        "Both SUM and RANK are missing PARTITION BY region."
    ),
    hint=(
        "Add PARTITION BY region to BOTH window functions, but with different ORDER BY: "
        "SUM(revenue) OVER (PARTITION BY region ORDER BY quarter) for running total, "
        "RANK() OVER (PARTITION BY region ORDER BY revenue DESC) for within-region rank."
    ),
    test_cases=[
        TestCase(
            description="Per-region running total and within-region revenue rank",
            # Running totals reset per region; ranks computed per region:
            # East 20000>18000>15000>12000, West 16000>14000>13000>11000.
            expected_rows=[
                {"region": "East", "quarter": 1, "revenue": 15000.0, "running_total": 15000.0, "revenue_rank": 3},
                {"region": "East", "quarter": 2, "revenue": 18000.0, "running_total": 33000.0, "revenue_rank": 2},
                {"region": "East", "quarter": 3, "revenue": 12000.0, "running_total": 45000.0, "revenue_rank": 4},
                {"region": "East", "quarter": 4, "revenue": 20000.0, "running_total": 65000.0, "revenue_rank": 1},
                {"region": "West", "quarter": 1, "revenue": 11000.0, "running_total": 11000.0, "revenue_rank": 4},
                {"region": "West", "quarter": 2, "revenue": 14000.0, "running_total": 25000.0, "revenue_rank": 2},
                {"region": "West", "quarter": 3, "revenue": 16000.0, "running_total": 41000.0, "revenue_rank": 1},
                {"region": "West", "quarter": 4, "revenue": 13000.0, "running_total": 54000.0, "revenue_rank": 3},
            ],
            order_by="region,quarter",
        )
    ],
    # Reference answer used by the grader.
    solution_query="""\
SELECT region, quarter, revenue,
SUM(revenue) OVER (PARTITION BY region ORDER BY quarter) AS running_total,
RANK() OVER (PARTITION BY region ORDER BY revenue DESC) AS revenue_rank
FROM quarterly_sales
ORDER BY region, quarter""",
    max_steps=6,  # episode ends after 6 attempts
)
541
+
542
+
543
  # ── Task Registry ─────────────────────────────────────────────────────────────
544
 
545
  class TaskRegistry:
 
550
  Custom tasks can be added via register(), load_from_json(), or POST /tasks.
551
  """
552
 
553
    # IDs of the tasks that ship with this module (3 core + 3 expert).
    # NOTE(review): presumably consulted to protect built-ins from removal or
    # overwrite via the registry API — confirm against the methods that use it.
    _BUILTIN_IDS: frozenset = frozenset([
        "task_easy_syntax", "task_medium_join", "task_hard_cte",
        "task_expert_rank", "task_expert_recursive", "task_expert_window",
    ])
557
 
558
  def __init__(self, initial_tasks: List[SQLTask]) -> None:
559
  self._lock = Lock()
 
686
 
687
  # ── Global singleton ──────────────────────────────────────────────────────────
688
 
689
# Process-wide registry seeded with every built-in task (3 core + 3 expert).
# Prefer REGISTRY.get()/register() for runtime access — it reflects tasks
# added after import.
REGISTRY = TaskRegistry([
    _TASK_EASY, _TASK_MEDIUM, _TASK_HARD,
    _TASK_EXPERT_RANK, _TASK_EXPERT_RECURSIVE, _TASK_EXPERT_WINDOW,
])

# Backwards-compat: snapshot of all built-in tasks at import time
TASKS: List[SQLTask] = [
    _TASK_EASY, _TASK_MEDIUM, _TASK_HARD,
    _TASK_EXPERT_RANK, _TASK_EXPERT_RECURSIVE, _TASK_EXPERT_WINDOW,
]
# Backwards-compat id → task lookup built once from the snapshot above;
# tasks registered later are NOT reflected here (use REGISTRY.get()).
TASK_BY_ID: Dict[str, SQLTask] = {t.id: t for t in TASKS}