sql_env / server /verifier.py
hjerpe's picture
Upload folder using huggingface_hub
9e64e71 verified
"""Answer verification for SQLEnv using type-aware comparisons."""
from __future__ import annotations
import ast
import re
def _strip_answer_wrapping(text: str) -> str:
"""Remove common LLM wrapping artifacts from an answer string.
Strips markdown code fences, surrounding quotes, "Answer: " prefix,
and extra whitespace so the type-aware comparators see clean values.
"""
s = text.strip()
# Markdown code blocks: ```...``` or ```sql\n...\n```
if s.startswith("```") and s.endswith("```"):
# Language tag only if followed by newline (e.g. ```sql\n)
s = re.sub(r"^```(?:\w+\n|\n?)", "", s)
s = re.sub(r"\n?```$", "", s)
s = s.strip()
# "Answer: " or "answer:" prefix
s = re.sub(r"^[Aa]nswer:\s*", "", s)
# Surrounding quotes (single or double)
if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
s = s[1:-1].strip()
return s
def verify_answer(
    predicted: str,
    gold: str,
    answer_type: str | None = None,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Decide whether a submitted answer matches the gold answer.

    Both inputs are coerced to text (``None`` becomes the empty string) and
    the submission is cleaned with ``_strip_answer_wrapping``. An empty
    submission never matches. The comparison strategy is then chosen from
    ``answer_type``: "integer", "float", "list" and "table" each have a
    dedicated comparator; "string" and every other value (including
    ``None``) use the normalized string comparison.

    ``gold_rows`` is only forwarded to the list and table comparators.
    """
    submitted = "" if predicted is None else str(predicted)
    expected = "" if gold is None else str(gold)

    submitted = _strip_answer_wrapping(submitted)
    if not submitted.strip():
        # Nothing left after cleanup: treat as "no answer given".
        return False

    if answer_type == "integer":
        return _compare_integer(submitted, expected)
    if answer_type == "float":
        return _compare_float(submitted, expected)
    if answer_type == "list":
        return _compare_list(submitted, expected, gold_rows)
    if answer_type == "table":
        return _compare_table(submitted, expected, gold_rows)

    # "string" and any unrecognised type share the same fallback.
    return _compare_string(submitted, expected)
def _normalize_value(value: str) -> str:
"""Normalize strings for case-insensitive, whitespace-stable comparison."""
text = "" if value is None else str(value)
return " ".join(text.strip().lower().split())
def _compare_integer(predicted: str, gold: str) -> bool:
"""Compare integer values after coercing with ``int(float(x))``."""
try:
return int(float(predicted)) == int(float(gold))
except (TypeError, ValueError):
return False
def _compare_float(predicted: str, gold: str, tolerance: float = 0.01) -> bool:
"""Compare float values using a relative tolerance."""
try:
predicted_value = float(predicted)
gold_value = float(gold)
except (TypeError, ValueError):
return False
if gold_value == 0.0:
return abs(predicted_value - gold_value) <= 1e-9
return abs(predicted_value - gold_value) <= tolerance * abs(gold_value)
def _compare_string(predicted: str, gold: str) -> bool:
    """Return True when both strings are equal after ``_normalize_value``."""
    normalized_predicted = _normalize_value(predicted)
    normalized_gold = _normalize_value(gold)
    return normalized_predicted == normalized_gold
def _parse_list_values(raw: str) -> set[str]:
    """Parse comma/newline/pipe-separated values into a normalized set.

    Handles plain delimited strings and Python list representations:
        "121\\n111\\n171"     -> {"121", "111", "171"}
        "[121, 111, 171]"     -> {"121", "111", "171"}
        "['Feil', 'Fisher']"  -> {"feil", "fisher"}

    Text starting with "[" is first tried as a Python literal. If that
    fails, or the literal is not a list, the text is split on commas,
    newlines and pipes instead. Every value is passed through
    ``_normalize_value`` and blank values are dropped.
    """
    text = raw.strip()

    # Try Python literal (e.g., [121, 111] or ['Feil', 'Fisher'])
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval can raise any of these on malformed input (for
            # example TypeError for "[{[1]: 2}]"); the text is model output,
            # so none of them may escape -- fall back to delimiter parsing.
            parsed = None
        if isinstance(parsed, list):
            return {
                _normalize_value(str(item)) for item in parsed if str(item).strip()
            }

    tokens = re.split(r"\s*(?:,|\n|\|)\s*", text)
    return {_normalize_value(token) for token in tokens if token.strip()}
def _compare_list(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Compare list-like answers as sets, ignoring order and duplicates.

    The submission is always parsed with ``_parse_list_values``. When
    ``gold_rows`` is supplied, the expected set is built by flattening every
    cell of every row (normalized, blank cells skipped) and the ``gold``
    string is ignored; otherwise ``gold`` is parsed the same way as the
    submission.
    """
    submitted = _parse_list_values(predicted)

    if gold_rows is None:
        expected = _parse_list_values(gold)
    else:
        expected = set()
        for row in gold_rows:
            for cell in row:
                if str(cell).strip():
                    expected.add(_normalize_value(str(cell)))

    return submitted == expected
def _parse_table_rows(raw: str) -> list[tuple[str, ...]]:
    """Parse a table answer string into normalized rows.

    Supports formats:
        - Pipe-separated rows: "111 | 1\\n121 | 2"
        - Python list-of-lists: "[[111, 1], [121, 2]]"
        - Numbered rows: "1. 111 | 1\\n2. 121 | 2"

    Text starting with "[" is first tried as a Python literal; elements of
    the resulting list that are not themselves lists or tuples are skipped.
    If the literal cannot be parsed, the text is read line by line. A
    leading row number is only removed when its dot is followed by
    whitespace, so a decimal first cell such as "3.14 | x" stays intact.

    Returns:
        One tuple of normalized cell strings per row, in input order.
        Blank lines and rows whose cells are all empty are omitted.
    """
    text = raw.strip()
    if not text:
        return []

    # Try Python literal (list-of-lists from gold_answer storage)
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval can raise any of these on malformed input; fall
            # through to line-based parsing instead of crashing.
            parsed = None
        if isinstance(parsed, list):
            return [
                tuple(_normalize_value(str(cell)) for cell in row)
                for row in parsed
                if isinstance(row, (list, tuple))
            ]

    rows = []
    for line in text.split("\n"):
        line = line.strip()
        if not line:
            continue
        # Strip leading numbering: "1. value | value". Requiring whitespace
        # after the dot keeps decimals like "1.5" from losing their integer
        # part.
        line = re.sub(r"^\d+\.\s+", "", line)
        cells = [_normalize_value(cell) for cell in re.split(r"\s*\|\s*", line)]
        if any(cells):
            rows.append(tuple(cells))
    return rows
def _compare_table(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Compare table answers as multisets of normalized rows.

    Row order is ignored but duplicate rows are counted. The submission is
    parsed with ``_parse_table_rows``. When ``gold_rows`` is supplied, each
    of its rows is normalized cell by cell and the ``gold`` string is
    ignored; otherwise ``gold`` is parsed like the submission.
    """
    submitted_rows = _parse_table_rows(predicted)

    if gold_rows is None:
        expected_rows = _parse_table_rows(gold)
    else:
        expected_rows = [
            tuple(_normalize_value(str(cell)) for cell in row) for row in gold_rows
        ]

    # Sorting both sides and comparing the lists keeps duplicate counts,
    # which makes this a multiset equality check.
    expected_rows.sort()
    submitted_rows.sort()
    return submitted_rows == expected_rows