| """Unit tests for the TRL adapter shell.""" |
|
|
from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path
from types import SimpleNamespace

import pytest

from sql_env.models import SQLAction
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.training.notebook_pipeline import build_trainer
from sql_env.training.trl_adapter import (
    SQLEnvTRL,
    _REPEAT_PENALTY,
    _MinimalTokenizer,
    sql_env_reward_func,
)
|
|
| _PROJECT_ROOT = Path(__file__).resolve().parents[2] |
| _QUESTIONS_PATH = _PROJECT_ROOT / "data/questions/student_assessment.json" |
| _DB_DIR = _PROJECT_ROOT / "data/databases" |
|
|
|
|
@pytest.fixture(autouse=True)
def _reset_class_configuration() -> Iterator[None]:
    """Isolate every test from ``SQLEnvTRL``'s class-level configuration.

    ``_configure`` writes its settings onto the class itself, so a value set
    by one test would otherwise leak into the next. This fixture saves the
    current values, installs a blank state (no paths, step budget 10) for the
    duration of the test, and restores the saved values afterwards.

    Yields:
        None: control is handed to the test between setup and teardown.
    """
    previous_questions_path = SQLEnvTRL._questions_path
    previous_db_dir = SQLEnvTRL._db_dir
    previous_step_budget = SQLEnvTRL._step_budget

    SQLEnvTRL._questions_path = None
    SQLEnvTRL._db_dir = None
    SQLEnvTRL._step_budget = 10
    yield
    SQLEnvTRL._questions_path = previous_questions_path
    SQLEnvTRL._db_dir = previous_db_dir
    SQLEnvTRL._step_budget = previous_step_budget
|
|
|
|
def test_minimal_tokenizer_apply_chat_template() -> None:
    """The fallback tokenizer renders a single user message to a string."""
    messages = [{"role": "user", "content": "hi"}]

    assert isinstance(_MinimalTokenizer().apply_chat_template(messages), str)
|
|
|
|
def test_minimal_tokenizer_empty_messages() -> None:
    """An empty conversation still renders to a string."""
    output = _MinimalTokenizer().apply_chat_template([])

    assert isinstance(output, str)
|
|
|
|
def test_configure_sets_class_attrs() -> None:
    """``_configure`` stores every argument on the class.

    The autouse fixture leaves ``_step_budget`` at 10, so a non-default
    budget is used here; passing 10 would let the step_budget assertion pass
    even if ``_configure`` ignored the argument.
    """
    SQLEnvTRL._configure(
        questions_path="q.json",
        db_dir="dbs",
        step_budget=6,
    )

    assert SQLEnvTRL._questions_path == "q.json"
    assert SQLEnvTRL._db_dir == "dbs"
    assert SQLEnvTRL._step_budget == 6
|
|
|
|
def test_configure_custom_step_budget() -> None:
    """An explicit ``step_budget`` overrides the default."""
    custom_budget = 5
    SQLEnvTRL._configure(
        step_budget=custom_budget,
        questions_path="q.json",
        db_dir="dbs",
    )

    assert SQLEnvTRL._step_budget == custom_budget
|
|
|
|
def test_configure_default_step_budget() -> None:
    """Omitting ``step_budget`` sets it to the default of 10."""
    # The autouse fixture already leaves ``_step_budget`` at 10. Seed a
    # different value first, otherwise this test passes even when
    # ``_configure`` never applies its default.
    SQLEnvTRL._step_budget = 99

    SQLEnvTRL._configure(
        questions_path="q.json",
        db_dir="dbs",
    )

    assert SQLEnvTRL._step_budget == 10
|
|
|
|
def test_configure_is_classmethod() -> None:
    """``_configure`` is a classmethod and works when called on the class.

    The adapter is built with no constructor arguments, so its configuration
    has to be installed at class level before instantiation.
    """
    import inspect

    # getattr_static bypasses the descriptor protocol. That exposes the raw
    # classmethod object rather than the bound method.
    assert isinstance(inspect.getattr_static(SQLEnvTRL, "_configure"), classmethod)

    SQLEnvTRL._configure(
        questions_path="q.json",
        db_dir="dbs",
        step_budget=8,
    )

    assert SQLEnvTRL._questions_path == "q.json"
    assert SQLEnvTRL._db_dir == "dbs"
    assert SQLEnvTRL._step_budget == 8
