"""Contract and edge-case tests for ``SQLAnalystEnv``."""

import pytest

from env import Action, SQLAnalystEnv


class TestOpenEnvContract:
    """Test OpenEnv required methods: reset(), step(), state()."""

    def test_reset_returns_step_result(self):
        """reset() must return a result carrying an observation.

        The initial result is expected to have zero reward and to not be
        marked as done.
        """
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()

        assert result.observation is not None
        assert result.reward == 0.0
        assert result.done is False

    def test_reset_contains_required_fields(self):
        """Initial observation must contain all required fields."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()
        obs = result.observation

        assert obs.schema_summary is not None
        assert obs.question is not None
        assert obs.step == 0
        assert obs.max_steps == 10

    def test_step_returns_step_result(self):
        """step() must return a result with observation, reward, done, info."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT COUNT(*) FROM users")
        result = env.step(action)

        assert result.observation is not None
        assert isinstance(result.reward, float)
        assert isinstance(result.done, bool)
        assert result.info is not None

    def test_step_increments_step_count(self):
        """Each step should increment the step count on the observation."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT COUNT(*) FROM users")

        result1 = env.step(action)
        assert result1.observation.step == 1

        result2 = env.step(action)
        assert result2.observation.step == 2

    def test_step_sql_query_execution(self):
        """step() with sql_query should execute the query and return a result."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT COUNT(*) as cnt FROM users")
        result = env.step(action)

        assert result.observation.last_query is not None
        assert result.observation.last_result is not None
        assert "cnt" in result.observation.last_result.columns

    def test_step_submit_answer_terminates(self):
        """submit_answer should set done=True."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(submit_answer="100")
        result = env.step(action)

        assert result.done is True

    def test_step_max_steps_terminates(self):
        """Exhausting the step budget should terminate the episode.

        Issues at most 10 queries and passes as soon as any step reports
        ``done``.
        """
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        # Easy task has max_steps=10.
        # any() short-circuits on the first done=True, so the test never
        # steps an episode that has already reported termination.
        terminated = any(
            env.step(Action(sql_query="SELECT 1")).done for _ in range(10)
        )

        # At step 10 (last step), done should be True
        assert terminated, "Episode should terminate by step 10"

    def test_state_returns_full_state(self):
        """state() should return the environment state with all metadata."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()
        env.step(Action(sql_query="SELECT 1"))

        state = env.state()

        assert state.task_id == "monthly_signups"
        assert state.difficulty == "easy"
        assert state.step > 0
        assert state.max_steps == 10
        assert len(state.query_history) > 0

    def test_invalid_action_raises(self):
        """Action with both sql_query and submit_answer should error."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        # This should fail - exactly one must be set.
        # NOTE(review): it is not visible here whether building the Action
        # or stepping with it performs the check, so both stay inside the
        # context manager -- confirm against env.Action / env.step.
        with pytest.raises(AssertionError):
            action = Action(sql_query="SELECT 1", submit_answer="test")
            env.step(action)

    def test_all_three_tasks_work(self):
        """All task IDs should be supported."""
        for task_id in ["monthly_signups", "top_revenue_category", "churn_analysis"]:
            env = SQLAnalystEnv(task_id=task_id)
            result = env.reset()
            # Name the task in the message so a failure identifies which
            # iteration of the loop broke.
            assert result.observation.question is not None, (
                f"task {task_id!r} produced no question on reset"
            )


class TestEdgeCases:
    """Test edge cases and error handling."""

    def test_sql_error_returns_error_in_observation(self):
        """Invalid SQL should return an error in the observation."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT * FROM nonexistent_table")
        result = env.step(action)

        assert result.observation.last_error is not None
        assert result.observation.last_error != ""

    def test_non_select_blocked(self):
        """Non-SELECT queries should be blocked."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="DELETE FROM users")
        result = env.step(action)

        assert result.observation.last_error is not None
        assert "Only SELECT" in result.observation.last_error

    def test_empty_db_after_reset(self):
        """Database should have data after reset (users table is non-empty)."""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        result = env.step(Action(sql_query="SELECT COUNT(*) FROM users"))

        # First column of the first row holds the COUNT(*) value.
        assert result.observation.last_result.rows[0][0] > 0