"""Dependency injection for FastAPI."""

import hmac
from urllib.parse import unquote_plus

from fastapi import Depends, HTTPException, Request
from loguru import logger

from config.settings import Settings
from config.settings import get_settings as _get_settings
from providers.base import BaseProvider, ProviderConfig
from providers.common import get_user_facing_error_message
from providers.exceptions import AuthenticationError
from providers.llamacpp import LlamaCppProvider
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider

# Provider registry: keyed by provider type string, lazily populated
_providers: dict[str, BaseProvider] = {}


def get_settings() -> Settings:
    """Get application settings via dependency injection."""
    return _get_settings()


def _build_provider_config(
    settings: Settings, *, api_key: str, base_url: str
) -> ProviderConfig:
    """Build a ProviderConfig from the shared rate-limit/timeout settings.

    Only ``api_key`` and ``base_url`` differ between providers; every other
    field is copied from ``settings``.
    """
    return ProviderConfig(
        api_key=api_key,
        base_url=base_url,
        rate_limit=settings.provider_rate_limit,
        rate_window=settings.provider_rate_window,
        max_concurrency=settings.provider_max_concurrency,
        http_read_timeout=settings.http_read_timeout,
        http_write_timeout=settings.http_write_timeout,
        http_connect_timeout=settings.http_connect_timeout,
    )


def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider:
    """Construct and return a new provider instance for the given provider type.

    Args:
        provider_type: One of 'nvidia_nim', 'open_router', 'lmstudio',
            'llamacpp'.
        settings: Application settings supplying keys, base URLs and limits.

    Returns:
        A freshly constructed provider (not cached here).

    Raises:
        AuthenticationError: The provider needs an API key and the configured
            key is missing or blank.
        ValueError: ``provider_type`` is not one of the supported values.
    """
    if provider_type == "nvidia_nim":
        if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
            raise AuthenticationError(
                "NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
                "Get a key at https://build.nvidia.com/settings/api-keys"
            )
        config = _build_provider_config(
            settings,
            api_key=settings.nvidia_nim_api_key,
            base_url=NVIDIA_NIM_BASE_URL,
        )
        return NvidiaNimProvider(config, nim_settings=settings.nim)

    if provider_type == "open_router":
        if not settings.open_router_api_key or not settings.open_router_api_key.strip():
            raise AuthenticationError(
                "OPENROUTER_API_KEY is not set. Add it to your .env file. "
                "Get a key at https://openrouter.ai/keys"
            )
        config = _build_provider_config(
            settings,
            api_key=settings.open_router_api_key,
            base_url=OPENROUTER_BASE_URL,
        )
        return OpenRouterProvider(config)

    if provider_type == "lmstudio":
        # Local server: the key is a fixed placeholder, not a secret.
        config = _build_provider_config(
            settings,
            api_key="lm-studio",
            base_url=settings.lm_studio_base_url,
        )
        return LMStudioProvider(config)

    if provider_type == "llamacpp":
        # Local server: the key is a fixed placeholder, not a secret.
        config = _build_provider_config(
            settings,
            api_key="llamacpp",
            base_url=settings.llamacpp_base_url,
        )
        return LlamaCppProvider(config)

    logger.error(
        "Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio', 'llamacpp'",
        provider_type,
    )
    raise ValueError(
        f"Unknown provider_type: '{provider_type}'. "
        f"Supported: 'nvidia_nim', 'open_router', 'lmstudio', 'llamacpp'"
    )


def get_provider_for_type(provider_type: str) -> BaseProvider:
    """Get or create a provider for the given provider type.

    Providers are cached in the registry and reused across requests.

    Raises:
        HTTPException: 503 when the provider cannot be created because its
            API key is missing (wraps ``AuthenticationError``).
        ValueError: Propagated unchanged for an unknown ``provider_type``.
    """
    if provider_type not in _providers:
        try:
            _providers[provider_type] = _create_provider_for_type(
                provider_type, get_settings()
            )
        except AuthenticationError as e:
            raise HTTPException(
                status_code=503, detail=get_user_facing_error_message(e)
            ) from e
        logger.info("Provider initialized: {}", provider_type)
    return _providers[provider_type]


def _strip_model_suffix(token: str | None) -> str | None:
    """Return ``token`` with anything from the first colon onwards removed.

    Clients may append a model name as ``<token>:<model>``; only the part
    before the colon is the credential. A missing or empty token yields
    ``None``.
    """
    if not token:
        return None
    return token.partition(":")[0]


def _tokens_match(candidate: str, expected: str) -> bool:
    """Compare two tokens in constant time.

    Both sides are encoded to bytes first because ``hmac.compare_digest``
    rejects ``str`` arguments containing non-ASCII characters, which a
    client-controlled header may contain.
    """
    return hmac.compare_digest(candidate.encode("utf-8"), expected.encode("utf-8"))


