| """Prompt helpers for GRPO training rollouts. |
| |
| Canonical system prompt and observation formatting for SQL exploration. |
| The system prompt matches the SFT/GRPO tool-calling format used by |
| ``scripts/generate_sft_data.py`` and ``notebooks/train_grpo.ipynb``. |
| """ |
|
|
| try: |
| from sql_env.models import SQLObservation |
| except ImportError: |
| from models import SQLObservation |
|
|
# Upper bound on how many characters of a query result are placed into the
# observation text before it is cut off.
_MAX_RESULT_CHARS: int = 2000

# Base system prompt. Adjacent literals are concatenated by the parser, so the
# layout below (one literal per prompt line) is purely for readability.
_SYSTEM_PROMPT: str = (
    # Role and tool restriction.
    "You answer questions about a SQL database. Use ONLY the provided tools.\n"
    "\n"
    # Suggested order of tool calls.
    "Strategy:\n"
    "1. Call describe(table_name=...) to see columns\n"
    "2. Call query(sql=...) to run SELECT queries\n"
    "3. Call answer(value=...) to submit your final answer\n"
    "\n"
    # Expected shape of the submitted answer.
    "Answer format: submit ONLY the data values from your query result.\n"
    "- Single value: 42 or ford\n"
    "- Multiple values: alice, bob, charlie\n"
    "- Table rows: col1 | col2 (one row per line)\n"
    "- No results: []\n"
    "\n"
    # One tool call per turn.
    "IMPORTANT: Call only ONE tool at a time, "
    "then read the response before deciding what to do next."
)
|
|
|
|
def get_system_prompt(*, enable_thinking: bool = False) -> str:
    """Build the system prompt for SQL exploration rollouts.

    Parameters
    ----------
    enable_thinking
        When True, the base prompt is returned unchanged. When False
        (default), a ``/no_think`` line is placed in front of it to
        disable Qwen3 thinking mode.

    Returns
    -------
    str
        The module-level ``_SYSTEM_PROMPT`` text, prefixed with
        ``/no_think`` and a newline unless ``enable_thinking`` is set.
    """
    # An empty prefix leaves the base prompt untouched.
    prefix = "" if enable_thinking else "/no_think\n"
    return prefix + _SYSTEM_PROMPT
|
|
|
|
def format_observation(obs: SQLObservation) -> str:
    """Render an environment observation as text for a user turn.

    Parameters
    ----------
    obs
        Observation to serialize. The attributes read are ``result``,
        ``question``, ``schema_info``, ``error``, ``action_history``,
        ``step_count``, ``budget_remaining``, ``done`` and ``reward``.

    Returns
    -------
    str
        Newline-joined text with the question, schema, last result
        (cut off after ``_MAX_RESULT_CHARS`` characters), an error line
        and action history when present, the step/budget/done status,
        and the final reward once ``obs.done`` is truthy.
    """
    # A falsy result is shown with a placeholder instead of a blank line.
    shown_result = obs.result or "(empty)"
    if len(shown_result) > _MAX_RESULT_CHARS:
        shown_result = f"{shown_result[:_MAX_RESULT_CHARS]}... [truncated]"

    # Sections that are always present.
    parts = [f"Question: {obs.question}", ""]
    parts += ["Schema:", obs.schema_info or "(none)", ""]
    parts += ["Last Result:", shown_result]

    # Optional sections, each preceded by a blank separator line.
    if obs.error:
        parts.append("")
        parts.append(f"Error: {obs.error}")

    if obs.action_history:
        parts.append("")
        parts.append("Action History:")
        for entry in obs.action_history:
            parts.append(f"- {entry}")

    # Episode status footer.
    parts.append("")
    parts.append(f"Step: {obs.step_count}")
    parts.append(f"Budget Remaining: {obs.budget_remaining}")
    parts.append(f"Done: {obs.done}")

    if obs.done:
        if obs.reward is None:
            reward_text = "None"
        else:
            reward_text = str(obs.reward)
        parts.append(f"Final Reward: {reward_text}")

    return "\n".join(parts)
|
|