from __future__ import annotations

import uuid
from typing import Any

from openenv.core import Environment

from llmserve_env.models import EpisodeLog, ServeAction, ServeObservation, ServeState
from llmserve_env.task_catalog import get_task_config
from server.reward_calculator import RewardCalculator
from server.serving_backend import ServingBackend, create_serving_backend
from server.slo_monitor import SLOMonitor
from server.workload_generator import WorkloadGenerator


class LLMServeEnvironment(Environment[ServeAction, ServeObservation, ServeState]):
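    """Single-session LLM serving simulation exposed through the OpenEnv API.

    Each step() advances a synthetic workload by one window, applies the
    agent's ServeAction to the serving backend, and reports latency,
    throughput, memory, and SLO metrics (plus reward) as a ServeObservation.
    """
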
    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self, seed: int = 42, mode: str | None = None, backend: ServingBackend | None = None) -> None:
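        """Create the environment, optionally injecting a pre-built serving backend."""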
        super().__init__()
        self.seed = seed
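        # Fail fast with descriptive errors if either dependency cannot be built.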
        try:
            self.backend = backend or create_serving_backend(mode=mode, seed=seed)
        except Exception as e:
            raise RuntimeError(f"Failed to create serving backend: {e}") from e
        
        try:
            self.reward_calculator = RewardCalculator()
        except Exception as e:
            raise RuntimeError(f"Failed to create reward calculator: {e}") from e
        
        self.task_config: dict[str, Any] | None = None
        self.workload_generator: WorkloadGenerator | None = None
        self.slo_monitor: SLOMonitor | None = None
        self.actions: list[ServeAction] = []
        self.observations: list[ServeObservation] = []
        self.rewards: list[float] = []
        self._state = ServeState(
            episode_id=str(uuid.uuid4()),
            step_count=0,
            task_id="uninitialized",
            total_requests_served=0,
            total_slo_violations=0,
            cumulative_reward=0.0,
            elapsed_simulated_time_s=0.0,
            workload_phase="warmup",
            done=False,
        )

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_id: str = "static_workload",
        **_: Any,
    ) -> ServeObservation:
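        """Start a new episode for task_id and return the initial observation."""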
        if seed is not None:
            self.seed = seed
        self.task_config = get_task_config(task_id)
        self.workload_generator = WorkloadGenerator(self.task_config, seed=self.seed)
        self.backend.reset(seed=self.seed)
        self.slo_monitor = SLOMonitor()
        self.actions = []
        self.observations = []
        self.rewards = []
        self._state = ServeState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            task_id=task_id,
            total_requests_served=0,
            total_slo_violations=0,
            cumulative_reward=0.0,
            elapsed_simulated_time_s=0.0,
            workload_phase="warmup",
            done=False,
        )
        workload = self.workload_generator.next_snapshot(step_index=0)
        observation = self._build_initial_observation(workload)
        self.observations.append(observation)
        return observation

    def step(
        self,
        action: ServeAction,
        timeout_s: float | None = None,
        **_: Any,
    ) -> ServeObservation:
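        """Advance the simulation by one step and return the new observation.

        Raises RuntimeError if called before reset(); once the episode is
        done, further calls return a terminal copy of the last observation.
        """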
        del timeout_s
        if self.task_config is None or self.workload_generator is None or self.slo_monitor is None:
            raise RuntimeError("reset() must be called before step().")

        if self._state.done:
            return self._build_terminal_observation("Episode already completed.")

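        # Draw the next workload window and let the backend serve it under the
        # agent's chosen configuration.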
        next_step_index = self._state.step_count + 1
        workload = self.workload_generator.next_snapshot(step_index=next_step_index)
        metrics = self.backend.run_step(self._state.task_id, action, workload)
        compliance, violations = self.slo_monitor.evaluate(
            p99_ttft_ms=metrics.p99_ttft_ms,
            target_ms=float(self.task_config["slo_p99_ttft_ms"]),
            active_requests=max(1, metrics.requests_served),
        )
        metrics.slo_violations += violations
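        # Approximate KV-cache occupancy as the fraction of the memory cap in use.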
        memory_cap = float(self.task_config.get("memory_cap_gb", 40.0))
        kv_cache_occupancy = min(1.0, metrics.gpu_memory_used_gb / memory_cap)

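        # Score the step from measured metrics, SLO compliance, and the action's
        # quantization choice.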
        reward = self.reward_calculator.calculate(
            task_id=self._state.task_id,
            metrics=metrics,
            slo_compliance_rate=compliance,
            quantization_tier=action.quantization_tier,
            priority_fraction=workload.priority_fraction,
        )
        done = next_step_index >= int(self.task_config["max_steps"])

        observation = ServeObservation(
            queue_depth=workload.queue_depth,
            active_requests=metrics.requests_served,
            kv_cache_occupancy=kv_cache_occupancy,
            mean_prompt_length=workload.mean_prompt_length,
            p50_ttft_ms=metrics.p50_ttft_ms,
            p99_ttft_ms=metrics.p99_ttft_ms,
            p50_itl_ms=metrics.p50_itl_ms,
            throughput_tps=metrics.throughput_tps,
            slo_compliance_rate=compliance,
            gpu_memory_used_gb=metrics.gpu_memory_used_gb,
            estimated_cost_per_1k=metrics.estimated_cost_per_1k,
            request_arrival_rate=workload.arrival_rate,
            spec_acceptance_rate=metrics.spec_acceptance_rate,
            eviction_events=metrics.eviction_events,
            step_index=next_step_index,
            task_id=self._state.task_id,
            reward=reward,
            done=done,
            metadata={
                "phase": workload.phase,
                "priority_fraction": workload.priority_fraction,
                "task_name": self.task_config["name"],
                "is_throttled": metrics.is_throttled,
                "preemption_events": metrics.preemption_events,
                **self.backend.describe(),
            },
        )

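        # Record the transition and roll the aggregate episode state forward.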
        self.actions.append(action)
        self.observations.append(observation)
        self.rewards.append(reward)
        self._state.step_count = next_step_index
        self._state.total_requests_served += metrics.requests_served
        self._state.total_slo_violations += metrics.slo_violations
        self._state.cumulative_reward += reward
        self._state.elapsed_simulated_time_s += float(self.task_config["step_window_s"])
        self._state.workload_phase = workload.phase
        self._state.done = done
        return observation

    @property
    def state(self) -> ServeState:
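        """Current aggregate episode state."""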
        return self._state

    def export_episode_log(self) -> EpisodeLog:
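        """Bundle the episode's actions, observations, and rewards into an EpisodeLog."""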
        return EpisodeLog(
            task_id=self._state.task_id,
            actions=self.actions,
            observations=self.observations,
            rewards=self.rewards,
            final_state=self._state,
        )

    def _build_initial_observation(self, workload: Any) -> ServeObservation:
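        """Build the step-0 observation: workload stats with zeroed serving metrics."""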
        return ServeObservation(
            queue_depth=workload.queue_depth,
            active_requests=0,
            kv_cache_occupancy=0.0,
            mean_prompt_length=workload.mean_prompt_length,
            p50_ttft_ms=0.0,
            p99_ttft_ms=0.0,
            p50_itl_ms=0.0,
            throughput_tps=0.0,
            slo_compliance_rate=1.0,
            gpu_memory_used_gb=0.0,
            estimated_cost_per_1k=0.0,
            request_arrival_rate=workload.arrival_rate,
            spec_acceptance_rate=0.0,
            eviction_events=0,
            step_index=0,
            task_id=self._state.task_id,
            reward=0.0,
            done=False,
            metadata={
                "phase": workload.phase,
                "task_name": self.task_config["name"] if self.task_config else "",
                **self.backend.describe(),
            },
        )

    def _build_terminal_observation(self, message: str) -> ServeObservation:
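        """Return a done-flagged copy of the last observation with message attached."""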
        last = self.observations[-1]
        metadata = {**last.metadata, "message": message}
        return last.model_copy(update={"done": True, "reward": 0.0, "metadata": metadata})
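

if __name__ == "__main__":
    # Minimal usage sketch, not part of the environment API. ServeAction's
    # fields live in llmserve_env.models; only quantization_tier is referenced
    # in this module, so constructing it with defaults is an assumption.
    env = LLMServeEnvironment(seed=7)
    obs = env.reset(task_id="static_workload")
    while not obs.done:
        obs = env.step(ServeAction())  # assumed: action fields have defaults
    print(env.export_episode_log().final_state)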