File size: 2,197 Bytes
796e051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
'''
    Utilities for batching video samples: the ``collate_fn`` and the helpers
    that pad variable-size frames into a ``NestedTensor``.
'''
import os, sys, shutil
from typing import Optional, List
import torch
from torch import Tensor


# Make packages under the current working directory (e.g. ``util.misc``)
# importable.
# NOTE(review): this appends the CWD, not this file's directory, so it
# presumably assumes the script is launched from the repository root — confirm.
root_path = os.path.abspath('.')
sys.path.append(root_path)
from util.misc import NestedTensor





def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes



def nested_tensor_from_tensor_list(tensor_list: List[Tensor], split=True):
    """Pad a batch of (C, H, W) frames to a common size and wrap them in a NestedTensor.

    Modified from VisTR, which shows a possible way to handle video inputs:
    every frame of every clip becomes its own entry in the padded batch,
    just like a batch of images in DETR.

    Args:
        tensor_list: When ``split`` is True, a sequence of clips, each of which
            is iterated to yield its frames (e.g. a (F, C, H, W) tensor or a
            list of (C, H, W) tensors). When ``split`` is False, a sequence of
            (C, H, W) tensors used as-is.
        split: Whether to flatten the clips into one list of frames first.

    Returns:
        ``NestedTensor(tensor, mask)``. ``tensor`` has shape (N, C, H_max, W_max),
        where N is the total number of frames, and each frame is zero-padded at
        the bottom/right. ``mask`` has shape (N, H_max, W_max), with ``True`` on
        padded pixels and ``False`` on real ones. dtype and device are taken
        from the first frame.

    Raises:
        ValueError: if there are no frames, or any frame is not 3-dimensional.
    """
    if split:
        # Flatten clips -> frames; the length becomes batch_size * num_frames.
        tensor_list = [frame for clip in tensor_list for frame in clip]

    if len(tensor_list) == 0:
        raise ValueError('not supported: received an empty list of frames')
    # Validate every frame, not only the first: a stray non-3-D frame would
    # otherwise fail later with an opaque IndexError/RuntimeError.
    bad_ndims = sorted({img.ndim for img in tensor_list if img.ndim != 3})
    if bad_ndims:
        raise ValueError(
            f'not supported: expected 3-D (C, H, W) frames, got ndim {bad_ndims}'
        )

    # Same padding scheme as DETR: allocate the max size, copy each frame into
    # the top-left corner, and mark the copied region as valid (False).
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    b, _, h, w = batch_shape
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device
    tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        m[: img.shape[1], : img.shape[2]] = False

    # tensor shape is (B*F, C, H, W) and mask shape is (B*F, H, W)
    return NestedTensor(tensor, mask)



def collate_fn(batch):
    """Transpose a list of samples into per-field tuples and pack the videos.

    Each sample is a tuple whose first element is the video input and whose
    remaining elements (e.g. ground-truth labels) are passed through
    unchanged. The returned tuple holds:

    - at index 0, a NestedTensor built from all samples' video inputs;
    - at each later index, a tuple with that field's value from every sample.

    Presumably used as a ``torch.utils.data.DataLoader`` ``collate_fn``;
    confirm against the caller.
    """
    fields = [*zip(*batch)]
    packed_videos = nested_tensor_from_tensor_list(fields[0])
    return (packed_videos,) + tuple(fields[1:])