File size: 2,400 Bytes
df724f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Compare baseline, Gemini, and GRPO JSON summaries."""
import argparse
import json
from pathlib import Path


def _load(path: str) -> dict:
    """Read the UTF-8 encoded JSON file at ``path`` and return its decoded content.

    Args:
        path: Location of the JSON file to read.

    Returns:
        Whatever the JSON document decodes to; callers in this file treat it
        as a mapping and read metrics from it with ``.get``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def _fmt_pct(value: float) -> str:
    """Render a fraction as a percentage string with one decimal place.

    Args:
        value: The fraction to convert, e.g. ``0.5``.

    Returns:
        ``value`` multiplied by 100, formatted with one decimal and a
        trailing ``%`` (``0.5`` becomes ``"50.0%"``).
    """
    percent = 100.0 * value
    return "{:.1f}%".format(percent)


def _row(label: str, data: dict) -> str:
    """Format one model's summary as a Markdown table row.

    Args:
        label: Text for the first ("Model") column.
        data: Summary mapping. The keys ``avg_reward``, ``deal_rate``,
            ``avg_efficiency``, ``avg_tom_accuracy`` and ``bluffs_caught``
            are read; each falls back to 0 when the key is absent.

    Returns:
        A pipe-delimited row: ``avg_reward``, ``avg_efficiency`` and
        ``avg_tom_accuracy`` with three decimals, ``deal_rate`` as a
        percentage with one decimal, and ``bluffs_caught`` as an integer.

    Raises:
        TypeError, ValueError: If a present value cannot be converted by
            ``float()`` / ``int()``.
    """
    # Every metric goes through an explicit numeric coercion so that all
    # columns accept the same inputs (e.g. a numeric string); previously
    # avg_reward alone was formatted without float() and would reject them.
    avg_reward = float(data.get("avg_reward", 0))
    deal_rate = float(data.get("deal_rate", 0))
    avg_efficiency = float(data.get("avg_efficiency", 0))
    avg_tom_accuracy = float(data.get("avg_tom_accuracy", 0))
    bluffs_caught = int(data.get("bluffs_caught", 0))
    return (
        f"| {label} | {avg_reward:.3f} | "
        f"{_fmt_pct(deal_rate)} | "
        f"{avg_efficiency:.3f} | "
        f"{avg_tom_accuracy:.3f} | "
        f"{bluffs_caught} |"
    )


def _save_chart(baseline: dict, gemini: dict, grpo: dict, output_path: Path) -> None:
    """Draw a grouped bar chart of four summary metrics and save it as an image.

    For each of ``avg_reward``, ``deal_rate``, ``avg_efficiency`` and
    ``avg_tom_accuracy`` three bars are drawn side by side, labelled
    "Random", "Gemini" and "GRPO". A metric missing from a summary is
    plotted as 0.0.

    Args:
        baseline: Summary mapping plotted under the "Random" label.
        gemini: Summary mapping plotted under the "Gemini" label.
        grpo: Summary mapping plotted under the "GRPO" label.
        output_path: Destination file for the chart; its parent directory
            is created if it does not exist.
    """
    # Imported here rather than at module level, so matplotlib is only
    # loaded when a chart is actually drawn.
    import matplotlib.pyplot as plt

    labels = ["avg_reward", "deal_rate", "avg_efficiency", "avg_tom_accuracy"]
    runs = [("Random", baseline), ("Gemini", gemini), ("GRPO", grpo)]

    positions = list(range(len(labels)))
    width = 0.22

    fig = plt.figure(figsize=(10, 5))
    try:
        for idx, (name, summary) in enumerate(runs):
            vals = [float(summary.get(key, 0.0)) for key in labels]
            # Shift each series by one bar width around the tick so the
            # three bars sit left of, on, and right of the tick position.
            offset = (idx - 1) * width
            plt.bar([pos + offset for pos in positions], vals, width=width, label=name)

        plt.xticks(positions, labels)
        plt.ylabel("Metric value")
        plt.title("Parlay Baseline vs Gemini vs GRPO")
        plt.legend()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        plt.tight_layout()
        plt.savefig(output_path, dpi=150)
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # it does not stay registered in pyplot's global figure list.
        plt.close(fig)


def main() -> None:
    """Print a Markdown comparison table for three result files and save a chart.

    Command-line options:
        --baseline-results, --gemini-results, --grpo-results: required paths
            to the JSON summary files to compare.
        --chart-path: where to write the bar chart image; defaults to
            ``results/comparison.png``.

    The table is written to standard output, followed by the resolved path
    of the saved chart.
    """
    parser = argparse.ArgumentParser(description="Compare evaluation result JSON files")
    parser.add_argument(
        "--baseline-results", required=True, help="JSON summary of the random baseline run"
    )
    parser.add_argument(
        "--gemini-results", required=True, help="JSON summary of the Gemini baseline run"
    )
    parser.add_argument(
        "--grpo-results", required=True, help="JSON summary of the GRPO run"
    )
    # Optional so existing invocations keep writing to the previous fixed location.
    parser.add_argument(
        "--chart-path",
        default="results/comparison.png",
        help="Output image path for the comparison chart (default: %(default)s)",
    )
    args = parser.parse_args()

    baseline = _load(args.baseline_results)
    gemini = _load(args.gemini_results)
    grpo = _load(args.grpo_results)

    lines = [
        "| Model | avg_reward | deal_rate | avg_efficiency | avg_tom_accuracy | bluffs_caught |",
        "|---|---:|---:|---:|---:|---:|",
        _row("Random baseline", baseline),
        _row("Gemini baseline", gemini),
        _row("GRPO", grpo),
    ]
    table = "\n".join(lines)
    print(table)

    chart_path = Path(args.chart_path)
    _save_chart(baseline, gemini, grpo, chart_path)
    print(f"\nSaved chart: {chart_path.resolve()}")


# Run the comparison only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()