# Source: ml-intern-local-fork / production/tests/test_fallback.py
# Uploaded by raazkumar ("Upload production/tests/test_fallback.py"), commit f36c6d0 (verified).
"""Integration tests for NIM ↔ Cloudflare ↔ Gemini fallback logic."""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import sys
import os
import time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from production_server import (
FallbackManager,
FallbackConfig,
CircuitBreaker,
RedisManager,
HTTPException,
)
@pytest.fixture
def mock_redis():
    """Return a ``RedisManager`` stand-in whose async methods are ``AsyncMock``s.

    Defaults model a healthy system:

    - ``get_circuit_state`` reports a closed circuit with no recorded failures.
    - ``get_cache`` returns ``None`` (presumably a cache miss).
    - ``check_rate_limit`` returns ``(True, 0.0)``.

    Individual tests replace ``get_circuit_state`` to script per-provider
    circuit states.

    This is deliberately a plain (synchronous) fixture: it awaits nothing,
    and an ``async def`` fixture decorated with ``@pytest.fixture`` is not
    awaited by pytest-asyncio in strict mode, so tests would receive a
    coroutine object instead of the mock.
    """
    redis = MagicMock(spec=RedisManager)
    redis.get_circuit_state = AsyncMock(
        return_value={"state": "closed", "failures": 0, "last_failure": 0}
    )
    redis.set_circuit_state = AsyncMock()
    redis.get_cache = AsyncMock(return_value=None)
    redis.set_cache = AsyncMock()
    # NOTE(review): tuple presumably means (allowed, retry_after) — confirm
    # against RedisManager.check_rate_limit.
    redis.check_rate_limit = AsyncMock(return_value=(True, 0.0))
    return redis
