"""Run Gemini self-play baseline and save summary JSON."""

import argparse
import asyncio
import json
import logging
from pathlib import Path

from agent.runner import run_episode
from parlay_env.models import PersonaType

logger = logging.getLogger(__name__)

# Episode i uses PERSONAS[i % 3]; the scenario advances once every full
# pass over the personas, so all persona/scenario pairs are visited in turn.
PERSONAS = [PersonaType.SHARK, PersonaType.DIPLOMAT, PersonaType.VETERAN]
SCENARIOS = ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]


def _mean(values: list[float]) -> float:
    """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
    return sum(values) / len(values) if values else 0.0


async def _run(episodes: int) -> list[dict]:
    """Run *episodes* self-play episodes sequentially and collect metrics.

    Episodes are awaited one at a time, each with a deterministic seed
    (``i + 100``), noise injection disabled, drift forced, and a 20-turn cap.

    An episode whose ``run_episode`` call raises is logged with its traceback
    and skipped, so a single failure does not discard the episodes that
    already finished. The returned list can therefore be shorter than
    *episodes*.

    Args:
        episodes: Number of episodes to attempt. Zero or a negative value
            yields an empty list.

    Returns:
        One dict per completed episode with the keys ``avg_reward``,
        ``deal_rate``, ``avg_efficiency``, ``avg_tom_accuracy`` and
        ``bluffs_caught``. Despite the ``avg_`` prefixes these are
        per-episode values; averaging happens in ``_summarise``.
    """
    rows: list[dict] = []
    for i in range(episodes):
        persona = PERSONAS[i % len(PERSONAS)]
        scenario_id = SCENARIOS[(i // len(PERSONAS)) % len(SCENARIOS)]
        try:
            result = await run_episode(
                persona=persona,
                scenario_id=scenario_id,
                inject_noise=False,
                force_drift=True,
                seed=i + 100,
                max_turns=20,
            )
        except Exception:
            # asyncio.CancelledError and KeyboardInterrupt derive from
            # BaseException, so cancellation and Ctrl-C still propagate.
            logger.exception(
                "Episode %d failed (persona=%s, scenario=%s); skipping",
                i,
                persona,
                scenario_id,
            )
            continue
        rows.append(
            {
                "avg_reward": float(result.grade.total_reward),
                # An episode counts as a deal when a final price was set.
                "deal_rate": 1.0 if result.final_price is not None else 0.0,
                "avg_efficiency": float(result.grade.deal_efficiency),
                "avg_tom_accuracy": float(result.grade.tom_accuracy_avg),
                "bluffs_caught": int(result.grade.bluffs_caught),
            }
        )
    return rows


def _summarise(rows: list[dict], episodes_requested: int) -> dict:
    """Aggregate per-episode rows into the baseline summary.

    Args:
        rows: Per-episode metric dicts as produced by ``_run``.
        episodes_requested: Number of episodes the caller asked for; stored
            alongside ``len(rows)`` so incomplete runs are visible.

    Returns:
        A JSON-serialisable dict. Rate and average fields are means over
        *rows* rounded to four decimals (0.0 when *rows* is empty);
        ``bluffs_caught`` is the total across all rows.
    """
    return {
        "episodes_requested": episodes_requested,
        "episodes_completed": len(rows),
        "avg_reward": round(_mean([r["avg_reward"] for r in rows]), 4),
        "deal_rate": round(_mean([r["deal_rate"] for r in rows]), 4),
        "avg_efficiency": round(_mean([r["avg_efficiency"] for r in rows]), 4),
        "avg_tom_accuracy": round(_mean([r["avg_tom_accuracy"] for r in rows]), 4),
        "bluffs_caught": int(sum(r["bluffs_caught"] for r in rows)),
    }


def main() -> None:
    """Parse CLI arguments, run the baseline, and write the summary JSON.

    The summary is written to ``--output`` (parent directories are created)
    and echoed to stdout.

    Raises:
        SystemExit: If episodes were requested but none completed, so that a
            total failure neither exits successfully nor overwrites an
            existing baseline file with all-zero metrics.
    """
    parser = argparse.ArgumentParser(description="Run Gemini self-play baseline")
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--output", default="results/gemini_baseline.json")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
    rows = asyncio.run(_run(args.episodes))

    if args.episodes > 0 and not rows:
        raise SystemExit("No episodes completed; baseline summary not written.")
    if len(rows) < args.episodes:
        logger.warning(
            "Only %d of %d requested episodes completed",
            len(rows),
            args.episodes,
        )

    summary = _summarise(rows, args.episodes)
    # Serialise once so the file and the console output are identical.
    payload = json.dumps(summary, indent=2)

    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(payload, encoding="utf-8")
    print(payload)
    print(f"\nSaved Gemini baseline to {out.resolve()}")


if __name__ == "__main__":
    main()