import sqlite3
from typing import Optional

from .models import Action, Observation, StepResult, EnvState, QueryResult
from .database import create_database, seed_database, get_schema_summary
from .reward import RewardCalculator
from .tasks import TASKS


class SQLAnalystEnv:
    """
    OpenEnv-compliant SQL Data Analyst environment.

    An agent must answer business questions by iteratively writing and
    executing SQL queries.
    """

    def __init__(self, task_id: str = "monthly_signups"):
        """Create an environment bound to one task.

        Args:
            task_id: key into ``TASKS`` selecting the question to answer.

        The database connection is not created here; call ``reset()`` first.
        """
        assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}"
        self.task_id = task_id
        self.task = TASKS[task_id]
        self.conn: Optional[sqlite3.Connection] = None
        self.step_count: int = 0
        self.total_reward: float = 0.0
        self.done: bool = False
        self._query_history: list = []  # SQL strings executed this episode
        self._reward_calc = RewardCalculator()

    def reset(self) -> StepResult:
        """Start a new episode on a freshly built and seeded database.

        Closes any previous connection, creates and seeds a new database,
        clears the per-episode counters and query history, and recomputes
        the task's ground truth against the new data.

        Returns:
            StepResult holding the initial observation and a reward of 0.0.
        """
        if self.conn is not None:
            self.conn.close()
            # Don't keep a closed connection around if create_database() fails.
            self.conn = None
        self.conn = create_database()
        seed_database(self.conn)
        self.step_count = 0
        self.total_reward = 0.0
        self.done = False
        self._query_history = []
        self.task.compute_ground_truth(self.conn)
        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            step=0,
            max_steps=self.task.max_steps,
        )
        return StepResult(observation=obs, reward=0.0, done=False)

    def step(self, action: Action) -> StepResult:
        """Execute one agent action.

        Args:
            action: must carry exactly one of ``sql_query`` or
                ``submit_answer``.

        Returns:
            StepResult with the new observation, this step's reward (rounded
            to 3 decimals), the done flag, and an ``info`` dict containing
            ``step``, ``total_reward`` and ``task_id``.

        The episode ends when an answer is submitted or ``max_steps`` is
        reached.
        """
        assert self.conn is not None, "Call reset() before step()"
        assert not self.done, "Episode is done. Call reset()."
        assert action.is_valid(), (
            "Action must have exactly one of: sql_query, submit_answer"
        )

        self.step_count += 1
        query_result = None
        error = None

        if action.sql_query:
            query_result = self._execute_sql(action.sql_query)
            self._query_history.append(action.sql_query)
            error = query_result.error

        terminal = (
            action.submit_answer is not None
            or self.step_count >= self.task.max_steps
        )

        reward = self._reward_calc.calculate(
            action=action,
            result=query_result,
            task=self.task,
            step=self.step_count,
            query_history=self._query_history,
            terminal=terminal,
        )
        self.total_reward += reward
        self.done = terminal

        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            last_query=action.sql_query,
            last_result=query_result,
            last_error=error,
            step=self.step_count,
            max_steps=self.task.max_steps,
            hints=self.task.get_hints(self.step_count),
            done=self.done,
        )
        return StepResult(
            observation=obs,
            reward=round(reward, 3),
            done=self.done,
            info={
                "step": self.step_count,
                "total_reward": round(self.total_reward, 3),
                "task_id": self.task_id,
            },
        )

    def state(self) -> EnvState:
        """Return a snapshot of the environment's current state.

        The query history is copied so callers cannot mutate the live list.
        """
        return EnvState(
            task_id=self.task_id,
            difficulty=self.task.difficulty,
            step=self.step_count,
            max_steps=self.task.max_steps,
            query_history=self._query_history.copy(),
            total_reward=round(self.total_reward, 3),
            done=self.done,
        )

    def _execute_sql(self, query: str, max_rows: int = 50) -> QueryResult:
        """Execute a read-only SQL query and return at most ``max_rows`` rows.

        Only queries starting with SELECT or WITH are accepted. In SQLite a
        WITH clause can also prefix INSERT/UPDATE/DELETE, so the prefix check
        alone is not enough. The query is therefore run with
        ``PRAGMA query_only`` enabled: any attempted write fails with an error
        instead of mutating the episode's database.

        Args:
            query: SQL text supplied by the agent.
            max_rows: maximum number of rows returned (default 50).

        Returns:
            QueryResult with ``columns``/``rows`` on success, or with
            ``error`` set if the query was rejected or failed. ``truncated``
            is True only when the query produced more than ``max_rows`` rows.
            ``total_rows`` is the number of rows actually returned.
        """
        q = query.strip().upper()
        if not q.startswith(("SELECT", "WITH")):
            return QueryResult(error="Only SELECT / WITH queries are allowed.")

        # Save the current setting so it is restored exactly afterwards.
        previous_query_only = self.conn.execute("PRAGMA query_only").fetchone()[0]
        self.conn.execute("PRAGMA query_only = ON")
        cursor = None
        try:
            cursor = self.conn.execute(query)
            cols = [d[0] for d in cursor.description] if cursor.description else []
            # Fetch one extra row so truncation is detected exactly: a result
            # with exactly max_rows rows is complete, not truncated.
            fetched = cursor.fetchmany(max_rows + 1)
            rows = fetched[:max_rows]
            return QueryResult(
                columns=cols,
                rows=[list(r) for r in rows],
                truncated=len(fetched) > max_rows,
                total_rows=len(rows),
            )
        except Exception as e:
            # Agent-written SQL may fail in many ways; report it, don't crash.
            return QueryResult(error=str(e))
        finally:
            if cursor is not None:
                cursor.close()
            self.conn.execute(f"PRAGMA query_only = {int(previous_query_only)}")