File size: 3,354 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
OpenEnv client for SQL Data Analyst environment.

Provides a Python client interface to interact with the environment.
"""

from typing import Dict, Any, Optional
from env import SQLAnalystEnv, Action


class SQLAnalystClient:
    """Wrap an in-process ``SQLAnalystEnv`` and return plain dictionaries.

    Every public method delegates to the wrapped environment and flattens
    the objects it returns into ``dict``s so callers need not depend on the
    environment's own result types.
    """

    def __init__(self, task_id: str = "monthly_signups"):
        """Create the client and its underlying environment.

        Args:
            task_id: Task identifier forwarded to ``SQLAnalystEnv``; it is
                also kept on the client as ``self.task_id``.
        """
        self.env = SQLAnalystEnv(task_id=task_id)
        self.task_id = task_id

    def reset(self) -> Dict[str, Any]:
        """Reset the wrapped environment.

        Returns:
            ``{"observation": {...}, "reward": ..., "done": ...}`` where the
            observation carries the schema summary, question, step counters,
            hints and its own ``done`` flag.
        """
        outcome = self.env.reset()
        obs = outcome.observation

        observation = {
            "schema_summary": obs.schema_summary,
            "question": obs.question,
            "step": obs.step,
            "max_steps": obs.max_steps,
            "hints": obs.hints,
            "done": obs.done,
        }
        return {
            "observation": observation,
            "reward": outcome.reward,
            "done": outcome.done,
        }

    def step(self, action: Action) -> Dict[str, Any]:
        """Forward ``action`` to the environment and flatten the outcome.

        The ``last_result`` entry is always present in the observation; when
        the environment reports no (falsy) last result, its ``columns``,
        ``rows`` and ``error`` values are all ``None``.

        Args:
            action: The action passed unchanged to ``self.env.step``.

        Returns:
            A dict with ``observation``, ``reward``, ``done`` and ``info``.
        """
        outcome = self.env.step(action)
        obs = outcome.observation

        latest = obs.last_result
        if latest:
            last_result = {
                "columns": latest.columns,
                "rows": latest.rows,
                "error": latest.error,
            }
        else:
            # Same keys as the populated case, all defaulting to None.
            last_result = dict.fromkeys(("columns", "rows", "error"))

        observation = {
            "schema_summary": obs.schema_summary,
            "question": obs.question,
            "last_query": obs.last_query,
            "last_result": last_result,
            "last_error": obs.last_error,
            "step": obs.step,
            "max_steps": obs.max_steps,
            "hints": obs.hints,
            "done": obs.done,
        }
        return {
            "observation": observation,
            "reward": outcome.reward,
            "done": outcome.done,
            "info": outcome.info,
        }

    def state(self) -> Dict[str, Any]:
        """Return a dict snapshot of the environment's ``state()`` object.

        The snapshot copies the task id, difficulty, step counters, query
        history, accumulated reward and done flag.
        """
        snapshot = self.env.state()
        field_names = (
            "task_id",
            "difficulty",
            "step",
            "max_steps",
            "query_history",
            "total_reward",
            "done",
        )
        return {name: getattr(snapshot, name) for name in field_names}

    def execute_sql(self, query: str) -> Dict[str, Any]:
        """Run ``query`` by stepping with an ``Action(sql_query=query)``."""
        return self.step(Action(sql_query=query))

    def submit_answer(self, answer: str) -> Dict[str, Any]:
        """Submit ``answer`` by stepping with an ``Action(submit_answer=answer)``."""
        return self.step(Action(submit_answer=answer))


def get_client(task_id: str = "monthly_signups") -> SQLAnalystClient:
    """Factory returning a fresh ``SQLAnalystClient`` bound to ``task_id``.

    Equivalent to calling ``SQLAnalystClient(task_id=task_id)`` directly.
    """
    client = SQLAnalystClient(task_id=task_id)
    return client


# Public API of this module: the client class and its factory function.
__all__ = ["SQLAnalystClient", "get_client"]