File size: 2,783 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python3
"""Evaluate baseline policies on one sampled case."""

from __future__ import annotations

import json
import os
from pathlib import Path

import sys

# ROOT is two path levels above this file (the parent of the directory that
# holds this script) — presumably the project root; confirm against the repo
# layout. It is inserted at the front of sys.path, only if not already there,
# before the `app.*` imports below so they can resolve when this file is run
# directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.api.service import APIService
from app.env.env_core import PolyGuardEnv


def _play_no_change_episode(seed: int) -> tuple[float, float]:
    """Play one single-step episode and return ``(reward, legality_flag)``.

    A fresh ``PolyGuardEnv`` is reset with ``seed`` at "medium" difficulty.
    The candidate whose ``candidate_id`` is "cand_01" is stepped; when no such
    candidate exists the first candidate is used instead. If there is nothing
    to step at all, the episode is scored as ``(0.001, 0.0)``.
    """
    env = PolyGuardEnv()
    env.reset(seed=seed, difficulty="medium")
    candidates = env.get_candidate_actions()

    # Look for the preferred candidate first.
    chosen = None
    for candidate in candidates:
        if candidate.get("candidate_id") == "cand_01":
            chosen = candidate
            break

    # Fall back to the first candidate, or give up on an empty candidate set.
    if chosen is None:
        if not candidates:
            return 0.001, 0.0
        chosen = candidates[0]

    _, reward, _, info = env.step(chosen)
    reward_value = float(reward)
    safety_report = info.get("safety_report", {})
    legality_flag = 1.0 if bool(safety_report.get("legal")) else 0.0
    return reward_value, legality_flag


def _evaluate_no_change_baseline(episodes: int = 8) -> dict[str, float | int | str]:
    """Summarise the "no_change_candidate" baseline over several episodes.

    Episode ``i`` (0-based) uses seed ``99 + i``. The returned mapping holds the
    policy name, the requested episode count, the mean reward and the legality
    rate (both rounded to 6 decimals, or 0.0 when no episode was played), and a
    ``success_rate`` that is always reported as 0.0.
    """
    outcomes = [_play_no_change_episode(99 + offset) for offset in range(episodes)]
    played = len(outcomes)

    if played:
        mean_reward = round(sum(score for score, _ in outcomes) / played, 6)
        legal_share = round(sum(flag for _, flag in outcomes) / played, 6)
    else:
        mean_reward = 0.0
        legal_share = 0.0

    return {
        "baseline_policy": "no_change_candidate",
        "episodes": episodes,
        "avg_reward": mean_reward,
        "legality_rate": legal_share,
        "success_rate": 0.0,
    }


def main() -> None:
    """Run the baseline evaluation and write ``outputs/reports/baselines.json``.

    Steps, in order:

    1. Create an ``APIService``, reset it (seed 99, "medium" difficulty) and
       take the result of ``run_baselines()`` as the report mapping.
    2. Merge the no-change baseline summary into that mapping; keys with the
       same name are overwritten by ``dict.update`` semantics.
    3. For each policy stack ("bandit-only", "llm-only", "llm+bandit"), set the
       ``POLYGUARD_POLICY_STACK`` environment variable, reset the service and
       call ``orchestrate()`` up to three times (stopping early when a step
       reports ``done``), recording mean reward, legality rate and step count
       under ``policy_stack_ablations``.
    4. Serialise the report as indented ASCII JSON and print
       ``baseline_eval_done``.

    ``POLYGUARD_POLICY_STACK`` is restored to its previous value (or removed if
    it was unset) once the ablation loop finishes or raises, so calling
    ``main()`` does not leave the process environment modified.
    """
    service = APIService()
    service.reset(seed=99, difficulty="medium")
    out = service.run_baselines()
    out.update(_evaluate_no_change_baseline())

    stack_env_key = "POLYGUARD_POLICY_STACK"
    previous_stack = os.environ.get(stack_env_key)
    ablations: dict[str, dict[str, float]] = {}
    try:
        for stack in ["bandit-only", "llm-only", "llm+bandit"]:
            # The variable is set before the reset so the service sees the
            # selected stack for this rollout.
            os.environ[stack_env_key] = stack
            service.reset(seed=99, difficulty="medium")
            rollout_rewards: list[float] = []
            legal: list[float] = []
            for _ in range(3):
                step = service.orchestrate()
                rollout_rewards.append(float(step.get("reward", 0.0)))
                legal.append(1.0 if bool(step.get("critic", {}).get("legal")) else 0.0)
                if step.get("done"):
                    break
            ablations[stack] = {
                "avg_reward": (sum(rollout_rewards) / len(rollout_rewards)) if rollout_rewards else 0.0,
                "legality_rate": (sum(legal) / len(legal)) if legal else 0.0,
                "steps": float(len(rollout_rewards)),
            }
    finally:
        # Undo the global environment mutation made by the loop above.
        if previous_stack is None:
            os.environ.pop(stack_env_key, None)
        else:
            os.environ[stack_env_key] = previous_stack

    out["policy_stack_ablations"] = ablations
    root = Path(__file__).resolve().parents[1]
    report = root / "outputs" / "reports"
    report.mkdir(parents=True, exist_ok=True)
    (report / "baselines.json").write_text(json.dumps(out, ensure_ascii=True, indent=2), encoding="utf-8")
    print("baseline_eval_done")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()