| |
| |
| |
| |
| |
| |
| import sys |
| import argparse |
| import importlib |
| from omegaconf import DictConfig |
|
|
def prepare_parser_from_dict(dic, parser=None):
    """Prepare an argparser from a dictionary.

    Args:
        dic (dict): Config dictionary. A mapping value defines an argument
            group named after its key, whose (unique) keys become ``--key``
            arguments. Any other value becomes a single ``--key`` argument in
            a group named after that key. List/tuple defaults accept one or
            more command-line values.
        parser (argparse.ArgumentParser, optional): If a parser already
            exists, add the keys from the dictionary on the top of it.

    Returns:
        argparse.ArgumentParser:
            Parser instance with groups corresponding to the first level keys
            and arguments corresponding to the second level keys with default
            values given by the values.
    """
    from collections.abc import Mapping, Sequence

    def standardized_entry_type(value):
        """If the default value is None, replace NoneType by str_int_float.
        If the default value is boolean, look for boolean strings."""
        if value is None:
            return str_int_float
        if isinstance(str2bool(value), bool):
            return str2bool_arg
        return type(value)

    def add_entry(group, name, value):
        """Register ``--name`` in ``group`` with ``value`` as its default."""
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            # ``type=list`` would turn a command-line string into a list of its
            # characters. Instead, accept one or more values, each converted
            # with the type inferred from the default's first element.
            item_type = standardized_entry_type(value[0]) if len(value) else str_int_float
            group.add_argument("--" + name, default=value, type=item_type, nargs="+")
        else:
            group.add_argument(
                "--" + name, default=value, type=standardized_entry_type(value)
            )

    if parser is None:
        parser = argparse.ArgumentParser()
    for key, value in dic.items():
        group = parser.add_argument_group(key)
        if isinstance(value, Mapping):
            # Two-level entry: every sub-key becomes an argument of this group.
            for sub_key, sub_value in value.items():
                add_entry(group, sub_key, sub_value)
        else:
            # Top-level scalar or list: the key itself is the argument name.
            add_entry(group, key, value)
    return parser
|
|
|
|
def str_int_float(value):
    """Type to convert strings to int, float (in this order) if possible.

    Args:
        value (str): Value to convert.

    Returns:
        int, float, str: ``int(value)`` if that succeeds, else
        ``float(value)`` if that succeeds, else ``value`` unchanged (so a
        non-numeric string is returned as-is and never ``None``).
    """
    for cast in (int, float):
        try:
            return cast(value)
        except (TypeError, ValueError, OverflowError):
            # Not representable with this cast; try the next one.
            continue
    return value
|
|
|
|
def str2bool(value):
    """Interpret common yes/no spellings as booleans.

    Strings are matched case-insensitively: "yes", "true", "y" and "1" give
    ``True``; "no", "false", "n" and "0" give ``False``. Any other string, and
    any non-string input, is returned unchanged.
    """
    if isinstance(value, str):
        lowered = value.lower()
        if lowered in ("yes", "true", "y", "1"):
            return True
        if lowered in ("no", "false", "n", "0"):
            return False
    return value
|
|
|
|
def str2bool_arg(value):
    """Strict boolean converter meant for argparse's ``type=`` hook.

    Args:
        value: Raw command-line value.

    Returns:
        bool: The boolean spelled by ``value``, as recognized by ``str2bool``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    parsed = str2bool(value)
    if not isinstance(parsed, bool):
        raise argparse.ArgumentTypeError("Boolean value expected.")
    return parsed
|
|
|
|
def isfloat(value):
    """Computes whether `value` can be cast to a float.

    Args:
        value (str): Value to check.

    Returns:
        bool: Whether `value` can be cast to a float. Inputs that ``float()``
        rejects with ``TypeError`` (e.g. ``None``) or ``OverflowError``
        (e.g. ints too large for a float) yield ``False`` instead of raising.

    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
|
|
|
|
def isint(value):
    """Computes whether `value` can be cast to an int

    Args:
        value (str): Value to check.

    Returns:
        bool: Whether `value` can be cast to an int. Inputs that ``int()``
        rejects with ``TypeError`` (e.g. ``None``) or ``OverflowError``
        (e.g. ``float("inf")``) yield ``False`` instead of raising.

    """
    try:
        int(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
|
|
|
|
def parse_args_as_dict(parser, return_plain_args=False, args=None):
    """Get a dict of dicts out of process `parser.parse_args()`

    Top-level keys corresponding to groups and bottom-level keys corresponding
    to arguments. Under `'main_args'`, the arguments which don't belong to a
    argparse group (i.e main arguments defined before parsing from a dict) can
    be found.

    Args:
        parser (argparse.ArgumentParser): ArgumentParser instance containing
            groups. Output of `prepare_parser_from_dict`.
        return_plain_args (bool): Whether to return the output of
            `parser.parse_args()`.
        args (list): List of arguments as read from the command line.
            Used for unit testing.

    Returns:
        dict:
            Dictionary of dictionaries containing the arguments. Optionally the
            direct output `parser.parse_args()`.
    """
    args = parser.parse_args(args=args)
    args_dic = {}
    main_args = {}
    for group in parser._action_groups:
        group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
        # argparse's default group for optional arguments was titled
        # "optional arguments" before Python 3.10 and "options" since then.
        # Match it by identity, so neither spelling is hard-coded and a config
        # group that happens to share the title cannot clash with it.
        if group is parser._optionals:
            main_args = group_dict
        else:
            args_dic[group.title] = group_dict
    args_dic["main_args"] = main_args
    if return_plain_args:
        return args_dic, args
    return args_dic
|
|
def instantiate(config, **kwargs):
    """Recursively build objects from a config mapping.

    If ``config`` has a ``'__target__'`` key (a dotted ``module.attr`` path),
    that attribute is imported and called. The remaining entries are passed as
    keyword arguments, plus ``kwargs`` (which override same-named entries).
    Entries that are themselves mappings with a ``'__target__'`` key are
    instantiated first, without ``kwargs``; other entries are passed through
    unchanged.

    Without ``'__target__'``, a plain ``dict`` is returned in which every
    mapping value is processed recursively with the same ``kwargs`` and all
    other values are kept as-is.

    Args:
        config (Mapping): Config node, e.g. an omegaconf ``DictConfig`` or a
            plain ``dict``.
        **kwargs: Extra keyword arguments forwarded to the target(s).

    Returns:
        The constructed object, or a ``dict`` of recursively processed values.

    Raises:
        ValueError: If ``'__target__'`` has no module part (e.g. ``"Foo"``).
    """
    # DictConfig is a MutableMapping, so this check covers it and plain dicts.
    from collections.abc import Mapping

    if '__target__' not in config:
        return {
            key: instantiate(value, **kwargs) if isinstance(value, Mapping) else value
            for key, value in config.items()
        }

    target = config['__target__']
    module_path, _, attr_name = target.rpartition('.')
    if not module_path:
        raise ValueError(
            f"'__target__' must be a dotted 'module.attr' path, got {target!r}"
        )
    module = importlib.import_module(module_path)
    cls = getattr(module, attr_name)

    params = {
        key: instantiate(value)
        if isinstance(value, Mapping) and '__target__' in value
        else value
        for key, value in config.items()
        if key != '__target__'
    }
    params.update(kwargs)
    return cls(**params)