sh4shv4t commited on
Commit
dd46a0d
·
1 Parent(s): f5f4abf

refactor(agent): migrate Gemini client from google-generativeai to google-genai

Browse files
.cursorrules CHANGED
@@ -12,8 +12,8 @@ NO Anthropic API anywhere. NO npm. NO build step.
12
  ## LLM Client Rules (Gemini)
13
 
14
  - Model: `gemini-2.0-flash` everywhere. Never use other model names in game/agent code.
15
- - Import: `import google.generativeai as genai` — never `anthropic`.
16
- - Configure once at module level: `genai.configure(api_key=os.environ["GOOGLE_API_KEY"])`.
17
  - ALL Gemini calls wrapped in try/except returning `SYNTHETIC_RESPONSE` on failure.
18
  - JSON extraction: always strip markdown fences before `json.loads()`.
19
  - Async Gemini calls: use `asyncio.get_event_loop().run_in_executor(None, lambda: ...)`.
 
12
  ## LLM Client Rules (Gemini)
13
 
14
  - Model: `gemini-2.0-flash` everywhere. Never use other model names in game/agent code.
15
+ - Import: `from google import genai` and `from google.genai import types` — never `anthropic` or legacy `google.generativeai`.
16
+ - Client: `genai.Client(api_key=os.environ.get("GOOGLE_API_KEY", ""))`; chats via `client.chats.create(model="gemini-2.0-flash", ...)`.
17
  - ALL Gemini calls wrapped in try/except returning `SYNTHETIC_RESPONSE` on failure.
18
  - JSON extraction: always strip markdown fences before `json.loads()`.
19
  - Async Gemini calls: use `asyncio.get_event_loop().run_in_executor(None, lambda: ...)`.
agent/gemini_client.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Google Gemini 2.0 Flash client for Parlay.
3
- All calls are async (via run_in_executor). All errors return SYNTHETIC_RESPONSE.
4
  """
5
  import asyncio
6
  import json
@@ -8,12 +8,41 @@ import logging
8
  import os
9
  from typing import Optional
10
 
11
- import google.generativeai as genai
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
- genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", ""))
16
- _model = genai.GenerativeModel("gemini-2.0-flash")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  SYNTHETIC_RESPONSE: dict = {
19
  "utterance": "I need a moment to consider your proposal.",
@@ -49,22 +78,25 @@ async def call_gemini(
49
  f'{{"utterance": "...", "offer_amount": <number or null>, '
50
  f'"tactical_move": <string or null>}}'
51
  )
52
-
53
- loop = asyncio.get_event_loop()
54
- chat = _model.start_chat(history=history)
55
-
56
- response = await loop.run_in_executor(
57
- None,
58
- lambda: chat.send_message(
59
- f"{full_prompt}\n\nUser: {last_msg}",
60
- generation_config=genai.types.GenerationConfig(
 
61
  max_output_tokens=max_tokens,
62
  temperature=0.7,
63
  ),
64
- ),
65
- )
 
 
66
 
67
- text = response.text.strip()
68
  text = text.replace("```json", "").replace("```", "").strip()
69
  parsed = json.loads(text)
70
 
@@ -79,10 +111,10 @@ async def call_gemini(
79
 
80
  except json.JSONDecodeError:
81
  logger.warning("Gemini JSON parse failed — using text fallback")
82
- try:
83
- return {**SYNTHETIC_RESPONSE, "utterance": response.text[:300]}
84
- except Exception:
85
- return SYNTHETIC_RESPONSE
86
  except Exception as exc:
87
  logger.error(f"Gemini API error: {exc}")
88
  return SYNTHETIC_RESPONSE
@@ -117,19 +149,22 @@ async def call_gemini_tom(
117
  )
118
 
119
  try:
120
- loop = asyncio.get_event_loop()
121
- chat = _model.start_chat(history=conversation_history)
122
- response = await loop.run_in_executor(
123
- None,
124
- lambda: chat.send_message(
 
125
  tom_prompt,
126
- generation_config=genai.types.GenerationConfig(
127
  max_output_tokens=200,
128
  temperature=0.3,
129
  ),
130
- ),
131
- )
132
- text = response.text.strip().replace("```json", "").replace("```", "").strip()
 
 
133
  return json.loads(text)
134
  except Exception as exc:
135
  logger.error(f"Gemini ToM inference error: {exc}")
 
1
  """
2
  Google Gemini 2.0 Flash client for Parlay.
