File size: 4,310 Bytes
5dd1bb4 9e64e71 5dd1bb4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | from typing import Any, Dict, Iterable
from openenv.core.client_types import StepResult
from openenv.core.env_server.interfaces import Message
from openenv.core.env_client import EnvClient
from .models import SQLAction, SQLObservation, SQLState
class SQLEnvClient(EnvClient[SQLAction, SQLObservation, SQLState]):
    """Client for interacting with the SQLEnv environment server.

    Provides the payload/response conversion hooks (``_step_payload``,
    ``_parse_result``, ``_parse_state``) plus ``message_to_action``, which
    turns a chat-style message into a ``SQLAction``.
    """

    @staticmethod
    def _non_null(source: Dict[str, Any], key: str, default: Any = None) -> Any:
        """Return ``source[key]``, or ``default`` if it is missing or ``None``.

        JSON ``null`` decodes to ``None``; treating it like a missing key keeps
        ``str(None)`` from turning into the literal text ``"None"`` and keeps
        ``int(None)`` / ``list(None)`` from raising ``TypeError``.
        """
        value = source.get(key)
        return default if value is None else value

    def _step_payload(self, action: SQLAction) -> Dict[str, Any]:
        """Convert a SQLAction into the payload for the step endpoint.

        Args:
            action: The action to serialize.

        Returns:
            A dict with the action's ``action_type``, ``argument`` and
            ``metadata`` under keys of the same names.
        """
        return {
            "action_type": action.action_type,
            "argument": action.argument,
            "metadata": action.metadata,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SQLObservation]:
        """Parse the response from the step endpoint into a StepResult.

        The observation fields are read from ``payload["observation"]`` when
        that value is a dict; otherwise they are read from ``payload`` itself.
        ``done`` and ``reward`` are taken from the top level of ``payload``
        first and from the observation data second.

        Args:
            payload: Decoded response body of the step endpoint.

        Returns:
            A ``StepResult`` wrapping the reconstructed ``SQLObservation``.
        """
        obs_data = payload.get("observation")
        if not isinstance(obs_data, dict):
            # Flat response: the observation fields live at the top level.
            obs_data = payload

        non_null = self._non_null

        # Top-level values win; a missing *or null* top-level value falls back
        # to the one nested in the observation.
        done = bool(non_null(payload, "done", non_null(obs_data, "done", False)))
        reward = non_null(payload, "reward", non_null(obs_data, "reward"))

        observation = SQLObservation(
            question=str(non_null(obs_data, "question", "")),
            schema_info=str(non_null(obs_data, "schema_info", "")),
            result=str(non_null(obs_data, "result", "")),
            error=str(non_null(obs_data, "error", "")),
            step_count=int(non_null(obs_data, "step_count", 0)),
            budget_remaining=int(non_null(obs_data, "budget_remaining", 0)),
            action_history=list(non_null(obs_data, "action_history", [])),
            done=done,
            reward=reward,
            metadata=non_null(obs_data, "metadata", {}),
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> SQLState:
        """Build a SQLState from a decoded state payload.

        Missing keys default to ``None`` for ``episode_id``, ``0`` for
        ``step_count``, an empty list for ``history_messages`` and ``"QUERY"``
        for ``current_action_type``.
        """
        return SQLState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            history_messages=payload.get("history_messages", []),
            current_action_type=payload.get("current_action_type", "QUERY"),
        )

    def _detect_action_type(self, message_content: str) -> str:
        """Detect the action type from user message content.

        Matching is case-insensitive and checked in this order:

        1. content starting with ``"answer "`` -> ``"ANSWER"``
        2. content containing a describe keyword -> ``"DESCRIBE"``
        3. content containing a sample keyword -> ``"SAMPLE"``
        4. anything else -> ``"QUERY"``

        Keywords are matched as plain substrings, so e.g. ``"data"`` also
        matches inside ``"database"``.
        """
        content_lower = message_content.lower()

        if content_lower.startswith("answer "):
            return "ANSWER"

        describe_keywords = (
            "describe",
            "schema",
            "columns",
            "structure",
            "what columns",
            "show columns",
        )
        if any(keyword in content_lower for keyword in describe_keywords):
            return "DESCRIBE"

        sample_keywords = (
            "sample",
            "example",
            "rows",
            "data",
            "show me",
            "few rows",
            "how many",
        )
        if any(keyword in content_lower for keyword in sample_keywords):
            return "SAMPLE"

        return "QUERY"

    def message_to_action(
        self,
        message: Message,
        tokenizer: Any,
        history_messages: Iterable[Message] | None = None,
    ) -> SQLAction:
        """Convert a user Message into a SQLAction.

        For a non-empty message with role ``"user"``:

        * if the first whitespace-delimited word is ``DESCRIBE``, ``SAMPLE``,
          ``QUERY`` or ``ANSWER`` (any case), that word is the action type and
          the rest of the message (empty string if nothing follows) is the
          argument;
        * otherwise the action type comes from ``_detect_action_type`` and the
          argument is the stripped message content.

        For any other role, or for whitespace-only content, the action is a
        ``QUERY`` whose argument is the unmodified content.

        Args:
            message: Mapping with at least ``"role"`` and ``"content"`` keys.
            tokenizer: Unused; accepted for interface compatibility.
            history_messages: Unused; accepted for interface compatibility.

        Returns:
            The ``SQLAction`` derived from the message.

        Raises:
            ValueError: If ``"role"`` or ``"content"`` is missing, or if the
                content is ``None``.
        """
        if "role" not in message:
            raise ValueError("Message must contain a 'role' key")
        if "content" not in message:
            raise ValueError("Message must contain a 'content' key")
        if message["content"] is None:
            raise ValueError("Message content cannot be None")

        # Part of the public signature but not needed for this conversion.
        del tokenizer, history_messages

        content = str(message["content"])
        parsed = content.strip()

        action_type = "QUERY"
        argument = content

        if message["role"].lower() == "user" and parsed:
            # Split on any whitespace run (space, tab, newline) so inputs such
            # as "DESCRIBE\temployees" are recognised as explicit commands and
            # the argument never carries stray leading whitespace.
            pieces = parsed.split(maxsplit=1)
            normalized_prefix = pieces[0].upper()
            if normalized_prefix in {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}:
                action_type = normalized_prefix
                argument = pieces[1] if len(pieces) > 1 else ""
            else:
                action_type = self._detect_action_type(parsed)
                argument = parsed

        return SQLAction(
            action_type=action_type,
            argument=argument,
        )
|