Spaces:
Sleeping
Sleeping
| """ | |
| nl2sql-bench/server/grader.py | |
| ============================== | |
| Deterministic, programmatic reward grader. | |
| No LLM-as-judge. Every reward is computed by comparing the agent's SQL | |
| execution results against a ground-truth result set. | |
| Reward decomposition (sums to 1.0 for a perfect first-attempt answer): | |
| +0.10 syntax_ok — query runs without SQLite error | |
| +0.20 columns_match — returned column names match ground truth exactly | |
| +0.20 row_count_match — number of returned rows matches | |
| +0.50 exact_match — full result set equals ground truth (order-aware | |
| for ORDER BY queries, order-agnostic otherwise) | |
| Step penalty: | |
| -0.05 per step beyond the first (encourages solving in fewer steps), | |
| clamped so the minimum is always 0.0. | |
| All rewards are floats in [0.0, 1.0]. | |
| """ | |
| from __future__ import annotations | |
| import sqlite3 | |
| from typing import Any, Dict, List, Optional, Tuple | |
| # ── Result normalisation ─────────────────────────────────────────────────── | |
| def _normalise_value(v: Any) -> Any: | |
| """Round floats for comparison so 1.2300000001 == 1.23.""" | |
| if isinstance(v, float): | |
| return round(v, 4) | |
| if isinstance(v, str): | |
| return v.strip() | |
| return v | |
| def _normalise_row(row: Dict[str, Any]) -> Dict[str, Any]: | |
| return {k: _normalise_value(v) for k, v in row.items()} | |
| def _normalise_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| return [_normalise_row(r) for r in rows] | |
| # ── SQL execution ────────────────────────────────────────────────────────── | |
def _leading_keyword(sql: str) -> str:
    """Return the first SQL keyword of *sql* (lower-cased), skipping whitespace and comments."""
    pos, end = 0, len(sql)
    while pos < end:
        if sql[pos].isspace():
            pos += 1
        elif sql.startswith("--", pos):
            # Line comment: skip to the end of the line.
            newline = sql.find("\n", pos)
            pos = end if newline == -1 else newline + 1
        elif sql.startswith("/*", pos):
            # Block comment: skip past the closing marker.
            close = sql.find("*/", pos + 2)
            pos = end if close == -1 else close + 2
        else:
            break
    stop = pos
    # Only consume identifier characters so "DROP/**/TABLE" still yields "drop".
    while stop < end and (sql[stop].isalpha() or sql[stop] == "_"):
        stop += 1
    return sql[pos:stop].lower()


def execute_query(
    conn: sqlite3.Connection,
    query: str,
    max_rows: int = 200,
) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
    """
    Execute a SQL query safely and read-only.

    Parameters
    ----------
    conn     : Open SQLite connection to run the query on.
    query    : SQL text; surrounding whitespace and trailing ';' are removed.
    max_rows : Maximum number of rows fetched from the result.

    Returns
    -------
    (rows, error_string). On success ``rows`` is a list of
    ``{column_name: value}`` dicts (at most ``max_rows`` of them) and the
    error is None. On failure ``rows`` is None and the error is a message.
    Because rows are dicts, a result with duplicate column names keeps only
    the last value for each name.

    Write protection is two-layered:
    1. The leading keyword (found after skipping comments) is checked
       against a deny list, which produces a friendly error message.
    2. The statement runs under ``PRAGMA query_only = ON`` so that writes
       hidden behind e.g. a ``WITH`` clause are rejected by SQLite itself.
       The pragma's previous value is restored afterwards.
    """
    query = query.strip().rstrip(";")
    if not query:
        return None, "Empty query."

    # Block write operations — the environment is read-only from the agent's view.
    forbidden = ("insert", "update", "delete", "drop", "alter",
                 "create", "replace", "truncate", "pragma",
                 "attach", "detach", "vacuum", "reindex")
    first_word = _leading_keyword(query)
    if first_word in forbidden:
        return None, (
            f"Operation '{first_word.upper()}' is not allowed. "
            "Only SELECT queries are permitted."
        )

    try:
        state = conn.execute("PRAGMA query_only").fetchone()
        was_query_only = bool(state[0]) if state else False
        conn.execute("PRAGMA query_only = ON")
        try:
            cur = conn.execute(query)
            cols = [d[0] for d in cur.description] if cur.description else []
            rows = [dict(zip(cols, row)) for row in cur.fetchmany(max_rows)]
        finally:
            # Always hand the connection back in the state we found it.
            conn.execute(f"PRAGMA query_only = {1 if was_query_only else 0}")
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # sqlite3.Warning is not a subclass of sqlite3.Error; older Python
        # versions raise it when the string contains several statements.
        return None, str(exc)
    return rows, None
