Spaces:
Sleeping
Sleeping
| import pytest | |
| import sqlite3 | |
| from env.tasks import ( | |
| TASKS, | |
| MonthlySignupsTask, | |
| TopRevenueCategoryTask, | |
| ChurnAnalysisTask, | |
| ) | |
| from env.database import create_database, seed_database | |
class TestMonthlySignupsGrader:
    """Test the easy task grader.

    Every test presets ``ground_truth`` by hand and grades one submitted
    answer string, so no database is involved. Scores are compared with
    ``pytest.approx`` rather than ``==`` so the assertions do not depend on
    how the grader arrives at a fractional value such as 0.6 or 0.3.
    """

    @staticmethod
    def _score(answer, ground_truth=100):
        """Return the grade for *answer* on a task with a preset ground truth."""
        task = MonthlySignupsTask()
        task.ground_truth = ground_truth
        return task.grade(answer)

    def test_perfect_answer(self):
        """Exact match should return 1.0"""
        assert self._score("100") == pytest.approx(1.0)

    def test_partial_credit_within_3(self):
        """Answer within 3 should return 0.6"""
        # 98 is off by 2 from the ground truth of 100.
        assert self._score("98") == pytest.approx(0.6)

    def test_small_credit_within_10(self):
        """Answer within 10 should return 0.3"""
        # 92 is off by 8 from the ground truth of 100.
        assert self._score("92") == pytest.approx(0.3)

    def test_wrong_answer(self):
        """Wrong answer should return 0.0"""
        assert self._score("50") == pytest.approx(0.0)

    def test_comma_separated_number(self):
        """Numbers with commas should work"""
        assert self._score("1,000", ground_truth=1000) == pytest.approx(1.0)

    def test_invalid_input(self):
        """Invalid input should return 0.0"""
        assert self._score("not a number") == pytest.approx(0.0)
class TestTopRevenueCategoryGrader:
    """Test the medium task grader.

    All tests share one hand-set fixture: the ground truth is
    ``"Electronics"`` and the top three categories are Electronics, Books
    and Clothing. Scores are compared with ``pytest.approx`` so the
    fractional partial-credit value is not tested with exact float equality.
    """

    @staticmethod
    def _score(answer):
        """Return the grade for *answer* against the shared fixture."""
        task = TopRevenueCategoryTask()
        task.ground_truth = "Electronics"
        task.top_3_categories = ["Electronics", "Books", "Clothing"]
        return task.grade(answer)

    def test_perfect_match(self):
        """Exact category match should return 1.0"""
        assert self._score("Electronics") == pytest.approx(1.0)

    def test_partial_match_top_3(self):
        """Answer in top 3 but not first should return 0.4"""
        assert self._score("Books") == pytest.approx(0.4)

    def test_case_insensitive(self):
        """Should be case insensitive"""
        assert self._score("electronics") == pytest.approx(1.0)

    def test_llm_preamble_removed(self):
        """LLM preamble should be stripped"""
        assert self._score("The answer is: Electronics") == pytest.approx(1.0)

    def test_wrong_category(self):
        """Wrong category should return 0.0"""
        # "Sports" is not in the fixture's top-three list.
        assert self._score("Sports") == pytest.approx(0.0)
class TestChurnAnalysisGrader:
    """Grader tests for the hard task, using hand-set email ground truths."""

    @staticmethod
    def _graded(expected_emails, submission):
        """Grade *submission* on a task whose ground truth is *expected_emails*."""
        grader = ChurnAnalysisTask()
        grader.ground_truth = expected_emails
        return grader.grade(submission)

    def test_perfect_match_all_emails(self):
        """Submitting every expected email scores 1.0."""
        result = self._graded(
            {"a@test.com", "b@test.com", "c@test.com"},
            "a@test.com, b@test.com, c@test.com",
        )
        assert result == 1.0

    def test_partial_match_precision_recall(self):
        """A partly correct submission is scored close to 0.667."""
        # Two of the three submitted emails are in the ground truth, and two
        # of the three ground-truth emails were submitted, so
        # precision = recall = 2/3 and the expected F1 is about 0.667.
        result = self._graded(
            {"a@test.com", "b@test.com", "c@test.com"},
            "a@test.com, b@test.com, wrong@test.com",
        )
        assert abs(result - 0.667) < 0.01

    def test_empty_submission(self):
        """An empty string scores 0.0."""
        result = self._graded({"a@test.com", "b@test.com"}, "")
        assert result == 0.0

    def test_no_valid_emails(self):
        """Text that contains no email addresses scores 0.0."""
        result = self._graded(
            {"a@test.com", "b@test.com"},
            "not an email, also not",
        )
        assert result == 0.0

    def test_case_insensitive(self):
        """Letter case differences between truth and submission are ignored."""
        result = self._graded(
            {"A@Test.com", "B@Test.com"},
            "a@test.com, b@test.com",
        )
        assert result == 1.0
class TestTaskIntegration:
    """Test tasks with real database.

    Each test builds a database with ``create_database``, fills it with
    ``seed_database`` and lets the task compute its own ground truth. The
    connection is closed in a ``finally`` block so it is released even when
    an assertion fails.

    NOTE(review): assumes ``create_database()`` returns a connection object
    exposing ``close()`` (e.g. ``sqlite3.Connection``) - confirm against
    ``env.database``.
    """

    def test_monthly_signups_with_real_db(self):
        """Ground truth for monthly signups is computed as an ``int``."""
        conn = create_database()
        try:
            seed_database(conn)
            task = MonthlySignupsTask()
            task.compute_ground_truth(conn)
            assert task.ground_truth is not None
            assert isinstance(task.ground_truth, int)
        finally:
            conn.close()

    def test_top_revenue_with_real_db(self):
        """Ground truth is set and exactly three top categories are recorded."""
        conn = create_database()
        try:
            seed_database(conn)
            task = TopRevenueCategoryTask()
            task.compute_ground_truth(conn)
            assert task.ground_truth is not None
            assert len(task.top_3_categories) == 3
        finally:
            conn.close()

    def test_churn_analysis_with_real_db(self):
        """Ground truth for churn analysis is computed as a ``set``."""
        conn = create_database()
        try:
            seed_database(conn)
            task = ChurnAnalysisTask()
            task.compute_ground_truth(conn)
            assert task.ground_truth is not None
            assert isinstance(task.ground_truth, set)
        finally:
            conn.close()
class TestHintSystem:
    """Checks what ``get_hints`` returns for ``MonthlySignupsTask`` at several step counts."""

    def test_no_hints_early(self):
        """At step 3 the hint list is empty."""
        early_hints = MonthlySignupsTask().get_hints(3)
        assert len(early_hints) == 0

    def test_first_hint_after_5(self):
        """At step 6 there is at least one hint, and the first mentions relevant tables."""
        available = MonthlySignupsTask().get_hints(6)
        assert len(available) >= 1
        first_hint = available[0].lower()
        assert "relevant tables" in first_hint

    def test_second_hint_after_10(self):
        """At step 11 there are at least two hints."""
        available = MonthlySignupsTask().get_hints(11)
        assert len(available) >= 2

    def test_third_hint_after_15(self):
        """At step 16 there are at least three hints, and the third mentions submit_answer."""
        available = MonthlySignupsTask().get_hints(16)
        assert len(available) >= 3
        third_hint = available[2].lower()
        assert "submit_answer" in third_hint