# blog2code-api/codes/llm_provider.py
"""
LLM Provider abstraction layer for Blog2Code.
Supports multiple LLM providers: OpenAI, Google Gemini, and Gemma served via the NVIDIA API.
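
Example usage (a minimal sketch; assumes the relevant API key, e.g.
GEMINI_API_KEY, is set in the environment):

    provider = get_provider('gemini')
    completion = provider.create_completion(
        [{"role": "user", "content": "Say hello."}],
        model=get_default_model('gemini'),
        temperature=0.2,
    )
    print(provider.get_response_text(completion))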
"""
import os
from typing import Dict, List, Any, Optional
from abc import ABC, abstractmethod


class LLMProvider(ABC):
"""Base class for LLM providers"""
@abstractmethod
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
"""Create a chat completion"""
pass
@abstractmethod
def get_response_text(self, completion: Any) -> str:
"""Extract text from completion response"""
pass
@abstractmethod
def get_usage_info(self, completion: Any) -> Dict:
"""Extract token usage information"""
pass
@abstractmethod
def calculate_cost(self, usage: Dict, model: str) -> float:
"""Calculate cost based on usage"""
        pass


class OpenAIProvider(LLMProvider):
"""OpenAI API implementation"""
def __init__(self, api_key: Optional[str] = None):
from openai import OpenAI
self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
"""Create OpenAI chat completion"""
return self.client.chat.completions.create(
model=model,
messages=messages,
**kwargs
)
def get_response_text(self, completion: Any) -> str:
"""Extract text from OpenAI response"""
return completion.choices[0].message.content
def get_usage_info(self, completion: Any) -> Dict:
"""Extract usage from OpenAI response"""
        usage = completion.usage
        cached_tokens = 0
        # Cached prompt tokens are reported under prompt_tokens_details on
        # newer models; the attribute may be absent or None on older ones
        if hasattr(usage, 'prompt_tokens_details') and usage.prompt_tokens_details:
            cached_tokens = getattr(usage.prompt_tokens_details, 'cached_tokens', 0) or 0
        return {
            'prompt_tokens': usage.prompt_tokens,
            'completion_tokens': usage.completion_tokens,
            'total_tokens': usage.total_tokens,
            'cached_tokens': cached_tokens
        }
def calculate_cost(self, usage: Dict, model: str) -> float:
"""Calculate OpenAI cost"""
        # Pricing in USD per 1M tokens (verify against current OpenAI pricing)
model_costs = {
"gpt-4o-mini": {"input": 0.150, "cached": 0.075, "output": 0.600},
"gpt-4o": {"input": 2.50, "cached": 1.25, "output": 10.00},
"gpt-3.5-turbo": {"input": 0.50, "cached": 0.25, "output": 1.50},
"o3-mini": {"input": 1.10, "cached": 0.55, "output": 4.40},
}
        # Unlisted models fall back to gpt-4o-mini rates
        costs = model_costs.get(model, model_costs["gpt-4o-mini"])
prompt_tokens = usage['prompt_tokens']
cached_tokens = usage.get('cached_tokens', 0)
completion_tokens = usage['completion_tokens']
actual_input_tokens = prompt_tokens - cached_tokens
input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
output_cost = (completion_tokens / 1_000_000) * costs["output"]
return input_cost + cached_cost + output_cost
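
# Worked example (illustrative): a gpt-4o-mini call with 10,000 prompt tokens
# (2,000 of them cached) and 500 completion tokens costs
#   (8,000 / 1e6) * 0.150 + (2,000 / 1e6) * 0.075 + (500 / 1e6) * 0.600
#   = 0.0012 + 0.00015 + 0.0003 = $0.00165
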
class GeminiProvider(LLMProvider):
"""Google Gemini API implementation"""
def __init__(self, api_key: Optional[str] = None):
try:
import google.generativeai as genai
self.genai = genai
genai.configure(api_key=api_key or os.environ.get("GEMINI_API_KEY"))
except ImportError:
raise ImportError(
"google-generativeai not installed. "
"Install with: pip install google-generativeai"
)
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
"""Create Gemini chat completion"""
# Convert OpenAI message format to Gemini format
gemini_messages = self._convert_messages(messages)
# Fix model name - Gemini expects models/model-name format
if not model.startswith('models/'):
model = f'models/{model}'
# Create model
gemini_model = self.genai.GenerativeModel(model)
# Generate response
response = gemini_model.generate_content(
gemini_messages,
generation_config=self._get_generation_config(**kwargs)
)
return response
def _convert_messages(self, messages: List[Dict]) -> str:
"""Convert OpenAI messages to Gemini prompt format"""
# Gemini uses a simpler format - concatenate all messages
prompt_parts = []
for msg in messages:
role = msg['role']
content = msg['content']
if role == 'system':
prompt_parts.append(f"System Instructions:\n{content}\n")
elif role == 'user':
prompt_parts.append(f"User:\n{content}\n")
elif role == 'assistant':
prompt_parts.append(f"Assistant:\n{content}\n")
return "\n".join(prompt_parts)
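
    # Example: [{"role": "system", "content": "Be terse."},
    #           {"role": "user", "content": "Hi"}]
    # joins to "System Instructions:\nBe terse.\n\nUser:\nHi\n"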
def _get_generation_config(self, **kwargs):
"""Convert OpenAI kwargs to Gemini generation config"""
config = {}
# Map common parameters
if 'temperature' in kwargs:
config['temperature'] = kwargs['temperature']
if 'max_tokens' in kwargs:
config['max_output_tokens'] = kwargs['max_tokens']
if 'top_p' in kwargs:
config['top_p'] = kwargs['top_p']
return config
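
    # Example: {'temperature': 0.2, 'max_tokens': 512, 'top_p': 0.9} maps to
    # {'temperature': 0.2, 'max_output_tokens': 512, 'top_p': 0.9}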
def get_response_text(self, completion: Any) -> str:
"""Extract text from Gemini response"""
return completion.text
def get_usage_info(self, completion: Any) -> Dict:
"""Extract usage from Gemini response"""
# Gemini provides token counts in metadata
try:
metadata = completion.usage_metadata
return {
'prompt_tokens': metadata.prompt_token_count,
'completion_tokens': metadata.candidates_token_count,
'total_tokens': metadata.total_token_count,
'cached_tokens': getattr(metadata, 'cached_content_token_count', 0)
}
        except (AttributeError, KeyError):
            # Fallback when the response carries no usage metadata
return {
'prompt_tokens': 0,
'completion_tokens': 0,
'total_tokens': 0,
'cached_tokens': 0
}
def calculate_cost(self, usage: Dict, model: str) -> float:
"""Calculate Gemini cost"""
        # Gemini pricing per 1M tokens (verify against current Google pricing)
        model_costs = {
            "gemini-1.5-flash": {"input": 0.075, "cached": 0.01875, "output": 0.30},
            "gemini-1.5-pro": {"input": 1.25, "cached": 0.3125, "output": 5.00},
            "gemini-2.0-flash-exp": {"input": 0.0, "cached": 0.0, "output": 0.0},  # Free during preview
        }
        # Unlisted models (including the default gemini-2.0-flash-lite) fall
        # back to 1.5-flash rates
        costs = model_costs.get(model, model_costs["gemini-1.5-flash"])
prompt_tokens = usage['prompt_tokens']
cached_tokens = usage.get('cached_tokens', 0)
completion_tokens = usage['completion_tokens']
actual_input_tokens = prompt_tokens - cached_tokens
input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + cached_cost + output_cost


class GemmaProvider(LLMProvider):
"""NVIDIA Gemma API implementation"""
def __init__(self, api_key: Optional[str] = None):
import requests
self.requests = requests
self.api_key = api_key or os.environ.get("NVIDIA_API_KEY")
if not self.api_key:
raise ValueError(
"NVIDIA_API_KEY not found. "
"Set it as an environment variable or pass it to the constructor."
)
self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
"""Create Gemma chat completion"""
# Prepare headers
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json" # Non-streaming for simplicity
}
# Prepare payload
payload = {
"model": model,
"messages": messages,
"max_tokens": kwargs.get('max_tokens', 512),
"temperature": kwargs.get('temperature', 0.20),
"top_p": kwargs.get('top_p', 0.70),
"stream": False # Disable streaming for now
}
# Make request
response = self.requests.post(self.invoke_url, headers=headers, json=payload)
response.raise_for_status()
return response.json()
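
    # Streaming sketch (an assumption, not verified against the NVIDIA docs:
    # the endpoint is OpenAI-compatible, so SSE chunks should arrive as
    # "data: {...}" lines):
    #
    #     payload["stream"] = True
    #     headers["Accept"] = "text/event-stream"
    #     with self.requests.post(self.invoke_url, headers=headers,
    #                             json=payload, stream=True) as resp:
    #         for line in resp.iter_lines():
    #             if line.startswith(b"data: ") and b"[DONE]" not in line:
    #                 chunk = json.loads(line[len(b"data: "):])  # needs `import json`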
def get_response_text(self, completion: Any) -> str:
"""Extract text from Gemma response"""
# NVIDIA API returns OpenAI-compatible format
if isinstance(completion, dict):
return completion['choices'][0]['message']['content']
return str(completion)
def get_usage_info(self, completion: Any) -> Dict:
"""Extract usage from Gemma response"""
try:
usage = completion.get('usage', {})
return {
'prompt_tokens': usage.get('prompt_tokens', 0),
'completion_tokens': usage.get('completion_tokens', 0),
'total_tokens': usage.get('total_tokens', 0),
'cached_tokens': 0 # NVIDIA API doesn't provide cached token info
}
        except (AttributeError, TypeError):
            # Fallback when the response is not the expected dict shape
return {
'prompt_tokens': 0,
'completion_tokens': 0,
'total_tokens': 0,
'cached_tokens': 0
}
def calculate_cost(self, usage: Dict, model: str) -> float:
"""Calculate Gemma cost"""
# NVIDIA API pricing (check current pricing at build.nvidia.com)
# For now, using placeholder values - update with actual pricing
model_costs = {
"google/gemma-3-27b-it": {"input": 0.0, "output": 0.0}, # Free tier or update with actual costs
}
costs = model_costs.get(model, {"input": 0.0, "output": 0.0})
prompt_tokens = usage['prompt_tokens']
completion_tokens = usage['completion_tokens']
input_cost = (prompt_tokens / 1_000_000) * costs["input"]
output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + output_cost


def get_provider(provider_name: str, api_key: Optional[str] = None) -> LLMProvider:
"""
Factory function to get LLM provider.
    Args:
        provider_name: Name of provider ('openai', 'gemini', or 'gemma')
        api_key: Optional API key (uses env var if not provided)
Returns:
LLMProvider instance
"""
providers = {
'openai': OpenAIProvider,
'gemini': GeminiProvider,
'gemma': GemmaProvider,
}
if provider_name not in providers:
raise ValueError(
f"Unknown provider: {provider_name}. "
f"Available providers: {list(providers.keys())}"
)
    return providers[provider_name](api_key=api_key)


def get_default_model(provider_name: str) -> str:
"""Get default model for a provider"""
defaults = {
'openai': 'gpt-4o-mini',
'gemini': 'gemini-2.0-flash-lite',
'gemma': 'google/gemma-3-27b-it',
}
    return defaults.get(provider_name, 'gpt-4o-mini')


if __name__ == "__main__":
# Test script
print("Testing LLM Provider abstraction...")
# Test OpenAI
try:
provider = get_provider('openai')
print("βœ… OpenAI provider initialized")
except Exception as e:
print(f"❌ OpenAI provider failed: {e}")
# Test Gemini
try:
provider = get_provider('gemini')
print("βœ… Gemini provider initialized")
except Exception as e:
print(f"❌ Gemini provider failed: {e}")
# Test Gemma
try:
provider = get_provider('gemma')
print("βœ… Gemma provider initialized")
except Exception as e:
print(f"❌ Gemma provider failed: {e}")
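
    # Optional live smoke test (a sketch; only runs when GEMINI_API_KEY is set,
    # and will incur whatever the current API pricing is)
    if os.environ.get("GEMINI_API_KEY"):
        model = get_default_model('gemini')
        provider = get_provider('gemini')
        completion = provider.create_completion(
            [{"role": "user", "content": "Reply with the single word: ok"}],
            model=model,
            max_tokens=10,
        )
        print("Response:", provider.get_response_text(completion))
        usage = provider.get_usage_info(completion)
        print(f"Usage: {usage} | est. cost: ${provider.calculate_cost(usage, model):.6f}")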