| # ── Grading logic ────────────────────────────────────────────────────────── | |
class GradeResult:
    """Outcome of grading one submitted query.

    Holds the final reward, the four check outcomes, the step penalty that
    was applied, and ``breakdown``: a dict mapping each component name to
    its signed contribution (the penalty is stored negated). A component
    other than ``syntax_ok`` only earns its weight when ``syntax_ok`` is
    also true.
    """

    __slots__ = (
        "reward", "syntax_ok", "columns_match",
        "row_count_match", "exact_match", "step_penalty",
        "breakdown",
    )

    def __init__(
        self,
        reward: float,
        syntax_ok: bool,
        columns_match: bool,
        row_count_match: bool,
        exact_match: bool,
        step_penalty: float,
    ) -> None:
        def _credit(weight: float, passed: bool) -> float:
            # Partial credit is void unless the query actually ran.
            return weight if (syntax_ok and passed) else 0.0

        self.reward = reward
        self.syntax_ok = syntax_ok
        self.columns_match = columns_match
        self.row_count_match = row_count_match
        self.exact_match = exact_match
        self.step_penalty = step_penalty
        self.breakdown = {
            "syntax_ok": _credit(0.10, True),
            "columns_match": _credit(0.20, columns_match),
            "row_count_match": _credit(0.20, row_count_match),
            "exact_match": _credit(0.50, exact_match),
            "step_penalty": -step_penalty,
        }

    def __repr__(self) -> str:  # pragma: no cover
        fields = ", ".join((
            f"reward={self.reward:.3f}",
            f"exact={self.exact_match}",
            f"cols={self.columns_match}",
            f"rows={self.row_count_match}",
            f"syntax={self.syntax_ok}",
        ))
        return f"GradeResult({fields})"
