sql_data_analyst / tests /test_reward.py
YashashMathur's picture
SQL Data Analyst OpenEnv - Initial commit
d103a0f verified
raw
history blame
4.52 kB
import pytest
from env.reward import RewardCalculator
from env.models import Action, QueryResult
from env.tasks import MonthlySignupsTask
class MockTask:
    """Minimal stand-in for a task object, used to drive the reward tests.

    It carries the four attributes set in ``__init__`` plus ``grade`` and
    ``get_hints``; nothing else is provided.
    """

    def __init__(self) -> None:
        # Fixed fixture values; a new list is built for every instance.
        self.relevant_tables = ["users", "orders"]
        self.ground_truth = 100
        self.difficulty = "easy"
        self.max_steps = 10

    def grade(self, answer) -> float:
        """Return 1.0 when *answer* equals the ground truth rendered as a string, else 0.0."""
        expected = str(self.ground_truth)
        if answer == expected:
            return 1.0
        return 0.0

    def get_hints(self, step) -> list:
        """Return an empty hint list, whatever *step* is."""
        return list()
class TestRewardCalculator:
    """Exercise ``RewardCalculator.calculate`` against a ``MockTask``.

    The reward amounts quoted in the test docstrings (+0.15, +0.10, ...) are
    the values these tests were written against; the calculator itself is
    imported from ``env.reward`` and is not visible from this file.

    NOTE(review): ``calculate`` is always called positionally here; the
    arguments look like (action, result, task, step, query_history,
    is_terminal) -- confirm against ``env.reward``.
    """

    def setup_method(self):
        """Give every test its own calculator and mock task."""
        self.calc = RewardCalculator()
        self.task = MockTask()

    def test_no_error_query_reward(self):
        """Query without error gets +0.15"""
        action = Action(sql_query="SELECT 1 FROM users")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)
        reward = self.calc.calculate(action, result, self.task, 1, [], False)
        assert reward >= 0.15

    def test_relevant_table_reward(self):
        """Query touching relevant table gets +0.10"""
        action = Action(sql_query="SELECT * FROM users")
        result = QueryResult(columns=["id"], rows=[[1]], error=None)
        reward = self.calc.calculate(action, result, self.task, 1, [], False)
        assert reward >= 0.10

    def test_non_empty_result_reward(self):
        """Query with rows gets +0.05"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)
        reward = self.calc.calculate(action, result, self.task, 1, [], False)
        assert reward >= 0.05

    def test_error_query_no_reward(self):
        """Query with error gets no step rewards"""
        action = Action(sql_query="SELECT * FROM nonexistent")
        result = QueryResult(columns=[], rows=[], error="Table not found")
        reward = self.calc.calculate(action, result, self.task, 1, [], False)
        # approx() instead of == so float rounding cannot cause a spurious failure.
        assert reward == pytest.approx(0.0)

    def test_efficiency_penalty_after_step_3(self):
        """Steps beyond 3 get -0.02 per step"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)
        # Score the same query at step 1 with a separate calculator, so the
        # comparison does not rely on the calculator being stateless.
        baseline = RewardCalculator().calculate(action, result, self.task, 1, [], False)
        reward = self.calc.calculate(action, result, self.task, 5, [], False)
        # "SELECT 1" names no relevant table, so by the documented amounts it
        # earns about 0.15 + 0.05 + 0.05 = 0.25 before any penalty. The
        # absolute bound below therefore holds even with no penalty at all;
        # the baseline comparison is what actually detects the penalty.
        assert reward < baseline
        assert reward < 0.35

    def test_infinite_loop_penalty(self):
        """Same query 3 times gets -0.10"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)
        query_history = ["SELECT 1", "SELECT 1", "SELECT 1"]
        # Same step, empty history: isolates the repetition penalty from the
        # per-step efficiency penalty that also applies at step 4.
        baseline = RewardCalculator().calculate(action, result, self.task, 4, [], False)
        reward = self.calc.calculate(action, result, self.task, 4, query_history, False)
        assert reward < baseline
        assert reward < 0.30

    def test_terminal_submit_grade_reward(self):
        """Terminal submit gets up to 0.60 based on grade"""
        action = Action(submit_answer="100")
        result = None
        # Use step 1 to avoid efficiency penalty
        reward = self.calc.calculate(action, result, self.task, 1, [], True)
        # grade("100") = 1.0, scaled by 0.60 -> 0.60
        assert reward >= 0.60

    def test_terminal_submit_wrong_answer(self):
        """Wrong answer earns no grade-based terminal reward"""
        action = Action(submit_answer="999")
        result = None
        reward = self.calc.calculate(action, result, self.task, 5, [], True)
        # grade("999") = 0.0, scaled by 0.60 -> 0.0
        assert reward < 0.10

    def test_reward_clamped_to_0_1(self):
        """Reward should be clamped between 0 and 1"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)
        # Step 50 is far past the penalty threshold; the accumulated penalty
        # must not push the reward below 0.
        reward = self.calc.calculate(action, result, self.task, 50, [], False)
        assert 0.0 <= reward <= 1.0
class TestRewardBreakdown:
    """Test specific reward components"""

    def test_max_step_reward_calculation(self):
        """A clean query on a relevant table at step 1 earns the full per-step reward."""
        action = Action(sql_query="SELECT * FROM users")
        result = QueryResult(columns=["id"], rows=[[1], [2], [3]], error=None)
        calc = RewardCalculator()
        task = MockTask()
        reward = calc.calculate(action, result, task, 1, [], False)
        # 0.15 (no error) + 0.10 (relevant table) + 0.05 (has rows) + 0.05 (reasonable size)
        expected = 0.35
        # pytest.approx reports both values on failure, unlike a bare abs() check.
        assert reward == pytest.approx(expected, abs=0.01)