| """Gymnasium-compatible wrapper around LLMServeEnvironment for RL training.""" |
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| import numpy as np |
|
|
| from llmserve_env.models import ServeAction, ServeObservation |
| from rl.normalize import RunningNormalizer |
| from server.llmserve_environment import LLMServeEnvironment |
|
|
|
|
| |
# Ordered names of the ``ServeObservation`` attributes that form the flat
# observation vector. A name's position in this list is its index in every
# vector built from it, so reordering entries changes the vector layout
# (append new fields at the end instead).
OBS_FIELDS: list[str] = [
    "queue_depth",            # index 0
    "active_requests",
    "kv_cache_occupancy",
    "mean_prompt_length",
    "p50_ttft_ms",
    "p99_ttft_ms",            # index 5
    "p50_itl_ms",
    "throughput_tps",
    "slo_compliance_rate",
    "gpu_memory_used_gb",
    "estimated_cost_per_1k",  # index 10
    "request_arrival_rate",
    "spec_acceptance_rate",
    "eviction_events",
    "step_index",             # index 14 (last)
]

# Length of the flattened observation vector: one slot per entry above.
OBS_DIM: int = len(OBS_FIELDS)
|
|
|
|
def obs_to_vector(obs: ServeObservation) -> np.ndarray:
    """Convert *obs* into a flat ``float32`` feature vector.

    Each name in ``OBS_FIELDS`` is read from the observation in list order and
    coerced with ``float()``. The value is stored at the same index, so the
    result has shape ``(OBS_DIM,)``.

    Raises:
        AttributeError: if *obs* lacks one of the listed fields.
        TypeError / ValueError: if a field value cannot be converted by ``float()``.
    """
    field_values = (float(getattr(obs, field_name)) for field_name in OBS_FIELDS)
    # ``count`` lets NumPy allocate the output once; it always equals the
    # number of fields consumed from the generator.
    return np.fromiter(field_values, dtype=np.float32, count=len(OBS_FIELDS))
|
|
|
|
class GymEnvWrapper:
    """Thin wrapper that gives the LLMServeEnvironment a Gymnasium-like interface.

    Supports:
      - ``reset(seed=None) -> obs`` (``np.ndarray``)
      - ``step(action) -> (obs, reward, done, info)`` (4-tuple form)
      - Optional running normalization of observations

    Args:
        task_id: Task identifier passed to the environment on every ``reset``.
        seed: Seed given to the environment constructor; also the episode seed
            used by ``reset`` when no explicit seed is supplied.
        normalize: If True, observation vectors pass through a
            ``RunningNormalizer``. Its statistics are updated on every
            ``reset`` and ``step``; this wrapper has no frozen/eval mode.
        mode: Forwarded unchanged to ``LLMServeEnvironment``.
    """

    def __init__(
        self,
        task_id: str = "static_workload",
        seed: int = 42,
        normalize: bool = True,
        mode: str = "sim",
    ) -> None:
        self.task_id = task_id
        self.seed = seed
        self._env = LLMServeEnvironment(seed=seed, mode=mode)
        self.normalizer = RunningNormalizer(shape=(OBS_DIM,)) if normalize else None
        # Most recent raw observation, exposed via ``last_observation``.
        self._last_obs: ServeObservation | None = None
        # Number of ``step`` calls since the last ``reset``.
        self._episode_step = 0

    def _vectorize(self, obs: ServeObservation) -> np.ndarray:
        """Flatten *obs* and, if enabled, update and apply the running normalizer."""
        vec = obs_to_vector(obs)
        if self.normalizer is not None:
            # Statistics are updated before normalizing, so the current
            # observation already contributes to its own normalization.
            self.normalizer.update(vec)
            vec = self.normalizer.normalize(vec)
        return vec

    def reset(self, seed: int | None = None) -> np.ndarray:
        """Start a new episode and return its first observation vector.

        Args:
            seed: Episode seed. When None, ``self.seed`` is used.

        Returns:
            The flattened, optionally normalized, initial observation.

        NOTE(review): because ``None`` falls back to the constructor seed,
        every un-seeded ``reset()`` reuses the same seed. Confirm this is
        intended for training; Gymnasium's convention is to seed only once.
        """
        ep_seed = seed if seed is not None else self.seed
        obs = self._env.reset(seed=ep_seed, task_id=self.task_id)
        self._last_obs = obs
        self._episode_step = 0
        return self._vectorize(obs)

    def step(self, action: dict[str, Any] | ServeAction) -> tuple[np.ndarray, float, bool, dict[str, Any]]:
        """Advance the environment by one step.

        Args:
            action: A ``ServeAction``, or a dict that is expanded as keyword
                arguments into ``ServeAction``.

        Returns:
            ``(obs, reward, done, info)`` where:
              - ``obs`` is the flattened, optionally normalized observation.
              - ``reward`` is ``obs.reward`` as a float; a missing or falsy
                reward becomes ``0.0``.
              - ``done`` is ``bool(obs.done)``, defaulting to False if absent.
              - ``info`` holds ``task_id``, the 1-based ``step`` count within
                the episode, and the unprocessed observation under ``raw_obs``.
        """
        if isinstance(action, dict):
            action = ServeAction(**action)
        obs = self._env.step(action)
        self._last_obs = obs
        self._episode_step += 1
        # ``or 0.0`` also maps a ``None`` reward to 0.0 before ``float()``.
        reward = float(getattr(obs, "reward", 0.0) or 0.0)
        done = bool(getattr(obs, "done", False))
        vec = self._vectorize(obs)
        info = {"task_id": self.task_id, "step": self._episode_step, "raw_obs": obs}
        return vec, reward, done, info

    @property
    def obs_dim(self) -> int:
        """Length of the observation vectors returned by ``reset``/``step``."""
        return OBS_DIM

    @property
    def last_observation(self) -> ServeObservation | None:
        """The most recent raw observation, or None before the first ``reset``."""
        return self._last_obs
|
|