"""
Create a small test set from The Pile that is disjoint from the training window
(same stream order and filter as GistAutoencoderDataset).

Usage:
    python -m data.create_test_set [--num_examples 10] [--output data/test_set_10.json]

Uses config.MAX_SAMPLES and config.INPUT_SEQ_LENGTH (unless overridden on the
command line) so the test set is guaranteed to come after the training window
in the stream.
"""

import argparse
import json
import os
from itertools import islice

from datasets import load_dataset

# Same mapping and filter as dataset.py
DATASET_HF_MAP = {"the_pile": "monology/pile-uncopyrighted"}

# Extra filtered examples skipped past the training window, as a safety margin
# so the test set cannot overlap training even if the training count drifts.
DEFAULT_SKIP_MARGIN = 1000


def _passes_filter(text, max_chars):
    """Return True if *text* is a non-empty str of at most *max_chars* characters."""
    return isinstance(text, str) and 0 < len(text) <= max_chars


def _collect_examples(rows, max_chars, skip, target):
    """Skip the first *skip* filtered texts in *rows*, then take up to *target* more.

    Args:
        rows: iterable of mapping-like rows; only the "text" key is read.
        max_chars: maximum allowed text length (inclusive).
        skip: number of filter-passing texts to discard first.
        target: maximum number of filter-passing texts to return after the skip.

    Returns:
        (examples, skipped): ``examples`` is a list of ``{"text": ...}`` dicts
        (fewer than *target* only if *rows* ran out), and ``skipped`` is how many
        filter-passing texts were actually discarded (fewer than *skip* only if
        *rows* ran out during the skip phase).
    """
    kept = (
        text
        for text in (row.get("text") for row in rows)
        if _passes_filter(text, max_chars)
    )
    # islice stops exactly at the bound without pulling an extra row, and a
    # bound of 0 consumes nothing — so target=0 collects nothing.
    skipped = sum(1 for _ in islice(kept, skip))
    examples = [{"text": text} for text in islice(kept, target)]
    return examples, skipped


def main():
    """Parse CLI args, stream the dataset, and write the held-out test set as JSON."""
    p = argparse.ArgumentParser(description="Create test set after first N training examples")
    p.add_argument("--num_examples", type=int, default=10, help="Number of test examples to collect")
    p.add_argument("--output", type=str, default="data/test_set_10.json", help="Output JSON path")
    p.add_argument("--max_training_samples", type=int, default=None, help="Skip this many training examples (default: config.MAX_SAMPLES)")
    p.add_argument("--input_seq_length", type=int, default=None, help="Same as training (default: config.INPUT_SEQ_LENGTH)")
    p.add_argument("--skip_margin", type=int, default=DEFAULT_SKIP_MARGIN, help="Extra filtered examples to skip after the training window (safety margin)")
    args = p.parse_args()

    if args.max_training_samples is None or args.input_seq_length is None:
        # Project-local config; only imported when a default is actually needed.
        import config as cfg

        if args.max_training_samples is None:
            args.max_training_samples = cfg.MAX_SAMPLES
        if args.input_seq_length is None:
            args.input_seq_length = cfg.INPUT_SEQ_LENGTH

    # Validate up front: e.g. input_seq_length <= 0 would make every text fail
    # the filter and silently stream the entire dataset.
    if args.num_examples < 0:
        p.error("--num_examples must be >= 0")
    if args.max_training_samples < 0:
        p.error("--max_training_samples must be >= 0")
    if args.input_seq_length <= 0:
        p.error("--input_seq_length must be > 0")
    if args.skip_margin < 0:
        p.error("--skip_margin must be >= 0")

    hf_id = DATASET_HF_MAP.get("the_pile", "monology/pile-uncopyrighted")
    # Character cap derived from the token budget at ~4 chars/token; assumed to
    # match the filter in dataset.py — verify if that file changes.
    max_chars = args.input_seq_length * 4
    skip = args.max_training_samples + args.skip_margin
    target = args.num_examples

    print(f"Loading stream: {hf_id} (train)")
    print(f"Filter: len(text) <= {max_chars} chars (input_seq_length={args.input_seq_length})")
    print(f"Skipping first {skip} examples that pass filter, then collecting {target} for test set...")

    ds = load_dataset(hf_id, split="train", streaming=True)
    collected, skipped = _collect_examples(ds, max_chars, skip, target)

    print(f"Collected {len(collected)} test examples (after skipping {skipped} filtered examples).")
    if len(collected) < target:
        print(f"WARNING: stream exhausted; only {len(collected)} of {target} requested examples collected.")

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"examples": collected, "num_training_skipped": skip, "max_chars": max_chars}, f, indent=2)
    print(f"Saved to {args.output}")


if __name__ == "__main__":
    main()