# asr_RL_reward_v56_dirty.py
# -*- coding: utf-8 -*-
"""ASR reinforcement-learning reward (v56) with dirty-audio residual penalties.

The reward class defined here is registered in the ``orms`` registry under
``"asr_wer_sub_len_cmp_hallu_dirty_v56"``. For every completion it combines:

* a piecewise-linear WER score (``wer_reward_main``),
* substitution, length-ratio, tail and character-similarity penalties,
* extra penalties for detected hallucinations and empty outputs,
* small residual penalties for two "dirty" audio families
  (``voices_noise_plus_farfield`` and ``noise_rsp_pure_noise``),

and clips the sum to ``[reward_clip_min, reward_clip_max]``.

Per-sample JSONL debug rows are written when ``reward_debug`` (kwarg) or
``ASR_REWARD_DEBUG`` (env var) is truthy and a debug path is configured via
``reward_debug_path`` / ``ASR_REWARD_DEBUG_PATH``.
"""
import json
import os
import re
import time
from collections import Counter
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

# Try ``swift.rewards`` first and fall back to ``swift.plugin``; presumably the
# location of ORM/orms differs between swift versions (verify if upgrading).
try:
    from swift.rewards import ORM, orms
except Exception:
    from swift.plugin import ORM, orms

from qwen_asr.inference.utils import parse_asr_output

_REWARD_NAME = "asr_wer_sub_len_cmp_hallu_dirty_v56"

# Tag name of an optional answer wrapper around the transcription. The regex is
# built from this constant so it always has literal open/close anchors: an
# unanchored ``(.*?)`` matches the empty string at position 0 and would blank
# every completion.
_ANSWER_TAG = "answer"
_ANSWER_RE = re.compile(rf"<{_ANSWER_TAG}>(.*?)</{_ANSWER_TAG}>", re.S | re.I)

# Leading labels stripped (case-insensitively, first match only) from the
# extracted answer. Chinese labels accept both ASCII ":" and full-width "：".
_ANSWER_PREFIXES: Tuple[str, ...] = (
    "transcription:",
    "asr:",
    "answer:",
    "答案:",
    "答案：",
    "识别结果:",
    "识别结果：",
)

# Number of debug rows already written by this process, keyed by file path.
_REWARD_DEBUG_COUNTER: Dict[str, int] = {}

# kwargs copied (per sample) into every debug row when present.
_COMMON_DEBUG_FIELDS: Tuple[str, ...] = (
    "step",
    "id",
    "sample_id",
    "group_id",
    "utt_id",
    "audio_id",
    "audio_path",
    "task",
    "lang",
    "language",
    "base_wer",
    "base_wer_bucket",
    "difficulty_bucket",
    "wer",
    "dirty_type",
)

# Accepted spellings of the two dirty-audio families in the ``dirty_type`` kwarg.
_VOICES_ALIASES = frozenset(
    {"voices", "voices_noise_plus_farfield", "voices_far", "voices_farfield"}
)
_NOISE_ALIASES = frozenset({"noise", "noise_rsp", "noise_rsp_pure_noise", "pure_noise"})

# Back-pointer codes used by ``_edit_ops_counts``.
_OP_MATCH, _OP_SUB, _OP_DEL, _OP_INS = 0, 1, 2, 3


# --------------------------------------------------------------------------
# Generic helpers
# --------------------------------------------------------------------------
def _as_bool(x: Any, default: bool = False) -> bool:
    """Interpret *x* as a boolean flag.

    ``None`` yields *default*; bools pass through; numbers use ``bool()``;
    anything else is true iff its stripped, lower-cased string form is one of
    ``1``, ``true``, ``yes``, ``y``, ``on``.
    """
    if x is None:
        return default
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, float)):
        return bool(x)
    return str(x).strip().lower() in {"1", "true", "yes", "y", "on"}


