srishtichugh committed on
Commit
757c114
·
1 Parent(s): 8f2f756

Add retry logic for NVIDIA API errors

Browse files
Files changed (1) hide show
  1. codes/llm_provider.py +16 -5
codes/llm_provider.py CHANGED
@@ -184,7 +184,8 @@ class GemmaProvider(LLMProvider):
184
  self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"
185
 
186
  def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
187
- """Create NVIDIA API chat completion"""
 
188
  headers = {
189
  "Authorization": f"Bearer {self.api_key}",
190
  "Accept": "application/json"
@@ -192,14 +193,24 @@ class GemmaProvider(LLMProvider):
192
  payload = {
193
  "model": model,
194
  "messages": messages,
195
- "max_tokens": kwargs.get('max_tokens', 8192), # increased for code generation
196
  "temperature": kwargs.get('temperature', 0.20),
197
  "top_p": kwargs.get('top_p', 0.70),
198
  "stream": False
199
  }
200
- response = self.requests.post(self.invoke_url, headers=headers, json=payload)
201
- response.raise_for_status()
202
- return response.json()
 
 
 
 
 
 
 
 
 
 
203
 
204
  def get_response_text(self, completion: Any) -> str:
205
  """Extract text from NVIDIA API response"""
 
184
  self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"
185
 
186
  def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
187
+ """Create NVIDIA API chat completion with retry logic"""
188
+ import time
189
  headers = {
190
  "Authorization": f"Bearer {self.api_key}",
191
  "Accept": "application/json"
 
193
  payload = {
194
  "model": model,
195
  "messages": messages,
196
+ "max_tokens": kwargs.get('max_tokens', 8192),
197
  "temperature": kwargs.get('temperature', 0.20),
198
  "top_p": kwargs.get('top_p', 0.70),
199
  "stream": False
200
  }
201
+ max_retries = 5
202
+ for attempt in range(max_retries):
203
+ try:
204
+ response = self.requests.post(self.invoke_url, headers=headers, json=payload)
205
+ response.raise_for_status()
206
+ return response.json()
207
+ except Exception as e:
208
+ if attempt < max_retries - 1:
209
+ wait = 10 * (attempt + 1) # 10s, 20s, 30s, 40s
210
+ print(f"[RETRY] Attempt {attempt+1} failed: {e}. Retrying in {wait}s...")
211
+ time.sleep(wait)
212
+ else:
213
+ raise
214
 
215
  def get_response_text(self, completion: Any) -> str:
216
  """Extract text from NVIDIA API response"""