# Page header captured with this file: Space status "Sleeping"; file size 13,000 bytes; revision 745f62a.
"""
Rigorous ASR + normalization test suite for Sakhi.
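Tests:
  1. Hindi number parser — edge cases, Whisper misspellings, compound numbers
  2. Compound word splits — merged forms such as "एकसो"
  3. Medical term normalization — all abbreviations
  4. Full normalize_transcript — real Whisper output patterns
  5. Edge cases — empty input, number position, partial-word matches
  6. Saved Whisper transcripts — normalize and check expected values
  7. Live ASR — transcribe test_audio/ files, verify medical values extracted
     (round-trip: ASR output → normalization → check expected values)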
Usage:
python scripts/test_asr.py
python scripts/test_asr.py --skip-gpu # skip live ASR tests (no GPU needed)
"""
import sys
import os
import argparse
import time
# NOTE(review): PYTHONIOENCODING is normally consulted at interpreter startup,
# so assigning it here presumably only influences child processes, not this
# process's own stdout — confirm; safe_print() below covers the console case.
os.environ["PYTHONIOENCODING"] = "utf-8"
# Put the parent of this file's directory first on sys.path so the
# `src.hindi_normalize` import below resolves when this file is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.hindi_normalize import (
parse_hindi_number, convert_numbers, normalize_transcript,
WORD_TO_NUM, MEDICAL_TERMS
)
# Module-level tallies: check() increments one of these per assertion and
# main() prints them and derives the process exit code from FAIL.
PASS = 0
FAIL = 0
def safe_print(s):
    """Print ``s``; if the output stream cannot encode it, print an ASCII rendering.

    The fallback replaces every non-ASCII character with ``?`` so that
    consoles with a narrow code page (e.g. Windows cp1252) still show a line.
    """
    try:
        print(s)
    except UnicodeEncodeError:
        # Lossy on purpose: a readable placeholder line beats a crash.
        ascii_bytes = s.encode('ascii', errors='replace')
        print(ascii_bytes.decode('ascii'))
def check(name, got, expected, exact=True):
    """Record one assertion in the module-level PASS/FAIL tallies.

    Args:
        name: Label printed when the check fails.
        got: Actual value produced by the code under test.
        expected: Value to compare against.
        exact: When true, require ``got == expected``. Otherwise require
            ``str(expected)`` to occur as a substring of ``str(got)``.

    Returns:
        True if the check passed, False otherwise. Existing callers that
        ignore the return value are unaffected.
    """
    global PASS, FAIL
    if exact:
        ok = got == expected
    else:
        # Coerce both sides so a non-string expectation (e.g. an int) does not
        # raise TypeError inside the substring test.
        ok = str(expected) in str(got)
    if ok:
        PASS += 1
        return True
    FAIL += 1
    safe_print(f" FAIL: {name}")
    safe_print(f" got: {got!r}")
    safe_print(f" expected: {expected!r}")
    return False
def test_number_parser():
    """Run parse_hindi_number over single words, misspellings and compounds.

    Every case is an ``(input text, expected integer)`` pair checked for
    exact equality through check().
    """
    print("\n=== 1. Number Parser ===")

    # Single number words in the 0-99 range.
    single_words = [
        ("शून्य", 0), ("एक", 1), ("दो", 2), ("तीन", 3), ("चार", 4),
        ("पांच", 5), ("पाँच", 5), ("छह", 6), ("सात", 7), ("आठ", 8),
        ("नौ", 9), ("दस", 10), ("ग्यारह", 11), ("बारह", 12), ("तेरह", 13),
        ("चौदह", 14), ("पंद्रह", 15), ("सोलह", 16), ("सत्रह", 17),
        ("अठारह", 18), ("उन्नीस", 19), ("बीस", 20), ("पच्चीस", 25),
        ("तीस", 30), ("पैंतीस", 35), ("चालीस", 40), ("पैंतालीस", 45),
        ("पचास", 50), ("पचपन", 55), ("अट्ठावन", 58), ("साठ", 60),
        ("पैंसठ", 65), ("सत्तर", 70), ("पचहत्तर", 75), ("अस्सी", 80),
        ("पचासी", 85), ("नब्बे", 90), ("पंचानवे", 95), ("निन्यानवे", 99),
    ]

    # Spellings Whisper tends to emit instead of the canonical forms.
    whisper_variants = [
        ("गयारह", 11), ("बारा", 12), ("पंद्रा", 15), ("इक्किस", 21),
        ("बाइस", 22), ("पचीस", 25), ("अठ्ठाईस", 28), ("बतीस", 32),
        ("चालिस", 40), ("अठावन", 58), ("सतर", 70), ("अठहतर", 78),
        ("उनासी", 79), ("अस्सि", 80), ("पचानवे", 95),
    ]

    # Multi-word numbers of 100 and above.
    hundreds = [
        ("सौ", 100), ("सो", 100),
        ("एक सौ", 100), ("एक सो", 100),
        ("एक सौ दस", 110), ("एक सौ बीस", 120),
        ("एक सौ पंद्रह", 115), ("एक सौ पच्चीस", 125),
        ("एक सौ तीस", 130), ("एक सौ पैंतीस", 135),
        ("एक सौ चालीस", 140), ("एक सौ पैंतालीस", 145),
        ("एक सौ पचास", 150), ("एक सौ पचपन", 155),
        ("एक सौ साठ", 160), ("एक सौ पैंसठ", 165),
        ("एक सौ सत्तर", 170), ("एक सौ अस्सी", 180),
        ("एक सौ नब्बे", 190),
        ("दो सौ", 200), ("तीन सौ", 300),
    ]

    # Values that commonly show up as blood-pressure readings.
    blood_pressure = [
        ("एक सौ दस", 110), ("एक सौ बीस", 120), ("एक सौ तीस", 130),
        ("एक सौ चालीस", 140), ("एक सौ पचास", 150), ("एक सौ साठ", 160),
    ]

    # The label becomes the prefix of each check name, e.g. "parse(एक)".
    labelled_groups = (
        ("parse", single_words),
        ("misspell", whisper_variants),
        ("compound", hundreds),
        ("bp", blood_pressure),
    )
    for label, pairs in labelled_groups:
        for text, value in pairs:
            check(f"{label}({text})", parse_hindi_number(text), value)
def test_compound_splits():
    """Check that merged number words such as 'एकसो' are still converted.

    Each case pairs an input string with the digit string that must appear
    somewhere in the output of convert_numbers().
    """
    print("\n=== 2. Compound Word Splits ===")
    cases = [
        ("एकसो दस", "110"),
        ("एकसो बीस", "120"),
        ("एकसो पचपन", "155"),
        ("दोसो", "200"),
    ]
    for inp, expected in cases:
        result = convert_numbers(inp)
        # Substring mode hands check() the converted text itself, so a failure
        # shows what was produced instead of a bare False/True pair.
        check(f"split({inp})", result, expected, exact=False)
def test_medical_terms():
    """Check that Hindi medical terms are rewritten to their short forms.

    Each Hindi spelling is embedded in ``"test … test"`` and run through
    normalize_transcript(); the expected replacement must appear in the
    result.
    """
    print("\n=== 3. Medical Terms ===")
    cases = [
        ("बीपी", "BP"), ("भीपी", "BP"), ("बी पी", "BP"), ("बी.पी.", "BP"),
        ("एचबी", "Hb"), ("हीमोग्लोबिन", "Hb"), ("एच बी", "Hb"),
        ("आईएफए", "IFA"), ("टीटी", "TT"), ("टी टी", "TT"),
        ("पीएचसी", "PHC"), ("पी एच सी", "PHC"),
        ("सीएचसी", "CHC"), ("बीसीजी", "BCG"), ("ओपीवी", "OPV"),
        ("किलो", "kg"), ("किलोग्राम", "kg"),
        ("बटा", "/"), ("दशमलव", "."),
    ]
    for hindi, expected in cases:
        result = normalize_transcript(f"test {hindi} test")
        # Substring mode so a failure prints the normalized text, not False/True.
        check(f"med({hindi}→{expected})", result, expected, exact=False)
def test_full_normalization():
    """Run normalize_transcript on realistic Whisper-style sentences.

    Each case pairs an input sentence with a substring that must appear in
    the normalized output.
    """
    print("\n=== 4. Full Normalization (realistic patterns) ===")
    cases = [
        # BP values
        ("आपका बीपी एक सौ दस बटा सत्तर है", "110/70"),
        ("बीपी एकसो पचपन बटा सौ", "155/100"),
        ("बी पी एक सौ बीस बटा अस्सी है", "120/80"),
        # Weight
        ("वजन अट्ठावन किलो है", "58 kg"),
        ("वजन पचास किलोग्राम", "50 kg"),
        # Hemoglobin with decimal
        ("एचबी ग्यारह दशमलव पांच है", "Hb 11.5"),
        ("हीमोग्लोबिन बारह दशमलव दो", "Hb 12.2"),
        # Gestational weeks
        ("लगभग चौबीस हफ्ते", "24 हफ्ते"),
        ("बत्तीस हफ्ते की", "32 हफ्ते"),
        # Temperature
        ("तापमान सौ दशमलव पांच डिग्री", "100.5"),
        # Mixed digits and words
        ("बीपी 110 बटा सत्तर", "110/70"),
        ("बीपी 110/सतर", "110/70"),
        # Whisper repetition fix
        ("हाँ हाँ हाँ हाँ हाँ ठीक है", "हाँ ठीक है"),
        # Sentence boundaries
        ("पहला। दूसरा। तीसरा", "पहला।\nदूसरा।\nतीसरा"),
        # Already-digit pass-through
        ("BP 120/80 है, weight 55 kg", "120/80"),
    ]
    for inp, expect_substr in cases:
        result = normalize_transcript(inp)
        # Substring mode so a failure prints the normalized sentence rather
        # than an uninformative False/True pair.
        check(f"norm({inp[:40]}...→{expect_substr})", result, expect_substr, exact=False)
def test_edge_cases():
    """Probe inputs that could trip up number conversion or normalization."""
    print("\n=== 5. Edge Cases ===")

    # Empty string in, empty string out.
    empty_out = normalize_transcript("")
    check("empty", empty_out, "")

    # Input that is nothing but a single number word.
    lone_number = convert_numbers("पच्चीस")
    check("only_num", "25" in lone_number, True)

    # Number word in the first position of the text.
    leading = normalize_transcript("पच्चीस हफ्ते")
    check("num_start", leading.startswith("25"), True)

    # Number word in the last position of the text.
    trailing = normalize_transcript("हफ्ते पच्चीस")
    check("num_end", trailing.endswith("25"), True)

    # Two numbers with ordinary words between them must both be converted.
    two_numbers = normalize_transcript("वजन पचास और उम्र तीस")
    both_present = "50" in two_numbers and "30" in two_numbers
    check("adjacent_nums", both_present, True)

    # A number word glued into a longer word must stay untouched:
    # "एकतरफ" should NOT become "1तरफ".
    glued = convert_numbers("एकतरफ")
    check("no_partial", "1" not in glued, True)

    # Three-digit value built from hundreds plus a two-digit word.
    check("nine_hundred", parse_hindi_number("नौ सौ निन्यानवे"), 999)
def test_real_whisper_transcripts():
    """Normalize saved Whisper transcripts and look for the expected values.

    Transcript files are looked up in the parent of this file's directory.
    A file that is not present is reported as SKIP and contributes no checks.
    """
    print("\n=== 6. Real Whisper Transcripts ===")
    test_files = {
        "postprocess_test.txt": ["110", "70"],  # BP should be found after norm
        "postprocess_test2.txt": ["110", "70"],  # Has "110/सतर"
        "pp_test.txt": ["110/70", "58 kg", "11.5"],  # Already clean
    }
    # abspath keeps the lookup independent of the current working directory
    # when __file__ is a relative path, matching the sys.path setup at the
    # top of the file.
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    for fname, expected_values in test_files.items():
        path = os.path.join(base_dir, fname)
        if not os.path.exists(path):
            print(f" SKIP: {fname} not found")
            continue
        # Context manager so the file handle is closed deterministically.
        with open(path, encoding="utf-8") as handle:
            raw = handle.read().strip()
        result = normalize_transcript(raw)
        for val in expected_values:
            # Substring mode: a failure prints the normalized transcript.
            check(f"file({fname}→{val})", result, val, exact=False)
def test_live_asr():
    """Transcribe the files in test_audio/ and check the normalized output.

    Prints SKIP and returns early when the test_audio directory is missing or
    holds no .mp3/.wav files. Otherwise loads a faster-whisper model on CUDA
    (a local CT2 model directory if present, else a model id), transcribes
    each file, normalizes the text, and checks the per-file expected values
    and terms against the normalized text.
    """
    print("\n=== 7. Live ASR (GPU required) ===")
    # abspath keeps the lookup independent of the current working directory
    # when __file__ is a relative path, matching the sys.path setup at the
    # top of the file.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    audio_dir = os.path.join(project_root, "test_audio")
    if not os.path.exists(audio_dir):
        print(" SKIP: test_audio/ not found")
        return
    # os.listdir order is arbitrary; sort so runs are reproducible.
    audio_files = sorted(
        f for f in os.listdir(audio_dir) if f.endswith((".mp3", ".wav"))
    )
    if not audio_files:
        print(" SKIP: no audio files found")
        return

    # Expected values per file; files not listed here are transcribed and
    # printed but contribute no checks.
    expectations = {
        "anc_normal.mp3": {
            "values": ["110", "70", "58", "11.5", "24"],
            "terms": ["BP", "kg", "Hb"],
        },
        "anc_danger.mp3": {
            "values": ["155", "100"],
            "terms": ["BP", "PHC"],
        },
    }

    # Imported here so the rest of the suite runs without faster_whisper.
    from faster_whisper import WhisperModel

    ct2_path = os.path.join(project_root, "models", "whisper-hindi-ct2")
    if os.path.exists(ct2_path):
        print(f" Loading CT2 model from {ct2_path}...")
        model = WhisperModel(ct2_path, device="cuda", compute_type="float16")
    else:
        print(" Loading collabora/whisper-large-v2-hindi from HuggingFace...")
        model = WhisperModel("collabora/whisper-large-v2-hindi", device="cuda", compute_type="float16")

    for fname in audio_files:
        fpath = os.path.join(audio_dir, fname)
        expect = expectations.get(fname, {"values": [], "terms": []})
        print(f"\n --- {fname} ---")

        t0 = time.perf_counter()
        segments, info = model.transcribe(fpath, language="hi", task="transcribe")
        # NOTE(review): segments is presumably a lazy generator, so the join
        # has to stay inside the timed region — confirm against faster-whisper.
        raw_text = " ".join(seg.text.strip() for seg in segments)
        asr_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        normalized = normalize_transcript(raw_text)
        norm_time = time.perf_counter() - t0

        safe_print(f" ASR time: {asr_time:.1f}s")
        safe_print(f" Norm time: {norm_time*1000:.0f}ms")
        safe_print(f" Raw: {raw_text[:150]}...")
        safe_print(f" Norm: {normalized[:150]}...")

        for val in expect["values"]:
            found_raw = val in raw_text
            found_norm = val in normalized
            # Label must agree with the check below: a value seen only in the
            # raw text did not survive normalization.
            if found_norm:
                status = "RAW+NORM" if found_raw else "NORM"
            else:
                status = "RAW-ONLY" if found_raw else "MISS"
            check(f"asr({fname}→{val})", found_norm, True)
            safe_print(f" {val}: {status}")
        for term in expect["terms"]:
            check(f"asr({fname}→{term})", term in normalized, True)
def main():
    """Parse CLI flags, run every test group, print totals, return an exit code.

    Returns:
        0 when no check failed, 1 otherwise.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--skip-gpu", action="store_true", help="Skip live ASR tests")
    options = cli.parse_args()

    rule = "=" * 60
    print(rule)
    print("SAKHI ASR + NORMALIZATION TEST SUITE")
    print(rule)

    # Groups that only exercise the normalizer; run in their numbered order.
    offline_groups = (
        test_number_parser,
        test_compound_splits,
        test_medical_terms,
        test_full_normalization,
        test_edge_cases,
        test_real_whisper_transcripts,
    )
    for run_group in offline_groups:
        run_group()

    if options.skip_gpu:
        print("\n=== 7. Live ASR — SKIPPED (--skip-gpu) ===")
    else:
        test_live_asr()

    print("\n" + rule)
    print(f"RESULTS: {PASS} passed, {FAIL} failed")
    print(rule)
    return 1 if FAIL != 0 else 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|