ronitraj commited on
Commit
49f43bf
·
1 Parent(s): 8ec915d

Make client compatible with spaces lacking session_id

Browse files
Files changed (1) hide show
  1. llmserve_env/client.py +8 -6
llmserve_env/client.py CHANGED
@@ -26,17 +26,19 @@ class LLMServeEnv:
26
  return self._parse_observation_payload(payload)
27
 
28
  def step(self, action: dict[str, Any] | ServeAction) -> tuple[ServeObservation, float, bool, dict[str, Any]]:
29
- if self.session_id is None:
30
- raise RuntimeError("reset() must be called before step() so the client has a session_id.")
31
  action_payload = action.model_dump(mode="json") if isinstance(action, ServeAction) else action
32
- payload = self._post("/step", {"action": action_payload, "session_id": self.session_id})
 
 
 
33
  observation = self._parse_observation_payload(payload)
 
 
34
  return observation, float(payload["reward"]), bool(payload["done"]), observation.metadata
35
 
36
  def state(self) -> ServeState:
37
- if self.session_id is None:
38
- raise RuntimeError("reset() must be called before state() so the client has a session_id.")
39
- payload = self._get(f"/state?session_id={self.session_id}")
40
  return ServeState.model_validate(payload)
41
 
42
  def tasks(self) -> dict[str, Any]:
 
26
  return self._parse_observation_payload(payload)
27
 
28
  def step(self, action: dict[str, Any] | ServeAction) -> tuple[ServeObservation, float, bool, dict[str, Any]]:
 
 
29
  action_payload = action.model_dump(mode="json") if isinstance(action, ServeAction) else action
30
+ body: dict[str, Any] = {"action": action_payload}
31
+ if self.session_id is not None:
32
+ body["session_id"] = self.session_id
33
+ payload = self._post("/step", body)
34
  observation = self._parse_observation_payload(payload)
35
+ if payload.get("session_id") and self.session_id is None:
36
+ self.session_id = str(payload["session_id"])
37
  return observation, float(payload["reward"]), bool(payload["done"]), observation.metadata
38
 
39
  def state(self) -> ServeState:
40
+ path = f"/state?session_id={self.session_id}" if self.session_id is not None else "/state"
41
+ payload = self._get(path)
 
42
  return ServeState.model_validate(payload)
43
 
44
  def tasks(self) -> dict[str, Any]: