import time
from typing import Callable, Dict, List, Optional, Union

import torch.nn as nn

import mmengine
from mmengine.device import get_device
from mmengine.model import revert_sync_batchnorm
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
from mmengine.registry import STRATEGIES
from mmengine.utils import get_git_hash
from .base import BaseStrategy


@STRATEGIES.register_module()
class SingleDeviceStrategy(BaseStrategy):
| """Strategy for single device training.""" |
    """

    def prepare(
        self,
        model: Union[nn.Module, dict],
        *,
        optim_wrapper: Union[BaseOptimWrapper, dict, None] = None,
        param_scheduler: Union[_ParamScheduler, Dict, List, None] = None,
        compile: Union[dict, bool] = False,
        dispatch_kwargs: Optional[dict] = None,
    ):
| """Prepare model and some components. |
| |
| Args: |
| model (:obj:`torch.nn.Module` or dict): The model to be run. It |
| can be a dict used for build a model. |
| |
| Keyword Args: |
| optim_wrapper (BaseOptimWrapper or dict, optional): Computing the |
| gradient of model parameters and updating them. |
| Defaults to None. |
| See :meth:`build_optim_wrapper` for examples. |
| param_scheduler (_ParamScheduler or dict or list, optional): |
| Parameter scheduler for updating optimizer parameters. If |
| specified, :attr:`optim_wrapper` should also be specified. |
| Defaults to None. |
| See :meth:`build_param_scheduler` for examples. |
| compile (dict, optional): Config to compile model. |
| Defaults to False. Requires PyTorch>=2.0. |
| dispatch_kwargs (dict, optional): Kwargs to be passed to other |
| methods of Strategy. Defaults to None. |
| If ``accumulative_counts`` is set in ``optim_wrapper``, you |
| need to provide ``max_iters`` in ``dispatch_kwargs``. |
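
        The following is a minimal, illustrative sketch; ``ToyModel`` stands
        in for any ``nn.Module`` (or model config dict), and the optimizer
        settings are placeholder assumptions:

        Examples:
            >>> strategy = SingleDeviceStrategy()
            >>> strategy.prepare(
            ...     ToyModel(),
            ...     optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
            ...     dispatch_kwargs=dict(max_iters=100))
            >>> strategy.model  # the model wrapped and moved to the device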
| """ |
| if self._prepared: |
| return self._prepared_components() |
| if dispatch_kwargs is not None: |
| self.dispatch_kwargs.update(dispatch_kwargs) |
|
|
| model = self.build_model(model) |
| model = self._init_model_weights(model) |
| model = self._wrap_model(model) |
| model = self.compile_model(model, compile=compile) |
|
|
| self.model = model |
|
|
| if optim_wrapper is not None: |
| self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model) |
| self._scale_lr() |
|
|
| accumulative_counts = getattr(self.optim_wrapper, |
| '_accumulative_counts', 1) |
| if accumulative_counts > 1: |
| if 'max_iters' not in self.dispatch_kwargs: |
| raise ValueError( |
| '"max_iters" must be specified because ' |
| '"accumulative_counts" was set as ' |
| f'{accumulative_counts} which is greater than 1.') |
|
|
| self.optim_wrapper.initialize_count_status( |
| self.model, 0, self.dispatch_kwargs['max_iters']) |
|
|
| if param_scheduler is not None: |
| self.param_schedulers = self.build_param_scheduler( |
| param_scheduler, self.optim_wrapper) |
|
|
| self._prepared = True |
| return self._prepared_components() |
|
|
| def _wrap_model(self, model: nn.Module) -> nn.Module: |
| model = self.convert_model(model) |
| current_device = get_device() |
| return model.to(current_device) |
|
|
| def convert_model(self, model: nn.Module) -> nn.Module: |
| """Convert layers of model. |
| |
| convert all ``SyncBatchNorm`` (SyncBN) and |
| ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers in the model to |
| ``BatchNormXd`` layers. |
| |
| Args: |
| model (nn.Module): Model to convert. |
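
        The snippet below is an illustrative sketch of the effect on a
        single layer; ``strategy`` stands for an already constructed
        instance of this class:

        Examples:
            >>> layer = nn.SyncBatchNorm(8)
            >>> converted = strategy.convert_model(layer)
            >>> # ``converted`` is a plain batch norm layer that runs
            >>> # without a distributed process group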
| """ |
| self.logger.info( |
| 'Distributed training is not used, all SyncBatchNorm (SyncBN) ' |
| 'layers in the model will be automatically reverted to ' |
| 'BatchNormXd layers if they are used.') |
| model = revert_sync_batchnorm(model) |
| return model |
|
|
| def load_checkpoint( |
| self, |
| filename: str, |
| *, |
| map_location: Union[str, Callable] = 'cpu', |
| strict: bool = False, |
| revise_keys: list = [(r'^module.', '')], |
| callback: Optional[Callable] = None, |
| ) -> dict: |
| """Load checkpoint from given ``filename``. |
| |
| Args: |
| filename (str): Accept local filepath, URL, ``torchvision://xxx``, |
| ``open-mmlab://xxx``. |
| |
| Keyword Args: |
| map_location (str or callable): A string or a callable function to |
| specifying how to remap storage locations. |
| Defaults to 'cpu'. |
| strict (bool): strict (bool): Whether to allow different params for |
| the model and checkpoint. |
| revise_keys (list): A list of customized keywords to modify the |
| state_dict in checkpoint. Each item is a (pattern, replacement) |
| pair of the regular expression operations. Defaults to strip |
| the prefix 'module.' by [(r'^module\\.', '')]. |
| callback (callable, callable): Callback function to modify the |
| checkpoint after loading the checkpoint. |
| Defaults to None. |
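
        The following is an illustrative sketch; the checkpoint path is a
        placeholder and :meth:`prepare` is assumed to have been called so
        that the model exists:

        Examples:
            >>> ckpt = strategy.load_checkpoint('work_dirs/iter_100.pth')
            >>> # the returned dict keeps everything except 'state_dict',
            >>> # e.g. the 'meta' and 'optimizer' entries if present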
| """ |
| from mmengine.runner.checkpoint import _load_checkpoint |
|
|
| self.logger.info(f'Load checkpoint from {filename}') |
|
|
| if map_location == 'default': |
| device = get_device() |
| checkpoint = _load_checkpoint(filename, map_location=device) |
| else: |
| checkpoint = _load_checkpoint(filename, map_location=map_location) |
|
|
| |
| if callback is not None: |
| callback(checkpoint) |
|
|
| state_dict = checkpoint.pop('state_dict') |
| self.load_model_state_dict( |
| state_dict, strict=strict, revise_keys=revise_keys) |
|
|
| return checkpoint |
|
|
| def resume( |
| self, |
| filename: str, |
| *, |
| resume_optimizer: bool = True, |
| resume_param_scheduler: bool = True, |
| map_location: Union[str, Callable] = 'default', |
| callback: Optional[Callable] = None, |
| ) -> dict: |
| """Resume training from given ``filename``. |
| |
| Four types of states will be resumed. |
| |
| - model state |
| - optimizer state |
| - scheduler state |
| - randomness state |
| |
| Args: |
| filename (str): Accept local filepath, URL, ``torchvision://xxx``, |
| ``open-mmlab://xxx``. |
| |
| Keyword Args: |
| resume_optimizer (bool): Whether to resume optimizer state. |
| Defaults to True. |
| resume_param_scheduler (bool): Whether to resume param scheduler |
| state. Defaults to True. |
| map_location (str or callable):A string or a callable function to |
| specifying how to remap storage locations. |
| Defaults to 'default'. |
| callback (callable, callable): Callback function to modify the |
| checkpoint before saving the checkpoint. |
| Defaults to None. |
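
        The following is an illustrative sketch; the checkpoint path is a
        placeholder and :meth:`prepare` is assumed to have been called
        beforehand:

        Examples:
            >>> checkpoint = strategy.resume(
            ...     'work_dirs/iter_100.pth',
            ...     resume_optimizer=True,
            ...     resume_param_scheduler=True)
            >>> checkpoint['meta']['iter']  # iteration the run stopped at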
| """ |
| self.logger.info(f'Resume checkpoint from {filename}') |
|
|
| checkpoint = self.load_checkpoint( |
| filename, map_location=map_location, callback=callback) |
|
|
| if resume_optimizer: |
| self.load_optim_state_dict(checkpoint.pop('optimizer')) |
|
|
| if resume_param_scheduler and hasattr(self, 'param_schedulers'): |
| self.load_scheduler_state_dict(checkpoint.pop('param_schedulers')) |
|
|
| |
| resumed_seed = checkpoint['meta'].get('seed', None) |
| current_seed = self._randomness.get('seed') |
| if resumed_seed is not None and resumed_seed != current_seed: |
| if current_seed is not None: |
| self.logger.warning(f'The value of random seed in the ' |
| f'checkpoint "{resumed_seed}" is ' |
| f'different from the value in ' |
| f'`randomness` config "{current_seed}"') |
| self._randomness.update(seed=resumed_seed) |
| self._set_randomness(**self._randomness) |
|
|
| |
| cur_iter = checkpoint['meta']['iter'] |
|
|
| if hasattr(self, 'optim_wrapper'): |
| accumulative_counts = getattr(self.optim_wrapper, |
| '_accumulative_counts', 1) |
| if accumulative_counts > 1: |
| if 'max_iters' not in self.dispatch_kwargs: |
| raise ValueError( |
| '"max_iters" must be specified because ' |
| '"accumulative_counts" was set as ' |
| f'{accumulative_counts} which is greater than 1.') |
| |
| self.optim_wrapper.initialize_count_status( |
| self.model, cur_iter, self.dispatch_kwargs['max_iters']) |
|
|
| return checkpoint |
|
|
| def save_checkpoint( |
| self, |
| filename: str, |
| *, |
| save_optimizer: bool = True, |
| save_param_scheduler: bool = True, |
| extra_ckpt: Optional[dict] = None, |
| callback: Optional[Callable] = None, |
| ) -> None: |
| """Save checkpoint to given ``filename``. |
| |
| Args: |
| filename (str): Filename to save checkpoint. |
| |
| Keyword Args: |
| save_optimizer (bool): Whether to save the optimizer to |
| the checkpoint. Defaults to True. |
| save_param_scheduler (bool): Whether to save the param_scheduler |
| to the checkpoint. Defaults to True. |
| extra_ckpt (dict, optional): Extra checkpoint to save. |
| Defaults to None. |
| callback (callable, callable): Callback function to modify the |
| checkpoint before saving the checkpoint. |
| Defaults to None. |
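
        The following is an illustrative sketch; the target path and the
        extra metadata are placeholders, and :meth:`prepare` is assumed to
        have been called beforehand:

        Examples:
            >>> strategy.save_checkpoint(
            ...     'work_dirs/iter_100.pth',
            ...     extra_ckpt=dict(meta=dict(epoch=1, iter=100)))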
| """ |
| from mmengine.runner.checkpoint import save_checkpoint |
|
|
| state_dict: dict = dict() |
| state_dict['state_dict'] = self.model_state_dict() |
|
|
| |
| if save_optimizer and hasattr(self, 'optim_wrapper'): |
| state_dict['optimizer'] = self.optim_state_dict() |
|
|
| if save_param_scheduler and hasattr(self, 'param_schedulers'): |
| state_dict['param_schedulers'] = self.scheduler_state_dict() |
|
|
| |
| if extra_ckpt is None: |
| extra_ckpt = dict() |
| if 'meta' not in extra_ckpt: |
| extra_ckpt['meta'] = dict() |
| extra_ckpt['meta'].update( |
| seed=self.seed, |
| time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), |
| mmengine=mmengine.__version__ + get_git_hash(), |
| ) |
|
|
| state_dict.update(extra_ckpt) |
|
|
| |
| if callback is not None: |
| callback(state_dict) |
|
|
| save_checkpoint(state_dict, filename) |
|
|