# -*- coding: utf-8 -*- # Rewards for multimodal tasks with ...... outputs. import re import json import math import itertools from typing import Any, Dict, List, Optional import random import torch from rouge_score import rouge_scorer from math_verify import parse as math_parse, verify as math_verify from mathruler.grader import grade_answer # ===================== Model-based reward configuration ===================== # Whether to use external Reward Model to compute accuracy for open-ended type USE_MODEL_FOR_OPEN_ENDED: bool = False # External RM model and service address (kept consistent with example) RM_MODEL_PATH = "internlm/POLAR-7B" RM_SERVER_ADDRESS = "xx.xx.xx.xx:xxxx" # ========================================================== # ===================== External RM evaluation dependencies ===================== from verl.workers.reward.model_reward import RewardModelClient import numpy as np # ========================================================= # ------------------------- # Patterns for format check # ------------------------- THINK_ANSWER_PATTERN = re.compile( r"\A\s*.*?\s*.*?\s*\Z", re.DOTALL ) ANSWER_CAPTURE_PATTERN = re.compile( r"\s*(.*?)\s*", re.DOTALL ) # ------------------------- # Utilities # ------------------------- def extract_answer(text: str) -> Optional[str]: if not isinstance(text, str): return None m = ANSWER_CAPTURE_PATTERN.search(text) return m.group(1).strip() if m else None def normalize_number(num_str: str) -> Optional[float]: try: return float((num_str or "").replace(",", "")) except Exception: return None def mean_relative_accuracy(pred, target, start=0.5, end=0.95, interval=0.05) -> float: pred_t = torch.tensor(pred, dtype=torch.float32) tgt_t = torch.tensor(target, dtype=torch.float32) rel_error = torch.abs(pred_t - tgt_t) / (torch.abs(tgt_t) + 1e-8) thresholds = torch.arange(start, end + interval/2, interval, dtype=torch.float32) return (rel_error < (1 - thresholds)).float().mean().item() def wer(reference: str, hypothesis: str) -> float: ref_words, hyp_words = (reference or "").split(), (hypothesis or "").split() m, n = len(ref_words), len(hyp_words) d = [[0] * (n + 1) for _ in range(m + 1)] for i in range(m + 1): d[i][0] = i for j in range(n + 1): d[0][j] = j for i in range(1, m + 1): for j in range(1, n + 1): d[i][j] = d[i - 1][j - 1] if ref_words[i - 1] == hyp_words[j - 1] else 1 + min( d[i - 1][j], d[i][j - 1], d[i - 1][j - 1] ) return d[m][n] / max(1, m) def compute_rouge_score(reference: str, hypothesis: str) -> float: scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True) scores = scorer.score(reference or "", hypothesis or "") return (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3.0 # ---------- IoU helpers (strict format: must be numeric lists) ---------- def _is_list_of_numbers(x, n=None): if not isinstance(x, list): return False if n is not None and len(x) != n: return False try: for v in x: float(v) return True except Exception: return False def iou_1d(pred: List[float], gt: List[float]) -> float: # Strict: must be numeric lists with length 2; otherwise return 0 if not _is_list_of_numbers(pred, 2) or not _is_list_of_numbers(gt, 2): return 0.0 try: s1, e1 = float(pred[0]), float(pred[1]) s2, e2 = float(gt[0]), float(gt[1]) except Exception: return 0.0 inter = max(0.0, min(e1, e2) - max(s1, s2)) union = max(e1, e2) - min(s1, s2) return inter / union if union > 1e-12 else 0.0 def iou_2d(box1: List[float], box2: List[float]) -> float: # Strict: must be numeric lists with length 4; otherwise return 0 if not _is_list_of_numbers(box1, 4) or not _is_list_of_numbers(box2, 4): return 0.0 try: x1, y1, x2, y2 = map(float, box1) X1, Y1, X2, Y2 = map(float, box2) except Exception: return 0.0 inter_x1, inter_y1 = max(x1, X1), max(y1, Y1) inter_x2, inter_y2 = min(x2, X2), min(y2, Y2) inter_area = max(0.0, inter_x2 - inter_x1) * max(0.0, inter_y2 - inter_y1) area1 = max(0.0, x2 - x1) * max(0.0, y2 - y1) area2 = max(0.0, X2 - X1) * max(0.0, Y2 - Y1) union = area1 + area2 - inter_area return inter_area / union if union > 1e-12 else 0.0 def mean_iou_over_gt_frames(pred_boxes: Dict[str, List[float]], gt_boxes: Dict[str, List[float]]) -> float: """ For tracking: average over all GT frames (missing predicted frames are counted as 0). pred_boxes, gt_boxes: {frame_str: [x1,y1,x2,y2]} """ if not isinstance(gt_boxes, dict) or not gt_boxes: return 0.0 total, n = 0.0, 0 for k, gbox in gt_boxes.items(): total += iou_2d(pred_boxes.get(k, []), gbox) n += 1 return total / n if n > 0 else 0.0 def mean_iou_over_intersection(pred_boxes: Dict[str, List[float]], gt_boxes: Dict[str, List[float]]) -> float: """ For spatial-temporal tasks: IoU is averaged only over the intersection of frame keys (missing frames are not penalized). """ if not isinstance(pred_boxes, dict) or not isinstance(gt_boxes, dict): return 0.0 common = [k for k in pred_boxes.keys() if k in gt_boxes] if not common: return 0.0 vals = [iou_2d(pred_boxes[k], gt_boxes[k]) for k in common] return sum(vals) / len(vals) if vals else 0.0 # ------------------------- # Segmentation point matching: strict 3↔3 optimal assignment # ------------------------- def _pairwise_l2(p, q) -> float: try: dx = float(p[0]) - float(q[0]) dy = float(p[1]) - float(q[1]) return math.hypot(dx, dy) except Exception: return float("inf") def assignment_similarity_3(pred_pts: List[List[float]], gt_pts: List[List[float]], sigma: float = 50.0) -> float: """ 3↔3 optimal matching (minimal total distance), returns Gaussian kernel similarity: sim = exp(- (avg_dist^2) / (2 * sigma^2)) ∈ [0,1] Explanation of parameters: - sigma: controls sensitivity range (smaller → more sensitive, sharper) - pred_pts, gt_pts: [[x,y], [x,y], [x,y]] exactly 3 points each - If the number or format of points is invalid, return 0.0 """ # ----------- Check input validity ----------- if not isinstance(pred_pts, list) or not isinstance(gt_pts, list) or len(pred_pts) != 3 or len(gt_pts) != 3: return 0.0 for p in pred_pts + gt_pts: if not _is_list_of_numbers(p, 2): return 0.0 # ----------- Compute optimal matching distance ----------- best_sum = float('inf') for perm in itertools.permutations(range(3)): s = 0.0 good = True for i in range(3): d = _pairwise_l2(pred_pts[perm[i]], gt_pts[i]) if math.isinf(d): good = False break s += d if good: best_sum = min(best_sum, s) if math.isinf(best_sum): return 0.0 # ----------- Average distance & Gaussian similarity ----------- avg_d = best_sum / 3.0 sim = math.exp(- (avg_d ** 2) / (2 * sigma ** 2)) return max(0.0, min(1.0, sim)) # ------------------------- # Format reward + structure reward # ------------------------- def tag_format_reward(response: str) -> float: """ Format requirement (format reward): Must strictly be: ...... Arbitrary newlines/whitespaces are allowed in the middle, but tag order and closures must be correct. Returns 1.0 if satisfied; otherwise 0.0. """ return 1.0 if THINK_ANSWER_PATTERN.fullmatch(response or "") else 0.0 def answer_structure_bonus(answer: str, ground_truth: str, data_type: str, problem_type: str) -> float: """ Structure requirements (structure reward): - spatial-temporal grounding: JSON structure must satisfy: {"time": [s, e], "boxes": {"frame_id": [x1, y1, x2, y2], ...}} +0.25 if the structure is valid; plus bbox key overlap ratio * 0.25 (overlap ratio = |pred.keys ∩ gt.keys| / |gt.keys|). - tracking: JSON structure must satisfy: {"boxes": {"frame_id": [x1, y1, x2, y2], ...}} +0.25 if the structure is valid; plus bbox key overlap ratio * 0.25. - temporal grounding: {"time": [s, e]} gets +0.5 if valid, otherwise 0. - spatial grounding: {"boxes": [x1, y1, x2, y2]} gets +0.5 if valid, otherwise 0. - segmentation: * image: {"boxes": [..], "positive_points": [[x,y],[x,y],[x,y]], "negative_points": [[x,y],[x,y],[x,y]]} * video: {"time": t, "boxes": [..], "positive_points": [[x,y],[x,y],[x,y]], "negative_points": [[x,y],[x,y],[x,y]]} +0.5 if the corresponding structure is satisfied, otherwise 0. - Other non-structured tasks: default +0.5. """ ptype = (problem_type or "").lower() dtype = (data_type or "").lower() def _json(s): try: return json.loads(s) except Exception: return None if ptype in {"spatial-temporal grounding", "tracking"}: obj_pred = _json(answer) obj_gt = _json(ground_truth) part_json = 0.0 part_overlap = 0.0 if ptype == "spatial-temporal grounding": json_ok = ( isinstance(obj_pred, dict) and isinstance(obj_pred.get("time"), list) and len(obj_pred["time"]) == 2 and isinstance(obj_pred.get("boxes"), dict) and all(_is_list_of_numbers(v, 4) for v in obj_pred["boxes"].values()) ) else: # tracking json_ok = ( isinstance(obj_pred, dict) and isinstance(obj_pred.get("boxes"), dict) and all(_is_list_of_numbers(v, 4) for v in obj_pred["boxes"].values()) ) if json_ok: part_json = 0.25 if isinstance(obj_pred, dict) and isinstance(obj_gt, dict): pboxes = obj_pred.get("boxes", {}) gboxes = obj_gt.get("boxes", {}) if isinstance(pboxes, dict) and isinstance(gboxes, dict) and len(gboxes) > 0: inter = set(pboxes.keys()) & set(gboxes.keys()) overlap_ratio = len(inter) / float(len(gboxes)) part_overlap = 0.25 * max(0.0, min(1.0, float(overlap_ratio))) return part_json + part_overlap needs_check = {"temporal grounding", "spatial grounding", "segmentation"} # if ptype in needs_check: if ptype in needs_check or "segmentation" in ptype: obj = _json(answer) if ptype == "temporal grounding": ok = isinstance(obj, dict) and _is_list_of_numbers(obj.get("time"), 2) return 0.5 if ok else 0.0 if ptype == "spatial grounding": ok = isinstance(obj, dict) and _is_list_of_numbers(obj.get("boxes"), 4) return 0.5 if ok else 0.0 # if ptype == "segmentation": if ptype == "segmentation" or "segmentation" in ptype: if dtype == "image": ok = ( isinstance(obj, dict) and _is_list_of_numbers(obj.get("boxes"), 4) and isinstance(obj.get("positive_points"), list) and len(obj["positive_points"]) == 3 and isinstance(obj.get("negative_points"), list) and len(obj["negative_points"]) == 3 and all(_is_list_of_numbers(p, 2) for p in obj["positive_points"]) and all(_is_list_of_numbers(p, 2) for p in obj["negative_points"]) ) return 0.5 if ok else 0.0 elif dtype == "video": ok = ( isinstance(obj, dict) and isinstance(obj.get("time"), (int, float)) # time must be numeric and _is_list_of_numbers(obj.get("boxes"), 4) and isinstance(obj.get("positive_points"), list) and len(obj["positive_points"]) == 3 and isinstance(obj.get("negative_points"), list) and len(obj["negative_points"]) == 3 and all(_is_list_of_numbers(p, 2) for p in obj["positive_points"]) and all(_is_list_of_numbers(p, 2) for p in obj["negative_points"]) ) return 0.5 if ok else 0.0 else: return 0.0 # Non-structured tasks: default +0.5 return 0.5 # ------------------------- # Math equivalence helper # ------------------------- def _math_equivalent(gt: str, pred: str) -> bool: """ Use math_verify to perform symbolic equivalence checking; if it fails (exceptions, etc.), fall back to grade_answer. """ try: return bool(math_verify(math_parse(gt), math_parse(pred))) except Exception: return grade_answer(pred, gt) # ------------------------- # Accuracy reward (normalized to [0,1]) # ------------------------- def accuracy_reward(response: str, ground_truth: str, data_type: str, problem_type: str) -> float: """ Normalized accuracy ∈ [0,1]. Strict format requirement: if the format is invalid, always return 0. Wrapped with try/except: any exception → 0.0. """ try: ans = extract_answer(response) or response.strip() ptype = (problem_type or "").lower() dtype = (data_type or "").lower() gt = ground_truth or "" # ------ Pure QA type ------ if ptype == "multiple choice": return 1.0 if grade_answer(ans.strip(), gt.strip()) else 0.0 if ptype == "numerical": gt_num, pr_num = normalize_number(gt), normalize_number(ans) return 1.0 if (gt_num is not None and pr_num is not None and round(gt_num, 2) == round(pr_num, 2)) else 0.0 if ptype == "regression": gt_num, pr_num = normalize_number(gt), normalize_number(ans) if gt_num is None or pr_num is None: return 0.0 return mean_relative_accuracy(pr_num, gt_num) if ptype == "ocr": return max(0.0, min(1.0, 1.0 - wer(gt, ans))) if ptype == "open-ended": return max(0.0, min(1.0, compute_rouge_score(gt, ans))) if ptype == "math": return 1.0 if _math_equivalent(gt, ans) else 0.0 # ------ JSON type (strict format)------ def _load_json(s: str): try: return json.loads(s) except Exception: return None # temporal grounding: tIoU ∈ [0,1] if ptype == "temporal grounding": pred = _load_json(ans) gtj = _load_json(gt) if not isinstance(pred, dict) or not isinstance(gtj, dict): return 0.0 return iou_1d(pred.get("time"), gtj.get("time")) # spatial grounding: box IoU ∈ [0,1] if ptype == "spatial grounding": pred = _load_json(ans) gtj = _load_json(gt) if not isinstance(pred, dict) or not isinstance(gtj, dict): return 0.0 return iou_2d(pred.get("boxes"), gtj.get("boxes")) # spatial-temporal grounding: 0.5*tIoU + 0.5*mIoU(intersection) if ptype == "spatial-temporal grounding": pred = _load_json(ans) gtj = _load_json(gt) if not isinstance(pred, dict) or not isinstance(gtj, dict): return 0.0 tiou = iou_1d(pred.get("time"), gtj.get("time")) pboxes = pred.get("boxes") gboxes = gtj.get("boxes") if not isinstance(pboxes, dict) or not isinstance(gboxes, dict): miou_inter = 0.0 else: miou_inter = mean_iou_over_intersection(pboxes, gboxes) return 0.5 * tiou + 0.5 * miou_inter # tracking: mean mIoU over GT frames (missing frames = 0) if ptype == "tracking": pred = _load_json(ans) gtj = _load_json(gt) if not isinstance(pred, dict) or not isinstance(gtj, dict): return 0.0 pboxes = pred.get("boxes") gboxes = gtj.get("boxes") if not isinstance(pboxes, dict) or not isinstance(gboxes, dict): return 0.0 return mean_iou_over_gt_frames(pboxes, gboxes) # segmentation(image/video) # if ptype == "segmentation": if "segmentation" in ptype: pred = _load_json(ans) gtj = _load_json(gt) if not isinstance(pred, dict) or not isinstance(gtj, dict): return 0.0 iou = iou_2d(pred.get("boxes"), gtj.get("boxes")) # pos_pred = pred.get("positive_points") # pos_gt = gtj.get("positive_points") # neg_pred = pred.get("negative_points") # neg_gt = gtj.get("negative_points") # # Must be strict 3↔3 # if not (isinstance(pos_pred, list) and len(pos_pred) == 3 and # isinstance(pos_gt, list) and len(pos_gt) == 3 and # all(_is_list_of_numbers(p, 2) for p in pos_pred) and # all(_is_list_of_numbers(p, 2) for p in pos_gt)): # return 0.0 # if not (isinstance(neg_pred, list) and len(neg_pred) == 3 and # isinstance(neg_gt, list) and len(neg_gt) == 3 and # all(_is_list_of_numbers(p, 2) for p in neg_pred) and # all(_is_list_of_numbers(p, 2) for p in neg_gt)): # return 0.0 # pos_sim = assignment_similarity_3(pos_pred, pos_gt) # neg_sim = assignment_similarity_3(neg_pred, neg_gt) if dtype == "image": # return 0.5 * iou + 0.25 * pos_sim + 0.25 * neg_sim return iou if dtype == "video": # time must be numeric (strict) t_pred = pred.get("time") t_gt = gtj.get("time") if not isinstance(t_pred, (int, float)) or not isinstance(t_gt, (int, float)): return 0.0 time_sim = math.exp(-abs(float(t_pred) - float(t_gt)) / 2.0) # τ=2s return 0.3 * iou + 0.3 * time_sim + 0.2 * pos_sim + 0.2 * neg_sim return 0.0 if "localization" in ptype: # Handle both dict format {"boxes": [...]} and direct list format [...] pred = _load_json(ans) gtj = _load_json(gt) # Direct list format: [x1, y1, x2, y2] if isinstance(pred, list) and isinstance(gtj, list): return iou_2d(pred, gtj) # Dict format: {"boxes": [x1, y1, x2, y2]} if isinstance(pred, dict) and isinstance(gtj, dict): return iou_2d(pred.get("boxes"), gtj.get("boxes")) return 0.0 if "recognition" in ptype: # Handle multi-label case: compare as sets (order-independent) gt_labels = set(label.strip().lower() for label in gt.split(",")) ans_labels = set(label.strip().lower() for label in ans.split(",")) return 1.0 if gt_labels == ans_labels else 0.0 if ptype == "surgical instrument count": # Use numerical comparison for count task try: gt_num = int(gt.strip()) ans_num = int(ans.strip()) return 1.0 if gt_num == ans_num else 0.0 except ValueError: # Fallback to string comparison if not valid numbers return 1.0 if gt.lower().strip() == ans.lower().strip() else 0.0 if ptype == "critical view safety": # Parse list format: ['Yes', 'Yes', 'No'] try: import ast gt_list = ast.literal_eval(gt.strip()) ans_list = ast.literal_eval(ans.strip()) if not isinstance(gt_list, list) or not isinstance(ans_list, list): return 0.0 if len(gt_list) != 3 or len(ans_list) != 3: return 0.0 # Calculate accuracy for 3 CVS criteria correct = sum(1 for g, a in zip(gt_list, ans_list) if g.strip().lower() == a.strip().lower()) return correct / 3.0 except (ValueError, SyntaxError): return 0.0 # Unknown type return 0.0 except Exception: # Outer fallback: any exception will be scored as 0 return 0.0 # ===================== Wrapper: batch call external model for open-ended ===================== def evaluate_open_ended_with_rm( open_ended_queue: List[Dict[str, Any]], results: List[Dict[str, float]], format_weight: float, rm_server_type: str, rm_batch_size: int, normalize_model_reward_by_problem_id: bool ) -> None: """ Take open-ended samples in open_ended_queue, and call external RM in batches to evaluate accuracy. Failed batches fall back to ROUGE. Optionally apply mean-std → min-max normalization within each problem_id group. After evaluation, this function will fill results[idx]['accuracy'] in-place and recompute results[idx]['overall']. """ if not USE_MODEL_FOR_OPEN_ENDED or not open_ended_queue: return client = RewardModelClient( RM_MODEL_PATH, server_type=rm_server_type, server_address=RM_SERVER_ADDRESS ) def _chunks(lst, n): for i in range(0, len(lst), n): yield lst[i:i+n] model_scores: List[float] = [0.0] * len(open_ended_queue) for batch_id, batch in enumerate(_chunks(open_ended_queue, rm_batch_size)): data = [{"prompt": b["prompt"], "reference": b["reference"], "output": b["output"]} for b in batch] try: rewards = client(data) # expected to return list[float] for j, sc in enumerate(rewards): model_scores[(batch_id * rm_batch_size) + j] = float(sc) except Exception: # Fallback: use ROUGE to compute scores for this batch for j, b in enumerate(batch): ref = b["reference"] out = b["output"] model_scores[(batch_id * rm_batch_size) + j] = float(max(0.0, min(1.0, compute_rouge_score(ref, out)))) if normalize_model_reward_by_problem_id: groups: Dict[Any, List[int]] = {} for k, b in enumerate(open_ended_queue): gid = b.get("problem_id", None) groups.setdefault(gid, []).append(k) for gid, indices in groups.items(): vals = np.array([model_scores[k] for k in indices], dtype=np.float32) mean, std = vals.mean(), vals.std() if std == 0: norm_vals = np.ones_like(vals) else: z = (vals - mean) / (std + 1e-6) norm_vals = (z - z.min()) / (z.max() - z.min() + 1e-12) for t, k in enumerate(indices): model_scores[k] = float(norm_vals[t]) # Fill back accuracy, and recompute overall for k, b in enumerate(open_ended_queue): idx = b["idx"] results[idx]["accuracy"] = float(max(0.0, min(1.0, model_scores[k]))) results[idx]["overall"] = ( (1.0 - format_weight) * results[idx]["accuracy"] + format_weight * results[idx]["format"] + results[idx]["structure_reward"] ) # ================================================================== # ------------------------- # Public API # ------------------------- def compute_score( reward_inputs: List[Dict[str, Any]], format_weight: float = 0.1, # ===== Still kept as configurable parameters ===== rm_server_type: str = "vllm", rm_batch_size: int = 64, normalize_model_reward_by_problem_id: bool = True, ) -> List[Dict[str, float]]: """ Batch interface. Each item: { "response": str, "response_length": int, "ground_truth": str, # may also contain ..., here we extract it first "data_type": str, # "image" | "video" | ... "problem_type": str # see branches above # Optional additional fields: # "problem": str # used as prompt for external RM in open-ended tasks # "problem_id": Any # grouping key for normalization } Returns: list of dict with keys {overall, format, accuracy, structure_reward} overall = (1 - format_weight) * accuracy + format_weight * format + structure_reward - format: 1.0 if ......, otherwise 0.0 - structure_reward: * spatial-temporal / tracking: 0.25 (JSON valid) + 0.25 (key overlap ratio) * other structured tasks: +0.5 if valid * non-structured tasks: default +0.5 """ if not isinstance(reward_inputs, list): raise ValueError("Please use `reward_type=batch` for this reward function.") results: List[Dict[str, float]] = [] # ===================== Collect open-ended samples to be evaluated ===================== open_ended_queue = [] # Each item: {idx, prompt, reference, output, problem_id} # ================================================================ for idx, item in enumerate(reward_inputs): try: # Normalize tag whitespaces, e.g. < / think > → raw_response = item.get("response", "") or "" response = re.sub(r"\s*(<|>|/)\s*", r"\1", raw_response) # print(f"reponse: {response}") # print(response) data_type = item.get("data_type", "") or "" problem_type = item.get("problem_type", "") or "" # ground_truth may also be wrapped in ...; extract it first here raw_gt = item.get("ground_truth", "") or "" gt_extracted = extract_answer(raw_gt) or raw_gt # print(f"gt_extracted: {gt_extracted}") # 1) format reward —— requires strict tag structure: ...... f_score = tag_format_reward(response) # 2) structure reward —— according to JSON structure requirements by task type (see function doc) ans = extract_answer(response) or "" s_reward = answer_structure_bonus(ans, gt_extracted, data_type, problem_type) # 3) accuracy (all normalized to [0,1]) if USE_MODEL_FOR_OPEN_ENDED and (problem_type or "").lower() == "open-ended": # First set to 0, and finally compute with external model and fill back a_score = 0.0 open_ended_queue.append({ "idx": idx, "prompt": item.get("problem", "") or "", "reference": gt_extracted or "", "output": ans or "", "problem_id": item.get("problem_id", None), }) else: a_score = accuracy_reward(response, gt_extracted, data_type, problem_type) if f_score == 0: s_reward = 0 overall = (1.0 - format_weight) * a_score + format_weight * f_score + s_reward results.append({ "overall": float(overall), "format": float(f_score), "accuracy": float(a_score), "structure_reward": float(s_reward), }) except Exception: # Fallback for the entire sample: any exception, all four fields are set to 0 results.append({ "overall": 0.0, "format": 0.0, "accuracy": 0.0, "structure_reward": 0.0, }) # ===================== Call wrapper for batch external evaluation and fill back ===================== evaluate_open_ended_with_rm( open_ended_queue=open_ended_queue, results=results, format_weight=format_weight, rm_server_type=rm_server_type, rm_batch_size=rm_batch_size, normalize_model_reward_by_problem_id=normalize_model_reward_by_problem_id ) # ====================================================================== if random.random() < 0.01: for idx, item in enumerate(reward_inputs): print('type', item.get("problem_type", "")) print('gt', extract_answer(item.get("ground_truth", ""))) print('ans', extract_answer(item.get("response", ""))) print({ "overall": results[idx]["overall"], "format": results[idx]["format"], "accuracy": results[idx]["accuracy"], "structure_reward": results[idx]["structure_reward"], }) return results