#!/usr/bin/env python3
"""Benchmark the live AMD MI300X / ROCm vLLM endpoint.

Sends N concurrent chat-completion requests against the configured vLLM
server, records per-request latency and token counts, and writes a structured
summary to ``assets/rocm_benchmark.json`` so the Streamlit UI and README can
display real, reproducible AMD ROCm performance evidence for the hackathon
submission.

Usage:
    python scripts/rocm_benchmark.py [--concurrency 4] [--requests 12]

Reads VLLM_BASE_URL, VLLM_API_KEY, and MODEL_NAME from the environment / .env.
"""
from __future__ import annotations

import argparse
import json
import os
import statistics
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import httpx
from dotenv import load_dotenv

PROMPT = (
    "You are a senior detection engineer. In two short sentences, summarize how a "
    "Sigma rule for MITRE ATT&CK T1059.001 (PowerShell) should reason about parent "
    "process lineage and command-line obfuscation. Be concrete."
)


def _post_chat(client: httpx.Client, base_url: str, api_key: str, model: str) -> dict:
    """Send one chat-completion request and return latency + token metrics."""
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": PROMPT}],
        "temperature": 0.2,
        "max_tokens": 256,
    }
    started = time.perf_counter()
    resp = client.post(
        base_url.rstrip("/") + "/chat/completions",
        json=payload,
        headers=headers,
        timeout=120.0,
    )
    elapsed_ms = (time.perf_counter() - started) * 1000.0
    resp.raise_for_status()
    data = resp.json()
    usage = data.get("usage") or {}
    completion = data["choices"][0]["message"]["content"]
    return {
        "latency_ms": elapsed_ms,
        "prompt_tokens": int(usage.get("prompt_tokens") or 0),
        "completion_tokens": int(usage.get("completion_tokens") or 0),
        "total_tokens": int(usage.get("total_tokens") or 0),
        "completion_chars": len(completion),
    }


def _percentile(values: list[float], pct: float) -> Optional[float]:
    """Linearly interpolated percentile over the sorted values."""
    if not values:
        return None
    sorted_values = sorted(values)
    k = (len(sorted_values) - 1) * (pct / 100.0)
    lo = int(k)
    hi = min(lo + 1, len(sorted_values) - 1)
    frac = k - lo
    return round(sorted_values[lo] * (1 - frac) + sorted_values[hi] * frac, 2)


def main() -> int:
    load_dotenv()
    parser = argparse.ArgumentParser(
        description="Benchmark AegisOps AI vLLM endpoint on AMD MI300X / ROCm"
    )
    parser.add_argument("--requests", type=int, default=12, help="total request count")
    parser.add_argument("--concurrency", type=int, default=4, help="parallel workers")
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="output JSON path (default: assets/rocm_benchmark.json)",
    )
    args = parser.parse_args()

    base_url = os.getenv("VLLM_BASE_URL")
    api_key = os.getenv("VLLM_API_KEY", "")
    model = os.getenv("MODEL_NAME")
    if not base_url or not model:
        print("ERROR: VLLM_BASE_URL and MODEL_NAME must be set (use .env or shell env).")
        return 1

    output = (
        Path(args.output)
        if args.output
        else Path(__file__).resolve().parent.parent / "assets" / "rocm_benchmark.json"
    )
    output.parent.mkdir(parents=True, exist_ok=True)

    print(f"Benchmarking {args.requests} requests @ concurrency={args.concurrency}")
    print(f"  endpoint: {base_url}")
    print(f"  model:    {model}")

    results: list[dict] = []
    errors = 0
    started_wall = time.perf_counter()
    with httpx.Client() as client:
        with ThreadPoolExecutor(max_workers=args.concurrency) as pool:
            futures = [
                pool.submit(_post_chat, client, base_url, api_key, model)
                for _ in range(args.requests)
            ]
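            # Drain futures in completion order so one slow request does not
            # serialize the rest; failures are counted but never abort the run.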
            for future in as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as exc:  # noqa: BLE001
                    errors += 1
                    print(f"  request failed: {type(exc).__name__}: {exc}")
    wall_seconds = max(time.perf_counter() - started_wall, 1e-6)

    if not results:
        print("ERROR: no successful requests; nothing written.")
        return 2

    latencies = [r["latency_ms"] for r in results]
    completion_tokens = sum(r["completion_tokens"] for r in results)
    total_tokens = sum(r["total_tokens"] for r in results)
    # Aggregate throughput: completion tokens across all requests divided by
    # total wall-clock time, i.e. what the endpoint sustained under concurrency.
    tps = round(completion_tokens / wall_seconds, 2)

    summary = {
        "captured_at": datetime.now(timezone.utc).isoformat(),
        "endpoint": base_url,
        "model": model,
        "runtime": "vLLM on ROCm container",
        "gpu": "AMD Instinct MI300X (AMD Developer Cloud)",
        "concurrency": args.concurrency,
        "requests": args.requests,
        "successful": len(results),
        "failed": errors,
        "wall_clock_seconds": round(wall_seconds, 3),
        "latency_ms_p50": _percentile(latencies, 50),
        "latency_ms_p95": _percentile(latencies, 95),
        "latency_ms_avg": round(statistics.fmean(latencies), 2),
        "latency_ms_min": round(min(latencies), 2),
        "latency_ms_max": round(max(latencies), 2),
        "tokens_per_second": tps,
        "completion_tokens_total": completion_tokens,
        "total_tokens": total_tokens,
        "prompt": PROMPT,
    }
    output.write_text(json.dumps(summary, indent=2) + "\n")
    print(f"Wrote {output}")
    print(json.dumps(summary, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
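
# A minimal consumer sketch (illustration only, not executed by this script):
# how the Streamlit UI or README tooling might read the summary written above.
# The path assumes the default --output location; the keys match the `summary`
# dict built in main().
#
#     import json
#     from pathlib import Path
#
#     summary = json.loads(Path("assets/rocm_benchmark.json").read_text())
#     print(f"p50 {summary['latency_ms_p50']} ms @ {summary['tokens_per_second']} tok/s")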