| |
| import json |
| import torch |
| import random |
| from datasets import load_dataset |
| from transformers import BloomTokenizerFast |
| from torch.utils.data import Dataset, get_worker_info |
|
|
|
|
def cycled(itr):
    """Yield items from *itr* forever, restarting it each time it is exhausted.

    Unlike ``itertools.cycle`` this does not cache the items: the iterable is
    re-iterated from scratch on every pass, so a source that re-shuffles on
    iteration produces a fresh order each cycle.
    """
    while True:
        yield from itr
|
|
class C4X(Dataset):
    """Infinite map-style dataset over streaming C4 (English).

    Streams the ``c4/en`` split from the HuggingFace hub, tokenizes it with
    the Bloom tokenizer, and serves fixed-length windows of ``seq_len`` token
    ids as ``torch.LongTensor``s. The stream is shuffled and cycled lazily,
    once per worker process, so each DataLoader worker sees its own order.
    """

    def __init__(self, seq_len=512, split='train'):
        self.seq = seq_len
        self.ds = load_dataset(
            'c4',
            name='en',
            split=split,
            streaming=True,
        )
        self.tok = BloomTokenizerFast.from_pretrained('bigscience/bloomz-1b7')
        # Deferred until first access so that shuffle seeding happens inside
        # the DataLoader worker process, not in the parent.
        self.init = False

    def __len__(self):
        # Effectively infinite: the underlying stream is cycled, so any
        # index a sampler produces is valid.
        return 1_000_000_000

    def _init(self):
        """Shuffle and cycle the stream; idempotent, runs once per process."""
        if self.init:
            return
        wi = get_worker_info()
        # BUG FIX: get_worker_info() returns None in the main process
        # (num_workers=0 or direct iteration); the original dereferenced
        # wi.seed unconditionally and crashed with AttributeError there.
        seed = wi.seed if wi is not None else random.randrange(2**31)
        self.ds = cycled(
            self.ds.shuffle(
                seed=seed,
                buffer_size=10_000,
            )
        )
        self.init = True

    def _get_next(self):
        """Return the token ids of the next document in the stream."""
        self._init()
        obj = next(self.ds)['text']
        tkn = self.tok.encode(obj)
        return tkn

    def _get_full(self):
        """Concatenate EOS-terminated documents until at least ``self.seq``
        tokens are buffered, then return a random ``self.seq``-long window."""
        obj = []
        while len(obj) < self.seq:
            obj += self._get_next()
            obj.append(self.tok.eos_token_id)
        # randint is inclusive on both ends, so every full window is reachable.
        s = random.randint(0, len(obj) - self.seq)
        return obj[s:s + self.seq]

    def __getitem__(self, _):
        # The index is ignored: samples are drawn sequentially from the
        # (shuffled, cycled) stream regardless of the requested position.
        return torch.tensor(self._get_full())

    def decode(self, tkns):
        """Inverse of tokenization: token ids -> text."""
        return self.tok.decode(tkns)
|
|