"""
Self-play episode runner for Parlay.
CLI: python -m agent.runner --steps N --persona shark --scenario saas_enterprise
"""
import argparse
import asyncio
import logging
import uuid
from dataclasses import dataclass
from typing import Optional

import numpy as np

from parlay_env.grader import EpisodeGrade, compute_step_reward, grade_episode
from parlay_env.models import (
    BeliefState,
    HiddenState,
    ParlayAction,
    ParlayState,
    PersonaType,
    TacticalMove,
)

from agent.gemini_client import call_gemini, call_gemini_tom
from agent.personas import PERSONAS, build_system_prompt
from agent.tom_tracker import ToMTracker
from game.scenarios import get_scenario

TOM_DIAGNOSTIC = False  # Set False before full training run

logger = logging.getLogger(__name__)

# Phrases in the agent's utterance that count as acknowledging a drift event.
_ADAPTATION_SIGNALS: tuple[str, ...] = (
    "understand",
    "noted",
    "given that",
    "considering",
)

# Two offers closer than this fraction of the agent's offer close the deal.
_DEAL_TOLERANCE = 0.03

# Credibility-point cost per tactical move; moves not listed cost nothing.
_CP_COSTS: dict[TacticalMove, int] = {
    TacticalMove.ANCHOR_HIGH: 0,
    TacticalMove.BATNA_REVEAL: 20,
    TacticalMove.SILENCE: 5,
}


@dataclass
class EpisodeResult:
    """Result from a single self-play episode.

    Attributes:
        session: The ParlayState after the last turn that was played.
        system_prompt: The system prompt given to the agent for this episode.
        conversation: Chronological message dicts (opening line, then one
            "negotiator" and one "opponent" entry per turn).
        grade: The terminal grade returned by grade_episode().
        final_price: Midpoint of the two offers if a deal closed, else None.
    """

    session: ParlayState
    system_prompt: str
    conversation: list[dict]
    grade: EpisodeGrade
    final_price: Optional[float]


