"""
Task definitions for the SQL agent benchmark.

Three difficulty tiers, each with 5 questions and a grader function.

Grader contract:
    grader(question, sql, rows, error, attempts) -> float strictly in (0, 1)
    - question: the TaskQuestion being graded
    - sql:      str (passed through; not inspected by the graders in this module)
    - rows:     list[dict] from the executed SQL (may be empty)
    - error:    str | None
    - attempts: int (1-indexed count of attempts taken)

All graders return values in the closed interval [_EPS, 1 - _EPS], so no
path can emit exact 0.0 or 1.0.
"""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Callable, Optional

from env.database import execute_query

# ─── Score clamping (strictly in (0, 1)) ──────────────────────────

_EPS = 0.05  # margin so :.2f/:.3f formatting never rounds to 0.00 or 1.00


def _clamp(x: Optional[float]) -> float:
    """Clamp a raw score into [_EPS, 1 - _EPS]. NaN/None → 0.5."""
    if x is None or x != x:  # None or NaN (NaN is the only value != itself)
        return 0.5
    return max(_EPS, min(1.0 - _EPS, float(x)))


# ─── Task Definitions ─────────────────────────────────────────────


@dataclass
class TaskQuestion:
    """A single benchmark question plus the shape its result set should have."""

    id: str
    question: str
    expected_columns: list[str]  # at least these columns should appear
    min_rows: int  # minimum expected rows
    max_rows: Optional[int] = None  # None = no upper bound
    # Tables that must be touched. NOTE: informational only here — none of the
    # graders in this module read this field.
    hint_tables: list[str] = field(default_factory=list)


@dataclass
class Task:
    """A difficulty tier: a group of questions sharing one grader."""

    id: str
    name: str
    difficulty: str  # "easy" | "medium" | "hard"
    description: str
    questions: list[TaskQuestion]
    # grader(question, sql, rows, error, attempts) -> float
    grader: Callable[[TaskQuestion, str, list[dict], Optional[str], int], float]


# ─── Grader Helpers ───────────────────────────────────────────────


def _has_required_columns(rows: list[dict], required: list[str]) -> bool:
    """Return True if the first row contains every required column.

    Column names are compared case-insensitively. Only ``rows[0]`` is
    inspected; an empty result set never satisfies the check.
    """
    if not rows:
        return False
    row_keys = {k.lower() for k in rows[0].keys()}
    return all(col.lower() in row_keys for col in required)


def _row_count_score(
    rows: list[dict], min_rows: int, max_rows: Optional[int]
) -> float:
    """Score the row count against the expected ``[min_rows, max_rows]`` window.

    Returns a raw score in [0, 1]; graders must clamp before returning.
    - 0.0                    no rows at all
    - 1.0                    count within the window
    - 0.5                    at least ``min_rows`` but above ``max_rows``
    - 0.5 * n / min_rows     fewer than ``min_rows`` (always < 0.5)
    """
    n = len(rows)
    if n == 0:
        return 0.0
    if n >= min_rows:
        if max_rows is None or n <= max_rows:
            return 1.0
        return 0.5
    # Here 1 <= n < min_rows, so min_rows > 0 and the division is safe.
    return 0.5 * (n / min_rows)


def _tiered_score(col_ok: bool, row_score: float) -> float:
    """Collapse the column check and row-count score into 1.0 / 0.5 / 0.0."""
    if col_ok and row_score == 1.0:
        return 1.0
    if col_ok or row_score >= 0.5:
        return 0.5
    return 0.0


def _attempt_penalty(attempts: int) -> float:
    """Return 0.1 per attempt beyond the first, never negative.

    The floor at 0 keeps an out-of-contract ``attempts <= 0`` from turning
    the penalty into a bonus.
    """
    return max(0.0, 0.1 * (attempts - 1))


# ─── Task 1: Simple Queries (Easy) ────────────────────────────────

