YashashMathur's picture
Sync from GitHub - all files
f762b8d verified
# environment/tasks.py
# Task definitions for SQL Data Analyst environment
# 3 Tasks: Easy (single table COUNT), Medium (JOIN + aggregation), Hard (subquery/ordering)
from dataclasses import dataclass
from typing import List, Callable, Any
import random
@dataclass
class Task:
"""
Represents a data analysis task for the agent.
Attributes:
task_id: Unique identifier for the task
difficulty: easy, medium, or hard
question: The business question to answer
ground_truth: The expected correct answer
ground_truth_sql: A SQL query that produces the correct answer
description: Additional context about the task
"""
task_id: str
difficulty: str
question: str
ground_truth: Any
ground_truth_sql: str
description: str
# ============================================
# TASK DEFINITIONS
# ============================================
TASK_EASY = Task(
task_id="easy_user_count",
difficulty="easy",
question=(
"How many users are registered in the system? "
"Provide the total count as a single number."
),
ground_truth=15,
ground_truth_sql="SELECT COUNT(*) FROM users",
description="Single table COUNT query on users table"
)
TASK_MEDIUM = Task(
task_id="medium_usa_revenue",
difficulty="medium",
question=(
"What is the total revenue (sum of total_amount) from purchases made by users in the USA? "
"Provide the total as a number (rounded to 2 decimal places if needed)."
),
ground_truth=2423.87, # Sum of purchases by USA users (user_ids: 1, 4, 7, 10, 14)
ground_truth_sql="""
SELECT ROUND(SUM(p.total_amount), 2) as total_revenue
FROM purchases p
JOIN users u ON p.user_id = u.user_id
WHERE u.country = 'USA'
""",
description="Two-table JOIN with SUM aggregation filtered by country"
)
TASK_HARD = Task(
task_id="hard_top_spender",
difficulty="hard",
question=(
"Who is the top spender (user with highest total purchase amount)? "
"Provide the username of the user who spent the most money in total."
),
ground_truth="alice", # alice has purchases totaling 1509.96 (1299.99 + 59.98 + 149.99)
ground_truth_sql="""
SELECT u.username
FROM users u
JOIN purchases p ON u.user_id = p.user_id
GROUP BY u.user_id, u.username
ORDER BY SUM(p.total_amount) DESC
LIMIT 1
""",
description="Complex query with JOIN, GROUP BY, ORDER BY, and LIMIT"
)
# List of all tasks
TASKS: List[Task] = [TASK_EASY, TASK_MEDIUM, TASK_HARD]
def get_task_by_id(task_id: str) -> Task:
"""
Get a task by its ID.
Args:
task_id: The unique task identifier
Returns:
Task: The matching task
Raises:
ValueError: If task_id not found
"""
for task in TASKS:
if task.task_id == task_id:
return task
raise ValueError(f"Task not found: {task_id}")
def get_task_by_difficulty(difficulty: str) -> Task:
"""
Get a task by difficulty level.
Args:
difficulty: easy, medium, or hard
Returns:
Task: A task matching the difficulty
Raises:
ValueError: If difficulty not found
"""
for task in TASKS:
if task.difficulty == difficulty:
return task
raise ValueError(f"No task found for difficulty: {difficulty}")
def get_random_task() -> Task:
"""
Get a random task from the available tasks.
Returns:
Task: A randomly selected task
"""
return random.choice(TASKS)
def get_all_tasks() -> List[Task]:
"""
Get all available tasks.
Returns:
List[Task]: All defined tasks
"""
return TASKS.copy()