"""Tests for SQLEnvTRL question-database alignment and post-episode penalty."""

import pytest

from sql_env.training.trl_adapter import SQLEnvTRL, _POST_EPISODE_PENALTY

# Shared environment configuration used by every test in this module.
_QUESTIONS_PATH = "data/questions/questions_train.json"
_DB_DIR = "data/databases"
_STEP_BUDGET = 5


def _make_env() -> SQLEnvTRL:
    """Configure SQLEnvTRL with the training questions/databases and return a new instance.

    ``_configure`` is called on the class before instantiation, mirroring how
    each test previously set it up inline.
    """
    SQLEnvTRL._configure(
        questions_path=_QUESTIONS_PATH,
        db_dir=_DB_DIR,
        step_budget=_STEP_BUDGET,
    )
    return SQLEnvTRL()


def test_reset_with_question_text_loads_correct_database():
    """Verify that passing question_text to reset() loads the matching DB."""
    env = _make_env()

    # "How many templates do we have?" belongs to cre_Doc_Template_Mgt
    obs = env.reset(question_text="How many templates do we have?")
    assert "Templates" in obs, (
        f"Expected 'Templates' table for template question, got: {obs}"
    )

    # The bonus/evaluations question is expected to map to
    # employee_hire_evaluation, so its employee/evaluation tables should appear.
    obs = env.reset(
        question_text="Find the total amount of bonus given in all the evaluations."
    )
    assert "employee" in obs or "evaluation" in obs, (
        f"Expected employee/evaluation tables, got: {obs}"
    )


def test_reset_without_question_text_still_works():
    """Verify reset() works without question_text (random question)."""
    env = _make_env()
    obs = env.reset()
    assert "Tables:" in obs
    assert "Use describe" in obs


def test_reset_with_unknown_question_falls_back():
    """Unknown question_text falls back to random selection."""
    env = _make_env()
    obs = env.reset(question_text="This question does not exist")
    assert "Tables:" in obs


def test_post_episode_penalty_on_tool_call_after_done():
    """Calling a tool after episode ends should apply penalty and raise."""
    env = _make_env()
    env.reset()

    # Force the episode into its finished state without playing it out,
    # and seed a known reward so the penalty delta is easy to check.
    env._done = True
    env.reward = 1.0

    with pytest.raises(ValueError, match="Episode is over"):
        env.describe("some_table")
    assert env.reward == pytest.approx(1.0 + _POST_EPISODE_PENALTY)

    # Second call stacks the penalty rather than replacing it.
    with pytest.raises(ValueError, match="Episode is over"):
        env.query("SELECT 1")
    assert env.reward == pytest.approx(1.0 + 2 * _POST_EPISODE_PENALTY)