# File: sakhi/scripts/test_asr.py
# Tushar9802 — "HF Space deploy — initial" (commit 745f62a)
"""
Rigorous ASR + normalization test suite for Sakhi.
Tests:
1. Hindi number parser — edge cases, Whisper misspellings, compound numbers
2. Medical term normalization — all abbreviations
3. Full normalize_transcript — real Whisper output patterns
4. Live ASR — transcribe test_audio/ files, verify medical values extracted
5. Round-trip — ASR output → normalization → check expected values
Usage:
python scripts/test_asr.py
python scripts/test_asr.py --skip-gpu # skip live ASR tests (no GPU needed)
"""
import sys
import os
import argparse
import time
# Ask for UTF-8 on Python's standard streams, since this script prints Hindi text.
# NOTE(review): PYTHONIOENCODING is read at interpreter startup, so assigning it
# here probably only affects child processes, not this one — confirm.
os.environ["PYTHONIOENCODING"] = "utf-8"
# Put the parent of this script's directory first on sys.path so that the
# `src` package imported just below can be resolved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.hindi_normalize import (
parse_hindi_number, convert_numbers, normalize_transcript,
WORD_TO_NUM, MEDICAL_TERMS
)
# Running totals of passed / failed assertions. check() increments them and
# main() reports them and derives the process exit code from FAIL.
PASS = 0
FAIL = 0
def safe_print(s):
    """Print ``s``, degrading to plain ASCII if the console cannot encode it.

    Meant for consoles with a narrow encoding (e.g. Windows cp1252): when the
    normal ``print`` raises ``UnicodeEncodeError``, the text is printed again
    with every non-ASCII character replaced by ``?``.
    """
    try:
        print(s)
    except UnicodeEncodeError:
        ascii_bytes = s.encode('ascii', errors='replace')
        print(ascii_bytes.decode('ascii'))
def check(name, got, expected, exact=True):
    """Record one assertion in the module-level PASS / FAIL counters.

    Args:
        name: Label printed when the assertion fails.
        got: Value produced by the code under test.
        expected: Value to compare against.
        exact: When true, require ``got == expected``; otherwise require
            ``expected`` to be a substring of ``str(got)``.

    A failure prints the label together with the repr of both values;
    a success is silent.
    """
    global PASS, FAIL
    ok = (got == expected) if exact else (expected in str(got))
    if not ok:
        FAIL += 1
        for line in (
            f" FAIL: {name}",
            f" got: {got!r}",
            f" expected: {expected!r}",
        ):
            safe_print(line)
        return
    PASS += 1
