File size: 6,314 Bytes
8fc1355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import os
import re
import json
import textwrap
from typing import List
from openai import OpenAI
from client import SQLAnalystClient
from env import Action as SQLAction

# Debug switch; not read anywhere in the code visible here.
DEBUG = True

# Strips a leading label such as "Action:" or "Next action -" from one reply line.
ACTION_PREFIX_RE = re.compile(r"^(action|next action)\s*[:\-]\s*", flags=re.IGNORECASE)

# A call-shaped action like execute_sql('...'). DOTALL lets ".*" cross newlines
# when the pattern is searched against a whole multi-line reply.
ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", flags=re.DOTALL)

FALLBACK_ACTION = "noop()"  # used when no usable action can be recovered
MAX_STEPS = 20  # upper bound on env.step calls per task

# System message sent with every model request.
SYSTEM_PROMPT = "\n".join(
    (
        "You are a SQL Data Analyst Agent.",
        "Your goal is to answer business questions by writing and executing SQL queries.",
        "Reply with exactly one action string.",
        "The action must be a valid SQL command such as:",
        "- execute_sql('SELECT * FROM users')",
        "- submit_answer('42')",
        "- noop()",
        "Use single quotes around string arguments.",
        "Do not include explanations or additional text.",
    )
)


def build_history_lines(history: List[str]) -> str:
    """Render recent history entries for inclusion in the user prompt.

    Returns the literal string "None" when *history* is empty (or None);
    otherwise the last four entries, oldest first, joined by newlines.
    """
    recent = history[-4:] if history else []
    return "\n".join(recent) if recent else "None"


def _observation_field(observation, name: str, default=None):
    """Read *name* from an observation that may be an object or a mapping.

    An existing attribute wins. Otherwise the observation's ``.get`` method is
    used if it has one. Otherwise *default* is returned. The lookup is lazy, so
    attribute-only objects without ``.get`` no longer raise AttributeError.
    """
    if hasattr(observation, name):
        return getattr(observation, name)
    getter = getattr(observation, "get", None)
    if callable(getter):
        return getter(name, default)
    return default


def build_user_prompt(step: int, observation, history: List[str]) -> str:
    """Build the per-step user message sent to the model.

    Args:
        step: 1-based step number shown to the model.
        observation: Current environment observation. The fields
            ``question``, ``schema_summary`` and ``last_error`` are read
            either as attributes or via ``.get``.
        history: Previously recorded step summaries; the most recent ones are
            rendered by ``build_history_lines``.

    Returns:
        The prompt text, one field per line.
    """
    goal = _observation_field(observation, "question", "(not provided)")
    schema = _observation_field(observation, "schema_summary", "(none detected)")
    last_error = _observation_field(observation, "last_error", None)
    error_note = "Yes" if last_error else "No"

    # Assemble line by line instead of dedenting an f-string. textwrap.dedent
    # runs after interpolation, so a multi-line schema or history (inserted at
    # column 0) would leave the template's indentation on every other line.
    return "\n".join(
        [
            f"Step: {step}",
            f"Goal: {goal}",
            f"Database Schema: {schema}",
            "Previous steps:",
            build_history_lines(history),
            f"Last action error: {error_note}",
            "Reply with exactly one SQL action string.",
        ]
    )


# Start of a call-shaped action: an identifier followed by "(".
_CALL_START_RE = re.compile(r"[A-Za-z_]+\s*\(")


def _balanced_call(text: str):
    """Return the prefix of *text* ending at the ")" that closes its first "(".

    Parentheses inside single- or double-quoted strings are ignored. Inside a
    string, a backslash escapes the next character. Returns None if the
    opening "(" is never closed.
    """
    depth = 0
    quote = None
    escaped = False
    for pos, ch in enumerate(text):
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
        elif ch in ("'", '"'):
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return text[: pos + 1]
    return None


def parse_model_action(response_text: str) -> str:
    """Extract a single call-shaped action string from a raw model reply.

    Lines are scanned in order. On the first non-blank line containing
    ``name(`` (after removing an "Action:"-style prefix), the call is taken up
    to its matching ")". The scan may continue onto following lines, so a
    multi-line SQL argument is captured whole rather than cut at the first
    ")" on its opening line. If no balanced call is found, the whole reply is
    searched with ACTION_PATTERN as a last resort. Runs of whitespace in the
    result are collapsed to single spaces.

    Returns FALLBACK_ACTION when the reply is empty or contains no call.
    """
    if not response_text:
        return FALLBACK_ACTION

    lines = response_text.splitlines()
    for index, raw_line in enumerate(lines):
        line = raw_line.strip()
        if not line:
            continue
        line = ACTION_PREFIX_RE.sub("", line)
        call_start = _CALL_START_RE.search(line)
        if call_start is None:
            continue
        # Scan from the call name through the rest of the reply so an
        # argument list that continues on later lines is not truncated.
        remainder = "\n".join([line[call_start.start():], *lines[index + 1:]])
        call = _balanced_call(remainder)
        if call is not None:
            return re.sub(r"\s+", " ", call.strip())

    # Last resort: greedy whole-reply match (e.g. unbalanced quoting above).
    match = ACTION_PATTERN.search(response_text)
    if match:
        return re.sub(r"\s+", " ", match.group(0).strip())

    return FALLBACK_ACTION


