Spaces:
Running
Running
File size: 1,692 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from typing import Any, Callable, Dict, Type, Union
class Registry(dict):
    """A name -> object mapping used to register and build components.

    Classes or factory functions are added with :meth:`register` or
    :meth:`register_module` (usually as decorators) and instantiated from
    config mappings with :meth:`build`. Because ``Registry`` subclasses
    ``dict``, ordinary mapping operations (``in``, ``[]``, ``keys()``, ...)
    work directly on the registered names.
    """

    def register(
        self, obj: Union[Type, Callable] | None = None, *, name: str | None = None
    ):
        """Register ``obj`` under ``name`` (default: ``obj.__name__``).

        Can be called directly (``reg.register(Foo)``), used as a bare
        decorator (``@reg.register``) or as a decorator factory
        (``@reg.register(name="foo")``).

        Args:
            obj: The class or callable to register. When ``None``, a
                decorator is returned instead.
            name: Registry key. A falsy value falls back to ``obj.__name__``.

        Returns:
            The registered object, or a decorator when ``obj`` is ``None``.
            If the key is already taken, nothing is overwritten and the
            object already stored under that key is returned.
        """

        def _do_register(o):
            key = name or o.__name__
            if key in self:
                # First registration wins; later duplicates are skipped.
                # NOTE(review): the caller (or decorated name) then receives
                # the previously registered object rather than ``o`` —
                # confirm this is intended.
                return self[key]
            self[key] = o
            return o

        return _do_register(obj) if obj is not None else _do_register

    def register_module(self, *args, name: str | None = None):
        """Decorator-friendly wrapper around :meth:`register`.

        Supported forms::

            reg.register_module(Foo)            # register immediately
            @reg.register_module                # bare decorator
            @reg.register_module()              # empty decorator factory
            @reg.register_module(name="foo")    # explicit key
            @reg.register_module("foo")         # explicit key, positional

        Args:
            *args: Optionally a callable to register right away, or a string
                to use as the registry key.
            name: Registry key; takes precedence over a positional string.

        Returns:
            The registered object when a callable is passed, otherwise a
            decorator that registers its argument.
        """
        if args and callable(args[0]):
            return self.register(args[0], name=name)
        if args and isinstance(args[0], str) and name is None:
            # Accept the key positionally too, instead of silently dropping
            # it and falling back to ``obj.__name__``.
            name = args[0]

        def decorator(obj):
            return self.register(obj, name=name)

        return decorator

    def build(self, cfg: Dict[str, Any], **extra_kwargs) -> Any:
        """Instantiate the object registered under ``cfg["type"]``.

        Args:
            cfg: Config mapping. Its ``"type"`` entry selects the registered
                object; all other entries are passed as keyword arguments.
                ``cfg`` itself is not modified.
            **extra_kwargs: Extra keyword arguments forwarded to the
                registered callable. A key present in both ``cfg`` and
                ``extra_kwargs`` raises ``TypeError`` (duplicate keyword).

        Returns:
            Whatever the registered callable returns.

        Raises:
            KeyError: If ``cfg`` has no ``"type"`` entry, or if that type is
                not registered.
        """
        cfg = dict(cfg)  # shallow copy so the caller's config is untouched
        try:
            obj_type = cfg.pop("type")
        except KeyError:
            raise KeyError(
                f"cfg must contain a 'type' key, got keys {sorted(map(str, cfg))!r}."
            ) from None
        if obj_type not in self:
            raise KeyError(
                f"{obj_type!r} not found in registry. "
                f"Registered: {sorted(map(str, self))!r}."
            )
        cls_or_fn = self[obj_type]
        return cls_or_fn(**cfg, **extra_kwargs)
# --------------------------------------------------------------------------- #
# Module-level registries, one per component family. Each name is bound to
# its own independent, initially empty ``Registry`` instance.
(
    MODELS,
    DATASETS,
    TRANSFORMS,
    OPTIMIZERS,
    SCHEDULERS,
    LOGGERS,
    VISUALIZERS,
    HOOKS,
) = (Registry() for _ in range(8))
# Public API for ``from <module> import *``. Every registry created at
# module level must be listed here, otherwise star-imports silently omit it.
__all__ = [
    "Registry",
    "MODELS",
    "DATASETS",
    "TRANSFORMS",
    "OPTIMIZERS",
    "SCHEDULERS",
    "LOGGERS",
    "VISUALIZERS",
    "HOOKS",
]
|