# 0420upload / 0417_reward.py
# Prummn's picture
# Add files using upload-large-folder tool
# 03cb542 verified
# asr_RL_reward_v56_dirty.py
# -*- coding: utf-8 -*-
import re
import json
import os
import time
from functools import lru_cache
from collections import Counter
from typing import Any, Dict, List, Tuple
try:
from swift.rewards import ORM, orms
except Exception:
from swift.plugin import ORM, orms
from qwen_asr.inference.utils import parse_asr_output
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.S | re.I)
_REWARD_DEBUG_COUNTER: Dict[str, int] = {}
def _as_bool(x, default: bool = False) -> bool:
if x is None:
return default
if isinstance(x, bool):
return x
if isinstance(x, (int, float)):
return bool(x)
return str(x).strip().lower() in {"1", "true", "yes", "y", "on"}
def _to_jsonable(x: Any):
if x is None or isinstance(x, (str, int, float, bool)):
return x
if isinstance(x, (list, tuple)):
return [_to_jsonable(v) for v in x]
if isinstance(x, dict):
return {str(k): _to_jsonable(v) for k, v in x.items()}
return str(x)
def _pick_field(x, i: int):
if x is None:
return None
if isinstance(x, (list, tuple)):
return x[i] if i < len(x) else None
return x
def _reward_debug_enabled(kwargs) -> bool:
    """Tell whether debug-row logging is switched on.

    An explicit ``reward_debug`` entry in ``kwargs`` wins (even when its
    value is ``None``); otherwise the ``ASR_REWARD_DEBUG`` environment
    variable is consulted. Both are interpreted by ``_as_bool``.
    """
    if "reward_debug" in kwargs:
        raw_flag = kwargs.get("reward_debug")
    else:
        raw_flag = os.environ.get("ASR_REWARD_DEBUG")
    return _as_bool(raw_flag, default=False)
def _reward_debug_path(kwargs, reward_name: str) -> str:
base = kwargs.get("reward_debug_path") or os.environ.get("ASR_REWARD_DEBUG_PATH")
if not base:
return ""
base = str(base)
if base.endswith(".jsonl"):
base = base[:-6]
return f"{base}.{reward_name}.pid{os.getpid()}.jsonl"
def _reward_debug_max_rows(kwargs) -> int:
x = kwargs.get("reward_debug_max_rows")
if x is None:
x = os.environ.get("ASR_REWARD_DEBUG_MAX_ROWS", 1000)
try:
x = int(x)
except Exception:
x = 1000
return max(0, x)
def _collect_common_debug_meta(kwargs, i: int) -> Dict[str, Any]:
    """Gather the known metadata fields present in ``kwargs`` for sample ``i``.

    Only keys that actually appear in ``kwargs`` are included. Each value is
    narrowed to sample ``i`` with ``_pick_field`` and made JSON-friendly
    with ``_to_jsonable``.
    """
    wanted = (
        "step", "id", "sample_id", "group_id", "utt_id", "audio_id", "audio_path",
        "task", "lang", "language", "base_wer", "base_wer_bucket", "difficulty_bucket", "wer",
        "dirty_type",
    )
    return {
        name: _to_jsonable(_pick_field(kwargs.get(name), i))
        for name in wanted
        if name in kwargs
    }
def _append_reward_debug_row(reward_name: str, kwargs, row: Dict[str, Any]) -> None:
    """Append one JSON line describing a scored sample to the debug file.

    Does nothing when debugging is disabled, no path is configured, the row
    cap is zero, or this process has already written the cap for that path.
    The written object carries a ``ts`` timestamp and ``reward_name`` ahead
    of the JSON-converted ``row``. Any failure while writing is swallowed on
    purpose so that debug logging can never interrupt reward computation.
    """
    if not _reward_debug_enabled(kwargs):
        return
    target = _reward_debug_path(kwargs, reward_name)
    if not target:
        return
    limit = _reward_debug_max_rows(kwargs)
    if limit <= 0:
        return
    written = _REWARD_DEBUG_COUNTER.get(target, 0)
    if written >= limit:
        return

    try:
        folder = os.path.dirname(target)
        if folder:
            os.makedirs(folder, exist_ok=True)
        record = {"ts": time.time(), "reward_name": reward_name, **_to_jsonable(row)}
        with open(target, "a", encoding="utf-8") as handle:
            handle.write(json.dumps(record, ensure_ascii=False) + "\n")
        # Count only rows that were actually written.
        _REWARD_DEBUG_COUNTER[target] = written + 1
    except Exception:
        pass
