File size: 4,357 Bytes
ec4ae03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Utilities for SymPy-oriented solver output: validation and GSM8K trace cleanup.

Aligned with ``src.agent.math_agent.SOLVER_SYSTEM_PROMPT`` (Step N: / Final Answer:).
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import List, Optional

from sympy.parsing.sympy_parser import parse_expr

from src.sft.sympy_normalize import normalize_for_parse_expr


# Matches a "Step N:" header that begins exactly at the start of a line
# (case-insensitive, multiline); group 1 captures the step number digits.
# Leading whitespace before "Step" is NOT accepted by this pattern.
STEP_RE = re.compile(r"^Step\s+(\d+)\s*:", re.IGNORECASE | re.MULTILINE)
# Matches a "Final Answer: <value>" line (case-insensitive, multiline);
# group 1 lazily captures the value up to optional trailing whitespace.
# Callers strip group 1 before using it.
FINAL_RE = re.compile(r"(?im)^Final\s*Answer\s*:\s*([^\n]+?)\s*$")


@dataclass
class FormatCheckResult:
    """Outcome of ``validate_sympy_solution_format`` for one solution text."""

    # True when no validation errors were recorded (``errors`` is empty).
    ok: bool
    # Number of ``Step N:`` headers matched by ``STEP_RE``.
    step_count: int
    # Whether at least one ``Final Answer:`` line matched ``FINAL_RE``.
    has_final_line: bool
    # Stripped text captured from the last ``Final Answer:`` line, or "" if none.
    final_answer_raw: str
    # Count of step bodies that the loose SymPy parse check accepted.
    sympy_parseable_steps: int
    # Whether ``final_answer_raw`` parsed as a SymPy expression.
    sympy_parseable_final: bool
    # Human-readable descriptions of every failed check.
    errors: List[str]


def strip_gsm8k_scratchpads(text: str) -> str:
    """Drop GSM8K ``<<...>>`` calculator annotations and tidy whitespace.

    The substitutions run in a fixed order: annotations are removed first,
    then runs of spaces/tabs become a single space, then runs of three or
    more newlines shrink to one blank line. The result is stripped of
    leading and trailing whitespace.
    """
    # (pattern, replacement) pairs; order matters because removing an
    # annotation can leave adjacent spaces that the next pass collapses.
    passes = (
        (r"<<[^>]*>>", ""),
        (r"[ \t]+", " "),
        (r"\n{3,}", "\n\n"),
    )
    cleaned = text
    for pattern, replacement in passes:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()


def _step_bodies(text: str) -> List[str]:
    """Collect the text belonging to each ``Step N:`` section (best-effort).

    A section starts on a line whose first non-blank token is ``Step N:``
    (case-insensitive) and runs until the next such line, the first
    ``Final Answer:`` line, or the end of the text. Lines before the first
    step and everything from ``Final Answer:`` onward are ignored. Sections
    that are empty after stripping are dropped.
    """
    sections: List[str] = []
    # Lines of the section being collected; non-empty exactly while inside a
    # step, because opening a step always seeds it with one element.
    buffer: List[str] = []

    for raw in text.splitlines():
        header = re.match(r"^\s*Step\s+\d+\s*:\s*", raw, re.I)
        if header is not None:
            if buffer:
                sections.append("\n".join(buffer).strip())
            # Keep only what follows the "Step N:" prefix on the header line.
            buffer = [raw[header.end():]]
            continue
        if re.match(r"^\s*Final\s*Answer\s*:", raw, re.I):
            # Stop scanning; the pending section is flushed after the loop.
            break
        if buffer:
            buffer.append(raw)

    if buffer:
        sections.append("\n".join(buffer).strip())
    return [section for section in sections if section]


