| import matplotlib.pyplot as plt |
| from datasets import load_dataset, concatenate_datasets |
| from transformers import AutoTokenizer |
| import re |
| import numpy as np |
| import os |
|
|
# Local directory the tokenizer is loaded from.
TOKENIZER_DIR = "/home/rm"

# Parquet shards that are cleaned, measured and merged further down.
paths = [
    "/home/data/pk-2089-L6.parquet",
    "/home/data/pk-1820-L6.parquet",
    "/home/data/pk-2355-L6.parquet",
    "/home/data/pk-4088-L6.parquet",
    "/home/data/pk-3876-L6.parquet",
]

# Shared tokenizer; add_lengths() uses it to count tokens per reply.
tok = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
|
|
# Chat-template / control markers that are removed from reply text.
# Note that "|eot_id|" (without the angle brackets) is listed on purpose
# next to the well-formed "<|eot_id|>".
special_tokens = {
    "<|im_start|>", "<|im_end|>",
    "<|eot_id|>", "|eot_id|", "<|end_of_text|>",
    "<s>", "</s>",
    "<|system|>", "<|user|>", "<|assistant|>",
    "<bos>", "<eos>", "<pad>",
    "<|start_header_id|>", "<|end_header_id|>",
    "[INST]", "[/INST]",
}

# Build the alternation in a fixed order instead of set-iteration order.
# Set iteration depends on string hash randomisation, so the compiled pattern
# used to differ from run to run. Regex alternation is first-match-wins, so
# the longest tokens go first: a token that is a prefix of a longer one can
# then never shadow it. Ties are broken alphabetically for reproducibility.
pat = re.compile(
    "|".join(
        re.escape(token)
        for token in sorted(special_tokens, key=lambda t: (-len(t), t))
    )
)
|
|
def clean_text(ex):
    """Normalise the reply fields of one example and blank its prompt.

    ``chosen`` and ``reject`` are each stripped, have every match of the
    module-level ``pat`` deleted (replaced by the empty string, so the text
    on either side of a marker is joined), and have runs of whitespace
    collapsed to a single space. A missing or non-string value becomes
    ``""``. ``prompt`` is always overwritten with ``""``.

    Args:
        ex: Mapping for a single example; it is modified in place.

    Returns:
        The same ``ex`` object with ``chosen``, ``reject`` and ``prompt`` set.
    """
    def scrub(value):
        # Anything that is not a string (e.g. None) turns into an empty reply.
        if isinstance(value, str):
            without_markers = pat.sub("", value.strip())
            return re.sub(r"\s+", " ", without_markers).strip()
        return ""

    for field in ("chosen", "reject"):
        ex[field] = scrub(ex.get(field, ""))
    ex["prompt"] = ""
    return ex
|
|
def add_lengths(batch):
    """Compute token counts for a batch of chosen/reject replies.

    Both text columns are encoded with the module-level ``tok`` without
    special tokens, and the number of input ids per entry is recorded.

    Args:
        batch: Mapping with ``"chosen"`` and ``"reject"`` entries, each passed
            to the tokenizer as-is.

    Returns:
        A dict with three parallel lists: ``len_c`` (chosen token counts),
        ``len_r`` (reject token counts) and ``len_diff`` (their absolute
        difference).
    """
    def token_counts(texts):
        encoded = tok(texts, add_special_tokens=False)
        return [len(ids) for ids in encoded["input_ids"]]

    chosen_counts = token_counts(batch["chosen"])
    reject_counts = token_counts(batch["reject"])
    gaps = [
        abs(n_chosen - n_reject)
        for n_chosen, n_reject in zip(chosen_counts, reject_counts)
    ]
    return {"len_c": chosen_counts, "len_r": reject_counts, "len_diff": gaps}
|
|
# Columns kept in every shard; everything else is dropped before merging.
needed = ["prompt", "chosen", "reject", "len_c", "len_r", "len_diff"]


def _load_pairs(path):
    """Load one parquet shard, clean it, add length columns, trim columns."""
    shard = load_dataset("parquet", data_files=path, split="train")
    shard = shard.map(clean_text, num_proc=4)
    shard = shard.map(add_lengths, batched=True, batch_size=1024, num_proc=4)
    extra = [name for name in shard.column_names if name not in needed]
    # Only call remove_columns when there is actually something to remove.
    return shard.remove_columns(extra) if extra else shard


# One processed dataset per input file, then a single merged dataset.
sets = [_load_pairs(p) for p in paths]
full = concatenate_datasets(sets)
|
|
| |
# Quantile levels reported for the |chosen - reject| token-length gap.
QUANTILE_LEVELS = (0.50, 0.75, 0.90, 0.95, 0.99)

# Absolute token-length gaps of all merged pairs as a NumPy array.
len_diffs = np.array(full["len_diff"])

for q in QUANTILE_LEVELS:
    print(f"|Δlen| 分位数 q={q:.2f}: {np.quantile(len_diffs, q)}")

# Pairs whose gap exceeds the 0.95 quantile are filtered out later on.
cut = np.quantile(len_diffs, 0.95)
print(f"长度差 0.95 分位数阈值: {cut}")
|
|
| |
# Histogram of the length gaps with the cut-off drawn as a dashed red line.
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(len_diffs, bins=50, color="skyblue", edgecolor="black")
ax.axvline(cut, color="red", linestyle="--", label=f"0.95分位: {cut}")
ax.set_title("|Δlen| 长度差分布(chosen vs reject)")
ax.set_xlabel("Token Length Difference")
ax.set_ylabel("Frequency")
ax.legend()

# Make sure the output directory exists, write the image, free the figure.
os.makedirs("./plots", exist_ok=True)
plot_path = "./plots/len_diff_distribution.png"
fig.savefig(plot_path, dpi=300)
plt.close(fig)
print(f"✅ 已保存长度差分布图: {plot_path}")
|
|
| |
def _within_cut(example):
    """Return whether the example's length gap is at most ``cut``."""
    return example["len_diff"] <= cut


# Keep only pairs at or below the cut-off, then drop the helper columns so
# the saved file no longer carries len_c / len_r / len_diff.
full = full.filter(_within_cut, num_proc=4)
full = full.remove_columns(["len_c", "len_r", "len_diff"])

# Write the filtered pairs to parquet and report the row count.
out = "/home/data/reply_only_pairs.parquet"
full.to_parquet(out)
print("saved:", out, "rows:", len(full))
|
|