import torch
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

from pile import ThePile


class ThePileTokenized(IterableDataset):
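    """Wrap ThePile as a stream of fixed-length token tensors.

    Documents are tokenized, joined with EOS separators, and packed into
    chunks of max_length tokens. When repeat_factor > 1, successive chunks
    overlap, so each token is yielded roughly repeat_factor times.
    """
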
    def __init__(
        self,
        base_dataset: ThePile,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 1024,
        repeat_factor: float = 1.0,
    ):
        self.pile = base_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.repeat_factor = repeat_factor

    def __iter__(self):
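        # Concatenate tokenized documents into a rolling buffer, separating
        # them with EOS, and emit fixed-length windows of max_length tokens.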
        buffer = []
        for sample in self.pile:
            tokens = self.tokenizer.encode(sample["text"])
            buffer += [self.tokenizer.eos_token_id] + tokens
            while len(buffer) >= self.max_length:
                yield torch.tensor(buffer[: self.max_length])
                # Advance by max_length / repeat_factor tokens, so that
                # successive windows overlap whenever repeat_factor > 1.
                buffer = buffer[int(self.max_length / self.repeat_factor) :]


if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from transformers import GPT2Tokenizer

    dataset = ThePileTokenized(
        ThePile("train"),
        GPT2Tokenizer.from_pretrained("gpt2"),
        max_length=2048,
        repeat_factor=4 / 3,
    )
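    # With an iterable-style dataset the DataLoader cannot shuffle or
    # resample; ordering comes entirely from the underlying stream.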
    dataloader = DataLoader(
        dataset,
        batch_size=1,
    )
    # Drain the loader to benchmark end-to-end tokenization throughput.
    for batch in tqdm(dataloader, smoothing=0.01):
        pass