ar9avg's picture
Clamp every grader return value strictly inside (0, 1)
98b87b7
"""
Task definitions for the SQL agent benchmark.
Three difficulty tiers, each with 5 questions and a grader function.
Grader contract: grader(question, sql, rows, error, attempts) -> float strictly in (0, 1)
- question: the TaskQuestion being graded
- rows: list[dict] from the executed SQL (may be empty)
- error: str | None
- attempts: int (1-indexed count of attempts taken)
All graders return values in the closed range [_EPS, 1 - _EPS], so no path
can emit exact 0.0 or 1.0.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Callable, Optional
from env.database import execute_query
# ─── Score clamping (strictly in (0, 1)) ──────────────────────────
_EPS = 0.05 # margin so :.2f/:.3f formatting never rounds to 0.00 or 1.00
def _clamp(x: float) -> float:
"""Clamp to strictly (0, 1). NaN/None β†’ 0.5."""
if x is None or x != x: # None or NaN
return 0.5
return max(_EPS, min(1.0 - _EPS, float(x)))
# ─── Task Definitions ─────────────────────────────────────────────
@dataclass
class TaskQuestion:
    """One benchmark question plus the result-shape expectations used to grade it."""

    id: str  # identifier that grade_response matches against question_id
    question: str  # question text
    expected_columns: list[str]  # at least these columns should appear
    min_rows: int  # minimum expected rows
    max_rows: Optional[int] = None  # None = no upper bound
    hint_tables: list[str] = field(default_factory=list)  # tables that must be touched
@dataclass
class Task:
    """A named group of questions that share a single grader callable."""

    id: str  # task identifier, e.g. "simple_queries"
    name: str  # display name
    difficulty: str  # "easy" | "medium" | "hard"
    description: str  # one-line summary of the kind of SQL the task exercises
    questions: list[TaskQuestion]
    grader: Callable[..., float]  # grader(question, sql, rows, error, attempts) -> float
# ─── Grader Helpers ───────────────────────────────────────────────
def _has_required_columns(rows: list[dict], required: list[str]) -> bool:
if not rows:
return False
row_keys = {k.lower() for k in rows[0].keys()}
return all(col.lower() in row_keys for col in required)
def _row_count_score(rows: list[dict], min_rows: int, max_rows: Optional[int]) -> float:
"""Returns a raw score in [0, 1]; graders must clamp before returning."""
n = len(rows)
if n == 0:
return 0.0
if n >= min_rows:
if max_rows is None or n <= max_rows:
return 1.0
return 0.5
return 0.5 * (n / min_rows)
# ─── Task 1: Simple Queries (Easy) ────────────────────────────────
# Easy-tier questions. Each entry pairs the prompt with the expected result
# shape: the columns that must appear and the accepted row-count window.
# hint_tables is informational here; the graders shown in this file do not read it.
_SIMPLE_QUESTIONS = [
    TaskQuestion(
        id="sq-01",
        question="List all users from the USA.",
        expected_columns=["name", "email", "country"],
        min_rows=10,
        max_rows=25,
        hint_tables=["users"],
    ),
    TaskQuestion(
        id="sq-02",
        question="Show all products in the 'Electronics' category with their prices.",
        expected_columns=["name", "price"],
        min_rows=8,
        max_rows=20,
        hint_tables=["products"],
    ),
    TaskQuestion(
        id="sq-03",
        question="Find all orders with status 'delivered'.",
        expected_columns=["id", "status"],
        min_rows=30,
        max_rows=50,
        hint_tables=["orders"],
    ),
    TaskQuestion(
        id="sq-04",
        question="List all sellers and their countries.",
        expected_columns=["name", "country"],
        # min_rows == max_rows: exactly 10 rows earn the full row-count score.
        min_rows=10,
        max_rows=10,
        hint_tables=["sellers"],
    ),
    TaskQuestion(
        id="sq-05",
        question="Show all reviews with a rating of 5 stars.",
        expected_columns=["rating"],
        min_rows=15,
        max_rows=35,
        hint_tables=["reviews"],
    ),
]
def _grade_simple(
    question: TaskQuestion,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade an easy-tier answer on result shape only.

    Raw score before clamping:
      * 1.0 - all expected columns present and row count within bounds
      * 0.5 - either the columns are present or the row-count score is >= 0.5
      * 0.0 - execution error, or neither condition holds

    ``sql`` and ``attempts`` are accepted for the common grader signature but
    do not influence the score.
    """
    if error:
        return _clamp(0.0)

    columns_present = _has_required_columns(rows, question.expected_columns)
    count_score = _row_count_score(rows, question.min_rows, question.max_rows)

    if columns_present and count_score == 1.0:
        raw = 1.0
    elif columns_present or count_score >= 0.5:
        raw = 0.5
    else:
        raw = 0.0
    return _clamp(raw)
