File size: 6,324 Bytes
5dd1bb4
 
 
 
9e64e71
5dd1bb4
 
 
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dd1bb4
 
 
 
 
 
 
 
 
 
9e64e71
 
5dd1bb4
 
 
 
 
 
 
 
 
 
9e64e71
 
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""Answer verification for SQLEnv using type-aware comparisons."""

from __future__ import annotations

import ast
import re


def _strip_answer_wrapping(text: str) -> str:
    """Remove common LLM wrapping artifacts from an answer string.

    Strips markdown code fences, surrounding quotes, "Answer: " prefix,
    and extra whitespace so the type-aware comparators see clean values.
    """
    s = text.strip()
    # Markdown code blocks: ```...``` or ```sql\n...\n```
    if s.startswith("```") and s.endswith("```"):
        # Language tag only if followed by newline (e.g. ```sql\n)
        s = re.sub(r"^```(?:\w+\n|\n?)", "", s)
        s = re.sub(r"\n?```$", "", s)
        s = s.strip()
    # "Answer: " or "answer:" prefix
    s = re.sub(r"^[Aa]nswer:\s*", "", s)
    # Surrounding quotes (single or double)
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
        s = s[1:-1].strip()
    return s


def verify_answer(
    predicted: str,
    gold: str,
    answer_type: str | None = None,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Decide whether a submitted answer matches the gold answer.

    Both answers are coerced to text (``None`` becomes ""). Only the
    prediction is passed through ``_strip_answer_wrapping``; a prediction
    that is empty afterwards never matches. The comparison routine is then
    chosen by ``answer_type``: "integer", "float", "list" and "table" have
    dedicated comparators, while "string" and any other value (including
    ``None``) use the normalized string comparison. ``gold_rows`` is
    forwarded to the list and table comparators only.
    """
    submitted = str(predicted) if predicted is not None else ""
    expected = str(gold) if gold is not None else ""

    submitted = _strip_answer_wrapping(submitted)
    if not submitted.strip():
        return False

    if answer_type == "integer":
        return _compare_integer(submitted, expected)
    if answer_type == "float":
        return _compare_float(submitted, expected)
    if answer_type == "list":
        return _compare_list(submitted, expected, gold_rows)
    if answer_type == "table":
        return _compare_table(submitted, expected, gold_rows)

    # "string" and every unrecognized type share the same comparator.
    return _compare_string(submitted, expected)


def _normalize_value(value: str) -> str:
    """Normalize strings for case-insensitive, whitespace-stable comparison."""
    text = "" if value is None else str(value)
    return " ".join(text.strip().lower().split())


def _compare_integer(predicted: str, gold: str) -> bool:
    """Compare integer values after coercing with ``int(float(x))``."""
    try:
        return int(float(predicted)) == int(float(gold))
    except (TypeError, ValueError):
        return False


def _compare_float(predicted: str, gold: str, tolerance: float = 0.01) -> bool:
    """Compare float values using a relative tolerance."""
    try:
        predicted_value = float(predicted)
        gold_value = float(gold)
    except (TypeError, ValueError):
        return False

    if gold_value == 0.0:
        return abs(predicted_value - gold_value) <= 1e-9

    return abs(predicted_value - gold_value) <= tolerance * abs(gold_value)


def _compare_string(predicted: str, gold: str) -> bool:
    """Return True when both strings are equal after ``_normalize_value``."""
    submitted = _normalize_value(predicted)
    expected = _normalize_value(gold)
    return submitted == expected


