adithya9903's picture
Deploy PolyGuard HF training Space
fd0c71a verified
#!/usr/bin/env python3
"""Evaluate baseline policies and policy-stack ablations on seeded cases.

Writes the combined results to ``outputs/reports/baselines.json``.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
import sys
# Directory one level above the folder containing this script
# (presumably the project root -- confirm against the repository layout).
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT on sys.path before the `app.*` imports that follow, so they resolve
# when this file is executed directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from app.api.service import APIService
from app.env.env_core import PolyGuardEnv
def _evaluate_no_change_baseline(
    episodes: int = 8,
    *,
    base_seed: int = 99,
    difficulty: str = "medium",
) -> dict[str, float | int | str]:
    """Take one step per episode with the ``cand_01`` candidate and summarize.

    Every episode builds a fresh ``PolyGuardEnv``, resets it with seed
    ``base_seed + idx`` and the given ``difficulty``, and performs a single
    ``env.step`` with the candidate whose ``candidate_id`` is ``"cand_01"``
    (presumably the "no change" action -- confirm against the candidate
    generator). If that id is absent, the first candidate is used instead.

    Args:
        episodes: Number of independent episodes to run. Zero or a negative
            value runs nothing and yields all-zero averages.
        base_seed: Seed of the first episode; episode ``idx`` uses
            ``base_seed + idx``.
        difficulty: Difficulty string forwarded to ``env.reset``.

    Returns:
        A dict with the keys ``baseline_policy``, ``episodes``,
        ``avg_reward``, ``legality_rate`` and ``success_rate``. Averages are
        rounded to six decimal places.
    """
    rewards: list[float] = []
    legal: list[float] = []

    for idx in range(episodes):
        env = PolyGuardEnv()
        env.reset(seed=base_seed + idx, difficulty=difficulty)
        candidates = env.get_candidate_actions()

        # Prefer "cand_01"; otherwise fall back to the first candidate, if any.
        action = next(
            (item for item in candidates if item.get("candidate_id") == "cand_01"),
            candidates[0] if candidates else None,
        )
        if action is None:
            # No candidates at all: record the placeholder reward 0.001
            # (presumably the reward floor -- confirm) and count the episode
            # as not legal.
            rewards.append(0.001)
            legal.append(0.0)
            continue

        _, reward, _, info = env.step(action)
        rewards.append(float(reward))
        # A missing or None info / safety_report counts as "not legal"
        # instead of raising AttributeError.
        safety_report = (info or {}).get("safety_report") or {}
        legal.append(1.0 if safety_report.get("legal") else 0.0)

    return {
        "baseline_policy": "no_change_candidate",
        "episodes": episodes,
        "avg_reward": round(sum(rewards) / len(rewards), 6) if rewards else 0.0,
        "legality_rate": round(sum(legal) / len(legal), 6) if legal else 0.0,
        # Success is not measured by this baseline; it is always reported as 0.0.
        "success_rate": 0.0,
    }
def main() -> None:
    """Run the baseline and policy-stack evaluations and write the JSON report.

    Steps:
      1. Reset an ``APIService`` (seed 99, difficulty "medium") and collect
         ``service.run_baselines()``.
      2. Merge in the summary from ``_evaluate_no_change_baseline()``.
      3. For each policy stack, set ``POLYGUARD_POLICY_STACK``, reset the
         service and call ``service.orchestrate()`` up to three times
         (stopping early when a step reports ``done``), recording average
         reward, legality rate and step count.
      4. Write everything to ``outputs/reports/baselines.json`` under the
         directory one level above this script's folder and print
         ``baseline_eval_done``.

    The previous value of ``POLYGUARD_POLICY_STACK`` is restored on exit, even
    if a rollout raises, so calling ``main()`` from another module does not
    leave the process pinned to the last ablation stack.
    """
    service = APIService()
    service.reset(seed=99, difficulty="medium")
    out = service.run_baselines()
    out.update(_evaluate_no_change_baseline())

    stack_env_key = "POLYGUARD_POLICY_STACK"
    previous_stack = os.environ.get(stack_env_key)
    ablations: dict[str, dict[str, float]] = {}
    try:
        for stack in ("bandit-only", "llm-only", "llm+bandit"):
            os.environ[stack_env_key] = stack
            service.reset(seed=99, difficulty="medium")
            rollout_rewards: list[float] = []
            legal: list[float] = []
            for _ in range(3):
                step = service.orchestrate()
                rollout_rewards.append(float(step.get("reward", 0.0)))
                # A missing or None critic entry counts as "not legal".
                critic = step.get("critic") or {}
                legal.append(1.0 if critic.get("legal") else 0.0)
                if step.get("done"):
                    break
            ablations[stack] = {
                "avg_reward": (sum(rollout_rewards) / len(rollout_rewards)) if rollout_rewards else 0.0,
                "legality_rate": (sum(legal) / len(legal)) if legal else 0.0,
                "steps": float(len(rollout_rewards)),
            }
    finally:
        # Undo the environment mutation so the ablation loop has no lasting effect.
        if previous_stack is None:
            os.environ.pop(stack_env_key, None)
        else:
            os.environ[stack_env_key] = previous_stack
    out["policy_stack_ablations"] = ablations

    root = Path(__file__).resolve().parents[1]
    report = root / "outputs" / "reports"
    report.mkdir(parents=True, exist_ok=True)
    (report / "baselines.json").write_text(json.dumps(out, ensure_ascii=True, indent=2), encoding="utf-8")
    print("baseline_eval_done")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()