# sql-agent-openenv / backend/env/sql_env.py
# Last commit e99d0aa by ar9avg:
#   Clamp all remaining score leak paths: /state, step_rewards, demo SSE
"""
SQLAgentEnv — OpenEnv-compliant environment for SQL generation.
Observation → Action → (Observation, Reward) loop.
The step() function:
1. Selects a repair prompt based on action.repair_action
2. Calls the LLM (OpenAI-compatible) to generate/repair SQL
3. Executes SQL on the benchmark DB
4. Classifies any error
5. Computes reward via grader
6. Updates LinUCB bandit
7. Returns (new_observation, reward)
Environment variables:
API_BASE_URL — OpenAI-compatible base URL (default: https://router.huggingface.co/v1)
MODEL_NAME   — model to use (default: Qwen/Qwen2.5-72B-Instruct)
HF_TOKEN     — bearer token / API key (required; no default)
"""
from __future__ import annotations
import asyncio
import os
import re
from typing import Optional, AsyncIterator
from openai import AsyncOpenAI
from pydantic import BaseModel
from env.database import ensure_seeded, get_schema_info, execute_query
from env.tasks import get_task, get_all_tasks, TASKS
from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
from rl.error_classifier import classify_error, extract_offending_token
from rl.grader import GraderInput, compute_reward, compute_episode_reward
from rl.linucb import LinUCB
from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
from rl.experience import record_episode
from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
# ─── OpenEnv Models ──────────────────────────────────────────────
class Observation(BaseModel):
    # What the agent sees after reset()/step(); built by SQLAgentEnv._build_observation.
    question: str  # natural-language question of the current episode
    schema_info: str  # schema text from get_schema_info(), captured when the env is created
    current_sql: Optional[str] = None  # last cleaned SQL attempted (None before the first step)
    error_message: Optional[str] = None  # execution error of the last attempt, if any
    error_class: Optional[str] = None  # name of the classified error (an ERROR_CLASS_NAMES entry)
    attempt_number: int = 0  # attempts made so far in this episode
    max_attempts: int = 5  # attempt budget (the env passes SQLAgentEnv.MAX_ATTEMPTS)
    task_id: str
    task_difficulty: str  # difficulty label of the task (task.difficulty)
class Action(BaseModel):
    # Agent action for one SQLAgentEnv step.
    repair_action: str  # "generate" for fresh generation, else a repair action name (REPAIR_ACTION_BY_NAME key); unknown names fall back to REWRITE_FULL
    custom_sql: Optional[str] = None  # if non-empty, this SQL is executed instead of calling the LLM
class RewardInfo(BaseModel):
    # Reward returned by SQLAgentEnv.step() together with the new Observation.
    value: float  # task score clamped into [_SCORE_MIN, _SCORE_MAX] by _clamp_score
    success: bool  # grader score reached the success threshold (0.8 in step())
    done: bool  # episode ended: solved or attempt budget exhausted
    info: dict  # step() fills: task_score, attempt, rows (first 5), row_count, sql
# ─── LLM Client ──────────────────────────────────────────────────
# Endpoint and model are overridable via the environment; both have defaults.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
_MODEL = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Deliberately no default: the API key must be supplied via the environment.
HF_TOKEN = os.getenv("HF_TOKEN")
# ─── Score clamping (strictly in (0, 1)) ─────────────────────────
_SCORE_MIN = 0.05
_SCORE_MAX = 0.95
def _clamp_score(x) -> float:
    """Squeeze an arbitrary score-like value into [_SCORE_MIN, _SCORE_MAX].

    That band lies strictly inside (0, 1). ``True``/``False`` map to the band
    edges; ``None``, NaN, ``-inf`` and values ``float()`` rejects map to
    ``_SCORE_MIN``; ``+inf`` maps to ``_SCORE_MAX``; finite numbers are clipped.
    """
    if x is None:
        return _SCORE_MIN
    if isinstance(x, bool):
        return _SCORE_MAX if x else _SCORE_MIN
    try:
        value = float(x)
    except (TypeError, ValueError):
        return _SCORE_MIN
    if value != value:  # NaN is the only float not equal to itself
        return _SCORE_MIN
    if value == float("inf"):
        return _SCORE_MAX
    if value == float("-inf"):
        return _SCORE_MIN
    return min(_SCORE_MAX, max(_SCORE_MIN, value))
def _make_client() -> AsyncOpenAI:
    """Build an AsyncOpenAI client for the configured endpoint and API token."""
    client_kwargs = {"base_url": API_BASE_URL, "api_key": HF_TOKEN}
    return AsyncOpenAI(**client_kwargs)
