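"""Tokenize a text dataset and save the result to disk for later training runs.

Example invocation (assuming this file is saved as preprocess.py):

    python preprocess.py \
        --dataset HuggingFaceFW/fineweb-edu \
        --dataset_name sample-100BT \
        --tokenizer fla-hub/transformer-1.3B-100B \
        --path data/fineweb-edu \
        --num_workers 64
"""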
import argparse
from typing import Any, Dict, List

from transformers import AutoTokenizer, PreTrainedTokenizer

from flame.data import build_dataset
from torchtitan.tools.logging import init_logger, logger


def tokenize(
    examples: Dict[str, List[Any]],
    tokenizer: PreTrainedTokenizer,
) -> Dict:
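    """Tokenize a batch of raw text examples.

    Expects a 'text' or 'content' column and returns the token ids together with
    a per-sample 'bits_per_token' statistic (bits of UTF-8 text divided by the
    number of produced tokens).
    """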
    if 'text' in examples:
        samples = examples['text']
    elif 'content' in examples:
        samples = examples['content']
    else:
        raise ValueError(f'No "text" or "content" field found in examples:\n{examples}')
    input_ids = tokenizer(samples)['input_ids']
    # Compression statistic: bits of raw UTF-8 text per produced token.
    bits_per_token = [
        len(sample.encode('utf-8')) * 8 / len(input_ids[i])
        for i, sample in enumerate(samples)
    ]
    return {'input_ids': input_ids, 'bits_per_token': bits_per_token}


if __name__ == '__main__':
    init_logger()
    parser = argparse.ArgumentParser(description='Preprocess the dataset.')
    parser.add_argument(
        '--dataset',
        default='HuggingFaceFW/fineweb-edu',
        help='Dataset to use, with comma-separated values',
    )
    parser.add_argument(
        '--dataset_name',
        default='sample-100BT',
        help='The name of the dataset config, with comma-separated values if provided',
    )
    parser.add_argument(
        '--dataset_split',
        default='train',
        help='Dataset split to use, with comma-separated values if provided',
    )
    parser.add_argument(
        '--data_dir',
        default=None,
        help='Data dirs to use, with comma-separated values if provided',
    )
    parser.add_argument(
        '--data_files',
        default=None,
        help='Data files to use, with comma-separated values if provided',
    )
    parser.add_argument(
        '--data_probs',
        default=None,
        help='Data sampling probabilities, with comma-separated values if provided',
    )
    parser.add_argument(
        '--streaming',
        action='store_true',
        help='Whether to use streaming mode',
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=64,
        help='Number of workers to use for preprocessing',
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for preprocessing',
    )
    parser.add_argument(
        '--path',
        default='data',
        help='Path to save the preprocessed dataset',
    )
    parser.add_argument(
        '--tokenizer',
        default='fla-hub/transformer-1.3B-100B',
        help='Tokenizer to use',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=2048,
        help='Batch size for processing',
    )
    args = parser.parse_args()

    logger.info(f'Loading tokenizer {args.tokenizer}')
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    logger.info(f'{tokenizer}')
    logger.info(f'Loading dataset {args.dataset} {args.dataset_name} {args.dataset_split}')
    dataset = build_dataset(
        dataset=args.dataset,
        dataset_name=args.dataset_name,
        dataset_split=args.dataset_split,
        data_dir=args.data_dir,
        data_files=args.data_files,
        data_probs=args.data_probs,
        streaming=args.streaming,
        num_workers=args.num_workers,
        seed=args.seed,
    )
    logger.info(f'Tokenizing and processing the dataset with batch size {args.batch_size}')
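    # The original columns (inferred from the first example) are dropped after
    # tokenization, leaving only 'input_ids' and 'bits_per_token'.
    # Note: num_proc, desc and save_to_disk below assume a map-style
    # (non-streaming) dataset rather than the --streaming path.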
    dataset = dataset.map(
        lambda examples: tokenize(examples, tokenizer),
        batched=True,
        batch_size=args.batch_size,
        remove_columns=list(next(iter(dataset)).keys()),
        num_proc=args.num_workers,
        desc='Running tokenizer on dataset',
    )
    logger.info(f'{dataset}')
    logger.info(f'Saving tokenized dataset to {args.path}')
    dataset.save_to_disk(args.path)
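    # The saved dataset can be reloaded later with datasets.load_from_disk(args.path).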