import torch
from torch.utils.data import DataLoader

import utils.register as R
from utils.logger import print_log

# dataset implementations exposed by this package
# (CoDesignDataset and ClusterResampler are not referenced directly below)
from .dataset_wrapper import MixDatasetWrapper
from .codesign import CoDesignDataset
from .resample import ClusterResampler


def create_dataset(config: dict):
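    """
    Build the train/valid/test datasets described by ``config``.

    Each of the keys ``train``, ``valid`` and ``test`` is optional; a missing
    key yields ``None`` for that split. A single config dict is constructed
    through the registry, while a list of config dicts is constructed item by
    item and combined into a ``MixDatasetWrapper``.

    Returns:
        list: ``[train_set, valid_set, test_set]``, where entries may be ``None``.
    """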
    splits = []
    for split_name in ['train', 'valid', 'test']:
        split_config = config.get(split_name, None)
        if split_config is None:
            # split not specified in the config
            splits.append(None)
            continue
        if isinstance(split_config, list):
            # a list of configs is merged into a single mixed dataset
            dataset = MixDatasetWrapper(
                *[R.construct(cfg) for cfg in split_config]
            )
        else:
            dataset = R.construct(split_config)
        splits.append(dataset)
    return splits
|
|
def create_dataloader(dataset, config: dict, n_gpu: int = 1, validation: bool = False):
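    """
    Wrap ``dataset`` in a ``DataLoader`` according to ``config``.

    Supported config keys: ``wrapper`` (an optional registry config applied to
    the dataset), ``batch_size``, ``val_batch_size`` (used when ``validation``
    is True), ``shuffle`` and ``num_workers``. With ``n_gpu > 1`` a
    ``DistributedSampler`` is used and the batch size is split across GPUs.
    """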
    if 'wrapper' in config:
        # optionally wrap the dataset through the registry (e.g. a resampling wrapper)
        dataset = R.construct(config['wrapper'], dataset=dataset)
    batch_size = config.get('batch_size', n_gpu)  # defaults to one sample per GPU
    if validation:
        batch_size = config.get('val_batch_size', batch_size)
    shuffle = config.get('shuffle', False)
    num_workers = config.get('num_workers', 4)
    collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
    if n_gpu > 1:
        # distributed training: each process loads its own shard of the data
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
        batch_size = batch_size // n_gpu
        print_log(f'Batch size on a single GPU: {batch_size}')
    else:
        sampler = None
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=(shuffle and sampler is None),  # the sampler handles shuffling itself
        collate_fn=collate_fn,
        sampler=sampler
    )
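

# A minimal usage sketch (illustrative only): the exact config schema for each
# dataset depends on what is registered in utils.register, so the keys and
# paths below are hypothetical placeholders.
#
#   config = {
#       'train': {'class': 'CoDesignDataset', 'data_dir': './data/train'},
#       'valid': {'class': 'CoDesignDataset', 'data_dir': './data/valid'},
#   }
#   train_set, valid_set, test_set = create_dataset(config)
#   train_loader = create_dataloader(
#       train_set, {'batch_size': 16, 'shuffle': True, 'num_workers': 4}, n_gpu=1
#   )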