| """Integration tests for NIM ↔ Cloudflare ↔ Gemini fallback logic.""" |
|
|
| import asyncio |
| import pytest |
| from unittest.mock import AsyncMock, MagicMock, patch |
|
|
| import sys |
| import os |
| import time |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
|
|
| from production_server import ( |
| FallbackManager, |
| FallbackConfig, |
| CircuitBreaker, |
| RedisManager, |
| HTTPException, |
| ) |
|
|
|
|
@pytest.fixture
def mock_redis():
    """Provide a ``RedisManager`` mock whose circuits all report closed.

    Declared as a plain (synchronous) fixture because it awaits nothing.
    An ``async def`` fixture decorated with ``@pytest.fixture`` can be handed
    to tests as an un-awaited coroutine when pytest-asyncio runs in strict
    mode. That would leave every test using a coroutine instead of this mock.

    Returns:
        A ``MagicMock`` specced on ``RedisManager``. Its async methods are
        ``AsyncMock`` stubs that tests may replace individually.
    """
    redis = MagicMock(spec=RedisManager)
    # Every circuit starts closed with no recorded failures.
    redis.get_circuit_state = AsyncMock(return_value={"state": "closed", "failures": 0, "last_failure": 0})
    redis.set_circuit_state = AsyncMock()
    # Cache always misses.
    redis.get_cache = AsyncMock(return_value=None)
    redis.set_cache = AsyncMock()
    # NOTE(review): presumably (allowed, retry_after); confirm against RedisManager.check_rate_limit.
    redis.check_rate_limit = AsyncMock(return_value=(True, 0.0))
    return redis
|
|
|
|
def _open_state():
    """Return a fresh circuit-state dict for a tripped breaker.

    ``last_failure`` lies far in the future. This presumes the breaker
    compares it against the current time, as the circuit-breaker test in this
    file relies on. Under that assumption the recovery window never elapses
    and the provider stays unavailable for the whole test.
    """
    return {"state": "open", "failures": 5, "last_failure": 9999999999}


def _closed_state():
    """Return a fresh circuit-state dict for a healthy breaker."""
    return {"state": "closed", "failures": 0, "last_failure": 0}


def _make_manager(redis, *, enabled=True):
    """Build a FallbackManager over *redis* with the nim → cloudflare → gemini chain."""
    return FallbackManager(
        redis,
        FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=enabled),
    )


def _queue_circuit_states(redis, *states):
    """Make successive ``get_circuit_state`` awaits on *redis* return *states* in order.

    Each circuit probe consumes one entry. Once the entries are exhausted,
    ``AsyncMock`` raises ``StopAsyncIteration``, so the number of states also
    bounds how many circuits the manager may probe.
    """
    redis.get_circuit_state = AsyncMock(side_effect=list(states))


class TestFallbackManager:
    """Provider selection by ``FallbackManager.get_active_provider`` across circuit states.

    Module globals of ``production_server`` are patched through ``monkeypatch``.
    Patches are undone after each test even if an assertion fails.
    """

    @pytest.mark.asyncio
    async def test_uses_primary_when_healthy(self, mock_redis):
        provider, config = await _make_manager(mock_redis).get_active_provider()
        assert provider == "nim"
        assert config["api_base"] == "https://integrate.api.nvidia.com/v1"

    @pytest.mark.asyncio
    async def test_falls_back_when_primary_open(self, mock_redis):
        _queue_circuit_states(mock_redis, _open_state(), _closed_state())
        provider, _config = await _make_manager(mock_redis).get_active_provider()
        assert provider == "cloudflare"

    @pytest.mark.asyncio
    async def test_falls_to_tertiary_when_secondary_open(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.GEMINI_API_KEY", "test-key")
        _queue_circuit_states(mock_redis, _open_state(), _open_state(), _closed_state())
        provider, config = await _make_manager(mock_redis).get_active_provider()
        assert provider == "gemini"
        assert "generativelanguage" in config["api_base"]
        assert config["api_key"] == "test-key"

    @pytest.mark.asyncio
    async def test_falls_to_mlx_when_all_cloud_down(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.MLX_ENABLED", True)
        # NOTE(review): pinned so the result does not depend on the environment's
        # GEMINI_API_KEY. Confirm whether the manager skips Gemini (and its
        # circuit probe) when no key is configured.
        monkeypatch.setattr("production_server.GEMINI_API_KEY", "test-key")
        _queue_circuit_states(
            mock_redis, _open_state(), _open_state(), _open_state(), _closed_state()
        )
        provider, _config = await _make_manager(mock_redis).get_active_provider()
        assert provider == "mlx"

    @pytest.mark.asyncio
    async def test_raises_when_all_down(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.MLX_ENABLED", False)
        # Pinned for the same environment-independence reason as the MLX test.
        monkeypatch.setattr("production_server.GEMINI_API_KEY", "test-key")
        _queue_circuit_states(mock_redis, _open_state(), _open_state(), _open_state())
        manager = _make_manager(mock_redis)
        with pytest.raises(HTTPException) as exc_info:
            await manager.get_active_provider()
        assert exc_info.value.status_code == 503

    @pytest.mark.asyncio
    async def test_respects_disabled_fallback(self, mock_redis):
        # NOTE(review): the primary is healthy here, so this passes whether or not
        # fallback is enabled. A case with an open primary would exercise the flag
        # once the intended disabled-mode behaviour is confirmed.
        provider, _config = await _make_manager(mock_redis, enabled=False).get_active_provider()
        assert provider == "nim"

    @pytest.mark.asyncio
    async def test_gemini_config(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.GEMINI_API_KEY", "gemini-test-key")
        _queue_circuit_states(mock_redis, _open_state(), _open_state(), _closed_state())
        provider, config = await _make_manager(mock_redis).get_active_provider()
        assert provider == "gemini"
        assert config["rpm_limit"] == 60
        assert config["cost_per_1m_input"] == 0.075

    @pytest.mark.asyncio
    async def test_nim_config(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.NIM_API_BASE", "https://custom.nvidia.com/v1")
        provider, config = await _make_manager(mock_redis).get_active_provider()
        assert provider == "nim"
        assert config["api_base"] == "https://custom.nvidia.com/v1"
|
|
|
|
class TestCircuitBreakerFallback:
    """Walk a CircuitBreaker through open → half-open → closed via stubbed Redis state."""

    @pytest.mark.asyncio
    async def test_circuit_open_then_half_open_then_closed(self, mock_redis):
        breaker = CircuitBreaker(mock_redis, "nim")

        def report_state(state, failures, last_failure):
            # Install a fresh stub so the next can_execute() reads this snapshot.
            snapshot = {"state": state, "failures": failures, "last_failure": last_failure}
            mock_redis.get_circuit_state = AsyncMock(return_value=snapshot)

        # The fixture reports a closed circuit, so calls are admitted.
        assert await breaker.can_execute()

        # Record five failures against the breaker.
        for _attempt in range(5):
            await breaker.record_failure()

        # Open, with a failure recorded just now: calls are rejected.
        report_state("open", 5, time.time())
        assert not await breaker.can_execute()

        # Open, but the last failure is two minutes old: a trial (half-open) call is admitted.
        report_state("open", 5, time.time() - 120)
        assert await breaker.can_execute()

        # The trial succeeds. The circuit then reads as closed and admits calls.
        await breaker.record_success()
        report_state("closed", 0, 0)
        assert await breaker.can_execute()
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly reports
    # failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|