"""Score chosen/rejected response pairs with a reward model and log to wandb.

For every row of the parquet file the script builds two texts
(``chosen_prompt + chosen`` and ``chosen_prompt + reject``), runs both through
a single-logit sequence-classification model, and counts the pair as correct
when the chosen score is strictly greater than the rejected score.  It also
reports how often inputs reach ``max_length`` tokens (treated as "possibly
truncated") and the accuracy split by that flag.
"""
import pandas as pd
import torch
import wandb
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# === Parameters ===
rm_path = "/home/rm3.4.1_9e-6"
data_path = "/home/data/test_sys_3round.parquet"
batch_size = 16
max_length = 8192
N = 1500   # only used by the optional sub-sampling line below
seed = 42  # only used by the optional sub-sampling line below

# === wandb ===
wandb.init(project="reward_model_scoring", name="rm3.4_9e-6_-format_test-all-v1")

# === Model & tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(rm_path, trust_remote_code=True)
tokenizer.padding_side = "left"  # kept as in the original setup

config = AutoConfig.from_pretrained(rm_path)
config.num_labels = 1  # reward head: one scalar logit per sequence

model = AutoModelForSequenceClassification.from_pretrained(
    rm_path, config=config, device_map="auto")
model.eval()
# Inputs are placed on the device that holds the model's first parameter.
device = next(model.parameters()).device

# === Data ===
# Optional sub-sampling (disabled; the full file is scored):
# df = pd.read_parquet(data_path).sample(n=N, random_state=seed).reset_index(drop=True)
df = pd.read_parquet(data_path).reset_index(drop=True)


def format_input(prompt, reply):
    """Concatenate prompt and reply and make sure the text ends with EOS.

    Trailing newlines are stripped first; if the result does not already end
    with ``tokenizer.eos_token``, a space and the EOS token are appended.
    """
    txt = (prompt + reply).rstrip("\n")
    if not txt.endswith(tokenizer.eos_token):
        txt += " " + tokenizer.eos_token
    return txt


def _left_pad(seq, fill, width):
    """Return ``seq`` prefixed with ``fill`` so that its length is ``width``."""
    return [fill] * (width - len(seq)) + seq


def encode_batch(chosen_texts, rejected_texts, tokenizer, max_length, device):
    """Tokenize chosen and rejected texts into one left-padded batch.

    Args:
        chosen_texts: list of formatted chosen texts.
        rejected_texts: list of formatted rejected texts.
        tokenizer: tokenizer used for encoding; must expose ``eos_token_id``
            and, when padding is required, ``pad_token_id``.
        max_length: truncation length passed to the tokenizer.
        device: device the resulting tensors are created on.

    Returns:
        ``(input_ids, attn_masks, split, ch, rj)`` where the first ``split``
        rows of both tensors are the chosen sequences and the remaining rows
        the rejected ones; ``ch`` / ``rj`` are the raw tokenizer outputs
        (returned for truncation diagnostics).

    Raises:
        ValueError: if padding is needed but the tokenizer has no pad token.
    """
    # 1) Tokenize without special tokens and without padding (kept as-is).
    ch = tokenizer(chosen_texts, add_special_tokens=False,
                   truncation=True, max_length=max_length, padding=False)
    rj = tokenizer(rejected_texts, add_special_tokens=False,
                   truncation=True, max_length=max_length, padding=False)

    ids1, mask1 = ch["input_ids"], ch["attention_mask"]
    ids2, mask2 = rj["input_ids"], rj["attention_mask"]

    # 2) Force the last token to be EOS (kept as-is): truncation may have cut
    #    off the EOS that format_input appended.  The lists are edited in
    #    place, so ``ch`` / ``rj`` reflect the change as well.
    eos_id = tokenizer.eos_token_id
    for all_ids, all_masks in ((ids1, mask1), (ids2, mask2)):
        for seq_ids, seq_mask in zip(all_ids, all_masks):
            if seq_ids:
                seq_ids[-1] = eos_id
                seq_mask[-1] = 1
            else:
                # An empty encoding has no last position to overwrite.
                seq_ids.append(eos_id)
                seq_mask.append(1)

    # 3) Left-pad every sequence to the longest length across both groups.
    joint_max = max(max(len(x) for x in ids1), max(len(x) for x in ids2))
    pad_id = tokenizer.pad_token_id
    if pad_id is None and any(len(x) < joint_max for x in ids1 + ids2):
        raise ValueError(
            "tokenizer.pad_token_id is None but the batch requires padding; "
            "set a pad token on the tokenizer before scoring."
        )

    ids1 = [_left_pad(x, pad_id, joint_max) for x in ids1]
    ids2 = [_left_pad(x, pad_id, joint_max) for x in ids2]
    mask1 = [_left_pad(x, 0, joint_max) for x in mask1]
    mask2 = [_left_pad(x, 0, joint_max) for x in mask2]

    input_ids = torch.tensor(ids1 + ids2, dtype=torch.long).to(device)
    attn_masks = torch.tensor(mask1 + mask2, dtype=torch.long).to(device)
    # Also return the tokenized results for the truncation diagnostics.
    return input_ids, attn_masks, len(chosen_texts), ch, rj


def was_truncated(token_seqs, max_length):
    """Flag sequences whose length reached ``max_length``.

    This is a quick heuristic: a sequence of exactly ``max_length`` tokens is
    reported as (possibly) truncated even if nothing was actually cut.
    """
    return [len(x) >= max_length for x in token_seqs]


