Make client compatible with spaces lacking session_id
Browse files- llmserve_env/client.py +8 -6
llmserve_env/client.py
CHANGED
|
@@ -26,17 +26,19 @@ class LLMServeEnv:
|
|
| 26 |
return self._parse_observation_payload(payload)
|
| 27 |
|
| 28 |
def step(self, action: dict[str, Any] | ServeAction) -> tuple[ServeObservation, float, bool, dict[str, Any]]:
|
| 29 |
-
if self.session_id is None:
|
| 30 |
-
raise RuntimeError("reset() must be called before step() so the client has a session_id.")
|
| 31 |
action_payload = action.model_dump(mode="json") if isinstance(action, ServeAction) else action
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
| 33 |
observation = self._parse_observation_payload(payload)
|
|
|
|
|
|
|
| 34 |
return observation, float(payload["reward"]), bool(payload["done"]), observation.metadata
|
| 35 |
|
| 36 |
def state(self) -> ServeState:
|
| 37 |
-
if self.session_id is None
|
| 38 |
-
|
| 39 |
-
payload = self._get(f"/state?session_id={self.session_id}")
|
| 40 |
return ServeState.model_validate(payload)
|
| 41 |
|
| 42 |
def tasks(self) -> dict[str, Any]:
|
|
|
|
| 26 |
return self._parse_observation_payload(payload)
|
| 27 |
|
| 28 |
def step(self, action: dict[str, Any] | ServeAction) -> tuple[ServeObservation, float, bool, dict[str, Any]]:
|
|
|
|
|
|
|
| 29 |
action_payload = action.model_dump(mode="json") if isinstance(action, ServeAction) else action
|
| 30 |
+
body: dict[str, Any] = {"action": action_payload}
|
| 31 |
+
if self.session_id is not None:
|
| 32 |
+
body["session_id"] = self.session_id
|
| 33 |
+
payload = self._post("/step", body)
|
| 34 |
observation = self._parse_observation_payload(payload)
|
| 35 |
+
if payload.get("session_id") and self.session_id is None:
|
| 36 |
+
self.session_id = str(payload["session_id"])
|
| 37 |
return observation, float(payload["reward"]), bool(payload["done"]), observation.metadata
|
| 38 |
|
| 39 |
def state(self) -> ServeState:
|
| 40 |
+
path = f"/state?session_id={self.session_id}" if self.session_id is not None else "/state"
|
| 41 |
+
payload = self._get(path)
|
|
|
|
| 42 |
return ServeState.model_validate(payload)
|
| 43 |
|
| 44 |
def tasks(self) -> dict[str, Any]:
|