def _to_jsonable(x: Any) -> Any:
    """Recursively convert *x* into JSON-serializable primitives.

    Lists/tuples become lists, dict keys are stringified, and any other
    non-primitive object is replaced by ``str(x)``.
    """
    if x is None or isinstance(x, (str, int, float, bool)):
        return x
    if isinstance(x, (list, tuple)):
        return [_to_jsonable(v) for v in x]
    if isinstance(x, dict):
        return {str(k): _to_jsonable(v) for k, v in x.items()}
    return str(x)


def _pick_field(x: Any, i: int) -> Any:
    """Return the *i*-th entry of a per-sample list, or *x* itself.

    Lists/tuples are indexed (``None`` when *i* is out of range); any other
    value is treated as shared by every sample.
    """
    if x is None:
        return None
    if isinstance(x, (list, tuple)):
        return x[i] if i < len(x) else None
    return x


# --------------------------------------------------------------------------
# Debug logging (best effort; must never break training)
# --------------------------------------------------------------------------
def _reward_debug_enabled(kwargs: Dict[str, Any]) -> bool:
    """Return whether debug rows are requested (kwarg wins over env var)."""
    if "reward_debug" in kwargs:
        return _as_bool(kwargs.get("reward_debug"), default=False)
    return _as_bool(os.environ.get("ASR_REWARD_DEBUG"), default=False)


def _reward_debug_path(kwargs: Dict[str, Any], reward_name: str) -> str:
    """Build the per-reward, per-process JSONL path, or ``""`` if unset.

    A trailing ``.jsonl`` on the configured base is dropped and re-added after
    ``.<reward_name>.pid<pid>`` so concurrent processes write separate files.
    """
    base = kwargs.get("reward_debug_path") or os.environ.get("ASR_REWARD_DEBUG_PATH")
    if not base:
        return ""
    base = str(base)
    if base.endswith(".jsonl"):
        base = base[: -len(".jsonl")]
    return f"{base}.{reward_name}.pid{os.getpid()}.jsonl"


def _reward_debug_max_rows(kwargs: Dict[str, Any]) -> int:
    """Return the per-file row cap (default 1000; unparsable values -> 1000)."""
    x = kwargs.get("reward_debug_max_rows")
    if x is None:
        x = os.environ.get("ASR_REWARD_DEBUG_MAX_ROWS", 1000)
    try:
        x = int(x)
    except (TypeError, ValueError, OverflowError):
        x = 1000
    return max(0, x)


def _collect_common_debug_meta(kwargs: Dict[str, Any], i: int) -> Dict[str, Any]:
    """Return JSON-safe per-sample values of the known metadata kwargs."""
    return {
        k: _to_jsonable(_pick_field(kwargs.get(k), i))
        for k in _COMMON_DEBUG_FIELDS
        if k in kwargs
    }


def _append_reward_debug_row(reward_name: str, kwargs: Dict[str, Any], row: Dict[str, Any]) -> None:
    """Append one JSON line describing a scored sample, if debugging is on.

    Does nothing when debugging is disabled, no path is configured, or the
    per-process row cap for that path is reached. Any error while writing is
    deliberately ignored: debug output must never interrupt training.
    """
    if not _reward_debug_enabled(kwargs):
        return
    path = _reward_debug_path(kwargs, reward_name)
    if not path:
        return
    max_rows = _reward_debug_max_rows(kwargs)
    if max_rows <= 0:
        return
    written = _REWARD_DEBUG_COUNTER.get(path, 0)
    if written >= max_rows:
        return
    try:
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        payload = {"ts": time.time(), "reward_name": reward_name, **_to_jsonable(row)}
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")
        _REWARD_DEBUG_COUNTER[path] = written + 1
    except Exception:
        # Best effort by design: a failed debug write must not fail the step.
        pass


