| import os |
| import re |
| from typing import Dict, List, Union |
|
|
| import json |
|
|
| from swift.llm import InferRequest |
|
|
|
|
class ORM:
    """Base class for outcome reward models (ORMs).

    Subclasses implement ``__call__`` and return one float reward per sample.
    """

    def __call__(self, **kwargs) -> List[float]:
        """Return one float reward per sample; subclasses must override."""
        raise NotImplementedError
|
|
|
|
class ReactORM(ORM):
    """ORM that scores ReAct-style (Action / Action Input) tool-call outputs.

    A prediction earns reward 1.0 only when its parsed action matches the
    reference action and the action inputs match (key-wise for JSON inputs,
    via ROUGE-L for free-text inputs).
    """

    @staticmethod
    def evaluate_action_reward(action_pred: list, action_ref: list, cand_list: list, ref_list: list):
        """Score predicted actions and inputs against references.

        Args:
            action_pred: Predicted action names.
            action_ref: Reference action names.
            cand_list: Predicted action-input strings (JSON or free text).
            ref_list: Reference action-input strings.

        Returns:
            bool: True iff the FIRST sample's F1 equals 1.0. Callers pass
            single-element lists, so this acts as a per-sample exact check.
        """
        f1 = []
        for i in range(len(action_pred)):
            ref_action = action_ref[i]
            pred_action = action_pred[i]

            ref_input = ref_list[i]
            cand_input = cand_list[i]

            # Try to interpret each input as JSON; fall back to the raw text.
            ref_is_json = False
            try:
                ref_input_json = json.loads(ref_input)
                ref_is_json = True
            except Exception:
                ref_input_json = ref_input

            cand_is_json = False
            try:
                cand_input_json = json.loads(cand_input)
                cand_is_json = True
            except Exception:
                cand_input_json = cand_input

            # Different action, or one input is JSON while the other is not.
            if ref_action != pred_action or (ref_is_json ^ cand_is_json):
                f1.append(0)
            elif not ref_is_json and not cand_is_json:
                # Both free text: compare with ROUGE-L.
                # NOTE(review): `rouge` f-scores are in [0, 1], but these
                # thresholds read like percentages (10/20) — with the `rouge`
                # package this branch always appends 0; confirm intended scale.
                rougel = ReactORM.evaluate_rougel([ref_input_json], [cand_input_json])
                if rougel is None or rougel < 10:
                    f1.append(0)
                elif 10 <= rougel < 20:
                    f1.append(0.1)
                else:
                    f1.append(1)
            else:
                # Both parsed as JSON but not both dicts (e.g. numbers or
                # lists): treat as a miss.
                if not isinstance(ref_input_json, dict) or not isinstance(cand_input_json, dict):
                    f1.append(0)
                    continue

                # Key-wise comparison: an exact key+value match counts fully,
                # a key present with a different value counts half.
                half_match = 0
                full_match = 0
                if ref_input_json == {}:
                    if cand_input_json == {}:
                        f1.append(1)
                    else:
                        f1.append(0)
                else:
                    for k, v in ref_input_json.items():
                        if k in cand_input_json.keys():
                            if cand_input_json[k] == v:
                                full_match += 1
                            else:
                                half_match += 1

                    # F1 over the weighted match counts; the epsilon guards
                    # against a zero-length candidate dict.
                    recall = (0.5 * half_match + full_match) / (len(ref_input_json) + 1e-30)
                    precision = (0.5 * half_match + full_match) / (len(cand_input_json) + 1e-30)
                    try:
                        f1.append((2 * recall * precision) / (recall + precision))
                    except Exception:
                        f1.append(0.0)

        # Only the first element is inspected (single-sample usage).
        if f1[0] == 1.0:
            return True
        else:
            return False

    @staticmethod
    def parse_action(text):
        """Extract the last ``Action:`` / ``Action Input:`` pair from a ReAct trace.

        Returns:
            tuple: ``(action, action_input)``; defaults are ``'none'`` and
            ``'{}'`` when the respective marker is absent.
        """
        if 'Action Input:' in text:
            input_idx = text.rindex('Action Input:')
            action_input = text[input_idx + len('Action Input:'):].strip()
        else:
            action_input = '{}'

        if 'Action:' in text:
            action_idx = text.rindex('Action:')
            action = text[action_idx + len('Action:'):].strip()
            if 'Action Input:' in action:
                # Keep only the action name, dropping the trailing input part.
                input_idx = action.index('Action Input:')
                action = action[:input_idx].strip()
        else:
            action = 'none'
        return action, action_input

    @staticmethod
    def parse_output(text):
        """Thin wrapper over :meth:`parse_action`; returns (action, action_input)."""
        action, action_input = ReactORM.parse_action(text)
        return action, action_input

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], solution: List[str], **kwargs) -> List[float]:
        """Return a 1.0/0.0 reward per sample for matching tool calls.

        Args:
            infer_requests: Either raw prediction strings or request objects
                whose last message holds the prediction (accessed as
                ``request['messages'][-1]['content']``).
            solution: Reference ReAct traces.
        """
        rewards = []
        if not isinstance(infer_requests[0], str):
            predictions = [request['messages'][-1]['content'] for request in infer_requests]
        else:
            predictions = infer_requests
        for prediction, ground_truth in zip(predictions, solution):
            # Drop a trailing 'Observation:' continuation from the prediction.
            if prediction.endswith('Observation:'):
                prediction = prediction[:prediction.index('Observation:')].strip()
            action_ref = []
            action_input_ref = []
            action_pred = []
            action_input_pred = []
            reference = ground_truth
            # Strip chat special tokens before parsing.
            prediction = prediction.replace('<|endoftext|>', '').replace('<|im_end|>', '').strip()
            ref_action, ref_input = ReactORM.parse_output(reference)
            pred_action, pred_input = ReactORM.parse_output(prediction)
            action_ref.append(ref_action)
            action_input_ref.append(ref_input)
            # parse_output never returns None today; the guards are defensive.
            if pred_action is None:
                action_pred.append('none')
            else:
                action_pred.append(pred_action)

            if pred_input is None:
                action_input_pred.append('{}')
            else:
                action_input_pred.append(pred_input)

            reward = ReactORM.evaluate_action_reward(action_pred, action_ref, action_input_pred, action_input_ref)
            rewards.append(float(reward))
        return rewards

    @staticmethod
    def evaluate_rougel(cand_list: list, ref_list: list):
        """Return the averaged ROUGE-L f-measure of cand_list vs ref_list.

        Returns None for empty references or on any Rouge failure
        (including the `rouge` package being unavailable).
        """
        if len(ref_list) == 0:
            return None
        try:
            from rouge import Rouge
            rouge = Rouge()
            rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True)
            rougel = rouge_score['rouge-l']['f']
            return rougel
        except Exception:
            return None
