# nl2sql-bench / server / grader.py
# ritvik360's picture
# Upload folder using huggingface_hub
# a39d8ef verified
"""
nl2sql-bench/server/grader.py
==============================
Deterministic, programmatic reward grader.
No LLM-as-judge. Every reward is computed by comparing the agent's SQL
execution results against a ground-truth result set.
Reward decomposition (sums to 1.0 for a perfect first-attempt answer):
+0.10 syntax_ok — query runs without SQLite error
+0.20 columns_match — returned column names match ground truth exactly
+0.20 row_count_match — number of returned rows matches
+0.50 exact_match — full result set equals ground truth (order-aware
for ORDER BY queries, order-agnostic otherwise)
Step penalty:
-0.05 per step beyond the first (encourages solving in fewer steps),
clamped so the minimum is always 0.0.
All rewards are floats in [0.0, 1.0].
"""
from __future__ import annotations
import sqlite3
from typing import Any, Dict, List, Optional, Tuple
# ── Result normalisation ───────────────────────────────────────────────────
def _normalise_value(v: Any) -> Any:
    """Canonicalise a single cell value before comparison.

    Strings lose leading/trailing whitespace, floats are rounded to four
    decimal places (so 1.2300000001 compares equal to 1.23), and every
    other value is handed back untouched.
    """
    if isinstance(v, str):
        return v.strip()
    return round(v, 4) if isinstance(v, float) else v
def _normalise_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with the same keys as *row* and normalised values."""
    cleaned: Dict[str, Any] = {}
    for column, cell in row.items():
        cleaned[column] = _normalise_value(cell)
    return cleaned
def _normalise_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Normalise every row of a result set, preserving row order."""
    return list(map(_normalise_row, rows))
# ── SQL execution ──────────────────────────────────────────────────────────
# Leading keywords of statements that can modify the database, its schema,
# attached files, or the connection's transaction state.
_FORBIDDEN_KEYWORDS = frozenset({
    "insert", "update", "delete", "drop", "alter",
    "create", "replace", "truncate", "pragma",
    "attach", "detach", "vacuum", "reindex",
    "begin", "commit", "end", "rollback", "savepoint", "release",
})


def _leading_keyword(sql: str) -> str:
    """Return the lower-cased first keyword of *sql*.

    Leading whitespace, ``-- line`` comments and ``/* block */`` comments are
    skipped, and only the leading run of letters/underscores is returned, so
    ``DELETE/**/FROM t`` and ``/* x */ DELETE FROM t`` both yield ``"delete"``.
    Returns ``""`` when the statement does not start with a word.
    """
    pos, end = 0, len(sql)
    while pos < end:
        if sql[pos].isspace():
            pos += 1
        elif sql.startswith("--", pos):
            newline = sql.find("\n", pos)
            pos = end if newline == -1 else newline + 1
        elif sql.startswith("/*", pos):
            close = sql.find("*/", pos + 2)
            pos = end if close == -1 else close + 2
        else:
            break
    stop = pos
    while stop < end and (sql[stop].isalpha() or sql[stop] == "_"):
        stop += 1
    return sql[pos:stop].lower()


def execute_query(
    conn: sqlite3.Connection,
    query: str,
    max_rows: int = 200,
) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
    """
    Execute a SQL query safely.

    Parameters
    ----------
    conn     : Open SQLite connection to run the query on.
    query    : SQL text; surrounding whitespace and trailing ';' are removed.
    max_rows : Upper bound on the number of rows fetched (default 200).

    Returns
    -------
    (rows, error_string)
        rows is a list of ``{column_name: value}`` dicts and error_string is
        None on success; rows is None and error_string describes the problem
        when the query is empty, rejected, or fails inside SQLite.
    """
    query = query.strip().rstrip(";")
    if not query:
        return None, "Empty query."

    # Block write operations — the environment is read-only from the agent's
    # view. The keyword is extracted after skipping comments so the guard
    # cannot be sidestepped with `DELETE/**/FROM t` or a leading comment.
    # NOTE(review): a prefix check cannot see a CTE-prefixed write
    # (`WITH x AS (...) DELETE ...`). For a hard guarantee open the connection
    # read-only (URI `mode=ro`) or enable `PRAGMA query_only` in the caller.
    keyword = _leading_keyword(query)
    if keyword in _FORBIDDEN_KEYWORDS:
        return None, (
            f"Operation '{keyword.upper()}' is not allowed. "
            "Only SELECT queries are permitted."
        )

    try:
        cur = conn.execute(query)
        cols = [d[0] for d in cur.description] if cur.description else []
        rows = [dict(zip(cols, row)) for row in cur.fetchmany(max_rows)]
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # sqlite3.Warning is not a subclass of sqlite3.Error; older Python
        # versions raise it when the text contains more than one statement.
        return None, str(exc)
    return rows, None
