raazkumar commited on
Commit
e7b843e
·
verified ·
1 Parent(s): 8296605

Upload production/tests/test_fallback.py

Browse files
Files changed (1) hide show
  1. production/tests/test_fallback.py +150 -0
production/tests/test_fallback.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integration tests for NIM ↔ Cloudflare fallback logic.
3
+ """
4
+
5
+ import asyncio
6
+ import pytest
7
+ from unittest.mock import AsyncMock, MagicMock, patch
8
+
9
+ import sys
10
+ import os
11
+ import time
12
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
13
+
14
+ from production_server import (
15
+ FallbackManager,
16
+ FallbackConfig,
17
+ CircuitBreaker,
18
+ RedisManager,
19
+ HTTPException,
20
+ )
21
+
22
+
23
+ @pytest.fixture
24
+ async def mock_redis():
25
+ redis = MagicMock(spec=RedisManager)
26
+ redis.get_circuit_state = AsyncMock(return_value={"state": "closed", "failures": 0, "last_failure": 0})
27
+ redis.set_circuit_state = AsyncMock()
28
+ redis.get_cache = AsyncMock(return_value=None)
29
+ redis.set_cache = AsyncMock()
30
+ redis.check_rate_limit = AsyncMock(return_value=(True, 0.0))
31
+ return redis
32
+
33
+
34
+ class TestFallbackManager:
35
+ @pytest.mark.asyncio
36
+ async def test_uses_primary_when_healthy(self, mock_redis):
37
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
38
+ provider, config = await mgr.get_active_provider()
39
+ assert provider == "nim"
40
+ assert config["api_base"] == "https://integrate.api.nvidia.com/v1"
41
+
42
+ @pytest.mark.asyncio
43
+ async def test_falls_back_when_primary_open(self, mock_redis):
44
+ mock_redis.get_circuit_state = AsyncMock(side_effect=[
45
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
46
+ {"state": "closed", "failures": 0, "last_failure": 0},
47
+ ])
48
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
49
+ provider, config = await mgr.get_active_provider()
50
+ assert provider == "cloudflare"
51
+
52
+ @pytest.mark.asyncio
53
+ async def test_falls_to_mlx_when_both_cloud_down(self, mock_redis):
54
+ import production_server
55
+ old_mlx = production_server.MLX_ENABLED
56
+ production_server.MLX_ENABLED = True
57
+ try:
58
+ mock_redis.get_circuit_state = AsyncMock(side_effect=[
59
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
60
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
61
+ {"state": "closed", "failures": 0, "last_failure": 0},
62
+ ])
63
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
64
+ provider, config = await mgr.get_active_provider()
65
+ assert provider == "mlx"
66
+ finally:
67
+ production_server.MLX_ENABLED = old_mlx
68
+
69
+ @pytest.mark.asyncio
70
+ async def test_raises_when_all_down(self, mock_redis):
71
+ import production_server
72
+ old_mlx = production_server.MLX_ENABLED
73
+ production_server.MLX_ENABLED = False
74
+ try:
75
+ mock_redis.get_circuit_state = AsyncMock(side_effect=[
76
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
77
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
78
+ ])
79
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
80
+ with pytest.raises(HTTPException) as exc_info:
81
+ await mgr.get_active_provider()
82
+ assert exc_info.value.status_code == 503
83
+ finally:
84
+ production_server.MLX_ENABLED = old_mlx
85
+
86
+ @pytest.mark.asyncio
87
+ async def test_respects_disabled_fallback(self, mock_redis):
88
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=False))
89
+ provider, config = await mgr.get_active_provider()
90
+ assert provider == "nim"
91
+
92
+ @pytest.mark.asyncio
93
+ async def test_cloudflare_config(self, mock_redis):
94
+ import production_server
95
+ old_cf_key = production_server.CLOUDFLARE_API_KEY
96
+ old_cf_id = production_server.CLOUDFLARE_ACCOUNT_ID
97
+ production_server.CLOUDFLARE_API_KEY = "test-key"
98
+ production_server.CLOUDFLARE_ACCOUNT_ID = "test-account"
99
+ try:
100
+ mock_redis.get_circuit_state = AsyncMock(side_effect=[
101
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
102
+ {"state": "closed", "failures": 0, "last_failure": 0},
103
+ ])
104
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
105
+ provider, config = await mgr.get_active_provider()
106
+ assert provider == "cloudflare"
107
+ assert "api.cloudflare.com" in config["api_base"]
108
+ assert config["api_key"] == "test-key"
109
+ finally:
110
+ production_server.CLOUDFLARE_API_KEY = old_cf_key
111
+ production_server.CLOUDFLARE_ACCOUNT_ID = old_cf_id
112
+
113
+ @pytest.mark.asyncio
114
+ async def test_nim_config(self, mock_redis):
115
+ import production_server
116
+ old_nim = production_server.NIM_API_BASE
117
+ production_server.NIM_API_BASE = "https://custom.nvidia.com/v1"
118
+ try:
119
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
120
+ provider, config = await mgr.get_active_provider()
121
+ assert provider == "nim"
122
+ assert config["api_base"] == "https://custom.nvidia.com/v1"
123
+ finally:
124
+ production_server.NIM_API_BASE = old_nim
125
+
126
+
127
+ class TestCircuitBreakerFallback:
128
+ @pytest.mark.asyncio
129
+ async def test_circuit_open_then_half_open_then_closed(self, mock_redis):
130
+ cb = CircuitBreaker(mock_redis, "nim")
131
+ assert await cb.can_execute()
132
+ for _ in range(5):
133
+ await cb.record_failure()
134
+ mock_redis.get_circuit_state = AsyncMock(return_value={
135
+ "state": "open", "failures": 5, "last_failure": time.time()
136
+ })
137
+ assert not await cb.can_execute()
138
+ mock_redis.get_circuit_state = AsyncMock(return_value={
139
+ "state": "open", "failures": 5, "last_failure": time.time() - 120
140
+ })
141
+ assert await cb.can_execute()
142
+ await cb.record_success()
143
+ mock_redis.get_circuit_state = AsyncMock(return_value={
144
+ "state": "closed", "failures": 0, "last_failure": 0
145
+ })
146
+ assert await cb.can_execute()
147
+
148
+
149
+ if __name__ == "__main__":
150
+ pytest.main([__file__, "-v"])