"""
Baseline inference script for sql-data-analyst OpenEnv.
Usage:
export OPENAI_API_KEY=sk-...
python baseline/run_baseline.py
Produces reproducible scores across all 3 tasks.
"""
import os
import json
import sys
from pathlib import Path
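
# Make the repository root importable so `env` and `baseline` resolve
# when this script is run directly.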
sys.path.insert(0, str(Path(__file__).parent.parent))

from typing import Any, Dict

try:
    from openai import OpenAI
except ImportError:
    print("Error: openai package not installed. Run: pip install openai")
    sys.exit(1)

from env import SQLAnalystEnv, Action
from baseline.prompts import SYSTEM_PROMPT, build_prompt, parse_action
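
# Baseline configuration: model, per-task step budget, and the task IDs to evaluate.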
MODEL = "gpt-4o-mini"
MAX_STEPS = 20
TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"]


def run_task(
    client: OpenAI, task_id: str, max_steps: int = MAX_STEPS
) -> Dict[str, Any]:
    """Run a single task with the LLM agent."""
    print(f"\n{'=' * 50}")
    print(f"Task: {task_id}")
    print("=" * 50)

    env = SQLAnalystEnv(task_id=task_id)
    result = env.reset()
    obs = result.observation

    history = []
    total_reward = 0.0

    print(f"Question: {obs.question}")
    print(f"Schema: {obs.schema_summary[:200]}...")
    for step in range(1, max_steps + 1):
        if result.done:
            print(f"Episode done at step {step - 1}")
            break

        user_prompt = build_prompt(obs)
        history.append({"role": "user", "content": user_prompt})

        try:
            response = client.chat.completions.create(
                model=MODEL,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    # Only the most recent turns are sent to keep the prompt bounded.
                    *history[-8:],
                ],
                temperature=0.0,  # greedy decoding for reproducible scores
            )
        except Exception as e:
            print(f"API Error: {e}")
            break

        reply = response.choices[0].message.content or ""
        history.append({"role": "assistant", "content": reply})

        action = parse_action(reply)
        if action.sql_query:
            print(f"Step {step}: Executing SQL...")
            print(f" Query: {action.sql_query[:100]}...")
        else:
            print(f"Step {step}: Submitting answer...")
            print(
                f" Answer: {action.submit_answer[:100] if action.submit_answer else 'empty'}..."
            )

        result = env.step(action)
        obs = result.observation
        total_reward = result.info.get("total_reward", 0.0)

        if result.done:
            break
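
    # Episode finished (done flag set or step budget exhausted); summarize results.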
    state = env.state()
    print(f"\nFinal total reward: {total_reward:.3f}")
    print(f"Steps taken: {state.step}")

    return {
        "task_id": task_id,
        "difficulty": state.difficulty,
        "total_reward": round(total_reward, 3),
        "steps": state.step,
        "max_steps": state.max_steps,
    }


def main():
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY environment variable not set")
        print("Usage: export OPENAI_API_KEY=sk-...")
        sys.exit(1)

    client = OpenAI(api_key=api_key)

    print("=" * 60)
    print("SQL Data Analyst - Baseline Inference")
    print("=" * 60)
    print(f"Model: {MODEL}")
    print(f"Max steps per task: {MAX_STEPS}")
    print(f"Tasks: {TASK_IDS}")
    results = []
    for task_id in TASK_IDS:
        try:
            r = run_task(client, task_id)
            results.append(r)
        except Exception as e:
            print(f"Error running task {task_id}: {e}")
            results.append(
                {
                    "task_id": task_id,
                    "error": str(e),
                    "total_reward": 0.0,
                    "steps": 0,
                }
            )
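
    # Print a per-task summary and the average score, then persist results to JSON.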
print("\n" + "=" * 60)
print("BASELINE RESULTS")
print("=" * 60)
for r in results:
task = r.get("task_id", "unknown")
reward = r.get("total_reward", 0.0)
steps = r.get("steps", 0)
print(f"{task:30s} score={reward:.3f} steps={steps}")
valid_results = [r for r in results if "total_reward" in r]
if valid_results:
avg = sum(r["total_reward"] for r in valid_results) / len(valid_results)
print(f"\nAverage score: {avg:.3f}")
output_file = "baseline_scores.json"
with open(output_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nSaved results to {output_file}")


if __name__ == "__main__":
    main()