Parlay / agent /tom_tracker_bayesian.py
sh4shv4t's picture
feat: flash-lite for data-gen and flash for UI; remove training page; card tests; --quiet data gen; data/ inspect path; random baseline; GRPO env wrapper; reward fixes (buyer ZOPA, ToM signals); drift + Brier metrics; Bayesian ToM module
15976d0
"""
Bayesian Theory-of-Mind belief tracker for Parlay.
Drop-in replacement for ToMTracker that uses Kalman-filter-style Gaussian
belief updates instead of hand-tuned arithmetic nudges.
Key insight
-----------
The opponent has hidden variables (budget_ceiling, walk_away_price, urgency,
has_alternative). Each observed offer is a noisy signal about these.
We model each continuous variable as a Gaussian (mean, variance) and update
using the standard Bayesian update for Gaussian conjugate priors:
posterior_mean = (prior_mean / prior_var + obs / obs_var) /
(1 / prior_var + 1 / obs_var)
posterior_var = 1 / (1 / prior_var + 1 / obs_var)
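`confidence` is derived from the posterior spread of every tracked variable:
    confidence = clamp(1 - mean(budget_rel_std, walk_rel_std, urgency_std, alt_std), 0, 1)
where budget_rel_std = sqrt(budget_var) / max(|budget_mean|, 1) (likewise for
walk-away), urgency_std = sqrt(urgency_var) and alt_std = sqrt(p * (1 - p)).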
Usage (as feature-flag alternative to ToMTracker):
from agent.tom_tracker_bayesian import BayesianToMTracker as ToMTracker
# Then use exactly the same API as ToMTracker — all method signatures match.
"""
import logging
import math
import sys
from typing import Optional
from parlay_env.models import BeliefState, HiddenState, PersonaType, TacticalMove
logger = logging.getLogger(__name__)
class BayesianToMTracker:
    """
    Gaussian-posterior belief tracker for the opponent's hidden state.

    Extends the original ToMTracker API with proper Bayesian updating.
    The same public methods (update, drift_event, accuracy_against,
    brier_scores, log_belief_snapshot) are preserved for drop-in use.

    Internal state:
        _budget_mean, _budget_var   — Gaussian over opponent's budget ceiling.
        _walk_mean, _walk_var       — Gaussian over opponent's walk-away price.
        _urgency_mean, _urgency_var — Gaussian over urgency [0, 1].
        _alt_prob                   — Bernoulli probability of has_alternative.

    Project model types are written as quoted forward references in the
    signatures, so they are only resolved when the methods actually run.
    """

    # Observation noise variances (tuned for B2B negotiation scale).
    # Budget/walk-away: observed offer is a noisy signal; high variance because
    # opponents rarely reveal their true limits.
    _OBS_BUDGET_VAR_FRAC = 0.10  # 10% of current mean estimate as std
    _OBS_URGENCY_VAR = 0.05  # small update per offer-ratio signal

    # Substrings (matched case-insensitively) that hint at an outside option.
    _ALT_SIGNALS = ("competitor", "alternative", "other offer", "another bid")

    def __init__(
        self,
        initial_belief: "BeliefState",
        persona: "PersonaType",
    ) -> None:
        """
        Args:
            initial_belief: Starting BeliefState (imprecise prior).
            persona: Opponent persona (known to the player).
        """
        self.persona = persona
        self._bluffs_detected: int = 0

        # Initialise Gaussian priors from the initial belief
        self._budget_mean = float(initial_belief.est_budget)
        self._walk_mean = float(initial_belief.est_walk_away)
        self._urgency_mean = float(initial_belief.est_urgency)
        self._alt_prob = 0.3  # prior: 30% chance opponent has an alternative

        # Initial variances — large uncertainty at the start
        self._budget_var = (self._budget_mean * 0.30) ** 2  # ±30% std
        self._walk_var = (self._walk_mean * 0.30) ** 2
        self._urgency_var = 0.08  # std ≈ 0.28 over [0, 1]

        # history[0] is the prior; every update/drift appends one snapshot.
        self.history: list["BeliefState"] = [self._snapshot()]
        logger.debug(
            "BayesianToMTracker init: budget_mean=%.0f walk_mean=%.0f urgency_mean=%.2f",
            self._budget_mean, self._walk_mean, self._urgency_mean,
        )

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _snapshot(self) -> "BeliefState":
        """Convert current Gaussian state to a BeliefState snapshot."""
        confidence = self._compute_confidence()
        return BeliefState(
            est_budget=round(self._budget_mean, 2),
            est_walk_away=round(self._walk_mean, 2),
            est_urgency=round(max(0.0, min(1.0, self._urgency_mean)), 4),
            # The snapshot stores only a thresholded boolean, not _alt_prob.
            est_has_alternative=self._alt_prob >= 0.5,
            confidence=round(confidence, 4),
        )

    def _compute_confidence(self) -> float:
        """
        Confidence = 1 - mean uncertainty across all four variables.

        Budget and walk-away contribute their relative std (std / |mean|, with
        the denominator floored at 1.0), urgency its std, and has_alternative
        the Bernoulli std sqrt(p * (1 - p)). The result is clamped to [0, 1];
        shrinking variances yield higher confidence.
        """
        budget_rel_std = math.sqrt(self._budget_var) / max(abs(self._budget_mean), 1.0)
        walk_rel_std = math.sqrt(self._walk_var) / max(abs(self._walk_mean), 1.0)
        urgency_std = math.sqrt(self._urgency_var)
        alt_std = math.sqrt(self._alt_prob * (1.0 - self._alt_prob))
        mean_uncertainty = (budget_rel_std + walk_rel_std + urgency_std + alt_std) / 4.0
        return max(0.0, min(1.0, 1.0 - mean_uncertainty))

    @staticmethod
    def _gaussian_update(
        prior_mean: float,
        prior_var: float,
        obs: float,
        obs_var: float,
    ) -> tuple[float, float]:
        """
        Closed-form Bayesian update for Gaussian conjugate prior.

            posterior_mean = (prior_mean / prior_var + obs / obs_var) /
                             (1 / prior_var + 1 / obs_var)
            posterior_var  = 1 / (1 / prior_var + 1 / obs_var)

        Both variances are floored at 1e-10 so a zero variance cannot cause a
        division by zero.

        Returns:
            (posterior_mean, posterior_var)
        """
        prec_prior = 1.0 / max(prior_var, 1e-10)
        prec_obs = 1.0 / max(obs_var, 1e-10)
        posterior_prec = prec_prior + prec_obs
        posterior_mean = (prec_prior * prior_mean + prec_obs * obs) / posterior_prec
        posterior_var = 1.0 / posterior_prec
        return posterior_mean, posterior_var

    @staticmethod
    def _urgency_signal(offer_ratio: float) -> float:
        """Map offer / estimated-budget ratio to a pseudo-observation of urgency."""
        if offer_ratio < 0.85:
            return 0.70  # low offer → opponent likely more urgent
        if offer_ratio > 0.95:
            return 0.30  # high offer → opponent comfortable
        return 0.50  # neutral

    @staticmethod
    def _drift_alt_prob(alt_prob: float, toward_alternative: bool) -> float:
        """
        Shift the has_alternative probability for a drift event.

        Moves +0.25 (capped at 0.9) toward an alternative, or -0.1 (floored at
        0.1) away from one. `update` may already have pushed the belief outside
        [0.1, 0.9] (its own bounds are [0.05, 0.95]); in that case the value is
        left unchanged rather than being clamped *against* the drift direction.
        """
        if toward_alternative:
            return max(alt_prob, min(0.9, alt_prob + 0.25))
        return min(alt_prob, max(0.1, alt_prob - 0.1))

    # ── Public API (matches ToMTracker) ──────────────────────────────────────

    @property
    def current_belief(self) -> "BeliefState":
        """Most recent BeliefState snapshot."""
        return self.history[-1]

    @property
    def bluffs_detected(self) -> int:
        """Number of bluffs flagged so far by `update`."""
        return self._bluffs_detected

    def log_belief_snapshot(self, turn: int) -> None:
        """Print a one-line summary of the current belief to stderr."""
        b = self.current_belief
        print(
            f"[BayesToM turn={turn}] "
            f"budget={b.est_budget:.0f}±{math.sqrt(self._budget_var):.0f} "
            f"urgency={b.est_urgency:.3f}±{math.sqrt(self._urgency_var):.3f} "
            f"alt_prob={self._alt_prob:.2f} conf={b.confidence:.2f}",
            file=sys.stderr,
        )

    def update(
        self,
        observed_offer: Optional[float],
        observed_move: Optional["TacticalMove"],
        utterance: str,
        turn: int,
    ) -> "BeliefState":
        """
        Bayesian update of all beliefs from one observed opponent action.

        Budget update: if we see an offer O, the true budget is likely > O.
            We treat O as a lower-bound signal: observation = O * 1.05
            with variance proportional to the current mean.
        Walk-away update: only on BATNA_REVEAL, observation = O * 0.95.
        Urgency update: offer-ratio below 0.85 → urgency signal 0.7;
            above 0.95 → urgency signal 0.3; otherwise 0.5.
        has_alternative: heuristic keyword update — moves 35% of the way
            toward 1 on a keyword match (capped at 0.95), otherwise decays
            by 2% (floored at 0.05).

        A missing or non-positive offer carries no price information and is
        ignored by the budget, walk-away and urgency updates alike.

        Args:
            observed_offer: Opponent's offer this turn, if any.
            observed_move: Opponent's tactical move this turn, if any.
            utterance: Opponent's message text (None is treated as empty).
            turn: Turn index, used for logging only.

        Returns:
            The new BeliefState snapshot (also appended to `history`).
        """
        # Normalise inputs once; every price-based update shares this guard.
        offer = observed_offer if (observed_offer is not None and observed_offer > 0) else None
        text = (utterance or "").lower()

        # ── Budget Bayesian update ──────────────────────────────────────────
        if offer is not None:
            budget_obs = offer * 1.05
            obs_budget_var = (self._budget_mean * self._OBS_BUDGET_VAR_FRAC) ** 2
            self._budget_mean, self._budget_var = self._gaussian_update(
                self._budget_mean, self._budget_var,
                budget_obs, obs_budget_var,
            )
            logger.debug(
                "Bayesian budget update: obs=%.0f → mean=%.0f std=%.0f",
                budget_obs, self._budget_mean, math.sqrt(self._budget_var),
            )

        # ── Walk-away update: BATNA_REVEAL is a noisy signal ───────────────
        if observed_move == TacticalMove.BATNA_REVEAL and offer is not None:
            walk_obs = offer * 0.95
            obs_walk_var = (self._walk_mean * 0.15) ** 2
            self._walk_mean, self._walk_var = self._gaussian_update(
                self._walk_mean, self._walk_var,
                walk_obs, obs_walk_var,
            )
            logger.debug("Bayesian walk-away update via BATNA_REVEAL")

        # ── Urgency Bayesian update via offer-ratio signal ─────────────────
        # The ratio uses the budget mean *after* this turn's budget update.
        if offer is not None and self._budget_mean > 0:
            urgency_obs = self._urgency_signal(offer / self._budget_mean)
            self._urgency_mean, self._urgency_var = self._gaussian_update(
                self._urgency_mean, self._urgency_var,
                urgency_obs, self._OBS_URGENCY_VAR,
            )
            self._urgency_mean = max(0.0, min(1.0, self._urgency_mean))

        # ── has_alternative heuristic update (keyword match) ───────────────
        if any(sig in text for sig in self._ALT_SIGNALS):
            self._alt_prob = min(0.95, self._alt_prob + (1.0 - self._alt_prob) * 0.35)
            logger.debug("Alternative signal detected → alt_prob=%.2f", self._alt_prob)
        else:
            self._alt_prob = max(0.05, self._alt_prob * 0.98)  # small decay

        # ── Bluff detection (shark persona + BATNA_REVEAL + "competitor") ──
        if (
            self.persona == PersonaType.SHARK
            and observed_move == TacticalMove.BATNA_REVEAL
            and "competitor" in text
        ):
            self._bluffs_detected += 1
            logger.info("BayesToM: bluff detected (total: %d)", self._bluffs_detected)

        updated = self._snapshot()
        self.history.append(updated)
        logger.debug(
            "BayesToM update turn=%d: budget=%.0f walk=%.0f urgency=%.2f alt_prob=%.2f conf=%.2f",
            turn, self._budget_mean, self._walk_mean, self._urgency_mean,
            self._alt_prob, updated.confidence,
        )
        return updated

    def drift_event(
        self,
        effect_on_urgency: float,
        effect_on_has_alternative: bool,
        event_description: str = "",
    ) -> "BeliefState":
        """
        Apply a market/scenario drift event.

        Nudges the urgency mean by `effect_on_urgency` (clamped to [0, 1]) and
        shifts alt_prob in the drift direction. Also inflates all variances
        (drift = increased uncertainty); the urgency variance is capped at 0.1.

        Args:
            effect_on_urgency: Additive change to the urgency mean.
            effect_on_has_alternative: True shifts belief toward the opponent
                having an alternative, False shifts it away.
            event_description: Optional label included in the log line.

        Returns:
            The new BeliefState snapshot (also appended to `history`).
        """
        self._urgency_mean = float(max(0.0, min(1.0, self._urgency_mean + effect_on_urgency)))
        self._urgency_var = min(0.1, self._urgency_var * 1.5)  # inflate uncertainty

        # Drift shifts alt belief, never against the drift direction.
        self._alt_prob = self._drift_alt_prob(self._alt_prob, effect_on_has_alternative)

        # Inflate budget/walk variances — drift reduces confidence
        self._budget_var *= 1.3
        self._walk_var *= 1.3

        updated = self._snapshot()
        self.history.append(updated)
        desc_part = f" | event={event_description!r}" if event_description else ""
        logger.info(
            "BayesToM drift applied%s: urgency_delta=%+.2f → %.2f, alt_prob=%.2f, conf=%.2f",
            desc_part, effect_on_urgency, self._urgency_mean, self._alt_prob, updated.confidence,
        )
        return updated

    def accuracy_against(self, hidden: "HiddenState") -> float:
        """
        Compute current belief accuracy against true hidden state.

        Same formula as ToMTracker for comparability: budget and walk-away
        errors are normalised by half the true value (floored at 1.0), urgency
        error is absolute, the alternative error is 0/1, and the result is
        1 - mean error, floored at 0.
        """
        b = self.current_belief
        budget_range = max(hidden.budget_ceiling * 0.5, 1.0)
        walk_range = max(hidden.walk_away_price * 0.5, 1.0)
        budget_err = abs(b.est_budget - hidden.budget_ceiling) / budget_range
        walk_err = abs(b.est_walk_away - hidden.walk_away_price) / walk_range
        urgency_err = abs(b.est_urgency - hidden.urgency_score)
        alt_err = 0.0 if b.est_has_alternative == hidden.has_alternative else 1.0
        mean_err = (budget_err + walk_err + urgency_err + alt_err) / 4.0
        return max(0.0, 1.0 - mean_err)

    def brier_scores(self, hidden: "HiddenState") -> dict[str, float]:
        """
        Brier scores for urgency and has_alternative over full belief history.

        Returns:
            {"urgency": ..., "has_alt": ...}, each the mean squared error over
            all snapshots, rounded to 6 decimals.
        """
        if not self.history:
            return {"urgency": 1.0, "has_alt": 1.0}
        actual_urgency = hidden.urgency_score
        actual_alt = float(hidden.has_alternative)
        n = len(self.history)
        brier_urgency = sum((b.est_urgency - actual_urgency) ** 2 for b in self.history) / n
        # NOTE: snapshots hold only the thresholded boolean, so this term is
        # effectively the misclassification rate rather than a probabilistic
        # Brier score over _alt_prob.
        brier_alt = sum((float(b.est_has_alternative) - actual_alt) ** 2 for b in self.history) / n
        return {
            "urgency": round(brier_urgency, 6),
            "has_alt": round(brier_alt, 6),
        }