File size: 2,629 Bytes
"""
Create a small test set from The Pile that is disjoint from the first 100K
training examples (same stream order and filter as GistAutoencoderDataset).

Usage:
    python -m data.create_test_set [--num_examples 10] [--output data/test_set_10.json]

By default the script reads config.MAX_SAMPLES and config.INPUT_SEQ_LENGTH, and
skips that many filtered examples plus a 1000-example safety margin, so the
test set comes after the training window in the stream.
"""
import argparse
import json
import os
from datasets import load_dataset
# Maps the short dataset name to its Hugging Face Hub repository id.
# Intended to mirror the mapping (and filter) used by dataset.py.
DATASET_HF_MAP = {
    "the_pile": "monology/pile-uncopyrighted",
}
def _collect_after_skip(rows, max_chars: int, skip: int, target: int):
    """Filter a row stream, skip the training window, then collect examples.

    A row passes the filter when its "text" value is a non-empty string of at
    most ``max_chars`` characters. The first ``skip`` passing rows are counted
    and discarded; the next ``target`` passing rows are collected.

    Args:
        rows: Iterable of mapping-like rows exposing ``.get("text")``.
        max_chars: Maximum allowed text length in characters.
        skip: Number of passing rows to discard before collecting.
        target: Number of examples to collect.

    Returns:
        A ``(collected, seen)`` tuple: ``collected`` is a list of
        ``{"text": ...}`` dicts and ``seen`` is how many passing rows were
        actually skipped (less than ``skip`` if the stream ended early).
    """
    collected = []
    seen = 0
    # Nothing requested: do not consume the stream at all. Without this guard
    # one example would be appended before the size check below runs.
    if target <= 0:
        return collected, seen
    for row in rows:
        text = row.get("text")
        if not text or not isinstance(text, str) or len(text) > max_chars:
            continue
        if seen < skip:
            seen += 1
            continue
        collected.append({"text": text})
        if len(collected) >= target:
            break
    return collected, seen


def main():
    """Build a held-out test set from the stream and write it as JSON.

    Parses the command-line options, streams the dataset, skips
    ``max_training_samples + 1000`` examples that pass the length filter, then
    collects ``num_examples`` further examples and saves them to ``--output``
    together with the skip count and the character limit used.

    ``--max_training_samples`` and ``--input_seq_length`` fall back to
    ``config.MAX_SAMPLES`` and ``config.INPUT_SEQ_LENGTH`` when omitted.
    """
    p = argparse.ArgumentParser(description="Create test set after first N training examples")
    p.add_argument("--num_examples", type=int, default=10, help="Number of test examples to collect")
    p.add_argument("--output", type=str, default="data/test_set_10.json", help="Output JSON path")
    p.add_argument("--max_training_samples", type=int, default=None, help="Skip this many training examples (default: config.MAX_SAMPLES)")
    p.add_argument("--input_seq_length", type=int, default=None, help="Same as training (default: config.INPUT_SEQ_LENGTH)")
    args = p.parse_args()

    # Import config lazily, and only once, so the script still runs without it
    # when both values are supplied on the command line.
    if args.max_training_samples is None or args.input_seq_length is None:
        import config as cfg
        if args.max_training_samples is None:
            args.max_training_samples = cfg.MAX_SAMPLES
        if args.input_seq_length is None:
            args.input_seq_length = cfg.INPUT_SEQ_LENGTH

    hf_id = DATASET_HF_MAP.get("the_pile", "monology/pile-uncopyrighted")
    # Character budget used by the length filter: four characters per unit of
    # input_seq_length.
    max_chars = args.input_seq_length * 4
    # Skip an extra 1000 filtered examples beyond the training window as a margin.
    skip = args.max_training_samples + 1000
    target = args.num_examples

    print(f"Loading stream: {hf_id} (train)")
    print(f"Filter: len(text) <= {max_chars} chars (input_seq_length={args.input_seq_length})")
    print(f"Skipping first {skip} examples that pass filter, then collecting {target} for test set...")

    ds = load_dataset(hf_id, split="train", streaming=True)
    collected, seen = _collect_after_skip(ds, max_chars, skip, target)

    if target > 0 and len(collected) < target:
        # The stream ended before enough examples were found; say so rather
        # than silently writing a short test set.
        print(f"Warning: stream ended early; only {len(collected)} of {target} requested examples were collected.")
    print(f"Collected {len(collected)} test examples (after skipping {seen} training examples).")

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    # Explicit encoding keeps the output independent of the platform default.
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"examples": collected, "num_training_skipped": skip, "max_chars": max_chars}, f, indent=2)
    print(f"Saved to {args.output}")
# Script entry point: run main() only when executed directly
# (e.g. `python -m data.create_test_set`), not when imported.
if __name__ == "__main__":
    main()
|