| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import glob |
| import io |
| import json |
| import os |
| import pdb |
| import random |
| import tarfile |
| from enum import Enum |
| from typing import Union |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from torchvision.transforms import InterpolationMode, Resize, CenterCrop |
| import torchvision.transforms as transforms |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from src.util.depth_transform import DepthNormalizerBase |
| import random |
|
|
| from src.dataset.eval_base_dataset import DatasetMode, DepthFileNameMode |
|
|
|
|
def read_image_from_tar(tar_obj, img_rel_path):
    """Decode an image stored inside an already-open tar archive.

    Args:
        tar_obj: an open ``tarfile.TarFile``.
        img_rel_path: member path relative to the archive root; the member is
            looked up under the name ``"./" + img_rel_path``.

    Returns:
        The ``PIL.Image.Image`` decoded from the member's bytes.
    """
    # Read the member fully and close its handle; the image is then decoded
    # from an in-memory buffer, so it does not depend on the tar handle.
    with tar_obj.extractfile("./" + img_rel_path) as member:
        data = member.read()
    return Image.open(io.BytesIO(data))
|
|
|
|
class BaseDepthDataset(Dataset):
    """Depth dataset driven by JSON file lists of per-sample records.

    Every JSON file matched by ``filename_ls_path`` contributes its records to
    ``self.filenames``. Records are read as dicts with the key ``rgb_path`` and
    the optional keys ``depth_path``, ``valid_mask`` and ``caption``.
    ``__getitem__`` returns one flat dict holding the raster tensors plus the
    ``index``, ``rgb_path`` and sampled ``text`` caption.
    """

    def __init__(
        self,
        mode: DatasetMode,
        filename_ls_path: str,
        dataset_dir: str,
        disp_name: str,
        min_depth: float,
        max_depth: float,
        has_filled_depth: bool,
        name_mode: DepthFileNameMode,
        depth_transform: Union[DepthNormalizerBase, None] = None,
        tokenizer: CLIPTokenizer = None,
        augmentation_args: dict = None,
        resize_to_hw=None,
        move_invalid_to_far_plane: bool = True,
        rgb_transform=lambda x: x / 255.0 * 2 - 1,
        **kwargs,
    ) -> None:
        """Load the sample records and store the dataset configuration.

        Args:
            mode: ``DatasetMode.TRAIN`` enables ``_training_preprocess``;
                ``DatasetMode.RGB_ONLY`` skips loading depth and masks.
            filename_ls_path: glob pattern matching one or more JSON files,
                each holding a list of sample records.
            dataset_dir: dataset location; only used here to set ``is_tar``.
            disp_name: stored as ``self.disp_name``; not used in this class.
            min_depth: exclusive lower bound for the computed valid mask.
            max_depth: exclusive upper bound for the computed valid mask.
            has_filled_depth: if True, a separate filled depth map is read in
                ``_load_depth_data``; otherwise the raw depth is reused.
            name_mode: stored as ``self.name_mode``; not used in this class.
            depth_transform: called as ``depth_transform(depth, valid_mask)``
                during training; must expose ``far_plane_at_max``,
                ``norm_max`` and ``norm_min`` when
                ``move_invalid_to_far_plane`` is set.
            tokenizer: stored as ``self.tokenizer``; not used in this class.
            augmentation_args: object or dict providing ``lr_flip_p``; None
                disables augmentation.
            resize_to_hw: target ``(height, width)`` (or int for a square);
                None disables resizing.
            move_invalid_to_far_plane: in training, overwrite invalid filled
                depth pixels with the normalizer's far-plane value.
            rgb_transform: applied to the channel-first integer RGB array to
                produce ``rgb_norm``; the default maps [0, 255] to [-1, 1].
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.mode = mode

        self.filename_ls_path = filename_ls_path
        self.dataset_dir = dataset_dir
        self.disp_name = disp_name
        self.has_filled_depth = has_filled_depth
        self.name_mode: DepthFileNameMode = name_mode
        self.min_depth = min_depth
        self.max_depth = max_depth

        self.depth_transform: DepthNormalizerBase = depth_transform
        self.augm_args = augmentation_args
        self.resize_to_hw = resize_to_hw
        self.rgb_transform = rgb_transform
        self.move_invalid_to_far_plane = move_invalid_to_far_plane
        self.tokenizer = tokenizer

        self.filenames = []
        # Sort the matches so the index -> sample mapping is reproducible
        # instead of depending on filesystem enumeration order.
        for path in sorted(glob.glob(self.filename_ls_path)):
            with open(path, "r", encoding="utf-8") as f:
                self.filenames += json.load(f)

        self.tar_obj = None
        self.is_tar = os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)

    def __len__(self):
        """Return the number of sample records loaded from the file lists."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load sample ``index``, applying training preprocessing in TRAIN mode.

        Returns one dict holding the raster tensors and the ``index``,
        ``rgb_path`` and ``text`` entries.
        """
        rasters, other = self._get_data_item(index)
        if DatasetMode.TRAIN == self.mode:
            rasters = self._training_preprocess(rasters)

        outputs = rasters
        outputs.update(other)
        return outputs

    @staticmethod
    def _sample_caption(caption_info):
        """Pick the text prompt for one sample.

        Returns ``""`` when the record carries no caption. Otherwise it draws
        the COCA caption, the spatial caption or the empty string, with
        weights 0.4 / 0.4 / 0.2.
        """
        if caption_info is None:
            return ""
        choices = [caption_info["coca_caption"], caption_info["spatial_caption"], ""]
        return random.choices(choices, weights=[0.4, 0.4, 0.2])[0]

    def _get_data_item(self, index):
        """Load the rasters and metadata for sample ``index``.

        Returns:
            ``(rasters, other)``. ``rasters`` maps names to tensors:
            ``rgb_int`` and ``rgb_norm`` always. When depth is loaded, it also
            holds ``depth_raw_linear``, ``depth_filled_linear``,
            ``valid_mask_raw`` and ``valid_mask_filled``. ``other`` holds
            ``index``, ``rgb_path`` and the sampled caption under ``text``.
        """
        record = self.filenames[index]
        rgb_path = record["rgb_path"]
        # Optional keys: RGB-only lists may omit depth; masks and captions
        # are optional per record.
        depth_path = record.get("depth_path")
        mask_path = record.get("valid_mask")
        caption = self._sample_caption(record.get("caption"))

        rasters = {}
        rasters.update(self._load_rgb_data(rgb_path))

        if DatasetMode.RGB_ONLY != self.mode and depth_path is not None:
            rasters.update(self._load_depth_data(depth_path))

            if mask_path is not None:
                # Read the stored mask once; it serves both raw and filled depth.
                with Image.open(mask_path) as mask_img:
                    mask = torch.from_numpy(np.array(mask_img)).unsqueeze(0).bool()
                rasters["valid_mask_raw"] = mask
                rasters["valid_mask_filled"] = mask.clone()
            else:
                rasters["valid_mask_raw"] = self._get_valid_mask(
                    rasters["depth_raw_linear"]
                ).clone()
                rasters["valid_mask_filled"] = self._get_valid_mask(
                    rasters["depth_filled_linear"]
                ).clone()

        other = {"index": index, "rgb_path": rgb_path, "text": caption}

        rasters = self._resize_rasters(rasters)
        return rasters, other

    def _resize_rasters(self, rasters, skip_if_sized=False):
        """Resize, then center-crop, every raster to ``self.resize_to_hw``.

        The shorter side is scaled to ``max(resize_to_hw)`` (an int target is
        used as-is) with NEAREST_EXACT interpolation. The result is then
        center-cropped to ``resize_to_hw``. With ``skip_if_sized`` set, rasters
        whose last two dims already equal the target are left unchanged.
        Returns ``rasters`` unchanged when ``resize_to_hw`` is None.
        """
        hw = self.resize_to_hw
        if hw is None:
            return rasters
        if isinstance(hw, int):
            short_side, target = hw, (hw, hw)
        else:
            short_side, target = max(hw), tuple(hw)
        resize_transform = transforms.Compose([
            Resize(size=short_side, interpolation=InterpolationMode.NEAREST_EXACT),
            CenterCrop(size=hw),
        ])
        return {
            k: v if skip_if_sized and tuple(v.shape[-2:]) == target else resize_transform(v)
            for k, v in rasters.items()
        }

    def _load_rgb_data(self, rgb_path):
        """Read an RGB image and return its integer and normalized tensors.

        ``rgb_norm`` is ``self.rgb_transform`` applied to the channel-first
        integer array. A missing or ``None`` transform falls back to mapping
        [0, 255] to [-1, 1].
        """
        rgb = self._read_rgb_file(rgb_path)
        transform = getattr(self, "rgb_transform", None)
        if transform is None:
            rgb_norm = rgb / 255.0 * 2.0 - 1.0
        else:
            rgb_norm = transform(rgb)

        return {
            "rgb_int": torch.from_numpy(rgb).int(),
            # as_tensor accepts both ndarray and tensor transform outputs.
            "rgb_norm": torch.as_tensor(rgb_norm).float(),
        }

    @staticmethod
    def _depth_to_tensor(depth):
        """Copy a decoded depth array into a writable float32 tensor with a leading dim."""
        # np.array copies, because the PIL-backed array is read-only. The
        # float32 cast also covers integer dtypes (e.g. uint16) that
        # torch.from_numpy may reject on some torch versions.
        return torch.from_numpy(np.array(depth, dtype=np.float32)).unsqueeze(0)

    def _load_depth_data(self, depth_path, filled_rel_path=None):
        """Load raw (and optionally filled) depth as float32 tensors.

        When ``has_filled_depth`` is False, the filled depth is a copy of the
        raw depth. NOTE(review): with ``has_filled_depth`` True,
        ``filled_rel_path`` must be given; ``_get_data_item`` does not pass it.
        """
        outputs = {}
        depth_raw = self._read_depth_file(depth_path).squeeze()
        depth_raw_linear = self._depth_to_tensor(depth_raw)
        outputs["depth_raw_linear"] = depth_raw_linear.clone()

        if self.has_filled_depth:
            depth_filled = self._read_depth_file(filled_rel_path).squeeze()
            outputs["depth_filled_linear"] = self._depth_to_tensor(depth_filled)
        else:
            outputs["depth_filled_linear"] = depth_raw_linear.clone()

        return outputs

    def _get_data_path(self, index):
        """Return ``(rgb_rel_path, depth_rel_path, text_rel_path)`` for ``index``.

        NOTE(review): this indexes the record positionally (``[0]``, ``[1]``,
        ``[2]``), whereas ``_get_data_item`` reads records by key
        (``"rgb_path"``, ...). It is not called by this class. Confirm whether
        any subclass still relies on list-style records before using it.
        """
        filename_line = self.filenames[index]

        rgb_rel_path = filename_line[0]

        depth_rel_path, text_rel_path = None, None
        if DatasetMode.RGB_ONLY != self.mode:
            depth_rel_path = filename_line[1]
            if len(filename_line) > 2:
                text_rel_path = filename_line[2]
        return rgb_rel_path, depth_rel_path, text_rel_path

    def _read_image(self, img_path) -> np.ndarray:
        """Open an image file and return its pixels as a numpy array."""
        # The context manager closes the file handle. np.asarray copies the
        # pixel data, so the array stays valid after close.
        with Image.open(img_path) as image:
            return np.asarray(image)

    def _read_rgb_file(self, path) -> np.ndarray:
        """Read an image as a channel-first (C, H, W) integer array."""
        rgb = self._read_image(path)
        rgb = np.transpose(rgb, (2, 0, 1)).astype(int)
        return rgb

    def _read_depth_file(self, path):
        """Read a depth image; values are returned exactly as stored in the file."""
        return self._read_image(path)

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return a bool mask that is True where min_depth < depth < max_depth."""
        valid_mask = torch.logical_and(
            (depth > self.min_depth), (depth < self.max_depth)
        ).bool()
        return valid_mask

    def _training_preprocess(self, rasters):
        """Augment rasters and add normalized depth maps for training.

        Adds ``depth_raw_norm`` and ``depth_filled_norm``. Optionally, invalid
        filled pixels are set to the far-plane value.
        """
        if self.augm_args is not None:
            rasters = self._augment_data(rasters)

        rasters["depth_raw_norm"] = self.depth_transform(
            rasters["depth_raw_linear"], rasters["valid_mask_raw"]
        ).clone()
        rasters["depth_filled_norm"] = self.depth_transform(
            rasters["depth_filled_linear"], rasters["valid_mask_filled"]
        ).clone()

        if self.move_invalid_to_far_plane:
            if self.depth_transform.far_plane_at_max:
                far_value = self.depth_transform.norm_max
            else:
                far_value = self.depth_transform.norm_min
            rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = far_value

        # _get_data_item has already resized. Re-applying Resize+CenterCrop
        # would zoom non-square targets a second time, so only rasters not
        # yet at the target size (e.g. from a subclass override) are resized.
        return self._resize_rasters(rasters, skip_if_sized=True)

    def _augment_data(self, rasters_dict):
        """Flip every raster horizontally with probability ``lr_flip_p``."""
        augm = self.augm_args
        # Support plain dicts (per the annotation) as well as attribute-style configs.
        lr_flip_p = augm["lr_flip_p"] if isinstance(augm, dict) else augm.lr_flip_p
        if random.random() < lr_flip_p:
            rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}

        return rasters_dict

    def __del__(self):
        """Close the tar handle if one was opened."""
        if hasattr(self, "tar_obj") and self.tar_obj is not None:
            self.tar_obj.close()
            self.tar_obj = None
