import re
import textwrap
from copy import deepcopy
from typing import Dict, List, Optional

import torch

from swift.llm import PtEngine, RequestConfig, Template, to_device
from swift.llm.infer.protocol import ChatCompletionResponse
from swift.utils import get_logger
|
|
| logger = get_logger() |
|
|
|
|
class DefaultRMPlugin:
    """
    Default Reward Model Plugin.

    Implements the default scoring logic for reward models. It assumes that
    ``self.model`` is a classification model with a value head (output
    dimension 1). The first logits value from the model's output is used as
    the reward score.
    """

    def __init__(self, model, template):
        # Sequence-classification reward model with a single-value head.
        self.model = model
        # Template used to encode and collate incoming infer requests.
        self.template: Template = template

    def __call__(self, inputs):
        """Encode ``inputs``, run the reward model, and return reward scores.

        Args:
            inputs: A list of infer requests (dicts with a 'messages' key,
                plus any extra dataset columns the template understands).

        Returns:
            torch.Tensor: Shape (N,), the first logit per sample.
        """
        # Encode deep copies so the caller's requests are not mutated.
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        # Labels are irrelevant for scoring; drop them if present. The default
        # of None avoids a KeyError when the collator produced no labels.
        reward_inputs.pop('labels', None)

        with torch.inference_mode():
            return self.model(**reward_inputs).logits[:, 0]
