"""
Collection of various reward signals for the arithmetic problem.
"""
|
|
|
import logging
import re
from typing import Any

from src.utils.string_helper import (
    extract_answers_from_completions,
    extract_response_from_completions,
)
|
|
|
# Configure the root logger at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
# Named logger shared by the reward functions in this module.
logger = logging.getLogger("rewards")
|
|
|
|
|
def _is_valid_arithmetic_expression(expression: str) -> bool:
    """
    Check if a string is a valid arithmetic expression containing only:
    - Numbers (integers only)
    - Arithmetic operators: +, -, x, /
    - Whitespace

    Beyond the character whitelist, the expression must contain at least one
    digit and one operator, must not place two operators next to each other
    (which also rules out ``**`` and ``//``), must not contain two numbers
    separated only by whitespace, and must parse as a Python expression once
    ``x`` is mapped to ``*``.

    Division by zero is deliberately reported as valid: the expression is
    well-formed even though it has no value.

    Args:
        expression: The expression to validate

    Returns:
        bool: True if valid arithmetic expression, False otherwise
    """
    if not expression or not expression.strip():
        return False

    # Character whitelist: letters (other than "x"), parentheses, dots, etc.
    # are rejected before anything is evaluated.
    if not re.match(r"^[\d\s\+\-x\/]+$", expression):
        return False

    has_number = re.search(r"\d", expression)
    has_operator = re.search(r"[\+\-x\/]", expression)
    if not (has_number and has_operator):
        return False

    # Whitespace is stripped below, which would silently merge "1 2 + 3" into
    # "12+3". Two numbers with no operator between them are not an expression.
    if re.search(r"\d\s+\d", expression):
        return False

    normalized = "".join(expression.replace("x", "*").split())

    # Consecutive operators ("1 ++ 2", "2 x x 3", "4 / / 2") are not allowed.
    if re.search(r"[\+\-\*\/]{2,}", normalized):
        return False

    try:
        # NOTE(security): eval() runs on model-generated text. It is only
        # reached after the whitelist above (digits, whitespace, + - * /), and
        # builtins are removed as defence in depth.
        eval(normalized, {"__builtins__": {}}, {})
    except ZeroDivisionError:
        # Syntactically fine; it just has no value.
        return True
    except Exception:
        # SyntaxError (e.g. trailing operator, leading zeros), OverflowError, ...
        return False
    return True
|
|
|
|
|
def _calculate_distance_based_reward(
    answer: str,
    correct_answer: float,
    *,
    max_reward: float = 2.0,
    penalty_per_unit: float = 0.2,
) -> float:
    """
    Calculate reward based on distance from the correct answer.

    Uses linear scaling: reward = max(0, max_reward - (distance * penalty_per_unit))

    Args:
        answer: The arithmetic expression to evaluate (``x`` denotes
            multiplication)
        correct_answer: The expected result
        max_reward: Reward granted for an exact match
        penalty_per_unit: Reward deducted per unit of absolute distance

    Returns:
        float: Reward between 0.0 and ``max_reward`` based on distance from
        the correct answer. Empty or invalid expressions, and expressions
        whose evaluation fails (e.g. division by zero), yield 0.0.
    """
    if not answer or not answer.strip():
        return 0.0

    if not _is_valid_arithmetic_expression(answer):
        return 0.0

    try:
        # NOTE(security): eval() on model-generated text; only reached after
        # _is_valid_arithmetic_expression has accepted the input.
        result = eval(answer.replace("x", "*"))
        if not isinstance(result, (int, float)):
            return 0.0
        distance = abs(result - correct_answer)
    except Exception:
        # Best effort: an unevaluable answer (ZeroDivisionError, SyntaxError,
        # OverflowError, a non-numeric target, ...) earns nothing. Unlike a
        # bare ``except``, this lets KeyboardInterrupt/SystemExit propagate.
        return 0.0

    # Tolerance absorbs floating-point noise introduced by division.
    if distance < 0.0001:
        return max_reward

    return max(0.0, max_reward - (distance * penalty_per_unit))
