# sql_env/tests/unit/test_prompts.py
# hjerpe's picture
# Upload folder using huggingface_hub
# 9e64e71 verified
"""Unit tests for GRPO training prompts."""
from sql_env.models import SQLObservation
from sql_env.training.prompts import format_observation, get_system_prompt
def _build_observation(
    *,
    result: str = "25",
    error: str = "",
    done: bool = False,
    reward: float | None = None,
) -> SQLObservation:
    """Build an ``SQLObservation`` fixture for the formatting tests.

    Only the fields the tests vary are exposed as keyword-only arguments;
    every other field is pinned to a fixed value.

    Args:
        result: Value passed through as the observation's ``result``.
        error: Value passed through as the observation's ``error``.
        done: Value passed through as the observation's ``done`` flag.
        reward: Value passed through as the observation's ``reward``.

    Returns:
        A freshly constructed ``SQLObservation``.
    """
    # Fields that stay the same for every test; rebuilt on each call so the
    # action-history list is never shared between observations.
    fixed_fields = {
        "question": "How many employees are there?",
        "schema_info": "Available tables:\n- employees",
        "step_count": 1,
        "budget_remaining": 9,
        "action_history": ["DESCRIBE employees"],
    }
    return SQLObservation(
        **fixed_fields,
        result=result,
        error=error,
        done=done,
        reward=reward,
    )
def test_system_prompt_returns_string() -> None:
    """The default system prompt is a ``str`` with non-whitespace content."""
    system_prompt = get_system_prompt()

    assert isinstance(system_prompt, str)
    # A whitespace-only prompt strips down to "" and fails here.
    assert system_prompt.strip()
def test_system_prompt_mentions_tools() -> None:
    """The default system prompt contains each expected tool keyword."""
    system_prompt = get_system_prompt()

    # Checked in this order; the match is case-sensitive.
    for keyword in ("describe", "query", "answer"):
        assert keyword in system_prompt
def test_system_prompt_is_deterministic() -> None:
    """Two consecutive calls with default arguments return equal prompts."""
    first_prompt = get_system_prompt()
    second_prompt = get_system_prompt()

    assert first_prompt == second_prompt
def test_system_prompt_thinking_disabled() -> None:
    """With ``enable_thinking=False`` the prompt starts with ``/no_think``."""
    system_prompt = get_system_prompt(enable_thinking=False)

    assert system_prompt.startswith("/no_think")
def test_system_prompt_thinking_enabled() -> None:
    """With ``enable_thinking=True`` the prompt does not start with ``/no_think``."""
    system_prompt = get_system_prompt(enable_thinking=True)

    has_no_think_prefix = system_prompt.startswith("/no_think")
    assert not has_no_think_prefix
def test_format_observation_happy() -> None:
    """A default observation renders its question, result and remaining budget."""
    rendered = format_observation(_build_observation())

    assert rendered
    # Fragments are checked in this order: question, result, budget line.
    expected_fragments = (
        "How many employees are there?",
        "25",
        "Budget Remaining: 9",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
def test_format_observation_with_error() -> None:
    """An observation's error text appears in the formatted output."""
    failing_observation = _build_observation(result="", error="syntax error")

    rendered = format_observation(failing_observation)

    assert "syntax error" in rendered
def test_format_observation_done_state() -> None:
    """A finished observation renders its done flag and final reward."""
    finished_observation = _build_observation(done=True, reward=1.0)

    rendered = format_observation(finished_observation)

    # Done flag first, then the reward line.
    for fragment in ("Done: True", "Final Reward: 1.0"):
        assert fragment in rendered
def test_format_observation_empty_result() -> None:
    """An empty result string is rendered with the ``(empty)`` placeholder."""
    blank_observation = _build_observation(result="")

    rendered = format_observation(blank_observation)

    assert rendered
    assert "(empty)" in rendered
def test_format_observation_long_result() -> None:
    """A 10,000-character result is rendered with a ``[truncated]`` marker."""
    oversized_result = "x" * 10000

    rendered = format_observation(_build_observation(result=oversized_result))

    assert rendered
    assert "[truncated]" in rendered