| """Compare baseline, Gemini, and GRPO JSON summaries.""" |
| import argparse |
| import json |
| from pathlib import Path |
|
|
|
|
def _load(path: str) -> dict:
    """Read the UTF-8 encoded JSON file at *path* and return the decoded value.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        The value decoded by the ``json`` module (annotated as ``dict``).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
|
|
|
|
def _fmt_pct(value: float) -> str:
    """Format a fraction as a percentage with one decimal place.

    Args:
        value: Fraction to format; it is multiplied by 100 before formatting.

    Returns:
        The percentage followed by ``%``, e.g. ``0.5`` -> ``"50.0%"``.
    """
    percent = 100.0 * value
    return "{:.1f}%".format(percent)
|
|
|
|
def _row(label: str, data: dict) -> str:
    """Render one Markdown table row for a model's summary metrics.

    The columns are, in order: the label, ``avg_reward`` (3 decimals),
    ``deal_rate`` (percentage with one decimal), ``avg_efficiency``
    (3 decimals), ``avg_tom_accuracy`` (3 decimals) and ``bluffs_caught``
    (integer).

    Args:
        label: Text shown in the first column.
        data: Mapping of metric name to value. A missing key or a ``None``
            value (JSON ``null``) is rendered as 0.

    Returns:
        The row as a single string, without a trailing newline.

    Raises:
        TypeError, ValueError: If a metric value cannot be converted to a
            number.
    """

    def _value(key: str):
        # Treat an explicit null the same as an absent key.
        raw = data.get(key)
        return 0 if raw is None else raw

    # Every float column goes through float() so numeric strings are handled
    # uniformly (previously avg_reward was the only column not coerced).
    return (
        f"| {label} | {float(_value('avg_reward')):.3f} | "
        f"{_fmt_pct(float(_value('deal_rate')))} | "
        f"{float(_value('avg_efficiency')):.3f} | "
        f"{float(_value('avg_tom_accuracy')):.3f} | "
        f"{int(_value('bluffs_caught'))} |"
    )
|
|
|
|
def _save_chart(baseline: dict, gemini: dict, grpo: dict, output_path: Path) -> None:
    """Draw a grouped bar chart of the three summaries and save it to disk.

    Four metrics are plotted (``avg_reward``, ``deal_rate``,
    ``avg_efficiency``, ``avg_tom_accuracy``), each with one bar per summary
    labelled "Random", "Gemini" and "GRPO". A metric missing from a summary
    is plotted as 0.

    Args:
        baseline: Summary plotted under the "Random" legend entry.
        gemini: Summary plotted under the "Gemini" legend entry.
        grpo: Summary plotted under the "GRPO" legend entry.
        output_path: Destination image file; its parent directory is created
            if it does not exist.
    """
    # Imported inside the function, so matplotlib is loaded only when this runs.
    import matplotlib.pyplot as plt

    metric_keys = ["avg_reward", "deal_rate", "avg_efficiency", "avg_tom_accuracy"]
    groups = [("Random", baseline), ("Gemini", gemini), ("GRPO", grpo)]
    bar_width = 0.22
    slots = range(len(metric_keys))

    plt.figure(figsize=(10, 5))
    for position, (legend_name, summary) in enumerate(groups):
        heights = [float(summary.get(key, 0.0)) for key in metric_keys]
        # Offset each series by one bar width; the middle series sits on the tick.
        shift = (position - 1) * bar_width
        lefts = [slot + shift for slot in slots]
        plt.bar(lefts, heights, width=bar_width, label=legend_name)

    plt.xticks(list(slots), metric_keys)
    plt.ylabel("Metric value")
    plt.title("Parlay Baseline vs Gemini vs GRPO")
    plt.legend()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
|
|
|
|
def main() -> None:
    """Command-line entry point: print a comparison table and save a chart.

    Loads the three JSON summaries given on the command line, prints a
    Markdown table with one row per model, then writes a bar chart image.
    The chart destination defaults to ``results/comparison.png`` and can be
    overridden with ``--chart-output``.
    """
    parser = argparse.ArgumentParser(description="Compare evaluation result JSON files")
    parser.add_argument(
        "--baseline-results",
        required=True,
        help="Path to the random-baseline JSON summary",
    )
    parser.add_argument(
        "--gemini-results",
        required=True,
        help="Path to the Gemini JSON summary",
    )
    parser.add_argument(
        "--grpo-results",
        required=True,
        help="Path to the GRPO JSON summary",
    )
    # Previously hard-coded; the default keeps the original output location.
    parser.add_argument(
        "--chart-output",
        default="results/comparison.png",
        help="Where to save the comparison chart (default: %(default)s)",
    )
    args = parser.parse_args()

    baseline = _load(args.baseline_results)
    gemini = _load(args.gemini_results)
    grpo = _load(args.grpo_results)

    lines = [
        "| Model | avg_reward | deal_rate | avg_efficiency | avg_tom_accuracy | bluffs_caught |",
        "|---|---:|---:|---:|---:|---:|",
        _row("Random baseline", baseline),
        _row("Gemini baseline", gemini),
        _row("GRPO", grpo),
    ]
    table = "\n".join(lines)
    print(table)

    chart_path = Path(args.chart_output)
    _save_chart(baseline, gemini, grpo, chart_path)
    print(f"\nSaved chart: {chart_path.resolve()}")
|
|
|
|
# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|