# Parlay / scripts/diagnose_tom.py
# sh4shv4t's picture
# fix: upgrade gemini model string to 2.5-flash-lite + add tom diagnostic script
# 3f61551
"""
One-episode ToM belief diagnostic. Run from repo root:
- Put GOOGLE_API_KEY in .env, or: $env:GOOGLE_API_KEY="..." # PowerShell
python scripts/diagnose_tom.py
"""
from __future__ import annotations
import asyncio
import copy
import json
import os
import sys
from pathlib import Path
# The script lives one directory below the repository root. Put that root at
# the front of sys.path so the project packages imported below resolve when
# this file is launched directly as: python scripts/diagnose_tom.py
_ROOT = Path(__file__).resolve().parents[1]
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))
# Best-effort load of <repo root>/.env (same as main.py) so that a
# GOOGLE_API_KEY defined there is visible even when it was never exported in
# the shell. python-dotenv is optional: when it is not installed the script
# simply relies on the existing process environment.
try:
    from dotenv import load_dotenv

    load_dotenv(_ROOT.joinpath(".env"))
except ImportError:
    pass
from parlay_env.grader import _tom_accuracy
from parlay_env.models import BeliefState
from parlay_env.reward import BETA
from dashboard.api import ( # noqa: E402
MoveRequest,
_build_observation,
_build_session,
_sessions,
make_move,
)
# Scripted message sent on odd-numbered turns (1, 3, 5, 7).
UTTERANCE_ODD = "I think $155,000 reflects fair value here. We'd like to move forward."

# Scripted message sent on even-numbered turns (2, 4, 6, 8).
UTTERANCE_EVEN = "We can come down to $148,000 but that's near our floor."
def _tom_reward(belief: BeliefState, hidden) -> float:
    """Return the ToM accuracy of ``belief`` against ``hidden``, scaled by ``BETA``.

    Args:
        belief: The tracker's current belief estimate.
        hidden: The ground truth passed straight through to ``_tom_accuracy``
            (callers in this file pass ``state.hidden_state``).

    Returns:
        ``BETA * _tom_accuracy(belief, hidden)``.
    """
    accuracy = _tom_accuracy(belief, hidden)
    return BETA * accuracy
async def _run() -> None:
    """Play one scripted episode and report how the ToM belief estimates move.

    Steps:
      1. Abort with exit status 1 when ``GOOGLE_API_KEY`` is unset or blank.
      2. Build a session, register it in the dashboard's ``_sessions`` mapping
         and print the initial observation as JSON.
      3. Submit up to eight moves through ``make_move``, alternating the two
         scripted offers, printing the belief estimates and ToM reward after
         each move. Stop early once the move result reports ``done``.
      4. Print a summary comparing the first and last belief snapshots and the
         total absolute change accumulated across all snapshots.

    Raises:
        SystemExit: With code 1 if ``GOOGLE_API_KEY`` is missing or blank.
    """
    api_key = os.environ.get("GOOGLE_API_KEY") or ""
    if not api_key.strip():
        print("ERROR: GOOGLE_API_KEY not set", file=sys.stderr)
        raise SystemExit(1)

    session_id, session = _build_session("saas_enterprise", "diplomat", "ToM-Diagnose")
    # make_move looks sessions up by id, so the new session must be registered.
    _sessions[session_id] = session

    tracker = session["tom_tracker"]
    opening_state = session["state"]
    # Deep copies keep each snapshot independent of later tracker updates.
    history: list[BeliefState] = [copy.deepcopy(tracker.current_belief)]
    print(json.dumps(_build_observation(opening_state), indent=2, default=str))

    # Keyed by turn parity: odd turns -> 1, even turns -> 0.
    scripted_offers = {
        1: (UTTERANCE_ODD, 155_000.0),
        0: (UTTERANCE_EVEN, 148_000.0),
    }

    for turn in range(1, 9):
        utterance, amount = scripted_offers[turn % 2]
        request = MoveRequest(
            session_id=session_id,
            amount=amount,
            message=utterance,
            tactical_move=None,
        )
        outcome = await make_move(request)

        # Re-read the session from the registry after every move rather than
        # reusing the objects captured before the call.
        session = _sessions[session_id]
        state = session["state"]
        belief = session["tom_tracker"].current_belief
        history.append(copy.deepcopy(belief))
        reward = _tom_reward(belief, state.hidden_state)
        print(
            f" [Turn {turn}] belief_budget={belief.est_budget:.3f} "
            f"belief_urgency={belief.est_urgency:.3f} "
            f"belief_walkaway={belief.est_walk_away:.3f} "
            f"tom_reward={reward:.4f}"
        )
        if outcome.get("done"):
            break

    first, last = history[0], history[-1]

    # Sum of absolute per-field deltas between consecutive snapshots.
    # NOTE(review): the three fields are added together unscaled; confirm they
    # share a comparable range, otherwise one field dominates the 0.05 check.
    total_move = 0.0
    for previous, current in zip(history, history[1:]):
        for field in ("est_budget", "est_urgency", "est_walk_away"):
            total_move += abs(getattr(current, field) - getattr(previous, field))

    print()
    print(" === ToM Diagnostic Summary ===")
    print(
        f" Initial beliefs: budget={first.est_budget:.3f} "
        f"urgency={first.est_urgency:.3f} walkaway={first.est_walk_away:.3f}"
    )
    print(
        f" Final beliefs: budget={last.est_budget:.3f} "
        f"urgency={last.est_urgency:.3f} walkaway={last.est_walk_away:.3f}"
    )
    print(f" Total belief movement: {total_move:.4f}")

    verdict = (
        "BELIEFS ARE MOVING — ToM reward is live"
        if total_move > 0.05
        else "WARNING: beliefs stuck — ToM reward contributing ~0 to training signal"
    )
    print(f" RESULT: {verdict}")
def main() -> None:
    """Synchronous entry point: drive the async diagnostic to completion."""
    episode = _run()
    asyncio.run(episode)


# Only run the diagnostic when executed as a script, not when imported.
if __name__ == "__main__":
    main()