Spaces:
Running
Running
File size: 4,554 Bytes
df97e68 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 | """
Typed HTTP client for Gov Workflow OpenEnv.
This keeps a simple OpenEnv-style client interface:
reset() -> observation wrapper
step(action) -> step wrapper
state() -> state wrapper
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, TYPE_CHECKING
import requests
try:
from openenv.core import EnvClient
from openenv.core.env_client import StepResult
except ModuleNotFoundError:
EnvClient = None # type: ignore[assignment]
StepResult = None # type: ignore[assignment]
if TYPE_CHECKING:
from app.models import ActionModel, EpisodeStateModel, ObservationModel, StepInfoModel
@dataclass
class ClientStepResult:
    """Typed bundle of the values returned by one ``/step`` call.

    Field order is part of the interface: it fixes the positional order of
    the generated ``__init__``.
    """

    # Observation model parsed from the response's "observation" entry.
    observation: "ObservationModel"
    # Reward reported by the server for this step.
    reward: float
    # Mirrors the "done" flag of the response.
    done: bool
    # Mirrors the "terminated" flag of the response.
    terminated: bool
    # Mirrors the "truncated" flag of the response.
    truncated: bool
    # Step-info model parsed from the response's "info" entry.
    info: "StepInfoModel"
class GovWorkflowClient:
    """Small typed client for the FastAPI deployment.

    The client POSTs JSON to ``/reset``, ``/step`` and ``/state`` under
    ``base_url`` and wraps the responses in the models from ``app.models``.
    ``reset()`` must be called first: it stores the ``session_id`` returned
    by the server, which ``step()`` and ``state()`` send with every request.
    """

    def __init__(self, base_url: str) -> None:
        """Remember the server root URL; no request is made here.

        Args:
            base_url: Root URL of the deployment. Trailing slashes are
                stripped so paths such as ``/reset`` can be appended directly.
        """
        self.base_url = base_url.rstrip("/")
        # Filled in by reset(); None means no session has been opened yet.
        self.session_id: str | None = None

    def _post(self, path: str, body: dict[str, Any]) -> dict[str, Any]:
        """POST ``body`` as JSON to ``base_url + path`` and return the decoded reply.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = requests.post(f"{self.base_url}{path}", json=body, timeout=30)
        response.raise_for_status()
        return response.json()

    def _require_session(self) -> str:
        """Return the active session id, or raise if ``reset()`` was never called."""
        if not self.session_id:
            raise RuntimeError("Session not initialized. Call reset() first.")
        return self.session_id

    def _step_body(self, action: "ActionModel") -> dict[str, Any]:
        """Build the JSON request body for ``/step``.

        The action is dumped with ``mode="json"`` (as the OpenEnv client in
        this module does) so every value is JSON-serialisable before it is
        handed to ``requests``.

        Raises:
            RuntimeError: If no session has been opened with ``reset()``.
        """
        return {
            # Evaluated first so a missing session fails before the action is touched.
            "session_id": self._require_session(),
            "action": action.model_dump(exclude_none=True, mode="json"),
        }

    def reset(self, task_id: str = "district_backlog_easy", seed: int | None = None) -> "ObservationModel":
        """Open a new session on the server and return its first observation.

        Args:
            task_id: Sent as ``task_id`` in the request body.
            seed: Sent as ``seed`` only when it is not ``None``.

        Returns:
            An ``ObservationModel`` built from the response's ``observation``
            entry. The response's ``session_id`` is stored on the client.
        """
        from app.models import ObservationModel

        payload: dict[str, Any] = {"task_id": task_id}
        if seed is not None:
            payload["seed"] = seed
        data = self._post("/reset", payload)
        self.session_id = data["session_id"]
        return ObservationModel(**data["observation"])

    def step(self, action: "ActionModel") -> ClientStepResult:
        """Send one action to ``/step`` and return the typed result.

        Args:
            action: Model exposing ``model_dump``; ``None`` fields are omitted
                from the request.

        Returns:
            A ``ClientStepResult`` populated from the response's
            ``observation``, ``reward``, ``done``, ``terminated``,
            ``truncated`` and ``info`` entries.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        # Validate the session before importing project models or hitting the network.
        body = self._step_body(action)

        from app.models import ObservationModel, StepInfoModel

        data = self._post("/step", body)
        return ClientStepResult(
            observation=ObservationModel(**data["observation"]),
            reward=float(data["reward"]),
            done=bool(data["done"]),
            terminated=bool(data["terminated"]),
            truncated=bool(data["truncated"]),
            info=StepInfoModel(**data["info"]),
        )

    def state(self, include_action_history: bool = False) -> "EpisodeStateModel":
        """Fetch the current episode state from ``/state``.

        Args:
            include_action_history: Forwarded to the server unchanged as
                ``include_action_history``.

        Returns:
            An ``EpisodeStateModel`` built from the response's ``state`` entry.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        # Validate the session before importing project models or hitting the network.
        session_id = self._require_session()

        from app.models import EpisodeStateModel

        data = self._post(
            "/state",
            {
                "session_id": session_id,
                "include_action_history": include_action_history,
            },
        )
        return EpisodeStateModel(**data["state"])
if EnvClient is not None and StepResult is not None:

    class GovWorkflowOpenEnvClient(
        EnvClient["ActionModel", "ObservationModel", "EpisodeStateModel"]
    ):
        """
        OpenEnv-native websocket client.

        Offered alongside the HTTP client defined earlier in this module; it
        does not replace it. This variant exists only when the optional
        ``openenv`` imports succeeded.
        """

        def _step_payload(self, action: "ActionModel") -> dict[str, Any]:
            """Dump ``action`` in JSON mode, leaving out fields that are ``None``."""
            return action.model_dump(exclude_none=True, mode="json")

        def _parse_result(self, payload: dict[str, Any]) -> StepResult["ObservationModel"]:
            """Convert a raw step payload into a ``StepResult``.

            A missing ``observation`` entry is treated as an empty mapping,
            ``reward`` is passed through as-is (``None`` when absent) and
            ``done`` falls back to ``False``.
            """
            from app.models import ObservationModel

            raw_observation = payload.get("observation", {})
            return StepResult(
                observation=ObservationModel(**raw_observation),
                reward=payload.get("reward"),
                done=bool(payload.get("done", False)),
            )

        def _parse_state(self, payload: dict[str, Any]) -> "EpisodeStateModel":
            """Build an ``EpisodeStateModel`` from a raw state payload.

            Uses the nested ``state`` entry when present; otherwise the
            payload itself is taken to be the state.
            """
            from app.models import EpisodeStateModel

            return EpisodeStateModel(**payload.get("state", payload))

else:

    class GovWorkflowOpenEnvClient:  # type: ignore[no-redef]
        """
        Stand-in used when the optional `openenv` package is not installed.

        Instantiating it always raises ``ModuleNotFoundError``.
        """

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            message = (
                "GovWorkflowOpenEnvClient requires the optional 'openenv' package. "
                "Install it to use websocket OpenEnv client features."
            )
            raise ModuleNotFoundError(message)
|