Spaces:
Sleeping
Sleeping
| """ | |
| Sakhi (सखी) — ASHA Health Worker AI Companion | |
| ================================================ | |
| Hindi voice → structured MCTS/HMIS forms + danger sign detection | |
| powered by Gemma 4 E4B (fine-tuned via Unsloth). | |
| This module is the pipeline library (ASR + extraction + validation). The | |
| React UI is served by api.py; this file is not run directly. | |
| """ | |
| import os | |
| import re | |
| import json | |
| import time | |
# Disable torch.compile / TorchDynamo through environment variables at import
# time, i.e. before load_model() imports torch lazily.
os.environ.update({"TORCH_COMPILE_DISABLE": "1", "TORCHDYNAMO_DISABLE": "1"})

# ============================================================
# CONFIGURATION
# ============================================================

# Local Unsloth checkpoint loaded by load_model() for the non-Ollama path.
MODEL_PATH = "./models/checkpoints/final"
MAX_SEQ_LENGTH = 4096

# Ollama backend (on by default; any USE_OLLAMA value other than "1" turns it
# off). Switch OLLAMA_MODEL to "sakhi" once the fine-tuned GGUF is registered;
# until then it points at the base model tag.
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "gemma4:e4b-it-q4_K_M")
USE_OLLAMA = os.getenv("USE_OLLAMA", "1") == "1"
# Native tool-use toggle; enabled unless USE_FUNCTION_CALLING is set to
# something other than "1".
USE_FUNCTION_CALLING = os.getenv("USE_FUNCTION_CALLING", "1") == "1"

# Whisper ASR model. The default is a CTranslate2-converted mirror of
# collabora's Hindi fine-tune of whisper-large-v2 (picked after session 19's
# real-voice validation pass). faster-whisper only loads CT2 format, so the
# original transformers-format collabora/ repo cannot be used directly.
# Set WHISPER_MODEL to evaluate other variants. A pre-converted local CT2
# directory at models/whisper-hindi-ct2/ wins over this setting — see
# warm_whisper().
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "Tushar9802/whisper-large-v2-hindi-ct2")
# System prompts (same as training).
# The prompts below are noted as identical to the ones used for fine-tuning;
# editing them would presumably introduce train/serve skew — confirm against
# the training config before changing any wording.

# Form-extraction prompt: asks the model for schema-shaped JSON with nulls for
# anything not explicitly stated in the transcript.
# NOTE(review): rule 5's example spells "hundred" as 'सो'; standard Hindi is
# 'सौ'. Left untouched because this text is meant to match training.
FORM_SYSTEM_PROMPT = (
    "You are a clinical data extraction system for India's ASHA health worker program. "
    "Extract structured data from the Hindi/Hinglish home visit conversation into the requested JSON schema. "
    "ONLY extract information explicitly stated in the conversation. Use null for any field not mentioned.\n\n"
    "STRICT RULES:\n"
    "1. Do NOT invent names, dates, phone numbers, or addresses. If the patient is only called 'दीदी' or 'बहन', set name to null.\n"
    "2. If age is not explicitly stated as a number, set age to null. Do NOT guess from context.\n"
    "3. If blood group, HIV status, or other lab tests are not discussed, they MUST be null — never assume 'negative' or a default group.\n"
    "4. If the conversation has no speaker labels (ASHA/Patient), still extract data but be extra strict about nulls.\n"
    "5. Numbers may appear as Hindi words (e.g., 'एक सो दस बटा सत्तर' = 110/70). Convert them to digits.\n"
    "Return valid JSON only."
)
# JSON-output danger-sign prompt. The tool-calling variant of the same rules
# is DANGER_FC_SYSTEM_PROMPT.
DANGER_SYSTEM_PROMPT = (
    "You are a clinical danger sign detection system for India's ASHA health worker program. "
    "Analyze the Hindi/Hinglish home visit conversation for NHM-defined danger signs.\n\n"
    "STRICT RULES:\n"
    "1. ONLY flag a danger sign if the EXACT words proving it appear in the conversation.\n"
    "2. utterance_evidence MUST be a verbatim copy-paste from the conversation — do NOT paraphrase or fabricate.\n"
    "3. If a vital sign is NORMAL (e.g., BP 110/70, temperature 37°C), that is NOT a danger sign.\n"
    "4. Most routine visits have ZERO danger signs. Return an empty danger_signs array when none exist.\n"
    "5. When in doubt, do NOT flag — a missed flag is better than a false alarm.\n"
    "Return valid JSON only."
)
# ============================================================
# EXAMPLE TRANSCRIPTS (for demo)
# ============================================================
# Each entry is a two-element list: [display label, transcript text].
# Labels ending in "(DANGER)" mark transcripts that contain danger signs
# (preeclampsia symptoms with BP 155/100; a lethargic, febrile newborn that
# is not feeding); the other two are routine visits.
EXAMPLE_TRANSCRIPTS = [
    [
        "ANC Visit — Normal",
        (
            "ASHA: नमस्ते, कैसे हैं आप?\n"
            "Patient: नमस्ते दीदी, मैं ठीक हूँ।\n"
            "ASHA: अच्छा है। मैं आपका चेकअप करने आई हूँ। चलिए, पहले आपका BP चेक कर लेती हूँ।\n"
            "Patient: ठीक है।\n"
            "ASHA: आपका BP 110/70 है, बिल्कुल ठीक है। अब वजन देखती हूँ... 58 kg है। पिछली बार 56 था, तो अच्छा बढ़ रहा है।\n"
            "Patient: हाँ, मैं अच्छा खा रही हूँ।\n"
            "ASHA: बहुत अच्छा! Hb कितना आया था पिछली बार?\n"
            "Patient: डॉक्टर ने कहा था 11.5 है।\n"
            "ASHA: ये तो बहुत अच्छा है। IFA की गोलियाँ ले रही हैं?\n"
            "Patient: हाँ, रोज़ लेती हूँ।\n"
            "ASHA: TT का टीका लगा?\n"
            "Patient: हाँ, पहला लग गया है।\n"
            "ASHA: बच्चे की हलचल कैसी है?\n"
            "Patient: बहुत हिलता-डुलता है, ठीक है।\n"
            "ASHA: बहुत अच्छा। आप लगभग 24 हफ्ते की हैं। डिलीवरी के लिए कहाँ जाएँगी?\n"
            "Patient: PHC में।\n"
            "ASHA: गाड़ी का इंतज़ाम है?\n"
            "Patient: हाँ, पति की गाड़ी है।\n"
            "ASHA: ठीक है। अगली बार 2 हफ्ते बाद आऊँगी। कोई तकलीफ़ हो तो फ़ोन कर दीजिए।\n"
            "Patient: ठीक है दीदी, धन्यवाद।"
        ),
    ],
    [
        "ANC Visit — Preeclampsia (DANGER)",
        (
            "ASHA: नमस्ते दीदी, कैसे हैं?\n"
            "Patient: दीदी, मुझे बहुत सिरदर्द हो रहा है कल से।\n"
            "ASHA: अच्छा, और कोई तकलीफ़?\n"
            "Patient: हाँ, आँखों के सामने धुंधला दिखता है कभी-कभी। और चेहरे पर सूजन भी आ गई है।\n"
            "ASHA: ये तो ठीक नहीं है। मैं BP चेक करती हूँ... आपका BP 155/100 आ रहा है। ये बहुत ज़्यादा है।\n"
            "Patient: क्या करें दीदी?\n"
            "ASHA: आपको तुरंत PHC जाना होगा। ये गंभीर हो सकता है। आप कितने महीने की हैं?\n"
            "Patient: लगभग 8 महीने।\n"
            "ASHA: पैरों में सूजन है?\n"
            "Patient: हाँ, काफी सूजन है।\n"
            "ASHA: मैं अभी गाड़ी का इंतज़ाम करती हूँ। आपको आज ही PHC ले चलती हूँ।"
        ),
    ],
    [
        "PNC — Newborn not feeding (DANGER)",
        (
            "ASHA: नमस्ते, कैसे हैं? बच्चा कैसा है?\n"
            "Mother: दीदी, बच्चा बहुत सोता रहता है। दूध भी ठीक से नहीं पीता।\n"
            "ASHA: कब से ऐसा है?\n"
            "Mother: कल से। पहले ठीक था, अब लगभग 12 घंटे से दूध नहीं पिया।\n"
            "ASHA: बच्चे का रोना कैसा है?\n"
            "Mother: बहुत कमज़ोर आवाज़ में रोता है।\n"
            "ASHA: तापमान चेक करती हूँ... 100.5 डिग्री है। बुखार है। और बच्चा सुस्त लग रहा है।\n"
            "Mother: क्या करें?\n"
            "ASHA: ये IMNCI के danger signs हैं। बच्चे को तुरंत PHC ले जाना होगा। मैं गाड़ी बुलाती हूँ।"
        ),
    ],
    [
        "Child Health — Routine visit",
        (
            "ASHA: नमस्ते, बच्चा कैसा है?\n"
            "Mother: बिल्कुल ठीक है दीदी। खूब खाता है, खेलता है।\n"
            "ASHA: बहुत अच्छा! वजन देखती हूँ... 8.5 kg है। 9 महीने के लिए अच्छा है।\n"
            "Mother: हाँ, दाल-चावल, केला सब खाता है अब।\n"
            "ASHA: Vitamin A की दवाई दी थी पिछली बार?\n"
            "Mother: हाँ, 6 महीने में दी थी।\n"
            "ASHA: अच्छा। अब deworming भी देनी है। और टीके सब लगे हैं?\n"
            "Mother: हाँ, सब समय पर लगे हैं।\n"
            "ASHA: बहुत अच्छा। बच्चा बैठता है, घुटनों पर चलता है?\n"
            "Mother: हाँ, सब करता है। बोलने भी लगा है थोड़ा।\n"
            "ASHA: बढ़िया है। अगली बार 3 महीने बाद आऊँगी।"
        ),
    ],
]
| # ============================================================ | |
| # SCHEMA LOADING | |
| # ============================================================ | |
def load_schema(name):
    """Load ``configs/schemas/<name>.json`` and return the decoded JSON.

    The path is resolved relative to this module's directory first, the same
    way warm_whisper() locates ``models/``. That makes loading independent of
    the process working directory. A CWD-relative ``configs/schemas/`` is
    still used as a fallback so existing launch setups keep working.

    Args:
        name: Schema basename without the ``.json`` suffix (e.g. "anc_visit").

    Returns:
        The parsed schema (typically a dict).

    Raises:
        FileNotFoundError: if the schema exists in neither location.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    relative = os.path.join("configs", "schemas", f"{name}.json")
    beside_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), relative)
    path = beside_module if os.path.exists(beside_module) else relative
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Schema registry, keyed by schema name; populated by init_schemas().
SCHEMAS = {}

# Visit type -> schema key. Currently an identity mapping over the four visit
# forms; callers such as _build_form_tool fall back to "anc_visit" for any
# visit type missing here.
VISIT_TYPE_MAP = {
    visit: visit
    for visit in ("anc_visit", "pnc_visit", "delivery", "child_health")
}
def init_schemas():
    """Fill the module-level SCHEMAS registry in place.

    Loads the four visit-form schemas plus the danger_signs schema through
    load_schema(). Entries with the same names are overwritten. If a load
    fails, the schemas loaded before it remain registered.
    """
    wanted = ("anc_visit", "pnc_visit", "delivery", "child_health", "danger_signs")
    for schema_name in wanted:
        # Mutating (not rebinding) the global dict, so no `global` is needed.
        SCHEMAS[schema_name] = load_schema(schema_name)
| # ============================================================ | |
| # MODEL LOADING | |
| # ============================================================ | |
# Lazily-initialised Unsloth model/tokenizer singletons (see load_model()).
_model = None
_tokenizer = None


def load_model():
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    On the first call, loads the fine-tuned checkpoint at MODEL_PATH via
    Unsloth's FastLanguageModel (4-bit, MAX_SEQ_LENGTH context). It then
    switches the model to inference mode. torch and unsloth are imported
    inside the function, so importing this module does not require them.
    """
    global _model, _tokenizer
    if _model is None:
        import torch
        torch._dynamo.config.suppress_errors = True
        from unsloth import FastLanguageModel

        print("[MODEL] Loading Gemma 4 E4B fine-tuned model...")
        loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
            model_name=MODEL_PATH,
            max_seq_length=MAX_SEQ_LENGTH,
            load_in_4bit=True,
        )
        _model, _tokenizer = loaded_model, loaded_tokenizer
        FastLanguageModel.for_inference(_model)
        print("[MODEL] Model loaded.")
    return _model, _tokenizer
