| |
| import logging |
| import asyncio |
| import aiohttp |
| import time |
| from typing import Dict, Optional |
| from .models_config import LLM_CONFIG |
| from .config import get_settings |
|
|
| logger = logging.getLogger(__name__) |
|
|
class LLMRouter:
    """Routes LLM inference requests to the ZeroGPU Chat API (RunPod).

    Manages the authentication lifecycle (login / token refresh), maps
    internal task types onto API task names, and dispatches requests over
    a lazily created aiohttp session.
    """

    def __init__(self, hf_token=None, use_local_models: bool = False):
        """
        Initialize LLM Router with ZeroGPU Chat API (RunPod).

        Args:
            hf_token: Not used (kept for backward compatibility)
            use_local_models: Must be False (local models disabled)

        Raises:
            ValueError: If local models are requested, or if required
                ZeroGPU settings (base URL, email, password) are missing.
        """
        if use_local_models:
            raise ValueError("Local models are disabled. Only ZeroGPU Chat API is supported.")

        self.settings = get_settings()

        # Validate required settings BEFORE touching them: previously
        # .rstrip('/') was called on the base URL first, so an unset value
        # raised an opaque AttributeError instead of this clear message.
        if not self.settings.zerogpu_base_url:
            raise ValueError(
                "ZEROGPU_BASE_URL is required. "
                "Set it in environment variables or .env file"
            )
        if not self.settings.zerogpu_email or not self.settings.zerogpu_password:
            raise ValueError(
                "ZEROGPU_EMAIL and ZEROGPU_PASSWORD are required. "
                "Set them in environment variables or .env file"
            )

        self.base_url = self.settings.zerogpu_base_url.rstrip('/')
        # Auth state: tokens are acquired lazily on the first request.
        self.access_token = None
        self.refresh_token = None
        self.token_expires_at = 0  # epoch seconds; 0 forces an initial login
        self.session = None  # aiohttp.ClientSession, created on demand

        logger.info("ZeroGPU Chat API client initializing")
        logger.info(f"Base URL: {self.base_url}")
        logger.info("ZeroGPU Chat API client initialized (authentication on first request)")
