Spaces:
Running on Zero
Running on Zero
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| """ | |
| Data Loader for OmniShotCut. Modified from DETR. | |
| """ | |
| import os, sys, shutil | |
| import random | |
| from typing import List, Union, Optional, Tuple | |
| from pathlib import Path | |
| import numpy as np | |
| import ffmpeg | |
| import torch | |
| import torch.utils.data | |
| import imageio | |
| import torchvision | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| # Import files from the local folder | |
| root_path = os.path.abspath('.') | |
| sys.path.append(root_path) | |
| from datasets.transforms import Video_Augmentation_Transform | |
| from config.label_correspondence import unique_intra_label_mapping, unique_inter_label_mapping | |
| from util.visualization import visualize_concated_frames | |
def align_segments_to_crop(segments, crop_start, crop_len):
    """
    Clip global-frame segments to a crop window and convert them to local indices.

    Args:
        segments: list of (start, end, intra_label, inter_label) tuples in global
            frame indices; each range is half-open [start, end).
        crop_start: first global frame index of the crop window.
        crop_len: number of frames in the crop window.

    Returns:
        ab: list of [start, end) pairs in cropped local indices.
        intras: list of intra labels for the surviving segments.
        inters: list of inter labels; the first surviving segment's inter label is
            overwritten with "new_start", since its true transition happened
            before (or at) the crop boundary.
    """
    s = int(crop_start)
    e = s + int(crop_len)
    ab = []      # Clipped [start, end) ranges, local to the crop
    intras = []  # Intra-clip labels of surviving segments
    inters = []  # Inter-clip labels of surviving segments
    for start, end, intra, inter in segments:
        start = int(start)
        end = int(end)
        # Intersect [start, end) with the crop window [s, e)
        na = max(start, s)
        nb = min(end, e)
        if nb <= na:
            continue  # Segment lies entirely outside the crop
        ab.append([na - s, nb - s])
        intras.append(intra)
        inters.append(inter)
    # Change the first Inter to be new start. Guard against an empty result:
    # the original indexed inters[0] unconditionally and raised IndexError
    # when no segment overlapped the crop window.
    if inters:
        inters[0] = unique_inter_label_mapping["new_start"]
    return ab, intras, inters
def pad_to_length(x, N, pad_value=(1.0, 0.0)):
    """
    Pad a (K, D) tensor with copies of pad_value rows until it has N rows.

    Args:
        x: tensor of shape (K, D); requires K <= N.
        N: target number of rows.
        pad_value: row used for padding; its length must equal D. The default
            (1.0, 0.0) matches the original 2-column usage.

    Returns:
        Tensor of shape (N, D): x on top, padding rows appended below.
    """
    K = x.shape[0]
    assert K <= N
    pad = torch.tensor(pad_value, dtype=x.dtype, device=x.device)
    # Generalized: expand to the tensor's actual row width rather than the
    # previously hard-coded width of 2.
    pad = pad.unsqueeze(0).expand(N - K, x.shape[1])  # (N-K, D)
    return torch.cat([x, pad], dim=0)
