#!/usr/bin/env python3
# inference.py
# Baseline Inference Script for OpenEnv SQL Analyst
# Uses OpenAI API client to run model against the environment
import os
import sys
import json
from typing import Optional
# Add the project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from openai import OpenAI
from environment.env import SQLAnalystEnv
from environment.models import Action
# ============================================
# CONFIGURATION
# ============================================
API_BASE_URL = os.environ.get("API_BASE_URL")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
API_KEY = os.environ.get("API_KEY")
if not API_BASE_URL:
    raise ValueError("API_BASE_URL environment variable is required")
if not API_KEY:
    raise ValueError("API_KEY environment variable is required")
# Environment configuration
BENCHMARK_NAME = "sql_analyst"
MAX_STEPS = 15
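
# Example invocation (illustrative only; the URL, key, and model name below are
# placeholders, not values tied to this project):
#
#   API_BASE_URL=https://api.openai.com/v1 \
#   API_KEY=sk-... \
#   MODEL_NAME=gpt-4o-mini \
#   python inference.py
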
# ============================================
# SYSTEM PROMPT
# ============================================
SYSTEM_PROMPT = """You are an expert SQL Data Analyst AI agent. Your task is to answer business questions by querying a SQLite database.
You have two possible actions each turn:
1. Execute a SQL query to explore the data: {"sql_query": "SELECT ..."}
2. Submit your final answer: {"submit_answer": "your answer"}
IMPORTANT RULES:
- Only use SELECT queries. INSERT, UPDATE, DELETE, DROP, ALTER, TRUNCATE are blocked.
- Explore the data step by step before submitting your final answer.
- Your final answer should be just the value requested (a number, name, etc.), not a SQL query.
- Respond with ONLY a valid JSON object, no other text.
DATABASE SCHEMA:
{schema_info}
CURRENT QUESTION:
{current_question}
LAST QUERY RESULT:
{last_query_result}
{error_section}
Respond with a JSON object containing either "sql_query" or "submit_answer"."""
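
# For reference, a well-formed reply to the prompt above is a bare JSON object,
# e.g. (the query and answer shown are illustrative, not from a real task):
#
#   {"sql_query": "SELECT name FROM sqlite_master WHERE type='table'"}
#   {"submit_answer": "42"}
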
def format_action_str(action: Action) -> str:
    """Format action for logging."""
    if action.sql_query:
        # Truncate long queries for logging
        query = action.sql_query.replace("\n", " ").strip()
        if len(query) > 50:
            query = query[:47] + "..."
        return f"sql_query={query}"
    elif action.submit_answer:
        answer = str(action.submit_answer).strip()
        if len(answer) > 30:
            answer = answer[:27] + "..."
        return f"submit_answer={answer}"
    return "invalid_action"

def parse_model_response(response_text: str) -> Optional[Action]:
    """
    Parse the model's response into an Action.

    Args:
        response_text: The raw text response from the model

    Returns:
        Action or None if parsing fails
    """
    try:
        # Clean the response
        text = response_text.strip()

        # Try to extract JSON from the response.
        # Handle cases where the model wraps JSON in markdown code blocks.
        if "```json" in text:
            start = text.find("```json") + 7
            end = text.find("```", start)
            text = text[start:end].strip()
        elif "```" in text:
            start = text.find("```") + 3
            end = text.find("```", start)
            text = text[start:end].strip()

        # Parse JSON
        data = json.loads(text)

        # Create Action
        return Action(
            sql_query=data.get("sql_query"), submit_answer=data.get("submit_answer")
        )
    except (json.JSONDecodeError, ValueError):
        return None
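
# Illustrative behaviour of parse_model_response (not an exhaustive spec):
# a reply wrapped in a fenced block, such as
#   ```json
#   {"sql_query": "SELECT 1"}
#   ```
# is unwrapped to the inner JSON and becomes Action(sql_query="SELECT 1");
# anything that does not parse as JSON yields None, which the loop below
# treats as a parse error and replaces with a fallback query.
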
def run_inference():
    """
    Run the baseline inference loop.

    This function:
    1. Initializes the environment
    2. Runs the model against the environment
    3. Outputs structured logs in the exact required format
    """
    # Initialize OpenAI client
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Initialize environment
    env = SQLAnalystEnv()

    # Reset environment and get initial observation
    observation = env.reset()

    # Get task info from state
    state = env.state()
    task_name = state.get("task_id", "unknown")

    # ============================================
    # [START] LOG - EXACT FORMAT REQUIRED
    # ============================================
    print(f"[START] task={task_name} env={BENCHMARK_NAME} model={MODEL_NAME}")

    # Track rewards and steps
    rewards = []
    step_num = 0
    done = False
    success = False
    final_score = 0.0
    while not done and step_num < MAX_STEPS:
        step_num += 1

        # Build the prompt
        error_section = ""
        if observation.error_message:
            error_section = f"ERROR FROM LAST ACTION:\n{observation.error_message}"

        prompt = SYSTEM_PROMPT.format(
            schema_info=observation.schema_info,
            current_question=observation.current_question,
            last_query_result=observation.last_query_result,
            error_section=error_section,
        )

        try:
            # Call the model
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a SQL expert. Respond only with valid JSON.",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=0.0,
                max_tokens=500,
            )

            # Extract response text
            response_text = response.choices[0].message.content

            # Parse into Action
            action = parse_model_response(response_text)
            if action is None:
                # Failed to parse, try a simple query as fallback
                action = Action(sql_query="SELECT 1")
                error_msg = "parse_error"
            else:
                error_msg = "null"

            # Execute action in environment
            observation, reward, done, info = env.step(action)

            # Track reward
            reward_value = reward.value
            rewards.append(reward_value)

            # Check for errors in observation
            if observation.error_message:
                error_msg = observation.error_message.replace("\n", " ")[:50]

            # ============================================
            # [STEP] LOG - EXACT FORMAT REQUIRED
            # ============================================
            action_str = format_action_str(action)
            done_str = "true" if done else "false"
            print(
                f"[STEP] step={step_num} action={action_str} reward={reward_value:.2f} done={done_str} error={error_msg}"
            )

            # Update final results
            if done:
                success = info.get("success", False)
                final_score = info.get("final_score", 0.0)

        except Exception as e:
            # Handle API or other errors
            error_msg = str(e).replace("\n", " ")[:50]
            print(
                f"[STEP] step={step_num} action=error reward=0.00 done=false error={error_msg}"
            )
            rewards.append(0.0)

            # Try to continue with a simple action
            try:
                action = Action(submit_answer="error")
                observation, reward, done, info = env.step(action)
                success = info.get("success", False)
                final_score = info.get("final_score", 0.0)
            except Exception:
                done = True
                success = False
                final_score = 0.0
    # ============================================
    # [END] LOG - EXACT FORMAT REQUIRED
    # ============================================
    success_str = "true" if success else "false"
    rewards_str = ",".join([f"{r:.2f}" for r in rewards])
    print(
        f"[END] success={success_str} steps={step_num} score={final_score:.2f} rewards={rewards_str}"
    )

    # Cleanup
    env.close()

    return success, final_score
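
# For reference, a run of run_inference() emits logs in this shape (the values
# shown are illustrative, not captured from a real run):
#
#   [START] task=task_001 env=sql_analyst model=gpt-4o-mini
#   [STEP] step=1 action=sql_query=SELECT name FROM sqlite_master WHERE type='table' reward=0.00 done=false error=null
#   [STEP] step=2 action=submit_answer=42 reward=1.00 done=true error=null
#   [END] success=true steps=2 score=1.00 rewards=0.00,1.00
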
def main():
    """Main entry point."""
    try:
        run_inference()
        sys.exit(0)  # Always exit 0 for the validation script
    except Exception as e:
        # Emergency fallback - still output required logs
        print(f"[START] task=error env={BENCHMARK_NAME} model={MODEL_NAME}")
        print(f"[STEP] step=1 action=error reward=0.00 done=true error={str(e)[:50]}")
        print("[END] success=false steps=1 score=0.00 rewards=0.00")
        sys.exit(0)


if __name__ == "__main__":
    main()