| import io |
| import os |
|
|
| import torch |
| import torch.distributed as dist |
|
|
|
|
| def _load_checkpoint_for_ema(model_ema, checkpoint): |
| """ |
| Workaround for ModelEma._load_checkpoint to accept an already-loaded object |
| """ |
| mem_file = io.BytesIO() |
| torch.save(checkpoint, mem_file) |
| mem_file.seek(0) |
| model_ema._load_checkpoint(mem_file) |
|
|
|
|
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    original_print = __builtin__.print

    def print(*args, **kwargs):
        # Callers may pass force=True to emit output even on non-master ranks.
        force = kwargs.pop('force', False)
        if force or is_master:
            original_print(*args, **kwargs)

    # Replace the builtin so every subsequent print() call in the process
    # goes through the rank-aware wrapper.
    __builtin__.print = print
|
|
|
|
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and the
    default process group has been initialized."""
    return dist.is_available() and dist.is_initialized()
|
|
|
|
def get_world_size():
    """Number of processes in the default group; 1 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
|
|
|
|
def get_rank():
    """Rank of this process in the default group; 0 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
|
|
|
|
def is_main_process():
    """Whether this process is rank 0 (always True when not distributed)."""
    rank = get_rank()
    return rank == 0
|
|
|
|
def save_on_master(*args, **kwargs):
    """torch.save(), but executed only by the main process to avoid
    concurrent writes from multiple ranks."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
|
|
|
def init_distributed_mode(args):
    # Configure distributed training from environment variables.
    #
    # Mutates ``args`` in place: sets ``rank``, ``gpu``, ``distributed``,
    # ``dist_backend`` (and ``world_size`` on the torchrun path), then
    # initializes the NCCL process group and silences printing on
    # non-master ranks. Expects ``args.dist_url`` to be set by the caller.
    #
    # torchrun / torch.distributed.launch export RANK, WORLD_SIZE and
    # LOCAL_RANK for each worker.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # SLURM exposes a global task id; map it onto the local GPUs.
        # NOTE(review): args.world_size is NOT set on this branch —
        # presumably the caller supplies it (e.g. via argparse default or
        # SLURM_NTASKS); verify before relying on SLURM launches.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        # No launcher detected: fall back to single-process mode.
        print('Not using distributed mode')
        args.distributed = False
        return


    args.distributed = True


    # Bind this process to its GPU before creating the process group.
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # Synchronize all ranks before continuing, then restrict printing to
    # the master process.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
|
|
|
|
def format_step(step):
    """
    Render a training step descriptor as a human-readable string.

    A string is returned unchanged; a sequence is interpreted as up to
    (epoch, training iteration, validation iteration) and each present
    element is formatted with its label.
    """
    if isinstance(step, str):
        return step
    labels = ("Training Epoch: {} ",
              "Training Iteration: {} ",
              "Validation Iteration: {} ")
    # zip truncates at the shorter input, so only the provided elements
    # (at most three) are rendered.
    return "".join(fmt.format(value) for fmt, value in zip(labels, step))
|
|