| |
| import argparse |
| import os |
| import random |
| import time |
| import logging |
| import numpy as np |
| from base import config |
|
|
|
|
def get_parser():
    """Parse the command line and return the resulting configuration.

    Despite its name, this function does not return the ``ArgumentParser``.
    It parses the process arguments, loads the file named by ``--config``
    through ``config.load_cfg_from_cfg_file`` and, when the trailing
    ``opts`` value is not ``None``, merges it into the loaded configuration
    with ``config.merge_cfg_from_list``.

    Returns:
        Whatever the ``config`` helpers return for the loaded (and possibly
        merged) configuration.
    """
    cli = argparse.ArgumentParser(description=' ')
    cli.add_argument('--config', type=str, default='**.yaml', help='config file')
    # Everything after the recognised options is captured verbatim in ``opts``.
    cli.add_argument('opts', help=' ', default=None, nargs=argparse.REMAINDER)
    parsed = cli.parse_args()
    assert parsed.config is not None

    loaded = config.load_cfg_from_cfg_file(parsed.config)
    if parsed.opts is None:
        return loaded
    return config.merge_cfg_from_list(loaded, parsed.opts)
|
|
|
|
def get_logger():
    """Return the shared ``"main-logger"`` logger set to INFO level.

    ``logging.getLogger`` hands back the same logger object for a given
    name, so this function may be called any number of times. The stream
    handler is attached only when the logger has no handlers yet; attaching
    one on every call would print each record once per call.

    Returns:
        logging.Logger: the configured ``"main-logger"`` instance.
    """
    logger_name = "main-logger"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    # Attach the handler once; repeated calls must not stack duplicates.
    if not logger.handlers:
        handler = logging.StreamHandler()
        fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d]=>%(message)s"
        handler.setFormatter(logging.Formatter(fmt))
        logger.addHandler(handler)
    return logger
|
|
|
|
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes set by ``reset`` and maintained by ``update``:
        val: the most recent value passed to ``update``.
        sum: running total of ``val * n`` over all updates.
        count: running total of ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new value.
            n: weight applied to ``val`` (number of samples it represents).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty first batch with n=0) would
        # otherwise raise ZeroDivisionError; report 0 in that case.
        self.avg = self.sum / self.count if self.count else 0
|
|
|
|
def check_mkdir(dir_name):
    """Create the single directory ``dir_name`` unless the path already exists.

    The creation is attempted directly instead of testing
    ``os.path.exists`` first: with several processes starting at once,
    another process can create the directory between the test and the
    ``os.mkdir`` call, which would raise ``FileExistsError``.

    Args:
        dir_name: path of the directory to create; its parent must exist.

    Raises:
        OSError: any ``os.mkdir`` failure other than the path already
            existing (for example a missing parent directory).
    """
    try:
        os.mkdir(dir_name)
    except FileExistsError:
        # Already present (created earlier or by a concurrent process).
        pass
|
|
|
|
def check_makedirs(dir_name):
    """Create ``dir_name`` and any missing parents unless the path exists.

    ``exist_ok=True`` makes the call safe when several processes try to
    create the same tree at once; a separate ``os.path.exists`` test
    followed by ``os.makedirs`` can fail with ``FileExistsError`` if
    another process wins the race.

    Args:
        dir_name: path of the directory tree to create.

    Raises:
        OSError: any ``os.makedirs`` failure other than the target path
            already existing.
    """
    try:
        os.makedirs(dir_name, exist_ok=True)
    except FileExistsError:
        # The path exists but is not a directory. The exists-check version
        # of this helper treated that as a no-op, so keep doing so.
        pass
|
|
|
|
def main_process(args):
    """Tell whether the current process should act as the main one.

    Returns ``True`` when ``args.multiprocessing_distributed`` is falsy.
    Otherwise returns whether ``args.rank`` is a multiple of
    ``args.ngpus_per_node``.
    """
    if not args.multiprocessing_distributed:
        return True
    return args.rank % args.ngpus_per_node == 0
|
|