# File size: 8,748 Bytes
# f762b8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
#!/usr/bin/env python3
# inference.py
# Baseline Inference Script for OpenEnv SQL Analyst
# Uses OpenAI API client to run model against the environment

import os
import sys
import json
from typing import Optional

# Add the project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from openai import OpenAI
from environment.env import SQLAnalystEnv
from environment.models import Action


# ============================================
# CONFIGURATION
# ============================================
# Endpoint, credentials and model are supplied through the process environment.
API_BASE_URL = os.getenv("API_BASE_URL")
API_KEY = os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")

# Fail fast at import time: an unset or empty endpoint/key is unusable.
if not API_BASE_URL:
    raise ValueError("API_BASE_URL environment variable is required")
if not API_KEY:
    raise ValueError("API_KEY environment variable is required")

# Benchmark label written to the [START] log line, and the per-episode step cap.
BENCHMARK_NAME = "sql_analyst"
MAX_STEPS = 15


# ============================================
# SYSTEM PROMPT
# ============================================
# NOTE: this template is rendered with str.format(), so the literal braces of
# the JSON examples are doubled ({{ ... }}) to keep them from being parsed as
# replacement fields. Only {schema_info}, {current_question},
# {last_query_result} and {error_section} are real placeholders.
SYSTEM_PROMPT = """You are an expert SQL Data Analyst AI agent. Your task is to answer business questions by querying a SQLite database.



You have two possible actions each turn:

1. Execute a SQL query to explore the data: {{"sql_query": "SELECT ..."}}

2. Submit your final answer: {{"submit_answer": "your answer"}}



IMPORTANT RULES:

- Only use SELECT queries. INSERT, UPDATE, DELETE, DROP, ALTER, TRUNCATE are blocked.

- Explore the data step by step before submitting your final answer.

- Your final answer should be just the value requested (a number, name, etc.), not a SQL query.

- Respond with ONLY a valid JSON object, no other text.



DATABASE SCHEMA:

{schema_info}



CURRENT QUESTION:

{current_question}



LAST QUERY RESULT:

{last_query_result}



{error_section}



Respond with a JSON object containing either "sql_query" or "submit_answer"."""


def format_action_str(action: Action) -> str:
    """Return a compact, single-line description of *action* for the step log.

    Args:
        action: Object exposing ``sql_query`` and ``submit_answer`` attributes.

    Returns:
        ``"sql_query=<query>"`` (query capped at 50 characters) when a
        non-empty query is set, otherwise ``"submit_answer=<answer>"`` (answer
        capped at 30 characters) when an answer is set, otherwise
        ``"invalid_action"``.
    """
    if action.sql_query:
        # Flatten newlines so the log record stays on one line, then truncate.
        query = action.sql_query.replace("\n", " ").strip()
        if len(query) > 50:
            query = query[:47] + "..."
        return f"sql_query={query}"

    # Compare against None rather than truthiness: falsy answers such as 0 are
    # still real submissions and must not be reported as invalid.
    if action.submit_answer is not None:
        # Same newline flattening as for queries, for the same reason.
        answer = str(action.submit_answer).replace("\n", " ").strip()
        if len(answer) > 30:
            answer = answer[:27] + "..."
        return f"submit_answer={answer}"

    return "invalid_action"


def _first_json_object(text: str) -> Optional[dict]:
    """Return the first JSON object embedded anywhere in *text*, or None."""
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            candidate, _ = decoder.raw_decode(text, pos)
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        pos = text.find("{", pos + 1)
    return None


def parse_model_response(response_text: Optional[str]) -> Optional[Action]:
    """
    Parse the model's response into an Action.

    The response is expected to be a JSON object with a "sql_query" and/or a
    "submit_answer" key. A surrounding markdown code fence (with or without a
    "json" tag, closed or not) is stripped first; if the remaining text is
    not pure JSON, the first JSON object embedded in it is used.

    Args:
        response_text: The raw text response from the model. May be None or
            empty when the model returned no content.

    Returns:
        Action built from the "sql_query" / "submit_answer" keys, or None if
        no JSON object can be extracted or Action rejects the values with a
        ValueError.
    """
    if not isinstance(response_text, str) or not response_text.strip():
        return None

    text = response_text.strip()

    # Handle cases where the model wraps JSON in markdown code blocks. A
    # "```json" fence takes precedence over a bare "```" fence.
    for marker in ("```json", "```"):
        start = text.find(marker)
        if start == -1:
            continue
        start += len(marker)
        end = text.find("```", start)
        # A missing closing fence means "take everything after the opener";
        # slicing with end == -1 would silently drop the last character.
        text = (text[start:] if end == -1 else text[start:end]).strip()
        break

    try:
        data = json.loads(text)
    except ValueError:
        # Not pure JSON: tolerate explanatory prose around the object.
        data = _first_json_object(text)

    # Valid JSON that is not an object (list, number, string) has no keys.
    if not isinstance(data, dict):
        return None

    try:
        return Action(
            sql_query=data.get("sql_query"), submit_answer=data.get("submit_answer")
        )
    except ValueError:
        return None


