| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import tarfile |
| import tempfile |
| import typing |
|
|
| import numpy as np |
| import pytest |
| import torch |
| from PIL import Image |
| from pytriton.model_config import Tensor |
|
|
| from nemo.deploy.utils import ( |
| NEMO1, |
| NEMO2, |
| broadcast_list, |
| cast_output, |
| ndarray2img, |
| nemo_checkpoint_version, |
| str_list2numpy, |
| str_ndarray2list, |
| typedict2tensor, |
| ) |
|
|
|
|
class TestTypedict2Tensor:
    """Tests for ``typedict2tensor``, which turns an annotated class into Triton ``Tensor`` specs."""

    class SampleTypedict:
        # One scalar field and one single-level list field for each of int, float, bool and str.
        int_field: int
        float_field: float
        bool_field: bool
        str_field: str
        int_list: typing.List[int]
        float_list: typing.List[float]
        bool_list: typing.List[bool]
        str_list: typing.List[str]

    @staticmethod
    def _index_by_name(specs):
        """Return a ``{name: Tensor}`` mapping for the given tensor specs."""
        return {spec.name: spec for spec in specs}

    def test_typedict2tensor_basic(self):
        specs = typedict2tensor(self.SampleTypedict)
        assert len(specs) == 8
        assert all(isinstance(spec, Tensor) for spec in specs)

        by_name = self._index_by_name(specs)
        # Scalar annotations: expected numpy dtype per field; each spec has shape (1,).
        expected_dtypes = {
            "int_field": np.int32,
            "float_field": np.float32,
            "bool_field": np.bool_,
            "str_field": bytes,
        }
        for field_name, dtype in expected_dtypes.items():
            spec = by_name[field_name]
            assert spec.dtype == dtype
            assert spec.shape == (1,)

    def test_typedict2tensor_with_overwrite(self):
        # The overwrite kwargs are expected to be applied to every generated spec.
        specs = typedict2tensor(self.SampleTypedict, overwrite_kwargs={"optional": True})
        assert all(spec.optional for spec in specs)

    def test_typedict2tensor_list_types(self):
        by_name = self._index_by_name(typedict2tensor(self.SampleTypedict))
        # A single List[...] level maps to the element's dtype and still yields shape (1,).
        expected_dtypes = {
            "int_list": np.int32,
            "float_list": np.float32,
            "bool_list": np.bool_,
            "str_list": bytes,
        }
        for field_name, dtype in expected_dtypes.items():
            spec = by_name[field_name]
            assert spec.dtype == dtype
            assert spec.shape == (1,)
