| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import pdb |
|
|
| from .base_depth_dataset import BaseDepthDataset |
| from .eval_base_dataset import EvaluateBaseDataset, DatasetMode, get_pred_name |
| from .diode_dataset import DIODEDataset |
| from .eth3d_dataset import ETH3DDataset |
| from .hypersim_dataset import HypersimDataset |
| from .kitti_dataset import KITTIDataset |
| from .nyu_dataset import NYUDataset |
| from .scannet_dataset import ScanNetDataset |
| from .vkitti_dataset import VirtualKITTIDataset |
| from .depthanything_dataset import DepthAnythingDataset |
| from .base_inpaint_dataset import BaseInpaintDataset |
|
|
# Registry that maps the ``name`` field of a dataset split config to the
# dataset class used to load it. ``get_dataset`` and ``get_eval_dataset``
# look names up here; the special name "mixed" is handled by those functions
# and is intentionally not a key of this dict.
dataset_name_class_dict = {
    "hypersim": HypersimDataset,
    "vkitti": VirtualKITTIDataset,
    "nyu_v2": NYUDataset,
    "kitti": KITTIDataset,
    "eth3d": ETH3DDataset,
    "diode": DIODEDataset,
    "scannet": ScanNetDataset,
    "depthanything": DepthAnythingDataset,
    "inpainting": BaseInpaintDataset,
}
|
|
|
|
def get_dataset(
    cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs
):
    """Build the dataset (or list of datasets) described by a split config.

    Args:
        cfg_data_split: Split configuration. It must expose ``name``. A
            config named ``"mixed"`` must also expose ``dataset_list`` (an
            iterable of nested split configs). Any other config must expose
            ``filenames`` and ``dir`` and support ``**`` unpacking, because
            all of its entries are forwarded to the dataset constructor.
        base_data_dir: Root directory that ``cfg_data_split.dir`` is joined to.
        mode: Dataset mode, passed to the dataset constructor as ``mode``.
        **kwargs: Extra keyword arguments forwarded to every constructor.

    Returns:
        For a ``"mixed"`` config, a list holding the result of this function
        for each entry of ``dataset_list`` (in order). Otherwise a single
        instance of the class registered in ``dataset_name_class_dict`` under
        ``cfg_data_split.name``.

    Raises:
        NotImplementedError: If the name is neither ``"mixed"`` nor a key of
            ``dataset_name_class_dict``.
    """
    name = cfg_data_split.name

    if name == "mixed":
        # A mixed split is just a container: resolve every nested config
        # with the same base directory, mode and extra keyword arguments.
        return [
            get_dataset(sub_cfg, base_data_dir, mode, **kwargs)
            for sub_cfg in cfg_data_split.dataset_list
        ]

    dataset_class = dataset_name_class_dict.get(name)
    if dataset_class is None:
        raise NotImplementedError(
            f"Unknown dataset name {name!r}; expected 'mixed' or one of "
            f"{sorted(dataset_name_class_dict)}"
        )

    # The whole config is unpacked into the constructor in addition to the
    # explicit arguments, so the config must not define `mode`,
    # `filename_ls_path` or `dataset_dir` itself (that would be a TypeError).
    return dataset_class(
        mode=mode,
        filename_ls_path=cfg_data_split.filenames,
        dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir),
        **cfg_data_split,
        **kwargs,
    )
|
|
def get_eval_dataset(
    cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs
) -> EvaluateBaseDataset:
    """Build the evaluation dataset described by a split config.

    Args:
        cfg_data_split: Split configuration. It must expose ``name``. A
            config named ``"mixed"`` must also expose ``dataset_list`` (an
            iterable of nested split configs). Any other config must expose
            ``filenames`` and ``dir`` and support ``**`` unpacking, because
            all of its entries are forwarded to the dataset constructor.
        base_data_dir: Root directory that ``cfg_data_split.dir`` is joined to.
        mode: Dataset mode, passed to the dataset constructor as ``mode``.
        **kwargs: Extra keyword arguments forwarded to every constructor.

    Returns:
        A single instance of the class registered in
        ``dataset_name_class_dict`` under ``cfg_data_split.name``.
        NOTE: for a ``"mixed"`` config the result is instead a *list* built
        by calling ``get_dataset`` on each entry of ``dataset_list``, which
        the return annotation does not reflect.

    Raises:
        AssertionError: If the config is ``"mixed"`` and ``mode`` is not
            ``DatasetMode.TRAIN``.
        NotImplementedError: If the name is neither ``"mixed"`` nor a key of
            ``dataset_name_class_dict``.
    """
    name = cfg_data_split.name

    if name == "mixed":
        # NOTE: this check is an `assert`, so it is skipped when Python runs
        # with -O; it is kept as-is so callers still see AssertionError.
        assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets."
        # Nested configs are resolved through `get_dataset`, which applies no
        # mode restriction of its own.
        return [
            get_dataset(sub_cfg, base_data_dir, mode, **kwargs)
            for sub_cfg in cfg_data_split.dataset_list
        ]

    dataset_class = dataset_name_class_dict.get(name)
    if dataset_class is None:
        raise NotImplementedError(
            f"Unknown dataset name {name!r}; expected 'mixed' or one of "
            f"{sorted(dataset_name_class_dict)}"
        )

    # The whole config is unpacked into the constructor in addition to the
    # explicit arguments, so the config must not define `mode`,
    # `filename_ls_path` or `dataset_dir` itself (that would be a TypeError).
    return dataset_class(
        mode=mode,
        filename_ls_path=cfg_data_split.filenames,
        dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir),
        **cfg_data_split,
        **kwargs,
    )
|
|