Spaces:
Sleeping
Sleeping
| import pytest | |
| from env import SQLAnalystEnv, Action | |
class TestOpenEnvContract:
    """Test OpenEnv required methods: reset(), step(), state()"""

    def test_reset_returns_step_result(self):
        """reset() must return a StepResult with observation"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()
        assert result.observation is not None
        assert result.reward == 0.0
        assert result.done is False

    def test_reset_contains_required_fields(self):
        """Initial observation must contain all required fields"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()
        obs = result.observation
        assert obs.schema_summary is not None
        assert obs.question is not None
        assert obs.step == 0
        assert obs.max_steps == 10

    def test_step_returns_step_result(self):
        """step() must return StepResult with observation, reward, done"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(sql_query="SELECT COUNT(*) FROM users")
        result = env.step(action)
        assert result.observation is not None
        assert isinstance(result.reward, float)
        assert isinstance(result.done, bool)
        assert result.info is not None

    def test_step_increments_step_count(self):
        """Each step should increment step count"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(sql_query="SELECT COUNT(*) FROM users")
        result1 = env.step(action)
        assert result1.observation.step == 1
        result2 = env.step(action)
        assert result2.observation.step == 2

    def test_step_sql_query_execution(self):
        """step() with sql_query should execute query and return result"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(sql_query="SELECT COUNT(*) as cnt FROM users")
        result = env.step(action)
        assert result.observation.last_query is not None
        assert result.observation.last_result is not None
        assert "cnt" in result.observation.last_result.columns

    def test_step_submit_answer_terminates(self):
        """submit_answer should set done=True"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(submit_answer="100")
        result = env.step(action)
        assert result.done is True

    def test_step_max_steps_terminates(self):
        """Exceeding max_steps should terminate episode"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        # Easy task has max_steps=10
        for _ in range(10):
            result = env.step(Action(sql_query="SELECT 1"))
            if result.done:
                # Stop at the first terminal result: stepping an already
                # finished episode is not what this test is checking.
                break
        # At step 10 (last step), done should be True
        assert result.done, "Episode should terminate by step 10"

    def test_state_returns_full_state(self):
        """state() should return EnvState with all metadata"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        env.step(Action(sql_query="SELECT 1"))
        state = env.state()
        assert state.task_id == "monthly_signups"
        assert state.difficulty == "easy"
        assert state.step > 0
        assert state.max_steps == 10
        assert len(state.query_history) > 0

    def test_invalid_action_raises(self):
        """Action with both sql_query and submit_answer should error"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        # This should fail - exactly one must be set.
        # Both the construction and the step stay inside the raises block, so
        # the test passes whichever of the two performs the check.
        with pytest.raises(AssertionError):
            action = Action(sql_query="SELECT 1", submit_answer="test")
            env.step(action)

    def test_all_three_tasks_work(self):
        """All task IDs should be supported"""
        for task_id in ["monthly_signups", "top_revenue_category", "churn_analysis"]:
            env = SQLAnalystEnv(task_id=task_id)
            result = env.reset()
            # Name the task in the message so a failure inside the loop is
            # attributable without re-running.
            assert result.observation.question is not None, (
                f"reset() produced no question for task {task_id!r}"
            )
class TestEdgeCases:
    """Test edge cases and error handling"""

    def test_sql_error_returns_error_in_observation(self):
        """Invalid SQL should return error in observation"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(sql_query="SELECT * FROM nonexistent_table")
        result = env.step(action)
        assert result.observation.last_error is not None
        assert result.observation.last_error != ""

    def test_non_select_blocked(self):
        """Non-SELECT queries should be blocked"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        action = Action(sql_query="DELETE FROM users")
        result = env.step(action)
        assert result.observation.last_error is not None
        assert "Only SELECT" in result.observation.last_error

    def test_empty_db_after_reset(self):
        """Database should have data after reset"""
        # NOTE: despite the "empty_db" in the name, this asserts the users
        # table is NOT empty after reset(); the name is kept so the test ID
        # stays stable.
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        result = env.step(Action(sql_query="SELECT COUNT(*) FROM users"))
        last_result = result.observation.last_result
        # Fail with a readable message (including the reported error) rather
        # than an AttributeError/TypeError when the query produced no result.
        assert last_result is not None, (
            f"COUNT(*) query returned no result: {result.observation.last_error}"
        )
        assert last_result.rows[0][0] > 0