srishtichugh commited on
Commit
97a9e25
·
1 Parent(s): c7412fb

fix gemini model name and remove models/ prefix

Browse files
Files changed (1) hide show
  1. codes/llm_provider.py +11 -66
codes/llm_provider.py CHANGED
@@ -61,26 +61,20 @@ class OpenAIProvider(LLMProvider):
61
 
62
  def calculate_cost(self, usage: Dict, model: str) -> float:
63
  """Calculate OpenAI cost"""
64
- # Pricing per 1M tokens
65
  model_costs = {
66
  "gpt-4o-mini": {"input": 0.150, "cached": 0.075, "output": 0.600},
67
  "gpt-4o": {"input": 2.50, "cached": 1.25, "output": 10.00},
68
  "gpt-3.5-turbo": {"input": 0.50, "cached": 0.25, "output": 1.50},
69
  "o3-mini": {"input": 1.10, "cached": 0.55, "output": 4.40},
70
  }
71
-
72
  costs = model_costs.get(model, model_costs["gpt-4o-mini"])
73
-
74
  prompt_tokens = usage['prompt_tokens']
75
  cached_tokens = usage.get('cached_tokens', 0)
76
  completion_tokens = usage['completion_tokens']
77
-
78
  actual_input_tokens = prompt_tokens - cached_tokens
79
-
80
  input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
81
  cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
82
  output_cost = (completion_tokens / 1_000_000) * costs["output"]
83
-
84
  return input_cost + cached_cost + output_cost
85
 
86
 
@@ -103,11 +97,7 @@ class GeminiProvider(LLMProvider):
103
  # Convert OpenAI message format to Gemini format
104
  gemini_messages = self._convert_messages(messages)
105
 
106
- """# Fix model name - Gemini expects models/model-name format
107
- if not model.startswith('models/'):
108
- model = f'models/{model}'"""
109
-
110
- # Create model
111
  gemini_model = self.genai.GenerativeModel(model)
112
 
113
  # Generate response
@@ -120,34 +110,27 @@ class GeminiProvider(LLMProvider):
120
 
121
  def _convert_messages(self, messages: List[Dict]) -> str:
122
  """Convert OpenAI messages to Gemini prompt format"""
123
- # Gemini uses a simpler format - concatenate all messages
124
  prompt_parts = []
125
-
126
  for msg in messages:
127
  role = msg['role']
128
  content = msg['content']
129
-
130
  if role == 'system':
131
  prompt_parts.append(f"System Instructions:\n{content}\n")
132
  elif role == 'user':
133
  prompt_parts.append(f"User:\n{content}\n")
134
  elif role == 'assistant':
135
  prompt_parts.append(f"Assistant:\n{content}\n")
136
-
137
  return "\n".join(prompt_parts)
138
 
139
  def _get_generation_config(self, **kwargs):
140
  """Convert OpenAI kwargs to Gemini generation config"""
141
  config = {}
142
-
143
- # Map common parameters
144
  if 'temperature' in kwargs:
145
  config['temperature'] = kwargs['temperature']
146
  if 'max_tokens' in kwargs:
147
  config['max_output_tokens'] = kwargs['max_tokens']
148
  if 'top_p' in kwargs:
149
  config['top_p'] = kwargs['top_p']
150
-
151
  return config
152
 
153
  def get_response_text(self, completion: Any) -> str:
@@ -156,7 +139,6 @@ class GeminiProvider(LLMProvider):
156
 
157
  def get_usage_info(self, completion: Any) -> Dict:
158
  """Extract usage from Gemini response"""
159
- # Gemini provides token counts in metadata
160
  try:
161
  metadata = completion.usage_metadata
162
  return {
@@ -166,7 +148,6 @@ class GeminiProvider(LLMProvider):
166
  'cached_tokens': getattr(metadata, 'cached_content_token_count', 0)
167
  }
168
  except:
