File size: 3,969 Bytes
f762b8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# environment/tasks.py
# Task definitions for SQL Data Analyst environment
# 3 tasks: Easy (single-table COUNT), Medium (JOIN + aggregation), Hard (JOIN + GROUP BY + ordering)

from dataclasses import dataclass
from typing import List, Callable, Any
import random


@dataclass
class Task:
    """A single question the SQL data-analyst agent is asked to answer.

    Attributes:
        task_id: Unique identifier for the task.
        difficulty: Difficulty tier; expected to be "easy", "medium", or "hard".
        question: The business question posed to the agent.
        ground_truth: The expected correct answer (number or string).
        ground_truth_sql: A SQL query whose result is the correct answer.
        description: Extra context describing what the task exercises.
    """

    # Identity and grading tier.
    task_id: str
    difficulty: str
    # What the agent is asked, plus the reference answer and the query behind it.
    question: str
    ground_truth: Any
    ground_truth_sql: str
    # Human-readable summary of the SQL skills involved.
    description: str


# ============================================
# TASK DEFINITIONS
# ============================================

# Easy: a single-table COUNT over ``users``.
TASK_EASY = Task(
    task_id="easy_user_count",
    difficulty="easy",
    description="Single table COUNT query on users table",
    question=(
        "How many users are registered in the system? Provide the total "
        "count as a single number."
    ),
    ground_truth_sql="SELECT COUNT(*) FROM users",
    ground_truth=15,
)

# Medium: join purchases to users, filter by country, then sum.
TASK_MEDIUM = Task(
    task_id="medium_usa_revenue",
    difficulty="medium",
    description="Two-table JOIN with SUM aggregation filtered by country",
    question=(
        "What is the total revenue (sum of total_amount) from purchases "
        "made by users in the USA? Provide the total as a number "
        "(rounded to 2 decimal places if needed)."
    ),
    ground_truth_sql="""

        SELECT ROUND(SUM(p.total_amount), 2) as total_revenue

        FROM purchases p

        JOIN users u ON p.user_id = u.user_id

        WHERE u.country = 'USA'

    """,
    # Sum of purchases by USA users (user_ids: 1, 4, 7, 10, 14)
    ground_truth=2423.87,
)

# Hard: aggregate spend per user, rank by total, and take the top row.
TASK_HARD = Task(
    task_id="hard_top_spender",
    difficulty="hard",
    question=(
        "Who is the top spender (user with highest total purchase amount)? "
        "Provide the username of the user who spent the most money in total."
    ),
    ground_truth="alice",  # alice has purchases totaling 1509.96 (1299.99 + 59.98 + 149.99)
    # The secondary ``u.user_id`` sort key makes the reference query deterministic
    # when two users tie on total spend; with only ``ORDER BY SUM(...) DESC``,
    # ``LIMIT 1`` may return either of the tied rows.
    ground_truth_sql="""

        SELECT u.username

        FROM users u

        JOIN purchases p ON u.user_id = p.user_id

        GROUP BY u.user_id, u.username

        ORDER BY SUM(p.total_amount) DESC, u.user_id ASC

        LIMIT 1

    """,
    description="Complex query with JOIN, GROUP BY, ORDER BY, and LIMIT"
)


# Registry of every defined task, ordered easy -> medium -> hard.
TASKS: List[Task] = [
    TASK_EASY,
    TASK_MEDIUM,
    TASK_HARD,
]


def get_task_by_id(task_id: str) -> Task:
    """
    Look up a task by its unique identifier.

    Args:
        task_id: Identifier compared against each ``Task.task_id``.

    Returns:
        Task: The first task in ``TASKS`` whose id equals ``task_id``.

    Raises:
        ValueError: If no task in ``TASKS`` has that id.
    """
    found = next(
        (candidate for candidate in TASKS if candidate.task_id == task_id),
        None,
    )
    if found is None:
        raise ValueError(f"Task not found: {task_id}")
    return found


def get_task_by_difficulty(difficulty: str) -> Task:
    """
    Return the first task registered at the given difficulty tier.

    Args:
        difficulty: Tier name compared against ``Task.difficulty``
            (the defined tasks use "easy", "medium", and "hard").

    Returns:
        Task: The earliest task in ``TASKS`` with a matching difficulty.

    Raises:
        ValueError: If no task in ``TASKS`` has that difficulty.
    """
    same_tier = [t for t in TASKS if t.difficulty == difficulty]
    if not same_tier:
        raise ValueError(f"No task found for difficulty: {difficulty}")
    return same_tier[0]


def get_random_task() -> Task:
    """
    Pick one task uniformly at random from ``TASKS``.

    Uses the module-level ``random`` generator, so the choice follows any
    global ``random.seed(...)`` call.

    Returns:
        Task: One element of ``TASKS``.
    """
    pool = TASKS
    picked = random.choice(pool)
    return picked


def get_all_tasks() -> List[Task]:
    """
    Return every defined task.

    The result is a new list, so callers may reorder or extend it without
    touching ``TASKS``. The ``Task`` objects inside are shared, not copied.

    Returns:
        List[Task]: A shallow copy of ``TASKS``.
    """
    return list(TASKS)