Spaces:
Sleeping
Sleeping
File size: 3,969 Bytes
f762b8d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | # environment/tasks.py
# Task definitions for SQL Data Analyst environment
# 3 Tasks: Easy (single table COUNT), Medium (JOIN + aggregation), Hard (subquery/ordering)
from dataclasses import dataclass
from typing import List, Callable, Any
import random
@dataclass
class Task:
    """A single data analysis task presented to the agent.

    Each task pairs a business question with the answer that is expected
    for it and a reference SQL query that yields that answer.
    """

    # Unique identifier for the task.
    task_id: str
    # Difficulty label: "easy", "medium", or "hard".
    difficulty: str
    # The business question the agent has to answer.
    question: str
    # The expected correct answer (type varies per task).
    ground_truth: Any
    # A SQL query that produces the correct answer.
    ground_truth_sql: str
    # Additional context about the task.
    description: str
# ============================================
# TASK DEFINITIONS
# ============================================
# Easy task: count the rows of a single table.
TASK_EASY = Task(
    task_id="easy_user_count",
    difficulty="easy",
    description="Single table COUNT query on users table",
    question=(
        "How many users are registered in the system? "
        "Provide the total count "
        "as a single number."
    ),
    ground_truth_sql="SELECT COUNT(*) FROM users",
    ground_truth=15,
)
# Medium task: join purchases to users and aggregate revenue for one country.
TASK_MEDIUM = Task(
    task_id="medium_usa_revenue",
    difficulty="medium",
    description="Two-table JOIN with SUM aggregation filtered by country",
    question=(
        "What is the total revenue (sum of total_amount) "
        "from purchases made by users in the USA? "
        "Provide the total as a number "
        "(rounded to 2 decimal places if needed)."
    ),
    ground_truth_sql="""
SELECT ROUND(SUM(p.total_amount), 2) as total_revenue
FROM purchases p
JOIN users u ON p.user_id = u.user_id
WHERE u.country = 'USA'
""",
    # Sum of purchases by USA users (user_ids: 1, 4, 7, 10, 14).
    ground_truth=2423.87,
)
# Hard task: rank users by their summed purchase amount and take the top one.
TASK_HARD = Task(
    task_id="hard_top_spender",
    difficulty="hard",
    description="Complex query with JOIN, GROUP BY, ORDER BY, and LIMIT",
    question=(
        "Who is the top spender "
        "(user with highest total purchase amount)? "
        "Provide the username of the user "
        "who spent the most money in total."
    ),
    ground_truth_sql="""
SELECT u.username
FROM users u
JOIN purchases p ON u.user_id = p.user_id
GROUP BY u.user_id, u.username
ORDER BY SUM(p.total_amount) DESC
LIMIT 1
""",
    # alice has purchases totaling 1509.96 (1299.99 + 59.98 + 149.99).
    ground_truth="alice",
)
# Every defined task, ordered easy -> medium -> hard.
TASKS: List[Task] = [
    TASK_EASY,
    TASK_MEDIUM,
    TASK_HARD,
]
def get_task_by_id(task_id: str) -> Task:
    """
    Look up a task by its unique identifier.

    Args:
        task_id: The unique task identifier

    Returns:
        Task: The first task in TASKS whose task_id equals the argument

    Raises:
        ValueError: If no task carries that task_id
    """
    # next() stops at the first hit; None marks "nothing matched".
    found = next(
        (candidate for candidate in TASKS if candidate.task_id == task_id),
        None,
    )
    if found is None:
        raise ValueError(f"Task not found: {task_id}")
    return found
def get_task_by_difficulty(difficulty: str) -> Task:
    """
    Look up a task by its difficulty level.

    Args:
        difficulty: easy, medium, or hard

    Returns:
        Task: The first task in TASKS whose difficulty equals the argument

    Raises:
        ValueError: If no task has that difficulty
    """
    matching = [entry for entry in TASKS if entry.difficulty == difficulty]
    if not matching:
        raise ValueError(f"No task found for difficulty: {difficulty}")
    # TASKS order is preserved, so index 0 is the earliest match.
    return matching[0]
def get_random_task() -> Task:
    """
    Pick one of the available tasks at random.

    Returns:
        Task: An element of TASKS selected via random.choice
    """
    picked = random.choice(TASKS)
    return picked
def get_all_tasks() -> List[Task]:
    """
    Return every available task.

    Returns:
        List[Task]: A new list holding the same Task objects as TASKS, so
        adding or removing entries on the result leaves TASKS untouched
    """
    # Shallow copy: the list is new, the Task instances are shared.
    return list(TASKS)
|