| from typing import Dict, List, Optional, NoReturn |
| import torch |
| import lightning.pytorch as pl |
| from torch.utils.data import DataLoader |
| from data.audiotext_dataset import AudioTextDataset |
|
|
|
|
class DataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_dataset: object,
        batch_size: int,
        num_workers: int
    ):
        r"""Data module wrapping a single training dataset.

        To get one batch of data:

        code-block:: python

            data_module.setup()

            for batch_data_dict in data_module.train_dataloader():
                print(batch_data_dict.keys())
                break

        Args:
            train_dataset: Dataset object yielding dicts with 'text',
                'waveform', and 'modality' keys (consumed by ``collate_fn``)
            batch_size: int, number of samples per mini-batch
            num_workers: int, number of DataLoader worker processes
        """
        super().__init__()
        self._train_dataset = train_dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        # Module-level collate_fn groups samples by modality.
        self.collate_fn = collate_fn

    def prepare_data(self):
        # No download / one-time preparation step is needed.
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device; assign the dataset used by train_dataloader.

        Note: annotated ``-> None`` (the original ``NoReturn`` means
        "never returns normally", which is wrong for this method).
        """
        self.train_dataset = self._train_dataset

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Build and return the shuffled training DataLoader."""
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=False,
            shuffle=True
        )

        return train_loader

    def val_dataloader(self):
        # No validation set is used.
        pass

    def test_dataloader(self):
        # No test set is used.
        pass

    def teardown(self, stage: Optional[str] = None) -> None:
        # Accept the ``stage`` argument that Lightning passes as
        # ``teardown(stage=...)``; the original zero-argument signature
        # raised TypeError when the Trainer invoked it. The default keeps
        # existing direct ``teardown()`` callers working.
        pass
|
|
|
|
def collate_fn(list_data_dict):
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {
                'text': 'a sound of dog',
                'waveform': (1, samples),
                'modality': 'audio_text'
            }
            ...
        ]

    Returns:
        data_dict: e.g.
            'audio_text': {
                'text': ['a sound of dog', ...]
                'waveform': (batch_size, 1, samples)
            }
    """
    # Keep only the audio-text samples; other modalities are dropped.
    at_list_data_dict = [
        data_dict for data_dict in list_data_dict
        if data_dict['modality'] == 'audio_text'
    ]

    at_data_dict = {}

    if at_list_data_dict:
        for key in at_list_data_dict[0].keys():
            # NOTE: the original comprehension reused ``at_data_dict`` as
            # its loop variable, shadowing the output dict being built —
            # use a distinct name. The old ``key == 'text'`` branch was a
            # redundant identity copy and is dropped (behavior unchanged).
            values = [sample[key] for sample in at_list_data_dict]
            if key == 'waveform':
                # Stack (1, samples) tensors into (batch_size, 1, samples);
                # assumes all waveforms share the same length — TODO confirm
                # against AudioTextDataset.
                values = torch.stack(values)
            at_data_dict[key] = values

    data_dict = {
        'audio_text': at_data_dict
    }

    return data_dict