| """ |
| Integration tests for ml-intern production server. |
| |
| Tests cover: |
| - Rate limiting across distributed instances |
| - Circuit breaker state transitions |
| - Cache hit/miss behavior |
| - Budget enforcement |
| - Session isolation |
| - Health check endpoints |
| - Graceful shutdown handling |
| """ |
|
|
| import asyncio |
| import hashlib |
| import json |
| import os |
| import sys |
| import time |
| import uuid |
| from unittest.mock import AsyncMock, MagicMock, patch |
|
|
| import pytest |
| import redis.asyncio as aioredis |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
|
|
| from production_server import ( |
| CircuitBreaker, |
| ConcurrencyLimiter, |
| CostTracker, |
| RedisManager, |
| generate_cache_key, |
| get_provider_from_model, |
| estimate_cost, |
| ChatRequest, |
| ) |
|
|
|
|
@pytest.fixture
def redis_manager():
    """Provide a mock stand-in for ``RedisManager`` that needs no Redis server.

    The double is a ``MagicMock`` constrained by ``spec=RedisManager`` with
    these stubs:

    * ``get_cache`` resolves to ``None``; ``set_cache`` and ``delete_cache``
      are bare ``AsyncMock`` objects.
    * ``check_rate_limit(key, rpm)`` always resolves to ``(True, 0.0)``.
    * ``get_circuit_state(provider)`` always reports a closed circuit with
      zero failures.
    * ``set_circuit_state(provider, state)`` accepts the write and discards it.

    The fixture is intentionally synchronous: building the double awaits
    nothing and needs no teardown, and a plain fixture is resolved the same
    way regardless of how pytest-asyncio is configured, whereas an
    ``async def`` fixture under plain ``@pytest.fixture`` relies on
    pytest-asyncio's auto mode.
    """
    manager = MagicMock(spec=RedisManager)

    # Cache layer: every lookup is a miss, writes and deletes are recorded only.
    manager.get_cache = AsyncMock(return_value=None)
    manager.set_cache = AsyncMock()
    manager.delete_cache = AsyncMock()

    async def always_allowed(key, rpm):
        # (allowed, retry_after) -- never throttles.
        return True, 0.0

    async def closed_circuit(provider):
        return {"state": "closed", "failures": 0, "last_failure": 0}

    async def discard_circuit_write(provider, state):
        # Writes are dropped, so tests that need another state must replace
        # ``get_circuit_state`` themselves.
        return None

    manager.check_rate_limit = always_allowed
    manager.get_circuit_state = closed_circuit
    manager.set_circuit_state = discard_circuit_write
    return manager
|
|
|
|
@pytest.fixture
def concurrency_limiter():
    """Provide a fresh ``ConcurrencyLimiter`` constructed with the argument 10."""
    limiter = ConcurrencyLimiter(10)
    return limiter
|
|
|
|
class TestProviderResolution:
    """Checks the ``(provider, model)`` pair returned by ``get_provider_from_model``."""

    @staticmethod
    def _assert_prefix_stripped(model_ids):
        """Assert each ``"<provider>/<name>"`` id resolves to ``(provider, name)``."""
        for model_id in model_ids:
            provider, _, bare_name = model_id.partition("/")
            assert get_provider_from_model(model_id) == (provider, bare_name)

    def test_cloud_providers(self):
        # For these two providers the expected model keeps its full
        # "<provider>/..." form instead of being stripped.
        expectations = {
            "anthropic/claude-opus-4": ("anthropic", "anthropic/claude-opus-4"),
            "openai/gpt-5": ("openai", "openai/gpt-5"),
        }
        for model_id, expected in expectations.items():
            assert get_provider_from_model(model_id) == expected

    def test_free_tier_providers(self):
        self._assert_prefix_stripped(["groq/llama-3.3-70b"])

    def test_nim_provider(self):
        self._assert_prefix_stripped(
            [
                "nim/llama-3-8b",
                "nim/llama-3.1-405b-instruct",
            ]
        )

    def test_local_providers(self):
        self._assert_prefix_stripped(
            [
                "ollama/llama3.1",
                "vllm/llama-3-8b",
                "llamacpp/llama-3-8b",
                "lmstudio/llama-3-8b",
                "mlx/llama-3-8b",
                "tgi/llama-3-8b",
                "local/my-model",
            ]
        )

    def test_default_provider(self):
        # An id without any "<provider>/" prefix is expected to map to huggingface.
        resolved = get_provider_from_model("some-model")
        assert resolved == ("huggingface", "some-model")