def _extract_completion_text(s: str) -> str:
    """Reduce a raw model completion to the bare transcription text.

    Steps, in order:

    1. ``None`` becomes ``""``; otherwise surrounding whitespace is stripped.
    2. If an ``<answer>...</answer>`` span is present, only its content is kept.
    3. At most one known leading label (e.g. ``transcription:``) is removed,
       matched case-insensitively.
    4. Every ``<|im_end|>`` marker is removed.
    5. The result is passed to ``parse_asr_output``; when that returns a
       non-empty text it replaces the string, otherwise (or when the parser
       raises) the string from step 4 is returned.

    Args:
        s: Raw completion string, or ``None``.

    Returns:
        The cleaned hypothesis text.
    """
    if s is None:
        return ""
    s = s.strip()

    tagged = _ANSWER_RE.search(s)
    if tagged:
        s = tagged.group(1).strip()

    # The previous list repeated the ASCII-colon Chinese labels verbatim, so
    # the second copy of each could never match. The duplicates are replaced
    # by the full-width-colon spellings they most likely stood for.
    prefixes = (
        "transcription:",
        "asr:",
        "answer:",
        "答案:",
        "答案：",
        "识别结果:",
        "识别结果：",
    )
    lowered = s.lower()
    for pfx in prefixes:
        if lowered.startswith(pfx):
            s = s[len(pfx):].strip()
            break

    s = s.replace("<|im_end|>", "").strip()

    # Best effort: a parser failure must not break reward computation, so
    # fall back to the text cleaned so far.
    try:
        _lang, text = parse_asr_output(s, user_language=None)
    except Exception:
        return s
    return text or s
def normalize_text(s: str) -> str:
    """Return ``s`` stripped and lower-cased; falsy input gives ``""``."""
    if not s:
        return ""
    return s.strip().lower()
@lru_cache(maxsize=50000)
def _tokenize_cached(s: str) -> Tuple[str, ...]:
    """Tokenize ``s`` into an immutable tuple (memoized).

    After ``normalize_text``, a string containing at least one ASCII space
    is split on whitespace into words; a string without one is split into
    its individual non-whitespace characters. Empty input gives ``()``.
    """
    text = normalize_text(s)
    if not text:
        return ()
    if " " in text:
        # ``str.split()`` never yields empty pieces.
        return tuple(text.split())
    return tuple(filter(lambda ch: not ch.isspace(), text))
def _tokenize(s: str) -> List[str]:
    """Return a fresh list of the tokens produced by ``_tokenize_cached``."""
    return [*_tokenize_cached(s)]
def _char_seq(s: str) -> List[str]:
    """List the non-whitespace characters of ``s`` after ``normalize_text``."""
    chars: List[str] = []
    for ch in normalize_text(s):
        if ch.isspace():
            continue
        chars.append(ch)
    return chars
def _infer_dirty_type_from_audio(x: Any) -> str:
if x is None:
return "other"
if isinstance(x, (list, tuple)) and len(x) > 0:
x = x[0]
s = str(x).lower()
if "voices" in s:
return "voices_noise_plus_farfield"
if "noise+rsp" in s or "resample_noise" in s or ("/noise/" in s and "voices" not in s):
return "noise_rsp_pure_noise"
return "other"
def _get_dirty_type(kwargs, i: int) -> str:
    """Resolve the dirty-data category for sample ``i``.

    An explicit ``dirty_type`` is preferred: known aliases collapse to one of
    the two canonical labels and any other value is returned stripped and
    lower-cased. Without it, the category is inferred from ``audio_path``
    or, when that is missing too, from ``audios``.
    """
    declared = _pick_field(kwargs.get("dirty_type"), i)
    if declared is None:
        source = _pick_field(kwargs.get("audio_path"), i)
        if source is None:
            source = _pick_field(kwargs.get("audios"), i)
        return _infer_dirty_type_from_audio(source)

    voices = "voices_noise_plus_farfield"
    noise = "noise_rsp_pure_noise"
    aliases = {
        "voices": voices,
        "voices_noise_plus_farfield": voices,
        "voices_far": voices,
        "voices_farfield": voices,
        "noise": noise,
        "noise_rsp": noise,
        "noise_rsp_pure_noise": noise,
        "pure_noise": noise,
    }
    label = str(declared).strip().lower()
    return aliases.get(label, label)
