File size: 1,692 Bytes
ba23d94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Any, Callable, Dict, Type, Union


class Registry(dict):
    """A ``name -> class/function`` mapping used to build objects from configs.

    Entries are added with :meth:`register` or :meth:`register_module`, both of
    which can be used as decorators. Objects are created from a config dict
    with :meth:`build`.
    """

    def register(
        self, obj: Union[Type, Callable] | None = None, *, name: str | None = None
    ):
        """Register ``obj`` under ``name`` (defaults to ``obj.__name__``).

        Usage: ``reg.register(fn)``, ``@reg.register()`` or
        ``@reg.register(name="key")``.

        If the key is already present, the existing entry is kept unchanged and
        that existing entry is returned instead of ``obj``.

        Args:
            obj: The class or function to register. If None, a decorator is
                returned instead.
            name: Registry key; a falsy value falls back to ``obj.__name__``.

        Returns:
            The registry's entry for the key, or a decorator when ``obj`` is
            None.
        """

        def _do_register(o):
            key = name or o.__name__
            if key in self:
                # First registration wins. NOTE(review): when used as a
                # decorator, this rebinds the decorated name to the earlier
                # object rather than ``o`` -- confirm this is intended.
                return self[key]
            self[key] = o
            return o

        return _do_register(obj) if obj is not None else _do_register

    def register_module(self, *args, name: str | None = None):
        """Flexible registration entry point.

        Supports ``@reg.register_module``, ``@reg.register_module()``,
        ``@reg.register_module(name="key")`` and ``@reg.register_module("key")``.
        A positional string is used as the registration name; an explicit
        ``name=`` keyword takes precedence over it.
        """
        if args and callable(args[0]):
            # Bare decorator form: the object itself was passed positionally.
            return self.register(args[0], name=name)
        if name is None and args and isinstance(args[0], str):
            # Honor a positional name instead of silently falling back to
            # the decorated object's ``__name__``.
            name = args[0]

        def decorator(obj):
            return self.register(obj, name=name)

        return decorator

    def build(self, cfg: Dict[str, Any], **extra_kwargs) -> Any:
        """Instantiate/call the registry entry named by ``cfg["type"]``.

        Args:
            cfg: Mapping with a ``"type"`` key naming a registered entry. All
                other keys are passed as keyword arguments. The mapping is
                not mutated.
            **extra_kwargs: Additional keyword arguments for the call. A key
                present in both ``cfg`` and ``extra_kwargs`` raises
                ``TypeError`` (duplicate keyword argument).

        Returns:
            The result of calling the registered entry.

        Raises:
            KeyError: If ``cfg`` has no ``"type"`` key, or the type is not
                registered.
        """
        cfg = dict(cfg)  # shallow copy so the caller's config is not mutated
        try:
            obj_type = cfg.pop("type")
        except KeyError:
            raise KeyError(
                f"Config is missing the required 'type' key: {cfg!r}"
            ) from None
        if obj_type not in self:
            raise KeyError(
                f"{obj_type!r} not found in registry. "
                f"Available: {sorted(map(str, self))}"
            )
        cls_or_fn = self[obj_type]
        return cls_or_fn(**cfg, **extra_kwargs)


# --------------------------------------------------------------------------- #
# Global registries, one per component category.
MODELS = Registry()
DATASETS = Registry()
TRANSFORMS = Registry()
OPTIMIZERS = Registry()
SCHEDULERS = Registry()
LOGGERS = Registry()
VISUALIZERS = Registry()
HOOKS = Registry()

# Public API: keep in sync with the registries defined above so that
# ``from <module> import *`` exports every one of them.
__all__ = [
    "Registry",
    "MODELS",
    "DATASETS",
    "TRANSFORMS",
    "OPTIMIZERS",
    "SCHEDULERS",
    "LOGGERS",
    "VISUALIZERS",
    "HOOKS",
]