|
|
|
|
class MathORM(ORM):
    """ORM that rewards math answers equivalent to the ground truth.

    Both prediction and ground truth are reduced to their final answer
    (the text after a trailing ``# Answer`` marker and/or inside
    ``\\boxed{...}``) and then compared, either with opencompass's
    MATHEvaluator (opt-in via the ``USE_OPENCOMPASS_EVALUATOR`` env var)
    or symbolically with sympy.
    """

    def __init__(self):
        from transformers.utils import strtobool
        # Opt in to the opencompass MATH evaluator through the environment.
        self.use_opencompass = strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR', 'False'))
        if self.use_opencompass:
            from opencompass.datasets.math import MATHEvaluator
            self.evaluator = MATHEvaluator()

    @staticmethod
    def check_terminate(answers: Union[str, List[str]]) -> List[bool]:
        """Return, for each answer, whether it contains a ``\\boxed`` marker."""
        if isinstance(answers, str):
            answers = [answers]
        return ['\\boxed' in answer for answer in answers]

    @staticmethod
    def extract_boxed_result(text):
        """Return the content of the first ``\\boxed{...}`` in *text*.

        Falls back to the whole text when no match is found.
        NOTE(review): ``[^}]*`` stops at the first ``}``, so content with
        nested braces is truncated — confirm inputs never nest braces.
        """
        match = re.search(r'\\boxed{([^}]*)}', text)
        return match.group(1).strip() if match else text

    @staticmethod
    def clean_latex(latex_str):
        """Strip math delimiters (``\\(``, ``\\)``, ``\\[``, ``\\]``) and braces."""
        latex_str = re.sub(r'\\\(|\\\)|\\\[|\\]', '', latex_str)
        latex_str = latex_str.replace('}}', '}').replace('{', '').replace('}', '')
        return latex_str.strip()

    @staticmethod
    def parse_expression(latex_str):
        """Parse LaTeX into a simplified sympy expression, or None on failure."""
        from sympy import simplify
        from sympy.parsing.latex import parse_latex
        try:
            expr = parse_latex(latex_str)
            return simplify(expr)
        except Exception:
            return None

    @staticmethod
    def compare_consecutive(first, second):
        """Return True when the two LaTeX strings are (symbolically) equivalent."""
        cleaned_list = [MathORM.clean_latex(latex) for latex in [first, second]]
        parsed_exprs = [MathORM.parse_expression(latex) for latex in cleaned_list]
        # Bug fix: previously two unparseable answers compared as
        # ``None == None`` and always matched, regardless of content.
        # Fall back to comparing the cleaned strings instead.
        if parsed_exprs[0] is None and parsed_exprs[1] is None:
            return cleaned_list[0] == cleaned_list[1]
        if hasattr(parsed_exprs[0], 'equals') and hasattr(parsed_exprs[1], 'equals'):
            value = parsed_exprs[0].equals(parsed_exprs[1])
        else:
            value = parsed_exprs[0] == parsed_exprs[1]
        # ``Expr.equals`` may return None when equivalence is undecidable.
        if value is None:
            value = False
        return value

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
                 **kwargs) -> List[float]:
        """Return a 1.0/0.0 reward per sample for answer equivalence.

        Args:
            infer_requests: ``InferRequest`` objects or plain message dicts;
                the last message's content is the prediction.
            ground_truths: Reference solutions.
        """
        rewards = []
        predictions = []
        for request in infer_requests:
            # Bug fix: the signature allows plain dicts, but attribute access
            # crashed on them; support both shapes (as ReactORM does).
            messages = request['messages'] if isinstance(request, dict) else request.messages
            predictions.append(messages[-1]['content'])
        for prediction, ground_truth in zip(predictions, ground_truths):
            # Keep only the final-answer section when present.
            if '# Answer' in prediction:
                prediction = prediction.split('# Answer')[1]
            if '# Answer' in ground_truth:
                ground_truth = ground_truth.split('# Answer')[1]
            prediction = MathORM.extract_boxed_result(prediction.strip())
            ground_truth = MathORM.extract_boxed_result(ground_truth.strip())
            if self.use_opencompass:
                reward = self.evaluator.is_equiv(prediction, ground_truth)
            else:
                reward = MathORM.compare_consecutive(prediction, ground_truth)
            rewards.append(float(reward))
        return rewards
