# File: sql_env/tests/unit/test_trl_adapter.py
# Hub listing residue: uploaded by hjerpe ("Upload folder using huggingface_hub"), commit 9e64e71, verified.
"""Unit tests for the TRL adapter shell."""
from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path
from types import SimpleNamespace

import pytest

from sql_env.models import SQLAction
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.training.notebook_pipeline import build_trainer
from sql_env.training.trl_adapter import (
    SQLEnvTRL,
    _REPEAT_PENALTY,
    _MinimalTokenizer,
    sql_env_reward_func,
)
# Directory two levels above the one containing this file, after resolving
# the file's path; the data paths below are anchored to it.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Question file handed to SQLEnvTRL._configure by the tests in this module.
_QUESTIONS_PATH = _PROJECT_ROOT / "data/questions/student_assessment.json"
# Database directory handed to SQLEnvTRL._configure by the tests in this module.
_DB_DIR = _PROJECT_ROOT / "data/databases"
@pytest.fixture(autouse=True)
def _reset_class_configuration() -> Iterator[None]:
    """Isolate each test from class-level ``SQLEnvTRL`` configuration.

    Before the test: remembers the current ``_questions_path``, ``_db_dir``
    and ``_step_budget`` class attributes, then resets them to a baseline
    (both paths ``None``, step budget ``10``).  After the test: puts the
    remembered values back.

    Yields:
        None. Control is handed to the test between setup and teardown.
    """
    previous_questions_path = SQLEnvTRL._questions_path
    previous_db_dir = SQLEnvTRL._db_dir
    previous_step_budget = SQLEnvTRL._step_budget

    # Baseline state every test starts from.
    SQLEnvTRL._questions_path = None
    SQLEnvTRL._db_dir = None
    SQLEnvTRL._step_budget = 10

    yield

    # Teardown: restore whatever was configured before this test ran.
    SQLEnvTRL._questions_path = previous_questions_path
    SQLEnvTRL._db_dir = previous_db_dir
    SQLEnvTRL._step_budget = previous_step_budget
def test_minimal_tokenizer_apply_chat_template() -> None:
    """A one-message conversation is rendered to a string."""
    messages = [{"role": "user", "content": "hi"}]
    output = _MinimalTokenizer().apply_chat_template(messages)
    assert isinstance(output, str)
def test_minimal_tokenizer_empty_messages() -> None:
    """An empty conversation is still rendered to a string."""
    output = _MinimalTokenizer().apply_chat_template([])
    assert isinstance(output, str)
def test_configure_sets_class_attrs() -> None:
    """``_configure`` stores all three settings on the class."""
    settings = {"questions_path": "q.json", "db_dir": "dbs", "step_budget": 10}
    SQLEnvTRL._configure(**settings)

    assert SQLEnvTRL._questions_path == "q.json"
    assert SQLEnvTRL._db_dir == "dbs"
    assert SQLEnvTRL._step_budget == 10
def test_configure_custom_step_budget() -> None:
    """An explicit ``step_budget`` is stored as given."""
    SQLEnvTRL._configure(questions_path="q.json", db_dir="dbs", step_budget=5)
    configured_budget = SQLEnvTRL._step_budget
    assert configured_budget == 5
def test_configure_default_step_budget() -> None:
    """Omitting ``step_budget`` leaves the class with a budget of 10."""
    SQLEnvTRL._configure(questions_path="q.json", db_dir="dbs")
    configured_budget = SQLEnvTRL._step_budget
    assert configured_budget == 10
def test_configure_is_classmethod() -> None:
    """``_configure`` can be called on the class itself, with no instance."""
    settings = {"questions_path": "q.json", "db_dir": "dbs", "step_budget": 8}
    SQLEnvTRL._configure(**settings)

    assert SQLEnvTRL._questions_path == "q.json"
    assert SQLEnvTRL._db_dir == "dbs"
    assert SQLEnvTRL._step_budget == 8
