"""Integration tests for NIM ↔ Cloudflare ↔ Gemini fallback logic.""" import asyncio import pytest from unittest.mock import AsyncMock, MagicMock, patch import sys import os import time sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) from production_server import ( FallbackManager, FallbackConfig, CircuitBreaker, RedisManager, HTTPException, ) @pytest.fixture async def mock_redis(): redis = MagicMock(spec=RedisManager) redis.get_circuit_state = AsyncMock(return_value={"state": "closed", "failures": 0, "last_failure": 0}) redis.set_circuit_state = AsyncMock() redis.get_cache = AsyncMock(return_value=None) redis.set_cache = AsyncMock() redis.check_rate_limit = AsyncMock(return_value=(True, 0.0)) return redis class TestFallbackManager: @pytest.mark.asyncio async def test_uses_primary_when_healthy(self, mock_redis): mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True)) provider, config = await mgr.get_active_provider() assert provider == "nim" assert config["api_base"] == "https://integrate.api.nvidia.com/v1" @pytest.mark.asyncio async def test_falls_back_when_primary_open(self, mock_redis): mock_redis.get_circuit_state = AsyncMock(side_effect=[ {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "closed", "failures": 0, "last_failure": 0}, ]) mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True)) provider, config = await mgr.get_active_provider() assert provider == "cloudflare" @pytest.mark.asyncio async def test_falls_to_tertiary_when_secondary_open(self, mock_redis): import production_server old_gemini = production_server.GEMINI_API_KEY production_server.GEMINI_API_KEY = "test-key" try: mock_redis.get_circuit_state = AsyncMock(side_effect=[ {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "closed", "failures": 0, "last_failure": 0}, ]) mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True)) provider, config = await mgr.get_active_provider() assert provider == "gemini" assert "generativelanguage" in config["api_base"] assert config["api_key"] == "test-key" finally: production_server.GEMINI_API_KEY = old_gemini @pytest.mark.asyncio async def test_falls_to_mlx_when_all_cloud_down(self, mock_redis): import production_server old_mlx = production_server.MLX_ENABLED production_server.MLX_ENABLED = True try: mock_redis.get_circuit_state = AsyncMock(side_effect=[ {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "closed", "failures": 0, "last_failure": 0}, ]) mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True)) provider, config = await mgr.get_active_provider() assert provider == "mlx" finally: production_server.MLX_ENABLED = old_mlx @pytest.mark.asyncio async def test_raises_when_all_down(self, mock_redis): import production_server old_mlx = production_server.MLX_ENABLED production_server.MLX_ENABLED = False try: mock_redis.get_circuit_state = AsyncMock(side_effect=[ {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "open", "failures": 5, "last_failure": 9999999999}, ]) mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True)) with pytest.raises(HTTPException) as exc_info: await mgr.get_active_provider() assert exc_info.value.status_code == 503 finally: production_server.MLX_ENABLED = old_mlx @pytest.mark.asyncio async def test_respects_disabled_fallback(self, mock_redis): mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=False)) provider, config = await mgr.get_active_provider() assert provider == "nim" @pytest.mark.asyncio async def test_gemini_config(self, mock_redis): import production_server old_gemini = production_server.GEMINI_API_KEY production_server.GEMINI_API_KEY = "gemini-test-key" try: mock_redis.get_circuit_state = AsyncMock(side_effect=[ {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "open", "failures": 5, "last_failure": 9999999999}, {"state": "closed", "failures": 0, "last_failure": 0}, ]) mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True)) provider, config = await mgr.get_active_provider() assert provider == "gemini" assert config["rpm_limit"] == 60 assert config["cost_per_1m_input"] == 0.075 finally: production_server.GEMINI_API_KEY = old_gemini @pytest.mark.asyncio async def test_nim_config(self, mock_redis): import production_server old_nim = production_server.NIM_API_BASE production_server.NIM_API_BASE = "https://custom.nvidia.com/v1" try: mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True)) provider, config = await mgr.get_active_provider() assert provider == "nim" assert config["api_base"] == "https://custom.nvidia.com/v1" finally: production_server.NIM_API_BASE = old_nim class TestCircuitBreakerFallback: @pytest.mark.asyncio async def test_circuit_open_then_half_open_then_closed(self, mock_redis): cb = CircuitBreaker(mock_redis, "nim") assert await cb.can_execute() for _ in range(5): await cb.record_failure() mock_redis.get_circuit_state = AsyncMock(return_value={ "state": "open", "failures": 5, "last_failure": time.time() }) assert not await cb.can_execute() mock_redis.get_circuit_state = AsyncMock(return_value={ "state": "open", "failures": 5, "last_failure": time.time() - 120 }) assert await cb.can_execute() await cb.record_success() mock_redis.get_circuit_state = AsyncMock(return_value={ "state": "closed", "failures": 0, "last_failure": 0 }) assert await cb.can_execute() if __name__ == "__main__": pytest.main([__file__, "-v"])