Spaces:
Running
Running
| """ | |
| SQL Agent OpenEnv — Baseline Inference Script | |
| ============================================== | |
| Runs a baseline LLM agent against all 3 tasks of the SQL Agent OpenEnv environment. | |
| Environment variables (each has a fallback): | |
| API_BASE_URL — OpenAI-compatible base URL (default: https://router.huggingface.co/v1) | |
| MODEL_NAME — Model identifier (default: Qwen/Qwen2.5-72B-Instruct) | |
| HF_TOKEN — Hugging Face / API key (falls back to API_KEY, then OPENAI_API_KEY; needed for real LLM calls) | |
| STDOUT format (strictly enforced): | |
| [START] task=<task_id> env=sql-agent-openenv model=<model> | |
| [STEP] step=<n> action=<action> reward=<0.00> done=<true|false> error=<msg|null> | |
| [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...> | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import os | |
| import sys | |
| import textwrap | |
| from typing import List, Optional | |
| # ββ Path setup (inference.py lives at repo root; backend is a subdirectory) ββ | |
| _BACKEND = os.path.join(os.path.dirname(os.path.abspath(__file__)), "backend") | |
| if _BACKEND not in sys.path: | |
| sys.path.insert(0, _BACKEND) | |
| from openai import OpenAI # noqa: E402 | |
| from env.sql_env import SQLAgentEnv, Action, Observation # noqa: E402 | |
# ---- Config -----------------------------------------------------------------
# API key: the first of HF_TOKEN, API_KEY, OPENAI_API_KEY that is set to a
# non-empty value, or "" when none is.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY", "")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

BENCHMARK = "sql-agent-openenv"  # echoed as env=<...> in every [START] line
TASKS = ["simple_queries", "join_queries", "complex_queries"]
MAX_STEPS = 5  # upper bound on steps per episode
TEMPERATURE = 0.2
MAX_TOKENS = 50  # the expected reply is a single action name

# Action names the LLM reply is matched against. "generate" is deliberately
# absent: the first step chooses it without asking the LLM. Keep this list in
# sync with the action names described in SYSTEM_PROMPT.
REPAIR_ACTIONS = [
    "rewrite_full",
    "fix_column",
    "fix_table",
    "add_groupby",
    "rewrite_cte",
    "fix_syntax",
    "change_dialect",
    "relax_filter",
]

# Plain ASCII separators: the previous non-ASCII dashes had been mis-decoded
# into stray Greek letters, which ended up verbatim in the prompt.
SYSTEM_PROMPT = textwrap.dedent("""
    You are an expert SQL agent interacting with a SQL repair environment.
    At each step you receive a natural language question, a database schema,
    and optionally the last SQL attempt + error message.

    Your job: pick ONE repair action from the list below that is most likely
    to fix the SQL error on the next attempt.

    Available actions:
      generate - write fresh SQL from scratch (use on first attempt)
      rewrite_full - completely rewrite the query from scratch
      fix_column - fix wrong column name references
      fix_table - fix wrong table name references
      add_groupby - add or fix GROUP BY / aggregation clauses
      rewrite_cte - restructure subqueries or CTEs
      fix_syntax - fix syntax errors (brackets, commas, keywords)
      change_dialect - convert to SQLite-compatible functions
      relax_filter - broaden or remove overly strict WHERE conditions

    Reply with ONLY the action name. No explanation. No punctuation.
    Example: fix_column
""").strip()
| # ββ Logging βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Hard bounds: every score/reward we ever emit is clamped to this closed range. | |
| # 0.05 margin guarantees that :.2f and :.3f formatting never produces | |
| # "0.00", "0.000", "1.00", or "1.000" (all of which parse as exactly 0.0 / 1.0). | |
| _MIN_SCORE = 0.05 | |
| _MAX_SCORE = 0.95 | |
| def _safe_score(x) -> float: | |
| """Coerce anything (None, NaN, str, bool, int, float) to a float strictly in (0, 1).""" | |
| try: | |
| if x is None: | |
| return _MIN_SCORE | |
| if isinstance(x, bool): | |
| return _MAX_SCORE if x else _MIN_SCORE | |
| v = float(x) | |
| if v != v: # NaN | |
| return _MIN_SCORE | |
| if v == float("inf"): | |
| return _MAX_SCORE | |
| if v == float("-inf"): | |
| return _MIN_SCORE | |
| except (TypeError, ValueError): | |
| return _MIN_SCORE | |
| return max(_MIN_SCORE, min(_MAX_SCORE, v)) | |
def log_start(task: str, model: str) -> None:
    """Print the [START] record that opens one task's episode, flushed immediately."""
    header = "[START] task={} env={} model={}".format(task, BENCHMARK, model)
    print(header, flush=True)
