| |
| |
|
|
| import re |
| import json |
| import os |
| import time |
| from functools import lru_cache |
| from collections import Counter |
| from typing import Any, Dict, List, Tuple |
|
|
| try: |
| from swift.rewards import ORM, orms |
| except Exception: |
| from swift.plugin import ORM, orms |
|
|
| from qwen_asr.inference.utils import parse_asr_output |
|
|
| _ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.S | re.I) |
| _REWARD_DEBUG_COUNTER: Dict[str, int] = {} |
|
|
|
|
| def _as_bool(x, default: bool = False) -> bool: |
| if x is None: |
| return default |
| if isinstance(x, bool): |
| return x |
| if isinstance(x, (int, float)): |
| return bool(x) |
| return str(x).strip().lower() in {"1", "true", "yes", "y", "on"} |
|
|
|
|
| def _to_jsonable(x: Any): |
| if x is None or isinstance(x, (str, int, float, bool)): |
| return x |
| if isinstance(x, (list, tuple)): |
| return [_to_jsonable(v) for v in x] |
| if isinstance(x, dict): |
| return {str(k): _to_jsonable(v) for k, v in x.items()} |
| return str(x) |
|
|
|
|
| def _pick_field(x, i: int): |
| if x is None: |
| return None |
| if isinstance(x, (list, tuple)): |
| return x[i] if i < len(x) else None |
| return x |
|
|
|
|
def _reward_debug_enabled(kwargs) -> bool:
    """Tell whether reward debug logging is switched on.

    A ``reward_debug`` entry in ``kwargs`` takes precedence; otherwise the
    ``ASR_REWARD_DEBUG`` environment variable decides. Both default to off.
    """
    if "reward_debug" in kwargs:
        flag = kwargs.get("reward_debug")
    else:
        flag = os.environ.get("ASR_REWARD_DEBUG")
    return _as_bool(flag, default=False)
|
|
|
|
| def _reward_debug_path(kwargs, reward_name: str) -> str: |
| base = kwargs.get("reward_debug_path") or os.environ.get("ASR_REWARD_DEBUG_PATH") |
| if not base: |
| return "" |
| base = str(base) |
| if base.endswith(".jsonl"): |
| base = base[:-6] |
| return f"{base}.{reward_name}.pid{os.getpid()}.jsonl" |
|
|
|
|
| def _reward_debug_max_rows(kwargs) -> int: |
| x = kwargs.get("reward_debug_max_rows") |
| if x is None: |
| x = os.environ.get("ASR_REWARD_DEBUG_MAX_ROWS", 1000) |
| try: |
| x = int(x) |
| except Exception: |
| x = 1000 |
| return max(0, x) |
|
|
|
|
def _collect_common_debug_meta(kwargs, i: int) -> Dict[str, Any]:
    """Gather the known metadata fields present in ``kwargs`` for sample ``i``.

    Only keys that actually appear in ``kwargs`` are included. Each value is
    resolved for index ``i`` with ``_pick_field`` and then made
    JSON-friendly with ``_to_jsonable``.
    """
    known_fields = (
        "step", "id", "sample_id", "group_id", "utt_id", "audio_id", "audio_path",
        "task", "lang", "language", "base_wer", "base_wer_bucket", "difficulty_bucket", "wer",
        "dirty_type",
    )
    return {
        name: _to_jsonable(_pick_field(kwargs.get(name), i))
        for name in known_fields
        if name in kwargs
    }
|
|
|
|
def _append_reward_debug_row(reward_name: str, kwargs, row: Dict[str, Any]) -> None:
    """Append one JSON line describing a scored sample to the debug log.

    Does nothing when debugging is disabled, no path is configured, the row
    cap is zero, or this process has already written the cap for that path.
    Writing is best effort: any exception raised while creating the
    directory, serialising or writing is swallowed.
    """
    if not _reward_debug_enabled(kwargs):
        return
    path = _reward_debug_path(kwargs, reward_name)
    if not path:
        return
    limit = _reward_debug_max_rows(kwargs)
    if limit <= 0:
        return
    written = _REWARD_DEBUG_COUNTER.get(path, 0)
    if written >= limit:
        return

    try:
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Row entries come last, so they override "ts"/"reward_name" if present.
        record = {"ts": time.time(), "reward_name": reward_name}
        record.update(_to_jsonable(row))
        with open(path, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(record, ensure_ascii=False) + "\n")
        # Counted only after a successful write.
        _REWARD_DEBUG_COUNTER[path] = written + 1
    except Exception:
        # Debug logging must never break reward computation.
        pass
