# sql_env/tests/unit/test_sft_terminal_message.py
# hjerpe's picture
# Upload folder using huggingface_hub
# 9e64e71 verified
"""Unit tests for terminal SFT assistant message."""
from __future__ import annotations
from types import SimpleNamespace
from scripts.generate_sft_data import generate_trajectory
class _FakeEnv:
    """Scripted environment double exposing ``questions``, ``reset`` and ``step``.

    The first two ``step`` calls return fixed observations with zero reward;
    every later call returns the "answer submitted" observation carrying the
    reward supplied at construction time.
    """

    def __init__(self, question_text: str, final_reward: float = 1.0) -> None:
        """Store the single question and the reward used from the third step on."""
        self._final_reward = final_reward
        self._step_index = 0
        self.questions = [SimpleNamespace(question_text=question_text)]

    def reset(self, seed: int | None = None) -> SimpleNamespace:
        """Return the fixed schema observation; ``seed`` is accepted but ignored."""
        del seed
        return SimpleNamespace(schema_info="Available tables:\n- students")

    def step(self, action: object) -> SimpleNamespace:
        """Advance the script by one step; ``action`` is accepted but ignored."""
        del action
        self._step_index += 1

        # Canned results for the two zero-reward steps, keyed by 1-based call count.
        scripted_results = {1: "students(id, name)", 2: "[(2,)]"}
        if self._step_index in scripted_results:
            return SimpleNamespace(
                result=scripted_results[self._step_index],
                error=None,
                reward=0.0,
            )

        # Third call and beyond: terminal observation with the configured reward.
        return SimpleNamespace(
            result="Answer submitted: correct.",
            error=None,
            reward=self._final_reward,
        )
def test_generate_trajectory_appends_terminal_assistant_message() -> None:
    """Check that a rewarded trajectory closes with a plain assistant message.

    The last message must be an assistant turn with content "Task complete."
    and no ``tool_calls`` key, directly preceded by a tool message.
    """
    sample = dict(
        question_text="How many students are there?",
        tables_involved=["students"],
        gold_sql="SELECT COUNT(*) FROM students",
        gold_answer="2",
        answer_type="integer",
    )
    fake_env = _FakeEnv(question_text=sample["question_text"])

    trajectory = generate_trajectory(env=fake_env, question=sample)
    assert trajectory is not None

    conversation = trajectory["messages"]
    tool_turn = conversation[-2]
    closing_turn = conversation[-1]

    assert tool_turn["role"] == "tool"
    assert closing_turn["role"] == "assistant"
    assert closing_turn["content"] == "Task complete."
    # The closing turn must be plain text, not another tool invocation.
    assert "tool_calls" not in closing_turn
def test_generate_trajectory_returns_none_when_reward_is_low() -> None:
    """Check that no trajectory is returned when the final step reward is 0.0."""
    sample = dict(
        question_text="How many students are there?",
        tables_involved=["students"],
        gold_sql="SELECT COUNT(*) FROM students",
        gold_answer="2",
        answer_type="integer",
    )
    # Same scripted steps as the success case, but the last step yields no reward.
    unrewarded_env = _FakeEnv(
        question_text=sample["question_text"],
        final_reward=0.0,
    )

    outcome = generate_trajectory(env=unrewarded_env, question=sample)

    assert outcome is None