"""WebSocket client for the DataDetective environment."""

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from .models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState

# Fallback step budget used when the server omits ``max_steps``. Shared by
# observation and state parsing so the two defaults cannot drift apart.
_DEFAULT_MAX_STEPS = 30


class DataDetectiveEnv(
    EnvClient[DataDetectiveAction, DataDetectiveObservation, DataDetectiveState]
):
    """
    Async/sync client for DataDetective.

    Example (sync):
        >>> with DataDetectiveEnv(base_url="http://localhost:7860").sync() as env:
        ...     result = env.reset(task_id="orders_drop")
        ...     result = env.step(DataDetectiveAction(action_type="query", content="SELECT COUNT(*) FROM orders"))
    """

    def _step_payload(self, action: DataDetectiveAction) -> Dict[str, Any]:
        """Serialize *action* into the JSON-able dict sent to the server.

        Args:
            action: The action to send; only its ``action_type`` and
                ``content`` fields are transmitted.

        Returns:
            A dict with the keys ``"action_type"`` and ``"content"``.
        """
        return {"action_type": action.action_type, "content": action.content}

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[DataDetectiveObservation]:
        """Convert a server step/reset response into a ``StepResult``.

        ``reward`` and ``done`` are read from the top level of *payload* and
        copied onto both the observation and the returned ``StepResult``.
        Observation fields missing from the response fall back to empty
        strings, ``0`` for ``step_number``, and ``_DEFAULT_MAX_STEPS`` for
        ``max_steps``.

        Args:
            payload: Decoded server response, expected to carry an
                ``"observation"`` mapping plus top-level ``"reward"`` and
                ``"done"`` entries.

        Returns:
            The parsed observation wrapped in a ``StepResult``.
        """
        # ``or {}`` also covers an explicit ``"observation": null``. With a
        # plain ``.get(..., {})`` that would come through as ``None`` and the
        # ``obs.get`` calls below would raise AttributeError.
        obs = payload.get("observation") or {}
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = DataDetectiveObservation(
            output=obs.get("output", ""),
            task_description=obs.get("task_description", ""),
            schema_info=obs.get("schema_info", ""),
            step_number=obs.get("step_number", 0),
            max_steps=obs.get("max_steps", _DEFAULT_MAX_STEPS),
            message=obs.get("message", ""),
            done=done,
            reward=reward,
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> DataDetectiveState:
        """Convert a server state response into a ``DataDetectiveState``.

        Missing fields fall back to ``None`` for ``episode_id``, ``0`` for the
        counters, ``""`` for ``task_id``, and ``_DEFAULT_MAX_STEPS`` for
        ``max_steps``.

        Args:
            payload: Decoded server state response.

        Returns:
            The parsed environment state.
        """
        return DataDetectiveState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            task_id=payload.get("task_id", ""),
            queries_executed=payload.get("queries_executed", 0),
            max_steps=payload.get("max_steps", _DEFAULT_MAX_STEPS),
        )