# Source: tb-llmae-s1-depth-study / create_test_set.py
# Uploaded by arkanathp: "Upload create_test_set.py with huggingface_hub"
# Commit db16238 (verified)
"""
Create a small test set from The Pile that is disjoint from the first 100K
training examples (same stream order and filter as GistAutoencoderDataset).

Usage:
    python -m data.create_test_set [--num_examples 10] [--output data/test_set_10.json]

Uses config.MAX_SAMPLES and config.INPUT_SEQ_LENGTH (unless overridden with
--max_training_samples / --input_seq_length) so the test set is guaranteed to
come after the training window in the stream. An extra 1000 filter-passing
examples are skipped beyond the training window before collection starts.
"""
import argparse
import json
import os
from datasets import load_dataset
# Dataset key -> Hugging Face Hub repository id.
# Kept identical to the mapping (and filter) used in dataset.py.
DATASET_HF_MAP = {
    "the_pile": "monology/pile-uncopyrighted",
}
def _collect_after_skip(rows, max_chars, skip, target):
    """Filter a row stream, skip the first ``skip`` matches, collect ``target``.

    A row passes the filter when its ``"text"`` value is a non-empty ``str``
    of at most ``max_chars`` characters; rows that fail are ignored and are
    not counted towards ``skip``.

    Args:
        rows: Iterable of mapping-like rows exposing ``.get("text")``.
        max_chars: Maximum allowed ``len(text)`` for a row to pass.
        skip: Number of passing rows to discard before collecting.
        target: Number of passing rows to collect after the skip window.

    Returns:
        ``(collected, seen)`` where ``collected`` is a list of
        ``{"text": ...}`` dicts (shorter than ``target`` if the stream ends
        early) and ``seen`` is the number of passing rows actually skipped.
    """
    collected = []
    seen = 0
    # Nothing requested: do not touch the stream at all.
    if target <= 0:
        return collected, seen
    for row in rows:
        text = row.get("text")
        if not text or not isinstance(text, str) or len(text) > max_chars:
            continue
        if seen < skip:
            seen += 1
            continue
        collected.append({"text": text})
        if len(collected) >= target:
            break
    return collected, seen


def main():
    """Build a JSON test set from rows that follow the training window.

    Parses the command line, falls back to ``config.MAX_SAMPLES`` and
    ``config.INPUT_SEQ_LENGTH`` for any option left unset, streams the
    dataset, skips ``max_training_samples + 1000`` rows that pass the length
    filter, collects ``num_examples`` further rows and writes them to
    ``--output`` as JSON with keys ``examples``, ``num_training_skipped`` and
    ``max_chars``.

    Exits via ``argparse`` (status 2) when ``--num_examples`` or the resolved
    input sequence length is smaller than 1.
    """
    p = argparse.ArgumentParser(description="Create test set after first N training examples")
    p.add_argument("--num_examples", type=int, default=10, help="Number of test examples to collect")
    p.add_argument("--output", type=str, default="data/test_set_10.json", help="Output JSON path")
    p.add_argument("--max_training_samples", type=int, default=None, help="Skip this many training examples (default: config.MAX_SAMPLES)")
    p.add_argument("--input_seq_length", type=int, default=None, help="Same as training (default: config.INPUT_SEQ_LENGTH)")
    args = p.parse_args()

    # Previously a value of 0 (or a negative value) still collected one
    # example, because the size check only ran after the first append.
    if args.num_examples < 1:
        p.error(f"--num_examples must be at least 1 (got {args.num_examples})")

    # Import config lazily, and only once, so the script still works without
    # it when both values are supplied on the command line.
    if args.max_training_samples is None or args.input_seq_length is None:
        import config as cfg
        if args.max_training_samples is None:
            args.max_training_samples = cfg.MAX_SAMPLES
        if args.input_seq_length is None:
            args.input_seq_length = cfg.INPUT_SEQ_LENGTH

    # A non-positive length makes max_chars <= 0, so every row would be
    # filtered out and the whole stream scanned for nothing.
    if args.input_seq_length < 1:
        p.error(f"input_seq_length must be at least 1 (got {args.input_seq_length})")

    hf_id = DATASET_HF_MAP.get("the_pile", "monology/pile-uncopyrighted")
    # Character budget used by the filter: 4 characters per sequence position.
    max_chars = args.input_seq_length * 4
    # Skip 1000 passing rows beyond the training window before collecting.
    skip = args.max_training_samples + 1000
    target = args.num_examples

    print(f"Loading stream: {hf_id} (train)")
    print(f"Filter: len(text) <= {max_chars} chars (input_seq_length={args.input_seq_length})")
    print(f"Skipping first {skip} examples that pass filter, then collecting {target} for test set...")

    ds = load_dataset(hf_id, split="train", streaming=True)
    collected, seen = _collect_after_skip(ds, max_chars, skip, target)

    print(f"Collected {len(collected)} test examples (after skipping {seen} training examples).")
    if len(collected) < target:
        # Make a truncated (possibly empty) test set visible instead of silent.
        print(f"WARNING: stream ended early; wanted {target} examples but only got {len(collected)}.")

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"examples": collected, "num_training_skipped": skip, "max_chars": max_chars}, f, indent=2)
    print(f"Saved to {args.output}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()