| |
| |
|
|
| from argparse import Namespace |
| import pandas as pd |
| from vllm import LLM, EngineArgs |
| from vllm.utils import FlexibleArgumentParser |
| import wandb |
|
|
| |
# Qwen3-Reranker chat-template framing: the model is asked to judge whether a
# Document satisfies a Query and may answer only "yes" or "no".
PREFIX = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
# Closes the user turn and opens the assistant turn with an empty <think>
# block — presumably so the score reflects the judgement directly; verify
# against the model card's recommended template.
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

# Evaluation dataset (parquet; rows are read via the columns
# "chosen_prompt", "chosen", "reject" in main) and the wandb run identity.
DATA_PATH = "/home/data/test_transformed_v1.parquet"
WANDB_PROJECT = "reranker_eval_wrong"
WANDB_RUN_NAME = "qwen3_seqcls_scoring"
|
|
def format_query(chosen_prompt: str) -> str:
    """Build the query half of a reranker request.

    Wraps *chosen_prompt* in the system/user chat template together with the
    fixed scoring instruction expected by the Qwen3 reranker.
    """
    task = (
        "Given a roleplay prompt and recent context, score candidate replies higher when they stay in character, continue the scene coherently, and feel vivid and engaging."
    )
    return PREFIX + f"<Instruct>: {task}\n<Query>:{chosen_prompt}\n"
|
|
def format_document(doc_text: str) -> str:
    """Build the document half of a reranker request.

    Prefixes *doc_text* with the "<Document>:" tag and appends the assistant
    turn suffix that terminates the chat template.
    """
    return "<Document>: " + doc_text + SUFFIX
|
|
def parse_args():
    """Parse CLI arguments: every vLLM EngineArgs flag, with defaults tuned
    for the Qwen3 sequence-classification reranker checkpoint."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    parser.set_defaults(
        model="deeppin/Qwen3-Reranker-8B-SequenceClassification",
        task="score",
        enforce_eager=True,
        trust_remote_code=True,
    )
    return parser.parse_args()
|
|
def _extract_pair(row):
    """Return (chosen_prompt, chosen, reject) from a dataframe row, or None
    when the row is unusable (non-string fields, or blank chosen/reject)."""
    chosen_prompt = row["chosen_prompt"]
    chosen = row["chosen"]
    reject = row["reject"]
    if not all(isinstance(v, str) for v in (chosen_prompt, chosen, reject)):
        return None
    # chosen_prompt may legitimately be empty; only the candidates must be non-blank.
    if chosen.strip() == "" or reject.strip() == "":
        return None
    return chosen_prompt, chosen, reject


def _log_wrong_samples(wrong_samples):
    """Log mis-ranked pairs to wandb as a table, plus a best-effort CSV artifact."""
    columns = [
        "index", "chosen_score", "reject_score", "margin",
        "chosen_prompt", "chosen", "reject",
    ]
    table = wandb.Table(columns=columns)
    for r in wrong_samples:
        table.add_data(*(r[c] for c in columns))
    wandb.log({"errors/wrong_samples": table})

    try:
        pd.DataFrame(wrong_samples).to_csv("wrong_samples.csv", index=False)
        art = wandb.Artifact("wrong_samples", type="dataset")
        art.add_file("wrong_samples.csv")
        wandb.log_artifact(art)
    except Exception as e:
        # Best-effort export: don't crash the run, but don't hide the failure
        # either (the original `except: pass` swallowed it silently).
        print(f"[Warn] failed to export wrong_samples artifact: {e}")


def main(args: Namespace):
    """Evaluate a reranker on (chosen, reject) pairs and log accuracy to wandb.

    For each row of the parquet at DATA_PATH, both candidate replies are scored
    against the same query; the pair counts as correct when the chosen reply
    scores strictly higher than the rejected one. Running accuracy and
    per-pair scores are streamed to wandb; all mis-ranked pairs are collected
    and logged at the end.

    Args:
        args: Parsed vLLM engine arguments (see parse_args()); forwarded
            verbatim to the LLM constructor.
    """
    df = pd.read_parquet(DATA_PATH)

    wandb.init(project=WANDB_PROJECT, name=WANDB_RUN_NAME)
    wandb.config.update({"model": args.model, "data_path": DATA_PATH})

    llm = LLM(**vars(args))

    correct = 0
    total = 0
    wrong_samples = []
    # finally-guard so the wandb run is closed even if scoring raises midway
    # (the original skipped wandb.finish() on any uncaught exception).
    try:
        for i, row in df.iterrows():
            sample = _extract_pair(row)
            if sample is None:
                continue
            chosen_prompt, chosen, reject = sample

            q = format_query(chosen_prompt)
            d1 = format_document(chosen)
            d2 = format_document(reject)

            # Keep the try body minimal: only the engine call may legitimately
            # fail per-row. Metric/logging bugs should surface, not be counted
            # as scoring errors (the original wrapped everything).
            try:
                outs = llm.score([q, q], [d1, d2])
                s1, s2 = (o.outputs.score for o in outs)
            except Exception as e:
                print(f"[Error] index={i}: {e}")
                continue

            chosen_better = s1 > s2
            total += 1
            if chosen_better:
                correct += 1
            running_acc = correct / total
            print(
                {"chosen_score": s1, "reject_score": s2, "chosen_better": chosen_better},
                f"[RunningAcc] {correct}/{total} = {running_acc:.4f}",
            )
            wandb.log({
                "metric/running_acc": running_acc,
                "score/chosen": float(s1),
                "score/reject": float(s2),
                "score/margin": float(s1 - s2),
            }, step=total)
            if not chosen_better:
                wrong_samples.append({
                    "index": int(i),
                    "chosen_score": float(s1),
                    "reject_score": float(s2),
                    "margin": float(s1 - s2),
                    "chosen_prompt": chosen_prompt,
                    "chosen": chosen,
                    "reject": reject,
                })

        final_acc = correct / total if total > 0 else 0.0
        print(f"[FinalAcc] {correct}/{total} = {final_acc:.4f}")
        wandb.summary["final/accuracy"] = final_acc
        wandb.summary["final/total"] = total
        wandb.summary["final/correct"] = correct
        wandb.summary["final/wrong"] = len(wrong_samples)

        if wrong_samples:
            _log_wrong_samples(wrong_samples)
    finally:
        wandb.finish()
|
|
| if __name__ == "__main__": |
| args = parse_args() |
| main(args) |
|
|