Spaces:
Running
Running
"""
Task definitions for the SQL agent benchmark.

Three difficulty tiers, each with 5 questions and a grader function.

Grader contract:
    grader(question, sql, rows, error, attempts) -> float strictly in (0, 1)
    - question: the TaskQuestion being answered
    - sql: the SQL text the agent submitted
    - rows: list[dict] from the executed SQL (may be empty)
    - error: str | None
    - attempts: int (1-indexed count of attempts taken)

All graders clamp their result to [_EPS, 1 - _EPS], so no path can
emit exactly 0.0 or 1.0.
"""
| from __future__ import annotations | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Callable, Optional | |
| from env.database import execute_query | |
# ─── Score clamping (strictly in (0, 1)) ──────────────────────────
| _EPS = 0.05 # margin so :.2f/:.3f formatting never rounds to 0.00 or 1.00 | |
| def _clamp(x: float) -> float: | |
| """Clamp to strictly (0, 1). NaN/None β 0.5.""" | |
| if x is None or x != x: # None or NaN | |
| return 0.5 | |
| return max(_EPS, min(1.0 - _EPS, float(x))) | |
# ─── Task Definitions ─────────────────────────────────────────────
@dataclass
class TaskQuestion:
    """One benchmark question plus the result-shape expectations used to grade it.

    Must be a dataclass: instances are built with keyword arguments and
    ``hint_tables`` relies on ``field(default_factory=list)`` so each
    instance gets its own list.
    """

    id: str  # question identifier, e.g. "sq-01"
    question: str  # natural-language prompt
    expected_columns: list[str]  # at least these columns should appear
    min_rows: int  # minimum expected rows
    max_rows: Optional[int] = None  # None = no upper bound
    # Tables the query is expected to touch; not consulted by the graders here.
    hint_tables: list[str] = field(default_factory=list)
@dataclass
class Task:
    """One difficulty tier: its metadata, its questions, and its grading function."""

    id: str  # registry key, e.g. "simple_queries"
    name: str  # human-readable title
    difficulty: str  # "easy" | "medium" | "hard"
    description: str
    questions: list[TaskQuestion]
    # Called as grader(question, sql, rows, error, attempts) -> float.
    grader: Callable[..., float]