|
|
|
|
def _extract_completion_text(s: str) -> str:
    """Reduce a raw model completion to the bare transcription text.

    Steps, in order:
      1. ``None`` becomes ``""``; surrounding whitespace is stripped.
      2. If an ``<answer>...</answer>`` wrapper is found, only its payload is kept.
      3. One leading label such as "transcription:" or "答案:" is removed
         (matched case-insensitively).
      4. Every ``<|im_end|>`` marker is removed.
      5. ``parse_asr_output`` is tried; when it returns non-empty text, that
         text replaces the result. Failures there are ignored.

    Args:
        s: Raw completion text, or ``None``.

    Returns:
        The cleaned hypothesis string (possibly empty).
    """
    if s is None:
        return ""
    s = s.strip()
    m = _ANSWER_RE.search(s)
    if m:
        s = m.group(1).strip()

    # The Chinese labels are listed with both an ASCII colon and a full-width
    # colon (U+FF1A). The full-width forms are written as escapes so they
    # cannot silently collapse into duplicates of the ASCII entries.
    prefixes = (
        "transcription:",
        "asr:",
        "answer:",
        "答案:",
        "答案\uff1a",
        "识别结果:",
        "识别结果\uff1a",
    )
    lower = s.lower()
    for pfx in prefixes:
        if lower.startswith(pfx):
            # Only the first matching label is stripped.
            s = s[len(pfx):].strip()
            break

    s = s.replace("<|im_end|>", "").strip()
    try:
        _lang, text = parse_asr_output(s, user_language=None)
        if text:
            s = text
    except Exception:
        # Best effort: keep the locally cleaned text if the parser fails.
        pass
    return s
|
|
|
|
def normalize_text(s: str) -> str:
    """Strip surrounding whitespace and lowercase ``s``; falsy input gives ``""``."""
    if not s:
        return ""
    return s.strip().lower()
|
|
|
|
@lru_cache(maxsize=50000)
def _tokenize_cached(s: str) -> Tuple[str, ...]:
    """Tokenize normalized text, memoised per input string.

    If the normalized text contains a space it is split on whitespace into
    word tokens; otherwise every non-whitespace character is its own token.
    Empty text yields an empty tuple.
    """
    text = normalize_text(s)
    if not text:
        return ()
    if " " in text:
        # str.split() with no argument never yields empty strings.
        return tuple(text.split())
    return tuple(filter(lambda ch: not ch.isspace(), text))
|
|
|
|
def _tokenize(s: str) -> List[str]:
    """Return the tokens of ``s`` as a fresh list (a copy of the cached tuple)."""
    return [*_tokenize_cached(s)]
|
|
|
|
def _char_seq(s: str) -> List[str]:
    """Return the non-whitespace characters of the normalized text, in order."""
    chars: List[str] = []
    for ch in normalize_text(s):
        if not ch.isspace():
            chars.append(ch)
    return chars
|
|
|
|
| def _infer_dirty_type_from_audio(x: Any) -> str: |
| if x is None: |
| return "other" |
| if isinstance(x, (list, tuple)) and len(x) > 0: |
| x = x[0] |
| s = str(x).lower() |
| if "voices" in s: |
| return "voices_noise_plus_farfield" |
| if "noise+rsp" in s or "resample_noise" in s or ("/noise/" in s and "voices" not in s): |
| return "noise_rsp_pure_noise" |
| return "other" |
|
|
|
|
def _get_dirty_type(kwargs, i: int) -> str:
    """Resolve the dirty-type label for sample ``i``.

    An explicit ``dirty_type`` entry wins: known aliases are mapped to their
    canonical label and any other value is returned stripped and lowercased.
    Without one, the label is inferred from ``audio_path`` and, if that is
    missing too, from ``audios``.
    """
    voices_aliases = {"voices", "voices_noise_plus_farfield", "voices_far", "voices_farfield"}
    noise_aliases = {"noise", "noise_rsp", "noise_rsp_pure_noise", "pure_noise"}

    explicit = _pick_field(kwargs.get("dirty_type"), i)
    if explicit is not None:
        label = str(explicit).strip().lower()
        if label in voices_aliases:
            return "voices_noise_plus_farfield"
        if label in noise_aliases:
            return "noise_rsp_pure_noise"
        return label

    source = _pick_field(kwargs.get("audio_path"), i)
    if source is None:
        source = _pick_field(kwargs.get("audios"), i)
    return _infer_dirty_type_from_audio(source)
