# OmniShotCut / datasets / utils.py
# (repo-page scrape residue, kept as comments so the file parses:)
# HikariDawn's picture
# feat: initial push
# 796e051
'''
Util for collate function and related needs
'''
import os, sys, shutil
from typing import Optional, List
import torch
from torch import Tensor
# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from util.misc import NestedTensor
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def nested_tensor_from_tensor_list(tensor_list: List[Tensor], split=True):
    """Pad a batch of (C, H, W) frames to a common size and wrap as NestedTensor.

    Adapted from VisTR's handling of video inputs: each sample may be a
    sequence of frames, which are first flattened into one flat frame list.

    Args:
        tensor_list: with split=True, a list of per-sample frame sequences;
            with split=False, an already-flat list of (C, H, W) tensors.
        split: whether to flatten one level of nesting first.

    Returns:
        NestedTensor whose tensor is (B*F, C, H, W) zero-padded frames and
        whose mask is (B*F, H, W), True on padded pixels.

    Raises:
        ValueError: if the frames are not 3-dimensional (C, H, W).
    """
    if split:
        # Flatten the per-sample sequences into one list of length
        # Batch Size * #Frames, mirroring an image batch.
        tensor_list = [frame for frames in tensor_list for frame in frames]

    # Only (C, H, W) frames are supported; guard before any allocation.
    if tensor_list[0].ndim != 3:
        raise ValueError('not supported')

    # Per-axis maximum over all frame shapes gives the padded (C, H, W) target.
    target_shape = _max_by_axis([list(frame.shape) for frame in tensor_list])
    n, c, h, w = [len(tensor_list)] + target_shape
    ref = tensor_list[0]
    padded = torch.zeros((n, c, h, w), dtype=ref.dtype, device=ref.device)
    pad_mask = torch.ones((n, h, w), dtype=torch.bool, device=ref.device)

    # Copy each frame into the top-left corner of its padded slot and mark
    # its valid region as False (i.e. not padding) in the mask.
    for frame, dst, msk in zip(tensor_list, padded, pad_mask):
        ci, hi, wi = frame.shape
        dst[:ci, :hi, :wi].copy_(frame)
        msk[:hi, :wi] = False

    return NestedTensor(padded, pad_mask)
def collate_fn(batch):
    """Collate a batch of (video, target) samples for the DataLoader.

    Transposes the batch into per-field groups, then packs the video group
    (field 0) into a NestedTensor; remaining fields (GT labels) pass through.
    """
    videos, *rest = zip(*batch)
    packed = nested_tensor_from_tensor_list(videos)
    return (packed, *rest)