def test_configure_overwrites_previous() -> None:
    """A second ``_configure`` call replaces the values from the first."""
    first = {"questions_path": "one.json", "db_dir": "one-dbs", "step_budget": 3}
    second = {"questions_path": "two.json", "db_dir": "two-dbs", "step_budget": 9}
    for settings in (first, second):
        SQLEnvTRL._configure(**settings)

    assert SQLEnvTRL._questions_path == "two.json"
    assert SQLEnvTRL._db_dir == "two-dbs"
    assert SQLEnvTRL._step_budget == 9
def test_init_after_configure() -> None:
    """After configuration, a new adapter wraps an ``SQLEnvironment``."""
    questions, databases = str(_QUESTIONS_PATH), str(_DB_DIR)
    SQLEnvTRL._configure(questions_path=questions, db_dir=databases, step_budget=10)

    wrapped_env = SQLEnvTRL()._env
    assert isinstance(wrapped_env, SQLEnvironment)
def test_init_no_args() -> None:
    """The adapter is constructible with no constructor arguments."""
    questions, databases = str(_QUESTIONS_PATH), str(_DB_DIR)
    SQLEnvTRL._configure(questions_path=questions, db_dir=databases)

    instance = SQLEnvTRL()
    assert isinstance(instance, SQLEnvTRL)
def test_init_without_configure_raises() -> None:
    """Constructing the adapter before ``_configure`` raises ``RuntimeError``."""
    expect_runtime_error = pytest.raises(RuntimeError)
    with expect_runtime_error:
        SQLEnvTRL()
def test_init_sets_reward_zero() -> None:
    """A fresh adapter starts with an accumulated reward of 0.0."""
    questions, databases = str(_QUESTIONS_PATH), str(_DB_DIR)
    SQLEnvTRL._configure(questions_path=questions, db_dir=databases)

    initial_reward = SQLEnvTRL().reward
    assert initial_reward == 0.0
def test_init_sets_done_false() -> None:
    """A fresh adapter starts with ``_done`` set to ``False``."""
    questions, databases = str(_QUESTIONS_PATH), str(_DB_DIR)
    SQLEnvTRL._configure(questions_path=questions, db_dir=databases)

    initial_done = SQLEnvTRL()._done
    assert initial_done is False
def test_init_invalid_questions_path() -> None:
    """A nonexistent questions file makes construction raise ``FileNotFoundError``."""
    missing_questions = "/no/such/file.json"
    SQLEnvTRL._configure(questions_path=missing_questions, db_dir=str(_DB_DIR))

    expect_missing_file = pytest.raises(FileNotFoundError)
    with expect_missing_file:
        SQLEnvTRL()
def test_init_invalid_db_dir() -> None:
    """A nonexistent database directory makes construction raise ``FileNotFoundError``."""
    missing_db_dir = "/no/such/dir"
    SQLEnvTRL._configure(questions_path=str(_QUESTIONS_PATH), db_dir=missing_db_dir)

    expect_missing_file = pytest.raises(FileNotFoundError)
    with expect_missing_file:
        SQLEnvTRL()
class _RecordingEnv:
    """Test double that records step actions and replays canned observations.

    Args:
        observations: Observations returned by successive ``step`` calls,
            in order.  The list is copied, so the caller's list is never
            drained by stepping.
    """

    def __init__(self, observations: list[SimpleNamespace]) -> None:
        # Copy: ``step`` pops from this queue and must not mutate the
        # list object the caller passed in.
        self._observations = list(observations)
        self.actions: list[SQLAction] = []

    def step(self, action: SQLAction) -> SimpleNamespace:
        """Record ``action`` and return the next queued observation.

        Raises:
            IndexError: If no queued observations remain.
        """
        self.actions.append(action)
        return self._observations.pop(0)
class _ResetRecordingEnv:
    """Test double that records ``reset`` calls and replays canned observations.

    Args:
        observations: Observations returned by successive ``reset`` calls,
            in order.  The list is copied, so the caller's list is never
            drained by resetting.
    """

    def __init__(self, observations: list[SimpleNamespace]) -> None:
        # Copy: ``reset`` pops from this queue and must not mutate the
        # list object the caller passed in.
        self._observations = list(observations)
        self.reset_calls: list[dict[str, object]] = []

    def reset(self, *, seed: int | None = None) -> SimpleNamespace:
        """Record the ``seed`` used and return the next queued observation.

        Raises:
            IndexError: If no queued observations remain.
        """
        self.reset_calls.append({"seed": seed})
        return self._observations.pop(0)
