Spaces:
Running
Running
Commit · 2c0acc5
1 Parent(s): af9bb6b
add llama
Browse files
- codes/llm_provider.py +11 -14
codes/llm_provider.py
CHANGED
|
@@ -79,7 +79,7 @@ class OpenAIProvider(LLMProvider):
|
|
| 79 |
|
| 80 |
|
| 81 |
class GeminiProvider(LLMProvider):
|
| 82 |
-
"""Google Gemini API
|
| 83 |
|
| 84 |
def __init__(self, api_key: Optional[str] = None):
|
| 85 |
try:
|
|
@@ -94,18 +94,13 @@ class GeminiProvider(LLMProvider):
|
|
| 94 |
|
| 95 |
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 96 |
"""Create Gemini chat completion"""
|
| 97 |
-
# Convert OpenAI message format to Gemini format
|
| 98 |
gemini_messages = self._convert_messages(messages)
|
| 99 |
-
|
| 100 |
# Do NOT add models/ prefix - pass model name directly
|
| 101 |
gemini_model = self.genai.GenerativeModel(model)
|
| 102 |
-
|
| 103 |
-
# Generate response
|
| 104 |
response = gemini_model.generate_content(
|
| 105 |
gemini_messages,
|
| 106 |
generation_config=self._get_generation_config(**kwargs)
|
| 107 |
)
|
| 108 |
-
|
| 109 |
return response
|
| 110 |
|
| 111 |
def _convert_messages(self, messages: List[Dict]) -> str:
|
|
@@ -175,7 +170,7 @@ class GeminiProvider(LLMProvider):
|
|
| 175 |
|
| 176 |
|
| 177 |
class GemmaProvider(LLMProvider):
|
| 178 |
-
"""NVIDIA
|
| 179 |
|
| 180 |
def __init__(self, api_key: Optional[str] = None):
|
| 181 |
import requests
|
|
@@ -189,7 +184,7 @@ class GemmaProvider(LLMProvider):
|
|
| 189 |
self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"
|
| 190 |
|
| 191 |
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 192 |
-
"""Create
|
| 193 |
headers = {
|
| 194 |
"Authorization": f"Bearer {self.api_key}",
|
| 195 |
"Accept": "application/json"
|
|
@@ -197,7 +192,7 @@ class GemmaProvider(LLMProvider):
|
|
| 197 |
payload = {
|
| 198 |
"model": model,
|
| 199 |
"messages": messages,
|
| 200 |
-
"max_tokens": kwargs.get('max_tokens',
|
| 201 |
"temperature": kwargs.get('temperature', 0.20),
|
| 202 |
"top_p": kwargs.get('top_p', 0.70),
|
| 203 |
"stream": False
|
|
@@ -207,13 +202,13 @@ class GemmaProvider(LLMProvider):
|
|
| 207 |
return response.json()
|
| 208 |
|
| 209 |
def get_response_text(self, completion: Any) -> str:
|
| 210 |
-
"""Extract text from
|
| 211 |
if isinstance(completion, dict):
|
| 212 |
return completion['choices'][0]['message']['content']
|
| 213 |
return str(completion)
|
| 214 |
|
| 215 |
def get_usage_info(self, completion: Any) -> Dict:
|
| 216 |
-
"""Extract usage from
|
| 217 |
try:
|
| 218 |
usage = completion.get('usage', {})
|
| 219 |
return {
|
|
@@ -231,9 +226,11 @@ class GemmaProvider(LLMProvider):
|
|
| 231 |
}
|
| 232 |
|
| 233 |
def calculate_cost(self, usage: Dict, model: str) -> float:
|
| 234 |
-
"""Calculate
|
| 235 |
model_costs = {
|
| 236 |
"google/gemma-3-27b-it": {"input": 0.0, "output": 0.0},
|
|
|
|
|
|
|
| 237 |
}
|
| 238 |
costs = model_costs.get(model, {"input": 0.0, "output": 0.0})
|
| 239 |
prompt_tokens = usage['prompt_tokens']
|
|
@@ -262,8 +259,8 @@ def get_default_model(provider_name: str) -> str:
|
|
| 262 |
"""Get default model for a provider"""
|
| 263 |
defaults = {
|
| 264 |
'openai': 'gpt-4o-mini',
|
| 265 |
-
'gemini': 'gemini-
|
| 266 |
-
'gemma': '
|
| 267 |
}
|
| 268 |
return defaults.get(provider_name, 'gpt-4o-mini')
|
| 269 |
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
class GeminiProvider(LLMProvider):
|
| 82 |
+
"""Google Gemini API implementation"""
|
| 83 |
|
| 84 |
def __init__(self, api_key: Optional[str] = None):
|
| 85 |
try:
|
|
|
|
| 94 |
|
| 95 |
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 96 |
"""Create Gemini chat completion"""
|
|
|
|
| 97 |
gemini_messages = self._convert_messages(messages)
|
|
|
|
| 98 |
# Do NOT add models/ prefix - pass model name directly
|
| 99 |
gemini_model = self.genai.GenerativeModel(model)
|
|
|
|
|
|
|
| 100 |
response = gemini_model.generate_content(
|
| 101 |
gemini_messages,
|
| 102 |
generation_config=self._get_generation_config(**kwargs)
|
| 103 |
)
|
|
|
|
| 104 |
return response
|
| 105 |
|
| 106 |
def _convert_messages(self, messages: List[Dict]) -> str:
|
|
|
|
| 170 |
|
| 171 |
|
| 172 |
class GemmaProvider(LLMProvider):
|
| 173 |
+
"""NVIDIA API implementation — supports Gemma, Llama, and other NVIDIA-hosted models"""
|
| 174 |
|
| 175 |
def __init__(self, api_key: Optional[str] = None):
|
| 176 |
import requests
|
|
|
|
| 184 |
self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"
|
| 185 |
|
| 186 |
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 187 |
+
"""Create NVIDIA API chat completion"""
|
| 188 |
headers = {
|
| 189 |
"Authorization": f"Bearer {self.api_key}",
|
| 190 |
"Accept": "application/json"
|
|
|
|
| 192 |
payload = {
|
| 193 |
"model": model,
|
| 194 |
"messages": messages,
|
| 195 |
+
"max_tokens": kwargs.get('max_tokens', 8192), # increased for code generation
|
| 196 |
"temperature": kwargs.get('temperature', 0.20),
|
| 197 |
"top_p": kwargs.get('top_p', 0.70),
|
| 198 |
"stream": False
|
|
|
|
| 202 |
return response.json()
|
| 203 |
|
| 204 |
def get_response_text(self, completion: Any) -> str:
|
| 205 |
+
"""Extract text from NVIDIA API response"""
|
| 206 |
if isinstance(completion, dict):
|
| 207 |
return completion['choices'][0]['message']['content']
|
| 208 |
return str(completion)
|
| 209 |
|
| 210 |
def get_usage_info(self, completion: Any) -> Dict:
|
| 211 |
+
"""Extract usage from NVIDIA API response"""
|
| 212 |
try:
|
| 213 |
usage = completion.get('usage', {})
|
| 214 |
return {
|
|
|
|
| 226 |
}
|
| 227 |
|
| 228 |
def calculate_cost(self, usage: Dict, model: str) -> float:
|
| 229 |
+
"""Calculate NVIDIA API cost"""
|
| 230 |
model_costs = {
|
| 231 |
"google/gemma-3-27b-it": {"input": 0.0, "output": 0.0},
|
| 232 |
+
"meta/llama-3.3-70b-instruct": {"input": 0.0, "output": 0.0},
|
| 233 |
+
"meta/llama-3.1-8b-instruct": {"input": 0.0, "output": 0.0},
|
| 234 |
}
|
| 235 |
costs = model_costs.get(model, {"input": 0.0, "output": 0.0})
|
| 236 |
prompt_tokens = usage['prompt_tokens']
|
|
|
|
| 259 |
"""Get default model for a provider"""
|
| 260 |
defaults = {
|
| 261 |
'openai': 'gpt-4o-mini',
|
| 262 |
+
'gemini': 'gemini-1.5-flash',
|
| 263 |
+
'gemma': 'meta/llama-3.3-70b-instruct', # Llama via NVIDIA API
|
| 264 |
}
|
| 265 |
return defaults.get(provider_name, 'gpt-4o-mini')
|
| 266 |
|