Spaces:
Sleeping
Sleeping
| """ | |
| OpenEnv client for SQL Data Analyst environment. | |
| Provides a Python client interface to interact with the environment. | |
| """ | |
| from typing import Dict, Any, Optional | |
| from env import SQLAnalystEnv, Action | |
class SQLAnalystClient:
    """Client for interacting with the SQL Data Analyst environment.

    Wraps a local ``SQLAnalystEnv`` and converts its result objects into
    plain ``dict`` payloads. Each payload always has the same key names and
    key order, so consumers can rely on a fixed shape.
    """

    def __init__(self, task_id: str = "monthly_signups"):
        """Create a client backed by a new ``SQLAnalystEnv`` for ``task_id``."""
        self.env = SQLAnalystEnv(task_id=task_id)
        self.task_id = task_id

    @staticmethod
    def _query_result_to_dict(query_result: Any) -> Dict[str, Any]:
        """Flatten a query result into ``columns``/``rows``/``error``.

        A missing result (``None``) still yields all three keys, each set to
        ``None``, so the ``last_result`` entry always has the same shape.
        """
        if query_result is None:
            return {"columns": None, "rows": None, "error": None}
        return {
            "columns": query_result.columns,
            "rows": query_result.rows,
            "error": query_result.error,
        }

    @classmethod
    def _observation_to_dict(
        cls, observation: Any, *, include_query: bool
    ) -> Dict[str, Any]:
        """Serialize an environment observation to a plain dict.

        Args:
            observation: Observation object returned by the environment.
            include_query: Also emit ``last_query``, ``last_result`` and
                ``last_error`` (used for ``step`` payloads).

        Returns:
            Dict of observation fields in a fixed key order.
        """
        payload: Dict[str, Any] = {
            "schema_summary": observation.schema_summary,
            "question": observation.question,
        }
        if include_query:
            # Query fields sit between ``question`` and ``step`` so step()
            # payloads keep a stable key order for consumers.
            payload["last_query"] = observation.last_query
            payload["last_result"] = cls._query_result_to_dict(
                observation.last_result
            )
            payload["last_error"] = observation.last_error
        payload["step"] = observation.step
        payload["max_steps"] = observation.max_steps
        payload["hints"] = observation.hints
        payload["done"] = observation.done
        return payload

    def reset(self) -> Dict[str, Any]:
        """Reset the environment and return the initial observation.

        Returns:
            Dict with keys ``observation`` (schema_summary, question, step,
            max_steps, hints, done), ``reward`` and ``done``.
        """
        result = self.env.reset()
        return {
            "observation": self._observation_to_dict(
                result.observation, include_query=False
            ),
            "reward": result.reward,
            "done": result.done,
        }

    def step(self, action: "Action") -> Dict[str, Any]:
        """Execute ``action`` in the environment and return the result.

        Returns:
            Dict with keys ``observation`` (schema_summary, question,
            last_query, last_result, last_error, step, max_steps, hints,
            done), ``reward``, ``done`` and ``info``.
        """
        result = self.env.step(action)
        return {
            "observation": self._observation_to_dict(
                result.observation, include_query=True
            ),
            "reward": result.reward,
            "done": result.done,
            "info": result.info,
        }

    def state(self) -> Dict[str, Any]:
        """Return the environment's current state as a plain dict.

        Returns:
            Dict with keys ``task_id``, ``difficulty``, ``step``,
            ``max_steps``, ``query_history``, ``total_reward`` and ``done``.
        """
        state = self.env.state()
        return {
            "task_id": state.task_id,
            "difficulty": state.difficulty,
            "step": state.step,
            "max_steps": state.max_steps,
            "query_history": state.query_history,
            "total_reward": state.total_reward,
            "done": state.done,
        }

    def execute_sql(self, query: str) -> Dict[str, Any]:
        """Run ``query`` as a SQL action and return the ``step`` payload."""
        action = Action(sql_query=query)
        return self.step(action)

    def submit_answer(self, answer: str) -> Dict[str, Any]:
        """Submit ``answer`` as the final answer and return the ``step`` payload."""
        action = Action(submit_answer=answer)
        return self.step(action)
def get_client(task_id: str = "monthly_signups") -> SQLAnalystClient:
    """Construct a fresh ``SQLAnalystClient`` for ``task_id``.

    Convenience factory: every call builds a new client, and the client's
    constructor creates its own ``SQLAnalystEnv``.
    """
    client = SQLAnalystClient(task_id=task_id)
    return client
# Public names exported by a star-import of this module.
__all__ = [
    "SQLAnalystClient",
    "get_client",
]