| |
| """ |
| Preprocess multilingual Wikipedia data for model training. |
| |
| This script performs the following steps: |
| 1. Downloads the Wikimedia Wikipedia dataset for the specified languages |
| (https://huggingface.co/datasets/wikimedia/wikipedia). |
| 2. Tokenizes the dataset using a Hugging Face tokenizer corresponding |
| to the specified model. |
| 3. Aggregates token IDs up to a target number of tokens per language. |
| 4. Saves the tokenized data as a PyTorch tensor for later use. |
| |
Usage example:
    python load_datas.py \
        --languages en,zh,vi \
        --model-id meta-llama/Llama-3.1-8B-Instruct \
        --tokenizer meta-llama/Llama-3.1-8B-Instruct \
        --output-dir train-data
| """ |
| |
| from datasets import load_dataset |
| from transformers import AutoTokenizer |
| import torch |
| import os |
| from tqdm import tqdm |
| import multiprocessing |
| from functools import partial |
| import argparse |
|
|
# Query the core count once: os.cpu_count() may return None, in which case
# we fall back to a single process. (The original expression called
# os.cpu_count() twice and relied on the truthiness check to dodge None.)
_CPU_COUNT = os.cpu_count() or 1

# Default process budget for dataset tokenization: half the machine's
# cores, but never fewer than one.
NUM_PROC_BASE = max(1, _CPU_COUNT // 2)

# Number of token IDs to aggregate per language before stopping (100M).
TARGET_TOKENS_PER_LANGUAGE = 100_000_000

# Wikipedia dump snapshot date; selects the wikimedia/wikipedia config
# (e.g. "20231101.en").
DATE_SNAPSHOT = "20231101"
|
|
def tokenize_function(examples, tokenizer):
    """Tokenize one batch of documents from the raw dataset.

    Args:
        examples: A batch dict from `datasets.Dataset.map(batched=True)`;
            the "text" key holds a list of document strings.
        tokenizer: A Hugging Face tokenizer callable.

    Returns:
        A dict with a single "input_ids" key: one token-ID list per
        document, with no special tokens, truncation, or padding applied.
    """
    encoded = tokenizer(
        examples["text"],
        add_special_tokens=False,
        truncation=False,
        padding=False,
    )
    return {"input_ids": encoded.input_ids}
|
|
def build_and_save(
    lang,
    model_id,
    tokenizer_name,
    output_dir,
    num_proc_map=NUM_PROC_BASE
):
    """Download, tokenize, and persist Wikipedia data for one language.

    Loads the wikimedia/wikipedia snapshot for `lang`, tokenizes it with
    `tokenizer_name`, aggregates token IDs up to TARGET_TOKENS_PER_LANGUAGE,
    and saves the result as a 1-D torch.long tensor in `output_dir`.

    Args:
        lang: Wikipedia language code (e.g. "en").
        model_id: Model identifier; used only to build the output filename
            (slashes are replaced so it is filesystem-safe).
        tokenizer_name: Hugging Face tokenizer name or local path.
        output_dir: Destination directory (created if it does not exist).
        num_proc_map: Number of worker processes for the `.map()` call.

    Raises:
        Exception: re-raised after logging if the dataset or tokenizer
            fails to load.
    """
    print(f"Starting data processing for language: {lang}")

    train_filename_base = f"id.{lang}.train.{model_id.replace('/', '_')}"
    train_output_path = os.path.join(output_dir, train_filename_base)

    try:
        ds = load_dataset("wikimedia/wikipedia", f"{DATE_SNAPSHOT}.{lang}", split="train", trust_remote_code=True)
        if len(ds) == 0:
            print(f"Warning: Dataset for {lang} is empty. Skipping.")
            return
    except Exception as e:
        print(f"Error loading dataset for {lang}: {e}")
        raise

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            use_fast=True,
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Error loading tokenizer '{tokenizer_name}': {e}")
        raise

    tokenization_func_with_tokenizer = partial(tokenize_function, tokenizer=tokenizer)

    tokenized_ds = ds.map(
        tokenization_func_with_tokenizer,
        batched=True,
        num_proc=num_proc_map,
        remove_columns=ds.column_names,
        desc=f"Tokenizing {lang}"
    )

    # Stream over the tokenized dataset in a single pass and stop as soon as
    # the target is reached. The previous two-pass approach first collected
    # EVERY document's token list into a Python list, holding the entire
    # tokenized corpus in RAM even though aggregation stops at
    # TARGET_TOKENS_PER_LANGUAGE.
    final_token_ids = []
    collected_tokens_count = 0
    for processed_example in tqdm(tokenized_ds, desc=f"Aggregating tokens for {lang}"):
        doc_tokens_list = processed_example['input_ids']
        if not isinstance(doc_tokens_list, list) or not doc_tokens_list:
            continue

        remaining_needed = TARGET_TOKENS_PER_LANGUAGE - collected_tokens_count
        if len(doc_tokens_list) <= remaining_needed:
            final_token_ids.extend(doc_tokens_list)
            collected_tokens_count += len(doc_tokens_list)
        else:
            # Truncate the final document so we land exactly on the target.
            final_token_ids.extend(doc_tokens_list[:remaining_needed])
            collected_tokens_count = TARGET_TOKENS_PER_LANGUAGE

        if collected_tokens_count >= TARGET_TOKENS_PER_LANGUAGE:
            break

    # Drop the large dataset handles before allocating the output tensor
    # to keep peak memory down.
    del tokenized_ds
    del ds

    if collected_tokens_count == 0:
        print(f"Warning: Zero tokens collected for {lang}. Skipping save.")
        return

    if collected_tokens_count < TARGET_TOKENS_PER_LANGUAGE:
        print(f"Warning: Language {lang} has only {collected_tokens_count:,} tokens, "
              f"which is less than the target of {TARGET_TOKENS_PER_LANGUAGE:,}.")

    full_tensor = torch.tensor(final_token_ids, dtype=torch.long)
    del final_token_ids  # free the Python list once the tensor owns the data

    os.makedirs(output_dir, exist_ok=True)
    torch.save(full_tensor, train_output_path)
    print(f"Saved {full_tensor.numel():,} tokens for {lang}.")
    del full_tensor
|
|
def run_job(args):
    """Run build_and_save for one language and report the outcome.

    Args:
        args: Tuple of (lang, model_id, tokenizer_name, output_dir,
            num_proc_map), matching one entry of the job list built in
            the `__main__` block.

    Returns:
        (lang, True, None) on success, or (lang, False, error_message)
        if build_and_save raised; the traceback is printed either way.
    """
    lang, model_id, tokenizer_name, output_dir, num_proc_map = args
    print(f"Processing language: {lang} (PID: {os.getpid()})")
    try:
        build_and_save(
            lang=lang,
            model_id=model_id,
            tokenizer_name=tokenizer_name,
            output_dir=output_dir,
            num_proc_map=num_proc_map
        )
    except Exception as exc:
        import traceback
        traceback.print_exc()
        return lang, False, str(exc)
    else:
        return lang, True, None
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess Wikipedia data for multiple languages.")
    parser.add_argument(
        "--languages", type=str, default='en,zh,eu,ga',
        help="Comma-separated list of languages to process, e.g., 'en,zh,fr'"
    )
    parser.add_argument("--model-id", type=str, required=True, help="Model identifier (used for file naming).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path.")
    parser.add_argument("--output-dir", type=str, default="train-data", help="Where to store tokenized tensors.")
    parser.add_argument("--max-concurrent", type=int, default=6, help="Max concurrent processes.")
    args = parser.parse_args()

    # Normalize the comma-separated --languages string into a list of
    # non-empty, stripped language codes.
    args.languages = [lang.strip() for lang in args.languages.split(',') if lang.strip()]

    # Split the CPU budget across concurrently processed languages so the
    # per-language ds.map() fan-out does not oversubscribe the machine.
    MAX_CONCURRENT_LANGUAGES = args.max_concurrent
    NUM_MAP_PROC_PER_LANG = max(1, NUM_PROC_BASE // MAX_CONCURRENT_LANGUAGES if MAX_CONCURRENT_LANGUAGES > 0 else NUM_PROC_BASE)

    print(f"Starting batch processing for {len(args.languages)} languages.")

    # One job tuple per language; run_job unpacks these positionally, so
    # the order here must match its destructuring.
    job_args_list = [
        (lang, args.model_id, args.tokenizer, args.output_dir, NUM_MAP_PROC_PER_LANG)
        for lang in args.languages
    ]

    successful_langs = []
    failed_langs_with_errors = {}

    # NOTE(review): multiprocessing.Pool workers are daemonic, and daemonic
    # processes may not spawn children. If NUM_MAP_PROC_PER_LANG > 1, the
    # nested datasets .map(num_proc=...) inside each worker will try to fork
    # subprocesses — confirm this runs on the target machine (it is safe
    # when NUM_MAP_PROC_PER_LANG == 1).
    with multiprocessing.Pool(processes=MAX_CONCURRENT_LANGUAGES) as pool:
        # imap_unordered yields results as languages finish, so the progress
        # bar advances in completion order rather than submission order.
        results_iterable = pool.imap_unordered(run_job, job_args_list)
        for result in tqdm(results_iterable, total=len(args.languages), desc="Overall Language Progress"):
            lang_processed, success, error_msg = result
            if success:
                successful_langs.append(lang_processed)
            else:
                failed_langs_with_errors[lang_processed] = error_msg

    # Final summary: which languages succeeded, and per-language error
    # messages for the ones that failed.
    print("Batch processing finished.")
    print(f"Successfully processed: {', '.join(sorted(successful_langs))}")
    if failed_langs_with_errors:
        print(f"Failed to process: {', '.join(sorted(failed_langs_with_errors.keys()))}")
        for lang_failed, err in failed_langs_with_errors.items():
            print(f" - {lang_failed}: {err}")
|