| """ | |
| LLM Provider abstraction layer for Blog2Code. | |
| Supports multiple LLM providers: OpenAI, Google Gemini, NVIDIA Gemma | |
| """ | |
| import os | |
| from typing import Dict, List, Any, Optional | |
| from abc import ABC, abstractmethod | |
| class LLMProvider(ABC): | |
| """Base class for LLM providers""" | |
| def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any: | |
| """Create a chat completion""" | |
| pass | |
| def get_response_text(self, completion: Any) -> str: | |
| """Extract text from completion response""" | |
| pass | |
| def get_usage_info(self, completion: Any) -> Dict: | |
| """Extract token usage information""" | |
| pass | |
| def calculate_cost(self, usage: Dict, model: str) -> float: | |
| """Calculate cost based on usage""" | |
| pass | |


class OpenAIProvider(LLMProvider):
    """OpenAI API implementation."""

    def __init__(self, api_key: Optional[str] = None):
        from openai import OpenAI
        self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))

    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Create an OpenAI chat completion."""
        return self.client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs
        )

    def get_response_text(self, completion: Any) -> str:
        """Extract text from an OpenAI response."""
        return completion.choices[0].message.content

    def get_usage_info(self, completion: Any) -> Dict:
        """Extract usage from an OpenAI response."""
        usage = completion.usage
        # prompt_tokens_details (with its cached_tokens field) is only present
        # on models/endpoints that support prompt caching.
        cached_tokens = 0
        if hasattr(usage, 'prompt_tokens_details'):
            cached_tokens = getattr(usage.prompt_tokens_details, 'cached_tokens', 0) or 0
        return {
            'prompt_tokens': usage.prompt_tokens,
            'completion_tokens': usage.completion_tokens,
            'total_tokens': usage.total_tokens,
            'cached_tokens': cached_tokens,
        }

    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Calculate OpenAI cost."""
        # Pricing in USD per 1M tokens
        model_costs = {
            "gpt-4o-mini": {"input": 0.150, "cached": 0.075, "output": 0.600},
            "gpt-4o": {"input": 2.50, "cached": 1.25, "output": 10.00},
            "gpt-3.5-turbo": {"input": 0.50, "cached": 0.25, "output": 1.50},
            "o3-mini": {"input": 1.10, "cached": 0.55, "output": 4.40},
        }
        costs = model_costs.get(model, model_costs["gpt-4o-mini"])

        prompt_tokens = usage['prompt_tokens']
        cached_tokens = usage.get('cached_tokens', 0)
        completion_tokens = usage['completion_tokens']

        # Cached prompt tokens are billed at the discounted "cached" rate.
        actual_input_tokens = prompt_tokens - cached_tokens
        input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
        cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
        output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + cached_cost + output_cost
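
    # Worked example (illustrative numbers, not from a real run): a gpt-4o-mini
    # call with 10,000 prompt tokens, of which 2,000 were cached, plus 500
    # completion tokens, costs
    #   (8,000 / 1e6) * 0.150 + (2,000 / 1e6) * 0.075 + (500 / 1e6) * 0.600
    #   = 0.00120 + 0.00015 + 0.00030 = 0.00165 USD.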


class GeminiProvider(LLMProvider):
    """Google Gemini API implementation."""

    def __init__(self, api_key: Optional[str] = None):
        try:
            import google.generativeai as genai
        except ImportError:
            raise ImportError(
                "google-generativeai is not installed. "
                "Install it with: pip install google-generativeai"
            )
        self.genai = genai
        genai.configure(api_key=api_key or os.environ.get("GEMINI_API_KEY"))

    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Create a Gemini chat completion."""
        # Convert OpenAI message format to a Gemini prompt string
        gemini_messages = self._convert_messages(messages)
        # Gemini expects model names in the form "models/<model-name>"
        if not model.startswith('models/'):
            model = f'models/{model}'
        gemini_model = self.genai.GenerativeModel(model)
        response = gemini_model.generate_content(
            gemini_messages,
            generation_config=self._get_generation_config(**kwargs)
        )
        return response

    def _convert_messages(self, messages: List[Dict]) -> str:
        """Convert OpenAI-style messages to a single Gemini prompt string."""
        # Gemini accepts a plain string, so concatenate all messages
        # with their role labels preserved.
        prompt_parts = []
        for msg in messages:
            role = msg['role']
            content = msg['content']
            if role == 'system':
                prompt_parts.append(f"System Instructions:\n{content}\n")
            elif role == 'user':
                prompt_parts.append(f"User:\n{content}\n")
            elif role == 'assistant':
                prompt_parts.append(f"Assistant:\n{content}\n")
        return "\n".join(prompt_parts)

    def _get_generation_config(self, **kwargs):
        """Map OpenAI-style kwargs onto a Gemini generation config."""
        config = {}
        if 'temperature' in kwargs:
            config['temperature'] = kwargs['temperature']
        if 'max_tokens' in kwargs:
            config['max_output_tokens'] = kwargs['max_tokens']
        if 'top_p' in kwargs:
            config['top_p'] = kwargs['top_p']
        return config
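
    # E.g. _get_generation_config(temperature=0.2, max_tokens=256) returns
    # {'temperature': 0.2, 'max_output_tokens': 256}.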

    def get_response_text(self, completion: Any) -> str:
        """Extract text from a Gemini response."""
        return completion.text

    def get_usage_info(self, completion: Any) -> Dict:
        """Extract usage from a Gemini response."""
        # Gemini reports token counts in the response's usage metadata
        try:
            metadata = completion.usage_metadata
            return {
                'prompt_tokens': metadata.prompt_token_count,
                'completion_tokens': metadata.candidates_token_count,
                'total_tokens': metadata.total_token_count,
                'cached_tokens': getattr(metadata, 'cached_content_token_count', 0)
            }
        except AttributeError:
            # Fall back to zeros if usage metadata is not available
            return {
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'total_tokens': 0,
                'cached_tokens': 0
            }

    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Calculate Gemini cost."""
        # Gemini pricing in USD per 1M tokens (as of Jan 2026)
        model_costs = {
            "gemini-1.5-flash": {"input": 0.075, "cached": 0.01875, "output": 0.30},
            "gemini-1.5-pro": {"input": 1.25, "cached": 0.3125, "output": 5.00},
            "gemini-2.0-flash-exp": {"input": 0.0, "cached": 0.0, "output": 0.0},  # Free during preview
        }
        costs = model_costs.get(model, model_costs["gemini-1.5-flash"])

        prompt_tokens = usage['prompt_tokens']
        cached_tokens = usage.get('cached_tokens', 0)
        completion_tokens = usage['completion_tokens']

        # Cached prompt tokens are billed at the discounted "cached" rate
        actual_input_tokens = prompt_tokens - cached_tokens
        input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
        cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
        output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + cached_cost + output_cost


class GemmaProvider(LLMProvider):
    """NVIDIA Gemma API implementation."""

    def __init__(self, api_key: Optional[str] = None):
        import requests
        self.requests = requests
        self.api_key = api_key or os.environ.get("NVIDIA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "NVIDIA_API_KEY not found. "
                "Set it as an environment variable or pass it to the constructor."
            )
        self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"

    def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
        """Create a Gemma chat completion via the NVIDIA API."""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json"  # Non-streaming for simplicity
        }
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": kwargs.get('max_tokens', 512),
            "temperature": kwargs.get('temperature', 0.20),
            "top_p": kwargs.get('top_p', 0.70),
            "stream": False  # Disable streaming for now
        }
        response = self.requests.post(self.invoke_url, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()
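
    # The endpoint returns OpenAI-compatible JSON; an abridged, illustrative
    # response looks like:
    #   {"choices": [{"message": {"role": "assistant", "content": "..."}}],
    #    "usage": {"prompt_tokens": 12, "completion_tokens": 34,
    #              "total_tokens": 46}}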

    def get_response_text(self, completion: Any) -> str:
        """Extract text from a Gemma response."""
        # The NVIDIA API returns an OpenAI-compatible format
        if isinstance(completion, dict):
            return completion['choices'][0]['message']['content']
        return str(completion)

    def get_usage_info(self, completion: Any) -> Dict:
        """Extract usage from a Gemma response."""
        try:
            usage = completion.get('usage', {})
            return {
                'prompt_tokens': usage.get('prompt_tokens', 0),
                'completion_tokens': usage.get('completion_tokens', 0),
                'total_tokens': usage.get('total_tokens', 0),
                'cached_tokens': 0  # The NVIDIA API doesn't report cached tokens
            }
        except (AttributeError, KeyError):
            return {
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'total_tokens': 0,
                'cached_tokens': 0
            }

    def calculate_cost(self, usage: Dict, model: str) -> float:
        """Calculate Gemma cost."""
        # NVIDIA API pricing (check current rates at build.nvidia.com).
        # Placeholder values for now; update with actual pricing.
        model_costs = {
            "google/gemma-3-27b-it": {"input": 0.0, "output": 0.0},  # Free tier, or update with actual costs
        }
        costs = model_costs.get(model, {"input": 0.0, "output": 0.0})

        prompt_tokens = usage['prompt_tokens']
        completion_tokens = usage['completion_tokens']

        input_cost = (prompt_tokens / 1_000_000) * costs["input"]
        output_cost = (completion_tokens / 1_000_000) * costs["output"]
        return input_cost + output_cost


def get_provider(provider_name: str, api_key: Optional[str] = None) -> LLMProvider:
    """
    Factory function to get an LLM provider.

    Args:
        provider_name: Name of the provider ('openai', 'gemini', or 'gemma')
        api_key: Optional API key (falls back to the environment variable)

    Returns:
        LLMProvider instance
    """
    providers = {
        'openai': OpenAIProvider,
        'gemini': GeminiProvider,
        'gemma': GemmaProvider,
    }
    if provider_name not in providers:
        raise ValueError(
            f"Unknown provider: {provider_name}. "
            f"Available providers: {list(providers.keys())}"
        )
    return providers[provider_name](api_key=api_key)


def get_default_model(provider_name: str) -> str:
    """Get the default model for a provider."""
    defaults = {
        'openai': 'gpt-4o-mini',
        'gemini': 'gemini-2.0-flash-lite',
        'gemma': 'google/gemma-3-27b-it',
    }
    return defaults.get(provider_name, 'gpt-4o-mini')
| if __name__ == "__main__": | |
| # Test script | |
| print("Testing LLM Provider abstraction...") | |
| # Test OpenAI | |
| try: | |
| provider = get_provider('openai') | |
| print("β OpenAI provider initialized") | |
| except Exception as e: | |
| print(f"β OpenAI provider failed: {e}") | |
| # Test Gemini | |
| try: | |
| provider = get_provider('gemini') | |
| print("β Gemini provider initialized") | |
| except Exception as e: | |
| print(f"β Gemini provider failed: {e}") | |
| # Test Gemma | |
| try: | |
| provider = get_provider('gemma') | |
| print("β Gemma provider initialized") | |
| except Exception as e: | |
| print(f"β Gemma provider failed: {e}") | |