#!/usr/bin/env python3
"""Benchmark the live AMD MI300X / ROCm vLLM endpoint.
Sends N concurrent chat-completion requests against the configured vLLM server,
records per-request latency + token counts, and writes a structured summary to
``assets/rocm_benchmark.json`` so the Streamlit UI and README can display real,
reproducible AMD ROCm performance evidence for the hackathon submission.
Usage:
python scripts/rocm_benchmark.py [--concurrency 4] [--requests 12] [--prompt-tokens 512]
Reads VLLM_BASE_URL, VLLM_API_KEY, MODEL_NAME from the environment / .env.
"""
from __future__ import annotations

import argparse
import json
import os
import statistics
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import httpx
from dotenv import load_dotenv

# Fixed prompt so every request exercises a comparable generation workload.
PROMPT = (
    "You are a senior detection engineer. In two short sentences, summarize how a "
    "Sigma rule for MITRE ATT&CK T1059.001 (PowerShell) should reason about parent "
    "process lineage and command-line obfuscation. Be concrete."
)


def _post_chat(client: httpx.Client, base_url: str, api_key: str, model: str) -> dict:
    """Send a single chat-completion request and return latency + token stats."""
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": PROMPT}],
        "temperature": 0.2,
        "max_tokens": 256,
    }
    # Time the full round trip: request serialization, network, and generation.
    started = time.perf_counter()
    resp = client.post(
        base_url.rstrip("/") + "/chat/completions",
        json=payload,
        headers=headers,
        timeout=120.0,
    )
    elapsed_ms = (time.perf_counter() - started) * 1000.0
    resp.raise_for_status()
    data = resp.json()
    usage = data.get("usage") or {}
    completion = data["choices"][0]["message"]["content"]
    return {
        "latency_ms": elapsed_ms,
        "prompt_tokens": int(usage.get("prompt_tokens") or 0),
        "completion_tokens": int(usage.get("completion_tokens") or 0),
        "total_tokens": int(usage.get("total_tokens") or 0),
        "completion_chars": len(completion),
    }


def _percentile(values: list[float], pct: float) -> Optional[float]:
    """Linear-interpolated percentile of ``values`` (``pct`` on a 0-100 scale)."""
    if not values:
        return None
    sorted_values = sorted(values)
    # Fractional rank between the two nearest order statistics.
    k = (len(sorted_values) - 1) * (pct / 100.0)
    lo = int(k)
    hi = min(lo + 1, len(sorted_values) - 1)
    frac = k - lo
    return round(sorted_values[lo] * (1 - frac) + sorted_values[hi] * frac, 2)
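

# Worked example for the interpolation above: with sorted latencies
# [10, 20, 30, 40], p95 gives k = 3 * 0.95 = 2.85, so lo = 2, frac = 0.85,
# and the result is 30 * 0.15 + 40 * 0.85 = 38.5.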


def main() -> int:
    load_dotenv()

    parser = argparse.ArgumentParser(description="Benchmark AegisOps AI vLLM endpoint on AMD MI300X / ROCm")
    parser.add_argument("--requests", type=int, default=12, help="total request count")
    parser.add_argument("--concurrency", type=int, default=4, help="parallel workers")
    parser.add_argument("--output", type=str, default=None, help="output JSON path (default: assets/rocm_benchmark.json)")
    args = parser.parse_args()

    base_url = os.getenv("VLLM_BASE_URL")
    api_key = os.getenv("VLLM_API_KEY", "")
    model = os.getenv("MODEL_NAME")
    if not base_url or not model:
        print("ERROR: VLLM_BASE_URL and MODEL_NAME must be set (use .env or shell env).")
        return 1

    output = Path(args.output) if args.output else Path(__file__).resolve().parent.parent / "assets" / "rocm_benchmark.json"
    output.parent.mkdir(parents=True, exist_ok=True)

    print(f"Benchmarking {args.requests} requests @ concurrency={args.concurrency}")
    print(f" endpoint: {base_url}")
    print(f" model: {model}")

    results: list[dict] = []
    errors = 0
    started_wall = time.perf_counter()

    # One shared httpx.Client reuses its connection pool across worker threads
    # (httpx.Client is safe for concurrent use).
    with httpx.Client() as client:
        with ThreadPoolExecutor(max_workers=args.concurrency) as pool:
            futures = [
                pool.submit(_post_chat, client, base_url, api_key, model)
                for _ in range(args.requests)
            ]
            # Collect results as they finish; a failed request is counted but
            # does not abort the run.
            for future in as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as exc:  # noqa: BLE001
                    errors += 1
                    print(f" request failed: {type(exc).__name__}: {exc}")

    wall_seconds = max(time.perf_counter() - started_wall, 1e-6)
    if not results:
        print("ERROR: no successful requests; nothing written.")
        return 2

    latencies = [r["latency_ms"] for r in results]
    completion_tokens = sum(r["completion_tokens"] for r in results)
    total_tokens = sum(r["total_tokens"] for r in results)
    # Aggregate throughput: generated tokens over the whole run's wall-clock time.
    tps = round(completion_tokens / wall_seconds, 2)
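    # Illustrative arithmetic (hypothetical numbers): 12 requests averaging
    # ~200 completion tokens, finishing in 20 s of wall time, would report
    # tokens_per_second = 2400 / 20 = 120.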

    summary = {
        "captured_at": datetime.now(timezone.utc).isoformat(),
        "endpoint": base_url,
        "model": model,
        "runtime": "vLLM on ROCm container",
        "gpu": "AMD Instinct MI300X (AMD Developer Cloud)",
        "concurrency": args.concurrency,
        "requests": args.requests,
        "successful": len(results),
        "failed": errors,
        "wall_clock_seconds": round(wall_seconds, 3),
        "latency_ms_p50": _percentile(latencies, 50),
        "latency_ms_p95": _percentile(latencies, 95),
        "latency_ms_avg": round(statistics.fmean(latencies), 2),
        "latency_ms_min": round(min(latencies), 2),
        "latency_ms_max": round(max(latencies), 2),
        "tokens_per_second": tps,
        "completion_tokens_total": completion_tokens,
        "total_tokens": total_tokens,
        "prompt": PROMPT,
    }

    output.write_text(json.dumps(summary, indent=2) + "\n")
    print(f"Wrote {output}")
    print(json.dumps(summary, indent=2))
    return 0
if __name__ == "__main__":
raise SystemExit(main())
|