| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import ast |
| import builtins |
| import collections.abc as abc |
| import importlib |
| import inspect |
| import os |
| import uuid |
| from collections import OrderedDict |
| from contextlib import contextmanager |
| from dataclasses import is_dataclass |
| from typing import Any, Dict, List, Tuple, Union |
|
|
| import attrs |
| import yaml |
| from omegaconf import DictConfig, ListConfig, OmegaConf |
|
|
| from cosmos1.utils.lazy_config.file_io import PathManager |
| from cosmos1.utils.lazy_config.registry import _convert_target_to_string |
|
|
| __all__ = ["LazyCall", "LazyConfig"] |
|
|
|
|
def sort_dict(d: Dict[str, Any]) -> OrderedDict[str, Any]:
    """Return the entries of ``d`` as an ``OrderedDict`` ordered by ascending key."""
    ordered = OrderedDict()
    for key in sorted(d):
        ordered[key] = d[key]
    return ordered
|
|
|
|
def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode:
    """Represent an ``OrderedDict`` as a plain YAML mapping node.

    Args:
        dumper: the YAML dumper asked to represent ``data``.
        data: the ordered mapping to emit.

    Returns:
        The node produced by ``dumper.represent_mapping`` with the standard map tag.
    """
    plain_map_tag = "tag:yaml.org,2002:map"
    # The item view is passed rather than the mapping itself -- presumably so
    # that PyYAML keeps the given order instead of re-sorting; TODO confirm.
    entries = data.items()
    return dumper.represent_mapping(plain_map_tag, entries)
|
|
|
|
def sort_recursive(obj: Union[Dict[str, Any], List[Any], Any]) -> Union[OrderedDict[str, Any], List[Any], Any]:
    """Return ``obj`` with every nested dict replaced by a key-sorted ``OrderedDict``.

    Lists are rebuilt element by element, dicts are processed value by value
    and then sorted with ``sort_dict``; any other object is returned unchanged.
    """
    if isinstance(obj, list):
        return [sort_recursive(element) for element in obj]
    if not isinstance(obj, dict):
        return obj
    normalized = {}
    for key, value in obj.items():
        normalized[key] = sort_recursive(value)
    return sort_dict(normalized)
|
|
|
|
# Register the representer globally so yaml dumps handle OrderedDict through dict_representer.
yaml.add_representer(data_type=OrderedDict, representer=dict_representer)
|
|
|
|
def get_default_params(cls_or_func):
    """Collect the default values declared in the signature of ``cls_or_func``.

    Args:
        cls_or_func: a callable (function, class, ...) whose own signature is
            inspected; for a non-callable object the signature of its
            ``__init__`` is inspected instead.

    Returns:
        A dict mapping parameter name to default value for every parameter
        that declares a default. The dict is empty when nothing has a default
        or when the signature cannot be introspected.
    """
    inspected = cls_or_func if callable(cls_or_func) else cls_or_func.__init__
    try:
        signature = inspect.signature(inspected)
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins implemented in C) expose no
        # signature. They simply contribute no defaults instead of crashing.
        return {}
    return {
        name: param.default
        for name, param in signature.parameters.items()
        if param.default is not inspect.Parameter.empty
    }
