| import os |
| from os import path, replace |
|
|
| import torch |
| from torch.utils.data.dataset import Dataset |
| from torchvision import transforms |
| from torchvision.transforms import InterpolationMode |
| from PIL import Image |
| import numpy as np |
|
|
| from dataset.range_transform import im_normalization, im_mean, im_rgb2lab_normalization, ToTensor, RGB2Lab |
| from dataset.reseed import reseed |
|
|
| import util.functional as F |
|
|
class VOSDataset_221128_TransColorization_batch(Dataset):
    """
    Video clip dataset that yields luminance/chroma style training samples.

    Directory layout follows the DAVIS/YouTubeVOS/BL30K convention: one
    sub-directory per video under ``im_root`` (frames) and under ``gt_root``
    (one ``.png`` per frame, same file name with ``.jpg`` -> ``.png``).

    For each sequence:
    - Pick ``num_frames`` distinct frames; the distance between frames is
      controlled by ``max_jump``
    - Apply some random transforms that are the same for all frames
    - Apply a random transform to each frame separately
    - Run ``final_im_transform`` and split the result by channel:
      channel 0 is repeated three times and returned under ``'rgb'``,
      channels 1-2 are returned as the targets
      (presumably Lab L / ab channels, judging by the helper names --
      confirm against ``dataset.range_transform``)
    """

    def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=2, finetune=False):
        """
        Index the videos under ``im_root`` and build the augmentation pipelines.

        Args:
            im_root: directory containing one sub-directory of frames per video.
            gt_root: directory containing one sub-directory of ``.png`` files
                per video, mirroring ``im_root``.
            max_jump: maximum index distance between a newly sampled frame and
                some already sampled frame.
            is_bl: when true, the per-frame affine uses no rotation/shear.
            subset: optional container of video names; videos not in it are
                skipped. ``None`` keeps every video.
            num_frames: number of frames per sample; videos with fewer frames
                are skipped.
            max_num_obj: length of the returned ``selector`` vector.
            finetune: when true, the per-frame affine uses no rotation/shear.
        """
        self.im_root = im_root
        self.gt_root = gt_root
        self.max_jump = max_jump
        self.is_bl = is_bl
        self.num_frames = num_frames
        self.max_num_obj = max_num_obj

        self.videos = []
        self.frames = {}
        vid_list = sorted(os.listdir(self.im_root))

        for vid in vid_list:
            if subset is not None and vid not in subset:
                continue
            frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
            if len(frames) < num_frames:
                # Too short to draw num_frames distinct frames from.
                continue
            self.frames[vid] = frames
            self.videos.append(vid)

        print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root))

        # Rotation/shear are disabled for finetuning and for BL data.
        no_affine = finetune or self.is_bl
        degrees = 0 if no_affine else 15
        shear = 0 if no_affine else 10

        # Per-frame ("pair") transforms: drawn independently for every frame.
        self.pair_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.01, 0.01, 0.01, 0),
        ])

        self.pair_im_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=degrees, shear=shear, interpolation=InterpolationMode.BILINEAR, fill=im_mean),
        ])

        self.pair_gt_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=degrees, shear=shear, interpolation=InterpolationMode.NEAREST, fill=0),
        ])

        # Sequence-level ("all") transforms: re-seeded so every frame of a
        # sample receives the same parameters.
        self.all_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.03, 0.03, 0),
        ])

        patchsz = 448
        self.all_im_dual_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop((patchsz, patchsz), scale=(0.36, 1.00), interpolation=InterpolationMode.BILINEAR),
        ])

        self.all_gt_dual_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop((patchsz, patchsz), scale=(0.36, 1.00), interpolation=InterpolationMode.NEAREST),
        ])

        # Final conversion applied to every augmented frame.
        self.final_im_transform = transforms.Compose([
            RGB2Lab(),
            ToTensor(),
            im_rgb2lab_normalization,
        ])

    def _sample_frame_indices(self, length):
        """
        Draw ``self.num_frames`` distinct frame indices from ``range(length)``.

        The first index is uniform; every further index is drawn from the
        frames lying within ``max_jump`` of any frame picked so far. The
        result is sorted and, with probability 0.5, reversed.

        Raises:
            ValueError: if no further frame can be reached (e.g. ``max_jump``
                is 0 or the video is shorter than ``num_frames``).
        """
        max_jump = min(length, self.max_jump)

        def window(center):
            # Indices within max_jump of `center`, clipped to the video.
            return set(range(max(0, center - max_jump), min(length, center + max_jump + 1)))

        picked = [np.random.randint(length)]
        candidates = window(picked[-1]).difference(picked)
        while len(picked) < self.num_frames:
            if not candidates:
                raise ValueError(
                    'Cannot sample %d distinct frames from a video of length %d with max_jump=%d.'
                    % (self.num_frames, length, self.max_jump)
                )
            picked.append(int(np.random.choice(list(candidates))))
            candidates = candidates.union(window(picked[-1])).difference(picked)

        picked = sorted(picked)
        if np.random.rand() < 0.5:
            # Randomly play the clip backwards.
            picked = picked[::-1]
        return picked

    def __getitem__(self, idx):
        """
        Build one training sample from video number ``idx``.

        Returns:
            dict with keys
            - ``'rgb'``: stacked per-frame tensors, channel 0 of the
              transformed frame repeated three times,
            - ``'first_frame_gt'``: channels 1-2 of the first sampled frame
              with a leading singleton dimension,
            - ``'cls_gt'``: ``np.stack`` of channels 1-2 of all sampled frames,
            - ``'selector'``: float tensor of length ``max_num_obj`` with ones
              for the first ``info['num_objects']`` entries,
            - ``'info'``: dict with ``'name'``, ``'frames'`` and
              ``'num_objects'``.
        """
        video = self.videos[idx]
        frames = self.frames[video]
        info = {'name': video, 'frames': []}

        vid_im_path = path.join(self.im_root, video)
        vid_gt_path = path.join(self.gt_root, video)

        frames_idx = self._sample_frame_indices(len(frames))

        # One seed shared by all frames so the sequence-level transforms agree.
        sequence_seed = np.random.randint(2147483647)
        images = []
        masks = []
        for f_idx in frames_idx:
            jpg_name = frames[f_idx]
            png_name = jpg_name.replace('.jpg', '.png')
            info['frames'].append(jpg_name)

            reseed(sequence_seed)
            # convert() returns an independent image, so the file handle can be
            # closed right away instead of waiting for garbage collection.
            with Image.open(path.join(vid_im_path, jpg_name)) as im_file:
                this_im = im_file.convert('RGB')
            this_im = self.all_im_dual_transform(this_im)
            this_im = self.all_im_lone_transform(this_im)

            # NOTE(review): the mask is loaded and transformed but never ends
            # up in the returned sample; kept so missing gt files still raise
            # exactly as before -- confirm whether it can be dropped.
            reseed(sequence_seed)
            with Image.open(path.join(vid_gt_path, png_name)) as gt_file:
                this_gt = gt_file.convert('P')
            this_gt = self.all_gt_dual_transform(this_gt)

            # Fresh seed per frame for the frame-specific transforms.
            pairwise_seed = np.random.randint(2147483647)
            reseed(pairwise_seed)
            this_im = self.pair_im_dual_transform(this_im)
            this_im = self.pair_im_lone_transform(this_im)

            reseed(pairwise_seed)
            this_gt = self.pair_gt_dual_transform(this_gt)

            this_im = self.final_im_transform(this_im)

            # Channel 0 becomes the (3-channel) network input, channels 1-2
            # the regression target.
            this_im_l = this_im[:1, :, :]
            this_im_ab = this_im[1:3, :, :]

            images.append(this_im_l.repeat(3, 1, 1))
            masks.append(this_im_ab)

        images = torch.stack(images, 0)

        first_frame_gt = masks[0].unsqueeze(0)

        # The two target channels are treated as two "objects".
        info['num_objects'] = 2

        # np.stack keeps the original return type (numpy array) for 'cls_gt'.
        cls_gt = np.stack(masks, 0)

        # 1 for slots that hold a valid object, 0 for padding slots.
        selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
        selector = torch.FloatTensor(selector)

        data = {
            'rgb': images,
            'first_frame_gt': first_frame_gt,
            'cls_gt': cls_gt,
            'selector': selector,
            'info': info,
        }

        return data

    def __len__(self):
        """Return the number of accepted videos."""
        return len(self.videos)
|
|