sh4shv4t committed on
Commit
3f61551
·
1 Parent(s): 80b3b2e

fix: upgrade gemini model string to 2.5-flash-lite + add ToM diagnostic script

Browse files
Files changed (3) hide show
  1. .cursorrules +2 -2
  2. agent/gemini_client.py +1 -1
  3. scripts/diagnose_tom.py +130 -0
.cursorrules CHANGED
@@ -11,9 +11,9 @@ NO Anthropic API anywhere. NO npm. NO build step.
11
 
12
  ## LLM Client Rules (Gemini)
13
 
14
- - Model: `gemini-2.5-flash` everywhere. Never use other model names in game/agent code.
15
  - Import: `from google import genai` and `from google.genai import types` — never `anthropic` or legacy `google.generativeai`.
16
- - Client: `genai.Client(api_key=os.environ.get("GOOGLE_API_KEY", ""))`; chats via `client.chats.create(model="gemini-2.5-flash", ...)`.
17
  - ALL Gemini calls wrapped in try/except returning `SYNTHETIC_RESPONSE` on failure.
18
  - JSON extraction: always strip markdown fences before `json.loads()`.
19
  - Async Gemini calls: use `asyncio.get_event_loop().run_in_executor(None, lambda: ...)`.
 
11
 
12
  ## LLM Client Rules (Gemini)
13
 
14
+ - Model: `gemini-2.5-flash-lite` everywhere. Never use other model names in game/agent code.
15
  - Import: `from google import genai` and `from google.genai import types` — never `anthropic` or legacy `google.generativeai`.
16
+ - Client: `genai.Client(api_key=os.environ.get("GOOGLE_API_KEY", ""))`; chats via `client.chats.create(model="gemini-2.5-flash-lite", ...)`.
17
  - ALL Gemini calls wrapped in try/except returning `SYNTHETIC_RESPONSE` on failure.
18
  - JSON extraction: always strip markdown fences before `json.loads()`.
19
  - Async Gemini calls: use `asyncio.get_event_loop().run_in_executor(None, lambda: ...)`.
agent/gemini_client.py CHANGED
@@ -42,7 +42,7 @@ SCENARIO_ROLE_CONTEXT: dict[str, dict[str, str]] = {
42
  },
43
  }
44
 
45
- GEMINI_MODEL = "gemini-2.5-flash"
46
  # Aliases for imports (dashboard, MCP, training all use flash-lite)
47
  MODEL_ID_DEMO = GEMINI_MODEL
48
  MODEL_ID_DATA = GEMINI_MODEL
 
42
  },
43
  }
44
 
45
+ GEMINI_MODEL = "gemini-2.5-flash-lite"
46
  # Aliases for imports (dashboard, MCP, training all use flash-lite)
47
  MODEL_ID_DEMO = GEMINI_MODEL
48
  MODEL_ID_DATA = GEMINI_MODEL
