File size: 1,877 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sqlite3
from .base import BaseTask


class TopRevenueCategoryTask(BaseTask):
    """Task 2 — Medium: Find product category with most revenue in Q3.

    The ground truth is the category whose completed orders, created in the
    Q3 window below, sum to the highest ``qty * unit_price``. The three
    highest-revenue categories are also kept so that a near-miss answer can
    receive partial credit.
    """

    task_id = "top_revenue_category"
    difficulty = "medium"
    max_steps = 15
    question = (
        "Which product category generated the most revenue in Q3 (July-September)?"
    )
    relevant_tables = ["orders", "order_items", "products"]
    sql_hint = "JOIN with GROUP BY and ORDER BY"

    # Q3 2024 as a half-open interval [start, end). An exclusive upper bound
    # keeps all of September 30th in range even when ``created_at`` carries a
    # time component ('2024-09-30 14:00:00' sorts after '2024-09-30' as text,
    # so an inclusive BETWEEN on the bare date would silently drop it).
    _Q3_START = "2024-07-01"
    _Q3_END_EXCLUSIVE = "2024-10-01"

    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
        """Compute and store the expected answer from the database.

        Sets ``self.ground_truth`` to the top-revenue category (``None`` when
        no completed order falls in the window) and ``self.top_3_categories``
        to the (up to) three highest-revenue categories, best first.

        Args:
            conn: Open SQLite connection exposing the ``orders``,
                ``order_items`` and ``products`` tables.
        """
        # One query feeds both attributes, so ground_truth is always
        # top_3_categories[0]. The secondary sort key makes revenue ties
        # resolve deterministically instead of by arbitrary row order.
        ranked = conn.execute(
            """
            SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
            FROM orders o
            JOIN order_items oi ON o.id = oi.order_id
            JOIN products p ON oi.product_id = p.id
            WHERE o.created_at >= ?
              AND o.created_at < ?
              AND o.status = 'completed'
            GROUP BY p.category
            ORDER BY revenue DESC, p.category ASC
            """,
            (self._Q3_START, self._Q3_END_EXCLUSIVE),
        ).fetchall()

        self.ground_truth = ranked[0][0] if ranked else None
        self.top_3_categories = [row[0] for row in ranked[:3]]

    def grade(self, submitted_answer: str) -> float:
        """Score a submitted answer against the computed ground truth.

        Args:
            submitted_answer: Free-text answer to grade.

        Returns:
            1.0 if the answer contains the top category, 0.4 if it contains
            any of the top three categories, otherwise 0.0.
        """
        # NOTE(review): _normalize is not defined in this class (presumably
        # inherited from BaseTask); the lowercase substring checks below
        # assume it returns a lowercased string — confirm.
        answer = self._normalize(submitted_answer)

        if self.ground_truth and self.ground_truth.lower() in answer:
            return 1.0

        # Skip NULL/empty categories: None has no .lower(), and an empty
        # string is a substring of every answer.
        if any(cat and cat.lower() in answer for cat in self.top_3_categories):
            return 0.4

        return 0.0