File size: 2,325 Bytes
3aeaf3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from __future__ import annotations

from typing import Any

import requests
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from models import SeigeAction, SeigeObservation


class SeigeOpenEnvClient(EnvClient[SeigeAction, SeigeObservation, State]):
    """``EnvClient`` specialisation that speaks the Seige action/observation types.

    Supplies the three conversion hooks: action -> request payload,
    response payload -> ``StepResult``, and response payload -> ``State``.
    """

    def _step_payload(self, action: SeigeAction) -> dict[str, Any]:
        """Return *action* as a dict via ``model_dump``, leaving out ``None`` fields."""
        serialised = action.model_dump(exclude_none=True)
        return serialised

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[SeigeObservation]:
        """Build a ``StepResult`` from a decoded step/reset response.

        ``reward`` and ``done`` are read from the top level of *payload* and
        copied into both the observation and the surrounding ``StepResult``.
        Missing keys fall back to: empty observation dict, ``"both"`` for
        ``current_agent``, ``{}`` for ``info``, ``False`` for ``done`` and
        ``None`` for ``reward``.
        """
        raw_obs = payload.get("observation", {})
        step_reward = payload.get("reward")
        finished = payload.get("done", False)

        obs_fields = {
            "red": raw_obs.get("red"),
            "blue": raw_obs.get("blue"),
            "current_agent": raw_obs.get("current_agent", "both"),
            "info": raw_obs.get("info", {}),
            "done": finished,
            "reward": step_reward,
        }
        return StepResult(
            observation=SeigeObservation(**obs_fields),
            reward=step_reward,
            done=finished,
        )

    def _parse_state(self, payload: dict[str, Any]) -> State:
        """Construct a ``State`` using the payload's keys as keyword arguments."""
        state = State(**payload)
        return state


class SeigeClient:
    """Thin synchronous HTTP client for a Seige environment server.

    Talks to the ``/reset``, ``/step``, ``/state`` and ``/health`` endpoints
    below ``base_url`` with ``requests`` and hands back the decoded JSON bodies.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 30.0) -> None:
        """Remember where the server lives and how long to wait for it.

        Args:
            base_url: Root URL of the server; trailing slashes are removed.
            timeout: Timeout passed to every ``requests`` call.
        """
        # Paths such as "/step" are appended verbatim, so the base must not
        # end in a slash.
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def reset(self) -> dict:
        """POST ``/reset`` and return the initial observation.

        Returns:
            The ``"observation"`` entry of the response, or the whole response
            body when that key is absent.

        Raises:
            requests.HTTPError: The server answered with a 4xx/5xx status.
        """
        payload = self._checked_json(self._post("/reset", {}))
        return payload.get("observation", payload)

    def step(self, action: dict[str, Any]) -> dict:
        """POST *action* to ``/step`` and return the response body.

        When the observation names a single ``current_agent`` (``"red"`` or
        ``"blue"``), the ``"observation"`` entry is replaced by that agent's
        own sub-observation (``{}`` if it is missing or empty).

        Raises:
            requests.HTTPError: The server answered with a 4xx/5xx status.
        """
        payload = self._checked_json(self._post("/step", {"action": action}))
        return self._flatten_observation(payload)

    def state(self) -> dict:
        """GET ``/state`` and return the decoded body.

        Raises:
            requests.HTTPError: The server answered with a 4xx/5xx status.
        """
        return self._checked_json(self._get("/state"))

    def health(self) -> dict:
        """GET ``/health`` and return the decoded body.

        NOTE(review): unlike the other calls this does not raise on a 4xx/5xx
        status, so a caller can still read the body of an "unhealthy" reply --
        confirm that this is what the server's health endpoint needs.
        """
        return self._get("/health").json()

    def _get(self, path: str):
        """Issue a GET for *path* relative to ``base_url`` and return the response."""
        return requests.get(f"{self.base_url}{path}", timeout=self.timeout)

    def _post(self, path: str, payload: dict[str, Any]):
        """POST *payload* as JSON to *path* relative to ``base_url``; return the response."""
        return requests.post(f"{self.base_url}{path}", json=payload, timeout=self.timeout)

    @staticmethod
    def _checked_json(response) -> dict:
        """Decode *response* as JSON after verifying the HTTP status.

        Without the status check an error reply that happens to carry a JSON
        body would be returned to the caller as if it were a valid result.
        """
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _flatten_observation(payload: dict) -> dict:
        """Replace a two-sided observation by the acting agent's own view.

        *payload* is modified in place and returned. It is left untouched when
        it has no dict ``"observation"``, when that dict has no
        ``"current_agent"`` key, or when the agent is neither ``"red"`` nor
        ``"blue"``.
        """
        observation = payload.get("observation")
        if not isinstance(observation, dict) or "current_agent" not in observation:
            return payload
        current_agent = observation.get("current_agent")
        if current_agent in {"red", "blue"}:
            payload["observation"] = observation.get(current_agent) or {}
        return payload