|
|
|
|
class LazyCall:
    """
    Wrap a target so that "calling" it does not execute anything; instead the
    call returns a config object that describes the call.

    A LazyCall object accepts keyword arguments only. Positional arguments
    are not yet supported.

    Examples:
    ::
        from detectron2.config import instantiate, LazyCall

        layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
        layer_cfg.out_channels = 64   # can edit it afterwards
        layer = instantiate(layer_cfg)
    """

    def __init__(self, target):
        """Store ``target``; it must be a callable, a string or a mapping.

        Raises:
            TypeError: if ``target`` is none of the accepted kinds.
        """
        is_valid = callable(target) or isinstance(target, (str, abc.Mapping))
        if not is_valid:
            raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
        self._target = target

    def __call__(self, **kwargs):
        """Build the config describing a call of the target with ``kwargs``.

        Returns:
            A ``DictConfig`` (created with ``allow_objects``) holding the
            target's signature defaults, overridden by ``kwargs``, plus a
            ``_target_`` entry identifying the target.
        """
        wrapped = self._target
        # dataclass / attrs targets are recorded by their string form rather
        # than as objects -- presumably because omegaconf cannot hold such
        # classes directly; TODO confirm.
        store_as_string = is_dataclass(wrapped) or attrs.has(wrapped)
        kwargs["_target_"] = _convert_target_to_string(wrapped) if store_as_string else wrapped

        # Signature defaults come first; explicit keyword arguments win.
        call_description = get_default_params(wrapped)
        call_description.update(kwargs)

        return DictConfig(content=call_description, flags={"allow_objects": True})
|
|
|
|
def _visit_dict_config(cfg, func):
    """
    Call ``func`` on every DictConfig reachable from ``cfg``, descending
    through both DictConfig values and ListConfig items. A DictConfig is
    visited before its children; other objects are ignored.
    """
    if isinstance(cfg, DictConfig):
        func(cfg)
        children = cfg.values()
    elif isinstance(cfg, ListConfig):
        children = cfg
    else:
        return
    for child in children:
        _visit_dict_config(child, func)
|
|
|
|
def _validate_py_syntax(filename):
    """Check that the file at ``filename`` parses as Python.

    The file is read through ``PathManager`` and handed to ``ast.parse``.

    Raises:
        SyntaxError: naming the config file, chained to the parser's error.
    """
    with PathManager.open(filename, "r") as stream:
        source = stream.read()
    try:
        ast.parse(source)
    except SyntaxError as parse_error:
        raise SyntaxError(f"Config file {filename} has syntax error!") from parse_error
|
|
|
|
def _cast_to_config(obj):
    """Wrap a plain ``dict`` in a ``DictConfig`` (with ``allow_objects``); return anything else as is."""
    if not isinstance(obj, dict):
        return obj
    return DictConfig(obj, flags={"allow_objects": True})
|
|
|
|
# A namespace to put all imported config into.
_CFG_PACKAGE_NAME = "detectron2._cfg_loader"
|
|
|
|
def _random_package_name(filename):
    """Build a package name for ``filename``: the config namespace prefix, four
    characters of a fresh UUID, a dot, and the file's base name."""
    random_tag = str(uuid.uuid4())[:4]
    base_name = os.path.basename(filename)
    return f"{_CFG_PACKAGE_NAME}{random_tag}.{base_name}"
