# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ``ProfilerConfig`` and ``NsightSystemsProfiler``."""

import unittest
from unittest.mock import MagicMock, patch

from verl.utils import omega_conf_to_dataclass
from verl.utils.profiler.config import NsightToolConfig, ProfilerConfig
from verl.utils.profiler.nvtx_profile import NsightSystemsProfiler


class TestProfilerConfig(unittest.TestCase):
    """Tests for building and using ``ProfilerConfig`` dataclasses."""

    def test_config_init(self):
        """Each ``profiler`` section of ``ppo_trainer`` converts to a matching ``ProfilerConfig``."""
        import os

        from hydra import compose, initialize_config_dir

        # The config directory is resolved against the current working
        # directory, so this test expects to be run from the repository root.
        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
            cfg = compose(config_name="ppo_trainer")

        for config in [
            cfg.actor_rollout_ref.actor.profiler,
            cfg.actor_rollout_ref.rollout.profiler,
            cfg.actor_rollout_ref.ref.profiler,
            cfg.critic.profiler,
            cfg.reward_model.profiler,
        ]:
            profiler_config = omega_conf_to_dataclass(config)
            self.assertIsInstance(profiler_config, ProfilerConfig)

            # Field values must carry over from the OmegaConf node unchanged.
            self.assertEqual(profiler_config.tool, config.tool)
            self.assertEqual(profiler_config.enable, config.enable)
            self.assertEqual(profiler_config.all_ranks, config.all_ranks)
            self.assertEqual(profiler_config.ranks, config.ranks)
            self.assertEqual(profiler_config.save_path, config.save_path)

            # Unknown keys: attribute access raises AttributeError, while the
            # dict-style ``get`` must behave like the OmegaConf node's ``get``
            # (both without and with an explicit default).
            with self.assertRaises(AttributeError):
                _ = profiler_config.non_existing_key
            self.assertEqual(config.get("non_existing_key"), profiler_config.get("non_existing_key"))
            self.assertEqual(config.get("non_existing_key", 1), profiler_config.get("non_existing_key", 1))

    def test_frozen_config(self):
        """Test that modifying frozen keys in ProfilerConfig raises exceptions."""
        from dataclasses import FrozenInstanceError

        config = ProfilerConfig(all_ranks=False, ranks=[0])

        # Attribute assignment is rejected by the frozen dataclass.
        with self.assertRaises(FrozenInstanceError):
            config.all_ranks = True
        with self.assertRaises(FrozenInstanceError):
            config.ranks = [1, 2, 3]

        # Item (dict-style) assignment is rejected with TypeError.
        with self.assertRaises(TypeError):
            config["all_ranks"] = True
        with self.assertRaises(TypeError):
            config["ranks"] = [1, 2, 3]


class TestNsightSystemsProfiler(unittest.TestCase):
    """Test suite for NsightSystemsProfiler functionality.

    Test Plan:
    1. Initialization: Verify profiler state after creation
    2. Basic Profiling: Test start/stop functionality
    3. Discrete Mode: TODO: Test discrete profiling behavior
    4. Annotation: Test the annotate decorator in normal mode
       (discrete-mode annotation: TODO)

    Conversion of profiler configs from OmegaConf is covered by
    ``TestProfilerConfig``.
    """

    def setUp(self):
        # Profile every rank, non-discrete mode; a fresh profiler per test so
        # state mutated by one test (e.g. ``this_step``) does not leak.
        self.config = ProfilerConfig(enable=True, all_ranks=True)
        self.rank = 0
        self.profiler = NsightSystemsProfiler(self.rank, self.config, tool_config=NsightToolConfig(discrete=False))

    def test_initialization(self):
        # With all_ranks=True this rank is selected, but no step is active yet.
        self.assertEqual(self.profiler.this_rank, True)
        self.assertEqual(self.profiler.this_step, False)

    def test_start_stop_profiling(self):
        with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop:
            # Test start
            self.profiler.start()
            self.assertTrue(self.profiler.this_step)
            mock_start.assert_called_once()

            # Test stop
            self.profiler.stop()
            self.assertFalse(self.profiler.this_step)
            mock_stop.assert_called_once()

    # NOTE(review): the disabled discrete-mode tests below pass ``discrete`` to
    # ProfilerConfig, whereas setUp passes it via NsightToolConfig; they likely
    # need updating to the current constructor before being re-enabled.
    # def test_discrete_profiling(self):
    #     discrete_config = ProfilerConfig(discrete=True, all_ranks=True)
    #     profiler = NsightSystemsProfiler(self.rank, discrete_config)

    #     with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop:
    #         profiler.start()
    #         self.assertTrue(profiler.this_step)
    #         mock_start.assert_not_called()  # Shouldn't start immediately in discrete mode

    #         profiler.stop()
    #         self.assertFalse(profiler.this_step)
    #         mock_stop.assert_not_called()  # Shouldn't stop immediately in discrete mode

    def test_annotate_decorator(self):
        mock_self = MagicMock()
        mock_self.profiler = self.profiler
        # Mark the current step as profiled (this mutates self.profiler, which
        # setUp recreates for every test) so the wrapped call emits one range.
        mock_self.profiler.this_step = True
        decorator = mock_self.profiler.annotate(message="test")

        @decorator
        def test_func(self, *args, **kwargs):
            return "result"

        with (
            patch("torch.cuda.profiler.start") as mock_start,
            patch("torch.cuda.profiler.stop") as mock_stop,
            patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
            patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
        ):
            result = test_func(mock_self)
            self.assertEqual(result, "result")
            mock_start_range.assert_called_once()
            mock_end_range.assert_called_once()
            mock_start.assert_not_called()  # Not discrete mode
            mock_stop.assert_not_called()  # Not discrete mode

    # def test_annotate_discrete_mode(self):
    #     discrete_config = ProfilerConfig(discrete=True, all_ranks=True)
    #     profiler = NsightSystemsProfiler(self.rank, discrete_config)
    #     mock_self = MagicMock()
    #     mock_self.profiler = profiler
    #     mock_self.profiler.this_step = True

    #     @NsightSystemsProfiler.annotate(message="test")
    #     def test_func(self, *args, **kwargs):
    #         return "result"

    #     with (
    #         patch("torch.cuda.profiler.start") as mock_start,
    #         patch("torch.cuda.profiler.stop") as mock_stop,
    #         patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
    #         patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
    #     ):
    #         result = test_func(mock_self)
    #         self.assertEqual(result, "result")
    #         mock_start_range.assert_called_once()
    #         mock_end_range.assert_called_once()
    #         mock_start.assert_called_once()  # Should start in discrete mode
    #         mock_stop.assert_called_once()  # Should stop in discrete mode


if __name__ == "__main__":
    unittest.main()