|
|
|
|
def test_configure_overwrites_previous() -> None:
    """A later ``_configure`` call fully replaces an earlier one."""
    for questions, databases, budget in (
        ("one.json", "one-dbs", 3),
        ("two.json", "two-dbs", 9),
    ):
        SQLEnvTRL._configure(
            questions_path=questions,
            db_dir=databases,
            step_budget=budget,
        )

    assert (
        SQLEnvTRL._questions_path,
        SQLEnvTRL._db_dir,
        SQLEnvTRL._step_budget,
    ) == ("two.json", "two-dbs", 9)
|
|
|
|
def test_init_after_configure() -> None:
    """Once configured, construction wires up a real ``SQLEnvironment``."""
    SQLEnvTRL._configure(
        step_budget=10,
        db_dir=str(_DB_DIR),
        questions_path=str(_QUESTIONS_PATH),
    )

    assert isinstance(SQLEnvTRL()._env, SQLEnvironment)
|
|
|
|
def test_init_no_args() -> None:
    """The adapter is constructible with zero arguments after configuration."""
    SQLEnvTRL._configure(db_dir=str(_DB_DIR), questions_path=str(_QUESTIONS_PATH))

    instance = SQLEnvTRL()

    assert isinstance(instance, SQLEnvTRL)
|
|
|
|
def test_init_without_configure_raises() -> None:
    """Instantiating before ``_configure`` has run is rejected."""
    with pytest.raises(RuntimeError):
        _ = SQLEnvTRL()
|
|
|
|
def test_init_sets_reward_zero() -> None:
    """A freshly built adapter starts with zero accumulated reward."""
    SQLEnvTRL._configure(db_dir=str(_DB_DIR), questions_path=str(_QUESTIONS_PATH))

    initial_reward = SQLEnvTRL().reward

    assert initial_reward == 0.0
|
|
|
|
def test_init_sets_done_false() -> None:
    """A freshly built adapter is not in a finished episode."""
    SQLEnvTRL._configure(db_dir=str(_DB_DIR), questions_path=str(_QUESTIONS_PATH))

    fresh = SQLEnvTRL()

    assert fresh._done is False
|
|
|
|
def test_init_invalid_questions_path() -> None:
    """A missing questions file surfaces as ``FileNotFoundError`` on init."""
    missing_questions = "/no/such/file.json"
    SQLEnvTRL._configure(db_dir=str(_DB_DIR), questions_path=missing_questions)

    with pytest.raises(FileNotFoundError):
        SQLEnvTRL()
|
|
|
|
def test_init_invalid_db_dir() -> None:
    """A missing database directory surfaces as ``FileNotFoundError`` on init."""
    missing_dir = "/no/such/dir"
    SQLEnvTRL._configure(db_dir=missing_dir, questions_path=str(_QUESTIONS_PATH))

    with pytest.raises(FileNotFoundError):
        SQLEnvTRL()
|
|
|
|
| class _RecordingEnv: |
| def __init__(self, observations: list[SimpleNamespace]) -> None: |
| self._observations = observations |
| self.actions: list[SQLAction] = [] |
|
|
| def step(self, action: SQLAction) -> SimpleNamespace: |
| self.actions.append(action) |
| return self._observations.pop(0) |
|
|
|
|
| class _ResetRecordingEnv: |
| def __init__(self, observations: list[SimpleNamespace]) -> None: |
| self._observations = observations |
| self.reset_calls: list[dict[str, object]] = [] |
|
|
| def reset(self, *, seed: int | None = None) -> SimpleNamespace: |
| self.reset_calls.append({"seed": seed}) |
| return self._observations.pop(0) |
|
|
|
|
| class _RecordingEnvWithReset: |
| def __init__( |
| self, |
| *, |
| step_observations: list[SimpleNamespace], |
| reset_observations: list[SimpleNamespace], |
| ) -> None: |
| self._step_observations = step_observations |
| self._reset_observations = reset_observations |
| self.actions: list[SQLAction] = [] |
| self.reset_calls: list[dict[str, object]] = [] |
|
|
| def step(self, action: SQLAction) -> SimpleNamespace: |
| self.actions.append(action) |
| return self._step_observations.pop(0) |
|
|
| def reset(self, *, seed: int | None = None) -> SimpleNamespace: |
| self.reset_calls.append({"seed": seed}) |
| return self._reset_observations.pop(0) |
|
|
|
|
| def _observation( |
| *, |
| result: str = "ok", |
| reward: float | None = 0.0, |
| done: bool = False, |
| ) -> SimpleNamespace: |
| return SimpleNamespace(result=result, error="", reward=reward, done=done) |
|
|
|
|
def _build_adapter_with_recording_env(
    observation: SimpleNamespace,
) -> tuple[SQLEnvTRL, _RecordingEnv]:
    """Build a configured adapter whose env replays exactly one observation.

    This is a thin wrapper over ``_build_adapter_with_recording_observations``.
    Keeping the configure-then-swap-env setup in one place means the two
    helpers cannot drift apart.

    Args:
        observation: The observation returned by the first ``step`` call.

    Returns:
        The adapter and the recording env installed on it.
    """
    return _build_adapter_with_recording_observations([observation])
