vegarl / llmserve_env /client.py
ronitraj's picture
Make client compatible with spaces lacking session_id
49f43bf
from __future__ import annotations
import json
from typing import Any
from urllib import request
from llmserve_env.models import EpisodeLog, ServeAction, ServeObservation, ServeState
class LLMServeEnv:
    """Minimal HTTP client for an LLMServe environment server.

    Talks JSON to the server's ``/reset``, ``/step``, ``/state``, ``/tasks``,
    ``/grader`` and ``/baseline`` endpoints using only ``urllib``. When the
    server hands out a ``session_id`` it is remembered and sent with later
    ``step``/``state`` calls; servers that never return one work as well.
    """

    def __init__(self, base_url: str) -> None:
        """Create a client for the server rooted at *base_url*.

        Trailing slashes are stripped so endpoint paths (which start with
        ``/``) can be appended directly.
        """
        self.base_url = base_url.rstrip("/")
        # Set by reset()/step() once the server returns a session id.
        self.session_id: str | None = None

    @classmethod
    def from_url(cls, base_url: str) -> "LLMServeEnv":
        """Alternate constructor: build a client from an explicit base URL."""
        return cls(base_url=base_url)

    @classmethod
    def from_hub(cls, repo_id: str) -> "LLMServeEnv":
        """Alternate constructor: build a client from a Hugging Face Space id.

        *repo_id* is interpolated into ``https://huggingface.co/spaces/...``.
        """
        # NOTE(review): this is the Space's page URL on huggingface.co; confirm
        # the server's API routes are actually reachable under it (Spaces
        # usually serve app traffic from a separate *.hf.space host). Use
        # from_url() with the direct host if they are not.
        return cls(base_url=f"https://huggingface.co/spaces/{repo_id}")

    def reset(self, task_id: str, seed: int | None = None) -> ServeObservation:
        """Start a new episode for *task_id* and return the first observation.

        POSTs ``{"task_id": ..., "seed": ...}`` to ``/reset`` and stores the
        ``session_id`` from the response (``None`` if the server sends none).
        """
        payload = self._post("/reset", {"task_id": task_id, "seed": seed})
        self.session_id = payload.get("session_id")
        return self._parse_observation_payload(payload)

    def step(self, action: dict[str, Any] | ServeAction) -> tuple[ServeObservation, float, bool, dict[str, Any]]:
        """Apply *action* and return ``(observation, reward, done, info)``.

        *action* may be a ``ServeAction`` (dumped to JSON-compatible form) or
        an already-serialisable dict. ``info`` is the observation's
        ``metadata``. ``session_id`` is only included in the request when one
        is known.

        Raises:
            KeyError: if the response lacks ``observation``, ``reward`` or
                ``done``.
        """
        action_payload = action.model_dump(mode="json") if isinstance(action, ServeAction) else action
        body: dict[str, Any] = {"action": action_payload}
        if self.session_id is not None:
            body["session_id"] = self.session_id
        payload = self._post("/step", body)
        observation = self._parse_observation_payload(payload)
        # Adopt a session id issued late (e.g. reset() did not return one),
        # but never overwrite one we already hold.
        if payload.get("session_id") and self.session_id is None:
            self.session_id = str(payload["session_id"])
        return observation, float(payload["reward"]), bool(payload["done"]), observation.metadata

    def state(self) -> ServeState:
        """Fetch ``/state`` (scoped to the current session when one is known)."""
        params = {} if self.session_id is None else {"session_id": self.session_id}
        payload = self._get(self._with_query("/state", params))
        return ServeState.model_validate(payload)

    def tasks(self) -> dict[str, Any]:
        """Return the decoded JSON body of ``GET /tasks``."""
        return self._get("/tasks")

    def grade(self, log: EpisodeLog | None = None) -> dict[str, Any]:
        """POST to ``/grader`` and return the decoded JSON response.

        With *log* given, it is sent as ``episode_log``; otherwise an empty
        JSON object is posted.
        """
        body = {} if log is None else {"episode_log": log.model_dump(mode="json")}
        return self._post("/grader", body)

    def baseline(self, task_id: str | None = None, use_openai: bool = False, model: str | None = None) -> dict[str, Any]:
        """Call ``GET /baseline`` and return the decoded JSON response.

        Only truthy arguments are forwarded as query parameters
        (``task_id``, ``use_openai=true``, ``model``), in that order.
        """
        params: dict[str, Any] = {}
        if task_id:
            params["task_id"] = task_id
        if use_openai:
            params["use_openai"] = "true"
        if model:
            params["model"] = model
        return self._get(self._with_query("/baseline", params))

    def _parse_observation_payload(self, payload: dict[str, Any]) -> ServeObservation:
        """Build a ``ServeObservation`` from a ``/reset`` or ``/step`` response.

        The top-level ``reward`` and ``done`` (default ``False``) are copied
        into a shallow copy of ``payload["observation"]`` before validation,
        so the caller's payload is not mutated.
        """
        observation_payload = dict(payload["observation"])
        observation_payload["reward"] = payload.get("reward")
        observation_payload["done"] = payload.get("done", False)
        return ServeObservation.model_validate(observation_payload)

    @staticmethod
    def _with_query(path: str, params: dict[str, Any]) -> str:
        """Return *path* with *params* appended as an escaped query string.

        Values are percent-encoded so characters such as ``&``, ``=``, ``/``
        or spaces cannot corrupt the URL. An empty *params* returns *path*
        unchanged.
        """
        # Local import keeps this class self-contained; urllib.request has
        # already loaded urllib.parse, so this is only a name lookup.
        from urllib import parse

        if not params:
            return path
        return f"{path}?{parse.urlencode(params)}"

    def _send(self, target: str | request.Request) -> dict[str, Any]:
        """Open *target*, read the whole body and decode it as UTF-8 JSON."""
        with request.urlopen(target) as response:
            return json.loads(response.read().decode("utf-8"))

    def _get(self, path: str) -> dict[str, Any]:
        """GET ``base_url + path`` and return the decoded JSON body."""
        return self._send(f"{self.base_url}{path}")

    def _post(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST *payload* as JSON to ``base_url + path``; return decoded JSON."""
        body = json.dumps(payload).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        req = request.Request(f"{self.base_url}{path}", data=body, headers=headers, method="POST")
        return self._send(req)