srishtichugh committed on
Commit 2c0acc5 · 1 Parent(s): af9bb6b

add llama

Files changed (1):
  codes/llm_provider.py (+11, -14)
codes/llm_provider.py CHANGED
@@ -79,7 +79,7 @@ class OpenAIProvider(LLMProvider):
 
 
 class GeminiProvider(LLMProvider):
-    """Google Gemini API implementat`ion"""
+    """Google Gemini API implementation"""
 
     def __init__(self, api_key: Optional[str] = None):
         try:
@@ -94,18 +94,13 @@ class GeminiProvider(LLMProvider):
 
     def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
         """Create Gemini chat completion"""
-        # Convert OpenAI message format to Gemini format
         gemini_messages = self._convert_messages(messages)
-
         # Do NOT add models/ prefix - pass model name directly
         gemini_model = self.genai.GenerativeModel(model)
-
-        # Generate response
         response = gemini_model.generate_content(
             gemini_messages,
             generation_config=self._get_generation_config(**kwargs)
         )
-
         return response
 
     def _convert_messages(self, messages: List[Dict]) -> str:
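Note on this hunk: it keeps the call to `self._convert_messages(messages)`, whose body is outside the diff. A minimal sketch of such a conversion, assuming the OpenAI-style `[{"role": ..., "content": ...}]` list is flattened into a single prompt string; the role-prefix formatting below is an illustration, not the repository's actual implementation:

```python
from typing import Dict, List


def convert_messages(messages: List[Dict]) -> str:
    """Flatten OpenAI-style chat messages into one prompt string for Gemini.

    Hypothetical illustration; the real _convert_messages in llm_provider.py
    may format roles differently.
    """
    parts = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if role == "system":
            parts.append(f"System instructions:\n{content}")
        else:
            parts.append(f"{role.capitalize()}: {content}")
    return "\n\n".join(parts)


if __name__ == "__main__":
    demo = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the NVIDIA API change."},
    ]
    print(convert_messages(demo))
```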
@@ -175,7 +170,7 @@ class GeminiProvider(LLMProvider):
 
 
 class GemmaProvider(LLMProvider):
-    """NVIDIA Gemma API implementation"""
+    """NVIDIA API implementation — supports Gemma, Llama, and other NVIDIA-hosted models"""
 
     def __init__(self, api_key: Optional[str] = None):
         import requests
@@ -189,7 +184,7 @@ class GemmaProvider(LLMProvider):
         self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"
 
     def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
-        """Create Gemma chat completion"""
+        """Create NVIDIA API chat completion"""
         headers = {
             "Authorization": f"Bearer {self.api_key}",
             "Accept": "application/json"
@@ -197,7 +192,7 @@ class GemmaProvider(LLMProvider):
         payload = {
             "model": model,
             "messages": messages,
-            "max_tokens": kwargs.get('max_tokens', 512),
+            "max_tokens": kwargs.get('max_tokens', 8192),  # increased for code generation
             "temperature": kwargs.get('temperature', 0.20),
             "top_p": kwargs.get('top_p', 0.70),
             "stream": False
@@ -207,13 +202,13 @@ class GemmaProvider(LLMProvider):
         return response.json()
 
     def get_response_text(self, completion: Any) -> str:
-        """Extract text from Gemma response"""
+        """Extract text from NVIDIA API response"""
         if isinstance(completion, dict):
             return completion['choices'][0]['message']['content']
         return str(completion)
 
     def get_usage_info(self, completion: Any) -> Dict:
-        """Extract usage from Gemma response"""
+        """Extract usage from NVIDIA API response"""
         try:
             usage = completion.get('usage', {})
             return {
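Both extractors assume an OpenAI-compatible response body. A self-contained illustration of the shape they expect; the text and token counts below are invented for the example:

```python
# Illustrative OpenAI-compatible response body; values are invented.
completion = {
    "choices": [
        {"message": {"role": "assistant", "content": "Silicon rivers hum..."}}
    ],
    "usage": {
        "prompt_tokens": 12,
        "completion_tokens": 18,
        "total_tokens": 30,
    },
}

# Mirrors the lookups in get_response_text / get_usage_info above.
text = completion["choices"][0]["message"]["content"]
usage = completion.get("usage", {})
print(text)
print(usage["prompt_tokens"], usage["completion_tokens"], usage["total_tokens"])
```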
@@ -231,9 +226,11 @@ class GemmaProvider(LLMProvider):
            }
 
     def calculate_cost(self, usage: Dict, model: str) -> float:
-        """Calculate Gemma cost"""
+        """Calculate NVIDIA API cost"""
         model_costs = {
             "google/gemma-3-27b-it": {"input": 0.0, "output": 0.0},
+            "meta/llama-3.3-70b-instruct": {"input": 0.0, "output": 0.0},
+            "meta/llama-3.1-8b-instruct": {"input": 0.0, "output": 0.0},
         }
         costs = model_costs.get(model, {"input": 0.0, "output": 0.0})
         prompt_tokens = usage['prompt_tokens']
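Every model registered in `model_costs` carries zero rates, so NVIDIA-hosted models stay free in the cost report. A sketch of the arithmetic `calculate_cost` presumably performs once non-zero rates are added, assuming rates quoted per million tokens (the unit is not visible in this hunk):

```python
from typing import Dict

# Per-1M-token rates are an assumption for illustration; every model this
# commit registers actually uses {"input": 0.0, "output": 0.0}.
MODEL_COSTS = {
    "google/gemma-3-27b-it": {"input": 0.0, "output": 0.0},
    "meta/llama-3.3-70b-instruct": {"input": 0.0, "output": 0.0},
    "meta/llama-3.1-8b-instruct": {"input": 0.0, "output": 0.0},
}


def calculate_cost(usage: Dict, model: str) -> float:
    costs = MODEL_COSTS.get(model, {"input": 0.0, "output": 0.0})
    prompt_tokens = usage["prompt_tokens"]
    completion_tokens = usage["completion_tokens"]
    return (prompt_tokens * costs["input"] +
            completion_tokens * costs["output"]) / 1_000_000


print(calculate_cost({"prompt_tokens": 1200, "completion_tokens": 800},
                     "meta/llama-3.3-70b-instruct"))  # -> 0.0
```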
@@ -262,8 +259,8 @@ def get_default_model(provider_name: str) -> str:
     """Get default model for a provider"""
     defaults = {
         'openai': 'gpt-4o-mini',
-        'gemini': 'gemini-2.0-flash',  # valid, free tier, no models/ prefix needed
-        'gemma': 'google/gemma-3-27b-it',
+        'gemini': 'gemini-1.5-flash',
+        'gemma': 'meta/llama-3.3-70b-instruct',  # Llama via NVIDIA API
     }
     return defaults.get(provider_name, 'gpt-4o-mini')
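Net effect of the commit: a caller asking for the 'gemma' provider now defaults to Llama 3.3 70B served through the NVIDIA endpoint. A hedged usage sketch built only from names visible in this diff; the import path and constructor arguments are assumptions about the rest of the repository:

```python
# End-to-end sketch using only names visible in the diff; the api_key value
# is a placeholder and the codes.llm_provider import path is an assumption.
from codes.llm_provider import GemmaProvider, get_default_model

provider = GemmaProvider(api_key="nvapi-...")  # NVIDIA API key
model = get_default_model("gemma")             # 'meta/llama-3.3-70b-instruct'

completion = provider.create_completion(
    messages=[{"role": "user", "content": "Explain top_p in one sentence."}],
    model=model,
    max_tokens=256,
    temperature=0.2,
)

print(provider.get_response_text(completion))
usage = provider.get_usage_info(completion)
print("cost:", provider.calculate_cost(usage, model))  # 0.0 for the models listed above
```

The higher `max_tokens` default (8192) mainly matters for long code-generation responses; callers can still pass a smaller value per request, as shown above.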