| |
| |
|
|
| |
| |
|
|
| import logging |
| import random |
| from copy import deepcopy |
|
|
| import numpy as np |
|
|
| import torch |
| from iopath.common.file_io import g_pathmgr |
| from PIL import Image as PILImage |
| from torchvision.datasets.vision import VisionDataset |
|
|
| from training.dataset.vos_raw_dataset import VOSRawDataset |
| from training.dataset.vos_sampler import VOSSampler |
| from training.dataset.vos_segment_loader import JSONSegmentLoader |
|
|
| from training.utils.data_utils import Frame, Object, VideoDatapoint |
|
|
# Upper bound on how many times VOSDataset._get_datapoint will try a
# (re-sampled) video in training mode before giving up.
MAX_RETRIES = 100
|
|
|
|
class VOSDataset(VisionDataset):
    """Map-style dataset producing transformed ``VideoDatapoint`` samples.

    Each item is built by fetching a raw video from ``video_dataset``, letting
    ``sampler`` choose frames and object ids, loading the frame images and
    per-object masks, and finally applying every callable in ``transforms``.
    """

    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        """
        Args:
            transforms: iterable of callables, each invoked as
                ``transform(datapoint, epoch=self.curr_epoch)`` in order.
            training: when True, a failing video is logged and replaced by a
                random one (see ``_get_datapoint``); when False the first
                failure is re-raised.
            video_dataset: raw dataset exposing ``get_video(idx)`` and ``len``.
            sampler: object exposing ``sample(video, segment_loader, epoch=...)``.
            multiplier: value written into every entry of ``repeat_factors``.
            always_target: when True, an object missing from a frame's
                segments gets an all-zero uint8 mask instead of being skipped.
            target_segments_available: stored as-is; not read by the methods
                of this class.
        """
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        # One float32 repeat factor per raw video, all equal to `multiplier`.
        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        """Load, sample, construct and transform the datapoint at ``idx``.

        In training mode, a video whose loading or sampling raises is logged
        and replaced by a uniformly random index, for at most ``MAX_RETRIES``
        attempts. In eval mode the first exception propagates unchanged.

        Raises:
            RuntimeError: in training mode, when every attempt failed (chained
                from the last underlying exception).
        """
        last_error = None
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()

                video, segment_loader = self.video_dataset.get_video(idx)
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break
            except Exception as e:
                if not self.training:
                    raise
                last_error = e
                logging.warning(
                    f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                )
                idx = random.randrange(0, len(self.video_dataset))
        else:
            # Without this branch, exhausting every retry fell through with
            # `video` / `sampled_frms_and_objs` unbound (UnboundLocalError).
            raise RuntimeError(
                f"Failed to load a datapoint after {MAX_RETRIES} retries"
            ) from last_error

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """Construct a VideoDatapoint sample to pass to transforms.

        Args:
            video: raw video object; only ``video.video_id`` is read here.
            sampled_frms_and_objs: sampler output exposing ``frames`` (each
                with ``frame_idx`` plus ``data``/``image_path``) and
                ``object_ids``.
            segment_loader: loader whose ``load(frame_idx[, obj_ids=...])``
                returns a mapping from object id to mask tensor.

        Returns:
            VideoDatapoint whose ``size`` is ``(h, w)`` of the last frame.

        Raises:
            ValueError: if the sampler returned no frames, since the datapoint
                size would be undefined.
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids
        if len(sampled_frames) == 0:
            raise ValueError(
                f"Sampler returned no frames for video {video.video_id!r}"
            )

        images = []
        rgb_images = load_images(sampled_frames)
        # load_images returns exactly one image per frame, in order.
        for frame, rgb_image in zip(sampled_frames, rgb_images):
            w, h = rgb_image.size
            frame_out = Frame(data=rgb_image, objects=[])
            images.append(frame_out)

            # JSON-backed loaders can restrict decoding to the sampled objects.
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)

            for obj_id in sampled_object_ids:
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # Object absent in this frame: skip it, or emit an empty
                    # mask of the frame's size when a target is always wanted.
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                frame_out.objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)
|
|
|
|
def load_images(frames):
    """Return one RGB PIL image per entry of ``frames``, preserving order.

    A frame whose ``data`` is not None is converted with ``tensor_2_PIL``.
    Otherwise the image is read from ``frame.image_path`` through
    ``g_pathmgr`` and converted to RGB. A path already decoded earlier in the
    same call is not read again: the first decoded image is deep-copied so
    every returned image is an independent object.
    """
    images = []
    decoded_by_path = {}
    for frame in frames:
        if frame.data is not None:
            # Pixel data is already in memory; no disk access needed.
            images.append(tensor_2_PIL(frame.data))
            continue

        path = frame.image_path
        if path in decoded_by_path:
            images.append(deepcopy(decoded_by_path[path]))
            continue

        with g_pathmgr.open(path, "rb") as fopen:
            # convert() forces the decode while the file handle is still open.
            decoded = PILImage.open(fopen).convert("RGB")
        images.append(decoded)
        decoded_by_path[path] = decoded
    return images
|
|
|
|
def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    """Convert a channel-first image tensor to an 8-bit PIL image.

    The tensor is transposed from (C, H, W) to (H, W, C) and scaled by 255.
    Presumably the input values lie in [0, 1]; TODO confirm with producers of
    ``Frame.data``.

    Args:
        data: rank-3 tensor in (C, H, W) layout, on any device; it may
            require grad.

    Returns:
        PIL image built from the uint8 (H, W, C) array.
    """
    array = data.detach().cpu().numpy().transpose((1, 2, 0)) * 255.0
    # Round rather than truncate so values k/255 map back to exactly k, and
    # clip so slight overshoots do not wrap around when cast to uint8.
    array = np.clip(np.rint(array), 0, 255).astype(np.uint8)
    return PILImage.fromarray(array)
|
|