import io
import os

import psutil
import torch
import torch.distributed as dist

# Keep a reference to the builtin print before it is shadowed below.
_print = print


def get_world_size():
    return int(os.getenv('WORLD_SIZE', 1))


def get_rank():
    return int(os.getenv('RANK', 0))


def get_local_rank():
    return int(os.getenv('LOCAL_RANK', 0))


def is_dist():
    # Distributed mode requires an available, initialized backend and more
    # than one participating process.
    return dist.is_available() and dist.is_initialized() and get_world_size() > 1


def get_current_memory_gb():
    # Unique set size (USS) of the current process, in GiB; USS counts memory
    # private to this process, which makes it a good per-process figure.
    pid = os.getpid()
    p = psutil.Process(pid)
    info = p.memory_full_info()
    return info.uss / 1024. / 1024. / 1024.
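# Usage sketch (illustrative, not part of the original module): log per-rank
# host memory alongside training output, via the rank-aware print below.
#
#     print('host mem: {:.2f} GB'.format(get_current_memory_gb()), all=True)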


def print(*args, all=False, **kwargs):
    # Drop-in replacement for the builtin print that is rank-aware: by default
    # only local rank 0 of each node prints; pass all=True to print from
    # every rank.
    if not is_dist():
        _print(*args, **kwargs)
        return

    if not all and get_local_rank() != 0:
        return

    # Render the message into a buffer first so the rank prefix and the body
    # are emitted in one write, avoiding interleaved output across ranks.
    output = io.StringIO()
    kwargs['end'] = ''
    kwargs['file'] = output
    kwargs['flush'] = True
    _print(*args, **kwargs)

    s = output.getvalue()
    output.close()

    s = '[rank {}] {}'.format(dist.get_rank(), s)
    _print(s)
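# Usage sketch (illustrative): on single-process runs the override behaves
# like the builtin; under an initialized process group it prints once per
# node by default, or from every rank with a '[rank N]' prefix:
#
#     print('loading checkpoint')      # local rank 0 of each node only
#     print('shard ready', all=True)   # every rank, tagged '[rank N] ...'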


def reduce_mean(tensor, nprocs=None):
    # All-reduce a tensor (or plain Python number) with a SUM, then divide by
    # the number of processes, so every rank ends up with the global mean.
    if not is_dist():
        return tensor
    if not isinstance(tensor, torch.Tensor):
        # Plain numbers are wrapped in a tensor on the current CUDA device.
        device = torch.cuda.current_device()
        rt = torch.tensor(tensor, device=device)
    else:
        rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    nprocs = nprocs if nprocs else dist.get_world_size()
    rt = rt / nprocs
    if not isinstance(tensor, torch.Tensor):
        rt = rt.item()
    return rt
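# Usage sketch (illustrative): average a per-rank scalar such as a loss. Note
# the non-tensor path allocates on the current CUDA device, so it assumes a
# GPU backend (e.g. NCCL); pass a tensor already on the right device otherwise.
#
#     avg_loss = reduce_mean(loss.detach())   # same value on every rank
#     avg_acc = reduce_mean(batch_acc)        # plain float in, plain float out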


def reduce_sum(tensor):
    # All-reduce a tensor (or plain Python number) with a SUM, so every rank
    # ends up with the global total.
    if not is_dist():
        return tensor
    if not isinstance(tensor, torch.Tensor):
        device = torch.cuda.current_device()
        rt = torch.tensor(tensor, device=device)
    else:
        rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    if not isinstance(tensor, torch.Tensor):
        rt = rt.item()
    return rt
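# Usage sketch (illustrative): accumulate counts across ranks, e.g. to compute
# a global accuracy from per-rank correct/total counts.
#
#     correct = reduce_sum((pred == target).sum())
#     total = reduce_sum(float(target.numel()))
#     accuracy = correct.item() / total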


def barrier():
    # Synchronize all processes; a no-op outside of distributed runs.
    if not is_dist():
        return
    dist.barrier()
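# Minimal end-to-end sketch (an assumption, not part of the original module):
# launch with `torchrun --nproc_per_node=2 this_module.py`; torchrun sets the
# RANK / LOCAL_RANK / WORLD_SIZE variables the getters above read.
if __name__ == '__main__':
    if get_world_size() > 1:
        dist.init_process_group(backend='gloo')  # use 'nccl' for GPU tensors
    print('world size: {}'.format(get_world_size()), all=True)
    print('host mem: {:.2f} GB'.format(get_current_memory_gb()))
    barrier()
    if is_dist():
        dist.destroy_process_group()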