# vegarl / llmserve_env / models.py
# NOTE(review): the lines above/below were scraped page chrome, not code —
# restored as comments so the module parses. Commit 4fbc241
# ("Deploy Space without oversized raw dataset", author: ronitraj).
from __future__ import annotations
from enum import Enum
from typing import Any, Literal
from openenv.core import Action, Observation
from pydantic import BaseModel, ConfigDict, Field, model_validator
class QuantizationTier(str, Enum):
    """Closed set of model-weight precision tiers accepted by ServeAction.

    Subclassing ``str`` lets members compare equal to (and serialize as)
    their plain string values.
    """

    FP16 = "FP16"  # half precision (the fallback/default tier)
    INT8 = "INT8"  # 8-bit integer quantization
    INT4 = "INT4"  # 4-bit integer quantization
class ServeAction(Action):
    """Control knobs the agent applies to the serving engine for one step.

    Unknown keys are rejected (``extra="forbid"``); known keys arriving from
    a raw web payload are clamped/coerced before field validation runs.
    """

    model_config = ConfigDict(extra="forbid")

    batch_cap: int = Field(default=32, ge=1, le=512)
    kv_budget_fraction: float = Field(default=1.0, ge=0.1, le=1.0)
    speculation_depth: int = Field(default=0, ge=0, le=8)
    quantization_tier: Literal["FP16", "INT8", "INT4"] = QuantizationTier.FP16.value
    prefill_decode_split: bool = False
    priority_routing: bool = False

    @model_validator(mode="before")
    @classmethod
    def normalize_web_payload(cls, data: Any) -> Any:
        """Coerce raw dict payloads into valid field values.

        Non-dict input is passed through untouched for pydantic to handle.
        Note this always materializes the four normalized keys, so missing
        entries are filled with their defaults here rather than by pydantic.
        """
        if not isinstance(data, dict):
            return data

        payload = {**data}  # never mutate the caller's dict
        payload["batch_cap"] = _clamp_int(
            payload.get("batch_cap"), default=32, minimum=1, maximum=512
        )
        payload["kv_budget_fraction"] = _clamp_float(
            payload.get("kv_budget_fraction"), default=1.0, minimum=0.1, maximum=1.0
        )
        payload["speculation_depth"] = _clamp_int(
            payload.get("speculation_depth"), default=0, minimum=0, maximum=8
        )
        payload["quantization_tier"] = _normalize_quantization_tier(
            payload.get("quantization_tier")
        )
        return payload
class ServeObservation(Observation):
    """Per-step telemetry snapshot exposed to the agent.

    All latency fields are in milliseconds (per the ``_ms`` suffixes); rates
    and occupancies are fractions in [0, 1].
    """

    model_config = ConfigDict(extra="forbid")
    queue_depth: int = Field(ge=0)  # requests waiting to be scheduled
    active_requests: int = Field(ge=0)  # requests currently being served
    kv_cache_occupancy: float = Field(ge=0.0, le=1.0)  # fraction of KV cache in use
    mean_prompt_length: float = Field(ge=0.0)
    p50_ttft_ms: float = Field(ge=0.0)  # median TTFT (time to first token, per naming)
    p99_ttft_ms: float = Field(ge=0.0)  # tail TTFT
    p50_itl_ms: float = Field(ge=0.0)  # median ITL (inter-token latency, per naming)
    throughput_tps: float = Field(ge=0.0)  # presumably tokens/sec — TODO confirm unit
    slo_compliance_rate: float = Field(ge=0.0, le=1.0)
    gpu_memory_used_gb: float = Field(ge=0.0)
    estimated_cost_per_1k: float = Field(ge=0.0)  # cost per 1k units (tokens? requests?) — verify
    request_arrival_rate: float = Field(ge=0.0)
    spec_acceptance_rate: float = Field(ge=0.0, le=1.0)  # speculative-decoding acceptance
    eviction_events: int = Field(ge=0)
    step_index: int = Field(ge=0)
    task_id: str = "uninitialized"  # sentinel until an episode assigns a real task
