# rm_code / reward_acc_v1.py
# Uploaded by hahayang012 using huggingface_hub (commit d8a76be, verified).
import torch, wandb, pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
# === Parameters ===
rm_path = "/home/rm3.4.1_9e-6"                       # reward-model checkpoint directory
data_path = "/home/data/test_sys_3round.parquet"     # preference pairs to score
batch_size = 16      # number of (chosen, rejected) pairs per forward pass
max_length = 8192    # tokenizer truncation limit, in tokens
N = 1500             # subset size; only used by the optional sampling shown below
seed = 42            # random_state for that optional sampling

# === wandb ===
wandb.init(
    project="reward_model_scoring",
    name="rm3.4_9e-6_-format_test-all-v1",
)

# === Model & tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(rm_path, trust_remote_code=True)
tokenizer.padding_side = "left"  # kept as in the original setup

config = AutoConfig.from_pretrained(rm_path)
config.num_labels = 1  # single-logit reward head
model = AutoModelForSequenceClassification.from_pretrained(
    rm_path,
    config=config,
    device_map="auto",
)
model.eval()

# Input tensors are created on whichever device holds the first parameter.
first_param = next(model.parameters())
device = first_param.device

# === Data ===
# To score a random subset instead of the full file:
# df = pd.read_parquet(data_path).sample(n=N, random_state=seed).reset_index(drop=True)
df = pd.read_parquet(data_path)
df = df.reset_index(drop=True)
def format_input(prompt, reply, eos_token=None):
    """Join a prompt and a reply into one string that ends with the EOS token.

    Trailing newlines of the concatenated text are stripped. If the result
    does not already end with the EOS token, a single space followed by the
    EOS token is appended.

    Args:
        prompt: Prompt text placed first.
        reply: Reply text appended directly after the prompt (no separator).
        eos_token: EOS string to enforce. Defaults to the module-level
            ``tokenizer.eos_token`` so existing two-argument calls behave
            exactly as before.

    Returns:
        The formatted text, guaranteed to end with ``eos_token``.
    """
    if eos_token is None:
        # Fall back to the script-wide tokenizer only when no override is given.
        eos_token = tokenizer.eos_token
    text = (prompt + reply).rstrip("\n")
    if not text.endswith(eos_token):
        text += " " + eos_token
    return text
def encode_batch(chosen_texts, rejected_texts, tokenizer, max_length, device):
    """Tokenize chosen/rejected texts and stack them into one left-padded batch.

    Args:
        chosen_texts: List of formatted chosen strings.
        rejected_texts: List of formatted rejected strings.
        tokenizer: Callable tokenizer exposing ``eos_token_id`` and
            ``pad_token_id``; called with ``add_special_tokens=False``,
            ``truncation=True`` and ``padding=False``.
        max_length: Truncation limit passed to the tokenizer.
        device: Device on which the output tensors are created.

    Returns:
        A tuple ``(input_ids, attn_masks, num_chosen, ch, rj)`` where
        ``input_ids`` and ``attn_masks`` are ``torch.long`` tensors holding the
        chosen rows followed by the rejected rows, all left-padded to the
        longest sequence across both lists; ``num_chosen`` is
        ``len(chosen_texts)`` (the row index where rejected rows start); and
        ``ch`` / ``rj`` are the tokenizer outputs (unpadded, with the EOS fix
        applied in place), returned for truncation diagnostics.
    """
    # 1. Tokenize without padding so the per-sequence lengths stay available.
    ch = tokenizer(chosen_texts, add_special_tokens=False,
                   truncation=True, max_length=max_length, padding=False)
    rj = tokenizer(rejected_texts, add_special_tokens=False,
                   truncation=True, max_length=max_length, padding=False)
    ids1, mask1 = ch["input_ids"], ch["attention_mask"]
    ids2, mask2 = rj["input_ids"], rj["attention_mask"]

    # 2. Make every sequence end with EOS. The last token is overwritten, so a
    #    sequence cut by truncation still ends with EOS. A text that tokenizes
    #    to nothing becomes a lone EOS instead of raising IndexError.
    eos_id = tokenizer.eos_token_id
    for seqs, masks in ((ids1, mask1), (ids2, mask2)):
        for seq, mask in zip(seqs, masks):
            if seq:
                seq[-1] = eos_id
                mask[-1] = 1
            else:
                seq.append(eos_id)
                mask.append(1)

    # 3. Left-pad both groups to the longest sequence in either group.
    joint_max = max(max(len(x) for x in ids1), max(len(x) for x in ids2))

    def left_pad(seq, fill):
        return [fill] * (joint_max - len(seq)) + seq

    pad_id = tokenizer.pad_token_id
    padded_ids = [left_pad(x, pad_id) for x in ids1] + [left_pad(x, pad_id) for x in ids2]
    padded_masks = [left_pad(x, 0) for x in mask1] + [left_pad(x, 0) for x in mask2]

    # Create the tensors directly on the target device.
    input_ids = torch.tensor(padded_ids, dtype=torch.long, device=device)
    attn_masks = torch.tensor(padded_masks, dtype=torch.long, device=device)

    # ch / rj are returned as well so the caller can run truncation diagnostics.
    return input_ids, attn_masks, len(chosen_texts), ch, rj
def was_truncated(token_seqs, max_length):
    """Flag each sequence whose length reached ``max_length``.

    A sequence at or above the limit is treated as *possibly* truncated; this
    is a quick diagnostic, not proof that tokens were dropped.

    Returns:
        A list of booleans, one per sequence in ``token_seqs``.
    """
    flags = []
    for seq in token_seqs:
        hit_limit = len(seq) >= max_length
        flags.append(hit_limit)
    return flags
