"""WebSocket client for CERNenv. Wraps OpenEnv's ``EnvClient`` so users can ``await client.reset()`` and ``await client.step(action)`` against a running CERNenv server. """ from __future__ import annotations from typing import Any, Dict from openenv.core import EnvClient from openenv.core.client_types import StepResult from models import CollisionObservation, ExperimentAction from server.environment import CernState class CernEnv(EnvClient[ExperimentAction, CollisionObservation, CernState]): """Async WebSocket client for the CERN environment.""" def _step_payload(self, action: ExperimentAction) -> Dict[str, Any]: return action.model_dump() def _parse_result(self, payload: Dict[str, Any]) -> StepResult[CollisionObservation]: obs_data = payload.get("observation", payload) observation = CollisionObservation(**obs_data) return StepResult( observation=observation, reward=payload.get("reward", observation.reward), done=payload.get("done", observation.done), ) def _parse_state(self, payload: Dict[str, Any]) -> CernState: return CernState(**payload) __all__ = ["CernEnv"]