sql_data_analyst / tests /test_env.py
YashashMathur's picture
SQL Data Analyst OpenEnv - Initial commit
d103a0f verified
import pytest
from env import SQLAnalystEnv, Action
class TestOpenEnvContract:
    """Test OpenEnv required methods: reset(), step(), state()."""

    def test_reset_returns_step_result(self):
        """reset() must return a StepResult with observation."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()
        assert result.observation is not None
        assert result.reward == 0.0
        assert result.done is False

    def test_reset_contains_required_fields(self):
        """Initial observation must contain all required fields."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()
        obs = result.observation
        assert obs.schema_summary is not None
        assert obs.question is not None
        assert obs.step == 0
        assert obs.max_steps == 10

    def test_step_returns_step_result(self):
        """step() must return StepResult with observation, reward, done."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(sql_query="SELECT COUNT(*) FROM users")
        result = env.step(action)
        assert result.observation is not None
        assert isinstance(result.reward, float)
        assert isinstance(result.done, bool)
        assert result.info is not None

    def test_step_increments_step_count(self):
        """Each step should increment the step count by one."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(sql_query="SELECT COUNT(*) FROM users")
        result1 = env.step(action)
        assert result1.observation.step == 1
        result2 = env.step(action)
        assert result2.observation.step == 2

    def test_step_sql_query_execution(self):
        """step() with sql_query should execute the query and return a result."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(sql_query="SELECT COUNT(*) as cnt FROM users")
        result = env.step(action)
        assert result.observation.last_query is not None
        assert result.observation.last_result is not None
        # The column alias from the query must be visible in the result.
        assert "cnt" in result.observation.last_result.columns

    def test_step_submit_answer_terminates(self):
        """submit_answer should set done=True."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(submit_answer="100")
        result = env.step(action)
        assert result.done is True

    def test_step_max_steps_terminates(self):
        """Stepping up to max_steps must terminate the episode."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        # Easy task has max_steps=10
        max_steps = 10
        terminated = False
        for _ in range(max_steps):
            result = env.step(Action(sql_query="SELECT 1"))
            if result.done:
                # Stop as soon as the episode ends: stepping an already
                # finished episode is not part of what this test checks.
                terminated = True
                break
        assert terminated, "Episode should terminate by step 10"

    def test_state_returns_full_state(self):
        """state() should return EnvState with all metadata."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        env.step(Action(sql_query="SELECT 1"))
        state = env.state()
        assert state.task_id == "monthly_signups"
        assert state.difficulty == "easy"
        assert state.step > 0
        assert state.max_steps == 10
        assert len(state.query_history) > 0

    def test_invalid_action_raises(self):
        """Action with both sql_query and submit_answer should error."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        # This should fail - exactly one must be set.
        # Both the construction and the step stay inside the context
        # manager, so the test passes whichever of them raises.
        with pytest.raises(AssertionError):
            action = Action(sql_query="SELECT 1", submit_answer="test")
            env.step(action)

    def test_all_three_tasks_work(self):
        """All task IDs should be supported."""
        for task_id in ["monthly_signups", "top_revenue_category", "churn_analysis"]:
            env = SQLAnalystEnv(task_id=task_id)
            result = env.reset()
            assert result.observation.question is not None
class TestEdgeCases:
    """Edge cases and error handling around query execution."""

    def test_sql_error_returns_error_in_observation(self):
        """Invalid SQL should surface a non-empty error in the observation."""
        sandbox = SQLAnalystEnv(task_id="monthly_signups")
        sandbox.reset()
        outcome = sandbox.step(Action(sql_query="SELECT * FROM nonexistent_table"))
        error = outcome.observation.last_error
        assert error is not None
        assert error != ""

    def test_non_select_blocked(self):
        """Non-SELECT statements should be rejected with an explanatory error."""
        sandbox = SQLAnalystEnv(task_id="monthly_signups")
        sandbox.reset()
        outcome = sandbox.step(Action(sql_query="DELETE FROM users"))
        error = outcome.observation.last_error
        assert error is not None
        assert "Only SELECT" in error

    def test_empty_db_after_reset(self):
        """After reset, counting rows in users should give a positive number."""
        sandbox = SQLAnalystEnv(task_id="monthly_signups")
        sandbox.reset()
        outcome = sandbox.step(Action(sql_query="SELECT COUNT(*) FROM users"))
        # First column of the first row holds the COUNT(*) value.
        row_count = outcome.observation.last_result.rows[0][0]
        assert row_count > 0