File size: 4,392 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
Baseline inference script for sql-data-analyst OpenEnv.

Usage:
    export OPENAI_API_KEY=sk-...
    python baseline/run_baseline.py

Produces reproducible scores across all 3 tasks.
"""

import os
import json
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from typing import List, Dict, Any

try:
    from openai import OpenAI
except ImportError:
    print("Error: openai package not installed. Run: pip install openai")
    sys.exit(1)

from env import SQLAnalystEnv, Action
from baseline.prompts import SYSTEM_PROMPT, build_prompt, parse_action


# OpenAI chat model used for every task (deterministic with temperature=0.0 below).
MODEL = "gpt-4o-mini"
# Hard cap on agent steps per episode; run_task accepts an override.
MAX_STEPS = 20
# Tasks evaluated by main(), in this fixed order for reproducible output.
TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"]


def run_task(
    client: OpenAI, task_id: str, max_steps: int = MAX_STEPS
) -> Dict[str, Any]:
    """Run a single task episode with the LLM agent.

    Args:
        client: Configured OpenAI client used for chat completions.
        task_id: Identifier of the task loaded into SQLAnalystEnv.
        max_steps: Upper bound on agent steps before the episode is abandoned.

    Returns:
        Summary dict with ``task_id``, ``difficulty``, ``total_reward``
        (rounded to 3 decimals), ``steps`` taken, and the env's ``max_steps``.
    """
    print(f"\n{'=' * 50}")
    print(f"Task: {task_id}")
    print("=" * 50)

    env = SQLAnalystEnv(task_id=task_id)
    result = env.reset()
    obs = result.observation
    history: List[Dict[str, str]] = []
    total_reward = 0.0

    print(f"Question: {obs.question}")
    print(f"Schema: {obs.schema_summary[:200]}...")

    for step in range(1, max_steps + 1):
        if result.done:
            print(f"Episode done at step {step - 1}")
            break

        user_prompt = build_prompt(obs)
        history.append({"role": "user", "content": user_prompt})

        try:
            response = client.chat.completions.create(
                model=MODEL,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    # Only the last 8 turns are sent, to bound prompt size.
                    *history[-8:],
                ],
                temperature=0.0,  # deterministic decoding for reproducible scores
            )
        except Exception as e:
            # Best-effort: abort this episode but still report partial results.
            print(f"API Error: {e}")
            break

        reply = response.choices[0].message.content or ""
        history.append({"role": "assistant", "content": reply})

        action = parse_action(reply)

        if action.sql_query:
            print(f"Step {step}: Executing SQL...")
            print(f"  Query: {action.sql_query[:100]}...")
        else:
            print(f"Step {step}: Submitting answer...")
            print(
                f"  Answer: {action.submit_answer[:100] if action.submit_answer else 'empty'}..."
            )

        result = env.step(action)
        obs = result.observation
        # FIX: fall back to the previously accumulated reward, not 0.0, so a
        # step whose info dict lacks "total_reward" doesn't erase progress.
        total_reward = result.info.get("total_reward", total_reward)

        if result.done:
            break

    state = env.state()
    print(f"\nFinal total reward: {total_reward:.3f}")
    print(f"Steps taken: {state.step}")

    return {
        "task_id": task_id,
        "difficulty": state.difficulty,
        "total_reward": round(total_reward, 3),
        "steps": state.step,
        "max_steps": state.max_steps,
    }


def main():
    """Entry point: validate credentials, run every task, print and save scores."""
    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        print("Error: OPENAI_API_KEY environment variable not set")
        print("Usage: export OPENAI_API_KEY=sk-...")
        sys.exit(1)

    client = OpenAI(api_key=key)

    banner = "=" * 60
    print(banner)
    print("SQL Data Analyst - Baseline Inference")
    print(banner)
    print(f"Model: {MODEL}")
    print(f"Max steps per task: {MAX_STEPS}")
    print(f"Tasks: {TASK_IDS}")

    results = []
    for tid in TASK_IDS:
        try:
            results.append(run_task(client, tid))
        except Exception as e:
            # A crashed task still contributes a zero-score row to the report.
            print(f"Error running task {tid}: {e}")
            results.append(
                {
                    "task_id": tid,
                    "error": str(e),
                    "total_reward": 0.0,
                    "steps": 0,
                }
            )

    print("\n" + banner)
    print("BASELINE RESULTS")
    print(banner)

    for entry in results:
        print(
            f"{entry.get('task_id', 'unknown'):30s}"
            f"  score={entry.get('total_reward', 0.0):.3f}"
            f"  steps={entry.get('steps', 0)}"
        )

    scored = [entry for entry in results if "total_reward" in entry]
    if scored:
        mean = sum(entry["total_reward"] for entry in scored) / len(scored)
        print(f"\nAverage score: {mean:.3f}")

    out_path = "baseline_scores.json"
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved results to {out_path}")


if __name__ == "__main__":
    main()