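"""Unit tests for the FastAPI-to-PyTriton interface and the REST model API in nemo.deploy.service."""
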
import json
import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from fastapi.testclient import TestClient

from nemo.deploy.service.fastapi_interface_to_pytriton import (
    ChatCompletionRequest,
    CompletionRequest,
    TritonSettings,
    _helper_fun,
    app,
    convert_numpy,
    dict_to_str,
    query_llm_async,
)
from nemo.deploy.service.rest_model_api import CompletionRequest as RestCompletionRequest
from nemo.deploy.service.rest_model_api import TritonSettings as RestTritonSettings
from nemo.deploy.service.rest_model_api import app as rest_app


@pytest.fixture
def client():
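    """Return a TestClient for the FastAPI interface to PyTriton."""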
    return TestClient(app)


@pytest.fixture
def mock_triton_settings():
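    """Patch TritonSettings in the PyTriton interface module and yield the mocked instance."""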
    with patch('nemo.deploy.service.fastapi_interface_to_pytriton.TritonSettings') as mock:
        instance = mock.return_value
        instance.triton_service_port = 8000
        instance.triton_service_ip = "localhost"
        yield instance


@pytest.fixture
def rest_client():
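    """Return a TestClient for the REST model API."""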
    return TestClient(rest_app)


@pytest.fixture
def mock_rest_triton_settings():
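    """Patch TritonSettings in the REST model API module and yield the mocked instance."""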
    with patch('nemo.deploy.service.rest_model_api.TritonSettings') as mock:
        instance = mock.return_value
        instance.triton_service_port = 8080
        instance.triton_service_ip = "localhost"
        instance.triton_request_timeout = 60
        instance.openai_format_response = False
        instance.output_generation_logits = False
        yield instance


class TestTritonSettings:
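    """TritonSettings defaults and environment-variable overrides."""
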
    def test_default_values(self):
        with patch.dict(os.environ, {}, clear=True):
            settings = TritonSettings()
            assert settings.triton_service_port == 8000
            assert settings.triton_service_ip == "0.0.0.0"

    def test_custom_values(self):
        with patch.dict(os.environ, {'TRITON_PORT': '9000', 'TRITON_HTTP_ADDRESS': '127.0.0.1'}, clear=True):
            settings = TritonSettings()
            assert settings.triton_service_port == 9000
            assert settings.triton_service_ip == "127.0.0.1"


class TestCompletionRequest:
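    """Default field values of CompletionRequest and ChatCompletionRequest."""
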
    def test_default_completions_values(self):
        request = CompletionRequest(model="test_model", prompt="test prompt")
        assert request.model == "test_model"
        assert request.prompt == "test prompt"
        assert request.max_tokens == 512
        assert request.temperature == 1.0
        assert request.top_p == 0.0
        assert request.top_k == 0
        assert request.logprobs is None
        assert request.echo is False

    def test_default_chat_values(self):
        request = ChatCompletionRequest(model="test_model", messages=[{"role": "user", "content": "test message"}])
        assert request.model == "test_model"
        assert request.messages == [{"role": "user", "content": "test message"}]
        assert request.max_tokens == 512
        assert request.temperature == 1.0
        assert request.top_p == 0.0
        assert request.top_k == 0

    def test_greedy_params(self):
        request = CompletionRequest(model="test_model", prompt="test prompt", temperature=0.0, top_p=0.0)
        assert request.top_k == 1


class TestHealthEndpoints:
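    """Health endpoint of the FastAPI interface to PyTriton."""
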
    def test_health_check(self, client):
        response = client.get("/v1/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}


class TestUtilityFunctions:
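    """Behavior of the convert_numpy and dict_to_str helpers."""
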
    def test_convert_numpy(self):
        # A plain numpy array is converted to a Python list
        arr = np.array([1, 2, 3])
        assert convert_numpy(arr) == [1, 2, 3]

        # Arrays nested inside dictionaries are converted recursively
        nested = {"a": np.array([1, 2]), "b": {"c": np.array([3, 4])}}
        assert convert_numpy(nested) == {"a": [1, 2], "b": {"c": [3, 4]}}

        # Arrays inside lists are converted element-wise
        lst = [np.array([1, 2]), np.array([3, 4])]
        assert convert_numpy(lst) == [[1, 2], [3, 4]]

    def test_dict_to_str(self):
        test_dict = {"key": "value", "number": 42}
        result = dict_to_str(test_dict)
        assert isinstance(result, str)
        assert json.loads(result) == test_dict


class TestLLMQueryFunctions:
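    """_helper_fun and query_llm_async with the Triton query client mocked out."""
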
    def test_helper_fun(self):
        mock_nq = MagicMock()
        mock_nq.query_llm.return_value = {"test": "response"}

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.NemoQueryLLMPyTorch', return_value=mock_nq):
            result = _helper_fun(
                url="http://test",
                model="test_model",
                prompts=["test prompt"],
                temperature=0.7,
                top_k=10,
                top_p=0.9,
                compute_logprob=True,
                max_length=100,
                apply_chat_template=False,
                echo=False,
                n_top_logprobs=0,
            )
            assert result == {"test": "response"}
            mock_nq.query_llm.assert_called_once()

    def test_query_llm_async(self):
        mock_result = {"test": "response"}
        with patch('nemo.deploy.service.fastapi_interface_to_pytriton._helper_fun', return_value=mock_result):
            import asyncio

            # Drive the coroutine to completion on a fresh event loop
            result = asyncio.run(
                query_llm_async(
                    url="http://test",
                    model="test_model",
                    prompts=["test prompt"],
                    temperature=0.7,
                    top_k=10,
                    top_p=0.9,
                    compute_logprob=True,
                    max_length=100,
                    apply_chat_template=False,
                    echo=False,
                    n_top_logprobs=0,
                )
            )
            assert result == mock_result


class TestAPIEndpoints:
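    """Completions and chat completions endpoints with query_llm_async mocked."""
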
    def test_completions_v1(self, client):
        mock_output = {
            "choices": [
                {
                    "text": [["test response"]],
                    "logprobs": {"token_logprobs": [[1.0, 2.0]], "top_logprobs": [[{"a": 0.5}, {"b": 0.5}]]},
                }
            ]
        }

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.query_llm_async', return_value=mock_output):
            response = client.post(
                "/v1/completions/", json={"model": "test_model", "prompt": "test prompt", "logprobs": 1}
            )
            assert response.status_code == 200
            data = response.json()
            assert data["choices"][0]["text"] == "test response"
            assert "logprobs" in data["choices"][0]

    def test_chat_completions_v1(self, client):
        mock_output = {"choices": [{"text": [["test response"]]}]}

        with patch('nemo.deploy.service.fastapi_interface_to_pytriton.query_llm_async', return_value=mock_output):
            response = client.post(
                "/v1/chat/completions/",
                json={"model": "test_model", "messages": [{"role": "user", "content": "test message"}]},
            )
            assert response.status_code == 200
            data = response.json()
            assert data["choices"][0]["message"]["role"] == "assistant"
            assert data["choices"][0]["message"]["content"] == "test response"


class TestRestTritonSettings:
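    """REST model API TritonSettings defaults and environment-variable overrides."""
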
    def test_default_values(self):
        with patch.dict(os.environ, {}, clear=True):
            settings = RestTritonSettings()
            assert settings.triton_service_port == 8080
            assert settings.triton_service_ip == "0.0.0.0"
            assert settings.triton_request_timeout == 60
            assert settings.openai_format_response is False
            assert settings.output_generation_logits is False

    def test_custom_values(self):
        with patch.dict(
            os.environ,
            {
                'TRITON_PORT': '9000',
                'TRITON_HTTP_ADDRESS': '127.0.0.1',
                'TRITON_REQUEST_TIMEOUT': '120',
                'OPENAI_FORMAT_RESPONSE': 'True',
                'OUTPUT_GENERATION_LOGITS': 'True',
            },
            clear=True,
        ):
            settings = RestTritonSettings()
            assert settings.triton_service_port == 9000
            assert settings.triton_service_ip == "127.0.0.1"
            assert settings.triton_request_timeout == 120
            assert settings.openai_format_response is True
            assert settings.output_generation_logits is True


class TestRestCompletionRequest:
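    """Default field values of the REST model API CompletionRequest."""
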
    def test_default_values(self):
        request = RestCompletionRequest(model="test_model", prompt="test prompt")
        assert request.model == "test_model"
        assert request.prompt == "test prompt"
        assert request.max_tokens == 512
        assert request.temperature == 1.0
        assert request.top_p == 0.0
        assert request.top_k == 1
        assert request.stream is False
        assert request.stop is None
        assert request.frequency_penalty == 1.0


class TestRestHealthEndpoints:
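    """Health and Triton reachability endpoints of the REST model API."""
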
    def test_health_check(self, rest_client):
        response = rest_client.get("/v1/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}

    def test_triton_health_success(self, rest_client):
        with patch('requests.get') as mock_get:
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_get.return_value = mock_response

            response = rest_client.get("/v1/triton_health")
            assert response.status_code == 200
            assert response.json() == {"status": "Triton server is reachable and ready"}


class TestRestCompletionsEndpoint:
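    """Completions endpoint of the REST model API with NemoQueryLLM mocked."""
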
    def test_completions_success(self, rest_client):
        mock_output = [["test response"]]
        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as mock_llm:
            mock_instance = mock_llm.return_value
            mock_instance.query_llm.return_value = mock_output

            response = rest_client.post(
                "/v1/completions/",
                json={
                    "model": "test_model",
                    "prompt": "test prompt",
                    "max_tokens": 100,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "top_k": 10,
                },
            )
            assert response.status_code == 200
            assert response.json() == {"output": "test response"}

    def test_completions_standard_format(self, rest_client, mock_rest_triton_settings):
        mock_output = [["test response"]]
        mock_rest_triton_settings.openai_format_response = False

        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as mock_llm:
            mock_instance = mock_llm.return_value
            mock_instance.query_llm.return_value = mock_output

            response = rest_client.post("/v1/completions/", json={"model": "test_model", "prompt": "test prompt"})
            assert response.status_code == 200
            assert response.json() == {"output": "test response"}

    def test_completions_error_handling(self, rest_client):
        with patch('nemo.deploy.service.rest_model_api.NemoQueryLLM') as mock_llm:
            mock_instance = mock_llm.return_value
            mock_instance.query_llm.side_effect = Exception("Test error")

            response = rest_client.post("/v1/completions/", json={"model": "test_model", "prompt": "test prompt"})
            assert response.status_code == 200
            assert response.json() == {"error": "An exception occurred"}