| import torch, wandb, pandas as pd |
| from tqdm import tqdm |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig |
|
|
| |
# ---- Configuration -------------------------------------------------------
rm_path = "/home/rm3.4.1_9e-6"                      # reward-model checkpoint dir
data_path = "/home/data/test_sys_3round.parquet"    # pairwise preference data
batch_size = 16
max_length = 8192   # token budget per sequence (beyond this -> truncation)
N = 1500            # NOTE(review): currently unused in this script
seed = 42           # NOTE(review): currently unused in this script
|
|
| |
| wandb.init(project="reward_model_scoring", name="rm3.4_9e-6_-format_test-all-v1") |
|
|
| |
# ---- Tokenizer and reward model ------------------------------------------
# Left padding keeps the last real token at the end of every row, which is
# where a sequence-classification reward head reads its score.
tokenizer = AutoTokenizer.from_pretrained(rm_path, trust_remote_code=True)
tokenizer.padding_side = "left"

config = AutoConfig.from_pretrained(rm_path)
config.num_labels = 1  # single scalar reward per sequence
model = AutoModelForSequenceClassification.from_pretrained(
    rm_path, config=config, device_map="auto"
)
model.eval()

# device_map="auto" may shard the model across devices; send inputs to the
# device of the first parameter (the embedding side).
device = next(model.parameters()).device
|
|
| |
| |
| df = pd.read_parquet(data_path).reset_index(drop=True) |
|
|
def format_input(prompt, reply, eos=None):
    """Concatenate prompt and reply and guarantee a trailing EOS marker.

    Args:
        prompt: prompt text (already chat-formatted upstream).
        reply: model response to be scored.
        eos: EOS marker string; defaults to the global ``tokenizer``'s
            ``eos_token`` (backward-compatible generalization so the
            function can also be used/tested without the global tokenizer).

    Returns:
        ``prompt + reply`` with trailing newlines stripped and a single
        ``" " + eos`` appended when the text does not already end with EOS.
    """
    if eos is None:
        eos = tokenizer.eos_token
    txt = (prompt + reply).rstrip("\n")
    if not txt.endswith(eos):
        txt += " " + eos
    return txt
|
|
def encode_batch(chosen_texts, rejected_texts, tokenizer, max_length, device):
    """Tokenize chosen/rejected pairs, force a trailing EOS, left-pad to a
    shared length, and stack everything into one batch.

    Args:
        chosen_texts: list of formatted chosen-completion strings.
        rejected_texts: list of formatted rejected-completion strings
            (same length as ``chosen_texts``).
        tokenizer: tokenizer providing ``eos_token_id`` / ``pad_token_id``.
        max_length: truncation budget per sequence.
        device: torch device for the returned tensors.

    Returns:
        ``(input_ids, attn_masks, split, ch, rj)`` — tensors of shape
        ``(2 * split, joint_max)`` where rows ``[:split]`` are the chosen
        sequences, plus the raw (unpadded) tokenizer outputs ``ch``/``rj``
        for truncation diagnostics.
    """
    ch = tokenizer(chosen_texts, add_special_tokens=False,
                   truncation=True, max_length=max_length, padding=False)
    rj = tokenizer(rejected_texts, add_special_tokens=False,
                   truncation=True, max_length=max_length, padding=False)
    ids1, mask1 = ch["input_ids"], ch["attention_mask"]
    ids2, mask2 = rj["input_ids"], rj["attention_mask"]

    # Guarantee every sequence ends with EOS (the reward head scores the
    # final token). Truncated sequences get their last token overwritten.
    # FIX: an empty tokenization previously crashed on [-1]; append EOS
    # instead so empty inputs are scored rather than raising IndexError.
    for arr_ids, arr_mask in ((ids1, mask1), (ids2, mask2)):
        for i in range(len(arr_ids)):
            if arr_ids[i]:
                arr_ids[i][-1] = tokenizer.eos_token_id
                arr_mask[i][-1] = 1
            else:
                arr_ids[i].append(tokenizer.eos_token_id)
                arr_mask[i].append(1)

    # Left-pad chosen and rejected to one joint length so both halves can be
    # stacked into a single forward pass (matches padding_side="left").
    joint_max = max(max(len(x) for x in ids1), max(len(x) for x in ids2))

    def lpad(seq, pad):
        return [pad] * (joint_max - len(seq)) + seq

    ids1 = [lpad(x, tokenizer.pad_token_id) for x in ids1]
    ids2 = [lpad(x, tokenizer.pad_token_id) for x in ids2]
    mask1 = [lpad(x, 0) for x in mask1]
    mask2 = [lpad(x, 0) for x in mask2]

    input_ids = torch.tensor(ids1 + ids2, dtype=torch.long).to(device)
    attn_masks = torch.tensor(mask1 + mask2, dtype=torch.long).to(device)
    return input_ids, attn_masks, len(chosen_texts), ch, rj
|
|
def was_truncated(token_seqs, max_length):
    """Flag sequences whose length reached ``max_length``.

    Quick diagnostic only: a sequence that is exactly ``max_length`` tokens
    long without having been cut is also flagged as possibly truncated.
    """
    return [max_length <= len(seq) for seq in token_seqs]