# === Inference ===
chosen_scores, rejected_scores, accs = [], [], []
sample_table = wandb.Table(columns=["index","prompt","chosen","rejected",
                                    "chosen_score","rejected_score","delta","acc"])

total_ch_trunc = 0
total_rj_trunc = 0
total_count = 0

# Accuracy split by "truncated vs. not truncated".
accs_truncated = []   # pairs where chosen or rejected (or both) hit max_length
accs_not_trunc = []   # pairs where neither side hit max_length

correct_so_far = 0    # running sum of accs, avoids re-summing the list per sample

for i in tqdm(range(0, len(df), batch_size)):
    batch = df.iloc[i:i+batch_size]
    chosen_texts = [format_input(p, a) for p, a in zip(batch["chosen_prompt"], batch["chosen"])]
    rejected_texts = [format_input(p, a) for p, a in zip(batch["chosen_prompt"], batch["reject"])]

    input_ids, attn_masks, split, ch_tok, rj_tok = encode_batch(
        chosen_texts, rejected_texts, tokenizer, max_length, device
    )

    # -- Truncation diagnostics (per batch) --
    ch_trunc_flags = was_truncated(ch_tok["input_ids"], max_length)
    rj_trunc_flags = was_truncated(rj_tok["input_ids"], max_length)
    batch_ch_trunc_rate = sum(ch_trunc_flags) / len(ch_trunc_flags)
    batch_rj_trunc_rate = sum(rj_trunc_flags) / len(rj_trunc_flags)
    wandb.log({
        "batch_trunc_rate_chosen": batch_ch_trunc_rate,
        "batch_trunc_rate_reject": batch_rj_trunc_rate,
    })
    total_ch_trunc += sum(ch_trunc_flags)
    total_rj_trunc += sum(rj_trunc_flags)
    total_count += len(ch_trunc_flags)

    with torch.no_grad():
        rewards = model(input_ids=input_ids, attention_mask=attn_masks).logits.squeeze(-1)

    # De-normalisation is intentionally disabled (kept consistent with the
    # training side):
    # if config.std is not None and config.mean is not None:
    #     rewards = rewards * config.std + config.mean

    # One device->host transfer per batch instead of one .item() per score.
    scores = rewards.cpu().tolist()
    chosen_r, rejected_r = scores[:split], scores[split:]

    for j, (c, r) in enumerate(zip(chosen_r, rejected_r)):
        idx = i + j
        delta = c - r
        acc = int(delta > 0)  # ties count as incorrect

        chosen_scores.append(c)
        rejected_scores.append(r)
        accs.append(acc)
        correct_so_far += acc

        # -- Per-pair "truncated vs. not truncated" bucket --
        pair_truncated = bool(ch_trunc_flags[j] or rj_trunc_flags[j])
        if pair_truncated:
            accs_truncated.append(acc)
        else:
            accs_not_trunc.append(acc)

        avg_acc = correct_so_far / len(accs)
        print(f"[{idx}] acc={acc}, chosen={c:.3f}, rejected={r:.3f}, Δ={delta:.3f} | avg acc={avg_acc:.3f}")

        sample_table.add_data(idx,
                              batch["chosen_prompt"].iloc[j],
                              batch["chosen"].iloc[j],
                              batch["reject"].iloc[j],
                              c, r, delta, acc)

# === Results ===
df["chosen_score"] = chosen_scores
df["rejected_score"] = rejected_scores
df["delta"] = df["chosen_score"] - df["rejected_score"]
df["acc"] = accs

accuracy = df["acc"].mean()
mean_chosen = df["chosen_score"].mean()
mean_reject = df["rejected_score"].mean()
mean_delta = df["delta"].mean()

# Overall truncation rate (simple estimate; max() guards the empty-data case).
overall_ch_trunc_rate = total_ch_trunc / max(total_count, 1)
overall_rj_trunc_rate = total_rj_trunc / max(total_count, 1)

# Accuracy on truncated vs. non-truncated pairs (NaN when a bucket is empty).
acc_trunc = sum(accs_truncated)/len(accs_truncated) if accs_truncated else float("nan")
acc_notrunc = sum(accs_not_trunc)/len(accs_not_trunc) if accs_not_trunc else float("nan")

print(f"\n✅ Accuracy = {accuracy:.3f}")
print(f"📊 mean_chosen = {mean_chosen:.3f}, mean_rejected = {mean_reject:.3f}, mean_delta = {mean_delta:.3f}")
print(f"✂️ trunc_rate_chosen = {overall_ch_trunc_rate:.3f}, trunc_rate_reject = {overall_rj_trunc_rate:.3f}")
print(f"🔍 acc_truncated = {acc_trunc:.3f} | acc_not_truncated = {acc_notrunc:.3f}")

wandb.log({
    "samples_table": sample_table,
    "final_accuracy": accuracy,
    "mean_chosen_score": mean_chosen,
    "mean_rejected_score": mean_reject,
    "mean_delta_score": mean_delta,
    "overall_trunc_rate_chosen": overall_ch_trunc_rate,
    "overall_trunc_rate_reject": overall_rj_trunc_rate,
    "acc_truncated": acc_trunc,
    "acc_not_truncated": acc_notrunc,
})
wandb.finish()