Spaces:
Sleeping
Sleeping
File size: 3,354 Bytes
d103a0f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | """
OpenEnv client for SQL Data Analyst environment.
Provides a Python client interface to interact with the environment.
"""
from typing import Dict, Any, Optional
from env import SQLAnalystEnv, Action
class SQLAnalystClient:
    """Client for interacting with the SQL Data Analyst environment.

    Each public method forwards to a wrapped ``SQLAnalystEnv`` instance and
    converts the objects it returns into plain ``dict`` structures, so callers
    never have to handle the environment's own result/observation types.
    """

    def __init__(self, task_id: str = "monthly_signups"):
        """Create a client that owns its own environment instance.

        Args:
            task_id: Task identifier forwarded to ``SQLAnalystEnv`` and kept
                on the client as ``self.task_id``.
        """
        self.env = SQLAnalystEnv(task_id=task_id)
        self.task_id = task_id

    @staticmethod
    def _serialize_last_result(last_result) -> Dict[str, Any]:
        """Flatten an observation's ``last_result`` into a plain dict.

        The returned dict always has the keys ``"columns"``, ``"rows"`` and
        ``"error"``. When ``last_result`` is falsy (e.g. ``None`` because no
        query has run yet), every value is ``None``.
        """
        # Truthiness check (not ``is None``) mirrors the original per-field
        # conditionals exactly.
        if not last_result:
            return {"columns": None, "rows": None, "error": None}
        return {
            "columns": last_result.columns,
            "rows": last_result.rows,
            "error": last_result.error,
        }

    def reset(self) -> Dict[str, Any]:
        """Reset the environment and return the initial observation.

        Returns:
            A dict with keys ``"observation"``, ``"reward"`` and ``"done"``.
            ``"observation"`` holds ``schema_summary``, ``question``,
            ``step``, ``max_steps``, ``hints`` and ``done`` copied from the
            result of ``SQLAnalystEnv.reset()``.
        """
        result = self.env.reset()
        obs = result.observation
        return {
            "observation": {
                "schema_summary": obs.schema_summary,
                "question": obs.question,
                "step": obs.step,
                "max_steps": obs.max_steps,
                "hints": obs.hints,
                "done": obs.done,
            },
            "reward": result.reward,
            "done": result.done,
        }

    def step(self, action: Action) -> Dict[str, Any]:
        """Execute an action and return the serialized step result.

        Args:
            action: The ``Action`` forwarded to ``SQLAnalystEnv.step``.

        Returns:
            A dict with keys ``"observation"``, ``"reward"``, ``"done"`` and
            ``"info"``. ``observation["last_result"]`` always contains the
            keys ``"columns"``, ``"rows"`` and ``"error"``; all three are
            ``None`` when the environment reports no last result.
        """
        result = self.env.step(action)
        obs = result.observation
        return {
            "observation": {
                "schema_summary": obs.schema_summary,
                "question": obs.question,
                "last_query": obs.last_query,
                "last_result": self._serialize_last_result(obs.last_result),
                "last_error": obs.last_error,
                "step": obs.step,
                "max_steps": obs.max_steps,
                "hints": obs.hints,
                "done": obs.done,
            },
            "reward": result.reward,
            "done": result.done,
            "info": result.info,
        }

    def state(self) -> Dict[str, Any]:
        """Return the environment's current state as a dict.

        Values are passed through from ``SQLAnalystEnv.state()`` unchanged;
        in particular ``query_history`` is the environment's own object, not
        a copy.
        """
        env_state = self.env.state()
        return {
            "task_id": env_state.task_id,
            "difficulty": env_state.difficulty,
            "step": env_state.step,
            "max_steps": env_state.max_steps,
            "query_history": env_state.query_history,
            "total_reward": env_state.total_reward,
            "done": env_state.done,
        }

    def execute_sql(self, query: str) -> Dict[str, Any]:
        """Run ``query`` as an ``Action(sql_query=...)`` via :meth:`step`."""
        action = Action(sql_query=query)
        return self.step(action)

    def submit_answer(self, answer: str) -> Dict[str, Any]:
        """Submit ``answer`` as an ``Action(submit_answer=...)`` via :meth:`step`."""
        action = Action(submit_answer=answer)
        return self.step(action)
def get_client(task_id: str = "monthly_signups") -> SQLAnalystClient:
    """Build a new ``SQLAnalystClient`` for ``task_id``.

    Convenience factory: every call constructs a separate client object; no
    instance is cached or reused.
    """
    client = SQLAnalystClient(task_id=task_id)
    return client


__all__ = [
    "SQLAnalystClient",
    "get_client",
]
|