| |
| import inspect |
| import warnings |
| from functools import partial |
|
|
| from .misc import is_seq_of |
|
|
|
|
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    The ``type`` entry of ``cfg`` (or of ``default_args`` when ``cfg`` does
    not define it) selects the class to build: a string is looked up in
    ``registry``, a class is used as-is. All remaining entries are passed as
    keyword arguments to the class constructor. Entries of ``default_args``
    are only used for keys that ``cfg`` does not already define.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg``, ``default_args`` or ``registry`` has the wrong
            type, or if the resolved ``type`` is neither a str nor a class.
        KeyError: If neither ``cfg`` nor ``default_args`` contains ``type``,
            or a string ``type`` is not found in ``registry``.
        Exception: An exception raised by the constructor is re-raised with
            the same type, its message prefixed by the class name, and the
            original exception chained as its cause.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # Validate default_args before probing it with ``in`` below, so a bad
    # value yields this message rather than an unrelated "not iterable".
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    # Work on a copy so the caller's config dict is never mutated.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Prefix the message with the class name to show which config entry
        # failed, keeping the original exception type for callers.
        try:
            wrapped = type(e)(f'{obj_cls.__name__}: {e}')
        except Exception:
            # Some exception types (e.g. UnicodeDecodeError) cannot be
            # built from a single message; re-raise the original instead.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
|
|
|
|
class Registry:
    """A registry to map strings to classes.

    Registered objects can be built from the registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` nor
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The classes registered
            in a child registry can be built from the parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for child registries. If not specified, scope will be the name of
            the top-level package where the registry is instantiated, e.g.
            mmdet, mmcls, mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        # ``infer_scope`` inspects the frame two levels up, so it must be
        # called directly from ``__init__`` (frame 2 = the instantiator).
        self._scope = self.infer_scope() if scope is None else scope

        if parent is not None:
            # Validate before ``parent.build_func`` is read below.
            assert isinstance(parent, Registry)

        # Priority: explicit build_func > parent's build_func > default.
        if build_func is not None:
            self.build_func = build_func
        elif parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = build_from_cfg

        if parent is not None:
            parent._add_children(self)
        self.parent = parent

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the top-level package of the module that instantiated
        the registry will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # Frame 0 is this function, frame 1 is ``Registry.__init__`` and
        # frame 2 is the code that instantiated the registry. Walking
        # ``f_back`` avoids ``inspect.stack()``, which materializes (and
        # reads source context for) every frame on the call stack.
        frame = inspect.currentframe()
        try:
            caller = frame.f_back.f_back
            module = inspect.getmodule(caller)
            if module is not None:
                module_name = module.__name__
            else:
                # No module object can be found for the frame (e.g. code run
                # via ``exec``); fall back to the frame's ``__name__`` global.
                module_name = caller.f_globals.get('__name__', '')
        finally:
            # Break the reference cycle between this frame and its locals.
            del frame
        return module_name.split('.')[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        scope, sep, remaining = key.partition('.')
        if sep:
            return scope, remaining
        return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Lookup rules:

        * no scope prefix, or a prefix equal to this registry's scope: look
          in this registry only;
        * a prefix naming one of this registry's children: delegate the
          unprefixed key to that child;
        * any other prefix: delegate the full key to the root registry.

        Args:
            key (str): The class name in string format, optionally prefixed
                by a scope, e.g. ``'mmdet.ResNet'``.

        Returns:
            class | None: The corresponding class, or None if not found.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            return self._module_dict.get(real_key)
        if scope in self._children:
            return self._children[scope].get(real_key)
        if self.parent is None:
            # Already at the root and the scope is unknown: nothing to find.
            return None
        root = self.parent
        while root.parent is not None:
            root = root.parent
        return root.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        names = ([module_name] if isinstance(module_name, str) else
                 list(module_name))
        # Validate every alias before mutating the registry so a conflict
        # does not leave it half-registered. Duplicate aliases within the
        # same call are rejected too, as they would collide with each other.
        if not force:
            seen = set()
            for name in names:
                if name in self._module_dict or name in seen:
                    raise KeyError(f'{name} is already registered '
                                   f'in {self.name}')
                seen.add(name)
        for name in names:
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)

        Args:
            name (str | Sequence[str] | None): The module name(s) to be
                registered. If not specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            The registered ``module`` when it is given directly; otherwise a
            decorator that registers and returns the decorated class.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # Old API: ``register_module(SomeClass)`` passes the class as `name`.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # Used as a normal method: ``register_module(module=SomeClass)``.
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # Used as a decorator: ``@register_module()``.
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register
|
|