def test_number_parser():
    """Run parse_hindi_number over single words, misspellings and compounds.

    Each table pairs a Hindi phrase with the integer it should parse to.
    The tables are checked in order, and each check is labelled
    ``<group>(<phrase>)``.
    """
    print("\n=== 1. Number Parser ===")

    # Singles (0-99)
    single_words = [
        ("शून्य", 0), ("एक", 1), ("दो", 2), ("तीन", 3), ("चार", 4),
        ("पांच", 5), ("पाँच", 5), ("छह", 6), ("सात", 7), ("आठ", 8),
        ("नौ", 9), ("दस", 10), ("ग्यारह", 11), ("बारह", 12), ("तेरह", 13),
        ("चौदह", 14), ("पंद्रह", 15), ("सोलह", 16), ("सत्रह", 17),
        ("अठारह", 18), ("उन्नीस", 19), ("बीस", 20), ("पच्चीस", 25),
        ("तीस", 30), ("पैंतीस", 35), ("चालीस", 40), ("पैंतालीस", 45),
        ("पचास", 50), ("पचपन", 55), ("अट्ठावन", 58), ("साठ", 60),
        ("पैंसठ", 65), ("सत्तर", 70), ("पचहत्तर", 75), ("अस्सी", 80),
        ("पचासी", 85), ("नब्बे", 90), ("पंचानवे", 95), ("निन्यानवे", 99),
    ]

    # Whisper misspellings
    whisper_variants = [
        ("गयारह", 11), ("बारा", 12), ("पंद्रा", 15), ("इक्किस", 21),
        ("बाइस", 22), ("पचीस", 25), ("अठ्ठाईस", 28), ("बतीस", 32),
        ("चालिस", 40), ("अठावन", 58), ("सतर", 70), ("अठहतर", 78),
        ("उनासी", 79), ("अस्सि", 80), ("पचानवे", 95),
    ]

    # Compounds (100+)
    hundreds = [
        ("सौ", 100), ("सो", 100),
        ("एक सौ", 100), ("एक सो", 100),
        ("एक सौ दस", 110), ("एक सौ बीस", 120),
        ("एक सौ पंद्रह", 115), ("एक सौ पच्चीस", 125),
        ("एक सौ तीस", 130), ("एक सौ पैंतीस", 135),
        ("एक सौ चालीस", 140), ("एक सौ पैंतालीस", 145),
        ("एक सौ पचास", 150), ("एक सौ पचपन", 155),
        ("एक सौ साठ", 160), ("एक सौ पैंसठ", 165),
        ("एक सौ सत्तर", 170), ("एक सौ अस्सी", 180),
        ("एक सौ नब्बे", 190),
        ("दो सौ", 200), ("तीन सौ", 300),
    ]

    # Common BP values
    blood_pressure = [
        ("एक सौ दस", 110), ("एक सौ बीस", 120), ("एक सौ तीस", 130),
        ("एक सौ चालीस", 140), ("एक सौ पचास", 150), ("एक सौ साठ", 160),
    ]

    # One loop drives every table; the group label becomes the check prefix.
    groups = (
        ("parse", single_words),
        ("misspell", whisper_variants),
        ("compound", hundreds),
        ("bp", blood_pressure),
    )
    for label, table in groups:
        for phrase, value in table:
            check(f"{label}({phrase})", parse_hindi_number(phrase), value)
def test_compound_splits():
    """Verify convert_numbers copes with fused words such as 'एकसो'.

    Each input maps to the digit string that must appear somewhere in the
    converted output.
    """
    print("\n=== 2. Compound Word Splits ===")
    digits_by_input = {
        "एकसो दस": "110",
        "एकसो बीस": "120",
        "एकसो पचपन": "155",
        "दोसो": "200",
    }
    for text, digits in digits_by_input.items():
        converted = convert_numbers(text)
        check(f"split({text})", digits in converted, True)
def test_medical_terms():
    """Verify Hindi spellings of medical terms normalize to their short forms.

    Every spelling is embedded in ``"test <spelling> test"`` and passed to
    normalize_transcript; the expected abbreviation or symbol must appear
    in the result.
    """
    print("\n=== 3. Medical Terms ===")
    # Expected output -> Hindi spellings that should produce it.
    spellings_by_term = {
        "BP": ["बीपी", "भीपी", "बी पी", "बी.पी."],
        "Hb": ["एचबी", "हीमोग्लोबिन", "एच बी"],
        "IFA": ["आईएफए"],
        "TT": ["टीटी", "टी टी"],
        "PHC": ["पीएचसी", "पी एच सी"],
        "CHC": ["सीएचसी"],
        "BCG": ["बीसीजी"],
        "OPV": ["ओपीवी"],
        "kg": ["किलो", "किलोग्राम"],
        "/": ["बटा"],
        ".": ["दशमलव"],
    }
    for expected, spellings in spellings_by_term.items():
        for hindi in spellings:
            normalized = normalize_transcript(f"test {hindi} test")
            check(f"med({hindi}{expected})", expected in normalized, True)