def validate_request_api_key(request: Request, settings: Settings) -> None:
    """Validate a request against configured server API key.

    Checks the `x-api-key`, `Authorization` (raw or `Bearer ...`) or
    `anthropic-auth-token` header, or the query parameter `psw`, against
    `Settings.anthropic_auth_token`. If `ANTHROPIC_AUTH_TOKEN` is empty,
    this is a no-op.

    Supports Hugging Face Spaces private deployments via query parameter
    authentication:
    - Append `?psw=your-token` to the base URL
    - Or `?psw:your-token` (URL-encoded colon becomes %3A)

    The query string is only consulted when none of the headers is present.

    Raises:
        HTTPException: 401 "Missing API key" when no token was supplied,
            401 "Invalid API key" when it does not match.
    """
    anthropic_auth_token = getattr(settings, "anthropic_auth_token", None)
    if not anthropic_auth_token:
        # No API key configured -> allow
        return

    # Keep Space health/app shell reachable even when API auth is enabled.
    if _is_public_probe_request(request):
        return

    # Allow Hugging Face private Space signed browser requests for UI pages.
    # This keeps API routes protected while avoiding 401 on Space shell probes.
    if _is_hf_signed_page_request(request):
        return

    # Check headers first (preferred)
    header = (
        request.headers.get("x-api-key")
        or request.headers.get("authorization")
        or request.headers.get("anthropic-auth-token")
    )
    if header:
        # Support both raw key in X-API-Key and Bearer token in Authorization
        token = header
        if header.lower().startswith("bearer "):
            token = header.split(" ", 1)[1]
        # Strip anything after the first colon to handle tokens with
        # appended model names
        token = _strip_model_suffix(token)
    else:
        token = _extract_query_token(request)

    if not token:
        raise HTTPException(status_code=401, detail="Missing API key")

    # Constant-time comparison: a plain `!=` short-circuits on the first
    # differing character and leaks timing information about the secret.
    if not _tokens_match(token, anthropic_auth_token):
        raise HTTPException(status_code=401, detail="Invalid API key")


def _extract_query_token(request: Request) -> str | None:
    """Extract auth token from query string for private proxy deployments.

    Accepts `psw=<token>`, and the raw forms `psw:<token>` / `psw%3A<token>`
    (either hex case) which do not parse as a normal key/value pair. Any
    `:<suffix>` after the token is dropped.

    Returns:
        The token, or ``None`` when no non-empty token is present.
    """
    query_params = request.query_params
    if "psw" in query_params:
        return _strip_model_suffix(query_params["psw"])

    raw_query_bytes = request.scope.get("query_string", b"")
    raw_query = raw_query_bytes.decode("utf-8", errors="ignore")
    if not raw_query:
        return None

    for part in raw_query.split("&"):
        if part.startswith("psw:"):
            return _strip_model_suffix(unquote_plus(part[len("psw:") :]))
        # "%3A" and "%3a" have the same length, so one slice covers both.
        if part.startswith(("psw%3A", "psw%3a")):
            return _strip_model_suffix(unquote_plus(part[len("psw%3A") :]))
    return None


def _is_hf_signed_page_request(request: Request) -> bool:
    """Return True for Hugging Face signed browser requests to non-API pages.

    A request qualifies when it is a GET/HEAD outside `/v1/`, carries a
    `__sign` query parameter, and its Accept header admits `text/html` or
    `*/*`.
    """
    if request.method not in {"GET", "HEAD"}:
        return False
    if request.url.path.startswith("/v1/"):
        return False
    if "__sign" not in request.query_params:
        return False
    accept = request.headers.get("accept", "").lower()
    return "text/html" in accept or "*/*" in accept


def _is_public_probe_request(request: Request) -> bool:
    """Return True for read-only public endpoints used by app shells/health checks.

    Only GET/HEAD requests to exactly `/` or `/health` qualify.
    """
    if request.method not in {"GET", "HEAD"}:
        return False
    return request.url.path in {"/", "/health"}


def require_api_key(
    request: Request, settings: Settings = Depends(get_settings)
) -> None:
    """FastAPI dependency wrapper for API key validation."""
    validate_request_api_key(request, settings)


def get_provider() -> BaseProvider:
    """Get or create the default provider (based on MODEL env var).

    Looks up the provider named by `get_settings().provider_type`.
    Backward-compatible convenience for health/root endpoints and tests.
    """
    return get_provider_for_type(get_settings().provider_type)


async def cleanup_provider() -> None:
    """Cleanup all provider resources.

    Awaits `cleanup()` on every cached provider in turn, then replaces the
    registry with an empty dict.
    """
    global _providers
    for provider in _providers.values():
        await provider.cleanup()
    _providers = {}
    logger.debug("Provider cleanup completed")