File size: 4,524 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import pytest
from env.reward import RewardCalculator
from env.models import Action, QueryResult
from env.tasks import MonthlySignupsTask


class MockTask:
    """Lightweight task stand-in used to drive the reward calculator tests.

    It carries fixed task attributes and grades an answer by comparing it
    with the string form of the ground truth.
    """

    def __init__(self):
        # Fixed fixture values; the ground truth is the integer 100.
        self.max_steps = 10
        self.difficulty = "easy"
        self.ground_truth = 100
        self.relevant_tables = ["users", "orders"]

    def grade(self, answer):
        """Return 1.0 when *answer* equals ``str(ground_truth)``, else 0.0."""
        expected = str(self.ground_truth)
        if answer == expected:
            return 1.0
        return 0.0

    def get_hints(self, step):
        """Return an empty hint list regardless of *step*."""
        no_hints = []
        return no_hints


class TestRewardCalculator:
    """Test the reward calculation logic.

    Each test calls ``RewardCalculator.calculate`` positionally with, in
    order: the action, the query result (``None`` for submit actions), the
    task, a step number, the query history, and a flag that is ``True`` only
    in the submit tests.
    """

    def setup_method(self):
        """Create a fresh calculator and mock task before every test."""
        self.calc = RewardCalculator()
        self.task = MockTask()

    def test_no_error_query_reward(self):
        """Query without error gets +0.15"""
        action = Action(sql_query="SELECT 1 FROM users")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        reward = self.calc.calculate(action, result, self.task, 1, [], False)

        assert reward >= 0.15

    def test_relevant_table_reward(self):
        """Query touching relevant table gets +0.10"""
        action = Action(sql_query="SELECT * FROM users")
        result = QueryResult(columns=["id"], rows=[[1]], error=None)

        reward = self.calc.calculate(action, result, self.task, 1, [], False)

        assert reward >= 0.10

    def test_non_empty_result_reward(self):
        """Query with rows gets +0.05"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        reward = self.calc.calculate(action, result, self.task, 1, [], False)

        assert reward >= 0.05

    def test_error_query_no_reward(self):
        """Query with error gets no step rewards"""
        action = Action(sql_query="SELECT * FROM nonexistent")
        result = QueryResult(columns=[], rows=[], error="Table not found")

        reward = self.calc.calculate(action, result, self.task, 1, [], False)

        assert reward == 0.0

    def test_efficiency_penalty_after_step_3(self):
        """Steps beyond 3 get -0.02 per step"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        # Baseline: the identical query scored at step 1, on a separate
        # calculator so the two measurements cannot influence each other.
        baseline = RewardCalculator().calculate(
            action, result, self.task, 1, [], False
        )
        reward = self.calc.calculate(action, result, self.task, 5, [], False)

        # Step 5 is two steps past the threshold, so the expected deduction is
        # 0.02 * 2 = 0.04 relative to the baseline.
        assert reward < 0.35
        # The absolute bound can hold even when no penalty is applied ("SELECT 1"
        # names no relevant table), so also require a strict drop.
        assert reward < baseline

    def test_infinite_loop_penalty(self):
        """Same query 3 times gets -0.10"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        query_history = ["SELECT 1", "SELECT 1", "SELECT 1"]

        # Baseline: same query and step but with no repeated history, scored on
        # a separate calculator so the two measurements stay independent.
        baseline = RewardCalculator().calculate(
            action, result, self.task, 4, [], False
        )
        reward = self.calc.calculate(action, result, self.task, 4, query_history, False)

        assert reward < 0.30
        # The absolute bound alone does not prove the repetition penalty fired;
        # the repeated history must score strictly lower than the clean one.
        assert reward < baseline

    def test_terminal_submit_grade_reward(self):
        """Terminal submit gets up to 0.60 based on grade"""
        action = Action(submit_answer="100")
        result = None

        # Use step 1 to avoid efficiency penalty
        reward = self.calc.calculate(action, result, self.task, 1, [], True)

        # grade(100) = 1.0 * 0.60 = 0.60
        assert reward >= 0.60

    def test_terminal_submit_wrong_answer(self):
        """Wrong answer earns no grade-based terminal reward"""
        action = Action(submit_answer="999")
        result = None

        reward = self.calc.calculate(action, result, self.task, 5, [], True)

        # grade(999) = 0.0 * 0.60 = 0.0
        assert reward < 0.10

    def test_reward_clamped_to_0_1(self):
        """Reward should be clamped between 0 and 1"""
        # Fresh mock task (it grades 1.0 only for the answer "100").
        task = MockTask()

        # Many steps should accumulate penalty but stay >= 0
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        reward = self.calc.calculate(action, result, task, 50, [], False)

        assert 0.0 <= reward <= 1.0


class TestRewardBreakdown:
    """Checks on how the individual reward components add up."""

    def test_max_step_reward_calculation(self):
        """A clean, relevant, non-empty query at step 1 scores about 0.35."""
        # Expected sum: 0.15 (no error) + 0.10 (relevant table)
        # + 0.05 (has rows) + 0.05 (reasonable size).
        expected = 0.35
        tolerance = 0.01

        query = Action(sql_query="SELECT * FROM users")
        outcome = QueryResult(columns=["id"], rows=[[1], [2], [3]], error=None)

        reward = RewardCalculator().calculate(
            query,
            outcome,
            MockTask(),
            1,
            [],
            False,
        )

        assert abs(reward - expected) < tolerance