|
|
|
|
| def _edit_ops_counts(ref_toks: List[str], hyp_toks: List[str]) -> Tuple[int, int, int]: |
| n, m = len(ref_toks), len(hyp_toks) |
| dp = [[0] * (m + 1) for _ in range(n + 1)] |
| bt = [[0] * (m + 1) for _ in range(n + 1)] |
|
|
| for i in range(1, n + 1): |
| dp[i][0] = i |
| bt[i][0] = 2 |
| for j in range(1, m + 1): |
| dp[0][j] = j |
| bt[0][j] = 3 |
|
|
| for i in range(1, n + 1): |
| ri = ref_toks[i - 1] |
| for j in range(1, m + 1): |
| hj = hyp_toks[j - 1] |
| if ri == hj: |
| dp[i][j] = dp[i - 1][j - 1] |
| bt[i][j] = 0 |
| else: |
| sub = dp[i - 1][j - 1] + 1 |
| dele = dp[i - 1][j] + 1 |
| ins = dp[i][j - 1] + 1 |
| best = min(sub, dele, ins) |
| dp[i][j] = best |
| if best == sub: |
| bt[i][j] = 1 |
| elif best == dele: |
| bt[i][j] = 2 |
| else: |
| bt[i][j] = 3 |
|
|
| i, j = n, m |
| sub = dele = ins = 0 |
| while i > 0 or j > 0: |
| op = bt[i][j] |
| if i > 0 and j > 0 and op == 0: |
| i -= 1 |
| j -= 1 |
| elif i > 0 and j > 0 and op == 1: |
| sub += 1 |
| i -= 1 |
| j -= 1 |
| elif i > 0 and op == 2: |
| dele += 1 |
| i -= 1 |
| else: |
| ins += 1 |
| j -= 1 |
| return sub, dele, ins |
|
|
|
|
def _char_bigram_f1(hyp: str, ref: str) -> float:
    """F1 overlap of character bigrams between ``hyp`` and ``ref``.

    Both strings are reduced to their non-whitespace characters first. Two
    empty sequences score 1.0 and exactly one empty sequence scores 0.0. If
    either side has fewer than two characters, the overlap is counted as
    position-wise equal characters instead of bigrams.
    """
    hyp_chars = _char_seq(hyp)
    ref_chars = _char_seq(ref)
    if not hyp_chars and not ref_chars:
        return 1.0
    if not (hyp_chars and ref_chars):
        return 0.0

    if min(len(hyp_chars), len(ref_chars)) < 2:
        # Too short for bigrams: compare characters at matching positions.
        overlap = sum(a == b for a, b in zip(hyp_chars, ref_chars))
        hyp_total, ref_total = len(hyp_chars), len(ref_chars)
    else:
        hyp_grams = Counter(zip(hyp_chars, hyp_chars[1:]))
        ref_grams = Counter(zip(ref_chars, ref_chars[1:]))
        # Counter intersection keeps the minimum count of each shared bigram.
        overlap = sum((hyp_grams & ref_grams).values())
        hyp_total = sum(hyp_grams.values())
        ref_total = sum(ref_grams.values())

    precision = overlap / max(1, hyp_total)
    recall = overlap / max(1, ref_total)
    return 2.0 * precision * recall / max(1e-8, precision + recall)
