raazkumar committed on
Commit
7f8d8d7
·
verified ·
1 Parent(s): f9ffc4e

Upload production/tests/test_integration.py

Browse files
Files changed (1) hide show
  1. production/tests/test_integration.py +268 -0
production/tests/test_integration.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
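"""
Integration tests for ml-intern production server.

Tests cover:
- Provider resolution from model identifiers
- Cost estimation per provider
- Cache key generation
- Circuit breaker state transitions
- Budget enforcement
- Concurrency limiting
- Rate limiting (against a stubbed check only)
- An end-to-end request flow using a mocked Redis manager

Not yet covered in this file: rate limiting across distributed instances,
cache hit/miss behavior, session isolation, health check endpoints, and
graceful shutdown handling.
"""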
13
+
14
+ import asyncio
15
+ import hashlib
16
+ import json
17
+ import os
18
+ import sys
19
+ import time
20
+ import uuid
21
+ from unittest.mock import AsyncMock, MagicMock, patch
22
+
23
+ import pytest
24
+ import redis.asyncio as aioredis
25
+
26
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
27
+
28
+ from production_server import (
29
+ CircuitBreaker,
30
+ ConcurrencyLimiter,
31
+ CostTracker,
32
+ RedisManager,
33
+ generate_cache_key,
34
+ get_provider_from_model,
35
+ estimate_cost,
36
+ ChatRequest,
37
+ )
38
+
39
+
40
@pytest.fixture
def redis_manager():
    """Provide a ``RedisManager`` stand-in that never talks to Redis.

    The fixture is deliberately synchronous: it awaits nothing during setup,
    and a plain ``@pytest.fixture`` on an ``async def`` is not evaluated by
    pytest-asyncio in strict mode, which would hand tests an async-generator
    object instead of the mock.

    Returns:
        A ``MagicMock(spec=RedisManager)`` whose cache methods are
        ``AsyncMock`` objects, whose rate-limit check always returns
        ``(True, 0.0)``, and whose circuit state is always reported closed.
    """
    manager = MagicMock(spec=RedisManager)
    manager.get_cache = AsyncMock(return_value=None)
    manager.set_cache = AsyncMock()
    manager.delete_cache = AsyncMock()

    async def check_limit(key, rpm):
        return True, 0.0

    async def get_circuit(provider):
        # A fresh dict per call, so one caller mutating its copy cannot
        # leak state into the next lookup.
        return {"state": "closed", "failures": 0, "last_failure": 0}

    async def set_circuit(provider, state):
        pass

    manager.check_rate_limit = check_limit
    manager.get_circuit_state = get_circuit
    manager.set_circuit_state = set_circuit
    return manager
56
+
57
+
58
@pytest.fixture
def concurrency_limiter():
    """Build a fresh ``ConcurrencyLimiter`` constructed with the argument 10."""
    limiter = ConcurrencyLimiter(10)
    return limiter
61
+
62
+
63
class TestProviderResolution:
    """Expectations for ``get_provider_from_model`` on various model ids."""

    @staticmethod
    def _check(expected_by_model):
        """Assert every model id resolves to its expected (provider, model) pair."""
        for model_id, expected in expected_by_model.items():
            assert get_provider_from_model(model_id) == expected

    def test_cloud_providers(self):
        # These expectations keep the full "provider/model" id as the model.
        self._check({
            "anthropic/claude-opus-4": ("anthropic", "anthropic/claude-opus-4"),
            "openai/gpt-5": ("openai", "openai/gpt-5"),
        })

    def test_free_tier_providers(self):
        self._check({
            "groq/llama-3.3-70b": ("groq", "llama-3.3-70b"),
        })

    def test_nim_provider(self):
        self._check({
            "nim/llama-3-8b": ("nim", "llama-3-8b"),
            "nim/llama-3.1-405b-instruct": ("nim", "llama-3.1-405b-instruct"),
        })

    def test_local_providers(self):
        # These expectations have the provider prefix stripped from the model.
        self._check({
            "ollama/llama3.1": ("ollama", "llama3.1"),
            "vllm/llama-3-8b": ("vllm", "llama-3-8b"),
            "llamacpp/llama-3-8b": ("llamacpp", "llama-3-8b"),
            "lmstudio/llama-3-8b": ("lmstudio", "llama-3-8b"),
            "mlx/llama-3-8b": ("mlx", "llama-3-8b"),
            "tgi/llama-3-8b": ("tgi", "llama-3-8b"),
            "local/my-model": ("local", "my-model"),
        })

    def test_default_provider(self):
        # An id without a prefix is expected to resolve to "huggingface".
        self._check({
            "some-model": ("huggingface", "some-model"),
        })
86
+
87
+
88
class TestCostEstimation:
    """Expectations for ``estimate_cost`` at one million tokens per argument."""

    def test_anthropic_cost(self):
        estimate = estimate_cost("anthropic", "claude-opus-4", 1000000, 1000000)
        # Accept anything strictly within one unit of 90.0.
        assert -1 < estimate - 90.0 < 1

    def test_openai_cost(self):
        estimate = estimate_cost("openai", "gpt-5", 1000000, 1000000)
        # Accept anything strictly within one unit of 12.5.
        assert -1 < estimate - 12.5 < 1

    def test_free_providers_zero_cost(self):
        free_providers = (
            "groq",
            "nim",
            "ollama",
            "vllm",
            "llamacpp",
            "lmstudio",
            "mlx",
            "tgi",
            "local",
            "huggingface",
        )
        for provider in free_providers:
            cost = estimate_cost(provider, "test-model", 1000000, 1000000)
            assert cost == 0.0, f"Provider {provider} should have zero cost"
101
+
102
+
103
class TestCacheKeyGeneration:
    """Expectations for ``generate_cache_key`` over ``ChatRequest`` objects."""

    @staticmethod
    def _chat(text, **options):
        """Build a single-user-message request for the groq llama model."""
        return ChatRequest(
            model="groq/llama-3.3-70b",
            messages=[{"role": "user", "content": text}],
            **options,
        )

    def test_deterministic_keys(self):
        # Two separately built but identical requests must share a key.
        first = self._chat("Hello", temperature=0.7)
        second = self._chat("Hello", temperature=0.7)
        assert generate_cache_key(first) == generate_cache_key(second)

    def test_different_content_different_keys(self):
        hello = self._chat("Hello")
        world = self._chat("World")
        assert generate_cache_key(hello) != generate_cache_key(world)

    def test_stream_not_in_cache_key(self):
        # Only the ``stream`` flag differs, and the keys are expected to match.
        buffered = self._chat("Hello", stream=False)
        streamed = self._chat("Hello", stream=True)
        assert generate_cache_key(buffered) == generate_cache_key(streamed)
140
+
141
+
142
class TestCircuitBreaker:
    """Expectations for ``CircuitBreaker`` against a stubbed circuit state."""

    @staticmethod
    def _stub_state(manager, state, failures, last_failure):
        """Make ``manager.get_circuit_state`` report the given circuit record."""
        record = {
            "state": state,
            "failures": failures,
            "last_failure": last_failure,
        }
        manager.get_circuit_state = AsyncMock(return_value=record)

    @pytest.mark.asyncio
    async def test_initially_closed(self, redis_manager):
        breaker = CircuitBreaker(redis_manager, "groq")
        allowed = await breaker.can_execute()
        assert allowed

    @pytest.mark.asyncio
    async def test_opens_after_threshold(self, redis_manager):
        breaker = CircuitBreaker(redis_manager, "groq")
        for _attempt in range(5):
            await breaker.record_failure()
        # The open state is stubbed in after the failures are recorded;
        # an open circuit that failed just now is expected to refuse execution.
        self._stub_state(redis_manager, "open", 5, time.time())
        allowed = await breaker.can_execute()
        assert not allowed

    @pytest.mark.asyncio
    async def test_half_open_after_timeout(self, redis_manager):
        breaker = CircuitBreaker(redis_manager, "groq")
        # The last failure is reported as 120 seconds in the past.
        self._stub_state(redis_manager, "open", 5, time.time() - 120)
        allowed = await breaker.can_execute()
        assert allowed

    @pytest.mark.asyncio
    async def test_closes_on_success(self, redis_manager):
        breaker = CircuitBreaker(redis_manager, "groq")
        self._stub_state(redis_manager, "half-open", 0, 0)
        await breaker.record_success()
        self._stub_state(redis_manager, "closed", 0, 0)
        allowed = await breaker.can_execute()
        assert allowed
185
+
186
+
187
class TestBudgetTracking:
    """Expectations for ``CostTracker.can_spend`` around the budget limit."""

    @staticmethod
    def _tracker(budget, spent=None):
        """Create a tracker for "session-1", optionally presetting ``spent_usd``."""
        tracker = CostTracker("session-1", budget_usd=budget)
        if spent is not None:
            tracker.spent_usd = spent
        return tracker

    def test_can_spend_within_budget(self):
        assert self._tracker(10.0).can_spend(5.0)

    def test_cannot_exceed_budget(self):
        nearly_spent = self._tracker(10.0, spent=8.0)
        assert not nearly_spent.can_spend(3.0)

    def test_exact_budget_boundary(self):
        half_spent = self._tracker(10.0, spent=5.0)
        # Landing exactly on the budget is allowed; one cent over is not.
        assert half_spent.can_spend(5.0)
        assert not half_spent.can_spend(5.01)

    def test_zero_budget(self):
        assert not self._tracker(0.0).can_spend(0.01)
206
+
207
+
208
class TestConcurrencyLimiter:
    """Checks that ``ConcurrencyLimiter`` hands out and reclaims slots."""

    @pytest.mark.asyncio
    async def test_acquire_release(self):
        """A released slot must become available again."""
        limiter = ConcurrencyLimiter(2)
        await limiter.acquire()
        limiter.release()
        # With a limit of 2, both acquires below can only finish promptly if
        # the release above really returned its slot; otherwise the second
        # one times out and fails the test.
        await asyncio.wait_for(limiter.acquire(), timeout=1.0)
        await asyncio.wait_for(limiter.acquire(), timeout=1.0)
        limiter.release()
        limiter.release()

    @pytest.mark.asyncio
    async def test_blocks_at_limit(self):
        """A second acquire on a limit-1 limiter waits until the first releases."""
        limiter = ConcurrencyLimiter(1)
        await limiter.acquire()
        waiter = asyncio.create_task(limiter.acquire())
        try:
            # Give the waiter a chance to run; it must still be pending.
            await asyncio.sleep(0.1)
            assert not waiter.done(), "acquire() should block while the limit is reached"
            limiter.release()
            await asyncio.wait_for(waiter, timeout=2.0)
        finally:
            # Do not leave a pending task behind if an assertion failed.
            if not waiter.done():
                waiter.cancel()
        limiter.release()
224
+
225
+
226
class TestRateLimiting:
    """Checks the (allowed, retry) result shape of a rate-limit call.

    NOTE(review): both tests call a stub attached to a bare ``MagicMock``;
    no production rate-limiting code is executed here.
    """

    @pytest.mark.asyncio
    async def test_token_bucket_allows_requests(self):
        manager = MagicMock()
        manager.check_rate_limit = AsyncMock(return_value=(True, 0.0))
        outcome = await manager.check_rate_limit("groq:session-1", 40)
        allowed, retry = outcome
        assert allowed
        assert retry == 0.0

    @pytest.mark.asyncio
    async def test_token_bucket_denies_when_empty(self):
        manager = MagicMock()
        manager.check_rate_limit = AsyncMock(return_value=(False, 1.5))
        outcome = await manager.check_rate_limit("groq:session-1", 40)
        allowed, retry = outcome
        assert not allowed
        assert retry > 0
246
+
247
+
248
class TestEndToEndFlow:
    """Walks one request through rate limit, budget, and circuit breaker."""

    @pytest.mark.asyncio
    async def test_full_request_flow(self, redis_manager):
        """Run the happy path for a groq request against the mocked manager."""
        session_id = str(uuid.uuid4())
        provider = "groq"
        model = "llama-3.3-70b-versatile"

        # 1. Rate limit check.
        allowed, _ = await redis_manager.check_rate_limit(f"{provider}:{session_id}", 30)
        assert allowed

        # 2. Budget check against the estimated cost.
        tracker = CostTracker(session_id, budget_usd=10.0)
        estimated_cost = estimate_cost(provider, model, 1000, 500)
        assert tracker.can_spend(estimated_cost)

        # 3. Circuit breaker must let the call through.
        cb = CircuitBreaker(redis_manager, provider)
        assert await cb.can_execute()

        # 4. Record the outcome.
        spent_before = tracker.spent_usd
        tracker.record_spend(estimated_cost)
        await cb.record_success()

        # Assert the spend delta rather than ``spent_usd > 0``: a sibling test
        # in this file asserts groq estimates are 0.0, so a strictly-positive
        # check would contradict it.
        # NOTE(review): assumes record_spend adds its argument to spent_usd.
        assert tracker.spent_usd - spent_before == pytest.approx(estimated_cost)
        assert tracker.spent_usd <= tracker.budget_usd
265
+
266
+
267
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly reports
    # failures to the shell or CI instead of always exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))