"""Pydantic schemas for the serving environment: agent actions, per-step
observations, episode bookkeeping, and payload-normalization helpers."""

from __future__ import annotations

import math
from enum import Enum
from typing import Any, Literal

from openenv.core import Action, Observation
from pydantic import BaseModel, ConfigDict, Field, model_validator


class QuantizationTier(str, Enum):
    """Supported model-precision tiers for serving."""

    FP16 = "FP16"
    INT8 = "INT8"
    INT4 = "INT4"


class ServeAction(Action):
    """Serving-configuration knobs the agent applies at each step."""

    model_config = ConfigDict(extra="forbid")

    batch_cap: int = Field(default=32, ge=1, le=512)
    kv_budget_fraction: float = Field(default=1.0, ge=0.1, le=1.0)
    speculation_depth: int = Field(default=0, ge=0, le=8)
    # The Literal mirrors QuantizationTier so the parsed field stays a plain str.
    quantization_tier: Literal["FP16", "INT8", "INT4"] = QuantizationTier.FP16.value
    prefill_decode_split: bool = False
    priority_routing: bool = False

    @model_validator(mode="before")
    @classmethod
    def normalize_web_payload(cls, data: Any) -> Any:
        """Repair raw (e.g. web-form) payloads before field validation.

        Numeric values are parsed and clamped to their declared bounds, and
        unknown quantization tiers fall back to FP16, so malformed input is
        coerced to a safe configuration rather than rejected.
        """
        if not isinstance(data, dict):
            return data

        normalized = dict(data)
        normalized["batch_cap"] = _clamp_int(
            normalized.get("batch_cap"),
            default=32,
            minimum=1,
            maximum=512,
        )
| normalized["kv_budget_fraction"] = _clamp_float( |
| normalized.get("kv_budget_fraction"), |
| default=1.0, |
| minimum=0.1, |
| maximum=1.0, |
| ) |
| normalized["speculation_depth"] = _clamp_int( |
| normalized.get("speculation_depth"), |
| default=0, |
| minimum=0, |
| maximum=8, |
| ) |
| normalized["quantization_tier"] = _normalize_quantization_tier(normalized.get("quantization_tier")) |
| return normalized |
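

# Illustrative behaviour of the validator above (a sketch, assuming openenv's
# Action base class adds no extra required fields):
#
#     raw = {"batch_cap": "4096", "kv_budget_fraction": -3, "quantization_tier": "INT2"}
#     action = ServeAction.model_validate(raw)
#     # action.batch_cap == 512          (parsed, then clamped to the upper bound)
#     # action.kv_budget_fraction == 0.1 (clamped to the lower bound)
#     # action.quantization_tier == "FP16"  (unknown tier falls back to the default)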


class ServeObservation(Observation):
    """Telemetry snapshot returned to the agent after each environment step."""

    model_config = ConfigDict(extra="forbid")

    queue_depth: int = Field(ge=0)
    active_requests: int = Field(ge=0)
    kv_cache_occupancy: float = Field(ge=0.0, le=1.0)
    mean_prompt_length: float = Field(ge=0.0)
    p50_ttft_ms: float = Field(ge=0.0)  # median time-to-first-token (ms)
    p99_ttft_ms: float = Field(ge=0.0)
    p50_itl_ms: float = Field(ge=0.0)  # median inter-token latency (ms)
    throughput_tps: float = Field(ge=0.0)
    slo_compliance_rate: float = Field(ge=0.0, le=1.0)
    gpu_memory_used_gb: float = Field(ge=0.0)
    estimated_cost_per_1k: float = Field(ge=0.0)
    request_arrival_rate: float = Field(ge=0.0)
    spec_acceptance_rate: float = Field(ge=0.0, le=1.0)
    eviction_events: int = Field(ge=0)
    step_index: int = Field(ge=0)
    task_id: str = "uninitialized"


class ServeState(BaseModel):
    """Mutable per-episode bookkeeping tracked internally by the environment."""

    model_config = ConfigDict(extra="forbid")

    episode_id: str
    step_count: int = Field(ge=0)
    task_id: str
    total_requests_served: int = Field(ge=0)
    total_slo_violations: int = Field(ge=0)
    cumulative_reward: float = 0.0
    elapsed_simulated_time_s: float = Field(ge=0.0)
    workload_phase: str = "warmup"
    done: bool = False


class RewardSignal(BaseModel):
    """Scalar step reward together with its named component terms."""

    model_config = ConfigDict(extra="forbid")

    reward: float
    components: dict[str, float]
    done: bool


class WorkloadSnapshot(BaseModel):
    """Point-in-time description of the incoming request workload."""

    model_config = ConfigDict(extra="forbid")

    arrival_rate: float = Field(ge=0.0)
    queue_depth: int = Field(ge=0)
    mean_prompt_length: float = Field(ge=0.0)
    prompt_length_bucket: int = Field(ge=0, le=7)
    priority_fraction: float = Field(ge=0.0, le=1.0)
    phase: str
    step_index: int = Field(default=0, ge=0)


class MetricsSnapshot(BaseModel):
    """Raw serving metrics recorded for a single step."""

    model_config = ConfigDict(extra="forbid")

    p50_ttft_ms: float = Field(ge=0.0)
    p99_ttft_ms: float = Field(ge=0.0)
    p50_itl_ms: float = Field(ge=0.0)
    throughput_tps: float = Field(ge=0.0)
    gpu_memory_used_gb: float = Field(ge=0.0)
    estimated_cost_per_1k: float = Field(ge=0.0)
    spec_acceptance_rate: float = Field(ge=0.0, le=1.0)
    eviction_events: int = Field(ge=0)
    preemption_events: int = Field(default=0, ge=0)
    is_throttled: bool = Field(default=False)
    slo_violations: int = Field(ge=0)
    requests_served: int = Field(ge=0)


class EpisodeLog(BaseModel):
    """Complete recorded trajectory (actions, observations, rewards) for one episode."""

    model_config = ConfigDict(extra="forbid")

    task_id: str
    actions: list[ServeAction]
    observations: list[ServeObservation]
    rewards: list[float]
    final_state: ServeState


def default_action() -> ServeAction:
    """Return the baseline action: every knob at its schema default."""
    return ServeAction(
        batch_cap=32,
        kv_budget_fraction=1.0,
        speculation_depth=0,
        quantization_tier=QuantizationTier.FP16.value,
        prefill_decode_split=False,
        priority_routing=False,
    )


def model_to_dict(model: BaseModel) -> dict[str, Any]:
    """Serialize any schema object to a JSON-compatible dict."""
    return model.model_dump(mode="json")
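

# Illustrative round trip (assumed usage; ignores any extra fields that
# openenv's Action base class may contribute):
#
#     payload = model_to_dict(default_action())
#     restored = ServeAction.model_validate(payload)
#     assert restored == default_action()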


def _clamp_int(value: Any, default: int, minimum: int, maximum: int) -> int:
    """Parse ``value`` as an int and clamp it into bounds; else ``default``."""
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        return default
    return max(minimum, min(maximum, parsed))
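

# Behaviour follows directly from the definition above:
#
#     _clamp_int("64", default=32, minimum=1, maximum=512)    # -> 64  (parsed)
#     _clamp_int(10_000, default=32, minimum=1, maximum=512)  # -> 512 (clamped)
#     _clamp_int(None, default=32, minimum=1, maximum=512)    # -> 32  (fallback)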


def _clamp_float(value: Any, default: float, minimum: float, maximum: float) -> float:
    """Parse ``value`` as a float and clamp it into bounds; else ``default``."""
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        return default
    if math.isnan(parsed):  # NaN slips through min/max comparisons; treat as malformed
        return default
    return max(minimum, min(maximum, parsed))
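

# Edge cases worth noting (assuming the NaN guard above):
#
#     _clamp_float("0.5", default=1.0, minimum=0.1, maximum=1.0)         # -> 0.5
#     _clamp_float(float("inf"), default=1.0, minimum=0.1, maximum=1.0)  # -> 1.0 (clamped)
#     _clamp_float(float("nan"), default=1.0, minimum=0.1, maximum=1.0)  # -> 1.0 (fallback to default)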


def _normalize_quantization_tier(value: Any) -> str:
    """Coerce ``value`` to a known tier name, defaulting to FP16."""
    if isinstance(value, QuantizationTier):
        return value.value
    if isinstance(value, str) and value in {tier.value for tier in QuantizationTier}:
        return value
    return QuantizationTier.FP16.value
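

# Matching is exact and case-sensitive, per the membership check above:
#
#     _normalize_quantization_tier(QuantizationTier.INT8)  # -> "INT8"
#     _normalize_quantization_tier("INT4")                 # -> "INT4"
#     _normalize_quantization_tier("int8")                 # -> "FP16" (fallback)
#     _normalize_quantization_tier(8)                      # -> "FP16" (fallback)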