File size: 4,310 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e64e71
 
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from typing import Any, Dict, Iterable

from openenv.core.client_types import StepResult

from openenv.core.env_server.interfaces import Message
from openenv.core.env_client import EnvClient

from .models import SQLAction, SQLObservation, SQLState


class SQLEnvClient(EnvClient[SQLAction, SQLObservation, SQLState]):
    """Client for interacting with the SQLEnv environment server.

    Implements the payload (de)serialisation hooks for ``SQLAction``,
    ``SQLObservation`` and ``SQLState`` and adds ``message_to_action``, which
    maps a chat-style message onto a ``SQLAction``.
    """

    # Leading keywords (compared case-insensitively) that select the action
    # type explicitly in ``message_to_action``.
    _EXPLICIT_ACTION_TYPES = frozenset(
        {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}
    )

    # Keyword heuristics used by ``_detect_action_type``. Matching is by plain
    # substring on the lower-cased text, so e.g. "columns" also matches
    # "what columns" and "rows" also matches "few rows".
    _DESCRIBE_KEYWORDS = (
        "describe",
        "schema",
        "columns",
        "structure",
        "what columns",
        "show columns",
    )
    _SAMPLE_KEYWORDS = (
        "sample",
        "example",
        "rows",
        "data",
        "show me",
        "few rows",
        "how many",
    )

    def _step_payload(self, action: SQLAction) -> Dict[str, Any]:
        """Convert a SQLAction into the payload for the step endpoint.

        Args:
            action: The action to serialise.

        Returns:
            A dict with the keys ``action_type``, ``argument`` and
            ``metadata`` copied from ``action``.
        """
        return {
            "action_type": action.action_type,
            "argument": action.argument,
            "metadata": action.metadata,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SQLObservation]:
        """Parse the response from the step endpoint into a StepResult.

        The observation fields are read from ``payload["observation"]`` when
        that value is a dict; otherwise ``payload`` itself is treated as the
        observation. ``done`` and ``reward`` are taken from the top level of
        ``payload`` when present there, else from the observation.

        Fields that are missing *or explicitly null* fall back to an empty
        value (``""``, ``0``, ``[]`` or ``{}``).

        Args:
            payload: Decoded response body.

        Returns:
            A ``StepResult`` wrapping the parsed ``SQLObservation``.
        """
        obs_data = payload.get("observation")
        if not isinstance(obs_data, dict):
            obs_data = payload

        done = bool(payload.get("done", obs_data.get("done", False)))
        reward = payload.get("reward", obs_data.get("reward"))

        def text_field(key: str) -> str:
            # dict.get's default only covers a missing key; an explicit null
            # must not be stringified into the literal "None".
            value = obs_data.get(key)
            return "" if value is None else str(value)

        observation = SQLObservation(
            question=text_field("question"),
            schema_info=text_field("schema_info"),
            result=text_field("result"),
            error=text_field("error"),
            # ``or`` maps an explicit null to the empty default instead of
            # letting int()/list() raise TypeError on None.
            step_count=int(obs_data.get("step_count") or 0),
            budget_remaining=int(obs_data.get("budget_remaining") or 0),
            action_history=list(obs_data.get("action_history") or []),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> SQLState:
        """Parse the response from the state endpoint into a SQLState.

        Args:
            payload: Decoded response body.

        Returns:
            A ``SQLState`` built from ``episode_id``, ``step_count``,
            ``history_messages`` and ``current_action_type``. Missing keys
            default to ``None``, ``0``, ``[]`` and ``"QUERY"`` respectively.
        """
        return SQLState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            history_messages=payload.get("history_messages", []),
            current_action_type=payload.get("current_action_type", "QUERY"),
        )

    def _detect_action_type(self, message_content: str) -> str:
        """Detect the action type from user message content.

        Checks are applied in order on the lower-cased text: an ``"answer "``
        prefix yields ``"ANSWER"``; any describe keyword yields
        ``"DESCRIBE"``; any sample keyword yields ``"SAMPLE"``; otherwise the
        result is ``"QUERY"``. Keywords match as substrings, not whole words.

        Args:
            message_content: Free-form message text.

        Returns:
            One of ``"ANSWER"``, ``"DESCRIBE"``, ``"SAMPLE"`` or ``"QUERY"``.
        """
        content_lower = message_content.lower()

        if content_lower.startswith("answer "):
            return "ANSWER"
        if any(keyword in content_lower for keyword in self._DESCRIBE_KEYWORDS):
            return "DESCRIBE"
        if any(keyword in content_lower for keyword in self._SAMPLE_KEYWORDS):
            return "SAMPLE"
        return "QUERY"

    def message_to_action(
        self,
        message: Message,
        tokenizer: Any,
        history_messages: Iterable[Message] | None = None,
    ) -> SQLAction:
        """Convert a user Message into a SQLAction.

        For a non-empty message whose role is ``"user"`` (case-insensitive):

        * if the first whitespace-delimited token is one of ``DESCRIBE``,
          ``SAMPLE``, ``QUERY`` or ``ANSWER`` (any case), that token becomes
          the action type and the rest of the text (leading whitespace
          removed) becomes the argument;
        * otherwise the action type is guessed by ``_detect_action_type`` and
          the stripped text is the argument.

        Any other message becomes a ``"QUERY"`` action whose argument is the
        unmodified content.

        Args:
            message: Mapping with at least the keys ``"role"`` and
                ``"content"``.
            tokenizer: Unused; accepted for interface compatibility.
            history_messages: Unused; accepted for interface compatibility.

        Returns:
            The ``SQLAction`` derived from ``message``.

        Raises:
            ValueError: If ``"role"`` or ``"content"`` is missing, or the
                content is ``None``.
        """
        if "role" not in message:
            raise ValueError("Message must contain a 'role' key")
        if "content" not in message:
            raise ValueError("Message must contain a 'content' key")
        if message["content"] is None:
            raise ValueError("Message content cannot be None")

        # Part of the signature but not needed by this conversion.
        del tokenizer, history_messages

        content = str(message["content"])
        parsed = content.strip()

        action_type = "QUERY"
        argument = content
        if message["role"].lower() == "user" and parsed:
            # Split on any whitespace run, not just a single space, so that
            # "QUERY\nSELECT ..." or "DESCRIBE\tusers" is recognised as an
            # explicit prefix. ``parsed`` is non-empty and stripped, so there
            # is always at least one token.
            tokens = parsed.split(None, 1)
            normalized_prefix = tokens[0].upper()
            if normalized_prefix in self._EXPLICIT_ACTION_TYPES:
                action_type = normalized_prefix
                argument = tokens[1] if len(tokens) > 1 else ""
            else:
                action_type = self._detect_action_type(parsed)
                argument = parsed

        return SQLAction(
            action_type=action_type,
            argument=argument,
        )