# sql_data_analyst/baseline/prompts.py
# Scraped hosting-page header: YashashMathur - "SQL Data Analyst OpenEnv - Initial commit"
# Commit d103a0f (verified), 2.37 kB
"""
Prompt templates for the baseline inference script.
"""
# Instruction text given to the model (presumably sent as the system message by
# the inference script -- the call site is outside this file). It fixes the
# reply protocol: a single JSON object carrying either "sql_query" or
# "submit_answer", which are the two keys parse_action() in this module decodes.
SYSTEM_PROMPT = """
You are a SQL data analyst. You are given a database schema and a business question.
Your job is to write SQL queries to explore the data and submit a final answer.
Rules:
- Only write SELECT or WITH queries (no INSERT, UPDATE, DELETE, DROP, etc.)
- Reply with JSON only. No explanation.
- To run a query: {"sql_query": "SELECT ..."}
- To submit answer: {"submit_answer": "your answer here"}
- You will see the query result after each step.
- Submit your answer when you are confident.
Important:
- Always use valid SQL syntax
- Table names: users, products, orders, order_items, events
- Dates are stored as ISO timestamps
- Always filter orders by status='completed' for revenue calculations
"""
def build_prompt(obs) -> str:
    """Build the user prompt from an observation.

    Args:
        obs: Observation object. The attributes read here are
            ``schema_summary``, ``question``, ``step``, ``max_steps``,
            ``last_query``, ``last_result`` (itself exposing ``error``,
            ``columns`` and ``rows``) and ``hints``.

    Returns:
        The prompt sections joined with newlines, always ending with the
        request for the next JSON action.
    """
    parts = [
        f"Database schema:\n{obs.schema_summary}",
        f"\nQuestion: {obs.question}",
        f"\nStep: {obs.step} / {obs.max_steps}",
    ]

    if obs.last_query:
        parts.append(f"\nLast query:\n{obs.last_query}")

    result = obs.last_result
    if result:
        if result.error:
            parts.append(f"\nSQL error: {result.error}")
        elif result.rows:
            # Only the first 10 rows are shown to keep the prompt short.
            rows = result.rows[:10]
            parts.append(f"\nResult columns: {result.columns}")
            # default=str: cells that are not JSON-native (dates, Decimal,
            # bytes, ...) are rendered via str() instead of raising TypeError.
            # This output is display-only, so lossy stringification is fine.
            parts.append(
                f"Result rows (first {len(rows)}):\n"
                f"{json.dumps(rows, indent=2, default=str)}"
            )
        elif obs.last_query:
            # A query ran without error but matched nothing. Say so explicitly;
            # otherwise the model cannot tell an empty result from no result.
            parts.append("\nResult: query returned no rows.")

    if obs.hints:
        parts.append(f"\nHints: {'; '.join(obs.hints)}")

    parts.append("\nWhat is your next action? Reply with JSON only.")
    return "\n".join(parts)
import json
def parse_action(response_text: str | None):
    """Extract JSON action from LLM response.

    Markdown code fences are stripped, then the text is decoded as JSON.

    Args:
        response_text: Raw model reply; may be ``None`` or empty.

    Returns:
        An ``env.Action``:
        * ``Action(sql_query=...)`` when the reply is a JSON object with a
          non-empty ``"sql_query"`` value;
        * ``Action(submit_answer=...)`` when it is a JSON object with a
          non-empty ``"submit_answer"`` value;
        * ``Action(submit_answer=<stripped text>)`` for anything else
          (invalid JSON, JSON that is not an object, or an object with
          neither key set);
        * ``Action(submit_answer="")`` when there is no reply at all.
    """
    from env import Action

    if not response_text:
        return Action(submit_answer="")

    text = response_text.strip()
    text = text.replace("```json", "").replace("```", "").strip()

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        data = None

    # json.loads also accepts bare scalars and arrays ("42", "null", "[1]").
    # A membership test on an int/None/bool would raise TypeError, so only
    # inspect the keys when the payload really is a JSON object.
    if isinstance(data, dict):
        if data.get("sql_query"):
            return Action(sql_query=data["sql_query"])
        if data.get("submit_answer"):
            return Action(submit_answer=data["submit_answer"])

    # Best-effort fallback: treat the whole reply as the final answer.
    return Action(submit_answer=text)