File size: 13,934 Bytes
745f62a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3595cb
 
 
745f62a
 
 
 
 
 
 
 
 
 
 
 
d3595cb
 
 
 
745f62a
 
 
d3595cb
 
 
 
745f62a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3595cb
 
 
 
745f62a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3595cb
 
 
 
 
745f62a
 
 
 
 
 
 
 
 
 
 
 
d3595cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745f62a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3595cb
 
 
 
 
 
 
745f62a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
"""
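Sakhi API — FastAPI backend for React frontend.

Endpoints:
  POST /api/process-audio         — Upload audio file → transcript + form + danger signs
  POST /api/process-text          — Submit transcript text → form + danger signs
  POST /api/process-audio-stream  — Audio pipeline with per-stage progress as Server-Sent Events
  POST /api/process-text-stream   — Text pipeline with per-stage progress as Server-Sent Events
  POST /api/translate             — Hindi / Hinglish text → English
  GET  /api/health                — Health check
  GET  /api/examples              — List example transcripts
  GET  /api/audio-examples        — List bundled demo audio clips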

Runs on port 8000. React frontend runs on port 3000.
"""
import os
import json
import time
import tempfile

# Exported before the FastAPI / pipeline imports below, so the values are
# already in the environment when `app` (and whatever it imports) is loaded.
# NOTE(review): presumably these switch off torch.compile / TorchDynamo — confirm.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

from fastapi import FastAPI, UploadFile, File, Form, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import Optional

# Import pipeline functions from app.py
from app import (
    transcribe_audio,
    extract_form,
    extract_danger_signs,
    extract_all,
    detect_visit_type,
    init_schemas,
    validate_form_output,
    postprocess_transcript,
    translate_to_english,
    warm_whisper,
    WHISPER_MODEL,
)

# ASGI application object; every route and static mount in this module is
# registered on it.
app = FastAPI(title="Sakhi API", version="1.0.0")

# CORS for React dev server
# NOTE(review): origins, methods and headers are all wildcarded. That suits
# local development; consider pinning allow_origins before exposing the API
# publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Startup: load schemas + pre-warm Whisper so the Space only reports ready
# when the audio path is hot. Whisper load is wrapped in try/except — if the
# eager load fails (no GPU, network blip), fall back to lazy loading on
# first audio request instead of blocking the whole boot.
@app.on_event("startup")
def startup():
    """Boot hook: load the schemas, then attempt to warm up Whisper.

    A warm-up failure is reported on stdout and otherwise ignored, so the
    application still finishes starting.
    """
    init_schemas()
    try:
        warm_whisper()
    except Exception as warm_error:
        print(f"[startup] WARN: Whisper pre-warm failed ({warm_error!r}); falling back to lazy load")


# ── Models ──
class PatientMetadata(BaseModel):
    """ASHA-entered patient identifier fields. All optional — pipeline still runs without them.
    When provided, override LLM-extracted name/age/sex in the form (see apply_metadata in app.py)."""
    # Every field defaults to None; _metadata_dict() drops None / "" entries
    # before the values are handed to extract_all().
    patient_name: Optional[str] = None
    patient_age: Optional[int] = None
    age_unit: Optional[str] = None        # "years" | "months"
    patient_sex: Optional[str] = None     # "male" | "female"
    patient_mobile: Optional[str] = None
    asha_id: Optional[str] = None
    visit_date: Optional[str] = None      # ISO date string


class TextRequest(BaseModel):
    """JSON body accepted by /api/process-text and /api/process-text-stream."""
    # Raw visit transcript; the endpoints strip it and reject it when empty.
    transcript: str
    # "auto" (or an empty value) makes the endpoints call detect_visit_type();
    # any other value is lower-cased and has spaces replaced by underscores.
    visit_type: Optional[str] = "auto"
    # Optional ASHA-entered identifiers forwarded to extract_all().
    metadata: Optional[PatientMetadata] = None


class TranslateRequest(BaseModel):
    """JSON body accepted by /api/translate."""
    # Text handed unchanged to translate_to_english().
    text: str


class ExtractionResult(BaseModel):
    """Response body of the non-streaming /api/process-text and /api/process-audio endpoints."""
    # Resolved visit type, or "unknown" when the request is rejected early.
    visit_type: str
    # Copied from the "form" / "danger" / "metadata" entries of extract_all()'s result.
    form: Optional[dict] = None
    danger: Optional[dict] = None
    metadata: Optional[dict] = None
    # Only populated by the audio endpoint (the ASR output).
    transcript: Optional[str] = None
    # Stage durations in seconds; the endpoints add "total_s" and, for audio, "asr_s".
    timing: dict = {}
    tool_calls: Optional[list] = None
    # Set instead of form/danger when the transcript is empty.
    error: Optional[str] = None


def _metadata_dict(meta):
    """Coerce a PatientMetadata or None into a dict (or None if empty)."""
    if meta is None:
        return None
    d = meta.dict() if hasattr(meta, "dict") else dict(meta)
    # Drop all-None entries so apply_metadata short-circuits cleanly
    return {k: v for k, v in d.items() if v is not None and v != ""} or None


# ── Endpoints ──
@app.get("/api/health")
def health():
    """Report liveness plus the configured LLM name and the Whisper model name.

    The LLM name comes from the OLLAMA_MODEL environment variable, with a
    built-in fallback when it is unset.
    """
    llm_name = os.environ.get("OLLAMA_MODEL", "gemma4:e4b-it-q4_K_M")
    return dict(status="ok", model=llm_name, whisper=WHISPER_MODEL)


@app.get("/api/examples")
def examples():
    """Return the bundled example transcripts as label / transcript / default records.

    Exactly the entry at position 1 is flagged ``default``.
    """
    from app import EXAMPLE_TRANSCRIPTS

    # index 1 = "ANC Visit — Preeclampsia (DANGER)" — best for demo (has danger signs)
    default_position = 1
    records = []
    for position, example in enumerate(EXAMPLE_TRANSCRIPTS):
        records.append({
            "label": example[0],
            "transcript": example[1],
            "default": position == default_position,
        })
    return records


@app.post("/api/translate")
def translate(req: TranslateRequest):
    """Hindi / Hinglish → English via translate_to_english().

    Reviewer-facing convenience; none of the extraction endpoints in this
    module call it. Per the original author's note it reuses the Gemma model
    that is already loaded, costing roughly one extra LLM call.

    Returns:
        ``{"english": <translation>, "time_s": <elapsed seconds, 2 decimals>}``
    """
    started = time.time()
    english_text = translate_to_english(req.text)
    elapsed = round(time.time() - started, 2)
    return {"english": english_text, "time_s": elapsed}


# Absolute path of the "demo_audio" directory that sits next to this file.
# Read by /api/audio-examples and mounted under /audio when it exists.
_DEMO_AUDIO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "demo_audio")


@app.get("/api/audio-examples")
def audio_examples():
    """List the curated voice clips described by demo_audio/manifest.json.

    Every manifest entry is returned with an added ``url`` of the form
    ``/audio/<file>``, relative to the server origin, so the frontend can
    both play the clip and re-POST it to /api/process-audio-stream.
    Returns an empty list when the manifest file does not exist.
    """
    manifest_file = os.path.join(_DEMO_AUDIO_DIR, "manifest.json")
    if os.path.isfile(manifest_file):
        with open(manifest_file, "r", encoding="utf-8") as handle:
            clips = json.load(handle)
        for clip in clips:
            clip["url"] = f"/audio/{clip['file']}"
        return clips
    return []


@app.post("/api/process-text", response_model=ExtractionResult)
def process_text(req: TextRequest):
    """Run extraction on a submitted transcript and return the result in one response.

    An empty (after stripping) transcript yields an ``ExtractionResult`` with
    ``visit_type="unknown"`` and an error message. Otherwise the visit type
    is either the caller's explicit choice (normalised to lower snake case)
    or the output of detect_visit_type(), and extract_all() produces the
    form, danger signs, metadata and tool calls. ``timing["total_s"]`` is the
    wall-clock duration of the whole handler.
    """
    started = time.time()

    text = req.transcript.strip()
    if not text:
        return ExtractionResult(visit_type="unknown", error="Empty transcript")

    # An explicit, non-"auto" visit type wins over detection.
    requested = req.visit_type
    if not requested or requested == "auto":
        kind = detect_visit_type(text)
    else:
        kind = requested.lower().replace(" ", "_")

    patient_meta = _metadata_dict(req.metadata)
    extraction = extract_all(text, kind, metadata=patient_meta)

    elapsed = time.time() - started
    timing = extraction.get("timing", {})
    timing["total_s"] = round(elapsed, 1)

    return ExtractionResult(
        visit_type=kind,
        form=extraction["form"],
        danger=extraction["danger"],
        metadata=extraction.get("metadata"),
        timing=timing,
        tool_calls=extraction.get("tool_calls"),
    )


@app.post("/api/process-audio", response_model=ExtractionResult)
async def process_audio(
    audio: UploadFile = File(...),
    visit_type: str = Form("auto"),
    patient_name: Optional[str] = Form(None),
    patient_age: Optional[int] = Form(None),
    age_unit: Optional[str] = Form(None),
    patient_sex: Optional[str] = Form(None),
    patient_mobile: Optional[str] = Form(None),
    asha_id: Optional[str] = Form(None),
    visit_date: Optional[str] = Form(None),
):
    """Transcribe an uploaded audio file, then extract the form and danger signs.

    Args:
        audio: Uploaded recording; its filename extension is reused for the
            temporary file handed to transcribe_audio().
        visit_type: "auto" (or empty) to call detect_visit_type(); any other
            value is lower-cased with spaces replaced by underscores.
        patient_name, patient_age, age_unit, patient_sex, patient_mobile,
        asha_id, visit_date: Optional ASHA-entered identifiers, bundled into
            a PatientMetadata and forwarded to extract_all().

    Returns:
        ExtractionResult with the transcript, form, danger signs, metadata,
        tool calls and timing ("asr_s", "total_s"). When transcription yields
        nothing, ``visit_type="unknown"`` and ``error`` is set instead.
    """
    # fastapi.concurrency re-exports Starlette's run_in_threadpool; imported
    # locally so this handler does not depend on a new module-level import.
    from fastapi.concurrency import run_in_threadpool

    t_total = time.time()

    def run_pipeline(path):
        """Blocking ASR + extraction for the audio file at ``path``."""
        # ASR
        t0 = time.time()
        transcript = transcribe_audio(path)
        asr_time = time.time() - t0

        if not transcript or not transcript.strip():
            return ExtractionResult(
                visit_type="unknown",
                error="Transcription returned empty",
                timing={"asr_s": round(asr_time, 1)},
            )

        # NOTE(review): /api/process-audio-stream passes the ASR output through
        # postprocess_transcript() before extraction; this endpoint does not.
        # Confirm whether that difference is intended.

        # Detect visit type
        if visit_type and visit_type != "auto":
            vtype = visit_type.lower().replace(" ", "_")
        else:
            vtype = detect_visit_type(transcript)

        metadata = _metadata_dict(PatientMetadata(
            patient_name=patient_name, patient_age=patient_age, age_unit=age_unit,
            patient_sex=patient_sex, patient_mobile=patient_mobile,
            asha_id=asha_id, visit_date=visit_date,
        ))
        result = extract_all(transcript, vtype, metadata=metadata)

        total = time.time() - t_total
        timing = result.get("timing", {})
        timing["asr_s"] = round(asr_time, 1)
        timing["total_s"] = round(total, 1)

        return ExtractionResult(
            visit_type=vtype,
            form=result["form"],
            danger=result["danger"],
            metadata=result.get("metadata"),
            transcript=transcript,
            timing=timing,
            tool_calls=result.get("tool_calls"),
        )

    # Save uploaded audio to temp file. The file is created first and the
    # write happens inside the try, so a failed read/write cannot strand it.
    suffix = os.path.splitext(audio.filename or "audio.wav")[1]
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(await audio.read())
        # The pipeline is synchronous and slow; running it on a worker thread
        # keeps the event loop free to serve other requests meanwhile.
        return await run_in_threadpool(run_pipeline, tmp_path)
    finally:
        os.unlink(tmp_path)


def _sse_event(data: dict) -> str:
    return f"data: {json.dumps(data)}\n\n"


@app.post("/api/process-text-stream")
async def process_text_stream(req: TextRequest):
    """Run the text pipeline and stream per-stage progress as Server-Sent Events.

    Event order: ``detect`` running/done, ``form`` running/done, ``danger``
    done, then a final ``complete`` event carrying the form, danger signs,
    metadata, tool calls and timing. An empty transcript produces a single
    ``{"error": ...}`` event instead.
    """
    def event_stream():
        started = time.time()
        text = req.transcript.strip()
        if not text:
            yield _sse_event({"error": "Empty transcript"})
            return

        # Detect visit type (an explicit, non-"auto" choice skips detection)
        yield _sse_event({"stage": "detect", "status": "running"})
        requested = req.visit_type
        if not requested or requested == "auto":
            kind = detect_visit_type(text)
        else:
            kind = requested.lower().replace(" ", "_")
        yield _sse_event({"stage": "detect", "status": "done", "visit_type": kind})

        patient_meta = _metadata_dict(req.metadata)

        # Unified extraction (form + danger in one LLM call via function calling)
        yield _sse_event({"stage": "form", "status": "running"})
        extract_started = time.time()
        extraction = extract_all(text, kind, metadata=patient_meta)
        extract_elapsed = time.time() - extract_started
        yield _sse_event({"stage": "form", "status": "done", "time": round(extract_elapsed, 1)})

        # Danger stage is instant (already done in same call)
        yield _sse_event({"stage": "danger", "status": "done", "time": 0.0})

        elapsed_total = time.time() - started
        timing = extraction.get("timing", {})
        timing["total_s"] = round(elapsed_total, 1)
        final_event = {
            "stage": "complete",
            "visit_type": kind,
            "form": extraction["form"],
            "danger": extraction["danger"],
            "metadata": extraction.get("metadata"),
            "tool_calls": extraction.get("tool_calls"),
            "timing": timing,
        }
        yield _sse_event(final_event)

    return StreamingResponse(event_stream(), media_type="text/event-stream")


@app.post("/api/process-audio-stream")
async def process_audio_stream(
    audio: UploadFile = File(...),
    visit_type: str = Form("auto"),
    patient_name: Optional[str] = Form(None),
    patient_age: Optional[int] = Form(None),
    age_unit: Optional[str] = Form(None),
    patient_sex: Optional[str] = Form(None),
    patient_mobile: Optional[str] = Form(None),
    asha_id: Optional[str] = Form(None),
    visit_date: Optional[str] = Form(None),
):
    """Run the audio pipeline and stream per-stage progress as Server-Sent Events.

    Args:
        audio: Uploaded recording; its filename extension is reused for the
            temporary file handed to transcribe_audio().
        visit_type: "auto" (or empty) to call detect_visit_type(); any other
            value is lower-cased with spaces replaced by underscores.
        patient_name, patient_age, age_unit, patient_sex, patient_mobile,
        asha_id, visit_date: Optional ASHA-entered identifiers forwarded to
            extract_all().

    Returns:
        A ``text/event-stream`` response. Event order: ``asr`` running/done,
        ``normalize`` running/done (with the post-processed transcript),
        ``detect`` running/done, ``form`` running/done, ``danger`` done, then
        ``complete`` with the full result. An empty transcription ends the
        stream with a single ``{"error": ...}`` event after the ``asr`` events.
    """
    # Build the metadata before touching the filesystem, so a failure here
    # cannot leave a temp file behind.
    metadata = _metadata_dict(PatientMetadata(
        patient_name=patient_name, patient_age=patient_age, age_unit=age_unit,
        patient_sex=patient_sex, patient_mobile=patient_mobile,
        asha_id=asha_id, visit_date=visit_date,
    ))

    # Save uploaded audio to temp file before streaming
    suffix = os.path.splitext(audio.filename or "audio.wav")[1]
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(await audio.read())
    except BaseException:
        # generate() never runs on this path, so its cleanup would not either;
        # remove the file here and let the original error propagate.
        os.unlink(tmp_path)
        raise

    def generate():
        t_total = time.time()
        try:
            # ASR
            yield _sse_event({"stage": "asr", "status": "running"})
            t0 = time.time()
            transcript = transcribe_audio(tmp_path)
            asr_time = time.time() - t0
            yield _sse_event({"stage": "asr", "status": "done", "time": round(asr_time, 1)})

            if not transcript or not transcript.strip():
                yield _sse_event({"error": "Transcription returned empty"})
                return

            # Normalize
            yield _sse_event({"stage": "normalize", "status": "running"})
            transcript = postprocess_transcript(transcript)
            yield _sse_event({"stage": "normalize", "status": "done", "transcript": transcript})

            # Detect visit type
            yield _sse_event({"stage": "detect", "status": "running"})
            if visit_type and visit_type != "auto":
                vtype = visit_type.lower().replace(" ", "_")
            else:
                vtype = detect_visit_type(transcript)
            yield _sse_event({"stage": "detect", "status": "done", "visit_type": vtype})

            # Unified extraction (form + danger in one LLM call via function calling)
            yield _sse_event({"stage": "form", "status": "running"})
            t1 = time.time()
            result = extract_all(transcript, vtype, metadata=metadata)
            extract_time = time.time() - t1
            yield _sse_event({"stage": "form", "status": "done", "time": round(extract_time, 1)})

            # Danger stage is instant (already done in same call)
            yield _sse_event({"stage": "danger", "status": "done", "time": 0.0})

            total = time.time() - t_total
            timing = result.get("timing", {})
            timing["asr_s"] = round(asr_time, 1)
            timing["total_s"] = round(total, 1)
            yield _sse_event({
                "stage": "complete",
                "visit_type": vtype,
                "form": result["form"],
                "danger": result["danger"],
                "metadata": result.get("metadata"),
                "transcript": transcript,
                "tool_calls": result.get("tool_calls"),
                "timing": timing,
            })
        finally:
            # The temp file must outlive the request handler, so it is removed
            # here once the stream finishes (or is closed early).
            os.unlink(tmp_path)

    return StreamingResponse(generate(), media_type="text/event-stream")


# Serve curated demo audio under /audio/* so the frontend can <audio src=...>
# them. Must be mounted BEFORE the SPA catch-all below; otherwise the
# StaticFiles for `/` would swallow these paths.
# The mount is skipped when the directory is absent; /api/audio-examples then
# returns an empty list because the manifest cannot be found either.
if os.path.isdir(_DEMO_AUDIO_DIR):
    app.mount("/audio", StaticFiles(directory=_DEMO_AUDIO_DIR), name="demo_audio")


# Serve built React frontend at / when dist exists (unified desktop UI for health centers).
# Must be mounted AFTER all /api/* routes so they take priority.
# NOTE(review): html=True presumably makes StaticFiles serve index.html for
# directory requests — confirm against the Starlette docs.
_FRONTEND_DIST = os.path.join(os.path.dirname(os.path.abspath(__file__)), "frontend", "dist")
if os.path.isdir(_FRONTEND_DIST):
    app.mount("/", StaticFiles(directory=_FRONTEND_DIST, html=True), name="frontend")


# Direct execution (`python <this file>`): start uvicorn on all interfaces,
# port 8000. uvicorn is imported lazily so importing this module never needs it.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)