File size: 2,377 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""FineWeb-Edu text data preparation and streaming.

Dataset: https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu
Subset: sample-10BT (~10B tokens) or sample-100BT (~100B tokens)
Storage: ~6 GB (10BT) or ~60 GB (100BT)
License: ODC-BY
"""
import math, torch, numpy as np
from dataclasses import dataclass
from typing import Optional
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB


@dataclass
class FineWebConfig:
    """Settings consumed by :class:`FineWebStream`."""

    # HF dataset config name passed to `load_dataset`, e.g. "sample-10BT" or
    # "sample-100BT" (see module docstring).
    subset: str = "sample-10BT"
    # Max positions per input/target row: each document is truncated to
    # ctx + 1 byte-level token ids before the one-step input/target shift.
    ctx: int = 2048
    # Number of documents stacked into each yielded batch.
    batch_size: int = 8
    # `buffer_size` of the streaming dataset's shuffle buffer.
    shuffle_buffer: int = 10000
    # Dataset split to stream.
    split: str = "train"


class FineWebStream:
    """Streaming text batches from FineWeb-Edu via HF datasets iterable.

    Documents are tokenized at the byte level: each text becomes
    ``[BOS] + <utf-8 bytes> + [EOS]``, truncated to ``cfg.ctx + 1`` ids, so the
    one-step-shifted input/target pair each hold at most ``cfg.ctx`` positions.
    BOS, EOS and PAD ids come from ``SPECIAL_VOCAB``.
    NOTE(review): assumes those special ids lie outside the byte range 0-255
    and so cannot collide with text bytes — confirm against arbitor.config.
    """

    # Documents that tokenize to fewer ids than this are skipped.
    _MIN_TOKENS = 5

    def __init__(self, cfg: "FineWebConfig"):
        self.cfg = cfg
        self._ds = None  # streaming dataset, opened on first use by _lazy_init
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']

    def _lazy_init(self):
        """Open the shuffled streaming dataset once (defers importing `datasets`)."""
        if self._ds is not None:
            return
        from datasets import load_dataset
        ds = load_dataset("HuggingFaceFW/fineweb-edu", self.cfg.subset,
                          split=self.cfg.split, streaming=True)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _tokenize(self, text: str) -> torch.LongTensor:
        """Return ``[BOS] + utf-8 bytes of text + [EOS]`` truncated to ``ctx + 1`` ids.

        Documents with at least ``ctx`` bytes lose their EOS to the truncation.
        """
        # After the leading BOS only `ctx` byte ids can survive the final cut,
        # so trim the bytes first instead of expanding the whole document.
        raw = text.encode("utf-8")[:self.cfg.ctx]
        tokens = [self._bos, *raw, self._eos]
        return torch.tensor(tokens[:self.cfg.ctx + 1], dtype=torch.long)

    def _collate(self, batch):
        """Right-pad ``batch`` with PAD to equal length; return shifted (inputs, targets)."""
        max_len = max(t.numel() for t in batch)
        padded = torch.stack([
            torch.cat([t, t.new_full((max_len - t.numel(),), self._pad)])
            for t in batch
        ])
        return padded[:, :-1].contiguous(), padded[:, 1:].contiguous()

    def batches(self, epoch: Optional[int] = None):
        """Yield ``(inputs, targets)`` int64 tensors of shape ``(batch_size, L)``.

        ``L`` is one less than the longest tokenized document in the batch
        (so ``L <= ctx``); ``targets`` is ``inputs`` shifted left by one. Shorter
        documents are right-padded with PAD in both tensors, so the caller
        presumably needs to mask PAD targets in the loss. Documents with fewer
        than ``_MIN_TOKENS`` ids are skipped, and a trailing partial batch at
        the end of the stream is dropped.

        Args:
            epoch: If given, passed to the dataset's ``set_epoch`` so the
                shuffle order differs between passes; with ``None`` every call
                replays the same fixed-seed order.

        Raises:
            ValueError: if ``cfg.batch_size`` is smaller than 1 (raised on the
                first ``next()`` since this is a generator).
        """
        batch_size = self.cfg.batch_size
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self._lazy_init()
        if epoch is not None:
            self._ds.set_epoch(epoch)
        buf = []
        for example in self._ds:
            tokens = self._tokenize(example["text"])
            if tokens.numel() < self._MIN_TOKENS:
                continue
            buf.append(tokens)
            if len(buf) == batch_size:
                yield self._collate(buf)
                buf = []

    def num_samples(self) -> int:
        """Rough sample count: the subset's nominal token budget // ``ctx``.

        The budget is parsed from ``sample-<N>BT`` subset names (N billion
        tokens); any other subset falls back to 100B. The nominal figure comes
        from the subset name, not from this stream's byte-level tokenization,
        so treat the result only as an estimate.
        """
        import re
        match = re.fullmatch(r"sample-(\d+)BT", self.cfg.subset)
        billions = int(match.group(1)) if match else 100
        return billions * 1_000_000_000 // self.cfg.ctx