# blog2code-api/codes/llm_provider.py
# Last change by srishtichugh: "add retry logic for nvidia api errors" (commit 757c114)
"""
LLM Provider abstraction layer for Blog2Code.
Supports multiple LLM providers: OpenAI, Google Gemini, NVIDIA Gemma
"""
import os
from typing import Dict, List, Any, Optional
from abc import ABC, abstractmethod
class LLMProvider(ABC):
    """Abstract interface that every concrete LLM backend implements.

    A provider knows how to send chat messages to its API, pull the
    generated text and token counts back out of the raw response object,
    and price a request from those token counts.
    """

    @abstractmethod
    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Send *messages* to *model* and return the raw completion object."""
        ...

    @abstractmethod
    def get_response_text(self, completion: Any) -> str:
        """Return the generated text contained in *completion*."""
        ...

    @abstractmethod
    def get_usage_info(self, completion: Any) -> Dict:
        """Return a dict of token counts extracted from *completion*."""
        ...

    @abstractmethod
    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Return the monetary cost implied by *usage* for *model*."""
        ...
class OpenAIProvider(LLMProvider):
    """LLM provider backed by the official OpenAI chat-completions API."""

    def __init__(self, api_key: Optional[str] = None):
        """Build an OpenAI client from *api_key* or the OPENAI_API_KEY env var."""
        # Imported lazily so the module loads even when openai is absent.
        from openai import OpenAI
        key = api_key if api_key else os.environ.get("OPENAI_API_KEY")
        self.client = OpenAI(api_key=key)

    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Forward the request straight to the OpenAI chat-completions endpoint."""
        return self.client.chat.completions.create(
            model=model, messages=messages, **kwargs
        )

    def get_response_text(self, completion: Any) -> str:
        """Pull the assistant message text out of the first choice."""
        first_choice = completion.choices[0]
        return first_choice.message.content

    def get_usage_info(self, completion: Any) -> Dict:
        """Collect prompt/completion/total/cached token counts from the response."""
        usage = completion.usage
        cached = 0
        # prompt_tokens_details only exists on newer API responses.
        if hasattr(usage, 'prompt_tokens_details'):
            cached = getattr(usage.prompt_tokens_details, 'cached_tokens', 0)
        return {
            'prompt_tokens': usage.prompt_tokens,
            'completion_tokens': usage.completion_tokens,
            'total_tokens': usage.total_tokens,
            'cached_tokens': cached,
        }

    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Price the request in USD; rates below are per million tokens."""
        model_costs = {
            "gpt-4o-mini": {"input": 0.150, "cached": 0.075, "output": 0.600},
            "gpt-4o": {"input": 2.50, "cached": 1.25, "output": 10.00},
            "gpt-3.5-turbo": {"input": 0.50, "cached": 0.25, "output": 1.50},
            "o3-mini": {"input": 1.10, "cached": 0.55, "output": 4.40},
        }
        # Unknown models fall back to gpt-4o-mini pricing.
        rates = model_costs.get(model, model_costs["gpt-4o-mini"])
        cached = usage.get('cached_tokens', 0)
        fresh = usage['prompt_tokens'] - cached
        per_million = 1_000_000
        return (
            (fresh / per_million) * rates["input"]
            + (cached / per_million) * rates["cached"]
            + (usage['completion_tokens'] / per_million) * rates["output"]
        )