class _RecordingEnvWithReset:
    """Test double supporting both ``step`` and ``reset`` with separate queues.

    Args:
        step_observations: Observations returned by successive ``step``
            calls, in order.
        reset_observations: Observations returned by successive ``reset``
            calls, in order.

    Both lists are copied, so the caller's lists are never drained.
    """

    def __init__(
        self,
        *,
        step_observations: list[SimpleNamespace],
        reset_observations: list[SimpleNamespace],
    ) -> None:
        # Copies: the queues are consumed with pop(0) and must not mutate
        # the list objects the caller passed in.
        self._step_observations = list(step_observations)
        self._reset_observations = list(reset_observations)
        self.actions: list[SQLAction] = []
        self.reset_calls: list[dict[str, object]] = []

    def step(self, action: SQLAction) -> SimpleNamespace:
        """Record ``action`` and return the next queued step observation.

        Raises:
            IndexError: If no queued step observations remain.
        """
        self.actions.append(action)
        return self._step_observations.pop(0)

    def reset(self, *, seed: int | None = None) -> SimpleNamespace:
        """Record the ``seed`` used and return the next queued reset observation.

        Raises:
            IndexError: If no queued reset observations remain.
        """
        self.reset_calls.append({"seed": seed})
        return self._reset_observations.pop(0)
def _observation(
    *,
    result: str = "ok",
    reward: float | None = 0.0,
    done: bool = False,
) -> SimpleNamespace:
    """Build a canned step observation with an empty ``error`` field.

    Args:
        result: Value for the observation's ``result`` attribute.
        reward: Value for the observation's ``reward`` attribute.
        done: Value for the observation's ``done`` attribute.

    Returns:
        A ``SimpleNamespace`` with ``result``, ``error``, ``reward`` and
        ``done`` attributes.
    """
    fields = {"result": result, "error": "", "reward": reward, "done": done}
    return SimpleNamespace(**fields)
def _build_adapter_with_recording_env(
    observation: SimpleNamespace,
) -> tuple[SQLEnvTRL, _RecordingEnv]:
    """Build an adapter whose environment replays one canned observation.

    Args:
        observation: The single observation the recording environment
            returns from its first ``step`` call.

    Returns:
        The configured adapter and the ``_RecordingEnv`` installed as its
        ``_env``.
    """
    # Single-observation case of the list-based builder: same configure,
    # construct, and env-swap sequence.
    return _build_adapter_with_recording_observations([observation])
def _build_adapter_with_recording_observations(
    observations: list[SimpleNamespace],
) -> tuple[SQLEnvTRL, _RecordingEnv]:
    """Build an adapter whose environment replays ``observations`` on ``step``.

    Args:
        observations: Observations returned by successive ``step`` calls.

    Returns:
        The configured adapter and the ``_RecordingEnv`` installed as its
        ``_env``.
    """
    SQLEnvTRL._configure(questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR))
    trl_adapter = SQLEnvTRL()

    # Replace the environment built by the constructor with the test double.
    fake_env = _RecordingEnv(observations)
    trl_adapter._env = fake_env
    return trl_adapter, fake_env
def _build_adapter_with_reset_env(
    observations: list[SimpleNamespace],
) -> tuple[SQLEnvTRL, _ResetRecordingEnv]:
    """Build an adapter whose environment replays ``observations`` on ``reset``.

    Args:
        observations: Observations returned by successive ``reset`` calls.

    Returns:
        The configured adapter and the ``_ResetRecordingEnv`` installed as
        its ``_env``.
    """
    SQLEnvTRL._configure(questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR))
    trl_adapter = SQLEnvTRL()

    # Replace the environment built by the constructor with the test double.
    fake_env = _ResetRecordingEnv(observations)
    trl_adapter._env = fake_env
    return trl_adapter, fake_env
def test_describe_dispatches_action_and_accumulates_reward() -> None:
    """``describe`` sends a DESCRIBE action and adds the step reward."""
    step_result = SimpleNamespace(result="schema", reward=0.25, done=False)
    adapter, env = _build_adapter_with_recording_env(step_result)

    returned = adapter.describe("employees")

    assert returned == "schema"
    assert adapter.reward == 0.25
    assert adapter._done is False
    expected_action = SQLAction(action_type="DESCRIBE", argument="employees")
    assert env.actions == [expected_action]
