| import pytorch_lightning as pl |
| from torch.utils.data import DataLoader |
| import torch |
| from typing import Dict |
|
|
|
|
class DataModule(pl.LightningDataModule):
    """Lightning DataModule wrapping pre-built training and validation datasets.

    The datasets are supplied already constructed; this class only stores
    them and wraps them in ``DataLoader`` objects.

    Args:
        training_set (torch.utils.data.Dataset): Training dataset.
        validation_set (torch.utils.data.Dataset): Validation dataset.
        batch_size (int): Batch size passed to both data loaders.
            Keyword-only. Defaults to 1.

    Attributes:
        training_set (torch.utils.data.Dataset): Training dataset.
        validation_set (torch.utils.data.Dataset): Validation dataset.
        batch_size (int): Batch size used by both data loaders.
        train_ds (torch.utils.data.Dataset): Alias for the training dataset,
            (re)bound by ``setup``.
        val_ds (torch.utils.data.Dataset): Alias for the validation dataset,
            (re)bound by ``setup``.
    """

    def __init__(
        self,
        training_set: torch.utils.data.Dataset,
        validation_set: torch.utils.data.Dataset,
        *,
        batch_size: int = 1,
    ) -> None:
        super().__init__()
        self.training_set = training_set
        self.validation_set = validation_set
        self.batch_size = batch_size
        # Bind the aliases up front so the dataloader methods also work when
        # they are called directly, before setup() has run.
        self.train_ds = training_set
        self.val_ds = validation_set

    def setup(self, stage: "str | None" = None) -> None:
        """Bind ``train_ds`` / ``val_ds`` to the stored datasets.

        Args:
            stage: Accepted for interface compatibility; it is not used, and
                both aliases are bound regardless of its value.
        """
        self.train_ds = self.training_set
        self.val_ds = self.validation_set

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling ``DataLoader`` over the training dataset."""
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling ``DataLoader`` over the validation dataset."""
        return DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False)