raazkumar commited on
Commit
f36c6d0
·
verified ·
1 Parent(s): 091c7e0

Upload production/tests/test_fallback.py

Browse files
Files changed (1) hide show
  1. production/tests/test_fallback.py +38 -21
production/tests/test_fallback.py CHANGED
@@ -1,6 +1,4 @@
1
- """
2
- Integration tests for NIM ↔ Cloudflare fallback logic.
3
- """
4
 
5
  import asyncio
6
  import pytest
@@ -34,7 +32,7 @@ async def mock_redis():
34
  class TestFallbackManager:
35
  @pytest.mark.asyncio
36
  async def test_uses_primary_when_healthy(self, mock_redis):
37
- mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
38
  provider, config = await mgr.get_active_provider()
39
  assert provider == "nim"
40
  assert config["api_base"] == "https://integrate.api.nvidia.com/v1"
@@ -45,22 +43,42 @@ class TestFallbackManager:
45
  {"state": "open", "failures": 5, "last_failure": 9999999999},
46
  {"state": "closed", "failures": 0, "last_failure": 0},
47
  ])
48
- mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
49
  provider, config = await mgr.get_active_provider()
50
  assert provider == "cloudflare"
51
 
52
  @pytest.mark.asyncio
53
- async def test_falls_to_mlx_when_both_cloud_down(self, mock_redis):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  import production_server
55
  old_mlx = production_server.MLX_ENABLED
56
  production_server.MLX_ENABLED = True
57
  try:
58
  mock_redis.get_circuit_state = AsyncMock(side_effect=[
 
59
  {"state": "open", "failures": 5, "last_failure": 9999999999},
60
  {"state": "open", "failures": 5, "last_failure": 9999999999},
61
  {"state": "closed", "failures": 0, "last_failure": 0},
62
  ])
63
- mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
64
  provider, config = await mgr.get_active_provider()
65
  assert provider == "mlx"
66
  finally:
@@ -75,8 +93,9 @@ class TestFallbackManager:
75
  mock_redis.get_circuit_state = AsyncMock(side_effect=[
76
  {"state": "open", "failures": 5, "last_failure": 9999999999},
77
  {"state": "open", "failures": 5, "last_failure": 9999999999},
 
78
  ])
79
- mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
80
  with pytest.raises(HTTPException) as exc_info:
81
  await mgr.get_active_provider()
82
  assert exc_info.value.status_code == 503
@@ -85,30 +104,28 @@ class TestFallbackManager:
85
 
86
  @pytest.mark.asyncio
87
  async def test_respects_disabled_fallback(self, mock_redis):
88
- mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=False))
89
  provider, config = await mgr.get_active_provider()
90
  assert provider == "nim"
91
 
92
  @pytest.mark.asyncio
93
- async def test_cloudflare_config(self, mock_redis):
94
  import production_server
95
- old_cf_key = production_server.CLOUDFLARE_API_KEY
96
- old_cf_id = production_server.CLOUDFLARE_ACCOUNT_ID
97
- production_server.CLOUDFLARE_API_KEY = "test-key"
98
- production_server.CLOUDFLARE_ACCOUNT_ID = "test-account"
99
  try:
100
  mock_redis.get_circuit_state = AsyncMock(side_effect=[
 
101
  {"state": "open", "failures": 5, "last_failure": 9999999999},
102
  {"state": "closed", "failures": 0, "last_failure": 0},
103
  ])
104
- mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
105
  provider, config = await mgr.get_active_provider()
106
- assert provider == "cloudflare"
107
- assert "api.cloudflare.com" in config["api_base"]
108
- assert config["api_key"] == "test-key"
109
  finally:
110
- production_server.CLOUDFLARE_API_KEY = old_cf_key
111
- production_server.CLOUDFLARE_ACCOUNT_ID = old_cf_id
112
 
113
  @pytest.mark.asyncio
114
  async def test_nim_config(self, mock_redis):
@@ -116,7 +133,7 @@ class TestFallbackManager:
116
  old_nim = production_server.NIM_API_BASE
117
  production_server.NIM_API_BASE = "https://custom.nvidia.com/v1"
118
  try:
119
- mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", enabled=True))
120
  provider, config = await mgr.get_active_provider()
121
  assert provider == "nim"
122
  assert config["api_base"] == "https://custom.nvidia.com/v1"
 
