"""Unit tests for terminal SFT assistant message.""" from __future__ import annotations from types import SimpleNamespace from scripts.generate_sft_data import generate_trajectory class _FakeEnv: def __init__(self, question_text: str, final_reward: float = 1.0) -> None: self.questions = [SimpleNamespace(question_text=question_text)] self._step_index = 0 self._final_reward = final_reward def reset(self, seed: int | None = None) -> SimpleNamespace: del seed return SimpleNamespace(schema_info="Available tables:\n- students") def step(self, action: object) -> SimpleNamespace: del action self._step_index += 1 if self._step_index == 1: return SimpleNamespace(result="students(id, name)", error=None, reward=0.0) if self._step_index == 2: return SimpleNamespace(result="[(2,)]", error=None, reward=0.0) return SimpleNamespace( result="Answer submitted: correct.", error=None, reward=self._final_reward, ) def test_generate_trajectory_appends_terminal_assistant_message() -> None: question = { "question_text": "How many students are there?", "tables_involved": ["students"], "gold_sql": "SELECT COUNT(*) FROM students", "gold_answer": "2", "answer_type": "integer", } env = _FakeEnv(question_text=question["question_text"]) result = generate_trajectory(env=env, question=question) assert result is not None messages = result["messages"] assert messages[-2]["role"] == "tool" assert messages[-1]["role"] == "assistant" assert messages[-1]["content"] == "Task complete." assert "tool_calls" not in messages[-1] def test_generate_trajectory_returns_none_when_reward_is_low() -> None: question = { "question_text": "How many students are there?", "tables_involved": ["students"], "gold_sql": "SELECT COUNT(*) FROM students", "gold_answer": "2", "answer_type": "integer", } env = _FakeEnv(question_text=question["question_text"], final_reward=0.0) assert generate_trajectory(env=env, question=question) is None