| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import tempfile |
| from unittest.mock import MagicMock, patch |
|
|
| import numpy as np |
| import pytest |
| from PIL import Image |
|
|
| from nemo.deploy.multimodal.query_multimodal import NemoQueryMultimodal |
|
|
|
|
class TestNemoQueryMultimodal:
    """Unit tests for ``NemoQueryMultimodal`` media setup, frame sampling and querying."""

    @staticmethod
    def _temp_media_path(suffix):
        """Create an empty temporary file ending in ``suffix`` and return its path.

        The handle is closed before returning so the path can be reopened or
        overwritten on any platform. The caller is responsible for deleting it.
        """
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            return tmp.name

    @pytest.fixture
    def query_multimodal(self):
        """Client configured with the ``neva`` model type."""
        return NemoQueryMultimodal(url="localhost", model_name="test_model", model_type="neva")

    @pytest.fixture
    def mock_image(self):
        """Yield the path of a temporary 100x100 red JPEG; delete it on teardown.

        Cleanup lives in ``finally`` so the file is removed even when the test
        (or the image write itself) fails.
        """
        path = self._temp_media_path('.jpg')
        try:
            # Written after the temp-file handle is closed (see _temp_media_path).
            Image.new('RGB', (100, 100), color='red').save(path)
            yield path
        finally:
            os.unlink(path)

    @pytest.fixture
    def mock_video(self):
        """Yield the path of an empty temporary ``.mp4`` file; delete it on teardown.

        Only the path matters: the test using it patches ``VideoReader``.
        """
        path = self._temp_media_path('.mp4')
        try:
            yield path
        finally:
            os.unlink(path)

    @pytest.fixture
    def mock_audio(self):
        """Yield the path of an empty temporary ``.wav`` file; delete it on teardown.

        Only the path matters: the test using it patches ``soundfile.read``.
        """
        path = self._temp_media_path('.wav')
        try:
            yield path
        finally:
            os.unlink(path)

    def test_init(self):
        """The constructor stores url, model name and model type as given."""
        nq = NemoQueryMultimodal(url="localhost", model_name="test_model", model_type="neva")
        assert nq.url == "localhost"
        assert nq.model_name == "test_model"
        assert nq.model_type == "neva"

    def test_setup_media_image_local(self, query_multimodal, mock_image):
        """A local image path yields a numpy array whose first dimension is 1."""
        result = query_multimodal.setup_media(mock_image)
        assert isinstance(result, np.ndarray)
        assert result.shape[0] == 1

    @patch('requests.get')
    def test_setup_media_image_url(self, mock_get, query_multimodal):
        """An http URL is fetched via ``requests.get`` and turned into a numpy array."""
        # Fake HTTP response; its bytes are never decoded because Image.open is patched too.
        mock_response = MagicMock()
        mock_response.content = b"fake_image_data"
        mock_get.return_value = mock_response

        with patch('PIL.Image.open') as mock_image_open:
            fake_image = MagicMock()
            # .convert(...) returns the same mock so any conversion chain stays mocked.
            fake_image.convert.return_value = fake_image
            mock_image_open.return_value = fake_image

            result = query_multimodal.setup_media("http://example.com/image.jpg")
            assert isinstance(result, np.ndarray)
            assert result.shape[0] == 1

        # Image.open is patched, so without this the test would also pass if
        # the URL were (wrongly) treated as a local path.
        mock_get.assert_called_once()

    def test_frame_len(self, query_multimodal):
        """Short clips keep every frame; long clips are capped at 256 frames."""
        # Below the cap: the frame count is returned unchanged.
        frames = [np.zeros((100, 100, 3)) for _ in range(100)]
        assert query_multimodal.frame_len(frames) == 100

        # Above the cap: the result must not exceed 256.
        frames = [np.zeros((100, 100, 3)) for _ in range(300)]
        result = query_multimodal.frame_len(frames)
        assert result <= 256

    def test_get_subsampled_frames(self, query_multimodal):
        """Subsampling returns exactly the requested number of frames."""
        frames = [np.zeros((100, 100, 3)) for _ in range(10)]
        subsample_len = 5
        result = query_multimodal.get_subsampled_frames(frames, subsample_len)
        assert len(result) == subsample_len

    @patch('nemo.deploy.multimodal.query_multimodal.ModelClient')
    def test_query(self, mock_model_client, query_multimodal, mock_image):
        """``query`` returns the ``outputs`` array produced by the model client."""
        # ModelClient is used as a context manager; stub the object bound by ``with``.
        mock_client_instance = MagicMock()
        mock_client_instance.infer_batch.return_value = {"outputs": np.array(["test response"])}
        mock_client_instance.model_config.outputs = [MagicMock(dtype=np.bytes_)]
        mock_model_client.return_value.__enter__.return_value = mock_client_instance

        result = query_multimodal.query(
            input_text="test prompt",
            input_media=mock_image,
            max_output_len=30,
            top_k=1,
            top_p=0.0,
            temperature=1.0,
        )

        assert isinstance(result, np.ndarray)
        assert result[0] == "test response"

    @patch('nemo.deploy.multimodal.query_multimodal.VideoReader')
    def test_setup_media_video(self, mock_video_reader, mock_video):
        """The ``video-neva`` model type decodes frames via VideoReader into a numpy array."""
        nq = NemoQueryMultimodal(url="localhost", model_name="test_model", model_type="video-neva")

        # Stand-in reader: ten frames, each exposing asnumpy() like a decoded frame.
        mock_frames = [MagicMock(asnumpy=lambda: np.zeros((100, 100, 3))) for _ in range(10)]
        mock_video_reader.return_value = mock_frames

        result = nq.setup_media(mock_video)
        assert isinstance(result, np.ndarray)

    @patch('soundfile.read')
    def test_setup_media_audio(self, mock_sf_read, mock_audio):
        """The ``salm`` model type returns a dict with signal and signal-length entries."""
        nq = NemoQueryMultimodal(url="localhost", model_name="test_model", model_type="salm")

        # soundfile.read returns (samples, sample_rate); 1000 zero samples at 16 kHz.
        mock_sf_read.return_value = (np.zeros(1000), 16000)

        result = nq.setup_media(mock_audio)
        assert isinstance(result, dict)
        assert "input_signal" in result
        assert "input_signal_length" in result
|
|