from __future__ import annotations
import math
import random
from typing import Any

from llmserve_env.models import WorkloadSnapshot
from server.replay_assets import load_prompt_samples, load_trace_table
class WorkloadGenerator:
    """Produce per-step :class:`WorkloadSnapshot` objects for a simulated
    LLM-serving workload.

    Arrival rates, prompt lengths, and phase labels come from
    ``task_config``; when a ``trace_file`` is configured, replayed trace
    rows take precedence over the synthetic configuration.  All
    randomness flows through a seeded ``random.Random`` so runs are
    reproducible via :meth:`reset`.
    """

    def __init__(self, task_config: dict[str, Any], seed: int = 42) -> None:
        """Store config, seed the RNG, and eagerly load replay assets.

        ``task_config`` keys read by this class include ``id``,
        ``arrival_rate_rps``, ``prompt_distribution``, ``max_steps``,
        and the optional burst/sinusoidal/trace settings.
        """
        self.task_config = task_config
        self.seed = seed
        self.rng = random.Random(seed)
        # Running estimate of outstanding requests; updated on every
        # call to next_snapshot().
        self.queue_depth = 0
        self.trace_rows = self._load_trace_rows()
        self.prompt_samples = self._load_prompt_samples()

    def reset(self, seed: int | None = None) -> None:
        """Reset RNG and queue state; optionally install a new seed."""
        if seed is not None:
            self.seed = seed
        self.rng = random.Random(self.seed)
        self.queue_depth = 0

    def next_snapshot(self, step_index: int) -> WorkloadSnapshot:
        """Build the workload snapshot for ``step_index`` and advance the
        internal queue-depth estimate.

        Side effect: mutates ``self.queue_depth`` (arrivals minus an
        estimated served count, plus any trace-supplied bias, floored
        at zero).
        """
        trace_row = self._trace_row_for_step(step_index)
        arrival_rate = self._arrival_rate_for_step(step_index, trace_row)
        # The adversarial multitenant task periodically injects a huge
        # prompt.  Cadence and size were previously hard-coded (100 steps,
        # 16384 tokens); they are now configurable with the same defaults.
        mega_every = max(1, int(self.task_config.get("mega_prompt_every_steps", 100)))
        is_mega_step = (
            self.task_config["id"] == "adversarial_multitenant"
            and (step_index + 1) % mega_every == 0
        )
        if is_mega_step:
            mean_prompt_length = float(
                self.task_config.get("mega_prompt_tokens", 16384.0)
            )
            phase = "mega-prompt"
        else:
            mean_prompt_length = self._prompt_length_for_step(trace_row)
            phase = self._phase_for_step(step_index, trace_row)
        service_hint = (
            float(trace_row.get("service_rate_hint", arrival_rate * 0.6))
            if trace_row
            else arrival_rate * 0.6
        )
        # Cannot serve more than is queued; always drain at least one request.
        served_estimate = min(self.queue_depth, max(1, int(service_hint)))
        queue_bias = int(trace_row.get("queue_bias", 0)) if trace_row else 0
        self.queue_depth = max(
            0, self.queue_depth + int(arrival_rate) - served_estimate + queue_bias
        )
        if trace_row:
            priority_fraction = float(
                trace_row.get(
                    "priority_fraction",
                    self.task_config.get("priority_fraction", 0.0),
                )
            )
        else:
            priority_fraction = float(self.task_config.get("priority_fraction", 0.0))
        return WorkloadSnapshot(
            arrival_rate=arrival_rate,
            queue_depth=self.queue_depth,
            mean_prompt_length=mean_prompt_length,
            prompt_length_bucket=self._prompt_bucket(mean_prompt_length),
            priority_fraction=priority_fraction,
            phase=phase,
            step_index=step_index,
        )

    def _arrival_rate_for_step(
        self, step_index: int, trace_row: dict[str, Any] | None = None
    ) -> float:
        """Arrival rate (requests/sec) for ``step_index``.

        Precedence: explicit trace value > burst window > sinusoidal
        pattern > constant base rate.
        """
        if trace_row and "arrival_rate_rps" in trace_row:
            return float(trace_row["arrival_rate_rps"])
        base = float(self.task_config["arrival_rate_rps"])
        burst_rate = float(self.task_config.get("burst_rate_rps", base))
        burst_every = int(self.task_config.get("burst_every_steps", 0))
        burst_length = int(self.task_config.get("burst_length_steps", 0))
        if burst_every and burst_length and (step_index % burst_every) < burst_length:
            return burst_rate
        if self.task_config.get("arrival_pattern") == "sinusoidal":
            floor = float(self.task_config.get("arrival_floor_rps", base))
            ceiling = float(self.task_config.get("arrival_ceiling_rps", burst_rate))
            cycle = max(1, int(self.task_config.get("arrival_cycle_steps", 50)))
            alpha = (step_index % cycle) / cycle
            # BUG FIX: the previous expression,
            #   0.5 + 0.5 * (1 if alpha < 0.5 else -1),
            # produced a square wave (ceiling for the first half-cycle,
            # floor for the second) despite the "sinusoidal" name.  Use
            # an actual sine oscillation between floor and ceiling.
            return floor + (ceiling - floor) * (
                0.5 + 0.5 * math.sin(2.0 * math.pi * alpha)
            )
        return base

    def _prompt_length_for_step(self, trace_row: dict[str, Any] | None = None) -> float:
        """Sample a mean prompt length from the configured distribution.

        Supported ``prompt_distribution.type`` values: ``trace_sample``,
        ``uniform``, and ``bimodal``; any other value falls back to a
        uniform draw from the distribution's ``min``/``max``.
        """
        distribution = self.task_config["prompt_distribution"]
        mode = distribution["type"]
        if mode == "trace_sample":
            sample_pool = self.prompt_samples or [128]
            if trace_row:
                prompt_p50 = float(trace_row.get("prompt_p50", min(sample_pool)))
                prompt_p95 = float(trace_row.get("prompt_p95", max(sample_pool)))
                upper = max(prompt_p95 * 1.1, prompt_p50 + 1.0)
                bounded_pool = [
                    sample
                    for sample in sample_pool
                    if (prompt_p50 * 0.5) <= sample <= upper
                ]
                # If the percentile window excludes every sample, fall
                # back to the unfiltered pool rather than crash.
                sample_pool = bounded_pool or sample_pool
            return float(self.rng.choice(sample_pool))
        if mode == "bimodal":
            short_cfg = distribution["short"]
            long_cfg = distribution["long"]
            fraction = distribution["long_fraction"]
            bucket = long_cfg if self.rng.random() < fraction else short_cfg
            return self.rng.uniform(bucket["min"], bucket["max"])
        # "uniform" and any unrecognized mode share the same uniform draw
        # (the original duplicated this branch verbatim).
        return self.rng.uniform(distribution["min"], distribution["max"])

    def _phase_for_step(self, step_index: int, trace_row: dict[str, Any] | None = None) -> str:
        """Label the step: trace-supplied phase, else burst / warmup /
        cooldown / steady.  Warmup is the first 3 steps; cooldown the
        last 3 before ``max_steps``."""
        if trace_row and "phase" in trace_row:
            return str(trace_row["phase"])
        burst_every = int(self.task_config.get("burst_every_steps", 0))
        burst_length = int(self.task_config.get("burst_length_steps", 0))
        if burst_every and (step_index % burst_every) < burst_length:
            return "burst"
        if step_index < 3:
            return "warmup"
        if step_index >= int(self.task_config["max_steps"]) - 3:
            return "cooldown"
        return "steady"

    @staticmethod
    def _prompt_bucket(prompt_length: float) -> int:
        """Map a prompt length (tokens) onto bucket index 0-7.

        Bucket i holds lengths <= boundaries[i]; anything beyond the
        largest boundary (4096) lands in bucket 7.
        """
        boundaries = [64, 128, 256, 512, 1024, 2048, 4096]
        for idx, boundary in enumerate(boundaries):
            if prompt_length <= boundary:
                return idx
        return 7

    def _load_trace_rows(self) -> list[dict[str, Any]]:
        """Load trace rows from ``trace_file`` (empty list if unset)."""
        trace_file = self.task_config.get("trace_file")
        if not trace_file:
            return []
        frame = load_trace_table(trace_file)
        return frame.to_dict(orient="records")

    def _load_prompt_samples(self) -> list[int]:
        """Load prompt-length samples from ``prompt_distribution.sample_file``
        (empty list if unset)."""
        distribution = self.task_config.get("prompt_distribution", {})
        sample_file = distribution.get("sample_file")
        if not sample_file:
            return []
        return load_prompt_samples(sample_file)

    def _trace_row_for_step(self, step_index: int) -> dict[str, Any] | None:
        """Return the trace row for this step, cycling through the trace
        when it is shorter than the episode; ``None`` without a trace."""
        if not self.trace_rows:
            return None
        return self.trace_rows[step_index % len(self.trace_rows)]