Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import textwrap | |
| from typing import List | |
| from openai import OpenAI | |
| from client import SQLAnalystClient as SQLAnalystEnv | |
| from env import Action as SQLAction | |
# Debug switch; no code in this view reads it.
DEBUG = True

# Strips a leading "action" / "next action" label (followed by ":" or "-")
# from the start of a reply line, case-insensitively.
ACTION_PREFIX_RE = re.compile(r"^(action|next action)\s*[:\-]\s*", re.IGNORECASE)

# Matches a call-shaped action such as name(...). The match is greedy up to
# the last ")" and, thanks to DOTALL, the arguments may span several lines.
ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", flags=re.DOTALL)

# Action string used whenever the model reply cannot be parsed.
FALLBACK_ACTION = "noop()"

# Upper bound on the number of agent steps attempted per task.
MAX_STEPS = 20
# System message sent with every chat completion request. It tells the model
# to answer with a single action string (execute_sql / submit_answer / noop).
# dedent() + strip() remove the literal's indentation and the blank first/last
# lines, so only the instruction text itself reaches the model.
SYSTEM_PROMPT = textwrap.dedent(
    """
    You are a SQL Data Analyst Agent.
    Your goal is to answer business questions by writing and executing SQL queries.
    Reply with exactly one action string.
    The action must be a valid SQL command such as:
    - execute_sql('SELECT * FROM users')
    - submit_answer('42')
    - noop()
    Use single quotes around string arguments.
    Do not include explanations or additional text.
    """
).strip()
def build_history_lines(history: List[str]) -> str:
    """Render the tail of the step history for inclusion in a prompt.

    Args:
        history: One text line per step already taken, oldest first.

    Returns:
        The literal text "None" when there is no history, otherwise the
        four most recent entries joined with newlines.
    """
    if history:
        recent_entries = history[-4:]
        return "\n".join(recent_entries)
    return "None"
def _observation_field(observation, name: str, default=None):
    """Read ``name`` from an observation that is either an object or a mapping.

    Attribute access is tried first. The mapping-style ``get`` is consulted
    only when the attribute is missing, so observations that expose
    attributes but no ``get`` method do not raise ``AttributeError``.
    """
    try:
        return getattr(observation, name)
    except AttributeError:
        pass
    lookup = getattr(observation, "get", None)
    if callable(lookup):
        return lookup(name, default)
    return default


def build_user_prompt(step: int, observation, history: List[str]) -> str:
    """Build the per-step user message sent to the model.

    Args:
        step: 1-based index of the step about to be taken.
        observation: Latest environment observation. The fields ``question``,
            ``schema_summary`` and ``last_error`` are read from it, either as
            attributes or as mapping keys.
        history: Text lines describing earlier steps; only the most recent
            ones (as selected by ``build_history_lines``) are included.

    Returns:
        The prompt text: step number, goal, schema, recent history, a
        Yes/No flag for whether the last action errored, and the reply
        instruction.
    """
    goal = _observation_field(observation, "question", "(not provided)")
    schema = _observation_field(observation, "schema_summary", "(none detected)")
    last_error = _observation_field(observation, "last_error")
    error_note = "Yes" if last_error else "No"

    # The prompt is assembled line by line rather than by dedenting a
    # multi-line f-string: interpolated schema or history text may itself span
    # several lines, which would otherwise disturb the common-indent detection.
    prompt_lines = [
        f"Step: {step}",
        f"Goal: {goal}",
        f"Database Schema: {schema}",
        "Previous steps:",
        build_history_lines(history),
        f"Last action error: {error_note}",
        "Reply with exactly one SQL action string.",
    ]
    return "\n".join(prompt_lines).strip()
def _parens_balanced(candidate: str) -> bool:
    """Return True when every "(" in ``candidate`` is closed by a later ")"."""
    depth = 0
    for char in candidate:
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth < 0:
                return False
    return depth == 0


def parse_model_action(response_text: str) -> str:
    """Extract a single call-shaped action string from a model reply.

    The reply is scanned line by line first, after removing any leading
    "action:" style label. A line's match is accepted only when its
    parentheses are balanced. This prevents an action whose arguments
    continue on the next line (for example a query containing ``COUNT(*)``
    followed by a line break) from being returned in truncated form. If no
    line yields a complete call, the whole reply is searched, which lets the
    arguments span lines.

    Args:
        response_text: Raw text returned by the model; may be empty.

    Returns:
        The action with all whitespace runs collapsed to single spaces, or
        ``FALLBACK_ACTION`` when nothing call-shaped is found.
    """
    if not response_text:
        return FALLBACK_ACTION

    for raw_line in response_text.splitlines():
        line = ACTION_PREFIX_RE.sub("", raw_line.strip())
        if not line:
            continue
        line_match = ACTION_PATTERN.search(line)
        if line_match and _parens_balanced(line_match.group(0)):
            return re.sub(r"\s+", " ", line_match.group(0).strip())

    # Best effort: accept whatever the multi-line search finds, balanced or not.
    text_match = ACTION_PATTERN.search(response_text)
    if text_match:
        return re.sub(r"\s+", " ", text_match.group(0).strip())
    return FALLBACK_ACTION