# --------------------------------------------------------------------------
# Text extraction and tokenization
# --------------------------------------------------------------------------
def _extract_completion_text(s: Any) -> str:
    """Extract the transcription from a raw model completion.

    Steps: take the content of an answer tag pair if present; strip one known
    leading label (see ``_ANSWER_PREFIXES``); drop ``<|im_end|>`` markers;
    finally let ``parse_asr_output`` extract the text, keeping the current
    string if it fails or returns empty text. Non-string input is ``str()``-ed.
    """
    if s is None:
        return ""
    if not isinstance(s, str):
        s = str(s)
    s = s.strip()
    m = _ANSWER_RE.search(s)
    if m:
        s = m.group(1).strip()
    lower = s.lower()
    for pfx in _ANSWER_PREFIXES:
        if lower.startswith(pfx):
            s = s[len(pfx):].strip()
            break
    s = s.replace("<|im_end|>", "").strip()
    try:
        _lang, text = parse_asr_output(s, user_language=None)
    except Exception:
        # parse_asr_output is an optional refinement; fall back to the
        # already-cleaned string on any failure.
        pass
    else:
        if text:
            s = text
    return s


def normalize_text(s: str) -> str:
    """Strip surrounding whitespace and lower-case (``None`` -> ``""``)."""
    return (s or "").strip().lower()


def _split_tokens(norm: str, word_mode: bool) -> Tuple[str, ...]:
    """Split normalized text into whitespace words or non-whitespace chars."""
    if word_mode:
        return tuple(norm.split())
    return tuple(ch for ch in norm if not ch.isspace())


@lru_cache(maxsize=50000)
def _tokenize_cached(s: str, word_mode: Optional[bool] = None) -> Tuple[str, ...]:
    """Tokenize *s* after ``normalize_text``.

    With ``word_mode=None`` the mode is auto-detected from *s* alone: words
    when the normalized text contains an ASCII space, otherwise characters
    (e.g. unsegmented CJK text). Pass an explicit bool to force a mode.
    """
    norm = normalize_text(s)
    if not norm:
        return ()
    if word_mode is None:
        word_mode = " " in norm
    return _split_tokens(norm, word_mode)


def _tokenize(s: str) -> List[str]:
    """Return the auto-detected token list of *s* (see ``_tokenize_cached``)."""
    return list(_tokenize_cached(s))


def _tokenize_pair(ref: str, hyp: str) -> Tuple[List[str], List[str]]:
    """Tokenize reference and hypothesis under one shared word/char scheme.

    The mode is decided from the reference (from the hypothesis only when the
    reference is empty) and applied to both. Deciding per string would, for
    example, split a one-word English hypothesis into characters against a
    word-split reference, making the edit distance meaningless.
    """
    basis = normalize_text(ref) or normalize_text(hyp)
    word_mode = " " in basis
    return list(_tokenize_cached(ref, word_mode)), list(_tokenize_cached(hyp, word_mode))


def _char_seq(s: str) -> List[str]:
    """Return the non-whitespace characters of the normalized text."""
    return [ch for ch in normalize_text(s) if not ch.isspace()]


# --------------------------------------------------------------------------
# Dirty-audio type resolution
# --------------------------------------------------------------------------
def _infer_dirty_type_from_audio(x: Any) -> str:
    """Guess the dirty-audio family from an audio path (or a list's first item).

    Substring heuristics on the lower-cased string form: ``"voices"`` ->
    ``voices_noise_plus_farfield``; ``"noise+rsp"``, ``"resample_noise"`` or
    ``"/noise/"`` -> ``noise_rsp_pure_noise``; anything else -> ``"other"``.
    """
    if x is None:
        return "other"
    if isinstance(x, (list, tuple)) and x:
        x = x[0]
    s = str(x).lower()
    if "voices" in s:
        return "voices_noise_plus_farfield"
    if "noise+rsp" in s or "resample_noise" in s or "/noise/" in s:
        return "noise_rsp_pure_noise"
    return "other"


