| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Configuration utility functions |
| """ |
|
|
| import importlib |
| from typing import Any, Callable, List, Union |
| from omegaconf import DictConfig, ListConfig, OmegaConf |
|
|
# Register Python's builtin ``eval`` as an OmegaConf resolver named "eval",
# making ``${eval:<expression>}`` interpolations available in loaded configs.
# NOTE(review): this lets any loaded YAML file execute arbitrary Python code;
# only load configuration files from trusted sources.
OmegaConf.register_new_resolver(name="eval", resolver=eval)
|
|
|
|
def load_config(
    path: str, argv: Union[List[str], None] = None
) -> Union[DictConfig, ListConfig]:
    """
    Load a configuration file and resolve its ``__inherit__`` directives.

    Args:
        path: Path of the configuration file passed to ``OmegaConf.load``.
        argv: Optional dotlist overrides (for example ``["a.b=1"]``). They are
            merged on top of the file contents *before* inheritance is resolved,
            so they also take precedence over values inherited from parents.

    Returns:
        The loaded configuration with every ``__inherit__`` key resolved.
    """
    config = OmegaConf.load(path)
    if argv is not None:
        config = OmegaConf.merge(config, OmegaConf.from_dotlist(argv))
    return resolve_recursive(config, resolve_inheritance)
|
|
|
|
def resolve_recursive(
    config: Any,
    resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
) -> Any:
    """
    Apply ``resolver`` to ``config`` and then, depth-first, to every nested container.

    The resolver runs on a node before its children are visited, so children the
    resolver introduces (e.g. keys merged in from an inherited parent) are visited
    as well. Each nested container is replaced in place by the resolver's result.

    Interpolation values (``${...}``) are skipped instead of read. Reading them
    forces their resolution, which can fail when they reference keys that only
    appear once a sibling subtree has been resolved. It would also overwrite an
    interpolation that points at a container with a snapshot of its target.

    Args:
        config: Root node to resolve. Anything other than a DictConfig or
            ListConfig is returned after being passed through ``resolver``.
        resolver: Function applied to every node; returns the replacement node.

    Returns:
        The resolved config, which may be a different object than ``config``.
    """
    config = resolver(config)
    if isinstance(config, DictConfig):
        keys = list(config.keys())
    elif isinstance(config, ListConfig):
        keys = list(range(len(config)))
    else:
        return config

    for key in keys:
        # Leave interpolations untouched; see the docstring for why.
        if OmegaConf.is_interpolation(config, key):
            continue
        value = config.get(key)
        if isinstance(value, (DictConfig, ListConfig)):
            config[key] = resolve_recursive(value, resolver)
    return config
|
|
|
|
def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
    """
    Resolve an ``__inherit__`` directive on a single config node.

    If ``config`` is a DictConfig containing ``__inherit__`` (a path to a parent
    config file, or a ListConfig of such paths), the key is removed from
    ``config``. Each parent is then loaded with ``load_config``, so parents may
    themselves inherit. Parents are merged left to right (later paths win), and
    the node's remaining keys are merged on top of them. A node that held only
    ``__inherit__`` is replaced by the merged parents. Nodes without the key,
    and ListConfigs, are returned unchanged.

    Only this node is handled; nested nodes are visited by ``resolve_recursive``.

    NOTE(review): parent paths are passed to ``load_config`` unchanged, so relative
    paths are interpreted relative to the current working directory rather than
    the child file. There is also no cycle detection: a file that (transitively)
    inherits itself recurses until ``RecursionError``.

    Raises:
        TypeError: if an ``__inherit__`` entry is not a string.
    """
    if not isinstance(config, DictConfig):
        return config

    inherit = config.pop("__inherit__", None)
    if not inherit:
        return config

    parent_paths = inherit if isinstance(inherit, ListConfig) else [inherit]

    parent_config = None
    for parent_path in parent_paths:
        # Explicit check instead of ``assert``: assertions are stripped under ``python -O``.
        if not isinstance(parent_path, str):
            raise TypeError(
                f"__inherit__ entries must be file paths (str), "
                f"got {type(parent_path).__name__}: {parent_path!r}"
            )
        parent = load_config(parent_path)
        parent_config = (
            parent if parent_config is None else OmegaConf.merge(parent_config, parent)
        )

    if len(config.keys()) == 0:
        # The node only held ``__inherit__``: use the merged parents as-is.
        return parent_config
    # Keys defined on the child override the inherited ones.
    return OmegaConf.merge(parent_config, config)
|
|
|
|
def import_item(path: str, name: str) -> Any:
    """
    Import the module at dotted ``path`` and return its attribute ``name``.

    Example: import_item("path.to.file", "MyClass") -> MyClass

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute ``name``.
    """
    module = importlib.import_module(path)
    return getattr(module, name)
|
|
|
|
def create_object(config: DictConfig) -> Any:
    """
    Instantiate the item described by ``config.__object__``.

    The config is expected to contain:
        __object__:
            path: path.to.module
            name: MyClass
            args: as_config | as_params   (defaults to as_config)

    With ``as_config`` the item is called with the whole ``config`` (including
    its ``__object__`` entry). With ``as_params`` the config is converted with
    ``OmegaConf.to_object``, ``__object__`` is removed from the converted copy,
    and the remaining top-level keys are passed as keyword arguments.

    Raises:
        NotImplementedError: if ``args`` is neither ``as_config`` nor ``as_params``.
    """
    spec = config.__object__
    factory = import_item(path=spec.path, name=spec.name)
    mode = spec.get("args", "as_config")

    if mode == "as_config":
        return factory(config)

    if mode == "as_params":
        # Pop from the converted copy, so the caller's config is left untouched.
        params = OmegaConf.to_object(config)
        params.pop("__object__")
        return factory(**params)

    raise NotImplementedError(f"Unknown args type: {mode}")