File size: 2,374 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Prompt templates for the baseline inference script.
"""

# Instruction text for the model (presumably sent as the system message by the
# inference script -- the caller is not visible in this file). It fixes the
# reply protocol to JSON only: {"sql_query": ...} to run a query, or
# {"submit_answer": ...} to finish, and lists the table names and data
# conventions the model is told to rely on.
SYSTEM_PROMPT = """
You are a SQL data analyst. You are given a database schema and a business question.
Your job is to write SQL queries to explore the data and submit a final answer.

Rules:
- Only write SELECT or WITH queries (no INSERT, UPDATE, DELETE, DROP, etc.)
- Reply with JSON only. No explanation.
- To run a query: {"sql_query": "SELECT ..."}
- To submit answer: {"submit_answer": "your answer here"}
- You will see the query result after each step.
- Submit your answer when you are confident.

Important:
- Always use valid SQL syntax
- Table names: users, products, orders, order_items, events
- Dates are stored as ISO timestamps
- Always filter orders by status='completed' for revenue calculations
"""


def build_prompt(obs) -> str:
    """Build the user prompt from an observation.

    Args:
        obs: Observation object. The attributes read here are
            ``schema_summary``, ``question``, ``step``, ``max_steps``,
            ``last_query``, ``last_result`` (an object exposing ``error``,
            ``columns`` and ``rows``, or a falsy value) and ``hints``.

    Returns:
        The prompt text: one section per available piece of information,
        joined with newlines and ending with the request for the next action.
    """
    parts = [
        f"Database schema:\n{obs.schema_summary}",
        f"\nQuestion: {obs.question}",
        f"\nStep: {obs.step} / {obs.max_steps}",
    ]

    if obs.last_query:
        parts.append(f"\nLast query:\n{obs.last_query}")

    result = obs.last_result
    if result:
        if result.error:
            parts.append(f"\nSQL error: {result.error}")
        elif result.rows:
            # Only a preview goes into the prompt to keep it short.
            preview = result.rows[:10]
            parts.append(f"\nResult columns: {result.columns}")
            # default=str: query results can contain values json cannot encode
            # natively (dates, Decimal, bytes). This text is display-only, so
            # stringifying them is preferable to raising TypeError.
            rendered = json.dumps(preview, indent=2, default=str)
            parts.append(f"Result rows (first {len(preview)}):\n{rendered}")
        elif obs.last_query:
            # A successful query with an empty result set would otherwise leave
            # no trace in the prompt, which is indistinguishable from "no
            # result available".
            parts.append("\nResult: the query returned no rows.")

    if obs.hints:
        parts.append(f"\nHints: {'; '.join(obs.hints)}")

    parts.append("\nWhat is your next action? Reply with JSON only.")
    return "\n".join(parts)


import json


def parse_action(response_text: str | None):
    """Extract a JSON action from an LLM response.

    Markdown code fences are stripped, then the remaining text is parsed as
    JSON. A JSON object with a truthy ``sql_query`` yields a query action;
    otherwise one with a truthy ``submit_answer`` yields a submit action.

    Args:
        response_text: Raw model output; may be ``None`` or empty.

    Returns:
        An ``env.Action``. An empty or missing response becomes
        ``Action(submit_answer="")``. Anything that is not a JSON object
        carrying one of the two keys is submitted verbatim (after fence
        stripping) as the answer.
    """
    # Imported lazily, as in the original; presumably to avoid importing the
    # environment module at import time -- TODO confirm.
    from env import Action

    if not response_text:
        return Action(submit_answer="")

    text = response_text.strip()
    text = text.replace("```json", "").replace("```", "").strip()

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        data = None

    # json.loads also returns lists, strings, numbers, booleans and None.
    # A bare reply such as `42` previously raised TypeError on the `in` test,
    # so only inspect keys when the payload is really a JSON object.
    if isinstance(data, dict):
        if data.get("sql_query"):
            return Action(sql_query=data["sql_query"])
        if data.get("submit_answer"):
            return Action(submit_answer=data["submit_answer"])

    # Fallback: treat the whole (fence-stripped) reply as the final answer.
    return Action(submit_answer=text)