def log_step(step: int, action: str, reward, done: bool, error: Optional[str]) -> None:
    """Print one [STEP] record on stdout (flushed immediately).

    Args:
        step: Step number, coerced with int().
        action: Action name; an empty/None value is printed as "noop".
        reward: Any value; clamped via _safe_score and printed with 2 decimals.
        done: Printed as "true"/"false" according to its truthiness.
        error: Error message; None/empty/whitespace-only prints as "null".
            Non-string values are converted with str().
    """
    r = _safe_score(reward)
    # str.split() with no separator breaks on every whitespace character,
    # including "\r", "\n" and Unicode line separators, so re-joining with a
    # single space keeps the record on one physical line whatever the error
    # text contains.
    error_val = " ".join(str(error).split()) if error else ""
    done_val = str(bool(done)).lower()
    print(
        f"[STEP] step={int(step)} action={action or 'noop'} reward={r:.2f} "
        f"done={done_val} error={error_val or 'null'}",
        flush=True,
    )
def log_end(success: bool, steps: int, score, rewards: List) -> None:
    """Print the [END] record that closes one task's episode, flushed immediately.

    The score and every reward pass through _safe_score; an empty or None
    reward list is reported as a single _MIN_SCORE entry.
    """
    final_score = _safe_score(score)
    clamped = [_safe_score(value) for value in rewards or ()] or [_MIN_SCORE]
    joined = ",".join(format(value, ".2f") for value in clamped)
    success_val = "true" if success else "false"
    print(
        f"[END] success={success_val} steps={int(steps)} "
        f"score={final_score:.3f} rewards={joined}",
        flush=True,
    )
# ---- LLM helper -------------------------------------------------------------
def pick_action(
    client: OpenAI,
    obs: Observation,
    step: int,
) -> str:
    """Ask the LLM to pick a repair action given the current observation.

    Args:
        client: OpenAI-compatible client used for one chat completion.
        obs: Latest observation; reads ``question``, ``current_sql``,
            ``error_message``, ``error_class``, ``attempt_number`` and
            ``max_attempts``.
        step: Step number within the episode (1 is the first step).

    Returns:
        ``"generate"`` on step 1 or while there is no SQL yet; otherwise one
        of ``REPAIR_ACTIONS``. Falls back to ``"rewrite_full"`` when the reply
        names no known action or the LLM call fails.
    """
    if step == 1 or obs.current_sql is None:
        return "generate"

    # Joined line by line instead of textwrap.dedent(f"..."): dedent runs after
    # interpolation, so an unindented line inside a multi-line SQL string would
    # defeat it and leave the template lines indented.
    user_msg = "\n".join([
        f"Question: {obs.question}",
        "Current SQL (failed):",
        f"{obs.current_sql}",
        f"Error: {obs.error_message or 'unknown'}",
        f"Error class: {obs.error_class or 'unknown'}",
        f"Attempt number: {obs.attempt_number} of {obs.max_attempts}",
        "Which repair action should I use next?",
    ]).strip()

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        raw = (completion.choices[0].message.content or "").strip().lower()
    except Exception as exc:
        # Diagnostics go to stderr: stdout is reserved for [START]/[STEP]/[END].
        print(f"[DEBUG] LLM call failed: {exc}", file=sys.stderr, flush=True)
        return "rewrite_full"
    return _parse_action(raw)


def _parse_action(raw: str) -> str:
    """Map a free-text LLM reply to the first REPAIR_ACTIONS name it mentions."""
    # Accept "fix column" / "fix-column" spellings as well as "fix_column".
    text = raw.lower().replace("-", "_").replace(" ", "_")
    # Prefer the action named earliest in the reply (the model is told to
    # answer with the name first), not whichever comes first in the list.
    hits = [(text.find(name), name) for name in REPAIR_ACTIONS if name in text]
    return min(hits)[1] if hits else "rewrite_full"