# ── Grading logic ──────────────────────────────────────────────────────────
class GradeResult:
    """Outcome of grading one agent query.

    Attributes
    ----------
    reward          : Final reward as supplied by the caller.
    syntax_ok       : Whether the query ran without error.
    columns_match   : Whether the returned columns matched the ground truth.
    row_count_match : Whether the number of rows matched the ground truth.
    exact_match     : Whether the full result set matched the ground truth.
    step_penalty    : Penalty magnitude (non-negative) applied for extra steps.
    breakdown       : Per-component contribution to the score; the component
                      credits are only granted when ``syntax_ok`` is true and
                      the penalty is stored as a non-positive number.
    """

    __slots__ = (
        "reward", "syntax_ok", "columns_match",
        "row_count_match", "exact_match", "step_penalty",
        "breakdown",
    )

    def __init__(
        self,
        reward: float,
        syntax_ok: bool,
        columns_match: bool,
        row_count_match: bool,
        exact_match: bool,
        step_penalty: float,
    ) -> None:
        self.reward = reward
        self.syntax_ok = syntax_ok
        self.columns_match = columns_match
        self.row_count_match = row_count_match
        self.exact_match = exact_match
        self.step_penalty = step_penalty
        self.breakdown = {
            "syntax_ok": 0.10 if syntax_ok else 0.0,
            "columns_match": 0.20 if (syntax_ok and columns_match) else 0.0,
            "row_count_match": 0.20 if (syntax_ok and row_count_match) else 0.0,
            "exact_match": 0.50 if (syntax_ok and exact_match) else 0.0,
            # Negating a zero penalty would give -0.0, which prints and
            # serialises as "-0.0"; report a plain 0.0 in that case.
            "step_penalty": -step_penalty if step_penalty else 0.0,
        }

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"GradeResult(reward={self.reward:.3f}, "
            f"exact={self.exact_match}, cols={self.columns_match}, "
            f"rows={self.row_count_match}, syntax={self.syntax_ok})"
        )
def _row_counts(rows: List[Dict[str, Any]]) -> Dict[Any, int]:
    """Count occurrences of each row, keyed by its set of (column, value) pairs."""
    counts: Dict[Any, int] = {}
    for row in rows:
        signature = frozenset(row.items())
        counts[signature] = counts.get(signature, 0) + 1
    return counts


def _rows_equal_unordered(
    left: List[Dict[str, Any]],
    right: List[Dict[str, Any]],
) -> bool:
    """Return True when both lists contain the same rows, ignoring order.

    Rows are matched with the same equality used for the order-sensitive
    comparison, so numerically equal cells such as 1 and 1.0 match.
    """
    try:
        return _row_counts(left) == _row_counts(right)
    except TypeError:
        # A cell value is unhashable; fall back to pairwise matching.
        unmatched = list(right)
        for row in left:
            try:
                unmatched.remove(row)
            except ValueError:
                return False
        return not unmatched


