'''
Util for collate function and related needs
'''
import os, sys, shutil
from typing import Optional, List
import torch
from torch import Tensor
# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from util.misc import NestedTensor
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def nested_tensor_from_tensor_list(tensor_list: List[Tensor], split=True):
    """Pad a list of (C, H, W) frame tensors to a common size as a NestedTensor.

    Modified from VisTR to handle video inputs: when ``split`` is True, each
    element of ``tensor_list`` is itself a sequence of frame tensors, and all
    frames are flattened into one flat list first (so its length becomes
    batch_size * num_frames).

    Returns a NestedTensor whose tensor is (B*F, C, H, W) and whose mask is
    (B*F, H, W); the mask is False on valid pixels and True on padding.

    Raises ValueError for inputs that are not 3-dimensional per frame.
    """
    if split:
        # Flatten per-video frame groups into one image-like list
        frames = []
        for video in tensor_list:
            frames.extend(video)
        tensor_list = frames
    if tensor_list[0].ndim != 3:  # expected (C, H, W) per frame
        raise ValueError('not supported')
    # Same padding scheme as DETR: pad every frame to the per-axis maximum
    max_size = _max_by_axis([list(frame.shape) for frame in tensor_list])
    num, chans, height, width = [len(tensor_list)] + max_size
    ref = tensor_list[0]
    padded = torch.zeros((num, chans, height, width), dtype=ref.dtype, device=ref.device)
    mask = torch.ones((num, height, width), dtype=torch.bool, device=ref.device)
    for src, dst, m in zip(tensor_list, padded, mask):
        c, h, w = src.shape
        dst[:c, :h, :w].copy_(src)
        m[:h, :w] = False  # mark the real-image region as not padding
    return NestedTensor(padded, mask)
def collate_fn(batch):
    """DataLoader collate: transpose the batch and pack the video inputs.

    Index 0 of the transposed batch holds the video inputs (converted to a
    NestedTensor); index 1 holds the ground-truth labels, passed through.
    """
    transposed = list(zip(*batch))
    transposed[0] = nested_tensor_from_tensor_list(transposed[0])
    return tuple(transposed)