class TestFallbackManager:
    """Provider selection performed by ``FallbackManager.get_active_provider``.

    The chain configured here is nim -> cloudflare -> gemini, with mlx
    expected as a last resort when ``production_server.MLX_ENABLED`` is true.

    Circuit states are scripted through ``mock_redis.get_circuit_state``.
    Each entry of a ``side_effect`` list answers one successive lookup, so
    the scripts below assume providers are consulted in that order.

    Module-level settings of ``production_server`` are overridden with
    pytest's ``monkeypatch``. It restores the original value after each
    test, even when an assertion fails.
    """

    # Open circuit whose last failure lies far in the future. With a
    # negative "time since failure", the breaker presumably never reaches its
    # recovery window during a test.
    OPEN = {"state": "open", "failures": 5, "last_failure": 9999999999}
    # Healthy circuit, identical to the ``mock_redis`` fixture's default.
    CLOSED = {"state": "closed", "failures": 0, "last_failure": 0}

    @staticmethod
    def _manager(redis, *, enabled=True):
        """Build a manager over *redis* with the nim -> cloudflare -> gemini chain."""
        return FallbackManager(
            redis,
            FallbackConfig(
                primary="nim",
                secondary="cloudflare",
                tertiary="gemini",
                enabled=enabled,
            ),
        )

    @staticmethod
    def _script_circuits(redis, *states):
        """Make successive circuit lookups on *redis* return *states* in order.

        Each record is copied, so code under test cannot mutate the shared
        class-level templates.
        """
        redis.get_circuit_state = AsyncMock(side_effect=[dict(s) for s in states])

    @pytest.mark.asyncio
    async def test_uses_primary_when_healthy(self, mock_redis):
        provider, config = await self._manager(mock_redis).get_active_provider()
        assert provider == "nim"
        # Presumably the built-in default for NIM_API_BASE; this fails if the
        # default is overridden in the test environment.
        assert config["api_base"] == "https://integrate.api.nvidia.com/v1"

    @pytest.mark.asyncio
    async def test_falls_back_when_primary_open(self, mock_redis):
        self._script_circuits(mock_redis, self.OPEN, self.CLOSED)
        provider, _ = await self._manager(mock_redis).get_active_provider()
        assert provider == "cloudflare"

    @pytest.mark.asyncio
    async def test_falls_to_tertiary_when_secondary_open(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.GEMINI_API_KEY", "test-key")
        self._script_circuits(mock_redis, self.OPEN, self.OPEN, self.CLOSED)
        provider, config = await self._manager(mock_redis).get_active_provider()
        assert provider == "gemini"
        assert "generativelanguage" in config["api_base"]
        assert config["api_key"] == "test-key"

    @pytest.mark.asyncio
    async def test_falls_to_mlx_when_all_cloud_down(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.MLX_ENABLED", True)
        # Configure gemini so that "down" means an open circuit rather than a
        # missing key, independent of the environment.
        # NOTE(review): the four-entry script assumes gemini's circuit is
        # consulted; an unconfigured gemini may be skipped without a lookup.
        monkeypatch.setattr("production_server.GEMINI_API_KEY", "test-key")
        self._script_circuits(mock_redis, self.OPEN, self.OPEN, self.OPEN, self.CLOSED)
        provider, _ = await self._manager(mock_redis).get_active_provider()
        assert provider == "mlx"

    @pytest.mark.asyncio
    async def test_raises_when_all_down(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.MLX_ENABLED", False)
        # Same reasoning as the mlx test: make gemini's outcome depend on its
        # circuit, not on whether a key happens to be configured.
        monkeypatch.setattr("production_server.GEMINI_API_KEY", "test-key")
        self._script_circuits(mock_redis, self.OPEN, self.OPEN, self.OPEN)
        with pytest.raises(HTTPException) as exc_info:
            await self._manager(mock_redis).get_active_provider()
        assert exc_info.value.status_code == 503

    @pytest.mark.asyncio
    async def test_respects_disabled_fallback(self, mock_redis):
        manager = self._manager(mock_redis, enabled=False)
        provider, _ = await manager.get_active_provider()
        assert provider == "nim"

    @pytest.mark.asyncio
    async def test_gemini_config(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.GEMINI_API_KEY", "gemini-test-key")
        self._script_circuits(mock_redis, self.OPEN, self.OPEN, self.CLOSED)
        provider, config = await self._manager(mock_redis).get_active_provider()
        assert provider == "gemini"
        assert config["rpm_limit"] == 60
        assert config["cost_per_1m_input"] == 0.075

    @pytest.mark.asyncio
    async def test_nim_config(self, mock_redis, monkeypatch):
        monkeypatch.setattr("production_server.NIM_API_BASE", "https://custom.nvidia.com/v1")
        provider, config = await self._manager(mock_redis).get_active_provider()
        assert provider == "nim"
        assert config["api_base"] == "https://custom.nvidia.com/v1"
class TestCircuitBreakerFallback:
    """State transitions observed through a single provider's ``CircuitBreaker``."""

    @pytest.mark.asyncio
    async def test_circuit_open_then_half_open_then_closed(self, mock_redis):
        """Walk a breaker from closed, to open, through recovery, back to closed.

        The persisted circuit record is swapped by hand between phases. This
        test therefore checks how ``can_execute`` reacts to each stored
        state, not what ``record_failure``/``record_success`` write.
        """

        def stored_state(**record):
            # Every later lookup on the mock returns this record.
            mock_redis.get_circuit_state = AsyncMock(return_value=record)

        breaker = CircuitBreaker(mock_redis, "nim")

        # Fixture default: closed circuit, so calls are allowed.
        assert await breaker.can_execute()

        for _attempt in range(5):
            await breaker.record_failure()

        # Just tripped: open with a fresh failure timestamp, so calls are refused.
        stored_state(state="open", failures=5, last_failure=time.time())
        assert not await breaker.can_execute()

        # Last failure 120 s ago: the breaker lets a call through again
        # (presumably half-open; assumes the recovery timeout is under 120 s).
        stored_state(state="open", failures=5, last_failure=time.time() - 120)
        assert await breaker.can_execute()

        await breaker.record_success()

        # Recovered: closed again, calls allowed.
        stored_state(state="closed", failures=0, last_failure=0)
        assert await breaker.can_execute()
if __name__ == "__main__":
    # Propagate pytest's exit status so a failing run is visible to the shell/CI;
    # discarding pytest.main()'s return value made the script always exit 0.
    sys.exit(pytest.main([__file__, "-v"]))