"""
SQL Agent OpenEnv — Baseline Inference Script
==============================================
Runs a baseline LLM agent against all 3 tasks of the SQL Agent OpenEnv environment.

Environment variables (required):
  API_BASE_URL  — OpenAI-compatible base URL (default: https://router.huggingface.co/v1)
  MODEL_NAME    — Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
  HF_TOKEN      — Hugging Face / API key

STDOUT format (strictly enforced):
  [START] task=<task_id> env=sql-agent-openenv model=<model_name>
  [STEP] step=<n> action=<action_name> reward=<0.00> done=<true|false> error=<message|null>
  [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
"""

from __future__ import annotations

import asyncio
import os
import sys
import textwrap
from typing import List, Optional

# ── Path setup (inference.py lives at repo root; backend is a subdirectory) ──
_BACKEND = os.path.join(os.path.dirname(os.path.abspath(__file__)), "backend")
if _BACKEND not in sys.path:
    sys.path.insert(0, _BACKEND)

from openai import OpenAI  # noqa: E402
from env.sql_env import SQLAgentEnv, Action, Observation  # noqa: E402

# ── Config ────────────────────────────────────────────────────────────────────
# The first non-empty key wins: HF_TOKEN, then API_KEY, then OPENAI_API_KEY.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY", "")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
BENCHMARK = "sql-agent-openenv"

TASKS = ["simple_queries", "join_queries", "complex_queries"]
MAX_STEPS = 5
TEMPERATURE = 0.2
MAX_TOKENS = 50

# Actions the LLM reply is normalised to. "generate" is deliberately absent:
# pick_action() returns it directly for the first attempt without calling the LLM.
REPAIR_ACTIONS = [
    "rewrite_full",
    "fix_column",
    "fix_table",
    "add_groupby",
    "rewrite_cte",
    "fix_syntax",
    "change_dialect",
    "relax_filter",
]

SYSTEM_PROMPT = textwrap.dedent("""
    You are an expert SQL agent interacting with a SQL repair environment.
    At each step you receive a natural language question, a database schema,
    and optionally the last SQL attempt + error message.

    Your job: pick ONE repair action from the list below that is most likely
    to fix the SQL error on the next attempt.

    Available actions:
      generate        — write fresh SQL from scratch (use on first attempt)
      rewrite_full    — completely rewrite the query from scratch
      fix_column      — fix wrong column name references
      fix_table       — fix wrong table name references
      add_groupby     — add or fix GROUP BY / aggregation clauses
      rewrite_cte     — restructure subqueries or CTEs
      fix_syntax      — fix syntax errors (brackets, commas, keywords)
      change_dialect  — convert to SQLite-compatible functions
      relax_filter    — broaden or remove overly strict WHERE conditions

    Reply with ONLY the action name. No explanation. No punctuation.
    Example: fix_column
""").strip()

# ── Logging ───────────────────────────────────────────────────────────────────

# Hard bounds: every score/reward we ever emit is clamped to this closed range.
# 0.05 margin guarantees that :.2f and :.3f formatting never produces
# "0.00", "0.000", "1.00", or "1.000" (all of which parse as exactly 0.0 / 1.0).
_MIN_SCORE = 0.05
_MAX_SCORE = 0.95


def _safe_score(x) -> float:
    """Coerce anything (None, NaN, str, bool, int, float) to a float strictly in (0, 1).

    None, NaN, -inf and values that cannot be converted map to ``_MIN_SCORE``;
    +inf maps to ``_MAX_SCORE``; booleans map to the max/min bound; everything
    else is clamped into ``[_MIN_SCORE, _MAX_SCORE]``.
    """
    try:
        if x is None:
            return _MIN_SCORE
        if isinstance(x, bool):
            return _MAX_SCORE if x else _MIN_SCORE
        v = float(x)
        if v != v:  # NaN is the only float that is not equal to itself
            return _MIN_SCORE
        if v == float("inf"):
            return _MAX_SCORE
        if v == float("-inf"):
            return _MIN_SCORE
    except (TypeError, ValueError):
        return _MIN_SCORE
    return max(_MIN_SCORE, min(_MAX_SCORE, v))


def log_start(task: str, model: str) -> None:
    """Print the ``[START]`` record that opens one task's log."""
    print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True)