def test_sample_dispatches_action_and_accumulates_reward() -> None:
    """``sample`` sends a SAMPLE action and adds the step reward."""
    step_result = SimpleNamespace(result="rows", reward=0.1, done=False)
    adapter, env = _build_adapter_with_recording_env(step_result)

    returned = adapter.sample("employees")

    assert returned == "rows"
    assert adapter.reward == 0.1
    assert adapter._done is False
    expected_action = SQLAction(action_type="SAMPLE", argument="employees")
    assert env.actions == [expected_action]
def test_query_dispatches_action_and_accumulates_reward() -> None:
    """``query`` sends a QUERY action and adds the step reward."""
    step_result = SimpleNamespace(result="query output", reward=0.5, done=False)
    adapter, env = _build_adapter_with_recording_env(step_result)

    returned = adapter.query("SELECT 1")

    assert returned == "query output"
    assert adapter.reward == 0.5
    assert adapter._done is False
    expected_action = SQLAction(action_type="QUERY", argument="SELECT 1")
    assert env.actions == [expected_action]
def test_answer_dispatches_action_sets_done_and_accumulates_reward() -> None:
    """``answer`` sends an ANSWER action, adds the reward, and marks the episode done."""
    step_result = SimpleNamespace(
        result="Answer submitted: correct.",
        reward=1.0,
        done=True,
    )
    adapter, env = _build_adapter_with_recording_env(step_result)

    returned = adapter.answer("42")

    assert returned == "Answer submitted: correct."
    assert adapter.reward == 1.0
    assert adapter._done is True
    expected_action = SQLAction(action_type="ANSWER", argument="42")
    assert env.actions == [expected_action]
def test_query_repeat_penalty_applies_on_exact_repeat() -> None:
    """Issuing the same query twice costs exactly one repeat penalty."""
    canned = [_observation() for _ in range(2)]
    adapter, _env = _build_adapter_with_recording_observations(canned)

    for _ in range(2):
        adapter.query("SELECT 1")

    assert adapter.reward == pytest.approx(_REPEAT_PENALTY)
    assert adapter._repeat_count == 1
def test_query_repeat_penalty_not_applied_for_different_sql() -> None:
    """Two different queries incur no repeat penalty."""
    canned = [_observation() for _ in range(2)]
    adapter, _env = _build_adapter_with_recording_observations(canned)

    for sql in ("SELECT 1", "SELECT 2"):
        adapter.query(sql)

    assert adapter.reward == pytest.approx(0.0)
    assert adapter._repeat_count == 0
def test_repeat_penalty_not_applied_for_different_method_same_argument() -> None:
    """The same argument passed to two different tools is not a repeat."""
    canned = [_observation() for _ in range(2)]
    adapter, _env = _build_adapter_with_recording_observations(canned)

    table = "employees"
    adapter.describe(table)
    adapter.sample(table)

    assert adapter.reward == pytest.approx(0.0)
    assert adapter._repeat_count == 0
def test_reset_clears_recent_call_tracker_for_penalty() -> None:
    """After ``reset``, repeating a pre-reset query is no longer penalized."""
    fresh_episode = SimpleNamespace(
        question="Q",
        schema_info="schema",
        result="",
        error="",
        step_count=0,
        budget_remaining=10,
        action_history=[],
        done=False,
        reward=None,
    )
    SQLEnvTRL._configure(questions_path=str(_QUESTIONS_PATH), db_dir=str(_DB_DIR))
    adapter = SQLEnvTRL()
    adapter._env = _RecordingEnvWithReset(
        step_observations=[_observation() for _ in range(3)],
        reset_observations=[fresh_episode],
    )

    # Two identical queries: the second one is a repeat.
    for _ in range(2):
        adapter.query("SELECT 1")
    penalized_reward = adapter.reward

    adapter.reset()
    # Same SQL again, but in a new episode.
    adapter.query("SELECT 1")

    assert penalized_reward == pytest.approx(_REPEAT_PENALTY)
    assert adapter.reward == pytest.approx(0.0)
    assert adapter._repeat_count == 0
