#!/usr/bin/env python3
# inference.py
# Baseline Inference Script for OpenEnv SQL Analyst
# Uses the OpenAI API client to run a model against the environment

import os
import sys
import json
from typing import Optional

# Add the project root to the path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from openai import OpenAI

from environment.env import SQLAnalystEnv
from environment.models import Action

# ============================================
# CONFIGURATION
# ============================================
API_BASE_URL = os.environ.get("API_BASE_URL")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
API_KEY = os.environ.get("API_KEY")

if not API_BASE_URL:
    raise ValueError("API_BASE_URL environment variable is required")
if not API_KEY:
    raise ValueError("API_KEY environment variable is required")

# Environment configuration
BENCHMARK_NAME = "sql_analyst"
MAX_STEPS = 15

# ============================================
# SYSTEM PROMPT
# ============================================
# Literal braces in the JSON examples are doubled ({{ }}) so that
# str.format() does not treat them as replacement fields.
SYSTEM_PROMPT = """You are an expert SQL Data Analyst AI agent. Your task is to answer business questions by querying a SQLite database.

You have two possible actions each turn:
1. Execute a SQL query to explore the data: {{"sql_query": "SELECT ..."}}
2. Submit your final answer: {{"submit_answer": "your answer"}}

IMPORTANT RULES:
- Only use SELECT queries. INSERT, UPDATE, DELETE, DROP, ALTER, TRUNCATE are blocked.
- Explore the data step by step before submitting your final answer.
- Your final answer should be just the value requested (a number, name, etc.), not a SQL query.
- Respond with ONLY a valid JSON object, no other text.

DATABASE SCHEMA:
{schema_info}

CURRENT QUESTION:
{current_question}

LAST QUERY RESULT:
{last_query_result}

{error_section}

Respond with a JSON object containing either "sql_query" or "submit_answer"."""


def format_action_str(action: Action) -> str:
    """Format an action for logging."""
    if action.sql_query:
        # Truncate long queries for logging
        query = action.sql_query.replace("\n", " ").strip()
        if len(query) > 50:
            query = query[:47] + "..."
        return f"sql_query={query}"
    elif action.submit_answer:
        answer = str(action.submit_answer).strip()
        if len(answer) > 30:
            answer = answer[:27] + "..."
        return f"submit_answer={answer}"
    return "invalid_action"


def parse_model_response(response_text: str) -> Optional[Action]:
    """
    Parse the model's response into an Action.

    Args:
        response_text: The raw text response from the model

    Returns:
        Action, or None if parsing fails
    """
    try:
        # Clean the response
        text = response_text.strip()

        # Try to extract JSON from the response.
        # Handle cases where the model wraps JSON in markdown code blocks.
        if "```json" in text:
            start = text.find("```json") + 7
            end = text.find("```", start)
            text = text[start:end].strip()
        elif "```" in text:
            start = text.find("```") + 3
            end = text.find("```", start)
            text = text[start:end].strip()

        # Parse JSON
        data = json.loads(text)

        # Create Action
        return Action(
            sql_query=data.get("sql_query"),
            submit_answer=data.get("submit_answer"),
        )
    except (json.JSONDecodeError, ValueError):
        return None


def run_inference():
    """
    Run the baseline inference loop.

    This function:
    1. Initializes the environment
    2. Runs the model against the environment
    3. Outputs structured logs in the exact required format
    """
    # Initialize OpenAI client
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Initialize environment
    env = SQLAnalystEnv()

    # Reset environment and get the initial observation
    observation = env.reset()

    # Get task info from state
    state = env.state()
    task_name = state.get("task_id", "unknown")

    # ============================================
    # [START] LOG - EXACT FORMAT REQUIRED
    # ============================================
    print(f"[START] task={task_name} env={BENCHMARK_NAME} model={MODEL_NAME}")

    # Track rewards and steps
    rewards = []
    step_num = 0
    done = False
    success = False
    final_score = 0.0

    while not done and step_num < MAX_STEPS:
        step_num += 1

        # Build the prompt
        error_section = ""
        if observation.error_message:
            error_section = f"ERROR FROM LAST ACTION:\n{observation.error_message}"

        prompt = SYSTEM_PROMPT.format(
            schema_info=observation.schema_info,
            current_question=observation.current_question,
            last_query_result=observation.last_query_result,
            error_section=error_section,
        )

        try:
            # Call the model
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a SQL expert. Respond only with valid JSON.",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=0.0,
                max_tokens=500,
            )

            # Extract response text
            response_text = response.choices[0].message.content

            # Parse into an Action
            action = parse_model_response(response_text)

            if action is None:
                # Failed to parse; fall back to a harmless query
                action = Action(sql_query="SELECT 1")
                error_msg = "parse_error"
            else:
                error_msg = "null"

            # Execute the action in the environment
            observation, reward, done, info = env.step(action)

            # Track reward
            reward_value = reward.value
            rewards.append(reward_value)

            # Check for errors in the observation
            if observation.error_message:
                error_msg = observation.error_message.replace("\n", " ")[:50]

            # ============================================
            # [STEP] LOG - EXACT FORMAT REQUIRED
            # ============================================
            action_str = format_action_str(action)
            done_str = "true" if done else "false"
            print(
                f"[STEP] step={step_num} action={action_str} reward={reward_value:.2f} done={done_str} error={error_msg}"
            )

            # Update final results
            if done:
                success = info.get("success", False)
                final_score = info.get("final_score", 0.0)

        except Exception as e:
            # Handle API or other errors
            error_msg = str(e).replace("\n", " ")[:50]
            print(
                f"[STEP] step={step_num} action=error reward=0.00 done=false error={error_msg}"
            )
            rewards.append(0.0)

            # Try to continue with a simple action
            try:
                action = Action(submit_answer="error")
                observation, reward, done, info = env.step(action)
                success = info.get("success", False)
                final_score = info.get("final_score", 0.0)
            except Exception:
                done = True
                success = False
                final_score = 0.0

    # ============================================
    # [END] LOG - EXACT FORMAT REQUIRED
    # ============================================
    success_str = "true" if success else "false"
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={success_str} steps={step_num} score={final_score:.2f} rewards={rewards_str}"
    )

    # Cleanup
    env.close()

    return success, final_score


def main():
    """Main entry point."""
    try:
        run_inference()
        sys.exit(0)  # Always exit 0 for the validation script
    except Exception as e:
        # Emergency fallback - still output the required logs
        print(f"[START] task=error env={BENCHMARK_NAME} model={MODEL_NAME}")
        print(f"[STEP] step=1 action=error reward=0.00 done=true error={str(e)[:50]}")
        print("[END] success=false steps=1 score=0.00 rewards=0.00")
        sys.exit(0)


if __name__ == "__main__":
    main()
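# ---------------------------------------------------------------------------
# Example invocation (a sketch; the URL and key below are placeholders, not
# real configuration -- any OpenAI-compatible endpoint should work):
#
#   API_BASE_URL="https://api.openai.com/v1" \
#   API_KEY="sk-..." \
#   MODEL_NAME="gpt-4o-mini" \
#   python inference.py
#
# The script prints one [START] line, one [STEP] line per environment step,
# and one [END] line summarizing success, step count, score, and rewards.
# ---------------------------------------------------------------------------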