def _get_dirty_type(kwargs: Dict[str, Any], i: int) -> str:
    """Resolve the dirty-audio family of sample *i*.

    An explicit ``dirty_type`` kwarg wins: known aliases are canonicalized and
    unknown values are returned stripped and lower-cased. Otherwise the type
    is inferred from ``audio_path``, then from ``audios``.
    """
    dtype = _pick_field(kwargs.get("dirty_type"), i)
    if dtype is not None:
        s = str(dtype).strip().lower()
        if s in _VOICES_ALIASES:
            return "voices_noise_plus_farfield"
        if s in _NOISE_ALIASES:
            return "noise_rsp_pure_noise"
        return s
    audio_path = _pick_field(kwargs.get("audio_path"), i)
    if audio_path is not None:
        return _infer_dirty_type_from_audio(audio_path)
    return _infer_dirty_type_from_audio(_pick_field(kwargs.get("audios"), i))


# --------------------------------------------------------------------------
# Alignment and similarity metrics
# --------------------------------------------------------------------------
def _edit_ops_counts(ref_toks: List[str], hyp_toks: List[str]) -> Tuple[int, int, int]:
    """Return ``(substitutions, deletions, insertions)`` of a min-cost alignment.

    Unit-cost Levenshtein DP over ``(len(ref)+1) x (len(hyp)+1)`` tables,
    followed by a backtrace. Equal tokens always align as a match; on cost
    ties the backtrace prefers substitution, then deletion, then insertion.
    The three counts sum to the edit distance.
    """
    n, m = len(ref_toks), len(hyp_toks)
    dist = [[0] * (m + 1) for _ in range(n + 1)]
    back = [[_OP_MATCH] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dist[i][0] = i
        back[i][0] = _OP_DEL
    for j in range(1, m + 1):
        dist[0][j] = j
        back[0][j] = _OP_INS

    for i in range(1, n + 1):
        ri = ref_toks[i - 1]
        for j in range(1, m + 1):
            if ri == hyp_toks[j - 1]:
                dist[i][j] = dist[i - 1][j - 1]
                back[i][j] = _OP_MATCH
                continue
            sub = dist[i - 1][j - 1] + 1
            dele = dist[i - 1][j] + 1
            ins = dist[i][j - 1] + 1
            best = min(sub, dele, ins)
            dist[i][j] = best
            if best == sub:
                back[i][j] = _OP_SUB
            elif best == dele:
                back[i][j] = _OP_DEL
            else:
                back[i][j] = _OP_INS

    i, j = n, m
    n_sub = n_del = n_ins = 0
    while i > 0 or j > 0:
        op = back[i][j]
        if i > 0 and j > 0 and op == _OP_MATCH:
            i -= 1
            j -= 1
        elif i > 0 and j > 0 and op == _OP_SUB:
            n_sub += 1
            i -= 1
            j -= 1
        elif i > 0 and op == _OP_DEL:
            n_del += 1
            i -= 1
        else:
            n_ins += 1
            j -= 1
    return n_sub, n_del, n_ins


def _f1(p: float, r: float) -> float:
    """Harmonic mean of precision and recall (0 when both are 0)."""
    return 2.0 * p * r / max(1e-8, p + r)


def _char_bigram_f1(hyp: str, ref: str) -> float:
    """F1 over character-bigram multisets (whitespace ignored).

    Both empty -> 1.0; exactly one empty -> 0.0. If either side has fewer
    than two characters, position-wise character matches are used instead.
    """
    h = _char_seq(hyp)
    r = _char_seq(ref)
    if not h and not r:
        return 1.0
    if not h or not r:
        return 0.0
    if len(h) < 2 or len(r) < 2:
        inter = sum(1 for x, y in zip(h, r) if x == y)
        return _f1(inter / max(1, len(h)), inter / max(1, len(r)))
    hg = Counter(zip(h, h[1:]))
    rg = Counter(zip(r, r[1:]))
    inter = sum(min(v, rg[k]) for k, v in hg.items())
    return _f1(inter / max(1, sum(hg.values())), inter / max(1, sum(rg.values())))


def _lcs_lengths(hyp: str, ref: str) -> Tuple[int, int, int]:
    """Return ``(lcs_len, ref_len, hyp_len)`` over non-whitespace characters.

    Uses a two-row LCS DP; the LCS length is 0 when either side is empty.
    """
    h = _char_seq(hyp)
    r = _char_seq(ref)
    n, m = len(r), len(h)
    if n == 0 or m == 0:
        return 0, n, m
    prev = [0] * (m + 1)
    for i in range(1, n + 1):
        cur = [0] * (m + 1)
        ri = r[i - 1]
        for j in range(1, m + 1):
            if ri == h[j - 1]:
                cur[j] = prev[j - 1] + 1
            else:
                cur[j] = max(prev[j], cur[j - 1])
        prev = cur
    return prev[m], n, m


def _lcs_f1(hyp: str, ref: str) -> float:
    """Character-LCS F1 (both empty -> 1.0; exactly one empty -> 0.0)."""
    lcs_len, ref_len, hyp_len = _lcs_lengths(hyp, ref)
    if ref_len == 0 and hyp_len == 0:
        return 1.0
    if ref_len == 0 or hyp_len == 0:
        return 0.0
    return _f1(lcs_len / max(1, hyp_len), lcs_len / max(1, ref_len))


def _cmp_score(hyp: str, ref: str) -> float:
    """Blend of bigram F1 (70%) and LCS F1 (30%), in ``[0, 1]``."""
    return 0.70 * _char_bigram_f1(hyp, ref) + 0.30 * _lcs_f1(hyp, ref)


# --------------------------------------------------------------------------
# Reward / penalty shaping
# --------------------------------------------------------------------------
def wer_reward_main(wer: float) -> float:
    """Continuous piecewise-linear WER score: 1.0 at WER 0, -0.92 from WER 1.2.

    Knots: 0.15 -> 0.73, 0.35 -> 0.27, 0.70 -> -0.57, 1.20 -> -0.92.
    """
    if wer <= 0.15:
        return 1.0 - 1.8 * wer
    if wer <= 0.35:
        return 0.73 - 2.3 * (wer - 0.15)
    if wer <= 0.70:
        return 0.27 - 2.4 * (wer - 0.35)
    if wer <= 1.20:
        return -0.57 - 0.70 * (wer - 0.70)
    return -0.92


def length_ratio_penalty_v3(
    hyp_len: int,
    ref_len: int,
    soft_min: float = 0.90,
    soft_max: float = 1.10,
    hard_min: float = 0.78,
    hard_max: float = 1.30,
    soft_penalty: float = 0.10,
    hard_penalty: float = 0.36,
) -> float:
    """Non-positive penalty for a hyp/ref token-length ratio outside the soft band.

    0 inside ``[soft_min, soft_max]``; ramps linearly to ``-soft_penalty`` at
    ``hard_min`` / ``hard_max``; beyond the hard limits it ramps further
    towards ``-hard_penalty`` (reached at ratio 0, or at ``2 * hard_max``).
    """
    ref_len = max(1, ref_len)
    ratio = hyp_len / ref_len
    if soft_min <= ratio <= soft_max:
        return 0.0
    if hard_min <= ratio < soft_min:
        frac = (soft_min - ratio) / max(1e-6, soft_min - hard_min)
        return -soft_penalty * frac
    if soft_max < ratio <= hard_max:
        frac = (ratio - soft_max) / max(1e-6, hard_max - soft_max)
        return -soft_penalty * frac
    if ratio < hard_min:
        frac = min(1.0, (hard_min - ratio) / max(1e-6, hard_min))
    else:
        frac = min(1.0, (ratio - hard_max) / max(1e-6, hard_max))
    return -(soft_penalty + (hard_penalty - soft_penalty) * frac)


def tail_penalty(len_ratio: float) -> float:
    """Extra penalty for over-long outputs: 0 up to 1.15, -0.28 at 1.40, -0.70 from 2.0."""
    if len_ratio <= 1.15:
        return 0.0
    if len_ratio <= 1.40:
        return -0.28 * (len_ratio - 1.15) / 0.25
    if len_ratio <= 2.0:
        return -0.28 - 0.42 * (len_ratio - 1.40) / 0.60
    return -0.70


def is_hallucination_v56(hyp_toks: List[str], ref_toks: List[str], wer: float, len_ratio: float):
    """Heuristic hallucination check; returns ``(flag, reason)``.

    Flags: empty hypothesis; a run of >= 5 identical tokens; (for >= 8 tokens)
    the most common bigram covering > 22% of bigrams; length ratio > 1.60;
    WER >= 1.20. ``ref_toks`` is currently unused.
    """
    if not hyp_toks:
        return True, "empty"
    run = 1
    for prev_tok, tok in zip(hyp_toks, hyp_toks[1:]):
        run = run + 1 if tok == prev_tok else 1
        if run >= 5:
            return True, "repeat_run>=5"
    if len(hyp_toks) >= 8:
        bigrams = list(zip(hyp_toks, hyp_toks[1:]))
        most = Counter(bigrams).most_common(1)[0][1]
        if most / max(1, len(bigrams)) > 0.22:
            return True, "repeat_bigram>0.22"
    if len_ratio > 1.60:
        return True, "len_ratio>1.60"
    if wer >= 1.20:
        return True, "wer>=1.20"
    return False, "ok"


def _voices_residual(del_rate: float, len_ratio: float):
    """Extra (deletion, under-length) penalties for voices/far-field audio."""
    p_del_voice = -0.12 * del_rate - 0.08 * max(0.0, del_rate - 0.10)
    p_under_voice = -0.06 * max(0.0, 0.98 - len_ratio)
    return p_del_voice, p_under_voice


def _noise_residual(sub_rate: float, cmp_score: float):
    """Extra (substitution, dissimilarity) penalties for pure-noise audio."""
    p_sub_noise = -0.08 * sub_rate
    p_cmp_noise = -0.04 * (1.0 - cmp_score)
    return p_sub_noise, p_cmp_noise


# --------------------------------------------------------------------------
# Reward class
# --------------------------------------------------------------------------
class ASRWerSubLenCmpHalluDirtyV56(ORM):
    """WER-based ASR reward with shaping penalties and dirty-audio residuals.

    Class attributes are the tunable penalty weights and the final clip range.
    """

    sub_penalty_a = 0.40
    sub_penalty_b = 0.35
    cmp_penalty = 0.14
    hallu_extra_penalty = 0.42
    empty_extra_penalty = 0.28
    reward_clip_min = -4.0
    reward_clip_max = 2.0

    def __call__(self, completions, solution=None, **kwargs):
        """Return one clipped reward per completion.

        Args:
            completions: model outputs (non-strings are ``str()``-ed).
            solution: reference transcript(s): one string shared by every
                completion, or one entry per completion. ``None`` yields
                all-zero rewards.
            **kwargs: per-sample metadata (``dirty_type``, ``audio_path``,
                ``audios``, debug options, ...), each a scalar or a list
                aligned with ``completions``.

        Raises:
            ValueError: if a per-completion ``solution`` sequence does not have
                the same length as ``completions`` (instead of silently
                returning too few rewards).
        """
        completions = list(completions)
        if solution is None:
            return [0.0] * len(completions)
        if isinstance(solution, str):
            solution_list = [solution] * len(completions)
        else:
            solution_list = list(solution)
        if len(solution_list) != len(completions):
            raise ValueError(
                f"{_REWARD_NAME}: got {len(completions)} completions but "
                f"{len(solution_list)} solutions"
            )

        debug = _reward_debug_enabled(kwargs)
        rewards: List[float] = []
        for i, (comp, ref) in enumerate(zip(completions, solution_list)):
            reward, parts = self._score_one(comp, ref, kwargs, i)
            rewards.append(reward)
            if debug:
                _append_reward_debug_row(
                    reward_name=_REWARD_NAME,
                    kwargs=kwargs,
                    row={**_collect_common_debug_meta(kwargs, i), "index": i, **parts},
                )
        return rewards

    def _score_one(self, comp: Any, ref: Any, kwargs: Dict[str, Any], i: int) -> Tuple[float, Dict[str, Any]]:
        """Score one completion; return ``(clipped_reward, debug_components)``."""
        hyp = _extract_completion_text(comp)
        ref = "" if ref is None else str(ref)

        ref_toks, hyp_toks = _tokenize_pair(ref, hyp)
        ref_len = max(1, len(ref_toks))  # avoid division by zero for empty refs
        hyp_len = len(hyp_toks)
        len_ratio = float(hyp_len) / float(ref_len)

        sub_cnt, del_cnt, ins_cnt = _edit_ops_counts(ref_toks, hyp_toks)
        wer = float(sub_cnt + del_cnt + ins_cnt) / float(ref_len)
        sub_rate = float(sub_cnt) / float(ref_len)
        del_rate = float(del_cnt) / float(ref_len)

        r_wer = wer_reward_main(wer)
        p_sub = (
            -float(self.sub_penalty_a) * sub_rate
            - float(self.sub_penalty_b) * max(0.0, sub_rate - 0.35)
        )
        p_len = length_ratio_penalty_v3(hyp_len=hyp_len, ref_len=ref_len)
        p_tail = tail_penalty(len_ratio)
        cmp_score = _cmp_score(hyp, ref)
        p_cmp = -float(self.cmp_penalty) * (1.0 - cmp_score)
        hallu, hallu_reason = is_hallucination_v56(hyp_toks, ref_toks, wer, len_ratio)
        p_hallu = -float(self.hallu_extra_penalty) if hallu else 0.0
        p_empty = -float(self.empty_extra_penalty) if hyp_len == 0 else 0.0

        dirty_type = _get_dirty_type(kwargs, i)
        p_del_voice = p_under_voice = p_sub_noise = p_cmp_noise = 0.0
        if dirty_type == "voices_noise_plus_farfield":
            p_del_voice, p_under_voice = _voices_residual(del_rate, len_ratio)
        elif dirty_type == "noise_rsp_pure_noise":
            p_sub_noise, p_cmp_noise = _noise_residual(sub_rate, cmp_score)

        reward_raw = float(
            r_wer + p_sub + p_len + p_tail + p_cmp + p_hallu + p_empty
            + p_del_voice + p_under_voice + p_sub_noise + p_cmp_noise
        )
        reward = max(float(self.reward_clip_min), min(float(self.reward_clip_max), reward_raw))

        parts = {
            "dirty_type_resolved": dirty_type,
            "completion_raw": comp,
            "hyp": hyp,
            "ref": ref,
            "ref_len": ref_len,
            "hyp_len": hyp_len,
            "len_ratio": len_ratio,
            "sub_cnt": sub_cnt,
            "del_cnt": del_cnt,
            "ins_cnt": ins_cnt,
            "wer_calc": wer,
            "sub_rate": sub_rate,
            "del_rate": del_rate,
            "cmp_score": cmp_score,
            "hallu": hallu,
            "hallu_reason": hallu_reason,
            "r_wer": r_wer,
            "p_sub": p_sub,
            "p_len": p_len,
            "p_tail": p_tail,
            "p_cmp": p_cmp,
            "p_hallu": p_hallu,
            "p_empty": p_empty,
            "p_del_voice": p_del_voice,
            "p_under_voice": p_under_voice,
            "p_sub_noise": p_sub_noise,
            "p_cmp_noise": p_cmp_noise,
            "reward_raw": reward_raw,
            "reward": reward,
        }
        return reward, parts


orms[_REWARD_NAME] = ASRWerSubLenCmpHalluDirtyV56