File size: 3,772 Bytes
"""
One-episode ToM belief diagnostic.

Run from the repo root:
    python scripts/diagnose_tom.py

GOOGLE_API_KEY must be available: either put it in .env at the repo root, or
set it in the shell first, e.g. in PowerShell:
    $env:GOOGLE_API_KEY="..."
"""
from __future__ import annotations
import asyncio
import copy
import json
import os
import sys
from pathlib import Path
# When launched as `python scripts/diagnose_tom.py`, Python puts the script's
# own directory (scripts/) on sys.path rather than the repo root, so the
# project packages imported below would not resolve. Prepend the repo root
# (two levels up from this file) unless it is already listed.
_ROOT = Path(__file__).resolve().parent.parent
if os.fspath(_ROOT) not in sys.path:
    sys.path.insert(0, os.fspath(_ROOT))

# Load .env from the project root (same as main.py) so GOOGLE_API_KEY is
# picked up when it is set in .env but not exported in the shell.
# python-dotenv is treated as optional: if it is not installed, the key has to
# come from the process environment instead.
try:
    from dotenv import load_dotenv

    load_dotenv(_ROOT / ".env")
except ImportError:
    pass
from parlay_env.grader import _tom_accuracy
from parlay_env.models import BeliefState
from parlay_env.reward import BETA
from dashboard.api import ( # noqa: E402
MoveRequest,
_build_observation,
_build_session,
_sessions,
make_move,
)
# Scripted messages sent by the diagnostic: _run() sends UTTERANCE_ODD (with an
# amount of 155,000) on odd-numbered turns and UTTERANCE_EVEN (with an amount
# of 148,000) on even-numbered turns.
UTTERANCE_ODD = "I think $155,000 reflects fair value here. We'd like to move forward."
UTTERANCE_EVEN = "We can come down to $148,000 but that's near our floor."
def _tom_reward(belief: BeliefState, hidden) -> float:
    """Return the ToM reward term for ``belief`` measured against ``hidden``.

    The value is ``BETA`` (from ``parlay_env.reward``) multiplied by the score
    that ``parlay_env.grader._tom_accuracy`` gives the belief for the supplied
    hidden state.
    """
    accuracy = _tom_accuracy(belief, hidden)
    return BETA * accuracy
# Total belief movement above this value counts as "beliefs are moving".
_MOVEMENT_THRESHOLD = 0.05


def _total_belief_movement(snapshots: list[BeliefState]) -> float:
    """Return the summed absolute change of all belief estimates across turns.

    For each pair of consecutive snapshots this adds the absolute differences
    of ``est_budget``, ``est_urgency`` and ``est_walk_away``. A list with fewer
    than two snapshots yields ``0.0``.
    """
    total = 0.0
    # zip against the one-shifted list walks (previous, current) pairs without
    # any index arithmetic.
    for prev, curr in zip(snapshots, snapshots[1:]):
        total += abs(curr.est_budget - prev.est_budget)
        total += abs(curr.est_urgency - prev.est_urgency)
        total += abs(curr.est_walk_away - prev.est_walk_away)
    return total


def _print_summary(snapshots: list[BeliefState]) -> None:
    """Print first/last beliefs, total belief movement and the verdict line.

    ``snapshots`` must contain at least one entry (the initial belief).
    """
    first, last = snapshots[0], snapshots[-1]
    total_move = _total_belief_movement(snapshots)

    print()
    print(" === ToM Diagnostic Summary ===")
    print(
        f" Initial beliefs: budget={first.est_budget:.3f} "
        f"urgency={first.est_urgency:.3f} walkaway={first.est_walk_away:.3f}"
    )
    print(
        f" Final beliefs: budget={last.est_budget:.3f} "
        f"urgency={last.est_urgency:.3f} walkaway={last.est_walk_away:.3f}"
    )
    print(f" Total belief movement: {total_move:.4f}")
    if total_move > _MOVEMENT_THRESHOLD:
        msg = "BELIEFS ARE MOVING — ToM reward is live"
    else:
        msg = "WARNING: beliefs stuck — ToM reward contributing ~0 to training signal"
    print(f" RESULT: {msg}")


async def _run(max_turns: int = 8) -> None:
    """Play one scripted episode and report how the ToM belief state evolves.

    Builds a "saas_enterprise" / "diplomat" session, registers it in the
    dashboard's ``_sessions`` store, then submits up to ``max_turns`` moves
    that alternate between the two scripted utterances. After every move the
    current belief is snapshotted and printed together with its ToM reward;
    a summary of the overall belief movement is printed at the end.

    Args:
        max_turns: Maximum number of moves to submit. The episode stops
            earlier if a move result reports ``done``.

    Raises:
        SystemExit: With exit code 1 when GOOGLE_API_KEY is unset or blank.
    """
    # NOTE(review): the key is only checked here, never passed on; presumably
    # make_move's backend reads it from the environment - confirm.
    key = (os.environ.get("GOOGLE_API_KEY") or "").strip()
    if not key:
        print("ERROR: GOOGLE_API_KEY not set", file=sys.stderr)
        raise SystemExit(1)

    sid, session = _build_session("saas_enterprise", "diplomat", "ToM-Diagnose")
    # make_move() is given only the session id, so the session is registered
    # in the shared store before the first move.
    _sessions[sid] = session

    # Deep-copy each belief so later updates cannot alter recorded snapshots.
    snapshots: list[BeliefState] = [
        copy.deepcopy(session["tom_tracker"].current_belief)
    ]
    init_obs = _build_observation(session["state"])
    print(json.dumps(init_obs, indent=2, default=str))

    for n in range(1, max_turns + 1):
        # The amount mirrors the dollar figure quoted in the utterance.
        if n % 2 == 1:
            utterance, amount = UTTERANCE_ODD, 155_000.0
        else:
            utterance, amount = UTTERANCE_EVEN, 148_000.0
        result = await make_move(
            MoveRequest(
                session_id=sid,
                amount=amount,
                message=utterance,
                tactical_move=None,
            )
        )

        # Re-read the session from the store every turn instead of reusing
        # the reference obtained before the move.
        session = _sessions[sid]
        belief = session["tom_tracker"].current_belief
        snapshots.append(copy.deepcopy(belief))
        reward = _tom_reward(belief, session["state"].hidden_state)
        print(
            f" [Turn {n}] belief_budget={belief.est_budget:.3f} "
            f"belief_urgency={belief.est_urgency:.3f} "
            f"belief_walkaway={belief.est_walk_away:.3f} "
            f"tom_reward={reward:.4f}"
        )
        if result.get("done"):
            break

    _print_summary(snapshots)
def main() -> None:
    """Synchronous entry point: run the async diagnostic to completion."""
    diagnostic = _run()
    asyncio.run(diagnostic)


# Only run the diagnostic when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|