Spaces:
Sleeping
Sleeping
| """ | |
| Baseline inference script for sql-data-analyst OpenEnv. | |
| Usage: | |
| export OPENAI_API_KEY=sk-... | |
| python baseline/run_baseline.py | |
| Produces reproducible scores across all 3 tasks. | |
| """ | |
| import os | |
| import json | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from typing import List, Dict, Any | |
| try: | |
| from openai import OpenAI | |
| except ImportError: | |
| print("Error: openai package not installed. Run: pip install openai") | |
| sys.exit(1) | |
| from env import SQLAnalystEnv, Action | |
| from baseline.prompts import SYSTEM_PROMPT, build_prompt, parse_action | |
| MODEL = "gpt-4o-mini" | |
| MAX_STEPS = 20 | |
| TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"] | |
def run_task(
    client: OpenAI, task_id: str, max_steps: int = MAX_STEPS
) -> Dict[str, Any]:
    """Drive one environment episode with the LLM agent and report its score.

    Args:
        client: Configured OpenAI client used for chat completions.
        task_id: Identifier of the task to load into ``SQLAnalystEnv``.
        max_steps: Hard cap on agent steps before the loop gives up.

    Returns:
        A summary dict with the task id, difficulty, rounded total reward,
        and step counts pulled from the final environment state.
    """
    banner = "=" * 50
    print(f"\n{banner}")
    print(f"Task: {task_id}")
    print(banner)

    env = SQLAnalystEnv(task_id=task_id)
    result = env.reset()
    obs = result.observation
    conversation: List[Dict[str, str]] = []
    episode_reward = 0.0

    print(f"Question: {obs.question}")
    # Truncate the schema dump so the console stays readable.
    print(f"Schema: {obs.schema_summary[:200]}...")

    for step in range(1, max_steps + 1):
        if result.done:
            # The previous step already terminated the episode.
            print(f"Episode done at step {step - 1}")
            break

        conversation.append({"role": "user", "content": build_prompt(obs)})
        try:
            # Only the last 8 turns are replayed to bound prompt size.
            completion = client.chat.completions.create(
                model=MODEL,
                messages=[{"role": "system", "content": SYSTEM_PROMPT}]
                + conversation[-8:],
                temperature=0.0,
            )
        except Exception as exc:
            # Network/auth failures end the episode with whatever reward
            # has been accumulated so far.
            print(f"API Error: {exc}")
            break

        reply = completion.choices[0].message.content or ""
        conversation.append({"role": "assistant", "content": reply})
        action = parse_action(reply)

        if action.sql_query:
            print(f"Step {step}: Executing SQL...")
            print(f" Query: {action.sql_query[:100]}...")
        else:
            answer_preview = (
                action.submit_answer[:100] if action.submit_answer else "empty"
            )
            print(f"Step {step}: Submitting answer...")
            print(f" Answer: {answer_preview}...")

        result = env.step(action)
        obs = result.observation
        episode_reward = result.info.get("total_reward", 0.0)
        if result.done:
            break

    state = env.state()
    print(f"\nFinal total reward: {episode_reward:.3f}")
    print(f"Steps taken: {state.step}")
    return {
        "task_id": task_id,
        "difficulty": state.difficulty,
        "total_reward": round(episode_reward, 3),
        "steps": state.step,
        "max_steps": state.max_steps,
    }
def main():
    """Entry point: run every baseline task and persist scores to JSON.

    Requires OPENAI_API_KEY in the environment; exits with status 1 when
    it is missing. Tasks that raise are recorded with a score of 0.0 so
    the average still reflects them.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY environment variable not set")
        print("Usage: export OPENAI_API_KEY=sk-...")
        sys.exit(1)

    client = OpenAI(api_key=api_key)
    header = "=" * 60
    print(header)
    print("SQL Data Analyst - Baseline Inference")
    print(header)
    print(f"Model: {MODEL}")
    print(f"Max steps per task: {MAX_STEPS}")
    print(f"Tasks: {TASK_IDS}")

    results: List[Dict[str, Any]] = []
    for task_id in TASK_IDS:
        try:
            results.append(run_task(client, task_id))
        except Exception as exc:
            # A failed task still contributes a zero score entry so the
            # summary and average cover all configured tasks.
            print(f"Error running task {task_id}: {exc}")
            results.append(
                {
                    "task_id": task_id,
                    "error": str(exc),
                    "total_reward": 0.0,
                    "steps": 0,
                }
            )

    print("\n" + header)
    print("BASELINE RESULTS")
    print(header)
    for entry in results:
        print(
            f"{entry.get('task_id', 'unknown'):30s} "
            f"score={entry.get('total_reward', 0.0):.3f} "
            f"steps={entry.get('steps', 0)}"
        )

    scored = [entry for entry in results if "total_reward" in entry]
    if scored:
        avg = sum(entry["total_reward"] for entry in scored) / len(scored)
        print(f"\nAverage score: {avg:.3f}")

    output_file = "baseline_scores.json"
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved results to {output_file}")


if __name__ == "__main__":
    main()