| # ============================================================ | |
| # TRANSCRIPT POST-PROCESSING (delegated to src/hindi_normalize) | |
| # ============================================================ | |
| from src.hindi_normalize import normalize_transcript as postprocess_transcript | |
# Cached faster-whisper model; populated by warm_whisper().
_whisper_model = None


def warm_whisper():
    """Load the Whisper ASR model onto the GPU once and return it.

    Idempotent: later calls return the cached instance. A local CTranslate2
    directory at ``models/whisper-hindi-ct2`` (next to this module) takes
    priority. Otherwise WHISPER_MODEL is fetched from the Hugging Face Hub.
    Either way the model runs on CUDA in float16. FastAPI's startup hook calls
    this so the first audio request does not pay the load cost.
    """
    global _whisper_model
    if _whisper_model is None:
        from faster_whisper import WhisperModel

        local_ct2_dir = os.path.join(os.path.dirname(__file__), "models", "whisper-hindi-ct2")
        if os.path.exists(local_ct2_dir):
            print(f"[ASR] Loading CTranslate2 model from {local_ct2_dir}...")
            model_source = local_ct2_dir
        else:
            print(f"[ASR] Loading {WHISPER_MODEL} from HuggingFace Hub...")
            model_source = WHISPER_MODEL
        _whisper_model = WhisperModel(model_source, device="cuda", compute_type="float16")
        print("[ASR] Whisper loaded.")
    return _whisper_model
def transcribe_audio(audio_path):
    """Transcribe a Hindi audio file and return the normalized transcript text.

    Uses the cached faster-whisper model with language "hi", VAD filtering,
    beam_size=1 at temperature 0.0, and no conditioning on previous text.
    Segment texts are stripped and joined with single spaces, then passed
    through postprocess_transcript (src.hindi_normalize.normalize_transcript).
    """
    asr = warm_whisper()
    print("[ASR] Transcribing...")
    decode_options = {
        "language": "hi",
        "task": "transcribe",
        "vad_filter": True,
        "beam_size": 1,
        "temperature": 0.0,
        "condition_on_previous_text": False,
    }
    segments, _info = asr.transcribe(audio_path, **decode_options)
    joined = " ".join(segment.text.strip() for segment in segments)
    normalized = postprocess_transcript(joined)
    print(f"[ASR] Transcript ({len(normalized)} chars)")
    return normalized
def run_inference(system_prompt, user_prompt):
    """Send a system/user prompt pair to the configured LLM backend.

    Routes to the Ollama backend when USE_OLLAMA is set, otherwise to the
    local Unsloth model. Both return ``{"raw", "parsed", "time_s"}``.
    """
    backend = _run_inference_ollama if USE_OLLAMA else _run_inference_unsloth
    return backend(system_prompt, user_prompt)