def _render_prompt(observation, error_section: str) -> str:
    """Fill the SYSTEM_PROMPT placeholders from the current observation."""
    fields = {
        "schema_info": observation.schema_info,
        "current_question": observation.current_question,
        "last_query_result": observation.last_query_result,
        "error_section": error_section,
    }
    try:
        return SYSTEM_PROMPT.format(**fields)
    except (KeyError, IndexError, ValueError):
        # The template embeds literal JSON examples. If their braces are not
        # doubled, str.format parses them as replacement fields and raises.
        # Substitute only the known placeholders, in a single pass so that
        # substituted values are never re-scanned.
        import re

        pattern = r"\{(" + "|".join(re.escape(name) for name in fields) + r")\}"
        return re.sub(pattern, lambda m: str(fields[m.group(1)]), SYSTEM_PROMPT)


def run_inference():
    """
    Run the baseline inference loop.

    This function:
    1. Initializes the OpenAI client and the environment
    2. Runs the model against the environment for at most MAX_STEPS steps
    3. Outputs structured [START] / [STEP] / [END] logs on stdout

    If a step raises (API failure or otherwise), an error [STEP] line is
    logged with reward 0.00 and the answer "error" is submitted to the
    environment; if that submission also fails, the episode is ended locally.

    The environment is always closed, even when an exception escapes.

    Returns:
        Tuple ``(success, final_score)`` taken from the ``info`` dict of the
        step that ended the episode, or ``(False, 0.0)`` if none did.
    """
    # Initialize OpenAI client
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Initialize environment
    env = SQLAnalystEnv()

    # Track rewards and steps
    rewards = []
    step_num = 0
    done = False
    success = False
    final_score = 0.0

    try:
        # Reset environment and get initial observation
        observation = env.reset()

        # Get task info from state
        state = env.state()
        task_name = state.get("task_id", "unknown")

        # ============================================
        # [START] LOG - EXACT FORMAT REQUIRED
        # ============================================
        print(f"[START] task={task_name} env={BENCHMARK_NAME} model={MODEL_NAME}")

        while not done and step_num < MAX_STEPS:
            step_num += 1

            # Build the prompt
            error_section = ""
            if observation.error_message:
                error_section = (
                    f"ERROR FROM LAST ACTION:\n{observation.error_message}"
                )
            prompt = _render_prompt(observation, error_section)

            try:
                # Call the model
                response = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a SQL expert. Respond only with valid JSON.",
                        },
                        {"role": "user", "content": prompt},
                    ],
                    temperature=0.0,
                    max_tokens=500,
                )

                # Extract response text; the content can be None, which is
                # treated like any other unparseable reply.
                response_text = response.choices[0].message.content or ""

                # Parse into Action
                action = parse_model_response(response_text)

                if action is None:
                    # Failed to parse, try a simple query as fallback
                    action = Action(sql_query="SELECT 1")
                    error_msg = "parse_error"
                else:
                    error_msg = "null"

                # Execute action in environment
                observation, reward, done, info = env.step(action)

                # Track reward
                reward_value = reward.value
                rewards.append(reward_value)

                # An environment error takes precedence over "parse_error";
                # it is flattened and truncated to keep the log on one line.
                if observation.error_message:
                    error_msg = observation.error_message.replace("\n", " ")[:50]

                # ============================================
                # [STEP] LOG - EXACT FORMAT REQUIRED
                # ============================================
                action_str = format_action_str(action)
                done_str = "true" if done else "false"
                print(
                    f"[STEP]  step={step_num} action={action_str} reward={reward_value:.2f} done={done_str} error={error_msg}"
                )

                # Update final results
                if done:
                    success = info.get("success", False)
                    final_score = info.get("final_score", 0.0)

            except Exception as e:
                # Handle API or other errors
                error_msg = str(e).replace("\n", " ")[:50]
                print(
                    f"[STEP]  step={step_num} action=error reward=0.00 done=false error={error_msg}"
                )
                rewards.append(0.0)

                # Submit a placeholder answer so the environment can move on;
                # the reward of this recovery step is not recorded.
                try:
                    action = Action(submit_answer="error")
                    observation, _, done, info = env.step(action)
                    success = info.get("success", False)
                    final_score = info.get("final_score", 0.0)
                except Exception:
                    # Not a bare except: KeyboardInterrupt / SystemExit must
                    # still be able to stop the run.
                    done = True
                    success = False
                    final_score = 0.0

        # ============================================
        # [END] LOG - EXACT FORMAT REQUIRED
        # ============================================
        success_str = "true" if success else "false"
        rewards_str = ",".join(f"{r:.2f}" for r in rewards)
        print(
            f"[END]   success={success_str} steps={step_num} score={final_score:.2f} rewards={rewards_str}"
        )
    finally:
        # Cleanup
        env.close()

    return success, final_score


def main():
    """Main entry point.

    Runs the baseline and always exits with status 0, whatever the outcome
    (the validation script expects exit code 0). If run_inference() raises,
    a minimal [START] / [STEP] / [END] record describing the failure is
    printed instead.
    """
    try:
        run_inference()
    except Exception as e:
        # Emergency fallback - still output required logs. The message is
        # flattened and truncated so the [STEP] record stays on one line.
        error_msg = str(e).replace("\n", " ")[:50]
        print(f"[START] task=error env={BENCHMARK_NAME} model={MODEL_NAME}")
        print(f"[STEP]  step=1 action=error reward=0.00 done=true error={error_msg}")
        print("[END]   success=false steps=1 score=0.00 rewards=0.00")

    # Always exit 0 for validation script
    sys.exit(0)


# Run the baseline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()