#!/usr/bin/env python3 """Inference benchmark over provider runtime and policy stacks.""" from __future__ import annotations import argparse import json from pathlib import Path import time import sys ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from app.env.env_core import PolyGuardEnv from app.models.policy.provider_runtime import PolicyProviderRouter def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Benchmark local inference path.") parser.add_argument("--provider", default="transformers") parser.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct") parser.add_argument("--runs", type=int, default=5) return parser.parse_args() def main() -> None: args = parse_args() env = PolyGuardEnv() router = PolicyProviderRouter(hf_model=args.model) provider_preference = (args.provider,) if args.provider == "transformers" else (args.provider, "transformers") rows = [] for i in range(args.runs): env.reset(seed=7_100 + i, difficulty="medium") obs = env._build_observation() # noqa: SLF001 candidates = list(obs.candidate_action_set) start = time.monotonic() selection = router.select_candidate(candidates, prompt={"run": i}, provider_preference=provider_preference) latency = (time.monotonic() - start) * 1000.0 rows.append( { "run": i, "provider": selection.provider, "candidate_id": selection.candidate_id, "latency_ms": round(latency, 3), "rationale": selection.rationale, } ) avg_latency = sum(item["latency_ms"] for item in rows) / len(rows) if rows else 0.0 payload = { "status": "ok", "runs": rows, "avg_latency_ms": round(avg_latency, 3), "provider_requested": args.provider, "model": args.model, } out = ROOT / "outputs" / "reports" out.mkdir(parents=True, exist_ok=True) (out / "inference_benchmark.json").write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8") print("benchmark_inference_done") if __name__ == "__main__": main()