Spaces:
Sleeping
Sleeping
File size: 6,353 Bytes
8fc1355 f762b8d 8fc1355 f762b8d 8fc1355 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 | import os
import re
import json
import textwrap
from typing import List
from openai import OpenAI
from client import SQLAnalystClient
from env import Action as SQLAction
# Debug switch; nothing in the visible part of this file reads it.
DEBUG = True

# Removes a leading "action:" / "next action -" label (any letter case)
# from a single line of model output.
ACTION_PREFIX_RE = re.compile(r"^(action|next action)\s*[:\-]\s*", re.IGNORECASE)

# Finds a call-shaped token such as name(...). DOTALL lets the argument list
# span several lines; the greedy ".*" runs to the last ")" in the searched text.
ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", re.DOTALL)

# Action string used when the model reply is empty or cannot be parsed.
FALLBACK_ACTION = "noop()"

# Maximum number of agent steps attempted per task.
MAX_STEPS = 20
# System message sent with every model request. It pins the reply format to a
# single action string so that the reply can be parsed mechanically.
SYSTEM_PROMPT = "\n".join(
    (
        "You are a SQL Data Analyst Agent.",
        "Your goal is to answer business questions by writing and executing SQL queries.",
        "Reply with exactly one action string.",
        "The action must be a valid SQL command such as:",
        "- execute_sql('SELECT * FROM users')",
        "- submit_answer('42')",
        "- noop()",
        "Use single quotes around string arguments.",
        "Do not include explanations or additional text.",
    )
)
def build_history_lines(history: List[str]) -> str:
    """Format the tail of the step history for inclusion in the prompt.

    Args:
        history: One-line summaries of earlier steps, oldest first.

    Returns:
        The literal string ``"None"`` when ``history`` is empty; otherwise
        the last four entries joined with newlines.
    """
    return "\n".join(history[-4:]) if history else "None"
# Sentinel distinguishing "attribute absent" from "attribute present but None".
_MISSING = object()


def _observation_field(observation, name: str, default=None):
    """Read ``name`` from an observation given as an object or as a mapping.

    Attribute access is tried first. A ``.get(name, default)`` lookup is used
    only when the attribute is absent and the observation actually has a
    callable ``get``; otherwise ``default`` is returned.
    """
    value = getattr(observation, name, _MISSING)
    if value is not _MISSING:
        return value
    getter = getattr(observation, "get", None)
    if callable(getter):
        return getter(name, default)
    return default


def build_user_prompt(step: int, observation, history: List[str]) -> str:
    """Build the per-step user message sent to the model.

    Args:
        step: 1-based index of the step about to be taken.
        observation: Current environment observation. It may expose
            ``question``, ``schema_summary`` and ``last_error`` either as
            attributes or as mapping keys; missing fields fall back to
            placeholder text.
        history: One-line summaries of earlier steps; only the most recent
            entries are included (see ``build_history_lines``).

    Returns:
        The prompt text, one field per line.
    """
    # The previous form, getattr(obs, name, obs.get(...)), evaluated the
    # fallback eagerly, so an observation object without a .get method raised
    # AttributeError even when the attribute was present.
    goal = _observation_field(observation, "question", "(not provided)")
    schema = _observation_field(observation, "schema_summary", "(none detected)")
    last_error = _observation_field(observation, "last_error")
    error_note = "Yes" if last_error else "No"

    # Lines are joined explicitly rather than dedenting an f-string:
    # dedent runs after interpolation, so a multi-line schema or history
    # would change the common indentation and leave the template indented.
    prompt_lines = [
        f"Step: {step}",
        f"Goal: {goal}",
        f"Database Schema: {schema}",
        "Previous steps:",
        build_history_lines(history),
        f"Last action error: {error_note}",
        "Reply with exactly one SQL action string.",
    ]
    return "\n".join(prompt_lines)
def _normalize_action(raw_action: str) -> str:
    """Trim an action string and collapse each whitespace run to one space."""
    return re.sub(r"\s+", " ", raw_action.strip())


def parse_model_action(response_text: str) -> str:
    """Extract a single action string from a raw model reply.

    The reply is scanned line by line. The first line containing a
    call-shaped token (``ACTION_PATTERN``) after removing an optional
    "action:" label (``ACTION_PREFIX_RE``) supplies the action. If that
    token has more "(" than ")", the call is assumed to continue on later
    lines and is re-matched against the remaining reply text. This keeps a
    multi-line call such as ``execute_sql('SELECT COUNT(*)`` / ``FROM t')``
    from being cut off at the first line. If no single line matches, the
    whole reply is searched once.

    Args:
        response_text: Raw text returned by the model; may be empty.

    Returns:
        The whitespace-normalised action string, or ``FALLBACK_ACTION`` when
        the reply is empty or contains no call-shaped token.
    """
    if not response_text:
        return FALLBACK_ACTION

    offset = 0  # index in response_text where the current line starts
    for raw_line in response_text.splitlines(keepends=True):
        line_start = offset
        offset += len(raw_line)

        line = raw_line.strip()
        if not line:
            continue
        line = ACTION_PREFIX_RE.sub("", line)
        match = ACTION_PATTERN.search(line)
        if match is None:
            continue

        candidate = match.group(0)
        if candidate.count("(") > candidate.count(")"):
            # Unclosed parenthesis: the call spans several lines, so match
            # again from the start of this line across the rest of the reply.
            spanning = ACTION_PATTERN.search(response_text, line_start)
            if spanning is not None:
                candidate = spanning.group(0)
        return _normalize_action(candidate)

    # No individual line held a complete call; try the reply as a whole.
    match = ACTION_PATTERN.search(response_text)
    if match:
        return _normalize_action(match.group(0))
    return FALLBACK_ACTION