class ServeState(BaseModel):
    """Cumulative episode-level bookkeeping for one serving episode."""

    model_config = ConfigDict(extra="forbid")
    episode_id: str
    step_count: int = Field(ge=0)
    task_id: str
    total_requests_served: int = Field(ge=0)
    total_slo_violations: int = Field(ge=0)
    cumulative_reward: float = 0.0
    elapsed_simulated_time_s: float = Field(ge=0.0)  # simulated (not wall-clock) seconds
    workload_phase: str = "warmup"  # free-form phase label; starts in warmup
    done: bool = False  # True once the episode has terminated
class RewardSignal(BaseModel):
    """A single step's reward plus its named component breakdown."""

    model_config = ConfigDict(extra="forbid")
    reward: float  # total scalar reward for the step
    components: dict[str, float]  # per-component contributions (keys defined by the env)
    done: bool  # whether this step ended the episode
class WorkloadSnapshot(BaseModel):
    """Point-in-time description of the incoming request workload."""

    model_config = ConfigDict(extra="forbid")
    arrival_rate: float = Field(ge=0.0)
    queue_depth: int = Field(ge=0)
    mean_prompt_length: float = Field(ge=0.0)
    prompt_length_bucket: int = Field(ge=0, le=7)  # one of 8 discrete buckets (0-7)
    priority_fraction: float = Field(ge=0.0, le=1.0)  # share of requests flagged priority
    phase: str
    step_index: int = Field(default=0, ge=0)
class MetricsSnapshot(BaseModel):
    """Raw engine metrics for one step, before being folded into an observation.

    Superset of the latency/throughput fields on ServeObservation, plus
    engine-internal counters (preemptions, throttling).
    """

    model_config = ConfigDict(extra="forbid")
    p50_ttft_ms: float = Field(ge=0.0)
    p99_ttft_ms: float = Field(ge=0.0)
    p50_itl_ms: float = Field(ge=0.0)
    throughput_tps: float = Field(ge=0.0)
    gpu_memory_used_gb: float = Field(ge=0.0)
    estimated_cost_per_1k: float = Field(ge=0.0)
    spec_acceptance_rate: float = Field(ge=0.0, le=1.0)
    eviction_events: int = Field(ge=0)
    preemption_events: int = Field(default=0, ge=0)
    is_throttled: bool = Field(default=False)  # whether the engine throttled this step
    slo_violations: int = Field(ge=0)
    requests_served: int = Field(ge=0)
class EpisodeLog(BaseModel):
    """Full trace of one episode: parallel action/observation/reward lists.

    The three lists are presumably index-aligned per step — confirm against
    the episode loop that populates this.
    """

    model_config = ConfigDict(extra="forbid")
    task_id: str
    actions: list[ServeAction]
    observations: list[ServeObservation]
    rewards: list[float]
    final_state: ServeState
def default_action() -> ServeAction:
    """Build a ServeAction with every knob at its declared default.

    Values are spelled out explicitly so this stays a readable reference
    of the baseline configuration.
    """
    baseline: dict[str, Any] = {
        "batch_cap": 32,
        "kv_budget_fraction": 1.0,
        "speculation_depth": 0,
        "quantization_tier": QuantizationTier.FP16.value,
        "prefill_decode_split": False,
        "priority_routing": False,
    }
    return ServeAction(**baseline)
def model_to_dict(model: BaseModel) -> dict[str, Any]:
    """Serialize any pydantic model into a JSON-compatible plain dict."""
    dumped = model.model_dump(mode="json")
    return dumped
def _clamp_int(value: Any, default: int, minimum: int, maximum: int) -> int:
try:
parsed = int(value)
except (TypeError, ValueError):
return default
return max(minimum, min(maximum, parsed))
def _clamp_float(value: Any, default: float, minimum: float, maximum: float) -> float:
try:
parsed = float(value)
except (TypeError, ValueError):
return default
return max(minimum, min(maximum, parsed))
def _normalize_quantization_tier(value: Any) -> str:
    """Coerce a raw payload value to a canonical QuantizationTier string.

    Accepts QuantizationTier members and their string values; strings are
    matched case-insensitively with surrounding whitespace ignored (web
    forms commonly send lowercase, e.g. "int8"). Anything unrecognized
    falls back to FP16, matching the field default on ServeAction.
    """
    if isinstance(value, QuantizationTier):
        return value.value
    if isinstance(value, str):
        candidate = value.strip().upper()
        if candidate in {tier.value for tier in QuantizationTier}:
            return candidate
    return QuantizationTier.FP16.value