| """ | |
| LLM Provider abstraction layer for Blog2Code. | |
| Supports multiple LLM providers: OpenAI, Google Gemini, NVIDIA Gemma | |
| """ | |

import os
from typing import Dict, List, Any, Optional
from abc import ABC, abstractmethod


class LLMProvider(ABC):
    """Base class for LLM providers."""

    @abstractmethod
    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Create a chat completion."""

    @abstractmethod
    def get_response_text(self, completion: Any) -> str:
        """Extract text from a completion response."""

    @abstractmethod
    def get_usage_info(self, completion: Any) -> Dict:
        """Extract token usage information."""

    @abstractmethod
    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Calculate cost based on usage."""


class OpenAIProvider(LLMProvider):
    """OpenAI API implementation."""

    def __init__(self, api_key: Optional[str] = None):
        from openai import OpenAI
        self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))

    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Create an OpenAI chat completion."""
        return self.client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs
        )

    def get_response_text(self, completion: Any) -> str:
        """Extract text from an OpenAI response."""
        return completion.choices[0].message.content

    def get_usage_info(self, completion: Any) -> Dict:
        """Extract usage from an OpenAI response."""
        # prompt_tokens_details (and its cached_tokens field) may be absent or
        # None depending on the model, so fall back to 0.
        details = getattr(completion.usage, 'prompt_tokens_details', None)
        return {
            'prompt_tokens': completion.usage.prompt_tokens,
            'completion_tokens': completion.usage.completion_tokens,
            'total_tokens': completion.usage.total_tokens,
            'cached_tokens': getattr(details, 'cached_tokens', 0) or 0
        }

    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Calculate OpenAI cost in USD."""
        # Prices in USD per 1M tokens.
        model_costs = {
            "gpt-4o-mini": {"input": 0.150, "cached": 0.075, "output": 0.600},
            "gpt-4o": {"input": 2.50, "cached": 1.25, "output": 10.00},
            "gpt-3.5-turbo": {"input": 0.50, "cached": 0.25, "output": 1.50},
            "o3-mini": {"input": 1.10, "cached": 0.55, "output": 4.40},
        }
        costs = model_costs.get(model, model_costs["gpt-4o-mini"])
        prompt_tokens = usage['prompt_tokens']
        cached_tokens = usage.get('cached_tokens', 0)
        completion_tokens = usage['completion_tokens']
        # Cached prompt tokens are billed at the discounted rate, so subtract
        # them from the full-price input tokens.
        actual_input_tokens = prompt_tokens - cached_tokens
        input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
        cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
        output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + cached_cost + output_cost
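
# Illustrative arithmetic for OpenAIProvider.calculate_cost (a sketch with
# made-up numbers, not real billing data): gpt-4o-mini with 10_000 prompt
# tokens (2_000 of them cached) and 1_000 completion tokens costs
#   (8_000 / 1e6) * 0.150 + (2_000 / 1e6) * 0.075 + (1_000 / 1e6) * 0.600
#   = 0.00120 + 0.00015 + 0.00060 = $0.00195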


class GeminiProvider(LLMProvider):
    """Google Gemini API implementation."""

    def __init__(self, api_key: Optional[str] = None):
        try:
            import google.generativeai as genai
            self.genai = genai
            genai.configure(api_key=api_key or os.environ.get("GEMINI_API_KEY"))
        except ImportError:
            raise ImportError(
                "google-generativeai not installed. "
                "Install with: pip install google-generativeai"
            )

    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Create a Gemini chat completion."""
        gemini_messages = self._convert_messages(messages)
        # Do NOT add a models/ prefix - pass the model name directly.
        gemini_model = self.genai.GenerativeModel(model)
        response = gemini_model.generate_content(
            gemini_messages,
            generation_config=self._get_generation_config(**kwargs)
        )
        return response

    def _convert_messages(self, messages: List[Dict]) -> str:
        """Convert OpenAI-style messages to a single Gemini prompt string."""
        prompt_parts = []
        for msg in messages:
            role = msg['role']
            content = msg['content']
            if role == 'system':
                prompt_parts.append(f"System Instructions:\n{content}\n")
            elif role == 'user':
                prompt_parts.append(f"User:\n{content}\n")
            elif role == 'assistant':
                prompt_parts.append(f"Assistant:\n{content}\n")
        return "\n".join(prompt_parts)

    def _get_generation_config(self, **kwargs):
        """Map OpenAI-style kwargs to a Gemini generation config."""
        config = {}
        if 'temperature' in kwargs:
            config['temperature'] = kwargs['temperature']
        if 'max_tokens' in kwargs:
            config['max_output_tokens'] = kwargs['max_tokens']
        if 'top_p' in kwargs:
            config['top_p'] = kwargs['top_p']
        return config

    def get_response_text(self, completion: Any) -> str:
        """Extract text from a Gemini response."""
        return completion.text

    def get_usage_info(self, completion: Any) -> Dict:
        """Extract usage from a Gemini response."""
        try:
            metadata = completion.usage_metadata
            return {
                'prompt_tokens': metadata.prompt_token_count,
                'completion_tokens': metadata.candidates_token_count,
                'total_tokens': metadata.total_token_count,
                'cached_tokens': getattr(metadata, 'cached_content_token_count', 0)
            }
        except Exception:
            # Usage metadata is not guaranteed to be present; report zeros
            # rather than failing the whole call.
            return {
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'total_tokens': 0,
                'cached_tokens': 0
            }

    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Calculate Gemini cost in USD."""
        # Prices in USD per 1M tokens.
        model_costs = {
            "gemini-1.5-flash": {"input": 0.075, "cached": 0.01875, "output": 0.30},
            "gemini-1.5-pro": {"input": 1.25, "cached": 0.3125, "output": 5.00},
            "gemini-2.0-flash": {"input": 0.0, "cached": 0.0, "output": 0.0},
            "gemini-2.0-flash-lite": {"input": 0.0, "cached": 0.0, "output": 0.0},
        }
        costs = model_costs.get(model, {"input": 0.0, "cached": 0.0, "output": 0.0})
        prompt_tokens = usage['prompt_tokens']
        cached_tokens = usage.get('cached_tokens', 0)
        completion_tokens = usage['completion_tokens']
        actual_input_tokens = prompt_tokens - cached_tokens
        input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
        cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
        output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + cached_cost + output_cost


class GemmaProvider(LLMProvider):
    """NVIDIA API implementation - supports Gemma, Llama, and other NVIDIA-hosted models."""

    def __init__(self, api_key: Optional[str] = None):
        import requests
        self.requests = requests
        self.api_key = api_key or os.environ.get("NVIDIA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "NVIDIA_API_KEY not found. "
                "Set it as an environment variable or pass it to the constructor."
            )
        self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"

    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Create an NVIDIA API chat completion with retry logic."""
        import time
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json"
        }
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": kwargs.get('max_tokens', 8192),
            "temperature": kwargs.get('temperature', 0.20),
            "top_p": kwargs.get('top_p', 0.70),
            "stream": False
        }
        max_retries = 5
        for attempt in range(max_retries):
            try:
                response = self.requests.post(
                    self.invoke_url, headers=headers, json=payload,
                    timeout=120  # avoid hanging indefinitely on a stalled request
                )
                response.raise_for_status()
                return response.json()
            except Exception as e:
                if attempt < max_retries - 1:
                    wait = 10 * (attempt + 1)  # 10s, 20s, 30s, 40s
                    print(f"[RETRY] Attempt {attempt+1} failed: {e}. Retrying in {wait}s...")
                    time.sleep(wait)
                else:
                    raise

    def get_response_text(self, completion: Any) -> str:
        """Extract text from an NVIDIA API response."""
        if isinstance(completion, dict):
            return completion['choices'][0]['message']['content']
        return str(completion)

    def get_usage_info(self, completion: Any) -> Dict:
        """Extract usage from an NVIDIA API response."""
        try:
            usage = completion.get('usage', {})
            return {
                'prompt_tokens': usage.get('prompt_tokens', 0),
                'completion_tokens': usage.get('completion_tokens', 0),
                'total_tokens': usage.get('total_tokens', 0),
                'cached_tokens': 0  # cached-token accounting not tracked for this provider
            }
        except Exception:
            return {
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'total_tokens': 0,
                'cached_tokens': 0
            }

    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Calculate NVIDIA API cost in USD."""
        # Prices in USD per 1M tokens (all 0.0 for these hosted endpoints).
        model_costs = {
            "google/gemma-3-27b-it": {"input": 0.0, "output": 0.0},
            "meta/llama-3.3-70b-instruct": {"input": 0.0, "output": 0.0},
            "meta/llama-3.1-8b-instruct": {"input": 0.0, "output": 0.0},
        }
        costs = model_costs.get(model, {"input": 0.0, "output": 0.0})
        prompt_tokens = usage['prompt_tokens']
        completion_tokens = usage['completion_tokens']
        input_cost = (prompt_tokens / 1_000_000) * costs["input"]
        output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + output_cost


def get_provider(provider_name: str, api_key: Optional[str] = None) -> LLMProvider:
    """Factory function to get an LLM provider."""
    providers = {
        'openai': OpenAIProvider,
        'gemini': GeminiProvider,
        'gemma': GemmaProvider,
    }
    if provider_name not in providers:
        raise ValueError(
            f"Unknown provider: {provider_name}. "
            f"Available providers: {list(providers.keys())}"
        )
    return providers[provider_name](api_key=api_key)


def get_default_model(provider_name: str) -> str:
    """Get the default model for a provider."""
    defaults = {
        'openai': 'gpt-4o-mini',
        'gemini': 'gemini-1.5-flash',
        'gemma': 'meta/llama-3.3-70b-instruct',  # Llama via NVIDIA API
    }
    return defaults.get(provider_name, 'gpt-4o-mini')
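
# End-to-end sketch of the usage-to-cost flow (hypothetical prompt; assumes
# the provider's API key is configured):
#
#   provider = get_provider('gemini')
#   model = get_default_model('gemini')
#   completion = provider.create_completion(
#       [{"role": "user", "content": "Summarize this blog post."}], model)
#   usage = provider.get_usage_info(completion)
#   print(f"${provider.calculate_cost(usage, model):.6f}")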


if __name__ == "__main__":
    print("Testing LLM Provider abstraction...")
    try:
        provider = get_provider('openai')
        print("[OK] OpenAI provider initialized")
    except Exception as e:
        print(f"[FAIL] OpenAI provider failed: {e}")
    try:
        provider = get_provider('gemini')
        print("[OK] Gemini provider initialized")
    except Exception as e:
        print(f"[FAIL] Gemini provider failed: {e}")
    try:
        provider = get_provider('gemma')
        print("[OK] Gemma provider initialized")
    except Exception as e:
        print(f"[FAIL] Gemma provider failed: {e}")