def grade(
    actual_rows: Optional[List[Dict[str, Any]]],
    ground_truth_rows: List[Dict[str, Any]],
    error: Optional[str],
    step: int,
    order_sensitive: bool = False,
) -> GradeResult:
    """
    Grade the agent's query result against ground truth.

    Parameters
    ----------
    actual_rows       : Rows returned by the agent's query (None on error).
    ground_truth_rows : Expected rows (pre-computed at task load time).
    error             : SQLite error string (None if query ran successfully).
    step              : Current step number (1-indexed) for penalty calculation.
    order_sensitive   : If True, row order matters (queries with ORDER BY).

    Returns
    -------
    GradeResult whose ``reward`` is the sum of the earned components
    (0.10 syntax, 0.20 columns, 0.20 row count, 0.50 exact match) minus
    0.05 per step beyond the first, clamped to [0.0, 1.0] and rounded to
    four decimals. A failed query scores 0.0 with no penalty recorded.

    Notes
    -----
    Column names are read from the first row of each result, so an empty
    result contributes an empty column set: two empty results always
    "match" on columns, and an empty result never matches a non-empty one.
    """
    from collections import Counter

    # ── Syntax ──────────────────────────────────────────────────────────
    syntax_ok = error is None and actual_rows is not None
    if not syntax_ok:
        return GradeResult(
            reward=0.0,
            syntax_ok=False,
            columns_match=False,
            row_count_match=False,
            exact_match=False,
            step_penalty=0.0,
        )

    gt_norm = _normalise_rows(ground_truth_rows)
    act_norm = _normalise_rows(actual_rows)

    # ── Columns and row count ───────────────────────────────────────────
    gt_cols = set(gt_norm[0]) if gt_norm else set()
    act_cols = set(act_norm[0]) if act_norm else set()
    columns_match = act_cols == gt_cols
    row_count_match = len(act_norm) == len(gt_norm)

    # ── Exact match ─────────────────────────────────────────────────────
    if not (columns_match and row_count_match):
        exact_match = False
    elif order_sensitive:
        exact_match = act_norm == gt_norm
    else:
        def _row_key(r: Dict) -> tuple:
            return tuple(sorted(r.items()))

        try:
            # Multiset comparison: equal values hash alike (1 == 1.0), so
            # numerically equal rows match however they would print.
            exact_match = (
                Counter(map(_row_key, act_norm))
                == Counter(map(_row_key, gt_norm))
            )
        except TypeError:
            # Unhashable cell value: fall back to comparing the rows after
            # sorting them by their string representation.
            def _sort_key(r: Dict) -> str:
                return str(sorted(r.items()))

            exact_match = (
                sorted(act_norm, key=_sort_key) == sorted(gt_norm, key=_sort_key)
            )

    # ── Score assembly ──────────────────────────────────────────────────
    raw = (
        0.10  # syntax
        + (0.20 if columns_match else 0.0)
        + (0.20 if row_count_match else 0.0)
        + (0.50 if exact_match else 0.0)
    )
    # Round away binary floating-point noise (e.g. 0.5000000000000001) so
    # reported rewards and penalties are clean multiples of 0.05.
    penalty = round(max(0.0, step - 1) * 0.05, 4)
    reward = round(float(max(0.0, min(1.0, raw - penalty))), 4)

    return GradeResult(
        reward=reward,
        syntax_ok=syntax_ok,
        columns_match=columns_match,
        row_count_match=row_count_match,
        exact_match=exact_match,
        step_penalty=penalty,
    )
| # ── Convenience: pre-compute ground truth rows ───────────────────────────── | |
def compute_ground_truth(
    conn: sqlite3.Connection,
    sql: str,
) -> List[Dict[str, Any]]:
    """Run the reference SQL via ``execute_query`` and return its normalised rows.

    Raises
    ------
    ValueError : If the reference query reports an error or yields no row list.
    """
    result, failure = execute_query(conn, sql)
    succeeded = result is not None and not failure
    if not succeeded:
        raise ValueError(f"Ground-truth SQL failed: {failure}\nSQL: {sql}")
    return _normalise_rows(result)
def has_order_by(sql: str) -> bool:
    """Heuristic: does the top-level query have an ORDER BY?

    Only an ORDER BY outside every pair of parentheses counts, so clauses
    inside subqueries or window definitions (``OVER (ORDER BY ...)``) are
    ignored. String literals, double-quoted identifiers and comments are
    blanked out first so their contents can neither match nor unbalance
    the parenthesis count. Any run of whitespace between ORDER and BY is
    accepted, and matching is case-insensitive.
    """
    import re

    stripped = re.sub(
        r"'(?:[^']|'')*'"      # 'string literal' ('' escapes a quote)
        r'|"(?:[^"]|"")*"'     # "quoted identifier"
        r"|--[^\n]*"           # -- line comment
        r"|/\*.*?\*/",         # /* block comment */
        " ",
        sql,
        flags=re.DOTALL,
    )

    depth = 0
    top_level = []
    for ch in stripped:
        if ch == "(":
            depth += 1
            top_level.append(" ")  # keep tokens on either side separated
        elif ch == ")":
            depth = max(0, depth - 1)  # tolerate stray closing parens
            top_level.append(" ")
        elif depth == 0:
            top_level.append(ch)

    return re.search(r"\bORDER\s+BY\b", "".join(top_level), re.IGNORECASE) is not None