fix: upgrade gemini model string to 2.5-flash-lite + add ToM diagnostic script
Browse files- .cursorrules +2 -2
- agent/gemini_client.py +1 -1
- scripts/diagnose_tom.py +130 -0
.cursorrules
CHANGED
|
@@ -11,9 +11,9 @@ NO Anthropic API anywhere. NO npm. NO build step.
|
|
| 11 |
|
| 12 |
## LLM Client Rules (Gemini)
|
| 13 |
|
| 14 |
-
- Model: `gemini-2.5-flash` everywhere. Never use other model names in game/agent code.
|
| 15 |
- Import: `from google import genai` and `from google.genai import types` — never `anthropic` or legacy `google.generativeai`.
|
| 16 |
-
- Client: `genai.Client(api_key=os.environ.get("GOOGLE_API_KEY", ""))`; chats via `client.chats.create(model="gemini-2.5-flash", ...)`.
|
| 17 |
- ALL Gemini calls wrapped in try/except returning `SYNTHETIC_RESPONSE` on failure.
|
| 18 |
- JSON extraction: always strip markdown fences before `json.loads()`.
|
| 19 |
- Async Gemini calls: use `asyncio.get_event_loop().run_in_executor(None, lambda: ...)`.
|
|
|
|
| 11 |
|
| 12 |
## LLM Client Rules (Gemini)
|
| 13 |
|
| 14 |
+
- Model: `gemini-2.5-flash-lite` everywhere. Never use other model names in game/agent code.
|
| 15 |
- Import: `from google import genai` and `from google.genai import types` — never `anthropic` or legacy `google.generativeai`.
|
| 16 |
+
- Client: `genai.Client(api_key=os.environ.get("GOOGLE_API_KEY", ""))`; chats via `client.chats.create(model="gemini-2.5-flash-lite", ...)`.
|
| 17 |
- ALL Gemini calls wrapped in try/except returning `SYNTHETIC_RESPONSE` on failure.
|
| 18 |
- JSON extraction: always strip markdown fences before `json.loads()`.
|
| 19 |
- Async Gemini calls: use `asyncio.get_event_loop().run_in_executor(None, lambda: ...)`.
|
agent/gemini_client.py
CHANGED
|
@@ -42,7 +42,7 @@ SCENARIO_ROLE_CONTEXT: dict[str, dict[str, str]] = {
|
|
| 42 |
},
|
| 43 |
}
|
| 44 |
|
| 45 |
-
GEMINI_MODEL = "gemini-2.5-flash"
|
| 46 |
# Aliases for imports (dashboard, MCP, training all use flash-lite)
|
| 47 |
MODEL_ID_DEMO = GEMINI_MODEL
|
| 48 |
MODEL_ID_DATA = GEMINI_MODEL
|
|
|
|
| 42 |
},
|
| 43 |
}
|
| 44 |
|
| 45 |
+
GEMINI_MODEL = "gemini-2.5-flash-lite"
|
| 46 |
# Aliases for imports (dashboard, MCP, training all use flash-lite)
|
| 47 |
MODEL_ID_DEMO = GEMINI_MODEL
|
| 48 |
MODEL_ID_DATA = GEMINI_MODEL
|
scripts/diagnose_tom.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
One-episode ToM belief diagnostic. Run from repo root:
|
| 3 |
+
- Put GOOGLE_API_KEY in .env, or: $env:GOOGLE_API_KEY="..." # PowerShell
|
| 4 |
+
python scripts/diagnose_tom.py
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import copy
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# Repo root on sys.path when launched as: python scripts/diagnose_tom.py
|
| 16 |
+
_ROOT = Path(__file__).resolve().parent.parent
|
| 17 |
+
if str(_ROOT) not in sys.path:
|
| 18 |
+
sys.path.insert(0, str(_ROOT))
|
| 19 |
+
|
| 20 |
+
# Load .env from project root (same as main.py) so GOOGLE_API_KEY is available
|
| 21 |
+
# when set in .env but not exported in the shell.
|
| 22 |
+
try:
|
| 23 |
+
from dotenv import load_dotenv
|
| 24 |
+
|
| 25 |
+
load_dotenv(_ROOT / ".env")
|
| 26 |
+
except ImportError:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
from parlay_env.grader import _tom_accuracy
|
| 30 |
+
from parlay_env.models import BeliefState
|
| 31 |
+
from parlay_env.reward import BETA
|
| 32 |
+
|
| 33 |
+
from dashboard.api import ( # noqa: E402
|
| 34 |
+
MoveRequest,
|
| 35 |
+
_build_observation,
|
| 36 |
+
_build_session,
|
| 37 |
+
_sessions,
|
| 38 |
+
make_move,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
UTTERANCE_ODD = (
|
| 42 |
+
"I think $155,000 reflects fair value here. We'd like to move forward."
|
| 43 |
+
)
|
| 44 |
+
UTTERANCE_EVEN = (
|
| 45 |
+
"We can come down to $148,000 but that's near our floor."
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _tom_reward(belief: BeliefState, hidden) -> float:
|
| 50 |
+
return BETA * _tom_accuracy(belief, hidden)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
async def _run() -> None:
|
| 54 |
+
key = (os.environ.get("GOOGLE_API_KEY") or "").strip()
|
| 55 |
+
if not key:
|
| 56 |
+
print("ERROR: GOOGLE_API_KEY not set", file=sys.stderr)
|
| 57 |
+
raise SystemExit(1)
|
| 58 |
+
|
| 59 |
+
sid, session = _build_session("saas_enterprise", "diplomat", "ToM-Diagnose")
|
| 60 |
+
_sessions[sid] = session
|
| 61 |
+
tom = session["tom_tracker"]
|
| 62 |
+
state0 = session["state"]
|
| 63 |
+
|
| 64 |
+
initial = tom.current_belief
|
| 65 |
+
snapshots: list[BeliefState] = [copy.deepcopy(initial)]
|
| 66 |
+
|
| 67 |
+
init_obs = _build_observation(state0)
|
| 68 |
+
print(json.dumps(init_obs, indent=2, default=str))
|
| 69 |
+
|
| 70 |
+
for n in range(1, 9):
|
| 71 |
+
if n % 2 == 1:
|
| 72 |
+
utterance, amount = UTTERANCE_ODD, 155_000.0
|
| 73 |
+
else:
|
| 74 |
+
utterance, amount = UTTERANCE_EVEN, 148_000.0
|
| 75 |
+
result = await make_move(
|
| 76 |
+
MoveRequest(
|
| 77 |
+
session_id=sid,
|
| 78 |
+
amount=amount,
|
| 79 |
+
message=utterance,
|
| 80 |
+
tactical_move=None,
|
| 81 |
+
)
|
| 82 |
+
)
|
| 83 |
+
session = _sessions[sid]
|
| 84 |
+
st = session["state"]
|
| 85 |
+
tom = session["tom_tracker"]
|
| 86 |
+
b = tom.current_belief
|
| 87 |
+
snapshots.append(copy.deepcopy(b))
|
| 88 |
+
|
| 89 |
+
tr = _tom_reward(b, st.hidden_state)
|
| 90 |
+
print(
|
| 91 |
+
f" [Turn {n}] belief_budget={b.est_budget:.3f} "
|
| 92 |
+
f"belief_urgency={b.est_urgency:.3f} "
|
| 93 |
+
f"belief_walkaway={b.est_walk_away:.3f} "
|
| 94 |
+
f"tom_reward={tr:.4f}"
|
| 95 |
+
)
|
| 96 |
+
if result.get("done"):
|
| 97 |
+
break
|
| 98 |
+
|
| 99 |
+
s0, s1 = snapshots[0], snapshots[-1]
|
| 100 |
+
total_move = 0.0
|
| 101 |
+
for i in range(1, len(snapshots)):
|
| 102 |
+
a, snap_b = snapshots[i - 1], snapshots[i]
|
| 103 |
+
total_move += abs(snap_b.est_budget - a.est_budget)
|
| 104 |
+
total_move += abs(snap_b.est_urgency - a.est_urgency)
|
| 105 |
+
total_move += abs(snap_b.est_walk_away - a.est_walk_away)
|
| 106 |
+
|
| 107 |
+
print()
|
| 108 |
+
print(" === ToM Diagnostic Summary ===")
|
| 109 |
+
print(
|
| 110 |
+
f" Initial beliefs: budget={s0.est_budget:.3f} "
|
| 111 |
+
f"urgency={s0.est_urgency:.3f} walkaway={s0.est_walk_away:.3f}"
|
| 112 |
+
)
|
| 113 |
+
print(
|
| 114 |
+
f" Final beliefs: budget={s1.est_budget:.3f} "
|
| 115 |
+
f"urgency={s1.est_urgency:.3f} walkaway={s1.est_walk_away:.3f}"
|
| 116 |
+
)
|
| 117 |
+
print(f" Total belief movement: {total_move:.4f}")
|
| 118 |
+
if total_move > 0.05:
|
| 119 |
+
msg = "BELIEFS ARE MOVING — ToM reward is live"
|
| 120 |
+
else:
|
| 121 |
+
msg = "WARNING: beliefs stuck — ToM reward contributing ~0 to training signal"
|
| 122 |
+
print(f" RESULT: {msg}")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def main() -> None:
|
| 126 |
+
asyncio.run(_run())
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
main()
|