File size: 2,629 Bytes
db16238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
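"""
Create a small test set from The Pile that is disjoint from the training
window: the first config.MAX_SAMPLES examples (e.g. 100K) of the same stream,
read in the same order and with the same length filter as GistAutoencoderDataset.

Usage:
    python -m data.create_test_set [--num_examples 10] [--output data/test_set_10.json]
        [--max_training_samples N] [--input_seq_length N]

Defaults for --max_training_samples and --input_seq_length come from
config.MAX_SAMPLES and config.INPUT_SEQ_LENGTH. The script skips that many
filtered examples plus a 1000-example safety margin before collecting, so,
provided the filter matches the training dataset's, the test set comes after
the training window in the stream.
"""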

import argparse
import json
import os
import sys

from datasets import load_dataset

# Hugging Face hub IDs keyed by short dataset name; mirrors the mapping in
# dataset.py (main() likewise mirrors that file's length filter).
DATASET_HF_MAP = dict(the_pile="monology/pile-uncopyrighted")


def main():
    """Stream The Pile, skip the training window, and save a small held-out test set.

    Command-line arguments:
        --num_examples: number of test examples to collect (must be >= 0).
        --output: path of the JSON file to write; parent directories are created.
        --max_training_samples: size of the training window to skip
            (default: config.MAX_SAMPLES).
        --input_seq_length: sequence length used to derive the character
            filter (default: config.INPUT_SEQ_LENGTH).
        --skip_margin: extra filtered examples skipped beyond the training
            window as a safety buffer (default: 1000, must be >= 0).

    Writes JSON of the form
    ``{"examples": [{"text": ...}, ...], "num_training_skipped": int, "max_chars": int}``.
    Prints a warning to stderr if the stream ends before enough examples are found.
    """
    p = argparse.ArgumentParser(description="Create test set after first N training examples")
    p.add_argument("--num_examples", type=int, default=10, help="Number of test examples to collect")
    p.add_argument("--output", type=str, default="data/test_set_10.json", help="Output JSON path")
    p.add_argument("--max_training_samples", type=int, default=None, help="Skip this many training examples (default: config.MAX_SAMPLES)")
    p.add_argument("--input_seq_length", type=int, default=None, help="Same as training (default: config.INPUT_SEQ_LENGTH)")
    p.add_argument("--skip_margin", type=int, default=1000, help="Extra examples to skip beyond the training window as a safety buffer")
    args = p.parse_args()

    if args.num_examples < 0:
        p.error("--num_examples must be >= 0")
    if args.skip_margin < 0:
        p.error("--skip_margin must be >= 0")

    # Import config lazily and only once, so fully explicit CLI values work
    # even where the config module is unavailable.
    if args.max_training_samples is None or args.input_seq_length is None:
        import config as cfg
        if args.max_training_samples is None:
            args.max_training_samples = cfg.MAX_SAMPLES
        if args.input_seq_length is None:
            args.input_seq_length = cfg.INPUT_SEQ_LENGTH

    hf_id = DATASET_HF_MAP.get("the_pile", "monology/pile-uncopyrighted")
    # Character cap derived from the token length (~4 chars per token).
    # NOTE(review): presumably identical to the filter in dataset.py — keep in sync.
    max_chars = args.input_seq_length * 4
    skip = args.max_training_samples + args.skip_margin
    target = args.num_examples

    print(f"Loading stream: {hf_id} (train)")
    print(f"Filter: len(text) <= {max_chars} chars (input_seq_length={args.input_seq_length})")
    print(f"Skipping first {skip} examples that pass filter, then collecting {target} for test set...")

    collected = []
    seen = 0
    # With target == 0 there is nothing to collect, so the stream is never
    # opened. (Checking the count only after appending would otherwise
    # collect one example.)
    if target > 0:
        ds = load_dataset(hf_id, split="train", streaming=True)
        for row in ds:
            x = row.get("text")
            # isinstance first so truthiness is only evaluated on strings.
            if not isinstance(x, str) or not x or len(x) > max_chars:
                continue
            # Only filtered examples count toward the skip window.
            # NOTE(review): assumes training counts its window the same way.
            if seen < skip:
                seen += 1
                continue
            collected.append({"text": x})
            if len(collected) >= target:
                break

    print(f"Collected {len(collected)} test examples (after skipping {seen} training examples).")
    if len(collected) < target:
        print(
            f"WARNING: stream ended early; collected only {len(collected)} of {target} "
            f"requested examples (skipped {seen} of {skip}).",
            file=sys.stderr,
        )

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"examples": collected, "num_training_skipped": skip, "max_chars": max_chars}, f, indent=2)

    print(f"Saved to {args.output}")


if __name__ == "__main__":
    main()