def translate_to_english(hindi_text):
    """Translate Hindi/Hinglish home-visit text to English with the Ollama model.

    Uses the same OLLAMA_MODEL as extraction. Intended for on-demand use only,
    never on the main extraction path.

    Args:
        hindi_text: Source text; None or whitespace-only input is allowed.

    Returns:
        Plain English text (not JSON). Returns "" for empty input without
        contacting Ollama, and "" if the model returns no message content.
    """
    text = (hindi_text or "").strip()
    if not text:
        return ""
    # Imported only once there is something to translate, so empty input
    # does not require the ollama client to be installed.
    import ollama

    t0 = time.time()
    resp = ollama.chat(
        model=OLLAMA_MODEL,
        messages=[
            {"role": "system", "content": (
                "Translate the following Hindi or Hinglish conversation into clear, natural English. "
                "Preserve speaker labels (ASHA / Patient / Mother) and clinical numbers exactly. "
                "Do not add commentary or explanations — output ONLY the translation."
            )},
            {"role": "user", "content": text},
        ],
        options={"temperature": 0.1, "num_ctx": 4096, "num_gpu": 999},
        keep_alive=os.environ.get("OLLAMA_KEEP_ALIVE", "10m"),
    )
    elapsed = time.time() - t0
    # message.content can be None on an empty reply; avoid AttributeError.
    out = (resp.message.content or "").strip()
    tok_s = resp.eval_count / (resp.eval_duration / 1e9) if resp.eval_duration else 0
    print(f"[LLM] Translate: {elapsed:.1f}s ({resp.eval_count} tok, {tok_s:.0f} tok/s)")
    return out
def _run_inference_ollama(system_prompt, user_prompt):
    """Run one chat completion through Ollama in JSON mode.

    Returns ``{"raw": <model text>, "parsed": <decoded JSON or heuristic
    parse result>, "time_s": <wall-clock seconds>}``.
    """
    import ollama

    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    started = time.time()
    resp = ollama.chat(
        model=OLLAMA_MODEL,
        messages=chat_messages,
        format="json",
        options={"temperature": 0.1, "num_ctx": 4096, "num_gpu": 999},
        keep_alive=os.environ.get("OLLAMA_KEEP_ALIVE", "10m"),
    )
    elapsed = time.time() - started

    raw_text = resp.message.content
    rate = resp.eval_count / (resp.eval_duration / 1e9) if resp.eval_duration else 0
    print(f"[LLM] Ollama: {elapsed:.1f}s ({resp.eval_count} tok, {rate:.0f} tok/s)")

    # JSON mode should already yield valid JSON; keep the tolerant parser as
    # a fallback in case it does not.
    try:
        parsed = json.loads(raw_text)
    except json.JSONDecodeError:
        print("[WARN] Ollama JSON mode parse failed, falling back to heuristic parser")
        parsed = _parse_json_response(raw_text)
    return {"raw": raw_text, "parsed": parsed, "time_s": elapsed}
| # ============================================================ | |
| # FUNCTION CALLING — Gemma 4 native tool use | |
| # ============================================================ | |
def _build_form_tool(visit_type):
    """Build the ``extract_form`` tool definition for a visit type.

    The tool's parameters are the visit's JSON schema from SCHEMAS. Unknown
    visit types map to "anc_visit". If the mapped schema is not loaded, the
    "anc_visit" schema is used; that lookup raises KeyError when "anc_visit"
    itself is not loaded.
    """
    schema_key = VISIT_TYPE_MAP.get(visit_type, "anc_visit")
    parameters = SCHEMAS.get(schema_key, SCHEMAS["anc_visit"])
    readable = schema_key.replace("_", " ")
    description = (
        f"Extract structured {readable} form data from the "
        "ASHA home visit conversation. ONLY extract information explicitly stated. "
        "Use null for any field not mentioned."
    )
    function_spec = {
        "name": "extract_form",
        "description": description,
        "parameters": parameters,
    }
    return {"type": "function", "function": function_spec}
