Spaces:
Sleeping
Sleeping
File size: 5,367 Bytes
d103a0f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | import pytest
from env import SQLAnalystEnv, Action
class TestOpenEnvContract:
    """Test the OpenEnv contract methods: reset(), step(), and state()."""

    @staticmethod
    def _fresh_env(task_id="monthly_signups"):
        """Return a SQLAnalystEnv for *task_id* on which reset() has been called."""
        env = SQLAnalystEnv(task_id=task_id)
        env.reset()
        return env

    def test_reset_returns_step_result(self):
        """reset() must return a StepResult with an observation, zero reward, and done=False."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()
        assert result.observation is not None
        assert result.reward == 0.0
        assert result.done is False

    def test_reset_contains_required_fields(self):
        """The initial observation must expose schema, question, and step counters."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        obs = env.reset().observation
        assert obs.schema_summary is not None
        assert obs.question is not None
        assert obs.step == 0
        assert obs.max_steps == 10

    def test_step_returns_step_result(self):
        """step() must return a StepResult with observation, float reward, bool done, and info."""
        env = self._fresh_env()
        result = env.step(Action(sql_query="SELECT COUNT(*) FROM users"))
        assert result.observation is not None
        assert isinstance(result.reward, float)
        assert isinstance(result.done, bool)
        assert result.info is not None

    def test_step_increments_step_count(self):
        """Each step() call should advance observation.step by one."""
        env = self._fresh_env()
        action = Action(sql_query="SELECT COUNT(*) FROM users")
        result1 = env.step(action)
        assert result1.observation.step == 1
        result2 = env.step(action)
        assert result2.observation.step == 2

    def test_step_sql_query_execution(self):
        """A sql_query action should be executed and its result exposed in the observation."""
        env = self._fresh_env()
        result = env.step(Action(sql_query="SELECT COUNT(*) as cnt FROM users"))
        obs = result.observation
        assert obs.last_query is not None
        assert obs.last_result is not None
        assert "cnt" in obs.last_result.columns

    def test_step_submit_answer_terminates(self):
        """A submit_answer action should end the episode (done=True)."""
        env = self._fresh_env()
        result = env.step(Action(submit_answer="100"))
        assert result.done is True

    def test_step_max_steps_terminates(self):
        """The episode must end when max_steps is reached, and not before."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        # Easy task has max_steps=10 (also pinned by test_reset_contains_required_fields).
        max_steps = env.reset().observation.max_steps
        first_done_at = None
        for step_no in range(1, max_steps + 1):
            result = env.step(Action(sql_query="SELECT 1"))
            if result.done:
                # Stop at the first terminal step; what step() does on a finished
                # episode is not what this test is checking.
                first_done_at = step_no
                break
        # A plain "done at least once" check would also pass if the episode ended
        # prematurely, so pin the exact step at which termination happens.
        assert first_done_at == max_steps, (
            f"Episode should terminate at step {max_steps}, got {first_done_at}"
        )

    def test_state_returns_full_state(self):
        """state() should return an EnvState reflecting task metadata and progress."""
        env = self._fresh_env()
        env.step(Action(sql_query="SELECT 1"))
        state = env.state()
        assert state.task_id == "monthly_signups"
        assert state.difficulty == "easy"
        assert state.step > 0
        assert state.max_steps == 10
        assert len(state.query_history) > 0

    def test_invalid_action_raises(self):
        """An Action with both sql_query and submit_answer set must be rejected."""
        env = self._fresh_env()
        # Exactly one of sql_query / submit_answer may be set. The rejection may come
        # from Action construction or from step(), so both happen inside the block.
        # NOTE(review): presumably this AssertionError comes from an `assert` in env,
        # which `python -O` strips -- consider a dedicated exception type there.
        with pytest.raises(AssertionError):
            env.step(Action(sql_query="SELECT 1", submit_answer="test"))

    @pytest.mark.parametrize(
        "task_id", ["monthly_signups", "top_revenue_category", "churn_analysis"]
    )
    def test_all_three_tasks_work(self, task_id):
        """Every supported task ID should reset to an observation with a question."""
        # Parametrized so each task is reported separately and one failure
        # does not mask the remaining tasks.
        env = SQLAnalystEnv(task_id=task_id)
        result = env.reset()
        assert result.observation.question is not None
class TestEdgeCases:
    """Test edge cases and error handling."""

    @staticmethod
    def _count_users(env):
        """Run ``SELECT COUNT(*) FROM users`` through *env* and return the count.

        Fails with the observation's error message, rather than an opaque
        AttributeError, when the query produced no result.
        """
        result = env.step(Action(sql_query="SELECT COUNT(*) FROM users"))
        last_result = result.observation.last_result
        assert last_result is not None, result.observation.last_error
        return last_result.rows[0][0]

    def test_sql_error_returns_error_in_observation(self):
        """Invalid SQL should surface a non-empty error in the observation."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        result = env.step(Action(sql_query="SELECT * FROM nonexistent_table"))
        error = result.observation.last_error
        assert error is not None
        assert error != ""

    def test_non_select_blocked(self):
        """Non-SELECT queries should be rejected and must not modify the data."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        before = self._count_users(env)
        result = env.step(Action(sql_query="DELETE FROM users"))
        assert result.observation.last_error is not None
        assert "Only SELECT" in result.observation.last_error
        # An error message alone does not prove the statement never ran.
        assert self._count_users(env) == before

    def test_empty_db_after_reset(self):
        """The database should contain user rows after reset()."""
        # Despite its name, this test asserts the DB is NOT empty; the name is kept
        # so existing test selections (e.g. `pytest -k`) keep working.
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        assert self._count_users(env) > 0
|