def test_full_normalization():
    """Run normalize_transcript on sentences shaped like real Whisper output.

    Each sample pairs an input sentence with a substring that must be
    present in the normalized text.
    """
    print("\n=== 4. Full Normalization (realistic patterns) ===")
    samples = [
        # BP values
        ("आपका बीपी एक सौ दस बटा सत्तर है", "110/70"),
        ("बीपी एकसो पचपन बटा सौ", "155/100"),
        ("बी पी एक सौ बीस बटा अस्सी है", "120/80"),
        # Weight
        ("वजन अट्ठावन किलो है", "58 kg"),
        ("वजन पचास किलोग्राम", "50 kg"),
        # Hemoglobin with decimal
        ("एचबी ग्यारह दशमलव पांच है", "Hb 11.5"),
        ("हीमोग्लोबिन बारह दशमलव दो", "Hb 12.2"),
        # Gestational weeks
        ("लगभग चौबीस हफ्ते", "24 हफ्ते"),
        ("बत्तीस हफ्ते की", "32 हफ्ते"),
        # Temperature
        ("तापमान सौ दशमलव पांच डिग्री", "100.5"),
        # Mixed digits and words
        ("बीपी 110 बटा सत्तर", "110/70"),
        ("बीपी 110/सतर", "110/70"),
        # Whisper repetition fix
        ("हाँ हाँ हाँ हाँ हाँ ठीक है", "हाँ ठीक है"),
        # Sentence boundaries
        ("पहला। दूसरा। तीसरा", "पहला।\nदूसरा।\nतीसरा"),
        # Already-digit pass-through
        ("BP 120/80 है, weight 55 kg", "120/80"),
    ]
    for sentence, wanted in samples:
        normalized = normalize_transcript(sentence)
        # Only the first 40 characters of the input go into the label.
        label = f"norm({sentence[:40]}...→{wanted})"
        check(label, wanted in normalized, True)
def test_edge_cases():
    """Probe inputs that could trip up the parser or the normalizer."""
    print("\n=== 5. Edge Cases ===")

    # Empty input must come back empty.
    empty_out = normalize_transcript("")
    check("empty", empty_out, "")

    # Text consisting only of a number word.
    lone_number = convert_numbers("पच्चीस")
    check("only_num", "25" in lone_number, True)

    # Number as the first token.
    leading = normalize_transcript("पच्चीस हफ्ते")
    check("num_start", leading.startswith("25"), True)

    # Number as the last token.
    trailing = normalize_transcript("हफ्ते पच्चीस")
    check("num_end", trailing.endswith("25"), True)

    # Two numbers with ordinary words between them.
    two_numbers = normalize_transcript("वजन पचास और उम्र तीस")
    check("adjacent_nums", all(d in two_numbers for d in ("50", "30")), True)

    # A number word embedded in a longer Hindi word must be left alone:
    # "एकतरफ" should NOT become "1तरफ".
    embedded = convert_numbers("एकतरफ")
    check("no_partial", not ("1" in embedded), True)

    # Three-digit value built from hundreds plus a two-digit word.
    largest = parse_hindi_number("नौ सौ निन्यानवे")
    check("nine_hundred", largest, 999)
def test_real_whisper_transcripts():
    """Normalize saved Whisper transcripts and look for the expected values.

    Transcript files are looked up in the parent of this script's directory.
    A missing file is reported as SKIP and does not count as a failure.
    For each file that exists, every listed value must appear as a
    substring of the normalized text.
    """
    print("\n=== 6. Real Whisper Transcripts ===")
    test_files = {
        "postprocess_test.txt": ["110", "70"],  # BP should be found after norm
        "postprocess_test2.txt": ["110", "70"],  # Has "110/सतर"
        "pp_test.txt": ["110/70", "58 kg", "11.5"],  # Already clean
    }
    # abspath keeps the lookup correct when __file__ is a relative path
    # (same approach as the sys.path setup at the top of the file).
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    for fname, expected_values in test_files.items():
        path = os.path.join(project_root, fname)
        if not os.path.exists(path):
            print(f" SKIP: {fname} not found")
            continue
        # Context manager so the file handle is closed deterministically.
        with open(path, encoding="utf-8") as fh:
            raw = fh.read().strip()
        result = normalize_transcript(raw)
        for val in expected_values:
            check(f"file({fname}{val})", val in result, True)
