| """Tests for SQLEnvTRL question-database alignment and post-episode penalty.""" |
|
|
| import pytest |
|
|
| from sql_env.training.trl_adapter import SQLEnvTRL, _POST_EPISODE_PENALTY |
|
|
|
|
def test_reset_with_question_text_loads_correct_database():
    """Check that reset(question_text=...) yields an observation for the matching DB.

    Two known training questions are used; each observation must mention the
    table names expected for that question.
    """
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**settings)
    env = SQLEnvTRL()

    # First question: the observation is expected to contain "Templates".
    obs = env.reset(question_text="How many templates do we have?")
    assert "Templates" in obs, (
        f"Expected 'Templates' table for template question, got: {obs}"
    )

    # Second question, same env instance: either table name is accepted.
    obs = env.reset(
        question_text="Find the total amount of bonus given in all the evaluations."
    )
    mentions_expected_table = "employee" in obs or "evaluation" in obs
    assert mentions_expected_table, (
        f"Expected employee/evaluation tables, got: {obs}"
    )
|
|
|
|
def test_reset_without_question_text_still_works():
    """Check that a bare reset() call still returns the usual observation text."""
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**settings)

    # No question_text is supplied, so question selection is left to the env.
    first_observation = SQLEnvTRL().reset()

    assert "Tables:" in first_observation
    assert "Use describe" in first_observation
|
|
|
|
def test_reset_with_unknown_question_falls_back():
    """Check that an unrecognised question_text still produces an observation.

    The assertion only verifies that the observation contains the "Tables:"
    marker, i.e. that reset() did not fail on the unknown question.
    """
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**settings)

    env = SQLEnvTRL()
    fallback_observation = env.reset(question_text="This question does not exist")

    assert "Tables:" in fallback_observation
|
|
|
|
def test_post_episode_penalty_on_tool_call_after_done():
    """Check that tool calls on a finished episode raise and accumulate the penalty.

    The episode is forced into the done state with a known reward; each
    subsequent tool call must raise ValueError("Episode is over") and add one
    more _POST_EPISODE_PENALTY to the reward.
    """
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**settings)

    env = SQLEnvTRL()
    env.reset()

    # Force the finished state directly instead of playing out an episode.
    env._done = True
    env.reward = 1.0

    # Calls are wrapped in lambdas so each tool is looked up and invoked only
    # inside its own pytest.raises block, in the listed order.
    late_calls = (
        lambda: env.describe("some_table"),
        lambda: env.query("SELECT 1"),
    )

    for penalties_so_far, late_call in enumerate(late_calls, start=1):
        with pytest.raises(ValueError, match="Episode is over"):
            late_call()

        # The penalty stacks: one per call made after the episode ended.
        expected_reward = 1.0 + penalties_so_far * _POST_EPISODE_PENALTY
        assert env.reward == pytest.approx(expected_reward)
|
|