""" SQLAgentEnv — OpenEnv-compliant environment for SQL generation. Observation → Action → (Observation, Reward) loop. The step() function: 1. Selects a repair prompt based on action.repair_action 2. Calls the LLM (OpenAI-compatible) to generate/repair SQL 3. Executes SQL on the benchmark DB 4. Classifies any error 5. Computes reward via grader 6. Updates LinUCB bandit 7. Returns (new_observation, reward) Environment variables: API_BASE_URL — OpenAI-compatible base URL (default: https://api.openai.com/v1) MODEL_NAME — model to use (default: gpt-4o-mini) HF_TOKEN — bearer token / API key """ from __future__ import annotations import asyncio import os import re from typing import Optional, AsyncIterator from openai import AsyncOpenAI from pydantic import BaseModel from env.database import ensure_seeded, get_schema_info, execute_query from env.tasks import get_task, get_all_tasks, TASKS from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME from rl.error_classifier import classify_error, extract_offending_token from rl.grader import GraderInput, compute_reward, compute_episode_reward from rl.linucb import LinUCB from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message from rl.experience import record_episode from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES # ─── OpenEnv Models ────────────────────────────────────────────── class Observation(BaseModel): question: str schema_info: str current_sql: Optional[str] = None error_message: Optional[str] = None error_class: Optional[str] = None attempt_number: int = 0 max_attempts: int = 5 task_id: str task_difficulty: str class Action(BaseModel): repair_action: str # one of 8 repair action names or "generate" custom_sql: Optional[str] = None # optional direct SQL override class RewardInfo(BaseModel): value: float success: bool done: bool info: dict # ─── LLM Client ────────────────────────────────────────────────── API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1") _MODEL = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") HF_TOKEN = os.environ.get("HF_TOKEN") # no default — must be set explicitly # ─── Score clamping (strictly in (0, 1)) ───────────────────────── _SCORE_MIN = 0.05 _SCORE_MAX = 0.95 def _clamp_score(x) -> float: """Coerce any value into strictly (0, 1). None/NaN/invalid → _SCORE_MIN.""" try: if x is None: return _SCORE_MIN if isinstance(x, bool): return _SCORE_MAX if x else _SCORE_MIN v = float(x) if v != v or v == float("inf") or v == float("-inf"): return _SCORE_MIN if v != float("inf") else _SCORE_MAX except (TypeError, ValueError): return _SCORE_MIN return max(_SCORE_MIN, min(_SCORE_MAX, v)) def _make_client() -> AsyncOpenAI: return AsyncOpenAI( api_key=HF_TOKEN, base_url=API_BASE_URL, ) BASE_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query. Rules: - Output ONLY the SQL query, nothing else - No markdown, no code fences, no explanation - Use SQLite syntax - Do not include semicolons at the end""" _POSTGRES_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a PostgreSQL database schema, write a correct SQL query. Rules: - Output ONLY the SQL query, nothing else - No markdown, no code fences, no explanation - Use PostgreSQL syntax - Do not include semicolons at the end""" def get_system_prompt() -> str: """Return the system prompt appropriate for the currently active database dialect.""" from env.database import get_active_db_type if get_active_db_type() == "postgres": return _POSTGRES_SYSTEM_PROMPT return BASE_SYSTEM_PROMPT def _clean_sql(raw: str) -> str: """Strip markdown code fences and extra whitespace.""" raw = raw.strip() raw = re.sub(r"^```(?:sql)?\s*", "", raw, flags=re.IGNORECASE) raw = re.sub(r"\s*```$", "", raw) return raw.strip().rstrip(";") async def _call_llm( system_prompt: str, user_message: str, stream: bool = False, ) -> AsyncIterator[str] | str: """Call the LLM and return the generated text.""" client = _make_client() if stream: async def _gen(): resp = await client.chat.completions.create( model=_MODEL, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}, ], stream=True, temperature=0.1, ) async for chunk in resp: if not chunk.choices: continue delta = chunk.choices[0].delta.content if delta: yield delta return _gen() else: resp = await client.chat.completions.create( model=_MODEL, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}, ], temperature=0.1, ) return resp.choices[0].message.content or "" # ─── Episode State ──────────────────────────────────────────────── class _Episode: def __init__(self, task_id: str, question_id: str, question: str) -> None: self.task_id = task_id self.question_id = question_id self.question = question self.attempt_number = 0 self.current_sql: Optional[str] = None self.error_message: Optional[str] = None self.error_class: Optional[str] = None self.steps: list[EpisodeStep] = [] self.step_rewards: list[float] = [] self.previous_error_class = None self.consecutive_same_error = 0 self.last_action: Optional[RepairAction] = None self.current_rl_state: Optional[RLState] = None self.current_features: Optional[list[float]] = None self.done = False self.success = False # ─── Main Environment Class ─────────────────────────────────────── class SQLAgentEnv: """ OpenEnv-compliant environment for SQL generation and repair. One active episode at a time. """ MAX_ATTEMPTS = 5 def __init__(self) -> None: ensure_seeded() self._bandit = LinUCB() self._episode: Optional[_Episode] = None self._schema_info = get_schema_info() def reset(self, task_id: str = "simple_queries") -> Observation: """Start a new episode, picking the first question of the task.""" if self._episode and self._episode.steps and not self._episode.done: self._finalize_episode(success=False) task = get_task(task_id) question_obj = task.questions[0] self._episode = _Episode( task_id=task_id, question_id=question_obj.id, question=question_obj.question, ) return self._build_observation() def reset_with_question( self, task_id: str, question_id: str ) -> Observation: """Start a new episode for a specific question.""" if self._episode and self._episode.steps and not self._episode.done: self._finalize_episode(success=False) task = get_task(task_id) question_obj = next( (q for q in task.questions if q.id == question_id), task.questions[0] ) self._episode = _Episode( task_id=task_id, question_id=question_obj.id, question=question_obj.question, ) return self._build_observation() async def step(self, action: Action) -> tuple[Observation, RewardInfo]: """ Execute one step: 1. Generate/repair SQL via LLM 2. Execute SQL 3. Grade and reward 4. Update bandit """ if self._episode is None: raise RuntimeError("Call reset() before step()") if self._episode.done: raise RuntimeError("Episode is done. Call reset() to start a new one.") ep = self._episode ep.attempt_number += 1 # ── 1. Build prompt ────────────────────────────────────── if action.custom_sql: generated_sql = action.custom_sql else: generated_sql = await self._generate_sql(action, ep) generated_sql = _clean_sql(generated_sql) # ── 2. Execute SQL ─────────────────────────────────────── rows, error = execute_query(generated_sql) success = error is None and len(rows) > 0 # ── 3. Grade ───────────────────────────────────────────── task = get_task(ep.task_id) question_obj = next(q for q in task.questions if q.id == ep.question_id) from env.tasks import grade_response task_score = grade_response( ep.task_id, ep.question_id, generated_sql, rows, error, ep.attempt_number ) success = task_score >= 0.8 # ── 4. RL state + reward ───────────────────────────────── current_error_class = None error_class_name = None if error: ec = classify_error(error) current_error_class = ec error_class_name = ERROR_CLASS_NAMES[ec] error_changed = ( ep.previous_error_class is not None and ep.previous_error_class != current_error_class ) if ep.previous_error_class == current_error_class: ep.consecutive_same_error += 1 else: ep.consecutive_same_error = 1 rl_state = RLState( error_class=current_error_class, attempt_number=ep.attempt_number, previous_action=ep.last_action, error_changed=error_changed, consecutive_same_error=ep.consecutive_same_error, ) ep.current_rl_state = rl_state ep.current_features = featurize(rl_state) grader_in = GraderInput( success=success, attempt_number=ep.attempt_number, current_error_class=current_error_class, previous_error_class=ep.previous_error_class, ) grader_out = compute_reward(grader_in) if ep.current_rl_state and ep.current_features: # Determine action index if action.repair_action == "generate": repair_action_enum = RepairAction.REWRITE_FULL else: repair_action_enum = REPAIR_ACTION_BY_NAME.get( action.repair_action, RepairAction.REWRITE_FULL ) step_obj = EpisodeStep( state=ep.current_rl_state, featurized=ep.current_features, action=repair_action_enum, reward=grader_out.reward, error_message=error or "", sql=generated_sql, success=success, ) ep.steps.append(step_obj) # Store clamped reward so /state never returns raw RL values ep.step_rewards.append(_clamp_score(task_score)) ep.current_sql = generated_sql ep.error_message = error ep.error_class = error_class_name ep.previous_error_class = current_error_class # ── 5. Done check ──────────────────────────────────────── done = success or ep.attempt_number >= self.MAX_ATTEMPTS if done: self._finalize_episode(success=success) ep.done = True ep.success = success obs = self._build_observation() safe_task_score = _clamp_score(task_score) reward_info = RewardInfo( value=safe_task_score, # strictly in (0, 1) per OpenEnv spec success=success, done=done, info={ "task_score": safe_task_score, "attempt": ep.attempt_number, "rows": rows[:5] if rows else [], "row_count": len(rows), "sql": generated_sql, }, ) return obs, reward_info async def step_streaming( self, action: Action ) -> AsyncIterator[dict]: """ Step with SSE-compatible event streaming. Yields dicts representing stream events. """ if self._episode is None: raise RuntimeError("Call reset() before step_streaming()") ep = self._episode ep.attempt_number += 1 yield {"type": "attempt_start", "attempt": ep.attempt_number} # Generate SQL if action.custom_sql: generated_sql = action.custom_sql yield {"type": "sql_complete", "sql": generated_sql} else: chunks = [] async for chunk in await self._generate_sql_streaming(action, ep): chunks.append(chunk) yield {"type": "sql_chunk", "chunk": chunk} generated_sql = _clean_sql("".join(chunks)) yield {"type": "sql_complete", "sql": generated_sql} yield {"type": "executing"} rows, error = execute_query(generated_sql) from env.tasks import grade_response task_score = grade_response( ep.task_id, ep.question_id, generated_sql, rows, error, ep.attempt_number ) success = task_score >= 0.8 # RL processing current_error_class = None error_class_name = None repair_action_enum = RepairAction.REWRITE_FULL if action.repair_action != "generate": repair_action_enum = REPAIR_ACTION_BY_NAME.get( action.repair_action, RepairAction.REWRITE_FULL ) if error: ec = classify_error(error) current_error_class = ec error_class_name = ERROR_CLASS_NAMES[ec] error_changed = ( ep.previous_error_class is not None and ep.previous_error_class != current_error_class ) if ep.previous_error_class == current_error_class: ep.consecutive_same_error += 1 else: ep.consecutive_same_error = 1 rl_state = RLState( error_class=current_error_class, attempt_number=ep.attempt_number, previous_action=ep.last_action, error_changed=error_changed, consecutive_same_error=ep.consecutive_same_error, ) ep.current_rl_state = rl_state ep.current_features = featurize(rl_state) _, scores = self._bandit.select_action(ep.current_features) ucb_scores = { REPAIR_ACTION_NAMES[RepairAction(i)]: round(scores[i], 4) for i in range(len(scores)) } yield { "type": "rl_action", "action": REPAIR_ACTION_NAMES[repair_action_enum], "ucb_scores": ucb_scores, } yield {"type": "error", "error": error, "error_class": error_class_name} grader_in = GraderInput( success=success, attempt_number=ep.attempt_number, current_error_class=current_error_class, previous_error_class=ep.previous_error_class, ) grader_out = compute_reward(grader_in) if ep.current_rl_state and ep.current_features: step_obj = EpisodeStep( state=ep.current_rl_state, featurized=ep.current_features, action=repair_action_enum, reward=grader_out.reward, error_message=error or "", sql=generated_sql, success=success, ) ep.steps.append(step_obj) self._bandit.update(ep.current_features, repair_action_enum, grader_out.reward) ep.step_rewards.append(_clamp_score(task_score)) ep.current_sql = generated_sql ep.error_message = error ep.error_class = error_class_name ep.previous_error_class = current_error_class yield { "type": "rl_reward", "reward": grader_out.reward, "breakdown": { "base": grader_out.breakdown.base, "attempt_penalty": grader_out.breakdown.attempt_penalty, "severity_bonus": grader_out.breakdown.severity_bonus, "change_bonus": grader_out.breakdown.change_bonus, }, } done = success or ep.attempt_number >= self.MAX_ATTEMPTS if success: yield { "type": "success", "rows": rows, "row_count": len(rows), "sql": generated_sql, } if done: total_reward = compute_episode_reward(ep.step_rewards, success) self._finalize_episode(success=success) ep.done = True ep.success = success yield { "type": "rl_episode_end", "total_reward": total_reward, "success": success, } def state(self) -> dict: if self._episode is None: return {"active": False} ep = self._episode safe_rewards = [_clamp_score(r) for r in ep.step_rewards] total = sum(safe_rewards) / max(len(safe_rewards), 1) if safe_rewards else _SCORE_MIN return { "active": True, "task_id": ep.task_id, "question_id": ep.question_id, "question": ep.question, "attempt_number": ep.attempt_number, "max_attempts": self.MAX_ATTEMPTS, "current_sql": ep.current_sql, "error_message": ep.error_message, "error_class": ep.error_class, "done": ep.done, "success": ep.success, "step_rewards": safe_rewards, "total_reward": _clamp_score(total), } # ─── Private Helpers ────────────────────────────────────────── def _build_observation(self) -> Observation: if self._episode is None: raise RuntimeError("No active episode") ep = self._episode task = get_task(ep.task_id) return Observation( question=ep.question, schema_info=self._schema_info, current_sql=ep.current_sql, error_message=ep.error_message, error_class=ep.error_class, attempt_number=ep.attempt_number, max_attempts=self.MAX_ATTEMPTS, task_id=ep.task_id, task_difficulty=task.difficulty, ) async def _generate_sql(self, action: Action, ep: _Episode) -> str: if action.repair_action == "generate" or ep.current_sql is None: system = BASE_SYSTEM_PROMPT user = ( f"Schema:\n{self._schema_info}\n\n" f"Question: {ep.question}\n\n" "Write a SQL query to answer this question." ) else: repair_action_enum = REPAIR_ACTION_BY_NAME.get( action.repair_action, RepairAction.REWRITE_FULL ) suffix = get_repair_system_suffix(repair_action_enum) offending_token = extract_offending_token(ep.error_message or "") ctx = RepairContext( schema=self._schema_info, question=ep.question, failing_sql=ep.current_sql or "", error_message=ep.error_message or "", offending_token=offending_token, ) system = BASE_SYSTEM_PROMPT + suffix user = build_repair_user_message(repair_action_enum, ctx) result = await _call_llm(system, user, stream=False) return result # type: ignore[return-value] async def _generate_sql_streaming( self, action: Action, ep: _Episode ) -> AsyncIterator[str]: if action.repair_action == "generate" or ep.current_sql is None: system = BASE_SYSTEM_PROMPT user = ( f"Schema:\n{self._schema_info}\n\n" f"Question: {ep.question}\n\n" "Write a SQL query to answer this question." ) else: repair_action_enum = REPAIR_ACTION_BY_NAME.get( action.repair_action, RepairAction.REWRITE_FULL ) suffix = get_repair_system_suffix(repair_action_enum) offending_token = extract_offending_token(ep.error_message or "") ctx = RepairContext( schema=self._schema_info, question=ep.question, failing_sql=ep.current_sql or "", error_message=ep.error_message or "", offending_token=offending_token, ) system = BASE_SYSTEM_PROMPT + suffix user = build_repair_user_message(repair_action_enum, ctx) return await _call_llm(system, user, stream=True) # type: ignore[return-value] def _finalize_episode(self, success: bool) -> None: ep = self._episode if ep is None or not ep.steps: return try: episode_obj, relabeled = record_episode(ep.question, ep.steps, success) for exp in relabeled: self._bandit.update(exp.state, exp.action, exp.reward) self._bandit.decay_alpha() except Exception: pass # ─── Singleton instance ─────────────────────────────────────────── _env_instance: Optional[SQLAgentEnv] = None def get_env() -> SQLAgentEnv: global _env_instance if _env_instance is None: _env_instance = SQLAgentEnv() return _env_instance