Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Evaluate baseline policies on one sampled case.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| import sys | |
# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Prepend it to the import search path exactly once; presumably this is what
# lets the `app` package imported just below resolve when run as a script.
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))
| from app.api.service import APIService | |
| from app.env.env_core import PolyGuardEnv | |
| def _evaluate_no_change_baseline(episodes: int = 8) -> dict[str, float | int | str]: | |
| rewards: list[float] = [] | |
| legal: list[float] = [] | |
| for idx in range(episodes): | |
| env = PolyGuardEnv() | |
| env.reset(seed=99 + idx, difficulty="medium") | |
| candidates = env.get_candidate_actions() | |
| action = next((item for item in candidates if item.get("candidate_id") == "cand_01"), None) | |
| if action is None and candidates: | |
| action = candidates[0] | |
| if action is None: | |
| rewards.append(0.001) | |
| legal.append(0.0) | |
| continue | |
| _, reward, _, info = env.step(action) | |
| rewards.append(float(reward)) | |
| legal.append(1.0 if bool(info.get("safety_report", {}).get("legal")) else 0.0) | |
| return { | |
| "baseline_policy": "no_change_candidate", | |
| "episodes": episodes, | |
| "avg_reward": round(sum(rewards) / len(rewards), 6) if rewards else 0.0, | |
| "legality_rate": round(sum(legal) / len(legal), 6) if legal else 0.0, | |
| "success_rate": 0.0, | |
| } | |
def main() -> None:
    """Run the baseline evaluations and write the combined JSON report.

    Steps, in order:

    1. Reset an ``APIService`` (seed 99, difficulty ``"medium"``) and take
       the result of ``service.run_baselines()`` as the report dict.
    2. Merge the metrics from ``_evaluate_no_change_baseline()`` into it.
    3. For each policy stack, set ``POLYGUARD_POLICY_STACK``, reset the
       service and run up to three ``service.orchestrate()`` steps (stopping
       early when a step reports ``done``), recording the mean reward, the
       legality rate and the number of steps taken.
    4. Write the report to ``outputs/reports/baselines.json`` under the
       directory one level above this script's folder, then print
       ``baseline_eval_done``.

    ``POLYGUARD_POLICY_STACK`` is restored to its previous value (or removed
    if it was unset) once the ablation loop ends, even if a rollout raises.
    """
    service = APIService()
    service.reset(seed=99, difficulty="medium")
    out = service.run_baselines()
    out.update(_evaluate_no_change_baseline())

    stack_env_key = "POLYGUARD_POLICY_STACK"
    previous_stack = os.environ.get(stack_env_key)
    ablations: dict[str, dict[str, float]] = {}
    try:
        for stack in ("bandit-only", "llm-only", "llm+bandit"):
            # The variable is set before the reset so that both the reset and
            # the following rollout steps run with this stack selected.
            os.environ[stack_env_key] = stack
            service.reset(seed=99, difficulty="medium")

            rollout_rewards: list[float] = []
            legal_flags: list[float] = []
            for _ in range(3):
                step = service.orchestrate()
                rollout_rewards.append(float(step.get("reward", 0.0)))
                # A missing or null critic report counts as "not legal".
                critic = step.get("critic") or {}
                legal_flags.append(1.0 if bool(critic.get("legal")) else 0.0)
                if step.get("done"):
                    break

            ablations[stack] = {
                "avg_reward": (sum(rollout_rewards) / len(rollout_rewards)) if rollout_rewards else 0.0,
                "legality_rate": (sum(legal_flags) / len(legal_flags)) if legal_flags else 0.0,
                "steps": float(len(rollout_rewards)),
            }
    finally:
        # Do not leak the last ablation's stack into the process environment.
        if previous_stack is None:
            os.environ.pop(stack_env_key, None)
        else:
            os.environ[stack_env_key] = previous_stack

    out["policy_stack_ablations"] = ablations

    report_dir = Path(__file__).resolve().parents[1] / "outputs" / "reports"
    report_dir.mkdir(parents=True, exist_ok=True)
    (report_dir / "baselines.json").write_text(
        json.dumps(out, ensure_ascii=True, indent=2),
        encoding="utf-8",
    )
    print("baseline_eval_done")
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()