Spaces:
Running
Running
| """ | |
SQLAgentEnv — OpenEnv-compliant environment for SQL generation.

Observation → Action → (Observation, Reward) loop.

The step() function:
1. Selects a repair prompt based on action.repair_action
2. Calls the LLM (OpenAI-compatible) to generate/repair SQL
3. Executes SQL on the benchmark DB
4. Classifies any error
5. Computes reward via grader
6. Updates LinUCB bandit (when the episode finishes)
7. Returns (new_observation, reward)

Environment variables:
  API_BASE_URL — OpenAI-compatible base URL (default: https://router.huggingface.co/v1)
  MODEL_NAME — model to use (default: Qwen/Qwen2.5-72B-Instruct)
  HF_TOKEN — bearer token / API key (no default)
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import os | |
| import re | |
| from typing import Optional, AsyncIterator | |
| from openai import AsyncOpenAI | |
| from pydantic import BaseModel | |
| from env.database import ensure_seeded, get_schema_info, execute_query | |
| from env.tasks import get_task, get_all_tasks, TASKS | |
| from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME | |
| from rl.error_classifier import classify_error, extract_offending_token | |
| from rl.grader import GraderInput, compute_reward, compute_episode_reward | |
| from rl.linucb import LinUCB | |
| from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message | |
| from rl.experience import record_episode | |
| from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES | |
# ─── OpenEnv Models ──────────────────────────────────────────────
class Observation(BaseModel):
    """Snapshot of the current episode handed to the agent.

    Built by ``SQLAgentEnv._build_observation`` from the active episode.
    """

    question: str  # natural-language question of the current episode
    schema_info: str  # schema text obtained from get_schema_info()
    current_sql: Optional[str] = None  # SQL of the most recent attempt, if any
    error_message: Optional[str] = None  # execution error of the most recent attempt, if any
    error_class: Optional[str] = None  # name of the classified error, if any
    attempt_number: int = 0  # number of attempts made so far
    max_attempts: int = 5  # attempt budget for one episode
    task_id: str
    task_difficulty: str
class Action(BaseModel):
    """Agent request for one step: which repair strategy to use, or SQL to run directly."""

    repair_action: str  # one of 8 repair action names or "generate"
    custom_sql: Optional[str] = None  # optional direct SQL override (skips the LLM call when set)
class RewardInfo(BaseModel):
    """Reward and termination information returned alongside an Observation."""

    value: float  # reward value reported for the step
    success: bool  # whether the attempt was judged successful
    done: bool  # whether the episode has ended
    info: dict  # free-form diagnostic details about the attempt
# ─── LLM Client ──────────────────────────────────────────────────
# Connection settings for the OpenAI-compatible endpoint, read once at import time.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
_MODEL = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")  # no default — must be set explicitly
# ─── Score clamping (strictly in (0, 1)) ─────────────────────────
| _SCORE_MIN = 0.05 | |
| _SCORE_MAX = 0.95 | |
| def _clamp_score(x) -> float: | |
| """Coerce any value into strictly (0, 1). None/NaN/invalid β _SCORE_MIN.""" | |
| try: | |
| if x is None: | |
| return _SCORE_MIN | |
| if isinstance(x, bool): | |
| return _SCORE_MAX if x else _SCORE_MIN | |
| v = float(x) | |
| if v != v or v == float("inf") or v == float("-inf"): | |
| return _SCORE_MIN if v != float("inf") else _SCORE_MAX | |
| except (TypeError, ValueError): | |
| return _SCORE_MIN | |
| return max(_SCORE_MIN, min(_SCORE_MAX, v)) | |
def _make_client() -> AsyncOpenAI:
    """Create an AsyncOpenAI client configured from HF_TOKEN and API_BASE_URL."""
    client_options = {"api_key": HF_TOKEN, "base_url": API_BASE_URL}
    return AsyncOpenAI(**client_options)
# System prompt for SQLite targets; also what get_system_prompt() returns
# whenever the active database type is not "postgres".
BASE_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
Rules:
- Output ONLY the SQL query, nothing else
- No markdown, no code fences, no explanation
- Use SQLite syntax
- Do not include semicolons at the end"""
# System prompt variant returned by get_system_prompt() for PostgreSQL targets.
_POSTGRES_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a PostgreSQL database schema, write a correct SQL query.
Rules:
- Output ONLY the SQL query, nothing else
- No markdown, no code fences, no explanation
- Use PostgreSQL syntax
- Do not include semicolons at the end"""
def get_system_prompt() -> str:
    """Pick the system prompt matching the database dialect that is currently active.

    Returns the PostgreSQL prompt when ``get_active_db_type()`` reports
    ``"postgres"``; every other value yields the SQLite prompt.
    """
    # Imported at call time, as in the rest of this module's lazy env imports.
    from env.database import get_active_db_type

    uses_postgres = get_active_db_type() == "postgres"
    return _POSTGRES_SYSTEM_PROMPT if uses_postgres else BASE_SYSTEM_PROMPT
| def _clean_sql(raw: str) -> str: | |
| """Strip markdown code fences and extra whitespace.""" | |
| raw = raw.strip() | |
| raw = re.sub(r"^```(?:sql)?\s*", "", raw, flags=re.IGNORECASE) | |
| raw = re.sub(r"\s*```$", "", raw) | |
| return raw.strip().rstrip(";") | |
async def _call_llm(
    system_prompt: str,
    user_message: str,
    stream: bool = False,
) -> AsyncIterator[str] | str:
    """Send one chat-completion request and return the generated text.

    Args:
        system_prompt: Content of the system message.
        user_message: Content of the user message.
        stream: When True, return an async iterator that yields text deltas
            as they arrive; the request itself is only issued once the
            iterator is consumed. When False, await the full completion.

    Returns:
        An async iterator of non-empty text chunks when ``stream`` is True,
        otherwise the complete message text ("" if the response carries no
        choices or no content).
    """
    client = _make_client()
    # Shared request payload for both the streaming and non-streaming paths.
    request = {
        "model": _MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        "temperature": 0.1,
    }

    if stream:

        async def _gen():
            resp = await client.chat.completions.create(stream=True, **request)
            async for chunk in resp:
                # Some chunks carry no choices; skip them.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta

        return _gen()

    resp = await client.chat.completions.create(**request)
    # Mirror the streaming path: tolerate a response without choices instead
    # of failing with IndexError.
    if not resp.choices:
        return ""
    return resp.choices[0].message.content or ""
# ─── Episode State ───────────────────────────────────────────────
class _Episode:
    """Mutable bookkeeping for a single question-answering episode."""

    def __init__(self, task_id: str, question_id: str, question: str) -> None:
        """Create a fresh episode with no attempts made yet."""
        # Identity of the task/question being solved.
        self.task_id = task_id
        self.question_id = question_id
        self.question = question
        # Outcome of the most recent attempt.
        self.attempt_number = 0
        self.current_sql: Optional[str] = None
        self.error_message: Optional[str] = None
        self.error_class: Optional[str] = None
        # Recorded steps and the per-step rewards of this episode.
        self.steps: list[EpisodeStep] = []
        self.step_rewards: list[float] = []
        # Error tracking across attempts (previous class and repeat counter).
        self.previous_error_class = None
        self.consecutive_same_error = 0
        # Most recent repair action; passed to RLState as previous_action.
        self.last_action: Optional[RepairAction] = None
        # Latest RL state and its feature vector.
        self.current_rl_state: Optional[RLState] = None
        self.current_features: Optional[list[float]] = None
        # Termination flags.
        self.done = False
        self.success = False
# ─── Main Environment Class ──────────────────────────────────────
class SQLAgentEnv:
    """
    OpenEnv-compliant environment for SQL generation and repair.

    One active episode at a time. An episode targets a single question and
    ends on the first successful attempt or once MAX_ATTEMPTS attempts have
    been made. A LinUCB bandit instance is kept for the lifetime of the
    environment and is updated from the recorded episode steps.
    """

    MAX_ATTEMPTS = 5

    def __init__(self) -> None:
        """Run ensure_seeded(), create the bandit and cache the schema text."""
        ensure_seeded()
        self._bandit = LinUCB()
        self._episode: Optional[_Episode] = None
        self._schema_info = get_schema_info()

    # ─── Public API ──────────────────────────────────────────────

    def reset(self, task_id: str = "simple_queries") -> Observation:
        """Start a new episode, picking the first question of the task.

        An unfinished episode that already recorded steps is finalized as a
        failure first.
        """
        return self._start_episode(task_id, None)

    def reset_with_question(
        self, task_id: str, question_id: str
    ) -> Observation:
        """Start a new episode for a specific question.

        Falls back to the first question of the task when ``question_id``
        does not match any question. An unfinished episode that already
        recorded steps is finalized as a failure first.
        """
        return self._start_episode(task_id, question_id)

    async def step(self, action: Action) -> tuple[Observation, RewardInfo]:
        """
        Execute one step:
        1. Generate/repair SQL via LLM (or take ``action.custom_sql``)
        2. Execute SQL
        3. Grade and compute the RL reward
        4. Record the step; finalize the episode when it is done

        The bandit is not updated per step here; it is updated from the
        recorded steps when the episode is finalized.
        NOTE(review): step_streaming() additionally updates the bandit after
        every recorded step — confirm which behaviour is intended.

        Returns:
            The new observation and a RewardInfo whose ``value`` is the task
            score clamped into (0, 1).

        Raises:
            RuntimeError: if no episode is active or the episode is done.
        """
        ep = self._require_active_episode("step")
        ep.attempt_number += 1

        # ── 1. Obtain SQL ────────────────────────────────────────
        if action.custom_sql:
            generated_sql = action.custom_sql
        else:
            generated_sql = await self._generate_sql(action, ep)
        generated_sql = _clean_sql(generated_sql)

        # ── 2. Execute SQL ───────────────────────────────────────
        rows, error = execute_query(generated_sql)

        # ── 3. Grade ─────────────────────────────────────────────
        from env.tasks import grade_response

        task_score = grade_response(
            ep.task_id, ep.question_id, generated_sql, rows, error, ep.attempt_number
        )
        success = task_score >= 0.8

        # ── 4. RL state + reward ─────────────────────────────────
        repair_action_enum = self._resolve_repair_action(action)
        current_error_class = None
        error_class_name = None
        if error:
            current_error_class, error_class_name = self._register_error(ep, error)

        grader_out = compute_reward(
            GraderInput(
                success=success,
                attempt_number=ep.attempt_number,
                current_error_class=current_error_class,
                previous_error_class=ep.previous_error_class,
            )
        )
        self._record_step(
            ep, repair_action_enum, grader_out.reward, error, generated_sql, success
        )
        self._commit_attempt(
            ep,
            sql=generated_sql,
            error=error,
            error_class=current_error_class,
            error_class_name=error_class_name,
            task_score=task_score,
            repair_action=repair_action_enum,
        )

        # ── 5. Done check ────────────────────────────────────────
        done = success or ep.attempt_number >= self.MAX_ATTEMPTS
        if done:
            self._finalize_episode(success=success)
            ep.done = True
            ep.success = success

        obs = self._build_observation()
        safe_task_score = _clamp_score(task_score)
        reward_info = RewardInfo(
            value=safe_task_score,  # strictly in (0, 1) per OpenEnv spec
            success=success,
            done=done,
            info={
                "task_score": safe_task_score,
                "attempt": ep.attempt_number,
                "rows": rows[:5] if rows else [],
                "row_count": len(rows) if rows else 0,
                "sql": generated_sql,
            },
        )
        return obs, reward_info

    async def step_streaming(
        self, action: Action
    ) -> AsyncIterator[dict]:
        """
        Step with SSE-compatible event streaming.

        Yields dicts representing stream events, in order: ``attempt_start``,
        ``sql_chunk`` (LLM path only), ``sql_complete``, ``executing``,
        ``rl_action`` and ``error`` (only when execution failed),
        ``rl_reward``, ``success`` (on success) and ``rl_episode_end`` (when
        the episode ends).

        Raises:
            RuntimeError: on first iteration, if no episode is active or the
                episode is already done.
        """
        ep = self._require_active_episode("step_streaming")
        ep.attempt_number += 1
        yield {"type": "attempt_start", "attempt": ep.attempt_number}

        # Generate SQL
        if action.custom_sql:
            generated_sql = action.custom_sql
        else:
            chunks = []
            async for chunk in await self._generate_sql_streaming(action, ep):
                chunks.append(chunk)
                yield {"type": "sql_chunk", "chunk": chunk}
            generated_sql = _clean_sql("".join(chunks))
        yield {"type": "sql_complete", "sql": generated_sql}

        yield {"type": "executing"}
        rows, error = execute_query(generated_sql)

        from env.tasks import grade_response

        task_score = grade_response(
            ep.task_id, ep.question_id, generated_sql, rows, error, ep.attempt_number
        )
        success = task_score >= 0.8

        # RL processing
        repair_action_enum = self._resolve_repair_action(action)
        current_error_class = None
        error_class_name = None
        if error:
            current_error_class, error_class_name = self._register_error(ep, error)
            # Scores are only reported to the client; the action actually
            # applied is the one requested by the caller.
            _, scores = self._bandit.select_action(ep.current_features)
            ucb_scores = {
                REPAIR_ACTION_NAMES[RepairAction(i)]: round(score, 4)
                for i, score in enumerate(scores)
            }
            yield {
                "type": "rl_action",
                "action": REPAIR_ACTION_NAMES[repair_action_enum],
                "ucb_scores": ucb_scores,
            }
            yield {"type": "error", "error": error, "error_class": error_class_name}

        grader_out = compute_reward(
            GraderInput(
                success=success,
                attempt_number=ep.attempt_number,
                current_error_class=current_error_class,
                previous_error_class=ep.previous_error_class,
            )
        )
        if self._record_step(
            ep, repair_action_enum, grader_out.reward, error, generated_sql, success
        ):
            self._bandit.update(
                ep.current_features, repair_action_enum, grader_out.reward
            )
        self._commit_attempt(
            ep,
            sql=generated_sql,
            error=error,
            error_class=current_error_class,
            error_class_name=error_class_name,
            task_score=task_score,
            repair_action=repair_action_enum,
        )

        yield {
            "type": "rl_reward",
            "reward": grader_out.reward,
            "breakdown": {
                "base": grader_out.breakdown.base,
                "attempt_penalty": grader_out.breakdown.attempt_penalty,
                "severity_bonus": grader_out.breakdown.severity_bonus,
                "change_bonus": grader_out.breakdown.change_bonus,
            },
        }

        done = success or ep.attempt_number >= self.MAX_ATTEMPTS
        if success:
            yield {
                "type": "success",
                "rows": rows,
                "row_count": len(rows),
                "sql": generated_sql,
            }
        if done:
            total_reward = compute_episode_reward(ep.step_rewards, success)
            self._finalize_episode(success=success)
            ep.done = True
            ep.success = success
            yield {
                "type": "rl_episode_end",
                "total_reward": total_reward,
                "success": success,
            }

    def state(self) -> dict:
        """Return a JSON-friendly snapshot of the active episode.

        Returns ``{"active": False}`` when no episode exists. Rewards are
        clamped into (0, 1); ``total_reward`` is their mean, or the minimum
        score when no step has been taken yet.
        """
        if self._episode is None:
            return {"active": False}
        ep = self._episode
        safe_rewards = [_clamp_score(r) for r in ep.step_rewards]
        total = sum(safe_rewards) / len(safe_rewards) if safe_rewards else _SCORE_MIN
        return {
            "active": True,
            "task_id": ep.task_id,
            "question_id": ep.question_id,
            "question": ep.question,
            "attempt_number": ep.attempt_number,
            "max_attempts": self.MAX_ATTEMPTS,
            "current_sql": ep.current_sql,
            "error_message": ep.error_message,
            "error_class": ep.error_class,
            "done": ep.done,
            "success": ep.success,
            "step_rewards": safe_rewards,
            "total_reward": _clamp_score(total),
        }

    # ─── Private Helpers ─────────────────────────────────────────

    def _start_episode(self, task_id: str, question_id: Optional[str]) -> Observation:
        """Finalize any abandoned episode, then open a new one.

        ``question_id=None`` selects the first question of the task; an
        unknown id falls back to the first question as well.
        """
        previous = self._episode
        if previous and previous.steps and not previous.done:
            self._finalize_episode(success=False)

        task = get_task(task_id)
        question_obj = task.questions[0]
        if question_id is not None:
            question_obj = next(
                (q for q in task.questions if q.id == question_id), question_obj
            )
        self._episode = _Episode(
            task_id=task_id,
            question_id=question_obj.id,
            question=question_obj.question,
        )
        return self._build_observation()

    def _require_active_episode(self, caller: str) -> _Episode:
        """Return the running episode or raise RuntimeError if there is none."""
        if self._episode is None:
            raise RuntimeError(f"Call reset() before {caller}()")
        if self._episode.done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")
        return self._episode

    @staticmethod
    def _resolve_repair_action(action: Action) -> RepairAction:
        """Map the action name to a RepairAction; "generate"/unknown -> REWRITE_FULL."""
        if action.repair_action == "generate":
            return RepairAction.REWRITE_FULL
        return REPAIR_ACTION_BY_NAME.get(
            action.repair_action, RepairAction.REWRITE_FULL
        )

    def _register_error(self, ep: _Episode, error: str) -> tuple:
        """Classify ``error`` and refresh the episode's RL state and features.

        Returns:
            ``(error_class, error_class_name)`` for the classified error.
        """
        error_class = classify_error(error)
        error_class_name = ERROR_CLASS_NAMES[error_class]
        error_changed = (
            ep.previous_error_class is not None
            and ep.previous_error_class != error_class
        )
        if ep.previous_error_class == error_class:
            ep.consecutive_same_error += 1
        else:
            ep.consecutive_same_error = 1
        rl_state = RLState(
            error_class=error_class,
            attempt_number=ep.attempt_number,
            previous_action=ep.last_action,
            error_changed=error_changed,
            consecutive_same_error=ep.consecutive_same_error,
        )
        ep.current_rl_state = rl_state
        ep.current_features = featurize(rl_state)
        return error_class, error_class_name

    @staticmethod
    def _record_step(
        ep: _Episode,
        repair_action: RepairAction,
        reward: float,
        error: Optional[str],
        sql: str,
        success: bool,
    ) -> bool:
        """Append an EpisodeStep if an RL state exists; return whether one was added.

        No RL state exists until the first failing attempt of an episode, so
        a first-attempt success records nothing.
        """
        if not (ep.current_rl_state and ep.current_features):
            return False
        ep.steps.append(
            EpisodeStep(
                state=ep.current_rl_state,
                featurized=ep.current_features,
                action=repair_action,
                reward=reward,
                error_message=error or "",
                sql=sql,
                success=success,
            )
        )
        return True

    @staticmethod
    def _commit_attempt(
        ep: _Episode,
        *,
        sql: str,
        error: Optional[str],
        error_class,
        error_class_name: Optional[str],
        task_score,
        repair_action: RepairAction,
    ) -> None:
        """Store the outcome of the attempt on the episode."""
        # Store clamped reward so /state never returns raw RL values
        ep.step_rewards.append(_clamp_score(task_score))
        ep.current_sql = sql
        ep.error_message = error
        ep.error_class = error_class_name
        ep.previous_error_class = error_class
        # Remember the action so the next RLState sees it as previous_action.
        ep.last_action = repair_action

    def _build_observation(self) -> Observation:
        """Build the Observation for the active episode.

        Raises:
            RuntimeError: if there is no episode.
        """
        if self._episode is None:
            raise RuntimeError("No active episode")
        ep = self._episode
        task = get_task(ep.task_id)
        return Observation(
            question=ep.question,
            schema_info=self._schema_info,
            current_sql=ep.current_sql,
            error_message=ep.error_message,
            error_class=ep.error_class,
            attempt_number=ep.attempt_number,
            max_attempts=self.MAX_ATTEMPTS,
            task_id=ep.task_id,
            task_difficulty=task.difficulty,
        )

    def _build_prompts(self, action: Action, ep: _Episode) -> tuple[str, str]:
        """Return ``(system_prompt, user_message)`` for this attempt.

        A fresh-generation prompt is used for the "generate" action or when
        no SQL has been produced yet; otherwise a repair prompt for the
        requested strategy is built from the last failing SQL and error.
        The base system prompt follows the active database dialect.
        """
        base_prompt = get_system_prompt()
        if action.repair_action == "generate" or ep.current_sql is None:
            user = (
                f"Schema:\n{self._schema_info}\n\n"
                f"Question: {ep.question}\n\n"
                "Write a SQL query to answer this question."
            )
            return base_prompt, user

        repair_action_enum = REPAIR_ACTION_BY_NAME.get(
            action.repair_action, RepairAction.REWRITE_FULL
        )
        suffix = get_repair_system_suffix(repair_action_enum)
        offending_token = extract_offending_token(ep.error_message or "")
        ctx = RepairContext(
            schema=self._schema_info,
            question=ep.question,
            failing_sql=ep.current_sql or "",
            error_message=ep.error_message or "",
            offending_token=offending_token,
        )
        return base_prompt + suffix, build_repair_user_message(repair_action_enum, ctx)

    async def _generate_sql(self, action: Action, ep: _Episode) -> str:
        """Ask the LLM for SQL and return the raw (uncleaned) completion text."""
        system, user = self._build_prompts(action, ep)
        result = await _call_llm(system, user, stream=False)
        return result  # type: ignore[return-value]

    async def _generate_sql_streaming(
        self, action: Action, ep: _Episode
    ) -> AsyncIterator[str]:
        """Return an async iterator over the LLM's text chunks.

        This is a coroutine: callers must ``await`` it to obtain the iterator.
        """
        system, user = self._build_prompts(action, ep)
        return await _call_llm(system, user, stream=True)  # type: ignore[return-value]

    def _finalize_episode(self, success: bool) -> None:
        """Record the episode and feed the returned experiences to the bandit.

        Does nothing when there is no episode or no step was recorded.
        Recording is best-effort: failures are logged and never propagate to
        the caller.
        """
        ep = self._episode
        if ep is None or not ep.steps:
            return
        try:
            _episode_obj, relabeled = record_episode(ep.question, ep.steps, success)
            for exp in relabeled:
                self._bandit.update(exp.state, exp.action, exp.reward)
            self._bandit.decay_alpha()
        except Exception:
            # Keep the environment usable, but do not hide the failure.
            import logging

            logging.getLogger(__name__).exception(
                "Failed to record episode / update bandit"
            )
# ─── Singleton instance ──────────────────────────────────────────
# Process-wide environment, created on first use by get_env().
_env_instance: Optional[SQLAgentEnv] = None


def get_env() -> SQLAgentEnv:
    """Return the shared SQLAgentEnv, constructing it on the first call."""
    global _env_instance
    if _env_instance is not None:
        return _env_instance
    _env_instance = SQLAgentEnv()
    return _env_instance