# Source: test/dataloaders/dataloader_msrvtt_retrieval.py
# Uploaded by jaewooo ("Initial upload", commit de15dc5, verified).
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import os
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from collections import defaultdict
import json
import random
from dataloaders.rawvideo_util import RawVideoExtractor
class MSRVTT_DataLoader(Dataset):
    """MSRVTT dataset loader.

    Reads ``video_id``/``sentence`` rows from a CSV file and, per sample,
    returns the tokenized sentence together with the frames of the matching
    video file found under ``features_path``.
    """

    def __init__(
            self,
            csv_path,
            features_path,
            tokenizer,
            max_words=30,
            feature_framerate=1.0,
            max_frames=100,
            image_resolution=224,
            frame_order=0,
            slice_framepos=0,
    ):
        """Load the CSV index and create the raw-video extractor.

        Args:
            csv_path: CSV file read with ``pandas.read_csv``; ``__getitem__``
                reads its ``video_id`` and ``sentence`` columns.
            features_path: Directory holding ``<video_id>.mp4`` (or
                ``<video_id>.webm``) files.
            tokenizer: Object providing ``tokenize`` and
                ``convert_tokens_to_ids``.
            max_words: Fixed length of every token sequence, start and end
                tokens included.
            feature_framerate: Frame rate handed to ``RawVideoExtractor``.
            max_frames: Maximum number of frames kept per video.
            image_resolution: Frame size handed to ``RawVideoExtractor``.
            frame_order: 0: ordinary order; 1: reverse order; 2: random order.
            slice_framepos: 0: cut from head frames; 1: cut from tail frames;
                2: extract frames uniformly.
        """
        self.data = pd.read_csv(csv_path)
        self.features_path = features_path
        self.feature_framerate = feature_framerate
        self.max_words = max_words
        self.max_frames = max_frames
        self.tokenizer = tokenizer
        # 0: ordinary order; 1: reverse order; 2: random order.
        self.frame_order = frame_order
        assert self.frame_order in [0, 1, 2]
        # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.
        self.slice_framepos = slice_framepos
        assert self.slice_framepos in [0, 1, 2]

        self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution)
        self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
                              "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.data)

    def _encode_words(self, words):
        """Wrap ``words`` in start/end tokens and pad everything to ``max_words``.

        Returns:
            Tuple ``(input_ids, input_mask, segment_ids)`` of plain lists, each
            exactly ``max_words`` long. The mask is 1 on real tokens and 0 on
            padding; segment ids are all 0.
        """
        words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + list(words)
        # Reserve the last slot for the end token.
        total_length_with_cls = self.max_words - 1
        if len(words) > total_length_with_cls:
            words = words[:total_length_with_cls]
        words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]

        token_ids = list(self.tokenizer.convert_tokens_to_ids(words))
        n_tokens = len(token_ids)
        n_padding = max(self.max_words - n_tokens, 0)

        input_ids = token_ids + [0] * n_padding
        input_mask = [1] * n_tokens + [0] * n_padding
        segment_ids = [0] * (n_tokens + n_padding)

        assert len(input_ids) == self.max_words
        assert len(input_mask) == self.max_words
        assert len(segment_ids) == self.max_words
        return input_ids, input_mask, segment_ids

    def _get_text(self, video_id, sentence):
        """Tokenize ``sentence`` into fixed-length id, mask and segment arrays.

        Returns:
            Tuple ``(pairs_text, pairs_mask, pairs_segment, choice_video_ids)``.
            The three arrays are int64 with shape ``(1, max_words)``;
            ``choice_video_ids`` is ``[video_id]``.
        """
        choice_video_ids = [video_id]
        k = len(choice_video_ids)

        # np.int64 instead of np.long: that alias is absent from NumPy 1.24-1.26.
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)

        for i, _ in enumerate(choice_video_ids):
            words = self.tokenizer.tokenize(sentence)
            input_ids, input_mask, segment_ids = self._encode_words(words)
            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)

        return pairs_text, pairs_mask, pairs_segment, choice_video_ids

    def _slice_frames(self, frames):
        """Reduce ``frames`` to at most ``max_frames`` along the first axis.

        The strategy follows ``slice_framepos``: 0 keeps the head, 1 keeps the
        tail, anything else samples evenly spaced indices.
        """
        if self.max_frames >= frames.shape[0]:
            return frames
        if self.slice_framepos == 0:
            return frames[:self.max_frames, ...]
        if self.slice_framepos == 1:
            return frames[-self.max_frames:, ...]
        sample_indx = np.linspace(0, frames.shape[0] - 1, num=self.max_frames, dtype=int)
        return frames[sample_indx, ...]

    def _get_rawvideo(self, choice_video_ids):
        """Load, slice and zero-pad the frames of every requested video.

        Returns:
            Tuple ``(video, video_mask)``. ``video`` is float32 with shape
            ``(len(choice_video_ids), max_frames, 1, 3, size, size)`` where
            ``size`` is ``rawVideoExtractor.size``; ``video_mask`` is int64 with
            shape ``(len(choice_video_ids), max_frames)`` and holds 1 for every
            slot filled with a real frame. A video that cannot be decoded is
            reported on stdout and left as all zeros.
        """
        n_videos = len(choice_video_ids)
        frame_size = self.rawVideoExtractor.size

        video_mask = np.zeros((n_videos, self.max_frames), dtype=np.int64)
        # Pair x L x T x 3 x H x W
        video = np.zeros((n_videos, self.max_frames, 1, 3, frame_size, frame_size), dtype=np.float32)

        for i, video_id in enumerate(choice_video_ids):
            video_path = os.path.join(self.features_path, "{}.mp4".format(video_id))
            if not os.path.exists(video_path):
                # Fall back to the .webm variant (the original note attributes
                # this to the YouCookII video format). The path is rebuilt
                # rather than string-replaced so a ".mp4" inside the directory
                # or the id itself is never rewritten.
                video_path = os.path.join(self.features_path, "{}.webm".format(video_id))

            raw_video_data = self.rawVideoExtractor.get_video_data(video_path)['video']
            if len(raw_video_data.shape) <= 3:
                print("video path: {} error. video id: {}".format(video_path, video_id))
                continue

            # L x T x 3 x H x W
            frames = self.rawVideoExtractor.process_raw_data(raw_video_data)
            frames = self._slice_frames(frames)
            frames = self.rawVideoExtractor.process_frame_order(frames, frame_order=self.frame_order)

            n_frames = frames.shape[0]
            if n_frames < 1:
                continue
            video[i][:n_frames, ...] = frames
            video_mask[i][:n_frames] = 1

        return video, video_mask

    def __getitem__(self, idx):
        """Return ``(pairs_text, pairs_mask, pairs_segment, video, video_mask)`` for row ``idx``."""
        video_id = self.data['video_id'].values[idx]
        sentence = self.data['sentence'].values[idx]

        pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, sentence)
        video, video_mask = self._get_rawvideo(choice_video_ids)
        return pairs_text, pairs_mask, pairs_segment, video, video_mask
