File size: 5,367 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import pytest
from env import SQLAnalystEnv, Action


class TestOpenEnvContract:
    """Test the OpenEnv contract: reset(), step() and state().

    Every test builds its own fresh ``SQLAnalystEnv`` so episodes never leak
    state between tests.
    """

    def test_reset_returns_step_result(self):
        """reset() must return a StepResult with an observation."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()

        assert result.observation is not None
        # A freshly reset episode carries no reward and is not finished yet.
        assert result.reward == 0.0
        assert result.done is False

    def test_reset_contains_required_fields(self):
        """The initial observation must contain all required fields."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()
        obs = result.observation

        assert obs.schema_summary is not None
        assert obs.question is not None
        assert obs.step == 0
        # The easy task ("monthly_signups") is configured with 10 steps.
        assert obs.max_steps == 10

    def test_step_returns_step_result(self):
        """step() must return a StepResult with observation, reward, done and info."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT COUNT(*) FROM users")
        result = env.step(action)

        assert result.observation is not None
        assert isinstance(result.reward, float)
        assert isinstance(result.done, bool)
        assert result.info is not None

    def test_step_increments_step_count(self):
        """Each step() call should advance the observation's step counter by one."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT COUNT(*) FROM users")

        result1 = env.step(action)
        assert result1.observation.step == 1

        result2 = env.step(action)
        assert result2.observation.step == 2

    def test_step_sql_query_execution(self):
        """step() with sql_query should execute the query and expose its result."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT COUNT(*) as cnt FROM users")
        result = env.step(action)

        assert result.observation.last_query is not None
        assert result.observation.last_result is not None
        # The column alias from the query must survive into the result set.
        assert "cnt" in result.observation.last_result.columns

    def test_step_submit_answer_terminates(self):
        """A submit_answer action should end the episode (done=True)."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(submit_answer="100")
        result = env.step(action)

        assert result.done is True

    def test_step_max_steps_terminates(self):
        """Reaching max_steps should terminate the episode."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        # Read the limit from the env rather than hard-coding it; the exact
        # value (10 for the easy task) is pinned by
        # test_reset_contains_required_fields.
        max_steps = env.reset().observation.max_steps

        result = None
        for _ in range(max_steps):
            result = env.step(Action(sql_query="SELECT 1"))
            if result.done:
                # Stop as soon as the episode ends: stepping an already
                # finished episode is not part of the contract under test.
                break

        assert result is not None, "max_steps must be positive"
        assert result.done is True, f"Episode should terminate by step {max_steps}"

    def test_state_returns_full_state(self):
        """state() should return an EnvState carrying the episode metadata."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        env.step(Action(sql_query="SELECT 1"))

        state = env.state()

        assert state.task_id == "monthly_signups"
        assert state.difficulty == "easy"
        assert state.step > 0
        assert state.max_steps == 10
        # The query issued above must have been recorded.
        assert len(state.query_history) > 0

    def test_invalid_action_raises(self):
        """An Action with both sql_query and submit_answer set should error."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        # Exactly one of sql_query / submit_answer must be set. Both the
        # construction and the step stay inside the block because the check
        # may live in either Action or env.step().
        # NOTE(review): this expects AssertionError, which presumably comes
        # from an `assert` in the env; such asserts are stripped under
        # `python -O`, so a dedicated exception type would be more robust.
        with pytest.raises(AssertionError):
            action = Action(sql_query="SELECT 1", submit_answer="test")
            env.step(action)

    def test_all_three_tasks_work(self):
        """All supported task IDs should reset into a usable observation."""
        for task_id in ["monthly_signups", "top_revenue_category", "churn_analysis"]:
            env = SQLAnalystEnv(task_id=task_id)
            result = env.reset()
            assert result.observation.question is not None, (
                f"task {task_id!r} produced no question"
            )


class TestEdgeCases:
    """Test edge cases and error handling of SQLAnalystEnv.step()."""

    def test_sql_error_returns_error_in_observation(self):
        """Invalid SQL should be reported in the observation rather than raised."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT * FROM nonexistent_table")
        result = env.step(action)

        # The error must be surfaced as a non-empty message.
        assert result.observation.last_error is not None
        assert result.observation.last_error != ""

    def test_non_select_blocked(self):
        """Non-SELECT queries should be blocked and must not modify the data."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="DELETE FROM users")
        result = env.step(action)

        assert result.observation.last_error is not None
        assert "Only SELECT" in result.observation.last_error

        # Blocking is only meaningful if the statement never ran: the users
        # table must still hold rows afterwards.
        check = env.step(Action(sql_query="SELECT COUNT(*) FROM users"))
        assert check.observation.last_result is not None, check.observation.last_error
        assert len(check.observation.last_result.rows) > 0
        assert check.observation.last_result.rows[0][0] > 0

    def test_empty_db_after_reset(self):
        """After reset the database must NOT be empty (users has rows)."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        result = env.step(Action(sql_query="SELECT COUNT(*) FROM users"))
        last_result = result.observation.last_result

        # Fail with a clear assertion (showing the query error, if any)
        # instead of an AttributeError/IndexError on the indexing below.
        assert last_result is not None, result.observation.last_error
        assert len(last_result.rows) > 0, "COUNT(*) returned no rows"
        assert last_result.rows[0][0] > 0