| import os |
| import json |
| import torch |
| import datasets |
| from torch.utils.data import DataLoader, Dataset |
| from transformers import PreTrainedTokenizerFast |
|
|
class CustomDataset(Dataset):
    """Map-style PyTorch dataset over a list of ``{"text": ...}`` records.

    Each record is tokenized lazily on access and padded/truncated to a
    fixed length so that a DataLoader can stack items directly.
    """

    def __init__(self, data, tokenizer, max_length=512):
        # data: sequence of dicts, each expected to carry a "text" key
        # (assumed from __getitem__ — confirm with the data producer).
        self.data = data
        # tokenizer: any callable with the HF tokenizer signature used below.
        self.tokenizer = tokenizer
        # Fixed sequence length every item is padded/truncated to.
        self.max_length = max_length

    def __len__(self):
        """Number of records in the underlying list."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record *idx* and return fixed-length tensors."""
        encoded = self.tokenizer(
            self.data[idx]["text"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # The tokenizer returns tensors with a leading batch dim of 1;
        # drop it so the DataLoader's default collation stacks cleanly.
        return {key: encoded[key].squeeze(0) for key in ("input_ids", "attention_mask")}
|
|
class DataLoaderHandler:
    """Load a JSON/JSONL text dataset and serve shuffled, tokenized batches."""

    def __init__(self, dataset_path, tokenizer_path, batch_size=8, max_length=512):
        # dataset_path: a .json file (list of records) or a .jsonl file
        # (one JSON record per line).
        self.dataset_path = dataset_path
        # tokenizer_path: a `tokenizers` JSON file; the tokenizer is loaded
        # eagerly here, so a bad path fails fast at construction time.
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        self.batch_size = batch_size
        self.max_length = max_length

    def load_dataset(self):
        """Parse the dataset file into a list of records.

        Returns:
            list: parsed records (each expected to be a dict with a "text"
            key, as consumed by CustomDataset).

        Raises:
            ValueError: if the path ends in neither ".json" nor ".jsonl".
        """
        if self.dataset_path.endswith(".json"):
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                return json.load(f)
        if self.dataset_path.endswith(".jsonl"):
            # Fix vs. original: open() inside a list comprehension leaked the
            # file handle, and json.loads crashed on blank lines (a trailing
            # newline is routine in .jsonl files). Use a context manager and
            # skip whitespace-only lines.
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                return [json.loads(line) for line in f if line.strip()]
        raise ValueError("Unsupported dataset format. Use JSON or JSONL.")

    def get_dataloader(self):
        """Build a shuffled DataLoader over the tokenized dataset."""
        records = self.load_dataset()
        dataset = CustomDataset(records, self.tokenizer, self.max_length)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
|
|
| if __name__ == "__main__": |
| dataset_path = "data/dataset.jsonl" |
| tokenizer_path = "tokenizer.json" |
| batch_size = 16 |
|
|
| data_loader_handler = DataLoaderHandler(dataset_path, tokenizer_path, batch_size) |
| dataloader = data_loader_handler.get_dataloader() |
|
|
| for batch in dataloader: |
| print(batch["input_ids"].shape, batch["attention_mask"].shape) |
| break |