Spaces:
Sleeping
Sleeping
| # environment/tasks.py | |
| # Task definitions for SQL Data Analyst environment | |
| # 3 Tasks: Easy (single table COUNT), Medium (JOIN + aggregation), Hard (subquery/ordering) | |
| from dataclasses import dataclass | |
| from typing import List, Callable, Any | |
| import random | |
| class Task: | |
| """ | |
| Represents a data analysis task for the agent. | |
| Attributes: | |
| task_id: Unique identifier for the task | |
| difficulty: easy, medium, or hard | |
| question: The business question to answer | |
| ground_truth: The expected correct answer | |
| ground_truth_sql: A SQL query that produces the correct answer | |
| description: Additional context about the task | |
| """ | |
| task_id: str | |
| difficulty: str | |
| question: str | |
| ground_truth: Any | |
| ground_truth_sql: str | |
| description: str | |
| # ============================================ | |
| # TASK DEFINITIONS | |
| # ============================================ | |
| TASK_EASY = Task( | |
| task_id="easy_user_count", | |
| difficulty="easy", | |
| question=( | |
| "How many users are registered in the system? " | |
| "Provide the total count as a single number." | |
| ), | |
| ground_truth=15, | |
| ground_truth_sql="SELECT COUNT(*) FROM users", | |
| description="Single table COUNT query on users table" | |
| ) | |
| TASK_MEDIUM = Task( | |
| task_id="medium_usa_revenue", | |
| difficulty="medium", | |
| question=( | |
| "What is the total revenue (sum of total_amount) from purchases made by users in the USA? " | |
| "Provide the total as a number (rounded to 2 decimal places if needed)." | |
| ), | |
| ground_truth=2423.87, # Sum of purchases by USA users (user_ids: 1, 4, 7, 10, 14) | |
| ground_truth_sql=""" | |
| SELECT ROUND(SUM(p.total_amount), 2) as total_revenue | |
| FROM purchases p | |
| JOIN users u ON p.user_id = u.user_id | |
| WHERE u.country = 'USA' | |
| """, | |
| description="Two-table JOIN with SUM aggregation filtered by country" | |
| ) | |
| TASK_HARD = Task( | |
| task_id="hard_top_spender", | |
| difficulty="hard", | |
| question=( | |
| "Who is the top spender (user with highest total purchase amount)? " | |
| "Provide the username of the user who spent the most money in total." | |
| ), | |
| ground_truth="alice", # alice has purchases totaling 1509.96 (1299.99 + 59.98 + 149.99) | |
| ground_truth_sql=""" | |
| SELECT u.username | |
| FROM users u | |
| JOIN purchases p ON u.user_id = p.user_id | |
| GROUP BY u.user_id, u.username | |
| ORDER BY SUM(p.total_amount) DESC | |
| LIMIT 1 | |
| """, | |
| description="Complex query with JOIN, GROUP BY, ORDER BY, and LIMIT" | |
| ) | |
| # List of all tasks | |
| TASKS: List[Task] = [TASK_EASY, TASK_MEDIUM, TASK_HARD] | |
| def get_task_by_id(task_id: str) -> Task: | |
| """ | |
| Get a task by its ID. | |
| Args: | |
| task_id: The unique task identifier | |
| Returns: | |
| Task: The matching task | |
| Raises: | |
| ValueError: If task_id not found | |
| """ | |
| for task in TASKS: | |
| if task.task_id == task_id: | |
| return task | |
| raise ValueError(f"Task not found: {task_id}") | |
| def get_task_by_difficulty(difficulty: str) -> Task: | |
| """ | |
| Get a task by difficulty level. | |
| Args: | |
| difficulty: easy, medium, or hard | |
| Returns: | |
| Task: A task matching the difficulty | |
| Raises: | |
| ValueError: If difficulty not found | |
| """ | |
| for task in TASKS: | |
| if task.difficulty == difficulty: | |
| return task | |
| raise ValueError(f"No task found for difficulty: {difficulty}") | |
| def get_random_task() -> Task: | |
| """ | |
| Get a random task from the available tasks. | |
| Returns: | |
| Task: A randomly selected task | |
| """ | |
| return random.choice(TASKS) | |
| def get_all_tasks() -> List[Task]: | |
| """ | |
| Get all available tasks. | |
| Returns: | |
| List[Task]: All defined tasks | |
| """ | |
| return TASKS.copy() | |