| import argparse |
| import re |
|
|
| import torch |
| import uvicorn |
| from fastapi import FastAPI, Request |
| from fastapi.responses import JSONResponse |
|
|
| from openrlhf.models import get_llm_for_sequence_regression |
| from openrlhf.utils import get_tokenizer |
| from openrlhf.utils.logging_utils import init_logger |
|
|
# Module-level logger configured via OpenRLHF's logging helper.
logger = init_logger(__name__)
|
|
class RewardModelProxy:
    """Wraps a sequence-regression reward model and scores chosen/reject answer pairs.

    Loads the model and tokenizer once at construction; `get_reward` then
    evaluates pairwise preference accuracy over a DataFrame-like input.
    """

    def __init__(self, args):
        """Load the reward model and tokenizer from `args`.

        :param args: argparse.Namespace with reward_pretrain, normalize_reward,
            bf16, load_in_4bit, value_head_prefix, disable_fast_tokenizer,
            max_len and batch_size attributes.
        """
        # Load the pretrained model with a "reward" (sequence-regression) head,
        # sharded automatically across available devices.
        self.reward_model = get_llm_for_sequence_regression(
            args.reward_pretrain,
            "reward",
            normalize_reward=args.normalize_reward,
            # bugfix: the CLI parser never defines --flash_attn, so reading
            # args.flash_attn directly raised AttributeError at startup.
            # getattr keeps startup working whether or not the flag exists.
            use_flash_attention_2=getattr(args, "flash_attn", False),
            bf16=args.bf16,
            load_in_4bit=args.load_in_4bit,
            value_head_prefix=args.value_head_prefix,
            device_map="auto",
        )
        self.reward_model.eval()  # inference only; disable dropout etc.

        # Left padding keeps the final (scored) token at the end of the sequence.
        self.tokenizer = get_tokenizer(
            args.reward_pretrain, self.reward_model, "left", None, use_fast=not args.disable_fast_tokenizer
        )
        self.max_length = args.max_len
        self.batch_size = args.batch_size

    def get_reward(self, queries):
        """Score chosen/reject pairs and report how often 'chosen' wins.

        :param queries: DataFrame-like object exposing ``len()`` and
            ``iterrows()`` with columns ``chosen_prompt``, ``chosen`` and
            ``reject``.  (assumes a pandas-style DataFrame — TODO confirm
            against the actual caller payload)
        :return: Tuple ``(all_scores, accuracy)`` where ``all_scores`` is a
            list of ``(chosen_score, reject_score)`` pairs and ``accuracy`` is
            the fraction of rows where the chosen answer scored higher.
        """
        df = queries
        correct_count = 0
        total_count = len(df)
        # bugfix: the original initialized `allscores` but appended to and
        # returned `all_scores`, which raised NameError on the first row.
        all_scores = []

        for _, row in df.iterrows():
            # NOTE(review): both queries reuse `chosen_prompt`; confirm the
            # reject answer is not meant to pair with its own prompt column.
            chosen_query = row["chosen_prompt"] + " " + row["chosen"]
            reject_query = row["chosen_prompt"] + " " + row["reject"]

            scores = self.compare_queries(chosen_query, reject_query)
            all_scores.append(scores)

            chosen_score, reject_score = scores
            if chosen_score > reject_score:
                correct_count += 1

        # Guard against empty input to avoid ZeroDivisionError.
        accuracy = correct_count / total_count if total_count > 0 else 0
        print(f"Current Accuracy: {accuracy * 100:.2f}%")
        return all_scores, accuracy

    def compare_queries(self, chosen_query, reject_query):
        """
        Compare the reward scores for chosen_query and reject_query.
        :param chosen_query: The query with the 'chosen' answer
        :param reject_query: The query with the 'reject' answer
        :return: Tuple (chosen_score, reject_score)
        """
        with torch.no_grad():
            inputs_chosen = self.tokenize_fn([chosen_query], device=self.reward_model.device)
            inputs_reject = self.tokenize_fn([reject_query], device=self.reward_model.device)

            # Single-element batches: .tolist()[0] extracts the scalar score.
            chosen_score = self.reward_model(inputs_chosen["input_ids"], inputs_chosen["attention_mask"]).tolist()[0]
            reject_score = self.reward_model(inputs_reject["input_ids"], inputs_reject["attention_mask"]).tolist()[0]

        return chosen_score, reject_score

    def tokenize_fn(self, texts, device):
        """Tokenize `texts` (padded/truncated to self.max_length) and move tensors to `device`."""
        batch = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=self.max_length,
            padding=True,
            truncation=True,
        )
        return {k: v.to(device) for k, v in batch.items()}
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| |
| parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") |
| parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normazation") |
| parser.add_argument("--value_head_prefix", type=str, default="value_head") |
| parser.add_argument("--max_len", type=int, default="2048") |
|
|
| parser.add_argument("--port", type=int, default=5000, help="Port number for the server") |
| parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server") |
|
|
| |
| parser.add_argument("--load_in_4bit", action="store_true", default=False) |
| parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") |
| parser.add_argument( |
| "--attn_implementation", |
| type=str, |
| default="flash_attention_2", |
| help="Attention implementation (e.g., eager, flash_attention_2, flash_attention_3, kernels-community/vllm-flash-attn3)", |
| ) |
| parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) |
| parser.add_argument("--batch_size", type=int, default=None) |
|
|
| args = parser.parse_args() |
|
|
| |
| reward_model = RewardModelProxy(args) |
| app = FastAPI() |
|
|
| @app.post("/get_reward") |
| async def get_reward(request: Request): |
| data = await request.json() |
| queries = data.get("query") |
| rewards = reward_model.get_reward(queries) |
| result = {"rewards": rewards, "scores": rewards, "extra_logs": {"dummy_scores": rewards}} |
| logger.info(f"Sent JSON: {result}") |
| return JSONResponse(result) |
|
|
| uvicorn.run(app, host=args.host, port=args.port, log_level="info") |