1
+ """Integration tests for NIM ↔ Cloudflare ↔ Gemini fallback logic."""
 
 
2
 
3
  import asyncio
4
  import pytest
 
32
  class TestFallbackManager:
33
  @pytest.mark.asyncio
34
  async def test_uses_primary_when_healthy(self, mock_redis):
35
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True))
36
  provider, config = await mgr.get_active_provider()
37
  assert provider == "nim"
38
  assert config["api_base"] == "https://integrate.api.nvidia.com/v1"
 
43
  {"state": "open", "failures": 5, "last_failure": 9999999999},
44
  {"state": "closed", "failures": 0, "last_failure": 0},
45
  ])
46
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True))
47
  provider, config = await mgr.get_active_provider()
48
  assert provider == "cloudflare"
49
 
50
  @pytest.mark.asyncio
51
+ async def test_falls_to_tertiary_when_secondary_open(self, mock_redis):
52
+ import production_server
53
+ old_gemini = production_server.GEMINI_API_KEY
54
+ production_server.GEMINI_API_KEY = "test-key"
55
+ try:
56
+ mock_redis.get_circuit_state = AsyncMock(side_effect=[
57
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
58
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
59
+ {"state": "closed", "failures": 0, "last_failure": 0},
60
+ ])
61
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True))
62
+ provider, config = await mgr.get_active_provider()
63
+ assert provider == "gemini"
64
+ assert "generativelanguage" in config["api_base"]
65
+ assert config["api_key"] == "test-key"
66
+ finally:
67
+ production_server.GEMINI_API_KEY = old_gemini
68
+
69
+ @pytest.mark.asyncio
70
+ async def test_falls_to_mlx_when_all_cloud_down(self, mock_redis):
71
  import production_server
72
  old_mlx = production_server.MLX_ENABLED
73
  production_server.MLX_ENABLED = True
74
  try:
75
  mock_redis.get_circuit_state = AsyncMock(side_effect=[
76
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
77
  {"state": "open", "failures": 5, "last_failure": 9999999999},
78
  {"state": "open", "failures": 5, "last_failure": 9999999999},
79
  {"state": "closed", "failures": 0, "last_failure": 0},
80
  ])
81
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True))
82
  provider, config = await mgr.get_active_provider()
83
  assert provider == "mlx"
84
  finally:
 
93
  mock_redis.get_circuit_state = AsyncMock(side_effect=[
94
  {"state": "open", "failures": 5, "last_failure": 9999999999},
95
  {"state": "open", "failures": 5, "last_failure": 9999999999},
96
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
97
  ])
98
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True))
99
  with pytest.raises(HTTPException) as exc_info:
100
  await mgr.get_active_provider()
101
  assert exc_info.value.status_code == 503
 
104
 
105
  @pytest.mark.asyncio
106
  async def test_respects_disabled_fallback(self, mock_redis):
107
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=False))
108
  provider, config = await mgr.get_active_provider()
109
  assert provider == "nim"
110
 
111
  @pytest.mark.asyncio
112
+ async def test_gemini_config(self, mock_redis):
113
  import production_server
114
+ old_gemini = production_server.GEMINI_API_KEY
115
+ production_server.GEMINI_API_KEY = "gemini-test-key"
 
 
116
  try:
117
  mock_redis.get_circuit_state = AsyncMock(side_effect=[
118
+ {"state": "open", "failures": 5, "last_failure": 9999999999},
119
  {"state": "open", "failures": 5, "last_failure": 9999999999},
120
  {"state": "closed", "failures": 0, "last_failure": 0},
121
  ])
122
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True))
123
  provider, config = await mgr.get_active_provider()
124
+ assert provider == "gemini"
125
+ assert config["rpm_limit"] == 60
126
+ assert config["cost_per_1m_input"] == 0.075
127
  finally:
128
+ production_server.GEMINI_API_KEY = old_gemini
 
129
 
130
  @pytest.mark.asyncio
131
  async def test_nim_config(self, mock_redis):
 
133
  old_nim = production_server.NIM_API_BASE
134
  production_server.NIM_API_BASE = "https://custom.nvidia.com/v1"
135
  try:
136
+ mgr = FallbackManager(mock_redis, FallbackConfig(primary="nim", secondary="cloudflare", tertiary="gemini", enabled=True))
137
  provider, config = await mgr.get_active_provider()
138
  assert provider == "nim"
139
  assert config["api_base"] == "https://custom.nvidia.com/v1"