def test_repeat_penalty_catches_alternating_reuse_within_window() -> None:
    """An A, B, A sequence counts the second A as a repeat."""
    canned = [_observation() for _ in range(3)]
    adapter, _env = _build_adapter_with_recording_observations(canned)

    for sql in ("A", "B", "A"):
        adapter.query(sql)

    assert adapter.reward == pytest.approx(_REPEAT_PENALTY)
    assert adapter._repeat_count == 1
def test_repeat_penalty_stacks_for_three_identical_calls() -> None:
    """Three identical queries incur the repeat penalty twice."""
    canned = [_observation() for _ in range(3)]
    adapter, _env = _build_adapter_with_recording_observations(canned)

    for _ in range(3):
        adapter.query("SELECT 1")

    assert adapter.reward == pytest.approx(_REPEAT_PENALTY * 2)
    assert adapter._repeat_count == 2
def test_repeat_count_matches_penalty_fire_count_across_methods() -> None:
    """``_repeat_count`` equals the number of penalties applied across tools."""
    canned = [_observation() for _ in range(5)]
    adapter, _env = _build_adapter_with_recording_observations(canned)

    # describe, describe, sample, sample, describe -- all on the same table.
    call_order = ("describe", "describe", "sample", "sample", "describe")
    for tool_name in call_order:
        getattr(adapter, tool_name)("t")

    assert adapter._repeat_count == 3
    assert adapter.reward == pytest.approx(_REPEAT_PENALTY * 3)
@pytest.mark.parametrize(
    "method_name, argument",
    [
        ("describe", "employees"),
        ("sample", "employees"),
        ("query", "SELECT 1"),
        ("answer", "42"),
    ],
)
def test_tool_methods_raise_when_episode_is_over(
    method_name: str, argument: str
) -> None:
    """Every tool method raises ``ValueError`` once the episode is marked done."""
    placeholder = SimpleNamespace(result="unused", reward=0.0, done=False)
    adapter, _env = _build_adapter_with_recording_env(placeholder)
    adapter._done = True  # simulate a finished episode

    expect_episode_over = pytest.raises(ValueError, match="Episode is over")
    with expect_episode_over:
        getattr(adapter, method_name)(argument)
def test_tool_method_docstrings_include_args_and_returns_sections() -> None:
    """Each tool method has a docstring containing ``Args:`` and ``Returns:``."""
    tool_names = ("describe", "sample", "query", "answer")
    for tool_name in tool_names:
        docstring = getattr(SQLEnvTRL, tool_name).__doc__
        assert isinstance(docstring, str)
        assert "Args:" in docstring
        assert "Returns:" in docstring
def test_tool_methods_have_annotations() -> None:
    """Each tool method annotates its one parameter and its return as ``"str"``.

    The expected values are the string ``"str"``, i.e. the annotations are
    compared in their stringified form.
    """
    # Tool method name -> name of its single parameter, in the original order.
    parameter_by_tool = {
        "describe": "table_name",
        "sample": "table_name",
        "query": "sql",
        "answer": "value",
    }
    for tool_name, parameter in parameter_by_tool.items():
        annotations = getattr(SQLEnvTRL, tool_name).__annotations__
        assert annotations == {parameter: "str", "return": "str"}
def test_reset_returns_observation_string() -> None:
    """``reset`` returns non-blank text and resets the env with ``seed=None``."""
    initial_observation = SimpleNamespace(
        question="How many students?",
        schema_info="student(id, name)",
        result="",
        error="",
        step_count=0,
        budget_remaining=10,
        action_history=[],
        done=False,
        reward=None,
    )
    adapter, reset_env = _build_adapter_with_reset_env([initial_observation])

    text = adapter.reset()

    assert isinstance(text, str)
    assert text.strip() != ""
    assert reset_env.reset_calls == [{"seed": None}]
def test_reset_clears_reward() -> None:
    """``reset`` zeroes a previously accumulated reward."""
    initial_observation = SimpleNamespace(
        question="Q",
        schema_info="schema",
        result="",
        error="",
        step_count=0,
        budget_remaining=10,
        action_history=[],
        done=False,
        reward=None,
    )
    adapter, _env = _build_adapter_with_reset_env([initial_observation])
    adapter.reward = 3.5  # leftover reward from an earlier episode

    adapter.reset()

    assert adapter.reward == 0.0