169
- # Fallback if metadata not available
170
  return {
171
  'prompt_tokens': 0,
172
  'completion_tokens': 0,
@@ -176,25 +157,20 @@ class GeminiProvider(LLMProvider):
176
 
177
  def calculate_cost(self, usage: Dict, model: str) -> float:
178
  """Calculate Gemini cost"""
179
- # Gemini pricing per 1M tokens (as of Jan 2026)
180
  model_costs = {
181
  "gemini-1.5-flash": {"input": 0.075, "cached": 0.01875, "output": 0.30},
182
  "gemini-1.5-pro": {"input": 1.25, "cached": 0.3125, "output": 5.00},
183
- "gemini-2.0-flash-exp": {"input": 0.0, "cached": 0.0, "output": 0.0}, # Free during preview
 
184
  }
185
-
186
- costs = model_costs.get(model, model_costs["gemini-1.5-flash"])
187
-
188
  prompt_tokens = usage['prompt_tokens']
189
  cached_tokens = usage.get('cached_tokens', 0)
190
  completion_tokens = usage['completion_tokens']
191
-
192
  actual_input_tokens = prompt_tokens - cached_tokens
193
-
194
  input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
195
  cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
196
  output_cost = (completion_tokens / 1_000_000) * costs["output"]
197
-
198
  return input_cost + cached_cost + output_cost
199
 
200
 
@@ -214,31 +190,24 @@ class GemmaProvider(LLMProvider):
214
 
215
  def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
216
  """Create Gemma chat completion"""
217
- # Prepare headers
218
  headers = {
219
  "Authorization": f"Bearer {self.api_key}",
220
- "Accept": "application/json" # Non-streaming for simplicity
221
  }
222
-
223
- # Prepare payload
224
  payload = {
225
  "model": model,
226
  "messages": messages,
227
  "max_tokens": kwargs.get('max_tokens', 512),
228
  "temperature": kwargs.get('temperature', 0.20),
229
  "top_p": kwargs.get('top_p', 0.70),
230
- "stream": False # Disable streaming for now
231
  }
232
-
233
- # Make request
234
  response = self.requests.post(self.invoke_url, headers=headers, json=payload)
235
  response.raise_for_status()
236
-
237
  return response.json()
238
 
239
  def get_response_text(self, completion: Any) -> str:
240
  """Extract text from Gemma response"""
241
- # NVIDIA API returns OpenAI-compatible format
242
  if isinstance(completion, dict):
243
  return completion['choices'][0]['message']['content']
244
  return str(completion)
@@ -251,7 +220,7 @@ class GemmaProvider(LLMProvider):
251
  'prompt_tokens': usage.get('prompt_tokens', 0),
252
  'completion_tokens': usage.get('completion_tokens', 0),
253
  'total_tokens': usage.get('total_tokens', 0),
254
- 'cached_tokens': 0 # NVIDIA API doesn't provide cached token info
255
  }
256
  except:
