"""
OpenEnv client for SQL Data Analyst environment.

Provides a Python client interface to interact with the environment.
"""

from typing import Dict, Any, Optional

from env import SQLAnalystEnv, Action

# Task used when the caller does not name one; shared by SQLAnalystClient
# and get_client so the two defaults cannot drift apart.
DEFAULT_TASK_ID = "monthly_signups"


def _serialize_query_result(query_result: Optional[Any]) -> Dict[str, Any]:
    """Convert an observation's ``last_result`` into a plain dict.

    The same three keys (``columns``, ``rows``, ``error``) are always
    present so callers can index them unconditionally; every value is
    ``None`` when there is no result.
    """
    # Explicit ``is None`` rather than truthiness: a result object that
    # happens to evaluate falsy must not have its data discarded.
    if query_result is None:
        return {"columns": None, "rows": None, "error": None}
    return {
        "columns": query_result.columns,
        "rows": query_result.rows,
        "error": query_result.error,
    }


def _serialize_observation(
    observation: Any, *, include_query_fields: bool
) -> Dict[str, Any]:
    """Convert an environment observation object into a plain dict.

    Args:
        observation: The ``observation`` attribute of a result returned by
            ``SQLAnalystEnv.reset()`` or ``SQLAnalystEnv.step()``.
        include_query_fields: When True, also emit ``last_query``,
            ``last_result`` and ``last_error`` (used by ``step``); ``reset``
            omits them.

    Returns:
        Dict whose keys appear in a stable order: ``schema_summary``,
        ``question``, the optional query fields, then ``step``,
        ``max_steps``, ``hints`` and ``done``.
    """
    data: Dict[str, Any] = {
        "schema_summary": observation.schema_summary,
        "question": observation.question,
    }
    if include_query_fields:
        data["last_query"] = observation.last_query
        data["last_result"] = _serialize_query_result(observation.last_result)
        data["last_error"] = observation.last_error
    data["step"] = observation.step
    data["max_steps"] = observation.max_steps
    data["hints"] = observation.hints
    data["done"] = observation.done
    return data


class SQLAnalystClient:
    """Client for interacting with the SQL Data Analyst environment.

    Wraps a local ``SQLAnalystEnv`` instance and converts the objects it
    returns into plain dictionaries.
    """

    def __init__(self, task_id: str = DEFAULT_TASK_ID):
        """Create the wrapped environment for ``task_id``."""
        self.env = SQLAnalystEnv(task_id=task_id)
        self.task_id = task_id

    def reset(self) -> Dict[str, Any]:
        """Reset the environment and return the initial observation.

        Returns:
            Dict with keys ``observation`` (``schema_summary``, ``question``,
            ``step``, ``max_steps``, ``hints``, ``done``), ``reward`` and
            ``done``.
        """
        result = self.env.reset()
        return {
            "observation": _serialize_observation(
                result.observation, include_query_fields=False
            ),
            "reward": result.reward,
            "done": result.done,
        }

    def step(self, action: Action) -> Dict[str, Any]:
        """Execute an action and return the result.

        Args:
            action: The ``Action`` forwarded unchanged to ``env.step``.

        Returns:
            Dict with keys ``observation`` (the ``reset`` fields plus
            ``last_query``, ``last_result`` and ``last_error``), ``reward``,
            ``done`` and ``info``. ``last_result`` is always a dict with
            ``columns``/``rows``/``error`` keys, all ``None`` when the
            environment reported no result.
        """
        result = self.env.step(action)
        return {
            "observation": _serialize_observation(
                result.observation, include_query_fields=True
            ),
            "reward": result.reward,
            "done": result.done,
            "info": result.info,
        }

    def state(self) -> Dict[str, Any]:
        """Get the current state of the environment as a plain dict.

        Returns:
            Dict with keys ``task_id``, ``difficulty``, ``step``,
            ``max_steps``, ``query_history``, ``total_reward`` and ``done``,
            copied from ``env.state()``.
        """
        snapshot = self.env.state()
        return {
            "task_id": snapshot.task_id,
            "difficulty": snapshot.difficulty,
            "step": snapshot.step,
            "max_steps": snapshot.max_steps,
            "query_history": snapshot.query_history,
            "total_reward": snapshot.total_reward,
            "done": snapshot.done,
        }

    def execute_sql(self, query: str) -> Dict[str, Any]:
        """Execute a SQL query by stepping with ``Action(sql_query=query)``.

        Returns the same dict shape as :meth:`step`.
        """
        action = Action(sql_query=query)
        return self.step(action)

    def submit_answer(self, answer: str) -> Dict[str, Any]:
        """Submit the final answer by stepping with ``Action(submit_answer=answer)``.

        Returns the same dict shape as :meth:`step`.
        """
        action = Action(submit_answer=answer)
        return self.step(action)


def get_client(task_id: str = DEFAULT_TASK_ID) -> SQLAnalystClient:
    """Get a client instance for the specified task."""
    return SQLAnalystClient(task_id=task_id)


__all__ = ["SQLAnalystClient", "get_client"]