| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import tarfile |
| from io import BytesIO |
|
|
| import numpy as np |
| import torch |
|
|
| from .eval_base_dataset import EvaluateBaseDataset, DepthFileNameMode, DatasetMode |
|
|
|
|
| class DIODEDataset(EvaluateBaseDataset): |
| def __init__( |
| self, |
| **kwargs, |
| ) -> None: |
| super().__init__( |
| |
| min_depth=0.6, |
| max_depth=350, |
| has_filled_depth=False, |
| name_mode=DepthFileNameMode.id, |
| **kwargs, |
| ) |
|
|
| def _read_npy_file(self, rel_path): |
| if self.is_tar: |
| if self.tar_obj is None: |
| self.tar_obj = tarfile.open(self.dataset_dir) |
| fileobj = self.tar_obj.extractfile("./" + rel_path) |
| npy_path_or_content = BytesIO(fileobj.read()) |
| else: |
| npy_path_or_content = os.path.join(self.dataset_dir, rel_path) |
| data = np.load(npy_path_or_content).squeeze()[np.newaxis, :, :] |
| return data |
|
|
| def _read_depth_file(self, rel_path): |
| depth = self._read_npy_file(rel_path) |
| return depth |
|
|
| def _get_data_path(self, index): |
| return self.filenames[index] |
|
|
| def _get_data_item(self, index): |
| |
|
|
| rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index) |
|
|
| rasters = {} |
|
|
| |
| rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) |
|
|
| |
| if DatasetMode.RGB_ONLY != self.mode: |
| |
| depth_data = self._load_depth_data( |
| depth_rel_path=depth_rel_path, filled_rel_path=None |
| ) |
| rasters.update(depth_data) |
|
|
| |
| mask = self._read_npy_file(mask_rel_path).astype(bool) |
| mask = torch.from_numpy(mask).bool() |
| rasters["valid_mask_raw"] = mask.clone() |
| rasters["valid_mask_filled"] = mask.clone() |
|
|
| other = {"index": index, "rgb_relative_path": rgb_rel_path} |
|
|
| return rasters, other |
|
|