class MSRVTT_TrainDataLoader(Dataset):
    """MSRVTT train dataset loader.

    Video ids come from a CSV file and captions from a JSON annotation file.
    With ``unfold_sentences`` every (video, caption) pair is its own sample;
    otherwise there is one sample per CSV row and a caption of that video is
    drawn at random on each access.
    """

    def __init__(
            self,
            csv_path,
            json_path,
            features_path,
            tokenizer,
            max_words=30,
            feature_framerate=1.0,
            max_frames=100,
            unfold_sentences=False,
            image_resolution=224,
            frame_order=0,
            slice_framepos=0,
    ):
        """Load the CSV/JSON annotations and create the raw-video extractor.

        Args:
            csv_path: CSV file read with ``pandas.read_csv``; its ``video_id``
                column lists the training videos.
            json_path: JSON file with a ``sentences`` list (items carrying
                ``video_id`` and ``caption``) and, when ``unfold_sentences`` is
                false, a ``videos`` list (items carrying ``video_id`` and
                ``url``).
            features_path: Directory holding ``<video_id>.mp4`` (or
                ``<video_id>.webm``) files.
            tokenizer: Object providing ``tokenize`` and
                ``convert_tokens_to_ids``.
            max_words: Fixed length of every token sequence, start and end
                tokens included.
            feature_framerate: Frame rate handed to ``RawVideoExtractor``.
            max_frames: Maximum number of frames kept per video.
            unfold_sentences: Treat every caption of a CSV-listed video as a
                separate sample.
            image_resolution: Frame size handed to ``RawVideoExtractor``.
            frame_order: 0: ordinary order; 1: reverse order; 2: random order.
            slice_framepos: 0: cut from head frames; 1: cut from tail frames;
                2: extract frames uniformly.
        """
        self.csv = pd.read_csv(csv_path)
        # Context manager so the annotation file handle is closed promptly.
        with open(json_path, 'r') as json_file:
            self.data = json.load(json_file)
        self.features_path = features_path
        self.feature_framerate = feature_framerate
        self.max_words = max_words
        self.max_frames = max_frames
        self.tokenizer = tokenizer
        # 0: ordinary order; 1: reverse order; 2: random order.
        self.frame_order = frame_order
        assert self.frame_order in [0, 1, 2]
        # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.
        self.slice_framepos = slice_framepos
        assert self.slice_framepos in [0, 1, 2]

        self.unfold_sentences = unfold_sentences
        self.sample_len = 0
        if self.unfold_sentences:
            # A set keeps the per-sentence membership test O(1).
            train_video_ids = set(self.csv['video_id'].values)
            selected = [
                (itm['video_id'], itm['caption'])
                for itm in self.data['sentences']
                if itm['video_id'] in train_video_ids
            ]
            # Sample index -> (video_id, caption), in annotation order.
            self.sentences_dict = dict(enumerate(selected))
            self.sample_len = len(self.sentences_dict)
        else:
            # video_id -> every caption annotated for that video.
            self.sentences = defaultdict(list)
            for itm in self.data['sentences']:
                self.sentences[itm['video_id']].append(itm['caption'])

            # Use to find the clips in the same video
            self.parent_ids = {}
            self.children_video_ids = defaultdict(list)
            for itm in self.data['videos']:
                vid = itm["video_id"]
                url_posfix = itm["url"].split("?v=")[-1]
                self.parent_ids[vid] = url_posfix
                self.children_video_ids[url_posfix].append(vid)
            self.sample_len = len(self.csv)

        self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution)
        self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
                              "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}

    def __len__(self):
        """Return the number of samples computed in ``__init__``."""
        return self.sample_len

    def _encode_words(self, words):
        """Wrap ``words`` in start/end tokens and pad everything to ``max_words``.

        Returns:
            Tuple ``(input_ids, input_mask, segment_ids)`` of plain lists, each
            exactly ``max_words`` long. The mask is 1 on real tokens and 0 on
            padding; segment ids are all 0.
        """
        words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + list(words)
        # Reserve the last slot for the end token.
        total_length_with_cls = self.max_words - 1
        if len(words) > total_length_with_cls:
            words = words[:total_length_with_cls]
        words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]

        token_ids = list(self.tokenizer.convert_tokens_to_ids(words))
        n_tokens = len(token_ids)
        n_padding = max(self.max_words - n_tokens, 0)

        input_ids = token_ids + [0] * n_padding
        input_mask = [1] * n_tokens + [0] * n_padding
        segment_ids = [0] * (n_tokens + n_padding)

        assert len(input_ids) == self.max_words
        assert len(input_mask) == self.max_words
        assert len(segment_ids) == self.max_words
        return input_ids, input_mask, segment_ids

    def _get_text(self, video_id, caption=None):
        """Tokenize a caption into fixed-length id, mask and segment arrays.

        Args:
            video_id: Video the caption belongs to.
            caption: Caption text; when ``None`` a caption of ``video_id`` is
                drawn at random via ``_get_single_text``.

        Returns:
            Tuple ``(pairs_text, pairs_mask, pairs_segment, choice_video_ids)``.
            The three arrays are int64 with shape ``(1, max_words)``;
            ``choice_video_ids`` is ``[video_id]``.
        """
        k = 1
        choice_video_ids = [video_id]

        # np.int64 instead of np.long: that alias is absent from NumPy 1.24-1.26.
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)

        for i, current_id in enumerate(choice_video_ids):
            if caption is not None:
                words = self.tokenizer.tokenize(caption)
            else:
                words = self._get_single_text(current_id)
            input_ids, input_mask, segment_ids = self._encode_words(words)
            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)

        return pairs_text, pairs_mask, pairs_segment, choice_video_ids

    def _get_single_text(self, video_id):
        """Return the tokens of one randomly chosen caption of ``video_id``.

        Raises:
            ValueError: If no caption is recorded for ``video_id``.
        """
        # .get() avoids inserting an empty entry into the defaultdict on a miss.
        captions = self.sentences.get(video_id)
        if not captions:
            raise ValueError("no caption available for video id: {}".format(video_id))
        rind = random.randint(0, len(captions) - 1)
        return self.tokenizer.tokenize(captions[rind])

    def _slice_frames(self, frames):
        """Reduce ``frames`` to at most ``max_frames`` along the first axis.

        The strategy follows ``slice_framepos``: 0 keeps the head, 1 keeps the
        tail, anything else samples evenly spaced indices.
        """
        if self.max_frames >= frames.shape[0]:
            return frames
        if self.slice_framepos == 0:
            return frames[:self.max_frames, ...]
        if self.slice_framepos == 1:
            return frames[-self.max_frames:, ...]
        sample_indx = np.linspace(0, frames.shape[0] - 1, num=self.max_frames, dtype=int)
        return frames[sample_indx, ...]

    def _get_rawvideo(self, choice_video_ids):
        """Load, slice and zero-pad the frames of every requested video.

        Returns:
            Tuple ``(video, video_mask)``. ``video`` is float32 with shape
            ``(len(choice_video_ids), max_frames, 1, 3, size, size)`` where
            ``size`` is ``rawVideoExtractor.size``; ``video_mask`` is int64 with
            shape ``(len(choice_video_ids), max_frames)`` and holds 1 for every
            slot filled with a real frame. A video that cannot be decoded is
            reported on stdout and left as all zeros.
        """
        n_videos = len(choice_video_ids)
        frame_size = self.rawVideoExtractor.size

        video_mask = np.zeros((n_videos, self.max_frames), dtype=np.int64)
        # Pair x L x T x 3 x H x W
        video = np.zeros((n_videos, self.max_frames, 1, 3, frame_size, frame_size), dtype=np.float32)

        for i, video_id in enumerate(choice_video_ids):
            video_path = os.path.join(self.features_path, "{}.mp4".format(video_id))
            if not os.path.exists(video_path):
                # Fall back to the .webm variant (the original note attributes
                # this to the YouCookII video format). The path is rebuilt
                # rather than string-replaced so a ".mp4" inside the directory
                # or the id itself is never rewritten.
                video_path = os.path.join(self.features_path, "{}.webm".format(video_id))

            raw_video_data = self.rawVideoExtractor.get_video_data(video_path)['video']
            if len(raw_video_data.shape) <= 3:
                print("video path: {} error. video id: {}".format(video_path, video_id))
                continue

            # L x T x 3 x H x W
            frames = self.rawVideoExtractor.process_raw_data(raw_video_data)
            frames = self._slice_frames(frames)
            frames = self.rawVideoExtractor.process_frame_order(frames, frame_order=self.frame_order)

            n_frames = frames.shape[0]
            if n_frames < 1:
                continue
            video[i][:n_frames, ...] = frames
            video_mask[i][:n_frames] = 1

        return video, video_mask

    def __getitem__(self, idx):
        """Return ``(pairs_text, pairs_mask, pairs_segment, video, video_mask)`` for sample ``idx``."""
        if self.unfold_sentences:
            video_id, caption = self.sentences_dict[idx]
        else:
            # The caption is sampled later, inside _get_text.
            video_id, caption = self.csv['video_id'].values[idx], None

        pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption)
        video, video_mask = self._get_rawvideo(choice_video_ids)
        return pairs_text, pairs_mask, pairs_segment, video, video_mask