Spaces:
Running
Running
| """ | |
| Typed HTTP client for Gov Workflow OpenEnv. | |
| This keeps a simple OpenEnv-style client interface: | |
| reset() -> observation wrapper | |
| step(action) -> step wrapper | |
| state() -> state wrapper | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Any, TYPE_CHECKING | |
| import requests | |
| try: | |
| from openenv.core import EnvClient | |
| from openenv.core.env_client import StepResult | |
| except ModuleNotFoundError: | |
| EnvClient = None # type: ignore[assignment] | |
| StepResult = None # type: ignore[assignment] | |
| if TYPE_CHECKING: | |
| from app.models import ActionModel, EpisodeStateModel, ObservationModel, StepInfoModel | |
| class ClientStepResult: | |
| observation: "ObservationModel" | |
| reward: float | |
| done: bool | |
| terminated: bool | |
| truncated: bool | |
| info: "StepInfoModel" | |
| class GovWorkflowClient: | |
| """Small typed client for the FastAPI deployment.""" | |
| def __init__(self, base_url: str) -> None: | |
| self.base_url = base_url.rstrip("/") | |
| self.session_id: str | None = None | |
| def _post(self, path: str, body: dict[str, Any]) -> dict[str, Any]: | |
| response = requests.post(f"{self.base_url}{path}", json=body, timeout=30) | |
| response.raise_for_status() | |
| return response.json() | |
| def reset(self, task_id: str = "district_backlog_easy", seed: int | None = None) -> "ObservationModel": | |
| from app.models import ObservationModel | |
| payload: dict[str, Any] = {"task_id": task_id} | |
| if seed is not None: | |
| payload["seed"] = seed | |
| data = self._post("/reset", payload) | |
| self.session_id = data["session_id"] | |
| return ObservationModel(**data["observation"]) | |
| def step(self, action: "ActionModel") -> ClientStepResult: | |
| from app.models import ObservationModel, StepInfoModel | |
| if not self.session_id: | |
| raise RuntimeError("Session not initialized. Call reset() first.") | |
| data = self._post( | |
| "/step", | |
| { | |
| "session_id": self.session_id, | |
| "action": action.model_dump(exclude_none=True), | |
| }, | |
| ) | |
| return ClientStepResult( | |
| observation=ObservationModel(**data["observation"]), | |
| reward=float(data["reward"]), | |
| done=bool(data["done"]), | |
| terminated=bool(data["terminated"]), | |
| truncated=bool(data["truncated"]), | |
| info=StepInfoModel(**data["info"]), | |
| ) | |
| def state(self, include_action_history: bool = False) -> "EpisodeStateModel": | |
| from app.models import EpisodeStateModel | |
| if not self.session_id: | |
| raise RuntimeError("Session not initialized. Call reset() first.") | |
| data = self._post( | |
| "/state", | |
| { | |
| "session_id": self.session_id, | |
| "include_action_history": include_action_history, | |
| }, | |
| ) | |
| return EpisodeStateModel(**data["state"]) | |
| if EnvClient is not None and StepResult is not None: | |
| class GovWorkflowOpenEnvClient( | |
| EnvClient["ActionModel", "ObservationModel", "EpisodeStateModel"] | |
| ): | |
| """ | |
| OpenEnv-native websocket client. | |
| This class is additive and does not replace the existing HTTP client above. | |
| """ | |
| def _step_payload(self, action: "ActionModel") -> dict[str, Any]: | |
| return action.model_dump(exclude_none=True, mode="json") | |
| def _parse_result(self, payload: dict[str, Any]) -> StepResult["ObservationModel"]: | |
| from app.models import ObservationModel | |
| observation_payload = payload.get("observation", {}) | |
| obs = ObservationModel(**observation_payload) | |
| return StepResult( | |
| observation=obs, | |
| reward=payload.get("reward"), | |
| done=bool(payload.get("done", False)), | |
| ) | |
| def _parse_state(self, payload: dict[str, Any]) -> "EpisodeStateModel": | |
| from app.models import EpisodeStateModel | |
| state_payload = payload.get("state", payload) | |
| return EpisodeStateModel(**state_payload) | |
| else: | |
| class GovWorkflowOpenEnvClient: # type: ignore[no-redef] | |
| """ | |
| Placeholder when optional `openenv` package is unavailable. | |
| """ | |
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |
| raise ModuleNotFoundError( | |
| "GovWorkflowOpenEnvClient requires the optional 'openenv' package. " | |
| "Install it to use websocket OpenEnv client features." | |
| ) | |