3
+ Uses the google-genai SDK. All calls are async (via run_in_executor). All errors return SYNTHETIC_RESPONSE.
4
  """
5
  import asyncio
6
  import json
 
8
  import os
9
  from typing import Optional
10
 
11
+ from google import genai
12
+ from google.genai import types
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
+ MODEL_ID = "gemini-2.0-flash"
17
+
18
+ _client: Optional[genai.Client] = None
19
+
20
+
21
+ def _get_client() -> genai.Client:
22
+ """Lazily construct API client (empty key is allowed; calls then fail and return synthetic output)."""
23
+ global _client
24
+ if _client is None:
25
+ _client = genai.Client(api_key=os.environ.get("GOOGLE_API_KEY") or "")
26
+ return _client
27
+
28
+
29
+ def _legacy_messages_to_history(messages: list[dict]) -> list[types.Content]:
30
+ """Convert legacy {'role','parts'} messages to google-genai Content list."""
31
+ contents: list[types.Content] = []
32
+ for m in messages:
33
+ role = m.get("role", "user")
34
+ if role not in ("user", "model"):
35
+ role = "user"
36
+ raw_parts = m.get("parts") or []
37
+ parts: list[types.Part] = []
38
+ for p in raw_parts:
39
+ text = p if isinstance(p, str) else str(p)
40
+ parts.append(types.Part(text=text))
41
+ if not parts:
42
+ parts.append(types.Part(text=""))
43
+ contents.append(types.Content(role=role, parts=parts))
44
+ return contents
45
+
46
 
47
  SYNTHETIC_RESPONSE: dict = {
48
  "utterance": "I need a moment to consider your proposal.",
 
78
  f'{{"utterance": "...", "offer_amount": <number or null>, '
79
  f'"tactical_move": <string or null>}}'
80
  )
81
+ user_message = f"{full_prompt}\n\nUser: {last_msg}"
82
+
83
+ def _call() -> types.GenerateContentResponse:
84
+ chat = _get_client().chats.create(
85
+ model=MODEL_ID,
86
+ history=_legacy_messages_to_history(history),
87
+ )
88
+ return chat.send_message(
89
+ user_message,
90
+ config=types.GenerateContentConfig(
91
  max_output_tokens=max_tokens,
92
  temperature=0.7,
93
  ),
94
+ )
95
+
96
+ loop = asyncio.get_event_loop()
97
+ response = await loop.run_in_executor(None, _call)
98
 
99
+ text = (response.text or "").strip()
100
  text = text.replace("```json", "").replace("```", "").strip()
101
  parsed = json.loads(text)
102
 
 
111
 
112
  except json.JSONDecodeError:
113
  logger.warning("Gemini JSON parse failed — using text fallback")
114
+ raw = text[:300] if text else ""
115
+ if raw:
116
+ return {**SYNTHETIC_RESPONSE, "utterance": raw}
117
+ return SYNTHETIC_RESPONSE
118
  except Exception as exc:
119
  logger.error(f"Gemini API error: {exc}")
120
  return SYNTHETIC_RESPONSE
 
149
  )
150
 
151
  try:
152
+ def _call() -> types.GenerateContentResponse:
153
+ chat = _get_client().chats.create(
154
+ model=MODEL_ID,
155
+ history=_legacy_messages_to_history(conversation_history),
156
+ )
157
+ return chat.send_message(
158
  tom_prompt,
159
+ config=types.GenerateContentConfig(
160
  max_output_tokens=200,
161
  temperature=0.3,
162
  ),
163
+ )
164
+
165
+ loop = asyncio.get_event_loop()
166
+ response = await loop.run_in_executor(None, _call)
167
+ text = (response.text or "").strip().replace("```json", "").replace("```", "").strip()
168
  return json.loads(text)
169
  except Exception as exc:
170
  logger.error(f"Gemini ToM inference error: {exc}")
requirements.txt CHANGED
@@ -3,7 +3,7 @@ uvicorn[standard]==0.45.0
3
  websockets==15.0.1
4
  pydantic>=2.11.7,<3.0.0
5
  aiosqlite==0.20.0
6
- google-generativeai>=0.8.0
7
  fastmcp==3.2.4
8
  numpy==1.26.4
9
  scikit-learn==1.4.2
 
3
  websockets==15.0.1
4
  pydantic>=2.11.7,<3.0.0
5
  aiosqlite==0.20.0
6
+ google-genai>=1.0.0
7
  fastmcp==3.2.4
8
  numpy==1.26.4
9
  scikit-learn==1.4.2
training/notebooks/parlay_training.ipynb CHANGED
@@ -43,7 +43,7 @@
43
  "outputs": [],
44
  "source": [
45
  "# Cell 1: Install all dependencies\n",
46
- "!pip install -q fastapi uvicorn websockets pydantic aiosqlite google-generativeai fastmcp numpy python-dotenv httpx\n",
47
  "!pip install -q trl peft transformers accelerate bitsandbytes datasets huggingface-hub\n",
48
  "!pip install -q matplotlib\n",
49
  "print('✓ All dependencies installed')"
 
43
  "outputs": [],
44
  "source": [
45
  "# Cell 1: Install all dependencies\n",
46
+ "!pip install -q fastapi uvicorn websockets pydantic aiosqlite google-genai fastmcp numpy python-dotenv httpx\n",
47
  "!pip install -q trl peft transformers accelerate bitsandbytes datasets huggingface-hub\n",
48
  "!pip install -q matplotlib\n",
49
  "print('✓ All dependencies installed')"