Claude_Code / tests /api /test_api.py
Jainish1808
Move project files to repository root for Hugging Face Space
bf177ff
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from api.app import app
from providers.nvidia_nim import NvidiaNimProvider
# Log of every invocation of the stub stream defined below, stored as
# (args, kwargs) tuples; test_model_mapping reads it to inspect the request
# object the route handed to the provider.
_stream_response_calls = list()

# Stand-in for the real NVIDIA NIM provider. spec= limits attribute reads to
# names that NvidiaNimProvider actually defines.
mock_provider = MagicMock(spec=NvidiaNimProvider)
async def _mock_stream_response(*args, **kwargs):
    """Fake ``stream_response``: record the call, then emit two SSE chunks.

    Appends ``(args, kwargs)`` to ``_stream_response_calls`` when iteration
    starts, then yields a ``message_start`` event followed by ``[DONE]``.
    """
    _stream_response_calls.append((args, kwargs))
    for chunk in ("event: message_start\ndata: {}\n\n", "[DONE]\n\n"):
        yield chunk
# Default behaviour for the provider: the harmless streaming stub.
mock_provider.stream_response = _mock_stream_response

# Replace the route-level provider lookup so requests never reach a real
# backend. The patch is started/stopped by the xunit-style hooks below rather
# than at import time: pytest imports every test module during collection, so
# an import-time start() with no stop() would leave the mock active for the
# whole session, including other modules' tests.
_patcher = patch("api.routes.get_provider_for_type", return_value=mock_provider)


def setup_module(module):
    """Activate the provider patch before this module's first test."""
    _patcher.start()


def teardown_module(module):
    """Undo the provider patch after this module's last test."""
    _patcher.stop()


def setup_function(function):
    """Reset shared mock state before every test.

    Several tests swap ``mock_provider.stream_response`` for a raising stub.
    Re-installing the default here keeps a failed assertion in one test from
    leaking that stub (and stale call records) into the next test.
    """
    mock_provider.stream_response = _mock_stream_response
    _stream_response_calls.clear()


client = TestClient(app)
def test_root():
    """GET / succeeds and reports status "ok"."""
    resp = client.get("/")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "ok"
def test_health():
    """GET /health succeeds and reports status "healthy"."""
    resp = client.get("/health")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "healthy"
def test_create_message_stream():
    """Create message returns streaming response."""
    request_body = dict(
        model="claude-3-sonnet",
        messages=[{"role": "user", "content": "Hi"}],
        max_tokens=100,
        stream=True,
    )
    resp = client.post("/v1/messages", json=request_body)

    assert resp.status_code == 200
    content_type = resp.headers.get("content-type", "")
    assert "text/event-stream" in content_type

    # Drain the streamed body and look for any SSE event marker.
    streamed = b"".join(resp.iter_bytes())
    assert any(marker in streamed for marker in (b"message_start", b"event:"))
def test_model_mapping():
    """A Claude model name is rewritten before it reaches the provider.

    The request object passed to ``stream_response`` must have a ``model``
    different from the client's name while keeping the client's name in
    ``original_model``.
    """
    # Install the recording stub explicitly so this test does not depend on
    # earlier tests having restored it.
    mock_provider.stream_response = _mock_stream_response
    _stream_response_calls.clear()

    payload_haiku = {
        "model": "claude-3-haiku-20240307",
        "messages": [{"role": "user", "content": "Hi"}],
        "max_tokens": 100,
        "stream": True,
    }
    response = client.post("/v1/messages", json=payload_haiku)
    # Check the request itself first, so a routing failure is reported
    # directly instead of surfacing as an empty call log below.
    assert response.status_code == 200

    assert len(_stream_response_calls) == 1
    args = _stream_response_calls[0][0]
    # Assumes the route passes the request object as the first positional
    # argument, as the original test did.
    assert args[0].model != "claude-3-haiku-20240307"
    assert args[0].original_model == "claude-3-haiku-20240307"