BASE_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
Rules:
- Output ONLY the SQL query, nothing else
- No markdown, no code fences, no explanation
- Use SQLite syntax
- Do not include semicolons at the end"""
_POSTGRES_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a PostgreSQL database schema, write a correct SQL query.
Rules:
- Output ONLY the SQL query, nothing else
- No markdown, no code fences, no explanation
- Use PostgreSQL syntax
- Do not include semicolons at the end"""
def get_system_prompt() -> str:
    """Pick the generation system prompt matching the active database dialect.

    A Postgres backend gets the PostgreSQL prompt; any other backend gets the
    SQLite prompt.
    """
    from env.database import get_active_db_type

    is_postgres = get_active_db_type() == "postgres"
    return _POSTGRES_SYSTEM_PROMPT if is_postgres else BASE_SYSTEM_PROMPT
def _clean_sql(raw: str) -> str:
"""Strip markdown code fences and extra whitespace."""
raw = raw.strip()
raw = re.sub(r"^```(?:sql)?\s*", "", raw, flags=re.IGNORECASE)
raw = re.sub(r"\s*```$", "", raw)
return raw.strip().rstrip(";")
async def _call_llm(
    system_prompt: str,
    user_message: str,
    stream: bool = False,
) -> AsyncIterator[str] | str:
    """Send a system + user prompt pair to the configured chat model.

    Args:
        system_prompt: Content of the system message.
        user_message: Content of the user message.
        stream: When True, return an async iterator of text deltas instead of
            the full completion.

    Returns:
        The completion text (``""`` if the model returned no content or no
        choices) or, when ``stream`` is True, an async iterator yielding the
        non-empty content deltas.

    A fresh client is created per call and closed once the request finishes
    (or, for streaming, once the iterator is exhausted or closed), so the
    underlying HTTP connection pools are not leaked.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    # Created eagerly so configuration errors still surface at call time.
    client = _make_client()
    if stream:
        async def _gen() -> AsyncIterator[str]:
            try:
                resp = await client.chat.completions.create(
                    model=_MODEL,
                    messages=messages,
                    stream=True,
                    temperature=0.1,
                )
                async for chunk in resp:
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta.content
                    if delta:
                        yield delta
            finally:
                await client.close()

        return _gen()

    try:
        resp = await client.chat.completions.create(
            model=_MODEL,
            messages=messages,
            temperature=0.1,
        )
    finally:
        await client.close()
    if not resp.choices:
        return ""
    return resp.choices[0].message.content or ""
# ─── Episode State ────────────────────────────────────────────────
class _Episode:
def __init__(self, task_id: str, question_id: str, question: str) -> None:
self.task_id = task_id
self.question_id = question_id
self.question = question
self.attempt_number = 0
self.current_sql: Optional[str] = None
self.error_message: Optional[str] = None
self.error_class: Optional[str] = None
self.steps: list[EpisodeStep] = []
self.step_rewards: list[float] = []
self.previous_error_class = None
self.consecutive_same_error = 0
self.last_action: Optional[RepairAction] = None
self.current_rl_state: Optional[RLState] = None
self.current_features: Optional[list[float]] = None
self.done = False
self.success = False
# ─── Main Environment Class ───────────────────────────────────────
class SQLAgentEnv:
    """
    OpenEnv-compliant environment for SQL generation and repair.

    One active episode at a time. A single LinUCB bandit is shared across
    episodes. It is updated after every attempt (in both ``step`` and
    ``step_streaming``) and again from relabelled experience when an episode
    is finalized.
    """

    MAX_ATTEMPTS = 5
    # grade_response() score at or above which an attempt counts as solved.
    SUCCESS_THRESHOLD = 0.8

    def __init__(self) -> None:
        ensure_seeded()
        self._bandit = LinUCB()
        self._episode: Optional[_Episode] = None
        self._schema_info = get_schema_info()

    # ─── Public API ───────────────────────────────────────────────

    def reset(self, task_id: str = "simple_queries") -> Observation:
        """Start a new episode on the first question of ``task_id``.

        An unfinished previous episode with at least one step is finalized as
        a failure first.
        """
        self._abandon_unfinished_episode()
        task = get_task(task_id)
        return self._start_episode(task_id, task.questions[0])

    def reset_with_question(
        self, task_id: str, question_id: str
    ) -> Observation:
        """Start a new episode for a specific question.

        Falls back to the task's first question when ``question_id`` is not
        found. An unfinished previous episode is finalized as a failure first.
        """
        self._abandon_unfinished_episode()
        task = get_task(task_id)
        question_obj = next(
            (q for q in task.questions if q.id == question_id), task.questions[0]
        )
        return self._start_episode(task_id, question_obj)

    async def step(self, action: Action) -> tuple[Observation, RewardInfo]:
        """
        Execute one attempt:
          1. Generate/repair SQL via the LLM (or use ``action.custom_sql``)
          2. Execute SQL
          3. Grade and reward
          4. Update RL state, the bandit and the episode

        Raises:
            RuntimeError: if no episode is active or it is already done.
        """
        if self._episode is None:
            raise RuntimeError("Call reset() before step()")
        if self._episode.done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")
        ep = self._episode
        ep.attempt_number += 1

        # ── 1. Generate / repair SQL ─────────────────────────────
        if action.custom_sql:
            generated_sql = action.custom_sql
        else:
            generated_sql = await self._generate_sql(action, ep)
        generated_sql = _clean_sql(generated_sql)

        # ── 2. Execute SQL ───────────────────────────────────────
        rows, error = execute_query(generated_sql)
        # Reporting copy that tolerates a missing row set on error.
        row_list = rows or []

        # ── 3. Grade ─────────────────────────────────────────────
        task_score = self._grade(ep, generated_sql, rows, error)
        success = task_score >= self.SUCCESS_THRESHOLD

        # ── 4. RL state + reward + bandit ────────────────────────
        repair_action_enum = self._resolve_repair_action(action)
        error_class, error_class_name = self._update_rl_state(ep, error)
        self._record_attempt(
            ep,
            repair_action_enum,
            generated_sql,
            error,
            error_class,
            error_class_name,
            success,
            task_score,
        )

        # ── 5. Done check ────────────────────────────────────────
        done = success or ep.attempt_number >= self.MAX_ATTEMPTS
        if done:
            self._close_episode(ep, success)

        obs = self._build_observation()
        safe_task_score = _clamp_score(task_score)
        reward_info = RewardInfo(
            value=safe_task_score,  # strictly in (0, 1) per OpenEnv spec
            success=success,
            done=done,
            info={
                "task_score": safe_task_score,
                "attempt": ep.attempt_number,
                "rows": row_list[:5],
                "row_count": len(row_list),
                "sql": generated_sql,
            },
        )
        return obs, reward_info

    async def step_streaming(
        self, action: Action
    ) -> AsyncIterator[dict]:
        """
        Step with SSE-compatible event streaming.

        Yields event dicts in this order: ``attempt_start``, ``sql_chunk``
        (LLM path only, repeated), ``sql_complete``, ``executing``,
        ``rl_action`` (with UCB scores), ``error``, ``rl_reward``, then
        ``success`` if solved and ``rl_episode_end`` if the episode ended.

        Raises:
            RuntimeError: if no episode is active or it is already done.
        """
        if self._episode is None:
            raise RuntimeError("Call reset() before step_streaming()")
        # Same guard as step(): never run past MAX_ATTEMPTS or finalize twice.
        if self._episode.done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")
        ep = self._episode
        ep.attempt_number += 1
        yield {"type": "attempt_start", "attempt": ep.attempt_number}

        # Generate SQL (custom SQL is cleaned exactly like in step()).
        if action.custom_sql:
            generated_sql = _clean_sql(action.custom_sql)
        else:
            chunks: list[str] = []
            async for chunk in await self._generate_sql_streaming(action, ep):
                chunks.append(chunk)
                yield {"type": "sql_chunk", "chunk": chunk}
            generated_sql = _clean_sql("".join(chunks))
        yield {"type": "sql_complete", "sql": generated_sql}

        yield {"type": "executing"}
        rows, error = execute_query(generated_sql)
        row_list = rows or []
        task_score = self._grade(ep, generated_sql, rows, error)
        success = task_score >= self.SUCCESS_THRESHOLD

        # RL processing
        repair_action_enum = self._resolve_repair_action(action)
        error_class, error_class_name = self._update_rl_state(ep, error)
        # UCB scores are shown for the state *before* this attempt's update.
        _, scores = self._bandit.select_action(ep.current_features)
        ucb_scores = {
            REPAIR_ACTION_NAMES[RepairAction(i)]: round(score, 4)
            for i, score in enumerate(scores)
        }
        yield {
            "type": "rl_action",
            "action": REPAIR_ACTION_NAMES[repair_action_enum],
            "ucb_scores": ucb_scores,
        }
        yield {"type": "error", "error": error, "error_class": error_class_name}

        grader_out = self._record_attempt(
            ep,
            repair_action_enum,
            generated_sql,
            error,
            error_class,
            error_class_name,
            success,
            task_score,
        )
        yield {
            "type": "rl_reward",
            "reward": grader_out.reward,
            "breakdown": {
                "base": grader_out.breakdown.base,
                "attempt_penalty": grader_out.breakdown.attempt_penalty,
                "severity_bonus": grader_out.breakdown.severity_bonus,
                "change_bonus": grader_out.breakdown.change_bonus,
            },
        }

        done = success or ep.attempt_number >= self.MAX_ATTEMPTS
        if success:
            yield {
                "type": "success",
                "rows": row_list,
                "row_count": len(row_list),
                "sql": generated_sql,
            }
        if done:
            total_reward = compute_episode_reward(ep.step_rewards, success)
            self._close_episode(ep, success)
            yield {
                "type": "rl_episode_end",
                "total_reward": total_reward,
                "success": success,
            }

    def state(self) -> dict:
        """Return a snapshot of the active episode (``{"active": False}`` if none).

        All reward values are clamped by ``_clamp_score`` so raw RL rewards
        never leak out. ``total_reward`` is the mean clamped step reward, or
        ``_SCORE_MIN`` before any step.
        """
        if self._episode is None:
            return {"active": False}
        ep = self._episode
        safe_rewards = [_clamp_score(r) for r in ep.step_rewards]
        total = sum(safe_rewards) / len(safe_rewards) if safe_rewards else _SCORE_MIN
        return {
            "active": True,
            "task_id": ep.task_id,
            "question_id": ep.question_id,
            "question": ep.question,
            "attempt_number": ep.attempt_number,
            "max_attempts": self.MAX_ATTEMPTS,
            "current_sql": ep.current_sql,
            "error_message": ep.error_message,
            "error_class": ep.error_class,
            "done": ep.done,
            "success": ep.success,
            "step_rewards": safe_rewards,
            "total_reward": _clamp_score(total),
        }

    # ─── Private Helpers ──────────────────────────────────────────

    def _abandon_unfinished_episode(self) -> None:
        """Finalize an in-progress episode (with steps) as a failure and mark it done."""
        ep = self._episode
        if ep is not None and ep.steps and not ep.done:
            self._finalize_episode(success=False)
            # Prevent a second finalization if the following reset fails.
            ep.done = True

    def _start_episode(self, task_id: str, question_obj) -> Observation:
        """Install a fresh episode for ``question_obj`` and return its first observation."""
        self._episode = _Episode(
            task_id=task_id,
            question_id=question_obj.id,
            question=question_obj.question,
        )
        return self._build_observation()

    def _close_episode(self, ep: _Episode, success: bool) -> None:
        """Finalize ``ep`` (experience + bandit replay) and mark it done."""
        self._finalize_episode(success=success)
        ep.done = True
        ep.success = success

    @staticmethod
    def _resolve_repair_action(action: Action) -> RepairAction:
        """Map an action name to a RepairAction; "generate" and unknown names map to REWRITE_FULL."""
        if action.repair_action == "generate":
            return RepairAction.REWRITE_FULL
        return REPAIR_ACTION_BY_NAME.get(action.repair_action, RepairAction.REWRITE_FULL)

    @staticmethod
    def _grade(ep: _Episode, sql: str, rows, error: Optional[str]) -> float:
        """Score an attempt with the task grader (imported lazily)."""
        from env.tasks import grade_response

        return grade_response(
            ep.task_id, ep.question_id, sql, rows, error, ep.attempt_number
        )

    def _update_rl_state(
        self, ep: _Episode, error: Optional[str]
    ) -> tuple[object, Optional[str]]:
        """Classify ``error`` and refresh the episode's RL state and features.

        Returns ``(error_class, error_class_name)``. Both are ``None`` when the
        query executed without error.
        """
        current_error_class = None
        error_class_name = None
        if error:
            current_error_class = classify_error(error)
            error_class_name = ERROR_CLASS_NAMES[current_error_class]
        error_changed = (
            ep.previous_error_class is not None
            and ep.previous_error_class != current_error_class
        )
        if ep.previous_error_class == current_error_class:
            ep.consecutive_same_error += 1
        else:
            ep.consecutive_same_error = 1
        ep.current_rl_state = RLState(
            error_class=current_error_class,
            attempt_number=ep.attempt_number,
            previous_action=ep.last_action,
            error_changed=error_changed,
            consecutive_same_error=ep.consecutive_same_error,
        )
        ep.current_features = featurize(ep.current_rl_state)
        return current_error_class, error_class_name

    def _record_attempt(
        self,
        ep: _Episode,
        repair_action: RepairAction,
        sql: str,
        error: Optional[str],
        error_class,
        error_class_name: Optional[str],
        success: bool,
        task_score: float,
    ):
        """Compute the RL reward, log the step, update the bandit and episode fields.

        Returns the grader output (reward plus breakdown).
        """
        grader_out = compute_reward(
            GraderInput(
                success=success,
                attempt_number=ep.attempt_number,
                current_error_class=error_class,
                previous_error_class=ep.previous_error_class,
            )
        )
        if ep.current_rl_state and ep.current_features:
            ep.steps.append(
                EpisodeStep(
                    state=ep.current_rl_state,
                    featurized=ep.current_features,
                    action=repair_action,
                    reward=grader_out.reward,
                    error_message=error or "",
                    sql=sql,
                    success=success,
                )
            )
            self._bandit.update(ep.current_features, repair_action, grader_out.reward)
        # Store clamped reward so /state never returns raw RL values.
        ep.step_rewards.append(_clamp_score(task_score))
        ep.current_sql = sql
        ep.error_message = error
        ep.error_class = error_class_name
        ep.previous_error_class = error_class
        # Feeds RLState.previous_action on the next attempt.
        ep.last_action = repair_action
        return grader_out

    def _build_observation(self) -> Observation:
        """Build the agent-facing Observation for the active episode."""
        if self._episode is None:
            raise RuntimeError("No active episode")
        ep = self._episode
        task = get_task(ep.task_id)
        return Observation(
            question=ep.question,
            schema_info=self._schema_info,
            current_sql=ep.current_sql,
            error_message=ep.error_message,
            error_class=ep.error_class,
            attempt_number=ep.attempt_number,
            max_attempts=self.MAX_ATTEMPTS,
            task_id=ep.task_id,
            task_difficulty=task.difficulty,
        )

    def _build_prompt(self, action: Action, ep: _Episode) -> tuple[str, str]:
        """Return the (system, user) prompt pair for a generation or repair call.

        Fresh generation is used for the "generate" action or when no SQL has
        been produced yet. Otherwise the chosen repair strategy's system suffix
        and user message are used. The base system prompt follows the active
        database dialect.
        """
        system = get_system_prompt()
        if action.repair_action == "generate" or ep.current_sql is None:
            user = (
                f"Schema:\n{self._schema_info}\n\n"
                f"Question: {ep.question}\n\n"
                "Write a SQL query to answer this question."
            )
            return system, user
        repair_action_enum = self._resolve_repair_action(action)
        ctx = RepairContext(
            schema=self._schema_info,
            question=ep.question,
            failing_sql=ep.current_sql or "",
            error_message=ep.error_message or "",
            offending_token=extract_offending_token(ep.error_message or ""),
        )
        return (
            system + get_repair_system_suffix(repair_action_enum),
            build_repair_user_message(repair_action_enum, ctx),
        )

    async def _generate_sql(self, action: Action, ep: _Episode) -> str:
        """Ask the LLM for a complete SQL string (not yet cleaned)."""
        system, user = self._build_prompt(action, ep)
        result = await _call_llm(system, user, stream=False)
        return result  # type: ignore[return-value]

    async def _generate_sql_streaming(
        self, action: Action, ep: _Episode
    ) -> AsyncIterator[str]:
        """Ask the LLM for SQL and return an async iterator of raw text deltas."""
        system, user = self._build_prompt(action, ep)
        return await _call_llm(system, user, stream=True)  # type: ignore[return-value]

    def _finalize_episode(self, success: bool) -> None:
        """Record the episode's experience and replay relabelled steps into the bandit.

        Best-effort: failures are logged, never raised, so a logging/replay
        problem cannot break the episode flow.
        """
        ep = self._episode
        if ep is None or not ep.steps:
            return
        try:
            _, relabeled = record_episode(ep.question, ep.steps, success)
            for exp in relabeled:
                self._bandit.update(exp.state, exp.action, exp.reward)
            self._bandit.decay_alpha()
        except Exception:
            import logging

            logging.getLogger(__name__).exception(
                "Failed to finalize episode (task=%s, question=%s)",
                ep.task_id,
                ep.question_id,
            )
# ─── Singleton instance ───────────────────────────────────────────
# Lazily created process-wide environment shared by all callers of get_env().
_env_instance: Optional[SQLAgentEnv] = None


def get_env() -> SQLAgentEnv:
    """Return the shared SQLAgentEnv, constructing it on first use."""
    global _env_instance
    env = _env_instance
    if env is None:
        env = _env_instance = SQLAgentEnv()
    return env