| |
| async def route_inference(self, task_type: str, prompt: str, **kwargs): |
| """ |
| Route inference to ZeroGPU Chat API. |
| |
| Args: |
| task_type: Type of task (general_reasoning, intent_classification, etc.) |
| prompt: Input prompt |
| **kwargs: Additional parameters (max_tokens, temperature, etc.) |
| |
| Returns: |
| Generated text response |
| """ |
| logger.info(f"Routing inference to ZeroGPU Chat API for task: {task_type}") |
| |
| try: |
| |
| await self._ensure_authenticated() |
| |
| |
| api_task = self._map_task_type(task_type) |
| |
| |
| kwargs['original_task_type'] = task_type |
| |
| |
| if task_type == "embedding_generation": |
| logger.warning("Embedding generation via ZeroGPU API may require special implementation") |
| result = await self._call_zerogpu_api(api_task, prompt, **kwargs) |
| else: |
| result = await self._call_zerogpu_api(api_task, prompt, **kwargs) |
| |
| if result is None: |
| logger.error(f"ZeroGPU Chat API returned None for task: {task_type}") |
| raise RuntimeError(f"Inference failed for task: {task_type}") |
| |
| logger.info(f"Inference complete for {task_type} (ZeroGPU Chat API)") |
| return result |
| |
| except Exception as e: |
| logger.error(f"ZeroGPU Chat API inference failed: {e}", exc_info=True) |
| raise RuntimeError( |
| f"Inference failed for task: {task_type}. " |
| f"ZeroGPU Chat API error: {e}" |
| ) from e |
| |
| async def _ensure_authenticated(self): |
| """Ensure we have a valid access token, login if needed.""" |
| |
| if self.access_token and time.time() < (self.token_expires_at - 60): |
| return |
| |
| |
| if self.session is None: |
| self.session = aiohttp.ClientSession() |
| |
| |
| await self._login() |
| |
| async def _login(self): |
| """Login to ZeroGPU Chat API and get access/refresh tokens.""" |
| try: |
| login_url = f"{self.base_url}/login" |
| login_data = { |
| "email": self.settings.zerogpu_email, |
| "password": self.settings.zerogpu_password |
| } |
| |
| async with self.session.post(login_url, json=login_data) as response: |
| if response.status == 401: |
| raise ValueError("Invalid email or password for ZeroGPU Chat API") |
| response.raise_for_status() |
| data = await response.json() |
| |
| self.access_token = data.get("access_token") |
| self.refresh_token = data.get("refresh_token") |
| |
| |
| self.token_expires_at = time.time() + 900 |
| |
| logger.info("Successfully authenticated with ZeroGPU Chat API") |
| |
| except aiohttp.ClientError as e: |
| logger.error(f"Failed to login to ZeroGPU Chat API: {e}") |
| raise RuntimeError(f"Authentication failed: {e}") from e |
| |
| async def _refresh_token(self): |
| """Refresh access token using refresh token.""" |
| try: |
| refresh_url = f"{self.base_url}/refresh" |
| headers = {"X-Refresh-Token": self.refresh_token} |
| |
| async with self.session.post(refresh_url, headers=headers) as response: |
| if response.status == 401: |
| |
| await self._login() |
| return |
| |
| response.raise_for_status() |
| data = await response.json() |
| |
| self.access_token = data.get("access_token") |
| self.refresh_token = data.get("refresh_token") |
| self.token_expires_at = time.time() + 900 |
| |
| logger.info("Successfully refreshed ZeroGPU Chat API token") |
| |
| except aiohttp.ClientError as e: |
| logger.error(f"Failed to refresh token: {e}") |
| |
| await self._login() |
| |
| def _map_task_type(self, internal_task: str) -> str: |
| """Map internal task types to ZeroGPU Chat API task types.""" |
| task_mapping = { |
| "general_reasoning": "general", |
| "response_synthesis": "general", |
| "intent_classification": "classification", |
| "safety_check": "classification", |
| "embedding_generation": "embedding" |
| } |
| return task_mapping.get(internal_task, "general") |
| |
| async def _call_zerogpu_api(self, task: str, prompt: str, **kwargs) -> Optional[str]: |
| """Call ZeroGPU Chat API for inference.""" |
| if not self.session: |
| self.session = aiohttp.ClientSession() |
| |
| |
| original_task = kwargs.pop('original_task_type', None) |
| |
| |
| model_config = self._select_model(original_task or 'general_reasoning') |
| |
| |
| payload = { |
| "message": prompt, |
| "task": task, |
| "max_tokens": kwargs.get('max_tokens', model_config.get('max_tokens', 512)), |
| "temperature": kwargs.get('temperature', model_config.get('temperature', 0.7)), |
| "top_p": kwargs.get('top_p', model_config.get('top_p', 0.9)), |
| } |
| |
| |
| if 'context' in kwargs and kwargs['context']: |
| |
| context = kwargs['context'] |
| if isinstance(context, list) and len(context) > 0: |
| |
| api_context = [] |
| for item in context[:50]: |
| if isinstance(item, (list, tuple)) and len(item) >= 2: |
| |
| api_context.append({ |
| "role": "user", |
| "content": str(item[0]), |
| "timestamp": kwargs.get('timestamp', time.time()) |
| }) |
| api_context.append({ |
| "role": "assistant", |
| "content": str(item[1]), |
| "timestamp": kwargs.get('timestamp', time.time()) |
| }) |
| elif isinstance(item, dict): |
| api_context.append(item) |
| payload["context"] = api_context |
| |
| if 'system_prompt' in kwargs and kwargs['system_prompt']: |
| payload["system_prompt"] = kwargs['system_prompt'] |
| if 'repetition_penalty' in kwargs: |
| payload["repetition_penalty"] = kwargs['repetition_penalty'] |
| |
| |
| headers = { |
| "Authorization": f"Bearer {self.access_token}", |
| "Content-Type": "application/json" |
| } |
| |
| try: |
| chat_url = f"{self.base_url}/chat" |
| |
| async with self.session.post(chat_url, json=payload, headers=headers) as response: |
| |
| if response.status == 401: |
| logger.info("Token expired, refreshing...") |
| await self._refresh_token() |
| headers["Authorization"] = f"Bearer {self.access_token}" |
| |
| async with self.session.post(chat_url, json=payload, headers=headers) as retry_response: |
| retry_response.raise_for_status() |
| data = await retry_response.json() |
| return data.get("response") |
| |
| response.raise_for_status() |
| data = await response.json() |
| |
| |
| result = data.get("response") |
| if result: |
| logger.info(f"ZeroGPU Chat API generated response (length: {len(result)})") |
| return result |
| else: |
| logger.error("ZeroGPU Chat API returned empty response") |
| return None |
| |
| except aiohttp.ClientError as e: |
| logger.error(f"Error calling ZeroGPU Chat API: {e}", exc_info=True) |
| raise |
| |
| def _calculate_safe_max_tokens(self, prompt: str, requested_max_tokens: int) -> int: |
| """ |
| Calculate safe max_tokens based on input token count and model context window. |
| |
| Args: |
| prompt: Input prompt text |
| requested_max_tokens: Desired max_tokens value |
| |
| Returns: |
| int: Adjusted max_tokens that fits within context window |
| """ |
| |
| |
| input_tokens = len(prompt) // 4 |
| |
| |
| context_window = self.settings.zerogpu_model_context_window |
| |
| logger.debug( |
| f"Calculating safe max_tokens: input ~{input_tokens} tokens, " |
| f"context_window={context_window}, requested={requested_max_tokens}" |
| ) |
| |
| |
| available_tokens = context_window - input_tokens - 100 |
| |
| |
| safe_max_tokens = min(requested_max_tokens, available_tokens) |
| |
| |
| safe_max_tokens = max(50, safe_max_tokens) |
| |
| if safe_max_tokens < requested_max_tokens: |
| logger.warning( |
| f"Reduced max_tokens from {requested_max_tokens} to {safe_max_tokens} " |
| f"(input: ~{input_tokens} tokens, context window: {context_window} tokens, " |
| f"available: {available_tokens} tokens)" |
| ) |
| |
| return safe_max_tokens |
| |
| def _format_prompt(self, prompt: str, task_type: str, model_config: dict) -> str: |
| """ |
| Format prompt for ZeroGPU Chat API. |
| Can be customized based on model requirements. |
| """ |
| formatted_prompt = prompt |
| |
| |
| if self._is_math_query(prompt): |
| math_directive = "Please reason step by step, and put your final answer within \\boxed{}." |
| formatted_prompt = f"{formatted_prompt}\n\n{math_directive}" |
| |
| return formatted_prompt |
| |
| def _is_math_query(self, prompt: str) -> bool: |
| """Detect if query is mathematical""" |
| math_keywords = [ |
| "solve", "calculate", "compute", "equation", "formula", |
| "mathematical", "algebra", "geometry", "calculus", "integral", |
| "derivative", "theorem", "proof", "problem" |
| ] |
| prompt_lower = prompt.lower() |
| return any(keyword in prompt_lower for keyword in math_keywords) |
| |
| def _clean_reasoning_tags(self, text: str) -> str: |
| """Clean up reasoning tags from response if present""" |
| if not text: |
| return text |
| |
| text = text.replace("`<think>`", "").replace("`</think>`", "") |
| text = text.replace("`<think>`", "").replace("`</think>`", "") |
| text = text.strip() |
| return text |
| |
| def _select_model(self, task_type: str) -> dict: |
| """Select model configuration based on task type""" |
| model_map = { |
| "intent_classification": LLM_CONFIG["models"]["classification_specialist"], |
| "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"], |
| "safety_check": LLM_CONFIG["models"]["safety_checker"], |
| "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"], |
| "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"] |
| } |
| return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"]) |
| |
| async def get_available_models(self): |
| """Get list of available models from ZeroGPU Chat API""" |
| try: |
| await self._ensure_authenticated() |
| if not self.session: |
| self.session = aiohttp.ClientSession() |
| |
| tasks_url = f"{self.base_url}/tasks" |
| headers = {"Authorization": f"Bearer {self.access_token}"} |
| |
| async with self.session.get(tasks_url, headers=headers) as response: |
| if response.status == 401: |
| await self._refresh_token() |
| headers["Authorization"] = f"Bearer {self.access_token}" |
| async with self.session.get(tasks_url, headers=headers) as retry_response: |
| retry_response.raise_for_status() |
| data = await retry_response.json() |
| else: |
| response.raise_for_status() |
| data = await response.json() |
| |
| tasks = data.get("tasks", {}) |
| models = [f"ZeroGPU Chat API - {task}: {info.get('model', 'N/A')}" |
| for task, info in tasks.items()] |
| return models if models else ["ZeroGPU Chat API"] |
| except Exception as e: |
| logger.error(f"Failed to get available models: {e}") |
| return ["ZeroGPU Chat API"] |
| |
| async def health_check(self): |
| """Perform health check on ZeroGPU Chat API""" |
| try: |
| if not self.session: |
| self.session = aiohttp.ClientSession() |
| |
| |
| health_url = f"{self.base_url}/health" |
| async with self.session.get(health_url) as response: |
| response.raise_for_status() |
| data = await response.json() |
| |
| return { |
| "provider": "zerogpu_chat_api", |
| "status": "healthy" if data.get("status") == "healthy" else "unhealthy", |
| "models_ready": data.get("models_ready", False), |
| "base_url": self.base_url |
| } |
| except Exception as e: |
| logger.error(f"Health check failed: {e}") |
| return { |
| "provider": "zerogpu_chat_api", |
| "status": "unhealthy", |
| "error": str(e) |
| } |
| |
| async def __aenter__(self): |
| """Async context manager entry""" |
| if not self.session: |
| self.session = aiohttp.ClientSession() |
| return self |
| |
| async def __aexit__(self, exc_type, exc_val, exc_tb): |
| """Async context manager exit""" |
| if self.session: |
| await self.session.close() |
| self.session = None |
| |
| def prepare_context_for_llm(self, raw_context: Dict, max_tokens: Optional[int] = None, |
| user_input: Optional[str] = None) -> str: |
| """ |
| Smart context windowing with user input priority. |
| User input is NEVER truncated - context is reduced to fit. |
| |
| Args: |
| raw_context: Context dictionary |
| max_tokens: Optional override (uses config default if None) |
| user_input: Optional explicit user input (takes priority over raw_context['user_input']) |
| """ |
| |
| if max_tokens is None: |
| max_tokens = self.settings.context_preparation_budget |
| |
| |
| actual_user_input = user_input or raw_context.get('user_input', '') |
| |
| |
| user_input_tokens = len(actual_user_input) // 4 |
| |
| |
| user_input_max = self.settings.user_input_max_tokens |
| if user_input_tokens > user_input_max: |
| logger.warning(f"User input ({user_input_tokens} tokens) exceeds max ({user_input_max}), truncating") |
| max_chars = user_input_max * 4 |
| actual_user_input = actual_user_input[:max_chars - 3] + "..." |
| user_input_tokens = user_input_max |
| |
| |
| remaining_tokens = max_tokens - user_input_tokens |
| if remaining_tokens < 0: |
| logger.warning(f"User input ({user_input_tokens} tokens) exceeds total budget ({max_tokens})") |
| remaining_tokens = 0 |
| |
| logger.info(f"Token allocation: User input={user_input_tokens}, Context budget={remaining_tokens}, Total={max_tokens}") |
| |
| |
| priority_elements = [ |
| ('recent_interactions', 0.8), |
| ('user_preferences', 0.6), |
| ('session_summary', 0.4), |
| ('historical_context', 0.2) |
| ] |
| |
| formatted_context = [] |
| total_tokens = user_input_tokens |
| |
| |
| if actual_user_input: |
| formatted_context.append(f"=== USER INPUT ===\n{actual_user_input}") |
| |
| |
| for element, priority in priority_elements: |
| element_key_map = { |
| 'recent_interactions': raw_context.get('interaction_contexts', []), |
| 'user_preferences': raw_context.get('preferences', {}), |
| 'session_summary': raw_context.get('session_context', {}), |
| 'historical_context': raw_context.get('user_context', '') |
| } |
| |
| content = element_key_map.get(element, '') |
| |
| |
| if isinstance(content, dict): |
| content = str(content) |
| elif isinstance(content, list): |
| content = "\n".join([str(item) for item in content[:10]]) |
| |
| if not content: |
| continue |
| |
| |
| tokens = len(content) // 4 |
| |
| if total_tokens + tokens <= max_tokens: |
| formatted_context.append(f"=== {element.upper()} ===\n{content}") |
| total_tokens += tokens |
| elif priority > 0.5 and remaining_tokens > 0: |
| available = max_tokens - total_tokens |
| if available > 100: |
| truncated = self._truncate_to_tokens(content, available) |
| formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}") |
| total_tokens += available |
| break |
| |
| logger.info(f"Context prepared: {total_tokens}/{max_tokens} tokens (user input: {user_input_tokens}, context: {total_tokens - user_input_tokens})") |
| return "\n\n".join(formatted_context) |
| |
| def _truncate_to_tokens(self, content: str, max_tokens: int) -> str: |
| """Truncate content to fit within token limit""" |
| |
| max_chars = max_tokens * 4 |
| if len(content) <= max_chars: |
| return content |
| return content[:max_chars - 3] + "..." |
|
|