|
|
|
|
class GenRMPlugin(DefaultRMPlugin):
    """Generative Reward Model Plugin Example.

    Instead of reading a scalar value head, this plugin asks a generative
    model to judge the dialogue: the conversation is flattened into plain
    text, sent together with a scoring system prompt, and the numeric reward
    is parsed back out of the model's textual response.
    """

    def __init__(self, model, template):
        """
        Initialize the generative reward model plugin.

        Sets up the PtEngine for efficient inference, configures the request
        parameters, and defines the system prompt that guides the reward model
        in evaluating responses.

        Args:
            model (torch.nn.Module): The generative reward model.
            template (Template): The template used for encoding input data.
        """
        super().__init__(model, template)
        # max_batch_size=0 defers batch-size selection to the engine.
        self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0)
        self.request_config = RequestConfig()
        # Scoring instructions; textwrap.dedent strips the code indentation.
        self.system = textwrap.dedent("""
        Based on the dialogue history, analyze in detail whether the model's response is accurate, complete, and relevant.
        Assign a reward score between 0 and 1, where 0 indicates completely incorrect and 1 indicates fully correct.
        Before finishing your response, please assign a reward using the following format:

        Reward: {reward}

        For example:
        Reward: 0.85
        """)

    def __call__(self, inputs):
        """
        Compute reward scores for the provided inputs.

        Each input's dialogue is converted into a query, the query is sent to
        the reward model for inference, and the reward scores are extracted
        from the model's responses. The final reward for each input is the
        average of all extracted scores.

        Args:
            inputs (List[Dict]): A list of input requests. Each input request
                is a dictionary containing:
                - 'messages' (List[Dict]): messages from the training model,
                  each with 'role' (str) and 'content' (str) keys.
                - Additional dataset columns as key-value pairs
                  (e.g., 'solutions', 'images').

        Returns:
            torch.Tensor: A float32 tensor of shape (N,) holding the average
            reward score for each of the N input requests.
        """
        rm_inputs = self.prepare_rm_inputs(inputs)
        results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False)
        rewards = self.compute_rewards(results)
        return torch.tensor(rewards, dtype=torch.float32)

    def prepare_rm_inputs(self, inputs: List[Dict]) -> List[Dict]:
        """
        Prepare inputs for the reward model by converting messages into queries.

        Args:
            inputs (List[Dict]): A list of input requests.

        Returns:
            List[Dict]: Processed inputs for the reward model.
        """
        rm_inputs = []
        for infer_request in inputs:
            # Deep-copy so the caller's request (and message list) is untouched.
            rm_infer_request = deepcopy(infer_request)
            query = self.messages_to_query(rm_infer_request.get('messages'))
            # Replace the dialogue with a fresh two-turn prompt for the judge.
            rm_infer_request['messages'] = [
                {'role': 'system', 'content': self.system},
                {'role': 'user', 'content': query},
            ]
            rm_inputs.append(rm_infer_request)
        return rm_inputs

    @staticmethod
    def extract_reward(model_output: str) -> Optional[float]:
        """
        Extract the reward score from the model's output.

        Args:
            model_output (str): The model's output string, expected to contain
                "Reward: {reward}" with a score in [0, 1].

        Returns:
            Optional[float]: The extracted reward score, or None when no score
            in the valid [0, 1] range can be found (a warning is logged).
        """
        # Match any decimal number and then range-check it: the original
        # pattern ([0-1](?:\.\d+)?) silently accepted out-of-range values
        # such as 1.5.
        match = re.search(r'Reward:\s*(\d+(?:\.\d+)?)', model_output)
        if match:
            reward = float(match.group(1))
            if 0.0 <= reward <= 1.0:
                return reward
        logger.warning("Unable to extract reward score from the model's output, set reward to 0")
        return None

    @staticmethod
    def messages_to_query(messages):
        """
        Compress a list of message dictionaries into a single query string.

        Args:
            messages (list[dict]): A list of message dictionaries, each containing:
                - 'role' (str): The role of the speaker (e.g., 'user', 'assistant').
                - 'content' (str): The content of the message.

        Returns:
            str: A single string that concatenates all messages in a formatted manner.

        Raises:
            TypeError: If any element of ``messages`` is not a dictionary.

        Example:
            >>> messages = [
            ...     {'role': 'user', 'content': 'Hello, how are you?'},
            ...     {'role': 'assistant', 'content': 'I am fine, thank you! How can I assist you today?'},
            ...     {'role': 'user', 'content': 'Can you help me with my homework?'}
            ... ]
            >>> print(GenRMPlugin.messages_to_query(messages))
            User: Hello, how are you?
            Assistant: I am fine, thank you! How can I assist you today?
            User: Can you help me with my homework?
        """
        # Canonical display names for the common roles; any other role is
        # capitalized as-is.
        role_mapping = {
            'user': 'User',
            'assistant': 'Assistant',
            'system': 'System',
        }
        formatted_messages = []
        for idx, message in enumerate(messages):
            if not isinstance(message, dict):
                raise TypeError(f'Each message must be a dictionary. Found {type(message)} at index {idx}.')
            role = message.get('role')
            content = message.get('content')
            # Skip empty messages entirely.
            if not content:
                continue
            # Guard against a missing role: the original crashed with
            # AttributeError on None.lower().
            if role:
                role_formatted = role_mapping.get(role.lower(), role.capitalize())
            else:
                role_formatted = 'Unknown'
            formatted_messages.append(f'{role_formatted}: {content}')
        return '\n'.join(formatted_messages)

    def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]:
        """
        Compute average reward scores from the reward model's outputs.

        Args:
            results (List[ChatCompletionResponse]): A list of results from the
                reward model, one per input request.

        Returns:
            List[float]: A list of average reward scores (0.0 when no valid
            score could be extracted or an error occurred).
        """
        rewards = []
        for output in results:
            try:
                # One response may contain several sampled choices; average
                # the scores that parsed successfully.
                cur_rewards = [self.extract_reward(choice.message.content) for choice in output.choices]
                cur_rewards = [r for r in cur_rewards if r is not None]
                if cur_rewards:
                    rewards.append(sum(cur_rewards) / len(cur_rewards))
                else:
                    logger.warning('No valid rewards extracted. Assigning reward score of 0.0.')
                    rewards.append(0.0)
            except Exception as e:
                # Best-effort: a malformed response must not abort the batch.
                logger.error(f'Error computing reward: {e}')
                rewards.append(0.0)
        return rewards
|
|
|
|
# Registry mapping a plugin name (as selected in training config/CLI) to its
# reward-model plugin class.
rm_plugins = {
    'default': DefaultRMPlugin,
    'genrm': GenRMPlugin,
}
|
|