def _edit_ops_counts(ref_toks: List[str], hyp_toks: List[str]) -> Tuple[int, int, int]:
n, m = len(ref_toks), len(hyp_toks)
dp = [[0] * (m + 1) for _ in range(n + 1)]
bt = [[0] * (m + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
dp[i][0] = i
bt[i][0] = 2
for j in range(1, m + 1):
dp[0][j] = j
bt[0][j] = 3
for i in range(1, n + 1):
ri = ref_toks[i - 1]
for j in range(1, m + 1):
hj = hyp_toks[j - 1]
if ri == hj:
dp[i][j] = dp[i - 1][j - 1]
bt[i][j] = 0
else:
sub = dp[i - 1][j - 1] + 1
dele = dp[i - 1][j] + 1
ins = dp[i][j - 1] + 1
best = min(sub, dele, ins)
dp[i][j] = best
if best == sub:
bt[i][j] = 1
elif best == dele:
bt[i][j] = 2
else:
bt[i][j] = 3
i, j = n, m
sub = dele = ins = 0
while i > 0 or j > 0:
op = bt[i][j]
if i > 0 and j > 0 and op == 0:
i -= 1
j -= 1
elif i > 0 and j > 0 and op == 1:
sub += 1
i -= 1
j -= 1
elif i > 0 and op == 2:
dele += 1
i -= 1
else:
ins += 1
j -= 1
return sub, dele, ins
def _char_bigram_f1(hyp: str, ref: str) -> float:
    """F1 overlap of character bigrams between ``hyp`` and ``ref``.

    Both strings are reduced to their non-whitespace characters first. Two
    empty sequences score 1.0 and exactly one empty sequence scores 0.0.
    When either side has fewer than two characters, the overlap is the
    number of positions at which the characters agree, measured against
    the character counts; otherwise it is the multiset intersection of
    adjacent-character pairs, measured against the bigram counts.
    """
    hyp_chars = _char_seq(hyp)
    ref_chars = _char_seq(ref)
    if not hyp_chars and not ref_chars:
        return 1.0
    if not (hyp_chars and ref_chars):
        return 0.0

    if min(len(hyp_chars), len(ref_chars)) < 2:
        # Too short for bigrams: count position-wise character agreement.
        overlap = sum(1 for a, b in zip(hyp_chars, ref_chars) if a == b)
        hyp_total, ref_total = len(hyp_chars), len(ref_chars)
    else:
        hyp_grams = Counter(zip(hyp_chars, hyp_chars[1:]))
        ref_grams = Counter(zip(ref_chars, ref_chars[1:]))
        # ``&`` keeps the minimum count of each shared bigram.
        overlap = sum((hyp_grams & ref_grams).values())
        hyp_total = sum(hyp_grams.values())
        ref_total = sum(ref_grams.values())

    precision = overlap / max(1, hyp_total)
    recall = overlap / max(1, ref_total)
    return 2.0 * precision * recall / max(1e-8, precision + recall)
def _lcs_lengths(hyp: str, ref: str) -> Tuple[int, int, int]:
    """Longest-common-subsequence length over non-whitespace characters.

    Returns:
        ``(lcs_length, ref_char_count, hyp_char_count)``. The LCS length is
        0 whenever either side has no characters.
    """
    hyp_chars = _char_seq(hyp)
    ref_chars = _char_seq(ref)
    ref_count, hyp_count = len(ref_chars), len(hyp_chars)
    if not ref_count or not hyp_count:
        return 0, ref_count, hyp_count

    # Only the previous DP row is needed to compute the next one.
    row = [0] * (hyp_count + 1)
    for ref_ch in ref_chars:
        nxt = [0]
        for col, hyp_ch in enumerate(hyp_chars, start=1):
            if ref_ch == hyp_ch:
                nxt.append(row[col - 1] + 1)
            else:
                nxt.append(max(row[col], nxt[col - 1]))
        row = nxt
    return row[hyp_count], ref_count, hyp_count
def _lcs_f1(hyp: str, ref: str) -> float:
    """F1 score built from the character-level LCS of ``hyp`` and ``ref``.

    Precision is the LCS length over the hypothesis length and recall is the
    LCS length over the reference length. Two empty inputs score 1.0 and
    exactly one empty input scores 0.0.
    """
    common, ref_count, hyp_count = _lcs_lengths(hyp, ref)
    if not (ref_count and hyp_count):
        return 1.0 if ref_count == hyp_count else 0.0
    precision = common / max(1, hyp_count)
    recall = common / max(1, ref_count)
    return 2.0 * precision * recall / max(1e-8, precision + recall)
def _cmp_score(hyp: str, ref: str) -> float:
    """Blend of character-bigram F1 (weight 0.70) and LCS F1 (weight 0.30)."""
    bigram_part = 0.70 * _char_bigram_f1(hyp, ref)
    lcs_part = 0.30 * _lcs_f1(hyp, ref)
    return bigram_part + lcs_part
def wer_reward_main(wer: float) -> float:
    """Map a word error rate to the main reward via a piecewise-linear curve.

    The curve starts at 1.0 for ``wer == 0`` and falls through four linear
    segments (breakpoints at 0.15, 0.35, 0.70 and 1.20) to a floor of
    -0.92 for any ``wer`` above 1.20.
    """
    # (upper bound, value at segment start, slope, segment start)
    segments = (
        (0.15, 1.0, 1.8, 0.0),
        (0.35, 0.73, 2.3, 0.15),
        (0.70, 0.27, 2.4, 0.35),
        (1.20, -0.57, 0.70, 0.70),
    )
    for upper, base, slope, start in segments:
        if wer <= upper:
            return base - slope * (wer - start)
    return -0.92
def length_ratio_penalty_v3(
    hyp_len: int,
    ref_len: int,
    soft_min: float = 0.90,
    soft_max: float = 1.10,
    hard_min: float = 0.78,
    hard_max: float = 1.30,
    soft_penalty: float = 0.10,
    hard_penalty: float = 0.36,
) -> float:
    """Penalty (always <= 0) for a hypothesis/reference length mismatch.

    With ``ratio = hyp_len / max(1, ref_len)``:

    * inside ``[soft_min, soft_max]`` the penalty is 0;
    * between a soft and a hard bound it grows linearly from 0 to
      ``-soft_penalty``;
    * beyond a hard bound it grows from ``-soft_penalty`` towards
      ``-hard_penalty``, scaled by the overshoot relative to that bound and
      capped at ``-hard_penalty``.
    """
    ratio = hyp_len / max(1, ref_len)

    if soft_min <= ratio <= soft_max:
        return 0.0

    # Soft zones: linear ramp between the soft and hard bounds.
    if hard_min <= ratio < soft_min:
        ramp = (soft_min - ratio) / max(1e-6, soft_min - hard_min)
        return -soft_penalty * ramp
    if soft_max < ratio <= hard_max:
        ramp = (ratio - soft_max) / max(1e-6, hard_max - soft_max)
        return -soft_penalty * ramp

    # Hard zones: measure how far past the hard bound the ratio lies.
    if ratio < hard_min:
        overshoot = (hard_min - ratio) / max(1e-6, hard_min)
    else:
        overshoot = (ratio - hard_max) / max(1e-6, hard_max)
    extra = (hard_penalty - soft_penalty) * min(1.0, overshoot)
    return -(soft_penalty + extra)
def tail_penalty(len_ratio: float) -> float:
    """Penalty (always <= 0) for hypotheses longer than the reference.

    Zero up to a length ratio of 1.15, then two linear ramps reaching -0.28
    at 1.40 and -0.70 at 2.0, and a constant -0.70 beyond that.
    """
    if len_ratio <= 1.15:
        return 0.0
    # (upper bound, value at ramp start, total drop, ramp start, ramp width)
    ramps = (
        (1.40, 0.0, 0.28, 1.15, 0.25),
        (2.0, -0.28, 0.42, 1.40, 0.60),
    )
    for upper, base, drop, start, width in ramps:
        if len_ratio <= upper:
            return base - drop * (len_ratio - start) / width
    return -0.70
def is_hallucination_v56(hyp_toks: List[str], ref_toks: List[str], wer: float, len_ratio: float):
    """Heuristically flag a hypothesis as hallucinated.

    Checks run in this order and the first hit decides the reason:
    empty hypothesis; five or more identical consecutive tokens; for
    hypotheses of at least 8 tokens, a single bigram making up more than
    22% of all bigrams; length ratio above 1.60; WER of 1.20 or more.
    ``ref_toks`` is accepted but not used.

    Returns:
        ``(flag, reason)`` where ``reason`` is ``"ok"`` when nothing fired.
    """
    if not hyp_toks:
        return True, "empty"

    # Longest run of a repeated token, tracked over adjacent pairs.
    streak = 1
    for prev_tok, tok in zip(hyp_toks, hyp_toks[1:]):
        streak = streak + 1 if tok == prev_tok else 1
        if streak >= 5:
            return True, "repeat_run>=5"

    if len(hyp_toks) >= 8:
        pair_counts = Counter(zip(hyp_toks, hyp_toks[1:]))
        total_pairs = len(hyp_toks) - 1
        top_count = max(pair_counts.values())
        if top_count / max(1, total_pairs) > 0.22:
            return True, "repeat_bigram>0.22"

    if len_ratio > 1.60:
        return True, "len_ratio>1.60"
    if wer >= 1.20:
        return True, "wer>=1.20"
    return False, "ok"
def _voices_residual(del_rate: float, len_ratio: float):
p_del_voice = -0.12 * del_rate - 0.08 * max(0.0, del_rate - 0.10)
p_under_voice = -0.06 * max(0.0, 0.98 - len_ratio)
return p_del_voice, p_under_voice
def _noise_residual(sub_rate: float, cmp_score: float):
p_sub_noise = -0.08 * sub_rate
p_cmp_noise = -0.04 * (1.0 - cmp_score)
return p_sub_noise, p_cmp_noise
class ASRWerSubLenCmpHalluDirtyV56(ORM):
    """ASR reward combining WER, substitution, length, similarity and
    hallucination terms, with category-specific residual penalties.

    For every completion the raw reward is the sum of:

    * ``wer_reward_main(wer)``;
    * a substitution penalty (``sub_penalty_a`` per unit of substitution
      rate plus ``sub_penalty_b`` on the part above 0.35);
    * ``length_ratio_penalty_v3`` and ``tail_penalty``;
    * ``cmp_penalty * (1 - _cmp_score)``;
    * ``hallu_extra_penalty`` when ``is_hallucination_v56`` fires and
      ``empty_extra_penalty`` when the hypothesis has no tokens;
    * residual terms chosen by the resolved dirty type.

    The sum is clipped to ``[reward_clip_min, reward_clip_max]``.
    """

    # Substitution penalty: linear weight, and extra weight above a 0.35 rate.
    sub_penalty_a = 0.40
    sub_penalty_b = 0.35
    # Weight on (1 - similarity score).
    cmp_penalty = 0.14
    # Flat penalties for flagged hallucinations and for empty hypotheses.
    hallu_extra_penalty = 0.42
    empty_extra_penalty = 0.28
    # Bounds applied to the final reward.
    reward_clip_min = -4.0
    reward_clip_max = 2.0

    def _score_sample(self, hyp: str, ref: str, dirty_type: str) -> Dict[str, Any]:
        """Score one hypothesis against its reference.

        Returns a dict holding every intermediate quantity plus the final
        clipped value under ``"reward"``. Key order matches the debug-row
        layout so the dict can be merged into a debug row directly.
        """
        ref_toks = _tokenize(ref)
        hyp_toks = _tokenize(hyp)
        # An empty reference counts as length 1 to keep the rates finite.
        ref_len = max(1, len(ref_toks))
        hyp_len = len(hyp_toks)
        len_ratio = float(hyp_len) / float(ref_len)

        sub_cnt, del_cnt, ins_cnt = _edit_ops_counts(ref_toks, hyp_toks)
        wer = float(sub_cnt + del_cnt + ins_cnt) / float(ref_len)
        sub_rate = float(sub_cnt) / float(ref_len)
        del_rate = float(del_cnt) / float(ref_len)

        r_wer = wer_reward_main(wer)
        p_sub = (
            -float(self.sub_penalty_a) * sub_rate
            - float(self.sub_penalty_b) * max(0.0, sub_rate - 0.35)
        )
        p_len = length_ratio_penalty_v3(hyp_len=hyp_len, ref_len=ref_len)
        p_tail = tail_penalty(len_ratio)
        cmp_score = _cmp_score(hyp, ref)
        p_cmp = -float(self.cmp_penalty) * (1.0 - cmp_score)

        hallu, hallu_reason = is_hallucination_v56(hyp_toks, ref_toks, wer, len_ratio)
        p_hallu = -float(self.hallu_extra_penalty) if hallu else 0.0
        p_empty = -float(self.empty_extra_penalty) if hyp_len == 0 else 0.0

        # Only one residual family applies, depending on the dirty type.
        p_del_voice = p_under_voice = p_sub_noise = p_cmp_noise = 0.0
        if dirty_type == "voices_noise_plus_farfield":
            p_del_voice, p_under_voice = _voices_residual(del_rate, len_ratio)
        elif dirty_type == "noise_rsp_pure_noise":
            p_sub_noise, p_cmp_noise = _noise_residual(sub_rate, cmp_score)

        reward_raw = float(
            r_wer + p_sub + p_len + p_tail + p_cmp + p_hallu + p_empty
            + p_del_voice + p_under_voice + p_sub_noise + p_cmp_noise
        )
        reward = max(
            float(self.reward_clip_min),
            min(float(self.reward_clip_max), reward_raw),
        )

        return {
            "hyp": hyp,
            "ref": ref,
            "ref_len": ref_len,
            "hyp_len": hyp_len,
            "len_ratio": len_ratio,
            "sub_cnt": sub_cnt,
            "del_cnt": del_cnt,
            "ins_cnt": ins_cnt,
            "wer_calc": wer,
            "sub_rate": sub_rate,
            "del_rate": del_rate,
            "cmp_score": cmp_score,
            "hallu": hallu,
            "hallu_reason": hallu_reason,
            "r_wer": r_wer,
            "p_sub": p_sub,
            "p_len": p_len,
            "p_tail": p_tail,
            "p_cmp": p_cmp,
            "p_hallu": p_hallu,
            "p_empty": p_empty,
            "p_del_voice": p_del_voice,
            "p_under_voice": p_under_voice,
            "p_sub_noise": p_sub_noise,
            "p_cmp_noise": p_cmp_noise,
            "reward_raw": reward_raw,
            "reward": reward,
        }

    def __call__(self, completions, solution=None, **kwargs):
        """Compute one reward per completion.

        Args:
            completions: Raw model outputs, one per sample.
            solution: Reference transcription(s): a single string shared by
                all completions, or a sequence aligned with ``completions``.
            **kwargs: Optional per-sample metadata (``dirty_type``,
                ``audio_path``, ``audios``, ...) and debug settings.

        Returns:
            A list of floats with exactly ``len(completions)`` entries.
            Without any reference, every reward is 0.0; a completion that
            has no matching reference also gets 0.0.
        """
        completions = list(completions)
        if solution is None:
            solution = kwargs.get("solution")
        if solution is None:
            return [0.0 for _ in completions]

        if isinstance(solution, str):
            solution_list = [solution] * len(completions)
        else:
            solution_list = list(solution)

        # Decide once per call so that the debug row is not assembled for
        # every sample when logging is off.
        debug_on = _reward_debug_enabled(kwargs)

        rewards = []
        for i, comp in enumerate(completions):
            if i >= len(solution_list):
                # Zipping used to drop these samples silently, returning
                # fewer rewards than completions. Emit a neutral reward so
                # the output stays aligned with the input.
                rewards.append(0.0)
                continue

            hyp = _extract_completion_text(comp)
            ref = solution_list[i] or ""
            dirty_type = _get_dirty_type(kwargs, i)

            parts = self._score_sample(hyp, ref, dirty_type)
            rewards.append(parts["reward"])

            if debug_on:
                _append_reward_debug_row(
                    reward_name="asr_wer_sub_len_cmp_hallu_dirty_v56",
                    kwargs=kwargs,
                    row={
                        **_collect_common_debug_meta(kwargs, i),
                        "index": i,
                        "dirty_type_resolved": dirty_type,
                        "completion_raw": comp,
                        **parts,
                    },
                )
        return rewards
# Expose the reward class through the imported ``orms`` mapping under its
# string key.
orms["asr_wer_sub_len_cmp_hallu_dirty_v56"] = ASRWerSubLenCmpHalluDirtyV56