async def run_episode(
    persona: PersonaType = PersonaType.SHARK,
    scenario_id: str = "saas_enterprise",
    inject_noise: bool = False,
    force_drift: bool = False,
    seed: int = 42,
    max_turns: int = 20,
) -> EpisodeResult:
    """
    Run a single self-play negotiation episode.

    Per-turn step rewards are computed via compute_step_reward() and
    accumulated into state.cumulative_reward each turn so the terminal
    grader sees the full dense-reward signal.

    The episode ends early when both sides state offers within 3% of each
    other (relative to the agent's offer); the deal price is their midpoint.

    Args:
        persona: Opponent persona.
        scenario_id: Scenario to play.
        inject_noise: If True, inject random early moves for data diversity.
        force_drift: If True, force a drift event to fire.
        seed: Random seed for reproducibility.
        max_turns: Maximum number of turns.

    Returns:
        EpisodeResult with session state, conversation, and grade.
    """
    rng = np.random.default_rng(seed)
    scenario = get_scenario(scenario_id)
    persona_cfg = PERSONAS[persona]

    # One multiplicative jitter (±5%) is shared by both hidden price limits.
    noise = float(rng.uniform(0.95, 1.05))
    hidden = HiddenState(
        budget_ceiling=round(scenario.batna_buyer * noise, 2),
        walk_away_price=round(scenario.batna_seller * noise, 2),
        urgency_score=float(np.clip(0.5 + rng.uniform(-0.2, 0.2), 0.0, 1.0)),
        has_alternative=bool(rng.choice([True, False])),
        persona_drifted=False,
    )

    # The tracker starts from a deliberately off-target, low-confidence prior.
    initial_belief = BeliefState(
        est_budget=hidden.budget_ceiling * 0.75,
        est_walk_away=hidden.walk_away_price * 1.20,
        est_urgency=0.50,
        est_has_alternative=False,
        confidence=0.25,
    )
    tom = ToMTracker(initial_belief, persona)

    session_id = str(uuid.uuid4())
    state = ParlayState(
        session_id=session_id,
        scenario_id=scenario_id,
        persona=persona,
        step_count=0,
        cumulative_reward=0.0,
        hidden_state=hidden,
        belief_history=[initial_belief],
        offer_history=[],
        drift_events_fired=0,
        episode_done=False,
        credibility_points=100,
    )

    system_prompt = build_system_prompt(
        persona=persona,
        scenario_id=scenario_id,
        scenario_title=scenario.title,
        scenario_description=scenario.description,
        batna=hidden.walk_away_price,
        budget=hidden.budget_ceiling,
        urgency=hidden.urgency_score,
    )

    # The opponent prompt only depends on per-episode values, so build it once.
    opponent_prompt = (
        f"You are the human buyer in this negotiation. "
        f"Respond naturally to the AI seller.\n"
        f"Scenario: {scenario.title}. "
        f"Your budget ceiling: {hidden.budget_ceiling:,.0f}\n"
        f"Occasionally (roughly 1 in 4 turns), make a claim about having a competing offer "
        f"or a lower budget than you actually have. State a BATNA that is 20-30% away from "
        f"your true position. This is realistic negotiation behavior and tests whether the "
        f"agent can detect bluffing."
    )

    conversation: list[dict] = []
    drift_adapted = False
    drift_turn: Optional[int] = None
    final_price: Optional[float] = None
    cumulative_reward: float = 0.0

    opening = persona_cfg.opening_line
    conversation.append({"role": "model", "content": opening, "turn": 0})

    forced_drift_turn = int(rng.integers(3, 8)) if force_drift else None

    for turn in range(max_turns):
        # Fire at most one drift event per turn: the first scheduled for this
        # turn, or the first listed event when this is the forced drift turn.
        # NOTE(review): state.drift_events_fired is never updated here —
        # confirm whether the grader expects it to be.
        for event in scenario.drift_events:
            if event.trigger_turn == turn or (forced_drift_turn == turn):
                drift_turn = turn
                tom.drift_event(
                    event.effect_on_urgency,
                    event.effect_on_has_alternative,
                    event_description=event.event,
                )
                logger.info(f"Drift event at turn {turn}: {event.event!r}")
                break

        # Optionally override the agent's tactical move during the first
        # three turns (30% chance each) for data diversity.
        if inject_noise and turn < 3 and rng.random() < 0.3:
            random_move: Optional[TacticalMove] = TacticalMove(
                rng.choice([m.value for m in TacticalMove])
            )
        else:
            random_move = None

        # Roles alternate by position: the opening line and opponent replies
        # land on even indices ("user"), the agent's own turns on odd ones.
        agent_messages = [
            {"role": "user" if i % 2 == 0 else "model", "parts": [m["content"]]}
            for i, m in enumerate(conversation)
        ]
        current_offer_str = (
            str(state.offer_history[-1]) if state.offer_history else "None"
        )
        agent_messages.append({
            "role": "user",
            "parts": [
                f"Turn {turn + 1}. Make your move. "
                f"Current offer on table: {current_offer_str}"
            ],
        })

        # Always pass persona explicitly so mock mode uses the right responses
        agent_response = await call_gemini(
            system_prompt,
            agent_messages,
            persona=persona.value,
            scenario_id=scenario_id,
        )

        action = ParlayAction(
            utterance=agent_response.get("utterance", "..."),
            offer_amount=agent_response.get("offer_amount"),
            tactical_move=random_move or _parse_tactical_move(
                agent_response.get("tactical_move")
            ),
        )

        conversation.append({
            "role": "negotiator",
            "content": action.utterance,
            "offer": action.offer_amount,
            "move": action.tactical_move.value if action.tactical_move else None,
            "turn": turn + 1,
        })

        opponent_messages = agent_messages + [
            {"role": "user", "parts": [action.utterance]}
        ]
        opponent_response = await call_gemini(
            opponent_prompt,
            opponent_messages,
            persona=persona.value,
            scenario_id=scenario_id,
        )

        conversation.append({
            "role": "opponent",
            "content": opponent_response.get("utterance", "..."),
            "offer": opponent_response.get("offer_amount"),
            "turn": turn + 1,
        })

        tom.update(
            observed_offer=opponent_response.get("offer_amount"),
            observed_move=_parse_tactical_move(opponent_response.get("tactical_move")),
            utterance=opponent_response.get("utterance", ""),
            turn=turn,
        )

        if TOM_DIAGNOSTIC:
            tom.log_belief_snapshot(turn=turn)

        # The agent has a window of the drift turn plus two more turns to
        # acknowledge the drift with one of the known phrases.
        if drift_turn is not None and not drift_adapted and turn <= drift_turn + 2:
            utterance_lower = action.utterance.lower()
            matched = next(
                (s for s in _ADAPTATION_SIGNALS if s in utterance_lower), None
            )
            if matched:
                drift_adapted = True
                logger.info(
                    f"drift_adapted=True at turn={turn} "
                    f"matched_phrase={matched!r} "
                    f"utterance_snippet={action.utterance[:80]!r}"
                )

        new_offers = list(state.offer_history)
        if action.offer_amount:
            new_offers.append(action.offer_amount)

        cp_delta = _get_cp_cost(action.tactical_move)

        # Build next state (without cumulative_reward — computed below)
        next_state_fields = {
            **state.model_dump(),
            "step_count": turn + 1,
            "offer_history": new_offers,
            "belief_history": tom.history,
            "episode_done": turn + 1 >= max_turns,
            "termination_reason": "max_turns" if turn + 1 >= max_turns else None,
            # +5 credibility is regained every turn before the move's cost.
            "credibility_points": max(0, state.credibility_points + 5 - cp_delta),
        }
        next_state_tmp = ParlayState(**next_state_fields)

        # Compute per-step dense reward and accumulate
        step_reward = compute_step_reward(state, action, next_state_tmp)
        cumulative_reward += step_reward

        logger.debug(
            f"Turn {turn + 1}: step_reward={step_reward:.3f}, "
            f"cumulative={cumulative_reward:.3f}"
        )

        # Update state carrying forward the accumulated reward
        state = ParlayState(
            **{**next_state_fields, "cumulative_reward": cumulative_reward}
        )

        # Check for deal close (within 3% of each other). The opponent's offer
        # is raw model output, so a non-numeric value must not crash the run.
        raw_opp_offer = opponent_response.get("offer_amount")
        opp_offer = _coerce_offer(raw_opp_offer) if raw_opp_offer else None
        if action.offer_amount and opp_offer is not None:
            agent_offer = action.offer_amount
            if _offers_converged(agent_offer, opp_offer):
                final_price = (agent_offer + opp_offer) / 2
                logger.info(f"Deal reached at {final_price:,.0f} on turn {turn + 1}")
                # NOTE(review): state.episode_done / termination_reason are
                # left untouched on a deal — confirm this is what the grader
                # expects.
                break

    grade = grade_episode(
        state,
        final_price=final_price,
        t_close=state.step_count if final_price is not None else None,
        t_max=max_turns,
        drift_adapted=drift_adapted,
        bluffs_caught=tom.bluffs_detected,
    )

    logger.info(
        f"Episode done: scenario={scenario_id}, persona={persona.value}, "
        f"reward={grade.total_reward:.2f}, efficiency={grade.deal_efficiency:.3f}, "
        f"cumulative_step_reward={cumulative_reward:.3f}"
    )

    return EpisodeResult(
        session=state,
        system_prompt=system_prompt,
        conversation=conversation,
        grade=grade,
        final_price=final_price,
    )


