Spaces:
Running
Running
Initial: Gradio demo + 8 VITW examples + WER scoring
Browse files- .gitattributes +8 -0
- README.md +38 -7
- app.py +214 -0
- examples/distortion.wav +3 -0
- examples/dropout.wav +3 -0
- examples/echo.wav +3 -0
- examples/far_field.wav +3 -0
- examples/mixed.wav +3 -0
- examples/noise.wav +3 -0
- examples/obstructed.wav +3 -0
- examples/recording.wav +3 -0
- requirements.txt +11 -0
- vendor/MegaASR/A2S-SFT/__init__.py +1 -0
- vendor/MegaASR/A2S-SFT/arguments.py +57 -0
- vendor/MegaASR/A2S-SFT/checkpointing.py +63 -0
- vendor/MegaASR/A2S-SFT/dataloader.py +81 -0
- vendor/MegaASR/A2S-SFT/finetune.py +88 -0
- vendor/MegaASR/A2S-SFT/modeling.py +89 -0
- vendor/MegaASR/A2S-SFT/readme.md +53 -0
- vendor/MegaASR/A2S-SFT/trainer.py +101 -0
- vendor/MegaASR/DG-WGPO/README.md +3 -0
- vendor/MegaASR/data/download_librispeech.py +94 -0
- vendor/MegaASR/data/download_librispeech.sh +3 -0
- vendor/MegaASR/eval/cn_tn.py +1203 -0
- vendor/MegaASR/eval/evaluate_wer.py +179 -0
- vendor/MegaASR/eval/evaluate_wer.sh +4 -0
- vendor/MegaASR/eval/readme.md +48 -0
- vendor/MegaASR/model/Qwen3_ASR.py +104 -0
- vendor/MegaASR/model/__pycache__/Qwen3_ASR.cpython-311.pyc +0 -0
- vendor/MegaASR/model/__pycache__/megaASR.cpython-311.pyc +0 -0
- vendor/MegaASR/model/__pycache__/router.cpython-311.pyc +0 -0
- vendor/MegaASR/model/megaASR.py +192 -0
- vendor/MegaASR/model/router.py +108 -0
- vendor/MegaASR/model/utils/__init__.py +0 -0
- vendor/MegaASR/model/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- vendor/MegaASR/model/utils/__pycache__/audio_quality.cpython-311.pyc +0 -0
- vendor/MegaASR/model/utils/__pycache__/lora_switch.cpython-311.pyc +0 -0
- vendor/MegaASR/model/utils/audio_quality.py +179 -0
- vendor/MegaASR/model/utils/lora_switch.py +235 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/distortion.wav filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
examples/dropout.wav filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
examples/echo.wav filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/far_field.wav filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/mixed.wav filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
examples/noise.wav filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
examples/obstructed.wav filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
examples/recording.wav filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,13 +1,44 @@
|
|
| 1 |
---
|
| 2 |
-
title: Mega
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
python_version: '3.13'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Mega-ASR — robust ASR demo
|
| 3 |
+
emoji: 🎙️
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: Robust in-the-wild speech recognition with Mega-ASR
|
| 12 |
+
models:
|
| 13 |
+
- zhifeixie/Mega-ASR
|
| 14 |
+
- Reza2kn/mega-asr-onnx
|
| 15 |
+
datasets:
|
| 16 |
+
- xzf-thu/Voices-in-the-Wild-Bench
|
| 17 |
+
hf_oauth: false
|
| 18 |
+
tags:
|
| 19 |
+
- automatic-speech-recognition
|
| 20 |
+
- robust-asr
|
| 21 |
+
- mega-asr
|
| 22 |
+
- benchmark
|
| 23 |
+
- wer
|
| 24 |
---
|
| 25 |
|
| 26 |
+
# Mega-ASR demo
|
| 27 |
+
|
| 28 |
+
Live demo of [Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR), a 1.7B-param
|
| 29 |
+
multilingual ASR foundation model designed for robust transcription in real-world
|
| 30 |
+
acoustic conditions (noise, far-field, reverberation, recording artifacts, etc.).
|
| 31 |
+
|
| 32 |
+
Pre-loaded examples come from
|
| 33 |
+
[Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench).
|
| 34 |
+
|
| 35 |
+
The demo accepts your audio + an optional ground-truth transcript and shows the
|
| 36 |
+
word-level agreement (1 - WER) color-coded as:
|
| 37 |
+
|
| 38 |
+
- 🟢 **green** ≥ 70 %
|
| 39 |
+
- 🟠 **orange** 50-70 %
|
| 40 |
+
- 🟡 **yellow** 25-50 %
|
| 41 |
+
- 🔴 **red** < 25 %
|
| 42 |
+
|
| 43 |
+
INT4 ONNX deployment artifacts (encoder + prefill + step decoder) are at
|
| 44 |
+
[Reza2kn/mega-asr-onnx](https://huggingface.co/Reza2kn/mega-asr-onnx).
|
app.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mega-ASR robust speech-recognition demo on Hugging Face Spaces.
|
| 2 |
+
|
| 3 |
+
Loads Mega-ASR (zhifeixie/Mega-ASR) and transcribes user-supplied audio
|
| 4 |
+
against an optional reference text. Computes word-level agreement (1 - WER)
|
| 5 |
+
and shows it with a color band:
|
| 6 |
+
green >= 70%
|
| 7 |
+
orange 50-70%
|
| 8 |
+
yellow 25-50%
|
| 9 |
+
red < 25%
|
| 10 |
+
|
| 11 |
+
Pre-loaded examples are the 8 single-condition clips from the
|
| 12 |
+
Voices-in-the-Wild-Bench (noise / far_field / obstructed / distortion /
|
| 13 |
+
recording / echo / dropout / mixed).
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import re
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import gradio as gr
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------- model
|
| 25 |
+
_model = None
|
| 26 |
+
_normalize_text = None
|
| 27 |
+
_word_norm_re = re.compile(r"[^a-z0-9\s]")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _load_model():
|
| 31 |
+
"""Lazy-load Mega-ASR on first request (HF Spaces ZeroGPU pattern)."""
|
| 32 |
+
global _model
|
| 33 |
+
if _model is not None:
|
| 34 |
+
return _model
|
| 35 |
+
print("[mega-asr] loading model …")
|
| 36 |
+
from huggingface_hub import snapshot_download
|
| 37 |
+
import sys
|
| 38 |
+
# The mega-asr source is published at github.com/xzf-thu/Mega-ASR but not
|
| 39 |
+
# on PyPI. We vendor a copy here under ./vendor/MegaASR.
|
| 40 |
+
sys.path.insert(0, str(Path(__file__).parent / "vendor"))
|
| 41 |
+
from MegaASR.model.megaASR import MegaASR
|
| 42 |
+
ckpt = snapshot_download("zhifeixie/Mega-ASR")
|
| 43 |
+
_model = MegaASR(
|
| 44 |
+
model_path=Path(ckpt) / "Qwen3-ASR-1.7B",
|
| 45 |
+
lora_dir=Path(ckpt) / "mega-asr-merged",
|
| 46 |
+
router_checkpoint=Path(ckpt) / "audio_quality_router" / "best_acc_model.safetensors",
|
| 47 |
+
routing_enabled=True,
|
| 48 |
+
quality_threshold=0.5,
|
| 49 |
+
device_map="cuda" if _cuda_available() else "cpu",
|
| 50 |
+
)
|
| 51 |
+
print("[mega-asr] loaded.")
|
| 52 |
+
return _model
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _cuda_available() -> bool:
|
| 56 |
+
try:
|
| 57 |
+
import torch
|
| 58 |
+
return torch.cuda.is_available()
|
| 59 |
+
except Exception:
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ---------------------------------------------------------------------- agree
|
| 64 |
+
def _normalize(text: str) -> str:
|
| 65 |
+
# Mega-ASR prefixes its output with "language English<asr_text>…"
|
| 66 |
+
# Strip up to the `<asr_text>` marker if present.
|
| 67 |
+
if "<asr_text>" in text:
|
| 68 |
+
text = text.split("<asr_text>", 1)[1]
|
| 69 |
+
text = text.lower()
|
| 70 |
+
text = _word_norm_re.sub(" ", text)
|
| 71 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 72 |
+
return text
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _wer(ref: str, hyp: str) -> tuple[float, int, int]:
|
| 76 |
+
r = ref.split()
|
| 77 |
+
h = hyp.split()
|
| 78 |
+
if not r:
|
| 79 |
+
return (1.0 if h else 0.0, len(h), 0)
|
| 80 |
+
d = np.zeros((len(r) + 1, len(h) + 1), dtype=np.int32)
|
| 81 |
+
for i in range(len(r) + 1):
|
| 82 |
+
d[i, 0] = i
|
| 83 |
+
for j in range(len(h) + 1):
|
| 84 |
+
d[0, j] = j
|
| 85 |
+
for i in range(1, len(r) + 1):
|
| 86 |
+
for j in range(1, len(h) + 1):
|
| 87 |
+
sub = d[i - 1, j - 1] + (0 if r[i - 1] == h[j - 1] else 1)
|
| 88 |
+
ins = d[i, j - 1] + 1
|
| 89 |
+
dele = d[i - 1, j] + 1
|
| 90 |
+
d[i, j] = min(sub, ins, dele)
|
| 91 |
+
err = int(d[len(r), len(h)])
|
| 92 |
+
return err / max(len(r), 1), err, len(r)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _band_html(agree_pct: float, hyp_text: str) -> str:
|
| 96 |
+
if agree_pct >= 70:
|
| 97 |
+
color, label = "#2ec27e", "match"
|
| 98 |
+
emoji = "✅"
|
| 99 |
+
elif agree_pct >= 50:
|
| 100 |
+
color, label = "#e8a23a", "close"
|
| 101 |
+
emoji = "🟠"
|
| 102 |
+
elif agree_pct >= 25:
|
| 103 |
+
color, label = "#e0c34a", "partial"
|
| 104 |
+
emoji = "🟡"
|
| 105 |
+
else:
|
| 106 |
+
color, label = "#e0524c", "diverged"
|
| 107 |
+
emoji = "🔴"
|
| 108 |
+
return f"""
|
| 109 |
+
<div style="border-radius:8px;padding:14px 16px;background:{color}1a;border:2px solid {color};font-size:15px;">
|
| 110 |
+
<div style="font-size:18px;color:{color};margin-bottom:6px;"><b>{emoji} {agree_pct:.1f}% agreement</b> · {label}</div>
|
| 111 |
+
<div style="color:#1a1a1a;line-height:1.5;"><b>Transcription:</b> {hyp_text}</div>
|
| 112 |
+
</div>
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ---------------------------------------------------------------------- core
|
| 117 |
+
REFERENCES = {
|
| 118 |
+
"noise": "I usually take the quieter road home because the main street gets crowded after work.",
|
| 119 |
+
"far_field": "Please remind me to print the forms before we leave for the appointment tomorrow.",
|
| 120 |
+
"obstructed": "I forgot my charger at home, so I need to find an outlet before the meeting starts.",
|
| 121 |
+
"distortion": "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
|
| 122 |
+
"recording": "Can you check whether the train still stops at the downtown station after eight tonight?",
|
| 123 |
+
"echo": "I need to return these shoes because the size feels fine standing up but terrible while walking.",
|
| 124 |
+
"dropout": "My aunt is learning video calls, and she gets excited whenever the picture actually works.",
|
| 125 |
+
"mixed": "My sister is bringing dinner over later, so we do not need to cook tonight.",
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def transcribe(audio_path: str | None, reference: str = ""):
|
| 130 |
+
if not audio_path:
|
| 131 |
+
return "<div style='color:#999;font-style:italic;'>Please upload an audio clip first.</div>"
|
| 132 |
+
model = _load_model()
|
| 133 |
+
out = model.infer(audio_path, return_route=True)
|
| 134 |
+
hyp = out["text"][0] if isinstance(out.get("text"), list) else out.get("text", "")
|
| 135 |
+
use_lora = out.get("use_lora")
|
| 136 |
+
if not reference.strip():
|
| 137 |
+
# No reference: just show the transcription unscored.
|
| 138 |
+
return f"""
|
| 139 |
+
<div style="border-radius:8px;padding:14px 16px;background:#f3f4f6;border:1px solid #d1d5db;">
|
| 140 |
+
<div style="color:#555;margin-bottom:6px;font-size:13px;">{'LoRA' if use_lora else 'base'} adapter</div>
|
| 141 |
+
<div style="font-size:16px;"><b>Transcription:</b> {_normalize(hyp) or hyp}</div>
|
| 142 |
+
</div>"""
|
| 143 |
+
r_norm = _normalize(reference)
|
| 144 |
+
h_norm = _normalize(hyp)
|
| 145 |
+
wer_val, errors, n_words = _wer(r_norm, h_norm)
|
| 146 |
+
agree = max(0.0, 1.0 - wer_val) * 100
|
| 147 |
+
return _band_html(agree, h_norm or "<i>(empty)</i>") + f"""
|
| 148 |
+
<div style="margin-top:10px;font-size:13px;color:#555;">
|
| 149 |
+
<b>Reference:</b> {r_norm}<br>
|
| 150 |
+
<b>{'LoRA' if use_lora else 'base'} adapter</b> · WER {wer_val*100:.1f}% ({errors}/{n_words} words)
|
| 151 |
+
</div>"""
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ---------------------------------------------------------------------- UI
|
| 155 |
+
def _load_example(name: str) -> tuple[str, str]:
|
| 156 |
+
return (f"examples/{name}.wav", REFERENCES[name])
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
with gr.Blocks(title="Mega-ASR — robust speech recognition") as demo:
|
| 160 |
+
gr.Markdown("""
|
| 161 |
+
# Mega-ASR demo
|
| 162 |
+
|
| 163 |
+
[Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR) is a 1.7B-param multilingual ASR model
|
| 164 |
+
trained on 2.6M real-world voice samples covering noise, far-field, reverberation,
|
| 165 |
+
distortion, and other in-the-wild acoustic conditions. It uses an audio-quality
|
| 166 |
+
router that swaps between a base model and a LoRA-adapted "robust" model per input.
|
| 167 |
+
|
| 168 |
+
Upload an audio clip (or pick one of the 8 examples from
|
| 169 |
+
[Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)),
|
| 170 |
+
optionally paste the ground-truth transcript, and see what the model produces
|
| 171 |
+
plus a word-level agreement score.
|
| 172 |
+
""")
|
| 173 |
+
with gr.Row():
|
| 174 |
+
with gr.Column(scale=1):
|
| 175 |
+
audio_in = gr.Audio(type="filepath", label="Audio")
|
| 176 |
+
ref_in = gr.Textbox(label="Reference transcript (optional)",
|
| 177 |
+
placeholder="If you have the ground-truth text, paste it here for scoring.",
|
| 178 |
+
lines=3)
|
| 179 |
+
btn = gr.Button("Transcribe", variant="primary")
|
| 180 |
+
gr.Markdown("### 8 noisy examples")
|
| 181 |
+
with gr.Row():
|
| 182 |
+
noise_btn = gr.Button("🔊 noise")
|
| 183 |
+
far_btn = gr.Button("📡 far-field")
|
| 184 |
+
with gr.Row():
|
| 185 |
+
obs_btn = gr.Button("🚧 obstructed")
|
| 186 |
+
dist_btn = gr.Button("🎛️ distortion")
|
| 187 |
+
with gr.Row():
|
| 188 |
+
rec_btn = gr.Button("🎙️ recording")
|
| 189 |
+
echo_btn = gr.Button("🏛️ echo")
|
| 190 |
+
with gr.Row():
|
| 191 |
+
drop_btn = gr.Button("✂️ dropout")
|
| 192 |
+
mix_btn = gr.Button("🌪️ mixed")
|
| 193 |
+
with gr.Column(scale=1):
|
| 194 |
+
output = gr.HTML(label="Result")
|
| 195 |
+
|
| 196 |
+
btn.click(transcribe, inputs=[audio_in, ref_in], outputs=output)
|
| 197 |
+
# Example buttons
|
| 198 |
+
for b, name in [(noise_btn, "noise"), (far_btn, "far_field"), (obs_btn, "obstructed"),
|
| 199 |
+
(dist_btn, "distortion"), (rec_btn, "recording"), (echo_btn, "echo"),
|
| 200 |
+
(drop_btn, "dropout"), (mix_btn, "mixed")]:
|
| 201 |
+
b.click(_load_example, inputs=gr.State(name), outputs=[audio_in, ref_in])
|
| 202 |
+
|
| 203 |
+
gr.Markdown("""
|
| 204 |
+
---
|
| 205 |
+
**Color bands:** 🟢 ≥70% · 🟠 50-70% · 🟡 25-50% · 🔴 <25%
|
| 206 |
+
|
| 207 |
+
**Companion projects:**
|
| 208 |
+
- INT4 ONNX deployment artifacts: [Reza2kn/mega-asr-onnx](https://huggingface.co/Reza2kn/mega-asr-onnx)
|
| 209 |
+
- Benchmark dataset: [xzf-thu/Voices-in-the-Wild-Bench](https://huggingface.co/datasets/xzf-thu/Voices-in-the-Wild-Bench)
|
| 210 |
+
""")
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
if __name__ == "__main__":
|
| 214 |
+
demo.launch()
|
examples/distortion.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:194231a8aa2a31049d167df3f52bc62d4e9377aa935678c983d1165f3c9ca86d
|
| 3 |
+
size 353324
|
examples/dropout.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ff94831ee3497ce90d9b873719b823e5c4c4a9890dec86832e0b6357cd2b2e6f
|
| 3 |
+
size 320684
|
examples/echo.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2eda219a3b7091c7a2772408db3f0356d1d7d30184d0523a6c98f6fdec35bd2b
|
| 3 |
+
size 359084
|
examples/far_field.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4546478ee704b5b7b81bb4937e6c74b48be82aef541e0fb3388fcb49789082d
|
| 3 |
+
size 284204
|
examples/mixed.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:339b28e7f59abb2dcfac81c22298060a5da85d09ea3363a8e4004e17a15b31e2
|
| 3 |
+
size 243884
|
examples/noise.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:96793b49e286c7b05a3081fccf0d6a6f7df85cc5aef0a2d28f3b4aaba60d95d1
|
| 3 |
+
size 416684
|
examples/obstructed.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:411e605c25a6a8a6b09e49f0db2ad7543854bbcd0cab4cd7157fc429f9c5b0d3
|
| 3 |
+
size 422444
|
examples/recording.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a548239681d613a007375825fd2423494be634c49bd450e46743f29101ffdcfc
|
| 3 |
+
size 240044
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.44.0
|
| 2 |
+
qwen-asr
|
| 3 |
+
torch>=2.6
|
| 4 |
+
torchaudio
|
| 5 |
+
huggingface_hub
|
| 6 |
+
safetensors
|
| 7 |
+
psutil
|
| 8 |
+
soundfile
|
| 9 |
+
peft
|
| 10 |
+
whisper_normalizer
|
| 11 |
+
numpy
|
vendor/MegaASR/A2S-SFT/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
vendor/MegaASR/A2S-SFT/arguments.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def parse_args():
|
| 6 |
+
p = argparse.ArgumentParser("Mega-ASR A2S-SFT")
|
| 7 |
+
|
| 8 |
+
# paths
|
| 9 |
+
p.add_argument("--model_path", type=str, required=True)
|
| 10 |
+
p.add_argument("--train_file", type=str, required=True)
|
| 11 |
+
p.add_argument("--eval_file", type=str, default="")
|
| 12 |
+
p.add_argument("--output_dir", type=str, default="outputs/a2s_sft")
|
| 13 |
+
|
| 14 |
+
# data
|
| 15 |
+
p.add_argument("--sr", type=int, default=16000)
|
| 16 |
+
p.add_argument("--padding_side", type=str, default="auto",
|
| 17 |
+
choices=["auto", "left", "right"])
|
| 18 |
+
|
| 19 |
+
# training
|
| 20 |
+
p.add_argument("--batch_size", type=int, default=8)
|
| 21 |
+
p.add_argument("--grad_acc", type=int, default=8)
|
| 22 |
+
p.add_argument("--epochs", type=float, default=1)
|
| 23 |
+
p.add_argument("--lr", type=float, default=1e-5)
|
| 24 |
+
p.add_argument("--lr_encoder", type=float, default=1e-5)
|
| 25 |
+
p.add_argument("--lr_aligner", type=float, default=1e-5)
|
| 26 |
+
p.add_argument("--lr_llm", type=float, default=1e-5)
|
| 27 |
+
p.add_argument("--weight_decay", type=float, default=0.0)
|
| 28 |
+
p.add_argument("--warmup_ratio", type=float, default=0.03)
|
| 29 |
+
p.add_argument("--lr_scheduler_type", type=str, default="linear")
|
| 30 |
+
p.add_argument("--max_grad_norm", type=float, default=1.0)
|
| 31 |
+
p.add_argument("--log_steps", type=int, default=10)
|
| 32 |
+
p.add_argument("--report_to", type=str, default="none")
|
| 33 |
+
|
| 34 |
+
# dataloader
|
| 35 |
+
p.add_argument("--num_workers", type=int, default=4)
|
| 36 |
+
p.add_argument("--pin_memory", type=int, default=1)
|
| 37 |
+
p.add_argument("--persistent_workers", type=int, default=1)
|
| 38 |
+
p.add_argument("--prefetch_factor", type=int, default=2)
|
| 39 |
+
|
| 40 |
+
# save / resume
|
| 41 |
+
p.add_argument("--save_steps", type=int, default=200)
|
| 42 |
+
p.add_argument("--save_total_limit", type=int, default=5)
|
| 43 |
+
p.add_argument("--resume", type=int, default=0)
|
| 44 |
+
p.add_argument("--resume_from", type=str, default="")
|
| 45 |
+
|
| 46 |
+
# lora
|
| 47 |
+
p.add_argument("--use_lora", type=int, default=1)
|
| 48 |
+
p.add_argument("--lora_scope", type=str, default="encoder_aligner",
|
| 49 |
+
choices=["encoder", "aligner", "encoder_aligner",
|
| 50 |
+
"encoder_b4_aligner", "llm", "all"])
|
| 51 |
+
p.add_argument("--lora_r", type=int, default=8)
|
| 52 |
+
p.add_argument("--lora_alpha", type=int, default=16)
|
| 53 |
+
p.add_argument("--lora_dropout", type=float, default=0.05)
|
| 54 |
+
p.add_argument("--lora_bias", type=str, default="none")
|
| 55 |
+
p.add_argument("--merge_lora_into_base_from", type=str, default="")
|
| 56 |
+
|
| 57 |
+
return p.parse_args()
|
vendor/MegaASR/A2S-SFT/checkpointing.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import shutil
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from transformers import TrainerCallback, TrainingArguments
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def find_latest_checkpoint(output_dir: str) -> Optional[str]:
|
| 14 |
+
if not output_dir or not os.path.isdir(output_dir):
|
| 15 |
+
return None
|
| 16 |
+
|
| 17 |
+
best_step, best_path = -1, None
|
| 18 |
+
for name in os.listdir(output_dir):
|
| 19 |
+
match = _CKPT_RE.match(name)
|
| 20 |
+
if not match:
|
| 21 |
+
continue
|
| 22 |
+
path = os.path.join(output_dir, name)
|
| 23 |
+
step = int(match.group(1))
|
| 24 |
+
if os.path.isdir(path) and step > best_step:
|
| 25 |
+
best_step, best_path = step, path
|
| 26 |
+
return best_path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def copy_hf_files(src_dir: str, dst_dir: str):
|
| 30 |
+
os.makedirs(dst_dir, exist_ok=True)
|
| 31 |
+
for name in [
|
| 32 |
+
"config.json",
|
| 33 |
+
"generation_config.json",
|
| 34 |
+
"preprocessor_config.json",
|
| 35 |
+
"processor_config.json",
|
| 36 |
+
"tokenizer_config.json",
|
| 37 |
+
"tokenizer.json",
|
| 38 |
+
"special_tokens_map.json",
|
| 39 |
+
"chat_template.json",
|
| 40 |
+
"merges.txt",
|
| 41 |
+
"vocab.json",
|
| 42 |
+
]:
|
| 43 |
+
src = os.path.join(src_dir, name)
|
| 44 |
+
if os.path.exists(src):
|
| 45 |
+
shutil.copy2(src, os.path.join(dst_dir, name))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class MakeCheckpointInferableCallback(TrainerCallback):
|
| 49 |
+
"""Copy tokenizer/config files into every adapter checkpoint."""
|
| 50 |
+
|
| 51 |
+
def __init__(self, base_model_path: str):
|
| 52 |
+
self.base_model_path = base_model_path
|
| 53 |
+
|
| 54 |
+
def on_save(self, args: TrainingArguments, state, control, **kwargs):
|
| 55 |
+
if args.process_index != 0:
|
| 56 |
+
return control
|
| 57 |
+
|
| 58 |
+
ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
| 59 |
+
if not os.path.isdir(ckpt_dir):
|
| 60 |
+
ckpt_dir = kwargs.get("checkpoint", ckpt_dir)
|
| 61 |
+
|
| 62 |
+
copy_hf_files(self.base_model_path, ckpt_dir)
|
| 63 |
+
return control
|
vendor/MegaASR/A2S-SFT/dataloader.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any, Dict, List
|
| 4 |
+
|
| 5 |
+
import librosa
|
| 6 |
+
import torch
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def read_audio(path: str, sr: int = 16000):
|
| 11 |
+
return librosa.load(path, sr=sr, mono=True)[0]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def audio_messages(prompt: str):
|
| 15 |
+
return [
|
| 16 |
+
{"role": "system", "content": prompt or ""},
|
| 17 |
+
{"role": "user", "content": [{"type": "audio", "audio": None}]},
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class Qwen3ASRCollator:
|
| 23 |
+
processor: Any
|
| 24 |
+
sampling_rate: int = 16000
|
| 25 |
+
|
| 26 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
| 27 |
+
prompts = [x.get("prompt", "") for x in features]
|
| 28 |
+
targets = [x["text"] for x in features]
|
| 29 |
+
audios = [read_audio(x["audio"], self.sampling_rate) for x in features]
|
| 30 |
+
|
| 31 |
+
prefixes = [
|
| 32 |
+
self.processor.apply_chat_template(
|
| 33 |
+
[audio_messages(p)],
|
| 34 |
+
add_generation_prompt=True,
|
| 35 |
+
tokenize=False,
|
| 36 |
+
)[0]
|
| 37 |
+
for p in prompts
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
eos = self.processor.tokenizer.eos_token or ""
|
| 41 |
+
full_texts = [p + t + eos for p, t in zip(prefixes, targets)]
|
| 42 |
+
|
| 43 |
+
batch = self.processor(
|
| 44 |
+
text=full_texts,
|
| 45 |
+
audio=audios,
|
| 46 |
+
return_tensors="pt",
|
| 47 |
+
padding=True,
|
| 48 |
+
truncation=False,
|
| 49 |
+
)
|
| 50 |
+
prefix_batch = self.processor(
|
| 51 |
+
text=prefixes,
|
| 52 |
+
audio=audios,
|
| 53 |
+
return_tensors="pt",
|
| 54 |
+
padding=True,
|
| 55 |
+
truncation=False,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
labels = batch["input_ids"].clone()
|
| 59 |
+
prefix_lens = prefix_batch["attention_mask"].sum(dim=1)
|
| 60 |
+
full_lens = batch["attention_mask"].sum(dim=1)
|
| 61 |
+
|
| 62 |
+
seq_len = labels.size(1)
|
| 63 |
+
padding_side = getattr(self.processor.tokenizer, "padding_side", "right")
|
| 64 |
+
|
| 65 |
+
for i, prefix_len in enumerate(prefix_lens):
|
| 66 |
+
start = seq_len - int(full_lens[i]) if padding_side == "left" else 0
|
| 67 |
+
labels[i, start:start + int(prefix_len)] = -100
|
| 68 |
+
|
| 69 |
+
pad_id = self.processor.tokenizer.pad_token_id
|
| 70 |
+
if pad_id is not None:
|
| 71 |
+
labels[labels == pad_id] = -100
|
| 72 |
+
|
| 73 |
+
batch["labels"] = labels
|
| 74 |
+
return batch
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def build_datasets(train_file: str, eval_file: str = ""):
|
| 78 |
+
files = {"train": train_file}
|
| 79 |
+
if eval_file:
|
| 80 |
+
files["validation"] = eval_file
|
| 81 |
+
return load_dataset("json", data_files=files)
|
vendor/MegaASR/A2S-SFT/finetune.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from transformers import TrainingArguments
|
| 3 |
+
|
| 4 |
+
from arguments import parse_args
|
| 5 |
+
from checkpointing import MakeCheckpointInferableCallback, find_latest_checkpoint
|
| 6 |
+
from dataloader import Qwen3ASRCollator, build_datasets
|
| 7 |
+
from modeling import apply_lora, load_qwen3_asr
|
| 8 |
+
from trainer import MegaASRTrainer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def build_training_args(args, use_bf16: bool):
|
| 12 |
+
report_to = [] if args.report_to.lower() in ["", "none"] else [args.report_to]
|
| 13 |
+
|
| 14 |
+
return TrainingArguments(
|
| 15 |
+
output_dir=args.output_dir,
|
| 16 |
+
per_device_train_batch_size=args.batch_size,
|
| 17 |
+
gradient_accumulation_steps=args.grad_acc,
|
| 18 |
+
learning_rate=args.lr,
|
| 19 |
+
num_train_epochs=args.epochs,
|
| 20 |
+
logging_steps=args.log_steps,
|
| 21 |
+
lr_scheduler_type=args.lr_scheduler_type,
|
| 22 |
+
warmup_ratio=args.warmup_ratio,
|
| 23 |
+
weight_decay=args.weight_decay,
|
| 24 |
+
max_grad_norm=args.max_grad_norm,
|
| 25 |
+
dataloader_num_workers=args.num_workers,
|
| 26 |
+
dataloader_pin_memory=bool(args.pin_memory),
|
| 27 |
+
dataloader_persistent_workers=bool(args.persistent_workers),
|
| 28 |
+
dataloader_prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
|
| 29 |
+
save_strategy="steps",
|
| 30 |
+
save_steps=args.save_steps,
|
| 31 |
+
save_total_limit=args.save_total_limit,
|
| 32 |
+
save_safetensors=True,
|
| 33 |
+
eval_strategy="steps",
|
| 34 |
+
eval_steps=args.save_steps,
|
| 35 |
+
do_eval=bool(args.eval_file),
|
| 36 |
+
bf16=use_bf16,
|
| 37 |
+
fp16=not use_bf16,
|
| 38 |
+
ddp_find_unused_parameters=False,
|
| 39 |
+
remove_unused_columns=False,
|
| 40 |
+
report_to=report_to,
|
| 41 |
+
run_name="Mega-ASR-A2S-SFT",
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def main():
|
| 46 |
+
args = parse_args()
|
| 47 |
+
|
| 48 |
+
model, processor, use_bf16 = load_qwen3_asr(args.model_path)
|
| 49 |
+
|
| 50 |
+
if args.padding_side != "auto":
|
| 51 |
+
processor.tokenizer.padding_side = args.padding_side
|
| 52 |
+
print("padding_side =", processor.tokenizer.padding_side)
|
| 53 |
+
|
| 54 |
+
model = apply_lora(model, args)
|
| 55 |
+
|
| 56 |
+
dataset = build_datasets(args.train_file, args.eval_file)
|
| 57 |
+
collator = Qwen3ASRCollator(processor=processor, sampling_rate=args.sr)
|
| 58 |
+
training_args = build_training_args(args, use_bf16)
|
| 59 |
+
|
| 60 |
+
trainer = MegaASRTrainer(
|
| 61 |
+
model=model,
|
| 62 |
+
args=training_args,
|
| 63 |
+
train_dataset=dataset["train"],
|
| 64 |
+
eval_dataset=dataset.get("validation", None),
|
| 65 |
+
data_collator=collator,
|
| 66 |
+
processing_class=processor,
|
| 67 |
+
callbacks=[MakeCheckpointInferableCallback(args.model_path)],
|
| 68 |
+
processor=processor,
|
| 69 |
+
base_model_path=args.model_path,
|
| 70 |
+
merged_from_lora_path=args.merge_lora_into_base_from.strip(),
|
| 71 |
+
lr_encoder=args.lr_encoder,
|
| 72 |
+
lr_aligner=args.lr_aligner,
|
| 73 |
+
lr_llm=args.lr_llm,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
resume_from = args.resume_from.strip()
|
| 77 |
+
if not resume_from and args.resume:
|
| 78 |
+
resume_from = find_latest_checkpoint(args.output_dir) or ""
|
| 79 |
+
|
| 80 |
+
if resume_from:
|
| 81 |
+
print(f"[resume] {resume_from}")
|
| 82 |
+
trainer.train(resume_from_checkpoint=resume_from)
|
| 83 |
+
else:
|
| 84 |
+
trainer.train()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
if __name__ == "__main__":
|
| 88 |
+
main()
|
vendor/MegaASR/A2S-SFT/modeling.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
import torch
|
| 3 |
+
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
| 4 |
+
from transformers import GenerationConfig
|
| 5 |
+
from qwen_asr import Qwen3ASRModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
LORA_TARGETS = {
|
| 9 |
+
"encoder": r"^audio_tower\.layers\.\d+\..*\.(q_proj|k_proj|v_proj|out_proj|fc1|fc2)$",
|
| 10 |
+
"aligner": r"^audio_tower\.(conv_out|proj1|proj2)$",
|
| 11 |
+
"encoder_aligner": (
|
| 12 |
+
r"^(audio_tower\.(conv_out|proj1|proj2)$"
|
| 13 |
+
r"|audio_tower\.layers\.\d+\..*\.(q_proj|k_proj|v_proj|out_proj|fc1|fc2)$)"
|
| 14 |
+
),
|
| 15 |
+
"encoder_b4_aligner": (
|
| 16 |
+
r"^(audio_tower\.(conv_out|proj1|proj2)$"
|
| 17 |
+
r"|audio_tower\.layers\.(20|21|22|23)\..*\.(q_proj|k_proj|v_proj|out_proj|fc1|fc2)$)"
|
| 18 |
+
),
|
| 19 |
+
"llm": r"^model\.layers\.\d+\..*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)$",
|
| 20 |
+
"all": (
|
| 21 |
+
r"^(audio_tower\.(conv_out|proj1|proj2)$"
|
| 22 |
+
r"|audio_tower\.layers\.\d+\..*\.(q_proj|k_proj|v_proj|out_proj|fc1|fc2)$"
|
| 23 |
+
r"|model\.layers\.\d+\..*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)$)"
|
| 24 |
+
),
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def patch_outer_forward(model):
|
| 29 |
+
cls = model.__class__
|
| 30 |
+
if getattr(cls, "_forward_patched", False):
|
| 31 |
+
return
|
| 32 |
+
if not hasattr(model, "thinker"):
|
| 33 |
+
raise RuntimeError("Qwen3-ASR wrapper has no `thinker` module.")
|
| 34 |
+
|
| 35 |
+
def forward(self, input_ids=None, attention_mask=None, input_features=None,
|
| 36 |
+
feature_attention_mask=None, labels=None, **kwargs):
|
| 37 |
+
return self.thinker.forward(
|
| 38 |
+
input_ids=input_ids,
|
| 39 |
+
attention_mask=attention_mask,
|
| 40 |
+
input_features=input_features,
|
| 41 |
+
feature_attention_mask=feature_attention_mask,
|
| 42 |
+
labels=labels,
|
| 43 |
+
**kwargs,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
cls.forward = forward
|
| 47 |
+
cls._forward_patched = True
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_qwen3_asr(model_path: str):
|
| 51 |
+
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
|
| 52 |
+
wrapper = Qwen3ASRModel.from_pretrained(
|
| 53 |
+
model_path,
|
| 54 |
+
dtype=torch.bfloat16 if use_bf16 else torch.float16,
|
| 55 |
+
device_map=None,
|
| 56 |
+
)
|
| 57 |
+
model, processor = wrapper.model, wrapper.processor
|
| 58 |
+
patch_outer_forward(model)
|
| 59 |
+
model.generation_config = GenerationConfig.from_model_config(model.config)
|
| 60 |
+
return model, processor, use_bf16
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def apply_lora(model, args):
|
| 64 |
+
if not args.use_lora:
|
| 65 |
+
return model
|
| 66 |
+
|
| 67 |
+
old_lora = args.merge_lora_into_base_from.strip()
|
| 68 |
+
if old_lora:
|
| 69 |
+
if args.resume or args.resume_from.strip():
|
| 70 |
+
raise ValueError("Do not use --merge_lora_into_base_from with --resume.")
|
| 71 |
+
print(f"[merge_lora] {old_lora}")
|
| 72 |
+
model.thinker = PeftModel.from_pretrained(
|
| 73 |
+
model.thinker, old_lora, is_trainable=False
|
| 74 |
+
).merge_and_unload()
|
| 75 |
+
|
| 76 |
+
for param in model.parameters():
|
| 77 |
+
param.requires_grad = False
|
| 78 |
+
|
| 79 |
+
lora_config = LoraConfig(
|
| 80 |
+
r=args.lora_r,
|
| 81 |
+
lora_alpha=args.lora_alpha,
|
| 82 |
+
lora_dropout=args.lora_dropout,
|
| 83 |
+
bias=args.lora_bias,
|
| 84 |
+
task_type=TaskType.CAUSAL_LM,
|
| 85 |
+
target_modules=LORA_TARGETS[args.lora_scope],
|
| 86 |
+
)
|
| 87 |
+
model.thinker = get_peft_model(model.thinker, lora_config)
|
| 88 |
+
model.thinker.print_trainable_parameters()
|
| 89 |
+
return model
|
vendor/MegaASR/A2S-SFT/readme.md
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## A2S-SFT Training
|
| 2 |
+
|
| 3 |
+
`src/MegaASR/A2S-SFT` contains the core training code for Mega-ASR supervised fine-tuning. It is designed for Qwen3-ASR-style speech-to-text models and supports LoRA training on different parts of the model.
|
| 4 |
+
|
| 5 |
+
```text
|
| 6 |
+
src/MegaASR/A2S-SFT/
|
| 7 |
+
├── arguments.py # Defines training arguments and hyperparameters.
|
| 8 |
+
├── checkpointing.py # Saves base-model metadata and processor/tokenizer files for LoRA reuse.
|
| 9 |
+
├── dataloader.py # Loads JSONL data, reads audio, builds inputs, and masks non-target labels.
|
| 10 |
+
├── finetune.py # Main entry point for launching A2S-SFT training.
|
| 11 |
+
├── modeling.py # Loads the model and defines LoRA injection scopes.
|
| 12 |
+
├── trainer.py # Defines MegaASRTrainer with adapter-only saving and module-wise learning rates.
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
### Model and LoRA Scope
|
| 16 |
+
|
| 17 |
+
Choose the base ASR model with `--model_path`. The LoRA training range is controlled by `--lora_scope`:
|
| 18 |
+
|
| 19 |
+
```text
|
| 20 |
+
encoder # speech encoder only
|
| 21 |
+
aligner # audio-text aligner / projector only
|
| 22 |
+
encoder_aligner # speech encoder + aligner
|
| 23 |
+
encoder_b4_aligner # last four encoder layers + aligner
|
| 24 |
+
llm # language model only
|
| 25 |
+
all # encoder + aligner + language model
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
In our training pipeline, we use a progressive strategy:
|
| 29 |
+
|
| 30 |
+
```text
|
| 31 |
+
Stage 1: encoder_aligner
|
| 32 |
+
First adapt the speech encoder and audio-text aligner for robust acoustic perception and alignment.
|
| 33 |
+
|
| 34 |
+
Stage 2: llm
|
| 35 |
+
Then adapt the LLM to improve transcription generation under degraded acoustic conditions.
|
| 36 |
+
|
| 37 |
+
Stage 3: all
|
| 38 |
+
Finally tune encoder, aligner, and LLM together for joint optimization.
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
The learning rates of the three parts can be set separately:
|
| 42 |
+
|
| 43 |
+
```text
|
| 44 |
+
--lr_encoder # learning rate for speech encoder LoRA
|
| 45 |
+
--lr_aligner # learning rate for audio-text aligner LoRA
|
| 46 |
+
--lr_llm # learning rate for LLM LoRA
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
If a later stage starts from a previous LoRA checkpoint, use:
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
--merge_lora_into_base_from ${PREVIOUS_LORA_DIR}
|
| 53 |
+
```
|
vendor/MegaASR/A2S-SFT/trainer.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from safetensors.torch import load_file as safe_load_file
|
| 7 |
+
from transformers import Trainer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MegaASRTrainer(Trainer):
|
| 11 |
+
"""Trainer for Mega-ASR LoRA SFT."""
|
| 12 |
+
|
| 13 |
+
def __init__(self, *args, processor=None, base_model_path: str = "",
|
| 14 |
+
merged_from_lora_path: str = "", lr_encoder: float = 1e-5,
|
| 15 |
+
lr_aligner: float = 1e-5, lr_llm: float = 1e-5, **kwargs):
|
| 16 |
+
super().__init__(*args, **kwargs)
|
| 17 |
+
self.processor = processor
|
| 18 |
+
self.base_model_path = base_model_path
|
| 19 |
+
self.merged_from_lora_path = merged_from_lora_path
|
| 20 |
+
self.lr_encoder = lr_encoder
|
| 21 |
+
self.lr_aligner = lr_aligner
|
| 22 |
+
self.lr_llm = lr_llm
|
| 23 |
+
|
| 24 |
+
def _prepare_inputs(self, inputs):
|
| 25 |
+
inputs = super()._prepare_inputs(inputs)
|
| 26 |
+
dtype = getattr(self.model, "dtype", None)
|
| 27 |
+
if dtype is None:
|
| 28 |
+
return inputs
|
| 29 |
+
for k, v in list(inputs.items()):
|
| 30 |
+
if torch.is_tensor(v) and v.is_floating_point():
|
| 31 |
+
inputs[k] = v.to(dtype=dtype)
|
| 32 |
+
return inputs
|
| 33 |
+
|
| 34 |
+
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
| 35 |
+
output_dir = output_dir or self.args.output_dir
|
| 36 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 37 |
+
self.model.thinker.save_pretrained(output_dir, safe_serialization=True)
|
| 38 |
+
|
| 39 |
+
if self.processor is not None:
|
| 40 |
+
self.processor.save_pretrained(output_dir)
|
| 41 |
+
self._write_text(output_dir, "base_model.txt", self.base_model_path)
|
| 42 |
+
self._write_text(output_dir, "merged_from_lora.txt", self.merged_from_lora_path)
|
| 43 |
+
|
| 44 |
+
for name in ["model.safetensors", "pytorch_model.bin",
|
| 45 |
+
"model.safetensors.index.json", "pytorch_model.bin.index.json"]:
|
| 46 |
+
path = os.path.join(output_dir, name)
|
| 47 |
+
if os.path.exists(path):
|
| 48 |
+
os.remove(path)
|
| 49 |
+
|
| 50 |
+
@staticmethod
|
| 51 |
+
def _write_text(output_dir: str, name: str, text: str):
|
| 52 |
+
if text:
|
| 53 |
+
with open(os.path.join(output_dir, name), "w", encoding="utf-8") as f:
|
| 54 |
+
f.write(text + "\n")
|
| 55 |
+
|
| 56 |
+
def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
|
| 57 |
+
model = model or self.model
|
| 58 |
+
adapter_path = os.path.join(resume_from_checkpoint, "adapter_model.safetensors")
|
| 59 |
+
if os.path.isfile(adapter_path):
|
| 60 |
+
model.thinker.load_state_dict(safe_load_file(adapter_path), strict=False)
|
| 61 |
+
return
|
| 62 |
+
return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
|
| 63 |
+
|
| 64 |
+
@staticmethod
|
| 65 |
+
def _group_name(name: str) -> str:
|
| 66 |
+
if "lora_" not in name:
|
| 67 |
+
return "other"
|
| 68 |
+
if any(x in name for x in ["audio_tower.conv_out", "audio_tower.proj1", "audio_tower.proj2"]):
|
| 69 |
+
return "aligner"
|
| 70 |
+
if "audio_tower.layers." in name:
|
| 71 |
+
return "encoder"
|
| 72 |
+
if "model.layers." in name and "audio_tower.layers." not in name:
|
| 73 |
+
return "llm"
|
| 74 |
+
return "other"
|
| 75 |
+
|
| 76 |
+
def create_optimizer(self):
|
| 77 |
+
if self.optimizer is not None:
|
| 78 |
+
return self.optimizer
|
| 79 |
+
|
| 80 |
+
groups = {"encoder": [], "aligner": [], "llm": [], "other": []}
|
| 81 |
+
for name, param in self.model.named_parameters():
|
| 82 |
+
if param.requires_grad:
|
| 83 |
+
groups[self._group_name(name)].append(param)
|
| 84 |
+
|
| 85 |
+
lrs = {"encoder": self.lr_encoder, "aligner": self.lr_aligner,
|
| 86 |
+
"llm": self.lr_llm, "other": self.args.learning_rate}
|
| 87 |
+
optim_groups = [
|
| 88 |
+
{"params": params, "lr": lrs[name], "weight_decay": self.args.weight_decay}
|
| 89 |
+
for name, params in groups.items() if params
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
if self.args.process_index == 0:
|
| 93 |
+
for name, params in groups.items():
|
| 94 |
+
print(f"[optimizer] {name:7s}: {sum(p.numel() for p in params)} params")
|
| 95 |
+
|
| 96 |
+
self.optimizer = torch.optim.AdamW(
|
| 97 |
+
optim_groups,
|
| 98 |
+
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
| 99 |
+
eps=self.args.adam_epsilon,
|
| 100 |
+
)
|
| 101 |
+
return self.optimizer
|
vendor/MegaASR/DG-WGPO/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DAPO-LoRA Training
|
| 2 |
+
|
| 3 |
+
We use Verl for RL code base. Our code is coming soon ~
|
vendor/MegaASR/data/download_librispeech.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Convert LibriSpeech metadata JSONL to Mega-ASR SFT JSONL.
|
| 6 |
+
|
| 7 |
+
Input example:
|
| 8 |
+
{
|
| 9 |
+
"index": 0,
|
| 10 |
+
"audio_path": ".../xxx.flac",
|
| 11 |
+
"answer": "THE TRANSCRIPT TEXT",
|
| 12 |
+
"subset": "test_clean",
|
| 13 |
+
"task_type": "understanding"
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
Output example:
|
| 17 |
+
{
|
| 18 |
+
"audio": ".../xxx.flac",
|
| 19 |
+
"text": "language English<asr_text>THE TRANSCRIPT TEXT",
|
| 20 |
+
"prompt": ""
|
| 21 |
+
}
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import json
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def normalize_text(text: str, case: str) -> str:
|
| 30 |
+
text = " ".join(str(text).strip().split())
|
| 31 |
+
|
| 32 |
+
if case == "lower":
|
| 33 |
+
return text.lower()
|
| 34 |
+
if case == "upper":
|
| 35 |
+
return text.upper()
|
| 36 |
+
if case == "none":
|
| 37 |
+
return text
|
| 38 |
+
|
| 39 |
+
raise ValueError(f"Unknown case mode: {case}")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def convert_one(item: dict, text_case: str = "none", language: str = "English") -> dict:
|
| 43 |
+
audio = item.get("audio_path") or item.get("audio")
|
| 44 |
+
answer = item.get("answer") or item.get("text")
|
| 45 |
+
|
| 46 |
+
if not audio:
|
| 47 |
+
raise ValueError(f"Missing audio_path/audio in item: {item}")
|
| 48 |
+
if answer is None:
|
| 49 |
+
raise ValueError(f"Missing answer/text in item: {item}")
|
| 50 |
+
|
| 51 |
+
answer = normalize_text(answer, text_case)
|
| 52 |
+
|
| 53 |
+
return {
|
| 54 |
+
"audio": audio,
|
| 55 |
+
"text": f"language {language}<asr_text>{answer}",
|
| 56 |
+
"prompt": "",
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def main():
|
| 61 |
+
parser = argparse.ArgumentParser()
|
| 62 |
+
parser.add_argument("--input_jsonl", type=str, required=True)
|
| 63 |
+
parser.add_argument("--output_jsonl", type=str, required=True)
|
| 64 |
+
parser.add_argument("--language", type=str, default="English")
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--text_case",
|
| 67 |
+
type=str,
|
| 68 |
+
default="none",
|
| 69 |
+
choices=["none", "lower", "upper"],
|
| 70 |
+
help="LibriSpeech transcripts are usually uppercase. Use none to preserve original text.",
|
| 71 |
+
)
|
| 72 |
+
args = parser.parse_args()
|
| 73 |
+
|
| 74 |
+
input_path = Path(args.input_jsonl)
|
| 75 |
+
output_path = Path(args.output_jsonl)
|
| 76 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 77 |
+
|
| 78 |
+
count = 0
|
| 79 |
+
with input_path.open("r", encoding="utf-8") as fin, output_path.open("w", encoding="utf-8") as fout:
|
| 80 |
+
for line in fin:
|
| 81 |
+
line = line.strip()
|
| 82 |
+
if not line:
|
| 83 |
+
continue
|
| 84 |
+
item = json.loads(line)
|
| 85 |
+
out = convert_one(item, text_case=args.text_case, language=args.language)
|
| 86 |
+
fout.write(json.dumps(out, ensure_ascii=False) + "\n")
|
| 87 |
+
count += 1
|
| 88 |
+
|
| 89 |
+
print(f"[done] converted {count} samples")
|
| 90 |
+
print(f"[output] {output_path}")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
main()
|
vendor/MegaASR/data/download_librispeech.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python download_librispeech.py \
|
| 2 |
+
--output_dir data/LibriSpeech_test \
|
| 3 |
+
--subsets test-clean,test-other
|
vendor/MegaASR/eval/cn_tn.py
ADDED
|
@@ -0,0 +1,1203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# copied from https://github.com/speechio/chinese_text_normalization/blob/master/python/cn_tn.py
|
| 4 |
+
# Authors:
|
| 5 |
+
# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
|
| 6 |
+
# 2019.9 - 2022 Jiayu DU
|
| 7 |
+
#
|
| 8 |
+
# requirements:
|
| 9 |
+
# - python 3.X
|
| 10 |
+
# notes: python 2.X WILL fail or produce misleading results
|
| 11 |
+
|
| 12 |
+
import sys, os, argparse
|
| 13 |
+
import string, re
|
| 14 |
+
import csv
|
| 15 |
+
|
| 16 |
+
# ================================================================================ #
|
| 17 |
+
# basic constant
|
| 18 |
+
# ================================================================================ #
|
| 19 |
+
CHINESE_DIGIS = u'零一二三四五六七八九'
|
| 20 |
+
BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
|
| 21 |
+
BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
|
| 22 |
+
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
|
| 23 |
+
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
|
| 24 |
+
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
|
| 25 |
+
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
|
| 26 |
+
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
|
| 27 |
+
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
|
| 28 |
+
|
| 29 |
+
ZERO_ALT = u'〇'
|
| 30 |
+
ONE_ALT = u'幺'
|
| 31 |
+
TWO_ALTS = [u'两', u'兩']
|
| 32 |
+
|
| 33 |
+
POSITIVE = [u'正', u'正']
|
| 34 |
+
NEGATIVE = [u'负', u'負']
|
| 35 |
+
POINT = [u'点', u'點']
|
| 36 |
+
# PLUS = [u'加', u'加']
|
| 37 |
+
# SIL = [u'杠', u'槓']
|
| 38 |
+
|
| 39 |
+
FILLER_CHARS = ['呃', '啊']
|
| 40 |
+
|
| 41 |
+
ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
|
| 42 |
+
'胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
|
| 43 |
+
'儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
|
| 44 |
+
'佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'
|
| 45 |
+
ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST)
|
| 46 |
+
|
| 47 |
+
# 中文数字系统类型
|
| 48 |
+
NUMBERING_TYPES = ['low', 'mid', 'high']
|
| 49 |
+
|
| 50 |
+
CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
|
| 51 |
+
'里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
|
| 52 |
+
CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
|
| 53 |
+
COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
|
| 54 |
+
'砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
|
| 55 |
+
'针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
|
| 56 |
+
'毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
|
| 57 |
+
'盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
|
| 58 |
+
'纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
|
| 62 |
+
CN_PUNCS_STOP = '!?。。'
|
| 63 |
+
CN_PUNCS_NONSTOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-'
|
| 64 |
+
CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP
|
| 65 |
+
|
| 66 |
+
PUNCS = CN_PUNCS + string.punctuation
|
| 67 |
+
PUNCS_TRANSFORM = str.maketrans(PUNCS, ' ' * len(PUNCS), '') # replace puncs with space
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# https://zh.wikipedia.org/wiki/全行和半行
|
| 71 |
+
QJ2BJ = {
|
| 72 |
+
' ': ' ',
|
| 73 |
+
'!': '!',
|
| 74 |
+
'"': '"',
|
| 75 |
+
'#': '#',
|
| 76 |
+
'$': '$',
|
| 77 |
+
'%': '%',
|
| 78 |
+
'&': '&',
|
| 79 |
+
''': "'",
|
| 80 |
+
'(': '(',
|
| 81 |
+
')': ')',
|
| 82 |
+
'*': '*',
|
| 83 |
+
'+': '+',
|
| 84 |
+
',': ',',
|
| 85 |
+
'-': '-',
|
| 86 |
+
'.': '.',
|
| 87 |
+
'/': '/',
|
| 88 |
+
'0': '0',
|
| 89 |
+
'1': '1',
|
| 90 |
+
'2': '2',
|
| 91 |
+
'3': '3',
|
| 92 |
+
'4': '4',
|
| 93 |
+
'5': '5',
|
| 94 |
+
'6': '6',
|
| 95 |
+
'7': '7',
|
| 96 |
+
'8': '8',
|
| 97 |
+
'9': '9',
|
| 98 |
+
':': ':',
|
| 99 |
+
';': ';',
|
| 100 |
+
'<': '<',
|
| 101 |
+
'=': '=',
|
| 102 |
+
'>': '>',
|
| 103 |
+
'?': '?',
|
| 104 |
+
'@': '@',
|
| 105 |
+
'A': 'A',
|
| 106 |
+
'B': 'B',
|
| 107 |
+
'C': 'C',
|
| 108 |
+
'D': 'D',
|
| 109 |
+
'E': 'E',
|
| 110 |
+
'F': 'F',
|
| 111 |
+
'G': 'G',
|
| 112 |
+
'H': 'H',
|
| 113 |
+
'I': 'I',
|
| 114 |
+
'J': 'J',
|
| 115 |
+
'K': 'K',
|
| 116 |
+
'L': 'L',
|
| 117 |
+
'M': 'M',
|
| 118 |
+
'N': 'N',
|
| 119 |
+
'O': 'O',
|
| 120 |
+
'P': 'P',
|
| 121 |
+
'Q': 'Q',
|
| 122 |
+
'R': 'R',
|
| 123 |
+
'S': 'S',
|
| 124 |
+
'T': 'T',
|
| 125 |
+
'U': 'U',
|
| 126 |
+
'V': 'V',
|
| 127 |
+
'W': 'W',
|
| 128 |
+
'X': 'X',
|
| 129 |
+
'Y': 'Y',
|
| 130 |
+
'Z': 'Z',
|
| 131 |
+
'[': '[',
|
| 132 |
+
'\': '\\',
|
| 133 |
+
']': ']',
|
| 134 |
+
'^': '^',
|
| 135 |
+
'_': '_',
|
| 136 |
+
'`': '`',
|
| 137 |
+
'a': 'a',
|
| 138 |
+
'b': 'b',
|
| 139 |
+
'c': 'c',
|
| 140 |
+
'd': 'd',
|
| 141 |
+
'e': 'e',
|
| 142 |
+
'f': 'f',
|
| 143 |
+
'g': 'g',
|
| 144 |
+
'h': 'h',
|
| 145 |
+
'i': 'i',
|
| 146 |
+
'j': 'j',
|
| 147 |
+
'k': 'k',
|
| 148 |
+
'l': 'l',
|
| 149 |
+
'm': 'm',
|
| 150 |
+
'n': 'n',
|
| 151 |
+
'o': 'o',
|
| 152 |
+
'p': 'p',
|
| 153 |
+
'q': 'q',
|
| 154 |
+
'r': 'r',
|
| 155 |
+
's': 's',
|
| 156 |
+
't': 't',
|
| 157 |
+
'u': 'u',
|
| 158 |
+
'v': 'v',
|
| 159 |
+
'w': 'w',
|
| 160 |
+
'x': 'x',
|
| 161 |
+
'y': 'y',
|
| 162 |
+
'z': 'z',
|
| 163 |
+
'{': '{',
|
| 164 |
+
'|': '|',
|
| 165 |
+
'}': '}',
|
| 166 |
+
'~': '~',
|
| 167 |
+
}
|
| 168 |
+
QJ2BJ_TRANSFORM = str.maketrans(''.join(QJ2BJ.keys()), ''.join(QJ2BJ.values()), '')
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources:
|
| 172 |
+
# https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total
|
| 173 |
+
CN_CHARS_COMMON = (
|
| 174 |
+
'一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举'
|
| 175 |
+
'乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互'
|
| 176 |
+
'亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从'
|
| 177 |
+
'仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优'
|
| 178 |
+
'伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚'
|
| 179 |
+
'佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣'
|
| 180 |
+
'侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯'
|
| 181 |
+
'俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌'
|
| 182 |
+
'偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚'
|
| 183 |
+
'僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六'
|
| 184 |
+
'兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况'
|
| 185 |
+
'冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈'
|
| 186 |
+
'刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐'
|
| 187 |
+
'剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼'
|
| 188 |
+
'劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹'
|
| 189 |
+
'区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵'
|
| 190 |
+
'卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔'
|
| 191 |
+
'叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊'
|
| 192 |
+
'同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆'
|
| 193 |
+
'呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎'
|
| 194 |
+
'咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌'
|
| 195 |
+
'响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛'
|
| 196 |
+
'唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴'
|
| 197 |
+
'啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌'
|
| 198 |
+
'嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡'
|
| 199 |
+
'嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓'
|
| 200 |
+
'嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢'
|
| 201 |
+
'圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥'
|
| 202 |
+
'坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩'
|
| 203 |
+
'垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基'
|
| 204 |
+
'埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填'
|
| 205 |
+
'塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复'
|
| 206 |
+
'夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖'
|
| 207 |
+
'套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮'
|
| 208 |
+
'妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈'
|
| 209 |
+
'娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻'
|
| 210 |
+
'婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱'
|
| 211 |
+
'嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽'
|
| 212 |
+
'宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾'
|
| 213 |
+
'宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝'
|
| 214 |
+
'尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山'
|
| 215 |
+
'屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃'
|
| 216 |
+
'峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧'
|
| 217 |
+
'崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉'
|
| 218 |
+
'巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡'
|
| 219 |
+
'带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐'
|
| 220 |
+
'庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷'
|
| 221 |
+
'建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖'
|
| 222 |
+
'彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循'
|
| 223 |
+
'徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀'
|
| 224 |
+
'态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓'
|
| 225 |
+
'恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟'
|
| 226 |
+
'悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭'
|
| 227 |
+
'惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥'
|
| 228 |
+
'慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我'
|
| 229 |
+
'戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔'
|
| 230 |
+
'托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡'
|
| 231 |
+
'抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥'
|
| 232 |
+
'拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫'
|
| 233 |
+
'振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎'
|
| 234 |
+
'掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭'
|
| 235 |
+
'揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘'
|
| 236 |
+
'摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢'
|
| 237 |
+
'擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整'
|
| 238 |
+
'敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗'
|
| 239 |
+
'旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝'
|
| 240 |
+
'星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡'
|
| 241 |
+
'晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜'
|
| 242 |
+
'曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽'
|
| 243 |
+
'杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅'
|
| 244 |
+
'枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔'
|
| 245 |
+
'柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩'
|
| 246 |
+
'株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯'
|
| 247 |
+
'桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘'
|
| 248 |
+
'棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂'
|
| 249 |
+
'楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱'
|
| 250 |
+
'榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞'
|
| 251 |
+
'橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正'
|
| 252 |
+
'此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒'
|
| 253 |
+
'毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮'
|
| 254 |
+
'氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽'
|
| 255 |
+
'汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽'
|
| 256 |
+
'沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼'
|
| 257 |
+
'泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿'
|
| 258 |
+
'流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸'
|
| 259 |
+
'浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅'
|
| 260 |
+
'淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟'
|
| 261 |
+
'渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆'
|
| 262 |
+
'溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕'
|
| 263 |
+
'滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶'
|
| 264 |
+
'漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧'
|
| 265 |
+
'澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵'
|
| 266 |
+
'灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈'
|
| 267 |
+
'烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯'
|
| 268 |
+
'焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵'
|
| 269 |
+
'熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍'
|
| 270 |
+
'牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷'
|
| 271 |
+
'犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎'
|
| 272 |
+
'猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎'
|
| 273 |
+
'玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊'
|
| 274 |
+
'珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊'
|
| 275 |
+
'琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙'
|
| 276 |
+
'瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱'
|
| 277 |
+
'璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯'
|
| 278 |
+
'田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐'
|
| 279 |
+
'疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒'
|
| 280 |
+
'痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩'
|
| 281 |
+
'瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙'
|
| 282 |
+
'皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾'
|
| 283 |
+
'省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦'
|
| 284 |
+
'睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知'
|
| 285 |
+
'矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮'
|
| 286 |
+
'砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍'
|
| 287 |
+
'碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷'
|
| 288 |
+
'磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭'
|
| 289 |
+
'祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣'
|
| 290 |
+
'秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗'
|
| 291 |
+
'穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立'
|
| 292 |
+
'竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯'
|
| 293 |
+
'笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅'
|
| 294 |
+
'箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾'
|
| 295 |
+
'簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮'
|
| 296 |
+
'粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢'
|
| 297 |
+
'縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁'
|
| 298 |
+
'绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩'
|
| 299 |
+
'绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕'
|
| 300 |
+
'编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐'
|
| 301 |
+
'网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲'
|
| 302 |
+
'羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者'
|
| 303 |
+
'耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋'
|
| 304 |
+
'职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸'
|
| 305 |
+
'肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳'
|
| 306 |
+
'胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒'
|
| 307 |
+
'腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻'
|
| 308 |
+
'臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般'
|
| 309 |
+
'舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊'
|
| 310 |
+
'芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈'
|
| 311 |
+
'苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆'
|
| 312 |
+
'茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐'
|
| 313 |
+
'荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛'
|
| 314 |
+
'莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥'
|
| 315 |
+
'菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著'
|
| 316 |
+
'葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺'
|
| 317 |
+
'蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷'
|
| 318 |
+
'蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸'
|
| 319 |
+
'薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱'
|
| 320 |
+
'虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆'
|
| 321 |
+
'蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗'
|
| 322 |
+
'蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃'
|
| 323 |
+
'螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡'
|
| 324 |
+
'蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒'
|
| 325 |
+
'袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂'
|
| 326 |
+
'褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉'
|
| 327 |
+
'觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认'
|
| 328 |
+
'讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词'
|
| 329 |
+
'诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请'
|
| 330 |
+
'诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡'
|
| 331 |
+
'谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹'
|
| 332 |
+
'豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼'
|
| 333 |
+
'贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤'
|
| 334 |
+
'赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑'
|
| 335 |
+
'跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮'
|
| 336 |
+
'踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇'
|
| 337 |
+
'躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较'
|
| 338 |
+
'辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄'
|
| 339 |
+
'迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆'
|
| 340 |
+
'选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒'
|
| 341 |
+
'道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱'
|
| 342 |
+
'邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴'
|
| 343 |
+
'郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝'
|
| 344 |
+
'酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭'
|
| 345 |
+
'醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒'
|
| 346 |
+
'钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼'
|
| 347 |
+
'钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨'
|
| 348 |
+
'铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐'
|
| 349 |
+
'锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹'
|
| 350 |
+
'锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣'
|
| 351 |
+
'镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼'
|
| 352 |
+
'闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶'
|
| 353 |
+
'阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈'
|
| 354 |
+
'隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳'
|
| 355 |
+
'零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰'
|
| 356 |
+
'靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵'
|
| 357 |
+
'韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额'
|
| 358 |
+
'颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰'
|
| 359 |
+
'饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥'
|
| 360 |
+
'馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑'
|
| 361 |
+
'骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高'
|
| 362 |
+
'髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾'
|
| 363 |
+
'鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨'
|
| 364 |
+
'鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓'
|
| 365 |
+
'鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶'
|
| 366 |
+
'鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟'
|
| 367 |
+
'鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄'
|
| 368 |
+
'黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷'
|
| 369 |
+
'鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃'
|
| 370 |
+
'㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡'
|
| 371 |
+
'䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽'
|
| 372 |
+
'𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯'
|
| 373 |
+
'𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟'
|
| 374 |
+
'𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟'
|
| 375 |
+
'𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟'
|
| 376 |
+
'𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓'
|
| 377 |
+
)
|
| 378 |
+
CN_CHARS_EXT = '吶诶屌囧飚屄'
|
| 379 |
+
|
| 380 |
+
CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT
|
| 381 |
+
IN_CH_CHARS = { c : True for c in CN_CHARS }
|
| 382 |
+
|
| 383 |
+
EN_CHARS = string.ascii_letters + string.digits
|
| 384 |
+
IN_EN_CHARS = { c : True for c in EN_CHARS }
|
| 385 |
+
|
| 386 |
+
VALID_CHARS = CN_CHARS + EN_CHARS + ' '
|
| 387 |
+
IN_VALID_CHARS = { c : True for c in VALID_CHARS }
|
| 388 |
+
|
| 389 |
+
# ================================================================================ #
|
| 390 |
+
# basic class
|
| 391 |
+
# ================================================================================ #
|
| 392 |
+
class ChineseChar(object):
|
| 393 |
+
"""
|
| 394 |
+
中文字符
|
| 395 |
+
每个字符对应简体和繁体,
|
| 396 |
+
e.g. 简体 = '负', 繁体 = '負'
|
| 397 |
+
转换时可转换为简体或繁体
|
| 398 |
+
"""
|
| 399 |
+
|
| 400 |
+
def __init__(self, simplified, traditional):
|
| 401 |
+
self.simplified = simplified
|
| 402 |
+
self.traditional = traditional
|
| 403 |
+
#self.__repr__ = self.__str__
|
| 404 |
+
|
| 405 |
+
def __str__(self):
|
| 406 |
+
return self.simplified or self.traditional or None
|
| 407 |
+
|
| 408 |
+
def __repr__(self):
|
| 409 |
+
return self.__str__()
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class ChineseNumberUnit(ChineseChar):
|
| 413 |
+
"""
|
| 414 |
+
中文数字/数位字符
|
| 415 |
+
每个字符除繁简体外还有一个额外的大写字符
|
| 416 |
+
e.g. '陆' 和 '陸'
|
| 417 |
+
"""
|
| 418 |
+
|
| 419 |
+
def __init__(self, power, simplified, traditional, big_s, big_t):
|
| 420 |
+
super(ChineseNumberUnit, self).__init__(simplified, traditional)
|
| 421 |
+
self.power = power
|
| 422 |
+
self.big_s = big_s
|
| 423 |
+
self.big_t = big_t
|
| 424 |
+
|
| 425 |
+
def __str__(self):
|
| 426 |
+
return '10^{}'.format(self.power)
|
| 427 |
+
|
| 428 |
+
@classmethod
|
| 429 |
+
def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
|
| 430 |
+
|
| 431 |
+
if small_unit:
|
| 432 |
+
return ChineseNumberUnit(power=index + 1,
|
| 433 |
+
simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
|
| 434 |
+
elif numbering_type == NUMBERING_TYPES[0]:
|
| 435 |
+
return ChineseNumberUnit(power=index + 8,
|
| 436 |
+
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
|
| 437 |
+
elif numbering_type == NUMBERING_TYPES[1]:
|
| 438 |
+
return ChineseNumberUnit(power=(index + 2) * 4,
|
| 439 |
+
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
|
| 440 |
+
elif numbering_type == NUMBERING_TYPES[2]:
|
| 441 |
+
return ChineseNumberUnit(power=pow(2, index + 3),
|
| 442 |
+
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
|
| 443 |
+
else:
|
| 444 |
+
raise ValueError(
|
| 445 |
+
'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class ChineseNumberDigit(ChineseChar):
|
| 449 |
+
"""
|
| 450 |
+
中文数字字符
|
| 451 |
+
"""
|
| 452 |
+
|
| 453 |
+
def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
|
| 454 |
+
super(ChineseNumberDigit, self).__init__(simplified, traditional)
|
| 455 |
+
self.value = value
|
| 456 |
+
self.big_s = big_s
|
| 457 |
+
self.big_t = big_t
|
| 458 |
+
self.alt_s = alt_s
|
| 459 |
+
self.alt_t = alt_t
|
| 460 |
+
|
| 461 |
+
def __str__(self):
|
| 462 |
+
return str(self.value)
|
| 463 |
+
|
| 464 |
+
@classmethod
|
| 465 |
+
def create(cls, i, v):
|
| 466 |
+
return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
class ChineseMath(ChineseChar):
|
| 470 |
+
"""
|
| 471 |
+
中文数位字符
|
| 472 |
+
"""
|
| 473 |
+
|
| 474 |
+
def __init__(self, simplified, traditional, symbol, expression=None):
|
| 475 |
+
super(ChineseMath, self).__init__(simplified, traditional)
|
| 476 |
+
self.symbol = symbol
|
| 477 |
+
self.expression = expression
|
| 478 |
+
self.big_s = simplified
|
| 479 |
+
self.big_t = traditional
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class NumberSystem(object):
|
| 486 |
+
"""
|
| 487 |
+
中文数字系统
|
| 488 |
+
"""
|
| 489 |
+
pass
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class MathSymbol(object):
|
| 493 |
+
"""
|
| 494 |
+
用于中文数字系统的数学符号 (繁/简体), e.g.
|
| 495 |
+
positive = ['正', '正']
|
| 496 |
+
negative = ['负', '負']
|
| 497 |
+
point = ['点', '點']
|
| 498 |
+
"""
|
| 499 |
+
|
| 500 |
+
def __init__(self, positive, negative, point):
|
| 501 |
+
self.positive = positive
|
| 502 |
+
self.negative = negative
|
| 503 |
+
self.point = point
|
| 504 |
+
|
| 505 |
+
def __iter__(self):
|
| 506 |
+
for v in self.__dict__.values():
|
| 507 |
+
yield v
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
# class OtherSymbol(object):
|
| 511 |
+
# """
|
| 512 |
+
# 其他符号
|
| 513 |
+
# """
|
| 514 |
+
#
|
| 515 |
+
# def __init__(self, sil):
|
| 516 |
+
# self.sil = sil
|
| 517 |
+
#
|
| 518 |
+
# def __iter__(self):
|
| 519 |
+
# for v in self.__dict__.values():
|
| 520 |
+
# yield v
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
# ================================================================================ #
|
| 524 |
+
# basic utils
|
| 525 |
+
# ================================================================================ #
|
| 526 |
+
def create_system(numbering_type=NUMBERING_TYPES[1]):
|
| 527 |
+
"""
|
| 528 |
+
根据数字系统类型返回创建相应的数字系统,默认为 mid
|
| 529 |
+
NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
|
| 530 |
+
low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
|
| 531 |
+
mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
|
| 532 |
+
high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
|
| 533 |
+
返回对应的数字系统
|
| 534 |
+
"""
|
| 535 |
+
|
| 536 |
+
# chinese number units of '亿' and larger
|
| 537 |
+
all_larger_units = zip(
|
| 538 |
+
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
|
| 539 |
+
larger_units = [CNU.create(i, v, numbering_type, False)
|
| 540 |
+
for i, v in enumerate(all_larger_units)]
|
| 541 |
+
# chinese number units of '十, 百, 千, 万'
|
| 542 |
+
all_smaller_units = zip(
|
| 543 |
+
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
|
| 544 |
+
smaller_units = [CNU.create(i, v, small_unit=True)
|
| 545 |
+
for i, v in enumerate(all_smaller_units)]
|
| 546 |
+
# digis
|
| 547 |
+
chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
|
| 548 |
+
BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
|
| 549 |
+
digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
|
| 550 |
+
digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
|
| 551 |
+
digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
|
| 552 |
+
digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
|
| 553 |
+
|
| 554 |
+
# symbols
|
| 555 |
+
positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
|
| 556 |
+
negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
|
| 557 |
+
point_cn = CM(POINT[0], POINT[1], '.', lambda x,
|
| 558 |
+
y: float(str(x) + '.' + str(y)))
|
| 559 |
+
# sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
|
| 560 |
+
system = NumberSystem()
|
| 561 |
+
system.units = smaller_units + larger_units
|
| 562 |
+
system.digits = digits
|
| 563 |
+
system.math = MathSymbol(positive_cn, negative_cn, point_cn)
|
| 564 |
+
# system.symbols = OtherSymbol(sil_cn)
|
| 565 |
+
return system
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
|
| 569 |
+
|
| 570 |
+
def get_symbol(char, system):
|
| 571 |
+
for u in system.units:
|
| 572 |
+
if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
|
| 573 |
+
return u
|
| 574 |
+
for d in system.digits:
|
| 575 |
+
if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
|
| 576 |
+
return d
|
| 577 |
+
for m in system.math:
|
| 578 |
+
if char in [m.traditional, m.simplified]:
|
| 579 |
+
return m
|
| 580 |
+
|
| 581 |
+
def string2symbols(chinese_string, system):
|
| 582 |
+
int_string, dec_string = chinese_string, ''
|
| 583 |
+
for p in [system.math.point.simplified, system.math.point.traditional]:
|
| 584 |
+
if p in chinese_string:
|
| 585 |
+
int_string, dec_string = chinese_string.split(p)
|
| 586 |
+
break
|
| 587 |
+
return [get_symbol(c, system) for c in int_string], \
|
| 588 |
+
[get_symbol(c, system) for c in dec_string]
|
| 589 |
+
|
| 590 |
+
def correct_symbols(integer_symbols, system):
|
| 591 |
+
"""
|
| 592 |
+
一百八 to 一百八十
|
| 593 |
+
一亿一千三百万 to 一亿 一千万 三百万
|
| 594 |
+
"""
|
| 595 |
+
|
| 596 |
+
if integer_symbols and isinstance(integer_symbols[0], CNU):
|
| 597 |
+
if integer_symbols[0].power == 1:
|
| 598 |
+
integer_symbols = [system.digits[1]] + integer_symbols
|
| 599 |
+
|
| 600 |
+
if len(integer_symbols) > 1:
|
| 601 |
+
if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
|
| 602 |
+
integer_symbols.append(
|
| 603 |
+
CNU(integer_symbols[-2].power - 1, None, None, None, None))
|
| 604 |
+
|
| 605 |
+
result = []
|
| 606 |
+
unit_count = 0
|
| 607 |
+
for s in integer_symbols:
|
| 608 |
+
if isinstance(s, CND):
|
| 609 |
+
result.append(s)
|
| 610 |
+
unit_count = 0
|
| 611 |
+
elif isinstance(s, CNU):
|
| 612 |
+
current_unit = CNU(s.power, None, None, None, None)
|
| 613 |
+
unit_count += 1
|
| 614 |
+
|
| 615 |
+
if unit_count == 1:
|
| 616 |
+
result.append(current_unit)
|
| 617 |
+
elif unit_count > 1:
|
| 618 |
+
for i in range(len(result)):
|
| 619 |
+
if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
|
| 620 |
+
result[-i - 1] = CNU(result[-i - 1].power +
|
| 621 |
+
current_unit.power, None, None, None, None)
|
| 622 |
+
return result
|
| 623 |
+
|
| 624 |
+
def compute_value(integer_symbols):
|
| 625 |
+
"""
|
| 626 |
+
Compute the value.
|
| 627 |
+
When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
|
| 628 |
+
e.g. '两千万' = 2000 * 10000 not 2000 + 10000
|
| 629 |
+
"""
|
| 630 |
+
value = [0]
|
| 631 |
+
last_power = 0
|
| 632 |
+
for s in integer_symbols:
|
| 633 |
+
if isinstance(s, CND):
|
| 634 |
+
value[-1] = s.value
|
| 635 |
+
elif isinstance(s, CNU):
|
| 636 |
+
value[-1] *= pow(10, s.power)
|
| 637 |
+
if s.power > last_power:
|
| 638 |
+
value[:-1] = list(map(lambda v: v *
|
| 639 |
+
pow(10, s.power), value[:-1]))
|
| 640 |
+
last_power = s.power
|
| 641 |
+
value.append(0)
|
| 642 |
+
return sum(value)
|
| 643 |
+
|
| 644 |
+
system = create_system(numbering_type)
|
| 645 |
+
int_part, dec_part = string2symbols(chinese_string, system)
|
| 646 |
+
int_part = correct_symbols(int_part, system)
|
| 647 |
+
int_str = str(compute_value(int_part))
|
| 648 |
+
dec_str = ''.join([str(d.value) for d in dec_part])
|
| 649 |
+
if dec_part:
|
| 650 |
+
return '{0}.{1}'.format(int_str, dec_str)
|
| 651 |
+
else:
|
| 652 |
+
return int_str
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
|
| 656 |
+
traditional=False, alt_zero=False, alt_one=False, alt_two=True,
|
| 657 |
+
use_zeros=True, use_units=True):
|
| 658 |
+
|
| 659 |
+
def get_value(value_string, use_zeros=True):
|
| 660 |
+
|
| 661 |
+
striped_string = value_string.lstrip('0')
|
| 662 |
+
|
| 663 |
+
# record nothing if all zeros
|
| 664 |
+
if not striped_string:
|
| 665 |
+
return []
|
| 666 |
+
|
| 667 |
+
# record one digits
|
| 668 |
+
elif len(striped_string) == 1:
|
| 669 |
+
if use_zeros and len(value_string) != len(striped_string):
|
| 670 |
+
return [system.digits[0], system.digits[int(striped_string)]]
|
| 671 |
+
else:
|
| 672 |
+
return [system.digits[int(striped_string)]]
|
| 673 |
+
|
| 674 |
+
# recursively record multiple digits
|
| 675 |
+
else:
|
| 676 |
+
result_unit = next(u for u in reversed(
|
| 677 |
+
system.units) if u.power < len(striped_string))
|
| 678 |
+
result_string = value_string[:-result_unit.power]
|
| 679 |
+
return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
|
| 680 |
+
|
| 681 |
+
system = create_system(numbering_type)
|
| 682 |
+
|
| 683 |
+
int_dec = number_string.split('.')
|
| 684 |
+
if len(int_dec) == 1:
|
| 685 |
+
int_string = int_dec[0]
|
| 686 |
+
dec_string = ""
|
| 687 |
+
elif len(int_dec) == 2:
|
| 688 |
+
int_string = int_dec[0]
|
| 689 |
+
dec_string = int_dec[1]
|
| 690 |
+
else:
|
| 691 |
+
raise ValueError(
|
| 692 |
+
"invalid input num string with more than one dot: {}".format(number_string))
|
| 693 |
+
|
| 694 |
+
if use_units and len(int_string) > 1:
|
| 695 |
+
result_symbols = get_value(int_string)
|
| 696 |
+
else:
|
| 697 |
+
result_symbols = [system.digits[int(c)] for c in int_string]
|
| 698 |
+
dec_symbols = [system.digits[int(c)] for c in dec_string]
|
| 699 |
+
if dec_string:
|
| 700 |
+
result_symbols += [system.math.point] + dec_symbols
|
| 701 |
+
|
| 702 |
+
if alt_two:
|
| 703 |
+
liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
|
| 704 |
+
system.digits[2].big_s, system.digits[2].big_t)
|
| 705 |
+
for i, v in enumerate(result_symbols):
|
| 706 |
+
if isinstance(v, CND) and v.value == 2:
|
| 707 |
+
next_symbol = result_symbols[i +
|
| 708 |
+
1] if i < len(result_symbols) - 1 else None
|
| 709 |
+
previous_symbol = result_symbols[i - 1] if i > 0 else None
|
| 710 |
+
if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
|
| 711 |
+
if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
|
| 712 |
+
result_symbols[i] = liang
|
| 713 |
+
|
| 714 |
+
# if big is True, '两' will not be used and `alt_two` has no impact on output
|
| 715 |
+
if big:
|
| 716 |
+
attr_name = 'big_'
|
| 717 |
+
if traditional:
|
| 718 |
+
attr_name += 't'
|
| 719 |
+
else:
|
| 720 |
+
attr_name += 's'
|
| 721 |
+
else:
|
| 722 |
+
if traditional:
|
| 723 |
+
attr_name = 'traditional'
|
| 724 |
+
else:
|
| 725 |
+
attr_name = 'simplified'
|
| 726 |
+
|
| 727 |
+
result = ''.join([getattr(s, attr_name) for s in result_symbols])
|
| 728 |
+
|
| 729 |
+
# if not use_zeros:
|
| 730 |
+
# result = result.strip(getattr(system.digits[0], attr_name))
|
| 731 |
+
|
| 732 |
+
if alt_zero:
|
| 733 |
+
result = result.replace(
|
| 734 |
+
getattr(system.digits[0], attr_name), system.digits[0].alt_s)
|
| 735 |
+
|
| 736 |
+
if alt_one:
|
| 737 |
+
result = result.replace(
|
| 738 |
+
getattr(system.digits[1], attr_name), system.digits[1].alt_s)
|
| 739 |
+
|
| 740 |
+
for i, p in enumerate(POINT):
|
| 741 |
+
if result.startswith(p):
|
| 742 |
+
return CHINESE_DIGIS[0] + result
|
| 743 |
+
|
| 744 |
+
# ^10, 11, .., 19
|
| 745 |
+
if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
|
| 746 |
+
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
|
| 747 |
+
result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
|
| 748 |
+
result = result[1:]
|
| 749 |
+
|
| 750 |
+
return result
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
# ================================================================================ #
|
| 754 |
+
# different types of rewriters
|
| 755 |
+
# ================================================================================ #
|
| 756 |
+
class Cardinal:
|
| 757 |
+
"""
|
| 758 |
+
CARDINAL类
|
| 759 |
+
"""
|
| 760 |
+
|
| 761 |
+
def __init__(self, cardinal=None, chntext=None):
|
| 762 |
+
self.cardinal = cardinal
|
| 763 |
+
self.chntext = chntext
|
| 764 |
+
|
| 765 |
+
def chntext2cardinal(self):
|
| 766 |
+
return chn2num(self.chntext)
|
| 767 |
+
|
| 768 |
+
def cardinal2chntext(self):
|
| 769 |
+
return num2chn(self.cardinal)
|
| 770 |
+
|
| 771 |
+
class Digit:
|
| 772 |
+
"""
|
| 773 |
+
DIGIT类
|
| 774 |
+
"""
|
| 775 |
+
|
| 776 |
+
def __init__(self, digit=None, chntext=None):
|
| 777 |
+
self.digit = digit
|
| 778 |
+
self.chntext = chntext
|
| 779 |
+
|
| 780 |
+
# def chntext2digit(self):
|
| 781 |
+
# return chn2num(self.chntext)
|
| 782 |
+
|
| 783 |
+
def digit2chntext(self):
|
| 784 |
+
return num2chn(self.digit, alt_two=False, use_units=False)
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
class TelePhone:
|
| 788 |
+
"""
|
| 789 |
+
TELEPHONE类
|
| 790 |
+
"""
|
| 791 |
+
|
| 792 |
+
def __init__(self, telephone=None, raw_chntext=None, chntext=None):
|
| 793 |
+
self.telephone = telephone
|
| 794 |
+
self.raw_chntext = raw_chntext
|
| 795 |
+
self.chntext = chntext
|
| 796 |
+
|
| 797 |
+
# def chntext2telephone(self):
|
| 798 |
+
# sil_parts = self.raw_chntext.split('<SIL>')
|
| 799 |
+
# self.telephone = '-'.join([
|
| 800 |
+
# str(chn2num(p)) for p in sil_parts
|
| 801 |
+
# ])
|
| 802 |
+
# return self.telephone
|
| 803 |
+
|
| 804 |
+
def telephone2chntext(self, fixed=False):
|
| 805 |
+
|
| 806 |
+
if fixed:
|
| 807 |
+
sil_parts = self.telephone.split('-')
|
| 808 |
+
self.raw_chntext = '<SIL>'.join([
|
| 809 |
+
num2chn(part, alt_two=False, use_units=False) for part in sil_parts
|
| 810 |
+
])
|
| 811 |
+
self.chntext = self.raw_chntext.replace('<SIL>', '')
|
| 812 |
+
else:
|
| 813 |
+
sp_parts = self.telephone.strip('+').split()
|
| 814 |
+
self.raw_chntext = '<SP>'.join([
|
| 815 |
+
num2chn(part, alt_two=False, use_units=False) for part in sp_parts
|
| 816 |
+
])
|
| 817 |
+
self.chntext = self.raw_chntext.replace('<SP>', '')
|
| 818 |
+
return self.chntext
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
class Fraction:
|
| 822 |
+
"""
|
| 823 |
+
FRACTION类
|
| 824 |
+
"""
|
| 825 |
+
|
| 826 |
+
def __init__(self, fraction=None, chntext=None):
|
| 827 |
+
self.fraction = fraction
|
| 828 |
+
self.chntext = chntext
|
| 829 |
+
|
| 830 |
+
def chntext2fraction(self):
|
| 831 |
+
denominator, numerator = self.chntext.split('分之')
|
| 832 |
+
return chn2num(numerator) + '/' + chn2num(denominator)
|
| 833 |
+
|
| 834 |
+
def fraction2chntext(self):
|
| 835 |
+
numerator, denominator = self.fraction.split('/')
|
| 836 |
+
return num2chn(denominator) + '分之' + num2chn(numerator)
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
class Date:
|
| 840 |
+
"""
|
| 841 |
+
DATE类
|
| 842 |
+
"""
|
| 843 |
+
|
| 844 |
+
def __init__(self, date=None, chntext=None):
|
| 845 |
+
self.date = date
|
| 846 |
+
self.chntext = chntext
|
| 847 |
+
|
| 848 |
+
# def chntext2date(self):
|
| 849 |
+
# chntext = self.chntext
|
| 850 |
+
# try:
|
| 851 |
+
# year, other = chntext.strip().split('年', maxsplit=1)
|
| 852 |
+
# year = Digit(chntext=year).digit2chntext() + '年'
|
| 853 |
+
# except ValueError:
|
| 854 |
+
# other = chntext
|
| 855 |
+
# year = ''
|
| 856 |
+
# if other:
|
| 857 |
+
# try:
|
| 858 |
+
# month, day = other.strip().split('月', maxsplit=1)
|
| 859 |
+
# month = Cardinal(chntext=month).chntext2cardinal() + '月'
|
| 860 |
+
# except ValueError:
|
| 861 |
+
# day = chntext
|
| 862 |
+
# month = ''
|
| 863 |
+
# if day:
|
| 864 |
+
# day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
|
| 865 |
+
# else:
|
| 866 |
+
# month = ''
|
| 867 |
+
# day = ''
|
| 868 |
+
# date = year + month + day
|
| 869 |
+
# self.date = date
|
| 870 |
+
# return self.date
|
| 871 |
+
|
| 872 |
+
def date2chntext(self):
|
| 873 |
+
date = self.date
|
| 874 |
+
try:
|
| 875 |
+
year, other = date.strip().split('年', 1)
|
| 876 |
+
year = Digit(digit=year).digit2chntext() + '年'
|
| 877 |
+
except ValueError:
|
| 878 |
+
other = date
|
| 879 |
+
year = ''
|
| 880 |
+
if other:
|
| 881 |
+
try:
|
| 882 |
+
month, day = other.strip().split('月', 1)
|
| 883 |
+
month = Cardinal(cardinal=month).cardinal2chntext() + '月'
|
| 884 |
+
except ValueError:
|
| 885 |
+
day = date
|
| 886 |
+
month = ''
|
| 887 |
+
if day:
|
| 888 |
+
day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
|
| 889 |
+
else:
|
| 890 |
+
month = ''
|
| 891 |
+
day = ''
|
| 892 |
+
chntext = year + month + day
|
| 893 |
+
self.chntext = chntext
|
| 894 |
+
return self.chntext
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
class Money:
|
| 898 |
+
"""
|
| 899 |
+
MONEY类
|
| 900 |
+
"""
|
| 901 |
+
|
| 902 |
+
def __init__(self, money=None, chntext=None):
|
| 903 |
+
self.money = money
|
| 904 |
+
self.chntext = chntext
|
| 905 |
+
|
| 906 |
+
# def chntext2money(self):
|
| 907 |
+
# return self.money
|
| 908 |
+
|
| 909 |
+
def money2chntext(self):
|
| 910 |
+
money = self.money
|
| 911 |
+
pattern = re.compile(r'(\d+(\.\d+)?)')
|
| 912 |
+
matchers = pattern.findall(money)
|
| 913 |
+
if matchers:
|
| 914 |
+
for matcher in matchers:
|
| 915 |
+
money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
|
| 916 |
+
self.chntext = money
|
| 917 |
+
return self.chntext
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
class Percentage:
|
| 921 |
+
"""
|
| 922 |
+
PERCENTAGE类
|
| 923 |
+
"""
|
| 924 |
+
|
| 925 |
+
def __init__(self, percentage=None, chntext=None):
|
| 926 |
+
self.percentage = percentage
|
| 927 |
+
self.chntext = chntext
|
| 928 |
+
|
| 929 |
+
def chntext2percentage(self):
|
| 930 |
+
return chn2num(self.chntext.strip().strip('百分之')) + '%'
|
| 931 |
+
|
| 932 |
+
def percentage2chntext(self):
|
| 933 |
+
return '百分之' + num2chn(self.percentage.strip().strip('%'))
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
def normalize_nsw(raw_text):
|
| 937 |
+
text = '^' + raw_text + '$'
|
| 938 |
+
|
| 939 |
+
# 规范化日期
|
| 940 |
+
pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
|
| 941 |
+
matchers = pattern.findall(text)
|
| 942 |
+
if matchers:
|
| 943 |
+
#print('date')
|
| 944 |
+
for matcher in matchers:
|
| 945 |
+
text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
|
| 946 |
+
|
| 947 |
+
# 规范化金钱
|
| 948 |
+
pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
|
| 949 |
+
matchers = pattern.findall(text)
|
| 950 |
+
if matchers:
|
| 951 |
+
#print('money')
|
| 952 |
+
for matcher in matchers:
|
| 953 |
+
text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
|
| 954 |
+
|
| 955 |
+
# 规范化固话/手机号码
|
| 956 |
+
# 手机
|
| 957 |
+
# http://www.jihaoba.com/news/show/13680
|
| 958 |
+
# 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
|
| 959 |
+
# 联通:130、131、132、156、155、186、185、176
|
| 960 |
+
# 电信:133、153、189、180、181、177
|
| 961 |
+
pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
|
| 962 |
+
matchers = pattern.findall(text)
|
| 963 |
+
if matchers:
|
| 964 |
+
#print('telephone')
|
| 965 |
+
for matcher in matchers:
|
| 966 |
+
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
|
| 967 |
+
# 固话
|
| 968 |
+
pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
|
| 969 |
+
matchers = pattern.findall(text)
|
| 970 |
+
if matchers:
|
| 971 |
+
# print('fixed telephone')
|
| 972 |
+
for matcher in matchers:
|
| 973 |
+
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
|
| 974 |
+
|
| 975 |
+
# 规范化分数
|
| 976 |
+
pattern = re.compile(r"(\d+/\d+)")
|
| 977 |
+
matchers = pattern.findall(text)
|
| 978 |
+
if matchers:
|
| 979 |
+
#print('fraction')
|
| 980 |
+
for matcher in matchers:
|
| 981 |
+
text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
|
| 982 |
+
|
| 983 |
+
# 规范化百分数
|
| 984 |
+
text = text.replace('%', '%')
|
| 985 |
+
pattern = re.compile(r"(\d+(\.\d+)?%)")
|
| 986 |
+
matchers = pattern.findall(text)
|
| 987 |
+
if matchers:
|
| 988 |
+
#print('percentage')
|
| 989 |
+
for matcher in matchers:
|
| 990 |
+
text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
|
| 991 |
+
|
| 992 |
+
# 规范化纯数+量词
|
| 993 |
+
pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
|
| 994 |
+
matchers = pattern.findall(text)
|
| 995 |
+
if matchers:
|
| 996 |
+
#print('cardinal+quantifier')
|
| 997 |
+
for matcher in matchers:
|
| 998 |
+
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
|
| 999 |
+
|
| 1000 |
+
# 规范化数字编号
|
| 1001 |
+
pattern = re.compile(r"(\d{4,32})")
|
| 1002 |
+
matchers = pattern.findall(text)
|
| 1003 |
+
if matchers:
|
| 1004 |
+
#print('digit')
|
| 1005 |
+
for matcher in matchers:
|
| 1006 |
+
text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
|
| 1007 |
+
|
| 1008 |
+
# 规范化纯数
|
| 1009 |
+
pattern = re.compile(r"(\d+(\.\d+)?)")
|
| 1010 |
+
matchers = pattern.findall(text)
|
| 1011 |
+
if matchers:
|
| 1012 |
+
#print('cardinal')
|
| 1013 |
+
for matcher in matchers:
|
| 1014 |
+
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
|
| 1015 |
+
|
| 1016 |
+
|
| 1017 |
+
# restore P2P, O2O, B2C, B2B etc
|
| 1018 |
+
pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
|
| 1019 |
+
matchers = pattern.findall(text)
|
| 1020 |
+
if matchers:
|
| 1021 |
+
# print('particular')
|
| 1022 |
+
for matcher in matchers:
|
| 1023 |
+
text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
|
| 1024 |
+
|
| 1025 |
+
return text.lstrip('^').rstrip('$')
|
| 1026 |
+
|
| 1027 |
+
|
| 1028 |
+
def remove_erhua(text):
|
| 1029 |
+
"""
|
| 1030 |
+
去除儿化音词中的儿:
|
| 1031 |
+
他女儿在那边儿 -> 他女儿在那边
|
| 1032 |
+
"""
|
| 1033 |
+
|
| 1034 |
+
new_str=''
|
| 1035 |
+
while re.search('儿',text):
|
| 1036 |
+
a = re.search('儿',text).span()
|
| 1037 |
+
remove_er_flag = 0
|
| 1038 |
+
|
| 1039 |
+
if ER_WHITELIST_PATTERN.search(text):
|
| 1040 |
+
b = ER_WHITELIST_PATTERN.search(text).span()
|
| 1041 |
+
if b[0] <= a[0]:
|
| 1042 |
+
remove_er_flag = 1
|
| 1043 |
+
|
| 1044 |
+
if remove_er_flag == 0 :
|
| 1045 |
+
new_str = new_str + text[0:a[0]]
|
| 1046 |
+
text = text[a[1]:]
|
| 1047 |
+
else:
|
| 1048 |
+
new_str = new_str + text[0:b[1]]
|
| 1049 |
+
text = text[b[1]:]
|
| 1050 |
+
|
| 1051 |
+
text = new_str + text
|
| 1052 |
+
return text
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
def remove_space(text):
|
| 1056 |
+
tokens = text.split()
|
| 1057 |
+
new = []
|
| 1058 |
+
for k,t in enumerate(tokens):
|
| 1059 |
+
if k != 0:
|
| 1060 |
+
if IN_EN_CHARS.get(tokens[k-1][-1]) and IN_EN_CHARS.get(t[0]):
|
| 1061 |
+
new.append(' ')
|
| 1062 |
+
new.append(t)
|
| 1063 |
+
return ''.join(new)
|
| 1064 |
+
|
| 1065 |
+
|
| 1066 |
+
class TextNorm:
|
| 1067 |
+
def __init__(self,
|
| 1068 |
+
to_banjiao:bool = False,
|
| 1069 |
+
to_upper:bool = False,
|
| 1070 |
+
to_lower:bool = False,
|
| 1071 |
+
remove_fillers:bool = False,
|
| 1072 |
+
remove_erhua:bool = False,
|
| 1073 |
+
check_chars:bool = False,
|
| 1074 |
+
remove_space:bool = False,
|
| 1075 |
+
cc_mode:str = '',
|
| 1076 |
+
) :
|
| 1077 |
+
self.to_banjiao = to_banjiao
|
| 1078 |
+
self.to_upper = to_upper
|
| 1079 |
+
self.to_lower = to_lower
|
| 1080 |
+
self.remove_fillers = remove_fillers
|
| 1081 |
+
self.remove_erhua = remove_erhua
|
| 1082 |
+
self.check_chars = check_chars
|
| 1083 |
+
self.remove_space = remove_space
|
| 1084 |
+
|
| 1085 |
+
self.cc = None
|
| 1086 |
+
if cc_mode:
|
| 1087 |
+
from opencc import OpenCC # Open Chinese Convert: pip install opencc
|
| 1088 |
+
self.cc = OpenCC(cc_mode)
|
| 1089 |
+
|
| 1090 |
+
def __call__(self, text):
|
| 1091 |
+
if self.cc:
|
| 1092 |
+
text = self.cc.convert(text)
|
| 1093 |
+
|
| 1094 |
+
if self.to_banjiao:
|
| 1095 |
+
text = text.translate(QJ2BJ_TRANSFORM)
|
| 1096 |
+
|
| 1097 |
+
if self.to_upper:
|
| 1098 |
+
text = text.upper()
|
| 1099 |
+
|
| 1100 |
+
if self.to_lower:
|
| 1101 |
+
text = text.lower()
|
| 1102 |
+
|
| 1103 |
+
if self.remove_fillers:
|
| 1104 |
+
for c in FILLER_CHARS:
|
| 1105 |
+
text = text.replace(c, '')
|
| 1106 |
+
|
| 1107 |
+
if self.remove_erhua:
|
| 1108 |
+
text = remove_erhua(text)
|
| 1109 |
+
|
| 1110 |
+
text = normalize_nsw(text)
|
| 1111 |
+
|
| 1112 |
+
text = text.translate(PUNCS_TRANSFORM)
|
| 1113 |
+
|
| 1114 |
+
if self.check_chars:
|
| 1115 |
+
for c in text:
|
| 1116 |
+
if not IN_VALID_CHARS.get(c):
|
| 1117 |
+
print(f'WARNING: illegal char {c} in: {text}', file=sys.stderr)
|
| 1118 |
+
return ''
|
| 1119 |
+
|
| 1120 |
+
if self.remove_space:
|
| 1121 |
+
text = remove_space(text)
|
| 1122 |
+
|
| 1123 |
+
return text
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
if __name__ == '__main__':
|
| 1127 |
+
p = argparse.ArgumentParser()
|
| 1128 |
+
|
| 1129 |
+
# normalizer options
|
| 1130 |
+
p.add_argument('--to_banjiao', action='store_true', help='convert quanjiao chars to banjiao')
|
| 1131 |
+
p.add_argument('--to_upper', action='store_true', help='convert to upper case')
|
| 1132 |
+
p.add_argument('--to_lower', action='store_true', help='convert to lower case')
|
| 1133 |
+
p.add_argument('--remove_fillers', action='store_true', help='remove filler chars such as "呃, 啊"')
|
| 1134 |
+
p.add_argument('--remove_erhua', action='store_true', help='remove erhua chars such as "他女儿在那边儿 -> 他女儿在那边"')
|
| 1135 |
+
p.add_argument('--check_chars', action='store_true' , help='skip sentences containing illegal chars')
|
| 1136 |
+
p.add_argument('--remove_space', action='store_true' , help='remove whitespace')
|
| 1137 |
+
p.add_argument('--cc_mode', choices=['', 't2s', 's2t'], default='', help='convert between traditional to simplified')
|
| 1138 |
+
|
| 1139 |
+
# I/O options
|
| 1140 |
+
p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
|
| 1141 |
+
p.add_argument('--has_key', action='store_true', help="will be deprecated, set --format ark instead")
|
| 1142 |
+
p.add_argument('--format', type=str, choices=['txt', 'ark', 'tsv'], default='txt', help='input format')
|
| 1143 |
+
p.add_argument('ifile', help='input filename, assume utf-8 encoding')
|
| 1144 |
+
p.add_argument('ofile', help='output filename')
|
| 1145 |
+
|
| 1146 |
+
args = p.parse_args()
|
| 1147 |
+
|
| 1148 |
+
if args.has_key:
|
| 1149 |
+
args.format = 'ark'
|
| 1150 |
+
|
| 1151 |
+
normalizer = TextNorm(
|
| 1152 |
+
to_banjiao = args.to_banjiao,
|
| 1153 |
+
to_upper = args.to_upper,
|
| 1154 |
+
to_lower = args.to_lower,
|
| 1155 |
+
remove_fillers = args.remove_fillers,
|
| 1156 |
+
remove_erhua = args.remove_erhua,
|
| 1157 |
+
check_chars = args.check_chars,
|
| 1158 |
+
remove_space = args.remove_space,
|
| 1159 |
+
cc_mode = args.cc_mode,
|
| 1160 |
+
)
|
| 1161 |
+
|
| 1162 |
+
ndone = 0
|
| 1163 |
+
with open(args.ifile, 'r', encoding = 'utf8') as istream, open(args.ofile, 'w+', encoding = 'utf8') as ostream:
|
| 1164 |
+
if args.format == 'tsv':
|
| 1165 |
+
reader = csv.DictReader(istream, delimiter = '\t')
|
| 1166 |
+
assert('TEXT' in reader.fieldnames)
|
| 1167 |
+
print('\t'.join(reader.fieldnames), file=ostream)
|
| 1168 |
+
|
| 1169 |
+
for item in reader:
|
| 1170 |
+
text = item['TEXT']
|
| 1171 |
+
|
| 1172 |
+
if text:
|
| 1173 |
+
text = normalizer(text)
|
| 1174 |
+
|
| 1175 |
+
if text:
|
| 1176 |
+
item['TEXT'] = text
|
| 1177 |
+
print('\t'.join([ item[f] for f in reader.fieldnames ]), file = ostream)
|
| 1178 |
+
|
| 1179 |
+
ndone += 1
|
| 1180 |
+
if ndone % args.log_interval == 0:
|
| 1181 |
+
print(f'text norm: {ndone} lines done.', file = sys.stderr, flush = True)
|
| 1182 |
+
else:
|
| 1183 |
+
for l in istream:
|
| 1184 |
+
key, text = '', ''
|
| 1185 |
+
if args.format == 'ark': # KALDI archive, line format: "key text"
|
| 1186 |
+
cols = l.strip().split(maxsplit=1)
|
| 1187 |
+
key, text = cols[0], cols[1] if len(cols) == 2 else ''
|
| 1188 |
+
else:
|
| 1189 |
+
text = l.strip()
|
| 1190 |
+
|
| 1191 |
+
if text:
|
| 1192 |
+
text = normalizer(text)
|
| 1193 |
+
|
| 1194 |
+
if text:
|
| 1195 |
+
if args.format == 'ark':
|
| 1196 |
+
print(key + '\t' + text, file = ostream)
|
| 1197 |
+
else:
|
| 1198 |
+
print(text, file = ostream)
|
| 1199 |
+
|
| 1200 |
+
ndone += 1
|
| 1201 |
+
if ndone % args.log_interval == 0:
|
| 1202 |
+
print(f'text norm: {ndone} lines done.', file = sys.stderr, flush = True)
|
| 1203 |
+
print(f'text norm: {ndone} lines done in total.', file = sys.stderr, flush = True)
|
vendor/MegaASR/eval/evaluate_wer.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import unicodedata
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
BATCH_SIZE = 8
|
| 11 |
+
MAX_NEW_TOKENS = 256
|
| 12 |
+
ROOT_DIR = Path(__file__).resolve().parents[3]
|
| 13 |
+
sys.path.append(str(ROOT_DIR / "src"))
|
| 14 |
+
sys.path.append(str(Path(__file__).resolve().parent))
|
| 15 |
+
|
| 16 |
+
from cn_tn import TextNorm
|
| 17 |
+
from whisper_normalizer.basic import BasicTextNormalizer
|
| 18 |
+
from whisper_normalizer.english import EnglishTextNormalizer
|
| 19 |
+
|
| 20 |
+
ENGLISH_NORMALIZER = EnglishTextNormalizer()
|
| 21 |
+
CHINESE_NORMALIZER = TextNorm(
|
| 22 |
+
to_banjiao=False,
|
| 23 |
+
to_upper=False,
|
| 24 |
+
to_lower=False,
|
| 25 |
+
remove_fillers=False,
|
| 26 |
+
remove_erhua=False,
|
| 27 |
+
check_chars=False,
|
| 28 |
+
remove_space=False,
|
| 29 |
+
cc_mode="",
|
| 30 |
+
)
|
| 31 |
+
BASIC_NORMALIZER = BasicTextNormalizer()
|
| 32 |
+
|
| 33 |
+
def detect_language(ref, pred):
|
| 34 |
+
return "zh" if any("\u4e00" <= ch <= "\u9fff" for ch in ref + pred) else "en"
|
| 35 |
+
|
| 36 |
+
def unwrap_prediction(pred):
|
| 37 |
+
if isinstance(pred, list):
|
| 38 |
+
return " ".join(str(x) for x in pred)
|
| 39 |
+
return str(pred)
|
| 40 |
+
|
| 41 |
+
class EvaluationTokenizer:
|
| 42 |
+
SPACE = chr(32)
|
| 43 |
+
SPACE_ESCAPE = chr(9601)
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
tokenizer_type: str = "zh",
|
| 48 |
+
lowercase: bool = True,
|
| 49 |
+
punctuation_removal: bool = True,
|
| 50 |
+
character_tokenization: bool = False,
|
| 51 |
+
):
|
| 52 |
+
from sacrebleu.tokenizers import TOKENIZERS
|
| 53 |
+
|
| 54 |
+
self.tokenizer = TOKENIZERS[tokenizer_type]
|
| 55 |
+
self.lowercase = lowercase
|
| 56 |
+
self.punctuation_removal = punctuation_removal
|
| 57 |
+
self.character_tokenization = character_tokenization
|
| 58 |
+
|
| 59 |
+
@classmethod
|
| 60 |
+
def remove_punctuation(cls, sent: str):
|
| 61 |
+
return cls.SPACE.join(
|
| 62 |
+
token
|
| 63 |
+
for token in sent.split(cls.SPACE)
|
| 64 |
+
if not all(unicodedata.category(char)[0] == "P" for char in token)
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
def tokenize(self, sent: str):
|
| 68 |
+
tokenized = self.tokenizer()(sent)
|
| 69 |
+
if self.punctuation_removal:
|
| 70 |
+
tokenized = self.remove_punctuation(tokenized)
|
| 71 |
+
if self.character_tokenization:
|
| 72 |
+
tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))
|
| 73 |
+
if self.lowercase:
|
| 74 |
+
tokenized = tokenized.lower()
|
| 75 |
+
return tokenized
|
| 76 |
+
|
| 77 |
+
def compute_one_error(ref, pred, language):
|
| 78 |
+
import editdistance as ed
|
| 79 |
+
import zhconv
|
| 80 |
+
|
| 81 |
+
if language == "yue":
|
| 82 |
+
ref = zhconv.convert(ref, "zh-cn")
|
| 83 |
+
pred = zhconv.convert(pred, "zh-cn")
|
| 84 |
+
if language == "en":
|
| 85 |
+
ref = ENGLISH_NORMALIZER(ref)
|
| 86 |
+
pred = ENGLISH_NORMALIZER(pred)
|
| 87 |
+
elif language == "zh":
|
| 88 |
+
ref = CHINESE_NORMALIZER(ref)
|
| 89 |
+
pred = CHINESE_NORMALIZER(pred)
|
| 90 |
+
else:
|
| 91 |
+
ref = BASIC_NORMALIZER(ref)
|
| 92 |
+
pred = BASIC_NORMALIZER(pred)
|
| 93 |
+
|
| 94 |
+
tokenizer = EvaluationTokenizer()
|
| 95 |
+
ref_items = tokenizer.tokenize(ref).split()
|
| 96 |
+
pred_items = tokenizer.tokenize(pred).split()
|
| 97 |
+
edits = ed.eval(ref_items, pred_items)
|
| 98 |
+
ref_len = len(ref_items)
|
| 99 |
+
return (edits / ref_len if ref_len else 0.0), {"err": int(edits), "nref": int(ref_len)}
|
| 100 |
+
|
| 101 |
+
def resolve_audio(path, jsonl_path):
|
| 102 |
+
path = Path(path)
|
| 103 |
+
if path.is_absolute():
|
| 104 |
+
return str(path)
|
| 105 |
+
jsonl_dir_path = Path(jsonl_path).resolve().parent / path
|
| 106 |
+
return str(jsonl_dir_path if jsonl_dir_path.exists() else Path.cwd() / path)
|
| 107 |
+
|
| 108 |
+
def get_audio_field(item):
|
| 109 |
+
if "audio" in item:
|
| 110 |
+
return item["audio"]
|
| 111 |
+
if "audio_path" in item:
|
| 112 |
+
return item["audio_path"]
|
| 113 |
+
raise KeyError("Input JSONL item must contain `audio` or `audio_path`.")
|
| 114 |
+
|
| 115 |
+
def main():
|
| 116 |
+
parser = argparse.ArgumentParser("Run Mega-ASR inference and compute WER/CER.")
|
| 117 |
+
parser.add_argument("--ckpt_dir", required=True, help="Mega-ASR ckpt root")
|
| 118 |
+
parser.add_argument("--input_jsonl", required=True)
|
| 119 |
+
parser.add_argument("--output_jsonl", required=True)
|
| 120 |
+
parser.add_argument("--routing", action=argparse.BooleanOptionalAction, default=True)
|
| 121 |
+
parser.add_argument("--threshold", type=float, default=0.5)
|
| 122 |
+
parser.add_argument("--device_map", default=None)
|
| 123 |
+
parser.add_argument("--gpu", default=None)
|
| 124 |
+
parser.add_argument("--keep_delta_on_gpu", action=argparse.BooleanOptionalAction, default=True)
|
| 125 |
+
args = parser.parse_args()
|
| 126 |
+
|
| 127 |
+
if args.gpu is not None:
|
| 128 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
| 129 |
+
|
| 130 |
+
from MegaASR.model.megaASR import MegaASR
|
| 131 |
+
|
| 132 |
+
ckpt_dir = Path(args.ckpt_dir).expanduser()
|
| 133 |
+
|
| 134 |
+
model = MegaASR(
|
| 135 |
+
model_path=ckpt_dir / "Qwen3-ASR-1.7B",
|
| 136 |
+
lora_dir=ckpt_dir / "mega-asr-merged",
|
| 137 |
+
router_checkpoint=ckpt_dir / "audio_quality_router/best_acc_model.safetensors",
|
| 138 |
+
routing_enabled=args.routing,
|
| 139 |
+
quality_threshold=args.threshold,
|
| 140 |
+
device_map=args.device_map,
|
| 141 |
+
max_inference_batch_size=BATCH_SIZE,
|
| 142 |
+
max_new_tokens=MAX_NEW_TOKENS,
|
| 143 |
+
keep_delta_on_gpu=args.keep_delta_on_gpu,
|
| 144 |
+
)
|
| 145 |
+
with open(args.input_jsonl, "r", encoding="utf-8") as f:
|
| 146 |
+
data = [json.loads(line) for line in f if line.strip()]
|
| 147 |
+
outputs, total_edits, total_ref_len = [], 0, 0
|
| 148 |
+
|
| 149 |
+
for i in tqdm(range(0, len(data), BATCH_SIZE), desc="evaluating"):
|
| 150 |
+
batch = data[i:i + BATCH_SIZE]
|
| 151 |
+
audio_paths = [resolve_audio(get_audio_field(x), args.input_jsonl) for x in batch]
|
| 152 |
+
results = model.batch_infer(audio_paths)
|
| 153 |
+
for item, pred in zip(batch, results):
|
| 154 |
+
pred = unwrap_prediction(pred).strip()
|
| 155 |
+
language = item.get("language") or detect_language(item["answer"], pred)
|
| 156 |
+
score, detail = compute_one_error(item["answer"], pred, language)
|
| 157 |
+
edits = detail["err"]
|
| 158 |
+
ref_len = detail["nref"]
|
| 159 |
+
metric = "cer" if language in {"zh", "yue"} else "wer"
|
| 160 |
+
item["prediction"] = pred
|
| 161 |
+
item["metric"] = metric
|
| 162 |
+
item["wer"] = round(float(score), 6)
|
| 163 |
+
item["num_edits"] = int(edits)
|
| 164 |
+
item["ref_len"] = int(ref_len)
|
| 165 |
+
total_edits += edits
|
| 166 |
+
total_ref_len += ref_len
|
| 167 |
+
outputs.append(item)
|
| 168 |
+
|
| 169 |
+
out_path = Path(args.output_jsonl)
|
| 170 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 171 |
+
with out_path.open("w", encoding="utf-8") as f:
|
| 172 |
+
for item in outputs:
|
| 173 |
+
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
| 174 |
+
print(f"samples: {len(outputs)}")
|
| 175 |
+
print(f"overall_error: {total_edits / total_ref_len if total_ref_len else 0.0:.6f}")
|
| 176 |
+
print(f"saved: {out_path}")
|
| 177 |
+
|
| 178 |
+
if __name__ == "__main__":
|
| 179 |
+
main()
|
vendor/MegaASR/eval/evaluate_wer.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python eval/evaluate_wer.py \
|
| 2 |
+
--model_path Qwen3-ASR-1.7B \
|
| 3 |
+
--input_jsonl examples.jsonl \
|
| 4 |
+
--output_jsonl output_with_wer.jsonl
|
vendor/MegaASR/eval/readme.md
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## ASR Evaluation
|
| 2 |
+
|
| 3 |
+
We provide a simple evaluation script for running Mega-ASR inference and computing WER/CER.
|
| 4 |
+
The input file should be a JSONL file. Each line only needs two required fields:
|
| 5 |
+
|
| 6 |
+
```json
|
| 7 |
+
{"audio": "examples/audio/noise.wav", "answer": "I usually take the quieter road home because the main street gets crowded after work."}
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
The script will keep all original fields and append the following fields to the output JSONL:
|
| 12 |
+
|
| 13 |
+
```text
|
| 14 |
+
prediction # model transcription
|
| 15 |
+
metric # "wer" for English samples, "cer" for Chinese samples
|
| 16 |
+
wer # WER/CER score value; CER is also stored in this field for compatibility
|
| 17 |
+
num_edits # edit distance between prediction and ground truth
|
| 18 |
+
ref_len # number of reference words or characters
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
The script reuses the Mega-ASR inference wrapper, so it loads the base Qwen3-ASR model,
|
| 22 |
+
the Mega-ASR LoRA, and the router from the checkpoint directory:
|
| 23 |
+
|
| 24 |
+
```text
|
| 25 |
+
ckpt/Mega-ASR/
|
| 26 |
+
├── Qwen3-ASR-1.7B
|
| 27 |
+
├── mega-asr-merged
|
| 28 |
+
└── audio_quality_router/best_acc_model.safetensors
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
### Run Evaluation
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
python src/MegaASR/eval/evaluate_wer.py \
|
| 35 |
+
--ckpt_dir ckpt/Mega-ASR \
|
| 36 |
+
--input_jsonl examples/test.jsonl \
|
| 37 |
+
--output_jsonl outputs/pred_with_wer.jsonl
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
Disable routing if you want to always use the Mega-ASR LoRA:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
python src/MegaASR/eval/evaluate_wer.py \
|
| 44 |
+
--ckpt_dir ckpt/Mega-ASR \
|
| 45 |
+
--input_jsonl examples/test.jsonl \
|
| 46 |
+
--output_jsonl outputs/pred_with_wer.jsonl \
|
| 47 |
+
--no-routing
|
| 48 |
+
```
|
vendor/MegaASR/model/Qwen3_ASR.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Qwen3ASR:
|
| 9 |
+
NAME = "Qwen3-ASR-1.7B"
|
| 10 |
+
HF_REPO_ID = "Qwen/Qwen3-ASR-1.7B"
|
| 11 |
+
DEFAULT_MODEL_DIR = "ckpt/Mega-ASR/Qwen3-ASR-1.7B"
|
| 12 |
+
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
model_path: str | os.PathLike[str] | None = None,
|
| 16 |
+
*,
|
| 17 |
+
repo_id: str | None = None,
|
| 18 |
+
device_map: str | None = None,
|
| 19 |
+
dtype: Any | None = None,
|
| 20 |
+
max_inference_batch_size: int = 32,
|
| 21 |
+
max_new_tokens: int = 2048,
|
| 22 |
+
download_kwargs: dict[str, Any] | None = None,
|
| 23 |
+
**model_kwargs: Any,
|
| 24 |
+
) -> None:
|
| 25 |
+
import torch
|
| 26 |
+
from qwen_asr import Qwen3ASRModel
|
| 27 |
+
|
| 28 |
+
repo_id = repo_id or self.HF_REPO_ID
|
| 29 |
+
self.model_path = str(Path(model_path or self.DEFAULT_MODEL_DIR).expanduser())
|
| 30 |
+
if not self._has_local_model(self.model_path):
|
| 31 |
+
self.model_path = self.download_model(
|
| 32 |
+
self.model_path,
|
| 33 |
+
repo_id=repo_id,
|
| 34 |
+
**(download_kwargs or {}),
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
if device_map is None:
|
| 38 |
+
device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
if dtype is None:
|
| 40 |
+
dtype = torch.bfloat16 if device_map != "cpu" else torch.float32
|
| 41 |
+
|
| 42 |
+
self.model = Qwen3ASRModel.from_pretrained(
|
| 43 |
+
self.model_path,
|
| 44 |
+
dtype=dtype,
|
| 45 |
+
device_map=device_map,
|
| 46 |
+
max_inference_batch_size=max_inference_batch_size,
|
| 47 |
+
max_new_tokens=max_new_tokens,
|
| 48 |
+
**model_kwargs,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
@staticmethod
|
| 52 |
+
def _has_local_model(model_path: str | os.PathLike[str]) -> bool:
|
| 53 |
+
path = Path(model_path).expanduser()
|
| 54 |
+
return path.is_dir() and (path / "config.json").is_file()
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def download_model(
|
| 58 |
+
model_path: str | os.PathLike[str],
|
| 59 |
+
*,
|
| 60 |
+
repo_id: str,
|
| 61 |
+
**snapshot_kwargs: Any,
|
| 62 |
+
) -> str:
|
| 63 |
+
from huggingface_hub import snapshot_download
|
| 64 |
+
|
| 65 |
+
local_dir = Path(model_path).expanduser()
|
| 66 |
+
local_dir.mkdir(parents=True, exist_ok=True)
|
| 67 |
+
|
| 68 |
+
return snapshot_download(
|
| 69 |
+
repo_id=repo_id,
|
| 70 |
+
local_dir=str(local_dir),
|
| 71 |
+
local_dir_use_symlinks=False,
|
| 72 |
+
**snapshot_kwargs,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def infer(
|
| 76 |
+
self,
|
| 77 |
+
audio: Any,
|
| 78 |
+
*,
|
| 79 |
+
language: str | None = None,
|
| 80 |
+
return_objects: bool = False,
|
| 81 |
+
**transcribe_kwargs: Any,
|
| 82 |
+
) -> str | list[str] | Any:
|
| 83 |
+
if isinstance(audio, os.PathLike):
|
| 84 |
+
audio = str(audio)
|
| 85 |
+
elif isinstance(audio, (list, tuple)):
|
| 86 |
+
audio = [str(item) if isinstance(item, os.PathLike) else item for item in audio]
|
| 87 |
+
|
| 88 |
+
results = self.model.transcribe(
|
| 89 |
+
audio=audio,
|
| 90 |
+
language=language,
|
| 91 |
+
**transcribe_kwargs,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
if return_objects:
|
| 95 |
+
return results
|
| 96 |
+
|
| 97 |
+
if isinstance(results, list):
|
| 98 |
+
return [str(getattr(result, "text", result)).strip() for result in results]
|
| 99 |
+
|
| 100 |
+
return str(getattr(results, "text", results)).strip()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def get_mega_asr(*args: Any, **kwargs: Any) -> Qwen3ASR:
|
| 104 |
+
return Qwen3ASR(*args, **kwargs)
|
vendor/MegaASR/model/__pycache__/Qwen3_ASR.cpython-311.pyc
ADDED
|
Binary file (5.27 kB). View file
|
|
|
vendor/MegaASR/model/__pycache__/megaASR.cpython-311.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
vendor/MegaASR/model/__pycache__/router.cpython-311.pyc
ADDED
|
Binary file (6.74 kB). View file
|
|
|
vendor/MegaASR/model/megaASR.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import warnings
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from safetensors.torch import load_file as safe_load_file
|
| 12 |
+
|
| 13 |
+
from .Qwen3_ASR import Qwen3ASR
|
| 14 |
+
from .router import AudioQualityRouter
|
| 15 |
+
from .utils.lora_switch import LoRADeltaSwitch
|
| 16 |
+
|
| 17 |
+
class MegaASR:
|
| 18 |
+
NAME = "Mega-ASR"
|
| 19 |
+
DEFAULT_MODEL_DIR = Qwen3ASR.DEFAULT_MODEL_DIR
|
| 20 |
+
DEFAULT_LORA_DIR = "ckpt/Mega-ASR/mega-asr-merged"
|
| 21 |
+
DEFAULT_ROUTER_CHECKPOINT = AudioQualityRouter.DEFAULT_CHECKPOINT
|
| 22 |
+
DOWNLOAD_URLS = {
|
| 23 |
+
"lora": None,
|
| 24 |
+
"router": None,
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
model_path: str | os.PathLike[str] | None = None,
|
| 30 |
+
*,
|
| 31 |
+
lora_dir: str | os.PathLike[str] | None = None,
|
| 32 |
+
router_checkpoint: str | os.PathLike[str] | None = None,
|
| 33 |
+
routing_enabled: bool = True,
|
| 34 |
+
quality_threshold: float = 0.5,
|
| 35 |
+
device_map: str | None = None,
|
| 36 |
+
quality_device: str | None = None,
|
| 37 |
+
max_inference_batch_size: int = 32,
|
| 38 |
+
max_new_tokens: int = 256,
|
| 39 |
+
keep_delta_on_gpu: bool = True,
|
| 40 |
+
**model_kwargs: Any,
|
| 41 |
+
) -> None:
|
| 42 |
+
self.model_path = str(Path(model_path or self.DEFAULT_MODEL_DIR).expanduser())
|
| 43 |
+
self.lora_dir = str(Path(lora_dir or self.DEFAULT_LORA_DIR).expanduser())
|
| 44 |
+
self.router_checkpoint = str(
|
| 45 |
+
Path(router_checkpoint or self.DEFAULT_ROUTER_CHECKPOINT).expanduser()
|
| 46 |
+
)
|
| 47 |
+
self.routing_enabled = routing_enabled
|
| 48 |
+
|
| 49 |
+
self.stats = {"total": 0, "use_base": 0, "use_lora": 0}
|
| 50 |
+
self.switch_times: list[dict[str, float | str]] = []
|
| 51 |
+
|
| 52 |
+
self.router = None
|
| 53 |
+
if self.routing_enabled:
|
| 54 |
+
self.router = AudioQualityRouter(
|
| 55 |
+
checkpoint_path=self.router_checkpoint,
|
| 56 |
+
device=quality_device,
|
| 57 |
+
threshold=quality_threshold,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
self.asr = Qwen3ASR(
|
| 61 |
+
model_path=self.model_path,
|
| 62 |
+
device_map=device_map,
|
| 63 |
+
max_inference_batch_size=max_inference_batch_size,
|
| 64 |
+
max_new_tokens=max_new_tokens,
|
| 65 |
+
**model_kwargs,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
self.lora_switch = LoRADeltaSwitch(keep_delta_on_gpu=keep_delta_on_gpu)
|
| 69 |
+
self._load_loras()
|
| 70 |
+
self._set_lora(True)
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def download(cls, name: str, target_dir: str | os.PathLike[str]) -> str:
|
| 74 |
+
url = cls.DOWNLOAD_URLS.get(name)
|
| 75 |
+
if not url:
|
| 76 |
+
raise NotImplementedError(f"Download URL for {name} is not set yet.")
|
| 77 |
+
|
| 78 |
+
from huggingface_hub import snapshot_download
|
| 79 |
+
|
| 80 |
+
return snapshot_download(
|
| 81 |
+
repo_id=url,
|
| 82 |
+
local_dir=str(Path(target_dir).expanduser()),
|
| 83 |
+
local_dir_use_symlinks=False,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def _load_loras(self) -> None:
|
| 87 |
+
self.lora_switch.add_adapter(
|
| 88 |
+
parent_module=self.asr.model.model,
|
| 89 |
+
adapter_dir=self.lora_dir,
|
| 90 |
+
name="mega_asr_merged_adapter",
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def _set_lora(self, active: bool) -> None:
|
| 94 |
+
elapsed = self.lora_switch.set_active(active)
|
| 95 |
+
if elapsed > 0:
|
| 96 |
+
direction = "base_to_lora" if active else "lora_to_base"
|
| 97 |
+
self.switch_times.append({"direction": direction, "time": elapsed})
|
| 98 |
+
|
| 99 |
+
@staticmethod
|
| 100 |
+
def _unwrap_audio(audio: Any) -> Any:
|
| 101 |
+
if isinstance(audio, (list, tuple)) and len(audio) == 1:
|
| 102 |
+
return audio[0]
|
| 103 |
+
return audio
|
| 104 |
+
|
| 105 |
+
def _route(self, audio: Any) -> tuple[bool, float | None, str]:
|
| 106 |
+
if self.routing_enabled and self.router is not None:
|
| 107 |
+
is_degraded, degraded_prob = self.router.predict(audio)
|
| 108 |
+
return is_degraded, degraded_prob, "router"
|
| 109 |
+
|
| 110 |
+
return True, None, "default"
|
| 111 |
+
|
| 112 |
+
def infer(
|
| 113 |
+
self,
|
| 114 |
+
audio: Any,
|
| 115 |
+
*,
|
| 116 |
+
language: str | None = None,
|
| 117 |
+
return_objects: bool = False,
|
| 118 |
+
return_route: bool = False,
|
| 119 |
+
**transcribe_kwargs: Any,
|
| 120 |
+
) -> Any:
|
| 121 |
+
audio = self._unwrap_audio(audio)
|
| 122 |
+
use_lora, degraded_prob, route_source = self._route(audio)
|
| 123 |
+
|
| 124 |
+
self._set_lora(use_lora)
|
| 125 |
+
result = self.asr.infer(
|
| 126 |
+
audio,
|
| 127 |
+
language=language,
|
| 128 |
+
return_objects=return_objects,
|
| 129 |
+
**transcribe_kwargs,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
self.stats["total"] += 1
|
| 133 |
+
if use_lora:
|
| 134 |
+
self.stats["use_lora"] += 1
|
| 135 |
+
else:
|
| 136 |
+
self.stats["use_base"] += 1
|
| 137 |
+
|
| 138 |
+
if return_route:
|
| 139 |
+
return {
|
| 140 |
+
"text": result,
|
| 141 |
+
"use_lora": use_lora,
|
| 142 |
+
"degraded_prob": degraded_prob,
|
| 143 |
+
"route_source": route_source,
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
return result
|
| 147 |
+
|
| 148 |
+
def infer_with_lora(self, audio: Any, **kwargs: Any) -> Any:
|
| 149 |
+
self._set_lora(True)
|
| 150 |
+
return self.asr.infer(self._unwrap_audio(audio), **kwargs)
|
| 151 |
+
|
| 152 |
+
def infer_without_lora(self, audio: Any, **kwargs: Any) -> Any:
|
| 153 |
+
self._set_lora(False)
|
| 154 |
+
return self.asr.infer(self._unwrap_audio(audio), **kwargs)
|
| 155 |
+
|
| 156 |
+
@torch.no_grad()
|
| 157 |
+
def batch_infer(self, audios: list[Any], **kwargs: Any) -> list[Any]:
|
| 158 |
+
audio_paths = [self._unwrap_audio(audio) for audio in audios]
|
| 159 |
+
routes = [self._route(audio) for audio in audio_paths]
|
| 160 |
+
|
| 161 |
+
base_indices = [idx for idx, route in enumerate(routes) if not route[0]]
|
| 162 |
+
lora_indices = [idx for idx, route in enumerate(routes) if route[0]]
|
| 163 |
+
|
| 164 |
+
results: list[Any] = [None] * len(audio_paths)
|
| 165 |
+
groups = [("lora", lora_indices), ("base", base_indices)]
|
| 166 |
+
if not self.lora_switch.active:
|
| 167 |
+
groups = [("base", base_indices), ("lora", lora_indices)]
|
| 168 |
+
|
| 169 |
+
for mode, indices in groups:
|
| 170 |
+
if not indices:
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
use_lora = mode == "lora"
|
| 174 |
+
self._set_lora(use_lora)
|
| 175 |
+
|
| 176 |
+
for idx in indices:
|
| 177 |
+
results[idx] = self.asr.infer(audio_paths[idx], **kwargs)
|
| 178 |
+
if use_lora:
|
| 179 |
+
self.stats["use_lora"] += 1
|
| 180 |
+
else:
|
| 181 |
+
self.stats["use_base"] += 1
|
| 182 |
+
|
| 183 |
+
self.stats["total"] += len(audio_paths)
|
| 184 |
+
return results
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def get_mega_asr(*args: Any, **kwargs: Any) -> MegaASR:
|
| 188 |
+
return MegaASR(*args, **kwargs)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_Mega_ASR(*args: Any, **kwargs: Any) -> MegaASR:
|
| 192 |
+
return get_mega_asr(*args, **kwargs)
|
vendor/MegaASR/model/router.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import soundfile as sf
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from safetensors.torch import load_file as safe_load_file
|
| 14 |
+
from safetensors.torch import safe_open
|
| 15 |
+
from scipy.signal import resample_poly
|
| 16 |
+
|
| 17 |
+
from .utils.audio_quality import LogMelSpectrogram, create_audio_quality_model
|
| 18 |
+
|
| 19 |
+
class AudioQualityRouter:
|
| 20 |
+
DEFAULT_CHECKPOINT = "ckpt/Mega-ASR/audio_quality_router/best_acc_model.safetensors"
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
checkpoint_path: str | os.PathLike[str] | None = None,
|
| 25 |
+
*,
|
| 26 |
+
device: str | None = None,
|
| 27 |
+
threshold: float = 0.5,
|
| 28 |
+
sample_rate: int = 16000,
|
| 29 |
+
) -> None:
|
| 30 |
+
self.checkpoint_path = str(
|
| 31 |
+
Path(checkpoint_path or self.DEFAULT_CHECKPOINT).expanduser()
|
| 32 |
+
)
|
| 33 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 34 |
+
self.threshold = threshold
|
| 35 |
+
self.sample_rate = sample_rate
|
| 36 |
+
|
| 37 |
+
self.model, self.mel_extractor = self._load_model()
|
| 38 |
+
|
| 39 |
+
def _load_model(self) -> tuple[torch.nn.Module, torch.nn.Module]:
|
| 40 |
+
checkpoint_path = Path(self.checkpoint_path)
|
| 41 |
+
if checkpoint_path.suffix == ".safetensors":
|
| 42 |
+
with safe_open(str(checkpoint_path), framework="pt", device="cpu") as f:
|
| 43 |
+
metadata = f.metadata()
|
| 44 |
+
checkpoint_config = json.loads(metadata.get("config", "{}"))
|
| 45 |
+
config = checkpoint_config.get("model", {})
|
| 46 |
+
state_dict = safe_load_file(str(checkpoint_path), device=self.device)
|
| 47 |
+
else:
|
| 48 |
+
checkpoint = torch.load(
|
| 49 |
+
self.checkpoint_path,
|
| 50 |
+
map_location=self.device,
|
| 51 |
+
weights_only=False,
|
| 52 |
+
)
|
| 53 |
+
config = checkpoint.get("config", {}).get("model", {})
|
| 54 |
+
state_dict = checkpoint["model_state_dict"]
|
| 55 |
+
|
| 56 |
+
model = create_audio_quality_model(config)
|
| 57 |
+
model.load_state_dict(state_dict)
|
| 58 |
+
model.to(self.device)
|
| 59 |
+
model.eval()
|
| 60 |
+
|
| 61 |
+
mel_extractor = LogMelSpectrogram(
|
| 62 |
+
sample_rate=self.sample_rate,
|
| 63 |
+
n_mels=config.get("n_mels", 80),
|
| 64 |
+
).to(self.device)
|
| 65 |
+
mel_extractor.eval()
|
| 66 |
+
|
| 67 |
+
return model, mel_extractor
|
| 68 |
+
|
| 69 |
+
def _load_audio(self, audio_path: str | os.PathLike[str]) -> torch.Tensor:
|
| 70 |
+
audio_np, sr = sf.read(str(audio_path), always_2d=True)
|
| 71 |
+
audio_np = audio_np.mean(axis=1)
|
| 72 |
+
|
| 73 |
+
if sr != self.sample_rate:
|
| 74 |
+
gcd = math.gcd(sr, self.sample_rate)
|
| 75 |
+
audio_np = resample_poly(
|
| 76 |
+
audio_np,
|
| 77 |
+
self.sample_rate // gcd,
|
| 78 |
+
sr // gcd,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
waveform = torch.from_numpy(audio_np).float().unsqueeze(0)
|
| 82 |
+
|
| 83 |
+
return waveform.to(self.device)
|
| 84 |
+
|
| 85 |
+
@torch.no_grad()
|
| 86 |
+
def infer(self, audio_path: str | os.PathLike[str]) -> dict[str, Any]:
|
| 87 |
+
waveform = self._load_audio(audio_path)
|
| 88 |
+
mel = self.mel_extractor(waveform)
|
| 89 |
+
mel = mel.squeeze(0).transpose(0, 1).unsqueeze(0)
|
| 90 |
+
|
| 91 |
+
logits = self.model(mel, mask=None)
|
| 92 |
+
probs = torch.softmax(logits, dim=-1)
|
| 93 |
+
degraded_prob = float(probs[0, 1].item())
|
| 94 |
+
is_degraded = degraded_prob >= self.threshold
|
| 95 |
+
|
| 96 |
+
return {
|
| 97 |
+
"is_degraded": is_degraded,
|
| 98 |
+
"degraded_prob": degraded_prob,
|
| 99 |
+
"label": int(is_degraded),
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
def predict(self, audio_path: str | os.PathLike[str]) -> tuple[bool, float]:
|
| 103 |
+
result = self.infer(audio_path)
|
| 104 |
+
return result["is_degraded"], result["degraded_prob"]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_router(*args: Any, **kwargs: Any) -> AudioQualityRouter:
|
| 108 |
+
return AudioQualityRouter(*args, **kwargs)
|
vendor/MegaASR/model/utils/__init__.py
ADDED
|
File without changes
|
vendor/MegaASR/model/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (170 Bytes). View file
|
|
|
vendor/MegaASR/model/utils/__pycache__/audio_quality.cpython-311.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
vendor/MegaASR/model/utils/__pycache__/lora_switch.cpython-311.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
vendor/MegaASR/model/utils/audio_quality.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torchaudio
|
| 8 |
+
|
| 9 |
+
class LogMelSpectrogram(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
sample_rate: int = 16000,
|
| 13 |
+
n_mels: int = 80,
|
| 14 |
+
n_fft: int = 400,
|
| 15 |
+
hop_length: int = 160,
|
| 16 |
+
win_length: int = 400,
|
| 17 |
+
) -> None:
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.mel_transform = torchaudio.transforms.MelSpectrogram(
|
| 20 |
+
sample_rate=sample_rate,
|
| 21 |
+
n_fft=n_fft,
|
| 22 |
+
hop_length=hop_length,
|
| 23 |
+
win_length=win_length,
|
| 24 |
+
n_mels=n_mels,
|
| 25 |
+
norm="slaney",
|
| 26 |
+
mel_scale="slaney",
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
|
| 30 |
+
mel = self.mel_transform(waveform)
|
| 31 |
+
log_mel = torch.clamp(mel, min=1e-10).log10()
|
| 32 |
+
return (log_mel + 4.0) / 4.0
|
| 33 |
+
|
| 34 |
+
class PositionalEncoding(nn.Module):
|
| 35 |
+
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1) -> None:
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 38 |
+
|
| 39 |
+
pe = torch.zeros(max_len, d_model)
|
| 40 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 41 |
+
div_term = torch.exp(
|
| 42 |
+
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 46 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 47 |
+
|
| 48 |
+
self.register_buffer("pe", pe.unsqueeze(0))
|
| 49 |
+
|
| 50 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 51 |
+
x = x + self.pe[:, : x.size(1)]
|
| 52 |
+
return self.dropout(x)
|
| 53 |
+
|
| 54 |
+
class AttentionPooling(nn.Module):
|
| 55 |
+
def __init__(self, d_model: int) -> None:
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.query = nn.Linear(d_model, 1)
|
| 58 |
+
|
| 59 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 60 |
+
weights = self.query(x).squeeze(-1)
|
| 61 |
+
|
| 62 |
+
if mask is not None:
|
| 63 |
+
weights = weights.masked_fill(~mask, float("-inf"))
|
| 64 |
+
|
| 65 |
+
weights = F.softmax(weights, dim=-1)
|
| 66 |
+
return torch.bmm(weights.unsqueeze(1), x).squeeze(1)
|
| 67 |
+
|
| 68 |
+
class ConvFrontend(nn.Module):
|
| 69 |
+
def __init__(self, n_mels: int, d_model: int, dropout: float = 0.1) -> None:
|
| 70 |
+
super().__init__()
|
| 71 |
+
|
| 72 |
+
self.conv = nn.Sequential(
|
| 73 |
+
nn.Conv1d(n_mels, d_model // 2, kernel_size=3, stride=2, padding=1),
|
| 74 |
+
nn.BatchNorm1d(d_model // 2),
|
| 75 |
+
nn.GELU(),
|
| 76 |
+
nn.Dropout(dropout),
|
| 77 |
+
nn.Conv1d(d_model // 2, d_model, kernel_size=3, stride=2, padding=1),
|
| 78 |
+
nn.BatchNorm1d(d_model),
|
| 79 |
+
nn.GELU(),
|
| 80 |
+
nn.Dropout(dropout),
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 84 |
+
x = x.transpose(1, 2)
|
| 85 |
+
x = self.conv(x)
|
| 86 |
+
return x.transpose(1, 2)
|
| 87 |
+
|
| 88 |
+
class AudioQualityClassifier(nn.Module):
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
n_mels: int = 80,
|
| 92 |
+
d_model: int = 192,
|
| 93 |
+
nhead: int = 4,
|
| 94 |
+
dim_feedforward: int = 512,
|
| 95 |
+
dropout: float = 0.1,
|
| 96 |
+
max_len: int = 3000,
|
| 97 |
+
num_classes: int = 2,
|
| 98 |
+
) -> None:
|
| 99 |
+
super().__init__()
|
| 100 |
+
|
| 101 |
+
self.downsample_rate = 4
|
| 102 |
+
self.frontend = ConvFrontend(n_mels, d_model, dropout)
|
| 103 |
+
self.pos_encoder = PositionalEncoding(d_model, max_len // 4 + 100, dropout)
|
| 104 |
+
|
| 105 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 106 |
+
d_model=d_model,
|
| 107 |
+
nhead=nhead,
|
| 108 |
+
dim_feedforward=dim_feedforward,
|
| 109 |
+
dropout=dropout,
|
| 110 |
+
activation="gelu",
|
| 111 |
+
batch_first=True,
|
| 112 |
+
norm_first=True,
|
| 113 |
+
)
|
| 114 |
+
self.transformer = nn.TransformerEncoder(
|
| 115 |
+
encoder_layer,
|
| 116 |
+
num_layers=1,
|
| 117 |
+
norm=nn.LayerNorm(d_model),
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
self.pooling = AttentionPooling(d_model)
|
| 121 |
+
self.classifier = nn.Sequential(
|
| 122 |
+
nn.Linear(d_model, d_model // 2),
|
| 123 |
+
nn.GELU(),
|
| 124 |
+
nn.Dropout(dropout),
|
| 125 |
+
nn.Linear(d_model // 2, num_classes),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
self._init_weights()
|
| 129 |
+
|
| 130 |
+
def _init_weights(self) -> None:
|
| 131 |
+
for module in self.modules():
|
| 132 |
+
if isinstance(module, nn.Linear):
|
| 133 |
+
nn.init.trunc_normal_(module.weight, std=0.02)
|
| 134 |
+
if module.bias is not None:
|
| 135 |
+
nn.init.zeros_(module.bias)
|
| 136 |
+
elif isinstance(module, nn.Conv1d):
|
| 137 |
+
nn.init.kaiming_normal_(
|
| 138 |
+
module.weight,
|
| 139 |
+
mode="fan_out",
|
| 140 |
+
nonlinearity="relu",
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
def forward(
|
| 144 |
+
self,
|
| 145 |
+
mels: torch.Tensor,
|
| 146 |
+
mask: torch.Tensor | None = None,
|
| 147 |
+
) -> torch.Tensor:
|
| 148 |
+
x = self.frontend(mels)
|
| 149 |
+
time_steps = x.shape[1]
|
| 150 |
+
|
| 151 |
+
if mask is not None:
|
| 152 |
+
mask = mask[:, :: self.downsample_rate]
|
| 153 |
+
if mask.shape[1] > time_steps:
|
| 154 |
+
mask = mask[:, :time_steps]
|
| 155 |
+
elif mask.shape[1] < time_steps:
|
| 156 |
+
pad = torch.ones(
|
| 157 |
+
mask.shape[0],
|
| 158 |
+
time_steps - mask.shape[1],
|
| 159 |
+
device=mask.device,
|
| 160 |
+
dtype=mask.dtype,
|
| 161 |
+
)
|
| 162 |
+
mask = torch.cat([mask, pad], dim=1)
|
| 163 |
+
|
| 164 |
+
x = self.pos_encoder(x)
|
| 165 |
+
src_key_padding_mask = ~mask if mask is not None else None
|
| 166 |
+
x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
|
| 167 |
+
x = self.pooling(x, mask)
|
| 168 |
+
return self.classifier(x)
|
| 169 |
+
|
| 170 |
+
def create_audio_quality_model(config: dict) -> nn.Module:
|
| 171 |
+
return AudioQualityClassifier(
|
| 172 |
+
n_mels=config.get("n_mels", 80),
|
| 173 |
+
d_model=config.get("d_model", 192),
|
| 174 |
+
nhead=config.get("nhead", 4),
|
| 175 |
+
dim_feedforward=config.get("dim_feedforward", 512),
|
| 176 |
+
dropout=config.get("dropout", 0.1),
|
| 177 |
+
max_len=config.get("max_len", 3000),
|
| 178 |
+
num_classes=config.get("num_classes", 2),
|
| 179 |
+
)
|
vendor/MegaASR/model/utils/lora_switch.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import warnings
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from safetensors.torch import load_file as safe_load_file
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LoRADeltaSwitch:
|
| 14 |
+
def __init__(self, keep_delta_on_gpu: bool = True) -> None:
|
| 15 |
+
self.keep_delta_on_gpu = keep_delta_on_gpu
|
| 16 |
+
self.items: list[dict[str, Any]] = []
|
| 17 |
+
self.active = False
|
| 18 |
+
|
| 19 |
+
def _load_adapter_state(self, adapter_dir: str | os.PathLike[str]) -> dict[str, torch.Tensor]:
|
| 20 |
+
adapter_dir = str(adapter_dir)
|
| 21 |
+
safetensors_path = os.path.join(adapter_dir, "adapter_model.safetensors")
|
| 22 |
+
bin_path = os.path.join(adapter_dir, "adapter_model.bin")
|
| 23 |
+
|
| 24 |
+
if os.path.exists(safetensors_path):
|
| 25 |
+
return safe_load_file(safetensors_path)
|
| 26 |
+
return torch.load(bin_path, map_location="cpu")
|
| 27 |
+
|
| 28 |
+
def _load_adapter_config(self, adapter_dir: str | os.PathLike[str]) -> dict[str, Any]:
|
| 29 |
+
config_path = os.path.join(str(adapter_dir), "adapter_config.json")
|
| 30 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
| 31 |
+
return json.load(f)
|
| 32 |
+
|
| 33 |
+
def _load_adapter_blocks(self, adapter_dir: str | os.PathLike[str]) -> dict[str, Any]:
|
| 34 |
+
blocks_path = os.path.join(str(adapter_dir), "mega_lora_blocks.json")
|
| 35 |
+
if not os.path.exists(blocks_path):
|
| 36 |
+
return {}
|
| 37 |
+
|
| 38 |
+
with open(blocks_path, "r", encoding="utf-8") as f:
|
| 39 |
+
return json.load(f)
|
| 40 |
+
|
| 41 |
+
@staticmethod
|
| 42 |
+
def _normalize_module_name(name: str) -> str:
|
| 43 |
+
for prefix in ("base_model.model.",):
|
| 44 |
+
if name.startswith(prefix):
|
| 45 |
+
name = name[len(prefix) :]
|
| 46 |
+
|
| 47 |
+
if name.startswith("thinker.layers."):
|
| 48 |
+
name = name.replace("thinker.layers.", "thinker.model.layers.", 1)
|
| 49 |
+
|
| 50 |
+
return name
|
| 51 |
+
|
| 52 |
+
@staticmethod
|
| 53 |
+
def _module_name_candidates(name: str) -> list[str]:
|
| 54 |
+
candidates = [name]
|
| 55 |
+
|
| 56 |
+
if name.startswith("model."):
|
| 57 |
+
candidates.append(name[len("model.") :])
|
| 58 |
+
|
| 59 |
+
if name.startswith("thinker.layers."):
|
| 60 |
+
candidates.append(name.replace("thinker.layers.", "thinker.model.layers.", 1))
|
| 61 |
+
|
| 62 |
+
if name.startswith("thinker.model."):
|
| 63 |
+
candidates.append(name.replace("thinker.model.", "thinker.", 1))
|
| 64 |
+
|
| 65 |
+
return list(dict.fromkeys(candidates))
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def _raw_module_name(key: str, marker: str) -> str:
|
| 69 |
+
name = key.split(marker)[0]
|
| 70 |
+
for prefix in ("base_model.model.", "model."):
|
| 71 |
+
if name.startswith(prefix):
|
| 72 |
+
return name[len(prefix) :]
|
| 73 |
+
return name
|
| 74 |
+
|
| 75 |
+
def _split_lora_key(self, key: str) -> tuple[str | None, str | None, str | None]:
|
| 76 |
+
raw_key = key
|
| 77 |
+
key = self._normalize_module_name(key)
|
| 78 |
+
|
| 79 |
+
for marker in (".lora_A.", ".lora_B."):
|
| 80 |
+
if marker in key:
|
| 81 |
+
module_name = key.split(marker)[0]
|
| 82 |
+
raw_module_name = self._raw_module_name(raw_key, marker)
|
| 83 |
+
kind = "A" if marker == ".lora_A." else "B"
|
| 84 |
+
return module_name, raw_module_name, kind
|
| 85 |
+
|
| 86 |
+
return None, None, None
|
| 87 |
+
|
| 88 |
+
def add_adapter(
|
| 89 |
+
self,
|
| 90 |
+
parent_module: torch.nn.Module,
|
| 91 |
+
adapter_dir: str | os.PathLike[str],
|
| 92 |
+
name: str,
|
| 93 |
+
strip_prefixes: list[str] | None = None,
|
| 94 |
+
) -> None:
|
| 95 |
+
config = self._load_adapter_config(adapter_dir)
|
| 96 |
+
state = self._load_adapter_state(adapter_dir)
|
| 97 |
+
blocks = self._load_adapter_blocks(adapter_dir)
|
| 98 |
+
|
| 99 |
+
lora_alpha = config.get("lora_alpha", 1)
|
| 100 |
+
rank = config.get("r")
|
| 101 |
+
alpha_pattern = config.get("alpha_pattern") or {}
|
| 102 |
+
rank_pattern = config.get("rank_pattern") or {}
|
| 103 |
+
fan_in_fan_out = bool(config.get("fan_in_fan_out", False))
|
| 104 |
+
|
| 105 |
+
module_dict = dict(parent_module.named_modules())
|
| 106 |
+
grouped: dict[str, dict[str, torch.Tensor]] = {}
|
| 107 |
+
|
| 108 |
+
for key, tensor in state.items():
|
| 109 |
+
module_name, raw_module_name, kind = self._split_lora_key(key)
|
| 110 |
+
if module_name is None or raw_module_name is None or kind is None:
|
| 111 |
+
continue
|
| 112 |
+
|
| 113 |
+
if strip_prefixes:
|
| 114 |
+
for prefix in strip_prefixes:
|
| 115 |
+
if module_name.startswith(prefix):
|
| 116 |
+
module_name = module_name[len(prefix) :]
|
| 117 |
+
if raw_module_name.startswith(prefix):
|
| 118 |
+
raw_module_name = raw_module_name[len(prefix) :]
|
| 119 |
+
|
| 120 |
+
matched_name = None
|
| 121 |
+
for candidate in self._module_name_candidates(module_name):
|
| 122 |
+
if candidate in module_dict:
|
| 123 |
+
matched_name = candidate
|
| 124 |
+
break
|
| 125 |
+
|
| 126 |
+
target_name = matched_name or module_name
|
| 127 |
+
group_key = f"{target_name}\0{raw_module_name}"
|
| 128 |
+
item = grouped.setdefault(
|
| 129 |
+
group_key,
|
| 130 |
+
{
|
| 131 |
+
"target_module_name": target_name,
|
| 132 |
+
"raw_module_name": raw_module_name,
|
| 133 |
+
},
|
| 134 |
+
)
|
| 135 |
+
item[kind] = tensor.cpu()
|
| 136 |
+
|
| 137 |
+
loaded = 0
|
| 138 |
+
missing = []
|
| 139 |
+
|
| 140 |
+
for pair in grouped.values():
|
| 141 |
+
if "A" not in pair or "B" not in pair:
|
| 142 |
+
continue
|
| 143 |
+
module_name = pair["target_module_name"]
|
| 144 |
+
raw_module_name = pair["raw_module_name"]
|
| 145 |
+
if module_name not in module_dict:
|
| 146 |
+
missing.append(module_name)
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
module = module_dict[module_name]
|
| 150 |
+
if not hasattr(module, "weight"):
|
| 151 |
+
missing.append(module_name)
|
| 152 |
+
continue
|
| 153 |
+
|
| 154 |
+
weight = module.weight
|
| 155 |
+
a_matrix = pair["A"].to(device=weight.device, dtype=torch.float32)
|
| 156 |
+
b_matrix = pair["B"].to(device=weight.device, dtype=torch.float32)
|
| 157 |
+
module_blocks = blocks.get(raw_module_name) or blocks.get(module_name)
|
| 158 |
+
|
| 159 |
+
if module_blocks:
|
| 160 |
+
deltas = []
|
| 161 |
+
for block in module_blocks:
|
| 162 |
+
start = int(block["start"])
|
| 163 |
+
end = int(block["end"])
|
| 164 |
+
block_rank = int(block.get("rank", end - start))
|
| 165 |
+
block_alpha = int(block.get("alpha", block_rank))
|
| 166 |
+
delta = torch.matmul(b_matrix[:, start:end], a_matrix[start:end])
|
| 167 |
+
delta = delta * (float(block_alpha) / float(block_rank))
|
| 168 |
+
if fan_in_fan_out:
|
| 169 |
+
delta = delta.T
|
| 170 |
+
deltas.append(delta)
|
| 171 |
+
else:
|
| 172 |
+
adapter_rank = rank_pattern.get(raw_module_name, rank_pattern.get(module_name, rank))
|
| 173 |
+
if adapter_rank is None:
|
| 174 |
+
adapter_rank = a_matrix.shape[0]
|
| 175 |
+
adapter_alpha = alpha_pattern.get(
|
| 176 |
+
raw_module_name,
|
| 177 |
+
alpha_pattern.get(module_name, lora_alpha),
|
| 178 |
+
)
|
| 179 |
+
scaling = float(adapter_alpha) / float(adapter_rank)
|
| 180 |
+
delta = torch.matmul(b_matrix, a_matrix) * scaling
|
| 181 |
+
if fan_in_fan_out:
|
| 182 |
+
delta = delta.T
|
| 183 |
+
deltas = [delta]
|
| 184 |
+
|
| 185 |
+
for delta in deltas:
|
| 186 |
+
if delta.shape != weight.shape:
|
| 187 |
+
try:
|
| 188 |
+
delta = delta.reshape(weight.shape)
|
| 189 |
+
except Exception:
|
| 190 |
+
missing.append(
|
| 191 |
+
f"{module_name}: delta shape {tuple(delta.shape)} != "
|
| 192 |
+
f"weight shape {tuple(weight.shape)}"
|
| 193 |
+
)
|
| 194 |
+
continue
|
| 195 |
+
|
| 196 |
+
delta = delta.to(dtype=weight.dtype)
|
| 197 |
+
if self.keep_delta_on_gpu:
|
| 198 |
+
delta = delta.to(device=weight.device)
|
| 199 |
+
else:
|
| 200 |
+
delta = delta.cpu()
|
| 201 |
+
|
| 202 |
+
self.items.append(
|
| 203 |
+
{
|
| 204 |
+
"name": name,
|
| 205 |
+
"module_name": module_name,
|
| 206 |
+
"weight": weight,
|
| 207 |
+
"delta": delta,
|
| 208 |
+
}
|
| 209 |
+
)
|
| 210 |
+
loaded += 1
|
| 211 |
+
|
| 212 |
+
if missing:
|
| 213 |
+
warnings.warn(
|
| 214 |
+
f"LoRA adapter {name} loaded {loaded} deltas, "
|
| 215 |
+
f"missing {len(missing)} modules. Examples: {missing[:5]}",
|
| 216 |
+
stacklevel=2,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
@torch.no_grad()
|
| 220 |
+
def set_active(self, active: bool) -> float:
|
| 221 |
+
if self.active == active:
|
| 222 |
+
return 0.0
|
| 223 |
+
|
| 224 |
+
start = time.perf_counter()
|
| 225 |
+
sign = 1.0 if active else -1.0
|
| 226 |
+
|
| 227 |
+
for item in self.items:
|
| 228 |
+
weight = item["weight"]
|
| 229 |
+
delta = item["delta"]
|
| 230 |
+
if delta.device != weight.device:
|
| 231 |
+
delta = delta.to(device=weight.device)
|
| 232 |
+
weight.data.add_(delta, alpha=sign)
|
| 233 |
+
|
| 234 |
+
self.active = active
|
| 235 |
+
return time.perf_counter() - start
|