def _parse_list_values(raw: str) -> set[str]:
    """Parse comma/newline/pipe-separated values into a normalized set.

    Handles plain delimited strings and Python list representations:
      "121\\n111\\n171"  ->  {"121", "111", "171"}
      "[121, 111, 171]"  ->  {"121", "111", "171"}
      "['Feil', 'Fisher']"  ->  {"feil", "fisher"}
      "[Feil, Fisher]"  ->  {"feil", "fisher"}

    Empty items are dropped. Every remaining item is passed through
    ``_normalize_value``.
    """
    text = raw.strip()

    # Try Python literal (e.g., [121, 111] or ['Feil', 'Fisher'])
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval is documented to raise any of these on malformed
            # input (e.g. TypeError for "[{[1]: 2}]"); treat all of them as
            # "not a literal" and fall through to delimiter splitting.
            parsed = None
        if isinstance(parsed, list):
            return {
                _normalize_value(str(item)) for item in parsed if str(item).strip()
            }
        # Bracketed but not a valid literal (typically unquoted names):
        # drop the outer brackets so they do not stick to the first and
        # last tokens.
        if text.endswith("]"):
            text = text[1:-1]

    tokens = re.split(r"\s*(?:,|\n|\|)\s*", text)
    return {_normalize_value(token) for token in tokens if token.strip()}


def _compare_list(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Check two list answers for equality as unordered sets of values.

    The prediction is always parsed with ``_parse_list_values``. When
    ``gold_rows`` is given, the expected set is built from every non-blank
    cell of every row (normalized); otherwise it is parsed from ``gold``.
    Ordering and duplicates do not affect the result.
    """
    submitted = _parse_list_values(predicted)

    if gold_rows is None:
        return submitted == _parse_list_values(gold)

    # Flatten the row tuples into one set of normalized cell values.
    expected: set[str] = set()
    for row in gold_rows:
        for cell in row:
            rendered = str(cell)
            if rendered.strip():
                expected.add(_normalize_value(rendered))

    return submitted == expected


def _parse_table_rows(raw: str) -> list[tuple[str, ...]]:
    """Parse a table answer string into normalized rows.

    Supports formats:
      - Pipe-separated rows: "111 | 1\\n121 | 2"
      - Python list-of-lists: "[[111, 1], [121, 2]]"
      - Numbered rows: "1. 111 | 1\\n2. 121 | 2"

    Returns one tuple of normalized cell strings per row, in input order.
    Blank lines and rows whose cells are all empty are skipped. For a
    Python-literal input, elements that are not lists or tuples are ignored.
    """
    text = raw.strip()
    if not text:
        return []

    # Try Python literal (list-of-lists from gold_answer storage)
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval is documented to raise any of these on malformed
            # input; fall back to line-based parsing instead of crashing.
            parsed = None
        if isinstance(parsed, list):
            return [
                tuple(_normalize_value(str(cell)) for cell in row)
                for row in parsed
                if isinstance(row, (list, tuple))
            ]

    rows: list[tuple[str, ...]] = []
    for line in text.split("\n"):
        line = line.strip()
        if not line:
            continue
        # Strip leading numbering: "1. value | value". The lookahead keeps
        # a decimal first cell intact ("3.14 | x" must not become "14 | x").
        line = re.sub(r"^\d+\.(?!\d)\s*", "", line)
        cells = [_normalize_value(cell) for cell in re.split(r"\s*\|\s*", line)]
        if any(cells):
            rows.append(tuple(cells))
    return rows


def _compare_table(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Check two table answers for equality, ignoring row order.

    The prediction is parsed with ``_parse_table_rows``. The expected rows
    come from ``gold_rows`` (each cell stringified and normalized) when it
    is provided, otherwise from parsing ``gold``. Both row lists are sorted
    before comparison, so duplicate rows must occur the same number of
    times on each side.
    """
    submitted_rows = _parse_table_rows(predicted)

    if gold_rows is None:
        expected_rows = _parse_table_rows(gold)
    else:
        expected_rows = [
            tuple(_normalize_value(str(cell)) for cell in row) for row in gold_rows
        ]

    # Comparing sorted lists keeps duplicate counts: multiset equality.
    expected_sorted = sorted(expected_rows)
    submitted_sorted = sorted(submitted_rows)
    return submitted_sorted == expected_sorted