"""
Theory of Mind belief tracker for Parlay.
Tracks and updates agent beliefs about opponent hidden state.
"""
import logging
import sys
from typing import Optional
from parlay_env.models import BeliefState, HiddenState, PersonaType, TacticalMove
logger = logging.getLogger(__name__)
# NOTE: ToMTracker is used in two paths:
# (1) agent/runner.py self-play — full update each turn;
# (2) parlay_env/server.py WebSocket server — also uses ToMTracker after Task 1 fix.
# Both paths now produce comparable belief_history for grader._tom_accuracy.
class ToMTracker:
    """
    Tracks Theory of Mind beliefs about the opponent's hidden state.

    Maintains a belief history and updates beliefs based on:
    - Observed offers
    - Tactical move signals
    - Utterance content
    - Drift event triggers
    """

    def __init__(
        self,
        initial_belief: BeliefState,
        persona: PersonaType,
    ) -> None:
        """
        Args:
            initial_belief: Starting belief state (imprecise prior).
            persona: Opponent's persona type (known to player).
        """
        # History always holds at least the prior, so current_belief is safe.
        self.history: list[BeliefState] = [initial_belief]
        self.persona = persona
        self._bluffs_detected: int = 0

    @property
    def current_belief(self) -> BeliefState:
        """Most recent belief state."""
        return self.history[-1]

    @property
    def bluffs_detected(self) -> int:
        """Count of detected bluffs this session."""
        return self._bluffs_detected

    def log_belief_snapshot(self, turn: int) -> None:
        """Print current belief estimates to stderr (diagnostic; training / live debug)."""
        # Deliberately bypasses the `logging` module so the snapshot is visible
        # regardless of configured handlers/levels.
        b = self.current_belief
        print(
            f"[ToM turn={turn}] budget={b.est_budget:.3f} "
            f"urgency={b.est_urgency:.3f} walkaway={b.est_walk_away:.3f}",
            file=sys.stderr,
        )

    def update(
        self,
        observed_offer: Optional[float],
        observed_move: Optional[TacticalMove],
        utterance: str,
        turn: int,
    ) -> BeliefState:
        """
        Update beliefs based on latest opponent action.

        Args:
            observed_offer: Opponent's counter-offer (None if no offer made).
            observed_move: Tactical move used (if any).
            utterance: Opponent's latest utterance.
            turn: Current turn number.

        Returns:
            Updated BeliefState (also appended to history).
        """
        last = self.current_belief
        # Hoisted: lowercased utterance is inspected by two heuristics below.
        lowered = utterance.lower()

        # Budget: an observed offer puts a floor under the budget estimate
        # (offer + 5% headroom); the estimate only ratchets upward.
        est_budget = last.est_budget
        if observed_offer is not None:
            est_budget = max(est_budget, observed_offer * 1.05)

        # Walk-away: a BATNA reveal suggests our walk-away estimate is high;
        # hedge it down by 5%.
        est_walk_away = last.est_walk_away
        if observed_move == TacticalMove.BATNA_REVEAL:
            est_walk_away = last.est_walk_away * 0.95
            logger.debug("ToM: BATNA_REVEAL detected — hedging walk-away estimate")

        # Urgency: offers far below our budget estimate hint at urgency to
        # close (+0.05); offers near it hint at patience (-0.03). Clamped [0, 1].
        est_urgency = last.est_urgency
        if observed_offer is not None and last.est_budget > 0:
            offer_ratio = observed_offer / last.est_budget
            if offer_ratio < 0.85:
                est_urgency = min(1.0, est_urgency + 0.05)
            elif offer_ratio > 0.95:
                est_urgency = max(0.0, est_urgency - 0.03)

        # Alternatives: sticky flag — once a signal phrase appears, it stays True.
        est_has_alternative = last.est_has_alternative
        alternative_signals = ("competitor", "alternative", "other offer", "another bid")
        if any(sig in lowered for sig in alternative_signals):
            est_has_alternative = True
            logger.debug("ToM: alternative signal detected in utterance")

        # Confidence grows a little with every observation, capped at 1.0.
        confidence = min(1.0, last.confidence + 0.04)

        # Bluff heuristic: a SHARK pairing a BATNA reveal with competitor talk
        # is treated as a bluff.
        if (
            self.persona == PersonaType.SHARK
            and observed_move == TacticalMove.BATNA_REVEAL
            and "competitor" in lowered
        ):
            self._bluffs_detected += 1
            # Lazy %-args: no interpolation cost when INFO is disabled.
            logger.info("ToM: bluff detected (total: %d)", self._bluffs_detected)

        updated = BeliefState(
            est_budget=round(est_budget, 2),
            est_walk_away=round(est_walk_away, 2),
            est_urgency=round(est_urgency, 4),
            est_has_alternative=est_has_alternative,
            confidence=round(confidence, 4),
        )
        self.history.append(updated)
        # Lazy %-args (runs every turn; skip formatting when DEBUG is disabled).
        logger.debug(
            "ToM update turn=%d: budget=%.0f, walk=%.0f, "
            "urgency=%.2f, alt=%s, confidence=%.2f",
            turn, est_budget, est_walk_away,
            est_urgency, est_has_alternative, confidence,
        )
        return updated

    def drift_event(
        self,
        effect_on_urgency: float,
        effect_on_has_alternative: bool,
        event_description: str = "",
    ) -> BeliefState:
        """
        Apply a drift event to beliefs.

        Args:
            effect_on_urgency: Signed delta to urgency estimate.
            effect_on_has_alternative: Override for has_alternative belief.
            event_description: Human-readable scenario event string
                (e.g. "Competitor drops price 15%").

        Returns:
            Updated BeliefState post-drift.
        """
        last = self.current_belief
        # Clamp urgency into [0, 1] after applying the signed delta.
        new_urgency = float(max(0.0, min(1.0, last.est_urgency + effect_on_urgency)))
        updated = BeliefState(
            est_budget=last.est_budget,
            est_walk_away=last.est_walk_away,
            est_urgency=round(new_urgency, 4),
            est_has_alternative=effect_on_has_alternative,
            # Drift invalidates part of what we believed: reduce confidence.
            # Rounded to 4 dp for consistency with update().
            confidence=round(max(0.0, last.confidence - 0.15), 4),
        )
        self.history.append(updated)
        desc_part = f" | event={event_description!r}" if event_description else ""
        logger.info(
            "ToM drift applied%s: urgency_delta=%+.2f → %.2f, alt=%s",
            desc_part, effect_on_urgency, new_urgency, effect_on_has_alternative,
        )
        return updated

    def brier_scores(self, hidden: HiddenState) -> dict[str, float]:
        """
        Compute per-field Brier scores over the full belief history.

        Brier score = (1/N) Σ (predicted - actual)²
        Lower is better; 0 = perfect.

        Fields scored:
        - urgency: est_urgency (continuous 0–1) vs hidden.urgency_score
        - has_alt: est_has_alternative (0/1 probability) vs hidden.has_alternative

        Args:
            hidden: The true hidden state revealed at episode end.

        Returns:
            Dict with keys "urgency" and "has_alt", each a float in [0, 1].
        """
        # Defensive: history normally always holds the prior; worst score if not.
        if not self.history:
            return {"urgency": 1.0, "has_alt": 1.0}
        actual_urgency = hidden.urgency_score
        actual_alt = float(hidden.has_alternative)
        urgency_sq_err = sum(
            (b.est_urgency - actual_urgency) ** 2 for b in self.history
        )
        alt_sq_err = sum(
            (float(b.est_has_alternative) - actual_alt) ** 2 for b in self.history
        )
        n = len(self.history)
        return {
            "urgency": round(urgency_sq_err / n, 6),
            "has_alt": round(alt_sq_err / n, 6),
        }

    def accuracy_against(self, hidden: HiddenState) -> float:
        """
        Compute current belief accuracy against true hidden state.

        Args:
            hidden: The true hidden state.

        Returns:
            Accuracy score in [0, 1].
        """
        b = self.current_belief
        # Normalize price errors against half the true value (floor of 1.0
        # guards against division by zero on degenerate scenarios).
        budget_range = max(hidden.budget_ceiling * 0.5, 1.0)
        walk_range = max(hidden.walk_away_price * 0.5, 1.0)
        budget_err = abs(b.est_budget - hidden.budget_ceiling) / budget_range
        walk_err = abs(b.est_walk_away - hidden.walk_away_price) / walk_range
        urgency_err = abs(b.est_urgency - hidden.urgency_score)
        alt_err = 0.0 if b.est_has_alternative == hidden.has_alternative else 1.0
        # Mean of the four error terms, inverted and floored at 0.
        mean_err = (budget_err + walk_err + urgency_err + alt_err) / 4.0
        return max(0.0, 1.0 - mean_err)