from transformers import AutoTokenizer
from datasets import load_dataset, concatenate_datasets
import numpy as np

# Paths: the reward-model tokenizer and the raw parquet files to clean.
tokenizer_path = "/home/rm3.4.1_9e-6"
parquet_paths = [
    # (additional input parquet paths omitted)
    "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"
]

# NOTE: output_path matches the last entry of parquet_paths, so the merged
# result overwrites that input file in place.
output_path = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

def count_total_tokens(ex):
    """Attach total_tokens / chosen_tokens / rejected_tokens fields to a sample."""
    prompt = ex["chosen_prompt"]
    chosen_ids = tokenizer(prompt + ex["chosen"], add_special_tokens=False)["input_ids"]
    rejected_ids = tokenizer(prompt + ex["reject"], add_special_tokens=False)["input_ids"]
    ex["total_tokens"] = len(chosen_ids) + len(rejected_ids)
    ex["chosen_tokens"] = len(chosen_ids)
    ex["rejected_tokens"] = len(rejected_ids)
    return ex
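

# Quick smoke test for count_total_tokens on a toy row (hypothetical field
# values; assumes only the chosen_prompt / chosen / reject fields used above):
_demo = count_total_tokens({"chosen_prompt": "Q: 1+1= ", "chosen": "2", "reject": "3"})
assert _demo["total_tokens"] == _demo["chosen_tokens"] + _demo["rejected_tokens"]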


def summary(arr):
    """Return (max, min, mean) of an array as (int, int, float)."""
    return int(arr.max()), int(arr.min()), float(arr.mean())
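
# For example, summary(np.array([3, 1, 2])) evaluates to (3, 1, 2.0).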


# Per-file cleaned datasets plus before/after token statistics.
cleaned_sets = []
stats_before = {}
stats_after = {}

for path in parquet_paths:
    name = path.split("/")[-1]
    print(f"\n▶ Processing {name}")

    ds = load_dataset("parquet", data_files=path, split="train")
    print(f"[{name}] rows loaded: {len(ds)}")

    # Token-counting pass: adds the total/chosen/rejected token columns.
    ds_tmp = ds.map(count_total_tokens, desc=f"[{name}] counting tokens (pre-filter)", num_proc=4)
    stats_before[name] = summary(np.array(ds_tmp["total_tokens"]))

    # Keep only samples whose combined token count lies in [1000, 8192].
    ds = ds_tmp.filter(
        lambda x: 1000 <= x["total_tokens"] <= 8192,
        desc=f"[{name}] filtering to [1000, 8192]"
    )
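    # Optional extra report: how many rows the range filter removed.
    dropped = len(ds_tmp) - len(ds)
    print(f"[{name}] dropped {dropped} rows outside [1000, 8192]")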
    stats_after[name] = summary(np.array(ds["total_tokens"]))
    cleaned_sets.append(ds)


# Before/after comparison of token statistics across all input files.
print("\n================ Token statistics comparison ================")
print(f"{'Dataset':<22} | {'pre-filter max/min/mean':<25} | {'post-filter max/min/mean':<25}")
print("-" * 80)
for path in parquet_paths:
    n = path.split("/")[-1]
    b_max, b_min, b_mean = stats_before[n]
    a_max, a_min, a_mean = stats_after[n]
    print(f"{n:<22} | {b_max:5d}/{b_min:5d}/{b_mean:7.1f} | {a_max:5d}/{a_min:5d}/{a_mean:7.1f}")
# Merge all cleaned splits and write them to a single parquet file.
merged = concatenate_datasets(cleaned_sets)
merged.to_parquet(output_path)
print("\n✅ Merged sample count:", len(merged))
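
# Optional round-trip check: re-read the written file and confirm the row
# count matches what was just merged.
check = load_dataset("parquet", data_files=output_path, split="train")
assert len(check) == len(merged), "row count changed after parquet round-trip"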