|
|
|
|
class TestCostEstimation:
    """Checks ``estimate_cost`` for one million input and one million output tokens."""

    def test_anthropic_cost(self):
        tokens = 1_000_000
        estimated = estimate_cost("anthropic", "claude-opus-4", tokens, tokens)
        # Expected 90.0, accepted within a tolerance of strictly less than 1.
        assert -1 < estimated - 90.0 < 1

    def test_openai_cost(self):
        tokens = 1_000_000
        estimated = estimate_cost("openai", "gpt-5", tokens, tokens)
        # Expected 12.5, accepted within a tolerance of strictly less than 1.
        assert -1 < estimated - 12.5 < 1

    def test_free_providers_zero_cost(self):
        tokens = 1_000_000
        zero_cost_providers = (
            "groq",
            "nim",
            "ollama",
            "vllm",
            "llamacpp",
            "lmstudio",
            "mlx",
            "tgi",
            "local",
            "huggingface",
        )
        for provider in zero_cost_providers:
            estimated = estimate_cost(provider, "test-model", tokens, tokens)
            assert estimated == 0.0, f"Provider {provider} should have zero cost"
|
|
|
|
class TestCacheKeyGeneration:
    """Checks which ``ChatRequest`` fields influence ``generate_cache_key``."""

    @staticmethod
    def _request(content, **extra):
        """Build a single-user-message request for the groq model used throughout."""
        return ChatRequest(
            model="groq/llama-3.3-70b",
            messages=[{"role": "user", "content": content}],
            **extra,
        )

    def test_deterministic_keys(self):
        # Two independently built but identical requests must share a key.
        first_key = generate_cache_key(self._request("Hello", temperature=0.7))
        second_key = generate_cache_key(self._request("Hello", temperature=0.7))
        assert first_key == second_key

    def test_different_content_different_keys(self):
        hello_key = generate_cache_key(self._request("Hello"))
        world_key = generate_cache_key(self._request("World"))
        assert hello_key != world_key

    def test_stream_not_in_cache_key(self):
        # Requests differing only in ``stream`` are expected to share a key.
        buffered_key = generate_cache_key(self._request("Hello", stream=False))
        streamed_key = generate_cache_key(self._request("Hello", stream=True))
        assert buffered_key == streamed_key
|
|
|
|
class TestCircuitBreaker:
    """Checks ``CircuitBreaker.can_execute`` against stubbed circuit states.

    NOTE(review): the states asserted here are injected by replacing
    ``redis_manager.get_circuit_state``; they are not observed as a result of
    ``record_failure`` / ``record_success`` -- confirm whether that is intended.
    """

    @staticmethod
    def _stub_state(manager, state, failures, last_failure):
        """Make ``manager.get_circuit_state`` resolve to the given snapshot."""
        snapshot = {
            "state": state,
            "failures": failures,
            "last_failure": last_failure,
        }
        manager.get_circuit_state = AsyncMock(return_value=snapshot)

    @pytest.mark.asyncio
    async def test_initially_closed(self, redis_manager):
        breaker = CircuitBreaker(redis_manager, "groq")
        permitted = await breaker.can_execute()
        assert permitted

    @pytest.mark.asyncio
    async def test_opens_after_threshold(self, redis_manager):
        breaker = CircuitBreaker(redis_manager, "groq")
        # Presumably 5 matches the breaker's failure threshold -- TODO confirm.
        for _attempt in range(5):
            await breaker.record_failure()
        self._stub_state(redis_manager, "open", 5, time.time())
        permitted = await breaker.can_execute()
        assert not permitted

    @pytest.mark.asyncio
    async def test_half_open_after_timeout(self, redis_manager):
        breaker = CircuitBreaker(redis_manager, "groq")
        # Last failure is placed 120 seconds in the past; presumably longer
        # than the breaker's recovery timeout -- TODO confirm.
        self._stub_state(redis_manager, "open", 5, time.time() - 120)
        permitted = await breaker.can_execute()
        assert permitted

    @pytest.mark.asyncio
    async def test_closes_on_success(self, redis_manager):
        breaker = CircuitBreaker(redis_manager, "groq")
        self._stub_state(redis_manager, "half-open", 0, 0)
        await breaker.record_success()
        self._stub_state(redis_manager, "closed", 0, 0)
        permitted = await breaker.can_execute()
        assert permitted
