""" Baseline inference script for sql-data-analyst OpenEnv. Usage: export OPENAI_API_KEY=sk-... python baseline/run_baseline.py Produces reproducible scores across all 3 tasks. """ import os import json import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from typing import List, Dict, Any try: from openai import OpenAI except ImportError: print("Error: openai package not installed. Run: pip install openai") sys.exit(1) from env import SQLAnalystEnv, Action from baseline.prompts import SYSTEM_PROMPT, build_prompt, parse_action MODEL = "gpt-4o-mini" MAX_STEPS = 20 TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"] def run_task( client: OpenAI, task_id: str, max_steps: int = MAX_STEPS ) -> Dict[str, Any]: """Run a single task with the LLM agent.""" print(f"\n{'=' * 50}") print(f"Task: {task_id}") print("=" * 50) env = SQLAnalystEnv(task_id=task_id) result = env.reset() obs = result.observation history = [] total_reward = 0.0 print(f"Question: {obs.question}") print(f"Schema: {obs.schema_summary[:200]}...") for step in range(1, max_steps + 1): if result.done: print(f"Episode done at step {step - 1}") break user_prompt = build_prompt(obs) history.append({"role": "user", "content": user_prompt}) try: response = client.chat.completions.create( model=MODEL, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, *history[-8:], ], temperature=0.0, ) except Exception as e: print(f"API Error: {e}") break reply = response.choices[0].message.content or "" history.append({"role": "assistant", "content": reply}) action = parse_action(reply) if action.sql_query: print(f"Step {step}: Executing SQL...") print(f" Query: {action.sql_query[:100]}...") else: print(f"Step {step}: Submitting answer...") print( f" Answer: {action.submit_answer[:100] if action.submit_answer else 'empty'}..." ) result = env.step(action) obs = result.observation total_reward = result.info.get("total_reward", 0.0) if result.done: break state = env.state() print(f"\nFinal total reward: {total_reward:.3f}") print(f"Steps taken: {state.step}") return { "task_id": task_id, "difficulty": state.difficulty, "total_reward": round(total_reward, 3), "steps": state.step, "max_steps": state.max_steps, } def main(): api_key = os.environ.get("OPENAI_API_KEY") if not api_key: print("Error: OPENAI_API_KEY environment variable not set") print("Usage: export OPENAI_API_KEY=sk-...") sys.exit(1) client = OpenAI(api_key=api_key) print("=" * 60) print("SQL Data Analyst - Baseline Inference") print("=" * 60) print(f"Model: {MODEL}") print(f"Max steps per task: {MAX_STEPS}") print(f"Tasks: {TASK_IDS}") results = [] for task_id in TASK_IDS: try: r = run_task(client, task_id) results.append(r) except Exception as e: print(f"Error running task {task_id}: {e}") results.append( { "task_id": task_id, "error": str(e), "total_reward": 0.0, "steps": 0, } ) print("\n" + "=" * 60) print("BASELINE RESULTS") print("=" * 60) for r in results: task = r.get("task_id", "unknown") reward = r.get("total_reward", 0.0) steps = r.get("steps", 0) print(f"{task:30s} score={reward:.3f} steps={steps}") valid_results = [r for r in results if "total_reward" in r] if valid_results: avg = sum(r["total_reward"] for r in valid_results) / len(valid_results) print(f"\nAverage score: {avg:.3f}") output_file = "baseline_scores.json" with open(output_file, "w") as f: json.dump(results, f, indent=2) print(f"\nSaved results to {output_file}") if __name__ == "__main__": main()