Spaces:
Sleeping
Sleeping
| import asyncio | |
| import logging | |
| from typing import Dict, Any | |
| # Import our schemas and providers | |
| from src.schemas import RouterConfig, LLMResponse | |
| from src.providers.base import BaseLLMProvider | |
| from src.providers.openai_client import OpenAIProvider | |
| from src.providers.anthropic_client import AnthropicProvider | |
# Configure root logging at INFO so retry and failover messages are visible
# in the terminal, then create this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
class OmniRouter:
    """Central routing engine for LLM providers.

    Looks up the provider named in the request config, retries failed calls
    with exponential backoff, and optionally fails over to a fallback provider.
    """

    def __init__(self, api_keys: Dict[str, str]):
        """Instantiate a provider client for each supported key in ``api_keys``.

        Args:
            api_keys: Maps provider names to API keys. Only the ``"openai"``
                and ``"anthropic"`` entries are read; other entries are
                ignored.
        """
        # Registry of provider name -> concrete provider instance.
        self.providers: Dict[str, BaseLLMProvider] = {}

        if "openai" in api_keys:
            self.providers["openai"] = OpenAIProvider(api_key=api_keys["openai"])

        if "anthropic" in api_keys:
            self.providers["anthropic"] = AnthropicProvider(api_key=api_keys["anthropic"])

        # Additional providers can be registered here.

    async def generate(self, prompt: str, config: RouterConfig) -> LLMResponse:
        """Route ``prompt`` to ``config.provider``, retrying and failing over.

        The provider is called up to ``config.max_retries`` times. Between
        failed attempts the coroutine sleeps ``2 ** attempt`` seconds
        (1s, 2s, 4s, ...). If every attempt fails and
        ``config.fallback_provider`` is set, the request is re-issued once
        through this method against the fallback provider.

        Args:
            prompt: Text passed through to the provider's ``async_generate``.
            config: Routing settings. Reads ``provider``, ``max_retries``,
                ``fallback_provider``, ``fallback_model``, ``model`` and
                ``temperature``.

        Returns:
            Whatever the selected provider's ``async_generate`` returns.

        Raises:
            ValueError: If ``config.provider`` (or, during failover, the
                fallback provider) is not registered, or if
                ``config.max_retries`` is less than 1.
            Exception: The last error raised by the provider once all attempts
                are exhausted and no fallback is configured.
        """
        provider = self.providers.get(config.provider)
        if provider is None:
            raise ValueError(f"Provider '{config.provider}' is not configured.")

        # With zero attempts there would be no exception to re-raise at the
        # end, so reject the configuration explicitly.
        if config.max_retries < 1:
            raise ValueError(
                f"max_retries must be at least 1, got {config.max_retries}."
            )

        last_exception = None
        final_attempt = config.max_retries - 1

        for attempt in range(config.max_retries):
            if attempt > 0:
                logger.info(
                    "[%s] Retrying... Attempt %s of %s",
                    config.provider, attempt + 1, config.max_retries,
                )

            try:
                return await provider.async_generate(prompt, config)
            except Exception as e:
                # Deliberately broad: any provider failure is retryable here.
                logger.warning(
                    "[%s] Attempt %s failed with error: %s",
                    config.provider, attempt + 1, e,
                )
                last_exception = e

            # Back off only when another attempt follows; sleeping after the
            # final failure would just delay failover or the final raise.
            if attempt < final_attempt:
                wait_time = 2 ** attempt
                logger.info("Waiting %s seconds before next attempt...", wait_time)
                await asyncio.sleep(wait_time)

        # Failover: the primary provider exhausted its attempts.
        if config.fallback_provider:
            logger.error(
                "🚨 Primary provider '%s' exhausted all retries. "
                "Initiating FAILOVER to '%s'...",
                config.provider, config.fallback_provider,
            )
            # Only these four fields are carried over to the fallback request;
            # any other RouterConfig fields take their defaults.
            # NOTE(review): the failover chain ends after one hop only if
            # RouterConfig.fallback_provider defaults to an empty value - confirm.
            fallback_config = RouterConfig(
                provider=config.fallback_provider,
                # Prefer the fallback-specific model when one is provided.
                model=config.fallback_model or config.model,
                temperature=config.temperature,
                max_retries=config.max_retries,
            )
            return await self.generate(prompt, fallback_config)

        # No fallback configured: surface the last provider error to the caller.
        logger.error(
            "All %s attempts failed and no fallback configured.",
            config.max_retries,
        )
        raise last_exception