"""Answer verification for SQLEnv using type-aware comparisons."""
from __future__ import annotations
import ast
import re
def _strip_answer_wrapping(text: str) -> str:
"""Remove common LLM wrapping artifacts from an answer string.
Strips markdown code fences, surrounding quotes, "Answer: " prefix,
and extra whitespace so the type-aware comparators see clean values.
"""
s = text.strip()
# Markdown code blocks: ```...``` or ```sql\n...\n```
if s.startswith("```") and s.endswith("```"):
# Language tag only if followed by newline (e.g. ```sql\n)
s = re.sub(r"^```(?:\w+\n|\n?)", "", s)
s = re.sub(r"\n?```$", "", s)
s = s.strip()
# "Answer: " or "answer:" prefix
s = re.sub(r"^[Aa]nswer:\s*", "", s)
# Surrounding quotes (single or double)
if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
s = s[1:-1].strip()
return s
def verify_answer(
    predicted: str,
    gold: str,
    answer_type: str | None = None,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Decide whether a submitted answer matches the gold answer.

    Both inputs are coerced to text (``None`` becomes ""). The submitted
    answer has its wrapping artifacts stripped and is rejected outright if
    nothing remains. The comparison is then chosen by ``answer_type``:
    "integer", "float", "list" and "table" use their dedicated comparators;
    "string" and every other value use normalized string comparison.
    ``gold_rows`` is forwarded to the list and table comparators only.
    """
    submitted = str(predicted) if predicted is not None else ""
    expected = str(gold) if gold is not None else ""

    submitted = _strip_answer_wrapping(submitted)
    if not submitted.strip():
        # An empty submission never counts as correct.
        return False

    if answer_type == "integer":
        return _compare_integer(submitted, expected)
    if answer_type == "float":
        return _compare_float(submitted, expected)
    if answer_type == "list":
        return _compare_list(submitted, expected, gold_rows)
    if answer_type == "table":
        return _compare_table(submitted, expected, gold_rows)

    # "string", None and unrecognized types share the string comparison.
    return _compare_string(submitted, expected)
def _normalize_value(value: str) -> str:
"""Normalize strings for case-insensitive, whitespace-stable comparison."""
text = "" if value is None else str(value)
return " ".join(text.strip().lower().split())
def _compare_integer(predicted: str, gold: str) -> bool:
"""Compare integer values after coercing with ``int(float(x))``."""
try:
return int(float(predicted)) == int(float(gold))
except (TypeError, ValueError):
return False
def _compare_float(predicted: str, gold: str, tolerance: float = 0.01) -> bool:
"""Compare float values using a relative tolerance."""
try:
predicted_value = float(predicted)
gold_value = float(gold)
except (TypeError, ValueError):
return False
if gold_value == 0.0:
return abs(predicted_value - gold_value) <= 1e-9
return abs(predicted_value - gold_value) <= tolerance * abs(gold_value)
def _compare_string(predicted: str, gold: str) -> bool:
    """Report whether both strings are equal once each is normalized."""
    submitted = _normalize_value(predicted)
    expected = _normalize_value(gold)
    return submitted == expected
def _parse_list_values(raw: str) -> set[str]:
    """Parse comma/newline/pipe-separated values into a normalized set.

    Handles plain delimited strings and Python list representations:
    "121\\n111\\n171" -> {"121", "111", "171"}
    "[121, 111, 171]" -> {"121", "111", "171"}
    "['Feil', 'Fisher']" -> {"feil", "fisher"}

    Text that starts with "[" is first tried as a Python literal. If it
    does not evaluate to a list, or cannot be evaluated at all, the text
    is split on commas, newlines and pipes instead. Blank items are
    dropped in both paths.
    """
    text = raw.strip()

    # Try Python literal (e.g., [121, 111] or ['Feil', 'Fisher'])
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval raises more than ValueError/SyntaxError on
            # malformed input (e.g. TypeError for an unhashable dict key,
            # RecursionError for deep nesting). Treat all of them as
            # "not a literal" so a bad answer cannot crash verification.
            parsed = None
        if isinstance(parsed, list):
            return {
                _normalize_value(str(item)) for item in parsed if str(item).strip()
            }

    tokens = re.split(r"\s*(?:,|\n|\|)\s*", text)
    return {_normalize_value(token) for token in tokens if token.strip()}
def _compare_list(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Check that two list answers hold the same set of values.

    Order and duplicates are ignored. When ``gold_rows`` is supplied, the
    expected set is built from every non-blank cell of every row and the
    ``gold`` string is not consulted; otherwise ``gold`` is parsed the
    same way as ``predicted``.
    """
    submitted = _parse_list_values(predicted)

    if gold_rows is None:
        expected = _parse_list_values(gold)
    else:
        # Flatten the rows: each non-blank cell is one expected value.
        expected = set()
        for row in gold_rows:
            for cell in row:
                cell_text = str(cell)
                if cell_text.strip():
                    expected.add(_normalize_value(cell_text))

    return submitted == expected
def _parse_table_rows(raw: str) -> list[tuple[str, ...]]:
    """Parse a table answer string into normalized rows.

    Supports formats:
    - Pipe-separated rows: "111 | 1\\n121 | 2"
    - Python list-of-lists: "[[111, 1], [121, 2]]"
    - Numbered rows: "1. 111 | 1\\n2. 121 | 2"

    Text that starts with "[" is first tried as a Python literal; elements
    of the resulting list that are not lists or tuples are skipped. If the
    text cannot be evaluated, or does not evaluate to a list, it is parsed
    line by line instead. Blank lines and rows whose cells are all empty
    are dropped. Returns an empty list for blank input.
    """
    text = raw.strip()
    if not text:
        return []

    # Try Python literal (list-of-lists from gold_answer storage)
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval raises more than ValueError/SyntaxError on
            # malformed input (e.g. TypeError for an unhashable dict key,
            # RecursionError for deep nesting). Fall back to line parsing
            # rather than crash verification.
            parsed = None
        if isinstance(parsed, list):
            return [
                tuple(_normalize_value(str(cell)) for cell in row)
                for row in parsed
                if isinstance(row, (list, tuple))
            ]

    rows = []
    for line in text.split("\n"):
        line = line.strip()
        if not line:
            continue
        # Strip leading numbering: "1. value | value". The lookahead keeps
        # a decimal first cell intact: "1.5 | 2" must not become "5 | 2".
        line = re.sub(r"^\d+\.(?!\d)\s*", "", line)
        cells = [_normalize_value(cell) for cell in re.split(r"\s*\|\s*", line)]
        if any(cells):
            rows.append(tuple(cells))
    return rows
def _compare_table(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Check that two table answers contain the same rows.

    Row order is ignored but duplicate rows must occur the same number of
    times on both sides. Every cell is normalized before comparison. When
    ``gold_rows`` is supplied it is used as the expected table and the
    ``gold`` string is not parsed.
    """
    submitted_rows = _parse_table_rows(predicted)

    if gold_rows is None:
        expected_rows = _parse_table_rows(gold)
    else:
        expected_rows = [
            tuple(_normalize_value(str(cell)) for cell in row) for row in gold_rows
        ]

    # Comparing the sorted row lists is multiset equality: order no longer
    # matters, yet repeated rows still have to match in count.
    return sorted(expected_rows) == sorted(submitted_rows)
|