# Source: sql_env/tests/test_trl_adapter.py
# Uploaded by hjerpe ("Upload folder using huggingface_hub"), commit 9e64e71 (verified).
"""Tests for SQLEnvTRL question-database alignment and post-episode penalty."""
import pytest
from sql_env.training.trl_adapter import SQLEnvTRL, _POST_EPISODE_PENALTY
def test_reset_with_question_text_loads_correct_database():
    """Check that reset(question_text=...) surfaces the matching database.

    Two known questions are reset in turn on the same environment, and the
    returned observation text is checked for table names from the database
    each question is expected to map to.
    """
    config = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**config)
    env = SQLEnvTRL()

    # Template-counting question; expected to map to cre_Doc_Template_Mgt.
    template_question = "How many templates do we have?"
    obs = env.reset(question_text=template_question)
    assert "Templates" in obs, (
        f"Expected 'Templates' table for template question, got: {obs}"
    )

    # Bonus-total question; expected to map to employee_hire_evaluation.
    bonus_question = "Find the total amount of bonus given in all the evaluations."
    obs = env.reset(question_text=bonus_question)
    mentions_expected_table = "employee" in obs or "evaluation" in obs
    assert mentions_expected_table, (
        f"Expected employee/evaluation tables, got: {obs}"
    )
def test_reset_without_question_text_still_works():
    """Check that a bare reset() still yields the standard observation text.

    No question_text is passed, so the environment is expected to choose
    the question itself; only the fixed markers of the observation are
    checked.
    """
    config = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**config)

    observation = SQLEnvTRL().reset()

    # Checked in this order so the first missing marker is the one reported.
    for marker in ("Tables:", "Use describe"):
        assert marker in observation
def test_reset_with_unknown_question_falls_back():
    """Check that an unrecognised question_text still produces an observation.

    The text passed to reset() is not expected to match any known question;
    the test only requires that the returned observation contains the
    table listing marker rather than failing.
    """
    config = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**config)

    adapter = SQLEnvTRL()
    unknown_question = "This question does not exist"
    observation = adapter.reset(question_text=unknown_question)

    assert "Tables:" in observation
def test_post_episode_penalty_on_tool_call_after_done():
    """Check that tool calls after the episode ends raise and cost reward.

    The environment is forced into the finished state with a reward of 1.0.
    Each later tool call must raise ValueError matching "Episode is over"
    and add one more _POST_EPISODE_PENALTY to the reward, so the penalty
    accumulates across calls.
    """
    config = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**config)
    env = SQLEnvTRL()
    env.reset()

    # Mark the episode finished directly instead of playing it to the end.
    env._done = True
    env.reward = 1.0

    # Tool calls in the order they are attempted; the n-th call should leave
    # n stacked penalties on top of the starting reward.
    post_episode_calls = (
        lambda: env.describe("some_table"),
        lambda: env.query("SELECT 1"),
    )
    for penalties_applied, tool_call in enumerate(post_episode_calls, start=1):
        with pytest.raises(ValueError, match="Episode is over"):
            tool_call()
        expected_reward = 1.0 + penalties_applied * _POST_EPISODE_PENALTY
        assert env.reward == pytest.approx(expected_reward)