"""FineWeb-Edu text data preparation and streaming. Dataset: https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu Subset: sample-10BT (~10B tokens) or sample-100BT (~100B tokens) Storage: ~6 GB (10BT) or ~60 GB (100BT) License: ODC-BY """ import math, torch, numpy as np from dataclasses import dataclass from typing import Optional import sys, os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from arbitor.config import SPECIAL_VOCAB @dataclass class FineWebConfig: subset: str = "sample-10BT" ctx: int = 2048 batch_size: int = 8 shuffle_buffer: int = 10000 split: str = "train" class FineWebStream: """Streaming text batches from FineWeb-Edu via HF datasets iterable.""" def __init__(self, cfg: FineWebConfig): self.cfg = cfg self._ds = None self._bos = SPECIAL_VOCAB['BOS'] self._eos = SPECIAL_VOCAB['EOS'] self._pad = SPECIAL_VOCAB['PAD'] def _lazy_init(self): if self._ds is not None: return from datasets import load_dataset ds = load_dataset("HuggingFaceFW/fineweb-edu", self.cfg.subset, split=self.cfg.split, streaming=True) self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42) def _tokenize(self, text: str) -> torch.LongTensor: raw = text.encode("utf-8") tokens = [self._bos] + list(raw) + [self._eos] return torch.tensor(tokens[:self.cfg.ctx + 1], dtype=torch.long) def batches(self): self._lazy_init() buf = [] for example in self._ds: tokens = self._tokenize(example["text"]) if tokens.numel() < 5: continue buf.append(tokens) if len(buf) >= self.cfg.batch_size: batch = buf[:self.cfg.batch_size] buf = buf[self.cfg.batch_size:] max_len = max(t.numel() for t in batch) padded = torch.stack([ torch.cat([t, t.new_full((max_len - t.numel(),), self._pad)]) for t in batch ]) yield padded[:, :-1].contiguous(), padded[:, 1:].contiguous() def num_samples(self) -> int: if self.cfg.subset == "sample-10BT": return 10_000_000_000 // self.cfg.ctx return 100_000_000_000 // self.cfg.ctx