"""Tests for the task graders in ``env.tasks``, their ground-truth computation
against a seeded database, and the progressive hint system."""

import sqlite3  # noqa: F401  (not referenced below; kept for compatibility)

import pytest

from env.database import create_database, seed_database
from env.tasks import (  # noqa: F401  (TASKS is not referenced below)
    TASKS,
    MonthlySignupsTask,
    TopRevenueCategoryTask,
    ChurnAnalysisTask,
)


def _signups_task(ground_truth=100):
    """Return a MonthlySignupsTask with a fixed ground truth (no DB needed)."""
    task = MonthlySignupsTask()
    task.ground_truth = ground_truth
    return task


def _revenue_task():
    """Return a TopRevenueCategoryTask whose top category is Electronics."""
    task = TopRevenueCategoryTask()
    task.ground_truth = "Electronics"
    task.top_3_categories = ["Electronics", "Books", "Clothing"]
    return task


def _churn_task(*emails):
    """Return a ChurnAnalysisTask whose ground truth is the given emails."""
    task = ChurnAnalysisTask()
    task.ground_truth = set(emails)
    return task


class TestMonthlySignupsGrader:
    """Test the easy task grader"""

    def test_perfect_answer(self):
        """Exact match should return 1.0"""
        assert _signups_task(100).grade("100") == 1.0

    def test_partial_credit_within_3(self):
        """Answer within 3 should return 0.6"""
        # Scores are floats; compare approximately rather than with ==.
        assert _signups_task(100).grade("98") == pytest.approx(0.6)

    def test_small_credit_within_10(self):
        """Answer within 10 should return 0.3"""
        assert _signups_task(100).grade("92") == pytest.approx(0.3)

    def test_wrong_answer(self):
        """Wrong answer should return 0.0"""
        assert _signups_task(100).grade("50") == 0.0

    def test_comma_separated_number(self):
        """Numbers with commas should work"""
        assert _signups_task(1000).grade("1,000") == 1.0

    def test_invalid_input(self):
        """Invalid input should return 0.0"""
        assert _signups_task(100).grade("not a number") == 0.0


class TestTopRevenueCategoryGrader:
    """Test the medium task grader"""

    def test_perfect_match(self):
        """Exact category match should return 1.0"""
        assert _revenue_task().grade("Electronics") == 1.0

    def test_partial_match_top_3(self):
        """Answer in top 3 but not first should return 0.4"""
        assert _revenue_task().grade("Books") == pytest.approx(0.4)

    def test_case_insensitive(self):
        """Should be case insensitive"""
        assert _revenue_task().grade("electronics") == 1.0

    def test_llm_preamble_removed(self):
        """LLM preamble should be stripped"""
        assert _revenue_task().grade("The answer is: Electronics") == 1.0

    def test_wrong_category(self):
        """Wrong category should return 0.0"""
        assert _revenue_task().grade("Sports") == 0.0


class TestChurnAnalysisGrader:
    """Test the hard task grader"""

    def test_perfect_match_all_emails(self):
        """All correct emails should return 1.0"""
        task = _churn_task("a@test.com", "b@test.com", "c@test.com")
        assert task.grade("a@test.com, b@test.com, c@test.com") == 1.0

    def test_partial_match_precision_recall(self):
        """Partial match should use F1 score"""
        task = _churn_task("a@test.com", "b@test.com", "c@test.com")
        # 2 correct out of 3 submitted, 3 total correct:
        # precision = 2/3, recall = 2/3, so F1 = 2/3.
        score = task.grade("a@test.com, b@test.com, wrong@test.com")
        assert score == pytest.approx(2 / 3, abs=0.01)

    def test_empty_submission(self):
        """Empty submission should return 0.0"""
        task = _churn_task("a@test.com", "b@test.com")
        assert task.grade("") == 0.0

    def test_no_valid_emails(self):
        """No valid emails should return 0.0"""
        task = _churn_task("a@test.com", "b@test.com")
        assert task.grade("not an email, also not") == 0.0

    def test_case_insensitive(self):
        """Should be case insensitive"""
        task = _churn_task("A@Test.com", "B@Test.com")
        assert task.grade("a@test.com, b@test.com") == 1.0


class TestTaskIntegration:
    """Test tasks with real database"""

    @pytest.fixture
    def seeded_conn(self):
        """Yield a freshly created and seeded database, closing it on teardown.

        Assumes create_database() returns an object with close() (presumably
        a sqlite3.Connection) -- confirm against env.database.
        """
        conn = create_database()
        try:
            seed_database(conn)
            yield conn
        finally:
            # Close explicitly: leaving connections open leaks handles, and
            # Python 3.13+ warns on unclosed sqlite3 connections.
            conn.close()

    def test_monthly_signups_with_real_db(self, seeded_conn):
        """Test with seeded database"""
        task = MonthlySignupsTask()
        task.compute_ground_truth(seeded_conn)
        assert task.ground_truth is not None
        assert isinstance(task.ground_truth, int)

    def test_top_revenue_with_real_db(self, seeded_conn):
        """Test with seeded database"""
        task = TopRevenueCategoryTask()
        task.compute_ground_truth(seeded_conn)
        assert task.ground_truth is not None
        assert len(task.top_3_categories) == 3

    def test_churn_analysis_with_real_db(self, seeded_conn):
        """Test with seeded database"""
        task = ChurnAnalysisTask()
        task.compute_ground_truth(seeded_conn)
        assert task.ground_truth is not None
        assert isinstance(task.ground_truth, set)


class TestHintSystem:
    """Test progressive hints"""

    def test_no_hints_early(self):
        """No hints at step 5 or less"""
        hints = MonthlySignupsTask().get_hints(3)
        assert len(hints) == 0

    def test_first_hint_after_5(self):
        """First hint after step 5"""
        hints = MonthlySignupsTask().get_hints(6)
        assert len(hints) >= 1
        assert "relevant tables" in hints[0].lower()

    def test_second_hint_after_10(self):
        """Second hint after step 10"""
        hints = MonthlySignupsTask().get_hints(11)
        assert len(hints) >= 2

    def test_third_hint_after_15(self):
        """Third hint after step 15"""
        hints = MonthlySignupsTask().get_hints(16)
        assert len(hints) >= 3
        assert "submit_answer" in hints[2].lower()