|
|
|
|
class TestNemoCheckpointVersion:
    """Tests for ``nemo_checkpoint_version`` on directory and tar checkpoints.

    The tests expect NEMO2 when a checkpoint contains both a ``context`` and a
    ``weights`` directory, and NEMO1 otherwise (empty or partial layouts).
    """

    # Top-level directory names that together mark a NeMo 2.0 checkpoint layout.
    NEMO2_DIRS = ("context", "weights")

    @staticmethod
    def _write_tar(tar_path, dir_names):
        """Write an uncompressed tar at ``tar_path`` with one empty directory entry per name."""
        with tarfile.open(tar_path, "w") as tar:
            for dir_name in dir_names:
                info = tarfile.TarInfo(dir_name)
                info.type = tarfile.DIRTYPE
                tar.addfile(info)

    def test_nemo2_checkpoint_dir(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            for dir_name in self.NEMO2_DIRS:
                os.makedirs(os.path.join(tmpdir, dir_name))
            assert nemo_checkpoint_version(tmpdir) == NEMO2

    def test_nemo1_checkpoint_dir(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # An empty directory contains neither marker directory.
            assert nemo_checkpoint_version(tmpdir) == NEMO1

    def test_partial_nemo2_layout_dir_is_nemo1(self):
        # Having only one of the two marker directories must not be treated as NEMO2.
        for dir_name in self.NEMO2_DIRS:
            with tempfile.TemporaryDirectory() as tmpdir:
                os.makedirs(os.path.join(tmpdir, dir_name))
                assert nemo_checkpoint_version(tmpdir) == NEMO1

    def test_nemo2_checkpoint_tar(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            tar_path = os.path.join(tmpdir, "checkpoint.tar")
            self._write_tar(tar_path, self.NEMO2_DIRS)
            assert nemo_checkpoint_version(tar_path) == NEMO2

    def test_nemo1_checkpoint_tar(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            tar_path = os.path.join(tmpdir, "checkpoint.tar")
            # A valid but empty archive.
            self._write_tar(tar_path, ())
            assert nemo_checkpoint_version(tar_path) == NEMO1

    def test_partial_nemo2_layout_tar_is_nemo1(self):
        # Same partial-layout check as for directories, but packed into a tar archive.
        for dir_name in self.NEMO2_DIRS:
            with tempfile.TemporaryDirectory() as tmpdir:
                tar_path = os.path.join(tmpdir, "checkpoint.tar")
                self._write_tar(tar_path, (dir_name,))
                assert nemo_checkpoint_version(tar_path) == NEMO1
|
|
|
|
class TestStringConversions:
    """Tests for the ``str_list2numpy`` / ``str_ndarray2list`` string <-> bytes-array helpers."""

    # Mixed-script strings to exercise encoding beyond ASCII.
    NON_ASCII = ["héllo", "wörld", "日本語"]

    def test_str_list2numpy(self):
        input_list = ["hello", "world", "test"]
        result = str_list2numpy(input_list)
        assert isinstance(result, np.ndarray)
        assert result.shape == (3, 1)
        assert all(isinstance(x, bytes) for x in result.flatten())
        # Verify the encoded values themselves, not only their type.
        assert result.flatten().tolist() == [b"hello", b"world", b"test"]

    def test_str_list2numpy_non_ascii(self):
        result = str_list2numpy(self.NON_ASCII)
        assert result.shape == (3, 1)
        # Non-ASCII text is expected to be stored as its UTF-8 encoding.
        assert result.flatten().tolist() == [s.encode("utf-8") for s in self.NON_ASCII]

    def test_str_ndarray2list(self):
        input_array = np.array([b"hello", b"world", b"test"]).reshape(3, 1)
        result = str_ndarray2list(input_array)
        assert isinstance(result, list)
        assert result == ["hello", "world", "test"]

    def test_str_conversion_roundtrip(self):
        # ASCII, non-ASCII and a single-element list must all survive list -> array -> list.
        for input_list in (["hello", "world", "test"], self.NON_ASCII, ["only"]):
            numpy_array = str_list2numpy(input_list)
            output_list = str_ndarray2list(numpy_array)
            assert input_list == output_list
|
|
|
|
class TestImageConversions:
    """Tests for ``ndarray2img``, which converts a batch of image arrays into PIL images."""

    def test_ndarray2img(self):
        # Non-square frames let the test catch a width/height swap. The seeded RNG keeps
        # the input deterministic. ``high`` is exclusive, so 256 covers the full uint8 range.
        batch, height, width = 2, 40, 60
        rng = np.random.default_rng(0)
        img_array = rng.integers(0, 256, size=(batch, height, width, 3), dtype=np.uint8)
        result = ndarray2img(img_array)

        assert isinstance(result, list)
        assert len(result) == batch
        assert all(isinstance(img, Image.Image) for img in result)
        # PIL reports ``size`` as (width, height).
        assert all(img.size == (width, height) for img in result)
        # Each image must hold exactly the pixels of the frame it came from.
        for img, frame in zip(result, img_array):
            assert np.array_equal(np.asarray(img), frame)
|
|
|
|
class TestCastOutput:
    """Tests for ``cast_output``: dtype conversion plus promotion of 1-D data to 2-D."""

    def test_cast_tensor(self):
        input_tensor = torch.tensor([1, 2, 3])
        result = cast_output(input_tensor, np.int32)
        assert isinstance(result, np.ndarray)
        assert result.dtype == np.int32
        assert result.shape == (3, 1)
        # Values must be preserved through the tensor -> ndarray conversion.
        assert result.flatten().tolist() == [1, 2, 3]

    def test_cast_numpy(self):
        input_array = np.array([1, 2, 3])
        result = cast_output(input_array, np.float32)
        assert isinstance(result, np.ndarray)
        assert result.dtype == np.float32
        assert result.shape == (3, 1)
        assert result.flatten().tolist() == [1.0, 2.0, 3.0]

    def test_cast_string(self):
        input_list = ["hello", "world"]
        result = cast_output(input_list, bytes)
        assert isinstance(result, np.ndarray)
        assert result.shape == (2, 1)
        # Strings are expected to come back as their UTF-8 encoded bytes.
        assert result.flatten().tolist() == [b"hello", b"world"]

    def test_cast_1d_to_2d(self):
        input_array = np.array([1, 2, 3])
        result = cast_output(input_array, np.int32)
        assert result.ndim == 2
        assert result.shape == (3, 1)

    def test_cast_2d_shape_preserved(self):
        # Input that is already 2-D should not gain an extra trailing axis.
        input_array = np.array([[1, 2], [3, 4]])
        result = cast_output(input_array, np.int32)
        assert result.dtype == np.int32
        assert result.shape == (2, 2)
        assert result.tolist() == [[1, 2], [3, 4]]
|
|
|
|
class TestBroadcastList:
    """Tests for ``broadcast_list`` with ``torch.distributed`` replaced by monkeypatched stubs."""

    def test_broadcast_list_no_distributed(self, monkeypatch):
        # Pin the uninitialized state so the test does not depend on whether a process
        # group was already set up by the environment or by another test.
        monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False)
        with pytest.raises(RuntimeError, match="Distributed environment is not initialized"):
            broadcast_list(["test"])

    def test_broadcast_list_distributed(self, monkeypatch):
        # Simulate an initialized group in which this process is rank 0 (the source).
        monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
        monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
        calls = []

        def mock_broadcast_object_list(object_list, src, group=None):
            # Record how the broadcast was invoked, then deliver the source's payload.
            calls.append((src, group))
            if src == 0:
                object_list[0] = ["test"]

        monkeypatch.setattr(torch.distributed, "broadcast_object_list", mock_broadcast_object_list)

        result = broadcast_list(["test"])
        assert result == ["test"]
        # Exactly one broadcast, from the default source rank with no explicit group.
        assert calls == [(0, None)]

    def test_broadcast_list_receiving_rank(self, monkeypatch):
        # On a non-source rank the return value must be the broadcast payload,
        # not the rank's own local argument.
        monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
        monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 1)

        def mock_broadcast_object_list(object_list, src, group=None):
            object_list[0] = ["from-rank-0"]

        monkeypatch.setattr(torch.distributed, "broadcast_object_list", mock_broadcast_object_list)

        assert broadcast_list(["local"]) == ["from-rank-0"]
|
|