|
|
def get_pred_name(rgb_basename, name_mode, suffix=".png"):
    """Derive the prediction file name from an RGB file name.

    Args:
        rgb_basename: file name (no directory) of the RGB image.
        name_mode: a ``DepthFileNameMode`` member selecting the naming rule:
            ``rgb_id``  -> "pred_" + the second "_"-separated field
                          (raises IndexError if there is no "_");
            ``i_d_rgb`` -> "_rgb." replaced by "_pred.";
            ``id``      -> "pred_" + the whole basename;
            ``rgb_i_d`` -> "pred_" + everything after the first "_".
        suffix: extension that replaces the resulting name's extension.

    Returns:
        The prediction file name, ending in ``suffix``.

    Raises:
        NotImplementedError: if ``name_mode`` matches none of the rules above.
    """
    if DepthFileNameMode.rgb_id == name_mode:
        pred_basename = "pred_" + rgb_basename.split("_")[1]
    elif DepthFileNameMode.i_d_rgb == name_mode:
        pred_basename = rgb_basename.replace("_rgb.", "_pred.")
    elif DepthFileNameMode.id == name_mode:
        pred_basename = "pred_" + rgb_basename
    elif DepthFileNameMode.rgb_i_d == name_mode:
        pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:])
    else:
        raise NotImplementedError(f"Unsupported name_mode: {name_mode!r}")

    # Swap whatever extension the derived name carries for the requested one.
    return os.path.splitext(pred_basename)[0] + suffix
|
|