def extract_sql_or_answer(action_str: str):
    """Split an action string into its SQL query or submitted answer.

    Recognised forms (whitespace between the name and "(" is allowed, and
    the argument may span several lines):

    * ``execute_sql('<sql>')``  -> ``(<sql>, None)``
    * ``submit_answer('<ans>')`` -> ``(None, <ans>)``
    * ``noop()``                 -> ``(None, None)``

    One pair of matching outer quotes (single or double) is removed from the
    argument. Anything else is returned unchanged as a SQL query.

    Args:
        action_str: Action string, e.g. as produced by the reply parser.

    Returns:
        A ``(sql_query, submit_answer)`` tuple in which at most one element
        is not ``None``.
    """
    action_str = action_str.strip()

    # DOTALL so a multi-line argument is captured; greedy ".*" so nested
    # parentheses such as COUNT(*) stay inside the argument.
    call = re.match(
        r"(execute_sql|submit_answer)\s*\((.*)\)", action_str, re.DOTALL
    )
    if call:
        name = call.group(1)
        content = call.group(2).strip()
        # Remove outer quotes if present
        quote = content[:1]
        if quote in ("'", '"') and content.endswith(quote):
            content = content[1:-1]
        if name == "execute_sql":
            return content, None
        return None, content

    if re.fullmatch(r"noop\s*\(\s*\)", action_str):
        return None, None

    # Default: treat as SQL query
    return action_str, None
def main():
    """Run the model-driven agent over each task and print progress records.

    Configuration is read from environment variables:

    * ``API_KEY`` / ``HF_TOKEN`` / ``OPENAI_API_KEY`` - first one set is used
      as the API key; if none is set an error is printed and nothing runs.
    * ``API_BASE_URL`` - chat-completions endpoint
      (default ``https://api.openai.com/v1``).
    * ``MODEL_NAME`` - model identifier (default ``gpt-4o-mini``).

    For every task a local environment is created and stepped for at most
    ``MAX_STEPS`` steps. One JSON record is printed per task start, per step
    and per task end, followed by a short summary banner.
    """
    api_key = (
        os.environ.get("API_KEY")
        or os.environ.get("HF_TOKEN")
        or os.environ.get("OPENAI_API_KEY")
    )
    base_url = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")
    if not api_key:
        print("Error: Set API_KEY, HF_TOKEN, or OPENAI_API_KEY environment variable")
        return

    client = OpenAI(base_url=base_url, api_key=api_key)

    # Use local environment instead of HTTP
    # (imported once here rather than on every task iteration).
    from env import SQLAnalystEnv as LocalEnv

    tasks = ["monthly_signups", "top_revenue_category", "churn_analysis"]
    for task_id in tasks:
        start_record = {
            "task_id": task_id,
            "task_name": task_id,
            "difficulty": "curriculum",
        }
        print(f" {json.dumps(start_record)}")

        history: List[str] = []
        env = LocalEnv(task_id=task_id)
        result = env.reset()
        observation = result.observation
        total_reward = 0.0
        # Counted separately from the loop variable: when the episode ends
        # early, the loop variable has already advanced by one at the time
        # the `done` check breaks out, which over-reported total_steps.
        steps_taken = 0

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            user_prompt = build_user_prompt(step, observation, history)
            try:
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.0,
                )
                response_text = completion.choices[0].message.content or ""
            except Exception as exc:
                # Best effort: a failed request must not abort the whole run.
                print(f"Model request failed ({exc}). Using fallback action.")
                response_text = FALLBACK_ACTION

            action_str = parse_model_action(response_text)
            sql_query, submit_answer = extract_sql_or_answer(action_str)
            if submit_answer:
                action = SQLAction(submit_answer=submit_answer)
            elif sql_query:
                action = SQLAction(sql_query=sql_query)
            else:
                # noop (or an empty argument) is sent as a trivial query.
                action = SQLAction(sql_query="SELECT 1")

            result = env.step(action)
            steps_taken = step
            observation = result.observation
            reward = result.reward or 0.0
            total_reward += reward

            step_record = {
                "step": step,
                "action": action_str,
                "reward": reward,
                "done": result.done,
            }
            print(f" {json.dumps(step_record)}")

            error_flag = " ERROR" if observation.last_error else ""
            history.append(
                f"Step {step}: {action_str} -> reward {reward:+.2f}{error_flag}"
            )

        end_record = {
            "total_steps": steps_taken,
            "final_reward": total_reward,
            "task_score": result.info.get("task_score", 0.0),
        }
        print(f" {json.dumps(end_record)}")

        print(f"\n{'=' * 60}")
        print(f"TASK: {task_id}")
        print(f"FINAL REWARD: {total_reward:.3f}")
        print(f"{'=' * 60}\n")


if __name__ == "__main__":
    main()
|