def test_reset_clears_done() -> None:
    """``reset`` sets ``_done`` back to ``False``."""
    initial_observation = SimpleNamespace(
        question="Q",
        schema_info="schema",
        result="",
        error="",
        step_count=0,
        budget_remaining=10,
        action_history=[],
        done=False,
        reward=None,
    )
    adapter, _env = _build_adapter_with_reset_env([initial_observation])
    adapter._done = True  # leftover flag from a finished episode

    adapter.reset()

    assert adapter._done is False
def test_reset_accepts_kwargs() -> None:
    """Extra keyword arguments to ``reset`` are accepted and not forwarded as a seed."""
    initial_observation = SimpleNamespace(
        question="Q",
        schema_info="schema",
        result="",
        error="",
        step_count=0,
        budget_remaining=10,
        action_history=[],
        done=False,
        reward=None,
    )
    adapter, reset_env = _build_adapter_with_reset_env([initial_observation])

    text = adapter.reset(foo="bar")

    assert isinstance(text, str)
    assert reset_env.reset_calls == [{"seed": None}]
def test_reset_multiple_times() -> None:
    """Three consecutive resets each return text and clear reward and done state."""
    # One queued observation per reset; only the question text differs.
    queued = [
        SimpleNamespace(
            question=question,
            schema_info="schema",
            result="",
            error="",
            step_count=0,
            budget_remaining=10,
            action_history=[],
            done=False,
            reward=None,
        )
        for question in ("Q1", "Q2", "Q3")
    ]
    adapter, reset_env = _build_adapter_with_reset_env(queued)

    first = adapter.reset()
    # Dirty the state between resets so each reset has something to clear.
    adapter.reward = 1.5
    adapter._done = True
    second = adapter.reset()
    adapter.reward = 2.0
    adapter._done = True
    third = adapter.reset()

    for text in (first, second, third):
        assert isinstance(text, str)
    assert adapter.reward == 0.0
    assert adapter._done is False
    assert reset_env.reset_calls == [{"seed": None}] * 3
def test_reward_func_reads_accumulated_rewards() -> None:
    """The reward function returns each environment's ``reward``, in order."""
    expected = [0.5, 1.0, 0.0]
    environments = [SimpleNamespace(reward=value) for value in expected]

    collected = sql_env_reward_func(environments)

    assert collected == [0.5, 1.0, 0.0]
def test_reward_func_empty_list() -> None:
    """No environments yields an empty reward list."""
    collected = sql_env_reward_func([])
    assert collected == []
def test_reward_func_single_env() -> None:
    """A single environment yields a one-element reward list."""
    only_env = SimpleNamespace(reward=0.75)
    collected = sql_env_reward_func([only_env])
    assert collected == [0.75]
def test_reward_func_ignores_kwargs() -> None:
    """Extra keyword arguments do not affect the returned rewards."""
    only_env = SimpleNamespace(reward=2.25)
    extras = {"completions": [], "foo": "bar"}
    collected = sql_env_reward_func([only_env], **extras)
    assert collected == [2.25]
def test_reward_func_returns_list_of_floats() -> None:
    """The result is a ``list`` of ``float`` even when an env holds an int reward."""
    environments = [SimpleNamespace(reward=1), SimpleNamespace(reward=0.25)]

    collected = sql_env_reward_func(environments)

    assert isinstance(collected, list)
    assert all(isinstance(item, float) for item in collected)
class _BuildTrainerConfigRecorder:
    """Stand-in config class that keeps the keyword arguments it was built with.

    The tests below pass it to ``build_trainer`` as ``trl_grpo_config_cls``.
    """

    def __init__(self, **kwargs: object) -> None:
        # Expose the constructor keywords for later inspection.
        self.kwargs: dict[str, object] = kwargs
class _BuildTrainerClassRecorder:
    """Stand-in trainer class that keeps the keyword arguments it was built with.

    The tests below pass it to ``build_trainer`` as ``grpo_trainer_cls`` and
    inspect ``kwargs`` on the returned object.
    """

    def __init__(self, **kwargs: object) -> None:
        # Expose the constructor keywords for later inspection.
        self.kwargs: dict[str, object] = kwargs
class _EnvironmentFactoryWithConfigure:
    """Environment-factory stand-in that records every ``configure`` call.

    NOTE(review): this double exposes ``configure`` while the tests above
    call ``SQLEnvTRL._configure`` -- confirm which name ``build_trainer``
    invokes.
    """

    # Shared, class-level call log; tests rebind it to a fresh list before use.
    configure_calls: list[dict[str, object]] = []

    @classmethod
    def configure(
        cls,
        *,
        questions_path: str,
        db_dir: str,
        step_budget: int,
    ) -> None:
        """Append the received keyword arguments to ``configure_calls``."""
        recorded_call = dict(
            questions_path=questions_path,
            db_dir=db_dir,
            step_budget=step_budget,
        )
        cls.configure_calls.append(recorded_call)
class _EnvironmentFactoryWithoutConfigure:
    """Environment-factory stand-in with no attributes or methods of its own."""
def _build_trainer_config() -> SimpleNamespace:
    """Return the attribute bag the ``build_trainer`` tests pass as ``config``.

    Returns:
        A ``SimpleNamespace`` carrying the output directory, optimizer and
        batching values, generation values, data paths, and a step budget
        of 7.
    """
    settings = {
        "output_dir": "outputs/test",
        "learning_rate": 1e-5,
        "per_device_train_batch_size": 2,
        "gradient_accumulation_steps": 2,
        "num_train_epochs": 1,
        "logging_steps": 1,
        "max_new_tokens": 128,
        "num_generations": 2,
        "questions_path": "data/questions/student_assessment.json",
        "db_dir": "data/databases",
        "step_budget": 7,
    }
    return SimpleNamespace(**settings)
def test_build_trainer_with_environment_factory() -> None:
    """A factory with ``configure`` is configured once and forwarded to the trainer."""
    factory = _EnvironmentFactoryWithConfigure
    factory.configure_calls = []  # start from a clean call log
    config = _build_trainer_config()

    trainer = build_trainer(
        model=object(),
        tokenizer=object(),
        prompts=["prompt"],
        config=config,
        trl_grpo_config_cls=_BuildTrainerConfigRecorder,
        grpo_trainer_cls=_BuildTrainerClassRecorder,
        reward_funcs=[sql_env_reward_func],
        environment_factory=factory,
    )

    assert trainer.kwargs["environment_factory"] is _EnvironmentFactoryWithConfigure
    expected_call = {
        "questions_path": config.questions_path,
        "db_dir": config.db_dir,
        "step_budget": config.step_budget,
    }
    assert _EnvironmentFactoryWithConfigure.configure_calls == [expected_call]
def test_build_trainer_without_environment_factory() -> None:
    """With ``environment_factory=None`` the trainer receives no such keyword."""
    call_arguments = {
        "model": object(),
        "tokenizer": object(),
        "prompts": ["prompt"],
        "config": _build_trainer_config(),
        "trl_grpo_config_cls": _BuildTrainerConfigRecorder,
        "grpo_trainer_cls": _BuildTrainerClassRecorder,
        "reward_funcs": [sql_env_reward_func],
        "environment_factory": None,
    }

    trainer = build_trainer(**call_arguments)

    assert "environment_factory" not in trainer.kwargs
def test_build_trainer_passes_reward_funcs() -> None:
    """The reward functions reach the trainer class's keyword arguments."""
    supplied_reward_funcs = [sql_env_reward_func]
    call_arguments = {
        "model": object(),
        "tokenizer": object(),
        "prompts": ["prompt"],
        "config": _build_trainer_config(),
        "trl_grpo_config_cls": _BuildTrainerConfigRecorder,
        "grpo_trainer_cls": _BuildTrainerClassRecorder,
        "reward_funcs": supplied_reward_funcs,
        "environment_factory": _EnvironmentFactoryWithoutConfigure,
    }

    trainer = build_trainer(**call_arguments)

    assert trainer.kwargs["reward_funcs"] == supplied_reward_funcs