| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.backends.cudnn as cudnn |
| |
| |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
|
|
| import os |
| import os.path as osp |
| import sys |
| import numpy as np |
| import pprint |
| import timeit |
| import time |
| import copy |
| import matplotlib.pyplot as plt |
|
|
| from .cfg_holder import cfg_unique_holder as cfguh |
|
|
| from .data_factory import \ |
| get_dataset, collate, \ |
| get_loader, \ |
| get_transform, \ |
| get_estimator, \ |
| get_formatter, \ |
| get_sampler |
|
|
| from .model_zoo import \ |
| get_model, get_optimizer, get_scheduler |
|
|
| from .log_service import print_log, distributed_log_manager |
|
|
| from .evaluator import get_evaluator |
| from . import sync |
|
|
class train_stage(object):
    """
    Template for a train stage (it may equally be used for testing or any
    other loop-driven stage).

    The stage is called with keyword arguments that usually include one
    dataloader (``trainloader``), one model (``net``), one ``optimizer`` and
    one ``scheduler``, but it is not limited to these. Subclasses implement
    :meth:`main`, which performs a single step and returns a dict that is
    merged back into the keyword arguments; that dict must provide a
    ``'log_info'`` entry, which is forwarded to the log manager.
    """
    def __init__(self):
        # Optional callable run periodically for evaluation. It is called as
        # ``nested_eval_stage(eval_cnt=..., **paras)`` and its returned dict
        # may carry an ``'eval_rv'`` entry.
        self.nested_eval_stage = None
        # Best evaluation result kept so far (None until the first eval).
        self.rv_keep = None

    def is_better(self, x):
        """
        Return True if eval result ``x`` should replace the kept best value.

        Higher is better. While nothing is kept yet, any ``x`` (including
        None) counts as better. Once a value is kept, a None result never
        counts as better; this avoids ``None > value`` raising TypeError.
        """
        if self.rv_keep is None:
            return True
        if x is None:
            return False
        return x > self.rv_keep

    def set_model(self, net, mode):
        """
        Switch ``net`` into ``mode`` ('train' or 'eval') and return it.

        Raises:
            ValueError: if ``mode`` is neither 'train' nor 'eval'.
        """
        if mode == 'train':
            return net.train()
        elif mode == 'eval':
            return net.eval()
        else:
            raise ValueError(
                "mode must be 'train' or 'eval', got {!r}".format(mode))

    @staticmethod
    def _crossed(prev, nxt, every):
        """Return True when the counter moving from ``prev`` to ``nxt`` passes a multiple of ``every``."""
        return (prev // every) != (nxt // every)

    @staticmethod
    def _is_grad_update_step(itern, gradacc_every):
        """Return True on the last iteration of each ``gradacc_every``-long accumulation window."""
        return itern % gradacc_every == (gradacc_every - 1)

    def _run_nested_eval(self, net, eval_cnt, step, is_rank0, optimizer, logm, paras):
        """
        Run ``self.nested_eval_stage`` once and track the best result.

        The steps are:
        1. Switch ``net`` to eval mode.
        2. Call the nested stage with ``eval_cnt`` and ``paras``.
        3. Log a non-None ``'eval_rv'`` to tensorboard via ``logm``.
        4. If the result is better than the kept one (see :meth:`is_better`),
           store it and, when ``is_rank0``, save a ``best`` checkpoint
           tagged with ``step``.

        Returns:
            ``net`` switched back to train mode.
        """
        net = self.set_model(net, 'eval')
        rv = self.nested_eval_stage(eval_cnt=eval_cnt, **paras)
        rv = rv.get('eval_rv', None)
        if rv is not None:
            logm.tensorboard_log(eval_cnt, rv, mode='eval')
        if self.is_better(rv):
            self.rv_keep = rv
            if is_rank0:
                self.save(net, is_best=True, step=step, optimizer=optimizer)
        return self.set_model(net, 'train')

    def __call__(self,
                 **paras):
        """
        Run the training loop until ``cfg.train.step_num`` is reached.

        Settings read from ``cfg.train``:
            step_type     -- 'epoch', 'iter' (default) or 'sample'. It is the
                             unit of step_num, log_every, eval_every and
                             ckpt_every.
            step_num      -- number of step_type units to run (required).
            gradacc_every -- gradient-accumulation window in iterations
                             (default 1).
            log_every, eval_every, ckpt_every -- periods in step_type units;
                             None disables the action.
            eval_start    -- periodic evals are suppressed until the
                             iteration count (sample count in 'sample' mode)
                             reaches this value (default 0).
            log_dir, save_init_model, skip_partial_batch, batch_size_per_gpu.

        Keyword Args (``paras``):
            trainloader, optimizer, scheduler, net -- required.
            resume_step -- optional dict with 'type', 'epochn', 'itern' and
                'samplen' keys, used to resume the counters.
            All remaining entries are forwarded to :meth:`main` and to
            ``self.nested_eval_stage``. The dict returned by :meth:`main`
            is merged back into them.

        Returns:
            An empty dict.

        Raises:
            ValueError: if ``cfg.train.step_num`` is not set.
        """
        cfg = cfguh().cfg
        cfgt = cfg.train
        logm = distributed_log_manager()
        epochn, itern, samplen = 0, 0, 0

        step_type = cfgt.get('step_type', 'iter')
        assert step_type in ['epoch', 'iter', 'sample'], \
            'Step type must be in [epoch, iter, sample]'

        step_num = cfgt.get('step_num', None)
        gradacc_every = cfgt.get('gradacc_every', 1)
        log_every = cfgt.get('log_every', None)
        ckpt_every = cfgt.get('ckpt_every', None)
        eval_start = cfgt.get('eval_start', 0)
        eval_every = cfgt.get('eval_every', None)
        if step_num is None:
            # The stop conditions below compare counters with step_num; fail
            # early with a clear message instead of a TypeError mid-training.
            raise ValueError('cfg.train.step_num must be set')

        if paras.get('resume_step', None) is not None:
            resume_step = paras['resume_step']
            assert step_type == resume_step['type']
            epochn = resume_step['epochn']
            itern = resume_step['itern']
            samplen = resume_step['samplen']
            del paras['resume_step']

        trainloader = paras['trainloader']
        optimizer = paras['optimizer']
        scheduler = paras['scheduler']
        net = paras['net']

        GRANK, LRANK, NRANK = sync.get_rank('all')
        GWSIZE, LWSIZE, NODES = sync.get_world_size('all')

        weight_path = osp.join(cfgt.log_dir, 'weight')
        if GRANK == 0:
            os.makedirs(weight_path, exist_ok=True)
            if cfgt.save_init_model:
                self.save(net, is_init=True, step=0, optimizer=optimizer)

        epoch_time = timeit.default_timer()
        end_flag = False
        net.train()

        while True:
            if step_type == 'epoch':
                lr = scheduler[epochn] if scheduler is not None else None

            for batch in trainloader:
                # Batch size as seen by this rank, taken from the first element.
                if isinstance(batch[0], list):
                    bs = len(batch[0])
                else:
                    bs = batch[0].shape[0]
                if cfgt.skip_partial_batch and (bs != cfgt.batch_size_per_gpu):
                    continue

                itern_next = itern + 1
                # Each of the GWSIZE ranks consumes ``bs`` samples per step.
                samplen_next = samplen + bs * GWSIZE

                # Computed for every step_type so that 'epoch' and 'sample'
                # runs also pass a defined grad_update to main().
                grad_update = self._is_grad_update_step(itern, gradacc_every)
                if step_type == 'iter':
                    lr = scheduler[itern // gradacc_every] \
                        if scheduler is not None else None
                elif step_type == 'sample':
                    lr = scheduler[samplen] if scheduler is not None else None

                paras_new = self.main(
                    batch=batch,
                    lr=lr,
                    itern=itern,
                    epochn=epochn,
                    samplen=samplen,
                    isinit=False,
                    grad_update=grad_update,
                    **paras)
                paras.update(paras_new)
                logm.accumulate(bs, **paras['log_info'])

                # ---- logging ('iter' / 'sample' step types) ----
                display_flag = False
                if log_every is not None:
                    if step_type == 'iter':
                        display_flag = self._crossed(itern, itern_next, log_every)
                    elif step_type == 'sample':
                        display_flag = self._crossed(samplen, samplen_next, log_every)

                if display_flag:
                    tbstep = itern_next if step_type == 'iter' else samplen_next
                    console_info = logm.train_summary(
                        itern_next, epochn, samplen_next, lr, tbstep=tbstep)
                    logm.clear()
                    print_log(console_info)

                # ---- evaluation ('iter' / 'sample' step types, node rank 0) ----
                eval_flag = False
                if (self.nested_eval_stage is not None) \
                        and (eval_every is not None) and (NRANK == 0):
                    if step_type == 'iter':
                        eval_flag = self._crossed(itern, itern_next, eval_every) \
                            and (itern_next >= eval_start)
                        # Additionally evaluate once on iteration 0.
                        eval_flag = eval_flag or itern == 0
                    elif step_type == 'sample':
                        eval_flag = self._crossed(samplen, samplen_next, eval_every) \
                            and (samplen_next >= eval_start)
                        eval_flag = eval_flag or samplen == 0

                if eval_flag:
                    eval_cnt = itern_next if step_type == 'iter' else samplen_next
                    step = {'epochn': epochn, 'itern': itern_next,
                            'samplen': samplen_next, 'type': step_type, }
                    net = self._run_nested_eval(
                        net, eval_cnt, step, GRANK == 0, optimizer, logm, paras)

                # ---- checkpointing ('iter' / 'sample' step types, global rank 0) ----
                ckpt_flag = False
                if (GRANK == 0) and (ckpt_every is not None):
                    if step_type == 'iter':
                        ckpt_flag = self._crossed(itern, itern_next, ckpt_every)
                    elif step_type == 'sample':
                        ckpt_flag = self._crossed(samplen, samplen_next, ckpt_every)

                if ckpt_flag:
                    step = {'epochn': epochn, 'itern': itern_next,
                            'samplen': samplen_next, 'type': step_type, }
                    if step_type == 'iter':
                        print_log('Checkpoint... {}'.format(itern_next))
                        self.save(net, itern=itern_next, step=step, optimizer=optimizer)
                    else:
                        print_log('Checkpoint... {}'.format(samplen_next))
                        self.save(net, samplen=samplen_next, step=step, optimizer=optimizer)

                itern = itern_next
                samplen = samplen_next

                end_flag = (step_type == 'iter' and itern >= step_num) \
                    or (step_type == 'sample' and samplen >= step_num)
                if end_flag:
                    break

            epochn += 1
            print_log('Epoch {} time:{:.2f}s.'.format(
                epochn, timeit.default_timer() - epoch_time))
            epoch_time = timeit.default_timer()

            if end_flag:
                break
            if step_type != 'epoch':
                # 'iter' / 'sample' runs do all their bookkeeping inside the
                # batch loop; just move on to the next pass over the data.
                trainloader = self.trick_update_trainloader(trainloader)
                continue

            # ---- logging ('epoch' step type) ----
            if (log_every is not None) \
                    and ((epochn == 1) or (epochn % log_every == 0)):
                console_info = logm.train_summary(
                    itern, epochn, samplen, lr, tbstep=epochn)
                logm.clear()
                print_log(console_info)

            # ---- evaluation ('epoch' step type, node rank 0) ----
            # ``itern`` (which equals the last ``itern_next``) is used so that an
            # empty trainloader does not raise NameError.
            if (self.nested_eval_stage is not None) \
                    and (eval_every is not None) and (NRANK == 0):
                eval_flag = (epochn == 1) \
                    or ((epochn % eval_every == 0) and (itern >= eval_start))
                if eval_flag:
                    step = {'epochn': epochn, 'itern': itern,
                            'samplen': samplen, 'type': step_type, }
                    net = self._run_nested_eval(
                        net, epochn, step, GRANK == 0, optimizer, logm, paras)

            # ---- checkpointing ('epoch' step type, global rank 0) ----
            if (ckpt_every is not None) and (GRANK == 0) \
                    and (epochn % ckpt_every == 0):
                print_log('Checkpoint... {}'.format(itern))
                step = {'epochn': epochn, 'itern': itern,
                        'samplen': samplen, 'type': step_type, }
                self.save(net, epochn=epochn, step=step, optimizer=optimizer)

            if epochn >= step_num:
                break

            trainloader = self.trick_update_trainloader(trainloader)

        logm.tensorboard_close()
        return {}

    def main(self, **paras):
        """
        Perform one training step; must be implemented by subclasses.

        Called with ``batch``, ``lr``, ``itern``, ``epochn``, ``samplen``,
        ``isinit`` and ``grad_update``, plus the stage keyword arguments.
        It must return a dict, which is merged into those arguments; the
        dict has to contain ``'log_info'``.
        """
        raise NotImplementedError

    def trick_update_trainloader(self, trainloader):
        """Hook called after each pass over ``trainloader``; the default returns it unchanged."""
        return trainloader

    def save_model(self, net, path_noext, **paras):
        """
        Save the ``state_dict`` of ``net`` to ``path_noext + '.pth'``.

        DataParallel / DistributedDataParallel wrappers are unwrapped first.
        Extra keyword arguments (e.g. ``step``, ``optimizer``) are accepted
        but not used by this implementation.
        """
        path = path_noext + '.pth'
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net
        torch.save(netm.state_dict(), path)
        print_log('Saving model file {0}'.format(path))

    def save(self, net, itern=None, epochn=None, samplen=None,
             is_init=False, is_best=False, is_last=False, **paras):
        """
        Save ``net`` under ``<cfg.train.log_dir>/weight/``.

        The file stem is ``<cfg.env.experiment_id>_<cfg.model.symbol>_<tag>``.
        The tag is ``iter_<itern>``, ``samplen_<samplen>``, ``epoch_<epochn>``,
        ``init``, ``best`` or ``last``, depending on which selector is given,
        and ``default`` when none is. Extra keyword arguments are forwarded
        to :meth:`save_model`.

        Raises:
            ValueError: if more than one selector is given.
        """
        cfg = cfguh().cfg
        exid = cfg.env.experiment_id
        net_symbol = cfg.model.symbol

        n_selected = sum([
            itern is not None, samplen is not None, epochn is not None,
            is_init, is_best, is_last])
        if n_selected >= 2:
            raise ValueError(
                'At most one of itern/samplen/epochn/is_init/is_best/is_last '
                'may be given')

        if itern is not None:
            tag = 'iter_{}'.format(itern)
        elif samplen is not None:
            tag = 'samplen_{}'.format(samplen)
        elif epochn is not None:
            tag = 'epoch_{}'.format(epochn)
        elif is_init:
            tag = 'init'
        elif is_best:
            tag = 'best'
        elif is_last:
            tag = 'last'
        else:
            tag = 'default'

        path_noext = osp.join(
            cfg.train.log_dir, 'weight', '{}_{}_{}'.format(exid, net_symbol, tag))
        self.save_model(net, path_noext, **paras)
|
|
class eval_stage(object):
    """
    Template for an evaluation stage.

    Subclasses implement :meth:`main`, which runs ``net`` on one batch and
    returns the keyword arguments for ``evaluator.add_batch``. When
    ``cfg.eval.output_result`` is enabled, they also implement ``output_f``,
    which is called with the same keyword arguments, plus ``cnt`` when the
    caller passed ``eval_cnt`` and ``output_f`` accepts it.
    """
    def __init__(self):
        # Built lazily from cfg.eval.evaluator on the first call, then reused.
        self.evaluator = None

    def create_dir(self, path):
        """Create directory ``path`` from local rank 0, then wait on ``sync.nodewise_sync().barrier()``."""
        local_rank = sync.get_rank('local')
        if (not osp.isdir(path)) and (local_rank == 0):
            # exist_ok tolerates another process (e.g. local rank 0 of another
            # node on a shared filesystem) creating it in the meantime.
            os.makedirs(path, exist_ok=True)
        sync.nodewise_sync().barrier()

    def main(self, batch, net):
        """Run ``net`` on ``batch`` and return a dict for ``evaluator.add_batch``; implemented by subclasses."""
        raise NotImplementedError

    def _call_output_f(self, rv, paras):
        """
        Forward ``rv`` to ``self.output_f``.

        When ``paras`` contains ``'eval_cnt'``, it is first passed as ``cnt``.
        Only a TypeError (an ``output_f`` that does not accept ``cnt``)
        triggers a retry without it. Any other exception propagates instead
        of being silently retried.
        """
        if 'eval_cnt' not in paras:
            self.output_f(**rv)
            return
        try:
            self.output_f(**rv, cnt=paras['eval_cnt'])
        except TypeError:
            self.output_f(**rv)

    def __call__(self,
                 evalloader,
                 net,
                 **paras):
        """
        Evaluate ``net`` over ``evalloader``.

        Each batch result from :meth:`main` is fed to the evaluator and,
        if ``cfg.eval.output_result`` is set, also to ``output_f``. Progress
        is printed every ``cfg.eval.log_display`` batches. Local rank 0
        prints a summary and saves the evaluator state to
        ``cfg.eval.log_dir``; the evaluator data is then cleared.

        Returns:
            ``{'eval_rv': <value returned by evaluator.compute()>}``
        """
        cfgt = cfguh().cfg.eval
        local_rank = sync.get_rank('local')
        if self.evaluator is None:
            self.evaluator = get_evaluator()(cfgt.evaluator)
        evaluator = self.evaluator

        time_check = timeit.default_timer()

        for idx, batch in enumerate(evalloader):
            rv = self.main(batch, net)
            evaluator.add_batch(**rv)
            if cfgt.output_result:
                self._call_output_f(rv, paras)
            if idx % cfgt.log_display == cfgt.log_display - 1:
                print_log('processed.. {}, Time:{:.2f}s'.format(
                    idx + 1, timeit.default_timer() - time_check))
                time_check = timeit.default_timer()

        evaluator.set_sample_n(len(evalloader.dataset))
        eval_rv = evaluator.compute()
        if local_rank == 0:
            evaluator.one_line_summary()
            evaluator.save(cfgt.log_dir)
        evaluator.clear_data()
        return {
            'eval_rv': eval_rv
        }
|
|
class exec_container(object):
    """
    Base functor for all types of executions.

    One execution can have multiple stages, but they all share the same
    config, network and dataloaders. In most cases one exec_container is
    one training / evaluation / demo run. When DDP is in use, this functor
    is meant to be spawned once per local process: it receives
    ``local_rank`` as its first argument.
    """
    def __init__(self,
                 cfg,
                 **kwargs):
        self.cfg = cfg
        self.registered_stages = []
        self.node_rank = None
        self.local_rank = None
        self.global_rank = None
        self.nodes = None
        self.local_world_size = None
        self.global_world_size = None
        self.nodewise_sync_global_obj = sync.nodewise_sync_global()

    def register_stage(self, stage):
        """Append ``stage``; registered stages are run in order by :meth:`__call__`."""
        self.registered_stages.append(stage)

    @staticmethod
    def _compute_global_rank(local_rank, node_rank, local_world_size):
        """
        Return the global rank of process ``local_rank`` on node ``node_rank``.

        Ranks are laid out node-major: node ``k`` owns global ranks
        ``[k * local_world_size, (k + 1) * local_world_size)``. This matches
        ``global_world_size = nodes * local_world_size``.
        """
        return local_rank + node_rank * local_world_size

    def __call__(self,
                 local_rank,
                 **kwargs):
        """
        Entry point of one process: set up distributed state and run all stages.

        The steps are:
        1. Store ``cfg`` in the global config holder.
        2. Derive ranks from ``cfg.env`` (``node_rank``, ``nodes``,
           ``gpu_count``).
        3. Initialise the ``torch.distributed`` process group, the CUDA
           device and the node-wise sync object.
        4. When ``cfg.env.rnd_seed`` is an int, seed numpy and torch with
           ``rnd_seed + global_rank``.
        5. Build the dataloaders and the model, then run each registered
           stage, merging any dict it returns into the shared parameters.
        6. Save the last model on global rank 0 and destroy the process group.
        """
        cfg = self.cfg
        cfguh().save_cfg(cfg)

        self.node_rank = cfg.env.node_rank
        self.local_rank = local_rank
        self.nodes = cfg.env.nodes
        self.local_world_size = cfg.env.gpu_count

        # Offset by GPUs-per-node (not by the node count) so that ranks are
        # unique and lie in [0, global_world_size).
        self.global_rank = self._compute_global_rank(
            self.local_rank, self.node_rank, self.local_world_size)
        self.global_world_size = self.nodes * self.local_world_size

        dist.init_process_group(
            backend=cfg.env.dist_backend,
            init_method=cfg.env.dist_url,
            rank=self.global_rank,
            world_size=self.global_world_size,)
        torch.cuda.set_device(local_rank)
        sync.nodewise_sync().copy_global(self.nodewise_sync_global_obj).local_init()

        if isinstance(cfg.env.rnd_seed, int):
            np.random.seed(cfg.env.rnd_seed + self.global_rank)
            torch.manual_seed(cfg.env.rnd_seed + self.global_rank)

        time_start = timeit.default_timer()

        para = {'itern_total': 0,}
        dl_para = self.prepare_dataloader()
        assert isinstance(dl_para, dict)
        para.update(dl_para)

        md_para = self.prepare_model()
        assert isinstance(md_para, dict)
        para.update(md_para)

        for stage in self.registered_stages:
            stage_para = stage(**para)
            if stage_para is not None:
                para.update(stage_para)

        if self.global_rank == 0:
            self.save_last_model(**para)

        print_log(
            'Total {:.2f} seconds'.format(timeit.default_timer() - time_start))
        dist.destroy_process_group()

    def prepare_dataloader(self):
        """
        Prepare the dataloader from config.

        Subclasses override this. The result must be a dict; it is merged
        into the stage parameters.
        """
        return {
            'trainloader': None,
            'evalloader': None}

    def prepare_model(self):
        """
        Prepare the model from config.

        Subclasses override this. The result must be a dict; it is merged
        into the stage parameters.
        """
        return {'net': None}

    def save_last_model(self, **para):
        """Hook run on global rank 0 after all stages; the default does nothing."""
        return

    def destroy(self):
        """Destroy the global node-wise sync object created in ``__init__``."""
        self.nodewise_sync_global_obj.destroy()
|
|
class train(exec_container):
    """
    Execution container for training runs.

    It builds:
    - the train dataloader, plus an eval dataloader when ``cfg`` has an
      ``eval`` section whose dataset is not None;
    - the model, wrapped in DistributedDataParallel when ``cfg.env.cuda``
      is set;
    - the optimizer and the scheduler.

    After all stages finish, global rank 0 writes the final weights.
    """
    def prepare_dataloader(self):
        """
        Build the dataloaders described by ``cfg.train`` and ``cfg.eval``.

        Returns:
            dict with ``'trainloader'`` and ``'evalloader'``. The eval loader
            is None when there is no ``eval`` section or its dataset is None.
        """
        cfg = cfguh().cfg
        train_cfg = cfg.train

        train_set = get_dataset()(train_cfg.dataset)
        train_sampler = get_sampler()(
            dataset=train_set,
            cfg=train_cfg.dataset.get('sampler', 'default_train'))
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=train_cfg.batch_size_per_gpu,
            sampler=train_sampler,
            num_workers=train_cfg.dataset_num_workers_per_gpu,
            drop_last=False,
            pin_memory=train_cfg.dataset.get('pin_memory', False),
            collate_fn=collate(),
        )

        eval_loader = None
        if 'eval' in cfg:
            eval_cfg = cfg.eval
            eval_set = get_dataset()(eval_cfg.dataset)
            if eval_set is not None:
                eval_sampler = get_sampler()(
                    dataset=eval_set,
                    cfg=eval_cfg.dataset.get('sampler', 'default_eval'))
                eval_loader = torch.utils.data.DataLoader(
                    eval_set,
                    batch_size=eval_cfg.batch_size_per_gpu,
                    sampler=eval_sampler,
                    num_workers=eval_cfg.dataset_num_workers_per_gpu,
                    drop_last=False,
                    pin_memory=eval_cfg.dataset.get('pin_memory', False),
                    collate_fn=collate(),
                )

        return {'trainloader': train_loader, 'evalloader': eval_loader}

    def prepare_model(self):
        """
        Build the model, scheduler and optimizer from config.

        The model is moved to ``self.local_rank`` and wrapped in
        DistributedDataParallel only when ``cfg.env.cuda`` is set. It is put
        in train mode before the optimizer is created from it.
        """
        cfg = cfguh().cfg
        model = get_model()(cfg.model)
        if cfg.env.cuda:
            rank = self.local_rank
            model.to(rank)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], find_unused_parameters=True)
        model.train()
        sched = get_scheduler()(cfg.train.scheduler)
        optim = get_optimizer()(model, cfg.train.optimizer)
        return dict(net=model, optimizer=optim, scheduler=sched)

    def save_last_model(self, **para):
        """
        Save the final ``state_dict`` of ``para['net']``.

        The file is ``<log_dir>/<experiment_id>_<symbol>_last.pth``, and
        DataParallel / DDP wrappers are unwrapped first.
        """
        # NOTE(review): train_stage.save reads cfg.env.experiment_id, while this
        # reads cfg.train.experiment_id; confirm both are populated.
        cfgt = cfguh().cfg.train
        net = para['net']
        symbol = cfguh().cfg.model.symbol
        wrapped = isinstance(
            net, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
        module = net.module if wrapped else net
        path = osp.join(
            cfgt.log_dir, '{}_{}_last.pth'.format(cfgt.experiment_id, symbol))
        torch.save(module.state_dict(), path)
        print_log('Saving model file {0}'.format(path))
|
|
class eval(exec_container):
    """
    Execution container for evaluation-only runs.

    It builds an eval dataloader and a model in eval mode. There is no
    optimizer or scheduler, and the final weights are not saved.
    """
    # NOTE: the class name shadows the ``eval`` builtin inside this module;
    # it is kept for backward compatibility with existing callers.
    def prepare_dataloader(self):
        """
        Build the evaluation dataloader from ``cfg.eval``.

        Returns:
            dict with ``'trainloader'`` (always None) and ``'evalloader'``.
            The eval loader is None when no eval dataset is configured or
            the dataset factory returns None.
        """
        cfg = cfguh().cfg
        evalloader = None
        if cfg.eval.get('dataset', None) is not None:
            evalset = get_dataset()(cfg.eval.dataset)
            if evalset is not None:
                sampler = get_sampler()(
                    dataset=evalset,
                    cfg=cfg.eval.dataset.get('sampler', 'default_eval'))
                evalloader = torch.utils.data.DataLoader(
                    evalset,
                    batch_size=cfg.eval.batch_size_per_gpu,
                    sampler=sampler,
                    num_workers=cfg.eval.dataset_num_workers_per_gpu,
                    drop_last=False,
                    pin_memory=cfg.eval.dataset.get('pin_memory', False),
                    collate_fn=collate(), )
        # Always return a dict, even with no dataset: exec_container.__call__
        # asserts ``isinstance(dl_para, dict)`` on this result.
        return {
            'trainloader': None,
            'evalloader': evalloader,}

    def prepare_model(self):
        """
        Build the model from ``cfg.model`` and put it in eval mode.

        When ``cfg.env.cuda`` is set, the model is moved to
        ``self.local_rank`` and wrapped in DistributedDataParallel.
        """
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        if cfg.env.cuda:
            net.to(self.local_rank)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[self.local_rank],
                find_unused_parameters=True)
        net.eval()
        return {'net': net,}

    def save_last_model(self, **para):
        """Evaluation runs do not save weights."""
        return
