| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import unittest |
| from unittest.mock import MagicMock, patch |
|
|
| from verl.utils.profiler.config import NPUToolConfig, ProfilerConfig |
| from verl.utils.profiler.mstx_profile import NPUProfiler |
|
|
|
|
| class TestNPUProfilerInitialization(unittest.TestCase): |
| def setUp(self): |
| NPUProfiler._define_count = 0 |
|
|
| def test_init_with_default_config(self): |
| tool_config = NPUToolConfig() |
| profiler = NPUProfiler(rank=0, config=None, tool_config=tool_config) |
| self.assertFalse(profiler.enable) |
| self.assertFalse(hasattr(profiler, "profile_npu")) |
|
|
| def test_init_with_disabled_config(self): |
| config = ProfilerConfig(enable=False) |
| tool_config = NPUToolConfig() |
| profiler = NPUProfiler(rank=0, config=config, tool_config=tool_config) |
| self.assertFalse(profiler.enable) |
| self.assertFalse(hasattr(profiler, "profile_npu")) |
|
|
| def test_init_with_all_ranks_true(self): |
| config = ProfilerConfig(enable=True, all_ranks=True) |
| tool_config = NPUToolConfig() |
| profiler = NPUProfiler(rank=0, config=config, tool_config=tool_config) |
| self.assertTrue(profiler.this_rank) |
|
|
| def test_init_with_ranks_list(self): |
| config = ProfilerConfig(enable=True, ranks=[1, 2]) |
| tool_config = NPUToolConfig() |
| profiler = NPUProfiler(rank=1, config=config, tool_config=tool_config) |
| self.assertTrue(profiler.this_rank) |
|
|
| def test_init_with_rank_not_in_ranks(self): |
| config = ProfilerConfig(enable=True, ranks=[1, 2]) |
| tool_config = NPUToolConfig() |
| profiler = NPUProfiler(rank=3, config=config, tool_config=tool_config) |
| self.assertFalse(profiler.this_rank) |
|
|
|
|
| class TestNPUProfilerStart(unittest.TestCase): |
| def setUp(self): |
| NPUProfiler._define_count = 0 |
| self.config = ProfilerConfig(enable=True, ranks=[0]) |
| self.tool_config = NPUToolConfig(discrete=False) |
|
|
| @patch("verl.utils.profiler.mstx_profile.get_npu_profiler") |
| def test_start_when_enabled_and_this_rank(self, mock_get_profiler): |
| profiler = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config) |
| profiler.start(role="worker", profile_step="1") |
| self.assertTrue(profiler.this_step) |
| self.assertEqual(NPUProfiler._define_count, 1) |
| mock_get_profiler.assert_called_once() |
|
|
| @patch("verl.utils.profiler.mstx_profile.get_npu_profiler") |
| def test_start_when_not_this_rank(self, mock_get_profiler): |
| profiler = NPUProfiler(rank=1, config=self.config, tool_config=self.tool_config) |
| profiler.start() |
| self.assertFalse(profiler.this_step) |
| self.assertEqual(NPUProfiler._define_count, 0) |
| mock_get_profiler.assert_not_called() |
|
|
| @patch("verl.utils.profiler.mstx_profile.get_npu_profiler") |
| def test_start_discrete_mode_does_not_increase_count(self, mock_get_profiler): |
| tool_config = NPUToolConfig(discrete=True) |
| profiler = NPUProfiler(rank=0, config=self.config, tool_config=tool_config) |
| profiler.start() |
| self.assertEqual(NPUProfiler._define_count, 0) |
| mock_get_profiler.assert_not_called() |
|
|
| @patch("verl.utils.profiler.mstx_profile.get_npu_profiler") |
| def test_multiple_start_calls_do_not_increase_count(self, mock_get_profiler): |
| profiler = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config) |
| profiler.start() |
| profiler.start() |
| self.assertEqual(NPUProfiler._define_count, 1) |
| mock_get_profiler.assert_called_once() |
|
|
|
|
| class TestNPUProfilerStartStopInteraction(unittest.TestCase): |
| def setUp(self): |
| NPUProfiler._define_count = 0 |
| self.config = ProfilerConfig(enable=True, ranks=[0]) |
| self.tool_config = NPUToolConfig(discrete=False) |
|
|
| @patch("verl.utils.profiler.mstx_profile.get_npu_profiler") |
| def test_start_stop_cycle(self, mock_get_profiler): |
| mock_profile_npu = MagicMock() |
| mock_get_profiler.return_value = mock_profile_npu |
|
|
| profiler = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config) |
| profiler.start() |
| self.assertEqual(NPUProfiler._define_count, 1) |
| self.assertEqual(mock_profile_npu.start.call_count, 1) |
| profiler.stop() |
| self.assertEqual(NPUProfiler._define_count, 0) |
| self.assertEqual(mock_profile_npu.step.call_count, 1) |
| self.assertEqual(mock_profile_npu.stop.call_count, 1) |
|
|
| @patch("verl.utils.profiler.mstx_profile.get_npu_profiler") |
| def test_multiple_instances_share_define_count(self, mock_get_profiler): |
| mock_profile_npu = MagicMock() |
| mock_get_profiler.return_value = mock_profile_npu |
|
|
| profiler1 = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config) |
| profiler2 = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config) |
| profiler1.start() |
| profiler2.start() |
| self.assertEqual(NPUProfiler._define_count, 1) |
| self.assertEqual(mock_profile_npu.start.call_count, 1) |
| profiler1.stop() |
| self.assertEqual(NPUProfiler._define_count, 0) |
|
|
|
|
| class TestNPUProfilerAnnotate(unittest.TestCase): |
| def setUp(self): |
| self.config = ProfilerConfig(enable=True, all_ranks=True) |
| self.tool_config = NPUToolConfig(discrete=False) |
| self.rank = 0 |
|
|
| def test_annotate_decorator_applied_correctly(self): |
| mock_worker = MagicMock() |
| mock_worker.profiler = NPUProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config) |
| mock_worker.profiler.this_step = True |
|
|
| mock_mark_range = "mocked_range_handle" |
|
|
| with ( |
| patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch, |
| patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch, |
| ): |
| mock_start_patch.return_value = mock_mark_range |
|
|
| with patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler: |
| decorator = mock_worker.profiler.annotate(message="test") |
|
|
| @decorator |
| def test_func(self, *args, **kwargs): |
| return "result" |
|
|
| result = test_func(mock_worker) |
|
|
| self.assertEqual(result, "result") |
| mock_start_patch.assert_called_once_with(message="test") |
| mock_end_patch.assert_called_once_with(mock_mark_range) |
| mock_get_profiler.assert_not_called() |
|
|
| def test_annotate_when_profiler_disabled(self): |
| disabled_config = ProfilerConfig(enable=False) |
| mock_worker = MagicMock() |
| mock_worker.profiler = NPUProfiler(rank=self.rank, config=disabled_config, tool_config=self.tool_config) |
|
|
| with ( |
| patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch, |
| patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch, |
| patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler, |
| ): |
| decorator = mock_worker.profiler.annotate(message="test") |
|
|
| @decorator |
| def test_func(self, *args, **kwargs): |
| return "result" |
|
|
| result = test_func(mock_worker) |
|
|
| self.assertEqual(result, "result") |
| mock_start_patch.assert_not_called() |
| mock_end_patch.assert_not_called() |
| mock_get_profiler.assert_not_called() |
|
|
| def test_annotate_when_this_step_disabled(self): |
| mock_worker = MagicMock() |
| mock_worker.profiler = NPUProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config) |
| mock_worker.profiler.this_step = False |
|
|
| with ( |
| patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch, |
| patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch, |
| patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler, |
| ): |
| decorator = mock_worker.profiler.annotate(message="test") |
|
|
| @decorator |
| def test_func(self, *args, **kwargs): |
| return "result" |
|
|
| result = test_func(mock_worker) |
|
|
| self.assertEqual(result, "result") |
| mock_start_patch.assert_not_called() |
| mock_end_patch.assert_not_called() |
| mock_get_profiler.assert_not_called() |
|
|
| def test_annotate_discrete_mode_enabled(self): |
| discrete_tool_config = NPUToolConfig(discrete=True) |
| mock_worker = MagicMock() |
| mock_worker.profiler = NPUProfiler(rank=self.rank, config=self.config, tool_config=discrete_tool_config) |
| mock_worker.profiler.this_step = True |
|
|
| mock_mark_range = "mocked_range_handle" |
| mock_profile_npu = MagicMock() |
|
|
| with ( |
| patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch, |
| patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch, |
| patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler, |
| ): |
| mock_start_patch.return_value = mock_mark_range |
| mock_get_profiler.return_value = mock_profile_npu |
| decorator = mock_worker.profiler.annotate(message="test", role="test_role") |
|
|
| @decorator |
| def test_func(self, *args, **kwargs): |
| return "result" |
|
|
| result = test_func(mock_worker) |
|
|
| self.assertEqual(result, "result") |
| mock_start_patch.assert_called_once_with(message="test") |
| mock_end_patch.assert_called_once_with(mock_mark_range) |
| mock_get_profiler.assert_called_once_with( |
| contents=mock_worker.profiler.profile_contents, |
| profile_level=mock_worker.profiler.profile_level, |
| profile_save_path=mock_worker.profiler.profile_save_path, |
| analysis=mock_worker.profiler.analysis, |
| role="test_role", |
| ) |
| mock_profile_npu.start.assert_called_once() |
| mock_profile_npu.step.assert_called_once() |
| mock_profile_npu.stop.assert_called_once() |
|
|
| def test_annotate_with_default_message(self): |
| mock_worker = MagicMock() |
| mock_worker.profiler = NPUProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config) |
| mock_worker.profiler.this_step = True |
|
|
| mock_mark_range = "mocked_range_handle" |
| with ( |
| patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch, |
| patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch, |
| ): |
| mock_start_patch.return_value = mock_mark_range |
| decorator = mock_worker.profiler.annotate() |
|
|
| @decorator |
| def test_func(self, *args, **kwargs): |
| return "result" |
|
|
| test_func(mock_worker) |
|
|
| mock_start_patch.assert_called_once_with(message="test_func") |
| mock_end_patch.assert_called_once_with(mock_mark_range) |
|
|
|
|
| if __name__ == "__main__": |
| unittest.main() |
|
|