| import re |
| from typing import List, Optional, Any, Union |
|
|
class RewardFunctions:
    """Heuristic reward functions for scoring model completions.

    Each public method is a ``@staticmethod`` that takes a list of completion
    strings and returns a list with exactly one float per completion. Every
    method accepts ``**kwargs`` so that all of them can be called with the
    same keyword arguments (as ``combined_reward`` does); keywords a method
    does not name are ignored.
    """

    # Format check: a <reasoning> block followed, after optional whitespace,
    # by an <answer> block. Case-sensitive, and may appear anywhere in the text.
    _FORMAT_RE = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)
    # Non-greedy capture of the first tagged section; tags matched case-insensitively.
    _ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
    _REASONING_RE = re.compile(r"<reasoning>(.*?)</reasoning>", re.DOTALL | re.IGNORECASE)

    # Step markers: "step N", numbered-list items such as "2.", or ordinal words.
    # The numbered-list branch skips digits that belong to a decimal number
    # (not preceded by a digit/dot, not followed by a digit), so "3.14" is not
    # counted as a step.
    _STEP_RE = re.compile(
        r"(?:step\s*\d+)|(?:(?<![\d.])\d+\.(?!\d))|(?:\bfirst\b|\bsecond\b|\bthird\b|\bfinally\b)",
        re.I,
    )
    _CONNECTOR_RE = re.compile(
        r"(?:\btherefore\b|\bthus\b|\bbecause\b|\bhence\b|\bso\b|\bsince\b|\bconsequently\b)",
        re.I,
    )
    _THOUGHT_RE = re.compile(r"(?:\blet's\b|\bwe can\b|\bif we\b|\bthen\b|\bassume\b)", re.I)

    @staticmethod
    def format_reward(completions: List[str], **kwargs) -> List[float]:
        """Check for the <reasoning>...</reasoning><answer>...</answer> format.

        Returns 1.0 for each completion that contains a reasoning block followed
        (after optional whitespace) by an answer block, else 0.0.
        """
        return [1.0 if RewardFunctions._FORMAT_RE.search(c) else 0.0 for c in completions]

    @staticmethod
    def _normalize(text: str) -> str:
        """Canonicalize an answer string for comparison.

        Removes stray <answer>/</answer> tags, lowercases and trims, drops
        trailing sentence punctuation (ASCII and full-width/CJK forms),
        collapses runs of whitespace, and strips a leading "the answer is",
        "answer:" or "result:" prefix.
        """
        text = re.sub(r"</?answer>", "", text, flags=re.IGNORECASE)
        text = text.lower().strip()
        text = re.sub(r'[.\u3002?!\uff01\uff1f]+$', '', text)
        text = " ".join(text.split())
        text = re.sub(r'^(the answer is|answer:|result:)\s*', '', text)
        return text

    @staticmethod
    def _extract_answer(text: str) -> str:
        """Return the stripped contents of the first <answer> block, or the whole stripped text if none."""
        match = RewardFunctions._ANSWER_RE.search(text)
        return match.group(1).strip() if match else text.strip()

    @staticmethod
    def _answer_score(norm_c: str, norm_ref: str) -> float:
        """Score two normalized answers.

        1.0 for an exact match. If one non-empty answer is a substring of the
        other, partial credit scaled by their length ratio: ``0.5 * ratio`` when
        the ratio exceeds 0.5, otherwise a flat 0.2. Anything else scores 0.0.
        """
        if norm_c == norm_ref:
            return 1.0
        if not norm_c or not norm_ref:
            return 0.0
        if norm_ref in norm_c or norm_c in norm_ref:
            ratio = min(len(norm_c), len(norm_ref)) / max(len(norm_c), len(norm_ref))
            return 0.5 * ratio if ratio > 0.5 else 0.2
        return 0.0

    @staticmethod
    def accuracy_reward(completions: List[str], output: Optional[Union[str, List[str]]] = None, **kwargs) -> List[float]:
        """Compare model completions to the reference output.

        Args:
            completions: Model completions. The answer is taken from the first
                <answer> block, or the whole completion if there is none.
            output: Reference answer(s). A single string is used for every
                completion; a list is paired with completions by position.
                References may themselves contain <answer> tags.

        Returns:
            One score per completion (see ``_answer_score``). Completions with
            no reference (``output`` is None, or the list is shorter than
            ``completions``) score 0.0.
        """
        if output is None:
            return [0.0] * len(completions)
        if isinstance(output, str):
            output = [output] * len(completions)

        rewards = []
        for c, ref in zip(completions, output):
            norm_c = RewardFunctions._normalize(RewardFunctions._extract_answer(c))
            norm_ref = RewardFunctions._normalize(RewardFunctions._extract_answer(str(ref)))
            rewards.append(RewardFunctions._answer_score(norm_c, norm_ref))

        # zip() stops at the shorter input; pad so the result stays aligned
        # one-to-one with completions (callers zip it against other rewards).
        rewards.extend([0.0] * (len(completions) - len(rewards)))
        return rewards

    @staticmethod
    def reasoning_reward(completions: List[str], **kwargs) -> List[float]:
        """Reward the presence and apparent quality of reasoning steps.

        Completions without a <reasoning> block score 0.0. Reasoning shorter
        than 20 characters scores a fixed 0.1. Otherwise the score is the sum of:
        a length bonus (0.4 if > 200 chars, 0.2 if > 50), step markers (0.1
        each, max 0.3), logical connectors (0.05 each, max 0.2), and thought
        markers (0.02 each, max 0.1), capped at 1.0.
        """
        rewards = []
        for c in completions:
            match = RewardFunctions._REASONING_RE.search(c)
            if not match:
                rewards.append(0.0)
                continue

            reasoning = match.group(1).strip()
            if len(reasoning) < 20:
                # Too short to assess; a small fixed credit for having the block.
                rewards.append(0.1)
                continue

            step_markers = len(RewardFunctions._STEP_RE.findall(reasoning))
            logical_connectors = len(RewardFunctions._CONNECTOR_RE.findall(reasoning))
            thought_markers = len(RewardFunctions._THOUGHT_RE.findall(reasoning))

            score = 0.0
            if len(reasoning) > 200:
                score += 0.4
            elif len(reasoning) > 50:
                score += 0.2

            score += min(0.3, step_markers * 0.1)
            score += min(0.2, logical_connectors * 0.05)
            score += min(0.1, thought_markers * 0.02)

            rewards.append(min(1.0, score))
        return rewards

    @staticmethod
    def length_penalty(completions: List[str], max_len: int = 1000, **kwargs) -> List[float]:
        """Penalize excessively long completions.

        Completions of at most ``max_len`` characters score 1.0. Beyond that the
        score falls linearly with the overshoot, reaching 0.0 at
        ``2 * max_len`` characters and staying at 0.0 for longer completions.

        Raises:
            ValueError: If ``max_len`` is not positive.
        """
        if max_len <= 0:
            raise ValueError(f"max_len must be positive, got {max_len}")
        return [
            # Scale by the overshoot past max_len rather than the total length,
            # so the penalty starts at 1.0 and decays instead of dropping to 0.
            1.0 if len(c) <= max_len else max(0.0, 1.0 - (len(c) - max_len) / max_len)
            for c in completions
        ]

    @staticmethod
    def combined_reward(completions: List[str], **kwargs) -> List[float]:
        """Combine format, accuracy, reasoning, and length rewards.

        Weighted sum per completion: 0.15 * format + 0.55 * accuracy
        + 0.2 * reasoning + 0.1 * length. The weights sum to 1.0. All keyword
        arguments (e.g. ``output``, ``max_len``) are forwarded to every component.
        """
        f_rewards = RewardFunctions.format_reward(completions, **kwargs)
        a_rewards = RewardFunctions.accuracy_reward(completions, **kwargs)
        r_rewards = RewardFunctions.reasoning_reward(completions, **kwargs)
        l_rewards = RewardFunctions.length_penalty(completions, **kwargs)

        return [
            f * 0.15 + a * 0.55 + r * 0.2 + l * 0.1
            for f, a, r, l in zip(f_rewards, a_rewards, r_rewards, l_rewards)
        ]
|
|