File size: 1,665 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from abc import ABC, abstractmethod
import sqlite3
import re
from typing import Any, List, Optional


class BaseTask(ABC):
    """Abstract base class for all tasks.

    Concrete tasks must implement ``compute_ground_truth`` and ``grade``.
    The class-level annotations below declare attributes that this class
    does not assign itself; ``get_hints`` reads ``relevant_tables`` and
    ``sql_hint``, so they must be set before it is called.
    """

    # Declared here but not read anywhere in this class.
    task_id: str
    difficulty: str
    max_steps: int
    question: str
    # Joined with ", " into the first progressive hint.
    relevant_tables: List[str]
    # Interpolated into the second progressive hint.
    sql_hint: str

    def __init__(self):
        # Starts as None; presumably filled in by compute_ground_truth()
        # in subclasses -- confirm against the concrete tasks.
        self.ground_truth: Any = None
        # Starts empty; not used in this class (presumably populated by
        # specific subclasses -- confirm).
        self.top_3_categories: List[str] = []

    @abstractmethod
    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
        """Compute ground truth after database seeding."""

    @abstractmethod
    def grade(self, submitted_answer: str) -> float:
        """Grade the submitted answer. Returns score 0.0-1.0."""

    def get_hints(self, step: int) -> List[str]:
        """Return progressive hints based on current step.

        Hints accumulate: past step 5 the relevant tables are listed, past
        step 10 the SQL hint is added, and past step 15 a reminder to submit
        is added.  At step 5 or below the list is empty.
        """
        hints = []
        if step > 5:
            hints.append(
                f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}"
            )
        if step > 10:
            hints.append(f"Hint: Try using {self.sql_hint}")
        if step > 15:
            hints.append("Hint: Make sure to submit your answer with submit_answer.")
        return hints

    def _normalize(self, text: str) -> str:
        """Remove common LLM formatting and normalize text.

        The text is lower-cased and trimmed, then in order:

        1. runs of ``*`` (markdown emphasis) are removed;
        2. triple-backtick fenced blocks are dropped, contents included;
        3. single-backtick inline code keeps its contents, loses the ticks;
        4. whitespace runs are collapsed to a single space;
        5. "the answer/result/category is[:]" preambles are removed.

        The preamble is removed last, on the cleaned text, so that markup
        or irregular spacing inside it (e.g. ``The **answer** is: 42``)
        cannot stop it from matching.
        """
        text = text.strip().lower()
        text = re.sub(r"\*+", "", text)
        text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
        text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
        text = re.sub(r"\s+", " ", text)
        # Must run after markup removal and whitespace collapsing; see docstring.
        text = re.sub(r"the (answer|result|category) is:?\s*", "", text)
        return text.strip()