|
|
| |
| |
| |
|
|
def torch_to_numpy(*argv):
    """
    Recursively convert torch tensors to numpy arrays.

    With one argument, that argument is converted. With several, they are
    first gathered into a list. Conversion rules:
    - tensors are moved to CPU, detached and returned as numpy arrays;
    - lists and tuples become lists of converted items;
    - dicts keep their keys and get converted values;
    - anything else is returned unchanged.
    """
    data = list(argv) if len(argv) > 1 else argv[0]

    if isinstance(data, torch.Tensor):
        return data.to('cpu').detach().numpy()
    if isinstance(data, (list, tuple)):
        return [torch_to_numpy(item) for item in data]
    if isinstance(data, dict):
        return {key: torch_to_numpy(value) for key, value in data.items()}
    return data
|
|
import importlib


def get_obj_from_str(string, reload=False):
    """
    Resolve a dotted path such as ``'package.module.attr'`` to the object.

    Everything before the last '.' is imported as a module and the final
    component is looked up on it with ``getattr``.

    Args:
        string: dotted path of the form ``'<module path>.<attribute>'``.
        reload: if True, the module is re-executed with ``importlib.reload``
            before the lookup.

    Returns:
        The named attribute of the imported module.
    """
    module_path, attr_name = string.rsplit(".", 1)
    module_obj = importlib.import_module(module_path)
    if reload:
        module_obj = importlib.reload(module_obj)
    return getattr(module_obj, attr_name)
|
|