""" Utilities for SymPy-oriented solver output: validation and GSM8K trace cleanup. Aligned with ``src.agent.math_agent.SOLVER_SYSTEM_PROMPT`` (Step N: / Final Answer:). """ from __future__ import annotations import re from dataclasses import dataclass from typing import List, Optional from sympy.parsing.sympy_parser import parse_expr from src.sft.sympy_normalize import normalize_for_parse_expr STEP_RE = re.compile(r"^Step\s+(\d+)\s*:", re.IGNORECASE | re.MULTILINE) FINAL_RE = re.compile(r"(?im)^Final\s*Answer\s*:\s*([^\n]+?)\s*$") @dataclass class FormatCheckResult: ok: bool step_count: int has_final_line: bool final_answer_raw: str sympy_parseable_steps: int sympy_parseable_final: bool errors: List[str] def strip_gsm8k_scratchpads(text: str) -> str: """Remove GSM8K ``<<...>>`` calculator traces; collapse extra spaces.""" s = re.sub(r"<<[^>]*>>", "", text) s = re.sub(r"[ \t]+", " ", s) s = re.sub(r"\n{3,}", "\n\n", s) return s.strip() def _step_bodies(text: str) -> List[str]: """Text after each 'Step N:' up to next Step or Final Answer (best-effort).""" lines = text.splitlines() bodies: List[str] = [] cur: List[str] = [] in_step = False for line in lines: if re.match(r"^\s*Step\s+\d+\s*:", line, re.I): if cur: bodies.append("\n".join(cur).strip()) cur = [re.sub(r"^\s*Step\s+\d+\s*:\s*", "", line, flags=re.I)] in_step = True elif re.match(r"^\s*Final\s*Answer\s*:", line, re.I): if cur: bodies.append("\n".join(cur).strip()) cur = [] in_step = False break elif in_step: cur.append(line) if cur: bodies.append("\n".join(cur).strip()) return [b for b in bodies if b] def _sympy_can_parse_fragment(s: str) -> bool: s = s.strip() if not s: return False # Normalize using shared normalizer (handles ^, currency, etc.) s = normalize_for_parse_expr(s) # Take first line or expression-ish segment after last '=' chunk = s if "=" in s and "==" not in s: chunk = s.split("=")[-1].strip() chunk = chunk.split()[0] if chunk.split() else chunk try: parse_expr(chunk) return True except Exception: try: parse_expr(s[:200]) return True except Exception: return False def validate_sympy_solution_format( text: str, *, require_step_prefix: bool = True, require_final_answer: bool = True, min_steps: int = 1, ) -> FormatCheckResult: """ Check solution text for structural compliance and loose SymPy parseability. Steps: at least ``min_steps`` lines starting with ``Step N:``. Final: a line ``Final Answer: ...`` where the RHS should parse with SymPy (integers and simple rationals usually succeed). """ errors: List[str] = [] steps = STEP_RE.findall(text) step_count = len(steps) if require_step_prefix and step_count < min_steps: errors.append(f"expected at least {min_steps} Step N: line(s), found {step_count}") m_final = None for m in FINAL_RE.finditer(text): m_final = m has_final = m_final is not None final_raw = m_final.group(1).strip() if m_final else "" if require_final_answer and not has_final: errors.append("missing 'Final Answer:' line") sympy_final = False if final_raw: try: parse_expr(normalize_for_parse_expr(final_raw)) sympy_final = True except Exception: errors.append(f"final answer does not parse as SymPy expr: {final_raw!r}") bodies = _step_bodies(text) sympy_parseable_steps = len([b for b in bodies if _sympy_can_parse_fragment(b)]) ok = len(errors) == 0 return FormatCheckResult( ok=ok, step_count=step_count, has_final_line=has_final, final_answer_raw=final_raw, sympy_parseable_steps=sympy_parseable_steps, sympy_parseable_final=sympy_final, errors=errors, ) def extract_final_answer_numeric_str(text: str) -> Optional[str]: """Return substring after 'Final Answer:' if present.""" m = list(FINAL_RE.finditer(text)) if not m: return None return m[-1].group(1).strip()