def _sympy_can_parse_fragment(s: str) -> bool:
    """Return True if a loose SymPy parse of the fragment succeeds.

    Blank input is rejected immediately. Otherwise the text is passed through
    ``normalize_for_parse_expr`` and two candidates are tried in order:

    1. the first whitespace-delimited token of the text after the last ``=``
       (or of the whole text when it has no ``=`` or contains ``==``);
    2. the first 200 characters of the normalized text.

    Any exception raised by ``parse_expr`` counts as "not parseable".
    """
    stripped = s.strip()
    if not stripped:
        return False

    # Shared normalizer; the original note says it handles ^, currency, etc.
    normalized = normalize_for_parse_expr(stripped)

    # For "lhs = rhs" text (but not "==" comparisons) look at the final rhs.
    if "=" in normalized and "==" not in normalized:
        candidate = normalized.rpartition("=")[2].strip()
    else:
        candidate = normalized
    tokens = candidate.split()
    if tokens:
        candidate = tokens[0]

    # NOTE(review): parse_expr is documented as using eval(); the text fed
    # here is presumably model output, so confirm it is trusted enough.
    for attempt in (candidate, normalized[:200]):
        try:
            parse_expr(attempt)
        except Exception:
            continue
        return True
    return False


def validate_sympy_solution_format(
    text: str,
    *,
    require_step_prefix: bool = True,
    require_final_answer: bool = True,
    min_steps: int = 1,
) -> FormatCheckResult:
    """
    Check solution text for structural compliance and loose SymPy parseability.

    Steps: at least ``min_steps`` lines starting with ``Step N:``.
    Final: a line ``Final Answer: ...`` where the RHS should parse with SymPy
    (integers and simple rationals usually succeed). When several such lines
    exist, only the last one is considered.

    Args:
        text: Full solver output to validate.
        require_step_prefix: When True, fewer than ``min_steps`` ``Step N:``
            headers is recorded as an error.
        require_final_answer: When True, a missing ``Final Answer:`` line, or
            one with nothing after the colon, is recorded as an error.
        min_steps: Minimum number of ``Step N:`` headers expected.

    Returns:
        A ``FormatCheckResult``; ``ok`` is True only when no error was
        recorded. A final answer that is present but does not parse is always
        an error, regardless of ``require_final_answer``.
    """
    errors: List[str] = []

    # STEP_RE only counts headers that start in column 0, whereas
    # _step_bodies (used below) also accepts indented headers, so
    # sympy_parseable_steps may exceed step_count for indented text.
    step_count = len(STEP_RE.findall(text))
    if require_step_prefix and step_count < min_steps:
        errors.append(f"expected at least {min_steps} Step N: line(s), found {step_count}")

    # Walk every match so m_final ends up bound to the last one.
    m_final = None
    for m_final in FINAL_RE.finditer(text):
        pass
    has_final = m_final is not None
    final_raw = m_final.group(1).strip() if m_final is not None else ""

    sympy_final = False
    if not has_final:
        if require_final_answer:
            errors.append("missing 'Final Answer:' line")
    elif not final_raw:
        # FINAL_RE's capture group can match a single whitespace character,
        # so "Final Answer: " with nothing after it yields an empty value.
        # Previously this slipped through with ok=True; a required final
        # answer that is empty is now reported.
        if require_final_answer:
            errors.append("empty 'Final Answer:' line")
    else:
        # NOTE(review): parse_expr is documented as using eval(); confirm the
        # validated text is trusted enough for that.
        try:
            parse_expr(normalize_for_parse_expr(final_raw))
            sympy_final = True
        except Exception:
            errors.append(f"final answer does not parse as SymPy expr: {final_raw!r}")

    # Step-body parseability is informational only; it never adds errors.
    sympy_parseable_steps = sum(
        1 for body in _step_bodies(text) if _sympy_can_parse_fragment(body)
    )

    return FormatCheckResult(
        ok=not errors,
        step_count=step_count,
        has_final_line=has_final,
        final_answer_raw=final_raw,
        sympy_parseable_steps=sympy_parseable_steps,
        sympy_parseable_final=sympy_final,
        errors=errors,
    )


def extract_final_answer_numeric_str(text: str) -> Optional[str]:
    """Return the stripped text after the last ``Final Answer:`` line.

    Returns ``None`` when no line matches ``FINAL_RE``. The value is returned
    as written; despite the function name, no numeric check is performed.
    """
    last_match = None
    # Exhaust the iterator so last_match is left bound to the final match.
    for last_match in FINAL_RE.finditer(text):
        pass
    if last_match is None:
        return None
    return last_match.group(1).strip()