import re
from typing import Any, List, Optional, Union


class RewardFunctions:
    """Reward functions for completions in the ``<think>...</think>
    <answer>...</answer>`` output format.

    Each reward function takes a list of completion strings and returns one
    float in [0.0, 1.0] per completion. Extra keyword arguments are accepted
    (and mostly ignored) so all functions share a uniform signature for use
    in RL training loops that fan kwargs out to every reward function.
    """

    # Compiled once at class creation — every method reuses these on each
    # completion, so hoisting avoids recompilation-lookup in hot loops.
    _FORMAT_RE = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    _ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
    _THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
    _TAG_RE = re.compile(r"</?(?:think|answer)>", re.IGNORECASE)

    @staticmethod
    def format_reward(completions: List[str], **kwargs: Any) -> List[float]:
        """Check for ``<think>...</think><answer>...</answer>`` format.

        Returns 1.0 per completion containing a think block followed by an
        answer block (whitespace allowed between them), else 0.0.
        """
        # NOTE: the original pattern had its tags stripped (r".*?\s*.*?"),
        # which matched everything; restored to the intended tag pattern.
        fmt = RewardFunctions._FORMAT_RE
        return [1.0 if fmt.search(c) else 0.0 for c in completions]

    @staticmethod
    def accuracy_reward(
        completions: List[str],
        output: Optional[Union[str, List[str]]] = None,
        **kwargs: Any,
    ) -> List[float]:
        """Compare model completions to the reference output.

        Robustly extracts answers from ``<answer>`` tags (falling back to the
        whole string) and normalizes both sides before comparing.

        Args:
            completions: Model completion strings.
            output: Reference answer(s). A single string is broadcast to all
                completions; ``None`` yields all-zero rewards.

        Returns:
            1.0 for an exact normalized match, partial credit (0.2-0.5) when
            one normalized answer is a substring of the other, else 0.0.
        """
        if output is None:
            return [0.0] * len(completions)
        if isinstance(output, str):
            output = [output] * len(completions)

        def normalize(text: str) -> str:
            # Remove tags if they still exist inside the extracted answer.
            text = RewardFunctions._TAG_RE.sub("", text)
            text = text.lower().strip()
            # Strip trailing punctuation (ASCII and fullwidth CJK variants).
            text = re.sub(r'[.\u3002?!\uff01\uff1f]+$', '', text)
            # Collapse internal whitespace runs.
            text = " ".join(text.split())
            # Drop a common "the answer is" / "answer:" / "result:" prefix.
            text = re.sub(r'^(the answer is|answer:|result:)\s*', '', text)
            return text

        answer_re = RewardFunctions._ANSWER_RE
        rewards: List[float] = []
        for c, ref in zip(completions, output):
            # Extract from <answer> tags when present; else use the raw text.
            c_match = answer_re.search(c)
            c_answer = c_match.group(1).strip() if c_match else c.strip()
            ref_match = answer_re.search(str(ref))
            ref_answer = ref_match.group(1).strip() if ref_match else str(ref).strip()

            norm_c = normalize(c_answer)
            norm_ref = normalize(ref_answer)
            if norm_c == norm_ref:
                rewards.append(1.0)
            elif norm_c and norm_ref and (norm_ref in norm_c or norm_c in norm_ref):
                # Partial credit for a substring match (e.g. "42" inside
                # "the answer is 42"), scaled by how much of the longer
                # string the shorter one covers.
                ratio = min(len(norm_c), len(norm_ref)) / max(len(norm_c), len(norm_ref))
                rewards.append(0.5 * ratio if ratio > 0.5 else 0.2)
            else:
                rewards.append(0.0)
        return rewards

    @staticmethod
    def reasoning_reward(completions: List[str], **kwargs: Any) -> List[float]:
        """Reward presence and quality of reasoning inside ``<think>`` tags.

        Scores combine reasoning length, enumerated-step markers, logical
        connectors, and exploratory phrasing, capped at 1.0. Completions
        without a think block score 0.0.
        """
        rewards: List[float] = []
        for c in completions:
            match = RewardFunctions._THINK_RE.search(c)
            if not match:
                rewards.append(0.0)
                continue
            reasoning = match.group(1).strip()
            # Step markers: "step 3", "1.", or ordinal words.
            step_markers = len(re.findall(
                r"(?:step\s*\d+)|(?:\d+\.)|(?:\bfirst\b|\bsecond\b|\bthird\b|\bfinally\b)",
                reasoning, re.I))
            # Logical connectors indicating deductive structure.
            logical_connectors = len(re.findall(
                r"(?:\btherefore\b|\bthus\b|\bbecause\b|\bhence\b|\bso\b|\bsince\b|\bconsequently\b)",
                reasoning, re.I))
            # Exploratory "thinking out loud" phrasing.
            thought_markers = len(re.findall(
                r"(?:\blet's\b|\bwe can\b|\bif we\b|\bthen\b|\bassume\b)",
                reasoning, re.I))

            # Base score from length, bonuses from marker diversity.
            score = 0.0
            if len(reasoning) > 200:
                score += 0.4
            elif len(reasoning) > 50:
                score += 0.2
            score += min(0.3, step_markers * 0.1)
            score += min(0.2, logical_connectors * 0.05)
            score += min(0.1, thought_markers * 0.02)
            # Tags present but nearly empty: flat token score of 0.1.
            if len(reasoning) < 20:
                score = 0.1
            rewards.append(min(1.0, score))
        return rewards

    @staticmethod
    def length_penalty(completions: List[str], max_len: int = 1000, **kwargs: Any) -> List[float]:
        """Penalize excessively long completions.

        Completions up to ``max_len`` characters score 1.0; beyond that the
        score decays linearly with the excess, reaching 0.0 at 2 * max_len.

        BUG FIX: the original computed ``1.0 - len(c)/max_len`` which is <= 0
        for every over-length completion, so the ``max(0.0, ...)`` clamp
        produced a hard 0.0 cliff instead of the intended graded penalty.
        """
        max_len = max(1, max_len)  # guard against ZeroDivisionError
        return [
            max(0.0, 1.0 - (len(c) - max_len) / max_len) if len(c) > max_len else 1.0
            for c in completions
        ]

    @staticmethod
    def combined_reward(completions: List[str], **kwargs: Any) -> List[float]:
        """Combine format, accuracy, reasoning, and length rewards.

        Weights: 15% format, 55% accuracy, 20% reasoning, 10% length.
        Forwards **kwargs so ``accuracy_reward`` receives ``output``.
        """
        f_rewards = RewardFunctions.format_reward(completions, **kwargs)
        a_rewards = RewardFunctions.accuracy_reward(completions, **kwargs)
        r_rewards = RewardFunctions.reasoning_reward(completions, **kwargs)
        l_rewards = RewardFunctions.length_penalty(completions, **kwargs)
        return [
            f * 0.15 + a * 0.55 + r * 0.2 + l * 0.1
            for f, a, r, l in zip(f_rewards, a_rewards, r_rewards, l_rewards)
        ]