|
|
|
|
@contextmanager
def _patch_import():
    """
    Enhance relative import statements in config files, so that they:
    1. locate files purely based on relative location, regardless of packages.
       e.g. you can import file without having __init__
    2. do not cache modules globally; modifications of module states has no side effect
    3. support other storage system through PathManager, so config files can be in the cloud
    4. imported dict are turned into omegaconf.DictConfig automatically

    While the context is active, ``builtins.__import__`` is replaced by a
    wrapper. The previous import function is restored on exit, including when
    the body of the ``with`` statement raises.

    Yields:
        The replacement import function.
    """
    # ``import importlib`` alone does not guarantee that these submodules are
    # loaded, so import them explicitly (and before the hook is installed).
    import importlib.machinery
    import importlib.util

    old_import = builtins.__import__

    def find_relative_file(original_file, relative_import_path, level):
        """Resolve a relative import made from ``original_file`` to a ``.py`` path.

        Raises:
            ImportError: if the import names no module, points at a directory,
                or the resolved file does not exist.
        """
        relative_import_err = """
Relative import of directories is not allowed within config files.
Within a config file, relative import can only import other config files.
""".replace(
            "\n", " "
        )
        if not relative_import_path:
            raise ImportError(relative_import_err)

        # One leading dot means "same directory"; each extra level goes one up.
        cur_file = os.path.dirname(original_file)
        for _ in range(level - 1):
            cur_file = os.path.dirname(cur_file)
        cur_name = relative_import_path.lstrip(".")
        for part in cur_name.split("."):
            cur_file = os.path.join(cur_file, part)
        if not cur_file.endswith(".py"):
            cur_file += ".py"
        if not PathManager.isfile(cur_file):
            cur_file_no_suffix = cur_file[: -len(".py")]
            if PathManager.isdir(cur_file_no_suffix):
                raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
            else:
                raise ImportError(
                    f"Cannot import name {relative_import_path} from " f"{original_file}: {cur_file} does not exist."
                )
        return cur_file

    def new_import(name, globals=None, locals=None, fromlist=(), level=0):
        """Handle relative imports issued by config files; delegate everything else."""
        if (
            # Only relative imports coming from a module loaded in the config namespace.
            level != 0
            and globals is not None
            and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
        ):
            cur_file = find_relative_file(globals["__file__"], name, level)
            _validate_py_syntax(cur_file)
            spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file)
            module = importlib.util.module_from_spec(spec)
            module.__file__ = cur_file
            with PathManager.open(cur_file) as f:
                content = f.read()
            exec(compile(content, cur_file, "exec"), module.__dict__)
            for imported_name in fromlist:
                # Names the module does not define (including "*") are left
                # alone, so the import statement itself reports the problem
                # with a regular ImportError instead of a KeyError from here.
                if imported_name in module.__dict__:
                    module.__dict__[imported_name] = _cast_to_config(module.__dict__[imported_name])
            return module
        return old_import(name, globals, locals, fromlist=fromlist, level=level)

    builtins.__import__ = new_import
    try:
        yield new_import
    finally:
        # Always undo the patch, even if loading the config failed; otherwise
        # the hook would stay installed for the rest of the process.
        builtins.__import__ = old_import
|
|
|
|
class LazyConfig:
    """
    Helpers for working with an omegaconf config object that may contain
    definitions of lazily-constructed objects.
    """

    @staticmethod
    def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
        """
        Load a config file.

        Args:
            filename: absolute path or relative path w.r.t. the current working directory
            keys: keys to load and return. If not given, return all keys
                (whose values are config objects) in a dict.

        Returns:
            The config stored under ``keys`` when ``keys`` is a string, a tuple
            of configs when it is a sequence of keys, and otherwise the whole
            config (for python files: only the non-underscore names whose
            values are config objects or dicts).

        Raises:
            ValueError: if the file extension is not ``.py``, ``.yaml`` or ``.yml``.
        """
        has_keys = keys is not None
        filename = filename.replace("/./", "/")
        extension = os.path.splitext(filename)[1]
        if extension not in [".py", ".yaml", ".yml"]:
            raise ValueError(f"Config file {filename} has to be a python or yaml file.")
        is_python_file = filename.endswith(".py")

        if is_python_file:
            _validate_py_syntax(filename)

            with _patch_import():
                # __file__ and __package__ are what the patched import uses to
                # recognise the config and resolve its relative imports.
                module_namespace = {
                    "__file__": filename,
                    "__package__": _random_package_name(filename),
                }
                with PathManager.open(filename) as stream:
                    source = stream.read()
                # Compile with the real filename before executing.
                code = compile(source, filename, "exec")
                exec(code, module_namespace)

            ret = module_namespace
        else:
            with PathManager.open(filename) as stream:
                # NOTE(review): unsafe_load can construct arbitrary Python
                # objects; only load config files from trusted sources.
                raw = yaml.unsafe_load(stream)
            ret = OmegaConf.create(raw, flags={"allow_objects": True})

        if has_keys:
            if isinstance(keys, str):
                return _cast_to_config(ret[keys])
            return tuple(_cast_to_config(ret[key]) for key in keys)

        if is_python_file:
            # Keep only the public config-like values defined by the file.
            public_configs = {}
            for name, value in ret.items():
                if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_"):
                    public_configs[name] = _cast_to_config(value)
            ret = DictConfig(public_configs, flags={"allow_objects": True})
        return ret
|
|