| |
| |
| """ |
| Misc functions, including distributed helpers. |
| |
| Mostly copy-paste from torchvision references. |
| """ |
| from typing import List, Optional |
|
|
| import torch |
| import torch.distributed as dist |
| import torchvision |
| from torch import Tensor |
|
|
|
|
| def _max_by_axis(the_list): |
| |
| maxes = the_list[0] |
| for sublist in the_list[1:]: |
| for index, item in enumerate(sublist): |
| maxes[index] = max(maxes[index], item) |
| return maxes |
|
|
|
|
class NestedTensor(object):
    """Container pairing a (batched) tensor with an optional mask.

    ``nested_tensor_from_tensor_list`` builds one from a padded batch plus a
    boolean mask. No validation of the two fields is done here.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        # The data tensor(s); stored as given.
        self.tensors = tensors
        # Optional companion mask; ``None`` means "no mask".
        self.mask = mask

    def to(self, device):
        """Return a new ``NestedTensor`` with its data (and mask, if any) on *device*.

        The original object is not modified. A ``None`` mask stays ``None``.
        """
        moved_tensors = self.tensors.to(device)
        # Bind to a local so the None-check narrows the Optional type.
        current_mask = self.mask
        moved_mask: Optional[Tensor] = None
        if current_mask is not None:
            moved_mask = current_mask.to(device)
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        pair = (self.tensors, self.mask)
        return pair

    def __repr__(self):
        """Represent the object by the string form of its tensors only."""
        return str(self.tensors)
|
|
|
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad a list of 3-D tensors to a common size and batch them.

    Only the first tensor's ``ndim`` is checked. Dims 1 and 2 of each
    tensor are treated as the spatial (height, width) extent for the mask.
    The output batch uses the dtype and device of the first tensor.

    While tracing (``torchvision._is_tracing()``), the work is delegated to
    ``_onnx_nested_tensor_from_tensor_list``.

    Returns:
        ``NestedTensor(batch, mask)``:
        - ``batch`` has shape ``[len(tensor_list)] + per-dim maxima``.
        - ``mask`` is ``(N, H_max, W_max)`` bool, False where an input pixel
          was copied and True on padding.

    Raises:
        ValueError: if the first tensor is not 3-D.
        IndexError: if ``tensor_list`` is empty.
    """
    if tensor_list[0].ndim != 3:
        raise ValueError("not supported")

    if torchvision._is_tracing():
        # The slice-assignment path below is replaced by an F.pad-based
        # implementation when tracing.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    n_images, _channels, max_h, max_w = batch_shape

    reference = tensor_list[0]
    target_device = reference.device
    padded = torch.zeros(batch_shape, dtype=reference.dtype, device=target_device)
    # Start fully "padded" and clear the region each image actually covers.
    mask = torch.ones((n_images, max_h, max_w), dtype=torch.bool, device=target_device)

    for img, canvas, pad_mask in zip(tensor_list, padded, mask):
        c_len = img.shape[0]
        h_len = img.shape[1]
        w_len = img.shape[2]
        canvas[:c_len, :h_len, :w_len].copy_(img)
        pad_mask[:h_len, :w_len] = False

    return NestedTensor(padded, mask)
|
|
|
|
| |
| |
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """Pad and stack tensors using out-of-place ops, for the tracing path.

    This is the variant ``nested_tensor_from_tensor_list`` dispatches to when
    ``torchvision._is_tracing()`` is true. Instead of slice-assigning into a
    preallocated buffer, it pads each tensor with ``torch.nn.functional.pad``
    and stacks the results. ``@torch.jit.unused`` excludes it from TorchScript
    compilation.

    NOTE(review): ``torch.stack`` below receives ``img.shape[i]``. In eager
    mode those are plain Python ints, which ``torch.stack`` does not accept.
    This function therefore presumably only works under tracing, where shape
    entries are recorded as tensors. Confirm before calling it outside a trace.

    Args:
        tensor_list: tensors with the same number of dims. The pad calls index
            ``padding[0]``..``padding[2]`` and the mask is built from
            ``img[0]``, so 3-D inputs (presumably C, H, W) are assumed.

    Returns:
        NestedTensor whose ``tensors`` is the stacked, padded batch and whose
        ``mask`` is a stacked bool tensor: False over each input's original
        region of its last two dims, True over the padding.
    """
    max_size = []
    for i in range(tensor_list[0].dim()):
        # Maximum size of dim i across the list. The float32 cast before
        # torch.max and int64 cast after are presumably for export-friendliness
        # — TODO confirm why the round-trip is needed.
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Amount of padding needed in each dim to reach max_size.
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad lists (before, after) pairs starting from the LAST dim, so this
        # pads only at the end of dims 2, 1 and 0 (default constant mode).
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # Zeros over the original extent of the last two dims; the padded area
        # is filled with 1, so after .to(bool) True marks padding.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
|
|
|
|
def is_dist_avail_and_initialized():
    """Report whether ``torch.distributed`` can be used right now.

    True only if the distributed package is available in this build AND a
    default process group has been initialized. ``is_initialized`` is not
    consulted when the package is unavailable.
    """
    return dist.is_available() and dist.is_initialized()
|
|