# ---- Single-episode runner --------------------------------------------------
async def run_episode(
    env: SQLAgentEnv,
    client: OpenAI,
    task_id: str,
) -> None:
    """Run one full episode for a task, emitting structured stdout logs.

    Prints one [START] line and one [STEP] line per attempted step. A failed
    reset, a failed env.step, or an unexpected error each count as one step.
    Exactly one [END] line is always printed, from the ``finally`` clause.
    The episode score is the mean of the per-step clamped rewards.

    Args:
        env: Environment. ``reset(task_id)`` is called synchronously and
            ``step(action)`` is awaited and unpacked as ``(obs, reward_info)``.
        client: OpenAI-compatible client passed through to ``pick_action``.
        task_id: Task identifier, also echoed in the [START] line.
    """
    log_start(task=task_id, model=MODEL_NAME)

    rewards: List[float] = []
    steps_taken = 0
    score = _MIN_SCORE
    success = False

    try:
        try:
            obs = env.reset(task_id)
        except Exception as exc:
            log_step(step=1, action="reset", reward=_MIN_SCORE, done=True, error=str(exc))
            rewards.append(_MIN_SCORE)
            steps_taken = 1
            return

        for step in range(1, MAX_STEPS + 1):
            try:
                action_name = pick_action(client, obs, step)
            except Exception:
                action_name = "generate"

            try:
                # Action construction is inside the try as well, so an action
                # the environment rejects is reported as this step's failure
                # instead of escaping to the catch-all below.
                obs, reward_info = await env.step(Action(repair_action=action_name))
            except Exception as exc:
                log_step(step=step, action=action_name, reward=_MIN_SCORE, done=True, error=str(exc))
                rewards.append(_MIN_SCORE)
                steps_taken = step
                break

            reward = _safe_score(getattr(reward_info, "value", None))
            done = bool(getattr(reward_info, "done", False))
            success = bool(getattr(reward_info, "success", False))
            rewards.append(reward)
            steps_taken = step
            log_step(
                step=step,
                action=action_name,
                reward=reward,
                done=done,
                error=getattr(obs, "error_message", None),
            )
            if done:
                break

        score = _safe_score(sum(rewards) / len(rewards)) if rewards else _MIN_SCORE
    except Exception as exc:
        # Report the unexpected failure as one more step, so step numbers stay
        # unique and [END] has one reward per logged [STEP] line.
        steps_taken += 1
        log_step(step=steps_taken, action="error", reward=_MIN_SCORE, done=True, error=str(exc))
        rewards.append(_MIN_SCORE)
        score = _MIN_SCORE
    finally:
        log_end(
            success=success,
            steps=max(int(steps_taken), 1),
            score=score,
            rewards=rewards,
        )
# ---- Main -------------------------------------------------------------------
async def main() -> None:
    """Run one episode per task in TASKS, sequentially, against a single env.

    If the client or environment cannot be created, a well-formed
    [START]/[STEP]/[END] triple is still printed for every task. A blank line
    separates the output of consecutive tasks.
    """
    try:
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        env = SQLAgentEnv()
    except Exception as exc:
        # Environment couldn't init - still emit a valid [START]/[STEP]/[END] per task.
        for task_id in TASKS:
            log_start(task=task_id, model=MODEL_NAME)
            log_step(step=1, action="init_error", reward=_MIN_SCORE, done=True, error=str(exc))
            log_end(success=False, steps=1, score=_MIN_SCORE, rewards=[_MIN_SCORE])
            print("", flush=True)
        return

    for task_id in TASKS:
        try:
            await run_episode(env, client, task_id)
        except Exception as exc:
            # run_episode has its own catch-all; this only guards against
            # anything leaking out of it (e.g. a failure inside log_end).
            log_end(success=False, steps=1, score=_MIN_SCORE, rewards=[_MIN_SCORE])
            # Diagnostics go to stderr: stdout is reserved for [START]/[STEP]/[END].
            print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", file=sys.stderr, flush=True)
        print("", flush=True)


if __name__ == "__main__":
    asyncio.run(main())