class GeminiProvider(LLMProvider):
    """Google Gemini API implementation.

    Adapts OpenAI-style chat messages and kwargs to the
    google-generativeai SDK, and normalizes its responses back into the
    provider-neutral text/usage/cost interface.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Configure the google-generativeai SDK.

        Args:
            api_key: Explicit API key; falls back to the GEMINI_API_KEY
                environment variable when omitted.

        Raises:
            ImportError: If google-generativeai is not installed.
        """
        try:
            import google.generativeai as genai
            self.genai = genai
            genai.configure(api_key=api_key or os.environ.get("GEMINI_API_KEY"))
        except ImportError:
            raise ImportError(
                "google-generativeai not installed. "
                "Install with: pip install google-generativeai"
            )

    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Create a Gemini chat completion from OpenAI-style messages."""
        gemini_messages = self._convert_messages(messages)
        # Do NOT add models/ prefix - pass model name directly
        gemini_model = self.genai.GenerativeModel(model)
        response = gemini_model.generate_content(
            gemini_messages,
            generation_config=self._get_generation_config(**kwargs)
        )
        return response

    def _convert_messages(self, messages: List[Dict]) -> str:
        """Flatten OpenAI messages into a single role-labelled Gemini prompt."""
        prompt_parts = []
        for msg in messages:
            role = msg['role']
            content = msg['content']
            if role == 'system':
                prompt_parts.append(f"System Instructions:\n{content}\n")
            elif role == 'user':
                prompt_parts.append(f"User:\n{content}\n")
            elif role == 'assistant':
                prompt_parts.append(f"Assistant:\n{content}\n")
        return "\n".join(prompt_parts)

    def _get_generation_config(self, **kwargs):
        """Map OpenAI kwargs (temperature/max_tokens/top_p) to Gemini config keys."""
        config = {}
        if 'temperature' in kwargs:
            config['temperature'] = kwargs['temperature']
        if 'max_tokens' in kwargs:
            config['max_output_tokens'] = kwargs['max_tokens']
        if 'top_p' in kwargs:
            config['top_p'] = kwargs['top_p']
        return config

    def get_response_text(self, completion: Any) -> str:
        """Extract text from Gemini response."""
        return completion.text

    def get_usage_info(self, completion: Any) -> Dict:
        """Extract usage from Gemini response.

        Returns all-zero counts when the response carries no usage
        metadata (best-effort, mirroring the other providers).
        """
        try:
            metadata = completion.usage_metadata
            return {
                'prompt_tokens': metadata.prompt_token_count,
                'completion_tokens': metadata.candidates_token_count,
                'total_tokens': metadata.total_token_count,
                'cached_tokens': getattr(metadata, 'cached_content_token_count', 0)
            }
        # FIX: was a bare `except:` that also swallowed KeyboardInterrupt/SystemExit.
        except Exception:
            return {
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'total_tokens': 0,
                'cached_tokens': 0
            }

    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Calculate Gemini cost in USD; rates are per million tokens."""
        model_costs = {
            "gemini-1.5-flash": {"input": 0.075, "cached": 0.01875, "output": 0.30},
            "gemini-1.5-pro": {"input": 1.25, "cached": 0.3125, "output": 5.00},
            "gemini-2.0-flash": {"input": 0.0, "cached": 0.0, "output": 0.0},
            "gemini-2.0-flash-lite": {"input": 0.0, "cached": 0.0, "output": 0.0},
        }
        # Unknown models are treated as free rather than guessing a rate.
        costs = model_costs.get(model, {"input": 0.0, "cached": 0.0, "output": 0.0})
        prompt_tokens = usage['prompt_tokens']
        cached_tokens = usage.get('cached_tokens', 0)
        completion_tokens = usage['completion_tokens']
        actual_input_tokens = prompt_tokens - cached_tokens
        input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
        cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
        output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + cached_cost + output_cost
