import torch from torch.utils.data import DataLoader from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_DataLoader from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader from dataloaders.dataloader_msvd_retrieval import MSVD_DataLoader from dataloaders.dataloader_lsmdc_retrieval import LSMDC_DataLoader from dataloaders.dataloader_activitynet_retrieval import ActivityNet_DataLoader from dataloaders.dataloader_didemo_retrieval import DiDeMo_DataLoader def dataloader_msrvtt_train(args, tokenizer): msrvtt_dataset = MSRVTT_TrainDataLoader( csv_path=args.train_csv, json_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, unfold_sentences=args.expand_msrvtt_sentences, frame_order=args.train_frame_order, slice_framepos=args.slice_framepos, ) train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset) dataloader = DataLoader( msrvtt_dataset, batch_size=args.batch_size // args.n_gpu, num_workers=args.num_thread_reader, pin_memory=False, shuffle=(train_sampler is None), sampler=train_sampler, drop_last=True, ) return dataloader, len(msrvtt_dataset), train_sampler def dataloader_msrvtt_test(args, tokenizer, subset="test"): msrvtt_testset = MSRVTT_DataLoader( csv_path=args.val_csv, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, frame_order=args.eval_frame_order, slice_framepos=args.slice_framepos, ) dataloader_msrvtt = DataLoader( msrvtt_testset, batch_size=args.batch_size_val, num_workers=args.num_thread_reader, shuffle=False, drop_last=False, ) return dataloader_msrvtt, len(msrvtt_testset) def dataloader_msvd_train(args, tokenizer): msvd_dataset = MSVD_DataLoader( subset="train", data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, frame_order=args.train_frame_order, slice_framepos=args.slice_framepos, ) train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset) dataloader = DataLoader( msvd_dataset, batch_size=args.batch_size // args.n_gpu, num_workers=args.num_thread_reader, pin_memory=False, shuffle=(train_sampler is None), sampler=train_sampler, drop_last=True, ) return dataloader, len(msvd_dataset), train_sampler def dataloader_msvd_test(args, tokenizer, subset="test"): msvd_testset = MSVD_DataLoader( subset=subset, data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, frame_order=args.eval_frame_order, slice_framepos=args.slice_framepos, ) dataloader_msrvtt = DataLoader( msvd_testset, batch_size=args.batch_size_val, num_workers=args.num_thread_reader, shuffle=False, drop_last=False, ) return dataloader_msrvtt, len(msvd_testset) def dataloader_lsmdc_train(args, tokenizer): lsmdc_dataset = LSMDC_DataLoader( subset="train", data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, frame_order=args.train_frame_order, slice_framepos=args.slice_framepos, ) train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset) dataloader = DataLoader( lsmdc_dataset, batch_size=args.batch_size // args.n_gpu, num_workers=args.num_thread_reader, pin_memory=False, shuffle=(train_sampler is None), sampler=train_sampler, drop_last=True, ) return dataloader, len(lsmdc_dataset), train_sampler def dataloader_lsmdc_test(args, tokenizer, subset="test"): lsmdc_testset = LSMDC_DataLoader( subset=subset, data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, frame_order=args.eval_frame_order, slice_framepos=args.slice_framepos, ) dataloader_msrvtt = DataLoader( lsmdc_testset, batch_size=args.batch_size_val, num_workers=args.num_thread_reader, shuffle=False, drop_last=False, ) return dataloader_msrvtt, len(lsmdc_testset) def dataloader_activity_train(args, tokenizer): activity_dataset = ActivityNet_DataLoader( subset="train", data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, frame_order=args.train_frame_order, slice_framepos=args.slice_framepos, ) train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset) dataloader = DataLoader( activity_dataset, batch_size=args.batch_size // args.n_gpu, num_workers=args.num_thread_reader, pin_memory=False, shuffle=(train_sampler is None), sampler=train_sampler, drop_last=True, ) return dataloader, len(activity_dataset), train_sampler def dataloader_activity_test(args, tokenizer, subset="test"): activity_testset = ActivityNet_DataLoader( subset=subset, data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, frame_order=args.eval_frame_order, slice_framepos=args.slice_framepos, ) dataloader_msrvtt = DataLoader( activity_testset, batch_size=args.batch_size_val, num_workers=args.num_thread_reader, shuffle=False, drop_last=False, ) return dataloader_msrvtt, len(activity_testset) def dataloader_didemo_train(args, tokenizer): didemo_dataset = DiDeMo_DataLoader( subset="train", data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, frame_order=args.train_frame_order, slice_framepos=args.slice_framepos, ) train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset) dataloader = DataLoader( didemo_dataset, batch_size=args.batch_size // args.n_gpu, num_workers=args.num_thread_reader, pin_memory=False, shuffle=(train_sampler is None), sampler=train_sampler, drop_last=True, ) return dataloader, len(didemo_dataset), train_sampler def dataloader_didemo_test(args, tokenizer, subset="test"): didemo_testset = DiDeMo_DataLoader( subset=subset, data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, frame_order=args.eval_frame_order, slice_framepos=args.slice_framepos, ) dataloader_didemo = DataLoader( didemo_testset, batch_size=args.batch_size_val, num_workers=args.num_thread_reader, shuffle=False, drop_last=False, ) return dataloader_didemo, len(didemo_testset) DATALOADER_DICT = {} DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_test, "test":None} DATALOADER_DICT["msvd"] = {"train":dataloader_msvd_train, "val":dataloader_msvd_test, "test":dataloader_msvd_test} DATALOADER_DICT["lsmdc"] = {"train":dataloader_lsmdc_train, "val":dataloader_lsmdc_test, "test":dataloader_lsmdc_test} DATALOADER_DICT["activity"] = {"train":dataloader_activity_train, "val":dataloader_activity_test, "test":None} DATALOADER_DICT["didemo"] = {"train":dataloader_didemo_train, "val":dataloader_didemo_test, "test":dataloader_didemo_test}