Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Any, Callable, List, Tuple, Union | |
| from sapiens.registry import DATASETS | |
| from .base_dataset import BaseDataset | |
class CombinedDataset(BaseDataset):
    """Concatenate several datasets and expose them through one flat index.

    Each entry of ``datasets`` is a config built with ``DATASETS.build``.
    Global indices map onto the sub-datasets in the order they were given:
    ``[0, len(d0))`` belongs to the first dataset, the next ``len(d1)``
    indices to the second, and so on. ``prepare_data`` fetches the item from
    the owning sub-dataset and then applies this dataset's own pipeline.
    """

    def __init__(
        self,
        datasets: list,
        pipeline: Union[List[Union[dict, Callable]], None] = None,
        **kwargs,
    ):
        """Build every sub-dataset, cache their sizes and init the base class.

        Args:
            datasets: Configs, each passed to ``DATASETS.build``. Must yield
                at least one dataset.
            pipeline: Transforms forwarded to ``BaseDataset.__init__``.
                ``None`` (the default) means an empty pipeline.
            **kwargs: Forwarded unchanged to ``BaseDataset.__init__``.

        Raises:
            ValueError: If ``datasets`` yields no dataset.
        """
        # Fresh list per instance instead of a shared mutable default.
        if pipeline is None:
            pipeline = []

        self.datasets = [DATASETS.build(cfg) for cfg in datasets]
        # Validate before base initialisation; an ``assert`` would be
        # stripped under ``python -O``.
        if not self.datasets:
            raise ValueError("CombinedDataset requires at least one dataset.")

        # Sizes are cached once; ``_get_subset_index`` relies on them.
        self._lens = [len(dataset) for dataset in self.datasets]
        self._len = sum(self._lens)
        super().__init__(pipeline=pipeline, **kwargs)

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        return self._len

    def _get_subset_index(self, index: int) -> Tuple[int, int]:
        """Map a global index to ``(subset_index, sample_index)``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            ValueError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        if index >= len(self) or index < -len(self):
            raise ValueError(f"index {index} out of bounds for length {len(self)}.")
        if index < 0:
            index = index + len(self)

        # Peel off whole sub-datasets until the remainder falls inside one;
        # zero-length sub-datasets are skipped because ``index < 0`` is never true.
        for subset_index, length in enumerate(self._lens):
            if index < length:
                return subset_index, index
            index -= length
        # Only reachable if ``len(self)`` disagrees with ``sum(self._lens)``.
        raise IndexError(f"index {index} not covered by sub-dataset lengths {self._lens}.")

    def prepare_data(self, idx: int) -> Any:
        """Fetch sample ``idx`` and run it through ``self.pipeline.transforms``.

        Returns ``None`` when the fetched sample is ``None`` or when any
        transform returns ``None``; in that case the remaining transforms are
        skipped rather than being called with ``None``.
        """
        data_info = self.get_data_info(idx)
        if data_info is None:
            return None
        for transform in self.pipeline.transforms:
            data_info = transform(data_info)
            # Same "drop this sample" sentinel as above: stop here instead of
            # handing ``None`` to the next transform.
            if data_info is None:
                return None
        return data_info

    def get_data_info(self, idx: int) -> dict:
        """Return the item at global index ``idx`` from its owning sub-dataset.

        The item is whatever the sub-dataset's ``__getitem__`` returns.
        """
        subset_idx, sample_idx = self._get_subset_index(idx)
        return self.datasets[subset_idx][sample_idx]