Spaces:
Sleeping
Sleeping
Commit ·
757c114
1
Parent(s): 8f2f756
add retry logic for nvidia api errors
Browse files- codes/llm_provider.py +16 -5
codes/llm_provider.py
CHANGED
|
@@ -184,7 +184,8 @@ class GemmaProvider(LLMProvider):
|
|
| 184 |
self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"
|
| 185 |
|
| 186 |
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 187 |
-
"""Create NVIDIA API chat completion"""
|
|
|
|
| 188 |
headers = {
|
| 189 |
"Authorization": f"Bearer {self.api_key}",
|
| 190 |
"Accept": "application/json"
|
|
@@ -192,14 +193,24 @@ class GemmaProvider(LLMProvider):
|
|
| 192 |
payload = {
|
| 193 |
"model": model,
|
| 194 |
"messages": messages,
|
| 195 |
-
"max_tokens": kwargs.get('max_tokens', 8192),
|
| 196 |
"temperature": kwargs.get('temperature', 0.20),
|
| 197 |
"top_p": kwargs.get('top_p', 0.70),
|
| 198 |
"stream": False
|
| 199 |
}
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
def get_response_text(self, completion: Any) -> str:
|
| 205 |
"""Extract text from NVIDIA API response"""
|
|
|
|
| 184 |
self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"
|
| 185 |
|
| 186 |
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 187 |
+
"""Create NVIDIA API chat completion with retry logic"""
|
| 188 |
+
import time
|
| 189 |
headers = {
|
| 190 |
"Authorization": f"Bearer {self.api_key}",
|
| 191 |
"Accept": "application/json"
|
|
|
|
| 193 |
payload = {
|
| 194 |
"model": model,
|
| 195 |
"messages": messages,
|
| 196 |
+
"max_tokens": kwargs.get('max_tokens', 8192),
|
| 197 |
"temperature": kwargs.get('temperature', 0.20),
|
| 198 |
"top_p": kwargs.get('top_p', 0.70),
|
| 199 |
"stream": False
|
| 200 |
}
|
| 201 |
+
max_retries = 5
|
| 202 |
+
for attempt in range(max_retries):
|
| 203 |
+
try:
|
| 204 |
+
response = self.requests.post(self.invoke_url, headers=headers, json=payload)
|
| 205 |
+
response.raise_for_status()
|
| 206 |
+
return response.json()
|
| 207 |
+
except Exception as e:
|
| 208 |
+
if attempt < max_retries - 1:
|
| 209 |
+
wait = 10 * (attempt + 1) # 10s, 20s, 30s, 40s
|
| 210 |
+
print(f"[RETRY] Attempt {attempt+1} failed: {e}. Retrying in {wait}s...")
|
| 211 |
+
time.sleep(wait)
|
| 212 |
+
else:
|
| 213 |
+
raise
|
| 214 |
|
| 215 |
def get_response_text(self, completion: Any) -> str:
|
| 216 |
"""Extract text from NVIDIA API response"""
|