# Tool definition passed to Ollama by _run_danger_fc(). Every
# flag_danger_sign call the model makes is appended (its arguments dict, as
# returned) to that function's danger_signs list, so one call == one sign.
TOOL_FLAG_DANGER_SIGN = {
    "type": "function",
    "function": {
        "name": "flag_danger_sign",
        "description": (
            "Flag a single danger sign detected in the patient conversation. "
            "Call once per danger sign found. Do NOT call if no danger signs exist. "
            "The evidence field MUST be an exact verbatim quote from the conversation."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "sign": {
                    "type": "string",
                    "description": "Standard NHM danger sign name (e.g., severe_preeclampsia, severe_anemia)",
                },
                # Same three-level category enum as build_trimmed_danger_schema().
                "category": {
                    "type": "string",
                    "enum": ["immediate_referral", "urgent_care", "monitor_closely"],
                },
                # Nullable: not every sign has a measured value.
                "clinical_value": {
                    "type": ["string", "null"],
                    "description": "Measured value if applicable (e.g., '145/95', '38.5C')",
                },
                "utterance_evidence": {
                    "type": "string",
                    "description": "REQUIRED: exact verbatim quote from conversation proving this sign",
                },
            },
            "required": ["sign", "category", "utterance_evidence"],
        },
    },
}
# Referral tool passed to Ollama by _run_danger_fc(). If the model calls it
# more than once, the last call's arguments become the referral.
TOOL_ISSUE_REFERRAL = {
    "type": "function",
    "function": {
        "name": "issue_referral",
        "description": (
            "Issue a referral decision based on detected danger signs. "
            "Only call if danger signs warrant referral. Do NOT call for routine visits."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "urgency": {
                    "type": "string",
                    "enum": ["immediate", "within_24h", "routine"],
                },
                # Nullable (both the type list and the enum admit null), yet
                # still listed in "required": a schema-conforming call
                # includes the key, possibly with a null value.
                "facility": {
                    "type": ["string", "null"],
                    "enum": ["PHC", "CHC", "district_hospital", "FRU", None],
                },
                "reason": {
                    "type": "string",
                    "description": "Brief clinical reasoning for referral",
                },
            },
            "required": ["urgency", "facility", "reason"],
        },
    },
}
# System prompt for the tool-calling danger-sign pass in _run_danger_fc().
# It restates DANGER_SYSTEM_PROMPT's rules, but "no danger signs" is
# expressed as "make no tool calls" instead of an empty JSON array.
DANGER_FC_SYSTEM_PROMPT = (
    "You are a clinical danger sign detection system for India's ASHA health worker program.\n\n"
    "Analyze the conversation and use the provided tools:\n"
    "1. flag_danger_sign — call ONCE per danger sign found. Evidence MUST be a verbatim quote from the conversation. "
    "If NO danger signs exist, do NOT call any tool.\n"
    "2. issue_referral — call only if danger signs warrant referral to a facility.\n\n"
    "STRICT RULES:\n"
    "- ONLY flag a danger sign if the EXACT words proving it appear in the conversation.\n"
    "- utterance_evidence MUST be a verbatim copy-paste from the conversation — do NOT paraphrase.\n"
    "- If a vital sign is NORMAL (e.g., BP 110/70, temperature 37°C), that is NOT a danger sign.\n"
    "- Most routine visits have ZERO danger signs. Do NOT call any tools for normal visits.\n"
    "- When in doubt, do NOT flag — a missed flag is better than a false alarm."
)
def _run_danger_fc(transcript, visit_type):
    """Detect danger signs through Gemma native tool calling via Ollama.

    The model is offered TOOL_FLAG_DANGER_SIGN and TOOL_ISSUE_REFERRAL.
    Each flag_danger_sign call contributes its arguments to ``danger_signs``.
    The last issue_referral call (if any) becomes ``referral``.

    Returns:
        dict with ``danger_signs`` (list of argument dicts), ``referral``
        (argument dict or None), ``tool_calls`` (every call as
        ``{"function", "arguments"}``), and ``time_s`` (wall-clock seconds).
    """
    import ollama

    prompt = (
        "Analyze this ASHA home visit conversation for danger signs.\n\n"
        f"Visit type: {visit_type}\n\n"
        f"{transcript}"
    )
    started = time.time()
    resp = ollama.chat(
        model=OLLAMA_MODEL,
        messages=[
            {"role": "system", "content": DANGER_FC_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        tools=[TOOL_FLAG_DANGER_SIGN, TOOL_ISSUE_REFERRAL],
        options={"temperature": 0.1, "num_ctx": 4096, "num_gpu": 999},
        keep_alive=os.environ.get("OLLAMA_KEEP_ALIVE", "10m"),
    )
    elapsed = time.time() - started
    rate = resp.eval_count / (resp.eval_duration / 1e9) if resp.eval_duration else 0
    print(f"[LLM] Danger FC: {elapsed:.1f}s ({resp.eval_count} tok, {rate:.0f} tok/s)")

    calls = resp.message.tool_calls or []
    tool_calls_raw = [
        {"function": call.function.name, "arguments": call.function.arguments}
        for call in calls
    ]
    danger_signs = [
        entry["arguments"] for entry in tool_calls_raw
        if entry["function"] == "flag_danger_sign"
    ]
    referral = None
    for entry in tool_calls_raw:
        if entry["function"] == "issue_referral":
            referral = entry["arguments"]  # last call wins

    if calls:
        referral_flag = "yes" if referral else "no"
        print(f"[LLM] Tool calls: {len(calls)} "
              f"(danger_signs={len(danger_signs)}, "
              f"referral={referral_flag})")
    else:
        print("[LLM] No tool calls — no danger signs detected")

    return {
        "danger_signs": danger_signs,
        "referral": referral,
        "tool_calls": tool_calls_raw,
        "time_s": elapsed,
    }
def _normalize_fc_form(raw, visit_type):
    """Normalize function-calling form output to the expected schema structure.

    The model sometimes uses free-form keys (``blood_pressure: "110/70"``)
    instead of schema keys (``bp_systolic: 110, bp_diastolic: 70``), or nests
    data differently. This flattens the tool-call arguments and remaps them
    onto the top-level sections declared by the visit's schema in SCHEMAS.
    It then applies field-specific coercions: BP split, infant/child and
    vitals weight, hemoglobin, gestational weeks and child age.

    Args:
        raw: ``extract_form`` tool-call arguments. Anything that is not a
            non-empty dict is returned unchanged.
        visit_type: Visit type; resolved through VISIT_TYPE_MAP with an
            "anc_visit" fallback.

    Returns:
        A new dict keyed by schema section names.
    """
    if not raw or not isinstance(raw, dict):
        return raw

    def _collect(d, prefix=""):
        # Flatten nested dicts into dotted keys ("vitals.weight_kg"). Leaf
        # values are also stored under their bare key ("weight_kg") so loosely
        # nested output can still be matched; when several sections share a
        # leaf name, the last one seen owns the bare key.
        items = {}
        if isinstance(d, dict):
            for k, v in d.items():
                key = f"{prefix}.{k}" if prefix else k
                if isinstance(v, dict):
                    items.update(_collect(v, key))
                else:
                    items[key] = v
                    items[k] = v
        return items

    def _to_kg(value):
        return float(str(value).replace("kg", "").replace("KG", "").strip())

    def _to_hb(value):
        return float(str(value).replace("g/dl", "").replace("gm", "").strip())

    def _first_int(value):
        # Raises AttributeError when the value contains no digits.
        return int(re.search(r"(\d+)", str(value)).group(1))

    def _needs_fill(section, field):
        # Fill when the field is missing/None, or when the generic section
        # walk copied a raw string (e.g. "55 kg") that should become a number.
        current = section.get(field)
        return current is None or isinstance(current, str)

    flat = _collect(raw)
    schema_key = VISIT_TYPE_MAP.get(visit_type, "anc_visit")
    schema = SCHEMAS.get(schema_key, SCHEMAS.get("anc_visit", {}))
    schema_props = schema.get("properties", {})
    result = {}

    # Walk the schema's top-level sections and fill them from flat values.
    for section_name, section_def in schema_props.items():
        section_type = section_def.get("type")
        if section_type == "object":
            section_data = {}
            for field_name in section_def.get("properties", {}):
                # Exact dotted key first; fall back to the bare leaf key only
                # when the exact key is absent/None. (An `or` here let a
                # same-named field from another section override a falsy
                # exact value such as False or 0.)
                val = flat.get(f"{section_name}.{field_name}")
                if val is None:
                    val = flat.get(field_name)
                if val is not None:
                    section_data[field_name] = val
            if section_data:
                result[section_name] = section_data
        elif section_type == "array":
            val = flat.get(section_name)
            result[section_name] = val if isinstance(val, list) else []
        else:
            val = flat.get(section_name)
            if val is not None:
                result[section_name] = val

    # ── BP splitting: "110/70" (units allowed) → bp_systolic / bp_diastolic ──
    vitals = result.get("vitals", {})
    bp_raw = flat.get("blood_pressure") or flat.get("bp") or flat.get("vitals.blood_pressure")
    bp_match = re.search(r"(\d+)\s*/\s*(\d+)", bp_raw) if isinstance(bp_raw, str) else None
    if bp_match:
        if _needs_fill(vitals, "bp_systolic"):
            vitals["bp_systolic"] = int(bp_match.group(1))
        if _needs_fill(vitals, "bp_diastolic"):
            vitals["bp_diastolic"] = int(bp_match.group(2))

    # ── Infant/child weight normalization (before vitals, to avoid misplacement) ──
    # PNC: infant_assessment.weight_kg, Delivery: infant.birth_weight_kg
    infant_weight_sources = (
        ("infant_assessment", "weight_kg",
         ("infant_assessment.weight_kg", "infant_assessment.weight")),
        ("infant", "birth_weight_kg",
         ("infant.birth_weight_kg", "infant.birth_weight", "infant.weight")),
        ("child", "weight_kg",
         ("child.weight_kg", "child.weight")),
        ("growth_assessment", "weight_kg",
         ("growth_assessment.weight_kg", "growth_assessment.weight")),
    )
    for iw_section, iw_field, iw_keys in infant_weight_sources:
        for iw_key in iw_keys:
            iw_val = flat.get(iw_key)
            if iw_val is None:
                continue
            section = result.get(iw_section, {})
            if isinstance(section, dict) and _needs_fill(section, iw_field):
                try:
                    section[iw_field] = _to_kg(iw_val)
                except (ValueError, TypeError):
                    pass
                else:
                    result[iw_section] = section
            break

    # ── Vitals weight normalization: "55 kg" → 55.0 ──
    # Only vitals-specific keys, to avoid grabbing an infant weight.
    for wkey in ("vitals.weight", "vitals.weight_kg"):
        wval = flat.get(wkey)
        if wval is None:
            continue
        try:
            num = _to_kg(wval)
        except (ValueError, TypeError):
            pass
        else:
            if _needs_fill(vitals, "weight_kg"):
                vitals["weight_kg"] = num
        break

    # ── Hemoglobin normalization ──
    for hkey in ("hemoglobin", "hemoglobin_gm_percent", "hb", "lab_results.hemoglobin"):
        hval = flat.get(hkey)
        if hval is None:
            continue
        try:
            num = _to_hb(hval)
        except (ValueError, TypeError):
            pass
        else:
            if _needs_fill(vitals, "hemoglobin_gm_percent"):
                vitals["hemoglobin_gm_percent"] = num
        break

    if vitals:
        result["vitals"] = vitals

    # ── Gestational weeks normalization ──
    pregnancy = result.get("pregnancy", {})
    if _needs_fill(pregnancy, "gestational_weeks"):
        for gkey in ("gestational_weeks", "gestational_age", "pregnancy.gestational_age",
                     "pregnancy.gestational_weeks", "gestation_weeks"):
            gval = flat.get(gkey)
            if gval is None:
                continue
            try:
                pregnancy["gestational_weeks"] = _first_int(gval)
            except (ValueError, TypeError, AttributeError):
                pass
            break
    if pregnancy:
        result["pregnancy"] = pregnancy

    # ── Child age normalization ──
    # Only for schemas that have a "child" section (or the child_health
    # form). The bare "age" leaf can be the patient's own age (patient.age),
    # which previously sprouted a spurious "child" section on ANC/PNC forms.
    if "child" in schema_props or schema_key == "child_health":
        for akey in ("age_months", "child.age_months", "age"):
            aval = flat.get(akey)
            if aval is None:
                continue
            child = result.get("child", {})
            if isinstance(child, dict) and _needs_fill(child, "age_months"):
                try:
                    child["age_months"] = _first_int(aval)
                except (ValueError, TypeError, AttributeError):
                    pass
                else:
                    result["child"] = child
            break

    return result
def _run_inference_unsloth(system_prompt, user_prompt):
    """Run one greedy generation on the local Unsloth model (CUDA).

    Slower than Ollama but has no Ollama dependency. Generates up to 768 new
    tokens without sampling, decodes only the newly generated tokens, and
    returns ``{"raw", "parsed", "time_s"}`` where ``parsed`` comes from
    _parse_json_response().
    """
    import torch

    model, tokenizer = load_model()
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    prompt_text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(text=[prompt_text], return_tensors="pt").to("cuda")
    prompt_length = encoded["input_ids"].shape[-1]

    started = time.time()
    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=768, do_sample=False)
    elapsed = time.time() - started

    # Drop the echoed prompt tokens; keep only the completion.
    completion = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
    return {"raw": completion, "parsed": _parse_json_response(completion), "time_s": elapsed}
def _parse_json_response(response):
    """Parse JSON from a model response, tolerating markdown fences and quirks.

    Cleanup, in order:
      1. Strip a BOM and surrounding whitespace.
      2. Drop ```json fences.
      3. Prepend a missing leading "{".
      4. Skip prose before the first "{" or "[".
      5. Collapse doubled quotes around values (""x"" -> "x") without
         touching genuine empty strings ("").
      6. Drop trailing commas before a closing bracket.

    If the result still fails to parse, it retries on successively shorter
    prefixes ending in "}" or "]", looking back at most 200 characters.

    Args:
        response: Raw model output text; may be None or empty.

    Returns:
        The decoded JSON value, or None when nothing parseable was found.
    """
    if not response:
        print("[DEBUG] empty model response")
        return None
    print(f"[DEBUG] raw response repr (first 80): {repr(response[:80])}")

    # BOM first, then whitespace, so "\ufeff  ```json" is handled too.
    clean = response.lstrip("\ufeff").strip()
    clean = re.sub(r'^`{3,}\s*(?:json)?\s*[\r\n]*', '', clean, flags=re.IGNORECASE)
    clean = re.sub(r'[\r\n]*`{3,}\s*$', '', clean)
    clean = clean.strip()

    # Model sometimes drops the opening brace: '"key": ...}'.
    if clean.startswith('"'):
        clean = "{" + clean
    # Skip leading prose before the first object/array.
    if clean and clean[0] not in "{[":
        starts = [pos for pos in (clean.find("{"), clean.find("[")) if pos >= 0]
        if starts:
            first_brace = min(starts)
            print(f"[DEBUG] skipped leading junk: {repr(clean[:first_brace])}")
            clean = clean[first_brace:]

    # Doubled quotes around a value: ""value"" -> "value". The lookahead stops
    # a match from starting on an empty string ("" followed by , ] } :), which
    # would otherwise merge adjacent empty strings such as ["", ""].
    clean = re.sub(r'"{2,}(?![,\]}:])([^"]+)"{2,}', r'"\1"', clean)
    # Leading doubled quote after a key: : ""value" -> : "value", but leave
    # an empty value (: "" followed by whitespace or , ] }).
    clean = re.sub(r'(?<=: )"{2,}(?![\s,\]}])', '"', clean)
    # Trailing doubled quote before a delimiter: "value"", -> "value", but
    # leave an empty string that directly follows ":", "[" or ",".
    clean = re.sub(r'(?<![:\[,]\s)(?<![:\[,])"{2,}(?=\s*[,}\]])', '"', clean)
    # Trailing commas before a closing bracket.
    clean = re.sub(r',\s*([}\]])', r'\1', clean)
    print(f"[DEBUG] cleaned JSON (first 120): {repr(clean[:120])}")

    try:
        return json.loads(clean)
    except json.JSONDecodeError as e:
        print(f"[DEBUG] JSON parse failed: {e}")

    # Recovery: trailing garbage after a complete value.
    for end_pos in range(len(clean), max(0, len(clean) - 200), -1):
        if clean[end_pos - 1] in "}]":
            try:
                parsed = json.loads(clean[:end_pos])
            except json.JSONDecodeError:
                continue
            print(f"[DEBUG] recovered JSON by truncating at pos {end_pos}")
            return parsed
    print(f"[DEBUG] FULL raw response ({len(response)} chars):\n{response}\n---END---")
    return None
| # ============================================================ | |
| # EXTRACTION PIPELINE | |
| # ============================================================ | |
def _contains_keyword(text, keyword):
    """Return True if *keyword* occurs in the already lower-cased *text*.

    A keyword that starts with an ASCII letter must not directly follow
    another ASCII letter. This stops "anc" from firing inside
    "ambulance"/"distance", "edd" inside "wedding" and "cord" inside "record".
    Suffixes are still allowed, so "breastfeed" matches "breastfeeding".
    Keywords that start with any other character (Devanagari) keep plain
    substring matching.
    """
    first = keyword[:1]
    if first.isascii() and first.isalpha():
        return re.search(r"(?<![a-z])" + re.escape(keyword), text) is not None
    return keyword in text


def detect_visit_type(transcript):
    """Heuristic visit type detection from transcript content.

    Categories are checked in priority order: delivery, then ANC, then PNC,
    then child health. Anything unmatched defaults to "anc_visit".

    Returns:
        One of "delivery", "anc_visit", "pnc_visit", "child_health".
    """
    t = transcript.lower()

    # Delivery — check first, most specific keywords
    delivery_kw = ("डिलीवरी हो गई", "डिलीवरी हुई", "delivery हुई",
                   "डिलीवरी कब हुई", "delivery कब",
                   "जन्म हुआ", "पैदा हुआ", "प्रसव हुआ",
                   "लड़का हुआ", "लड़की हुई", "लड़की हुआ",
                   "घर पर ही हो गया", "घर पर हुई", "घर पर हुआ",
                   "ऑपरेशन से हुई", "caesarean", "सिजेरियन",
                   "जन्म का वजन", "birth weight", "birth_weight",
                   "जन्म के समय", "normal delivery", "दाई ने")
    # ANC — check before PNC/child (broad keywords like टीका overlap)
    anc_kw = ("गर्भ", "प्रेग्नेंसी", "pregnancy", "anc", "पेट में बच्चा",
              "गर्भवती", "हफ्ते की", "हफ्ते हो", "महीने की",
              "lmp", "edd", "bp चेक", "hb ", "ifa", "tt का टीका",
              "बच्चे की हलचल", "fetal", "डिलीवरी कहाँ", "डिलीवरी के लिए",
              "जन्म के लिए तैयारी", "birth preparedness")
    # PNC — postpartum mother/newborn care
    pnc_kw = ("नवजात", "newborn", "दूध पीना", "दूध नहीं पीता", "दूध पीता",
              "दूध पी रहा", "दूध नहीं पी", "दूध पिला",
              "नाभि", "cord", "नाल", "स्तनपान",
              "breastfeed", "imnci", "hbnc", "डिलीवरी के बाद",
              "डिलीवरी को", "delivery को", "pnc",
              "खून बहना", "खून आ रहा", "pad ", "पैड ")
    # Child health — older infants/children.
    # Note: "बच्चे को" is deliberately absent — it fires falsely on ANC
    # danger-talk like "तुम्हारा और बच्चे को खतरा" (preeclampsia warning to
    # mother). "child" is absent too — too generic, it appears in
    # delivery/PNC counseling.
    child_kw = ("बच्चा कैसा", "बच्चा कैसी", "बच्चे का वजन", "बच्ची का वजन",
                "टीका लग", "vaccine", "deworming", "vitamin a", "hbyc",
                "महीने का", "महीने है", "दस्त", "diarrhea",
                "खाता है", "खेलता है", "आँखें धँसी",
                "सुस्त है", "सुस्त हो", "बहुत सुस्त")

    for visit_type, keywords in (("delivery", delivery_kw),
                                 ("anc_visit", anc_kw),
                                 ("pnc_visit", pnc_kw),
                                 ("child_health", child_kw)):
        if any(_contains_keyword(t, kw) for kw in keywords):
            return visit_type
    return "anc_visit"
def build_trimmed_danger_schema():
    """Return a compact danger-sign JSON schema without the checklists.

    Only visit_type, danger_signs (an array of sign/category/clinical_value/
    utterance_evidence objects) and referral_decision are requested, which
    keeps the model's output much smaller. A fresh dict is built on every
    call, so callers may mutate the result.
    """
    sign_properties = {
        "sign": {"type": "string"},
        "category": {"type": "string", "enum": ["immediate_referral", "urgent_care", "monitor_closely"]},
        "clinical_value": {"type": ["string", "null"]},
        "utterance_evidence": {"type": "string", "description": "REQUIRED: exact verbatim quote"},
    }
    sign_item = {
        "type": "object",
        "properties": sign_properties,
        "required": ["sign", "category", "utterance_evidence"],
    }
    referral_decision = {
        "type": "object",
        "properties": {
            "decision": {"type": "string", "enum": ["refer_immediately", "refer_within_24h", "continue_monitoring", "routine_followup"]},
            "reason": {"type": "string"},
        },
        "required": ["decision", "reason"],
    }
    top_level = {
        "visit_type": {
            "type": "string",
            "enum": ["antenatal", "postnatal_mother", "newborn", "child_under5"],
        },
        "danger_signs": {
            "type": "array",
            "description": "Detected danger signs. Empty array [] if none found.",
            "items": sign_item,
        },
        "referral_decision": referral_decision,
    }
    return {
        "type": "object",
        "properties": top_level,
        "required": ["visit_type", "danger_signs", "referral_decision"],
    }
# Maternal danger sign names that map to checklist fields.
# derive_checklists() marks a field "detected" when any of its keywords is a
# case-insensitive substring of the detected signs' names + evidence text.
# NOTE(review): matching is plain substring, so overlaps exist — e.g. "स्राव"
# (discharge) is contained in "रक्तस्राव" (bleeding), "cord" matches inside
# "recorded", and "fever"/"breathing"/"convulsion" appear in both the maternal
# and newborn tables. Verify these cross-hits are acceptable.
MATERNAL_CHECKLIST_SIGNS = {
    "severe_vaginal_bleeding": ["vaginal bleeding", "severe bleeding", "रक्तस्राव", "खून"],
    "convulsions": ["convulsion", "seizure", "दौरा", "अकड़न"],
    # preeclampsia is the diagnostic name the LLM may emit instead of the symptom triad —
    # treat its presence as an explicit detection of severe headache + blurred vision
    "severe_headache_blurred_vision": [
        "headache", "blurred vision", "सिरदर्द", "धुंधला",
        "preeclampsia", "pre-eclampsia", "प्रीक्लिम्सिया", "प्री-एक्लेम्पसिया",
    ],
    "high_fever": ["high fever", "fever", "बुखार", "तेज़ बुखार"],
    "severe_abdominal_pain": ["abdominal pain", "पेट दर्द", "पेट में दर्द"],
    "fast_difficult_breathing": ["breathing", "साँस", "सांस"],
    # "सूज" matches the verb-stem (पैर सूज रहे हैं) which "सूजन" does not
    "swelling_face_hands": ["swelling", "edema", "सूजन", "सूज"],
    "reduced_fetal_movement": ["fetal movement", "reduced movement", "हलचल कम", "हिलता नहीं"],
    "water_break_prom": ["water break", "पानी टूट", "झिल्ली"],
    "foul_vaginal_discharge": ["discharge", "बदबूदार", "स्राव"],
}
# Newborn counterpart, matched the same way by derive_checklists().
NEWBORN_CHECKLIST_SIGNS = {
    "not_feeding_well": ["not feeding", "feeding", "दूध नहीं", "दूध पीना"],
    "convulsions": ["convulsion", "seizure", "दौरा"],
    "fast_breathing_gte60": ["fast breathing", "breathing", "साँस तेज़"],
    "severe_chest_indrawing": ["chest indrawing", "छाती धँसना"],
    "high_temperature": ["high temperature", "fever", "बुखार", "तापमान"],
    "low_temperature": ["low temperature", "ठंडा", "हाइपोथर्मिया"],
    "no_movement": ["no movement", "सुस्त", "हिलता नहीं"],
    "jaundice": ["jaundice", "पीलिया"],
    "umbilicus_red_pus": ["umbilicus", "नाभि", "cord"],
}
| def derive_checklists(danger_signs, visit_type): | |
| """Derive maternal/newborn checklists from the danger_signs array.""" | |
| maternal_ck = {k: "not_assessed" for k in MATERNAL_CHECKLIST_SIGNS} | |
| newborn_ck = {k: "not_assessed" for k in NEWBORN_CHECKLIST_SIGNS} | |
| if not danger_signs: | |
| return maternal_ck, newborn_ck | |
| # Check each detected sign against checklist keywords | |
| detected_signs_text = " ".join( | |
| f"{s.get('sign', '')} {s.get('utterance_evidence', '')}".lower() | |
| for s in danger_signs | |
| ) | |
| for field, keywords in MATERNAL_CHECKLIST_SIGNS.items(): | |
| if any(kw.lower() in detected_signs_text for kw in keywords): | |
| maternal_ck[field] = "detected" | |
| else: | |
| maternal_ck[field] = "not_detected" | |
| for field, keywords in NEWBORN_CHECKLIST_SIGNS.items(): | |
| if any(kw.lower() in detected_signs_text for kw in keywords): | |
| newborn_ck[field] = "detected" | |
| else: | |
| newborn_ck[field] = "not_detected" | |
| return maternal_ck, newborn_ck | |
# Evidence that the number 30 was actually spoken. Digits must stand alone
# (\d also covers Devanagari digits), so BP "130/80" does not count, and
# "तीस" must be a whole Devanagari word, so "तीसरा" (third) does not count.
_THIRTY_MENTION_PATTERNS = (
    re.compile(r"(?<!\d)(?:30|३०)(?!\d)"),
    re.compile(r"(?<![\u0900-\u097F])तीस(?![\u0900-\u097F])"),
    re.compile(r"\bthirty\b", re.IGNORECASE),
)


def _is_age_thirty(age):
    """Return True if `age` is numerically 30 (accepts 30, 30.0 or "30")."""
    if age is None or isinstance(age, bool):
        return False
    try:
        return float(age) == 30
    except (TypeError, ValueError):
        return False


def validate_form_output(parsed, transcript):
    """Post-extraction validation: strip hallucinated fields, apply range checks.

    Common hallucination patterns on audio transcripts:
      - patient.name = "दीदी" / "बहन" / "Patient" (generic address, not a name)
      - patient.age = 30 (model's default guess)
      - lab_results.blood_group / hiv_status invented when not discussed

    `parsed` is mutated in place and returned. Non-dict input is returned
    unchanged. A None/non-string `transcript` is treated as empty, so nothing
    counts as "mentioned". A non-dict ``patient`` or ``lab_results`` value is
    left untouched.

    An age of 30 is kept only when 30 is spoken as a standalone number
    ("30", "३०", "तीस", "thirty").

    Vital-sign fields outside their plausible range (or NaN) are set to None,
    both at the top level and inside the vitals / pregnancy / anc_details /
    newborn sections.
    """
    if not isinstance(parsed, dict):
        return parsed
    text = transcript if isinstance(transcript, str) else ""
    t_lower = text.lower()

    patient = parsed.get("patient")
    if isinstance(patient, dict):
        # -- Name hallucination: generic Hindi address terms --
        fake_names = {"दीदी", "बहन", "बहनजी", "patient", "दी दी", "didi", "bahen"}
        name = patient.get("name") or patient.get("patient_name")
        if isinstance(name, str) and name.strip().lower() in fake_names:
            for key in ("name", "patient_name"):
                if key in patient:
                    patient[key] = None
            print(f"[VALIDATE] Stripped hallucinated name: {name}")

        # -- Age hallucination: exactly 30 when not mentioned --
        age = patient.get("age") or patient.get("patient_age")
        if _is_age_thirty(age) and not any(p.search(text) for p in _THIRTY_MENTION_PATTERNS):
            for key in ("age", "patient_age"):
                if key in patient:
                    patient[key] = None
            print(f"[VALIDATE] Stripped hallucinated age: {age}")

    # -- Lab results hallucination: blood_group, HIV when not discussed --
    lab = parsed.get("lab_results")
    if isinstance(lab, dict):
        blood_groups = {"a+", "a-", "b+", "b-", "ab+", "ab-", "o+", "o-"}
        bg = lab.get("blood_group")
        if bg and str(bg).strip().lower() in blood_groups:
            bg_keywords = ("blood group", "ब्लड ग्रुप", "खून का ग्रुप", "रक्त समूह")
            if not any(kw in t_lower for kw in bg_keywords):
                lab["blood_group"] = None
                print(f"[VALIDATE] Stripped hallucinated blood_group: {bg}")

        hiv = lab.get("hiv_status") or lab.get("hiv")
        if hiv and str(hiv).strip().lower() in ("negative", "positive", "नेगेटिव", "पॉजिटिव"):
            if not any(kw in t_lower for kw in ("hiv", "एचआईवी", "एड्स")):
                for key in ("hiv_status", "hiv"):
                    if key in lab:
                        lab[key] = None
                print(f"[VALIDATE] Stripped hallucinated HIV: {hiv}")

    # -- Range checks on vital signs --
    ranges = {
        "bp_systolic": (60, 250), "bp_diastolic": (30, 150),
        "weight_kg": (1, 200), "hemoglobin_gm_percent": (3, 20),
        "gestational_weeks": (1, 45), "temperature_f": (90, 110),
    }
    sections = [parsed] + [
        parsed.get(key) for key in ("vitals", "pregnancy", "anc_details", "newborn")
    ]
    for section in sections:
        if not isinstance(section, dict):
            continue
        for field, (lo, hi) in ranges.items():
            val = section.get(field)
            if val is None:
                continue
            try:
                num = float(val)
            except (ValueError, TypeError):
                continue  # non-numeric text like "120/80" is left as-is
            # `not lo <= num <= hi` also rejects NaN, which fails both comparisons.
            if not lo <= num <= hi:
                section[field] = None
                print(f"[VALIDATE] Out-of-range {field}={val} (valid: {lo}-{hi})")
    return parsed
def extract_form(transcript, visit_type):
    """Run the form-extraction prompt for one visit and sanity-check the output.

    The visit type is resolved to a schema key through VISIT_TYPE_MAP. Either
    lookup missing falls back to the "anc_visit" schema. When the model
    returns a dict under "parsed", it is cleaned by validate_form_output().

    Returns:
        The dict produced by run_inference(), with "parsed" possibly replaced.
    """
    schema_key = VISIT_TYPE_MAP.get(visit_type, "anc_visit")
    schema = SCHEMAS.get(schema_key, SCHEMAS["anc_visit"])
    schema_json = json.dumps(schema, ensure_ascii=False)
    prompt = "".join([
        "Extract structured data from this ASHA home visit conversation:\n\n",
        f"{transcript}\n\n",
        f"Output JSON schema:\n{schema_json}",
    ])

    result = run_inference(FORM_SYSTEM_PROMPT, prompt)
    parsed = result.get("parsed")
    if parsed and isinstance(parsed, dict):
        result["parsed"] = validate_form_output(parsed, transcript)
    return result
def extract_danger_signs(transcript, visit_type):
    """Extract danger signs using the trimmed schema (no checklists) + post-validation.

    The model call uses DANGER_SYSTEM_PROMPT and the schema from
    build_trimmed_danger_schema(). Any reported danger_signs are filtered by
    _validate_fc_danger_signs(), the same evidence checks the
    function-calling path in extract_all() uses:
      * evidence length,
      * generic ASHA phrases,
      * normal readings,
      * presence in the transcript,
      * identical-evidence collapse.
    If no sign survives, the referral decision is forced to routine_followup.
    The maternal/newborn checklists are always derived in code via
    derive_checklists() rather than generated by the model.

    Args:
        transcript: visit transcript text.
        visit_type: visit type label; echoed into the prompt and passed on
            to derive_checklists().

    Returns:
        The dict from run_inference(). result["parsed"] is updated in place
        when it is a non-empty dict and returned untouched otherwise.
    """
    schema = build_trimmed_danger_schema()
    user_prompt = (
        "Analyze this ASHA home visit conversation for danger signs.\n\n"
        f"Visit type: {visit_type}\n\n"
        f"{transcript}\n\n"
        f"Output JSON schema:\n{json.dumps(schema, ensure_ascii=False)}"
    )
    result = run_inference(DANGER_SYSTEM_PROMPT, user_prompt)

    parsed = result.get("parsed")
    if not parsed or not isinstance(parsed, dict):
        return result

    if "danger_signs" in parsed:
        raw_signs = parsed["danger_signs"]
        # null / non-list output would crash iteration; treat it as "no signs".
        if not isinstance(raw_signs, list):
            raw_signs = []
        # Shared with the function-calling path so both modes apply identical
        # evidence checks (the helper also tolerates null evidence, which the
        # previous inline copy crashed on while formatting its debug message).
        validated_signs = _validate_fc_danger_signs(raw_signs, transcript)
        parsed["danger_signs"] = validated_signs
        if not validated_signs:
            parsed["referral_decision"] = {
                "decision": "routine_followup",
                "reason": "No danger signs detected in conversation",
                "evidence_utterances": [],
            }

    # Derive checklists programmatically (instead of model generating them)
    maternal_ck, newborn_ck = derive_checklists(parsed.get("danger_signs") or [], visit_type)
    parsed["maternal_danger_signs_checklist"] = maternal_ck
    parsed["newborn_danger_signs_checklist"] = newborn_ck
    return result
| def _validate_fc_danger_signs(danger_signs, transcript): | |
| """Post-validate danger signs from function calling — same logic as extract_danger_signs.""" | |
| GENERIC_PHRASES = [ | |
| "कोई तकलीफ़ हो तो फ़ोन कर दीजिए", | |
| "कोई तकलीफ हो तो फोन कर दीजिए", | |
| "कोई समस्या हो तो तुरंत बताइए", | |
| "कोई समस्या हो तो फोन करें", | |
| "कोई दिक्कत हो तो", | |
| "अगली बार आऊँगी", | |
| "अगली विज़िट", | |
| "ठीक है दीदी, धन्यवाद", | |
| "ठीक है दीदी", | |
| ] | |
| NORMAL_INDICATORS = [ | |
| "110/70", "120/80", "110/80", "118/76", "108/72", | |
| "बिल्कुल ठीक", "सामान्य", "नॉर्मल", "अच्छा है", "ठीक है", | |
| "बिल्कुल सामान्य", | |
| ] | |
| validated = [] | |
| norm_transcript = re.sub(r'\s+', ' ', transcript.strip()) | |
| for sign in danger_signs: | |
| evidence = sign.get("utterance_evidence") or sign.get("evidence", "") | |
| if not evidence or len(evidence) < 10: | |
| print(f"[DEBUG] FC dropped sign '{sign.get('sign','')}': evidence too short") | |
| continue | |
| norm_evidence = re.sub(r'\s+', ' ', evidence.strip()) | |
| if any(phrase in norm_evidence for phrase in GENERIC_PHRASES): | |
| print(f"[DEBUG] FC dropped sign '{sign.get('sign','')}': generic phrase") | |
| continue | |
| if any(indicator in norm_evidence for indicator in NORMAL_INDICATORS): | |
| print(f"[DEBUG] FC dropped sign '{sign.get('sign','')}': normal vital") | |
| continue | |
| # Check evidence exists in transcript | |
| found = False | |
| if norm_evidence in norm_transcript: | |
| found = True | |
| elif len(norm_evidence) >= 20: | |
| min_chunk = min(30, len(norm_evidence)) | |
| for i in range(0, len(norm_evidence) - min_chunk + 1): | |
| if norm_evidence[i:i + min_chunk] in norm_transcript: | |
| found = True | |
| break | |
| if found: | |
| validated.append(sign) | |
| else: | |
| print(f"[DEBUG] FC dropped sign '{sign.get('sign','')}': evidence not in transcript") | |
| # Same-evidence dedup | |
| if len(validated) > 1: | |
| evidences = set((s.get("utterance_evidence") or s.get("evidence", "")).strip() for s in validated) | |
| if len(evidences) == 1: | |
| print(f"[DEBUG] FC dropped all {len(validated)} signs: same evidence") | |
| validated = [] | |
| dropped = len(danger_signs) - len(validated) | |
| if dropped: | |
| print(f"[DEBUG] FC post-validation dropped {dropped}/{len(danger_signs)} danger signs") | |
| return validated | |
def _merge_metadata_block(form, key, updates):
    """Write `updates` into form[key], creating the block when missing or null.

    A non-dict value under `key` (string, list, ...) is left untouched.
    Nothing is written, and no empty block is created, when `updates` is empty.
    """
    if not updates:
        return
    block = form.get(key)
    if block is None:
        form[key] = dict(updates)
    elif isinstance(block, dict):
        block.update(updates)


def _metadata_age_months(age, age_unit):
    """Convert an ASHA-entered age to whole months, or return None if unusable.

    age_unit "years" multiplies by 12, so fractional years work (1.5 -> 18).
    "" or "months" truncates to whole months. Any other unit, or a value that
    is not a finite number, yields None.
    """
    if age is None:
        return None
    try:
        value = float(age)
        if age_unit == "years":
            return int(value * 12)
        if age_unit in ("", "months"):
            return int(value)
    except (TypeError, ValueError, OverflowError):
        # non-numeric text, NaN (ValueError from int) or inf (OverflowError)
        return None
    return None


def apply_metadata(form, visit_type, metadata):
    """Merge ASHA-entered patient identifier metadata into the LLM-extracted form.

    Metadata keys are schema-agnostic (patient_name, patient_age, age_unit,
    patient_sex, asha_id, visit_date, patient_mobile). This function
    overrides whichever schema-specific fields make sense for the visit type
    and leaves other LLM output untouched:

      * anc_visit    -> form["patient"]: name, age (only when age_unit is ""
                        or "years", value passed through as given), mobile
      * child_health -> form["child"]: name, age_months (years are
                        converted x12), sex ("male" / "female" only)

    The target block is created when the model omitted it or set it to null,
    so entered identifiers are not silently lost. A blank age string is
    ignored rather than overwriting the extracted value, and a non-numeric
    child age is skipped instead of raising.

    PNC and delivery schemas have no patient block, so the metadata is
    preserved only in the envelope returned alongside the form (see
    extract_all).

    Returns:
        `form`, mutated in place. Returned unchanged when it is falsy, not
        a dict, or `metadata` is empty.
    """
    if not form or not isinstance(form, dict) or not metadata:
        return form
    name = metadata.get("patient_name") or None
    age = metadata.get("patient_age")
    if isinstance(age, str) and not age.strip():
        age = None  # an empty form field from the UI is "not entered"
    age_unit = (metadata.get("age_unit") or "").lower()
    sex = (metadata.get("patient_sex") or "").lower() or None
    mobile = metadata.get("patient_mobile") or None

    if visit_type == "anc_visit":
        updates = {}
        if name:
            updates["name"] = name
        if age is not None and age_unit in ("", "years"):
            updates["age"] = age
        if mobile:
            updates["mobile"] = mobile
        _merge_metadata_block(form, "patient", updates)
    elif visit_type == "child_health":
        updates = {}
        if name:
            updates["name"] = name
        months = _metadata_age_months(age, age_unit)
        if months is not None:
            updates["age_months"] = months
        if sex in ("male", "female"):
            updates["sex"] = sex
        _merge_metadata_block(form, "child", updates)
    # pnc_visit and delivery — no schema-level patient block; envelope-only.
    return form
def extract_all(transcript, visit_type, metadata=None):
    """Run the full extraction pipeline (form + danger signs) for one transcript.

    With USE_OLLAMA and USE_FUNCTION_CALLING both enabled, this is a hybrid:
      * the form comes from extract_form() (format="json");
      * danger signs and referral come from _run_danger_fc() (function
        calling), filtered by _validate_fc_danger_signs(), with checklists
        derived by derive_checklists().
    Otherwise both halves fall back to separate format="json" calls
    (extract_form() and extract_danger_signs()).

    Args:
        transcript: visit transcript text.
        visit_type: visit type key.
        metadata: optional dict of patient identifiers entered by the ASHA
            before recording. It is merged into the form via apply_metadata()
            and echoed in the envelope (None when empty).

    Returns:
        dict with "form", "danger", "metadata", "tool_calls" and "timing"
        ({"form_s", "danger_s"}, seconds rounded to one decimal place).
    """
    echoed_metadata = metadata or None

    if not USE_OLLAMA or not USE_FUNCTION_CALLING:
        # Fallback: two separate json-mode calls
        form_res = extract_form(transcript, visit_type)
        danger_res = extract_danger_signs(transcript, visit_type)
        merged_form = apply_metadata(form_res.get("parsed"), visit_type, metadata)
        return {
            "form": merged_form,
            "danger": danger_res.get("parsed"),
            "metadata": echoed_metadata,
            "tool_calls": [],
            "timing": {
                "form_s": round(form_res.get("time_s", 0), 1),
                "danger_s": round(danger_res.get("time_s", 0), 1),
            },
        }

    # Form extraction via format="json", timed on the wall clock.
    started_at = time.time()
    form_res = extract_form(transcript, visit_type)
    form_seconds = time.time() - started_at

    # Danger signs + referral via function calling, then evidence filtering.
    fc = _run_danger_fc(transcript, visit_type)
    kept = _validate_fc_danger_signs(fc["danger_signs"], transcript)
    referral_raw = fc["referral"]
    cited = [s.get("utterance_evidence") or s.get("evidence", "") for s in kept]

    if not kept:
        referral_decision = {
            "decision": "routine_followup",
            "reason": "No danger signs detected in conversation",
            "evidence_utterances": [],
        }
    elif referral_raw:
        # Tool-call urgency values -> referral_decision enum values.
        urgency_to_decision = {
            "immediate": "refer_immediately",
            "within_24h": "refer_within_24h",
            "routine": "continue_monitoring",
        }
        referral_decision = {
            "decision": urgency_to_decision.get(referral_raw.get("urgency"), "continue_monitoring"),
            "reason": referral_raw.get("reason", ""),
            "evidence_utterances": cited,
            "recommended_facility": referral_raw.get("facility"),
        }
    else:
        referral_decision = {
            "decision": "continue_monitoring",
            "reason": "Danger signs detected but no explicit referral issued",
            "evidence_utterances": cited,
        }

    # Reshape surviving signs into the danger-sign schema used elsewhere.
    normalized_signs = [
        {
            "sign": s.get("sign", ""),
            "category": s.get("category", "monitor_closely"),
            "clinical_value": s.get("clinical_value"),
            "utterance_evidence": s.get("utterance_evidence") or s.get("evidence", ""),
        }
        for s in kept
    ]
    maternal_ck, newborn_ck = derive_checklists(normalized_signs, visit_type)

    return {
        "form": apply_metadata(form_res.get("parsed"), visit_type, metadata),
        "danger": {
            "visit_type": visit_type,
            "danger_signs": normalized_signs,
            "referral_decision": referral_decision,
            "maternal_danger_signs_checklist": maternal_ck,
            "newborn_danger_signs_checklist": newborn_ck,
        },
        "metadata": echoed_metadata,
        "tool_calls": fc["tool_calls"],
        "timing": {
            "form_s": round(form_seconds, 1),
            "danger_s": round(fc["time_s"], 1),
        },
    }