| import json |
| import time |
| import random |
| from typing import Literal |
|
|
| import requests |
| import zstandard as zstd |
| from torch.utils.data import IterableDataset, get_worker_info |
|
|
|
|
# Names of the dataset splits that can be requested.
Subset = Literal["train", "val", "test"]

# Download locations of the ``.jsonl.zst`` shards, keyed by split name.
# The training split is spread over 30 shards numbered 00 through 29;
# validation and test are a single file each.
URLs = {
    "val": ["https://the-eye.eu/public/AI/pile/val.jsonl.zst"],
    "test": ["https://the-eye.eu/public/AI/pile/test.jsonl.zst"],
    "train": [
        f"https://the-eye.eu/public/AI/pile/train/{shard:02d}.jsonl.zst"
        for shard in range(30)
    ],
}
|
|
|
|
def _read_line_from_stream(reader, initial_line="", buffer_size=4096):
    """Split the next newline-terminated line off a binary stream.

    Args:
        reader: Object whose ``read(n)`` returns up to ``n`` bytes of UTF-8
            encoded text, and an empty value once the stream is exhausted.
        initial_line: Text left over from the previous call. It is searched
            for a newline first; stream data is appended only if needed.
        buffer_size: Number of bytes requested per ``read`` call.

    Returns:
        A two-item list ``[line, rest]``: the text before the first newline
        and everything after it. Pass ``rest`` back as ``initial_line`` on
        the next call.

    Raises:
        StopIteration: The stream ended before a newline was found. (Kept
            as the end-of-data signal because callers catch exactly this.)
        UnicodeDecodeError: The data is not valid UTF-8, or the stream ended
            in the middle of a multi-byte character.
    """
    import codecs  # local import so the block does not depend on file-level edits

    line = initial_line
    decoder = codecs.getincrementaldecoder("utf-8")()
    # Only touch the stream while the carried-over text holds no complete
    # line. Reading unconditionally lets the carry-over grow without bound
    # and throws away already-buffered lines once the stream reaches EOF.
    while "\n" not in line:
        chunk = reader.read(buffer_size)
        if not chunk:
            raise StopIteration
        line += decoder.decode(chunk)
        # A chunk may end part-way through a multi-byte character. Those
        # pending bytes cannot be handed back inside the returned text, so
        # pull single bytes until the character is complete.
        while decoder.getstate()[0]:
            extra = reader.read(1)
            # At EOF ``final=True`` makes the decoder raise for the
            # truncated character instead of looping forever.
            line += decoder.decode(extra, final=not extra)
    return line.split("\n", 1)
|
|
|
|
def _line_streamer(reader, buffer_size=4096):
    """Yield the lines of a UTF-8 byte stream one at a time, as ``str``.

    The stream is consumed in ``buffer_size``-byte reads and split on the
    newline byte *before* decoding, so a multi-byte character that straddles
    two reads is never decoded in halves. Every complete line found in a
    read is yielded, including those still buffered when the stream ends.

    Args:
        reader: Object whose ``read(n)`` returns up to ``n`` bytes and an
            empty value at end of stream.
        buffer_size: Number of bytes requested per ``read`` call.

    Yields:
        Each line without its trailing ``"\\n"``. A non-empty final line
        that lacks a terminating newline is yielded as well.

    Raises:
        UnicodeDecodeError: A line is not valid UTF-8.
    """
    fragments = []  # raw chunks of the line currently being assembled
    while True:
        chunk = reader.read(buffer_size)
        if not chunk:
            break
        fragments.append(chunk)
        if b"\n" not in chunk:
            # Still inside one long line; defer joining until it ends so a
            # large record is not re-scanned on every read.
            continue
        # The last piece after splitting is the start of the next line
        # (empty when the chunk ended exactly on a newline).
        *complete, tail = b"".join(fragments).split(b"\n")
        fragments = [tail] if tail else []
        for raw_line in complete:
            yield raw_line.decode("utf-8")
    if fragments:
        # Unterminated trailing line.
        yield b"".join(fragments).decode("utf-8")
|
|
|
|
class ThePile(IterableDataset):
    """Endless stream of JSON records downloaded from the shards in ``URLs``.

    Iteration never terminates: once every shard of the chosen split has
    been streamed, the shard list is reshuffled and streamed again. Shards
    are decompressed on the fly from the HTTP response, so nothing is
    written to disk.

    NOTE(review): the shard list is not partitioned between DataLoader
    workers; each worker streams every shard, only in a different order.
    """

    # Number of bytes requested from the decompressor per read.
    TEXT_BUFFER_SIZE = 4096

    def __init__(self, subset: Subset):
        """Choose which split to stream.

        Args:
            subset: Key into ``URLs`` (``"train"``, ``"val"`` or ``"test"``).
        """
        super().__init__()
        self.subset = subset

    def __iter__(self):
        """Yield one decoded JSON object per line of each shard, forever.

        Raises:
            KeyError: ``self.subset`` is not a key of ``URLs``.
            requests.HTTPError: A shard request returned an error status.
        """
        urls = URLs[self.subset].copy()
        # Worker identity cannot change while this iterator is alive, so it
        # is looked up once. Inside a DataLoader worker the shuffle is
        # seeded with the worker id; otherwise the generator is unseeded.
        wi = get_worker_info()
        seed = wi.id if wi is not None else None
        while True:
            rnd = random.Random(seed)
            rnd.shuffle(urls)
            for url in urls:
                # The context manager releases the connection when the shard
                # is finished or the consumer abandons the iterator.
                with requests.get(url, stream=True) as r:
                    # Fail loudly on 4xx/5xx instead of handing an HTML
                    # error page to the zstd decoder.
                    r.raise_for_status()
                    with zstd.ZstdDecompressor().stream_reader(r.raw) as reader:
                        for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
                            yield json.loads(line)
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: stream the training split and let tqdm report
    # throughput. The dataset never ends, so stop it with Ctrl-C.
    from tqdm import tqdm

    dataset = ThePile("train")
    progress = tqdm(dataset, smoothing=0.01)
    for data in progress:
        continue
| |