|
|
| |
# ---- Result accumulators -------------------------------------------------
# Per-pair scores/accuracy plus a wandb table of individual samples.
chosen_scores, rejected_scores, accs = [], [], []
sample_table = wandb.Table(columns=["index", "prompt", "chosen", "rejected",
                                    "chosen_score", "rejected_score", "delta", "acc"])

# Dataset-wide truncation counters.
total_ch_trunc = 0
total_rj_trunc = 0
total_count = 0

# Accuracy bucketed by whether either side of a pair hit max_length.
accs_truncated = []
accs_not_trunc = []
|
|
# ---- Main scoring loop ---------------------------------------------------
for start in tqdm(range(0, len(df), batch_size)):
    batch = df.iloc[start:start + batch_size]
    # Both completions share the same prompt column.
    chosen_texts = [format_input(p, a) for p, a in zip(batch["chosen_prompt"], batch["chosen"])]
    rejected_texts = [format_input(p, a) for p, a in zip(batch["chosen_prompt"], batch["reject"])]

    input_ids, attn_masks, split, ch_tok, rj_tok = encode_batch(
        chosen_texts, rejected_texts, tokenizer, max_length, device
    )

    # Per-batch truncation diagnostics.
    ch_trunc_flags = was_truncated(ch_tok["input_ids"], max_length)
    rj_trunc_flags = was_truncated(rj_tok["input_ids"], max_length)
    wandb.log({
        "batch_trunc_rate_chosen": sum(ch_trunc_flags) / len(ch_trunc_flags),
        "batch_trunc_rate_reject": sum(rj_trunc_flags) / len(rj_trunc_flags),
    })
    total_ch_trunc += sum(ch_trunc_flags)
    total_rj_trunc += sum(rj_trunc_flags)
    total_count += len(ch_trunc_flags)

    # One forward pass scores chosen and rejected together; the first
    # `split` rows are the chosen sequences.
    with torch.no_grad():
        rewards = model(input_ids=input_ids, attention_mask=attn_masks).logits.squeeze(-1)
    chosen_r, rejected_r = rewards[:split], rewards[split:]

    for offset, (cr, rr) in enumerate(zip(chosen_r, rejected_r)):
        idx = start + offset
        c, r = cr.item(), rr.item()
        delta = c - r
        acc = int(delta > 0)  # pair is "correct" when chosen outscores rejected

        chosen_scores.append(c)
        rejected_scores.append(r)
        accs.append(acc)

        # Bucket accuracy by whether either side of the pair was truncated.
        if ch_trunc_flags[offset] or rj_trunc_flags[offset]:
            accs_truncated.append(acc)
        else:
            accs_not_trunc.append(acc)

        avg_acc = sum(accs) / len(accs)
        print(f"[{idx}] acc={acc}, chosen={c:.3f}, rejected={r:.3f}, Δ={delta:.3f} | avg acc={avg_acc:.3f}")

        sample_table.add_data(idx, batch["chosen_prompt"].iloc[offset],
                              batch["chosen"].iloc[offset], batch["reject"].iloc[offset],
                              c, r, delta, acc)
|
|
| |
# ---- Aggregate results ---------------------------------------------------
# Attach per-pair results back onto the dataframe.
df["chosen_score"] = chosen_scores
df["rejected_score"] = rejected_scores
df["delta"] = df["chosen_score"] - df["rejected_score"]
df["acc"] = accs

accuracy = df["acc"].mean()
mean_chosen = df["chosen_score"].mean()
mean_reject = df["rejected_score"].mean()
mean_delta = df["delta"].mean()

# Overall truncation rates (max() guards an empty dataset).
overall_ch_trunc_rate = total_ch_trunc / max(total_count, 1)
overall_rj_trunc_rate = total_rj_trunc / max(total_count, 1)

# Accuracy per truncation bucket; NaN when a bucket is empty.
acc_trunc = sum(accs_truncated) / len(accs_truncated) if accs_truncated else float("nan")
acc_notrunc = sum(accs_not_trunc) / len(accs_not_trunc) if accs_not_trunc else float("nan")
|
|
# ---- Final report --------------------------------------------------------
print(f"\n✅ Accuracy = {accuracy:.3f}")
print(f"📊 mean_chosen = {mean_chosen:.3f}, mean_rejected = {mean_reject:.3f}, mean_delta = {mean_delta:.3f}")
print(f"✂️ trunc_rate_chosen = {overall_ch_trunc_rate:.3f}, trunc_rate_reject = {overall_rj_trunc_rate:.3f}")
print(f"🔍 acc_truncated = {acc_trunc:.3f} | acc_not_truncated = {acc_notrunc:.3f}")

# Push the per-sample table and summary metrics, then close the run.
final_metrics = {
    "samples_table": sample_table,
    "final_accuracy": accuracy,
    "mean_chosen_score": mean_chosen,
    "mean_rejected_score": mean_reject,
    "mean_delta_score": mean_delta,
    "overall_trunc_rate_chosen": overall_ch_trunc_rate,
    "overall_trunc_rate_reject": overall_rj_trunc_rate,
    "acc_truncated": acc_trunc,
    "acc_not_truncated": acc_notrunc,
}
wandb.log(final_metrics)
wandb.finish()
|
|