|
|
|
|
def _build_adapter_with_recording_observations(
    observations: list[SimpleNamespace],
) -> tuple[SQLEnvTRL, _RecordingEnv]:
    """Build a configured adapter whose env replays ``observations`` in order.

    The adapter is constructed against the real data paths. Its ``_env`` is
    then swapped for a ``_RecordingEnv`` so tool calls can be inspected.
    """
    SQLEnvTRL._configure(db_dir=str(_DB_DIR), questions_path=str(_QUESTIONS_PATH))
    fake_env = _RecordingEnv(observations)
    adapter = SQLEnvTRL()
    adapter._env = fake_env
    return adapter, fake_env
|
|
|
|
def _build_adapter_with_reset_env(
    observations: list[SimpleNamespace],
) -> tuple[SQLEnvTRL, _ResetRecordingEnv]:
    """Build a configured adapter whose env replays ``observations`` on reset."""
    SQLEnvTRL._configure(db_dir=str(_DB_DIR), questions_path=str(_QUESTIONS_PATH))
    fake_env = _ResetRecordingEnv(observations)
    adapter = SQLEnvTRL()
    adapter._env = fake_env
    return adapter, fake_env
|
|
|
|
def test_describe_dispatches_action_and_accumulates_reward() -> None:
    """``describe`` sends a DESCRIBE action and adds the step reward."""
    adapter, env = _build_adapter_with_recording_env(
        SimpleNamespace(result="schema", reward=0.25, done=False)
    )

    assert adapter.describe("employees") == "schema"
    assert adapter.reward == 0.25
    assert adapter._done is False
    expected = SQLAction(action_type="DESCRIBE", argument="employees")
    assert env.actions == [expected]
|
|
|
|
def test_sample_dispatches_action_and_accumulates_reward() -> None:
    """``sample`` sends a SAMPLE action and adds the step reward."""
    adapter, env = _build_adapter_with_recording_env(
        SimpleNamespace(result="rows", reward=0.1, done=False)
    )

    assert adapter.sample("employees") == "rows"
    assert adapter.reward == 0.1
    assert adapter._done is False
    expected = SQLAction(action_type="SAMPLE", argument="employees")
    assert env.actions == [expected]
|
|
|
|
def test_query_dispatches_action_and_accumulates_reward() -> None:
    """``query`` sends a QUERY action and adds the step reward."""
    adapter, env = _build_adapter_with_recording_env(
        SimpleNamespace(result="query output", reward=0.5, done=False)
    )

    assert adapter.query("SELECT 1") == "query output"
    assert adapter.reward == 0.5
    assert adapter._done is False
    expected = SQLAction(action_type="QUERY", argument="SELECT 1")
    assert env.actions == [expected]
|
|
|
|
def test_answer_dispatches_action_sets_done_and_accumulates_reward() -> None:
    """``answer`` sends an ANSWER action, ends the episode and adds the reward."""
    final_step = SimpleNamespace(
        result="Answer submitted: correct.", reward=1.0, done=True
    )
    adapter, env = _build_adapter_with_recording_env(final_step)

    assert adapter.answer("42") == "Answer submitted: correct."
    assert adapter.reward == 1.0
    assert adapter._done is True
    assert env.actions == [SQLAction(action_type="ANSWER", argument="42")]
|
|
|
|
def test_query_repeat_penalty_applies_on_exact_repeat() -> None:
    """Issuing the identical query twice incurs exactly one repeat penalty."""
    adapter, _ = _build_adapter_with_recording_observations(
        [_observation() for _ in range(2)]
    )

    for _ in range(2):
        adapter.query("SELECT 1")

    assert adapter.reward == pytest.approx(_REPEAT_PENALTY)
    assert adapter._repeat_count == 1
