File size: 4,555 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Download and tokenize a HuggingFace dataset to the ARB byte vocabulary.

Usage:
    python training/data/tokenize_from_hf.py \\
        --repo "HuggingFaceFW/fineweb-edu" \\
        --subset "sample-10BT" \\
        --split "train" \\
        --max-samples 10000 \\
        --output "training/data/fineweb-edu.pt"

The script streams text data, encodes each document as UTF-8 bytes framed
by BOS/EOS marker tokens, and saves the concatenated result as a tensor of
token indices compatible with ARB's VOCAB=297. Pass --verify (together with
--output) to load an existing file and report its token range instead.
"""
import os, sys, torch

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.config import VOCAB, SPECIAL_VOCAB


def encode_text(text: str) -> list[int]:
    """Encode text as UTF-8 byte tokens framed by BOS/EOS markers.

    Each UTF-8 byte of ``text`` becomes one token id equal to the byte value
    (0-255). The sequence is prefixed with ``SPECIAL_VOCAB['BOS']`` and
    suffixed with ``SPECIAL_VOCAB['EOS']``.

    Args:
        text: Input string.
    Returns:
        List of token ids: ``[BOS, b0, b1, ..., EOS]``.
    Raises:
        UnicodeEncodeError: if ``text`` contains code points that strict
            UTF-8 encoding rejects (e.g. lone surrogates).
    """
    payload = text.encode('utf-8')
    if VOCAB < 256:
        # Bytes >= VOCAB would have no token id; drop them, matching the
        # original per-byte check. When the vocab covers every byte value
        # (the normal case) this branch is skipped and no per-byte Python
        # work is done at all.
        payload = bytes(b for b in payload if b < VOCAB)
    # Unpacking a bytes object yields its integer byte values.
    return [SPECIAL_VOCAB['BOS'], *payload, SPECIAL_VOCAB['EOS']]


def download_dataset(repo: str, subset: str = None, split: str = "train",
                     max_samples: int = None, token: str = None,
                     text_col: str = "text"):
    """Stream a HuggingFace text dataset and tokenize it to byte tokens.

    The dataset is opened with ``streaming=True``. Each example's
    ``text_col`` value is encoded with ``encode_text`` (BOS + UTF-8 bytes +
    EOS), and all sequences are concatenated into one flat tensor. Examples
    whose ``text_col`` value is missing, empty, or not a ``str`` are skipped.

    Args:
        repo: HF dataset repo (e.g., "HuggingFaceFW/fineweb-edu")
        subset: Config/subset name (e.g., "sample-10BT"); falsy = not passed
        split: Dataset split ("train", "test", etc.)
        max_samples: Maximum number of examples read from the stream
            (None = all, 0 = none). Skipped examples count toward the limit.
        token: HF API token for private/gated datasets
        text_col: Column name containing text data
    Returns:
        1-D torch.Tensor (dtype long) of concatenated byte token sequences;
        empty if no usable example was found.
    Exits:
        Calls ``sys.exit(1)`` if the ``datasets`` package is not installed.
    """
    try:
        from datasets import load_dataset
    except ImportError:
        print("Install datasets: pip install datasets", file=sys.stderr)
        sys.exit(1)

    kwargs = {"split": split, "streaming": True}
    if subset:
        kwargs["name"] = subset
    if token:
        kwargs["token"] = token

    source = f"{repo}/{subset}" if subset else repo
    print(f"Loading {source} ({split})...")
    ds = load_dataset(repo, **kwargs)

    # `is not None` so an explicit 0 means "no samples" rather than silently
    # falling through to streaming the entire dataset.
    if max_samples is not None:
        ds = ds.take(max_samples)

    all_tokens = []
    count = 0
    for example in ds:
        text = example.get(text_col, "")
        # Type check first: truth-testing arbitrary column values (e.g.
        # array-likes) can raise, whereas isinstance never does.
        if not isinstance(text, str) or not text:
            continue
        all_tokens.extend(encode_text(text))
        count += 1
        if count % 1000 == 0:
            print(f"  Processed {count} samples, {len(all_tokens):,} tokens")

    if count == 0:
        # Typically a wrong --text-col or an empty split; surface it instead
        # of quietly returning an empty tensor.
        print(f"Warning: no non-empty string values found in column "
              f"'{text_col}'", file=sys.stderr)

    # One flat tensor of all documents back to back; no truncation/padding.
    tensor = torch.tensor(all_tokens, dtype=torch.long)
    print(f"Done: {count} samples → {len(tensor):,} byte tokens")
    return tensor


def save_tensor(tensor: torch.Tensor, path: str):
    """Write ``tensor`` to ``path`` using ``torch.save``.

    The parent directory is created first when missing; a bare filename
    (no directory part) is written relative to the current directory.
    """
    parent = os.path.dirname(path)
    target_dir = parent if parent else "."
    os.makedirs(target_dir, exist_ok=True)
    torch.save(tensor, path)
    print(f"Saved to {path}")


def load_tensor(path: str) -> torch.Tensor:
    """Load a tensor from a ``.pt`` file such as one written by ``save_tensor``.

    ``weights_only=True`` restricts unpickling to tensors and plain
    containers rather than arbitrary Python objects.
    """
    loaded = torch.load(path, weights_only=True)
    return loaded


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Tokenize HF dataset to ARB byte format")
    parser.add_argument("--repo", type=str, default="HuggingFaceFW/fineweb-edu",
                        help="HF dataset repo")
    parser.add_argument("--subset", type=str, default="sample-10BT",
                        help="Dataset config/subset")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--output", type=str, default="training/data/dataset.pt",
                        help="Output .pt file path")
    parser.add_argument("--max-samples", type=int, default=None,
                        help="Limit samples (default: all)")
    parser.add_argument("--token", type=str, default=None,
                        help="HF API token for private datasets")
    parser.add_argument("--text-col", type=str, default="text",
                        help="Text column name")
    parser.add_argument("--verify", action="store_true",
                        help="Load and verify an existing .pt file")
    args = parser.parse_args()

    if args.max_samples is not None and args.max_samples < 0:
        # Reject early with a clear message instead of an obscure error
        # from deep inside the streaming dataset.
        parser.error("--max-samples must be >= 0")

    if args.verify:
        # Inspect an existing output file instead of downloading. Exit code
        # is 0 only if the file is non-empty and every id is in [0, VOCAB).
        data = load_tensor(args.output)
        if data.numel() == 0:
            # min()/max() raise on empty tensors; report instead of crashing.
            print(f"Loaded {args.output}: 0 tokens (empty)", file=sys.stderr)
            sys.exit(1)
        lo, hi = data.min().item(), data.max().item()
        print(f"Loaded {args.output}: {data.numel():,} tokens, "
              f"min={lo}, max={hi}")
        print(f"  Valid range: 0-{VOCAB-1}")
        if lo < 0 or hi >= VOCAB:
            print(f"  ERROR: token ids outside 0-{VOCAB-1}", file=sys.stderr)
            sys.exit(1)
        sys.exit(0)

    tensor = download_dataset(args.repo, args.subset, args.split,
                              args.max_samples, args.token, args.text_col)
    save_tensor(tensor, args.output)