File size: 4,554 Bytes
3eae4cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
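Typed clients for the Gov Workflow OpenEnv deployment.

GovWorkflowClient keeps a simple OpenEnv-style HTTP interface:
    reset() -> initial ObservationModel (starts a server-side session)
    step(action) -> ClientStepResult (observation, reward, done flags, info)
    state() -> EpisodeStateModel for the current session

GovWorkflowOpenEnvClient is an optional OpenEnv-native websocket client that is
only usable when the `openenv` package is installed.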
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, TYPE_CHECKING

import requests
try:
    from openenv.core import EnvClient
    from openenv.core.env_client import StepResult
except ModuleNotFoundError:
    EnvClient = None  # type: ignore[assignment]
    StepResult = None  # type: ignore[assignment]

if TYPE_CHECKING:
    from app.models import ActionModel, EpisodeStateModel, ObservationModel, StepInfoModel


@dataclass()
class ClientStepResult:
    """Outcome of a single ``GovWorkflowClient.step()`` call.

    Mirrors the JSON body returned by the ``/step`` endpoint, with the
    observation and info payloads parsed into ``ObservationModel`` and
    ``StepInfoModel`` respectively.
    """

    # Observation produced after the action was applied.
    observation: "ObservationModel"
    # Reward for this step, coerced to float by the client.
    reward: float
    # Server-reported "done" flag.
    done: bool
    # Server-reported "terminated" flag.
    terminated: bool
    # Server-reported "truncated" flag.
    truncated: bool
    # Per-step info payload.
    info: "StepInfoModel"


class GovWorkflowClient:
    """Small typed client for the FastAPI deployment."""

    def __init__(self, base_url: str) -> None:
        self.base_url = base_url.rstrip("/")
        self.session_id: str | None = None

    def _post(self, path: str, body: dict[str, Any]) -> dict[str, Any]:
        response = requests.post(f"{self.base_url}{path}", json=body, timeout=30)
        response.raise_for_status()
        return response.json()

    def reset(self, task_id: str = "district_backlog_easy", seed: int | None = None) -> "ObservationModel":
        from app.models import ObservationModel

        payload: dict[str, Any] = {"task_id": task_id}
        if seed is not None:
            payload["seed"] = seed
        data = self._post("/reset", payload)
        self.session_id = data["session_id"]
        return ObservationModel(**data["observation"])

    def step(self, action: "ActionModel") -> ClientStepResult:
        from app.models import ObservationModel, StepInfoModel

        if not self.session_id:
            raise RuntimeError("Session not initialized. Call reset() first.")
        data = self._post(
            "/step",
            {
                "session_id": self.session_id,
                "action": action.model_dump(exclude_none=True),
            },
        )
        return ClientStepResult(
            observation=ObservationModel(**data["observation"]),
            reward=float(data["reward"]),
            done=bool(data["done"]),
            terminated=bool(data["terminated"]),
            truncated=bool(data["truncated"]),
            info=StepInfoModel(**data["info"]),
        )

    def state(self, include_action_history: bool = False) -> "EpisodeStateModel":
        from app.models import EpisodeStateModel

        if not self.session_id:
            raise RuntimeError("Session not initialized. Call reset() first.")
        data = self._post(
            "/state",
            {
                "session_id": self.session_id,
                "include_action_history": include_action_history,
            },
        )
        return EpisodeStateModel(**data["state"])


if EnvClient is not None and StepResult is not None:
    class GovWorkflowOpenEnvClient(
        EnvClient["ActionModel", "ObservationModel", "EpisodeStateModel"]
    ):
        """
        OpenEnv-native websocket client for the Gov Workflow environment.

        Defined only when the optional ``openenv`` imports succeed, and lives
        alongside the plain HTTP ``GovWorkflowClient`` rather than replacing it.
        """

        def _step_payload(self, action: "ActionModel") -> dict[str, Any]:
            """Serialize ``action`` to JSON-native values, dropping ``None`` fields."""
            serialized = action.model_dump(exclude_none=True, mode="json")
            return serialized

        def _parse_result(self, payload: dict[str, Any]) -> StepResult["ObservationModel"]:
            """Build a ``StepResult`` from a raw server payload.

            A missing ``observation`` key is treated as an empty observation
            payload and a missing ``done`` key as ``False``; ``reward`` is passed
            through unchanged (``None`` when absent).
            """
            from app.models import ObservationModel

            observation = ObservationModel(**payload.get("observation", {}))
            reward = payload.get("reward")
            is_done = bool(payload.get("done", False))
            return StepResult(observation=observation, reward=reward, done=is_done)

        def _parse_state(self, payload: dict[str, Any]) -> "EpisodeStateModel":
            """Build an ``EpisodeStateModel`` from a state payload.

            Accepts either a wrapper of the form ``{"state": {...}}`` or the
            state fields placed directly at the top level of ``payload``.
            """
            from app.models import EpisodeStateModel

            return EpisodeStateModel(**payload.get("state", payload))
else:
    class GovWorkflowOpenEnvClient:  # type: ignore[no-redef]
        """
        Stand-in used when the optional `openenv` package cannot be imported.

        Instantiating it always raises ``ModuleNotFoundError`` with install guidance.
        """

        def __init__(self, *_args: Any, **_kwargs: Any) -> None:
            message = (
                "GovWorkflowOpenEnvClient requires the optional 'openenv' package. "
                "Install it to use websocket OpenEnv client features."
            )
            raise ModuleNotFoundError(message)