# ─── Grader Helpers ───────────────────────────────────────────────
| def _has_required_columns(rows: list[dict], required: list[str]) -> bool: | |
| if not rows: | |
| return False | |
| row_keys = {k.lower() for k in rows[0].keys()} | |
| return all(col.lower() in row_keys for col in required) | |
| def _row_count_score(rows: list[dict], min_rows: int, max_rows: Optional[int]) -> float: | |
| """Returns a raw score in [0, 1]; graders must clamp before returning.""" | |
| n = len(rows) | |
| if n == 0: | |
| return 0.0 | |
| if n >= min_rows: | |
| if max_rows is None or n <= max_rows: | |
| return 1.0 | |
| return 0.5 | |
| return 0.5 * (n / min_rows) | |
# ─── Task 1: Simple Queries (Easy) ────────────────────────────────
# Easy tier: every question hints at exactly one table.
_SIMPLE_QUESTIONS = [
    TaskQuestion(
        id="sq-01",
        hint_tables=["users"],
        question="List all users from the USA.",
        expected_columns=["name", "email", "country"],
        min_rows=10,
        max_rows=25,
    ),
    TaskQuestion(
        id="sq-02",
        hint_tables=["products"],
        question="Show all products in the 'Electronics' category with their prices.",
        expected_columns=["name", "price"],
        min_rows=8,
        max_rows=20,
    ),
    TaskQuestion(
        id="sq-03",
        hint_tables=["orders"],
        question="Find all orders with status 'delivered'.",
        expected_columns=["id", "status"],
        min_rows=30,
        max_rows=50,
    ),
    TaskQuestion(
        id="sq-04",
        hint_tables=["sellers"],
        question="List all sellers and their countries.",
        expected_columns=["name", "country"],
        # min == max: exactly 10 rows are expected.
        min_rows=10,
        max_rows=10,
    ),
    TaskQuestion(
        id="sq-05",
        hint_tables=["reviews"],
        question="Show all reviews with a rating of 5 stars.",
        expected_columns=["rating"],
        min_rows=15,
        max_rows=35,
    ),
]
def _grade_simple(
    question: TaskQuestion,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade an easy-tier answer on result shape only.

    A query error scores the floor. Full credit needs every expected column
    AND a perfect row-count score; half credit needs the columns OR a
    row-count score of at least 0.5; anything else scores the floor.
    ``sql`` and ``attempts`` are accepted for interface parity but unused.
    The returned value is always clamped.
    """
    if error:
        raw = 0.0
    else:
        columns_match = _has_required_columns(rows, question.expected_columns)
        count_score = _row_count_score(rows, question.min_rows, question.max_rows)
        if columns_match and count_score == 1.0:
            raw = 1.0
        elif columns_match or count_score >= 0.5:
            raw = 0.5
        else:
            raw = 0.0
    return _clamp(raw)
# Easy tier registration: shape-only grading, no retry penalty.
_TASK_SIMPLE = Task(
    id="simple_queries",
    difficulty="easy",
    name="Simple Queries",
    description="Single-table SELECT queries with basic filters.",
    grader=_grade_simple,
    questions=_SIMPLE_QUESTIONS,
)
# ─── Task 2: Join Queries (Medium) ────────────────────────────────
# Medium tier: every question hints at two or more tables to join.
_JOIN_QUESTIONS = [
    TaskQuestion(
        id="jq-01",
        hint_tables=["users", "orders"],
        question="Show the total number of orders per user, including the user's name.",
        expected_columns=["name"],
        min_rows=10,
    ),
    TaskQuestion(
        id="jq-02",
        hint_tables=["products", "sellers"],
        question="List products along with the name of their seller.",
        # Product name and seller name are both called 'name'.
        expected_columns=["name", "name"],
        min_rows=20,
    ),
    TaskQuestion(
        id="jq-03",
        hint_tables=["products", "reviews"],
        question="Find the average rating for each product category.",
        expected_columns=["category"],
        min_rows=5,
        max_rows=10,
    ),
    TaskQuestion(
        id="jq-04",
        hint_tables=["sellers", "products", "orders"],
        question="Show the total revenue (sum of total_price) per seller.",
        expected_columns=["name"],
        min_rows=5,
    ),
    TaskQuestion(
        id="jq-05",
        hint_tables=["products", "reviews"],
        question="List the top 5 most reviewed products with their review counts.",
        expected_columns=["name"],
        # min == max: exactly the top 5 rows are expected.
        min_rows=5,
        max_rows=5,
    ),
]
def _grade_join(
    question: TaskQuestion,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade a medium-tier (JOIN) answer on result shape, penalising retries.

    Base score: 1.0 when every expected column is present and the row-count
    score is perfect; 0.5 when the columns match or the row-count score is
    at least 0.5; otherwise 0.0. Each attempt after the first subtracts 0.1
    (never a bonus) before the result is clamped. ``sql`` is unused.
    """
    if error:
        return _clamp(0.0)
    # Check the whole column list, as the other tiers do. Indexing
    # expected_columns[0] raised IndexError for a question with no expected
    # columns; repeated names (jq-02's two "name" entries) are harmless here.
    col_ok = _has_required_columns(rows, question.expected_columns)
    row_score = _row_count_score(rows, question.min_rows, question.max_rows)
    base = 0.0
    if col_ok and row_score == 1.0:
        base = 1.0
    elif col_ok or row_score >= 0.5:
        base = 0.5
    attempt_penalty = max(0.0, 0.1 * (attempts - 1))
    return _clamp(base - attempt_penalty)
# Medium tier registration: shape-based grading with a per-retry penalty.
_TASK_JOIN = Task(
    id="join_queries",
    difficulty="medium",
    name="Join Queries",
    description="Multi-table JOINs with GROUP BY and aggregation.",
    grader=_grade_join,
    questions=_JOIN_QUESTIONS,
)
# ─── Task 3: Complex Queries (Hard) ───────────────────────────────
# Hard tier: prompts ask for CTEs, window functions, or nested aggregation.
_COMPLEX_QUESTIONS = [
    TaskQuestion(
        id="cq-01",
        hint_tables=["users", "orders"],
        question=(
            "Find users who have placed more than 1 order, showing their name "
            "and total number of orders, ordered by order count descending."
        ),
        expected_columns=["name"],
        min_rows=1,
    ),
    TaskQuestion(
        id="cq-02",
        hint_tables=["products"],
        question=(
            "For each product category, show the category name, number of products, "
            "average price, and total stock. Use a CTE."
        ),
        expected_columns=["category"],
        min_rows=5,
        max_rows=10,
    ),
    TaskQuestion(
        id="cq-03",
        hint_tables=["sellers", "products", "orders"],
        question=(
            "Show each seller's name, their total sales revenue, and rank them "
            "by revenue using a window function (RANK() or ROW_NUMBER())."
        ),
        expected_columns=["name"],
        min_rows=5,
    ),
    TaskQuestion(
        id="cq-04",
        hint_tables=["products", "reviews"],
        question=(
            "Find the top-rated product in each category (highest average review rating). "
            "Show category, product name, and average rating."
        ),
        expected_columns=["category", "name"],
        min_rows=5,
        max_rows=10,
    ),
    TaskQuestion(
        id="cq-05",
        hint_tables=["orders"],
        question=(
            "Calculate the month-over-month order count for 2024, showing year, "
            "month, order_count, and a running total."
        ),
        expected_columns=["month"],
        min_rows=6,
        max_rows=12,
    ),
]
def _grade_complex(
    question: TaskQuestion,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade a hard-tier answer: strict on columns, rewards first-try success.

    - Query error, any missing expected column, or zero rows -> floor score.
    - Perfect row-count score: base 1.0 on the first attempt, 0.8 afterwards.
    - Otherwise (too few or too many rows): base 0.4.
    Each attempt after the first then subtracts 0.1 (never a bonus) before
    the result is clamped. ``sql`` is unused.
    """
    if error:
        return _clamp(0.0)
    col_ok = _has_required_columns(rows, question.expected_columns)
    row_score = _row_count_score(rows, question.min_rows, question.max_rows)
    if not col_ok or row_score == 0.0:
        return _clamp(0.0)
    # Past the guard col_ok is always True, so only the row score matters.
    if row_score == 1.0:
        base = 1.0 if attempts == 1 else 0.8
    else:
        base = 0.4
    # Floor the penalty at zero (as _grade_join does) so attempts < 1 cannot
    # turn it into a bonus.
    attempt_penalty = max(0.0, 0.1 * (attempts - 1))
    return _clamp(base - attempt_penalty)
# Hard tier registration: strict column check plus first-attempt bonus.
_TASK_COMPLEX = Task(
    id="complex_queries",
    difficulty="hard",
    name="Complex Queries",
    description="CTEs, window functions, and nested aggregations.",
    grader=_grade_complex,
    questions=_COMPLEX_QUESTIONS,
)
# ─── Registry ─────────────────────────────────────────────────────
# Registry keyed by each task's own ``id`` so key and id cannot drift apart.
# Insertion order (easy, medium, hard) is preserved.
TASKS: dict[str, Task] = {
    task.id: task for task in (_TASK_SIMPLE, _TASK_JOIN, _TASK_COMPLEX)
}
def get_task(task_id: str) -> Task:
    """Return the registered task with the given id.

    Raises:
        ValueError: if ``task_id`` is not in the registry; the message lists
            the valid ids.
    """
    task = TASKS.get(task_id)
    if task is None:
        raise ValueError(f"Unknown task_id: {task_id!r}. Valid: {list(TASKS)}")
    return task
def get_all_tasks() -> list[Task]:
    """Return a new list of every registered task, in registry order."""
    return [*TASKS.values()]
def grade_response(
    task_id: str,
    question_id: str,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Score one agent answer for the given task and question.

    Looks up the task, finds the first question whose id matches, and
    delegates to the task's grader.

    Raises:
        ValueError: if ``task_id`` or ``question_id`` is unknown.
    """
    task = get_task(task_id)
    for question in task.questions:
        if question.id == question_id:
            break
    else:
        raise ValueError(f"Unknown question_id {question_id!r} in task {task_id!r}")
    raw_score = task.grader(question, sql, rows, error, attempts)
    # The graders clamp on their own; clamping again is a final safety net.
    return _clamp(raw_score)