from __future__ import annotations

import json
from typing import Any
from urllib import parse, request

from llmserve_env.models import EpisodeLog, ServeAction, ServeObservation, ServeState


class LLMServeEnv:
    """Minimal HTTP client for an LLMServe environment server.

    Every public method maps to one JSON endpoint under ``base_url``
    (``/reset``, ``/step``, ``/state``, ``/tasks``, ``/grader``,
    ``/baseline``). The client remembers the ``session_id`` returned by the
    server and sends it back on subsequent ``step`` and ``state`` calls.

    Network failures surface as the exceptions raised by
    ``urllib.request.urlopen``; no retries or timeouts are applied here.
    """

    def __init__(self, base_url: str) -> None:
        """Store the server root URL (trailing slashes removed)."""
        self.base_url = base_url.rstrip("/")
        # Populated from server responses by reset() / step().
        self.session_id: str | None = None

    @classmethod
    def from_url(cls, base_url: str) -> "LLMServeEnv":
        """Build a client for the server rooted at ``base_url``."""
        return cls(base_url=base_url)

    @classmethod
    def from_hub(cls, repo_id: str) -> "LLMServeEnv":
        """Build a client pointing at the Hugging Face Space ``repo_id``."""
        # NOTE(review): this is the Space's page URL on huggingface.co;
        # confirm the API routes are actually reachable under it (as opposed
        # to the Space's direct host) before relying on this constructor.
        return cls(base_url=f"https://huggingface.co/spaces/{repo_id}")

    def reset(self, task_id: str, seed: int | None = None) -> ServeObservation:
        """Start a new episode via ``POST /reset`` and return its first observation.

        The stored ``session_id`` is replaced by whatever the response
        carries (``None`` when the response has no ``session_id`` key).
        """
        payload = self._post("/reset", {"task_id": task_id, "seed": seed})
        self.session_id = payload.get("session_id")
        return self._parse_observation_payload(payload)

    def step(
        self, action: dict[str, Any] | ServeAction
    ) -> tuple[ServeObservation, float, bool, dict[str, Any]]:
        """Submit one action via ``POST /step``.

        Args:
            action: A ``ServeAction`` (serialized with
                ``model_dump(mode="json")``) or an already-JSON-ready dict.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` is
            ``observation.metadata``.

        Raises:
            KeyError: If the response lacks ``observation``, ``reward`` or
                ``done``.
        """
        if isinstance(action, ServeAction):
            action_payload = action.model_dump(mode="json")
        else:
            action_payload = action

        body: dict[str, Any] = {"action": action_payload}
        if self.session_id is not None:
            body["session_id"] = self.session_id

        payload = self._post("/step", body)
        observation = self._parse_observation_payload(payload)

        # Adopt a server-assigned session only when none is known yet; an
        # existing session id is never overwritten by step().
        if payload.get("session_id") and self.session_id is None:
            self.session_id = str(payload["session_id"])

        return observation, float(payload["reward"]), bool(payload["done"]), observation.metadata

    def state(self) -> ServeState:
        """Fetch the server state via ``GET /state``.

        The current ``session_id`` is sent as a query parameter when known.
        """
        params: dict[str, Any] = {}
        if self.session_id is not None:
            params["session_id"] = self.session_id
        payload = self._get(self._with_query("/state", params))
        return ServeState.model_validate(payload)

    def tasks(self) -> dict[str, Any]:
        """Return the decoded JSON body of ``GET /tasks``."""
        return self._get("/tasks")

    def grade(self, log: EpisodeLog | None = None) -> dict[str, Any]:
        """Request grading via ``POST /grader``.

        Args:
            log: Optional episode log, sent under ``episode_log``; when
                omitted an empty JSON object is posted.
        """
        body = {} if log is None else {"episode_log": log.model_dump(mode="json")}
        return self._post("/grader", body)

    def baseline(
        self,
        task_id: str | None = None,
        use_openai: bool = False,
        model: str | None = None,
    ) -> dict[str, Any]:
        """Return the decoded JSON body of ``GET /baseline``.

        Only truthy arguments are sent as query parameters; values are
        percent-encoded so names containing spaces, ``&``, ``/`` and similar
        characters reach the server intact.
        """
        params: dict[str, Any] = {}
        if task_id:
            params["task_id"] = task_id
        if use_openai:
            params["use_openai"] = "true"
        if model:
            params["model"] = model
        return self._get(self._with_query("/baseline", params))

    def _parse_observation_payload(self, payload: dict[str, Any]) -> ServeObservation:
        """Validate a ``ServeObservation`` from a ``/reset`` or ``/step`` response.

        The top-level ``reward`` and ``done`` values (``None`` / ``False``
        when absent) are copied into the observation dict before validation.
        """
        # Copy so the caller's payload is not mutated.
        observation_payload = dict(payload["observation"])
        observation_payload["reward"] = payload.get("reward")
        observation_payload["done"] = payload.get("done", False)
        return ServeObservation.model_validate(observation_payload)

    @staticmethod
    def _with_query(path: str, params: dict[str, Any]) -> str:
        """Append ``params`` to ``path`` as a URL-encoded query string."""
        if not params:
            return path
        return f"{path}?{parse.urlencode(params)}"

    @staticmethod
    def _read_json(target: str | request.Request) -> dict[str, Any]:
        """Open ``target`` and decode the UTF-8 response body as JSON."""
        with request.urlopen(target) as response:
            return json.loads(response.read().decode("utf-8"))

    def _get(self, path: str) -> dict[str, Any]:
        """Issue a GET to ``base_url + path`` and return the decoded JSON."""
        return self._read_json(f"{self.base_url}{path}")

    def _post(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` as JSON to ``base_url + path`` and return the decoded JSON."""
        body = json.dumps(payload).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        req = request.Request(f"{self.base_url}{path}", data=body, headers=headers, method="POST")
        return self._read_json(req)