| """Unit tests for terminal SFT assistant message.""" | |
| from __future__ import annotations | |
| from types import SimpleNamespace | |
| from scripts.generate_sft_data import generate_trajectory | |
| class _FakeEnv: | |
| def __init__(self, question_text: str, final_reward: float = 1.0) -> None: | |
| self.questions = [SimpleNamespace(question_text=question_text)] | |
| self._step_index = 0 | |
| self._final_reward = final_reward | |
| def reset(self, seed: int | None = None) -> SimpleNamespace: | |
| del seed | |
| return SimpleNamespace(schema_info="Available tables:\n- students") | |
| def step(self, action: object) -> SimpleNamespace: | |
| del action | |
| self._step_index += 1 | |
| if self._step_index == 1: | |
| return SimpleNamespace(result="students(id, name)", error=None, reward=0.0) | |
| if self._step_index == 2: | |
| return SimpleNamespace(result="[(2,)]", error=None, reward=0.0) | |
| return SimpleNamespace( | |
| result="Answer submitted: correct.", | |
| error=None, | |
| reward=self._final_reward, | |
| ) | |
def test_generate_trajectory_appends_terminal_assistant_message() -> None:
    """A successful trajectory ends with a plain assistant message after a tool turn."""
    question = {
        "question_text": "How many students are there?",
        "tables_involved": ["students"],
        "gold_sql": "SELECT COUNT(*) FROM students",
        "gold_answer": "2",
        "answer_type": "integer",
    }
    result = generate_trajectory(
        env=_FakeEnv(question_text=question["question_text"]),
        question=question,
    )
    assert result is not None

    # The final two messages: the last tool observation, then the closing reply.
    tool_message, closing_message = result["messages"][-2:]
    assert tool_message["role"] == "tool"
    assert closing_message["role"] == "assistant"
    assert closing_message["content"] == "Task complete."
    # The closing reply must not request any further tool use.
    assert "tool_calls" not in closing_message
def test_generate_trajectory_returns_none_when_reward_is_low() -> None:
    """A trajectory whose final step yields reward 0.0 produces ``None``."""
    question = {
        "question_text": "How many students are there?",
        "tables_involved": ["students"],
        "gold_sql": "SELECT COUNT(*) FROM students",
        "gold_answer": "2",
        "answer_type": "integer",
    }
    zero_reward_env = _FakeEnv(
        question_text=question["question_text"],
        final_reward=0.0,
    )
    outcome = generate_trajectory(env=zero_reward_env, question=question)
    assert outcome is None