"""Unit tests for GRPO training prompts.""" from sql_env.models import SQLObservation from sql_env.training.prompts import format_observation, get_system_prompt def _build_observation( *, result: str = "25", error: str = "", done: bool = False, reward: float | None = None, ) -> SQLObservation: return SQLObservation( question="How many employees are there?", schema_info="Available tables:\n- employees", result=result, error=error, step_count=1, budget_remaining=9, action_history=["DESCRIBE employees"], done=done, reward=reward, ) def test_system_prompt_returns_string() -> None: prompt = get_system_prompt() assert isinstance(prompt, str) assert prompt.strip() def test_system_prompt_mentions_tools() -> None: prompt = get_system_prompt() assert "describe" in prompt assert "query" in prompt assert "answer" in prompt def test_system_prompt_is_deterministic() -> None: assert get_system_prompt() == get_system_prompt() def test_system_prompt_thinking_disabled() -> None: prompt = get_system_prompt(enable_thinking=False) assert prompt.startswith("/no_think") def test_system_prompt_thinking_enabled() -> None: prompt = get_system_prompt(enable_thinking=True) assert not prompt.startswith("/no_think") def test_format_observation_happy() -> None: formatted = format_observation(_build_observation()) assert formatted assert "How many employees are there?" in formatted assert "25" in formatted assert "Budget Remaining: 9" in formatted def test_format_observation_with_error() -> None: formatted = format_observation(_build_observation(result="", error="syntax error")) assert "syntax error" in formatted def test_format_observation_done_state() -> None: formatted = format_observation(_build_observation(done=True, reward=1.0)) assert "Done: True" in formatted assert "Final Reward: 1.0" in formatted def test_format_observation_empty_result() -> None: formatted = format_observation(_build_observation(result="")) assert formatted assert "(empty)" in formatted def test_format_observation_long_result() -> None: formatted = format_observation(_build_observation(result="x" * 10000)) assert formatted assert "[truncated]" in formatted