|
|
|
|
class MathAccuracy(ORM):
    """ORM computing accuracy of LaTeX math answers via ``math_verify``."""

    def __init__(self):
        import importlib.util
        # Raise explicitly instead of using ``assert``: asserts are stripped
        # when Python runs with ``-O``, which would defer the failure to the
        # first __call__ with a much less helpful error.
        if importlib.util.find_spec('math_verify') is None:
            raise ImportError(
                "The math_verify package is required but not installed. "
                "Please install it using 'pip install math_verify'.")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Return 1.0 per sample when the parsed completion verifies against the
        parsed solution, else 0.0."""
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import LatexExtractionConfig, parse, verify
        rewards = []
        for content, sol in zip(completions, solution):
            gold_parsed = parse(sol, extraction_mode='first_match')
            if len(gold_parsed) != 0:
                # Require an anchored (e.g. boxed) LaTeX answer in the
                # completion; bare expressions without an anchor are ignored.
                answer_parsed = parse(
                    content,
                    extraction_config=[
                        LatexExtractionConfig(
                            normalization_config=NormalizationConfig(
                                nits=False,
                                malformed_operators=False,
                                basic_latex=True,
                                equations=True,
                                boxed=True,
                                units=True,
                            ),
                            boxed_match_priority=0,
                            try_extract_without_anchor=False,
                        )
                    ],
                    extraction_mode='first_match',
                )
                try:
                    reward = float(verify(gold_parsed, answer_parsed))
                except Exception:
                    # verify() can raise on malformed expressions; score 0.
                    reward = 0.0
            else:
                # Unparseable gold solution: no reliable signal, score 0.
                reward = 0.0
            rewards.append(reward)
        return rewards
|
|
|
|
class Format(ORM):
    """Rewards completions laid out as '<think>...</think><answer>...</answer>'."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Return 1.0 per completion matching the think/answer layout, else 0.0."""
        # The lookahead (?![\s\S]) forbids any trailing text after </answer>.
        pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
        rewards = []
        for content in completions:
            matched = re.match(pattern, content, re.DOTALL | re.MULTILINE)
            rewards.append(1.0 if matched else 0.0)
        return rewards
|
|
|
|
class ReActFormat(ORM):
    """Rewards completions shaped as '<think>...</think> Action: ... Action Input: ...'."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Return 1.0 per completion matching the ReAct layout, else 0.0."""
        pattern = r'^<think>.*?</think>\s*Action:.*?Action Input:.*?$'
        rewards = []
        for content in completions:
            matched = re.match(pattern, content, re.DOTALL | re.MULTILINE)
            rewards.append(1.0 if matched else 0.0)
        return rewards
