| """FineWeb-Edu text data preparation and streaming. |
| |
| Dataset: https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu |
| Subset: sample-10BT (~10B tokens) or sample-100BT (~100B tokens) |
| Storage: ~6 GB (10BT) or ~60 GB (100BT) |
| License: ODC-BY |
| """ |
| import math, torch, numpy as np |
| from dataclasses import dataclass |
| from typing import Optional |
| import sys, os |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| from arbitor.config import SPECIAL_VOCAB |
|
|
|
|
@dataclass
class FineWebConfig:
    """Settings for streaming FineWeb-Edu through ``FineWebStream``.

    Field order and defaults are part of the constructor interface. Values
    that would otherwise only fail later, deep inside batching, are rejected
    up front in ``__post_init__``.
    """

    subset: str = "sample-10BT"    # HF config name of the FineWeb-Edu subset
    ctx: int = 2048                # rows are truncated to ctx + 1 ids -> ctx input positions
    batch_size: int = 8            # documents per yielded batch
    shuffle_buffer: int = 10000    # buffer size for the streaming shuffle
    split: str = "train"           # dataset split to stream

    def __post_init__(self):
        """Validate sizes.

        Raises:
            ValueError: if ``ctx < 4`` or ``batch_size < 1``.
        """
        # FineWebStream skips rows shorter than 5 ids, and a row holds at most
        # ctx + 1 ids, so a smaller ctx would never produce a single batch.
        if self.ctx < 4:
            raise ValueError(f"ctx must be >= 4, got {self.ctx}")
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
|
|
|
|
class FineWebStream:
    """Stream byte-level ``(inputs, targets)`` batches from FineWeb-Edu.

    Each document's text is UTF-8 encoded and the raw byte values are used
    directly as token ids, framed by the ``BOS``/``EOS`` ids taken from
    ``SPECIAL_VOCAB``. The HF dataset is opened lazily, in streaming mode,
    the first time ``batches()`` is called.
    """

    def __init__(self, cfg: "FineWebConfig"):
        self.cfg = cfg
        self._ds = None  # shuffled HF iterable dataset, created by _lazy_init()
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']

    def _lazy_init(self):
        """Open the streaming dataset and wrap it in a shuffle buffer (only once)."""
        if self._ds is not None:
            return
        # Deferred import: `datasets` is only needed once streaming starts.
        from datasets import load_dataset
        ds = load_dataset("HuggingFaceFW/fineweb-edu", self.cfg.subset,
                          split=self.cfg.split, streaming=True)
        # Fixed seed so the shuffle order is reproducible between runs
        # (assuming HF's streaming shuffle is deterministic for a given seed).
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _tokenize(self, text: str) -> torch.LongTensor:
        """Encode *text* as ``[BOS] + utf-8 bytes + [EOS]``, truncated to ``ctx + 1`` ids.

        ``ctx + 1`` ids yield ``ctx`` positions after the one-step
        input/target shift in ``batches()``. A document of ``ctx`` bytes or
        more keeps only its first ``ctx`` bytes and loses its ``EOS``.
        """
        ctx = self.cfg.ctx
        # Only the first ctx bytes can survive truncation. Every character is
        # at least one UTF-8 byte and encoding is prefix-preserving, so
        # text[:ctx] already contains all of them; this avoids encoding and
        # listing the whole of very long documents.
        raw = text[:ctx].encode("utf-8")[:ctx]
        tokens = [self._bos, *raw, self._eos]
        return torch.tensor(tokens[:ctx + 1], dtype=torch.long)

    def batches(self):
        """Yield ``(inputs, targets)`` LongTensor pairs of shape ``(batch_size, L)``.

        Rows shorter than 5 ids are skipped. Each batch is right-padded with
        the ``PAD`` id to its longest row, so ``L <= ctx``. It is then split
        into next-token inputs and targets by a one-position shift.

        Padding therefore also appears in ``targets``.
        NOTE(review): presumably the loss ignores the PAD id; confirm at the
        call site.

        A trailing partial batch at the end of the stream is not yielded.

        Raises:
            ValueError: if ``cfg.batch_size`` is less than 1.
        """
        from torch.nn.utils.rnn import pad_sequence

        batch_size = self.cfg.batch_size
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self._lazy_init()
        buf = []
        for example in self._ds:
            tokens = self._tokenize(example["text"])
            if tokens.numel() < 5:
                continue
            buf.append(tokens)
            if len(buf) < batch_size:
                continue
            # buf grows one row at a time, so here it holds exactly one batch.
            padded = pad_sequence(buf, batch_first=True, padding_value=self._pad)
            buf = []
            yield padded[:, :-1].contiguous(), padded[:, 1:].contiguous()

    def num_samples(self) -> int:
        """Return a rough count of ``ctx``-sized samples in the configured subset.

        The token budget is parsed from subset names of the form
        ``"sample-<N>BT"`` (N billion tokens). Any other name falls back to
        100B, matching the previous behaviour.

        NOTE(review): the subset size is presumably counted in the dataset's
        own tokens rather than the UTF-8 bytes used here. In addition,
        ``batches()`` yields at most one row per document. Treat this only as
        an order-of-magnitude estimate.
        """
        subset = self.cfg.subset
        billions = 100
        prefix, suffix = "sample-", "BT"
        if subset.startswith(prefix) and subset.endswith(suffix):
            digits = subset[len(prefix):len(subset) - len(suffix)]
            if digits.isdecimal():
                billions = int(digits)
        return billions * 1_000_000_000 // self.cfg.ctx
|
|