# File: sql_data_analyst/temp_upload/env/environment.py
# Author: YashashMathur
# Commit: d103a0f (verified) - "SQL Data Analyst OpenEnv - Initial commit"
# Size: 4.59 kB
import sqlite3
from typing import Optional
from .models import Action, Observation, StepResult, EnvState, QueryResult
from .database import create_database, seed_database, get_schema_summary
from .reward import RewardCalculator
from .tasks import TASKS
class SQLAnalystEnv:
    """
    OpenEnv-compliant SQL Data Analyst environment.

    An agent must answer business questions by iteratively
    writing and executing SQL queries.

    Lifecycle: call ``reset()`` to build and seed a fresh database, then
    call ``step()`` once per action until the returned ``done`` flag is
    true. ``close()`` releases the database connection.
    """

    # Maximum number of result rows returned to the agent for one query.
    MAX_RESULT_ROWS: int = 50

    def __init__(self, task_id: str = "monthly_signups"):
        """Create an environment for ``task_id`` (a key of ``TASKS``).

        No database is created here; call ``reset()`` before ``step()``.

        Raises:
            AssertionError: if ``task_id`` is not a known task.
        """
        # Explicit raise instead of ``assert`` so validation survives
        # ``python -O``; AssertionError is kept so existing callers that
        # catch it keep working.
        if task_id not in TASKS:
            raise AssertionError(
                f"Unknown task: {task_id}. Choose from {list(TASKS)}"
            )
        self.task_id = task_id
        self.task = TASKS[task_id]
        self.conn: Optional[sqlite3.Connection] = None
        self.step_count: int = 0
        self.total_reward: float = 0.0
        self.done: bool = False
        self._query_history: list = []
        self._reward_calc = RewardCalculator()

    def reset(self) -> StepResult:
        """Reset environment. Reseed DB. Return initial observation.

        Closes any previous connection, creates and seeds a new database,
        clears the per-episode counters and query history, and recomputes
        the task's ground truth against the new database.
        """
        # Drop the old connection first so a failure while rebuilding never
        # leaves ``self.conn`` pointing at a closed connection.
        self.close()
        self.conn = create_database()
        seed_database(self.conn)
        self.step_count = 0
        self.total_reward = 0.0
        self.done = False
        self._query_history = []
        self.task.compute_ground_truth(self.conn)
        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            step=0,
            max_steps=self.task.max_steps,
        )
        return StepResult(observation=obs, reward=0.0, done=False)

    def step(self, action: Action) -> StepResult:
        """Execute one agent action and return the resulting ``StepResult``.

        The action either runs a SQL query (``sql_query``) or submits an
        answer (``submit_answer``). The episode becomes done when an answer
        is submitted or when ``step_count`` reaches the task's ``max_steps``.

        Returns:
            StepResult with the new observation, the step reward rounded to
            3 decimals, the done flag, and an ``info`` dict containing the
            step number, rounded running total reward, and task id.

        Raises:
            AssertionError: if ``reset()`` has not been called, the episode
                is already done, or ``action.is_valid()`` is false.
        """
        # Explicit raises (not ``assert``) so these checks are not stripped
        # under ``python -O``; the exception type is unchanged for callers.
        if self.conn is None:
            raise AssertionError("Call reset() before step()")
        if self.done:
            raise AssertionError("Episode is done. Call reset().")
        if not action.is_valid():
            raise AssertionError(
                "Action must have exactly one of: sql_query, submit_answer"
            )
        self.step_count += 1
        query_result = None
        error = None
        if action.sql_query:
            query_result = self._execute_sql(action.sql_query)
            self._query_history.append(action.sql_query)
            error = query_result.error
        terminal = (
            action.submit_answer is not None
            or self.step_count >= self.task.max_steps
        )
        reward = self._reward_calc.calculate(
            action=action,
            result=query_result,
            task=self.task,
            step=self.step_count,
            query_history=self._query_history,
            terminal=terminal,
        )
        self.total_reward += reward
        self.done = terminal
        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            last_query=action.sql_query,
            last_result=query_result,
            last_error=error,
            step=self.step_count,
            max_steps=self.task.max_steps,
            hints=self.task.get_hints(self.step_count),
            done=self.done,
        )
        return StepResult(
            observation=obs,
            reward=round(reward, 3),
            done=self.done,
            info={
                "step": self.step_count,
                "total_reward": round(self.total_reward, 3),
                "task_id": self.task_id,
            },
        )

    def state(self) -> EnvState:
        """Return current full state of the environment.

        The query history is copied so callers cannot mutate internal state.
        """
        return EnvState(
            task_id=self.task_id,
            difficulty=self.task.difficulty,
            step=self.step_count,
            max_steps=self.task.max_steps,
            query_history=self._query_history.copy(),
            total_reward=round(self.total_reward, 3),
            done=self.done,
        )

    def close(self) -> None:
        """Close the database connection, if open. Safe to call repeatedly."""
        if self.conn is not None:
            self.conn.close()
            self.conn = None

    def _execute_sql(self, query: str) -> QueryResult:
        """Execute a read-only SQL query and return up to ``MAX_RESULT_ROWS`` rows.

        Queries whose first keyword is not SELECT or WITH are rejected
        without being run. Accepted queries run with SQLite's
        ``query_only`` pragma enabled, so a data-modifying statement behind
        a CTE (e.g. ``WITH ... DELETE``) fails instead of changing the data.

        Any error raised while executing or fetching is returned in
        ``QueryResult.error`` rather than propagated, so the agent sees it.
        ``truncated`` is true only when the query produced more rows than
        were returned; ``total_rows`` is the number of rows returned.
        """
        head = query.strip().upper()
        if not head.startswith(("SELECT", "WITH")):
            return QueryResult(error="Only SELECT / WITH queries are allowed.")

        limit = self.MAX_RESULT_ROWS
        conn = self.conn
        # The keyword check alone is not enough: SQLite allows
        # ``WITH ... DELETE/UPDATE/INSERT``. query_only makes any write fail.
        conn.execute("PRAGMA query_only = ON")
        cursor = None
        try:
            cursor = conn.execute(query)
            cols = [d[0] for d in cursor.description] if cursor.description else []
            # Fetch one extra row to tell "exactly limit rows" apart from
            # "more rows than limit".
            fetched = cursor.fetchmany(limit + 1)
        except Exception as e:  # boundary: report any failure to the agent
            return QueryResult(error=str(e))
        finally:
            if cursor is not None:
                cursor.close()
            conn.execute("PRAGMA query_only = OFF")

        rows = fetched[:limit]
        return QueryResult(
            columns=cols,
            rows=[list(r) for r in rows],
            truncated=len(fetched) > limit,
            total_rows=len(rows),
        )