File size: 1,936 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import sqlite3
from .base import BaseTask


class ChurnAnalysisTask(BaseTask):
    """Task 3 — Hard: Find users who placed exactly 3 orders and then churned.

    The ground truth is the set of e-mail addresses of users who have exactly
    three orders with ``status = 'completed'`` and whose most recent such
    order is more than 90 days older than the current date (as evaluated by
    SQLite at the time ``compute_ground_truth`` runs).

    Submitted answers are graded with the F1 score between the submitted
    e-mail set and the ground-truth set, compared case-insensitively.
    """

    # Task metadata. NOTE(review): presumably consumed by BaseTask / the task
    # runner — confirm against the base class.
    task_id = "churn_analysis"
    difficulty = "hard"
    max_steps = 20
    question = "Find the email addresses of users who placed exactly 3 orders and then never ordered again (churned after their 3rd purchase). Return as a comma-separated list."
    relevant_tables = ["users", "orders"]
    sql_hint = "CTE with COUNT and HAVING"

    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
        """Query the database and store the expected answer on the instance.

        Args:
            conn: Open SQLite connection exposing ``users`` (``id``,
                ``email``) and ``orders`` (``user_id``, ``status``,
                ``created_at``) tables.

        Side effects:
            Sets ``self.ground_truth`` to a set of lower-cased e-mail
            addresses. Rows whose e-mail is NULL are skipped.
        """
        # NOTE: DATE('now', ...) makes the result depend on the day the query
        # is executed; the same database can yield a different ground truth
        # on a later date.
        rows = conn.execute("""
            WITH order_counts AS (
                SELECT user_id, COUNT(*) AS total_orders,
                       MAX(created_at) AS last_order_date
                FROM orders
                WHERE status = 'completed'
                GROUP BY user_id
                HAVING COUNT(*) = 3
            ),
            churned AS (
                SELECT oc.user_id
                FROM order_counts oc
                WHERE oc.last_order_date < DATE('now', '-90 days')
            )
            SELECT u.email
            FROM users u
            JOIN churned c ON u.id = c.user_id
        """).fetchall()

        # A NULL e-mail comes back as None and has no .lower(); skip it
        # instead of crashing the whole ground-truth computation.
        self.ground_truth = {
            row[0].lower() for row in rows if row[0] is not None
        }

    def grade(self, submitted_answer: str) -> float:
        """Score a comma-separated list of e-mails against the ground truth.

        Args:
            submitted_answer: Comma-separated e-mail addresses. Tokens that
                do not contain ``"@"`` are ignored; the remaining tokens are
                stripped and lower-cased before comparison.

        Returns:
            The F1 score (harmonic mean of precision and recall) rounded to
            three decimals, or ``0.0`` when the answer is missing, blank,
            contains no e-mail-like token, or matches nothing in the ground
            truth. A blank answer scores ``0.0`` even if the ground truth is
            itself empty.
        """
        # Treat a missing / non-string answer like a blank one rather than
        # raising AttributeError on .strip().
        if not isinstance(submitted_answer, str) or not submitted_answer.strip():
            return 0.0

        submitted = {
            token.strip().lower()
            for token in submitted_answer.split(",")
            if "@" in token
        }
        if not submitted:
            return 0.0

        correct = {email.lower() for email in self.ground_truth}
        true_positives = len(submitted & correct)
        if true_positives == 0:
            return 0.0

        # From here on both sets are non-empty (they share at least one
        # element), so the divisions below cannot hit a zero denominator
        # and precision + recall is strictly positive.
        precision = true_positives / len(submitted)
        recall = true_positives / len(correct)

        f1 = 2 * precision * recall / (precision + recall)
        return round(f1, 3)