def grade(
    actual_rows: Optional[List[Dict[str, Any]]],
    ground_truth_rows: List[Dict[str, Any]],
    error: Optional[str],
    step: int,
    order_sensitive: bool = False,
) -> GradeResult:
    """
    Grade the agent's query result against ground truth.

    Parameters
    ----------
    actual_rows : Rows returned by the agent's query (None on error).
    ground_truth_rows : Expected rows (pre-computed at task load time).
    error : SQLite error string (None if query ran successfully).
    step : Current step number (1-indexed) for penalty calculation.
    order_sensitive : If True, row order matters (queries with ORDER BY).

    Returns
    -------
    GradeResult
        A failed query (error set or no rows object) scores 0.0 with no
        penalty. Otherwise the reward is 0.10 for running, +0.20 for matching
        columns, +0.20 for matching row count and +0.50 for an exact match,
        minus 0.05 per step beyond the first, clamped to [0.0, 1.0].

    Notes
    -----
    Column names are read from the first row of each result set, so an empty
    result set is treated as having no columns.
    """
    # ── Syntax ──────────────────────────────────────────────────────────
    if error is not None or actual_rows is None:
        return GradeResult(
            reward=0.0,
            syntax_ok=False,
            columns_match=False,
            row_count_match=False,
            exact_match=False,
            step_penalty=0.0,
        )

    expected = _normalise_rows(ground_truth_rows)
    produced = _normalise_rows(actual_rows)

    expected_cols = set(expected[0]) if expected else set()
    produced_cols = set(produced[0]) if produced else set()
    columns_match = produced_cols == expected_cols
    row_count_match = len(produced) == len(expected)

    # Exact match is only attempted once the cheaper checks agree.
    if not (columns_match and row_count_match):
        exact_match = False
    elif order_sensitive:
        exact_match = produced == expected
    else:
        # Multiset comparison: sorting by str() would order 1 and 1.0
        # differently even though they compare equal.
        exact_match = _rows_equal_unordered(produced, expected)

    # ── Score assembly ────────────────────────────────────────────────
    raw = (
        0.10  # syntax
        + (0.20 if columns_match else 0.0)
        + (0.20 if row_count_match else 0.0)
        + (0.50 if exact_match else 0.0)
    )
    # Round away binary floating-point noise (0.1 + 0.2 -> 0.30000000000000004,
    # 3 * 0.05 -> 0.15000000000000002) so scores are exact multiples of 0.05.
    penalty = round(max(0.0, step - 1) * 0.05, 4)
    reward = round(float(max(0.0, min(1.0, raw - penalty))), 4)
    return GradeResult(
        reward=reward,
        syntax_ok=True,
        columns_match=columns_match,
        row_count_match=row_count_match,
        exact_match=exact_match,
        step_penalty=penalty,
    )
# ── Convenience: pre-compute ground truth rows ─────────────────────────────
def compute_ground_truth(
    conn: sqlite3.Connection,
    sql: str,
) -> List[Dict[str, Any]]:
    """Run the reference SQL through ``execute_query`` and normalise its rows.

    The query is executed with ``execute_query``'s default row limit.

    Raises
    ------
    ValueError
        If the reference query reports an error or yields no row list.
    """
    result, failure = execute_query(conn, sql)
    if result is not None and not failure:
        return _normalise_rows(result)
    raise ValueError(f"Ground-truth SQL failed: {failure}\nSQL: {sql}")
def has_order_by(sql: str) -> bool:
    """Heuristic: does the top-level query have an ORDER BY?

    Only text outside parentheses, quoted literals/identifiers and comments
    is inspected, so an ORDER BY inside a subquery, a CTE body, or a window
    ``OVER (...)`` clause does not count. The two keywords may be separated
    by any whitespace and are matched case-insensitively.
    """
    visible = []  # characters that sit at parenthesis depth 0
    depth = 0
    pos, end = 0, len(sql)
    while pos < end:
        ch = sql[pos]
        if sql.startswith("--", pos):
            # Line comment: jump to the newline, which stays as a separator.
            newline = sql.find("\n", pos)
            pos = end if newline == -1 else newline
            continue
        if sql.startswith("/*", pos):
            close = sql.find("*/", pos + 2)
            pos = end if close == -1 else close + 2
            visible.append(" ")
            continue
        if ch in ("'", '"', "`"):
            # Quoted literal or identifier; a doubled quote is an escape.
            pos += 1
            while pos < end:
                if sql[pos] == ch:
                    if sql.startswith(ch, pos + 1):
                        pos += 2
                        continue
                    break
                pos += 1
            pos += 1  # step past the closing quote
            visible.append(" ")
            continue
        if ch == "[":
            # SQLite also accepts [bracketed] identifiers.
            close = sql.find("]", pos + 1)
            pos = end if close == -1 else close + 1
            visible.append(" ")
            continue
        if ch == "(":
            depth += 1
            visible.append(" ")
        elif ch == ")":
            depth = max(0, depth - 1)
            visible.append(" ")
        elif depth == 0:
            visible.append(ch)
        pos += 1

    words = "".join(visible).upper().split()
    return any(
        first == "ORDER" and second == "BY"
        for first, second in zip(words, words[1:])
    )