def _coerce_offer(value: object) -> Optional[float]:
    """Convert a raw offer value to float, returning None if it is not numeric."""
    try:
        return float(value)  # type: ignore[arg-type]
    except (TypeError, ValueError):
        logger.warning("Ignoring non-numeric offer_amount: %r", value)
        return None


def _offers_converged(agent_offer: float, opp_offer: float) -> bool:
    """Return True when the offers differ by less than the deal tolerance.

    The gap is measured relative to the agent's offer, floored at 1 to avoid
    dividing by zero or by a tiny amount.
    """
    return abs(agent_offer - opp_offer) / max(agent_offer, 1) < _DEAL_TOLERANCE


def _parse_tactical_move(value: Optional[str]) -> Optional[TacticalMove]:
    """Parse tactical move from string, returning None if invalid."""
    if not value:
        return None
    try:
        return TacticalMove(value)
    except ValueError:
        return None


def _get_cp_cost(move: Optional[TacticalMove]) -> int:
    """Return the credibility-point cost for a tactical move.

    A missing move, or a move with no entry in the cost table, costs 0.
    """
    if not move:
        return 0
    return _CP_COSTS.get(move, 0)


def main() -> None:
    """Parse CLI arguments and run the requested number of episodes.

    Episode ``i`` (0-based) is seeded with ``i``; a one-line summary is
    printed after each episode.
    """
    parser = argparse.ArgumentParser(description="Parlay self-play runner")
    parser.add_argument("--steps", type=int, default=20, help="Max turns per episode")
    parser.add_argument(
        "--persona", default="shark", choices=[p.value for p in PersonaType]
    )
    parser.add_argument("--scenario", default="saas_enterprise")
    parser.add_argument("--episodes", type=int, default=1)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")

    async def _run() -> None:
        # Episodes run sequentially inside a single event loop.
        for i in range(args.episodes):
            result = await run_episode(
                persona=PersonaType(args.persona),
                scenario_id=args.scenario,
                seed=i,
                max_turns=args.steps,
            )
            print(
                f"\nEpisode {i + 1}: reward={result.grade.total_reward:.2f}, "
                f"efficiency={result.grade.deal_efficiency:.3f}, "
                f"deal={'YES' if result.final_price is not None else 'NO'}"
            )

    asyncio.run(_run())


if __name__ == "__main__":
    main()