#!/usr/bin/env python3
"""Benchmark the live AMD MI300X / ROCm vLLM endpoint.

Sends N concurrent chat-completion requests against the configured vLLM server,
records per-request latency + token counts, and writes a structured summary to
``assets/rocm_benchmark.json`` so the Streamlit UI and README can display real,
reproducible AMD ROCm performance evidence for the hackathon submission.

Usage:
    python scripts/rocm_benchmark.py [--concurrency 4] [--requests 12] [--output PATH]

Reads VLLM_BASE_URL, VLLM_API_KEY, MODEL_NAME from the environment / .env.
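
Example ``.env`` (placeholder values; ``VLLM_BASE_URL`` should include vLLM's
OpenAI-compatible ``/v1`` prefix, since the script appends ``/chat/completions``
directly to it):

    VLLM_BASE_URL=http://localhost:8000/v1
    VLLM_API_KEY=changeme
    MODEL_NAME=your-served-model-name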
"""

from __future__ import annotations

import argparse
import json
import os
import statistics
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from pathlib import Path

import httpx
from dotenv import load_dotenv


PROMPT = (
    "You are a senior detection engineer. In two short sentences, summarize how a "
    "Sigma rule for MITRE ATT&CK T1059.001 (PowerShell) should reason about parent "
    "process lineage and command-line obfuscation. Be concrete."
)


def _post_chat(client: httpx.Client, base_url: str, api_key: str, model: str) -> dict:
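    """Send one non-streaming chat completion and time the full round trip.

    Latency is measured as end-to-end wall time for the HTTP call, so under
    concurrency it includes any server-side queueing, not just generation.
    """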
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": PROMPT}],
        "temperature": 0.2,
        "max_tokens": 256,
    }
    started = time.perf_counter()
    resp = client.post(
        base_url.rstrip("/") + "/chat/completions",
        json=payload,
        headers=headers,
        timeout=120.0,
    )
    elapsed_ms = (time.perf_counter() - started) * 1000.0
    resp.raise_for_status()
    data = resp.json()
    usage = data.get("usage") or {}
    completion = data["choices"][0]["message"]["content"]
    return {
        "latency_ms": elapsed_ms,
        "prompt_tokens": int(usage.get("prompt_tokens") or 0),
        "completion_tokens": int(usage.get("completion_tokens") or 0),
        "total_tokens": int(usage.get("total_tokens") or 0),
        "completion_chars": len(completion),
    }


def _percentile(values: list[float], pct: float) -> float | None:
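    """Linear-interpolation percentile, matching numpy.percentile's default.

    Worked example: _percentile([10, 20, 30, 40], 95) gives k = 2.85, so the
    result interpolates between indices 2 and 3: 30 * 0.15 + 40 * 0.85 = 38.5.
    """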
    if not values:
        return None
    sorted_values = sorted(values)
    k = (len(sorted_values) - 1) * (pct / 100.0)
    lo = int(k)
    hi = min(lo + 1, len(sorted_values) - 1)
    frac = k - lo
    return round(sorted_values[lo] * (1 - frac) + sorted_values[hi] * frac, 2)


def main() -> int:
    load_dotenv()
    parser = argparse.ArgumentParser(description="Benchmark AegisOps AI vLLM endpoint on AMD MI300X / ROCm")
    parser.add_argument("--requests", type=int, default=12, help="total request count")
    parser.add_argument("--concurrency", type=int, default=4, help="parallel workers")
    parser.add_argument("--output", type=str, default=None, help="output JSON path (default: assets/rocm_benchmark.json)")
    args = parser.parse_args()

    base_url = os.getenv("VLLM_BASE_URL")
    api_key = os.getenv("VLLM_API_KEY", "")
    model = os.getenv("MODEL_NAME")
    if not base_url or not model:
        print("ERROR: VLLM_BASE_URL and MODEL_NAME must be set (use .env or shell env).")
        return 1

    output = Path(args.output) if args.output else Path(__file__).resolve().parent.parent / "assets" / "rocm_benchmark.json"
    output.parent.mkdir(parents=True, exist_ok=True)

    print(f"Benchmarking {args.requests} requests @ concurrency={args.concurrency}")
    print(f"  endpoint: {base_url}")
    print(f"  model:    {model}")

    results: list[dict] = []
    errors = 0
    started_wall = time.perf_counter()
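    # One httpx.Client is shared by all workers; httpx documents the Client as
    # thread-safe, so every request reuses a single connection pool.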
    with httpx.Client() as client:
        with ThreadPoolExecutor(max_workers=args.concurrency) as pool:
            futures = [
                pool.submit(_post_chat, client, base_url, api_key, model)
                for _ in range(args.requests)
            ]
            for future in as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as exc:  # noqa: BLE001
                    errors += 1
                    print(f"  request failed: {type(exc).__name__}: {exc}")

    wall_seconds = max(time.perf_counter() - started_wall, 1e-6)
    if not results:
        print("ERROR: no successful requests; nothing written.")
        return 2

    latencies = [r["latency_ms"] for r in results]
    completion_tokens = sum(r["completion_tokens"] for r in results)
    total_tokens = sum(r["total_tokens"] for r in results)
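    # Throughput is aggregate: completion tokens across all workers divided by
    # the run's total wall-clock time, i.e. the concurrent serving rate.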
    tps = round(completion_tokens / wall_seconds, 2)

    summary = {
        "captured_at": datetime.now(timezone.utc).isoformat(),
        "endpoint": base_url,
        "model": model,
        "runtime": "vLLM on ROCm container",
        "gpu": "AMD Instinct MI300X (AMD Developer Cloud)",
        "concurrency": args.concurrency,
        "requests": args.requests,
        "successful": len(results),
        "failed": errors,
        "wall_clock_seconds": round(wall_seconds, 3),
        "latency_ms_p50": _percentile(latencies, 50),
        "latency_ms_p95": _percentile(latencies, 95),
        "latency_ms_avg": round(statistics.fmean(latencies), 2),
        "latency_ms_min": round(min(latencies), 2),
        "latency_ms_max": round(max(latencies), 2),
        "tokens_per_second": tps,
        "completion_tokens_total": completion_tokens,
        "total_tokens": total_tokens,
        "prompt": PROMPT,
    }

    output.write_text(json.dumps(summary, indent=2) + "\n")
    print(f"Wrote {output}")
    print(json.dumps(summary, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())