File size: 7,085 Bytes
745f62a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
End-to-end pipeline test suite for Sakhi.

Runs 15 synthetic Hindi audio files through the FULL pipeline:
  Audio → Whisper ASR → Hindi normalization → Form extraction → Danger sign detection

Validates against known ground truth values from manifest.json.
Tests: value accuracy, hallucination, danger sign detection, referral decisions.

Usage:
    python scripts/test_pipeline_e2e.py
"""
import json
import os
import sys
import time

# Disable torch.compile / TorchDynamo; set before `app` (and any model code it
# pulls in) is imported below.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"
# Only affects child Python processes; this interpreter's stdout is configured
# directly below.
os.environ["PYTHONIOENCODING"] = "utf-8"
# UTF-8 output for Hindi text, line-buffered for real-time progress on Windows.
# Reconfiguring the existing stream in place (instead of opening a second file
# object on the same fd) avoids two objects owning fd 1, where finalizing the
# extra one would close the real stdout.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", line_buffering=True)

# Ensure project root is on sys.path (absolute, so a later chdir cannot break it)
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from app import (
    transcribe_audio,
    extract_all,
    detect_visit_type,
    init_schemas,
)

# Synthetic audio clips and their ground-truth manifest. The directory path is
# relative, so it resolves against the current working directory.
AUDIO_DIR = "test_audio/synthetic"
MANIFEST = os.path.join(AUDIO_DIR, "manifest.json")

# Referral decision labels grouped by urgency, used by the referral check.
SAFE_REFERRALS = {
    "routine_followup",
    "continue_monitoring",
}
URGENT_REFERRALS = {
    "refer_immediately",
    "refer_within_24h",
}


def get_nested(d, path):
    """Follow a dotted key path (e.g. 'vitals.bp_systolic') through nested dicts.

    Returns whatever sits at the end of the path. Returns None if a key is
    absent, or if a value reached before the last key is not a dict.
    """
    node = d
    for key in path.split("."):
        if isinstance(node, dict):
            node = node.get(key)
        else:
            return None
    return node


def check_value(got, expected, tol=1.0):
    """Return True if an extracted value matches the expected ground truth.

    The comparison depends on the type of *expected*:

    - bool: exact equality with *got*.
    - int/float: *got* is coerced with float() and must lie within *tol*
      (default 1.0) of *expected*. A bool *got* is rejected so that True/False
      cannot pass for 1/0.
    - str: case-insensitive, whitespace-trimmed substring match in either
      direction. An empty *got* never matches a non-empty *expected*.
    - anything else: equality of the str() representations.

    A None *got* never matches.
    """
    if got is None:
        return False
    if isinstance(expected, bool):
        return got == expected
    if isinstance(expected, (int, float)):
        # bool is an int subclass: float(True) == 1.0 would otherwise pass.
        if isinstance(got, bool):
            return False
        try:
            return abs(float(got) - float(expected)) <= tol
        except (ValueError, TypeError, OverflowError):
            return False
    if isinstance(expected, str):
        got_norm = str(got).lower().strip()
        exp_norm = expected.lower().strip()
        if not got_norm or not exp_norm:
            # "" is a substring of every string; require exact equality instead.
            return got_norm == exp_norm
        return exp_norm in got_norm or got_norm in exp_norm
    return str(got) == str(expected)


def _check_form(form, expected):
    """Return issue strings comparing an extracted form to the manifest expectations."""
    issues = []
    for path, exp_val in expected.get("checks", {}).items():
        got = get_nested(form, path)
        if got is None:
            issues.append(f"MISSING {path} (expected {exp_val})")
        elif not check_value(got, exp_val):
            issues.append(f"WRONG {path}: got={got} expected={exp_val}")

    # Hallucination traps: these paths must hold a null-like value (or nothing).
    for path in expected.get("must_be_null", []):
        val = get_nested(form, path)
        if val is not None and str(val).lower() not in ("null", "none", "", "β€”"):
            issues.append(f"HALLUC {path}={val}")
    return issues


def _check_danger(danger, expected):
    """Return issue strings for danger-sign count and referral-decision checks."""
    issues = []
    signs = danger.get("danger_signs", [])
    n_signs = len(signs) if isinstance(signs, list) else 0
    d_min, d_max = expected.get("danger_count", [0, 0])
    if n_signs < d_min:
        issues.append(f"FALSE_NEG: {n_signs} danger signs < {d_min} expected")
    if n_signs > d_max:
        issues.append(f"FALSE_POS: {n_signs} danger signs > {d_max} expected")

    # Referral decisions are compared by group (safe vs urgent), not exact label.
    exp_ref = expected.get("referral", "")
    if exp_ref:
        # `or {}` also covers an explicit None for referral_decision.
        ref = danger.get("referral_decision") or {}
        ref_decision = ref.get("decision", "") if isinstance(ref, dict) else ""
        if not ref_decision:
            # A missing decision must not be counted as "urgent" by default.
            issues.append(f"REFERRAL: got=<missing> expected={exp_ref}")
        else:
            exp_group = "safe" if exp_ref in SAFE_REFERRALS else "urgent"
            got_group = "safe" if ref_decision in SAFE_REFERRALS else "urgent"
            if exp_group != got_group:
                issues.append(f"REFERRAL: got={ref_decision} expected={exp_ref}")
    return issues


def run_test(test_case, test_num, total):
    """Run one manifest entry through ASR, extraction and danger detection.

    Args:
        test_case: manifest entry. ``file`` is the audio file name under
            AUDIO_DIR. ``expected`` holds the ground truth: ``visit_type``
            plus optional ``checks``, ``must_be_null``, ``danger_count`` and
            ``referral``.
        test_num: 1-based index of this case, used only in the progress line.
        total: number of cases in the run, used only in the progress line.

    Returns:
        A tuple ``(passed, issues, timing)``.

        - ``passed`` is True when ``issues`` is empty.
        - ``timing`` maps the stages that ran (``asr``, ``form``, ``danger``)
          to seconds, plus their ``total``.
    """
    audio_file = os.path.join(AUDIO_DIR, test_case["file"])
    expected = test_case["expected"]
    name = test_case["file"].replace(".mp3", "")

    issues = []
    timing = {}

    # ── Step 1: ASR ──
    t0 = time.perf_counter()
    transcript = transcribe_audio(audio_file)
    timing["asr"] = round(time.perf_counter() - t0, 1)

    if not transcript or not transcript.strip():
        issues.append("ASR_EMPTY")
        # Record a total so main()'s averages include this sample's ASR time.
        timing["total"] = timing["asr"]
        print(f"  [{test_num}/{total}] FAIL [{name}] β€” ASR returned empty")
        return False, issues, timing

    # ── Step 2: Visit type detection ──
    detected_type = detect_visit_type(transcript)
    expected_type = expected["visit_type"]
    if detected_type != expected_type:
        issues.append(f"VISIT_TYPE: detected={detected_type} expected={expected_type}")

    # Use expected type for extraction (test extraction quality, not detection)
    visit_type = expected_type

    # ── Step 3: Unified extraction (function calling) ──
    t0 = time.perf_counter()
    result = extract_all(transcript, visit_type)
    timing["form"] = round(time.perf_counter() - t0, 1)
    timing["danger"] = 0.0  # included in single call

    form = result.get("form")
    danger = result.get("danger")
    n_tool_calls = len(result.get("tool_calls") or [])

    if form is None:
        issues.append("FORM_PARSE_FAIL")
    else:
        issues.extend(_check_form(form, expected))

    if danger is None:
        issues.append("DANGER_PARSE_FAIL")
    else:
        issues.extend(_check_danger(danger, expected))

    timing["total"] = round(sum(timing.values()), 1)
    passed = not issues
    status = "PASS" if passed else "FAIL"
    detail = "all checks OK" if passed else "; ".join(issues)
    tc_info = f"tools={n_tool_calls}" if n_tool_calls else "no-fc"
    print(f"  [{test_num}/{total}] {status} [{name}] ({timing['total']:.1f}s, {tc_info}) {detail}")

    return passed, issues, timing


def main():
    """Run every manifest test case through the pipeline and print a summary.

    Exits with status 1 when the manifest is missing or empty or when any case
    fails. Exits with status 0 only when every case passes (used by CI).
    """
    if not os.path.exists(MANIFEST):
        print(f"ERROR: {MANIFEST} not found. Run scripts/generate_test_audio.py first.")
        sys.exit(1)

    with open(MANIFEST, encoding="utf-8") as f:
        test_cases = json.load(f)

    if not test_cases:
        # An empty suite must not pass vacuously, and the timing averages
        # below would otherwise divide by zero.
        print(f"ERROR: {MANIFEST} contains no test cases.")
        sys.exit(1)

    print("Initializing schemas...")
    init_schemas()

    print(f"\n{'=' * 74}")
    print(f" Sakhi E2E Pipeline Test β€” {len(test_cases)} audio samples")
    print(f"{'=' * 74}")

    total_pass = 0
    total_fail = 0
    all_timings = []
    failures = []

    for i, tc in enumerate(test_cases, 1):
        passed, issues, timing = run_test(tc, i, len(test_cases))
        if passed:
            total_pass += 1
        else:
            total_fail += 1
            failures.append((tc["file"], issues))
        all_timings.append(timing)

    # ── Summary ──
    total = total_pass + total_fail
    pct = total_pass / total * 100 if total else 0

    def avg(key):
        # Mean of one timing field. A sample that never reached that stage
        # contributes 0 s. all_timings is non-empty because of the guard above.
        return sum(t.get(key, 0) for t in all_timings) / len(all_timings)

    avg_total = avg("total")
    avg_asr = avg("asr")
    avg_extract = avg("form")

    print(f"\n{'=' * 74}")
    print(f" RESULTS: {total_pass}/{total} ({pct:.0f}%)")
    print(f" Avg timing: ASR {avg_asr:.1f}s | Extract {avg_extract:.1f}s | Total {avg_total:.1f}s")
    print(f"{'=' * 74}")

    if failures:
        print("\n FAILURES:")
        for fname, issues in failures:
            print(f"  {fname}: {'; '.join(issues)}")

    # Exit code for CI
    sys.exit(0 if total_fail == 0 else 1)


# Script entry point; main() terminates via sys.exit() with the suite's status.
if __name__ == "__main__":
    main()