# qwen-trainer-scripts / rewards.py
# Uploaded via huggingface_hub by mindchain (commit 78a0ca9, verified).
import re
from typing import List, Optional, Any, Union
class RewardFunctions:
    """Reward functions for training on ``<reasoning>...</reasoning><answer>...</answer>``
    formatted completions (GRPO-style: each function maps a batch of completion
    strings to one float reward per completion).

    All patterns are compiled once at class creation — these functions run on
    every training step, so per-call ``re`` module-cache lookups are hoisted out.
    """

    # Whole-completion format check: reasoning block followed by answer block.
    _FORMAT_RE = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)
    # Strips stray opening/closing <answer> tags left inside an extracted answer.
    _ANSWER_TAG_RE = re.compile(r"</?answer>", re.IGNORECASE)
    # Extracts the answer / reasoning payloads.
    _ANSWER_EXTRACT_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
    _REASONING_EXTRACT_RE = re.compile(r"<reasoning>(.*?)</reasoning>", re.DOTALL | re.IGNORECASE)
    # Trailing ASCII and CJK sentence punctuation (. 。 ? ! ！ ？).
    _TRAILING_PUNCT_RE = re.compile(r'[.\u3002?!\uff01\uff1f]+$')
    # Boilerplate answer prefixes to drop before comparison (input is lowercased first).
    _ANSWER_PREFIX_RE = re.compile(r'^(the answer is|answer:|result:)\s*')
    # Heuristic markers of structured reasoning.
    _STEP_RE = re.compile(
        r"(?:step\s*\d+)|(?:\d+\.)|(?:\bfirst\b|\bsecond\b|\bthird\b|\bfinally\b)", re.I)
    _CONNECTOR_RE = re.compile(
        r"(?:\btherefore\b|\bthus\b|\bbecause\b|\bhence\b|\bso\b|\bsince\b|\bconsequently\b)", re.I)
    _THOUGHT_RE = re.compile(
        r"(?:\blet's\b|\bwe can\b|\bif we\b|\bthen\b|\bassume\b)", re.I)

    @staticmethod
    def format_reward(completions: List[str], **kwargs) -> List[float]:
        """Return 1.0 per completion containing a
        ``<reasoning>...</reasoning><answer>...</answer>`` sequence, else 0.0."""
        return [1.0 if RewardFunctions._FORMAT_RE.search(c) else 0.0 for c in completions]

    @staticmethod
    def accuracy_reward(completions: List[str],
                        output: Optional[Union[str, List[str]]] = None,
                        **kwargs) -> List[float]:
        """Compare each completion's answer to the reference ``output``.

        Answers are pulled from ``<answer>`` tags when present (full text
        otherwise) and normalized (lowercase, whitespace-collapsed, trailing
        punctuation and "the answer is"-style prefixes removed) before
        comparison.

        Returns:
            One reward per completion: 1.0 for an exact normalized match,
            partial credit (0.2 or 0.5*length-ratio) when one normalized answer
            is a substring of the other, 0.0 otherwise.

        Always returns exactly ``len(completions)`` rewards: if ``output`` is a
        list shorter than ``completions``, the unmatched tail gets 0.0 (the
        original ``zip`` silently truncated, misaligning rewards/completions).
        """
        if output is None:
            return [0.0] * len(completions)
        if isinstance(output, str):
            output = [output] * len(completions)

        def _normalize(text: str) -> str:
            # Drop stray <answer> tags, lowercase, strip trailing punctuation,
            # collapse whitespace, then drop boilerplate prefixes.
            text = RewardFunctions._ANSWER_TAG_RE.sub("", text)
            text = text.lower().strip()
            text = RewardFunctions._TRAILING_PUNCT_RE.sub("", text)
            text = " ".join(text.split())
            return RewardFunctions._ANSWER_PREFIX_RE.sub("", text)

        def _extract_answer(text: str) -> str:
            # Prefer the <answer> payload; fall back to the whole text.
            m = RewardFunctions._ANSWER_EXTRACT_RE.search(text)
            return m.group(1).strip() if m else text.strip()

        rewards: List[float] = []
        for c, ref in zip(completions, output):
            norm_c = _normalize(_extract_answer(c))
            norm_ref = _normalize(_extract_answer(str(ref)))
            if norm_c == norm_ref:
                rewards.append(1.0)
            elif norm_ref in norm_c or norm_c in norm_ref:
                # Partial credit when one answer contains the other
                # (e.g. "42" inside "the final value is 42"), scaled by overlap.
                if norm_c and norm_ref:
                    ratio = min(len(norm_c), len(norm_ref)) / max(len(norm_c), len(norm_ref))
                    rewards.append(0.5 * ratio if ratio > 0.5 else 0.2)
                else:
                    # Empty string is a substring of everything — no credit.
                    rewards.append(0.0)
            else:
                rewards.append(0.0)
        # BUG FIX: keep rewards aligned with completions when output is short.
        if len(rewards) < len(completions):
            rewards.extend([0.0] * (len(completions) - len(rewards)))
        return rewards

    @staticmethod
    def reasoning_reward(completions: List[str], **kwargs) -> List[float]:
        """Score the presence and heuristic quality of ``<reasoning>`` content.

        Returns:
            Per completion: 0.0 with no reasoning tags; otherwise a score in
            [0, 1] built from length tiers (+0.4 over 200 chars, +0.2 over 50)
            plus capped bonuses for step markers (+0.1 each, max 0.3), logical
            connectors (+0.05 each, max 0.2) and thought markers (+0.02 each,
            max 0.1). Tagged-but-nearly-empty reasoning (<20 chars) is pinned
            to 0.1 to discourage token-gaming the tags.
        """
        rewards: List[float] = []
        for c in completions:
            match = RewardFunctions._REASONING_EXTRACT_RE.search(c)
            if not match:
                rewards.append(0.0)
                continue
            reasoning = match.group(1).strip()
            step_markers = len(RewardFunctions._STEP_RE.findall(reasoning))
            logical_connectors = len(RewardFunctions._CONNECTOR_RE.findall(reasoning))
            thought_markers = len(RewardFunctions._THOUGHT_RE.findall(reasoning))
            score = 0.0
            if len(reasoning) > 200:
                score += 0.4
            elif len(reasoning) > 50:
                score += 0.2
            score += min(0.3, step_markers * 0.1)
            score += min(0.2, logical_connectors * 0.05)
            score += min(0.1, thought_markers * 0.02)
            if len(reasoning) < 20:
                # Tags present but effectively no reasoning: flat token score.
                score = 0.1
            rewards.append(min(1.0, score))
        return rewards

    @staticmethod
    def length_penalty(completions: List[str], max_len: int = 1000, **kwargs) -> List[float]:
        """Gradually penalize completions longer than ``max_len`` characters.

        Returns 1.0 up to ``max_len``, then decays linearly with the *excess*
        length, reaching 0.0 at ``2 * max_len``.

        BUG FIX: the original computed ``max(0.0, 1.0 - len(c)/max_len)`` for
        over-length completions, which is always 0.0 whenever
        ``len(c) > max_len`` — a hard cliff with a dead clamp, not the graded
        penalty the expression intended.
        """
        return [
            max(0.0, 1.0 - (len(c) - max_len) / max_len) if len(c) > max_len else 1.0
            for c in completions
        ]

    @staticmethod
    def combined_reward(completions: List[str], **kwargs) -> List[float]:
        """Weighted blend of the individual rewards per completion.

        Weights: 15% format, 55% accuracy, 20% reasoning, 10% length.
        ``kwargs`` (e.g. ``output=``) is forwarded to each component.
        """
        f_rewards = RewardFunctions.format_reward(completions, **kwargs)
        a_rewards = RewardFunctions.accuracy_reward(completions, **kwargs)
        r_rewards = RewardFunctions.reasoning_reward(completions, **kwargs)
        l_rewards = RewardFunctions.length_penalty(completions, **kwargs)
        return [
            f * 0.15 + a * 0.55 + r * 0.2 + l * 0.1
            for f, a, r, l in zip(f_rewards, a_rewards, r_rewards, l_rewards)
        ]