| """ |
| Theory of Mind belief tracker for Parlay. |
| Tracks and updates agent beliefs about opponent hidden state. |
| """ |
| import logging |
| import sys |
| from typing import Optional |
|
|
| from parlay_env.models import BeliefState, HiddenState, PersonaType, TacticalMove |
|
|
# Module-level logger; ToMTracker's update() and drift_event() report belief changes through it.
logger: logging.Logger = logging.getLogger(__name__)
|
|
| |
| |
| |
| |
|
|
|
|
class ToMTracker:
    """
    Tracks Theory of Mind beliefs about the opponent's hidden state.

    Maintains an append-only belief history (``self.history``) and updates
    beliefs based on:
    - Observed offers
    - Tactical move signals
    - Utterance content
    - Drift event triggers

    Every call to :meth:`update` or :meth:`drift_event` appends a new
    ``BeliefState``; :attr:`current_belief` is always the last entry.
    """

    # Lower-cased substrings which, when found in an opponent utterance, set
    # the has-alternative belief to True. Hoisted so it is built only once.
    _ALTERNATIVE_SIGNALS: tuple[str, ...] = (
        "competitor",
        "alternative",
        "other offer",
        "another bid",
    )

    def __init__(
        self,
        initial_belief: BeliefState,
        persona: PersonaType,
    ) -> None:
        """
        Args:
            initial_belief: Starting belief state (imprecise prior).
            persona: Opponent's persona type (known to player).
        """
        # History is seeded with the prior, so it is never empty and
        # current_belief can always index [-1].
        self.history: list[BeliefState] = [initial_belief]
        self.persona = persona
        self._bluffs_detected: int = 0

    @property
    def current_belief(self) -> BeliefState:
        """Most recent belief state."""
        return self.history[-1]

    @property
    def bluffs_detected(self) -> int:
        """Count of detected bluffs this session."""
        return self._bluffs_detected

    def log_belief_snapshot(self, turn: int) -> None:
        """Print current belief estimates to stderr (diagnostic; training / live debug)."""
        b = self.current_belief
        print(
            f"[ToM turn={turn}] budget={b.est_budget:.3f} "
            f"urgency={b.est_urgency:.3f} walkaway={b.est_walk_away:.3f}",
            file=sys.stderr,
        )

    def update(
        self,
        observed_offer: Optional[float],
        observed_move: Optional[TacticalMove],
        utterance: str,
        turn: int,
    ) -> BeliefState:
        """
        Update beliefs based on latest opponent action.

        Each signal adjusts one field of the previous belief:

        - ``observed_offer`` raises the budget estimate to at least 105% of
          the offer, and nudges urgency up by 0.05 (offer below 85% of the
          previous budget estimate) or down by 0.03 (offer above 95% of it),
          clamped to [0, 1].
        - ``TacticalMove.BATNA_REVEAL`` shrinks the walk-away estimate by 5%.
        - Any of ``_ALTERNATIVE_SIGNALS`` in the utterance sets
          ``est_has_alternative`` to True (this method never resets it).
        - Confidence grows by 0.04 per update, capped at 1.0.

        A SHARK persona using BATNA_REVEAL while mentioning a "competitor"
        increments the bluff counter.

        Args:
            observed_offer: Opponent's counter-offer (None if no offer made).
            observed_move: Tactical move used (if any).
            utterance: Opponent's latest utterance. ``None`` is tolerated and
                treated as an empty string.
            turn: Current turn number (used only for logging).

        Returns:
            The newly appended BeliefState.
        """
        last = self.current_belief
        # Lower-case once for all keyword checks; tolerate a missing utterance.
        text = (utterance or "").lower()

        est_budget = last.est_budget
        if observed_offer is not None:
            # Heuristic: keep the budget estimate at least 5% above any offer seen.
            est_budget = max(est_budget, observed_offer * 1.05)

        est_walk_away = last.est_walk_away
        if observed_move == TacticalMove.BATNA_REVEAL:
            est_walk_away = last.est_walk_away * 0.95
            logger.debug("ToM: BATNA_REVEAL detected — hedging walk-away estimate")

        est_urgency = last.est_urgency
        # The ratio is taken against the *previous* budget estimate, not the
        # one just updated above; the > 0 check avoids division by zero.
        if observed_offer is not None and last.est_budget > 0:
            offer_ratio = observed_offer / last.est_budget
            if offer_ratio < 0.85:
                est_urgency = min(1.0, est_urgency + 0.05)
            elif offer_ratio > 0.95:
                est_urgency = max(0.0, est_urgency - 0.03)

        est_has_alternative = last.est_has_alternative
        if any(sig in text for sig in self._ALTERNATIVE_SIGNALS):
            est_has_alternative = True
            logger.debug("ToM: alternative signal detected in utterance")

        confidence = min(1.0, last.confidence + 0.04)

        # NOTE(review): a detected bluff still sets est_has_alternative=True via
        # the "competitor" signal above — confirm that is intended.
        if (
            self.persona == PersonaType.SHARK
            and observed_move == TacticalMove.BATNA_REVEAL
            and "competitor" in text
        ):
            self._bluffs_detected += 1
            logger.info("ToM: bluff detected (total: %d)", self._bluffs_detected)

        updated = BeliefState(
            est_budget=round(est_budget, 2),
            est_walk_away=round(est_walk_away, 2),
            est_urgency=round(est_urgency, 4),
            est_has_alternative=est_has_alternative,
            confidence=round(confidence, 4),
        )
        self.history.append(updated)
        # The ",.0f" thousands-separator spec has no %-style equivalent, so the
        # f-string is kept but only built when DEBUG output is actually enabled
        # (this runs every turn).
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f"ToM update turn={turn}: "
                f"budget={est_budget:,.0f}, walk={est_walk_away:,.0f}, "
                f"urgency={est_urgency:.2f}, alt={est_has_alternative}, "
                f"confidence={confidence:.2f}"
            )
        return updated

    def drift_event(
        self,
        effect_on_urgency: float,
        effect_on_has_alternative: bool,
        event_description: str = "",
    ) -> BeliefState:
        """
        Apply a drift event to beliefs.

        Budget and walk-away estimates are carried over unchanged. Urgency is
        shifted by ``effect_on_urgency`` and clamped to [0, 1]. The
        has-alternative belief is overwritten. Confidence drops by 0.15
        (floored at 0) and is rounded to 4 places, matching :meth:`update`.

        Args:
            effect_on_urgency: Signed delta to urgency estimate.
            effect_on_has_alternative: Override for has_alternative belief.
            event_description: Human-readable scenario event string
                (e.g. "Competitor drops price 15%").

        Returns:
            The newly appended BeliefState post-drift.
        """
        last = self.current_belief
        new_urgency = float(max(0.0, min(1.0, last.est_urgency + effect_on_urgency)))
        updated = BeliefState(
            est_budget=last.est_budget,
            est_walk_away=last.est_walk_away,
            est_urgency=round(new_urgency, 4),
            est_has_alternative=effect_on_has_alternative,
            # Rounded like update() so the history contains no float noise
            # such as 0.19000000000000003.
            confidence=round(max(0.0, last.confidence - 0.15), 4),
        )
        self.history.append(updated)
        desc_part = f" | event={event_description!r}" if event_description else ""
        logger.info(
            "ToM drift applied%s: urgency_delta=%+.2f → %.2f, alt=%s",
            desc_part,
            effect_on_urgency,
            new_urgency,
            effect_on_has_alternative,
        )
        return updated

    def brier_scores(self, hidden: HiddenState) -> dict[str, float]:
        """
        Compute per-field Brier scores over the full belief history.

        Brier score = (1/N) Σ (predicted - actual)²
        Lower is better; 0 = perfect.

        Fields scored:
        - urgency: est_urgency (continuous 0–1) vs hidden.urgency_score
        - has_alt: est_has_alternative (0/1 probability) vs hidden.has_alternative

        Args:
            hidden: The true hidden state revealed at episode end.

        Returns:
            Dict with keys "urgency" and "has_alt", each rounded to 6 places.
            Both are 1.0 (worst score) if the history is empty.
        """
        # Defensive only: __init__ seeds history with the prior.
        if not self.history:
            return {"urgency": 1.0, "has_alt": 1.0}

        actual_urgency = hidden.urgency_score
        actual_alt = float(hidden.has_alternative)

        urgency_sq_err = sum(
            (b.est_urgency - actual_urgency) ** 2 for b in self.history
        )
        alt_sq_err = sum(
            (float(b.est_has_alternative) - actual_alt) ** 2 for b in self.history
        )
        n = len(self.history)
        return {
            "urgency": round(urgency_sq_err / n, 6),
            "has_alt": round(alt_sq_err / n, 6),
        }

    def accuracy_against(self, hidden: HiddenState) -> float:
        """
        Compute current belief accuracy against true hidden state.

        Averages four per-field errors:
        - budget and walk-away: absolute error divided by half the true
          value, with a minimum divisor of 1.0;
        - urgency: absolute difference;
        - has_alternative: 0 if it matches, 1 otherwise.

        The result is 1 minus that average, floored at 0. Budget and walk-away
        errors are not capped, so a single large miss can drive the score to 0.

        Args:
            hidden: The true hidden state.

        Returns:
            Accuracy score in [0, 1].
        """
        b = self.current_belief
        # Floor the divisor at 1.0 so a zero/near-zero true value cannot
        # cause division by zero.
        budget_range = max(hidden.budget_ceiling * 0.5, 1.0)
        walk_range = max(hidden.walk_away_price * 0.5, 1.0)

        budget_err = abs(b.est_budget - hidden.budget_ceiling) / budget_range
        walk_err = abs(b.est_walk_away - hidden.walk_away_price) / walk_range
        urgency_err = abs(b.est_urgency - hidden.urgency_score)
        alt_err = 0.0 if b.est_has_alternative == hidden.has_alternative else 1.0

        mean_err = (budget_err + walk_err + urgency_err + alt_err) / 4.0
        return max(0.0, 1.0 - mean_err)
|
|