# execute_sql(...) / submit_answer(...) at the start of the action. Whitespace
# before "(" is tolerated and DOTALL lets the argument span several lines.
_ACTION_CALL_RE = re.compile(r"(execute_sql|submit_answer)\s*\((.*)\)", re.DOTALL)

# noop() with optional whitespace around the parentheses.
_NOOP_CALL_RE = re.compile(r"noop\s*\(\s*\)")


def extract_sql_or_answer(action_str: str):
    """Split an action string into its SQL query or submitted answer.

    Args:
        action_str: Action text such as ``execute_sql('SELECT ...')``,
            ``submit_answer('42')`` or ``noop()``.

    Returns:
        A ``(sql_query, submit_answer)`` tuple:

        * ``(query, None)`` for ``execute_sql(...)``;
        * ``(None, answer)`` for ``submit_answer(...)``;
        * ``(None, None)`` for ``noop()``;
        * ``(action_str, None)`` for anything else, which is treated as a
          raw SQL query.

        One pair of matching outer quotes (single or double) around the
        argument is removed.
    """
    action_str = action_str.strip()

    call = _ACTION_CALL_RE.match(action_str)
    if call:
        action_name = call.group(1)
        content = call.group(2).strip()
        # Strip the outer quotes only when both ends carry the same quote
        # character and there are two distinct characters to strip.
        if len(content) >= 2 and content[0] == content[-1] and content[0] in "'\"":
            content = content[1:-1]
        if action_name == "execute_sql":
            return content, None
        return None, content

    if _NOOP_CALL_RE.fullmatch(action_str):
        return None, None

    # Default: treat as SQL query
    return action_str, None
def main():
    """Run the model-driven agent on each task and print JSON progress lines.

    Configuration is read from environment variables: ``HF_TOKEN`` or
    ``OPENAI_API_KEY`` (required), ``API_BASE_URL`` and ``MODEL_NAME``
    (optional, with defaults). For each task the local environment is reset
    and the model is queried for up to ``MAX_STEPS`` actions. Each reply is
    parsed into an action and applied to the environment. A failed model
    request falls back to ``FALLBACK_ACTION`` instead of aborting the task.

    Returns:
        None. Returns early, after printing an error, when no API key is set.
    """
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")
    if not api_key:
        print("Error: Set HF_TOKEN or OPENAI_API_KEY environment variable")
        return

    # Use local environment instead of HTTP. Imported once, not once per task.
    from env import SQLAnalystEnv as LocalEnv

    client = OpenAI(base_url=base_url, api_key=api_key)
    tasks = ["monthly_signups", "top_revenue_category", "churn_analysis"]
    for task_id in tasks:
        task_header = {
            "task_id": task_id,
            "task_name": task_id,
            "difficulty": "curriculum",
        }
        print(f" {json.dumps(task_header)}")

        history: List[str] = []
        env = LocalEnv(task_id=task_id)
        result = env.reset()
        observation = result.observation
        total_reward = 0.0
        # Counts steps actually executed. The loop index alone over-counts by
        # one whenever the episode is already done at the top of an iteration.
        steps_taken = 0

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            user_prompt = build_user_prompt(step, observation, history)
            try:
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.0,
                )
                response_text = completion.choices[0].message.content or ""
            except Exception as exc:
                # Deliberate best effort: a failed request must not end the run.
                print(f"Model request failed ({exc}). Using fallback action.")
                response_text = FALLBACK_ACTION

            action_str = parse_model_action(response_text)
            sql_query, submit_answer = extract_sql_or_answer(action_str)
            if submit_answer:
                action = SQLAction(submit_answer=submit_answer)
            elif sql_query:
                action = SQLAction(sql_query=sql_query)
            else:
                # noop (or an empty argument) is sent as a trivial query.
                action = SQLAction(sql_query="SELECT 1")

            result = env.step(action)
            steps_taken = step
            observation = result.observation
            reward = result.reward or 0.0
            total_reward += reward

            step_record = {
                "step": step,
                "action": action_str,
                "reward": reward,
                "done": result.done,
            }
            print(f" {json.dumps(step_record)}")

            error_flag = " ERROR" if observation.last_error else ""
            history.append(
                f"Step {step}: {action_str} -> reward {reward:+.2f}{error_flag}"
            )

        summary = {
            "total_steps": steps_taken,
            "final_reward": total_reward,
            "task_score": result.info.get("task_score", 0.0),
        }
        print(f" {json.dumps(summary)}")

        print(f"\n{'=' * 60}")
        print(f"TASK: {task_id}")
        print(f"FINAL REWARD: {total_reward:.3f}")
        print(f"{'=' * 60}\n")
# Run the agent only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()