sapiens2-normal / sapiens /registry.py
Rawal Khirodkar
Initial sapiens2-normal Space (HF download at startup, all 4 sizes)
ba23d94
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings
from typing import Any, Callable, Dict, Type, Union
class Registry(dict):
    """A mapping from string names to classes or factory callables.

    ``Registry`` is a plain ``dict`` subclass, so the normal mapping API
    (``key in reg``, ``reg[key]``, iteration) works on it directly. Entries
    are added with :meth:`register` / :meth:`register_module` and created
    from config dicts with :meth:`build`.
    """

    def register(
        self,
        obj: Union[Type, Callable] | None = None,
        *,
        name: str | None = None,
        force: bool = False,
    ):
        """Register *obj* under *name* (defaults to ``obj.__name__``).

        Usable directly (``reg.register(Foo)``) or, when *obj* is omitted,
        as a decorator factory (``@reg.register(name="foo")``).

        Unless *force* is set, the first registration of a key wins:

        * re-registering the *same* definition (same ``__module__`` and
          ``__qualname__``, e.g. a module executed twice) returns the object
          that is already registered;
        * registering a *different* object under an existing key leaves the
          registry untouched, emits a ``UserWarning`` and returns the new
          object unchanged, so a decorated name is never rebound to an
          unrelated object.

        Args:
            obj: Class or callable to register, or ``None`` to get a decorator.
            name: Registry key; an empty/``None`` value falls back to
                ``obj.__name__``.
            force: Overwrite an existing entry with the same key.

        Returns:
            The object as described above, or a decorator when *obj* is
            ``None``.
        """

        def _do_register(o):
            key = name or o.__name__
            if force or key not in self:
                self[key] = o
                return o
            existing = self[key]
            if existing is o or self._same_definition(existing, o):
                # Same definition seen again: keep the original entry and hand
                # it back (the historical "skip if already registered" path).
                return existing
            # A different object collides with an existing key. Keep the first
            # registration, but say so instead of silently swapping objects.
            warnings.warn(
                f"{key!r} is already registered to {existing!r}; ignoring "
                f"{o!r}. Pass force=True to overwrite.",
                stacklevel=2,
            )
            return o

        return _do_register(obj) if obj is not None else _do_register

    @staticmethod
    def _same_definition(a: Any, b: Any) -> bool:
        """Return True if *a* and *b* share ``__module__`` and ``__qualname__``."""
        qualname = getattr(a, "__qualname__", None)
        return (
            qualname is not None
            and qualname == getattr(b, "__qualname__", None)
            and getattr(a, "__module__", None) == getattr(b, "__module__", None)
        )

    def register_module(
        self, *args, name: str | None = None, force: bool = False
    ):
        """Decorator-friendly front end to :meth:`register`.

        Supported forms::

            @reg.register_module()          # key = obj.__name__
            @reg.register_module(name="x")  # explicit key
            @reg.register_module("x")       # explicit key, positional
            reg.register_module(Foo)        # direct registration

        Args:
            *args: Optionally a class/callable to register immediately, or a
                string key. Further positional values are ignored.
            name: Registry key; takes precedence over a positional string.
            force: Forwarded to :meth:`register`.

        Returns:
            The result of :meth:`register` for the direct form, otherwise a
            decorator.
        """
        if args and callable(args[0]):
            return self.register(args[0], name=name, force=force)
        if name is None and args and isinstance(args[0], str):
            # Positional key (register_module("key")) was previously dropped.
            name = args[0]
        return self.register(name=name, force=force)

    def build(self, cfg: Dict[str, Any], **extra_kwargs) -> Any:
        """Instantiate the entry selected by ``cfg["type"]``.

        Args:
            cfg: Mapping whose ``"type"`` value is a registered key (or the
                class/callable itself); all other items are passed as keyword
                arguments. The mapping itself is not modified.
            **extra_kwargs: Extra keyword arguments for the constructor. A key
                present both here and in *cfg* raises ``TypeError``
                (duplicate keyword argument).

        Returns:
            Whatever the selected class/callable returns.

        Raises:
            KeyError: If ``"type"`` is missing or names an unregistered entry.
        """
        cfg = dict(cfg)  # shallow copy: never mutate the caller's config
        if "type" not in cfg:
            raise KeyError(f"config has no 'type' key (keys: {list(cfg)})")
        obj_type = cfg.pop("type")
        if not isinstance(obj_type, str) and callable(obj_type):
            # ``type=SomeClass`` bypasses the name lookup.
            cls_or_fn = obj_type
        elif obj_type in self:
            cls_or_fn = self[obj_type]
        else:
            raise KeyError(f"{obj_type!r} not found in registry.")
        return cls_or_fn(**cfg, **extra_kwargs)
# --------------------------------------------------------------------------- #
# One independent Registry per component family. Each starts empty; entries
# are added through register()/register_module() and instantiated via build().
MODELS: Registry = Registry()
DATASETS: Registry = Registry()
TRANSFORMS: Registry = Registry()
OPTIMIZERS: Registry = Registry()
SCHEDULERS: Registry = Registry()
LOGGERS: Registry = Registry()
VISUALIZERS: Registry = Registry()
HOOKS: Registry = Registry()
# Public names exported by ``from ... import *``. Keep this in sync with the
# registries defined above (TRANSFORMS was previously missing).
__all__ = [
    "Registry",
    "MODELS",
    "DATASETS",
    "TRANSFORMS",
    "OPTIMIZERS",
    "SCHEDULERS",
    "LOGGERS",
    "VISUALIZERS",
    "HOOKS",
]