class CutAnything_Dataloader(Dataset):
    """Dataset yielding fixed-length video windows with shot-transition labels.

    Each item is (video_tensor, gt_target, aux_info):
      - video_tensor: (F, C, H, W) transformed frames, F == max_process_window_length.
      - gt_target: dict of tensors, each right-padded to num_queries entries
        ("shot_labels", "intra_clip_labels", "inter_clip_labels").
      - aux_info: bookkeeping (idx, video_path, fps, crop frame range).
    """

    def __init__(self, args, set_type):
        # Fetch information
        self.set_type = set_type  # "train" or "val" set for the dataloader
        self.args = args
        self.process_height = args.process_height
        self.process_width = args.process_width
        self.max_process_window_length = args.max_process_window_length  # The max number of frames we need
        self.has_overlength_prob = args.has_overlength_prob  # If we have overlength window
        self.max_padding_length = args.max_process_window_length - args.min_video_in_padding  # Max padding frames allowed in max_process_window_length
        self.num_queries = args.num_queries

        # Choose Data Info. Validate set_type up front: the original left
        # data_info_path unbound for an unknown set_type and crashed with a
        # NameError before reaching its NotImplementedError below.
        if set_type == "train":
            data_info_path = args.train_data_info_path
        elif set_type == "val":
            data_info_path = args.val_data_info_path
        else:
            raise NotImplementedError(f"we do not support set type of {set_type}")
        if not os.path.exists(data_info_path):
            print("We cannot find", data_info_path)
        assert(os.path.exists(data_info_path))

        # Load pkl files (one list of sample dicts per file, concatenated)
        data_info = []
        for sub_pkl_name in sorted(os.listdir(data_info_path)):
            sub_pkl_path = os.path.join(data_info_path, sub_pkl_name)
            data_info.extend(np.load(sub_pkl_path, allow_pickle=True))

        # Collect
        if set_type == "val" and args.max_val_num is not None:  # None means to use all
            data_info = data_info[:args.max_val_num]
        print("Total number of", set_type, "dataset is", len(data_info))
        self.data_info = data_info

        # Augmentation (Horizontal Flip + Color Jitter + Gray Scale + Blur) + Transform (ImageNet Normalization)
        if set_type == "train":
            self.video_transform = Video_Augmentation_Transform(
                set_type = "train",
                horizontal_flip_prob = 0.5,               # Horizontal Flip
                vertical_flip_prob = 0.0,                 # Vertical Flip
                jitter_prob = 0.15,                       # Color Jitter Prob
                jitter_param = (0.05, 0.05, 0.05, 0.02),  # Color Jitter
                grayscale_prob = 0.0,                     # GrayScale
                blur_prob = 0.03,                         # Blur
                blur_kernel_size = 3,                     # Should be odd number
                blur_sigma = (0.1, 0.3),
                noise_prob = 0.0,                         # Add Gaussian Noise
                noise_sigma = (0.003, 0.01),
                noise_clip = (0.0, 1.0),
                compression_prob = 0.05,                  # Image-based compression
                compression_choices = ["jpeg", "webp"],
            )
        else:  # set_type == "val" (anything else was rejected above)
            self.video_transform = Video_Augmentation_Transform(
                set_type = "val"
            )

    def __len__(self):
        # Number of samples listed in the loaded data info
        return len(self.data_info)

    def __getitem__(self, idx):
        """Decode one video, crop/pad it to max_process_window_length frames,
        and build the padded GT tensors. Any failure triggers a retry with a
        freshly drawn random index (best-effort loader, never raises)."""
        while True:  # Iterate until there is a valid video read
            try:
                # Fetch
                data_dict = self.data_info[idx]
                video_path = data_dict["video_path"]
                gt_ranges = data_dict["transition_ranges"]
                gt_intra_labels = data_dict["transition_intra_labels"]
                gt_inter_labels = data_dict["transition_inter_labels"]
                fps = data_dict["fps"]
                assert(len(gt_ranges) == len(gt_intra_labels) and len(gt_ranges) == len(gt_inter_labels))

                # Sanity Check
                if not os.path.exists(video_path):
                    print("We cannot find", video_path)
                assert(os.path.exists(video_path))

                ############################## Construct the Video Inputs ##############################
                # Read + resize the whole video via ffmpeg into raw RGB24 frames
                resolution = str(self.process_width) + "x" + str(self.process_height)
                video_stream, err = ffmpeg.input(
                    video_path
                ).output(
                    "pipe:", format = "rawvideo", pix_fmt = "rgb24", s = resolution, vsync = 'passthrough',
                ).run(
                    capture_stdout = True, capture_stderr = True
                )  # The resize is already included
                video_np_full = np.frombuffer(video_stream, np.uint8).reshape(-1, self.process_height, self.process_width, 3)
                original_num_frames = len(video_np_full)
                if original_num_frames < self.max_process_window_length:
                    print("We only has", original_num_frames, "number of frames!")
                    raise Exception("The number of frames in the video is too short")  # Exception Cases will choose a new idx

                # Crop the video to be fixed length
                if self.set_type == "train" and random.random() < self.has_overlength_prob:  # Overlength case, might have padding
                    # max(0, ...) guards randint against an empty sampling range
                    start_sample_frame_idx = random.randint(0, max(0, original_num_frames - self.max_process_window_length + self.max_padding_length - 1))
                else:  # Regular Case (Must inside the full video)
                    # max(0, ...): a video exactly max_process_window_length frames
                    # long passes the length check above, but the original
                    # randint(0, -1) then raised ValueError and the sample was
                    # endlessly retried. Now such videos crop from frame 0.
                    start_sample_frame_idx = random.randint(0, max(0, original_num_frames - self.max_process_window_length - 1)) if self.set_type == "train" else 0
                end_sample_frame_idx = min(len(video_np_full), start_sample_frame_idx + self.max_process_window_length)
                video_np = video_np_full[start_sample_frame_idx : end_sample_frame_idx]

                # Add padding (black frames at the tail, up to the window length)
                num_padding_frames = self.max_process_window_length - len(video_np)
                assert(num_padding_frames <= self.max_padding_length)
                black_padding_frames = np.zeros((num_padding_frames, self.process_height, self.process_width, 3), dtype=video_np.dtype)
                video_np = np.concatenate([video_np, black_padding_frames], axis=0)

                # Video Data Transform + Augmentation
                video_tensor = self.video_transform(video_np, idx)  # output shape is (F, C, H, W)

                ############################## Construct the Label System ##############################
                # Construct the Standard Label rows: [Start, End, Intra-Label-Id, Inter-Label-Id]
                standard_labels = [[ *gt_ranges[clip_idx], unique_intra_label_mapping[gt_intra_labels[clip_idx]], unique_inter_label_mapping[gt_inter_labels[clip_idx]] ] for clip_idx in range(len(gt_ranges))]
                # Map the GT into the cropped window's local coordinates
                cropped_ranges, crop_intra_classification_label, crop_inter_classification_label = align_segments_to_crop(standard_labels, start_sample_frame_idx, self.max_process_window_length)

                ## Sanity Check: the cropped number of clips must be less than the number of query
                if len(cropped_ranges) > self.num_queries:
                    raise Exception("The number of clips ", len(cropped_ranges), " is more than the number of query!")
                if len(cropped_ranges) != len(crop_intra_classification_label) or len(cropped_ranges) != len(crop_inter_classification_label):
                    raise Exception("We cannot find ranges to be aligned with labels!")

                # Prepare the GT Video Classification Labels, right-padded to
                # num_queries entries with the dedicated "padding" class id
                intra_label_tensor = torch.tensor(crop_intra_classification_label)
                inter_label_tensor = torch.tensor(crop_inter_classification_label)
                pad_len = self.num_queries - intra_label_tensor.numel()
                intra_label_tensor = F.pad(intra_label_tensor, (0, pad_len), "constant", unique_intra_label_mapping["padding"])
                inter_label_tensor = F.pad(inter_label_tensor, (0, pad_len), "constant", unique_inter_label_mapping["padding"])

                # Prepare the GT Shot Range: keep each clip's exclusive end index,
                # padded with the window length for unused query slots
                shot_labels_tensor = torch.tensor(cropped_ranges)[:, 1].to(torch.int64)  # [Inclusive, Exclusive)
                shot_labels_tensor = F.pad(shot_labels_tensor, (0, pad_len), "constant", self.max_process_window_length)

                # Write as dictionary
                gt_target = {"shot_labels" : shot_labels_tensor, "intra_clip_labels" : intra_label_tensor, "inter_clip_labels" : inter_label_tensor}

                # Build Auxiliary info for dictionary
                aux_info = {
                    "idx" : idx,
                    "video_path" : video_path,
                    "fps" : fps,
                    "start_frame_idx" : start_sample_frame_idx,
                    "end_frame_idx" : end_sample_frame_idx,
                }
            except Exception as error:
                print("We face error", error, "and we will fetch next one!")
                old_idx = idx
                # randint's upper bound is inclusive: the original drew from
                # [0, len(self.data_info)] and could produce an out-of-range
                # index, forcing an extra retry cycle via IndexError.
                idx = random.randint(0, len(self.data_info) - 1)
                print("We cannot process the video", old_idx, " and we choose a new idx of ", idx)
                continue  # For any error that occurs, retry with the newly drawn idx
            break

        # Return
        return video_tensor, gt_target, aux_info
def build(args, set_type):
    """Factory entry point: construct the dataset for the given split."""
    return CutAnything_Dataloader(args, set_type)