"""Lightweight wrappers around torch.distributed: initialization, rank/device
queries, and common collectives, with a graceful single-GPU / CPU fallback."""
import os
import sys
from typing import List, Union

import torch
import torch.distributed as tdist
import torch.multiprocessing as mp


__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__initialized = False


def initialized():
    return __initialized


def initialize(backend='nccl'):
    """Initialize torch.distributed, falling back to CPU or a single GPU when
    CUDA or the RANK environment variable (set by launchers such as torchrun)
    is unavailable."""
    global __device
    if not torch.cuda.is_available():
        print('[dist initialize] CUDA is not available, using CPU instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        __device = torch.empty(1).cuda().device
        print('[dist initialize] RANK is not set, using a single GPU instead', file=sys.stderr)
        return

    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    # Pin each process to the GPU matching its local rank on this node.
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)

    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
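

# A minimal launch sketch (illustrative, not part of this module). `initialize()`
# relies on the RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT environment
# variables that launchers such as torchrun provide:
#
#   torchrun --nproc_per_node=8 train.py           # hypothetical entry point
#
# and inside train.py (assuming this file is importable as `dist`):
#
#   import dist
#   dist.initialize(backend='nccl')
#   print(dist.get_rank(), dist.get_world_size(), dist.get_device())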


def get_rank():
    return __rank


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


def barrier():
    if __initialized:
        tdist.barrier()


def parallelize(net, syncbn=False):
    # Convert BatchNorm to SyncBatchNorm if requested, then wrap the CUDA model in DDP.
    if syncbn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(
        net, device_ids=[get_local_rank()],
        find_unused_parameters=False, broadcast_buffers=False,
    )
    return net
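

# Usage sketch (illustrative; `MyModel` and `batch` are stand-ins, not defined here):
#
#   model = parallelize(MyModel().train(), syncbn=True)
#   out = model(batch.to(get_device()))
#
# syncbn=True is mainly worthwhile when the per-GPU batch is small enough that
# BatchNorm statistics become noisy; it adds an extra cross-GPU sync per BN layer.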


def allreduce(t: torch.Tensor) -> None:
    # In-place sum-reduction of `t` across all ranks. CPU tensors are round-tripped
    # through CUDA because the NCCL backend only operates on GPU tensors.
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.all_reduce(cu)
            t.copy_(cu.cpu())
        else:
            tdist.all_reduce(t)


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    # Gather a copy of `t` from every rank; optionally concatenate them along dim 0.
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
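

# Typical use (illustrative; `local_logits` is hypothetical, and every rank is
# assumed to contribute the same batch size):
#
#   local_logits = model(batch)                      # [B, C] on this rank
#   all_logits = allgather(local_logits, cat=True)   # [world_size * B, C] on every rank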


def broadcast(t: torch.Tensor, src_rank) -> None:
    # In-place: overwrite `t` on every rank with the value held by rank `src_rank`.
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)
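

# Sketch of broadcast in practice (illustrative): keep every rank on the seed
# chosen by the master process.
#
#   seed = torch.tensor([42 if is_master() else 0])
#   broadcast(seed, src_rank=0)          # seed now holds 42 on every rank
#   torch.manual_seed(int(seed.item()))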