File size: 2,636 Bytes
df724f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Run Gemini self-play baseline and save summary JSON."""
import argparse
import asyncio
import json
import logging
from pathlib import Path

from agent.runner import run_episode
from parlay_env.models import PersonaType

# Personas passed to run_episode; `_run` cycles through them in this order,
# one per episode.
PERSONAS = [PersonaType.SHARK, PersonaType.DIPLOMAT, PersonaType.VETERAN]
# Scenario ids passed to run_episode; `_run` advances to the next id after each
# full pass over PERSONAS and wraps around at the end of the list.
SCENARIOS = ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]


def _mean(values: list[float]) -> float:
    return sum(values) / len(values) if values else 0.0


async def _run(episodes: int) -> list[dict]:
    rows: list[dict] = []
    for i in range(episodes):
        persona = PERSONAS[i % len(PERSONAS)]
        scenario_id = SCENARIOS[(i // len(PERSONAS)) % len(SCENARIOS)]
        result = await run_episode(
            persona=persona,
            scenario_id=scenario_id,
            inject_noise=False,
            force_drift=True,
            seed=i + 100,
            max_turns=20,
        )
        rows.append(
            {
                "avg_reward": float(result.grade.total_reward),
                "deal_rate": 1.0 if result.final_price is not None else 0.0,
                "avg_efficiency": float(result.grade.deal_efficiency),
                "avg_tom_accuracy": float(result.grade.tom_accuracy_avg),
                "bluffs_caught": int(result.grade.bluffs_caught),
            }
        )
    return rows


def _summarise(rows: list[dict], episodes_requested: int) -> dict:
    """Aggregate per-episode metric rows into one summary dict.

    Args:
        rows: Per-episode dicts carrying the keys ``avg_reward``,
            ``deal_rate``, ``avg_efficiency``, ``avg_tom_accuracy`` and
            ``bluffs_caught``.
        episodes_requested: Episode count the caller asked for; copied into
            the summary unchanged.

    Returns:
        A dict with the requested and completed episode counts, the mean of
        each averaged metric rounded to 4 decimals (0.0 when ``rows`` is
        empty), and the total number of bluffs caught.
    """
    averaged_keys = ("avg_reward", "deal_rate", "avg_efficiency", "avg_tom_accuracy")

    summary: dict = {
        "episodes_requested": episodes_requested,
        "episodes_completed": len(rows),
    }
    # Insertion order fixes the key order of the emitted JSON.
    for key in averaged_keys:
        summary[key] = round(_mean([row[key] for row in rows]), 4)
    # Bluffs are reported as a total across episodes, not as a mean.
    summary["bluffs_caught"] = int(sum(row["bluffs_caught"] for row in rows))
    return summary


def main() -> None:
    """Command-line entry point: run the baseline and persist its summary.

    Reads ``--episodes`` (default 20) and ``--output`` (default
    ``results/gemini_baseline.json``), configures INFO-level logging, runs the
    episodes, then writes the summary as indented JSON to the output path
    (creating parent directories as needed) and echoes it to stdout.
    """
    cli = argparse.ArgumentParser(description="Run Gemini self-play baseline")
    cli.add_argument("--episodes", type=int, default=20)
    cli.add_argument("--output", default="results/gemini_baseline.json")
    options = cli.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")

    episode_rows = asyncio.run(_run(options.episodes))
    # Serialise once so the file and the console show exactly the same text.
    rendered = json.dumps(_summarise(episode_rows, options.episodes), indent=2)

    destination = Path(options.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(rendered, encoding="utf-8")

    print(rendered)
    print(f"\nSaved Gemini baseline to {destination.resolve()}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()