sql-arena / tests /test_env.py
rahul2124's picture
Upload folder using huggingface_hub
ac49ad8 verified
"""Tests for SQL Arena environment."""
import pytest
from src.sql_arena.environment import SQLArenaEnvironment
from src.sql_arena.models import SQLArenaAction
from src.sql_arena.tasks import list_tasks, TASK_BY_ID
class TestEnvironmentBasics:
def setup_method(self):
self.env = SQLArenaEnvironment()
def teardown_method(self):
self.env.close()
def test_reset_returns_observation(self):
result = self.env.reset(difficulty="basic_select", task_id="easy_001")
assert result.observation is not None
assert result.reward == 0.0
assert result.done is False
def test_step_with_correct_query(self):
self.env.reset(difficulty="basic_select", task_id="easy_001")
task = self.env.current_task
action = SQLArenaAction(sql_query=task.expected_sql)
result = self.env.step(action)
assert result.reward > 0.0
assert result.info.get("score", 0) >= 0.8
def test_step_with_invalid_query(self):
self.env.reset(difficulty="basic_select", task_id="easy_001")
action = SQLArenaAction(sql_query="INVALID SQL QUERY")
result = self.env.step(action)
assert result.reward == 0.01 # Clamped to strictly > 0
assert result.observation.error_message is not None
def test_state_tracking(self):
self.env.reset(difficulty="basic_select", task_id="easy_001")
state = self.env.state()
assert state.current_step == 0
self.env.step(SQLArenaAction(sql_query="SELECT 1"))
state = self.env.state()
assert state.current_step == 1
def test_episode_terminates(self):
self.env.reset(difficulty="basic_select", task_id="easy_001")
task = self.env.current_task
for _ in range(task.max_steps + 1):
if self.env.state().done:
break
self.env.step(SQLArenaAction(sql_query="SELECT 1"))
assert self.env.state().done is True
class TestAllDifficulties:
def setup_method(self):
self.env = SQLArenaEnvironment()
def teardown_method(self):
self.env.close()
def test_easy(self):
result = self.env.reset(difficulty="basic_select")
assert result.observation.difficulty == "basic_select"
def test_medium(self):
result = self.env.reset(difficulty="join_aggregate")
assert result.observation.difficulty == "join_aggregate"
def test_hard(self):
result = self.env.reset(difficulty="complex_analysis")
assert result.observation.difficulty == "complex_analysis"
class TestGrading:
def setup_method(self):
self.env = SQLArenaEnvironment()
def teardown_method(self):
self.env.close()
def test_scores_in_range(self):
for task_id, task in TASK_BY_ID.items():
self.env.reset(difficulty=task.difficulty, task_id=task_id)
action = SQLArenaAction(sql_query=task.expected_sql)
result = self.env.step(action)
assert 0.0 <= result.reward <= 1.0
assert 0.0 <= result.info.get("score", 0) <= 1.0
self.env.reset(difficulty=task.difficulty, task_id=task_id)
def test_varying_scores(self):
scores = set()
queries = [
"SELECT name, salary FROM employees WHERE is_active = 1 AND salary > 80000 ORDER BY salary DESC",
"SELECT * FROM employees",
"INVALID",
"SELECT name FROM employees",
]
for q in queries:
self.env.reset(difficulty="basic_select", task_id="easy_001")
result = self.env.step(SQLArenaAction(sql_query=q))
scores.add(round(result.info.get("score", 0), 2))
assert len(scores) > 1, "Grader always returns the same score!"
class TestTaskRegistry:
def test_list_tasks(self):
tasks = list_tasks()
assert "basic_select" in tasks
assert "join_aggregate" in tasks
assert "complex_analysis" in tasks
def test_minimum_3_tasks(self):
tasks = list_tasks()
for difficulty, task_ids in tasks.items():
assert len(task_ids) >= 3, f"{difficulty} has fewer than 3 tasks"
if __name__ == "__main__":
pytest.main([__file__, "-v"])