def log_step(step: int, action: str, reward, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record on a single line.

    Args:
        step: 1-based step index.
        action: Action name; an empty/None value is printed as ``noop``.
        reward: Raw reward; clamped through ``_safe_score`` before printing.
        done: Whether the episode finished on this step.
        error: Error text, or None/empty for ``null``.
    """
    r = _safe_score(reward)
    # Collapse *all* whitespace (including "\r", "\r\n" and Unicode line
    # separators), not only "\n", so the record can never spill onto a second
    # line. str() covers callers that pass a non-string error object.
    error_val = " ".join(str(error).split()) if error else ""
    error_val = error_val or "null"
    done_val = str(bool(done)).lower()
    print(
        f"[STEP] step={int(step)} action={action or 'noop'} reward={r:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )


def log_end(success: bool, steps: int, score, rewards: List) -> None:
    """Print the ``[END]`` record that closes one task's log.

    Every reward and the score are clamped through ``_safe_score``; an empty
    or None reward list is printed as a single ``_MIN_SCORE`` entry.
    """
    s = _safe_score(score)
    safe_rewards = [_safe_score(r) for r in (rewards or [])]
    if not safe_rewards:
        safe_rewards = [_MIN_SCORE]
    rewards_str = ",".join(f"{r:.2f}" for r in safe_rewards)
    print(
        f"[END] success={str(bool(success)).lower()} steps={int(steps)} "
        f"score={s:.3f} rewards={rewards_str}",
        flush=True,
    )


# ── LLM helper ────────────────────────────────────────────────────────────────

def pick_action(
    client: OpenAI,
    obs: Observation,
    step: int,
) -> str:
    """Ask the LLM to pick a repair action given the current observation.

    Returns ``"generate"`` without calling the LLM on the first step or when
    there is no SQL attempt yet. Otherwise returns the first entry of
    ``REPAIR_ACTIONS`` found in the lower-cased reply, falling back to
    ``"rewrite_full"`` when nothing matches or the LLM call fails.
    """
    if step == 1 or obs.current_sql is None:
        return "generate"

    # Built from explicit lines rather than textwrap.dedent(f"""..."""):
    # dedent runs after interpolation, so multi-line SQL or error text with
    # unindented continuation lines would remove the common margin and leave
    # the whole prompt indented.
    user_msg = "\n".join([
        f"Question: {obs.question}",
        "",
        "Current SQL (failed):",
        f"{obs.current_sql}",
        "",
        f"Error: {obs.error_message or 'unknown'}",
        f"Error class: {obs.error_class or 'unknown'}",
        f"Attempt number: {obs.attempt_number} of {obs.max_attempts}",
        "",
        "Which repair action should I use next?",
    ]).strip()

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        raw = (completion.choices[0].message.content or "").strip().lower()
        # Normalise to valid action name
        for action in REPAIR_ACTIONS:
            if action in raw:
                return action
        return "rewrite_full"
    except Exception as exc:
        # Best-effort: a failed LLM call must not abort the episode.
        print(f"[DEBUG] LLM call failed: {exc}", flush=True)
        return "rewrite_full"


# ── Single-episode runner ─────────────────────────────────────────────────────

async def run_episode(
    env: SQLAgentEnv,
    client: OpenAI,
    task_id: str,
) -> None:
    """Run one full episode for a task, emitting structured stdout logs.

    Emits ``[START]``, one ``[STEP]`` per environment step (at most
    ``MAX_STEPS``), and always an ``[END]`` record via ``finally``. The score
    is the mean of the clamped step rewards.
    """
    log_start(task=task_id, model=MODEL_NAME)

    rewards: List[float] = []
    steps_taken = 0
    score = _MIN_SCORE
    success = False
    last_error: Optional[str] = None

    try:
        try:
            obs = env.reset(task_id)
        except Exception as exc:
            log_step(step=1, action="reset", reward=_MIN_SCORE, done=True, error=str(exc))
            rewards.append(_MIN_SCORE)
            steps_taken = 1
            return  # the finally clause still emits [END]

        for step in range(1, MAX_STEPS + 1):
            try:
                action_name = pick_action(client, obs, step)
            except Exception:
                action_name = "generate"

            action = Action(repair_action=action_name)

            try:
                obs, reward_info = await env.step(action)
            except Exception as exc:
                log_step(step=step, action=action_name, reward=_MIN_SCORE, done=True, error=str(exc))
                rewards.append(_MIN_SCORE)
                steps_taken = step
                break

            # getattr with defaults: tolerate missing fields on the step result.
            reward = _safe_score(getattr(reward_info, "value", None))
            done = bool(getattr(reward_info, "done", False))
            last_error = getattr(obs, "error_message", None)
            success = bool(getattr(reward_info, "success", False))

            rewards.append(reward)
            steps_taken = step

            log_step(
                step=step,
                action=action_name,
                reward=reward,
                done=done,
                error=last_error,
            )

            if done:
                break

        denom = max(len(rewards), 1)
        avg = sum(rewards) / denom if rewards else _MIN_SCORE
        score = _safe_score(avg)

    except Exception as exc:
        # Catch-all so we always emit a valid [END] line
        log_step(step=steps_taken or 1, action="error", reward=_MIN_SCORE, done=True, error=str(exc))
        if not rewards:
            rewards.append(_MIN_SCORE)
        score = _MIN_SCORE

    finally:
        log_end(
            success=success,
            steps=max(int(steps_taken), 1),
            score=score,
            rewards=rewards,
        )


# ── Main ──────────────────────────────────────────────────────────────────────

async def main() -> None:
    """Run every task in ``TASKS`` sequentially, one episode per task.

    If the client or environment cannot be constructed, a minimal
    ``[START]``/``[STEP]``/``[END]`` triple is still emitted for each task.
    """
    try:
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        env = SQLAgentEnv()
    except Exception as exc:
        # Environment couldn't init — still emit a valid [START]/[STEP]/[END] per task
        for task_id in TASKS:
            log_start(task=task_id, model=MODEL_NAME)
            log_step(step=1, action="init_error", reward=_MIN_SCORE, done=True, error=str(exc))
            log_end(success=False, steps=1, score=_MIN_SCORE, rewards=[_MIN_SCORE])
            print("", flush=True)
        return

    for task_id in TASKS:
        try:
            await run_episode(env, client, task_id)
        except Exception as exc:
            # run_episode already has its own catch-all, but guard against anything leaking
            log_end(success=False, steps=1, score=_MIN_SCORE, rewards=[_MIN_SCORE])
            print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", flush=True)
        print("", flush=True)


if __name__ == "__main__":
    asyncio.run(main())