| from grader import math_equal |
| from parser import strip_string |
| import timeout_decorator |
| from collections import defaultdict, Counter |
| from utils import load_jsonl |
|
|
|
|
@timeout_decorator.timeout(5)
def math_equal_timeout(pred, gt):
    """Compare *pred* and *gt* with math_equal, aborting after 5 seconds.

    Returns False instead of raising on timeout or any comparison error,
    so callers can treat every failure as "not equal".
    """
    try:
        return math_equal(pred, gt)
    except Exception as e:
        # Not necessarily a timeout: math_equal itself may raise on
        # malformed expressions. The original message claimed every
        # failure was a timeout, which is misleading in logs.
        print("math_equal failed (timeout or comparison error):", e)
        return False
|
|
|
|
def group_pred(preds, strip=True, use_symbol=False):
    """Cluster predictions into equivalence groups and pick the majority.

    Args:
        preds: list of prediction strings.
        strip: if True, normalize each prediction with strip_string first.
        use_symbol: if True, group by symbolic math equality via
            math_equal_timeout; otherwise group by exact string match.

    Returns:
        (groups, majority): groups maps a (possibly normalized) prediction
        to the list of indices in its group; majority is the ORIGINAL
        (un-stripped) text of the first member of the largest group.
    """
    original_preds = preds  # keep un-stripped text for the return value
    if not use_symbol:
        # Fast path: exact string matching after optional normalization.
        if strip:
            preds = [strip_string(pred) for pred in preds]
        cnt = Counter(preds)
        majority = cnt.most_common(1)[0][0]
        groups = defaultdict(list)
        for idx, pred in enumerate(preds):
            groups[pred].append(idx)
        return groups, original_preds[groups[majority][0]]

    # Symbolic path: each prediction joins the first existing group whose
    # representative it is mathematically equal to, else starts a new group.
    groups = defaultdict(list)
    for idx, pred in enumerate(preds):
        found_group = False
        if strip:
            pred = strip_string(pred)
        # Fix: the original loop variable was named `group_pred`, shadowing
        # this function's own name inside the loop body.
        for representative in groups:
            try:
                if math_equal_timeout(pred, representative):
                    groups[representative].append(idx)
                    found_group = True
                    break
            except Exception:
                # Fix: was a bare `except:`; comparison failed — try the
                # next group rather than aborting the whole grouping.
                continue
        if not found_group:
            groups[pred].append(idx)

    # Largest group wins; report its first member's original text.
    majority = sorted(groups.items(), key=lambda item: len(item[1]), reverse=True)[0][0]
    majority = original_preds[groups[majority][0]]
    return groups, majority
|
|
|
|
def eval_rm_k_metrics(data_path, k=8):
    """Reward-model best-of-k accuracy.

    For each sample, pick the candidate with the highest RM score among
    the first k and count whether that candidate was correct. Prints the
    sample count and accuracy, and returns the accuracy in percent.
    """
    print(f"evaluating rm@{k}")
    samples = load_jsonl(data_path)

    total, correct = 0, 0
    for sample in samples:
        assert len(sample['pred_score']) >= k, sample['data_source']
        pred_score = sample['pred_score'][:k]
        pred = sample['score'][:k]
        assert len(pred_score) == len(pred), f"{len(pred_score)}, {len(pred)}"

        # Each entry of pred_score is itself a list; flatten one level so
        # scores line up with the per-candidate correctness flags.
        flat_scores = [inner_score for score in pred_score for inner_score in score]
        assert len(flat_scores) == len(pred), f"{len(flat_scores)}, {len(pred)}"

        # Index of the top-scored candidate; its score flag is 0/1 correct.
        best_idx = flat_scores.index(max(flat_scores))
        correct += pred[best_idx]
        total += 1

    print(total)
    task_acc = correct / total * 100
    print(f"acc: {task_acc:.1f}")
    return task_acc
|
|
|
|
def eval_maj_k_metrics(data_path, k=8):
    """Majority-vote accuracy at k.

    Groups the first k predictions of each sample by exact string match,
    takes the majority answer, and scores it against the stored
    per-prediction correctness flags. Returns accuracy in percent.
    """
    print(f"evaluating maj@{k}")

    samples = load_jsonl(data_path)
    total, correct = 0, 0
    for sample in samples:
        assert len(sample['score']) >= k, sample
        groups, majority_pred = group_pred(
            sample['pred'][:k], strip=False, use_symbol=False
        )
        # Score the first member of the majority group — all members of a
        # group are the same answer under exact-match grouping.
        first_idx = groups[majority_pred][0]
        correct += sample['score'][first_idx]
        total += 1

    task_acc = correct / total * 100
    print(f"acc: {task_acc:.1f}")
    return task_acc
|
|
|
|
if __name__ == "__main__":
    # Example evaluation on a 100-sample MATH chain-of-thought dump.
    data_path = "./data/eval_rm_maj_example/math_cot_100.jsonl"

    candidate = 8
    # maj@k is computed before rm@k, matching the original execution order.
    all_result = {
        f'maj@{candidate}': eval_maj_k_metrics(data_path, k=candidate),
        f'rm@{candidate}': eval_rm_k_metrics(data_path, k=candidate),
    }
    print(all_result)
|
|