Spaces:
Sleeping
Sleeping
| import sqlite3 | |
| from typing import Optional | |
| from .models import Action, Observation, StepResult, EnvState, QueryResult | |
| from .database import create_database, seed_database, get_schema_summary | |
| from .reward import RewardCalculator | |
| from .tasks import TASKS | |
class SQLAnalystEnv:
    """
    OpenEnv-compliant SQL Data Analyst environment.

    An agent must answer business questions by iteratively
    writing and executing SQL queries.

    Typical usage: construct with a task id, call ``reset()`` to build and
    seed a fresh database, then call ``step()`` repeatedly until the returned
    ``StepResult.done`` is true.
    """

    # Maximum number of rows handed back to the agent for a single query.
    _MAX_RESULT_ROWS = 50

    def __init__(self, task_id: str = "monthly_signups"):
        """Bind the environment to one task from ``TASKS``.

        Args:
            task_id: Key into ``TASKS`` selecting the task to run.

        Raises:
            AssertionError: If ``task_id`` is not a key of ``TASKS``.
        """
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        if task_id not in TASKS:
            raise AssertionError(
                f"Unknown task: {task_id}. Choose from {list(TASKS)}"
            )
        self.task_id = task_id
        self.task = TASKS[task_id]
        self.conn: Optional[sqlite3.Connection] = None
        self.step_count: int = 0
        self.total_reward: float = 0.0
        self.done: bool = False
        self._query_history: list = []
        self._reward_calc = RewardCalculator()

    def reset(self) -> StepResult:
        """Reset environment. Reseed DB. Return initial observation.

        Closes any existing connection, creates and seeds a new database,
        clears the per-episode counters and query history, and asks the task
        to compute its ground truth against the fresh database.

        Returns:
            A ``StepResult`` with the initial observation (step 0), zero
            reward and ``done=False``.
        """
        if self.conn is not None:
            self.conn.close()
        self.conn = create_database()
        seed_database(self.conn)

        self.step_count = 0
        self.total_reward = 0.0
        self.done = False
        self._query_history = []

        self.task.compute_ground_truth(self.conn)

        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            step=0,
            max_steps=self.task.max_steps,
        )
        return StepResult(observation=obs, reward=0.0, done=False)

    def step(self, action: Action) -> StepResult:
        """Execute one agent action and return the resulting ``StepResult``.

        If the action carries a SQL query it is executed and recorded in the
        query history. The episode ends when the action submits an answer or
        when the task's ``max_steps`` is reached.

        Args:
            action: The agent's action; must satisfy ``action.is_valid()``.

        Returns:
            A ``StepResult`` holding the new observation, the step reward
            (rounded to 3 decimals), the done flag, and an ``info`` dict with
            ``step``, ``total_reward`` and ``task_id``.

        Raises:
            AssertionError: If ``reset()`` has not been called, the episode
                is already done, or the action is invalid.
        """
        # Explicit raises (not `assert`) so these guards survive `python -O`.
        if self.conn is None:
            raise AssertionError("Call reset() before step()")
        if self.done:
            raise AssertionError("Episode is done. Call reset().")
        if not action.is_valid():
            raise AssertionError(
                "Action must have exactly one of: sql_query, submit_answer"
            )

        self.step_count += 1

        query_result = None
        error = None
        if action.sql_query:
            query_result = self._execute_sql(action.sql_query)
            self._query_history.append(action.sql_query)
            error = query_result.error

        # Terminal on an explicit submission or when the step budget is spent.
        terminal = (
            action.submit_answer is not None
            or self.step_count >= self.task.max_steps
        )

        reward = self._reward_calc.calculate(
            action=action,
            result=query_result,
            task=self.task,
            step=self.step_count,
            query_history=self._query_history,
            terminal=terminal,
        )
        self.total_reward += reward
        self.done = terminal

        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            last_query=action.sql_query,
            last_result=query_result,
            last_error=error,
            step=self.step_count,
            max_steps=self.task.max_steps,
            hints=self.task.get_hints(self.step_count),
            done=self.done,
        )
        return StepResult(
            observation=obs,
            reward=round(reward, 3),
            done=self.done,
            info={
                "step": self.step_count,
                "total_reward": round(self.total_reward, 3),
                "task_id": self.task_id,
            },
        )

    def state(self) -> EnvState:
        """Return current full state of the environment.

        The query history is returned as a copy, so callers cannot mutate the
        environment's internal list.
        """
        return EnvState(
            task_id=self.task_id,
            difficulty=self.task.difficulty,
            step=self.step_count,
            max_steps=self.task.max_steps,
            query_history=self._query_history.copy(),
            total_reward=round(self.total_reward, 3),
            done=self.done,
        )

    def _execute_sql(self, query: str) -> QueryResult:
        """Execute SQL safely. Block non-SELECT. Return up to 50 rows.

        Args:
            query: SQL text supplied by the agent.

        Returns:
            A ``QueryResult``. On success it carries the column names, at
            most ``_MAX_RESULT_ROWS`` rows, a ``truncated`` flag that is true
            only when more rows were available than were returned, and
            ``total_rows`` set to the number of rows returned. On rejection
            or failure it carries only ``error``.
        """
        normalized = query.strip().upper()
        if not normalized.startswith(("SELECT", "WITH")):
            return QueryResult(error="Only SELECT / WITH queries are allowed.")

        limit = self._MAX_RESULT_ROWS
        try:
            # The prefix check alone is not enough: SQLite also accepts
            # `WITH ... INSERT/UPDATE/DELETE`. query_only makes the engine
            # itself reject any write for the duration of this query.
            self.conn.execute("PRAGMA query_only = ON")
            try:
                cursor = self.conn.execute(query)
                cols = (
                    [d[0] for d in cursor.description]
                    if cursor.description
                    else []
                )
                # Fetch one extra row so "exactly `limit` rows" is not
                # mistaken for a truncated result.
                fetched = cursor.fetchmany(limit + 1)
            finally:
                # Always restore write access; reset() must be able to reseed.
                self.conn.execute("PRAGMA query_only = OFF")

            rows = fetched[:limit]
            return QueryResult(
                columns=cols,
                rows=[list(r) for r in rows],
                truncated=len(fetched) > limit,
                total_rows=len(rows),
            )
        except Exception as e:
            # Deliberate catch-all boundary: any failure is reported back to
            # the agent as an error message instead of crashing the episode.
            return QueryResult(error=str(e))