""" inference.py — SQL Debug RL Environment Calls the running FastAPI server (/reset, /step) for each task and reports scores in the mandatory [START] / [STEP] / [END] format expected by OpenEnv. Uses official OpenAI client as required by OpenEnv evaluation rules. """ import os import time import json import urllib.request from typing import List, Optional from openai import OpenAI from openai.types.chat import ChatCompletion try: from dotenv import load_dotenv load_dotenv(override=True) except ImportError: pass # ── Configuration ───────────────────────────────────────────────────────────── ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860") # OpenEnv injects these two — ALWAYS use them, never hardcode API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1") API_KEY = os.getenv("API_KEY", os.getenv("OPENAI_API_KEY", "")) # OpenEnv injects API_KEY MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") # Initialize official OpenAI client client = OpenAI( api_key=API_KEY, base_url=API_BASE_URL ) # Task to run — OpenEnv injects this via env variable TASK_ID = os.getenv("TASK_ID", "").strip() MAX_STEPS = 5 TEMPERATURE = 0.3 MAX_TOKENS = 512 # All valid task IDs in this environment ALL_TASKS = [ "task_1_easy", "task_2_medium", "task_3_hard", "task_4_expert", "task_5_optimization", "task_6_migration", "task_7_chaos", ] SYSTEM_PROMPT = """You are an expert SQL debugger. You will receive a broken SQL query and must fix it. Return ONLY the corrected SQL query. No explanation, no markdown, no code fences. Just the raw SQL.""" # ── Logging helpers (OpenEnv required format) ───────────────────────────────── def log_start(task: str, model: str) -> None: print(f"[START] task={task} env=sql-debug-env model={model}", flush=True) def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None: error_val = error if error else "null" action_safe = repr(action[:80]) print(f"[STEP] step={step} action={action_safe} reward={reward:.4f} done={str(done).lower()} error={error_val}", flush=True) def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: rewards_str = ",".join(f"{r:.4f}" for r in rewards) print(f"[END] success={str(success).lower()} steps={steps} score={score:.4f} rewards={rewards_str}", flush=True) # ── Environment API calls ───────────────────────────────────────────────────── def http_post(url: str, payload: dict, timeout: int = 30) -> dict: req = urllib.request.Request( url, data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"}, method="POST" ) with urllib.request.urlopen(req, timeout=timeout) as resp: return json.loads(resp.read().decode()) def env_reset(task_id: str) -> dict: return http_post(f"{ENV_BASE_URL}/reset", {"task_id": task_id}) def env_step(fixed_sql: str, explanation: str = "") -> dict: return http_post(f"{ENV_BASE_URL}/step", {"fixed_sql": fixed_sql, "explanation": explanation}) # ── LLM call with retry ─────────────────────────────────────────────────────── def get_llm_fix(broken_sql: str, error_hint: str, schema_info: dict, previous_attempts: list) -> str: attempts_text = "" if previous_attempts: attempts_text = "\n\nPrevious failed attempts:\n" + "\n".join( f"- {a}" for a in previous_attempts[-2:] ) schema_text = "\n".join( f"Table {tbl}: {', '.join(cols)}" for tbl, cols in schema_info.items() ) user_msg = f"""Fix this broken SQL query. Schema: {schema_text} Error: {error_hint} Broken SQL: {broken_sql} {attempts_text} Return ONLY the fixed SQL. No explanation.""" for attempt in range(4): try: response: ChatCompletion = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_msg}, ], temperature=TEMPERATURE, max_tokens=MAX_TOKENS, ) text = (response.choices[0].message.content or "").strip() # Strip markdown code fences if present if text.startswith("```"): lines = text.split("\n") text = "\n".join(l for l in lines if not l.startswith("```")).strip() return text if text else broken_sql except Exception as e: # Handle rate limits (429) manually with backoff if "429" in str(e) and attempt < 3: wait = 4 * (2 ** attempt) time.sleep(wait) continue return broken_sql return broken_sql # ── Main loop ───────────────────────────────────────────────────────────────── def run_task(task_id: str) -> float: log_start(task=task_id, model=MODEL_NAME) rewards: List[float] = [] steps_taken = 0 score = 0.15 # safe non-zero default if env fails success = False try: # Reset environment for this task reset_resp = env_reset(task_id) obs = reset_resp.get("observation", {}) broken_sql = obs.get("broken_sql", "SELECT 1") error_hint = obs.get("error_hint", "") schema_info = obs.get("schema_info", {}) previous_attempts: List[str] = [] for step in range(1, MAX_STEPS + 1): # Ask LLM to fix the SQL fixed_sql = get_llm_fix(broken_sql, error_hint, schema_info, previous_attempts) # Submit to environment step_resp = env_step(fixed_sql) reward = float(step_resp.get("reward", 0.0)) done = bool(step_resp.get("done", False)) # Clamp reward to safe range strictly between 0 and 1 reward = max(-0.99, min(0.99, reward)) rewards.append(reward) steps_taken = step previous_attempts.append(f"step {step}: {fixed_sql[:60]!r}") log_step(step=step, action=fixed_sql, reward=reward, done=done, error=None) if done: break # Normalize total reward into (0, 1) — never exactly 0 or 1 positive_rewards = [r for r in rewards if r > 0] if positive_rewards: raw_score = sum(positive_rewards) / (len(rewards) * 0.99) else: raw_score = 0.1 # agent tried but didn't solve # Hard clamp: strictly between 0 and 1 score = max(0.01, min(0.99, raw_score)) success = score >= 0.5 except Exception as exc: score = 0.15 # Non-zero safe default success = False log_end(success=success, steps=steps_taken, score=score, rewards=rewards) return score def main(): specific_task = TASK_ID results_dir = "outputs" os.makedirs(results_dir, exist_ok=True) results_path = os.path.join(results_dir, "baseline_results.json") final_data = { "model": MODEL_NAME, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "tasks": {} } if specific_task and specific_task in ALL_TASKS: score = run_task(specific_task) final_data["tasks"][specific_task] = {"score": score} else: # Run all tasks so the validator sees graders for every task all_scores = [] for t_id in ALL_TASKS: score = run_task(t_id) all_scores.append(score) final_data["tasks"][t_id] = {"score": score} avg = sum(all_scores) / len(all_scores) final_data["avg_score"] = avg # Save to JSON for local tracking try: with open(results_path, "w") as f: json.dump(final_data, f, indent=2) except Exception as e: pass if __name__ == "__main__": main()