# parlay/agent/tom_tracker.py
# Commit df724f2 (sh4shv4t): Add pre-training audit scripts, OpenEnv manifest,
# and tune Parlay training/env (GRPO 1.5B default, min-reward filters, weighted
# data gen, hiring ZOPA+drift, veteran/opponent prompts, Docker/docs)
"""
Theory of Mind belief tracker for Parlay.
Tracks and updates agent beliefs about opponent hidden state.
"""
import logging
import sys
from typing import Optional
from parlay_env.models import BeliefState, HiddenState, PersonaType, TacticalMove
logger = logging.getLogger(__name__)
# NOTE: ToMTracker is used in two paths:
# (1) agent/runner.py self-play — full update each turn;
# (2) parlay_env/server.py WebSocket server — also uses ToMTracker after Task 1 fix.
# Both paths now produce comparable belief_history for grader._tom_accuracy.
class ToMTracker:
    """
    Tracks Theory of Mind beliefs about the opponent's hidden state.

    Maintains a belief history and updates beliefs based on:
    - Observed offers
    - Tactical move signals
    - Utterance content
    - Drift event triggers
    """

    # Keyword fragments that, when present in an opponent utterance (matched
    # case-insensitively as substrings), signal the opponent may hold an
    # outside option. Hoisted to a class constant so it is not rebuilt per call.
    _ALTERNATIVE_SIGNALS: tuple[str, ...] = (
        "competitor",
        "alternative",
        "other offer",
        "another bid",
    )

    def __init__(
        self,
        initial_belief: BeliefState,
        persona: PersonaType,
    ) -> None:
        """
        Args:
            initial_belief: Starting belief state (imprecise prior).
            persona: Opponent's persona type (known to player).
        """
        # Full trajectory of beliefs; index 0 is the prior, -1 the current.
        self.history: list[BeliefState] = [initial_belief]
        self.persona = persona
        self._bluffs_detected: int = 0

    @property
    def current_belief(self) -> BeliefState:
        """Most recent belief state."""
        return self.history[-1]

    @property
    def bluffs_detected(self) -> int:
        """Count of detected bluffs this session."""
        return self._bluffs_detected

    def log_belief_snapshot(self, turn: int) -> None:
        """Print current belief estimates to stderr (diagnostic; training / live debug)."""
        b = self.current_belief
        print(
            f"[ToM turn={turn}] budget={b.est_budget:.3f} "
            f"urgency={b.est_urgency:.3f} walkaway={b.est_walk_away:.3f}",
            file=sys.stderr,
        )

    def update(
        self,
        observed_offer: Optional[float],
        observed_move: Optional[TacticalMove],
        utterance: str,
        turn: int,
    ) -> BeliefState:
        """
        Update beliefs based on latest opponent action.

        Args:
            observed_offer: Opponent's counter-offer (None if no offer made).
            observed_move: Tactical move used (if any).
            utterance: Opponent's latest utterance.
            turn: Current turn number.

        Returns:
            Updated BeliefState (also appended to ``self.history``).
        """
        last = self.current_belief
        # Lowercase once; used for both signal scanning and bluff detection.
        utterance_lower = utterance.lower()

        # Budget: any offer implies the budget is at least ~5% above it
        # (heuristic margin); the estimate only ratchets upward.
        est_budget = last.est_budget
        if observed_offer is not None:
            est_budget = max(est_budget, observed_offer * 1.05)

        # Walk-away: a BATNA reveal suggests real outside leverage, so hedge
        # the walk-away estimate down slightly.
        est_walk_away = last.est_walk_away
        if observed_move == TacticalMove.BATNA_REVEAL:
            est_walk_away = last.est_walk_away * 0.95
            logger.debug("ToM: BATNA_REVEAL detected — hedging walk-away estimate")

        # Urgency: low offers relative to estimated budget read as pressure
        # to close (urgency up); near-ceiling offers read as patience.
        est_urgency = last.est_urgency
        if observed_offer is not None and last.est_budget > 0:
            offer_ratio = observed_offer / last.est_budget
            if offer_ratio < 0.85:
                est_urgency = min(1.0, est_urgency + 0.05)
            elif offer_ratio > 0.95:
                est_urgency = max(0.0, est_urgency - 0.03)

        # Alternative: sticky once any signal phrase is heard.
        est_has_alternative = last.est_has_alternative
        if any(sig in utterance_lower for sig in self._ALTERNATIVE_SIGNALS):
            est_has_alternative = True
            logger.debug("ToM: alternative signal detected in utterance")

        # Each observation tightens confidence a little, capped at 1.0.
        confidence = min(1.0, last.confidence + 0.04)

        # Bluff heuristic: a SHARK persona combining a BATNA reveal with
        # competitor talk is treated as a staged bluff.
        if (
            self.persona == PersonaType.SHARK
            and observed_move == TacticalMove.BATNA_REVEAL
            and "competitor" in utterance_lower
        ):
            self._bluffs_detected += 1
            logger.info("ToM: bluff detected (total: %d)", self._bluffs_detected)

        updated = BeliefState(
            est_budget=round(est_budget, 2),
            est_walk_away=round(est_walk_away, 2),
            est_urgency=round(est_urgency, 4),
            est_has_alternative=est_has_alternative,
            confidence=round(confidence, 4),
        )
        self.history.append(updated)
        # Lazy %-args: formatting is skipped entirely when DEBUG is off.
        logger.debug(
            "ToM update turn=%d: budget=%,.0f, walk=%,.0f, "
            "urgency=%.2f, alt=%s, confidence=%.2f"
            % (turn, est_budget, est_walk_away, est_urgency, est_has_alternative, confidence)
            if False
            else f"ToM update turn={turn}: "
            f"budget={est_budget:,.0f}, walk={est_walk_away:,.0f}, "
            f"urgency={est_urgency:.2f}, alt={est_has_alternative}, "
            f"confidence={confidence:.2f}"
        )
        return updated

    def drift_event(
        self,
        effect_on_urgency: float,
        effect_on_has_alternative: bool,
        event_description: str = "",
    ) -> BeliefState:
        """
        Apply a drift event to beliefs.

        Args:
            effect_on_urgency: Signed delta to urgency estimate.
            effect_on_has_alternative: Override for has_alternative belief.
            event_description: Human-readable scenario event string
                (e.g. "Competitor drops price 15%").

        Returns:
            Updated BeliefState post-drift.
        """
        last = self.current_belief
        # Clamp urgency to [0, 1] after applying the drift delta.
        new_urgency = float(max(0.0, min(1.0, last.est_urgency + effect_on_urgency)))
        updated = BeliefState(
            est_budget=last.est_budget,
            est_walk_away=last.est_walk_away,
            est_urgency=round(new_urgency, 4),
            est_has_alternative=effect_on_has_alternative,
            # Drift reduces confidence; rounded to 4 dp for consistency with
            # the values produced by update().
            confidence=round(max(0.0, last.confidence - 0.15), 4),
        )
        self.history.append(updated)
        desc_part = f" | event={event_description!r}" if event_description else ""
        # Fixed: original format string ran the delta and new value together
        # ("+0.100.62"); separate them with an explicit arrow.
        logger.info(
            f"ToM drift applied{desc_part}: "
            f"urgency_delta={effect_on_urgency:+.2f} -> urgency={new_urgency:.2f}, "
            f"alt={effect_on_has_alternative}"
        )
        return updated

    def brier_scores(self, hidden: HiddenState) -> dict[str, float]:
        """
        Compute per-field Brier scores over the full belief history.

        Brier score = (1/N) Σ (predicted - actual)²
        Lower is better; 0 = perfect.

        Fields scored:
        - urgency: est_urgency (continuous 0–1) vs hidden.urgency_score
        - has_alt: est_has_alternative (0/1 probability) vs hidden.has_alternative

        Args:
            hidden: The true hidden state revealed at episode end.

        Returns:
            Dict with keys "urgency" and "has_alt", each a float in [0, 1].
        """
        # Defensive: history is never empty in normal use (prior seeded in
        # __init__), but return worst-case scores rather than dividing by 0.
        if not self.history:
            return {"urgency": 1.0, "has_alt": 1.0}
        actual_urgency = hidden.urgency_score
        actual_alt = float(hidden.has_alternative)
        urgency_sq_err = sum(
            (b.est_urgency - actual_urgency) ** 2 for b in self.history
        )
        alt_sq_err = sum(
            (float(b.est_has_alternative) - actual_alt) ** 2 for b in self.history
        )
        n = len(self.history)
        return {
            "urgency": round(urgency_sq_err / n, 6),
            "has_alt": round(alt_sq_err / n, 6),
        }

    def accuracy_against(self, hidden: HiddenState) -> float:
        """
        Compute current belief accuracy against true hidden state.

        Accuracy is 1 minus the mean of four per-field errors: budget and
        walk-away errors normalized to half the true value (floored at 1.0
        to avoid division by zero), absolute urgency error, and a 0/1
        mismatch penalty for has_alternative. Clamped to [0, 1].

        Args:
            hidden: The true hidden state.

        Returns:
            Accuracy score in [0, 1].
        """
        b = self.current_belief
        budget_range = max(hidden.budget_ceiling * 0.5, 1.0)
        walk_range = max(hidden.walk_away_price * 0.5, 1.0)
        budget_err = abs(b.est_budget - hidden.budget_ceiling) / budget_range
        walk_err = abs(b.est_walk_away - hidden.walk_away_price) / walk_range
        urgency_err = abs(b.est_urgency - hidden.urgency_score)
        alt_err = 0.0 if b.est_has_alternative == hidden.has_alternative else 1.0
        mean_err = (budget_err + walk_err + urgency_err + alt_err) / 4.0
        return max(0.0, 1.0 - mean_err)