|
|
|
|
class CosineReward(ORM):
    """Length-dependent reward following a cosine schedule.

    Correct answers score highest when short and decay toward
    ``cosine_max_len_value_correct`` as they approach ``cosine_max_len``;
    wrong answers are penalized least when short.
    """

    def __init__(self,
                 tokenizer=None,
                 cosine_min_len_value_wrong: float = -0.5,
                 cosine_max_len_value_wrong: float = 0.0,
                 cosine_min_len_value_correct: float = 1.0,
                 cosine_max_len_value_correct: float = 0.5,
                 cosine_max_len: int = 1000,
                 accuracy_orm=None):
        self.tokenizer = tokenizer
        self.min_len_value_wrong = cosine_min_len_value_wrong
        self.max_len_value_wrong = cosine_max_len_value_wrong
        self.min_len_value_correct = cosine_min_len_value_correct
        self.max_len_value_correct = cosine_max_len_value_correct
        self.max_len = cosine_max_len
        # Correctness signal defaults to the math_verify-based accuracy ORM.
        self.accuracy_orm = accuracy_orm or MathAccuracy()

    @staticmethod
    def cosfn(t, T, min_value, max_value):
        """Cosine interpolation from ``max_value`` (t=0) down to ``min_value`` (t=T)."""
        import math
        return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Return a cosine-scheduled, length-aware reward per completion."""
        acc_rewards = self.accuracy_orm(completions, solution, **kwargs)
        rewards = []
        for text, accuracy in zip(completions, acc_rewards):
            if accuracy >= 1.:
                # Correct: start at min_len_value_correct, decay with length.
                endpoint, start = self.max_len_value_correct, self.min_len_value_correct
            else:
                # Wrong: short attempts receive the mildest penalty.
                endpoint, start = self.max_len_value_wrong, self.min_len_value_wrong
            token_count = len(self.tokenizer.encode(text))
            rewards.append(self.cosfn(token_count, self.max_len, endpoint, start))
        return rewards
|
|
|
|
class RepetitionPenalty(ORM):
    """Penalizes completions in proportion to their share of duplicated n-grams."""

    def __init__(self, repetition_n_grams: int = 3, repetition_max_penalty: float = -1.0):
        self.ngram_size = repetition_n_grams
        self.max_penalty = repetition_max_penalty

    @staticmethod
    def zipngram(text: str, ngram_size: int):
        """Return an iterator over all n-grams of the lower-cased, split text."""
        words = text.lower().split()
        return zip(*[words[i:] for i in range(ngram_size)])

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that penalizes repetitions.

        Args:
            completions: List of model completions.
        """
        scores = []
        for text in completions:
            # Empty or too short to contain a single n-gram: no penalty.
            if text == '' or len(text.split()) < self.ngram_size:
                scores.append(0.0)
                continue

            distinct = set()
            total = 0
            for gram in self.zipngram(text, self.ngram_size):
                distinct.add(gram)
                total += 1

            # Fraction of duplicated n-grams scales the (negative) penalty.
            duplication_ratio = 1 - len(distinct) / total
            scores.append(duplication_ratio * self.max_penalty)
        return scores
|
|
|
|
class SoftOverlong(ORM):
    """Soft overlong punishment: linearly penalizes completions whose token
    count exceeds ``soft_max_length - soft_cache_length``.

    The penalty is 0 within budget, reaches -1 at ``soft_max_length`` and
    keeps decreasing beyond it.
    """

    def __init__(self, tokenizer, soft_max_length, soft_cache_length):
        """
        Args:
            tokenizer: Tokenizer providing ``encode`` to count tokens.
            soft_max_length: Hard length budget in tokens.
            soft_cache_length: Width of the linear penalty ramp; must be
                smaller than ``soft_max_length``.

        Raises:
            ValueError: If ``soft_cache_length >= soft_max_length``.
        """
        self.tokenizer = tokenizer
        # Validate explicitly: ``assert`` would be stripped under ``python -O``.
        if soft_cache_length >= soft_max_length:
            raise ValueError('soft_cache_length must be smaller than soft_max_length')
        self.soft_max_length = soft_max_length
        self.soft_cache_length = soft_cache_length

    def __call__(self, completions, **kwargs) -> List[float]:
        """Return a non-positive length penalty per completion."""
        rewards = []
        # Loop-invariant: the token budget before the penalty ramp starts.
        expected_len = self.soft_max_length - self.soft_cache_length
        for completion in completions:
            completion_length = len(self.tokenizer.encode(completion))
            exceed_len = completion_length - expected_len
            rewards.append(min(-exceed_len / self.soft_cache_length, 0))
        return rewards
|
|
|
|
# Registry mapping reward-function names to their ORM classes
# (presumably looked up by name from training configuration — verify against callers).
orms = {
    'toolbench': ReactORM,
    'math': MathORM,
    'accuracy': MathAccuracy,
    'format': Format,
    'react_format': ReActFormat,
    'cosine': CosineReward,
    'repetition': RepetitionPenalty,
    'soft_overlong': SoftOverlong,
}
|
|