import os

from torch.utils.data import DataLoader
from lightning import LightningDataModule

from .mixed_dataset import MixedDataset


class MixedDataModule(LightningDataModule):
    """Lightning data module that serves ``MixedDataset`` splits.

    Datasets are read from ``<dataset_path>/train.json`` and
    ``<dataset_path>/test.json``. Note that the validation loader is backed
    by the *test* file as well: there is no separate validation split.
    """

    def __init__(
        self, bert_model, dataset_path, tool_capacity, batch_size, num_workers, seed
    ):
        """Store the configuration; no data is loaded until ``setup`` runs.

        Args:
            bert_model: Forwarded unchanged as the first argument of
                ``MixedDataset``.
            dataset_path: Directory containing ``train.json`` and
                ``test.json``.
            tool_capacity: Forwarded unchanged to ``MixedDataset``.
            batch_size: Batch size used by every dataloader.
            num_workers: Number of worker processes used by every dataloader.
            seed: Forwarded to ``MixedDataset`` as the ``seed`` keyword.
        """
        super().__init__()
        self.bert_model = bert_model
        self.dataset_path = dataset_path
        self.tool_capacity = tool_capacity
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

    def _build_dataset(self, split):
        """Return a ``MixedDataset`` for ``split``, read from ``<split>.json``."""
        return MixedDataset(
            self.bert_model,
            split,
            os.path.join(self.dataset_path, f"{split}.json"),
            self.tool_capacity,
            seed=self.seed,
        )

    def setup(self, stage=None):
        """Create the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` builds the train and validation datasets,
                ``"validate"`` builds the validation dataset, ``"test"``
                builds the test dataset, and ``None`` builds all of them.
                Any other value (e.g. ``"predict"``) builds nothing.
        """
        if stage in ("fit", None):
            self.train_dataset = self._build_dataset("train")
        # Lightning passes "validate" for ``Trainer.validate()``; without this
        # branch ``val_dataloader`` would fail on a missing ``val_dataset``.
        if stage in ("fit", "validate", None):
            # Validation intentionally reuses the "test" split/file.
            self.val_dataset = self._build_dataset("test")
        if stage in ("test", None):
            self.test_dataset = self._build_dataset("test")

    def train_dataloader(self):
        """Return the shuffled training loader; the last partial batch is dropped."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order, keeps the last batch)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Return the test loader (fixed order, keeps the last batch)."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )