#!/usr/bin/env python3
"""Compare baseline and candidate reports to show measurable improvement."""

from __future__ import annotations

import argparse
import json
from pathlib import Path


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Compare two run reports.")
    parser.add_argument("--baseline", required=True)
    parser.add_argument("--candidate", required=True)
    parser.add_argument("--output", default="outputs/reports/improvement_report.json")
    return parser.parse_args()


def _load(path: Path) -> dict:
    """Read a JSON report; a missing file yields {}, so every metric reads as 0."""
    if not path.exists():
        return {}
    return json.loads(path.read_text(encoding="utf-8"))


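# A run report is a JSON object. Two shapes are handled below: top-level
# metrics, e.g. {"avg_reward": 0.81}, or metrics nested under sections such
# as "offline_policy_eval" (example values are illustrative, not taken from
# the source reports).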
def _metric(payload: dict, key: str) -> float:
    """Look up a metric, falling back through the nested report sections."""
    # Prefer a top-level value when the report carries one (skip explicit nulls).
    value = payload.get(key)
    if value is not None:
        return float(value)
    # Otherwise fall back through the nested sections a report may contain.
    offline = payload.get("offline_policy_eval", {}) if isinstance(payload.get("offline_policy_eval"), dict) else {}
    process = payload.get("process_eval", {}) if isinstance(payload.get("process_eval"), dict) else {}
    ablations = payload.get("policy_stack_ablations", {}) if isinstance(payload.get("policy_stack_ablations"), dict) else {}
    llm_bandit = (
        payload.get("ablations", {}).get("llm_bandit", {})
        if isinstance(payload.get("ablations"), dict)
        else {}
    )
    if not isinstance(llm_bandit, dict):
        llm_bandit = {}
    llm_plus_bandit = ablations.get("llm+bandit", {})
    if not isinstance(llm_plus_bandit, dict):
        llm_plus_bandit = {}

    # Candidate locations per metric, in precedence order (first non-None wins).
    mapping = {
        "avg_reward": [
            offline.get("avg_reward"),
            llm_bandit.get("avg_reward"),
            llm_plus_bandit.get("avg_reward"),
        ],
        "legality_rate": [
            offline.get("legal_rate"),
            llm_bandit.get("legality_rate"),
            llm_plus_bandit.get("legality_rate"),
        ],
        "success_rate": [offline.get("success_rate"), llm_bandit.get("success_rate")],
        "avg_process_fidelity": [process.get("process_fidelity"), llm_bandit.get("avg_process_fidelity")],
        "timeout_rate": [payload.get("timeout_rate"), llm_bandit.get("timeout_rate")],
        "failure_visible_rate": [payload.get("failure_visible_rate"), llm_bandit.get("failure_visible_rate")],
    }
    for value in mapping.get(key, []):
        if value is not None:
            return float(value)
    return 0.0


def main() -> None:
    args = parse_args()
    baseline = _load(Path(args.baseline))
    candidate = _load(Path(args.candidate))

    keys = [
        "avg_reward",
        "legality_rate",
        "success_rate",
        "avg_process_fidelity",
        "timeout_rate",
        "failure_visible_rate",
    ]
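    # Deltas are candidate minus baseline, so a positive value means the
    # candidate scored higher on that metric.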
    deltas = {}
    for key in keys:
        b = _metric(baseline, key)
        c = _metric(candidate, key)
        deltas[key] = round(c - b, 6)

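    # Promotion gate: require non-regression on the three headline metrics.
    # timeout_rate and failure_visible_rate are reported for context only.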
    gate = {
        "avg_reward_up": deltas["avg_reward"] >= 0.0,
        "legality_up": deltas["legality_rate"] >= 0.0,
        "success_up": deltas["success_rate"] >= 0.0,
    }

    payload = {
        "status": "ok",
        "baseline": str(args.baseline),
        "candidate": str(args.candidate),
        "deltas": deltas,
        "gate": gate,
        "improved": all(gate.values()),
    }
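    # The written report might look like (values illustrative):
    #   {"status": "ok", "deltas": {"avg_reward": 0.02, ...}, "improved": true}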
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")
    print("evaluate_compare_runs_done")


if __name__ == "__main__":
    main()