from __future__ import annotations

import uuid
from typing import Any

from openenv.core import Environment

from llmserve_env.models import EpisodeLog, ServeAction, ServeObservation, ServeState
from llmserve_env.task_catalog import get_task_config
from server.reward_calculator import RewardCalculator
from server.serving_backend import ServingBackend, create_serving_backend
from server.slo_monitor import SLOMonitor
from server.workload_generator import WorkloadGenerator


class LLMServeEnvironment(Environment[ServeAction, ServeObservation, ServeState]):
    """Single-session environment for LLM-serving control.

    Each step feeds one workload snapshot through the serving backend,
    scores p99 TTFT against the task's SLO target, and converts the
    resulting metrics into a scalar reward.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self, seed: int = 42, mode: str | None = None, backend: ServingBackend | None = None) -> None:
        super().__init__()
        self.seed = seed
        try:
            self.backend = backend or create_serving_backend(mode=mode, seed=seed)
        except Exception as e:
            raise RuntimeError(f"Failed to create serving backend: {e}") from e
        try:
            self.reward_calculator = RewardCalculator()
        except Exception as e:
            raise RuntimeError(f"Failed to create reward calculator: {e}") from e
        # Task-specific components are deferred until reset() selects a task.
        self.task_config: dict[str, Any] | None = None
        self.workload_generator: WorkloadGenerator | None = None
        self.slo_monitor: SLOMonitor | None = None
        self.actions: list[ServeAction] = []
        self.observations: list[ServeObservation] = []
        self.rewards: list[float] = []
        self._state = ServeState(
            episode_id=str(uuid.uuid4()),
            step_count=0,
            task_id="uninitialized",
            total_requests_served=0,
            total_slo_violations=0,
            cumulative_reward=0.0,
            elapsed_simulated_time_s=0.0,
            workload_phase="warmup",
            done=False,
        )

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_id: str = "static_workload",
        **_: Any,
    ) -> ServeObservation:
        """Start a new episode for `task_id` and return the initial observation."""
        if seed is not None:
            self.seed = seed
        self.task_config = get_task_config(task_id)
        self.workload_generator = WorkloadGenerator(self.task_config, seed=self.seed)
        self.backend.reset(seed=self.seed)
        self.slo_monitor = SLOMonitor()
        self.actions = []
        self.observations = []
        self.rewards = []
        self._state = ServeState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            task_id=task_id,
            total_requests_served=0,
            total_slo_violations=0,
            cumulative_reward=0.0,
            elapsed_simulated_time_s=0.0,
            workload_phase="warmup",
            done=False,
        )
        workload = self.workload_generator.next_snapshot(step_index=0)
        observation = self._build_initial_observation(workload)
        self.observations.append(observation)
        return observation

    def step(
        self,
        action: ServeAction,
        timeout_s: float | None = None,
        **_: Any,
    ) -> ServeObservation:
        """Advance the episode by one control step and return the new observation."""
        del timeout_s
        if self.task_config is None or self.workload_generator is None or self.slo_monitor is None:
            raise RuntimeError("reset() must be called before step().")
        if self._state.done:
            return self._build_terminal_observation("Episode already completed.")

        next_step_index = self._state.step_count + 1
        workload = self.workload_generator.next_snapshot(step_index=next_step_index)
        metrics = self.backend.run_step(self._state.task_id, action, workload)
        # Score p99 TTFT against the task's SLO target.
        compliance, violations = self.slo_monitor.evaluate(
            p99_ttft_ms=metrics.p99_ttft_ms,
            target_ms=float(self.task_config["slo_p99_ttft_ms"]),
            active_requests=max(1, metrics.requests_served),
        )
        metrics.slo_violations += violations
        # KV-cache occupancy is approximated by GPU memory pressure against the task's cap.
        memory_cap = float(self.task_config.get("memory_cap_gb", 40.0))
        kv_cache_occupancy = min(1.0, metrics.gpu_memory_used_gb / memory_cap)
        reward = self.reward_calculator.calculate(
            task_id=self._state.task_id,
            metrics=metrics,
            slo_compliance_rate=compliance,
            quantization_tier=action.quantization_tier,
            priority_fraction=workload.priority_fraction,
        )
        # The episode ends after the task's fixed step budget.
        done = next_step_index >= int(self.task_config["max_steps"])
        observation = ServeObservation(
            queue_depth=workload.queue_depth,
            active_requests=metrics.requests_served,
            kv_cache_occupancy=kv_cache_occupancy,
            mean_prompt_length=workload.mean_prompt_length,
            p50_ttft_ms=metrics.p50_ttft_ms,
            p99_ttft_ms=metrics.p99_ttft_ms,
            p50_itl_ms=metrics.p50_itl_ms,
            throughput_tps=metrics.throughput_tps,
            slo_compliance_rate=compliance,
            gpu_memory_used_gb=metrics.gpu_memory_used_gb,
            estimated_cost_per_1k=metrics.estimated_cost_per_1k,
            request_arrival_rate=workload.arrival_rate,
            spec_acceptance_rate=metrics.spec_acceptance_rate,
            eviction_events=metrics.eviction_events,
            step_index=next_step_index,
            task_id=self._state.task_id,
            reward=reward,
            done=done,
            metadata={
                "phase": workload.phase,
                "priority_fraction": workload.priority_fraction,
                "task_name": self.task_config["name"],
                "is_throttled": metrics.is_throttled,
                "preemption_events": metrics.preemption_events,
                **self.backend.describe(),
            },
        )
        self.actions.append(action)
        self.observations.append(observation)
        self.rewards.append(reward)
        self._state.step_count = next_step_index
        self._state.total_requests_served += metrics.requests_served
        self._state.total_slo_violations += metrics.slo_violations
        self._state.cumulative_reward += reward
        self._state.elapsed_simulated_time_s += float(self.task_config["step_window_s"])
        self._state.workload_phase = workload.phase
        self._state.done = done
        return observation

    @property
    def state(self) -> ServeState:
        return self._state

    def export_episode_log(self) -> EpisodeLog:
        """Bundle the full action/observation/reward trace for offline analysis."""
        return EpisodeLog(
            task_id=self._state.task_id,
            actions=self.actions,
            observations=self.observations,
            rewards=self.rewards,
            final_state=self._state,
        )

    def _build_initial_observation(self, workload: Any) -> ServeObservation:
        """Observation for step 0: workload stats only, no serving metrics yet."""
        return ServeObservation(
            queue_depth=workload.queue_depth,
            active_requests=0,
            kv_cache_occupancy=0.0,
            mean_prompt_length=workload.mean_prompt_length,
            p50_ttft_ms=0.0,
            p99_ttft_ms=0.0,
            p50_itl_ms=0.0,
            throughput_tps=0.0,
            slo_compliance_rate=1.0,
            gpu_memory_used_gb=0.0,
            estimated_cost_per_1k=0.0,
            request_arrival_rate=workload.arrival_rate,
            spec_acceptance_rate=0.0,
            eviction_events=0,
            step_index=0,
            task_id=self._state.task_id,
            reward=0.0,
            done=False,
            metadata={
                "phase": workload.phase,
                "task_name": self.task_config["name"] if self.task_config else "",
                **self.backend.describe(),
            },
        )

    def _build_terminal_observation(self, message: str) -> ServeObservation:
        """Echo the last observation, flagged done, when step() is called post-episode."""
        last = self.observations[-1]
        return last.model_copy(
            update={"done": True, "reward": 0.0, "metadata": {**last.metadata, "message": message}}
        )
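

# A minimal smoke-test sketch, not part of the environment's API. It assumes
# ServeAction can be constructed with default values, which may not hold if the
# model declares required fields; a real agent would set batching/quantization
# knobs from the observation instead of stepping with defaults.
if __name__ == "__main__":
    env = LLMServeEnvironment(seed=7)
    obs = env.reset(task_id="static_workload")
    while not obs.done:
        obs = env.step(ServeAction())  # default-constructed action is an assumption
    log = env.export_episode_log()
    print(
        f"episode {log.final_state.episode_id}: "
        f"{log.final_state.step_count} steps, "
        f"reward={log.final_state.cumulative_reward:.3f}"
    )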