Spaces:
Sleeping
Sleeping
| """ | |
| Rigorous ASR + normalization test suite for Sakhi. | |
| Tests: | |
| 1. Hindi number parser — edge cases, Whisper misspellings, compound numbers | |
| 2. Medical term normalization — all abbreviations | |
| 3. Full normalize_transcript — real Whisper output patterns | |
| 4. Live ASR — transcribe test_audio/ files, verify medical values extracted | |
| 5. Round-trip — ASR output → normalization → check expected values | |
| Usage: | |
| python scripts/test_asr.py | |
| python scripts/test_asr.py --skip-gpu # skip live ASR tests (no GPU needed) | |
| """ | |
| import sys | |
| import os | |
| import argparse | |
| import time | |
| os.environ["PYTHONIOENCODING"] = "utf-8" | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from src.hindi_normalize import ( | |
| parse_hindi_number, convert_numbers, normalize_transcript, | |
| WORD_TO_NUM, MEDICAL_TERMS | |
| ) | |
| PASS = 0 | |
| FAIL = 0 | |
| def safe_print(s): | |
| """Print with fallback for Windows cp1252 encoding.""" | |
| try: | |
| print(s) | |
| except UnicodeEncodeError: | |
| print(s.encode('ascii', errors='replace').decode('ascii')) | |
| def check(name, got, expected, exact=True): | |
| global PASS, FAIL | |
| if exact: | |
| ok = got == expected | |
| else: | |
| # substring check | |
| ok = expected in str(got) | |
| if ok: | |
| PASS += 1 | |
| else: | |
| FAIL += 1 | |
| safe_print(f" FAIL: {name}") | |
| safe_print(f" got: {got!r}") | |
| safe_print(f" expected: {expected!r}") | |
| def test_number_parser(): | |
| """Test parse_hindi_number on individual and compound numbers.""" | |
| print("\n=== 1. Number Parser ===") | |
| # Singles (0-99) | |
| singles = [ | |
| ("शून्य", 0), ("एक", 1), ("दो", 2), ("तीन", 3), ("चार", 4), | |
| ("पांच", 5), ("पाँच", 5), ("छह", 6), ("सात", 7), ("आठ", 8), | |
| ("नौ", 9), ("दस", 10), ("ग्यारह", 11), ("बारह", 12), ("तेरह", 13), | |
| ("चौदह", 14), ("पंद्रह", 15), ("सोलह", 16), ("सत्रह", 17), | |
| ("अठारह", 18), ("उन्नीस", 19), ("बीस", 20), ("पच्चीस", 25), | |
| ("तीस", 30), ("पैंतीस", 35), ("चालीस", 40), ("पैंतालीस", 45), | |
| ("पचास", 50), ("पचपन", 55), ("अट्ठावन", 58), ("साठ", 60), | |
| ("पैंसठ", 65), ("सत्तर", 70), ("पचहत्तर", 75), ("अस्सी", 80), | |
| ("पचासी", 85), ("नब्बे", 90), ("पंचानवे", 95), ("निन्यानवे", 99), | |
| ] | |
| for word, expected in singles: | |
| check(f"parse({word})", parse_hindi_number(word), expected) | |
| # Whisper misspellings | |
| misspellings = [ | |
| ("गयारह", 11), ("बारा", 12), ("पंद्रा", 15), ("इक्किस", 21), | |
| ("बाइस", 22), ("पचीस", 25), ("अठ्ठाईस", 28), ("बतीस", 32), | |
| ("चालिस", 40), ("अठावन", 58), ("सतर", 70), ("अठहतर", 78), | |
| ("उनासी", 79), ("अस्सि", 80), ("पचानवे", 95), | |
| ] | |
| for word, expected in misspellings: | |
| check(f"misspell({word})", parse_hindi_number(word), expected) | |
| # Compounds (100+) | |
| compounds = [ | |
| ("सौ", 100), ("सो", 100), | |
| ("एक सौ", 100), ("एक सो", 100), | |
| ("एक सौ दस", 110), ("एक सौ बीस", 120), | |
| ("एक सौ पंद्रह", 115), ("एक सौ पच्चीस", 125), | |
| ("एक सौ तीस", 130), ("एक सौ पैंतीस", 135), | |
| ("एक सौ चालीस", 140), ("एक सौ पैंतालीस", 145), | |
| ("एक सौ पचास", 150), ("एक सौ पचपन", 155), | |
| ("एक सौ साठ", 160), ("एक सौ पैंसठ", 165), | |
| ("एक सौ सत्तर", 170), ("एक सौ अस्सी", 180), | |
| ("एक सौ नब्बे", 190), | |
| ("दो सौ", 200), ("तीन सौ", 300), | |
| ] | |
| for text, expected in compounds: | |
| check(f"compound({text})", parse_hindi_number(text), expected) | |
| # Common BP values | |
| bp_values = [ | |
| ("एक सौ दस", 110), ("एक सौ बीस", 120), ("एक सौ तीस", 130), | |
| ("एक सौ चालीस", 140), ("एक सौ पचास", 150), ("एक सौ साठ", 160), | |
| ] | |
| for text, expected in bp_values: | |
| check(f"bp({text})", parse_hindi_number(text), expected) | |
| def test_compound_splits(): | |
| """Test that merged words like 'एकसो' are handled.""" | |
| print("\n=== 2. Compound Word Splits ===") | |
| cases = [ | |
| ("एकसो दस", "110"), | |
| ("एकसो बीस", "120"), | |
| ("एकसो पचपन", "155"), | |
| ("दोसो", "200"), | |
| ] | |
| for inp, expected in cases: | |
| result = convert_numbers(inp) | |
| check(f"split({inp})", expected in result, True) | |
| def test_medical_terms(): | |
| """Test all medical abbreviation conversions.""" | |
| print("\n=== 3. Medical Terms ===") | |
| cases = [ | |
| ("बीपी", "BP"), ("भीपी", "BP"), ("बी पी", "BP"), ("बी.पी.", "BP"), | |
| ("एचबी", "Hb"), ("हीमोग्लोबिन", "Hb"), ("एच बी", "Hb"), | |
| ("आईएफए", "IFA"), ("टीटी", "TT"), ("टी टी", "TT"), | |
| ("पीएचसी", "PHC"), ("पी एच सी", "PHC"), | |
| ("सीएचसी", "CHC"), ("बीसीजी", "BCG"), ("ओपीवी", "OPV"), | |
| ("किलो", "kg"), ("किलोग्राम", "kg"), | |
| ("बटा", "/"), ("दशमलव", "."), | |
| ] | |
| for hindi, expected in cases: | |
| result = normalize_transcript(f"test {hindi} test") | |
| check(f"med({hindi}→{expected})", expected in result, True) | |
| def test_full_normalization(): | |
| """Test normalize_transcript on realistic Whisper output patterns.""" | |
| print("\n=== 4. Full Normalization (realistic patterns) ===") | |
| cases = [ | |
| # BP values | |
| ("आपका बीपी एक सौ दस बटा सत्तर है", "110/70"), | |
| ("बीपी एकसो पचपन बटा सौ", "155/100"), | |
| ("बी पी एक सौ बीस बटा अस्सी है", "120/80"), | |
| # Weight | |
| ("वजन अट्ठावन किलो है", "58 kg"), | |
| ("वजन पचास किलोग्राम", "50 kg"), | |
| # Hemoglobin with decimal | |
| ("एचबी ग्यारह दशमलव पांच है", "Hb 11.5"), | |
| ("हीमोग्लोबिन बारह दशमलव दो", "Hb 12.2"), | |
| # Gestational weeks | |
| ("लगभग चौबीस हफ्ते", "24 हफ्ते"), | |
| ("बत्तीस हफ्ते की", "32 हफ्ते"), | |
| # Temperature | |
| ("तापमान सौ दशमलव पांच डिग्री", "100.5"), | |
| # Mixed digits and words | |
| ("बीपी 110 बटा सत्तर", "110/70"), | |
| ("बीपी 110/सतर", "110/70"), | |
| # Whisper repetition fix | |
| ("हाँ हाँ हाँ हाँ हाँ ठीक है", "हाँ ठीक है"), | |
| # Sentence boundaries | |
| ("पहला। दूसरा। तीसरा", "पहला।\nदूसरा।\nतीसरा"), | |
| # Already-digit pass-through | |
| ("BP 120/80 है, weight 55 kg", "120/80"), | |
| ] | |
| for inp, expect_substr in cases: | |
| result = normalize_transcript(inp) | |
| check(f"norm({inp[:40]}...→{expect_substr})", expect_substr in result, True) | |
| def test_edge_cases(): | |
| """Edge cases that could break the parser.""" | |
| print("\n=== 5. Edge Cases ===") | |
| # Empty input | |
| check("empty", normalize_transcript(""), "") | |
| # Only numbers | |
| check("only_num", "25" in convert_numbers("पच्चीस"), True) | |
| # Number at start of text | |
| check("num_start", normalize_transcript("पच्चीस हफ्ते").startswith("25"), True) | |
| # Number at end of text | |
| check("num_end", normalize_transcript("हफ्ते पच्चीस").endswith("25"), True) | |
| # Adjacent numbers separated by non-number word | |
| result = normalize_transcript("वजन पचास और उम्र तीस") | |
| check("adjacent_nums", "50" in result and "30" in result, True) | |
| # Don't convert number words inside other Hindi words | |
| # "एकतरफ" should NOT become "1तरफ" | |
| result = convert_numbers("एकतरफ") | |
| check("no_partial", "1" not in result, True) | |
| # Very large numbers (should still work) | |
| check("nine_hundred", parse_hindi_number("नौ सौ निन्यानवे"), 999) | |
| def test_real_whisper_transcripts(): | |
| """Test on actual saved Whisper transcripts from previous sessions.""" | |
| print("\n=== 6. Real Whisper Transcripts ===") | |
| test_files = { | |
| "postprocess_test.txt": ["110", "70"], # BP should be found after norm | |
| "postprocess_test2.txt": ["110", "70"], # Has "110/सतर" | |
| "pp_test.txt": ["110/70", "58 kg", "11.5"], # Already clean | |
| } | |
| for fname, expected_values in test_files.items(): | |
| path = os.path.join(os.path.dirname(os.path.dirname(__file__)), fname) | |
| if not os.path.exists(path): | |
| print(f" SKIP: {fname} not found") | |
| continue | |
| raw = open(path, encoding="utf-8").read().strip() | |
| result = normalize_transcript(raw) | |
| for val in expected_values: | |
| check(f"file({fname}→{val})", val in result, True) | |
| def test_live_asr(): | |
| """Test actual ASR transcription on test audio files.""" | |
| print("\n=== 7. Live ASR (GPU required) ===") | |
| audio_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_audio") | |
| if not os.path.exists(audio_dir): | |
| print(" SKIP: test_audio/ not found") | |
| return | |
| audio_files = [f for f in os.listdir(audio_dir) if f.endswith((".mp3", ".wav"))] | |
| if not audio_files: | |
| print(" SKIP: no audio files found") | |
| return | |
| # Expected values per file | |
| expectations = { | |
| "anc_normal.mp3": { | |
| "values": ["110", "70", "58", "11.5", "24"], | |
| "terms": ["BP", "kg", "Hb"], | |
| }, | |
| "anc_danger.mp3": { | |
| "values": ["155", "100"], | |
| "terms": ["BP", "PHC"], | |
| }, | |
| } | |
| from faster_whisper import WhisperModel | |
| ct2_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models", "whisper-hindi-ct2") | |
| if os.path.exists(ct2_path): | |
| print(f" Loading CT2 model from {ct2_path}...") | |
| model = WhisperModel(ct2_path, device="cuda", compute_type="float16") | |
| else: | |
| print(" Loading collabora/whisper-large-v2-hindi from HuggingFace...") | |
| model = WhisperModel("collabora/whisper-large-v2-hindi", device="cuda", compute_type="float16") | |
| for fname in audio_files: | |
| fpath = os.path.join(audio_dir, fname) | |
| expect = expectations.get(fname, {"values": [], "terms": []}) | |
| print(f"\n --- {fname} ---") | |
| t0 = time.time() | |
| segments, info = model.transcribe(fpath, language="hi", task="transcribe") | |
| raw_text = " ".join(seg.text.strip() for seg in segments) | |
| asr_time = time.time() - t0 | |
| t0 = time.time() | |
| normalized = normalize_transcript(raw_text) | |
| norm_time = time.time() - t0 | |
| safe_print(f" ASR time: {asr_time:.1f}s") | |
| safe_print(f" Norm time: {norm_time*1000:.0f}ms") | |
| safe_print(f" Raw: {raw_text[:150]}...") | |
| safe_print(f" Norm: {normalized[:150]}...") | |
| for val in expect["values"]: | |
| found_raw = val in raw_text | |
| found_norm = val in normalized | |
| status = "RAW+NORM" if found_raw else ("NORM" if found_norm else "MISS") | |
| check(f"asr({fname}→{val})", found_norm, True) | |
| safe_print(f" {val}: {status}") | |
| for term in expect["terms"]: | |
| check(f"asr({fname}→{term})", term in normalized, True) | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--skip-gpu", action="store_true", help="Skip live ASR tests") | |
| args = parser.parse_args() | |
| print("=" * 60) | |
| print("SAKHI ASR + NORMALIZATION TEST SUITE") | |
| print("=" * 60) | |
| test_number_parser() | |
| test_compound_splits() | |
| test_medical_terms() | |
| test_full_normalization() | |
| test_edge_cases() | |
| test_real_whisper_transcripts() | |
| if not args.skip_gpu: | |
| test_live_asr() | |
| else: | |
| print("\n=== 7. Live ASR — SKIPPED (--skip-gpu) ===") | |
| print("\n" + "=" * 60) | |
| print(f"RESULTS: {PASS} passed, {FAIL} failed") | |
| print("=" * 60) | |
| return 0 if FAIL == 0 else 1 | |
| if __name__ == "__main__": | |
| sys.exit(main()) | |