"""Answer verification for SQLEnv using type-aware comparisons.""" from __future__ import annotations import ast import re def _strip_answer_wrapping(text: str) -> str: """Remove common LLM wrapping artifacts from an answer string. Strips markdown code fences, surrounding quotes, "Answer: " prefix, and extra whitespace so the type-aware comparators see clean values. """ s = text.strip() # Markdown code blocks: ```...``` or ```sql\n...\n``` if s.startswith("```") and s.endswith("```"): # Language tag only if followed by newline (e.g. ```sql\n) s = re.sub(r"^```(?:\w+\n|\n?)", "", s) s = re.sub(r"\n?```$", "", s) s = s.strip() # "Answer: " or "answer:" prefix s = re.sub(r"^[Aa]nswer:\s*", "", s) # Surrounding quotes (single or double) if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"): s = s[1:-1].strip() return s def verify_answer( predicted: str, gold: str, answer_type: str | None = None, gold_rows: list[tuple] | None = None, ) -> bool: """Compare submitted and gold answers with type-aware dispatch.""" predicted_text = "" if predicted is None else str(predicted) gold_text = "" if gold is None else str(gold) predicted_text = _strip_answer_wrapping(predicted_text) if not predicted_text.strip(): return False match answer_type: case "integer": return _compare_integer(predicted_text, gold_text) case "float": return _compare_float(predicted_text, gold_text) case "list": return _compare_list(predicted_text, gold_text, gold_rows) case "table": return _compare_table(predicted_text, gold_text, gold_rows) case "string": return _compare_string(predicted_text, gold_text) case _: return _compare_string(predicted_text, gold_text) def _normalize_value(value: str) -> str: """Normalize strings for case-insensitive, whitespace-stable comparison.""" text = "" if value is None else str(value) return " ".join(text.strip().lower().split()) def _compare_integer(predicted: str, gold: str) -> bool: """Compare integer values after coercing with ``int(float(x))``.""" try: return int(float(predicted)) == int(float(gold)) except (TypeError, ValueError): return False def _compare_float(predicted: str, gold: str, tolerance: float = 0.01) -> bool: """Compare float values using a relative tolerance.""" try: predicted_value = float(predicted) gold_value = float(gold) except (TypeError, ValueError): return False if gold_value == 0.0: return abs(predicted_value - gold_value) <= 1e-9 return abs(predicted_value - gold_value) <= tolerance * abs(gold_value) def _compare_string(predicted: str, gold: str) -> bool: """Compare two strings with normalization.""" return _normalize_value(predicted) == _normalize_value(gold) def _parse_list_values(raw: str) -> set[str]: """Parse comma/newline/pipe-separated values into a normalized set. Handles plain delimited strings and Python list representations: "121\\n111\\n171" -> {"121", "111", "171"} "[121, 111, 171]" -> {"121", "111", "171"} "['Feil', 'Fisher']" -> {"feil", "fisher"} """ text = raw.strip() # Try Python literal (e.g., [121, 111] or ['Feil', 'Fisher']) if text.startswith("["): try: parsed = ast.literal_eval(text) if isinstance(parsed, list): return { _normalize_value(str(item)) for item in parsed if str(item).strip() } except (ValueError, SyntaxError): pass tokens = re.split(r"\s*(?:,|\n|\|)\s*", text) normalized = {_normalize_value(token) for token in tokens if token.strip()} return normalized def _compare_list( predicted: str, gold: str, gold_rows: list[tuple] | None = None, ) -> bool: """Compare list-like answers as order-insensitive sets.""" predicted_set = _parse_list_values(predicted) if gold_rows is not None: gold_set = { _normalize_value(str(cell)) for row in gold_rows for cell in row if str(cell).strip() } else: gold_set = _parse_list_values(gold) return predicted_set == gold_set def _parse_table_rows(raw: str) -> list[tuple[str, ...]]: """Parse a table answer string into normalized rows. Supports formats: - Pipe-separated rows: "111 | 1\\n121 | 2" - Python list-of-lists: "[[111, 1], [121, 2]]" - Numbered rows: "1. 111 | 1\\n2. 121 | 2" """ text = raw.strip() if not text: return [] # Try Python literal (list-of-lists from gold_answer storage) if text.startswith("["): try: parsed = ast.literal_eval(text) if isinstance(parsed, list): return [ tuple(_normalize_value(str(cell)) for cell in row) for row in parsed if isinstance(row, (list, tuple)) ] except (ValueError, SyntaxError): pass rows = [] for line in text.split("\n"): line = line.strip() if not line: continue # Strip leading numbering: "1. value | value" line = re.sub(r"^\d+\.\s*", "", line) cells = [_normalize_value(cell) for cell in re.split(r"\s*\|\s*", line)] if any(c for c in cells): rows.append(tuple(cells)) return rows def _compare_table( predicted: str, gold: str, gold_rows: list[tuple] | None = None, ) -> bool: """Compare table answers row-by-row with cell-level normalization. Order-insensitive: rows are compared as multisets (sorted). """ pred_rows = _parse_table_rows(predicted) if gold_rows is not None: gold_normalized = sorted( tuple(_normalize_value(str(cell)) for cell in row) for row in gold_rows ) else: gold_normalized = sorted(_parse_table_rows(gold)) # Sorted comparison preserves duplicate counts, acting as multiset equality pred_normalized = sorted(pred_rows) return pred_normalized == gold_normalized