"""Unit tests for the TRL adapter shell.""" from __future__ import annotations from pathlib import Path from types import SimpleNamespace import pytest from sql_env.models import SQLAction from sql_env.server.sql_environment import SQLEnvironment from sql_env.training.notebook_pipeline import build_trainer from sql_env.training.trl_adapter import ( SQLEnvTRL, _REPEAT_PENALTY, _MinimalTokenizer, sql_env_reward_func, ) _PROJECT_ROOT = Path(__file__).resolve().parents[2] _QUESTIONS_PATH = _PROJECT_ROOT / "data/questions/student_assessment.json" _DB_DIR = _PROJECT_ROOT / "data/databases" @pytest.fixture(autouse=True) def _reset_class_configuration() -> None: previous_questions_path = SQLEnvTRL._questions_path previous_db_dir = SQLEnvTRL._db_dir previous_step_budget = SQLEnvTRL._step_budget SQLEnvTRL._questions_path = None SQLEnvTRL._db_dir = None SQLEnvTRL._step_budget = 10 yield SQLEnvTRL._questions_path = previous_questions_path SQLEnvTRL._db_dir = previous_db_dir SQLEnvTRL._step_budget = previous_step_budget def test_minimal_tokenizer_apply_chat_template() -> None: tokenizer = _MinimalTokenizer() rendered = tokenizer.apply_chat_template( [ {"role": "user", "content": "hi"}, ] ) assert isinstance(rendered, str) def test_minimal_tokenizer_empty_messages() -> None: tokenizer = _MinimalTokenizer() rendered = tokenizer.apply_chat_template([]) assert isinstance(rendered, str) def test_configure_sets_class_attrs() -> None: SQLEnvTRL._configure( questions_path="q.json", db_dir="dbs", step_budget=10, ) assert SQLEnvTRL._questions_path == "q.json" assert SQLEnvTRL._db_dir == "dbs" assert SQLEnvTRL._step_budget == 10 def test_configure_custom_step_budget() -> None: SQLEnvTRL._configure( questions_path="q.json", db_dir="dbs", step_budget=5, ) assert SQLEnvTRL._step_budget == 5 def test_configure_default_step_budget() -> None: SQLEnvTRL._configure( questions_path="q.json", db_dir="dbs", ) assert SQLEnvTRL._step_budget == 10 def test_configure_is_classmethod() -> None: SQLEnvTRL._configure( questions_path="q.json", db_dir="dbs", step_budget=8, ) assert SQLEnvTRL._questions_path == "q.json" assert SQLEnvTRL._db_dir == "dbs" assert SQLEnvTRL._step_budget == 8 def test_configure_overwrites_previous() -> None: SQLEnvTRL._configure( questions_path="one.json", db_dir="one-dbs", step_budget=3, ) SQLEnvTRL._configure( questions_path="two.json", db_dir="two-dbs", step_budget=9, ) assert SQLEnvTRL._questions_path == "two.json" assert SQLEnvTRL._db_dir == "two-dbs" assert SQLEnvTRL._step_budget == 9 def test_init_after_configure() -> None: SQLEnvTRL._configure( questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR), step_budget=10, ) adapter = SQLEnvTRL() assert isinstance(adapter._env, SQLEnvironment) def test_init_no_args() -> None: SQLEnvTRL._configure( questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR), ) adapter = SQLEnvTRL() assert isinstance(adapter, SQLEnvTRL) def test_init_without_configure_raises() -> None: with pytest.raises(RuntimeError): SQLEnvTRL() def test_init_sets_reward_zero() -> None: SQLEnvTRL._configure( questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR), ) adapter = SQLEnvTRL() assert adapter.reward == 0.0 def test_init_sets_done_false() -> None: SQLEnvTRL._configure( questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR), ) adapter = SQLEnvTRL() assert adapter._done is False def test_init_invalid_questions_path() -> None: SQLEnvTRL._configure( questions_path="/no/such/file.json", db_dir=str(_DB_DIR), ) with pytest.raises(FileNotFoundError): SQLEnvTRL() def test_init_invalid_db_dir() -> None: SQLEnvTRL._configure( questions_path=str(_QUESTIONS_PATH), db_dir="/no/such/dir", ) with pytest.raises(FileNotFoundError): SQLEnvTRL() class _RecordingEnv: def __init__(self, observations: list[SimpleNamespace]) -> None: self._observations = observations self.actions: list[SQLAction] = [] def step(self, action: SQLAction) -> SimpleNamespace: self.actions.append(action) return self._observations.pop(0) class _ResetRecordingEnv: def __init__(self, observations: list[SimpleNamespace]) -> None: self._observations = observations self.reset_calls: list[dict[str, object]] = [] def reset(self, *, seed: int | None = None) -> SimpleNamespace: self.reset_calls.append({"seed": seed}) return self._observations.pop(0) class _RecordingEnvWithReset: def __init__( self, *, step_observations: list[SimpleNamespace], reset_observations: list[SimpleNamespace], ) -> None: self._step_observations = step_observations self._reset_observations = reset_observations self.actions: list[SQLAction] = [] self.reset_calls: list[dict[str, object]] = [] def step(self, action: SQLAction) -> SimpleNamespace: self.actions.append(action) return self._step_observations.pop(0) def reset(self, *, seed: int | None = None) -> SimpleNamespace: self.reset_calls.append({"seed": seed}) return self._reset_observations.pop(0) def _observation( *, result: str = "ok", reward: float | None = 0.0, done: bool = False, ) -> SimpleNamespace: return SimpleNamespace(result=result, error="", reward=reward, done=done) def _build_adapter_with_recording_env( observation: SimpleNamespace, ) -> tuple[SQLEnvTRL, _RecordingEnv]: SQLEnvTRL._configure( questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR), ) adapter = SQLEnvTRL() recording_env = _RecordingEnv([observation]) adapter._env = recording_env return adapter, recording_env def _build_adapter_with_recording_observations( observations: list[SimpleNamespace], ) -> tuple[SQLEnvTRL, _RecordingEnv]: SQLEnvTRL._configure( questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR), ) adapter = SQLEnvTRL() recording_env = _RecordingEnv(observations) adapter._env = recording_env return adapter, recording_env def _build_adapter_with_reset_env( observations: list[SimpleNamespace], ) -> tuple[SQLEnvTRL, _ResetRecordingEnv]: SQLEnvTRL._configure( questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR), ) adapter = SQLEnvTRL() reset_env = _ResetRecordingEnv(observations) adapter._env = reset_env return adapter, reset_env def test_describe_dispatches_action_and_accumulates_reward() -> None: observation = SimpleNamespace(result="schema", reward=0.25, done=False) adapter, recording_env = _build_adapter_with_recording_env(observation) result = adapter.describe("employees") assert result == "schema" assert adapter.reward == 0.25 assert adapter._done is False assert recording_env.actions == [ SQLAction(action_type="DESCRIBE", argument="employees") ] def test_sample_dispatches_action_and_accumulates_reward() -> None: observation = SimpleNamespace(result="rows", reward=0.1, done=False) adapter, recording_env = _build_adapter_with_recording_env(observation) result = adapter.sample("employees") assert result == "rows" assert adapter.reward == 0.1 assert adapter._done is False assert recording_env.actions == [ SQLAction(action_type="SAMPLE", argument="employees") ] def test_query_dispatches_action_and_accumulates_reward() -> None: observation = SimpleNamespace(result="query output", reward=0.5, done=False) adapter, recording_env = _build_adapter_with_recording_env(observation) result = adapter.query("SELECT 1") assert result == "query output" assert adapter.reward == 0.5 assert adapter._done is False assert recording_env.actions == [ SQLAction(action_type="QUERY", argument="SELECT 1") ] def test_answer_dispatches_action_sets_done_and_accumulates_reward() -> None: observation = SimpleNamespace( result="Answer submitted: correct.", reward=1.0, done=True ) adapter, recording_env = _build_adapter_with_recording_env(observation) result = adapter.answer("42") assert result == "Answer submitted: correct." assert adapter.reward == 1.0 assert adapter._done is True assert recording_env.actions == [SQLAction(action_type="ANSWER", argument="42")] def test_query_repeat_penalty_applies_on_exact_repeat() -> None: adapter, _ = _build_adapter_with_recording_observations( [_observation(), _observation()] ) adapter.query("SELECT 1") adapter.query("SELECT 1") assert adapter.reward == pytest.approx(_REPEAT_PENALTY) assert adapter._repeat_count == 1 def test_query_repeat_penalty_not_applied_for_different_sql() -> None: adapter, _ = _build_adapter_with_recording_observations( [_observation(), _observation()] ) adapter.query("SELECT 1") adapter.query("SELECT 2") assert adapter.reward == pytest.approx(0.0) assert adapter._repeat_count == 0 def test_repeat_penalty_not_applied_for_different_method_same_argument() -> None: adapter, _ = _build_adapter_with_recording_observations( [_observation(), _observation()] ) adapter.describe("employees") adapter.sample("employees") assert adapter.reward == pytest.approx(0.0) assert adapter._repeat_count == 0 def test_reset_clears_recent_call_tracker_for_penalty() -> None: reset_obs = SimpleNamespace( question="Q", schema_info="schema", result="", error="", step_count=0, budget_remaining=10, action_history=[], done=False, reward=None, ) SQLEnvTRL._configure( questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR), ) adapter = SQLEnvTRL() adapter._env = _RecordingEnvWithReset( step_observations=[_observation(), _observation(), _observation()], reset_observations=[reset_obs], ) adapter.query("SELECT 1") adapter.query("SELECT 1") reward_before_reset = adapter.reward adapter.reset() adapter.query("SELECT 1") assert reward_before_reset == pytest.approx(_REPEAT_PENALTY) assert adapter.reward == pytest.approx(0.0) assert adapter._repeat_count == 0 def test_repeat_penalty_catches_alternating_reuse_within_window() -> None: adapter, _ = _build_adapter_with_recording_observations( [_observation(), _observation(), _observation()] ) adapter.query("A") adapter.query("B") adapter.query("A") assert adapter.reward == pytest.approx(_REPEAT_PENALTY) assert adapter._repeat_count == 1 def test_repeat_penalty_stacks_for_three_identical_calls() -> None: adapter, _ = _build_adapter_with_recording_observations( [_observation(), _observation(), _observation()] ) adapter.query("SELECT 1") adapter.query("SELECT 1") adapter.query("SELECT 1") assert adapter.reward == pytest.approx(_REPEAT_PENALTY * 2) assert adapter._repeat_count == 2 def test_repeat_count_matches_penalty_fire_count_across_methods() -> None: adapter, _ = _build_adapter_with_recording_observations( [_observation(), _observation(), _observation(), _observation(), _observation()] ) adapter.describe("t") adapter.describe("t") adapter.sample("t") adapter.sample("t") adapter.describe("t") assert adapter._repeat_count == 3 assert adapter.reward == pytest.approx(_REPEAT_PENALTY * 3) @pytest.mark.parametrize( "method_name, argument", [ ("describe", "employees"), ("sample", "employees"), ("query", "SELECT 1"), ("answer", "42"), ], ) def test_tool_methods_raise_when_episode_is_over( method_name: str, argument: str ) -> None: observation = SimpleNamespace(result="unused", reward=0.0, done=False) adapter, _ = _build_adapter_with_recording_env(observation) adapter._done = True with pytest.raises(ValueError, match="Episode is over"): getattr(adapter, method_name)(argument) def test_tool_method_docstrings_include_args_and_returns_sections() -> None: for method_name in ["describe", "sample", "query", "answer"]: doc = getattr(SQLEnvTRL, method_name).__doc__ assert isinstance(doc, str) assert "Args:" in doc assert "Returns:" in doc def test_tool_methods_have_annotations() -> None: assert SQLEnvTRL.describe.__annotations__ == { "table_name": "str", "return": "str", } assert SQLEnvTRL.sample.__annotations__ == { "table_name": "str", "return": "str", } assert SQLEnvTRL.query.__annotations__ == { "sql": "str", "return": "str", } assert SQLEnvTRL.answer.__annotations__ == { "value": "str", "return": "str", } def test_reset_returns_observation_string() -> None: adapter, reset_env = _build_adapter_with_reset_env( [ SimpleNamespace( question="How many students?", schema_info="student(id, name)", result="", error="", step_count=0, budget_remaining=10, action_history=[], done=False, reward=None, ) ] ) observation_text = adapter.reset() assert isinstance(observation_text, str) assert observation_text.strip() != "" assert reset_env.reset_calls == [{"seed": None}] def test_reset_clears_reward() -> None: adapter, _ = _build_adapter_with_reset_env( [ SimpleNamespace( question="Q", schema_info="schema", result="", error="", step_count=0, budget_remaining=10, action_history=[], done=False, reward=None, ) ] ) adapter.reward = 3.5 adapter.reset() assert adapter.reward == 0.0 def test_reset_clears_done() -> None: adapter, _ = _build_adapter_with_reset_env( [ SimpleNamespace( question="Q", schema_info="schema", result="", error="", step_count=0, budget_remaining=10, action_history=[], done=False, reward=None, ) ] ) adapter._done = True adapter.reset() assert adapter._done is False def test_reset_accepts_kwargs() -> None: adapter, reset_env = _build_adapter_with_reset_env( [ SimpleNamespace( question="Q", schema_info="schema", result="", error="", step_count=0, budget_remaining=10, action_history=[], done=False, reward=None, ) ] ) observation_text = adapter.reset(foo="bar") assert isinstance(observation_text, str) assert reset_env.reset_calls == [{"seed": None}] def test_reset_multiple_times() -> None: adapter, reset_env = _build_adapter_with_reset_env( [ SimpleNamespace( question="Q1", schema_info="schema", result="", error="", step_count=0, budget_remaining=10, action_history=[], done=False, reward=None, ), SimpleNamespace( question="Q2", schema_info="schema", result="", error="", step_count=0, budget_remaining=10, action_history=[], done=False, reward=None, ), SimpleNamespace( question="Q3", schema_info="schema", result="", error="", step_count=0, budget_remaining=10, action_history=[], done=False, reward=None, ), ] ) first = adapter.reset() adapter.reward = 1.5 adapter._done = True second = adapter.reset() adapter.reward = 2.0 adapter._done = True third = adapter.reset() assert isinstance(first, str) assert isinstance(second, str) assert isinstance(third, str) assert adapter.reward == 0.0 assert adapter._done is False assert reset_env.reset_calls == [{"seed": None}, {"seed": None}, {"seed": None}] def test_reward_func_reads_accumulated_rewards() -> None: env_one = SimpleNamespace(reward=0.5) env_two = SimpleNamespace(reward=1.0) env_three = SimpleNamespace(reward=0.0) rewards = sql_env_reward_func([env_one, env_two, env_three]) assert rewards == [0.5, 1.0, 0.0] def test_reward_func_empty_list() -> None: rewards = sql_env_reward_func([]) assert rewards == [] def test_reward_func_single_env() -> None: env = SimpleNamespace(reward=0.75) rewards = sql_env_reward_func([env]) assert rewards == [0.75] def test_reward_func_ignores_kwargs() -> None: env = SimpleNamespace(reward=2.25) rewards = sql_env_reward_func([env], completions=[], foo="bar") assert rewards == [2.25] def test_reward_func_returns_list_of_floats() -> None: env_one = SimpleNamespace(reward=1) env_two = SimpleNamespace(reward=0.25) rewards = sql_env_reward_func([env_one, env_two]) assert isinstance(rewards, list) assert all(isinstance(value, float) for value in rewards) class _BuildTrainerConfigRecorder: def __init__(self, **kwargs: object) -> None: self.kwargs = kwargs class _BuildTrainerClassRecorder: def __init__(self, **kwargs: object) -> None: self.kwargs = kwargs class _EnvironmentFactoryWithConfigure: configure_calls: list[dict[str, object]] = [] @classmethod def configure( cls, *, questions_path: str, db_dir: str, step_budget: int, ) -> None: cls.configure_calls.append( { "questions_path": questions_path, "db_dir": db_dir, "step_budget": step_budget, } ) class _EnvironmentFactoryWithoutConfigure: pass def _build_trainer_config() -> SimpleNamespace: return SimpleNamespace( output_dir="outputs/test", learning_rate=1e-5, per_device_train_batch_size=2, gradient_accumulation_steps=2, num_train_epochs=1, logging_steps=1, max_new_tokens=128, num_generations=2, questions_path="data/questions/student_assessment.json", db_dir="data/databases", step_budget=7, ) def test_build_trainer_with_environment_factory() -> None: _EnvironmentFactoryWithConfigure.configure_calls = [] config = _build_trainer_config() trainer = build_trainer( model=object(), tokenizer=object(), prompts=["prompt"], config=config, trl_grpo_config_cls=_BuildTrainerConfigRecorder, grpo_trainer_cls=_BuildTrainerClassRecorder, reward_funcs=[sql_env_reward_func], environment_factory=_EnvironmentFactoryWithConfigure, ) assert trainer.kwargs["environment_factory"] is _EnvironmentFactoryWithConfigure assert _EnvironmentFactoryWithConfigure.configure_calls == [ { "questions_path": config.questions_path, "db_dir": config.db_dir, "step_budget": config.step_budget, } ] def test_build_trainer_without_environment_factory() -> None: config = _build_trainer_config() trainer = build_trainer( model=object(), tokenizer=object(), prompts=["prompt"], config=config, trl_grpo_config_cls=_BuildTrainerConfigRecorder, grpo_trainer_cls=_BuildTrainerClassRecorder, reward_funcs=[sql_env_reward_func], environment_factory=None, ) assert "environment_factory" not in trainer.kwargs def test_build_trainer_passes_reward_funcs() -> None: config = _build_trainer_config() reward_funcs = [sql_env_reward_func] trainer = build_trainer( model=object(), tokenizer=object(), prompts=["prompt"], config=config, trl_grpo_config_cls=_BuildTrainerConfigRecorder, grpo_trainer_cls=_BuildTrainerClassRecorder, reward_funcs=reward_funcs, environment_factory=_EnvironmentFactoryWithoutConfigure, ) assert trainer.kwargs["reward_funcs"] == reward_funcs