#!/usr/bin/env python3
"""Inference benchmark over provider runtime and policy stacks."""

from __future__ import annotations

import argparse
import json
import sys
import time
from pathlib import Path
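
# Make the repository root importable before the in-repo imports below,
# so the script works when run directly rather than as an installed package.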
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.env.env_core import PolyGuardEnv
from app.models.policy.provider_runtime import PolicyProviderRouter


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Benchmark local inference path.")
    parser.add_argument("--provider", default="transformers")
    parser.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--runs", type=int, default=5)
    return parser.parse_args()
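

# Example invocation (the script's location within the repo is assumed):
#   python benchmark_inference.py --provider transformers --runs 10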


def main() -> None:
    args = parse_args()
    env = PolyGuardEnv()
    router = PolicyProviderRouter(hf_model=args.model)
    provider_preference = (
        (args.provider,)
        if args.provider == "transformers"
        else (args.provider, "transformers")
    )
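    # Non-"transformers" providers keep the local transformers backend as a
    # fallback entry; the exact fallback semantics are assumed to be handled
    # inside PolicyProviderRouter.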
    rows = []
    for i in range(args.runs):
        env.reset(seed=7_100 + i, difficulty="medium")
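        # Reach into the env's private observation builder (hence the noqa).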
        obs = env._build_observation()  # noqa: SLF001
        candidates = list(obs.candidate_action_set)
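        # time.monotonic() is unaffected by wall-clock adjustments, so it is
        # the right clock for measuring latency.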
        start = time.monotonic()
        selection = router.select_candidate(
            candidates, prompt={"run": i}, provider_preference=provider_preference
        )
        latency = (time.monotonic() - start) * 1000.0
        rows.append(
            {
                "run": i,
                "provider": selection.provider,
                "candidate_id": selection.candidate_id,
                "latency_ms": round(latency, 3),
                "rationale": selection.rationale,
            }
        )
    avg_latency = sum(item["latency_ms"] for item in rows) / len(rows) if rows else 0.0
    payload = {
        "status": "ok",
        "runs": rows,
        "avg_latency_ms": round(avg_latency, 3),
        "provider_requested": args.provider,
        "model": args.model,
    }
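    # Write the JSON report under outputs/reports/ relative to the repo root.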
    out = ROOT / "outputs" / "reports"
    out.mkdir(parents=True, exist_ok=True)
    (out / "inference_benchmark.json").write_text(
        json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    print("benchmark_inference_done")


if __name__ == "__main__":
    main()