refactor(agent): migrate Gemini client from google-generativeai to google-genai
Browse files- .cursorrules +2 -2
- agent/gemini_client.py +64 -29
- requirements.txt +1 -1
- training/notebooks/parlay_training.ipynb +1 -1
.cursorrules
CHANGED
|
@@ -12,8 +12,8 @@ NO Anthropic API anywhere. NO npm. NO build step.
|
|
| 12 |
## LLM Client Rules (Gemini)
|
| 13 |
|
| 14 |
- Model: `gemini-2.0-flash` everywhere. Never use other model names in game/agent code.
|
| 15 |
-
- Import: `import google.
|
| 16 |
-
-
|
| 17 |
- ALL Gemini calls wrapped in try/except returning `SYNTHETIC_RESPONSE` on failure.
|
| 18 |
- JSON extraction: always strip markdown fences before `json.loads()`.
|
| 19 |
- Async Gemini calls: use `asyncio.get_event_loop().run_in_executor(None, lambda: ...)`.
|
|
|
|
| 12 |
## LLM Client Rules (Gemini)
|
| 13 |
|
| 14 |
- Model: `gemini-2.0-flash` everywhere. Never use other model names in game/agent code.
|
| 15 |
+
- Import: `from google import genai` and `from google.genai import types` — never `anthropic` or legacy `google.generativeai`.
|
| 16 |
+
- Client: `genai.Client(api_key=os.environ.get("GOOGLE_API_KEY", ""))`; chats via `client.chats.create(model="gemini-2.0-flash", ...)`.
|
| 17 |
- ALL Gemini calls wrapped in try/except returning `SYNTHETIC_RESPONSE` on failure.
|
| 18 |
- JSON extraction: always strip markdown fences before `json.loads()`.
|
| 19 |
- Async Gemini calls: use `asyncio.get_event_loop().run_in_executor(None, lambda: ...)`.
|
agent/gemini_client.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
Google Gemini 2.0 Flash client for Parlay.
|
| 3 |
-
All calls are async (via run_in_executor). All errors return SYNTHETIC_RESPONSE.
|
| 4 |
"""
|
| 5 |
import asyncio
|
| 6 |
import json
|
|
@@ -8,12 +8,41 @@ import logging
|
|
| 8 |
import os
|
| 9 |
from typing import Optional
|
| 10 |
|
| 11 |
-
|
|
|
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
SYNTHETIC_RESPONSE: dict = {
|
| 19 |
"utterance": "I need a moment to consider your proposal.",
|
|
@@ -49,22 +78,25 @@ async def call_gemini(
|
|
| 49 |
f'{{"utterance": "...", "offer_amount": <number or null>, '
|
| 50 |
f'"tactical_move": <string or null>}}'
|
| 51 |
)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
| 61 |
max_output_tokens=max_tokens,
|
| 62 |
temperature=0.7,
|
| 63 |
),
|
| 64 |
-
)
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
text = response.text.strip()
|
| 68 |
text = text.replace("```json", "").replace("```", "").strip()
|
| 69 |
parsed = json.loads(text)
|
| 70 |
|
|
@@ -79,10 +111,10 @@ async def call_gemini(
|
|
| 79 |
|
| 80 |
except json.JSONDecodeError:
|
| 81 |
logger.warning("Gemini JSON parse failed — using text fallback")
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
except Exception as exc:
|
| 87 |
logger.error(f"Gemini API error: {exc}")
|
| 88 |
return SYNTHETIC_RESPONSE
|
|
@@ -117,19 +149,22 @@ async def call_gemini_tom(
|
|
| 117 |
)
|
| 118 |
|
| 119 |
try:
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
| 125 |
tom_prompt,
|
| 126 |
-
|
| 127 |
max_output_tokens=200,
|
| 128 |
temperature=0.3,
|
| 129 |
),
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
return json.loads(text)
|
| 134 |
except Exception as exc:
|
| 135 |
logger.error(f"Gemini ToM inference error: {exc}")
|
|
|
|
| 1 |
"""
|
| 2 |
Google Gemini 2.0 Flash client for Parlay.
|
| 3 |
+
Uses the google-genai SDK. All calls are async (via run_in_executor). All errors return SYNTHETIC_RESPONSE.
|
| 4 |
"""
|
| 5 |
import asyncio
|
| 6 |
import json
|
|
|
|
| 8 |
import os
|
| 9 |
from typing import Optional
|
| 10 |
|
| 11 |
+
from google import genai
|
| 12 |
+
from google.genai import types
|
| 13 |
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
+
# Single model name passed to every chats.create() call in this module.
MODEL_ID = "gemini-2.0-flash"

# Process-wide client instance; stays None until _get_client() builds it.
_client: Optional[genai.Client] = None


def _get_client() -> genai.Client:
    """Return the shared Gemini client, constructing it on first use.

    The API key is read from the ``GOOGLE_API_KEY`` environment variable. A
    missing or empty value is passed through as ``""`` instead of being
    rejected here, so a misconfigured environment surfaces as an error from
    the SDK (which callers are expected to catch) rather than at import time.
    """
    global _client
    if _client is not None:
        return _client
    api_key = os.environ.get("GOOGLE_API_KEY") or ""
    _client = genai.Client(api_key=api_key)
    return _client
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _legacy_messages_to_history(messages: list[dict]) -> list[types.Content]:
    """Convert legacy ``{'role', 'parts'}`` messages to a google-genai Content list.

    Args:
        messages: Chat turns in the legacy dict layout. ``parts`` may be a
            list of strings, a list of ``{'text': ...}`` dicts, a single bare
            string, or a single dict.

    Returns:
        One ``types.Content`` per input message, in input order. Roles other
        than ``"user"`` / ``"model"`` are coerced to ``"user"``. A message
        with no parts receives a single empty text part.
    """

    def _part_text(part) -> str:
        """Extract the text of one legacy part without leaking a dict repr."""
        if isinstance(part, str):
            return part
        if isinstance(part, dict) and "text" in part:
            value = part["text"]
            if value is None:
                return ""
            return value if isinstance(value, str) else str(value)
        return str(part)

    contents: list[types.Content] = []
    for message in messages:
        role = message.get("role", "user")
        if role not in ("user", "model"):
            role = "user"

        raw_parts = message.get("parts") or []
        # A bare string (or a lone dict) is ONE part. Iterating it directly
        # would yield one Part per character (or per dict key).
        if isinstance(raw_parts, (str, dict)):
            raw_parts = [raw_parts]

        parts: list[types.Part] = [
            types.Part(text=_part_text(p)) for p in raw_parts
        ]
        if not parts:
            # Keep the original fallback: never build a Content with no parts.
            parts.append(types.Part(text=""))

        contents.append(types.Content(role=role, parts=parts))
    return contents
|
| 45 |
+
|
| 46 |
|
| 47 |
SYNTHETIC_RESPONSE: dict = {
|
| 48 |
"utterance": "I need a moment to consider your proposal.",
|
|
|
|
| 78 |
f'{{"utterance": "...", "offer_amount": <number or null>, '
|
| 79 |
f'"tactical_move": <string or null>}}'
|
| 80 |
)
|
| 81 |
+
user_message = f"{full_prompt}\n\nUser: {last_msg}"
|
| 82 |
+
|
| 83 |
+
def _call() -> types.GenerateContentResponse:
|
| 84 |
+
chat = _get_client().chats.create(
|
| 85 |
+
model=MODEL_ID,
|
| 86 |
+
history=_legacy_messages_to_history(history),
|
| 87 |
+
)
|
| 88 |
+
return chat.send_message(
|
| 89 |
+
user_message,
|
| 90 |
+
config=types.GenerateContentConfig(
|
| 91 |
max_output_tokens=max_tokens,
|
| 92 |
temperature=0.7,
|
| 93 |
),
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
loop = asyncio.get_event_loop()
|
| 97 |
+
response = await loop.run_in_executor(None, _call)
|
| 98 |
|
| 99 |
+
text = (response.text or "").strip()
|
| 100 |
text = text.replace("```json", "").replace("```", "").strip()
|
| 101 |
parsed = json.loads(text)
|
| 102 |
|
|
|
|
| 111 |
|
| 112 |
except json.JSONDecodeError:
|
| 113 |
logger.warning("Gemini JSON parse failed — using text fallback")
|
| 114 |
+
raw = text[:300] if text else ""
|
| 115 |
+
if raw:
|
| 116 |
+
return {**SYNTHETIC_RESPONSE, "utterance": raw}
|
| 117 |
+
return SYNTHETIC_RESPONSE
|
| 118 |
except Exception as exc:
|
| 119 |
logger.error(f"Gemini API error: {exc}")
|
| 120 |
return SYNTHETIC_RESPONSE
|
|
|
|
| 149 |
)
|
| 150 |
|
| 151 |
try:
|
| 152 |
+
def _call() -> types.GenerateContentResponse:
|
| 153 |
+
chat = _get_client().chats.create(
|
| 154 |
+
model=MODEL_ID,
|
| 155 |
+
history=_legacy_messages_to_history(conversation_history),
|
| 156 |
+
)
|
| 157 |
+
return chat.send_message(
|
| 158 |
tom_prompt,
|
| 159 |
+
config=types.GenerateContentConfig(
|
| 160 |
max_output_tokens=200,
|
| 161 |
temperature=0.3,
|
| 162 |
),
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
loop = asyncio.get_event_loop()
|
| 166 |
+
response = await loop.run_in_executor(None, _call)
|
| 167 |
+
text = (response.text or "").strip().replace("```json", "").replace("```", "").strip()
|
| 168 |
return json.loads(text)
|
| 169 |
except Exception as exc:
|
| 170 |
logger.error(f"Gemini ToM inference error: {exc}")
|
requirements.txt
CHANGED
|
@@ -3,7 +3,7 @@ uvicorn[standard]==0.45.0
|
|
| 3 |
websockets==15.0.1
|
| 4 |
pydantic>=2.11.7,<3.0.0
|
| 5 |
aiosqlite==0.20.0
|
| 6 |
-
google-
|
| 7 |
fastmcp==3.2.4
|
| 8 |
numpy==1.26.4
|
| 9 |
scikit-learn==1.4.2
|
|
|
|
| 3 |
websockets==15.0.1
|
| 4 |
pydantic>=2.11.7,<3.0.0
|
| 5 |
aiosqlite==0.20.0
|
| 6 |
+
google-genai>=1.0.0
|
| 7 |
fastmcp==3.2.4
|
| 8 |
numpy==1.26.4
|
| 9 |
scikit-learn==1.4.2
|
training/notebooks/parlay_training.ipynb
CHANGED
|
@@ -43,7 +43,7 @@
|
|
| 43 |
"outputs": [],
|
| 44 |
"source": [
|
| 45 |
"# Cell 1: Install all dependencies\n",
|
| 46 |
-
"!pip install -q fastapi uvicorn websockets pydantic aiosqlite google-
|
| 47 |
"!pip install -q trl peft transformers accelerate bitsandbytes datasets huggingface-hub\n",
|
| 48 |
"!pip install -q matplotlib\n",
|
| 49 |
"print('✓ All dependencies installed')"
|
|
|
|
| 43 |
"outputs": [],
|
| 44 |
"source": [
|
| 45 |
"# Cell 1: Install all dependencies\n",
|
| 46 |
+
"!pip install -q fastapi uvicorn websockets pydantic aiosqlite google-genai fastmcp numpy python-dotenv httpx\n",
|
| 47 |
"!pip install -q trl peft transformers accelerate bitsandbytes datasets huggingface-hub\n",
|
| 48 |
"!pip install -q matplotlib\n",
|
| 49 |
"print('✓ All dependencies installed')"
|