|
|
|
|
class TestBudgetTracking:
    """Checks ``CostTracker.can_spend`` around the configured budget."""

    @staticmethod
    def _tracker(budget, spent=None):
        """Build a tracker for "session-1", overriding ``spent_usd`` only if given."""
        tracker = CostTracker("session-1", budget_usd=budget)
        if spent is not None:
            tracker.spent_usd = spent
        return tracker

    def test_can_spend_within_budget(self):
        fresh = self._tracker(10.0)
        assert fresh.can_spend(5.0)

    def test_cannot_exceed_budget(self):
        nearly_spent = self._tracker(10.0, spent=8.0)
        # 8.0 already spent + 3.0 requested would exceed the 10.0 budget.
        assert not nearly_spent.can_spend(3.0)

    def test_exact_budget_boundary(self):
        half_spent = self._tracker(10.0, spent=5.0)
        # Landing exactly on the budget is allowed; one cent over is not.
        assert half_spent.can_spend(5.0)
        assert not half_spent.can_spend(5.01)

    def test_zero_budget(self):
        broke = self._tracker(0.0)
        assert not broke.can_spend(0.01)
|
|
|
|
class TestConcurrencyLimiter:
    """Checks that ``ConcurrencyLimiter`` hands out and withholds slots."""

    @pytest.mark.asyncio
    async def test_acquire_release(self):
        """A slot can be taken and returned without raising."""
        limiter = ConcurrencyLimiter(2)
        await limiter.acquire()
        limiter.release()
        # No explicit assertion: the test passes when neither call raises.

    @pytest.mark.asyncio
    async def test_blocks_at_limit(self):
        """A second ``acquire`` waits until the only slot is released."""
        limiter = ConcurrencyLimiter(1)
        await limiter.acquire()

        waiter = asyncio.create_task(limiter.acquire())
        # Give the event loop time to run the waiter up to its blocking point.
        await asyncio.sleep(0.1)
        # Without this check the test would also pass for a limiter that never
        # blocks, because the final wait succeeds either way.
        assert not waiter.done(), "second acquire() should block while the slot is held"

        limiter.release()
        # Once the slot is free the waiter must finish promptly.
        await asyncio.wait_for(waiter, timeout=2.0)
|
|
|
|
class TestRateLimiting:
    """Checks the ``(allowed, retry)`` pair unpacked from ``check_rate_limit``.

    NOTE(review): both tests call a stub installed on a bare ``MagicMock``, so
    they exercise the stub's return value only -- confirm whether a real
    rate limiter should be under test here.
    """

    @pytest.mark.asyncio
    async def test_token_bucket_allows_requests(self):
        limiter_double = MagicMock()
        limiter_double.check_rate_limit = AsyncMock(return_value=(True, 0.0))
        permitted, retry_after = await limiter_double.check_rate_limit("groq:session-1", 40)
        assert permitted
        assert retry_after == 0.0

    @pytest.mark.asyncio
    async def test_token_bucket_denies_when_empty(self):
        limiter_double = MagicMock()
        limiter_double.check_rate_limit = AsyncMock(return_value=(False, 1.5))
        permitted, retry_after = await limiter_double.check_rate_limit("groq:session-1", 40)
        assert not permitted
        assert retry_after > 0
|
|
|
|
class TestEndToEndFlow:
    """Walks one request through rate limit, budget, and circuit breaker checks."""

    @pytest.mark.asyncio
    async def test_full_request_flow(self, redis_manager):
        session_id = str(uuid.uuid4())
        provider, model = "groq", "llama-3.3-70b-versatile"

        # 1. Rate limit, keyed per provider and session.
        rate_key = f"{provider}:{session_id}"
        rate_result = await redis_manager.check_rate_limit(rate_key, 30)
        assert rate_result[0]

        # 2. Budget check for the estimated cost of the call.
        budget = CostTracker(session_id, budget_usd=10.0)
        projected = estimate_cost(provider, model, 1000, 500)
        assert budget.can_spend(projected)

        # 3. Circuit breaker must let the call through.
        breaker = CircuitBreaker(redis_manager, provider)
        assert await breaker.can_execute()

        # 4. Record the outcome of the (simulated) successful call.
        budget.record_spend(projected)
        await breaker.record_success()

        assert budget.spent_usd > 0
        assert budget.spent_usd <= budget.budget_usd
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly exits
    # non-zero when any test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|