|
|
|
|
def _lcs_lengths(hyp: str, ref: str) -> Tuple[int, int, int]:
    """Return ``(lcs_len, ref_len, hyp_len)`` over non-whitespace characters.

    The longest-common-subsequence length is computed with a rolling
    two-row table. If either side is empty the LCS length is 0.
    """
    hyp_chars = _char_seq(hyp)
    ref_chars = _char_seq(ref)
    ref_len, hyp_len = len(ref_chars), len(hyp_chars)
    if not ref_len or not hyp_len:
        return 0, ref_len, hyp_len

    above = [0] * (hyp_len + 1)
    for ref_ch in ref_chars:
        row = [0]
        for col, hyp_ch in enumerate(hyp_chars, start=1):
            if ref_ch == hyp_ch:
                row.append(above[col - 1] + 1)
            else:
                row.append(max(above[col], row[col - 1]))
        above = row
    return above[-1], ref_len, hyp_len
|
|
|
|
def _lcs_f1(hyp: str, ref: str) -> float:
    """F1 score built from the character LCS length of ``hyp`` and ``ref``.

    Returns 1.0 when both sides are empty and 0.0 when exactly one is.
    """
    common, n_ref, n_hyp = _lcs_lengths(hyp, ref)
    if n_ref == 0 or n_hyp == 0:
        return 1.0 if n_ref == n_hyp else 0.0
    precision = common / max(1, n_hyp)
    recall = common / max(1, n_ref)
    return 2.0 * precision * recall / max(1e-8, precision + recall)
|
|
|
|
def _cmp_score(hyp: str, ref: str) -> float:
    """Character similarity: 70% bigram F1 plus 30% LCS F1."""
    bigram_part = 0.70 * _char_bigram_f1(hyp, ref)
    lcs_part = 0.30 * _lcs_f1(hyp, ref)
    return bigram_part + lcs_part
|
|
|
|
def wer_reward_main(wer: float) -> float:
    """Map a word error rate to the main reward with a piecewise-linear curve.

    The curve starts at 1.0 for ``wer == 0`` and falls through four linear
    segments ending at WER 0.15, 0.35, 0.70 and 1.20; beyond 1.20 the reward
    is a flat -0.92.
    """
    # (segment upper bound, reward at segment start, slope, WER at segment start)
    segments = (
        (0.15, 1.0, 1.8, 0.0),
        (0.35, 0.73, 2.3, 0.15),
        (0.70, 0.27, 2.4, 0.35),
        (1.20, -0.57, 0.70, 0.70),
    )
    for upper, start_reward, slope, start_wer in segments:
        if wer <= upper:
            return start_reward - slope * (wer - start_wer)
    return -0.92
|
|
|
|
def length_ratio_penalty_v3(
    hyp_len: int,
    ref_len: int,
    soft_min: float = 0.90,
    soft_max: float = 1.10,
    hard_min: float = 0.78,
    hard_max: float = 1.30,
    soft_penalty: float = 0.10,
    hard_penalty: float = 0.36,
) -> float:
    """Penalty (<= 0 with the defaults) for a hypothesis/reference length mismatch.

    The ratio ``hyp_len / max(1, ref_len)`` is scored in three zones:
      * inside ``[soft_min, soft_max]``: no penalty;
      * between a soft and a hard bound: a linear ramp from 0 down to
        ``-soft_penalty``;
      * beyond a hard bound: ``-soft_penalty`` growing towards
        ``-hard_penalty``, saturating once the overshoot reaches the size of
        the hard bound itself.
    """
    ratio = hyp_len / max(1, ref_len)

    def beyond_hard(overshoot: float, scale: float) -> float:
        # Shared shape of the two outer zones; the fraction is capped at 1.
        frac = min(1.0, overshoot / max(1e-6, scale))
        return -(soft_penalty + (hard_penalty - soft_penalty) * frac)

    if soft_min <= ratio <= soft_max:
        return 0.0
    if hard_min <= ratio < soft_min:
        return -soft_penalty * ((soft_min - ratio) / max(1e-6, soft_min - hard_min))
    if soft_max < ratio <= hard_max:
        return -soft_penalty * ((ratio - soft_max) / max(1e-6, hard_max - soft_max))
    if ratio < hard_min:
        return beyond_hard(hard_min - ratio, hard_min)
    return beyond_hard(ratio - hard_max, hard_max)