scripts/diagnose_tom.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
One-episode ToM belief diagnostic. Run from repo root:
- Put GOOGLE_API_KEY in .env, or: $env:GOOGLE_API_KEY="..." # PowerShell
python scripts/diagnose_tom.py
"""
from __future__ import annotations

import asyncio
import copy
import json
import os
import sys
from pathlib import Path

# Make the repo root importable when this file is launched directly as
# `python scripts/diagnose_tom.py` (parents[1] == repo root).
_ROOT = Path(__file__).resolve().parents[1]
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))

# Mirror main.py: pull GOOGLE_API_KEY from the project-root .env when it is
# set there but not exported in the shell. A missing python-dotenv package
# is non-fatal — the env var may already be exported.
try:
    from dotenv import load_dotenv

    load_dotenv(_ROOT / ".env")
except ImportError:
    pass

from parlay_env.grader import _tom_accuracy
from parlay_env.models import BeliefState
from parlay_env.reward import BETA

from dashboard.api import (  # noqa: E402
    MoveRequest,
    _build_observation,
    _build_session,
    _sessions,
    make_move,
)

# Two fixed opponent utterances, alternated turn-by-turn by _run().
UTTERANCE_ODD = (
    "I think $155,000 reflects fair value here. We'd like to move forward."
)
UTTERANCE_EVEN = (
    "We can come down to $148,000 but that's near our floor."
)
47
+
48
+
49
def _tom_reward(belief: BeliefState, hidden) -> float:
    """Return the ToM accuracy of *belief* against *hidden*, scaled by BETA."""
    return _tom_accuracy(belief, hidden) * BETA
51
+
52
+
53
async def _run() -> None:
    """Drive one scripted negotiation episode and report belief movement.

    Plays up to eight alternating fixed offers against a fresh session,
    snapshots the ToM belief state after every move, and prints a summary
    of how far beliefs moved in total — near-zero movement means the ToM
    reward is contributing nothing to the training signal.
    """
    api_key = (os.environ.get("GOOGLE_API_KEY") or "").strip()
    if not api_key:
        print("ERROR: GOOGLE_API_KEY not set", file=sys.stderr)
        raise SystemExit(1)

    sid, session = _build_session("saas_enterprise", "diplomat", "ToM-Diagnose")
    _sessions[sid] = session
    tracker = session["tom_tracker"]
    state = session["state"]

    # Deep-copy each snapshot: the tracker mutates its belief object in place.
    snapshots: list[BeliefState] = [copy.deepcopy(tracker.current_belief)]

    print(json.dumps(_build_observation(state), indent=2, default=str))

    for turn in range(1, 9):
        # Alternate between two fixed offers so belief updates are comparable
        # run-to-run.
        if turn % 2:
            utterance, amount = UTTERANCE_ODD, 155_000.0
        else:
            utterance, amount = UTTERANCE_EVEN, 148_000.0
        result = await make_move(
            MoveRequest(
                session_id=sid,
                amount=amount,
                message=utterance,
                tactical_move=None,
            )
        )
        # Re-fetch session state: make_move may have replaced nested objects.
        session = _sessions[sid]
        state = session["state"]
        tracker = session["tom_tracker"]
        belief = tracker.current_belief
        snapshots.append(copy.deepcopy(belief))

        reward = _tom_reward(belief, state.hidden_state)
        print(
            f" [Turn {turn}] belief_budget={belief.est_budget:.3f} "
            f"belief_urgency={belief.est_urgency:.3f} "
            f"belief_walkaway={belief.est_walk_away:.3f} "
            f"tom_reward={reward:.4f}"
        )
        if result.get("done"):
            break

    first, last = snapshots[0], snapshots[-1]
    # Total L1 movement across all three belief dimensions, turn over turn.
    total_move = 0.0
    for prev, curr in zip(snapshots, snapshots[1:]):
        total_move += abs(curr.est_budget - prev.est_budget)
        total_move += abs(curr.est_urgency - prev.est_urgency)
        total_move += abs(curr.est_walk_away - prev.est_walk_away)

    print()
    print(" === ToM Diagnostic Summary ===")
    print(
        f" Initial beliefs: budget={first.est_budget:.3f} "
        f"urgency={first.est_urgency:.3f} walkaway={first.est_walk_away:.3f}"
    )
    print(
        f" Final beliefs: budget={last.est_budget:.3f} "
        f"urgency={last.est_urgency:.3f} walkaway={last.est_walk_away:.3f}"
    )
    print(f" Total belief movement: {total_move:.4f}")
    # 0.05 is a heuristic floor — below it, beliefs are effectively frozen.
    if total_move > 0.05:
        msg = "BELIEFS ARE MOVING — ToM reward is live"
    else:
        msg = "WARNING: beliefs stuck — ToM reward contributing ~0 to training signal"
    print(f" RESULT: {msg}")
123
+
124
+
125
def main() -> None:
    """Synchronous entry point: run the async diagnostic to completion."""
    asyncio.run(_run())
127
+
128
+
129
# Allow direct invocation: python scripts/diagnose_tom.py
if __name__ == "__main__":
    main()