"""Prompt helpers for GRPO training rollouts. Canonical system prompt and observation formatting for SQL exploration. The system prompt matches the SFT/GRPO tool-calling format used by ``scripts/generate_sft_data.py`` and ``notebooks/train_grpo.ipynb``. """ try: from sql_env.models import SQLObservation except ImportError: from models import SQLObservation _MAX_RESULT_CHARS = 2000 _SYSTEM_PROMPT = ( "You answer questions about a SQL database. " "Use ONLY the provided tools.\n\n" "Strategy:\n" "1. Call describe(table_name=...) to see columns\n" "2. Call query(sql=...) to run SELECT queries\n" "3. Call answer(value=...) to submit your final answer\n\n" "Answer format: submit ONLY the data values from your query result.\n" "- Single value: 42 or ford\n" "- Multiple values: alice, bob, charlie\n" "- Table rows: col1 | col2 (one row per line)\n" "- No results: []\n\n" "IMPORTANT: Call only ONE tool at a time, then read the " "response before deciding what to do next." ) def get_system_prompt(*, enable_thinking: bool = False) -> str: """Return the SQL exploration system prompt. Parameters ---------- enable_thinking When False (default), prepends ``/no_think`` to disable Qwen3 thinking mode. When True, returns prompt as-is. Returns ------- str Deterministic prompt text describing tool-calling strategy. """ if enable_thinking: return _SYSTEM_PROMPT return "/no_think\n" + _SYSTEM_PROMPT def format_observation(obs: SQLObservation) -> str: """Format an observation into a model-ready user turn. Parameters ---------- obs Environment observation to serialize for the language model. Returns ------- str Human-readable observation context including question, schema, latest result/error, and remaining budget. """ result_text = obs.result or "(empty)" if len(result_text) > _MAX_RESULT_CHARS: result_text = f"{result_text[:_MAX_RESULT_CHARS]}... [truncated]" lines = [ f"Question: {obs.question}", "", "Schema:", obs.schema_info or "(none)", "", "Last Result:", result_text, ] if obs.error: lines.extend(["", f"Error: {obs.error}"]) if obs.action_history: lines.extend(["", "Action History:"]) lines.extend(f"- {entry}" for entry in obs.action_history) lines.extend( [ "", f"Step: {obs.step_count}", f"Budget Remaining: {obs.budget_remaining}", f"Done: {obs.done}", ] ) if obs.done: reward_text = "None" if obs.reward is None else str(obs.reward) lines.append(f"Final Reward: {reward_text}") return "\n".join(lines)