| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import importlib.util |
| import os |
| import sys |
| from abc import ABC, abstractmethod |
| from collections import defaultdict |
| from functools import partial |
| from typing import Callable, Optional, Tuple, TypedDict |
|
|
| import torch |
| from transformers import PreTrainedTokenizer |
|
|
| from ...protocol import DataProto |
| from .config import RewardConfig |
|
|
|
|
# Per-sample payload handed to a reward function.
#   response        -- response text decoded from the valid response tokens
#   response_length -- number of valid response tokens (sum of the response mask)
#   ground_truth    -- the sample's entry from data.non_tensor_batch["ground_truth"]
# NOTE(review): BatchFunctionRewardManager additionally passes "data_type",
# "problem_type", "problem" and "problem_id", which are not declared here.
RewardInput = TypedDict(
    "RewardInput",
    {
        "response": str,
        "response_length": int,
        "ground_truth": str,
    },
)
|
|
|
|
# Score a reward function returns for one sample.
# "overall" is the value written into the reward tensor; every key present
# (including optional components that may be None) is appended to the
# reward metrics by the reward managers.
RewardScore = TypedDict(
    "RewardScore",
    {
        "overall": float,
        "format": Optional[float],
        "accuracy": Optional[float],
    },
)
|
|
|
|
# Signature of a per-sample reward function: called once per response.
SequentialRewardFunction = Callable[
    [RewardInput],
    RewardScore,
]

# Signature of a batched reward function: called once with every sample.
# The batch manager pairs the i-th returned score with the i-th input.
BatchRewardFunction = Callable[
    [list[RewardInput]],
    list[RewardScore],
]
|
|
|
|
class FunctionRewardManager(ABC):
    """Reward manager for rule-based reward.

    Loads a user-supplied reward function from a Python source file when
    constructed, and binds ``config.reward_function_kwargs`` to it with
    ``functools.partial``. Subclasses decide how that function is invoked
    (per sample or per batch) in :meth:`compute_reward`.
    """

    def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
        """Load the configured reward function.

        Args:
            config: Reward configuration. Reads ``reward_function`` (path to a
                Python file), ``reward_function_name`` (attribute to load from
                it), ``reward_function_kwargs`` (bound onto the function) and,
                in subclasses, ``skip_special_tokens``.
            tokenizer: Tokenizer used by subclasses to decode responses.

        Raises:
            ValueError: If ``config.reward_function`` is None.
            FileNotFoundError: If ``config.reward_function`` is not a file.
            RuntimeError: If the file cannot be loaded or executed as a module.
            AttributeError: If the module lacks ``config.reward_function_name``.
        """
        if config.reward_function is None:
            raise ValueError("Reward function is not provided.")

        # isfile rather than exists: a directory cannot be loaded as a module.
        if not os.path.isfile(config.reward_function):
            raise FileNotFoundError(f"Reward function file {config.reward_function} not found.")

        module_name = "custom_reward_fn"
        spec = importlib.util.spec_from_file_location(module_name, config.reward_function)
        # spec_from_file_location returns None when no loader matches the path
        # (e.g. an unrecognised file suffix); fail with a clear message instead
        # of an AttributeError from module_from_spec(None).
        if spec is None or spec.loader is None:
            raise RuntimeError(
                f"Failed to load reward function: no module loader for {config.reward_function}"
            )

        module = importlib.util.module_from_spec(spec)
        # Register before executing (standard importlib recipe) so code inside
        # the module that looks itself up in sys.modules works.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except Exception as e:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(module_name, None)
            raise RuntimeError(f"Failed to load reward function: {e}") from e

        if not hasattr(module, config.reward_function_name):
            raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")

        reward_fn = getattr(module, config.reward_function_name)
        print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
        self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
        self.config = config
        self.tokenizer = tokenizer

    @abstractmethod
    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, dict[str, list[float]]]:
        """Compute reward for a batch of data.

        Returns:
            A reward tensor and a mapping from metric name to per-sample values.
        """
        ...
