File size: 7,965 Bytes
698f4d8
 
 
 
 
80b3b2e
698f4d8
 
 
 
 
 
df724f2
 
 
 
 
698f4d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80b3b2e
 
 
 
 
 
 
 
 
698f4d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15976d0
698f4d8
 
 
 
 
 
 
15976d0
 
698f4d8
 
 
 
 
 
 
 
 
 
 
 
 
 
15976d0
698f4d8
15976d0
 
698f4d8
 
 
 
15976d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
698f4d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""
Theory of Mind belief tracker for Parlay.
Tracks and updates agent beliefs about opponent hidden state.
"""
import logging
import sys
from typing import Optional

from parlay_env.models import BeliefState, HiddenState, PersonaType, TacticalMove

logger = logging.getLogger(__name__)

# NOTE: ToMTracker is used in two paths:
# (1) agent/runner.py self-play — full update each turn;
# (2) parlay_env/server.py WebSocket server — also uses ToMTracker after Task 1 fix.
# Both paths now produce comparable belief_history for grader._tom_accuracy.


class ToMTracker:
    """
    Tracks Theory of Mind beliefs about the opponent's hidden state.

    Maintains a belief history (one ``BeliefState`` appended per update or
    drift event, starting from the initial prior) and updates beliefs based on:
    - Observed offers
    - Tactical move signals
    - Utterance content
    - Drift event triggers
    """

    # Lower-case substrings whose presence in an utterance is read as the
    # opponent signalling an outside option.
    _ALTERNATIVE_SIGNALS = ("competitor", "alternative", "other offer", "another bid")

    def __init__(
        self,
        initial_belief: BeliefState,
        persona: PersonaType,
    ) -> None:
        """
        Args:
            initial_belief: Starting belief state (imprecise prior).
            persona:        Opponent's persona type (known to player).
        """
        self.history: list[BeliefState] = [initial_belief]
        self.persona = persona
        self._bluffs_detected: int = 0

    @property
    def current_belief(self) -> BeliefState:
        """Most recent belief state (last entry of the history)."""
        return self.history[-1]

    @property
    def bluffs_detected(self) -> int:
        """Count of detected bluffs this session."""
        return self._bluffs_detected

    def log_belief_snapshot(self, turn: int) -> None:
        """Print current belief estimates to stderr (diagnostic; training / live debug).

        Args:
            turn: Turn number to label the snapshot with.
        """
        b = self.current_belief
        print(
            f"[ToM turn={turn}] budget={b.est_budget:.3f}  "
            f"urgency={b.est_urgency:.3f}  walkaway={b.est_walk_away:.3f}",
            file=sys.stderr,
        )

    def update(
        self,
        observed_offer: Optional[float],
        observed_move: Optional[TacticalMove],
        utterance: str,
        turn: int,
    ) -> BeliefState:
        """
        Update beliefs based on latest opponent action.

        Rules applied to the previous belief:
        - budget:     raised to at least 105% of the observed offer.
        - walk-away:  scaled by 0.95 when the move is BATNA_REVEAL.
        - urgency:    +0.05 (capped at 1.0) when the offer is below 85% of the
                      previous budget estimate; -0.03 (floored at 0.0) when it
                      is above 95%.
        - alternative: set to True when the utterance contains an
                      alternative signal; never reset to False here.
        - confidence: +0.04 per update, capped at 1.0.

        A bluff is counted when the persona is SHARK, the move is
        BATNA_REVEAL, and the utterance mentions "competitor".

        Args:
            observed_offer: Opponent's counter-offer (None if no offer made).
            observed_move:  Tactical move used (if any).
            utterance:      Opponent's latest utterance.
            turn:           Current turn number (used for logging only).

        Returns:
            Updated BeliefState (also appended to the history).
        """
        last = self.current_belief
        # Lower-case once; reused by the alternative scan and the bluff check.
        # A missing utterance is treated as empty text rather than crashing.
        text = (utterance or "").lower()
        is_batna_reveal = observed_move == TacticalMove.BATNA_REVEAL

        est_budget = last.est_budget
        if observed_offer is not None:
            est_budget = max(est_budget, observed_offer * 1.05)

        est_walk_away = last.est_walk_away
        if is_batna_reveal:
            est_walk_away = last.est_walk_away * 0.95
            logger.debug("ToM: BATNA_REVEAL detected — hedging walk-away estimate")

        est_urgency = last.est_urgency
        # The ratio is taken against the PREVIOUS budget estimate, not the
        # one just revised above.
        if observed_offer is not None and last.est_budget > 0:
            offer_ratio = observed_offer / last.est_budget
            if offer_ratio < 0.85:
                est_urgency = min(1.0, est_urgency + 0.05)
            elif offer_ratio > 0.95:
                est_urgency = max(0.0, est_urgency - 0.03)

        est_has_alternative = last.est_has_alternative
        if any(sig in text for sig in self._ALTERNATIVE_SIGNALS):
            est_has_alternative = True
            logger.debug("ToM: alternative signal detected in utterance")

        confidence = min(1.0, last.confidence + 0.04)

        if (
            self.persona == PersonaType.SHARK
            and is_batna_reveal
            and "competitor" in text
        ):
            self._bluffs_detected += 1
            logger.info(f"ToM: bluff detected (total: {self._bluffs_detected})")

        updated = BeliefState(
            est_budget=round(est_budget, 2),
            est_walk_away=round(est_walk_away, 2),
            est_urgency=round(est_urgency, 4),
            est_has_alternative=est_has_alternative,
            confidence=round(confidence, 4),
        )
        self.history.append(updated)
        logger.debug(
            f"ToM update turn={turn}: "
            f"budget={est_budget:,.0f}, walk={est_walk_away:,.0f}, "
            f"urgency={est_urgency:.2f}, alt={est_has_alternative}, "
            f"confidence={confidence:.2f}"
        )
        return updated

    def drift_event(
        self,
        effect_on_urgency: float,
        effect_on_has_alternative: bool,
        event_description: str = "",
    ) -> BeliefState:
        """
        Apply a drift event to beliefs.

        Budget and walk-away estimates are carried over unchanged; urgency is
        shifted and clamped to [0, 1]; has_alternative is overwritten; and
        confidence drops by 0.15 (floored at 0.0).

        Args:
            effect_on_urgency:         Signed delta to urgency estimate.
            effect_on_has_alternative: Override for has_alternative belief.
            event_description:         Human-readable scenario event string
                                       (e.g. "Competitor drops price 15%").

        Returns:
            Updated BeliefState post-drift (also appended to the history).
        """
        last = self.current_belief
        new_urgency = float(max(0.0, min(1.0, last.est_urgency + effect_on_urgency)))
        # Rounded to 4 places like update() does, so float noise from the
        # subtraction does not leak into the history.
        new_confidence = round(max(0.0, last.confidence - 0.15), 4)
        updated = BeliefState(
            est_budget=last.est_budget,
            est_walk_away=last.est_walk_away,
            est_urgency=round(new_urgency, 4),
            est_has_alternative=effect_on_has_alternative,
            confidence=new_confidence,  # drift reduces confidence
        )
        self.history.append(updated)
        desc_part = f" | event={event_description!r}" if event_description else ""
        logger.info(
            f"ToM drift applied{desc_part}: "
            f"urgency_delta={effect_on_urgency:+.2f} -> urgency={new_urgency:.2f}, "
            f"alt={effect_on_has_alternative}"
        )
        return updated

    def brier_scores(self, hidden: HiddenState) -> dict[str, float]:
        """
        Compute per-field Brier scores over the full belief history.

        Brier score = (1/N) Σ (predicted - actual)²
        Lower is better; 0 = perfect.

        Fields scored:
          - urgency:      est_urgency (continuous 0–1) vs hidden.urgency_score
          - has_alt:      est_has_alternative (0/1 probability) vs hidden.has_alternative

        Args:
            hidden: The true hidden state revealed at episode end.

        Returns:
            Dict with keys "urgency" and "has_alt", each rounded to 6 places.
            Both are 1.0 (worst case) when the history is empty.
        """
        if not self.history:
            return {"urgency": 1.0, "has_alt": 1.0}

        actual_urgency = hidden.urgency_score
        actual_alt = float(hidden.has_alternative)

        urgency_sq_err = sum(
            (b.est_urgency - actual_urgency) ** 2 for b in self.history
        )
        alt_sq_err = sum(
            (float(b.est_has_alternative) - actual_alt) ** 2 for b in self.history
        )
        n = len(self.history)
        return {
            "urgency": round(urgency_sq_err / n, 6),
            "has_alt": round(alt_sq_err / n, 6),
        }

    def accuracy_against(self, hidden: HiddenState) -> float:
        """
        Compute current belief accuracy against true hidden state.

        The score is 1 minus the mean of four error terms: budget and
        walk-away errors (each normalised by half the true value, with a
        minimum divisor of 1.0), absolute urgency error, and a 0/1 mismatch
        on has_alternative. The result is floored at 0.0.

        Args:
            hidden: The true hidden state.

        Returns:
            Accuracy score in [0, 1].
        """
        b = self.current_belief
        # Normalise monetary errors by half the true value; the 1.0 floor
        # avoids division by zero for tiny or zero true values.
        budget_range = max(hidden.budget_ceiling * 0.5, 1.0)
        walk_range = max(hidden.walk_away_price * 0.5, 1.0)

        budget_err = abs(b.est_budget - hidden.budget_ceiling) / budget_range
        walk_err = abs(b.est_walk_away - hidden.walk_away_price) / walk_range
        urgency_err = abs(b.est_urgency - hidden.urgency_score)
        alt_err = 0.0 if b.est_has_alternative == hidden.has_alternative else 1.0

        mean_err = (budget_err + walk_err + urgency_err + alt_err) / 4.0
        return max(0.0, 1.0 - mean_err)