""" nl2sql-bench/server/grader.py ============================== Deterministic, programmatic reward grader. No LLM-as-judge. Every reward is computed by comparing the agent's SQL execution results against a ground-truth result set. Reward decomposition (sums to 1.0 for a perfect first-attempt answer): +0.10 syntax_ok — query runs without SQLite error +0.20 columns_match — returned column names match ground truth exactly +0.20 row_count_match — number of returned rows matches +0.50 exact_match — full result set equals ground truth (order-aware for ORDER BY queries, order-agnostic otherwise) Step penalty: -0.05 per step beyond the first (encourages solving in fewer steps), clamped so the minimum is always 0.0. All rewards are floats in [0.0, 1.0]. """ from __future__ import annotations import sqlite3 from typing import Any, Dict, List, Optional, Tuple # ── Result normalisation ─────────────────────────────────────────────────── def _normalise_value(v: Any) -> Any: """Round floats for comparison so 1.2300000001 == 1.23.""" if isinstance(v, float): return round(v, 4) if isinstance(v, str): return v.strip() return v def _normalise_row(row: Dict[str, Any]) -> Dict[str, Any]: return {k: _normalise_value(v) for k, v in row.items()} def _normalise_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]: return [_normalise_row(r) for r in rows] # ── SQL execution ────────────────────────────────────────────────────────── def execute_query( conn: sqlite3.Connection, query: str, max_rows: int = 200, ) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]: """ Execute a SQL query safely. Returns (rows, error_string). rows is None on error. """ query = query.strip().rstrip(";") if not query: return None, "Empty query." # Block write operations — the environment is read-only from the agent's view. forbidden = ("insert", "update", "delete", "drop", "alter", "create", "replace", "truncate", "pragma") first_word = query.split()[0].lower() if query.split() else "" if first_word in forbidden: return None, ( f"Operation '{first_word.upper()}' is not allowed. " "Only SELECT queries are permitted." ) try: cur = conn.execute(query) cols = [d[0] for d in cur.description] if cur.description else [] rows = [dict(zip(cols, row)) for row in cur.fetchmany(max_rows)] return rows, None except sqlite3.Error as exc: return None, str(exc) # ── Grading logic ────────────────────────────────────────────────────────── class GradeResult: __slots__ = ( "reward", "syntax_ok", "columns_match", "row_count_match", "exact_match", "step_penalty", "breakdown", ) def __init__( self, reward: float, syntax_ok: bool, columns_match: bool, row_count_match: bool, exact_match: bool, step_penalty: float, ) -> None: self.reward = reward self.syntax_ok = syntax_ok self.columns_match = columns_match self.row_count_match = row_count_match self.exact_match = exact_match self.step_penalty = step_penalty self.breakdown = { "syntax_ok": 0.10 if syntax_ok else 0.0, "columns_match": 0.20 if (syntax_ok and columns_match) else 0.0, "row_count_match": 0.20 if (syntax_ok and row_count_match) else 0.0, "exact_match": 0.50 if (syntax_ok and exact_match) else 0.0, "step_penalty": -step_penalty, } def __repr__(self) -> str: # pragma: no cover return ( f"GradeResult(reward={self.reward:.3f}, " f"exact={self.exact_match}, cols={self.columns_match}, " f"rows={self.row_count_match}, syntax={self.syntax_ok})" ) def grade( actual_rows: Optional[List[Dict[str, Any]]], ground_truth_rows: List[Dict[str, Any]], error: Optional[str], step: int, order_sensitive: bool = False, ) -> GradeResult: """ Grade the agent's query result against ground truth. Parameters ---------- actual_rows : Rows returned by the agent's query (None on error). ground_truth_rows : Expected rows (pre-computed at task load time). error : SQLite error string (None if query ran successfully). step : Current step number (1-indexed) for penalty calculation. order_sensitive : If True, row order matters (queries with ORDER BY). """ # ── Syntax ────────────────────────────────────────────────────────── syntax_ok = error is None and actual_rows is not None if not syntax_ok: return GradeResult( reward=0.0, syntax_ok=False, columns_match=False, row_count_match=False, exact_match=False, step_penalty=0.0, ) gt_norm = _normalise_rows(ground_truth_rows) act_norm = _normalise_rows(actual_rows) gt_cols = set(gt_norm[0].keys()) if gt_norm else set() act_cols = set(act_norm[0].keys()) if act_norm else set() columns_match = act_cols == gt_cols row_count_match = len(act_norm) == len(gt_norm) # Exact match: if order matters, compare list; otherwise compare sorted sets if columns_match and row_count_match: if order_sensitive: exact_match = act_norm == gt_norm else: # Sort rows by their string representation for order-agnostic compare def _sort_key(r: Dict) -> str: return str(sorted(r.items())) exact_match = ( sorted(act_norm, key=_sort_key) == sorted(gt_norm, key=_sort_key) ) else: exact_match = False # ── Score assembly ──────────────────────────────────────────────── raw = ( 0.10 # syntax + (0.20 if columns_match else 0.0) + (0.20 if row_count_match else 0.0) + (0.50 if exact_match else 0.0) ) penalty = max(0.0, step - 1) * 0.05 reward = float(max(0.0, min(1.0, raw - penalty))) return GradeResult( reward=reward, syntax_ok=syntax_ok, columns_match=columns_match, row_count_match=row_count_match, exact_match=exact_match, step_penalty=penalty, ) # ── Convenience: pre-compute ground truth rows ───────────────────────────── def compute_ground_truth( conn: sqlite3.Connection, sql: str, ) -> List[Dict[str, Any]]: """Execute the ground-truth SQL and return normalised rows.""" rows, error = execute_query(conn, sql) if error or rows is None: raise ValueError(f"Ground-truth SQL failed: {error}\nSQL: {sql}") return _normalise_rows(rows) def has_order_by(sql: str) -> bool: """Heuristic: does the top-level query have an ORDER BY?""" # Simple check sufficient for our controlled task SQL return "ORDER BY" in sql.upper()