""" Typed HTTP client for Gov Workflow OpenEnv. This keeps a simple OpenEnv-style client interface: reset() -> observation wrapper step(action) -> step wrapper state() -> state wrapper """ from __future__ import annotations from dataclasses import dataclass from typing import Any, TYPE_CHECKING import requests try: from openenv.core import EnvClient from openenv.core.env_client import StepResult except ModuleNotFoundError: EnvClient = None # type: ignore[assignment] StepResult = None # type: ignore[assignment] if TYPE_CHECKING: from app.models import ActionModel, EpisodeStateModel, ObservationModel, StepInfoModel @dataclass class ClientStepResult: observation: "ObservationModel" reward: float done: bool terminated: bool truncated: bool info: "StepInfoModel" class GovWorkflowClient: """Small typed client for the FastAPI deployment.""" def __init__(self, base_url: str) -> None: self.base_url = base_url.rstrip("/") self.session_id: str | None = None def _post(self, path: str, body: dict[str, Any]) -> dict[str, Any]: response = requests.post(f"{self.base_url}{path}", json=body, timeout=30) response.raise_for_status() return response.json() def reset(self, task_id: str = "district_backlog_easy", seed: int | None = None) -> "ObservationModel": from app.models import ObservationModel payload: dict[str, Any] = {"task_id": task_id} if seed is not None: payload["seed"] = seed data = self._post("/reset", payload) self.session_id = data["session_id"] return ObservationModel(**data["observation"]) def step(self, action: "ActionModel") -> ClientStepResult: from app.models import ObservationModel, StepInfoModel if not self.session_id: raise RuntimeError("Session not initialized. Call reset() first.") data = self._post( "/step", { "session_id": self.session_id, "action": action.model_dump(exclude_none=True), }, ) return ClientStepResult( observation=ObservationModel(**data["observation"]), reward=float(data["reward"]), done=bool(data["done"]), terminated=bool(data["terminated"]), truncated=bool(data["truncated"]), info=StepInfoModel(**data["info"]), ) def state(self, include_action_history: bool = False) -> "EpisodeStateModel": from app.models import EpisodeStateModel if not self.session_id: raise RuntimeError("Session not initialized. Call reset() first.") data = self._post( "/state", { "session_id": self.session_id, "include_action_history": include_action_history, }, ) return EpisodeStateModel(**data["state"]) if EnvClient is not None and StepResult is not None: class GovWorkflowOpenEnvClient( EnvClient["ActionModel", "ObservationModel", "EpisodeStateModel"] ): """ OpenEnv-native websocket client. This class is additive and does not replace the existing HTTP client above. """ def _step_payload(self, action: "ActionModel") -> dict[str, Any]: return action.model_dump(exclude_none=True, mode="json") def _parse_result(self, payload: dict[str, Any]) -> StepResult["ObservationModel"]: from app.models import ObservationModel observation_payload = payload.get("observation", {}) obs = ObservationModel(**observation_payload) return StepResult( observation=obs, reward=payload.get("reward"), done=bool(payload.get("done", False)), ) def _parse_state(self, payload: dict[str, Any]) -> "EpisodeStateModel": from app.models import EpisodeStateModel state_payload = payload.get("state", payload) return EpisodeStateModel(**state_payload) else: class GovWorkflowOpenEnvClient: # type: ignore[no-redef] """ Placeholder when optional `openenv` package is unavailable. """ def __init__(self, *args: Any, **kwargs: Any) -> None: raise ModuleNotFoundError( "GovWorkflowOpenEnvClient requires the optional 'openenv' package. " "Install it to use websocket OpenEnv client features." )