Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from __future__ import annotations | |
| from typing import Any, Callable, Dict, Type, Union | |
class Registry(dict):
    """A ``dict`` mapping string names to registered classes or callables.

    Entries are added with :meth:`register` / :meth:`register_module` (both
    usable as plain calls or as decorators) and instantiated from a config
    mapping with :meth:`build`.
    """

    def register(
        self, obj: Union[Type, Callable] | None = None, *, name: str | None = None
    ):
        """Register ``obj`` under ``name`` (or ``obj.__name__`` if no name is given).

        Supports three call forms::

            registry.register(obj)              # direct call
            @registry.register                  # bare decorator
            @registry.register(name="alias")    # parametrised decorator

        Args:
            obj: The class or callable to register. When ``None``, a decorator
                is returned instead of registering immediately.
            name: Optional explicit key. A falsy value (``None`` or ``""``)
                falls back to ``obj.__name__``.

        Returns:
            The object stored under the key, or a decorator when ``obj`` is
            ``None``. If the key is already taken, the *existing* entry is kept
            and returned, so a duplicate registration is silently skipped.
        """

        def _do_register(o):
            key = name or o.__name__
            # setdefault keeps the first registration and returns whatever is
            # stored under the key, i.e. duplicates are skipped, not overwritten.
            return self.setdefault(key, o)

        return _do_register(obj) if obj is not None else _do_register

    def register_module(self, *args, name: str | None = None):
        """Register an object; a thin wrapper that delegates to :meth:`register`.

        If the first positional argument is callable it is registered
        immediately; otherwise a decorator is returned that registers its
        argument under ``name``.

        Args:
            *args: Optionally, the object to register as the first element.
            name: Optional explicit key forwarded to :meth:`register`.

        Returns:
            The result of :meth:`register` for the direct form, or a decorator.
        """
        if args and callable(args[0]):
            return self.register(args[0], name=name)

        def decorator(obj):
            return self.register(obj, name=name)

        return decorator

    def build(self, cfg: Dict[str, Any], **extra_kwargs) -> Any:
        """Instantiate a registered entry from a config mapping.

        Args:
            cfg: Mapping with a ``"type"`` key naming the registered entry; all
                remaining items are passed to it as keyword arguments. The
                mapping itself is not modified.
            **extra_kwargs: Additional keyword arguments passed to the entry.

        Returns:
            Whatever the registered class or callable returns.

        Raises:
            KeyError: If ``cfg`` has no ``"type"`` key, or the named type is
                not present in this registry.
        """
        cfg = dict(cfg)  # shallow copy so the caller's mapping keeps its "type"
        try:
            obj_type = cfg.pop("type")
        except KeyError:
            raise KeyError(
                'cfg must contain a "type" key naming a registered entry.'
            ) from None
        if obj_type not in self:
            raise KeyError(f"{obj_type!r} not found in registry.")
        cls_or_fn = self[obj_type]
        return cls_or_fn(**cfg, **extra_kwargs)
| # --------------------------------------------------------------------------- # | |
# Module-level registry instances, one per name below. Each is an independent,
# initially empty Registry.
MODELS = Registry()
DATASETS = Registry()
TRANSFORMS = Registry()
OPTIMIZERS = Registry()
SCHEDULERS = Registry()
LOGGERS = Registry()
VISUALIZERS = Registry()
HOOKS = Registry()

# Public API: the Registry class plus every registry instance defined above.
__all__ = [
    "Registry",
    "MODELS",
    "DATASETS",
    "TRANSFORMS",
    "OPTIMIZERS",
    "SCHEDULERS",
    "LOGGERS",
    "VISUALIZERS",
    "HOOKS",
]