File size: 3,772 Bytes
3f61551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
One-episode ToM belief diagnostic.

Usage (run from the repo root):
  1. Provide GOOGLE_API_KEY: put it in .env, or set it in the shell,
     e.g. $env:GOOGLE_API_KEY="..."  # PowerShell
  2. python scripts/diagnose_tom.py
"""
from __future__ import annotations

import asyncio
import copy
import json
import os
import sys
from pathlib import Path

# Repo root on sys.path when launched as: python scripts/diagnose_tom.py
# _ROOT is the directory two levels above this file; it is inserted at the
# front of sys.path (only if absent) so the project imports below resolve.
_ROOT = Path(__file__).resolve().parent.parent
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))

# Load .env from project root (same as main.py) so GOOGLE_API_KEY is available
# when set in .env but not exported in the shell.
# python-dotenv is treated as optional: if it is not installed, the ImportError
# is ignored and the script relies on the existing process environment.
try:
    from dotenv import load_dotenv

    load_dotenv(_ROOT / ".env")
except ImportError:
    pass

from parlay_env.grader import _tom_accuracy
from parlay_env.models import BeliefState
from parlay_env.reward import BETA

from dashboard.api import (  # noqa: E402
    MoveRequest,
    _build_observation,
    _build_session,
    _sessions,
    make_move,
)

# Scripted player lines for the diagnostic episode: _run sends UTTERANCE_ODD
# on odd-numbered turns and UTTERANCE_EVEN on even-numbered turns.
UTTERANCE_ODD = "I think $155,000 reflects fair value here. We'd like to move forward."
UTTERANCE_EVEN = "We can come down to $148,000 but that's near our floor."


def _tom_reward(belief: BeliefState, hidden) -> float:
    """Return the ToM accuracy of *belief* against *hidden*, scaled by BETA.

    Args:
        belief: The belief estimate to score.
        hidden: The hidden state passed to ``_tom_accuracy`` for comparison.

    Returns:
        ``BETA * _tom_accuracy(belief, hidden)``.
    """
    accuracy = _tom_accuracy(belief, hidden)
    return BETA * accuracy


async def _run() -> None:
    """Play one scripted episode and report how far the ToM beliefs moved.

    Builds a session, registers it in ``_sessions``, prints the initial
    observation as JSON, then submits up to eight moves that alternate
    between the two scripted utterances. After every move the tracker's
    current belief is snapshotted and printed with its ToM reward. The loop
    stops early when the move result reports ``done``. A summary with the
    initial beliefs, final beliefs and the summed absolute change over all
    consecutive snapshots is printed at the end.

    Raises:
        SystemExit: With status 1 if GOOGLE_API_KEY is unset or blank.
    """
    api_key = os.environ.get("GOOGLE_API_KEY") or ""
    if not api_key.strip():
        print("ERROR: GOOGLE_API_KEY not set", file=sys.stderr)
        raise SystemExit(1)

    session_id, session = _build_session(
        "saas_enterprise", "diplomat", "ToM-Diagnose"
    )
    _sessions[session_id] = session
    tracker = session["tom_tracker"]
    opening_state = session["state"]

    # Snapshots are deep copies so each entry stays fixed even if the tracker
    # later changes its belief object.
    history: list[BeliefState] = [copy.deepcopy(tracker.current_belief)]

    opening_obs = _build_observation(opening_state)
    print(json.dumps(opening_obs, indent=2, default=str))

    # Keyed by turn parity: odd turns use the 155,000 line, even turns the
    # 148,000 line.
    scripted = {
        1: (UTTERANCE_ODD, 155_000.0),
        0: (UTTERANCE_EVEN, 148_000.0),
    }
    for turn in range(1, 9):
        utterance, amount = scripted[turn % 2]
        request = MoveRequest(
            session_id=session_id,
            amount=amount,
            message=utterance,
            tactical_move=None,
        )
        result = await make_move(request)

        # Look the session up again after the move rather than reusing the
        # earlier reference.
        session = _sessions[session_id]
        state = session["state"]
        tracker = session["tom_tracker"]
        belief = tracker.current_belief
        history.append(copy.deepcopy(belief))

        reward = _tom_reward(belief, state.hidden_state)
        print(
            f"     [Turn {turn}] belief_budget={belief.est_budget:.3f}  "
            f"belief_urgency={belief.est_urgency:.3f}  "
            f"belief_walkaway={belief.est_walk_away:.3f}  "
            f"tom_reward={reward:.4f}"
        )
        if result.get("done"):
            break

    first, last = history[0], history[-1]

    # Sum |delta| per field over each consecutive snapshot pair, in the order
    # budget, urgency, walk-away.
    tracked_fields = ("est_budget", "est_urgency", "est_walk_away")
    total_move = 0.0
    for before, after in zip(history, history[1:]):
        for field in tracked_fields:
            total_move += abs(getattr(after, field) - getattr(before, field))

    print()
    print("   === ToM Diagnostic Summary ===")
    print(
        f"   Initial beliefs: budget={first.est_budget:.3f}  "
        f"urgency={first.est_urgency:.3f}  walkaway={first.est_walk_away:.3f}"
    )
    print(
        f"   Final beliefs:   budget={last.est_budget:.3f}  "
        f"urgency={last.est_urgency:.3f}  walkaway={last.est_walk_away:.3f}"
    )
    print(f"   Total belief movement: {total_move:.4f}")
    verdict = (
        "BELIEFS ARE MOVING — ToM reward is live"
        if total_move > 0.05
        else "WARNING: beliefs stuck — ToM reward contributing ~0 to training signal"
    )
    print(f"   RESULT: {verdict}")


def main() -> None:
    """Synchronous entry point: run the async diagnostic to completion."""
    episode = _run()
    asyncio.run(episode)


# Run the diagnostic only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()