Spaces:
Runtime error
Runtime error
| # salespath_env/client.py | |
| from typing import Any, Dict | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from .models import ( | |
| SalesPathAction, | |
| SalesPathObservation, | |
| SalesPathState, | |
| ) | |
class SalesPathEnv(EnvClient[SalesPathAction, SalesPathObservation, SalesPathState]):
    """Async client for the SalesPath environment.

    Implements the serialisation hooks that ``EnvClient`` requires, and
    overrides ``reset``/``step`` with convenience signatures that return the
    ``SalesPathObservation`` directly. The returned observation has
    ``reward``/``done`` copied onto it, instead of being wrapped in a
    ``StepResult``.
    """

    # ------------------------------------------------------------------ #
    # Abstract method implementations required by EnvClient              #
    # ------------------------------------------------------------------ #

    def _step_payload(self, action: SalesPathAction) -> Dict[str, Any]:
        """Serialise action → JSON dict for the WebSocket server.

        WSStepMessage.data IS the action dict directly (no wrapper key), so
        the dumped model is returned as-is, minus its ``metadata`` field.
        """
        return action.model_dump(exclude={"metadata"})

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[SalesPathObservation]:
        """Deserialise server JSON → StepResult[SalesPathObservation].

        The observation may be nested under an ``"observation"`` key, or the
        payload itself may be the observation. Top-level ``reward``/``done``
        take precedence. When they are absent *or* ``null``, the values
        carried on the observation are used instead.
        """
        obs_data = payload.get("observation")
        if obs_data is None:
            # Flat payload (or explicit null): observation fields are top-level.
            obs_data = payload
        obs = SalesPathObservation(**obs_data)

        # An explicit ``null`` at top level must not clobber the observation's
        # own value, so fall back on None as well as on a missing key.
        reward = payload.get("reward")
        if reward is None:
            reward = obs.reward
        done = payload.get("done")
        if done is None:
            done = obs.done

        return StepResult(observation=obs, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> SalesPathState:
        """Deserialise server JSON → SalesPathState.

        The state may be nested under a ``"state"`` key, or the payload
        itself may be the state.
        """
        state_data = payload.get("state")
        if state_data is None:
            # Flat payload (or explicit null): state fields are top-level.
            state_data = payload
        return SalesPathState(**state_data)

    # ------------------------------------------------------------------ #
    # Convenience wrappers that return the unwrapped observation directly #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _with_step_fields(
        result: StepResult[SalesPathObservation],
    ) -> SalesPathObservation:
        """Return a copy of ``result.observation`` with reward/done synced.

        Some server payloads provide reward/done only at top level, so the
        wrapper's values are copied onto the observation.

        This must be a staticmethod: it takes no ``self``, but it is invoked
        as ``self._with_step_fields(result)``. Without the decorator, the
        instance would be bound to ``result`` and the call would raise
        ``TypeError``.
        """
        return result.observation.model_copy(
            update={
                "reward": result.reward,
                "done": result.done,
            }
        )

    async def reset(
        self,
        difficulty: int = 1,
    ) -> SalesPathObservation:
        """Start a new episode and return its initial observation.

        ``difficulty`` is forwarded to ``EnvClient.reset`` as a keyword
        argument.
        """
        result = await super().reset(difficulty=difficulty)
        return self._with_step_fields(result)

    async def step(
        self,
        action_type: str,
        content: str,
        target: str = "",
    ) -> SalesPathObservation:
        """Build a ``SalesPathAction`` from the arguments, send it, and return
        the resulting observation (with reward/done copied onto it).
        """
        action = SalesPathAction(
            action_type=action_type,
            content=content,
            target=target,
        )
        result = await super().step(action)
        return self._with_step_fields(result)