""" nl2sql-bench/server/environment.py ==================================== NL2SQL-Bench core environment — implements the OpenEnv Environment interface. Episode flow ------------ 1. reset(task_name?) → picks a task + question, returns initial observation 2. step(action) → executes the SQL, grades it, returns observation + reward 3. state() → returns episode metadata 4. Episode ends when: exact_match OR step count reaches max_steps The environment manages its own SQLite connection (in-memory, seeded deterministically). One connection per Environment instance; the FastAPI server creates one Environment per WebSocket session. """ from __future__ import annotations import os import sqlite3 import uuid from pathlib import Path from typing import Optional from openenv.core.env_server import Environment # Import after openenv so path is correct regardless of working directory _HERE = Path(__file__).parent # Lazy import of task registry (avoids circular imports) from tasks import get_task, all_task_names, BaseTask from tasks.base import TaskExample from grader import ( GradeResult, compute_ground_truth, execute_query, grade, has_order_by, ) # We import our models from one level up (models.py at project root) import sys sys.path.insert(0, str(_HERE.parent)) from models import NL2SQLAction, NL2SQLObservation, NL2SQLState # ── Constants ────────────────────────────────────────────────────────────── DEFAULT_TASK = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter") MAX_STEPS = int(os.getenv("NL2SQL_MAX_STEPS", "5")) RESULT_LIMIT = 10 # Max rows shown to agent per step class NL2SQLEnvironment(Environment): """ OpenEnv-compliant environment for NL-to-SQL query generation. One instance per WebSocket session (created by create_fastapi_app). """ def __init__(self) -> None: self._conn: Optional[sqlite3.Connection] = None self._task: Optional[BaseTask] = None self._example: Optional[TaskExample] = None self._ground_truth: list = [] self._order_sensitive: bool = False self._state = NL2SQLState( episode_id=None, step_count=0, task_name="", task_difficulty="", question="", best_reward=0.0, cumulative_reward=0.0, solved=False ) self._last_obs = NL2SQLObservation( question="", schema_context="", task_name="", last_query="", last_result=[], last_error=None, result_columns=[], step=0, max_steps=5, done=False, reward=None, score=0.0 ) self._episode_rewards: list = [] self._setup_db() # ── DB lifecycle ─────────────────────────────────────────────────────── def _setup_db(self) -> None: """Create in-memory SQLite DB and seed it.""" schema_path = _HERE / "db" / "schema.sql" from db.seed import seed_database # local import after sys.path setup conn = sqlite3.connect(":memory:", check_same_thread=False) conn.row_factory = sqlite3.Row conn.execute("PRAGMA foreign_keys = ON") conn.executescript(schema_path.read_text()) seed_database(conn) self._conn = conn # ── OpenEnv interface ────────────────────────────────────────────────── def reset(self, task_name: Optional[str] = None) -> NL2SQLObservation: """ Start a new episode. task_name: one of 'simple-filter', 'join-aggregation', 'analytics-window'. Defaults to NL2SQL_DEFAULT_TASK env-var or 'simple-filter'. """ task_name = task_name or DEFAULT_TASK if task_name not in all_task_names(): task_name = DEFAULT_TASK self._task = get_task(task_name) self._example = self._task.next_example() self._order_sensitive = has_order_by(self._example.sql) # Pre-compute ground truth once per episode self._ground_truth = compute_ground_truth(self._conn, self._example.sql) self._episode_rewards = [] self._state = NL2SQLState( episode_id=str(uuid.uuid4()), step_count=0, task_name=self._task.name, task_difficulty=self._task.difficulty, question=self._example.question, best_reward=0.0, cumulative_reward=0.0, solved=False, ) obs = NL2SQLObservation( question=self._example.question, schema_context=self._task.schema_context(), task_name=self._task.name, last_query="", last_result=[], last_error=None, result_columns=[], step=0, max_steps=MAX_STEPS, done=False, reward=None, score=0.0, ) self._last_obs = obs return obs def step(self, action: NL2SQLAction) -> NL2SQLObservation: """Execute the agent's SQL and return graded observation.""" if self._task is None or self._example is None: # Called before reset — auto-reset self.reset() self._state.step_count += 1 current_step = self._state.step_count done = False # Execute the query rows, error = execute_query(self._conn, action.query) # Grade it result: GradeResult = grade( actual_rows=rows, ground_truth_rows=self._ground_truth, error=error, step=current_step, order_sensitive=self._order_sensitive, ) reward = result.reward self._episode_rewards.append(reward) self._state.cumulative_reward += reward self._state.best_reward = max(self._state.best_reward, reward) if result.exact_match: self._state.solved = True done = True elif current_step >= MAX_STEPS: done = True # Prepare result rows for observation (truncated for agent readability) display_rows = (rows or [])[:RESULT_LIMIT] result_columns = list(display_rows[0].keys()) if display_rows else [] # Convert sqlite3.Row objects if needed display_rows = [dict(r) for r in display_rows] # Normalised cumulative score n = len(self._episode_rewards) score = self._state.cumulative_reward / max(n, 1) if n else 0.0 score = round(min(max(score, 0.0), 1.0), 4) obs = NL2SQLObservation( question=self._example.question, schema_context=self._task.schema_context(), task_name=self._task.name, last_query=action.query, last_result=display_rows, last_error=error, result_columns=result_columns, step=current_step, max_steps=MAX_STEPS, done=done, reward=reward, score=score, ) self._last_obs = obs # openenv-core expects ONLY the observation returned from step(). # The framework reads obs.reward and obs.done itself — do NOT return a tuple. return obs @property def state(self) -> NL2SQLState: return self._state # ── Helpers ──────────────────────────────────────────────────────────── def available_tasks(self) -> list: return all_task_names()