Parlay / scripts /run_gemini_baseline.py
sh4shv4t's picture
Add pre-training audit scripts, OpenEnv manifest, and tune Parlay training/env (GRPO 1.5B default, min-reward filters, weighted data gen, hiring ZOPA+drift, veteran/opponent prompts, Docker/docs)
df724f2
"""Run Gemini self-play baseline and save summary JSON."""
import argparse
import asyncio
import json
import logging
from pathlib import Path
from agent.runner import run_episode
from parlay_env.models import PersonaType
PERSONAS = [PersonaType.SHARK, PersonaType.DIPLOMAT, PersonaType.VETERAN]
SCENARIOS = ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
def _mean(values: list[float]) -> float:
return sum(values) / len(values) if values else 0.0
async def _run(episodes: int) -> list[dict]:
rows: list[dict] = []
for i in range(episodes):
persona = PERSONAS[i % len(PERSONAS)]
scenario_id = SCENARIOS[(i // len(PERSONAS)) % len(SCENARIOS)]
result = await run_episode(
persona=persona,
scenario_id=scenario_id,
inject_noise=False,
force_drift=True,
seed=i + 100,
max_turns=20,
)
rows.append(
{
"avg_reward": float(result.grade.total_reward),
"deal_rate": 1.0 if result.final_price is not None else 0.0,
"avg_efficiency": float(result.grade.deal_efficiency),
"avg_tom_accuracy": float(result.grade.tom_accuracy_avg),
"bluffs_caught": int(result.grade.bluffs_caught),
}
)
return rows
def _summarise(rows: list[dict], episodes_requested: int) -> dict:
return {
"episodes_requested": episodes_requested,
"episodes_completed": len(rows),
"avg_reward": round(_mean([r["avg_reward"] for r in rows]), 4),
"deal_rate": round(_mean([r["deal_rate"] for r in rows]), 4),
"avg_efficiency": round(_mean([r["avg_efficiency"] for r in rows]), 4),
"avg_tom_accuracy": round(_mean([r["avg_tom_accuracy"] for r in rows]), 4),
"bluffs_caught": int(sum(r["bluffs_caught"] for r in rows)),
}
def main() -> None:
parser = argparse.ArgumentParser(description="Run Gemini self-play baseline")
parser.add_argument("--episodes", type=int, default=20)
parser.add_argument("--output", default="results/gemini_baseline.json")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
rows = asyncio.run(_run(args.episodes))
summary = _summarise(rows, args.episodes)
out = Path(args.output)
out.parent.mkdir(parents=True, exist_ok=True)
out.write_text(json.dumps(summary, indent=2), encoding="utf-8")
print(json.dumps(summary, indent=2))
print(f"\nSaved Gemini baseline to {out.resolve()}")
if __name__ == "__main__":
main()