| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from unittest.mock import MagicMock, patch |
|
|
| import pytest |
| from omegaconf import DictConfig |
|
|
| from nemo.export.quantize.quantizer import QUANT_CFG_CHOICES, Quantizer |
|
|
|
|
@pytest.fixture
def basic_quantization_config():
    """Return a fresh ``DictConfig`` with baseline quantization settings.

    The config selects the ``'int8'`` algorithm for a ``'llama'`` decoder.
    Tests that need a different setup mutate the returned object.
    """
    settings = {
        'algorithm': 'int8',
        'decoder_type': 'llama',
        'awq_block_size': 128,
        'sq_alpha': 0.5,
        'enable_kv_cache': True,
    }
    return DictConfig(settings)
|
|
|
|
@pytest.fixture
def basic_export_config():
    """Return a fresh ``DictConfig`` with baseline export settings.

    The config uses dtype ``'16'``, a ``'llama'`` decoder, and a
    tensor/pipeline parallel size of 1. Tests that need a different setup
    mutate the returned object.
    """
    settings = {
        'dtype': '16',
        'decoder_type': 'llama',
        'inference_tensor_parallel': 1,
        'inference_pipeline_parallel': 1,
        'save_path': '/tmp/model.qnemo',
    }
    return DictConfig(settings)
|
|
|
|
class TestQuantizer:
    """Unit tests for ``Quantizer`` construction, quantization and export."""

    def test_init_valid_configs(self, basic_quantization_config, basic_export_config):
        """A quantizer built from valid configs keeps both and resolves the int8 preset."""
        instance = Quantizer(basic_quantization_config, basic_export_config)

        assert instance.quantization_config == basic_quantization_config
        assert instance.export_config == basic_export_config
        assert instance.quant_cfg == QUANT_CFG_CHOICES['int8']

    def test_init_invalid_algorithm(self, basic_quantization_config, basic_export_config):
        """An unrecognised algorithm name is expected to raise ``AssertionError``."""
        basic_quantization_config.algorithm = 'invalid_algo'

        with pytest.raises(AssertionError):
            Quantizer(basic_quantization_config, basic_export_config)

    def test_init_invalid_dtype(self, basic_quantization_config, basic_export_config):
        """An export dtype of ``'32'`` is expected to raise ``AssertionError``."""
        basic_export_config.dtype = '32'

        with pytest.raises(AssertionError):
            Quantizer(basic_quantization_config, basic_export_config)

    def test_null_algorithm(self, basic_quantization_config, basic_export_config):
        """With ``algorithm`` set to ``None`` the quantizer's ``quant_cfg`` is ``None``."""
        basic_quantization_config.algorithm = None

        instance = Quantizer(basic_quantization_config, basic_export_config)

        assert instance.quant_cfg is None

    @patch('nemo.export.quantize.quantizer.dist')
    def test_quantize_method(self, mock_dist, basic_quantization_config, basic_export_config):
        """``quantize`` forwards the model, the int8 preset and the forward loop to modelopt."""
        mock_dist.get_rank.return_value = 0

        # Stand-ins for the model and its calibration forward loop.
        fake_model = MagicMock()
        fake_forward_loop = MagicMock()

        instance = Quantizer(basic_quantization_config, basic_export_config)

        # Build the patchers up front; they only take effect inside the ``with``.
        quantize_patcher = patch('modelopt.torch.quantization.quantize')
        summary_patcher = patch('modelopt.torch.quantization.print_quant_summary')
        with quantize_patcher as quantize_mock, summary_patcher:
            instance.quantize(fake_model, fake_forward_loop)

        quantize_mock.assert_called_once_with(fake_model, QUANT_CFG_CHOICES['int8'], fake_forward_loop)

    @patch('nemo.export.quantize.quantizer.dist')
    def test_modify_model_config(self, mock_dist):
        """``modify_model_config`` sets the three fields asserted below on the given config."""
        source_config = DictConfig({'sequence_parallel': True})

        result = Quantizer.modify_model_config(source_config)

        assert result.sequence_parallel is False
        assert result.name == 'modelopt'
        assert result.apply_rope_fusion is False

    @patch('nemo.export.quantize.quantizer.dist')
    @patch('nemo.export.quantize.quantizer.export_tensorrt_llm_checkpoint')
    def test_export_method(self, mock_export, mock_dist, basic_quantization_config, basic_export_config):
        """``export`` calls the checkpoint exporter once with the configured decoder and parallelism."""
        mock_dist.get_rank.return_value = 0

        fake_model = MagicMock()
        fake_model.cfg.megatron_amp_O2 = False
        fake_model.trainer.num_nodes = 1

        instance = Quantizer(basic_quantization_config, basic_export_config)

        # ``save_artifacts`` is patched out for the duration of the export call.
        with patch('nemo.export.quantize.quantizer.save_artifacts'):
            instance.export(fake_model)

        mock_export.assert_called_once()
        # Only the keyword arguments of the exporter call are inspected.
        export_kwargs = mock_export.call_args.kwargs
        assert export_kwargs['decoder_type'] == 'llama'
        assert export_kwargs['inference_tensor_parallel'] == 1
        assert export_kwargs['inference_pipeline_parallel'] == 1
|
|