| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| from torch.utils.data import ( |
| BatchSampler, |
| RandomSampler, |
| SequentialSampler, |
| ) |
|
|
|
|
class MixedBatchSampler(BatchSampler):
    """Sample each batch from one randomly selected source dataset.

    Every yielded batch holds indices from a single source dataset, which
    keeps the sampler compatible with datasets at different resolutions.
    Indices are shifted by the cumulative length of the preceding datasets,
    so they address a ``ConcatDataset`` built from ``src_dataset_ls`` in the
    same order.
    """

    def __init__(
        self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None
    ):
        """Build one batch sampler per source dataset and the selection weights.

        Args:
            src_dataset_ls: Sequence of sized datasets, in the same order as
                the ``ConcatDataset`` the yielded indices will be used with.
            batch_size: Number of indices per batch.
            drop_last: Whether each per-dataset sampler drops its last
                incomplete batch.
            shuffle: If true, each dataset is sampled with a ``RandomSampler``
                (without replacement); otherwise with a ``SequentialSampler``.
            prob: Optional per-dataset selection weights (one per dataset).
                Passed to ``torch.multinomial``, so they act as relative
                weights. Defaults to each dataset's share of the total
                number of batches.
            generator: Optional ``torch.Generator`` used for both shuffling
                and dataset selection.

        Raises:
            ValueError: If ``prob`` does not contain exactly one weight per
                source dataset.
        """
        # BatchSampler.__init__ is not called; the attributes are set directly.
        self.base_sampler = None
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.generator = generator

        self.src_dataset_ls = src_dataset_ls
        self.n_dataset = len(self.src_dataset_ls)

        # Dataset lengths and the index offset of each dataset inside the
        # concatenated dataset (exclusive prefix sum of the lengths).
        self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
        offsets = []
        running_total = 0
        for length in self.dataset_length:
            offsets.append(running_total)
            running_total += length
        self.cum_dataset_length = offsets

        # One batch sampler per source dataset.
        if self.shuffle:
            index_samplers = [
                RandomSampler(ds, replacement=False, generator=self.generator)
                for ds in self.src_dataset_ls
            ]
        else:
            index_samplers = [SequentialSampler(ds) for ds in self.src_dataset_ls]
        self.src_batch_samplers = [
            BatchSampler(
                sampler=index_sampler,
                batch_size=self.batch_size,
                drop_last=self.drop_last,
            )
            for index_sampler in index_samplers
        ]

        # Pre-draw one full pass of batches per dataset; __iter__ consumes
        # these pools and refills a pool once it is exhausted.
        self.raw_batches = [list(bs) for bs in self.src_batch_samplers]
        self.n_batches = [len(b) for b in self.raw_batches]
        self.n_total_batch = sum(self.n_batches)

        if prob is None:
            # Select each dataset in proportion to its number of batches.
            self.prob = torch.tensor(self.n_batches) / self.n_total_batch
        else:
            prob_tensor = torch.as_tensor(prob)
            if prob_tensor.numel() != self.n_dataset:
                raise ValueError(
                    f"prob must contain one weight per dataset: expected "
                    f"{self.n_dataset}, got {prob_tensor.numel()}"
                )
            # torch.multinomial rejects integer tensors, so integer weights
            # such as [1, 3] are converted; float inputs keep their dtype.
            if not prob_tensor.is_floating_point():
                prob_tensor = prob_tensor.float()
            self.prob = prob_tensor

    def __iter__(self):
        """Yield ``len(self)`` batches, each drawn from one source dataset.

        Yields:
            list(int): a batch of indices, corresponding to the ConcatDataset
            of src_dataset_ls.

        Raises:
            IndexError: If a dataset that produces no batches is selected.
        """
        for _ in range(self.n_total_batch):
            idx_ds = torch.multinomial(
                self.prob, 1, replacement=True, generator=self.generator
            ).item()

            # Refill the pool of the selected dataset once it runs empty.
            if not self.raw_batches[idx_ds]:
                self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])
                if not self.raw_batches[idx_ds]:
                    raise IndexError(
                        f"dataset {idx_ds} yields no batches "
                        f"(length={self.dataset_length[idx_ds]}, "
                        f"batch_size={self.batch_size}, "
                        f"drop_last={self.drop_last}) but was selected for sampling"
                    )

            # pop() takes from the end of the pool, so without shuffling the
            # batches of a dataset come out in reverse order.
            batch_raw = self.raw_batches[idx_ds].pop()

            # Shift dataset-local indices into the concatenated index space.
            shift = self.cum_dataset_length[idx_ds]
            yield [n + shift for n in batch_raw]

    def __len__(self):
        """Return the total number of batches over all source datasets."""
        return self.n_total_batch
|
|
|
|
| |
if __name__ == "__main__":
    # Small smoke test: mix three toy datasets and print the sampled batches.
    from torch.utils.data import ConcatDataset, DataLoader, Dataset

    class SimpleDataset(Dataset):
        """Toy dataset whose item at ``index`` is ``start + index``."""

        def __init__(self, start, length) -> None:
            super().__init__()
            self.start = start
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, index):
            return self.start + index

    # Distinct start values make it visible which dataset a batch came from.
    sources = [
        SimpleDataset(0, 10),
        SimpleDataset(200, 20),
        SimpleDataset(1000, 50),
    ]
    combined = ConcatDataset(sources)

    seeded_generator = torch.Generator().manual_seed(0)
    sampler = MixedBatchSampler(
        src_dataset_ls=sources,
        batch_size=4,
        drop_last=True,
        shuffle=False,
        prob=[0.6, 0.3, 0.1],
        generator=seeded_generator,
    )

    for batch in DataLoader(combined, batch_sampler=sampler):
        print(batch)
|
|