| import torch, wandb, pandas as pd |
| from tqdm import tqdm |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig |
|
|
| |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
rm_path = "/home/rm5.0_9e-6"
data_path = "/home/data/raw/test/1159-L6_format_full_label_v5.0safe.parquet"
batch_size = 16
max_length = 8192
sample_size = 1500  # upper bound on the number of preference pairs scored
sample_seed = 0     # fixed seed so every run scores the same subset

# ---------------------------------------------------------------------------
# Experiment tracking
# ---------------------------------------------------------------------------
wandb.init(project="reward_model_scoring", name="5.0_9e-6")

# ---------------------------------------------------------------------------
# Tokenizer and reward model
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(rm_path, trust_remote_code=True)
tokenizer.padding_side = "left"
config = AutoConfig.from_pretrained(rm_path)
config.num_labels = 1  # single regression-style logit, used as the reward
model = AutoModelForSequenceClassification.from_pretrained(
    rm_path, config=config, device_map="auto")
model.eval()

# Device that holds the model's first parameter; input tensors are built there.
device = next(model.parameters()).device

# ---------------------------------------------------------------------------
# Evaluation data
# ---------------------------------------------------------------------------
df = pd.read_parquet(data_path)
# Clamp the sample size: DataFrame.sample raises ValueError when n exceeds the
# number of rows (sampling without replacement). Seeding keeps the evaluated
# subset identical across runs so scores are comparable.
df = (
    df.sample(n=min(sample_size, len(df)), random_state=sample_seed)
    .reset_index(drop=True)
)
|
|
def format_input(prompt, reply):
    """Build the text for one (prompt, reply) pair.

    The prompt and reply are concatenated and trailing newlines are stripped.
    If the result does not already end with the module-level tokenizer's EOS
    token, a single space followed by that token is appended.
    """
    text = (prompt + reply).rstrip("\n")
    eos = tokenizer.eos_token
    if text.endswith(eos):
        return text
    return text + " " + eos
|
|
def encode_batch(chosen_texts, rejected_texts, tokenizer, max_length, device):
    """Tokenize chosen and rejected texts into one left-padded batch.

    Both lists are tokenized without special tokens and truncated to
    ``max_length``. The last token of every sequence is then overwritten with
    the EOS id (so a truncated sequence still ends in EOS), and all sequences
    are left-padded to the longest length found across both lists.

    Args:
        chosen_texts: list of strings for the preferred responses.
        rejected_texts: list of strings for the rejected responses.
        tokenizer: callable returning ``input_ids`` and ``attention_mask``
            lists; must expose ``eos_token_id`` and ``pad_token_id``.
        max_length: truncation length passed to the tokenizer.
        device: device on which the returned tensors are created.

    Returns:
        ``(input_ids, attention_mask, split)`` where both tensors have dtype
        ``torch.long`` with the chosen rows first and the rejected rows after
        them, and ``split == len(chosen_texts)`` is the row index at which the
        rejected rows start.
    """
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Some tokenizers ship without a pad token. Padded positions carry
        # attention-mask 0, so reusing the EOS id there is safe and avoids
        # torch.tensor failing on None entries.
        pad_id = eos_id

    id_rows, mask_rows = [], []
    for texts in (chosen_texts, rejected_texts):
        encoded = tokenizer(texts, add_special_tokens=False,
                            truncation=True, max_length=max_length,
                            padding=False)
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"]):
            ids, mask = list(ids), list(mask)
            if ids:
                # Force the final position to EOS, even after truncation.
                ids[-1] = eos_id
                mask[-1] = 1
            else:
                # An empty tokenization would otherwise raise IndexError and
                # yield a fully masked row; represent it as a lone EOS token.
                ids, mask = [eos_id], [1]
            id_rows.append(ids)
            mask_rows.append(mask)

    # Pad on the left so the final (EOS) token of every row sits in the last
    # column.
    width = max(len(row) for row in id_rows)
    padded_ids = [[pad_id] * (width - len(row)) + row for row in id_rows]
    padded_mask = [[0] * (width - len(row)) + row for row in mask_rows]

    input_ids = torch.tensor(padded_ids, dtype=torch.long).to(device)
    attn_masks = torch.tensor(padded_mask, dtype=torch.long).to(device)
    return input_ids, attn_masks, len(chosen_texts)
|
|
| |
# ---------------------------------------------------------------------------
# Score every (chosen, rejected) pair and track pairwise accuracy
# ---------------------------------------------------------------------------
chosen_scores, rejected_scores, accs = [], [], []
sample_table = wandb.Table(columns=["index", "prompt", "chosen", "rejected",
                                    "chosen_score", "rejected_score", "delta", "acc"])

# Optional reward rescaling statistics. Not every model config defines these
# attributes, so read them defensively instead of risking an AttributeError.
reward_std = getattr(config, "std", None)
reward_mean = getattr(config, "mean", None)

num_correct = 0  # running count, so the average is not re-summed per sample
for i in tqdm(range(0, len(df), batch_size)):
    batch = df.iloc[i:i + batch_size]
    prompts = batch["chosen_prompt"].tolist()
    chosen_replies = batch["chosen"].tolist()
    rejected_replies = batch["reject"].tolist()

    # Both responses are scored against the same prompt column.
    chosen_texts = [format_input(p, a) for p, a in zip(prompts, chosen_replies)]
    rejected_texts = [format_input(p, a) for p, a in zip(prompts, rejected_replies)]

    input_ids, attn_masks, split = encode_batch(
        chosen_texts, rejected_texts, tokenizer, max_length, device)

    with torch.no_grad():
        rewards = model(input_ids=input_ids, attention_mask=attn_masks).logits.squeeze(-1)
        if reward_std is not None and reward_mean is not None:
            rewards = rewards * reward_std + reward_mean

    # One device-to-host transfer per batch instead of one .item() per score.
    rewards = rewards.float().cpu().tolist()
    chosen_r, rejected_r = rewards[:split], rewards[split:]

    for j, (c, r) in enumerate(zip(chosen_r, rejected_r)):
        idx = i + j
        delta = c - r
        acc = int(delta > 0)  # a tie counts as incorrect
        chosen_scores.append(c)
        rejected_scores.append(r)
        accs.append(acc)
        num_correct += acc
        avg_acc = num_correct / len(accs)
        print(f"[{idx}] acc={acc}, chosen={c:.3f}, rejected={r:.3f}, Δ={delta:.3f} | avg acc={avg_acc:.3f}")

        sample_table.add_data(idx, prompts[j], chosen_replies[j], rejected_replies[j],
                              c, r, delta, acc)

# ---------------------------------------------------------------------------
# Aggregate statistics
# ---------------------------------------------------------------------------
df["chosen_score"] = chosen_scores
df["rejected_score"] = rejected_scores
df["delta"] = df["chosen_score"] - df["rejected_score"]
df["acc"] = accs

accuracy = df["acc"].mean()
mean_chosen = df["chosen_score"].mean()
mean_reject = df["rejected_score"].mean()
mean_delta = df["delta"].mean()

print(f"\n✅ Accuracy = {accuracy:.3f}")
print(f"📊 mean_chosen = {mean_chosen:.3f}, mean_rejected = {mean_reject:.3f}, mean_delta = {mean_delta:.3f}")

wandb.log({
    "samples_table": sample_table,
    "final_accuracy": accuracy,
    "mean_chosen_score": mean_chosen,
    "mean_rejected_score": mean_reject,
    "mean_delta_score": mean_delta,
})
wandb.finish()
|
|