def extract_sql_or_answer(action_str: str):
    """Split an action string into ``(sql_query, submit_answer)``.

    Mapping:
        execute_sql(<arg>)   -> (<arg>, None)
        submit_answer(<arg>) -> (None, <arg>)
        noop()               -> (None, None)
        anything else        -> (action_str, None), treated as raw SQL

    Whitespace between the action name and "(" is tolerated, and the argument
    may span several lines. One pair of matching outer quotes around <arg> is
    removed. Backslash-escaped copies of that same quote character inside
    <arg> are unescaped, so SQL string literals written with escaped quotes
    reach the database as plain quotes.
    """
    action_str = action_str.strip()

    # Greedy ".*" runs to the last ")", matching the original extraction.
    call = re.match(
        r"(execute_sql|submit_answer)\s*\((.*)\)", action_str, re.DOTALL
    )
    if call:
        name, content = call.group(1), call.group(2).strip()
        quote = content[:1]
        if quote in ("'", '"') and content.endswith(quote):
            # Drop the outer quotes, then undo backslash-escaping of that quote.
            content = content[1:-1].replace("\\" + quote, quote)
        if name == "execute_sql":
            return content, None
        return None, content

    if re.fullmatch(r"noop\s*\(\s*\)", action_str):
        return None, None

    # Default: treat as SQL query
    return action_str, None


def _request_action_text(client, model_name: str, user_prompt: str) -> str:
    """Ask the chat model for the next action and return its raw reply text.

    Any exception from the API call is printed and replaced by FALLBACK_ACTION,
    so a transient model failure does not abort the running task.
    """
    try:
        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.0,
        )
        return completion.choices[0].message.content or ""
    except Exception as exc:  # deliberate best-effort boundary around the remote call
        print(f"Model request failed ({exc}). Using fallback action.")
        return FALLBACK_ACTION


def _to_env_action(action_str: str):
    """Convert a parsed action string into an environment ``SQLAction``."""
    sql_query, submit_answer = extract_sql_or_answer(action_str)
    if submit_answer:
        return SQLAction(submit_answer=submit_answer)
    if sql_query:
        return SQLAction(sql_query=sql_query)
    # noop() (or an empty argument) must still send something to env.step;
    # a trivial query is used for that.
    return SQLAction(sql_query="SELECT 1")


def _run_task(client, model_name: str, task_id: str, env_factory) -> None:
    """Run one episode of *task_id* and print per-step and summary lines.

    *env_factory* is called as ``env_factory(task_id=task_id)`` and must
    return an environment exposing ``reset()`` and ``step(action)``.
    """
    print(
        f" {json.dumps({'task_id': task_id, 'task_name': task_id, 'difficulty': 'curriculum'})}"
    )

    history: List[str] = []
    env = env_factory(task_id=task_id)
    result = env.reset()
    observation = result.observation
    total_reward = 0.0
    steps_taken = 0  # number of env.step calls actually made

    for step in range(1, MAX_STEPS + 1):
        if result.done:
            break

        user_prompt = build_user_prompt(step, observation, history)
        response_text = _request_action_text(client, model_name, user_prompt)
        action_str = parse_model_action(response_text)

        result = env.step(_to_env_action(action_str))
        steps_taken = step
        observation = result.observation
        reward = result.reward or 0.0
        total_reward += reward

        print(
            f" {json.dumps({'step': step, 'action': action_str, 'reward': reward, 'done': result.done})}"
        )

        error_flag = " ERROR" if observation.last_error else ""
        history.append(f"Step {step}: {action_str} -> reward {reward:+.2f}{error_flag}")

    # Tolerate a missing/None info payload instead of crashing on .get.
    info = getattr(result, "info", None) or {}
    print(
        f" {json.dumps({'total_steps': steps_taken, 'final_reward': total_reward, 'task_score': info.get('task_score', 0.0)})}"
    )

    print(f"\n{'=' * 60}")
    print(f"TASK: {task_id}")
    print(f"FINAL REWARD: {total_reward:.3f}")
    print(f"{'=' * 60}\n")


def main():
    """Run every curriculum task against the configured chat model.

    Reads HF_TOKEN (or OPENAI_API_KEY), API_BASE_URL and MODEL_NAME from the
    environment. Prints an error and returns if no API key is set. OPENENV_URL
    is not consulted; tasks run against the in-process environment.
    """
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")

    if not api_key:
        print("Error: Set HF_TOKEN or OPENAI_API_KEY environment variable")
        return

    client = OpenAI(base_url=base_url, api_key=api_key)

    # Use local environment instead of HTTP
    from env import SQLAnalystEnv as LocalEnv

    tasks = ["monthly_signups", "top_revenue_category", "churn_analysis"]
    for task_id in tasks:
        _run_task(client, model_name, task_id, LocalEnv)


if __name__ == "__main__":
    # Run the task loop only when executed directly, not when imported.
    main()