| """Answer verification for SQLEnv using type-aware comparisons.""" |
|
|
| from __future__ import annotations |
|
|
| import ast |
| import re |
|
|
|
|
| def _strip_answer_wrapping(text: str) -> str: |
| """Remove common LLM wrapping artifacts from an answer string. |
| |
| Strips markdown code fences, surrounding quotes, "Answer: " prefix, |
| and extra whitespace so the type-aware comparators see clean values. |
| """ |
| s = text.strip() |
| |
| if s.startswith("```") and s.endswith("```"): |
| |
| s = re.sub(r"^```(?:\w+\n|\n?)", "", s) |
| s = re.sub(r"\n?```$", "", s) |
| s = s.strip() |
| |
| s = re.sub(r"^[Aa]nswer:\s*", "", s) |
| |
| if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"): |
| s = s[1:-1].strip() |
| return s |
|
|
|
|
| def verify_answer( |
| predicted: str, |
| gold: str, |
| answer_type: str | None = None, |
| gold_rows: list[tuple] | None = None, |
| ) -> bool: |
| """Compare submitted and gold answers with type-aware dispatch.""" |
| predicted_text = "" if predicted is None else str(predicted) |
| gold_text = "" if gold is None else str(gold) |
|
|
| predicted_text = _strip_answer_wrapping(predicted_text) |
|
|
| if not predicted_text.strip(): |
| return False |
|
|
| match answer_type: |
| case "integer": |
| return _compare_integer(predicted_text, gold_text) |
| case "float": |
| return _compare_float(predicted_text, gold_text) |
| case "list": |
| return _compare_list(predicted_text, gold_text, gold_rows) |
| case "table": |
| return _compare_table(predicted_text, gold_text, gold_rows) |
| case "string": |
| return _compare_string(predicted_text, gold_text) |
| case _: |
| return _compare_string(predicted_text, gold_text) |
|
|
|
|
| def _normalize_value(value: str) -> str: |
| """Normalize strings for case-insensitive, whitespace-stable comparison.""" |
| text = "" if value is None else str(value) |
| return " ".join(text.strip().lower().split()) |
|
|
|
|
| def _compare_integer(predicted: str, gold: str) -> bool: |
| """Compare integer values after coercing with ``int(float(x))``.""" |
| try: |
| return int(float(predicted)) == int(float(gold)) |
| except (TypeError, ValueError): |
| return False |
|
|
|
|
| def _compare_float(predicted: str, gold: str, tolerance: float = 0.01) -> bool: |
| """Compare float values using a relative tolerance.""" |
| try: |
| predicted_value = float(predicted) |
| gold_value = float(gold) |
| except (TypeError, ValueError): |
| return False |
|
|
| if gold_value == 0.0: |
| return abs(predicted_value - gold_value) <= 1e-9 |
|
|
| return abs(predicted_value - gold_value) <= tolerance * abs(gold_value) |
|
|
|
|
def _compare_string(predicted: str, gold: str) -> bool:
    """Return True when both strings are equal after ``_normalize_value``.

    The comparison ignores letter case, leading/trailing whitespace, and the
    length of internal whitespace runs.
    """
    normalized_pred = _normalize_value(predicted)
    normalized_gold = _normalize_value(gold)
    return normalized_pred == normalized_gold
|
|
|
|
def _parse_list_values(raw: str) -> set[str]:
    """Parse comma/newline/pipe-separated values into a normalized set.

    Handles plain delimited strings and Python list literals:
        "121\\n111\\n171"      -> {"121", "111", "171"}
        "[121, 111, 171]"      -> {"121", "111", "171"}
        "['Feil', 'Fisher']"   -> {"feil", "fisher"}
        "[(121,), (111,)]"     -> {"121", "111"}
        "[Feil, Fisher]"       -> {"feil", "fisher"}

    Inside a list literal, nested lists/tuples are flattened one level, so
    echoed query rows contribute their individual cells. Text that starts
    with "[" but is not a valid literal falls back to delimiter splitting,
    with an enclosing "[...]" removed first. Blank values are dropped, and
    every value is passed through ``_normalize_value``.
    """
    text = raw.strip()

    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval can raise any of these on malformed or hostile
            # input; treat them all as "not a literal".
            parsed = None
        if isinstance(parsed, list):
            cells = []
            for item in parsed:
                # Flatten one level so rows such as (121,) yield their cells.
                if isinstance(item, (list, tuple)):
                    cells.extend(item)
                else:
                    cells.append(item)
            return {
                _normalize_value(str(cell)) for cell in cells if str(cell).strip()
            }
        if text.endswith("]"):
            # Unquoted list such as "[Feil, Fisher]": drop the brackets so
            # they do not stick to the first and last tokens.
            text = text[1:-1]

    tokens = re.split(r"\s*(?:,|\n|\|)\s*", text)
    return {_normalize_value(token) for token in tokens if token.strip()}
|
|
|
|
def _compare_list(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Return True when both answers contain the same set of values.

    Order and duplicates are ignored. When *gold_rows* is supplied, every
    non-blank cell of every row, normalized, forms the gold set. Otherwise
    *gold* is parsed with the same rules as *predicted*.
    """
    actual = _parse_list_values(predicted)

    if gold_rows is None:
        return actual == _parse_list_values(gold)

    expected = set()
    for row in gold_rows:
        for cell in row:
            cell_text = str(cell)
            if cell_text.strip():
                expected.add(_normalize_value(cell_text))
    return actual == expected
|
|
|
|
def _parse_table_rows(raw: str) -> list[tuple[str, ...]]:
    """Parse a table answer string into normalized rows.

    Supports formats:
      - Pipe-separated rows: "111 | 1\\n121 | 2"
      - Python list-of-lists: "[[111, 1], [121, 2]]". Scalar elements, as in
        "[111, 121]", are treated as one-column rows.
      - Numbered rows: "1. 111 | 1\\n2. 121 | 2". The row number must be
        followed by whitespace, so a decimal like "3.14" in the first cell
        is not mistaken for one.

    Every cell is passed through ``_normalize_value``. Text lines whose cells
    are all empty are skipped. Empty input yields ``[]``.
    """
    text = raw.strip()
    if not text:
        return []

    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval can raise any of these on malformed or hostile
            # input; fall back to line-based parsing.
            parsed = None
        if isinstance(parsed, list):
            return [
                tuple(_normalize_value(str(cell)) for cell in row)
                if isinstance(row, (list, tuple))
                else (_normalize_value(str(row)),)
                for row in parsed
            ]

    rows = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # Drop a leading "N. " row number. Requiring whitespace after the dot
        # keeps values such as "3.14" in the first column intact.
        line = re.sub(r"^\d+\.\s+", "", line)
        cells = tuple(
            _normalize_value(cell) for cell in re.split(r"\s*\|\s*", line)
        )
        if any(cells):
            rows.append(cells)
    return rows
|
|
|
|
def _compare_table(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Return True when both tables hold the same rows, ignoring row order.

    Rows are compared as multisets: each side is sorted, and the sorted
    lists must be equal, so a duplicated row must appear the same number of
    times on both sides. Column order within a row still matters. When
    *gold_rows* is given, its cells are normalized directly; otherwise
    *gold* is parsed like *predicted*.
    """
    predicted_sorted = sorted(_parse_table_rows(predicted))

    if gold_rows is None:
        expected_sorted = sorted(_parse_table_rows(gold))
    else:
        expected_sorted = sorted(
            tuple(_normalize_value(str(value)) for value in row)
            for row in gold_rows
        )

    return predicted_sorted == expected_sorted
|
|