sql_data_analyst / tests /test_graders.py
YashashMathur's picture
SQL Data Analyst OpenEnv - Initial commit
d103a0f verified
raw
history blame
6.69 kB
import pytest
import sqlite3
from env.tasks import (
TASKS,
MonthlySignupsTask,
TopRevenueCategoryTask,
ChurnAnalysisTask,
)
from env.database import create_database, seed_database
class TestMonthlySignupsGrader:
    """Grader checks for the easy task (``MonthlySignupsTask``)."""

    @staticmethod
    def _grade(truth, answer):
        """Grade ``answer`` on a fresh task whose ground truth is ``truth``."""
        grader = MonthlySignupsTask()
        grader.ground_truth = truth
        return grader.grade(answer)

    def test_perfect_answer(self):
        """Submitting exactly the ground truth earns 1.0."""
        assert self._grade(100, "100") == 1.0

    def test_partial_credit_within_3(self):
        """An answer off by 2 (the "within 3" tier) earns 0.6."""
        assert self._grade(100, "98") == 0.6

    def test_small_credit_within_10(self):
        """An answer off by 8 (the "within 10" tier) earns 0.3."""
        assert self._grade(100, "92") == 0.3

    def test_wrong_answer(self):
        """An answer off by 50 earns nothing."""
        assert self._grade(100, "50") == 0.0

    def test_comma_separated_number(self):
        """A thousands separator in the submission is accepted."""
        assert self._grade(1000, "1,000") == 1.0

    def test_invalid_input(self):
        """A non-numeric submission earns 0.0."""
        assert self._grade(100, "not a number") == 0.0
class TestTopRevenueCategoryGrader:
    """Grader checks for the medium task (``TopRevenueCategoryTask``)."""

    @staticmethod
    def _grade(answer):
        """Grade ``answer`` against the fixture shared by every test here.

        The fixture has "Electronics" as the ground truth and
        Electronics / Books / Clothing as the top three categories.
        """
        grader = TopRevenueCategoryTask()
        grader.ground_truth = "Electronics"
        grader.top_3_categories = ["Electronics", "Books", "Clothing"]
        return grader.grade(answer)

    def test_perfect_match(self):
        """Naming the top category exactly earns 1.0."""
        assert self._grade("Electronics") == 1.0

    def test_partial_match_top_3(self):
        """A top-3 category that is not the winner earns 0.4."""
        assert self._grade("Books") == 0.4

    def test_case_insensitive(self):
        """Letter case of the submission does not matter."""
        assert self._grade("electronics") == 1.0

    def test_llm_preamble_removed(self):
        """A leading "The answer is:" preamble does not cost credit."""
        assert self._grade("The answer is: Electronics") == 1.0

    def test_wrong_category(self):
        """A category outside the top three earns 0.0."""
        assert self._grade("Sports") == 0.0
class TestChurnAnalysisGrader:
    """Grader checks for the hard task (``ChurnAnalysisTask``)."""

    @staticmethod
    def _grade(expected_emails, submission):
        """Grade ``submission`` on a fresh task with the given ground truth."""
        grader = ChurnAnalysisTask()
        grader.ground_truth = expected_emails
        return grader.grade(submission)

    def test_perfect_match_all_emails(self):
        """Submitting every expected email earns 1.0."""
        expected = {"a@test.com", "b@test.com", "c@test.com"}
        result = self._grade(expected, "a@test.com, b@test.com, c@test.com")
        assert result == 1.0

    def test_partial_match_precision_recall(self):
        """A partially correct submission is scored by F1."""
        expected = {"a@test.com", "b@test.com", "c@test.com"}
        # Two of the three submitted emails are right and two of the three
        # expected emails are found, so precision = recall = F1 = 2/3.
        result = self._grade(
            expected, "a@test.com, b@test.com, wrong@test.com"
        )
        assert abs(result - 0.667) < 0.01

    def test_empty_submission(self):
        """An empty string earns 0.0."""
        result = self._grade({"a@test.com", "b@test.com"}, "")
        assert result == 0.0

    def test_no_valid_emails(self):
        """Text containing no email addresses earns 0.0."""
        result = self._grade(
            {"a@test.com", "b@test.com"}, "not an email, also not"
        )
        assert result == 0.0

    def test_case_insensitive(self):
        """Mixed-case ground truth matches a lower-case submission."""
        result = self._grade(
            {"A@Test.com", "B@Test.com"}, "a@test.com, b@test.com"
        )
        assert result == 1.0
class TestTaskIntegration:
    """Compute each task's ground truth against a freshly seeded database."""

    @staticmethod
    def _compute_on_seeded_db(task):
        """Run ``task.compute_ground_truth`` on a new seeded database.

        The connection is closed in a ``finally`` so it is released even
        when seeding or the ground-truth computation raises; previously
        every test left its connection open.

        NOTE(review): assumes ``create_database()`` returns a DB-API style
        connection exposing ``close()`` (the module imports ``sqlite3``) --
        confirm against ``env.database``.

        Returns:
            The same ``task`` object, after ``compute_ground_truth`` ran.
        """
        conn = create_database()
        try:
            seed_database(conn)
            task.compute_ground_truth(conn)
        finally:
            conn.close()
        return task

    def test_monthly_signups_with_real_db(self):
        """Monthly signups ground truth is a non-None ``int``."""
        task = self._compute_on_seeded_db(MonthlySignupsTask())
        assert task.ground_truth is not None
        assert isinstance(task.ground_truth, int)

    def test_top_revenue_with_real_db(self):
        """Top revenue task gets a ground truth and three top categories."""
        task = self._compute_on_seeded_db(TopRevenueCategoryTask())
        assert task.ground_truth is not None
        assert len(task.top_3_categories) == 3

    def test_churn_analysis_with_real_db(self):
        """Churn analysis ground truth is a non-None ``set``."""
        task = self._compute_on_seeded_db(ChurnAnalysisTask())
        assert task.ground_truth is not None
        assert isinstance(task.ground_truth, set)
class TestHintSystem:
    """Checks on how many hints ``get_hints`` returns at a given step."""

    def test_no_hints_early(self):
        """Step 3 yields an empty hint list."""
        early_hints = MonthlySignupsTask().get_hints(3)
        assert len(early_hints) == 0

    def test_first_hint_after_5(self):
        """Step 6 yields at least one hint; the first mentions relevant tables."""
        available = MonthlySignupsTask().get_hints(6)
        assert len(available) >= 1
        first_hint = available[0]
        assert "relevant tables" in first_hint.lower()

    def test_second_hint_after_10(self):
        """Step 11 yields at least two hints."""
        available = MonthlySignupsTask().get_hints(11)
        assert len(available) >= 2

    def test_third_hint_after_15(self):
        """Step 16 yields at least three hints; the third mentions submit_answer."""
        available = MonthlySignupsTask().get_hints(16)
        assert len(available) >= 3
        third_hint = available[2]
        assert "submit_answer" in third_hint.lower()