mini-agent / src /data /mixed_datamodule.py
qninhdt
cc
2c9e8bc
import os
from torch.utils.data import DataLoader
from lightning import LightningDataModule
from .mixed_dataset import MixedDataset
class MixedDataModule(LightningDataModule):
    """Lightning data module that serves ``MixedDataset`` splits.

    Datasets are read from ``<dataset_path>/train.json`` and
    ``<dataset_path>/test.json``. The validation loader is backed by the
    ``test`` split; this module does not read a separate validation file.
    """

    def __init__(
        self, bert_model, dataset_path, tool_capacity, batch_size, num_workers, seed
    ):
        """Store the configuration; no dataset is built until ``setup``.

        Args:
            bert_model: Forwarded unchanged as the first argument of
                ``MixedDataset``.
            dataset_path: Directory that contains ``train.json`` and
                ``test.json``.
            tool_capacity: Forwarded unchanged to ``MixedDataset``.
            batch_size: Batch size used by every dataloader.
            num_workers: Number of worker processes used by every dataloader.
            seed: Forwarded to ``MixedDataset`` as the ``seed`` keyword.
        """
        super().__init__()
        self.bert_model = bert_model
        self.dataset_path = dataset_path
        self.tool_capacity = tool_capacity
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

    def _build_dataset(self, split):
        """Create the ``MixedDataset`` for ``split`` from ``<split>.json``."""
        return MixedDataset(
            self.bert_model,
            split,
            os.path.join(self.dataset_path, f"{split}.json"),
            self.tool_capacity,
            seed=self.seed,
        )

    def _eval_dataloader(self, dataset) -> DataLoader:
        """Wrap ``dataset`` in a loader that keeps order and every sample."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def setup(self, stage=None):
        """Build the datasets required by ``stage``.

        ``"fit"`` builds the train and validation datasets, ``"validate"``
        builds only the validation dataset, ``"test"`` builds the test
        dataset, and ``None`` builds all of them. Any other stage builds
        nothing.

        Args:
            stage: Stage name passed by the trainer, or ``None`` when the
                method is called without an explicit stage.
        """
        if stage in ("fit", None):
            self.train_dataset = self._build_dataset("train")
        # The trainer passes "validate" for standalone validation runs, so
        # the validation dataset must not be tied to the "fit" stage only.
        if stage in ("fit", "validate", None):
            # Validation intentionally reads the test split (test.json).
            self.val_dataset = self._build_dataset("test")
        if stage in ("test", None):
            self.test_dataset = self._build_dataset("test")

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader; incomplete last batch is dropped."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the unshuffled loader over the validation dataset."""
        return self._eval_dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Return the unshuffled loader over the test dataset."""
        return self._eval_dataloader(self.test_dataset)