|
|
|
|
def test_query_repeat_penalty_not_applied_for_different_sql() -> None:
    """Distinct SQL strings are never treated as repeats."""
    adapter, _ = _build_adapter_with_recording_observations(
        [_observation() for _ in range(2)]
    )

    for sql in ("SELECT 1", "SELECT 2"):
        adapter.query(sql)

    assert adapter.reward == pytest.approx(0.0)
    assert adapter._repeat_count == 0
|
|
|
|
def test_repeat_penalty_not_applied_for_different_method_same_argument() -> None:
    """The repeat key includes the tool name, not just the argument."""
    adapter, _ = _build_adapter_with_recording_observations(
        [_observation() for _ in range(2)]
    )

    for tool in (adapter.describe, adapter.sample):
        tool("employees")

    assert adapter.reward == pytest.approx(0.0)
    assert adapter._repeat_count == 0
|
|
|
|
def test_reset_clears_recent_call_tracker_for_penalty() -> None:
    """After ``reset``, a query seen in the previous episode is not a repeat."""
    SQLEnvTRL._configure(db_dir=str(_DB_DIR), questions_path=str(_QUESTIONS_PATH))
    adapter = SQLEnvTRL()
    adapter._env = _RecordingEnvWithReset(
        step_observations=[_observation() for _ in range(3)],
        reset_observations=[
            SimpleNamespace(
                question="Q",
                schema_info="schema",
                result="",
                error="",
                step_count=0,
                budget_remaining=10,
                action_history=[],
                done=False,
                reward=None,
            )
        ],
    )

    # First episode: one exact repeat.
    adapter.query("SELECT 1")
    adapter.query("SELECT 1")
    penalised_reward = adapter.reward

    # Second episode: the same query is fresh again.
    adapter.reset()
    adapter.query("SELECT 1")

    assert penalised_reward == pytest.approx(_REPEAT_PENALTY)
    assert adapter.reward == pytest.approx(0.0)
    assert adapter._repeat_count == 0
|
|
|
|
def test_repeat_penalty_catches_alternating_reuse_within_window() -> None:
    """A call repeated after one intervening call is still penalised."""
    adapter, _ = _build_adapter_with_recording_observations(
        [_observation() for _ in range(3)]
    )

    for sql in ("A", "B", "A"):
        adapter.query(sql)

    assert adapter.reward == pytest.approx(_REPEAT_PENALTY)
    assert adapter._repeat_count == 1
|
|
|
|
def test_repeat_penalty_stacks_for_three_identical_calls() -> None:
    """Each additional identical call adds another penalty."""
    adapter, _ = _build_adapter_with_recording_observations(
        [_observation() for _ in range(3)]
    )

    for _ in range(3):
        adapter.query("SELECT 1")

    assert adapter.reward == pytest.approx(2 * _REPEAT_PENALTY)
    assert adapter._repeat_count == 2
|
|
|
|
def test_repeat_count_matches_penalty_fire_count_across_methods() -> None:
    """``_repeat_count`` equals the number of penalties actually applied."""
    adapter, _ = _build_adapter_with_recording_observations(
        [_observation() for _ in range(5)]
    )

    call_sequence = ("describe", "describe", "sample", "sample", "describe")
    for tool_name in call_sequence:
        getattr(adapter, tool_name)("t")

    assert adapter._repeat_count == 3
    assert adapter.reward == pytest.approx(3 * _REPEAT_PENALTY)
|
|
|
|
@pytest.mark.parametrize(
    ("method_name", "argument"),
    [
        ("describe", "employees"),
        ("sample", "employees"),
        ("query", "SELECT 1"),
        ("answer", "42"),
    ],
)
def test_tool_methods_raise_when_episode_is_over(
    method_name: str, argument: str
) -> None:
    """Every tool refuses to run once the episode has finished."""
    adapter, _ = _build_adapter_with_recording_env(
        SimpleNamespace(result="unused", reward=0.0, done=False)
    )
    adapter._done = True
    tool = getattr(adapter, method_name)

    with pytest.raises(ValueError, match="Episode is over"):
        tool(argument)
|
|
|
|
def test_tool_method_docstrings_include_args_and_returns_sections() -> None:
    """Tool docstrings carry the Google-style sections used for tool schemas."""
    for tool_name in ("describe", "sample", "query", "answer"):
        docstring = getattr(SQLEnvTRL, tool_name).__doc__
        assert isinstance(docstring, str)
        assert all(section in docstring for section in ("Args:", "Returns:"))
|
|
|
|
def test_tool_methods_have_annotations() -> None:
    """Each tool method has exactly one ``str`` parameter and returns ``str``.

    ``typing.get_type_hints`` resolves annotations to real types. The
    assertion therefore holds whether or not the adapter module postpones
    annotation evaluation (PEP 563), unlike a comparison against raw
    ``__annotations__`` strings.
    """
    import typing

    expected_hints = {
        "describe": {"table_name": str, "return": str},
        "sample": {"table_name": str, "return": str},
        "query": {"sql": str, "return": str},
        "answer": {"value": str, "return": str},
    }
    for method_name, hints in expected_hints.items():
        assert typing.get_type_hints(getattr(SQLEnvTRL, method_name)) == hints
|
|
|
|
def test_reset_returns_observation_string() -> None:
    """``reset`` returns non-blank text and calls the env with ``seed=None``."""
    initial = SimpleNamespace(
        question="How many students?",
        schema_info="student(id, name)",
        result="",
        error="",
        step_count=0,
        budget_remaining=10,
        action_history=[],
        done=False,
        reward=None,
    )
    adapter, env = _build_adapter_with_reset_env([initial])

    text = adapter.reset()

    assert isinstance(text, str)
    assert text.strip() != ""
    assert env.reset_calls == [{"seed": None}]
|
|
|
|
def test_reset_clears_reward() -> None:
    """``reset`` zeroes any reward left over from the previous episode."""
    initial = SimpleNamespace(
        question="Q",
        schema_info="schema",
        result="",
        error="",
        step_count=0,
        budget_remaining=10,
        action_history=[],
        done=False,
        reward=None,
    )
    adapter, _ = _build_adapter_with_reset_env([initial])
    adapter.reward = 3.5

    adapter.reset()

    assert adapter.reward == 0.0
|
|
|
|
def test_reset_clears_done() -> None:
    """``reset`` reopens an episode that had been marked finished."""
    initial = SimpleNamespace(
        question="Q",
        schema_info="schema",
        result="",
        error="",
        step_count=0,
        budget_remaining=10,
        action_history=[],
        done=False,
        reward=None,
    )
    adapter, _ = _build_adapter_with_reset_env([initial])
    adapter._done = True

    adapter.reset()

    assert adapter._done is False
|
|
|
|
def test_reset_accepts_kwargs() -> None:
    """Unknown keyword arguments to ``reset`` are accepted and not forwarded."""
    initial = SimpleNamespace(
        question="Q",
        schema_info="schema",
        result="",
        error="",
        step_count=0,
        budget_remaining=10,
        action_history=[],
        done=False,
        reward=None,
    )
    adapter, env = _build_adapter_with_reset_env([initial])

    text = adapter.reset(foo="bar")

    assert isinstance(text, str)
    assert env.reset_calls == [{"seed": None}]
|
|
|
|
def test_reset_multiple_times() -> None:
    """Repeated resets each return text and fully clear episode state."""
    queued = [
        SimpleNamespace(
            question=f"Q{index}",
            schema_info="schema",
            result="",
            error="",
            step_count=0,
            budget_remaining=10,
            action_history=[],
            done=False,
            reward=None,
        )
        for index in (1, 2, 3)
    ]
    adapter, env = _build_adapter_with_reset_env(queued)

    texts = [adapter.reset()]
    for leftover_reward in (1.5, 2.0):
        # Dirty the state before every subsequent reset.
        adapter.reward = leftover_reward
        adapter._done = True
        texts.append(adapter.reset())

    assert all(isinstance(text, str) for text in texts)
    assert adapter.reward == 0.0
    assert adapter._done is False
    assert env.reset_calls == [{"seed": None}] * 3
|
|
|
|
def test_reward_func_reads_accumulated_rewards() -> None:
    """The reward function returns each environment's ``reward`` in order."""
    envs = [SimpleNamespace(reward=value) for value in (0.5, 1.0, 0.0)]

    assert sql_env_reward_func(envs) == [0.5, 1.0, 0.0]
|
|
|
|
def test_reward_func_empty_list() -> None:
    """No environments produce no rewards."""
    assert sql_env_reward_func([]) == []