def test_error_fallbacks():
    """Provider exceptions become HTTP errors with a typed JSON ``error`` body.

    Covers AuthenticationError -> 401, RateLimitError -> 429 and
    OverloadedError -> 529.
    """
    from providers.exceptions import (
        AuthenticationError,
        OverloadedError,
        RateLimitError,
    )

    base_payload = {
        "model": "test",
        "messages": [{"role": "user", "content": "Hi"}],
        "max_tokens": 10,
        "stream": True,
    }

    # (exception type, message, expected HTTP status, expected error.type)
    cases = [
        (AuthenticationError, "Invalid Key", 401, "authentication_error"),
        (RateLimitError, "Too Many Requests", 429, "rate_limit_error"),
        (OverloadedError, "Server Overloaded", 529, "overloaded_error"),
    ]

    def _raiser(exc_type, message):
        # Factory so each stub binds its own exception type and message
        # (avoids the late-binding closure pitfall inside the loop).
        def _raise(*args, **kwargs):
            raise exc_type(message)

        return _raise

    try:
        for exc_type, message, status, error_type in cases:
            mock_provider.stream_response = _raiser(exc_type, message)
            response = client.post("/v1/messages", json=base_payload)
            assert response.status_code == status, exc_type.__name__
            assert response.json()["error"]["type"] == error_type
    finally:
        # Always restore the default stub, even when an assertion fails,
        # so later tests are not poisoned by a raising stub.
        mock_provider.stream_response = _mock_stream_response
def test_generic_exception_returns_500():
    """Non-ProviderError exceptions are caught and returned as HTTPException(500)."""

    def _raise_runtime(*args, **kwargs):
        raise RuntimeError("unexpected crash")

    mock_provider.stream_response = _raise_runtime
    try:
        response = client.post(
            "/v1/messages",
            json={
                "model": "test",
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 10,
                "stream": True,
            },
        )
        assert response.status_code == 500
    finally:
        # Restore the default stub even if the assertion fails.
        mock_provider.stream_response = _mock_stream_response
def test_generic_exception_with_status_code():
    """Generic exception with status_code attribute uses that status (getattr fallback)."""

    class ExceptionWithStatus(RuntimeError):
        # A plain RuntimeError that additionally carries an HTTP status.
        def __init__(self, msg: str, status_code: int = 500):
            super().__init__(msg)
            self.status_code = status_code

    def _raise_with_status(*args, **kwargs):
        raise ExceptionWithStatus("bad gateway", 502)

    mock_provider.stream_response = _raise_with_status
    try:
        response = client.post(
            "/v1/messages",
            json={
                "model": "test",
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 10,
                "stream": True,
            },
        )
        assert response.status_code == 502
    finally:
        # Restore the default stub even if the assertion fails.
        mock_provider.stream_response = _mock_stream_response
def test_generic_exception_empty_message_returns_non_empty_detail():
    """Exceptions with empty __str__ still return a readable HTTP detail."""

    class SilentError(RuntimeError):
        # str(exc) is "", so the route cannot simply echo the message.
        def __str__(self):
            return ""

    def _raise_silent(*args, **kwargs):
        raise SilentError()

    mock_provider.stream_response = _raise_silent
    try:
        response = client.post(
            "/v1/messages",
            json={
                "model": "test",
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 10,
                "stream": True,
            },
        )
        assert response.status_code == 500
        assert response.json()["detail"] != ""
    finally:
        # Restore the default stub even if an assertion fails.
        mock_provider.stream_response = _mock_stream_response
def test_count_tokens_endpoint():
    """POST /v1/messages/count_tokens answers 200 with an input_tokens field."""
    request_body = dict(
        model="test",
        messages=[{"role": "user", "content": "Hello"}],
    )
    resp = client.post("/v1/messages/count_tokens", json=request_body)
    assert resp.status_code == 200
    body = resp.json()
    assert "input_tokens" in body
def test_stop_endpoint_no_handler_no_cli_503():
    """POST /stop without handler or cli_manager returns 503."""
    # Temporarily remove both attributes from the shared app state, and put
    # back whatever was there afterwards so later tests that use them are
    # unaffected (the original deleted them permanently).
    saved = {}
    for name in ("message_handler", "cli_manager"):
        if hasattr(app.state, name):
            saved[name] = getattr(app.state, name)
            delattr(app.state, name)
    try:
        response = client.post("/stop")
        assert response.status_code == 503
    finally:
        for name, value in saved.items():
            setattr(app.state, name, value)