| import re |
| from tqdm import tqdm |
| from consts import REASONING_END, REASONING_START, SOLUTION_START, SOLUTION_END |
|
|
|
|
def formatting_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion on whether it has exactly one reasoning and one solution section.

    Args:
        completions: Iterable of completion strings to score.
        **kwargs: Ignored. Presumably accepted so the function fits a
            trainer's reward-function calling convention (TODO confirm).

    Returns:
        One float per completion, in ``[0.0, 2.0]``:
        +1.0 if exactly one ``REASONING_START ... REASONING_END`` span is found;
        +1.0 if exactly one ``SOLUTION_START ... SOLUTION_END`` span is found.
    """
    # The delimiters are literal tags, so escape them: tags such as "<|...|>"
    # or "[ANSWER]" contain regex metacharacters that would otherwise alter
    # the pattern. Compile once rather than re-parsing per completion.
    # DOTALL lets a section span multiple lines.
    thinking_re = re.compile(
        f"{re.escape(REASONING_START)}(.*?){re.escape(REASONING_END)}", re.DOTALL
    )
    answer_re = re.compile(
        f"{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}", re.DOTALL
    )

    scores = []
    for completion in tqdm(completions, desc="Computing formatting reward"):
        # Start from a float so every score has the same type.
        score = 0.0
        if len(thinking_re.findall(completion)) == 1:
            score += 1.0
        if len(answer_re.findall(completion)) == 1:
            score += 1.0
        scores.append(score)
    return scores
|
|
|
|
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose single solution section matches the reference answer.

    Args:
        prompts: Batch prompts. Only ``prompts[0]`` is used, and only for the
            debug print.
        completions: Completion strings, aligned index-by-index with ``answer``.
        answer: Reference answers, aligned index-by-index with ``completions``.
        **kwargs: Ignored.

    Returns:
        One float per (completion, answer) pair. The value is 2.0 when the
        completion contains exactly one ``SOLUTION_START ... SOLUTION_END``
        span whose content equals the reference answer. Equality is checked
        after removing newlines from the span and stripping surrounding
        whitespace from both sides. Otherwise the value is 0.0.

        Pairs are formed with ``zip``, so the result has
        ``min(len(completions), len(answer))`` entries.
    """
    # Escape the literal delimiter tags so regex metacharacters in them are
    # matched verbatim. DOTALL lets the solution span multiple lines.
    answer_re = re.compile(
        f"{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}", re.DOTALL
    )

    responses = [
        answer_re.findall(completion)
        for completion in tqdm(
            completions, desc="Extracting responses for correctness"
        )
    ]

    # Debug sample of the first example. Guarded so an empty batch returns []
    # rather than raising IndexError.
    if len(prompts) > 0 and len(answer) > 0 and len(completions) > 0:
        print(
            "-" * 20,
            f"Question:\n{prompts[0]}",
            f"\nAnswer:\n{answer[0]}",
            f"\nResponse:{completions[0]}",
        )

    # Removing only "\n" would leave "\r" from CRLF output and any leading or
    # trailing spaces. Strip those too, and stringify the reference so
    # non-str answers (e.g. ints) can still match.
    return [
        2.0
        if len(r) == 1 and r[0].replace("\n", "").strip() == str(a).strip()
        else 0.0
        for r, a in tqdm(
            zip(responses, answer), desc="Checking correctness", total=len(responses)
        )
    ]
|
|