# Easy tier: the simple questions scored by _grade_simple (no attempt penalty).
_TASK_SIMPLE = Task(
    id="simple_queries",
    name="Simple Queries",
    difficulty="easy",
    description="Single-table SELECT queries with basic filters.",
    questions=_SIMPLE_QUESTIONS,
    grader=_grade_simple,
)
# ─── Task 2: Join Queries (Medium) ────────────────────────────────
# Medium-tier questions. Entries without max_rows have no upper bound on the
# row count. Only the first expected column is checked by _grade_join.
_JOIN_QUESTIONS = [
    TaskQuestion(
        id="jq-01",
        question="Show the total number of orders per user, including the user's name.",
        expected_columns=["name"],
        min_rows=10,
        hint_tables=["users", "orders"],
    ),
    TaskQuestion(
        id="jq-02",
        question="List products along with the name of their seller.",
        # NOTE(review): the duplicate entry has no effect on scoring, because
        # _grade_join only looks at expected_columns[0].
        expected_columns=["name", "name"],  # product name + seller name both called 'name'
        min_rows=20,
        hint_tables=["products", "sellers"],
    ),
    TaskQuestion(
        id="jq-03",
        question="Find the average rating for each product category.",
        expected_columns=["category"],
        min_rows=5,
        max_rows=10,
        hint_tables=["products", "reviews"],
    ),
    TaskQuestion(
        id="jq-04",
        question="Show the total revenue (sum of total_price) per seller.",
        expected_columns=["name"],
        min_rows=5,
        hint_tables=["sellers", "products", "orders"],
    ),
    TaskQuestion(
        id="jq-05",
        question="List the top 5 most reviewed products with their review counts.",
        expected_columns=["name"],
        # Exactly 5 rows earn the full row-count score.
        min_rows=5,
        max_rows=5,
        hint_tables=["products", "reviews"],
    ),
]
def _grade_join(
    question: TaskQuestion,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade a medium-tier answer: result shape minus a retry penalty.

    Only the first expected column is checked. The shape score is 1.0 when
    that column is present and the row count is within bounds, 0.5 when only
    one of the two looks acceptable, else 0.0. Each attempt after the first
    costs 0.1; the penalty is never negative. ``sql`` is unused.
    """
    if error:
        return _clamp(0.0)

    first_column = question.expected_columns[0]
    column_present = _has_required_columns(rows, [first_column])
    count_score = _row_count_score(rows, question.min_rows, question.max_rows)

    if column_present and count_score == 1.0:
        shape_score = 1.0
    elif column_present or count_score >= 0.5:
        shape_score = 0.5
    else:
        shape_score = 0.0

    retry_cost = max(0.0, 0.1 * (attempts - 1))
    return _clamp(shape_score - retry_cost)
# Medium tier: the join questions scored by _grade_join (0.1 penalty per retry).
_TASK_JOIN = Task(
    id="join_queries",
    name="Join Queries",
    difficulty="medium",
    description="Multi-table JOINs with GROUP BY and aggregation.",
    questions=_JOIN_QUESTIONS,
    grader=_grade_join,
)
# ─── Task 3: Complex Queries (Hard) ───────────────────────────────
# Hard-tier questions. The prompts ask for CTEs / window functions, but the
# grader shown in this file only checks the result shape (columns and row
# count), not the SQL text.
_COMPLEX_QUESTIONS = [
    TaskQuestion(
        id="cq-01",
        question=(
            "Find users who have placed more than 1 order, showing their name "
            "and total number of orders, ordered by order count descending."
        ),
        expected_columns=["name"],
        min_rows=1,
        hint_tables=["users", "orders"],
    ),
    TaskQuestion(
        id="cq-02",
        question=(
            "For each product category, show the category name, number of products, "
            "average price, and total stock. Use a CTE."
        ),
        expected_columns=["category"],
        min_rows=5,
        max_rows=10,
        hint_tables=["products"],
    ),
    TaskQuestion(
        id="cq-03",
        question=(
            "Show each seller's name, their total sales revenue, and rank them "
            "by revenue using a window function (RANK() or ROW_NUMBER())."
        ),
        expected_columns=["name"],
        min_rows=5,
        hint_tables=["sellers", "products", "orders"],
    ),
    TaskQuestion(
        id="cq-04",
        question=(
            "Find the top-rated product in each category (highest average review rating). "
            "Show category, product name, and average rating."
        ),
        expected_columns=["category", "name"],
        min_rows=5,
        max_rows=10,
        hint_tables=["products", "reviews"],
    ),
    TaskQuestion(
        id="cq-05",
        question=(
            "Calculate the month-over-month order count for 2024, showing year, "
            "month, order_count, and a running total."
        ),
        expected_columns=["month"],
        min_rows=6,
        max_rows=12,
        hint_tables=["orders"],
    ),
]
def _grade_complex(
    question: TaskQuestion,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Grade a hard-tier answer: strict shape check, first-try bonus, retry penalty.

    Raw score before clamping:
      * execution error, a missing expected column, or an empty result -> 0.0
      * columns present and row count within bounds -> 0.8, plus 0.2 when
        solved on the first attempt
      * columns present but row count out of bounds -> 0.4
      * minus 0.1 for every attempt after the first

    Args:
        question: Question whose expected columns and row bounds are checked.
        sql: Submitted SQL text (not inspected by this grader).
        rows: Result rows of the executed SQL.
        error: Execution error message, or ``None`` on success.
        attempts: 1-indexed number of attempts taken.

    Returns:
        The score after ``_clamp``.
    """
    if error:
        return _clamp(0.0)

    col_ok = _has_required_columns(rows, question.expected_columns)
    row_score = _row_count_score(rows, question.min_rows, question.max_rows)
    if not col_ok or row_score == 0.0:
        return _clamp(0.0)

    # col_ok is guaranteed True past the guard above, so only the row-count
    # score decides between full and partial credit.
    if row_score == 1.0:
        base = 0.8 + (0.2 if attempts == 1 else 0.0)
    else:
        base = 0.4

    # Floor the penalty at zero (as _grade_join does) so an out-of-contract
    # attempts < 1 cannot turn the penalty into a bonus.
    attempt_penalty = max(0.0, 0.1 * (attempts - 1))
    return _clamp(base - attempt_penalty)
# Hard tier: the complex questions scored by _grade_complex
# (first-attempt bonus, 0.1 penalty per retry).
_TASK_COMPLEX = Task(
    id="complex_queries",
    name="Complex Queries",
    difficulty="hard",
    description="CTEs, window functions, and nested aggregations.",
    questions=_COMPLEX_QUESTIONS,
    grader=_grade_complex,
)
# ─── Registry ─────────────────────────────────────────────────────
# Lookup table from task id to Task, in easy -> medium -> hard order.
# Keys are taken from each task's own ``id`` so the two cannot drift apart.
TASKS: dict[str, Task] = {
    task.id: task for task in (_TASK_SIMPLE, _TASK_JOIN, _TASK_COMPLEX)
}
def get_task(task_id: str) -> Task:
    """Return the registered task for ``task_id``.

    Raises:
        ValueError: If ``task_id`` is not a key of ``TASKS``; the message
            lists the valid ids.
    """
    task = TASKS.get(task_id)
    if task is None:
        raise ValueError(f"Unknown task_id: {task_id!r}. Valid: {list(TASKS)}")
    return task
def get_all_tasks() -> list[Task]:
    """Return a new list holding every registered task, in registry order."""
    return [*TASKS.values()]
def grade_response(
    task_id: str,
    question_id: str,
    sql: str,
    rows: list[dict],
    error: Optional[str],
    attempts: int,
) -> float:
    """Score one answer with the grader of the task that owns the question.

    The first question in the task whose ``id`` equals ``question_id`` is
    passed to the task's grader together with the remaining arguments.

    Raises:
        ValueError: If ``task_id`` is unknown (raised by ``get_task``) or the
            task has no question with ``question_id``.
    """
    task = get_task(task_id)
    for candidate in task.questions:
        if candidate.id == question_id:
            raw_score = task.grader(candidate, sql, rows, error, attempts)
            # Graders already clamp internally; this is a final safety net.
            return _clamp(raw_score)
    raise ValueError(f"Unknown question_id {question_id!r} in task {task_id!r}")