257
  return {
@@ -263,46 +232,29 @@ class GemmaProvider(LLMProvider):
263
 
264
  def calculate_cost(self, usage: Dict, model: str) -> float:
265
  """Calculate Gemma cost"""
266
- # NVIDIA API pricing (check current pricing at build.nvidia.com)
267
- # For now, using placeholder values - update with actual pricing
268
  model_costs = {
269
- "google/gemma-3-27b-it": {"input": 0.0, "output": 0.0}, # Free tier or update with actual costs
270
  }
271
-
272
  costs = model_costs.get(model, {"input": 0.0, "output": 0.0})
273
-
274
  prompt_tokens = usage['prompt_tokens']
275
  completion_tokens = usage['completion_tokens']
276
-
277
  input_cost = (prompt_tokens / 1_000_000) * costs["input"]
278
  output_cost = (completion_tokens / 1_000_000) * costs["output"]
279
-
280
  return input_cost + output_cost
281
 
282
 
283
  def get_provider(provider_name: str, api_key: Optional[str] = None) -> LLMProvider:
284
- """
285
- Factory function to get LLM provider.
286
-
287
- Args:
288
- provider_name: Name of provider ('openai' or 'gemini')
289
- api_key: Optional API key (uses env var if not provided)
290
-
291
- Returns:
292
- LLMProvider instance
293
- """
294
  providers = {
295
  'openai': OpenAIProvider,
296
  'gemini': GeminiProvider,
297
  'gemma': GemmaProvider,
298
  }
299
-
300
  if provider_name not in providers:
301
  raise ValueError(
302
  f"Unknown provider: {provider_name}. "
303
  f"Available providers: {list(providers.keys())}"
304
  )
305
-
306
  return providers[provider_name](api_key=api_key)
307
 
308
 
@@ -310,33 +262,26 @@ def get_default_model(provider_name: str) -> str:
310
  """Get default model for a provider"""
311
  defaults = {
312
  'openai': 'gpt-4o-mini',
313
- 'gemini': 'gemini-1.5-flash-latest',
314
  'gemma': 'google/gemma-3-27b-it',
315
  }
316
  return defaults.get(provider_name, 'gpt-4o-mini')
317
 
318
 
319
  if __name__ == "__main__":
320
- # Test script
321
  print("Testing LLM Provider abstraction...")
322
-
323
- # Test OpenAI
324
  try:
325
  provider = get_provider('openai')
326
  print("✅ OpenAI provider initialized")
327
  except Exception as e:
328
  print(f"❌ OpenAI provider failed: {e}")
329
-
330
- # Test Gemini
331
  try:
332
  provider = get_provider('gemini')
333
  print("βœ… Gemini provider initialized")
334
  except Exception as e:
335
  print(f"❌ Gemini provider failed: {e}")
336
-
337
- # Test Gemma
338
  try:
339
  provider = get_provider('gemma')
340
  print("βœ… Gemma provider initialized")
341
  except Exception as e:
342
- print(f"❌ Gemma provider failed: {e}")
 
61
 
62
  def calculate_cost(self, usage: Dict, model: str) -> float:
63
  """Calculate OpenAI cost"""
 
64
  model_costs = {
65
  "gpt-4o-mini": {"input": 0.150, "cached": 0.075, "output": 0.600},
66
  "gpt-4o": {"input": 2.50, "cached": 1.25, "output": 10.00},
67
  "gpt-3.5-turbo": {"input": 0.50, "cached": 0.25, "output": 1.50},
68
  "o3-mini": {"input": 1.10, "cached": 0.55, "output": 4.40},
69
  }
 
70
  costs = model_costs.get(model, model_costs["gpt-4o-mini"])
 
71
  prompt_tokens = usage['prompt_tokens']
72
  cached_tokens = usage.get('cached_tokens', 0)
73
  completion_tokens = usage['completion_tokens']
 
74
  actual_input_tokens = prompt_tokens - cached_tokens
 
75
  input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
76
  cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
77
  output_cost = (completion_tokens / 1_000_000) * costs["output"]
 
78
  return input_cost + cached_cost + output_cost
79
 
80
 
 
97
  # Convert OpenAI message format to Gemini format
98
  gemini_messages = self._convert_messages(messages)
99
 
100
+ # Do NOT add models/ prefix - pass model name directly
 
 
 
 
101
  gemini_model = self.genai.GenerativeModel(model)
102
 
103
  # Generate response
 
110
 
111
  def _convert_messages(self, messages: List[Dict]) -> str:
112
  """Convert OpenAI messages to Gemini prompt format"""
 
113
  prompt_parts = []
 
114
  for msg in messages:
115
  role = msg['role']
116
  content = msg['content']
 
117
  if role == 'system':
118
  prompt_parts.append(f"System Instructions:\n{content}\n")
119
  elif role == 'user':
120
  prompt_parts.append(f"User:\n{content}\n")
121
  elif role == 'assistant':
122
  prompt_parts.append(f"Assistant:\n{content}\n")
 
123
  return "\n".join(prompt_parts)
124
 
125
  def _get_generation_config(self, **kwargs):
126
  """Convert OpenAI kwargs to Gemini generation config"""
127
  config = {}
 
 
128
  if 'temperature' in kwargs:
129
  config['temperature'] = kwargs['temperature']
130
  if 'max_tokens' in kwargs:
131
  config['max_output_tokens'] = kwargs['max_tokens']
132
  if 'top_p' in kwargs:
133
  config['top_p'] = kwargs['top_p']
 
134
  return config
135
 
136
  def get_response_text(self, completion: Any) -> str:
 
139
 
140
  def get_usage_info(self, completion: Any) -> Dict:
141
  """Extract usage from Gemini response"""
 
142
  try:
143
  metadata = completion.usage_metadata
144
  return {
 
148
  'cached_tokens': getattr(metadata, 'cached_content_token_count', 0)
149
  }
150
  except:
 
151
  return {
152
  'prompt_tokens': 0,
153
  'completion_tokens': 0,
 
157
 
158
  def calculate_cost(self, usage: Dict, model: str) -> float:
159
  """Calculate Gemini cost"""
 
160
  model_costs = {
161
  "gemini-1.5-flash": {"input": 0.075, "cached": 0.01875, "output": 0.30},
162
  "gemini-1.5-pro": {"input": 1.25, "cached": 0.3125, "output": 5.00},
163
+ "gemini-2.0-flash": {"input": 0.0, "cached": 0.0, "output": 0.0},
164
+ "gemini-2.0-flash-lite": {"input": 0.0, "cached": 0.0, "output": 0.0},
165
  }
166
+ costs = model_costs.get(model, {"input": 0.0, "cached": 0.0, "output": 0.0})
 
 
167
  prompt_tokens = usage['prompt_tokens']
168
  cached_tokens = usage.get('cached_tokens', 0)
169
  completion_tokens = usage['completion_tokens']
 
170
  actual_input_tokens = prompt_tokens - cached_tokens
 
171
  input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
172
  cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
173
  output_cost = (completion_tokens / 1_000_000) * costs["output"]
 
174
  return input_cost + cached_cost + output_cost
175
 
176
 
 
190
 
191
  def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
192
  """Create Gemma chat completion"""
 
193
  headers = {
194
  "Authorization": f"Bearer {self.api_key}",
195
+ "Accept": "application/json"
196
  }
 
 
197
  payload = {
198
  "model": model,
199
  "messages": messages,
200
  "max_tokens": kwargs.get('max_tokens', 512),
201
  "temperature": kwargs.get('temperature', 0.20),
202
  "top_p": kwargs.get('top_p', 0.70),
203
+ "stream": False
204
  }
 
 
205
  response = self.requests.post(self.invoke_url, headers=headers, json=payload)
206
  response.raise_for_status()
 
207
  return response.json()
208
 
209
  def get_response_text(self, completion: Any) -> str:
210
  """Extract text from Gemma response"""
 
211
  if isinstance(completion, dict):
212
  return completion['choices'][0]['message']['content']
213
  return str(completion)
 
220
  'prompt_tokens': usage.get('prompt_tokens', 0),
221
  'completion_tokens': usage.get('completion_tokens', 0),
222
  'total_tokens': usage.get('total_tokens', 0),
223
+ 'cached_tokens': 0
224
  }
225
  except:
226
  return {
 
232
 
233
  def calculate_cost(self, usage: Dict, model: str) -> float:
234
  """Calculate Gemma cost"""
 
 
235
  model_costs = {
236
+ "google/gemma-3-27b-it": {"input": 0.0, "output": 0.0},
237
  }
 
238
  costs = model_costs.get(model, {"input": 0.0, "output": 0.0})
 
239
  prompt_tokens = usage['prompt_tokens']
240
  completion_tokens = usage['completion_tokens']
 
241
  input_cost = (prompt_tokens / 1_000_000) * costs["input"]
242
  output_cost = (completion_tokens / 1_000_000) * costs["output"]
 
243
  return input_cost + output_cost
244
 
245
 
246
  def get_provider(provider_name: str, api_key: Optional[str] = None) -> LLMProvider:
247
+ """Factory function to get LLM provider."""
 
 
 
 
 
 
 
 
 
248
  providers = {
249
  'openai': OpenAIProvider,
250
  'gemini': GeminiProvider,
251
  'gemma': GemmaProvider,
252
  }
 
253
  if provider_name not in providers:
254
  raise ValueError(
255
  f"Unknown provider: {provider_name}. "
256
  f"Available providers: {list(providers.keys())}"
257
  )
 
258
  return providers[provider_name](api_key=api_key)
259
 
260
 
 
262
  """Get default model for a provider"""
263
  defaults = {
264
  'openai': 'gpt-4o-mini',
265
+ 'gemini': 'gemini-1.5-flash', # valid, free tier, no models/ prefix needed
266
  'gemma': 'google/gemma-3-27b-it',
267
  }
268
  return defaults.get(provider_name, 'gpt-4o-mini')
269
 
270
 
271
  if __name__ == "__main__":
 
272
  print("Testing LLM Provider abstraction...")
 
 
273
  try:
274
  provider = get_provider('openai')
275
  print("✅ OpenAI provider initialized")
276
  except Exception as e:
277
  print(f"❌ OpenAI provider failed: {e}")
 
 
278
  try:
279
  provider = get_provider('gemini')
280
  print("✅ Gemini provider initialized")
281
  except Exception as e:
282
  print(f"❌ Gemini provider failed: {e}")
 
 
283
  try:
284
  provider = get_provider('gemma')
285
  print("✅ Gemma provider initialized")
286
  except Exception as e:
287
+ print(f"❌ Gemma provider failed: {e}")