"""Download and tokenize a HuggingFace dataset to the ARB byte vocabulary.

Usage:
    python training/data/tokenize_from_hf.py \\
        --repo "HuggingFaceFW/fineweb-edu" \\
        --subset "sample-10BT" \\
        --split "train" \\
        --max-samples 10000 \\
        --output "training/data/fineweb-edu.pt"

    # Inspect a previously written file instead of downloading:
    python training/data/tokenize_from_hf.py --verify \\
        --output "training/data/fineweb-edu.pt"

The script streams text data from the Hub and encodes each sample as UTF-8
bytes (token ids 0-255) wrapped in BOS/EOS special tokens. It saves the
concatenation of all samples as a 1-D tensor of token indices compatible
with ARB's VOCAB=297.
"""
| import os, sys, torch |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.config import VOCAB, SPECIAL_VOCAB |
|
|
|
|
def encode_text(text: str) -> list[int]:
    """Encode *text* as UTF-8 byte tokens framed by BOS/EOS markers.

    Each UTF-8 byte becomes one token whose id equals the byte value
    (0-255). The sequence is prefixed with ``SPECIAL_VOCAB['BOS']`` and
    suffixed with ``SPECIAL_VOCAB['EOS']``.

    Args:
        text: Input string.

    Returns:
        ``[BOS, b0, b1, ..., EOS]`` as a list of ints.

    Raises:
        UnicodeEncodeError: if *text* contains code points (e.g. lone
            surrogates) that strict UTF-8 encoding rejects.
    """
    data = text.encode("utf-8")
    if VOCAB < 256:
        # Defensive only: reachable solely if the vocabulary cannot hold
        # every byte value. Out-of-range bytes are dropped, as before.
        data = bytes(b for b in data if b < VOCAB)
    # Unpacking a bytes object yields its ints at C speed. This replaces the
    # previous per-byte Python loop, whose `byte < VOCAB` test was always
    # true whenever VOCAB >= 256.
    return [SPECIAL_VOCAB["BOS"], *data, SPECIAL_VOCAB["EOS"]]
|
|
|
|
def download_dataset(repo: str, subset: str = None, split: str = "train",
                     max_samples: int = None, token: str = None,
                     text_col: str = "text"):
    """Stream a HuggingFace text dataset and tokenize it to byte tokens.

    Args:
        repo: HF dataset repo (e.g., "HuggingFaceFW/fineweb-edu").
        subset: Config/subset name (e.g., "sample-10BT"); falsy means the
            dataset's default config.
        split: Dataset split ("train", "test", etc.).
        max_samples: Maximum number of raw examples read from the stream
            (None = all, 0 = none). Examples whose ``text_col`` value is
            missing, empty or not a string count toward this limit but are
            skipped.
        token: HF API token for private/gated datasets.
        text_col: Column name containing text data.

    Returns:
        1-D ``torch.long`` tensor holding the ``encode_text`` sequences of
        all accepted examples, concatenated in stream order (may be empty).

    Raises:
        ValueError: if ``max_samples`` is negative.
        SystemExit: if the ``datasets`` package is not installed.
    """
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples must be >= 0, got {max_samples}")

    try:
        from datasets import load_dataset
    except ImportError:
        print("Install datasets: pip install datasets")
        sys.exit(1)

    kwargs = {"split": split, "streaming": True}
    if subset:
        kwargs["name"] = subset
    if token:
        kwargs["token"] = token

    # Avoid printing a dangling "repo/" when no subset is given.
    label = f"{repo}/{subset}" if subset else repo
    print(f"Loading {label} ({split})...")
    ds = load_dataset(repo, **kwargs)

    # `is not None` (not truthiness) so that max_samples=0 yields nothing
    # instead of silently streaming the entire dataset.
    if max_samples is not None:
        ds = ds.take(max_samples)

    all_tokens = []
    count = 0
    skipped = 0
    for example in ds:
        text = example.get(text_col, "")
        if not text or not isinstance(text, str):
            skipped += 1
            continue
        all_tokens.extend(encode_text(text))
        count += 1
        if count % 1000 == 0:
            print(f" Processed {count} samples, {len(all_tokens):,} tokens")

    if skipped and not count:
        # Every example was rejected. A misspelled text_col is one possible
        # cause, so warn instead of silently returning an empty tensor.
        print(f"Warning: no non-empty string values found in column "
              f"'{text_col}' ({skipped} examples skipped)", file=sys.stderr)

    tensor = torch.tensor(all_tokens, dtype=torch.long)
    print(f"Done: {count} samples → {len(tensor):,} byte tokens")
    return tensor
|
|
|
|
def save_tensor(tensor: torch.Tensor, path: str):
    """Persist a token tensor to ``path`` using ``torch.save``.

    The directory portion of ``path`` is created first (no error if it
    already exists). A bare filename resolves to the current directory.
    A confirmation line is printed once the file is written.
    """
    directory = os.path.dirname(path) or os.curdir
    os.makedirs(directory, exist_ok=True)

    torch.save(tensor, path)
    print(f"Saved to {path}")
|
|
|
|
def load_tensor(path: str) -> torch.Tensor:
    """Read back a tensor file such as one written by ``save_tensor``.

    Loading uses ``weights_only=True``, which limits unpickling to tensors
    and plain containers, so arbitrary pickled objects are refused.
    """
    restored = torch.load(path, weights_only=True)
    return restored
|
|
|
|
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Tokenize HF dataset to ARB byte format")
    parser.add_argument("--repo", type=str, default="HuggingFaceFW/fineweb-edu",
                        help="HF dataset repo")
    parser.add_argument("--subset", type=str, default="sample-10BT",
                        help="Dataset config/subset")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--output", type=str, default="training/data/dataset.pt",
                        help="Output .pt file path")
    # Underscore spellings are accepted as aliases so invocations written as
    # --max_samples / --text_col (as in older usage docs) keep working.
    parser.add_argument("--max-samples", "--max_samples", dest="max_samples",
                        type=int, default=None,
                        help="Limit samples (default: all)")
    parser.add_argument("--token", type=str, default=None,
                        help="HF API token for private datasets")
    parser.add_argument("--text-col", "--text_col", dest="text_col",
                        type=str, default="text",
                        help="Text column name")
    parser.add_argument("--verify", action="store_true",
                        help="Load and verify an existing .pt file")
    args = parser.parse_args()

    if args.verify:
        data = load_tensor(args.output)
        if data.numel() == 0:
            # min()/max() raise on empty tensors; report the problem instead
            # of crashing with a traceback.
            print(f"Loaded {args.output}: 0 tokens (empty file)", file=sys.stderr)
            sys.exit(1)
        lo, hi = data.min().item(), data.max().item()
        print(f"Loaded {args.output}: {len(data):,} tokens, "
              f"min={lo}, max={hi}")
        print(f" Valid range: 0-{VOCAB-1}")
        # Actually enforce the range so --verify can gate a pipeline via the
        # exit status.
        if lo < 0 or hi >= VOCAB:
            print(f" ERROR: tokens outside valid range 0-{VOCAB-1}", file=sys.stderr)
            sys.exit(1)
        sys.exit(0)

    tensor = download_dataset(args.repo, args.subset, args.split,
                              args.max_samples, args.token, args.text_col)
    save_tensor(tensor, args.output)
|
|