Spaces:
Sleeping
Sleeping
File size: 1,875 Bytes
"""OpenEnv client for PolypharmacyEnv.

Provides a typed async/sync client for interacting with a PolypharmacyEnv
server via WebSocket, following the OpenEnv EnvClient pattern.

Example (async):
    >>> async with PolypharmacyClient(base_url="ws://localhost:8000") as env:
    ...     result = await env.reset(task_id="easy_screening")
    ...     result = await env.step(PolypharmacyAction(action_type="finish_review"))

Example (sync):
    >>> with PolypharmacyClient(base_url="ws://localhost:8000").sync() as env:
    ...     result = env.reset(task_id="easy_screening")
"""
from __future__ import annotations
from typing import Any, Dict
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from .models import PolypharmacyAction, PolypharmacyObservation, PolypharmacyState
class PolypharmacyClient(
    EnvClient[PolypharmacyAction, PolypharmacyObservation, PolypharmacyState]
):
    """EnvClient specialisation bound to the PolypharmacyEnv message types.

    Supplies the three conversion hooks between the typed models and the
    plain dict payloads: action -> payload, payload -> StepResult, and
    payload -> PolypharmacyState.
    """

    def _step_payload(self, action: PolypharmacyAction) -> Dict[str, Any]:
        """Serialise ``action`` into the dict that is sent to the server.

        Uses ``model_dump(exclude_none=True)``, so fields whose value is
        ``None`` are left out of the payload.
        """
        serialised = action.model_dump(exclude_none=True)
        return serialised

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[PolypharmacyObservation]:
        """Build a StepResult carrying a validated observation.

        The observation is read from ``payload["observation"]``; when that
        key is missing the whole payload is validated as the observation.
        ``reward`` is ``None`` and ``done`` is ``False`` when the respective
        keys are absent.
        """
        # Fall back to the full payload if there is no "observation" key.
        raw_observation = payload.get("observation", payload)
        typed_observation = PolypharmacyObservation.model_validate(raw_observation)

        reward = payload.get("reward")
        finished = payload.get("done", False)
        return StepResult(
            observation=typed_observation,
            reward=reward,
            done=finished,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> PolypharmacyState:
        """Validate ``payload`` into a PolypharmacyState and return it."""
        state = PolypharmacyState.model_validate(payload)
        return state
|