|
|
|
|
def tail_penalty(len_ratio: float) -> float:
    """Penalty for over-long hypotheses, based on the length ratio.

    No penalty up to 1.15; a linear ramp to -0.28 at 1.40; a second ramp to
    -0.70 at 2.0; a flat -0.70 beyond that.
    """
    if len_ratio <= 1.15:
        return 0.0
    if len_ratio <= 1.40:
        overshoot = len_ratio - 1.15
        return -0.28 * overshoot / 0.25
    if len_ratio <= 2.0:
        overshoot = len_ratio - 1.40
        return -0.28 - 0.42 * overshoot / 0.60
    return -0.70
|
|
|
|
def is_hallucination_v56(hyp_toks: List[str], ref_toks: List[str], wer: float, len_ratio: float):
    """Flag a hypothesis as hallucinated and say why.

    Checks run in this order and the first hit wins:
      1. no tokens at all;
      2. the same token five or more times in a row;
      3. with at least 8 tokens, one bigram making up more than 22% of all bigrams;
      4. ``len_ratio`` above 1.60;
      5. ``wer`` of 1.20 or more.

    ``ref_toks`` is accepted but not used by any check.

    Returns:
        ``(flag, reason)``; ``reason`` is "ok" when nothing triggered.
    """
    if not hyp_toks:
        return True, "empty"

    following = hyp_toks[1:]

    streak = 1
    for prev_tok, tok in zip(hyp_toks, following):
        streak = streak + 1 if tok == prev_tok else 1
        if streak >= 5:
            return True, "repeat_run>=5"

    if len(hyp_toks) >= 8:
        pairs = list(zip(hyp_toks, following))
        _, top_count = Counter(pairs).most_common(1)[0]
        if top_count / max(1, len(pairs)) > 0.22:
            return True, "repeat_bigram>0.22"

    if len_ratio > 1.60:
        return True, "len_ratio>1.60"
    if wer >= 1.20:
        return True, "wer>=1.20"
    return False, "ok"
