Parlay / scripts /eval_comparison.py
sh4shv4t's picture
Add pre-training audit scripts, OpenEnv manifest, and tune Parlay training/env (GRPO 1.5B default, min-reward filters, weighted data gen, hiring ZOPA+drift, veteran/opponent prompts, Docker/docs)
df724f2
"""Compare baseline, Gemini, and GRPO JSON summaries."""
import argparse
import json
from pathlib import Path
def _load(path: str) -> dict:
    """Read the UTF-8 encoded JSON file at ``path`` and return the decoded document."""
    with open(path, encoding="utf-8") as handle:
        payload = handle.read()
    return json.loads(payload)
def _fmt_pct(value: float) -> str:
    """Render a fraction as a percentage with one decimal place, e.g. 0.5 -> "50.0%"."""
    percent = 100.0 * value
    return "{:.1f}%".format(percent)
def _row(label: str, data: dict) -> str:
    """Format one model's summary as a Markdown table row.

    Columns, in order: label, avg_reward, deal_rate (rendered as a
    percentage), avg_efficiency, avg_tom_accuracy and bluffs_caught.
    A key missing from ``data`` is rendered as zero.

    Args:
        label: Text for the first column.
        data: Summary mapping holding the metric values.

    Returns:
        A single ``| ... |`` row with no trailing newline.
    """
    # avg_reward goes through float() like the other numeric columns, so a
    # numeric string (e.g. "0.5") is formatted instead of raising ValueError
    # from the ".3f" format spec.
    return (
        f"| {label} | {float(data.get('avg_reward', 0)):.3f} | "
        f"{_fmt_pct(float(data.get('deal_rate', 0)))} | "
        f"{float(data.get('avg_efficiency', 0)):.3f} | "
        f"{float(data.get('avg_tom_accuracy', 0)):.3f} | "
        f"{int(data.get('bluffs_caught', 0))} |"
    )
def _save_chart(baseline: dict, gemini: dict, grpo: dict, output_path: Path) -> None:
    """Draw a grouped bar chart of the three summaries and save it to ``output_path``.

    Each metric gets one group on the x axis, with one bar per model inside
    the group. A metric missing from a summary is plotted as 0.0. The parent
    directory of ``output_path`` is created when it does not exist.
    """
    # Function-scope import: matplotlib is loaded only when a chart is drawn.
    import matplotlib.pyplot as plt

    metric_keys = ["avg_reward", "deal_rate", "avg_efficiency", "avg_tom_accuracy"]
    models = [("Random", baseline), ("Gemini", gemini), ("GRPO", grpo)]
    group_centers = range(len(metric_keys))
    bar_width = 0.22

    plt.figure(figsize=(10, 5))
    for slot, (model_name, summary) in enumerate(models):
        heights = [float(summary.get(key, 0.0)) for key in metric_keys]
        # The middle model (slot 1) sits on the tick; the others are shifted
        # one bar width to either side.
        positions = [center + (slot - 1) * bar_width for center in group_centers]
        plt.bar(positions, heights, width=bar_width, label=model_name)

    plt.xticks(list(group_centers), metric_keys)
    plt.ylabel("Metric value")
    plt.title("Parlay Baseline vs Gemini vs GRPO")
    plt.legend()

    target_dir = output_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
def main() -> None:
    """Print a Markdown comparison table and save a bar chart of three result files.

    Command-line options:
        --baseline-results, --gemini-results, --grpo-results: required paths
            to the JSON summaries being compared.
        --output: where the chart image is written; defaults to
            ``results/comparison.png`` so existing invocations are unchanged.
    """
    parser = argparse.ArgumentParser(description="Compare evaluation result JSON files")
    parser.add_argument(
        "--baseline-results", required=True, help="JSON summary for the random baseline"
    )
    parser.add_argument(
        "--gemini-results", required=True, help="JSON summary for the Gemini baseline"
    )
    parser.add_argument(
        "--grpo-results", required=True, help="JSON summary for the GRPO model"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("results/comparison.png"),
        help="Path of the chart image to write (default: %(default)s)",
    )
    args = parser.parse_args()

    baseline = _load(args.baseline_results)
    gemini = _load(args.gemini_results)
    grpo = _load(args.grpo_results)

    lines = [
        "| Model | avg_reward | deal_rate | avg_efficiency | avg_tom_accuracy | bluffs_caught |",
        "|---|---:|---:|---:|---:|---:|",
        _row("Random baseline", baseline),
        _row("Gemini baseline", gemini),
        _row("GRPO", grpo),
    ]
    table = "\n".join(lines)
    print(table)

    chart_path = args.output
    _save_chart(baseline, gemini, grpo, chart_path)
    print(f"\nSaved chart: {chart_path.resolve()}")
# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()