| import argparse
|
| import random
|
| import torch
|
| import yaml
|
| from collections import OrderedDict
|
| from os import path as osp
|
|
|
| from basicsr.utils import set_random_seed
|
| from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
|
|
|
|
|
def ordered_yaml():
    """Build a yaml Loader/Dumper pair that round-trips ``OrderedDict``.

    The C-accelerated ``CLoader``/``CDumper`` classes are used when PyYAML
    provides them; otherwise the pure-Python ``Loader``/``Dumper`` are used.
    A representer for ``OrderedDict`` is registered on the Dumper and a
    constructor for the default mapping tag is registered on the Loader, so
    mappings are loaded as ``OrderedDict``.

    Returns:
        tuple: ``(Loader, Dumper)`` classes with the hooks registered.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    def represent_ordered(dumper, data):
        # Emit an OrderedDict as a plain yaml mapping, keeping item order.
        return dumper.represent_dict(data.items())

    def construct_ordered(loader, node):
        # Build every yaml mapping as an OrderedDict.
        return OrderedDict(loader.construct_pairs(node))

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    Dumper.add_representer(OrderedDict, represent_ordered)
    Loader.add_constructor(mapping_tag, construct_ordered)
    return Loader, Dumper
|
|
|
|
|
def dict2str(opt, indent_level=1):
    """Render a (possibly nested) option dict as an indented string.

    Each nesting level is indented by two more spaces. A nested dict is
    rendered as ``key:[`` followed by its own block and a closing ``]``.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing. Always starts with a newline.
    """
    indent = ' ' * (indent_level * 2)
    parts = ['\n']
    for k, v in opt.items():
        # Keys are converted with str() explicitly: yaml keys may be ints or
        # bools, and concatenating those to a str would raise TypeError.
        if isinstance(v, dict):
            parts.append(f'{indent}{k!s}:[')
            parts.append(dict2str(v, indent_level + 1))
            parts.append(f'{indent}]\n')
        else:
            parts.append(f'{indent}{k!s}: {v!s}\n')
    return ''.join(parts)
|
|
|
|
|
| def _postprocess_yml_value(value):
|
|
|
| if value == '~' or value.lower() == 'none':
|
| return None
|
|
|
| if value.lower() == 'true':
|
| return True
|
| elif value.lower() == 'false':
|
| return False
|
|
|
| if value.startswith('!!float'):
|
| return float(value.replace('!!float', ''))
|
|
|
| if value.isdigit():
|
| return int(value)
|
| elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
|
| return float(value)
|
|
|
| if value.startswith('['):
|
| return eval(value)
|
|
|
| return value
|
|
|
|
|
def _force_update_option(opt, entry):
    """Apply one ``key1:key2=value`` override from ``--force_yml`` to ``opt``.

    Args:
        opt (dict): Option dict, modified in place.
        entry (str): Override text. Keys are separated by ``:``; the text
            after the first ``=`` is the value.

    Raises:
        ValueError: If ``entry`` contains no ``=``.
        KeyError: If an intermediate key does not exist in ``opt``.
    """
    # Split on the first '=' only, so the value itself may contain '='.
    keys, sep, value = entry.partition('=')
    if not sep:
        raise ValueError(f"Invalid --force_yml entry {entry!r}: expected 'key1:key2=value'.")
    value = _postprocess_yml_value(value.strip())

    # Walk the nested dicts directly instead of building and exec()-ing a
    # statement string: same result, but keys are never run as code.
    *parents, leaf = keys.strip().split(':')
    target = opt
    for key in parents:
        target = target[key]
    target[leaf] = value


def parse_options(root_path, is_train=True):
    """Parse command-line arguments and the option YAML file.

    Reads the YAML file given by ``-opt``, sets up distributed mode according
    to ``--launcher``, seeds the random generators, applies ``--force_yml``
    overrides, and fills in derived dataset and path options.

    Args:
        root_path (str): Root under which ``experiments/<name>`` (training)
            or ``results/<name>`` (testing) paths are built.
        is_train (bool): Whether the options are for training. Default: True.

    Returns:
        tuple: ``(opt, args)`` — the option dict and the argparse namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
    args = parser.parse_args()

    # Parse yml to dict.
    # NOTE(security): ordered_yaml() returns PyYAML's Loader/CLoader, which
    # can construct arbitrary Python objects; only load trusted option files.
    with open(args.opt, mode='r') as f:
        opt = yaml.load(f, Loader=ordered_yaml()[0])

    # Distributed settings.
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
    # Queried in both modes: the rank is needed for seeding below.
    opt['rank'], opt['world_size'] = get_dist_info()

    # Random seed: draw one if the file gives none, and store it in opt.
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    # Offset by rank so each process gets a different seed.
    set_random_seed(seed + opt['rank'])

    # Force to update yml options from the command line.
    if args.force_yml is not None:
        for entry in args.force_yml:
            _force_update_option(opt, entry)

    opt['auto_resume'] = args.auto_resume
    opt['is_train'] = is_train

    # Debug setting: prefix the experiment name once.
    if args.debug and not opt['name'].startswith('debug'):
        opt['name'] = 'debug_' + opt['name']

    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # Datasets.
    for phase, dataset in opt['datasets'].items():
        # Keep only the text before the first '_' (e.g. val_1 -> val).
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # Paths: expand '~' in resume-state and pretrained-network entries.
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')

        # Change some options for debug mode.
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt, args
|
|
|
|
|
@master_only
def copy_opt_file(opt_file, experiments_root):
    """Copy the option file into the experiment folder and prepend a header.

    The copy keeps the original base name. A comment header holding the
    current time and the command line (``sys.argv`` joined by spaces) is
    written in front of the copied content.

    Args:
        opt_file (str): Path of the option file to copy.
        experiments_root (str): Directory that receives the copy.
    """
    import sys
    import time
    from shutil import copyfile

    target_path = osp.join(experiments_root, osp.basename(opt_file))
    copyfile(opt_file, target_path)

    command_line = ' '.join(sys.argv)
    header = f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {command_line}\n\n'

    with open(target_path, 'r+') as copied:
        original_lines = copied.readlines()
        # Rewind and rewrite from the start: header first, then the content.
        copied.seek(0)
        copied.writelines([header, *original_lines])
|
|
|