|
|
|
|
|
def format_reward_functiondef(
    completions: list[list[dict[str, str]]], **kwargs: Any
) -> list[float]:
    """
    Reward function that checks if a completion contains <think>...</think>
    followed by <answer>...</answer> sections.

    The match is a search, so text before, between and after the two sections
    is tolerated; tag contents may span multiple lines.

    Args:
        completions: List of completions of the format:
            [
                [
                    {"role": "user", "content": "..."},
                    {"role": "assistant", "content": "..."},
                ]
            ]
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        List of rewards: 1.0 when the pattern is found, 0.0 otherwise.
    """
    pattern = re.compile(r"<think>.*?</think>.*?<answer>.*?</answer>", re.DOTALL)
    # Presumably one response string per completion; see
    # extract_response_from_completions for the exact contract.
    responses = extract_response_from_completions(completions)
    return [1.0 if pattern.search(response) else 0.0 for response in responses]
|
|
|
|
|
def arithmetic_format_reward_function(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """
    Reward function that checks if the content of the answer tag is a valid arithmetic expression.

    The answer should contain only numbers, arithmetic operators (+, -, x, /),
    and spaces. Examples of valid formats:
    - "1 + 2 x 6 / 3"
    - "2 x 1 + 3 - 1"
    - "4 + 5 x 2 - 1"

    Args:
        completions: List of completions of the format:
            [
                [
                    {"role": "user", "content": "..."},
                    {"role": "assistant", "content": "..."},
                ]
            ]
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        List of rewards (1.0 for valid arithmetic expressions, 0.0 otherwise).
    """
    # Presumably one extracted answer string per completion; see
    # extract_answers_from_completions for the exact contract.
    answers = extract_answers_from_completions(completions)
    return [
        1.0 if _is_valid_arithmetic_expression(answer) else 0.0 for answer in answers
    ]
|
|
|
|
|
def correctness_reward_function(
    completions: list[list[dict[str, str]]], **kwargs: Any
) -> list[float]:
    """
    Reward function that provides rewards based on how close the arithmetic answer is to the correct result.

    The reward is calculated using linear scaling:
    - Perfect match (distance = 0): reward = 2.0
    - Each unit of distance reduces reward by 0.2 points
    - Minimum reward is 0.0
    - Invalid expressions get 0.0

    Args:
        completions: List of completions of the format:
            [
                [
                    {"role": "user", "content": "..."},
                    {"role": "assistant", "content": "..."},
                ]
            ]
        **kwargs: Must contain 'correct_answer' key with the expected result.
            It may be a single number applied to every completion, or a
            list/tuple holding one expected result per completion.

    Returns:
        List of rewards (0.0 to 2.0 based on distance from correct answer).

    Raises:
        KeyError: If the correct answer is not provided in the kwargs.
    """
    correct_answer = kwargs["correct_answer"]

    answers = extract_answers_from_completions(completions)

    # Guarded so an empty batch yields [] instead of raising IndexError.
    if completions and answers:
        logger.info("First completion: %s", completions[0][-1]["content"])
        logger.info("First answer: %s", answers[0])

    # A per-sample list must be paired element-wise with the answers; handing
    # the whole list to each comparison would make every reward 0.0.
    if isinstance(correct_answer, (list, tuple)):
        targets = correct_answer
    else:
        targets = [correct_answer] * len(answers)

    return [
        _calculate_distance_based_reward(answer, target)
        for answer, target in zip(answers, targets, strict=False)
    ]
|
|
|
|
|
| def mathematical_correctness_reward_function(
|
| completions: list[str], **kwargs
|
| ) -> list[float]:
|
| """
|
| Evaluates completions based on Mathematical correctness of the answer
|
|
|
| Args:
|
| completions: Generated outputs
|
| target: Expected answers
|
| **kwargs: Additional keyword arguments
|
|
|
| Returns:
|
| list[float]: Reward scores (1.0 for correct, 0.0 for incorrect)
|
| """
|
| completions = [completion[-1]["content"] for completion in completions]
|
| target = kwargs["correct_answer"]
|
| first_nums = kwargs["num1"]
|
| second_nums = kwargs["num2"]
|
| third_nums = kwargs["num3"]
|
| fourth_nums = kwargs["num4"]
|
| rewards = []
|
|
|
|
|
| logger.info("Completion:\n%s", completions[0])
|
|
|
| for completion, gt, first_num, second_num, third_num, fourth_num in zip(
|
| completions,
|
| target,
|
| first_nums,
|
| second_nums,
|
| third_nums,
|
| fourth_nums,
|
| strict=False,
|
| ):
|
| reward = 0.0
|
| try:
|
|
|
| match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
|
| if match is None:
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| logger.info(
|
| "β β FORMAT ERROR: No <answer> tags found in completion β"
|
| )
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€"
|
| )
|
| logger.info(
|
| "β Completion snippet: %-47s β",
|
| completion[:47] + "..." if len(completion) > 47 else completion,
|
| )
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| rewards.append(reward)
|
| continue
|
|
|
|
|
| reward += 1.0
|
|
|
|
|
| equation = match.group(1).strip()
|
| if "=" in equation:
|
| equation = equation.split("=")[0]
|
|
|
|
|
| used_numbers = [int(n) for n in re.findall(r"\d+", equation)]
|
|
|
|
|
| correct_numbers = [first_num, second_num, third_num, fourth_num]
|
| if sorted(used_numbers) != sorted(correct_numbers):
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| logger.info(
|
| "β β NUMBER USAGE ERROR: Incorrect numbers used β"
|
| )
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€"
|
| )
|
| logger.info("β Equation: %-57s β", equation[:57])
|
| logger.info("β Expected numbers: %-51s β", str(correct_numbers))
|
| logger.info("β Used numbers: %-55s β", str(used_numbers))
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| rewards.append(reward)
|
| continue
|
|
|
|
|
| reward += 1.0
|
|
|
|
|
| allowed_pattern = r"^[\d+\-*/.\s]+$"
|
| if not re.match(allowed_pattern, equation):
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| logger.info(
|
| "β β INVALID CHARACTERS: Equation contains disallowed characters β"
|
| )
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€"
|
| )
|
| logger.info("β Equation: %-57s β", equation[:57])
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| rewards.append(reward)
|
| continue
|
|
|
|
|
| reward += 1.0
|
|
|
|
|
| result = eval(equation, {"__builtins__": None}, {})
|
|
|
|
|
| if abs(float(result) - float(gt)) < 1e-5:
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| logger.info(
|
| "β β
CORRECT ANSWER: Perfect match! β"
|
| )
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€"
|
| )
|
| logger.info(
|
| "β Equation: %-35s = %-20s β", equation[:35], str(result)[:20]
|
| )
|
| logger.info("β Target: %-59s β", str(gt))
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| reward += 4.0
|
| rewards.append(reward)
|
| else:
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| logger.info(
|
| "β β WRONG RESULT: Equation evaluated to incorrect value β"
|
| )
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€"
|
| )
|
| logger.info(
|
| "β Equation: %-35s = %-20s β", equation[:35], str(result)[:20]
|
| )
|
| logger.info("β Expected: %-57s β", str(gt))
|
| logger.info(
|
| "β Difference: %-55s β", str(abs(float(result) - float(gt)))
|
| )
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| rewards.append(reward)
|
| except Exception as e:
|
|
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| logger.info(
|
| "β β EVALUATION ERROR: Exception occurred during processing β"
|
| )
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€"
|
| )
|
| logger.info("β Error: %-61s β", str(e)[:61])
|
| logger.info(
|
| "β Equation: %-57s β",
|
| (equation if "equation" in locals() else "N/A")[:57],
|
| )
|
| logger.info(
|
| "βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| )
|
| rewards.append(reward)
|
| return rewards
|
|
|