from unittest.mock import MagicMock, patch from fastapi.testclient import TestClient from api.app import app from providers.nvidia_nim import NvidiaNimProvider # Mock provider mock_provider = MagicMock(spec=NvidiaNimProvider) # Track stream_response calls for test_model_mapping _stream_response_calls = [] async def _mock_stream_response(*args, **kwargs): """Minimal async generator for streaming tests.""" _stream_response_calls.append((args, kwargs)) yield "event: message_start\ndata: {}\n\n" yield "[DONE]\n\n" mock_provider.stream_response = _mock_stream_response # Patch get_provider_for_type to always return mock_provider _patcher = patch("api.routes.get_provider_for_type", return_value=mock_provider) _patcher.start() client = TestClient(app) def test_root(): response = client.get("/") assert response.status_code == 200 assert response.json()["status"] == "ok" def test_health(): response = client.get("/health") assert response.status_code == 200 assert response.json()["status"] == "healthy" def test_create_message_stream(): """Create message returns streaming response.""" payload = { "model": "claude-3-sonnet", "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 100, "stream": True, } response = client.post("/v1/messages", json=payload) assert response.status_code == 200 assert "text/event-stream" in response.headers.get("content-type", "") content = b"".join(response.iter_bytes()) assert b"message_start" in content or b"event:" in content def test_model_mapping(): # Test Haiku mapping _stream_response_calls.clear() payload_haiku = { "model": "claude-3-haiku-20240307", "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 100, "stream": True, } client.post("/v1/messages", json=payload_haiku) assert len(_stream_response_calls) == 1 args = _stream_response_calls[0][0] assert args[0].model != "claude-3-haiku-20240307" assert args[0].original_model == "claude-3-haiku-20240307" def test_error_fallbacks(): from providers.exceptions import ( AuthenticationError, OverloadedError, RateLimitError, ) base_payload = { "model": "test", "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 10, "stream": True, } def _raise_auth(*args, **kwargs): raise AuthenticationError("Invalid Key") def _raise_rate_limit(*args, **kwargs): raise RateLimitError("Too Many Requests") def _raise_overloaded(*args, **kwargs): raise OverloadedError("Server Overloaded") # 1. Authentication Error (401) mock_provider.stream_response = _raise_auth response = client.post("/v1/messages", json=base_payload) assert response.status_code == 401 assert response.json()["error"]["type"] == "authentication_error" # 2. Rate Limit (429) mock_provider.stream_response = _raise_rate_limit response = client.post("/v1/messages", json=base_payload) assert response.status_code == 429 assert response.json()["error"]["type"] == "rate_limit_error" # 3. Overloaded (529) mock_provider.stream_response = _raise_overloaded response = client.post("/v1/messages", json=base_payload) assert response.status_code == 529 assert response.json()["error"]["type"] == "overloaded_error" # Reset for subsequent tests mock_provider.stream_response = _mock_stream_response def test_generic_exception_returns_500(): """Non-ProviderError exceptions are caught and returned as HTTPException(500).""" def _raise_runtime(*args, **kwargs): raise RuntimeError("unexpected crash") mock_provider.stream_response = _raise_runtime response = client.post( "/v1/messages", json={ "model": "test", "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 10, "stream": True, }, ) assert response.status_code == 500 mock_provider.stream_response = _mock_stream_response def test_generic_exception_with_status_code(): """Generic exception with status_code attribute uses that status (getattr fallback).""" class ExceptionWithStatus(RuntimeError): def __init__(self, msg: str, status_code: int = 500): super().__init__(msg) self.status_code = status_code def _raise_with_status(*args, **kwargs): raise ExceptionWithStatus("bad gateway", 502) mock_provider.stream_response = _raise_with_status response = client.post( "/v1/messages", json={ "model": "test", "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 10, "stream": True, }, ) assert response.status_code == 502 mock_provider.stream_response = _mock_stream_response def test_generic_exception_empty_message_returns_non_empty_detail(): """Exceptions with empty __str__ still return a readable HTTP detail.""" class SilentError(RuntimeError): def __str__(self): return "" def _raise_silent(*args, **kwargs): raise SilentError() mock_provider.stream_response = _raise_silent response = client.post( "/v1/messages", json={ "model": "test", "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 10, "stream": True, }, ) assert response.status_code == 500 assert response.json()["detail"] != "" mock_provider.stream_response = _mock_stream_response def test_count_tokens_endpoint(): """count_tokens endpoint returns token count.""" response = client.post( "/v1/messages/count_tokens", json={"model": "test", "messages": [{"role": "user", "content": "Hello"}]}, ) assert response.status_code == 200 assert "input_tokens" in response.json() def test_stop_endpoint_no_handler_no_cli_503(): """POST /stop without handler or cli_manager returns 503.""" # Ensure no handler or cli_manager on app state if hasattr(app.state, "message_handler"): delattr(app.state, "message_handler") if hasattr(app.state, "cli_manager"): delattr(app.state, "cli_manager") response = client.post("/stop") assert response.status_code == 503