class GemmaProvider(LLMProvider):
    """NVIDIA API implementation β€” supports Gemma, Llama, and other NVIDIA-hosted models"""

    def __init__(self, api_key: Optional[str] = None):
        """Store credentials and the NVIDIA chat-completions endpoint.

        Args:
            api_key: Explicit API key; falls back to the NVIDIA_API_KEY
                environment variable when omitted.

        Raises:
            ValueError: If no API key can be found.
        """
        import requests
        self.requests = requests
        self.api_key = api_key or os.environ.get("NVIDIA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "NVIDIA_API_KEY not found. "
                "Set it as an environment variable or pass it to the constructor."
            )
        self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"

    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Create NVIDIA API chat completion with retry logic.

        Retries up to 5 times with linear backoff; the final failure is
        re-raised to the caller.

        Returns:
            The parsed JSON response as a dict.
        """
        import time
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json"
        }
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": kwargs.get('max_tokens', 8192),
            "temperature": kwargs.get('temperature', 0.20),
            "top_p": kwargs.get('top_p', 0.70),
            "stream": False
        }
        max_retries = 5
        for attempt in range(max_retries):
            try:
                response = self.requests.post(self.invoke_url, headers=headers, json=payload)
                response.raise_for_status()
                return response.json()
            except Exception as e:
                if attempt < max_retries - 1:
                    wait = 10 * (attempt + 1)  # 10s, 20s, 30s, 40s
                    print(f"[RETRY] Attempt {attempt+1} failed: {e}. Retrying in {wait}s...")
                    time.sleep(wait)
                else:
                    # Out of retries: surface the last error to the caller.
                    raise

    def get_response_text(self, completion: Any) -> str:
        """Extract the first choice's message text from an NVIDIA API response dict."""
        if isinstance(completion, dict):
            return completion['choices'][0]['message']['content']
        return str(completion)

    def get_usage_info(self, completion: Any) -> Dict:
        """Extract token usage from an NVIDIA API response.

        Falls back to all-zero counts when the response has no usable
        usage data (best-effort, mirroring the other providers).
        """
        try:
            usage = completion.get('usage', {})
            return {
                'prompt_tokens': usage.get('prompt_tokens', 0),
                'completion_tokens': usage.get('completion_tokens', 0),
                'total_tokens': usage.get('total_tokens', 0),
                # NVIDIA responses do not report cached tokens.
                'cached_tokens': 0
            }
        # FIX: was a bare `except:` that also swallowed KeyboardInterrupt/SystemExit.
        except Exception:
            return {
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'total_tokens': 0,
                'cached_tokens': 0
            }

    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Calculate NVIDIA API cost in USD; the listed models are free-tier (0.0)."""
        model_costs = {
            "google/gemma-3-27b-it": {"input": 0.0, "output": 0.0},
            "meta/llama-3.3-70b-instruct": {"input": 0.0, "output": 0.0},
            "meta/llama-3.1-8b-instruct": {"input": 0.0, "output": 0.0},
        }
        # Unknown models are treated as free rather than guessing a rate.
        costs = model_costs.get(model, {"input": 0.0, "output": 0.0})
        prompt_tokens = usage['prompt_tokens']
        completion_tokens = usage['completion_tokens']
        input_cost = (prompt_tokens / 1_000_000) * costs["input"]
        output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + output_cost
def get_provider(provider_name: str, api_key: Optional[str] = None) -> LLMProvider:
    """Instantiate the provider registered under *provider_name*.

    Raises:
        ValueError: If *provider_name* is not a known provider.
    """
    registry = {
        'openai': OpenAIProvider,
        'gemini': GeminiProvider,
        'gemma': GemmaProvider,
    }
    if provider_name not in registry:
        raise ValueError(
            f"Unknown provider: {provider_name}. "
            f"Available providers: {list(registry.keys())}"
        )
    provider_cls = registry[provider_name]
    return provider_cls(api_key=api_key)
def get_default_model(provider_name: str) -> str:
    """Return the default model for *provider_name* (gpt-4o-mini if unknown)."""
    if provider_name == 'openai':
        return 'gpt-4o-mini'
    if provider_name == 'gemini':
        return 'gemini-1.5-flash'
    if provider_name == 'gemma':
        # Llama served through the NVIDIA API
        return 'meta/llama-3.3-70b-instruct'
    return 'gpt-4o-mini'
if __name__ == "__main__":
    # Smoke test: try to construct each registered provider and report.
    print("Testing LLM Provider abstraction...")
    smoke_checks = [
        ('openai', "OpenAI"),
        ('gemini', "Gemini"),
        ('gemma', "Gemma"),
    ]
    for key, label in smoke_checks:
        try:
            provider = get_provider(key)
            print(f"βœ… {label} provider initialized")
        except Exception as e:
            print(f"❌ {label} provider failed: {e}")