|
|
|
|
| def _voices_residual(del_rate: float, len_ratio: float): |
| p_del_voice = -0.12 * del_rate - 0.08 * max(0.0, del_rate - 0.10) |
| p_under_voice = -0.06 * max(0.0, 0.98 - len_ratio) |
| return p_del_voice, p_under_voice |
|
|
|
|
| def _noise_residual(sub_rate: float, cmp_score: float): |
| p_sub_noise = -0.08 * sub_rate |
| p_cmp_noise = -0.04 * (1.0 - cmp_score) |
| return p_sub_noise, p_cmp_noise |
|
|
|
|
class ASRWerSubLenCmpHalluDirtyV56(ORM):
    """ASR reward combining a WER-based score with shaping penalties.

    For each (completion, reference) pair the reward is the sum of the
    piecewise WER reward, a substitution-rate penalty, two length penalties,
    a character-similarity penalty, hallucination and empty-output penalties,
    and a residual pair chosen by the resolved dirty type. The sum is clipped
    to ``[reward_clip_min, reward_clip_max]``.
    """

    # Linear weight on the substitution rate.
    sub_penalty_a = 0.40
    # Extra weight on the part of the substitution rate above 0.35.
    sub_penalty_b = 0.35
    # Weight on (1 - character similarity score).
    cmp_penalty = 0.14
    # Flat penalty when the hypothesis is flagged as a hallucination.
    hallu_extra_penalty = 0.42
    # Flat penalty when the hypothesis has no tokens.
    empty_extra_penalty = 0.28

    # Bounds applied to the summed reward.
    reward_clip_min = -4.0
    reward_clip_max = 2.0

    def __call__(self, completions, solution=None, **kwargs):
        """Score every completion against its reference transcription.

        Args:
            completions: Raw model outputs, one per sample.
            solution: One reference string shared by all completions, or an
                iterable of references paired with the completions in order.
            **kwargs: Optional per-sample metadata (e.g. ``dirty_type``,
                ``audio_path``, ``audios``) and debug-logging settings.

        Returns:
            A list of clipped float rewards, one per (completion, reference)
            pair. Without any reference, every completion gets 0.0.
        """
        if solution is None:
            solution = kwargs.get("solution")
            if solution is None:
                return [0.0 for _ in completions]

        references = (
            [solution for _ in completions]
            if isinstance(solution, str)
            else list(solution)
        )

        rewards = []
        for idx, (completion, reference) in enumerate(zip(completions, references)):
            hyp_text = _extract_completion_text(completion)
            ref_text = reference or ""

            # Token-level statistics; the reference length is floored at 1 so
            # the rates below never divide by zero.
            ref_tokens = _tokenize(ref_text)
            hyp_tokens = _tokenize(hyp_text)
            ref_len = max(1, len(ref_tokens))
            hyp_len = len(hyp_tokens)
            len_ratio = hyp_len / ref_len

            n_sub, n_del, n_ins = _edit_ops_counts(ref_tokens, hyp_tokens)
            wer = (n_sub + n_del + n_ins) / ref_len
            sub_rate = n_sub / ref_len
            del_rate = n_del / ref_len

            # Core reward and the generic shaping penalties.
            r_wer = wer_reward_main(wer)
            excess_sub = max(0.0, sub_rate - 0.35)
            p_sub = -float(self.sub_penalty_a) * sub_rate - float(self.sub_penalty_b) * excess_sub
            p_len = length_ratio_penalty_v3(hyp_len=hyp_len, ref_len=ref_len)
            p_tail = tail_penalty(len_ratio)

            cmp_score = _cmp_score(hyp_text, ref_text)
            p_cmp = -float(self.cmp_penalty) * (1.0 - cmp_score)

            hallu, hallu_reason = is_hallucination_v56(hyp_tokens, ref_tokens, wer, len_ratio)
            p_hallu = 0.0
            if hallu:
                p_hallu = -float(self.hallu_extra_penalty)
            p_empty = 0.0
            if hyp_len == 0:
                p_empty = -float(self.empty_extra_penalty)

            # Dirty-type specific residuals; only one family is ever non-zero.
            dirty_type = _get_dirty_type(kwargs, idx)
            p_del_voice = p_under_voice = p_sub_noise = p_cmp_noise = 0.0
            if dirty_type == "voices_noise_plus_farfield":
                p_del_voice, p_under_voice = _voices_residual(del_rate, len_ratio)
            elif dirty_type == "noise_rsp_pure_noise":
                p_sub_noise, p_cmp_noise = _noise_residual(sub_rate, cmp_score)

            # Terms are added in a fixed left-to-right order so the float
            # result stays reproducible.
            reward_raw = float(
                r_wer + p_sub + p_len + p_tail + p_cmp + p_hallu + p_empty
                + p_del_voice + p_under_voice + p_sub_noise + p_cmp_noise
            )
            reward = max(float(self.reward_clip_min), min(float(self.reward_clip_max), reward_raw))
            rewards.append(reward)

            debug_row = dict(_collect_common_debug_meta(kwargs, idx))
            debug_row.update({
                "index": idx,
                "dirty_type_resolved": dirty_type,
                "completion_raw": completion,
                "hyp": hyp_text,
                "ref": ref_text,
                "ref_len": ref_len,
                "hyp_len": hyp_len,
                "len_ratio": len_ratio,
                "sub_cnt": n_sub,
                "del_cnt": n_del,
                "ins_cnt": n_ins,
                "wer_calc": wer,
                "sub_rate": sub_rate,
                "del_rate": del_rate,
                "cmp_score": cmp_score,
                "hallu": hallu,
                "hallu_reason": hallu_reason,
                "r_wer": r_wer,
                "p_sub": p_sub,
                "p_len": p_len,
                "p_tail": p_tail,
                "p_cmp": p_cmp,
                "p_hallu": p_hallu,
                "p_empty": p_empty,
                "p_del_voice": p_del_voice,
                "p_under_voice": p_under_voice,
                "p_sub_noise": p_sub_noise,
                "p_cmp_noise": p_cmp_noise,
                "reward_raw": reward_raw,
                "reward": reward,
            })
            _append_reward_debug_row(
                reward_name="asr_wer_sub_len_cmp_hallu_dirty_v56",
                kwargs=kwargs,
                row=debug_row,
            )

        return rewards
|
|
|
|
# Register the reward class in the `orms` mapping under its lookup key.
orms["asr_wer_sub_len_cmp_hallu_dirty_v56"] = ASRWerSubLenCmpHalluDirtyV56
|
|