from __future__ import annotations

import random
from typing import Any

from llmserve_env.models import WorkloadSnapshot
from server.replay_assets import load_prompt_samples, load_trace_table


class WorkloadGenerator:
    """Produce per-step ``WorkloadSnapshot`` objects for a task configuration.

    Each snapshot is derived from ``task_config`` and, when the configuration
    names a ``trace_file``, from the trace row selected for the step. Random
    draws come from a private ``random.Random`` seeded with ``seed``, and the
    generator carries one piece of mutable state between steps:
    ``queue_depth``.
    """

    def __init__(self, task_config: dict[str, Any], seed: int = 42) -> None:
        """Store the configuration, seed the RNG and load optional assets.

        Args:
            task_config: Task settings read by the other methods (for example
                ``"id"``, ``"arrival_rate_rps"``, ``"prompt_distribution"``,
                ``"max_steps"`` and the optional ``"trace_file"``).
            seed: Seed for the generator's private random number generator.
        """
        self.task_config = task_config
        self.seed = seed
        self.rng = random.Random(seed)
        self.queue_depth = 0
        self.trace_rows = self._load_trace_rows()
        self.prompt_samples = self._load_prompt_samples()

    def reset(self, seed: int | None = None) -> None:
        """Re-seed the RNG and empty the queue.

        Args:
            seed: New seed to remember and use; when ``None`` the previously
                stored seed is reused.
        """
        self.seed = self.seed if seed is None else seed
        self.rng = random.Random(self.seed)
        self.queue_depth = 0

    def next_snapshot(self, step_index: int) -> WorkloadSnapshot:
        """Advance the queue model by one step and describe the workload.

        Args:
            step_index: Zero-based index of the step being generated.

        Returns:
            A ``WorkloadSnapshot`` carrying the arrival rate, the updated
            queue depth, the mean prompt length with its bucket, the priority
            fraction, the phase label and ``step_index``.
        """
        trace_row = self._trace_row_for_step(step_index)
        arrival_rate = self._arrival_rate_for_step(step_index, trace_row)

        # Every 100th step of the adversarial task is a fixed-size
        # "mega-prompt"; no random draw is made on those steps.
        is_mega_prompt = (
            self.task_config["id"] == "adversarial_multitenant"
            and (step_index + 1) % 100 == 0
        )
        if is_mega_prompt:
            mean_prompt_length = 16384.0
            phase = "mega-prompt"
        else:
            mean_prompt_length = self._prompt_length_for_step(trace_row)
            phase = self._phase_for_step(step_index, trace_row)

        # A missing (or empty) trace row behaves like a row with no overrides.
        overrides = trace_row or {}

        service_hint = float(overrides.get("service_rate_hint", arrival_rate * 0.6))
        # Serve at least one request per step, but never more than were
        # already queued before this step's arrivals.
        served_estimate = min(self.queue_depth, max(1, int(service_hint)))
        queue_bias = int(overrides.get("queue_bias", 0))
        backlog = self.queue_depth + int(arrival_rate) - served_estimate + queue_bias
        self.queue_depth = max(0, backlog)

        # The trace row's priority fraction wins over the task-level default.
        priority_fraction = float(
            overrides.get(
                "priority_fraction",
                self.task_config.get("priority_fraction", 0.0),
            )
        )

        return WorkloadSnapshot(
            arrival_rate=arrival_rate,
            queue_depth=self.queue_depth,
            mean_prompt_length=mean_prompt_length,
            prompt_length_bucket=self._prompt_bucket(mean_prompt_length),
            priority_fraction=priority_fraction,
            phase=phase,
            step_index=step_index,
        )

    def _arrival_rate_for_step(
        self, step_index: int, trace_row: dict[str, Any] | None = None
    ) -> float:
        """Return the arrival rate (requests per second) for a step.

        Precedence: the trace row's ``"arrival_rate_rps"``, then an active
        burst window, then the ``"sinusoidal"`` arrival pattern, then the
        configured base rate.

        Args:
            step_index: Zero-based index of the step.
            trace_row: Trace row for the step, if any.

        Returns:
            The arrival rate as a float.
        """
        if trace_row and "arrival_rate_rps" in trace_row:
            return float(trace_row["arrival_rate_rps"])

        config = self.task_config
        base = float(config["arrival_rate_rps"])
        burst_rate = float(config.get("burst_rate_rps", base))
        burst_every = int(config.get("burst_every_steps", 0))
        burst_length = int(config.get("burst_length_steps", 0))

        # The first ``burst_length`` steps of every ``burst_every``-step
        # period run at the burst rate.
        if burst_every and burst_length and step_index % burst_every < burst_length:
            return burst_rate

        if config.get("arrival_pattern") == "sinusoidal":
            floor = float(config.get("arrival_floor_rps", base))
            ceiling = float(config.get("arrival_ceiling_rps", burst_rate))
            cycle = max(1, int(config.get("arrival_cycle_steps", 50)))
            alpha = (step_index % cycle) / cycle
            # NOTE(review): despite the pattern name, the weight below is
            # exactly 1.0 in the first half of the cycle and 0.0 in the second,
            # so the rate is a two-level square wave - confirm that is intended.
            sign = 1 if alpha < 0.5 else -1
            weight = 0.5 + 0.5 * sign
            return floor + (ceiling - floor) * weight

        return base

    def _prompt_length_for_step(self, trace_row: dict[str, Any] | None = None) -> float:
        """Draw a mean prompt length according to ``prompt_distribution``.

        Supported ``"type"`` values are ``"trace_sample"`` and ``"bimodal"``;
        ``"uniform"`` and any other value draw uniformly between the
        distribution's ``"min"`` and ``"max"``.

        Args:
            trace_row: Trace row for the step, if any. In ``"trace_sample"``
                mode its ``"prompt_p50"`` / ``"prompt_p95"`` values narrow the
                pool of candidate samples.

        Returns:
            The drawn prompt length as a float.
        """
        distribution = self.task_config["prompt_distribution"]
        mode = distribution["type"]

        if mode == "trace_sample":
            candidates = self.prompt_samples or [128]
            if trace_row:
                p50 = float(trace_row.get("prompt_p50", min(candidates)))
                p95 = float(trace_row.get("prompt_p95", max(candidates)))
                lower = p50 * 0.5
                upper = max(p95 * 1.1, p50 + 1.0)
                in_range = [value for value in candidates if lower <= value <= upper]
                # Keep the unfiltered pool if the bounds exclude every sample.
                candidates = in_range or candidates
            return float(self.rng.choice(candidates))

        if mode == "bimodal":
            short = distribution["short"]
            long = distribution["long"]
            long_fraction = distribution["long_fraction"]
            chosen = long if self.rng.random() < long_fraction else short
            return self.rng.uniform(chosen["min"], chosen["max"])

        # "uniform" and every unrecognised mode share the same draw.
        return self.rng.uniform(distribution["min"], distribution["max"])

    def _phase_for_step(self, step_index: int, trace_row: dict[str, Any] | None = None) -> str:
        """Return the phase label for a step.

        The trace row's ``"phase"`` wins when present. Otherwise the label is
        ``"burst"`` inside a burst window, ``"warmup"`` for the first three
        steps, ``"cooldown"`` for the last three steps before ``max_steps``,
        and ``"steady"`` for everything else.

        Args:
            step_index: Zero-based index of the step.
            trace_row: Trace row for the step, if any.

        Returns:
            The phase label.
        """
        if trace_row and "phase" in trace_row:
            return str(trace_row["phase"])

        burst_every = int(self.task_config.get("burst_every_steps", 0))
        burst_length = int(self.task_config.get("burst_length_steps", 0))
        in_burst = bool(burst_every) and step_index % burst_every < burst_length
        if in_burst:
            return "burst"
        if step_index < 3:
            return "warmup"
        cooldown_start = int(self.task_config["max_steps"]) - 3
        return "cooldown" if step_index >= cooldown_start else "steady"

    @staticmethod
    def _prompt_bucket(prompt_length: float) -> int:
        """Map a prompt length to a bucket index from 0 to 7.

        Args:
            prompt_length: Prompt length to classify.

        Returns:
            The index of the first boundary in 64, 128, ..., 4096 that is not
            smaller than ``prompt_length``, or 7 when the length exceeds all
            of them.
        """
        upper_bounds = (64, 128, 256, 512, 1024, 2048, 4096)
        return next(
            (index for index, bound in enumerate(upper_bounds) if prompt_length <= bound),
            len(upper_bounds),
        )

    def _load_trace_rows(self) -> list[dict[str, Any]]:
        """Load the configured trace as a list of row dictionaries.

        Returns:
            One dictionary per trace row, or an empty list when the task
            configuration has no ``"trace_file"``.
        """
        trace_file = self.task_config.get("trace_file")
        if trace_file:
            # presumably a pandas DataFrame - verify against load_trace_table
            table = load_trace_table(trace_file)
            return table.to_dict(orient="records")
        return []

    def _load_prompt_samples(self) -> list[int]:
        """Load prompt-length samples named by the prompt distribution.

        Returns:
            The samples from the distribution's ``"sample_file"``, or an empty
            list when no sample file is configured.
        """
        sample_file = self.task_config.get("prompt_distribution", {}).get("sample_file")
        return load_prompt_samples(sample_file) if sample_file else []

    def _trace_row_for_step(self, step_index: int) -> dict[str, Any] | None:
        """Return the trace row for a step, wrapping around the trace.

        Args:
            step_index: Zero-based index of the step.

        Returns:
            The row at ``step_index`` modulo the trace length, or ``None``
            when no trace rows are loaded.
        """
        rows = self.trace_rows
        if rows:
            return rows[step_index % len(rows)]
        return None