def test_live_asr():
    """Transcribe the files in test_audio/ and check the normalized output.

    The test is skipped (with a message) when test_audio/ is missing or
    holds no .mp3/.wav files. Otherwise a faster-whisper model is loaded on
    CUDA: the local CT2 model under models/whisper-hindi-ct2 if that path
    exists, else "collabora/whisper-large-v2-hindi". Each audio file is
    transcribed, normalized, and checked against the values and terms
    listed for it in ``expectations``. Files without an entry are still
    transcribed and printed, but nothing is asserted for them.
    """
    print("\n=== 7. Live ASR (GPU required) ===")
    # abspath keeps the lookup correct when __file__ is a relative path
    # (same approach as the sys.path setup at the top of the file).
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    audio_dir = os.path.join(project_root, "test_audio")
    if not os.path.isdir(audio_dir):
        print(" SKIP: test_audio/ not found")
        return
    # Sorted so the run order does not depend on the directory listing order.
    audio_files = sorted(
        f for f in os.listdir(audio_dir) if f.endswith((".mp3", ".wav"))
    )
    if not audio_files:
        print(" SKIP: no audio files found")
        return

    # Expected values per file
    expectations = {
        "anc_normal.mp3": {
            "values": ["110", "70", "58", "11.5", "24"],
            "terms": ["BP", "kg", "Hb"],
        },
        "anc_danger.mp3": {
            "values": ["155", "100"],
            "terms": ["BP", "PHC"],
        },
    }

    # Imported here so the CPU-only tests do not need faster_whisper installed.
    from faster_whisper import WhisperModel

    ct2_path = os.path.join(project_root, "models", "whisper-hindi-ct2")
    if os.path.exists(ct2_path):
        print(f" Loading CT2 model from {ct2_path}...")
        model = WhisperModel(ct2_path, device="cuda", compute_type="float16")
    else:
        print(" Loading collabora/whisper-large-v2-hindi from HuggingFace...")
        model = WhisperModel(
            "collabora/whisper-large-v2-hindi",
            device="cuda",
            compute_type="float16",
        )

    for fname in audio_files:
        fpath = os.path.join(audio_dir, fname)
        expect = expectations.get(fname, {"values": [], "terms": []})
        print(f"\n --- {fname} ---")

        t0 = time.perf_counter()
        segments, _ = model.transcribe(fpath, language="hi", task="transcribe")
        # The segments are joined before the timer is read.
        # NOTE(review): faster-whisper documents them as a lazy generator,
        # so decoding would happen during this join — confirm.
        raw_text = " ".join(seg.text.strip() for seg in segments)
        asr_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        normalized = normalize_transcript(raw_text)
        norm_time = time.perf_counter() - t0

        safe_print(f" ASR time: {asr_time:.1f}s")
        safe_print(f" Norm time: {norm_time*1000:.0f}ms")
        safe_print(f" Raw: {raw_text[:150]}...")
        safe_print(f" Norm: {normalized[:150]}...")

        for val in expect["values"]:
            found_raw = val in raw_text
            found_norm = val in normalized
            # Report where the value was seen. A value present only in the
            # raw text was lost during normalization, so it must not be
            # labelled as if normalization had kept it.
            if found_norm:
                status = "RAW+NORM" if found_raw else "NORM"
            else:
                status = "RAW ONLY" if found_raw else "MISS"
            check(f"asr({fname}{val})", found_norm, True)
            safe_print(f" {val}: {status}")
        for term in expect["terms"]:
            check(f"asr({fname}{term})", term in normalized, True)
def main():
    """Run the whole suite and return a process exit code.

    The CPU-only test groups always run. The live ASR group runs unless
    ``--skip-gpu`` is given on the command line.

    Returns:
        0 when no check failed, 1 otherwise.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--skip-gpu", action="store_true", help="Skip live ASR tests")
    options = cli.parse_args()

    rule = "=" * 60
    print(rule)
    print("SAKHI ASR + NORMALIZATION TEST SUITE")
    print(rule)

    cpu_suites = (
        test_number_parser,
        test_compound_splits,
        test_medical_terms,
        test_full_normalization,
        test_edge_cases,
        test_real_whisper_transcripts,
    )
    for run_suite in cpu_suites:
        run_suite()

    if options.skip_gpu:
        print("\n=== 7. Live ASR — SKIPPED (--skip-gpu) ===")
    else:
        test_live_asr()

    print("\n" + rule)
    print(f"RESULTS: {PASS} passed, {FAIL} failed")
    print(rule)
    return 1 if FAIL else 0


if __name__ == "__main__":
    sys.exit(main())