# vegarl/server/workload_generator.py
# Snapshot of Hugging Face Space file, commit 4fbc241
# ("Deploy Space without oversized raw dataset", author: ronitraj)
from __future__ import annotations

import math
import random
from typing import Any

from llmserve_env.models import WorkloadSnapshot
from server.replay_assets import load_prompt_samples, load_trace_table
class WorkloadGenerator:
    """Produce per-step workload snapshots for a simulated LLM-serving environment.

    Arrival rates, prompt lengths, and phases come either from a replayed
    trace table / prompt-sample file (when configured) or from synthetic
    patterns described by ``task_config`` (constant, bursty, sinusoidal).
    A crude queue model tracks backlog across steps.
    """

    def __init__(self, task_config: dict[str, Any], seed: int = 42) -> None:
        """Initialize from *task_config*; *seed* fixes the RNG stream."""
        self.task_config = task_config
        self.seed = seed
        self.rng = random.Random(seed)
        self.queue_depth = 0  # simulated pending-request backlog
        self.trace_rows = self._load_trace_rows()
        self.prompt_samples = self._load_prompt_samples()

    def reset(self, seed: int | None = None) -> None:
        """Restart an episode: reseed the RNG (optionally with a new seed) and
        clear the simulated queue. Trace assets are kept, not reloaded."""
        if seed is not None:
            self.seed = seed
        self.rng = random.Random(self.seed)
        self.queue_depth = 0

    def next_snapshot(self, step_index: int) -> WorkloadSnapshot:
        """Advance the queue model one step and return the workload snapshot."""
        trace_row = self._trace_row_for_step(step_index)
        arrival_rate = self._arrival_rate_for_step(step_index, trace_row)
        # The adversarial multitenant task injects an oversized "mega-prompt"
        # every 100th step. FIX: use .get() so task configs without an "id"
        # key don't raise KeyError (every other config access uses .get()).
        if self.task_config.get("id") == "adversarial_multitenant" and (step_index + 1) % 100 == 0:
            mean_prompt_length = 16384.0
            phase = "mega-prompt"
        else:
            mean_prompt_length = self._prompt_length_for_step(trace_row)
            phase = self._phase_for_step(step_index, trace_row)
        # Crude queue dynamics: arrivals minus an estimated served count
        # (at least 1 while the queue is non-empty), clamped at zero.
        # A trace row may supply a service-rate hint and a direct depth bias.
        if trace_row:
            service_hint = float(trace_row.get("service_rate_hint", arrival_rate * 0.6))
            queue_bias = int(trace_row.get("queue_bias", 0))
        else:
            service_hint = arrival_rate * 0.6
            queue_bias = 0
        served_estimate = min(self.queue_depth, max(1, int(service_hint)))
        self.queue_depth = max(0, self.queue_depth + int(arrival_rate) - served_estimate + queue_bias)
        return WorkloadSnapshot(
            arrival_rate=arrival_rate,
            queue_depth=self.queue_depth,
            mean_prompt_length=mean_prompt_length,
            prompt_length_bucket=self._prompt_bucket(mean_prompt_length),
            priority_fraction=self._priority_fraction(trace_row),
            phase=phase,
            step_index=step_index,
        )

    def _priority_fraction(self, trace_row: dict[str, Any] | None) -> float:
        """Priority share for this step: trace row value, else task config, else 0."""
        default = self.task_config.get("priority_fraction", 0.0)
        if trace_row:
            return float(trace_row.get("priority_fraction", default))
        return float(default)

    def _arrival_rate_for_step(self, step_index: int, trace_row: dict[str, Any] | None = None) -> float:
        """Requests/sec for this step.

        Precedence: explicit trace value > burst window > sinusoidal pattern >
        constant base rate from the task config.
        """
        if trace_row and "arrival_rate_rps" in trace_row:
            return float(trace_row["arrival_rate_rps"])
        base = float(self.task_config["arrival_rate_rps"])
        burst_rate = float(self.task_config.get("burst_rate_rps", base))
        burst_every = int(self.task_config.get("burst_every_steps", 0))
        burst_length = int(self.task_config.get("burst_length_steps", 0))
        if burst_every and burst_length and step_index % burst_every < burst_length:
            return burst_rate
        if self.task_config.get("arrival_pattern") == "sinusoidal":
            floor = float(self.task_config.get("arrival_floor_rps", base))
            ceiling = float(self.task_config.get("arrival_ceiling_rps", burst_rate))
            cycle = max(1, int(self.task_config.get("arrival_cycle_steps", 50)))
            alpha = (step_index % cycle) / cycle
            # BUG FIX: the previous expression (0.5 + 0.5 * (1 if alpha < 0.5
            # else -1)) produced a square wave — ceiling for the first
            # half-cycle, floor for the second — not a sinusoid. Use a real
            # sine so the rate varies smoothly between floor and ceiling.
            return floor + (ceiling - floor) * (0.5 + 0.5 * math.sin(2.0 * math.pi * alpha))
        return base

    def _prompt_length_for_step(self, trace_row: dict[str, Any] | None = None) -> float:
        """Draw a mean prompt length according to the configured distribution."""
        distribution = self.task_config["prompt_distribution"]
        mode = distribution["type"]
        if mode == "trace_sample":
            sample_pool = self.prompt_samples or [128]
            if trace_row:
                # Restrict samples to roughly the row's p50..p95 band; fall
                # back to the full pool rather than sampling from nothing.
                prompt_p50 = float(trace_row.get("prompt_p50", min(sample_pool)))
                prompt_p95 = float(trace_row.get("prompt_p95", max(sample_pool)))
                low = prompt_p50 * 0.5
                high = max(prompt_p95 * 1.1, prompt_p50 + 1.0)
                bounded_pool = [sample for sample in sample_pool if low <= sample <= high]
                sample_pool = bounded_pool or sample_pool
            return float(self.rng.choice(sample_pool))
        if mode == "bimodal":
            fraction = distribution["long_fraction"]
            bucket = distribution["long"] if self.rng.random() < fraction else distribution["short"]
            return self.rng.uniform(bucket["min"], bucket["max"])
        # "uniform" and any unrecognized mode share the same uniform draw
        # (the original duplicated this branch verbatim).
        return self.rng.uniform(distribution["min"], distribution["max"])

    def _phase_for_step(self, step_index: int, trace_row: dict[str, Any] | None = None) -> str:
        """Label the step: trace-supplied phase, else burst/warmup/cooldown/steady."""
        if trace_row and "phase" in trace_row:
            return str(trace_row["phase"])
        burst_every = int(self.task_config.get("burst_every_steps", 0))
        burst_length = int(self.task_config.get("burst_length_steps", 0))
        if burst_every and (step_index % burst_every) < burst_length:
            return "burst"
        if step_index < 3:
            return "warmup"
        if step_index >= int(self.task_config["max_steps"]) - 3:
            return "cooldown"
        return "steady"

    @staticmethod
    def _prompt_bucket(prompt_length: float) -> int:
        """Map a prompt length to a histogram bucket index (0-7)."""
        boundaries = (64, 128, 256, 512, 1024, 2048, 4096)
        for idx, boundary in enumerate(boundaries):
            if prompt_length <= boundary:
                return idx
        # Anything beyond the last boundary lands in the overflow bucket (7).
        return len(boundaries)

    def _load_trace_rows(self) -> list[dict[str, Any]]:
        """Load replay trace rows as dicts; empty when no trace file is configured."""
        trace_file = self.task_config.get("trace_file")
        if not trace_file:
            return []
        return load_trace_table(trace_file).to_dict(orient="records")

    def _load_prompt_samples(self) -> list[int]:
        """Load prompt-length samples; empty unless a sample file is configured."""
        distribution = self.task_config.get("prompt_distribution", {})
        sample_file = distribution.get("sample_file")
        if not sample_file:
            return []
        return load_prompt_samples(sample_file)

    def _trace_row_for_step(self, step_index: int) -> dict[str, Any] | None:
        """Cycle through loaded trace rows; None when no trace is loaded."""
        if not self.trace_rows:
            return None
        return self.trace_rows[step_index % len(self.trace_rows)]