|
|
|
|
def test_reward_func_single_env() -> None:
    """A single environment yields a one-element reward list."""
    only_env = SimpleNamespace(reward=0.75)

    assert sql_env_reward_func([only_env]) == [0.75]
|
|
|
|
def test_reward_func_ignores_kwargs() -> None:
    """Extra keyword arguments passed by TRL do not affect the result."""
    extras = {"completions": [], "foo": "bar"}

    result = sql_env_reward_func([SimpleNamespace(reward=2.25)], **extras)

    assert result == [2.25]
|
|
|
|
def test_reward_func_returns_list_of_floats() -> None:
    """Integer rewards are coerced so every entry is a ``float``."""
    envs = [SimpleNamespace(reward=1), SimpleNamespace(reward=0.25)]

    result = sql_env_reward_func(envs)

    assert isinstance(result, list)
    assert not any(not isinstance(item, float) for item in result)
|
|
|
|
| class _BuildTrainerConfigRecorder: |
| def __init__(self, **kwargs: object) -> None: |
| self.kwargs = kwargs |
|
|
|
|
| class _BuildTrainerClassRecorder: |
| def __init__(self, **kwargs: object) -> None: |
| self.kwargs = kwargs |
|
|
|
|
| class _EnvironmentFactoryWithConfigure: |
| configure_calls: list[dict[str, object]] = [] |
|
|
| @classmethod |
| def configure( |
| cls, |
| *, |
| questions_path: str, |
| db_dir: str, |
| step_budget: int, |
| ) -> None: |
| cls.configure_calls.append( |
| { |
| "questions_path": questions_path, |
| "db_dir": db_dir, |
| "step_budget": step_budget, |
| } |
| ) |
|
|
|
|
| class _EnvironmentFactoryWithoutConfigure: |
| pass |
|
|
|
|
| def _build_trainer_config() -> SimpleNamespace: |
| return SimpleNamespace( |
| output_dir="outputs/test", |
| learning_rate=1e-5, |
| per_device_train_batch_size=2, |
| gradient_accumulation_steps=2, |
| num_train_epochs=1, |
| logging_steps=1, |
| max_new_tokens=128, |
| num_generations=2, |
| questions_path="data/questions/student_assessment.json", |
| db_dir="data/databases", |
| step_budget=7, |
| ) |
|
|
|
|
def test_build_trainer_with_environment_factory() -> None:
    """A factory with ``configure`` is configured once and passed to the trainer."""
    factory = _EnvironmentFactoryWithConfigure
    factory.configure_calls = []
    config = _build_trainer_config()

    trainer = build_trainer(
        model=object(),
        tokenizer=object(),
        prompts=["prompt"],
        config=config,
        trl_grpo_config_cls=_BuildTrainerConfigRecorder,
        grpo_trainer_cls=_BuildTrainerClassRecorder,
        reward_funcs=[sql_env_reward_func],
        environment_factory=factory,
    )

    expected_call = {
        "questions_path": config.questions_path,
        "db_dir": config.db_dir,
        "step_budget": config.step_budget,
    }
    assert trainer.kwargs["environment_factory"] is factory
    assert factory.configure_calls == [expected_call]
|
|
|
|
def test_build_trainer_without_environment_factory() -> None:
    """With no factory, ``environment_factory`` is omitted from trainer kwargs."""
    trainer = build_trainer(
        model=object(),
        tokenizer=object(),
        prompts=["prompt"],
        config=_build_trainer_config(),
        trl_grpo_config_cls=_BuildTrainerConfigRecorder,
        grpo_trainer_cls=_BuildTrainerClassRecorder,
        reward_funcs=[sql_env_reward_func],
        environment_factory=None,
    )

    assert "environment_factory" not in trainer.kwargs
|
|
|
|
def test_build_trainer_passes_reward_funcs() -> None:
    """Reward functions are forwarded to the trainer unchanged.

    The factory used here has no ``configure`` hook.
    """
    funcs = [sql_env_reward_func]

    trainer = build_trainer(
        model=object(),
        tokenizer=object(),
        prompts=["prompt"],
        config=_build_trainer_config(),
        trl_grpo_config_cls=_BuildTrainerConfigRecorder,
        grpo_trainer_cls=_BuildTrainerClassRecorder,
        reward_funcs=funcs,
        environment_factory=_EnvironmentFactoryWithoutConfigure,
    )

    assert trainer.kwargs["reward_funcs"] == funcs
|
|