| |
| |
| import re |
| import json |
| import math |
| import itertools |
| from typing import Any, Dict, List, Optional |
| import random |
|
|
| import torch |
| from rouge_score import rouge_scorer |
| from math_verify import parse as math_parse, verify as math_verify |
| from mathruler.grader import grade_answer |
|
|
| |
| |
| USE_MODEL_FOR_OPEN_ENDED: bool = False |
|
|
| |
| RM_MODEL_PATH = "internlm/POLAR-7B" |
| RM_SERVER_ADDRESS = "xx.xx.xx.xx:xxxx" |
| |
|
|
| |
| from verl.workers.reward.model_reward import RewardModelClient |
| import numpy as np |
| |
|
|
|
|
| |
| |
| |
| THINK_ANSWER_PATTERN = re.compile( |
| r"\A\s*<think>.*?</think>\s*<answer>.*?</answer>\s*\Z", |
| re.DOTALL |
| ) |
|
|
| ANSWER_CAPTURE_PATTERN = re.compile( |
| r"<answer>\s*(.*?)\s*</answer>", |
| re.DOTALL |
| ) |
|
|
|
|
| |
| |
| |
| def extract_answer(text: str) -> Optional[str]: |
| if not isinstance(text, str): |
| return None |
| m = ANSWER_CAPTURE_PATTERN.search(text) |
| return m.group(1).strip() if m else None |
|
|
|
|
| def normalize_number(num_str: str) -> Optional[float]: |
| try: |
| return float((num_str or "").replace(",", "")) |
| except Exception: |
| return None |
|
|
|
|
| def mean_relative_accuracy(pred, target, start=0.5, end=0.95, interval=0.05) -> float: |
| pred_t = torch.tensor(pred, dtype=torch.float32) |
| tgt_t = torch.tensor(target, dtype=torch.float32) |
| rel_error = torch.abs(pred_t - tgt_t) / (torch.abs(tgt_t) + 1e-8) |
| thresholds = torch.arange(start, end + interval/2, interval, dtype=torch.float32) |
| return (rel_error < (1 - thresholds)).float().mean().item() |
|
|
|
|
| def wer(reference: str, hypothesis: str) -> float: |
| ref_words, hyp_words = (reference or "").split(), (hypothesis or "").split() |
| m, n = len(ref_words), len(hyp_words) |
| d = [[0] * (n + 1) for _ in range(m + 1)] |
| for i in range(m + 1): d[i][0] = i |
| for j in range(n + 1): d[0][j] = j |
| for i in range(1, m + 1): |
| for j in range(1, n + 1): |
| d[i][j] = d[i - 1][j - 1] if ref_words[i - 1] == hyp_words[j - 1] else 1 + min( |
| d[i - 1][j], d[i][j - 1], d[i - 1][j - 1] |
| ) |
| return d[m][n] / max(1, m) |
|
|
|
|
| def compute_rouge_score(reference: str, hypothesis: str) -> float: |
| scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True) |
| scores = scorer.score(reference or "", hypothesis or "") |
| return (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3.0 |
|
|
|
|
| |
| def _is_list_of_numbers(x, n=None): |
| if not isinstance(x, list): |
| return False |
| if n is not None and len(x) != n: |
| return False |
| try: |
| for v in x: |
| float(v) |
| return True |
| except Exception: |
| return False |
|
|
|
|
| def iou_1d(pred: List[float], gt: List[float]) -> float: |
| |
| if not _is_list_of_numbers(pred, 2) or not _is_list_of_numbers(gt, 2): |
| return 0.0 |
| try: |
| s1, e1 = float(pred[0]), float(pred[1]) |
| s2, e2 = float(gt[0]), float(gt[1]) |
| except Exception: |
| return 0.0 |
| inter = max(0.0, min(e1, e2) - max(s1, s2)) |
| union = max(e1, e2) - min(s1, s2) |
| return inter / union if union > 1e-12 else 0.0 |
|
|
|
|
| def iou_2d(box1: List[float], box2: List[float]) -> float: |
| |
| if not _is_list_of_numbers(box1, 4) or not _is_list_of_numbers(box2, 4): |
| return 0.0 |
| try: |
| x1, y1, x2, y2 = map(float, box1) |
| X1, Y1, X2, Y2 = map(float, box2) |
| except Exception: |
| return 0.0 |
| inter_x1, inter_y1 = max(x1, X1), max(y1, Y1) |
| inter_x2, inter_y2 = min(x2, X2), min(y2, Y2) |
| inter_area = max(0.0, inter_x2 - inter_x1) * max(0.0, inter_y2 - inter_y1) |
| area1 = max(0.0, x2 - x1) * max(0.0, y2 - y1) |
| area2 = max(0.0, X2 - X1) * max(0.0, Y2 - Y1) |
| union = area1 + area2 - inter_area |
| return inter_area / union if union > 1e-12 else 0.0 |
|
|
|
|
| def mean_iou_over_gt_frames(pred_boxes: Dict[str, List[float]], gt_boxes: Dict[str, List[float]]) -> float: |
| """ |
| For tracking: average over all GT frames (missing predicted frames are counted as 0). |
| pred_boxes, gt_boxes: {frame_str: [x1,y1,x2,y2]} |
| """ |
| if not isinstance(gt_boxes, dict) or not gt_boxes: |
| return 0.0 |
| total, n = 0.0, 0 |
| for k, gbox in gt_boxes.items(): |
| total += iou_2d(pred_boxes.get(k, []), gbox) |
| n += 1 |
| return total / n if n > 0 else 0.0 |
|
|
|
|
| def mean_iou_over_intersection(pred_boxes: Dict[str, List[float]], gt_boxes: Dict[str, List[float]]) -> float: |
| """ |
| For spatial-temporal tasks: IoU is averaged only over the intersection of frame keys |
| (missing frames are not penalized). |
| """ |
| if not isinstance(pred_boxes, dict) or not isinstance(gt_boxes, dict): |
| return 0.0 |
| common = [k for k in pred_boxes.keys() if k in gt_boxes] |
| if not common: |
| return 0.0 |
| vals = [iou_2d(pred_boxes[k], gt_boxes[k]) for k in common] |
| return sum(vals) / len(vals) if vals else 0.0 |
|
|
|
|
| |
| |
| |
| def _pairwise_l2(p, q) -> float: |
| try: |
| dx = float(p[0]) - float(q[0]) |
| dy = float(p[1]) - float(q[1]) |
| return math.hypot(dx, dy) |
| except Exception: |
| return float("inf") |
|
|
|
|
| def assignment_similarity_3(pred_pts: List[List[float]], |
| gt_pts: List[List[float]], |
| sigma: float = 50.0) -> float: |
| """ |
| 3↔3 optimal matching (minimal total distance), returns Gaussian kernel similarity: |
| sim = exp(- (avg_dist^2) / (2 * sigma^2)) ∈ [0,1] |
| |
| Explanation of parameters: |
| - sigma: controls sensitivity range (smaller → more sensitive, sharper) |
| - pred_pts, gt_pts: [[x,y], [x,y], [x,y]] exactly 3 points each |
| - If the number or format of points is invalid, return 0.0 |
| """ |
| |
| if not isinstance(pred_pts, list) or not isinstance(gt_pts, list) or len(pred_pts) != 3 or len(gt_pts) != 3: |
| return 0.0 |
| for p in pred_pts + gt_pts: |
| if not _is_list_of_numbers(p, 2): |
| return 0.0 |
|
|
| |
| best_sum = float('inf') |
| for perm in itertools.permutations(range(3)): |
| s = 0.0 |
| good = True |
| for i in range(3): |
| d = _pairwise_l2(pred_pts[perm[i]], gt_pts[i]) |
| if math.isinf(d): |
| good = False |
| break |
| s += d |
| if good: |
| best_sum = min(best_sum, s) |
|
|
| if math.isinf(best_sum): |
| return 0.0 |
|
|
| |
| avg_d = best_sum / 3.0 |
| sim = math.exp(- (avg_d ** 2) / (2 * sigma ** 2)) |
| return max(0.0, min(1.0, sim)) |
|
|
|
|
| |
| |
| |
| def tag_format_reward(response: str) -> float: |
| """ |
| Format requirement (format reward): |
| Must strictly be: <think>...</think><answer>...</answer> |
| Arbitrary newlines/whitespaces are allowed in the middle, but tag order and closures must be correct. |
| Returns 1.0 if satisfied; otherwise 0.0. |
| """ |
| return 1.0 if THINK_ANSWER_PATTERN.fullmatch(response or "") else 0.0 |
|
|
|
|
| def answer_structure_bonus(answer: str, ground_truth: str, data_type: str, problem_type: str) -> float: |
| """ |
| Structure requirements (structure reward): |
| - spatial-temporal grounding: |
| JSON structure must satisfy: |
| {"time": [s, e], "boxes": {"frame_id": [x1, y1, x2, y2], ...}} |
| +0.25 if the structure is valid; |
| plus bbox key overlap ratio * 0.25 (overlap ratio = |pred.keys ∩ gt.keys| / |gt.keys|). |
| - tracking: |
| JSON structure must satisfy: |
| {"boxes": {"frame_id": [x1, y1, x2, y2], ...}} |
| +0.25 if the structure is valid; |
| plus bbox key overlap ratio * 0.25. |
| - temporal grounding: |
| {"time": [s, e]} gets +0.5 if valid, otherwise 0. |
| - spatial grounding: |
| {"boxes": [x1, y1, x2, y2]} gets +0.5 if valid, otherwise 0. |
| - segmentation: |
| * image: |
| {"boxes": [..], "positive_points": [[x,y],[x,y],[x,y]], "negative_points": [[x,y],[x,y],[x,y]]} |
| * video: |
| {"time": t, "boxes": [..], "positive_points": [[x,y],[x,y],[x,y]], "negative_points": [[x,y],[x,y],[x,y]]} |
| +0.5 if the corresponding structure is satisfied, otherwise 0. |
| - Other non-structured tasks: default +0.5. |
| """ |
| ptype = (problem_type or "").lower() |
| dtype = (data_type or "").lower() |
|
|
| def _json(s): |
| try: |
| return json.loads(s) |
| except Exception: |
| return None |
|
|
| if ptype in {"spatial-temporal grounding", "tracking"}: |
| obj_pred = _json(answer) |
| obj_gt = _json(ground_truth) |
| part_json = 0.0 |
| part_overlap = 0.0 |
|
|
| if ptype == "spatial-temporal grounding": |
| json_ok = ( |
| isinstance(obj_pred, dict) |
| and isinstance(obj_pred.get("time"), list) and len(obj_pred["time"]) == 2 |
| and isinstance(obj_pred.get("boxes"), dict) |
| and all(_is_list_of_numbers(v, 4) for v in obj_pred["boxes"].values()) |
| ) |
| else: |
| json_ok = ( |
| isinstance(obj_pred, dict) |
| and isinstance(obj_pred.get("boxes"), dict) |
| and all(_is_list_of_numbers(v, 4) for v in obj_pred["boxes"].values()) |
| ) |
| if json_ok: |
| part_json = 0.25 |
|
|
| if isinstance(obj_pred, dict) and isinstance(obj_gt, dict): |
| pboxes = obj_pred.get("boxes", {}) |
| gboxes = obj_gt.get("boxes", {}) |
| if isinstance(pboxes, dict) and isinstance(gboxes, dict) and len(gboxes) > 0: |
| inter = set(pboxes.keys()) & set(gboxes.keys()) |
| overlap_ratio = len(inter) / float(len(gboxes)) |
| part_overlap = 0.25 * max(0.0, min(1.0, float(overlap_ratio))) |
|
|
| return part_json + part_overlap |
|
|
| needs_check = {"temporal grounding", "spatial grounding", "segmentation"} |
| |
| if ptype in needs_check or "segmentation" in ptype: |
| obj = _json(answer) |
| if ptype == "temporal grounding": |
| ok = isinstance(obj, dict) and _is_list_of_numbers(obj.get("time"), 2) |
| return 0.5 if ok else 0.0 |
|
|
| if ptype == "spatial grounding": |
| ok = isinstance(obj, dict) and _is_list_of_numbers(obj.get("boxes"), 4) |
| return 0.5 if ok else 0.0 |
|
|
| |
| if ptype == "segmentation" or "segmentation" in ptype: |
| if dtype == "image": |
| ok = ( |
| isinstance(obj, dict) |
| and _is_list_of_numbers(obj.get("boxes"), 4) |
| and isinstance(obj.get("positive_points"), list) and len(obj["positive_points"]) == 3 |
| and isinstance(obj.get("negative_points"), list) and len(obj["negative_points"]) == 3 |
| and all(_is_list_of_numbers(p, 2) for p in obj["positive_points"]) |
| and all(_is_list_of_numbers(p, 2) for p in obj["negative_points"]) |
| ) |
| return 0.5 if ok else 0.0 |
| elif dtype == "video": |
| ok = ( |
| isinstance(obj, dict) |
| and isinstance(obj.get("time"), (int, float)) |
| and _is_list_of_numbers(obj.get("boxes"), 4) |
| and isinstance(obj.get("positive_points"), list) and len(obj["positive_points"]) == 3 |
| and isinstance(obj.get("negative_points"), list) and len(obj["negative_points"]) == 3 |
| and all(_is_list_of_numbers(p, 2) for p in obj["positive_points"]) |
| and all(_is_list_of_numbers(p, 2) for p in obj["negative_points"]) |
| ) |
| return 0.5 if ok else 0.0 |
| else: |
| return 0.0 |
|
|
| |
| return 0.5 |
|
|
|
|
| |
| |
| |
| def _math_equivalent(gt: str, pred: str) -> bool: |
| """ |
| Use math_verify to perform symbolic equivalence checking; if it fails (exceptions, etc.), |
| fall back to grade_answer. |
| """ |
| try: |
| return bool(math_verify(math_parse(gt), math_parse(pred))) |
| except Exception: |
| return grade_answer(pred, gt) |
|
|
|
|
| |
| |
| |
| def accuracy_reward(response: str, |
| ground_truth: str, |
| data_type: str, |
| problem_type: str) -> float: |
| """ |
| Normalized accuracy ∈ [0,1]. Strict format requirement: if the format is invalid, always return 0. |
| Wrapped with try/except: any exception → 0.0. |
| """ |
| try: |
| ans = extract_answer(response) or response.strip() |
| ptype = (problem_type or "").lower() |
| dtype = (data_type or "").lower() |
| gt = ground_truth or "" |
|
|
| |
| if ptype == "multiple choice": |
| return 1.0 if grade_answer(ans.strip(), gt.strip()) else 0.0 |
|
|
| if ptype == "numerical": |
| gt_num, pr_num = normalize_number(gt), normalize_number(ans) |
| return 1.0 if (gt_num is not None and pr_num is not None and round(gt_num, 2) == round(pr_num, 2)) else 0.0 |
|
|
| if ptype == "regression": |
| gt_num, pr_num = normalize_number(gt), normalize_number(ans) |
| if gt_num is None or pr_num is None: |
| return 0.0 |
| return mean_relative_accuracy(pr_num, gt_num) |
|
|
| if ptype == "ocr": |
| return max(0.0, min(1.0, 1.0 - wer(gt, ans))) |
|
|
| if ptype == "open-ended": |
| return max(0.0, min(1.0, compute_rouge_score(gt, ans))) |
|
|
| if ptype == "math": |
| return 1.0 if _math_equivalent(gt, ans) else 0.0 |
|
|
| |
| def _load_json(s: str): |
| try: |
| return json.loads(s) |
| except Exception: |
| return None |
|
|
| |
| if ptype == "temporal grounding": |
| pred = _load_json(ans) |
| gtj = _load_json(gt) |
| if not isinstance(pred, dict) or not isinstance(gtj, dict): |
| return 0.0 |
| return iou_1d(pred.get("time"), gtj.get("time")) |
|
|
| |
| if ptype == "spatial grounding": |
| pred = _load_json(ans) |
| gtj = _load_json(gt) |
| if not isinstance(pred, dict) or not isinstance(gtj, dict): |
| return 0.0 |
| return iou_2d(pred.get("boxes"), gtj.get("boxes")) |
|
|
| |
| if ptype == "spatial-temporal grounding": |
| pred = _load_json(ans) |
| gtj = _load_json(gt) |
| if not isinstance(pred, dict) or not isinstance(gtj, dict): |
| return 0.0 |
| tiou = iou_1d(pred.get("time"), gtj.get("time")) |
| pboxes = pred.get("boxes") |
| gboxes = gtj.get("boxes") |
| if not isinstance(pboxes, dict) or not isinstance(gboxes, dict): |
| miou_inter = 0.0 |
| else: |
| miou_inter = mean_iou_over_intersection(pboxes, gboxes) |
| return 0.5 * tiou + 0.5 * miou_inter |
|
|
| |
| if ptype == "tracking": |
| pred = _load_json(ans) |
| gtj = _load_json(gt) |
| if not isinstance(pred, dict) or not isinstance(gtj, dict): |
| return 0.0 |
| pboxes = pred.get("boxes") |
| gboxes = gtj.get("boxes") |
| if not isinstance(pboxes, dict) or not isinstance(gboxes, dict): |
| return 0.0 |
| return mean_iou_over_gt_frames(pboxes, gboxes) |
|
|
| |
| |
| if "segmentation" in ptype: |
| pred = _load_json(ans) |
| gtj = _load_json(gt) |
| if not isinstance(pred, dict) or not isinstance(gtj, dict): |
| return 0.0 |
|
|
| iou = iou_2d(pred.get("boxes"), gtj.get("boxes")) |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| if dtype == "image": |
| |
| return iou |
|
|
| if dtype == "video": |
| |
| t_pred = pred.get("time") |
| t_gt = gtj.get("time") |
| if not isinstance(t_pred, (int, float)) or not isinstance(t_gt, (int, float)): |
| return 0.0 |
| time_sim = math.exp(-abs(float(t_pred) - float(t_gt)) / 2.0) |
| return 0.3 * iou + 0.3 * time_sim + 0.2 * pos_sim + 0.2 * neg_sim |
|
|
| return 0.0 |
| |
| if "localization" in ptype: |
| |
| pred = _load_json(ans) |
| gtj = _load_json(gt) |
| |
| |
| if isinstance(pred, list) and isinstance(gtj, list): |
| return iou_2d(pred, gtj) |
| |
| |
| if isinstance(pred, dict) and isinstance(gtj, dict): |
| return iou_2d(pred.get("boxes"), gtj.get("boxes")) |
| |
| return 0.0 |
|
|
| if "recognition" in ptype: |
| |
| gt_labels = set(label.strip().lower() for label in gt.split(",")) |
| ans_labels = set(label.strip().lower() for label in ans.split(",")) |
| return 1.0 if gt_labels == ans_labels else 0.0 |
|
|
| if ptype == "surgical instrument count": |
| |
| try: |
| gt_num = int(gt.strip()) |
| ans_num = int(ans.strip()) |
| return 1.0 if gt_num == ans_num else 0.0 |
| except ValueError: |
| |
| return 1.0 if gt.lower().strip() == ans.lower().strip() else 0.0 |
| |
| if ptype == "critical view safety": |
| |
| try: |
| import ast |
| gt_list = ast.literal_eval(gt.strip()) |
| ans_list = ast.literal_eval(ans.strip()) |
| if not isinstance(gt_list, list) or not isinstance(ans_list, list): |
| return 0.0 |
| if len(gt_list) != 3 or len(ans_list) != 3: |
| return 0.0 |
| |
| correct = sum(1 for g, a in zip(gt_list, ans_list) if g.strip().lower() == a.strip().lower()) |
| return correct / 3.0 |
| except (ValueError, SyntaxError): |
| return 0.0 |
|
|
| |
| return 0.0 |
| except Exception: |
| |
| return 0.0 |
|
|
|
|
| |
| def evaluate_open_ended_with_rm( |
| open_ended_queue: List[Dict[str, Any]], |
| results: List[Dict[str, float]], |
| format_weight: float, |
| rm_server_type: str, |
| rm_batch_size: int, |
| normalize_model_reward_by_problem_id: bool |
| ) -> None: |
| """ |
| Take open-ended samples in open_ended_queue, and call external RM in batches to evaluate accuracy. |
| Failed batches fall back to ROUGE. Optionally apply mean-std → min-max normalization within |
| each problem_id group. |
| After evaluation, this function will fill results[idx]['accuracy'] in-place and recompute |
| results[idx]['overall']. |
| """ |
| if not USE_MODEL_FOR_OPEN_ENDED or not open_ended_queue: |
| return |
|
|
| client = RewardModelClient( |
| RM_MODEL_PATH, |
| server_type=rm_server_type, |
| server_address=RM_SERVER_ADDRESS |
| ) |
|
|
| def _chunks(lst, n): |
| for i in range(0, len(lst), n): |
| yield lst[i:i+n] |
|
|
| model_scores: List[float] = [0.0] * len(open_ended_queue) |
|
|
| for batch_id, batch in enumerate(_chunks(open_ended_queue, rm_batch_size)): |
| data = [{"prompt": b["prompt"], "reference": b["reference"], "output": b["output"]} for b in batch] |
| try: |
| rewards = client(data) |
| for j, sc in enumerate(rewards): |
| model_scores[(batch_id * rm_batch_size) + j] = float(sc) |
| except Exception: |
| |
| for j, b in enumerate(batch): |
| ref = b["reference"] |
| out = b["output"] |
| model_scores[(batch_id * rm_batch_size) + j] = float(max(0.0, min(1.0, compute_rouge_score(ref, out)))) |
|
|
| if normalize_model_reward_by_problem_id: |
| groups: Dict[Any, List[int]] = {} |
| for k, b in enumerate(open_ended_queue): |
| gid = b.get("problem_id", None) |
| groups.setdefault(gid, []).append(k) |
|
|
| for gid, indices in groups.items(): |
| vals = np.array([model_scores[k] for k in indices], dtype=np.float32) |
| mean, std = vals.mean(), vals.std() |
| if std == 0: |
| norm_vals = np.ones_like(vals) |
| else: |
| z = (vals - mean) / (std + 1e-6) |
| norm_vals = (z - z.min()) / (z.max() - z.min() + 1e-12) |
| for t, k in enumerate(indices): |
| model_scores[k] = float(norm_vals[t]) |
|
|
| |
| for k, b in enumerate(open_ended_queue): |
| idx = b["idx"] |
| results[idx]["accuracy"] = float(max(0.0, min(1.0, model_scores[k]))) |
| results[idx]["overall"] = ( |
| (1.0 - format_weight) * results[idx]["accuracy"] |
| + format_weight * results[idx]["format"] |
| + results[idx]["structure_reward"] |
| ) |
| |
|
|
|
|
| |
| |
| |
| def compute_score( |
| reward_inputs: List[Dict[str, Any]], |
| format_weight: float = 0.1, |
| |
| rm_server_type: str = "vllm", |
| rm_batch_size: int = 64, |
| normalize_model_reward_by_problem_id: bool = True, |
| ) -> List[Dict[str, float]]: |
| """ |
| Batch interface. |
| Each item: |
| { |
| "response": str, |
| "response_length": int, |
| "ground_truth": str, # may also contain <answer>...</answer>, here we extract it first |
| "data_type": str, # "image" | "video" | ... |
| "problem_type": str # see branches above |
| # Optional additional fields: |
| # "problem": str # used as prompt for external RM in open-ended tasks |
| # "problem_id": Any # grouping key for normalization |
| } |
| Returns: list of dict with keys {overall, format, accuracy, structure_reward} |
| overall = (1 - format_weight) * accuracy + format_weight * format + structure_reward |
| - format: 1.0 if <think>...</think><answer>...</answer>, otherwise 0.0 |
| - structure_reward: |
| * spatial-temporal / tracking: 0.25 (JSON valid) + 0.25 (key overlap ratio) |
| * other structured tasks: +0.5 if valid |
| * non-structured tasks: default +0.5 |
| """ |
| if not isinstance(reward_inputs, list): |
| raise ValueError("Please use `reward_type=batch` for this reward function.") |
|
|
| results: List[Dict[str, float]] = [] |
| |
| open_ended_queue = [] |
| |
|
|
| for idx, item in enumerate(reward_inputs): |
| try: |
| |
| raw_response = item.get("response", "") or "" |
| response = re.sub(r"\s*(<|>|/)\s*", r"\1", raw_response) |
| |
|
|
| |
|
|
| data_type = item.get("data_type", "") or "" |
| problem_type = item.get("problem_type", "") or "" |
|
|
| |
| raw_gt = item.get("ground_truth", "") or "" |
| gt_extracted = extract_answer(raw_gt) or raw_gt |
| |
|
|
| |
| f_score = tag_format_reward(response) |
|
|
| |
| ans = extract_answer(response) or "" |
| s_reward = answer_structure_bonus(ans, gt_extracted, data_type, problem_type) |
|
|
| |
| if USE_MODEL_FOR_OPEN_ENDED and (problem_type or "").lower() == "open-ended": |
| |
| a_score = 0.0 |
| open_ended_queue.append({ |
| "idx": idx, |
| "prompt": item.get("problem", "") or "", |
| "reference": gt_extracted or "", |
| "output": ans or "", |
| "problem_id": item.get("problem_id", None), |
| }) |
| else: |
| a_score = accuracy_reward(response, gt_extracted, data_type, problem_type) |
|
|
| if f_score == 0: |
| s_reward = 0 |
|
|
| overall = (1.0 - format_weight) * a_score + format_weight * f_score + s_reward |
|
|
| results.append({ |
| "overall": float(overall), |
| "format": float(f_score), |
| "accuracy": float(a_score), |
| "structure_reward": float(s_reward), |
| }) |
| except Exception: |
| |
| results.append({ |
| "overall": 0.0, |
| "format": 0.0, |
| "accuracy": 0.0, |
| "structure_reward": 0.0, |
| }) |
|
|
|
|
|
|
| |
| evaluate_open_ended_with_rm( |
| open_ended_queue=open_ended_queue, |
| results=results, |
| format_weight=format_weight, |
| rm_server_type=rm_server_type, |
| rm_batch_size=rm_batch_size, |
| normalize_model_reward_by_problem_id=normalize_model_reward_by_problem_id |
| ) |
| |
|
|
| if random.random() < 0.01: |
|
|
| for idx, item in enumerate(reward_inputs): |
|
|
| print('type', item.get("problem_type", "")) |
| print('gt', extract_answer(item.get("ground_truth", ""))) |
| print('ans', extract_answer(item.get("response", ""))) |
| print({ |
| "overall": results[idx]["overall"], |
| "format": results[idx]["format"], |
| "accuracy": results[idx]["accuracy"], |
| "structure_reward": results[idx]["structure_reward"], |
| }) |
|
|
| return results |
|
|