| from typing import List, Optional, Union |
|
|
| import torch |
| import torch.distributed |
| from torchpack import distributed |
|
|
| from utils.misc import list_mean, list_sum |
|
|
| __all__ = ["ddp_reduce_tensor", "DistributedMetric"] |
|
|
|
|
def ddp_reduce_tensor(
    tensor: torch.Tensor, reduce="mean"
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Gather ``tensor`` from every process and reduce the result.

    Args:
        tensor: value local to this process; gathered from all workers.
        reduce: one of ``"mean"``, ``"sum"``, ``"cat"`` (concatenate along
            dim 0), ``"root"`` (rank-0 copy only). Any other value returns
            the raw list of per-rank tensors.

    Returns:
        A single reduced tensor, or the list of gathered tensors when
        ``reduce`` names no known reduction.
    """
    world_size = distributed.size()
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    # all_gather requires contiguous input; blocking so results are ready.
    torch.distributed.all_gather(gathered, tensor.contiguous(), async_op=False)
    if reduce == "mean":
        return list_mean(gathered)
    if reduce == "sum":
        return list_sum(gathered)
    if reduce == "cat":
        return torch.cat(gathered, dim=0)
    if reduce == "root":
        return gathered[0]
    return gathered
|
|
|
|
class DistributedMetric(object):
    """Running average of a metric synchronized across distributed workers.

    Each :meth:`update` all-reduces both the metric value and the sample
    count over every process, so :attr:`avg` is the global average rather
    than a per-worker one.
    """

    def __init__(self, name: Optional[str] = None, backend="ddp"):
        # Optional human-readable label (e.g. for logging).
        self.name = name
        # Running totals. Start as Python ints; after the first update()
        # they become tensors (int += tensor promotes cleanly).
        self.sum = 0
        self.count = 0
        # Only "ddp" (torch.distributed all_gather) is implemented.
        self.backend = backend

    def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
        """Accumulate ``val`` weighted by ``delta_n`` samples, across all workers.

        Args:
            val: metric value for this batch; scalar or tensor.
            delta_n: number of samples this value represents.

        Raises:
            NotImplementedError: if ``backend`` is not ``"ddp"``.
        """
        if isinstance(val, (int, float)):
            # NOTE(review): assumes CUDA is available, matching the original code.
            val = torch.Tensor(1).fill_(val).cuda()
        # Scale out-of-place: the original `val *= delta_n` ran in-place on
        # tensor inputs and silently mutated the caller's tensor.
        val = val.detach() * delta_n
        if self.backend == "ddp":
            self.count += ddp_reduce_tensor(
                torch.Tensor(1).fill_(delta_n).cuda(), reduce="sum"
            )
            self.sum += ddp_reduce_tensor(val, reduce="sum")
        else:
            raise NotImplementedError

    @property
    def avg(self):
        """Global average so far; a tensor filled with -1 before any update."""
        if self.count == 0:
            # Sentinel preserved from the original API: -1 means "no data yet".
            return torch.Tensor(1).fill_(-1)
        else:
            return self.sum / self.count
|