| """Unit tests for GRPO training prompts.""" |
|
|
| from sql_env.models import SQLObservation |
| from sql_env.training.prompts import format_observation, get_system_prompt |
|
|
|
|
def _build_observation(
    *,
    result: str = "25",
    error: str = "",
    done: bool = False,
    reward: float | None = None,
) -> SQLObservation:
    """Return an ``SQLObservation`` fixture for the employees-count question.

    Only the fields that individual tests need to vary are exposed as
    keyword arguments. The question, schema text, step count, remaining
    budget and action history are pinned to fixed values.
    """
    # Built per call so each observation gets its own action_history list.
    pinned_fields = {
        "question": "How many employees are there?",
        "schema_info": "Available tables:\n- employees",
        "step_count": 1,
        "budget_remaining": 9,
        "action_history": ["DESCRIBE employees"],
    }
    return SQLObservation(
        **pinned_fields,
        result=result,
        error=error,
        done=done,
        reward=reward,
    )
|
|
|
|
def test_system_prompt_returns_string() -> None:
    """The default system prompt is a str containing non-whitespace text."""
    text = get_system_prompt()
    assert isinstance(text, str)
    assert text.strip() != ""
|
|
|
|
def test_system_prompt_mentions_tools() -> None:
    """The system prompt names each tool: describe, query and answer."""
    text = get_system_prompt()
    for tool_name in ("describe", "query", "answer"):
        assert tool_name in text
|
|
|
|
def test_system_prompt_is_deterministic() -> None:
    """Two consecutive calls with default arguments return equal prompts."""
    first, second = get_system_prompt(), get_system_prompt()
    assert first == second
|
|
|
|
def test_system_prompt_thinking_disabled() -> None:
    """With thinking disabled, the prompt begins with the /no_think directive."""
    assert get_system_prompt(enable_thinking=False).startswith("/no_think")
|
|
|
|
def test_system_prompt_thinking_enabled() -> None:
    """With thinking enabled, the prompt does not begin with /no_think."""
    no_think_marker = "/no_think"
    text = get_system_prompt(enable_thinking=True)
    assert not text.startswith(no_think_marker)
|
|
|
|
def test_format_observation_happy() -> None:
    """A default observation renders its question, result and remaining budget."""
    text = format_observation(_build_observation())
    assert text
    expected_fragments = (
        "How many employees are there?",
        "25",
        "Budget Remaining: 9",
    )
    for fragment in expected_fragments:
        assert fragment in text
|
|
|
|
def test_format_observation_with_error() -> None:
    """An observation that carries an error message surfaces that message."""
    observation = _build_observation(result="", error="syntax error")
    rendered = format_observation(observation)
    assert "syntax error" in rendered
|
|
|
|
def test_format_observation_done_state() -> None:
    """A finished episode reports its done flag and final reward."""
    rendered = format_observation(_build_observation(done=True, reward=1.0))
    for fragment in ("Done: True", "Final Reward: 1.0"):
        assert fragment in rendered
|
|
|
|
def test_format_observation_empty_result() -> None:
    """An empty result still produces output, using an ``(empty)`` placeholder."""
    observation = _build_observation(result="")
    rendered = format_observation(observation)
    assert rendered
    assert "(empty)" in rendered
|
|
|
|
def test_format_observation_long_result() -> None:
    """An oversized result is truncated rather than echoed verbatim."""
    long_result = "x" * 10000
    formatted = format_observation(_build_observation(result=long_result))

    assert formatted
    assert "[truncated]" in formatted
    # The marker alone does not prove truncation happened. A formatter that
    # appended "[truncated]" but still copied the whole payload would pass.
    # So also require that the full 10k-character result is not present.
    assert long_result not in formatted, "long result was not actually truncated"
|
|