# Source: vegarl/rl/env_wrapper.py
# Revision 4fbc241 (ronitraj): "Deploy Space without oversized raw dataset"
"""Gymnasium-like (classic 4-tuple ``step``) wrapper around LLMServeEnvironment for RL training."""
from __future__ import annotations
from typing import Any
import numpy as np
from llmserve_env.models import ServeAction, ServeObservation
from rl.normalize import RunningNormalizer
from server.llmserve_environment import LLMServeEnvironment
# Names of the 15 numeric observation attributes, in fixed order.
# The order is part of the contract: obs_to_vector() reads the attributes in
# exactly this sequence, so element i of an observation vector always
# corresponds to OBS_FIELDS[i].
OBS_FIELDS: list[str] = [
    "queue_depth",
    "active_requests",
    "kv_cache_occupancy",
    "mean_prompt_length",
    "p50_ttft_ms",
    "p99_ttft_ms",
    "p50_itl_ms",
    "throughput_tps",
    "slo_compliance_rate",
    "gpu_memory_used_gb",
    "estimated_cost_per_1k",
    "request_arrival_rate",
    "spec_acceptance_rate",
    "eviction_events",
    "step_index",
]

# Length of the flattened observation vector; derived so it cannot drift from
# the field list above.
OBS_DIM: int = len(OBS_FIELDS)
def obs_to_vector(obs: ServeObservation) -> np.ndarray:
    """Flatten a ServeObservation into a 1-D float32 feature vector.

    Every name in ``OBS_FIELDS`` is read from *obs* with ``getattr`` and
    coerced with ``float``; element ``i`` of the result corresponds to
    ``OBS_FIELDS[i]``.

    Args:
        obs: Observation exposing every attribute listed in ``OBS_FIELDS``.

    Returns:
        A ``float32`` array of shape ``(len(OBS_FIELDS),)``.

    Raises:
        AttributeError: If *obs* is missing one of the listed attributes.
        TypeError, ValueError: Propagated from ``float`` when an attribute
            value cannot be converted (for example ``None``).
    """
    values: list[float] = []
    for field_name in OBS_FIELDS:
        values.append(float(getattr(obs, field_name)))
    return np.asarray(values, dtype=np.float32)
class GymEnvWrapper:
    """Adapter that exposes LLMServeEnvironment through a Gym-style API.

    Provided surface:
    - ``reset()`` returns the first observation as an ``np.ndarray``.
    - ``step(action)`` returns the classic 4-tuple
      ``(obs, reward, done, info)``.
    - When enabled, observations are passed through a ``RunningNormalizer``
      whose statistics are updated with every observation seen (on both
      ``reset`` and ``step``).
    """

    def __init__(
        self,
        task_id: str = "static_workload",
        seed: int = 42,
        normalize: bool = True,
        mode: str = "sim",
    ) -> None:
        """Create the underlying environment and, optionally, a normalizer.

        Args:
            task_id: Task identifier forwarded to the environment on reset.
            seed: Seed given to the environment constructor and used by
                ``reset`` whenever no explicit seed is supplied.
            normalize: If true, build a ``RunningNormalizer`` over
                ``OBS_DIM`` features; otherwise observations are returned raw.
            mode: Forwarded unchanged to ``LLMServeEnvironment``.
        """
        self.task_id = task_id
        self.seed = seed
        self._env = LLMServeEnvironment(seed=seed, mode=mode)
        if normalize:
            self.normalizer = RunningNormalizer(shape=(OBS_DIM,))
        else:
            self.normalizer = None
        self._last_obs: ServeObservation | None = None
        self._episode_step = 0

    def _encode(self, raw_obs: ServeObservation) -> np.ndarray:
        """Vectorize *raw_obs* and, if a normalizer exists, update and apply it."""
        features = obs_to_vector(raw_obs)
        normalizer = self.normalizer
        if normalizer is None:
            return features
        # Statistics are refreshed with the raw vector before it is normalized.
        normalizer.update(features)
        return normalizer.normalize(features)

    def reset(self, seed: int | None = None) -> np.ndarray:
        """Start a new episode and return its first observation vector.

        Args:
            seed: Episode seed. When ``None`` the constructor seed is reused,
                so repeated unseeded resets pass the same seed to the
                environment.

        Returns:
            The (optionally normalized) observation vector.
        """
        episode_seed = self.seed if seed is None else seed
        first_obs = self._env.reset(seed=episode_seed, task_id=self.task_id)
        self._last_obs = first_obs
        self._episode_step = 0
        return self._encode(first_obs)

    def step(self, action: dict[str, Any] | ServeAction) -> tuple[np.ndarray, float, bool, dict[str, Any]]:
        """Advance the environment by one action.

        Args:
            action: Either a ``ServeAction`` or a dict of keyword arguments
                from which one is constructed.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is the (optionally
            normalized) observation vector, ``reward`` falls back to ``0.0``
            when the observation has no truthy ``reward`` attribute, ``done``
            falls back to ``False`` when the observation has no ``done``
            attribute, and ``info`` carries the task id, the step count within
            the episode and the raw observation object.
        """
        serve_action = ServeAction(**action) if isinstance(action, dict) else action
        result = self._env.step(serve_action)
        self._last_obs = result
        self._episode_step += 1

        raw_reward = getattr(result, "reward", 0.0)
        reward = float(raw_reward) if raw_reward else 0.0
        done = bool(getattr(result, "done", False))

        vector = self._encode(result)
        info = {
            "task_id": self.task_id,
            "step": self._episode_step,
            "raw_obs": result,
        }
        return vector, reward, done, info

    @property
    def obs_dim(self) -> int:
        """Number of features in an observation vector (``OBS_DIM``)."""
        return OBS_DIM

    @property
    def last_observation(self) -> ServeObservation | None:
        """Most recent raw observation, or ``None`` before the first reset."""
        return self._last_obs