# === Inference ===
# Per-pair results, appended in dataframe row order.
chosen_scores, rejected_scores, accs = [], [], []
sample_table = wandb.Table(columns=["index","prompt","chosen","rejected",
                                    "chosen_score","rejected_score","delta","acc"])
total_ch_trunc = 0
total_rj_trunc = 0
total_count = 0
# Accuracy split by truncation status of each pair.
accs_truncated = []   # pairs where chosen or rejected reached max_length
accs_not_trunc = []   # pairs where neither side reached max_length
# Running count of correct pairs, so the progress line does not re-sum
# `accs` for every sample.
num_correct = 0

for i in tqdm(range(0, len(df), batch_size)):
    batch = df.iloc[i:i+batch_size]
    prompts = batch["chosen_prompt"].tolist()
    chosen_replies = batch["chosen"].tolist()
    rejected_replies = batch["reject"].tolist()

    # Both sides of a pair are formatted with the same prompt column.
    chosen_texts = [format_input(p, a) for p, a in zip(prompts, chosen_replies)]
    rejected_texts = [format_input(p, a) for p, a in zip(prompts, rejected_replies)]

    input_ids, attn_masks, split, ch_tok, rj_tok = encode_batch(
        chosen_texts, rejected_texts, tokenizer, max_length, device
    )

    # --- Truncation diagnostics (batch level) ---
    ch_trunc_flags = was_truncated(ch_tok["input_ids"], max_length)
    rj_trunc_flags = was_truncated(rj_tok["input_ids"], max_length)
    batch_ch_trunc_rate = sum(ch_trunc_flags) / len(ch_trunc_flags)
    batch_rj_trunc_rate = sum(rj_trunc_flags) / len(rj_trunc_flags)
    wandb.log({
        "batch_trunc_rate_chosen": batch_ch_trunc_rate,
        "batch_trunc_rate_reject": batch_rj_trunc_rate,
    })
    total_ch_trunc += sum(ch_trunc_flags)
    total_rj_trunc += sum(rj_trunc_flags)
    total_count += len(ch_trunc_flags)

    with torch.no_grad():
        rewards = model(input_ids=input_ids, attention_mask=attn_masks).logits.squeeze(-1)
    # De-normalization with config.mean / config.std is intentionally skipped
    # (kept consistent with the training side):
    # if config.std is not None and config.mean is not None:
    #     rewards = rewards * config.std + config.mean

    # Move the whole batch to Python floats in one transfer instead of calling
    # .item() (one device sync each) twice per pair.
    scores = rewards.float().cpu().tolist()
    chosen_r, rejected_r = scores[:split], scores[split:]

    for j, (c, r) in enumerate(zip(chosen_r, rejected_r)):
        idx = i + j  # row position in df (df was re-indexed from 0)
        delta = c - r
        acc = int(delta > 0)  # correct when chosen scores strictly higher
        chosen_scores.append(c)
        rejected_scores.append(r)
        accs.append(acc)
        num_correct += acc

        # --- Per-pair "truncated vs. not truncated" bucketing ---
        pair_truncated = bool(ch_trunc_flags[j] or rj_trunc_flags[j])
        if pair_truncated:
            accs_truncated.append(acc)
        else:
            accs_not_trunc.append(acc)

        avg_acc = num_correct / len(accs)
        print(f"[{idx}] acc={acc}, chosen={c:.3f}, rejected={r:.3f}, Δ={delta:.3f} | avg acc={avg_acc:.3f}")
        sample_table.add_data(idx, prompts[j],
                              chosen_replies[j], rejected_replies[j],
                              c, r, delta, acc)
# === Results ===
# Attach per-pair results to the dataframe (lists are in row order).
df["chosen_score"] = chosen_scores
df["rejected_score"] = rejected_scores
df["delta"] = df["chosen_score"] - df["rejected_score"]
df["acc"] = accs

accuracy, mean_chosen, mean_reject, mean_delta = (
    df[column].mean()
    for column in ("acc", "chosen_score", "rejected_score", "delta")
)

# Overall truncation rate (simple estimate); the max() guards an empty run.
denominator = max(total_count, 1)
overall_ch_trunc_rate = total_ch_trunc / denominator
overall_rj_trunc_rate = total_rj_trunc / denominator

# Accuracy on truncated vs. non-truncated pairs; NaN when a bucket is empty.
if accs_truncated:
    acc_trunc = sum(accs_truncated) / len(accs_truncated)
else:
    acc_trunc = float("nan")
if accs_not_trunc:
    acc_notrunc = sum(accs_not_trunc) / len(accs_not_trunc)
else:
    acc_notrunc = float("nan")

print(f"\n✅ Accuracy = {accuracy:.3f}")
print(f"📊 mean_chosen = {mean_chosen:.3f}, mean_rejected = {mean_reject:.3f}, mean_delta = {mean_delta:.3f}")
print(f"✂️ trunc_rate_chosen = {overall_ch_trunc_rate:.3f}, trunc_rate_reject = {overall_rj_trunc_rate:.3f}")
print(f"🔍 acc_truncated = {acc_trunc:.3f} | acc_not_truncated = {acc_notrunc:.3f}")

# Log the sample table together with the summary metrics, then close the run.
final_metrics = {
    "samples_table": sample_table,
    "final_accuracy": accuracy,
    "mean_chosen_score": mean_chosen,
    "mean_rejected_score": mean_reject,
    "mean_delta_score": mean_delta,
    "overall_trunc_rate_chosen": overall_ch_trunc_rate,
    "overall_trunc_rate_reject": overall_rj_trunc_rate,
    "acc_truncated": acc_trunc,
    "acc_not_truncated": acc_notrunc,
}
wandb.log(final_metrics)
wandb.finish()