| import argparse |
| import os |
|
|
| import yaml |
|
|
# Names exported by ``from <module> import *``.
__all__ = ["get_config", "print_config"]
|
|
|
|
def get_config(args):
    """Build the run configuration from YAML files plus command-line overrides.

    The YAML file named by ``args.config`` is loaded, any keys it lacks are
    filled in from ``default.yml`` (both via ``_get_raw_config``), and the
    merged mapping is converted to a nested ``argparse.Namespace``.
    Selected command-line arguments then override the YAML values. A CLI
    value of ``0`` (or a falsy value for the flag-style options) leaves the
    YAML value untouched.

    Args:
        args: parsed command-line arguments. Attributes read: ``config``,
            ``consistent``, ``step_lr``, ``nsigma``, ``batch_size``,
            ``model_types``, ``ODI_steps``, ``fid_num_samples``,
            ``begin_ckpt``, ``end_ckpt``, ``adam``, ``adam_beta``,
            ``D_adam``, ``D_adam_beta`` and ``D_steps``.

    Returns:
        argparse.Namespace: the nested configuration object.

    Side effects:
        ``args.ODI_steps`` is rewritten in place: ``-1`` becomes ``None``.
    """
    config = dict2namespace(
        setdefault(_get_raw_config(args.config), _get_raw_config("default.yml"))
    )

    # Settings that fall back to a value from another section when the YAML
    # files do not set them explicitly.
    if not hasattr(config.sampling, "sigma_dist"):
        config.sampling.sigma_dist = config.model.sigma_dist
    if not hasattr(config.biggan, "resolution"):
        config.biggan.resolution = config.data.image_size

    # --- Sampling overrides ----------------------------------------------
    if args.consistent:
        config.sampling.consistent = args.consistent
        config.sampling.noise_first = False
    # Truthiness check so that both 0 and None keep the YAML value (a
    # `!= 0` test would let None overwrite it).
    if args.step_lr:
        config.sampling.step_lr = args.step_lr
    if args.nsigma != 0:
        config.sampling.nsigma = args.nsigma
    if args.batch_size != 0:
        config.sampling.batch_size = args.batch_size
        config.fast_fid.batch_size = args.batch_size

    # Cap the sampling batch size when exactly one model type is requested
    # and it is one of the listed types on tinyImages / CIFAR10.
    # NOTE(review): the reason for these particular caps is not visible here.
    single_model_type = (
        args.model_types[0]
        if args.model_types is not None and len(args.model_types) == 1
        else None
    )
    capped_datasets = ("tinyImages", "CIFAR10")
    if single_model_type in (0, 6, 23) and config.data.dataset in capped_datasets:
        config.sampling.batch_size = min(200, config.sampling.batch_size)
    elif single_model_type == 8 and config.data.dataset in capped_datasets:
        config.sampling.batch_size = min(800, config.sampling.batch_size)

    # -1 on the command line is the sentinel for "no value".
    if args.ODI_steps == -1:
        args.ODI_steps = None

    # --- FID / checkpoint overrides ----------------------------------------
    if args.fid_num_samples != 0:
        config.fast_fid.num_samples = args.fid_num_samples
    if args.begin_ckpt != 0:
        config.fast_fid.begin_ckpt = args.begin_ckpt
        config.sampling.ckpt_id = args.begin_ckpt
    if args.end_ckpt != 0:
        config.fast_fid.end_ckpt = args.end_ckpt

    # --- Optimizer / adversarial overrides ---------------------------------
    if args.adam:
        config.optim.beta1 = args.adam_beta[0]
        config.optim.beta2 = args.adam_beta[1]
    if args.D_adam:
        config.optim.adv_beta1 = args.D_adam_beta[0]
        config.optim.adv_beta2 = args.D_adam_beta[1]
    if args.D_steps != 0:
        config.adversarial.D_steps = args.D_steps

    return config
|
|
|
|
def _get_raw_config(name):
    """Load a YAML config file and return its parsed contents.

    Args:
        name: file name or path. A relative path is resolved against the
            directory containing this module, not the current working
            directory. An absolute path is used as-is (``os.path.join``
            discards ``here`` in that case).

    Returns:
        Whatever the YAML loader produces for the document, typically a
        ``dict`` for a mapping document or ``None`` for an empty file.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    # NOTE(review): FullLoader can construct some Python objects from tags.
    # That is acceptable for trusted local config files, but switch to
    # yaml.safe_load if config paths could ever come from untrusted input.
    with open(os.path.join(here, name), "r", encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.FullLoader)
|
|
|
|
def setdefault(config, default):
    """Recursively fill keys missing from ``config`` with values from ``default``.

    ``config`` is modified in place and also returned. Merge rules:

    * If both ``config`` and ``default`` hold a dict under a key, the two
      dicts are merged recursively by the same rules.
    * If ``config`` holds ``None`` where ``default`` holds a dict (an empty
      YAML section), the default dict is used instead of crashing.
    * Otherwise an existing value in ``config`` wins, and a missing key takes
      the default value.

    Args:
        config: the user mapping, or ``None`` for an empty YAML file (treated
            as an empty dict).
        default: the mapping of default values, or ``None`` (no defaults).

    Returns:
        The merged ``config`` mapping.
    """
    if config is None:
        config = {}
    for key, default_value in (default or {}).items():
        current = config.get(key)
        if isinstance(default_value, dict) and isinstance(current, dict):
            setdefault(current, default_value)
        elif isinstance(default_value, dict) and current is None:
            # Missing key or empty section: take the whole default section.
            config[key] = default_value
        else:
            config.setdefault(key, default_value)
    return config
|
|
|
|
def dict2namespace(config):
    """Recursively convert a (nested) dict into ``argparse.Namespace`` objects.

    Every key becomes an attribute of the returned namespace. Values that
    are themselves dicts are converted in turn; every other value (lists
    included) is stored unchanged.
    """
    ns = argparse.Namespace()
    for attr_name, attr_value in config.items():
        setattr(
            ns,
            attr_name,
            dict2namespace(attr_value) if isinstance(attr_value, dict) else attr_value,
        )
    return ns
|
|
|
|
def print_config(config):
    """Print ``config`` as block-style YAML framed by '>' and '<' rule lines."""
    rule_width = 80
    print(">" * rule_width)
    print(yaml.dump(config, default_flow_style=False))
    print("<" * rule_width)
|
|
|
|