# sql_env/training/prompts.py
# Uploaded by hjerpe via huggingface_hub (commit 9e64e71).
"""Prompt helpers for GRPO training rollouts.
Canonical system prompt and observation formatting for SQL exploration.
The system prompt matches the SFT/GRPO tool-calling format used by
``scripts/generate_sft_data.py`` and ``notebooks/train_grpo.ipynb``.
"""
try:
    from sql_env.models import SQLObservation
except ImportError:
    from models import SQLObservation
# Upper bound on the number of result characters echoed back to the model;
# ``format_observation`` truncates anything longer and marks the cut.
_MAX_RESULT_CHARS = 2000

# Tool-calling instructions shared by SFT data generation and GRPO rollouts.
# Each list entry is one line of the final prompt; empty entries produce the
# blank lines that separate the sections.
_SYSTEM_PROMPT = "\n".join(
    (
        "You answer questions about a SQL database. "
        "Use ONLY the provided tools.",
        "",
        "Strategy:",
        "1. Call describe(table_name=...) to see columns",
        "2. Call query(sql=...) to run SELECT queries",
        "3. Call answer(value=...) to submit your final answer",
        "",
        "Answer format: submit ONLY the data values from your query result.",
        "- Single value: 42 or ford",
        "- Multiple values: alice, bob, charlie",
        "- Table rows: col1 | col2 (one row per line)",
        "- No results: []",
        "",
        "IMPORTANT: Call only ONE tool at a time, then read the "
        "response before deciding what to do next.",
    )
)
def get_system_prompt(*, enable_thinking: bool = False) -> str:
    """Build the system prompt used for SQL-exploration rollouts.

    Parameters
    ----------
    enable_thinking
        If False (the default), the prompt is prefixed with a ``/no_think``
        line to switch off Qwen3 thinking mode. If True, the base prompt
        is returned without any prefix.

    Returns
    -------
    str
        The prompt text. The same argument always yields the same string.
    """
    # A leading "/no_think" line is the Qwen3 switch for disabling thinking.
    prefix = "" if enable_thinking else "/no_think\n"
    return prefix + _SYSTEM_PROMPT
def format_observation(obs: SQLObservation) -> str:
    """Serialize an environment observation into the text of a user turn.

    Parameters
    ----------
    obs
        Observation produced by the SQL environment.

    Returns
    -------
    str
        Newline-joined text with these parts, in order:

        - the question
        - the schema (``(none)`` if missing)
        - the latest result (``(empty)`` if missing, truncated past
          ``_MAX_RESULT_CHARS`` characters)
        - any error
        - the action history
        - the step count, remaining budget and done flag
        - the final reward, once the episode is done
    """
    raw_result = obs.result or "(empty)"
    # Results longer than the cap are cut and explicitly marked as truncated.
    if len(raw_result) > _MAX_RESULT_CHARS:
        shown_result = f"{raw_result[:_MAX_RESULT_CHARS]}... [truncated]"
    else:
        shown_result = raw_result

    sections = [
        f"Question: {obs.question}",
        "",
        "Schema:",
        obs.schema_info or "(none)",
        "",
        "Last Result:",
        shown_result,
    ]

    # Optional sections are only emitted when they carry content.
    if obs.error:
        sections += ["", f"Error: {obs.error}"]

    history = obs.action_history
    if history:
        sections += ["", "Action History:"]
        sections += [f"- {item}" for item in history]

    sections += [
        "",
        f"Step: {obs.step_count}",
        f"Budget Remaining: {obs.budget_remaining}",
        f"Done: {obs.done}",
    ]

    if obs.done:
        reward = obs.reward
        sections.append(
            "Final Reward: " + ("None" if reward is None else str(reward))
        )

    return "\n".join(sections)