| """Random-action baseline for Parlay.""" |
| import argparse |
| import asyncio |
| import json |
| import logging |
| import random |
| from pathlib import Path |
|
|
| from parlay_env.grader import grade_episode |
| from parlay_env.models import TacticalMove |
| from parlay_env.server import _handle_reset, _handle_step, get_session_state |
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Persona identifiers sent as "persona" in the reset payload; the baseline
# cycles through them one episode at a time.
PERSONAS = [
    "shark",
    "diplomat",
    "veteran",
]

# Scenario identifiers sent as "scenario_id" in the reset payload; the
# baseline advances to the next one after each full pass over PERSONAS.
SCENARIOS = [
    "saas_enterprise",
    "hiring_package",
    "acquisition_term_sheet",
]

# Canned utterances; one is picked at random for every action.
RANDOM_LINES = [
    "Let's keep talking.",
    "I can move a bit.",
    "This is my proposal.",
    "We should find middle ground.",
    "Given that context, here's my number.",
]
|
|
|
|
| def _mean(values: list[float]) -> float: |
| return sum(values) / len(values) if values else 0.0 |
|
|
|
|
async def _run_single_episode(scenario_id: str, persona: str, seed: int) -> dict:
    """Play one episode with random actions and return its graded metrics.

    The global ``random`` module is seeded with ``seed``, and the same seed is
    forwarded in the reset payload. On every turn the agent draws an offer
    uniformly between the session's ``walk_away_price`` and ``budget_ceiling``
    (rounded to 2 decimals), picks either no tactical move or one whose
    credibility threshold is met, and attaches a random canned utterance.

    Args:
        scenario_id: Scenario identifier sent in the reset payload.
        persona: Persona identifier sent in the reset payload.
        seed: Seed for the global RNG and for the reset payload.

    Returns:
        A dict with ``avg_reward``, ``deal_rate``, ``avg_efficiency``,
        ``avg_tom_accuracy`` and ``bluffs_caught`` for this single episode.

    Raises:
        RuntimeError: If the session state is missing after the loop ends.
    """
    random.seed(seed)
    reset_reply = await _handle_reset(
        {"scenario_id": scenario_id, "persona": persona, "seed": seed}
    )
    session_id = str(reset_reply["session_id"])
    final_price = None
    t_close = None

    while True:
        snapshot = get_session_state(session_id)
        if snapshot is None or snapshot.episode_done:
            break

        # Draw order (offer, move, utterance) is kept fixed so that a given
        # seed always yields the same sequence of actions.
        hidden = snapshot.hidden_state
        offer = round(random.uniform(hidden.walk_away_price, hidden.budget_ceiling), 2)

        # Each move becomes eligible once credibility reaches its threshold;
        # ``None`` (no tactical move) is always a candidate.
        points = snapshot.credibility_points
        gated_moves = (
            (0, TacticalMove.ANCHOR_HIGH),
            (5, TacticalMove.SILENCE),
            (20, TacticalMove.BATNA_REVEAL),
        )
        candidates: list[TacticalMove | None] = [None]
        candidates.extend(
            gated for threshold, gated in gated_moves if points >= threshold
        )
        chosen = random.choice(candidates)

        utterance = random.choice(RANDOM_LINES)
        action = {
            "utterance": utterance,
            "offer_amount": offer,
            "tactical_move": chosen.value if chosen else None,
        }
        reply = await _handle_step({"session_id": session_id, "action": action})
        finished = bool(reply.get("done", False))

        # Record the closing offer and step only the first time a deal shows up.
        after_step = get_session_state(session_id)
        if after_step and after_step.deal_reached and final_price is None:
            final_price = offer
            t_close = after_step.step_count

        if finished:
            break

    final_state = get_session_state(session_id)
    if final_state is None:
        raise RuntimeError(f"Missing session state for {session_id}")

    grade = grade_episode(final_state, final_price=final_price, t_close=t_close, t_max=20)
    metrics = {
        "avg_reward": float(grade.total_reward),
        "deal_rate": 1.0 if final_price is not None else 0.0,
        "avg_efficiency": float(grade.deal_efficiency),
        "avg_tom_accuracy": float(grade.tom_accuracy_avg),
        "bluffs_caught": int(grade.bluffs_caught),
    }
    return metrics
|
|
|
|
| async def _run_baseline(episodes: int) -> list[dict]: |
| rows: list[dict] = [] |
| for i in range(episodes): |
| persona = PERSONAS[i % len(PERSONAS)] |
| scenario = SCENARIOS[(i // len(PERSONAS)) % len(SCENARIOS)] |
| try: |
| rows.append(await _run_single_episode(scenario, persona, i + 7)) |
| except Exception as exc: |
| logger.warning("Baseline episode %d failed (%s/%s): %s", i + 1, scenario, persona, exc) |
| return rows |
|
|
|
|
| def _summarise(rows: list[dict], episodes_requested: int) -> dict: |
| if not rows: |
| return { |
| "episodes_requested": episodes_requested, |
| "episodes_completed": 0, |
| "avg_reward": 0.0, |
| "deal_rate": 0.0, |
| "avg_efficiency": 0.0, |
| "avg_tom_accuracy": 0.0, |
| "bluffs_caught": 0, |
| } |
| return { |
| "episodes_requested": episodes_requested, |
| "episodes_completed": len(rows), |
| "avg_reward": round(_mean([r["avg_reward"] for r in rows]), 4), |
| "deal_rate": round(_mean([r["deal_rate"] for r in rows]), 4), |
| "avg_efficiency": round(_mean([r["avg_efficiency"] for r in rows]), 4), |
| "avg_tom_accuracy": round(_mean([r["avg_tom_accuracy"] for r in rows]), 4), |
| "bluffs_caught": int(sum(r["bluffs_caught"] for r in rows)), |
| } |
|
|
|
|
def main() -> None:
    """Command-line entry point: run the baseline and save its summary.

    Parses ``--episodes`` (default 20) and ``--output`` (default
    ``results/baseline.json``), configures logging at INFO level, runs the
    episodes, then writes the JSON summary to the output path (creating parent
    directories as needed) and prints it to stdout.
    """
    cli = argparse.ArgumentParser(description="Parlay random baseline")
    cli.add_argument("--episodes", type=int, default=20)
    cli.add_argument("--output", default="results/baseline.json")
    options = cli.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
    episode_rows = asyncio.run(_run_baseline(options.episodes))
    summary = _summarise(episode_rows, options.episodes)

    # Render once so the file and the console show exactly the same text.
    rendered = json.dumps(summary, indent=2)

    destination = Path(options.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(rendered, encoding="utf-8")
    print(rendered)
    print(f"\nSaved random baseline to {destination.resolve()}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|