|
|
|
|
class SequentialFunctionRewardManager(FunctionRewardManager):
    """Reward manager that calls a per-sample reward function once per response."""

    reward_fn: SequentialRewardFunction

    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, dict[str, list[float]]]:
        """Decode and score each response individually.

        The number of valid tokens in a row is the sum of its
        ``response_mask`` row. Only that many leading tokens are decoded, so
        this presumes valid tokens form a prefix of each row (TODO confirm
        right padding upstream).

        Returns:
            A float32 tensor shaped like ``data.batch["responses"]``. It is zero
            everywhere except index ``length - 1`` of each row, which holds that
            sample's ``"overall"`` score. For a zero-length response, index -1
            wraps to the final column. Also returns a mapping from each score key
            to its per-sample values, in batch order.
        """
        token_ids = data.batch["responses"]
        rewards = torch.zeros_like(token_ids, dtype=torch.float32)
        metrics = defaultdict(list)
        lengths = data.batch["response_mask"].sum(dim=-1).tolist()

        for row in range(len(data)):
            n_valid = int(lengths[row])
            text = self.tokenizer.decode(
                token_ids[row][:n_valid],
                skip_special_tokens=self.config.skip_special_tokens,
            )
            payload = {
                "response": text,
                "response_length": n_valid,
                "ground_truth": data.non_tensor_batch["ground_truth"][row],
            }
            result = self.reward_fn(payload)

            rewards[row, n_valid - 1] = result["overall"]
            for metric_name, metric_value in result.items():
                metrics[metric_name].append(metric_value)

        return rewards, metrics
|
|
|
|
class BatchFunctionRewardManager(FunctionRewardManager):
    """Reward manager that scores every response with a single batched call."""

    reward_fn: BatchRewardFunction

    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, dict[str, list[float]]]:
        """Decode all responses, score them in one call, and build the reward tensor.

        The number of valid tokens in a row is the sum of its
        ``response_mask`` row. Only that many leading tokens are decoded, so
        this presumes valid tokens form a prefix of each row (TODO confirm
        right padding upstream).

        Returns:
            A float32 tensor shaped like ``data.batch["responses"]``. It is zero
            everywhere except index ``length - 1`` of each row, which holds that
            sample's ``"overall"`` score. For a zero-length response, index -1
            wraps to the final column. Also returns a mapping from each score key
            to its per-sample values, in batch order.

        Raises:
            ValueError: If the reward function does not return exactly one
                score per input sample.
        """
        response_ids = data.batch["responses"]
        # One host transfer for all lengths instead of a .item() per sample.
        response_lengths = [int(n) for n in torch.sum(data.batch["response_mask"], dim=-1).tolist()]
        non_tensor = data.non_tensor_batch

        reward_inputs = []
        for i in range(len(data)):
            cur_response_length = response_lengths[i]
            response_str = self.tokenizer.decode(
                response_ids[i][:cur_response_length],
                skip_special_tokens=self.config.skip_special_tokens,
            )
            reward_inputs.append(
                {
                    "response": response_str,
                    "response_length": cur_response_length,
                    "ground_truth": non_tensor["ground_truth"][i],
                    "data_type": non_tensor["data_type"][i],
                    "problem_type": non_tensor["problem_type"][i],
                    "problem": non_tensor["problem_reserved_text"][i],
                    "problem_id": non_tensor["problem_id"][i],
                }
            )

        # Materialise so a generator result is accepted and its length can be checked.
        scores = list(self.reward_fn(reward_inputs))
        # Scores are matched to samples by position. A count mismatch would
        # otherwise silently leave zero rewards and short metric lists, or fail
        # with an IndexError.
        if len(scores) != len(reward_inputs):
            raise ValueError(
                f"Reward function returned {len(scores)} scores for {len(reward_inputs)} inputs."
            )

        reward_tensor = torch.zeros_like(response_ids, dtype=torch.float32)
        reward_metrics = defaultdict(list)
        for i, score in enumerate(scores):
            reward_tensor[i, response_lengths[i] - 1] = score["overall"]
            for key, value in score.items():
                reward_metrics[key].append(value)

        return reward_tensor, reward_metrics
|
|