_SIMPLE_QUESTIONS = [
    TaskQuestion(
        id="sq-01",
        question="List all users from the USA.",
        expected_columns=["name", "email", "country"],
        min_rows=10,
        max_rows=25,
        hint_tables=["users"],
    ),
    TaskQuestion(
        id="sq-02",
        question="Show all products in the 'Electronics' category with their prices.",
        expected_columns=["name", "price"],
        min_rows=8,
        max_rows=20,
        hint_tables=["products"],
    ),
    TaskQuestion(
        id="sq-03",
        question="Find all orders with status 'delivered'.",
        expected_columns=["id", "status"],
        min_rows=30,
        max_rows=50,
        hint_tables=["orders"],
    ),
    TaskQuestion(
        id="sq-04",
        question="List all sellers and their countries.",
        expected_columns=["name", "country"],
        min_rows=10,
        max_rows=10,
        hint_tables=["sellers"],
    ),
    TaskQuestion(
        id="sq-05",
        question="Show all reviews with a rating of 5 stars.",
        expected_columns=["rating"],
        min_rows=15,
        max_rows=35,
        hint_tables=["reviews"],
    ),
]


def _grade_simple(
    question: TaskQuestion,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade an easy-tier answer: tiered score, no attempt penalty.

    Any truthy ``error`` yields the minimum score. ``sql`` and ``attempts``
    are accepted for the grader contract but not used.
    """
    if error:
        return _clamp(0.0)
    col_ok = _has_required_columns(rows, question.expected_columns)
    row_score = _row_count_score(rows, question.min_rows, question.max_rows)
    return _clamp(_tiered_score(col_ok, row_score))


_TASK_SIMPLE = Task(
    id="simple_queries",
    name="Simple Queries",
    difficulty="easy",
    description="Single-table SELECT queries with basic filters.",
    questions=_SIMPLE_QUESTIONS,
    grader=_grade_simple,
)


# ─── Task 2: Join Queries (Medium) ────────────────────────────────

_JOIN_QUESTIONS = [
    TaskQuestion(
        id="jq-01",
        question="Show the total number of orders per user, including the user's name.",
        expected_columns=["name"],
        min_rows=10,
        hint_tables=["users", "orders"],
    ),
    TaskQuestion(
        id="jq-02",
        question="List products along with the name of their seller.",
        expected_columns=["name", "name"],  # product name + seller name both called 'name'
        min_rows=20,
        hint_tables=["products", "sellers"],
    ),
    TaskQuestion(
        id="jq-03",
        question="Find the average rating for each product category.",
        expected_columns=["category"],
        min_rows=5,
        max_rows=10,
        hint_tables=["products", "reviews"],
    ),
    TaskQuestion(
        id="jq-04",
        question="Show the total revenue (sum of total_price) per seller.",
        expected_columns=["name"],
        min_rows=5,
        hint_tables=["sellers", "products", "orders"],
    ),
    TaskQuestion(
        id="jq-05",
        question="List the top 5 most reviewed products with their review counts.",
        expected_columns=["name"],
        min_rows=5,
        max_rows=5,
        hint_tables=["products", "reviews"],
    ),
]


def _grade_join(
    question: TaskQuestion,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade a medium-tier answer: tiered score minus 0.1 per extra attempt.

    Only the first expected column is checked. Any truthy ``error`` yields
    the minimum score. ``sql`` is accepted for the grader contract but not
    used.
    """
    if error:
        return _clamp(0.0)
    # Slice rather than index so an empty expected_columns list cannot raise.
    col_ok = _has_required_columns(rows, question.expected_columns[:1])
    row_score = _row_count_score(rows, question.min_rows, question.max_rows)
    base = _tiered_score(col_ok, row_score)
    return _clamp(base - _attempt_penalty(attempts))


_TASK_JOIN = Task(
    id="join_queries",
    name="Join Queries",
    difficulty="medium",
    description="Multi-table JOINs with GROUP BY and aggregation.",
    questions=_JOIN_QUESTIONS,
    grader=_grade_join,
)


# ─── Task 3: Complex Queries (Hard) ───────────────────────────────

_COMPLEX_QUESTIONS = [
    TaskQuestion(
        id="cq-01",
        question=(
            "Find users who have placed more than 1 order, showing their name "
            "and total number of orders, ordered by order count descending."
        ),
        expected_columns=["name"],
        min_rows=1,
        hint_tables=["users", "orders"],
    ),
    TaskQuestion(
        id="cq-02",
        question=(
            "For each product category, show the category name, number of products, "
            "average price, and total stock. Use a CTE."
        ),
        expected_columns=["category"],
        min_rows=5,
        max_rows=10,
        hint_tables=["products"],
    ),
    TaskQuestion(
        id="cq-03",
        question=(
            "Show each seller's name, their total sales revenue, and rank them "
            "by revenue using a window function (RANK() or ROW_NUMBER())."
        ),
        expected_columns=["name"],
        min_rows=5,
        hint_tables=["sellers", "products", "orders"],
    ),
    TaskQuestion(
        id="cq-04",
        question=(
            "Find the top-rated product in each category (highest average review rating). "
            "Show category, product name, and average rating."
        ),
        expected_columns=["category", "name"],
        min_rows=5,
        max_rows=10,
        hint_tables=["products", "reviews"],
    ),
    TaskQuestion(
        id="cq-05",
        question=(
            "Calculate the month-over-month order count for 2024, showing year, "
            "month, order_count, and a running total."
        ),
        expected_columns=["month"],
        min_rows=6,
        max_rows=12,
        hint_tables=["orders"],
    ),
]


def _grade_complex(
    question: TaskQuestion,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade a hard-tier answer.

    Missing columns, no rows, or a truthy ``error`` yield the minimum score.
    Otherwise the base is 0.8 for a row count inside the expected window
    (plus a 0.2 first-attempt bonus) or 0.4 outside it, minus 0.1 per
    attempt beyond the first. ``sql`` is accepted for the grader contract
    but not used.
    """
    if error:
        return _clamp(0.0)
    col_ok = _has_required_columns(rows, question.expected_columns)
    row_score = _row_count_score(rows, question.min_rows, question.max_rows)
    if not col_ok or row_score == 0.0:
        return _clamp(0.0)
    # col_ok is guaranteed True from here on.
    if row_score == 1.0:
        base = 0.8 + (0.2 if attempts == 1 else 0.0)
    else:
        base = 0.4
    return _clamp(base - _attempt_penalty(attempts))


_TASK_COMPLEX = Task(
    id="complex_queries",
    name="Complex Queries",
    difficulty="hard",
    description="CTEs, window functions, and nested aggregations.",
    questions=_COMPLEX_QUESTIONS,
    grader=_grade_complex,
)


# ─── Registry ─────────────────────────────────────────────────────

TASKS: dict[str, Task] = {
    "simple_queries": _TASK_SIMPLE,
    "join_queries": _TASK_JOIN,
    "complex_queries": _TASK_COMPLEX,
}


def get_task(task_id: str) -> Task:
    """Return the task registered under ``task_id``.

    Raises:
        ValueError: if ``task_id`` is not a key of ``TASKS``.
    """
    if task_id not in TASKS:
        raise ValueError(f"Unknown task_id: {task_id!r}. Valid: {list(TASKS)}")
    return TASKS[task_id]


def get_all_tasks() -> list[Task]:
    """Return all registered tasks in registry order."""
    return list(TASKS.values())


def grade_response(
    task_id: str,
    question_id: str,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade one response with the grader of the task that owns the question.

    Raises:
        ValueError: if ``task_id`` is unknown, or ``question_id`` does not
            belong to that task.
    """
    task = get_task(task_id)
    question = next((q for q in task.questions if q.id == question_id), None)
    if question is None:
        raise ValueError(f"Unknown question_id {question_id!r} in task {task_id!r}")
    # Graders already clamp internally; this is a final safety net.
    return _clamp(task.grader(question, sql, rows, error, attempts))