| """ |
| Create a small test set from The Pile that is disjoint from the first 100K |
| training examples (same stream order and filter as GistAutoencoderDataset). |
| |
| Usage: |
| python -m data.create_test_set [--num_examples 10] [--output data/test_set_10.json] |
| |
| Uses config.MAX_SAMPLES and config.INPUT_SEQ_LENGTH by default (both can be |
| overridden on the command line) and skips an extra 1000 filtered examples as |
| a margin, so the test set comes after the training window in the stream. |
| """ |
|
|
| import argparse |
| import json |
| import os |
|
|
| from datasets import load_dataset |
|
|
| |
# Short dataset key -> Hugging Face Hub dataset id used when streaming.
DATASET_HF_MAP = {
    "the_pile": "monology/pile-uncopyrighted",
}
|
|
|
|
def _collect_after_skip(rows, skip, target, max_chars):
    """Collect up to *target* examples after skipping *skip* filtered rows.

    A row passes the filter when its "text" value is a non-empty ``str`` of at
    most *max_chars* characters. Rows that fail the filter are ignored and do
    not count towards *skip*.

    Args:
        rows: Iterable of mapping-like rows exposing ``.get("text")``.
        skip: Number of filter-passing rows to discard before collecting.
        target: Number of examples to collect; nothing is consumed if < 1.
        max_chars: Maximum allowed ``len(text)`` for a row to pass the filter.

    Returns:
        ``(collected, seen)`` where ``collected`` is a list of
        ``{"text": ...}`` dicts and ``seen`` is how many filter-passing rows
        were actually skipped (less than *skip* if the stream ran out).
    """
    collected = []
    seen = 0
    if target < 1:
        # Nothing requested: do not touch the stream at all.
        return collected, seen

    for row in rows:
        text = row.get("text")
        if not text or not isinstance(text, str) or len(text) > max_chars:
            continue
        if seen < skip:
            seen += 1
            continue
        collected.append({"text": text})
        if len(collected) >= target:
            break
    return collected, seen


def main():
    """Build a small test set from the stream and write it as JSON.

    Parses the command line, streams the "train" split of the dataset, skips
    ``max_training_samples + 1000`` rows that pass the length filter, then
    collects ``num_examples`` further rows and writes them to ``--output`` as
    ``{"examples": [...], "num_training_skipped": ..., "max_chars": ...}``.

    ``--max_training_samples`` and ``--input_seq_length`` fall back to
    ``config.MAX_SAMPLES`` and ``config.INPUT_SEQ_LENGTH``; ``config`` is only
    imported when one of them is not given.

    Exits via ``argparse`` (status 2) when ``--num_examples`` or the resolved
    input sequence length is smaller than 1.
    """
    p = argparse.ArgumentParser(description="Create test set after first N training examples")
    p.add_argument("--num_examples", type=int, default=10, help="Number of test examples to collect")
    p.add_argument("--output", type=str, default="data/test_set_10.json", help="Output JSON path")
    p.add_argument("--max_training_samples", type=int, default=None, help="Skip this many training examples (default: config.MAX_SAMPLES)")
    p.add_argument("--input_seq_length", type=int, default=None, help="Same as training (default: config.INPUT_SEQ_LENGTH)")
    args = p.parse_args()

    # Reject a non-positive request up front; otherwise the whole skip window
    # would be streamed only to collect a single unintended example.
    if args.num_examples < 1:
        p.error(f"--num_examples must be at least 1 (got {args.num_examples})")

    # Import config lazily so the script works without it when both values
    # are supplied on the command line.
    if args.max_training_samples is None or args.input_seq_length is None:
        import config as cfg

        if args.max_training_samples is None:
            args.max_training_samples = cfg.MAX_SAMPLES
        if args.input_seq_length is None:
            args.input_seq_length = cfg.INPUT_SEQ_LENGTH

    # With a non-positive length no row can pass the filter, so the entire
    # dataset would be streamed for nothing.
    if args.input_seq_length < 1:
        p.error(f"input_seq_length must be at least 1 (got {args.input_seq_length})")

    hf_id = DATASET_HF_MAP.get("the_pile", "monology/pile-uncopyrighted")
    # Character budget per example; presumably a ~4 chars/token heuristic that
    # mirrors the training filter -- confirm against GistAutoencoderDataset.
    max_chars = args.input_seq_length * 4
    # Skip 1000 examples beyond the training window as a margin.
    skip = args.max_training_samples + 1000
    target = args.num_examples

    print(f"Loading stream: {hf_id} (train)")
    print(f"Filter: len(text) <= {max_chars} chars (input_seq_length={args.input_seq_length})")
    print(f"Skipping first {skip} examples that pass filter, then collecting {target} for test set...")

    ds = load_dataset(hf_id, split="train", streaming=True)
    collected, seen = _collect_after_skip(ds, skip, target, max_chars)

    print(f"Collected {len(collected)} test examples (after skipping {seen} training examples).")
    if len(collected) < target:
        # The stream ended before the request was satisfied; make that visible
        # instead of silently writing a short (possibly empty) test set.
        print(f"Warning: stream ended early; only {len(collected)} of {target} requested examples were collected.")

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"examples": collected, "num_training_skipped": skip, "max_chars": max_chars}, f, indent=2)

    print(f"Saved to {args.output}")
|
|
|
|
# Script entry point: only run when executed directly (e.g. via `python -m`),
# not when the module is imported.
if __name__ == "__main__":
    main()
|
|