YashashMathur commited on
Commit
8fc1355
·
verified ·
1 Parent(s): 2d4a521

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +204 -0
inference.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import textwrap
5
+ from typing import List
6
+ from openai import OpenAI
7
+ from client import SQLAnalystClient
8
+ from env import Action as SQLAction
9
+
10
+ DEBUG = True
11
+ ACTION_PREFIX_RE = re.compile(
12
+ r"^(action|next action)\s*[:\-]\s*",
13
+ re.IGNORECASE,
14
+ )
15
+ ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", re.DOTALL)
16
+ FALLBACK_ACTION = "noop()"
17
+ MAX_STEPS = 20
18
+
19
+ SYSTEM_PROMPT = textwrap.dedent(
20
+ """
21
+ You are a SQL Data Analyst Agent.
22
+ Your goal is to answer business questions by writing and executing SQL queries.
23
+ Reply with exactly one action string.
24
+ The action must be a valid SQL command such as:
25
+ - execute_sql('SELECT * FROM users')
26
+ - submit_answer('42')
27
+ - noop()
28
+ Use single quotes around string arguments.
29
+ Do not include explanations or additional text.
30
+ """
31
+ ).strip()
32
+
33
+
34
+ def build_history_lines(history: List[str]) -> str:
35
+ if not history:
36
+ return "None"
37
+ return "\n".join(history[-4:])
38
+
39
+
40
+ def build_user_prompt(step: int, observation, history: List[str]) -> str:
41
+ goal = getattr(
42
+ observation, "question", observation.get("question", "(not provided)")
43
+ )
44
+ schema = getattr(
45
+ observation,
46
+ "schema_summary",
47
+ observation.get("schema_summary", "(none detected)"),
48
+ )
49
+ last_error = getattr(observation, "last_error", observation.get("last_error", None))
50
+ error_note = "Yes" if last_error else "No"
51
+
52
+ prompt = textwrap.dedent(
53
+ f"""
54
+ Step: {step}
55
+ Goal: {goal}
56
+ Database Schema: {schema}
57
+ Previous steps:
58
+ {build_history_lines(history)}
59
+ Last action error: {error_note}
60
+ Reply with exactly one SQL action string.
61
+ """
62
+ ).strip()
63
+ return prompt
64
+
65
+
66
+ def parse_model_action(response_text: str) -> str:
67
+ if not response_text:
68
+ return FALLBACK_ACTION
69
+
70
+ lines = response_text.splitlines()
71
+ for raw_line in lines:
72
+ line = raw_line.strip()
73
+ if not line:
74
+ continue
75
+ line = ACTION_PREFIX_RE.sub("", line)
76
+ match = ACTION_PATTERN.search(line)
77
+ if match:
78
+ action = match.group(0).strip()
79
+ action = re.sub(r"\s+", " ", action)
80
+ return action
81
+
82
+ match = ACTION_PATTERN.search(response_text)
83
+ if match:
84
+ action = match.group(0).strip()
85
+ action = re.sub(r"\s+", " ", action)
86
+ return action
87
+
88
+ return FALLBACK_ACTION
89
+
90
+
91
+ def extract_sql_or_answer(action_str: str):
92
+ """Extract sql_query or submit_answer from action string like execute_sql('SELECT...')"""
93
+ action_str = action_str.strip()
94
+
95
+ if action_str.startswith("execute_sql(") or action_str.startswith("submit_answer("):
96
+ match = re.search(r"\((.*)\)", action_str)
97
+ if match:
98
+ content = match.group(1).strip()
99
+ # Remove outer quotes if present
100
+ if (content.startswith("'") and content.endswith("'")) or (
101
+ content.startswith('"') and content.endswith('"')
102
+ ):
103
+ content = content[1:-1]
104
+
105
+ if action_str.startswith("execute_sql("):
106
+ return content, None
107
+ else:
108
+ return None, content
109
+
110
+ if action_str == "noop()":
111
+ return None, None
112
+
113
+ # Default: treat as SQL query
114
+ return action_str, None
115
+
116
+
117
+ def main():
118
+ api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
119
+ base_url = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
120
+ model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")
121
+ env_url = os.environ.get("OPENENV_URL")
122
+
123
+ if not api_key:
124
+ print("Error: Set HF_TOKEN or OPENAI_API_KEY environment variable")
125
+ return
126
+
127
+ client = OpenAI(base_url=base_url, api_key=api_key)
128
+
129
+ tasks = ["monthly_signups", "top_revenue_category", "churn_analysis"]
130
+
131
+ for task_id in tasks:
132
+ print(
133
+ f" {json.dumps({'task_id': task_id, 'task_name': task_id, 'difficulty': 'curriculum'})}"
134
+ )
135
+
136
+ history: List[str] = []
137
+
138
+ # Use local environment instead of HTTP
139
+ from env import SQLAnalystEnv as LocalEnv
140
+
141
+ env = LocalEnv(task_id=task_id)
142
+ result = env.reset()
143
+ observation = result.observation
144
+ total_reward = 0.0
145
+
146
+ for step in range(1, MAX_STEPS + 1):
147
+ if result.done:
148
+ break
149
+
150
+ user_prompt = build_user_prompt(step, observation, history)
151
+
152
+ try:
153
+ completion = client.chat.completions.create(
154
+ model=model_name,
155
+ messages=[
156
+ {"role": "system", "content": SYSTEM_PROMPT},
157
+ {"role": "user", "content": user_prompt},
158
+ ],
159
+ temperature=0.0,
160
+ )
161
+ response_text = completion.choices[0].message.content or ""
162
+ except Exception as exc:
163
+ print(f"Model request failed ({exc}). Using fallback action.")
164
+ response_text = FALLBACK_ACTION
165
+
166
+ action_str = parse_model_action(response_text)
167
+
168
+ sql_query, submit_answer = extract_sql_or_answer(action_str)
169
+
170
+ if submit_answer:
171
+ action = SQLAction(submit_answer=submit_answer)
172
+ elif sql_query:
173
+ action = SQLAction(sql_query=sql_query)
174
+ else:
175
+ action = SQLAction(sql_query="SELECT 1")
176
+
177
+ result = env.step(action)
178
+ observation = result.observation
179
+ reward = result.reward or 0.0
180
+ total_reward += reward
181
+
182
+ print(
183
+ f" {json.dumps({'step': step, 'action': action_str, 'reward': reward, 'done': result.done})}"
184
+ )
185
+
186
+ error_flag = " ERROR" if observation.last_error else ""
187
+ history_line = (
188
+ f"Step {step}: {action_str} -> reward {reward:+.2f}{error_flag}"
189
+ )
190
+ history.append(history_line)
191
+
192
+ print(
193
+ f" {json.dumps({'total_steps': step, 'final_reward': total_reward, 'task_score': result.info.get('task_score', 0.0)})}"
194
+ )
195
+
196
+ avg_score = total_reward
197
+ print(f"\n{'=' * 60}")
198
+ print(f"TASK: {task_id}")
199
+ print(f"FINAL REWARD: {avg_score:.3f}")
200
+ print(f"{'=' * 60}\n")
201
+
202
+
203
+ if __name__ == "__main__":
204
+ main()