"""
Self-play episode runner for Parlay.
CLI: python -m agent.runner --steps N --persona shark --scenario saas_enterprise [--episodes K]
"""
import argparse
import asyncio
import logging
import uuid
from dataclasses import dataclass
from typing import Optional

import numpy as np

from parlay_env.grader import EpisodeGrade, compute_step_reward, grade_episode
from parlay_env.models import (
    BeliefState,
    HiddenState,
    ParlayAction,
    ParlayState,
    PersonaType,
    TacticalMove,
)

from agent.gemini_client import call_gemini, call_gemini_tom
from agent.personas import PERSONAS, build_system_prompt
from agent.tom_tracker import ToMTracker
from game.scenarios import get_scenario

TOM_DIAGNOSTIC = False  # Set True to log per-turn belief snapshots; keep False for full training runs

logger = logging.getLogger(__name__)


@dataclass
class EpisodeResult:
    """Result from a single self-play episode."""

    session: ParlayState
    system_prompt: str
    conversation: list[dict]
    grade: EpisodeGrade
    final_price: Optional[float]


async def run_episode(
    persona: PersonaType = PersonaType.SHARK,
    scenario_id: str = "saas_enterprise",
    inject_noise: bool = False,
    force_drift: bool = False,
    seed: int = 42,
    max_turns: int = 20,
) -> EpisodeResult:
    """
    Run a single self-play negotiation episode.

    Per-turn step rewards are computed via compute_step_reward() and
    accumulated into state.cumulative_reward each turn so the terminal
    grader sees the full dense-reward signal.

    Args:
        persona:      Opponent persona.
        scenario_id:  Scenario to play.
        inject_noise: If True, inject random early moves for data diversity.
        force_drift:  If True, force a drift event to fire.
        seed:         Random seed for reproducibility.
        max_turns:    Maximum number of turns.

    Returns:
        EpisodeResult with session state, conversation, and grade.
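
    Example (illustrative; assumes mock mode or working Gemini credentials):
        >>> import asyncio
        >>> result = asyncio.run(run_episode(max_turns=5, seed=0))  # doctest: +SKIP
        >>> print(result.grade.total_reward)  # doctest: +SKIP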
    """
    rng = np.random.default_rng(seed)
    scenario = get_scenario(scenario_id)
    persona_cfg = PERSONAS[persona]

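    # Perturb the scenario BATNAs by up to ±5% so different seeds produce
    # distinct hidden states even for the same scenario.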
    noise = float(rng.uniform(0.95, 1.05))
    hidden = HiddenState(
        budget_ceiling=round(scenario.batna_buyer * noise, 2),
        walk_away_price=round(scenario.batna_seller * noise, 2),
        urgency_score=float(np.clip(0.5 + rng.uniform(-0.2, 0.2), 0.0, 1.0)),
        has_alternative=bool(rng.choice([True, False])),
        persona_drifted=False,
    )

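    # Start beliefs deliberately mis-calibrated (budget low by 25%, walk-away
    # high by 20%, low confidence) so the ToM tracker has room to converge.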
    initial_belief = BeliefState(
        est_budget=hidden.budget_ceiling * 0.75,
        est_walk_away=hidden.walk_away_price * 1.20,
        est_urgency=0.50,
        est_has_alternative=False,
        confidence=0.25,
    )

    tom = ToMTracker(initial_belief, persona)
    session_id = str(uuid.uuid4())

    state = ParlayState(
        session_id=session_id,
        scenario_id=scenario_id,
        persona=persona,
        step_count=0,
        cumulative_reward=0.0,
        hidden_state=hidden,
        belief_history=[initial_belief],
        offer_history=[],
        drift_events_fired=0,
        episode_done=False,
        credibility_points=100,
    )

    system_prompt = build_system_prompt(
        persona=persona,
        scenario_id=scenario_id,
        scenario_title=scenario.title,
        scenario_description=scenario.description,
        batna=hidden.walk_away_price,
        budget=hidden.budget_ceiling,
        urgency=hidden.urgency_score,
    )

    conversation: list[dict] = []
    drift_adapted = False
    drift_turn: Optional[int] = None
    drift_fires = 0  # ParlayState is rebuilt each turn, so count fires locally
    final_price: Optional[float] = None
    cumulative_reward: float = 0.0

    opening = persona_cfg.opening_line
    conversation.append({"role": "model", "content": opening, "turn": 0})

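    # When forcing drift, fire it on a random turn in [3, 8) so the agent has
    # both pre-drift context and post-drift turns in which to adapt.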
    forced_drift_turn = int(rng.integers(3, 8)) if force_drift else None

    for turn in range(max_turns):
        for event in scenario.drift_events:
            if event.trigger_turn == turn or (forced_drift_turn == turn):
                drift_turn = turn
                drift_fires += 1
                tom.drift_event(
                    event.effect_on_urgency,
                    event.effect_on_has_alternative,
                    event_description=event.event,
                )
                logger.info(f"Drift event at turn {turn}: {event.event!r}")
                break

        # Occasionally override the model's tactical move early in the episode
        # to diversify the training data.
        random_move: Optional[TacticalMove] = None
        if inject_noise and turn < 3 and rng.random() < 0.3:
            random_move = TacticalMove(
                str(rng.choice([m.value for m in TacticalMove]))
            )

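        # Rebuild the Gemini-style history from the agent's perspective: even
        # indices (opponent turns, starting with the opening) map to "user",
        # odd indices (the agent's own turns) map to "model".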
        agent_messages = [
            {"role": "user" if i % 2 == 0 else "model", "parts": [m["content"]]}
            for i, m in enumerate(conversation)
        ]
        current_offer_str = str(state.offer_history[-1]) if state.offer_history else "None"
        agent_messages.append({
            "role": "user",
            "parts": [
                f"Turn {turn + 1}. Make your move. "
                f"Current offer on table: {current_offer_str}"
            ],
        })

        # Always pass persona explicitly so mock mode uses the right responses
        agent_response = await call_gemini(
            system_prompt,
            agent_messages,
            persona=persona.value,
            scenario_id=scenario_id,
        )
        action = ParlayAction(
            utterance=agent_response.get("utterance", "..."),
            offer_amount=agent_response.get("offer_amount"),
            tactical_move=random_move or _parse_tactical_move(
                agent_response.get("tactical_move")
            ),
        )

        conversation.append({
            "role": "negotiator",
            "content": action.utterance,
            "offer": action.offer_amount,
            "move": action.tactical_move.value if action.tactical_move else None,
            "turn": turn + 1,
        })

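        # The opponent model sees the same agent-perspective history plus the
        # new utterance; its buyer role comes from the system prompt below.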
        opponent_messages = agent_messages + [
            {"role": "user", "parts": [action.utterance]}
        ]
        opponent_response = await call_gemini(
            (
                f"You are the human buyer in this negotiation. "
                f"Respond naturally to the AI seller.\n"
                f"Scenario: {scenario.title}. "
                f"Your budget ceiling: {hidden.budget_ceiling:,.0f}\n"
                f"Occasionally (roughly 1 in 4 turns), make a claim about having a competing offer "
                f"or a lower budget than you actually have. State a BATNA that is 20-30% away from "
                f"your true position. This is realistic negotiation behavior and tests whether the "
                f"agent can detect bluffing."
            ),
            opponent_messages,
            persona=persona.value,
            scenario_id=scenario_id,
        )

        conversation.append({
            "role": "opponent",
            "content": opponent_response.get("utterance", "..."),
            "offer": opponent_response.get("offer_amount"),
            "turn": turn + 1,
        })

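        # Fold the opponent's observable behaviour (offer, tactical move,
        # utterance) into the belief state via the ToM tracker.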
        tom.update(
            observed_offer=opponent_response.get("offer_amount"),
            observed_move=_parse_tactical_move(opponent_response.get("tactical_move")),
            utterance=opponent_response.get("utterance", ""),
            turn=turn,
        )
        if TOM_DIAGNOSTIC:
            tom.log_belief_snapshot(turn=turn)

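        # Crude lexical heuristic: within two turns of a drift event, count the
        # episode as "adapted" if the agent's reply acknowledges the change.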
        if drift_turn is not None and not drift_adapted and turn <= drift_turn + 2:
            adaptation_signals = ["understand", "noted", "given that", "considering"]
            matched = next(
                (s for s in adaptation_signals if s in action.utterance.lower()), None
            )
            if matched:
                drift_adapted = True
                logger.info(
                    f"drift_adapted=True at turn={turn} "
                    f"matched_phrase={matched!r} "
                    f"utterance_snippet={action.utterance[:80]!r}"
                )

        new_offers = list(state.offer_history)
        if action.offer_amount is not None:
            new_offers.append(action.offer_amount)

        cp_delta = _get_cp_cost(action.tactical_move) if action.tactical_move else 0

        # Build the next state. cumulative_reward still carries the previous
        # turn's value here; it is updated below once the step reward is known.
        next_state_fields = {
            **state.model_dump(),
            "step_count": turn + 1,
            "offer_history": new_offers,
            "belief_history": tom.history,
            "drift_events_fired": drift_fires,
            "episode_done": turn + 1 >= max_turns,
            "termination_reason": "max_turns" if turn + 1 >= max_turns else None,
            # +5 passive credibility regen per turn, floored at 0.
            "credibility_points": max(0, state.credibility_points + 5 - cp_delta),
        }
        next_state_tmp = ParlayState(**next_state_fields)

        # Compute per-step dense reward and accumulate
        step_reward = compute_step_reward(state, action, next_state_tmp)
        cumulative_reward += step_reward
        logger.debug(
            f"Turn {turn + 1}: step_reward={step_reward:.3f}, "
            f"cumulative={cumulative_reward:.3f}"
        )

        # Update state carrying forward the accumulated reward
        state = ParlayState(
            **{**next_state_fields, "cumulative_reward": cumulative_reward}
        )

        # Close the deal when the offers differ by <3%, relative to the agent's offer
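        # e.g. offers of 100,000 vs 102,000 differ by 2% of the agent's offer,
        # so the deal closes at the midpoint, 101,000.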
        if action.offer_amount and opponent_response.get("offer_amount"):
            agent_offer = action.offer_amount
            opp_offer = float(opponent_response["offer_amount"])
            if abs(agent_offer - opp_offer) / max(agent_offer, 1) < 0.03:
                final_price = (agent_offer + opp_offer) / 2
                logger.info(f"Deal reached at {final_price:,.0f} on turn {turn + 1}")
                break

    grade = grade_episode(
        state,
        final_price=final_price,
        t_close=state.step_count if final_price else None,
        t_max=max_turns,
        drift_adapted=drift_adapted,
        bluffs_caught=tom.bluffs_detected,
    )

    logger.info(
        f"Episode done: scenario={scenario_id}, persona={persona.value}, "
        f"reward={grade.total_reward:.2f}, efficiency={grade.deal_efficiency:.3f}, "
        f"cumulative_step_reward={cumulative_reward:.3f}"
    )

    return EpisodeResult(
        session=state,
        system_prompt=system_prompt,
        conversation=conversation,
        grade=grade,
        final_price=final_price,
    )


def _parse_tactical_move(value: Optional[str]) -> Optional[TacticalMove]:
    """Parse tactical move from string, returning None if invalid."""
    if not value:
        return None
    try:
        return TacticalMove(value)
    except ValueError:
        return None


def _get_cp_cost(move: Optional[TacticalMove]) -> int:
    """Return the credibility-point cost for a tactical move."""
    costs: dict[TacticalMove, int] = {
        TacticalMove.ANCHOR_HIGH: 0,
        TacticalMove.BATNA_REVEAL: 20,
        TacticalMove.SILENCE: 5,
    }
    return costs.get(move, 0) if move else 0


def main() -> None:
    parser = argparse.ArgumentParser(description="Parlay self-play runner")
    parser.add_argument("--steps", type=int, default=20, help="Max turns per episode")
    parser.add_argument(
        "--persona", default="shark", choices=[p.value for p in PersonaType]
    )
    parser.add_argument("--scenario", default="saas_enterprise")
    parser.add_argument("--episodes", type=int, default=1)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")

    async def _run() -> None:
        for i in range(args.episodes):
            result = await run_episode(
                persona=PersonaType(args.persona),
                scenario_id=args.scenario,
                seed=i,
                max_turns=args.steps,
            )
            print(
                f"\nEpisode {i + 1}: reward={result.grade.total_reward:.2f}, "
                f"efficiency={result.grade.deal_efficiency:.3f}, "
                f"deal={'YES' if result.final_price else 'NO'}"
            )

    asyncio.run(_run())


if __name__ == "__main__":
    main()