| |
| """Inference benchmark over provider runtime and policy stacks.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
| import time |
|
|
| import sys |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
| if str(ROOT) not in sys.path: |
| sys.path.insert(0, str(ROOT)) |
|
|
| from app.env.env_core import PolyGuardEnv |
| from app.models.policy.provider_runtime import PolicyProviderRouter |
|
|
|
|
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser(description="Benchmark local inference path.") |
| parser.add_argument("--provider", default="transformers") |
| parser.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct") |
| parser.add_argument("--runs", type=int, default=5) |
| return parser.parse_args() |
|
|
|
|
def main() -> None:
    """Run the inference benchmark and write a JSON report.

    For each run: resets the environment with a distinct seed, asks the
    policy router to pick a candidate action, and records the wall-clock
    selection latency. Results are written to
    ``outputs/reports/inference_benchmark.json`` under the repo root.
    """
    args = parse_args()
    env = PolyGuardEnv()
    router = PolicyProviderRouter(hf_model=args.model)
    # Non-default providers fall back to "transformers" if they fail;
    # the default provider needs no fallback entry.
    provider_preference = (
        (args.provider,)
        if args.provider == "transformers"
        else (args.provider, "transformers")
    )

    rows = []
    for i in range(args.runs):
        env.reset(seed=7_100 + i, difficulty="medium")
        # NOTE(review): reaches into a private helper of the env; prefer a
        # public observation accessor if one exists — confirm with env API.
        obs = env._build_observation()
        candidates = list(obs.candidate_action_set)
        # perf_counter is the documented highest-resolution clock for
        # benchmarking short intervals (monotonic may be coarse on some OSes).
        start = time.perf_counter()
        selection = router.select_candidate(
            candidates, prompt={"run": i}, provider_preference=provider_preference
        )
        latency = (time.perf_counter() - start) * 1000.0
        rows.append(
            {
                "run": i,
                "provider": selection.provider,
                "candidate_id": selection.candidate_id,
                "latency_ms": round(latency, 3),
                "rationale": selection.rationale,
            }
        )

    # Guard the division so `--runs 0` still produces a valid report.
    avg_latency = sum(item["latency_ms"] for item in rows) / len(rows) if rows else 0.0
    payload = {
        "status": "ok",
        "runs": rows,
        "avg_latency_ms": round(avg_latency, 3),
        "provider_requested": args.provider,
        "model": args.model,
    }

    out = ROOT / "outputs" / "reports"
    out.mkdir(parents=True, exist_ok=True)
    (out / "inference_benchmark.json").write_text(
        json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    print("benchmark_inference_done")


if __name__ == "__main__":
    main()