Spaces:
Sleeping
Sleeping
"""
MedScribe v2 — Function Calling Test (Gate 2)

Tests Gemma 4 E4B's native function calling with the MCTS extraction schemas.

Tests:
  1. Schema loading and validation
  2. Function calling with a sample Hindi transcript
  3. Output validation against JSON schema
  4. Danger sign evidence check (each danger sign must cite an utterance)
  5. Ollama function calling (text-only path)
  6. Transformers function calling (for audio pipeline)

Schemas are read from configs/schemas/ and results are written under
data/temp/; both paths are relative to the current working directory, so run
the script from the project root.

Usage:
  python scripts/02_test_function_calling.py --mode ollama
  python scripts/02_test_function_calling.py --mode transformers
  python scripts/02_test_function_calling.py --validate-only   # schema validation only

Other options:
  --normal           use the low-risk transcript (expect no danger signs)
  --transcript FILE  use a custom UTF-8 transcript file
  --device DEVICE    device for the Transformers path (default: cuda)
"""
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
# ── Sample Test Data ────────────────────────────────────────────────────────
# Two synthetic ASHA home-visit conversations in Hindi, fed to the model as input.

# High-risk antenatal visit (default transcript; main() announces it as
# "HIGH-RISK ... expect danger signs"). The patient reports dizziness, persistent
# swelling of feet and hands, BP 145/95 (previously 130/85), headache since last
# night, blurred vision, ~33-34 weeks of pregnancy, a 3 kg weight gain in two
# weeks, and reduced fetal movement.
SAMPLE_TRANSCRIPT_HINDI = """
ASHA: नमस्ते बहन जी, कैसी तबीयत है आपकी?
Patient: नमस्ते दीदी, ठीक हूँ, बस थोड़ा चक्कर आता है कभी-कभी।
ASHA: पिछली बार जब आई थी तब आपने बताया था कि पैर सूज रहे हैं, अभी कैसे हैं?
Patient: हाँ दीदी, अभी भी सूजे हुए हैं, खासकर शाम को ज्यादा सूज जाते हैं। हाथ भी थोड़े सूजे लग रहे हैं।
ASHA: अच्छा, चलिए बी.पी. चेक करते हैं... बहन जी आपका बी.पी. 145/95 आ रहा है, ये थोड़ा ज़्यादा है। पिछली बार 130/85 था।
Patient: अरे, ये तो बढ़ गया दीदी! क्या करना चाहिए?
ASHA: और सिर में दर्द तो नहीं हो रहा?
Patient: हाँ, कल रात से सिर में दर्द हो रहा है, और आँखों के सामने थोड़ा धुंधला भी दिखा।
ASHA: ये सुनकर मुझे थोड़ी चिंता हो रही है। आपकी प्रेगनेंसी का कितना महीना चल रहा है?
Patient: 8वां महीना है दीदी, करीब 33-34 हफ्ते हो गए।
ASHA: वज़न चेक करते हैं... 62 किलो है। पिछली बार 59 किलो था, 3 किलो बढ़ा है दो हफ्ते में। बच्चा हिल रहा है ठीक से?
Patient: हाँ, हिल तो रहा है, लेकिन पहले से कम लग रहा है।
ASHA: आयरन की गोली खा रही हैं?
Patient: हाँ दीदी, रोज़ खा रही हूँ। TT का दूसरा टीका भी लगवा लिया है।
ASHA: अच्छा, अस्पताल जाने का इंतज़ाम किया है? कौन ले जाएगा?
Patient: हाँ, पति जी का ऑटो है, वो ले जाएंगे। ज़िला अस्पताल जाएंगे।
"""

# Routine postnatal / newborn visit, selected with --normal (expect no danger
# signs): exclusive breastfeeding, baby weight 3.2 kg vs 2.8 kg at birth, dry and
# clean umbilical stump, BCG and first OPV dose given at birth, mother on iron.
# NOTE(review): this is not an antenatal visit, yet the tools offered to the model
# are extract_anc_visit and extract_danger_signs — check whether a postnatal
# schema should be used for this case.
SAMPLE_TRANSCRIPT_NORMAL = """
ASHA: नमस्ते बहन जी, कैसी हैं? बच्चे की तबीयत कैसी है?
Patient: नमस्ते दीदी, मैं ठीक हूँ, बच्चा भी ठीक है।
ASHA: बच्चे को दूध पिला रही हैं?
Patient: हाँ दीदी, सिर्फ अपना दूध दे रही हूँ, ऊपर का कुछ नहीं दिया।
ASHA: बहुत अच्छा! बच्चे का वज़न देखते हैं... 3.2 किलो है, जन्म के समय 2.8 था। अच्छा बढ़ रहा है।
Patient: हाँ, ठीक से पी रहा है, हर 2-3 घंटे में।
ASHA: नाभि कैसी है?
Patient: सूख गई है दीदी, साफ है।
ASHA: बच्चे की BCG और OPV की पहली खुराक लगवा ली थी ना?
Patient: हाँ, अस्पताल में ही लगा दी थी जन्म के समय।
ASHA: अगला टीका 6 हफ्ते पर लगेगा, याद रखना। और आप आयरन की गोली खा रही हैं?
Patient: हाँ दीदी, रोज़ खा रही हूँ।
"""
# ── Schema Loading ─────────────────────────────────────────────────────────

def load_schemas(schema_dir="configs/schemas") -> dict:
    """Load every ``*.json`` extraction schema in *schema_dir*, keyed by file stem.

    Args:
        schema_dir: Directory to scan. The default ``configs/schemas`` is resolved
            relative to the current working directory.

    Returns:
        Mapping of schema name (file stem, e.g. ``anc_visit``) to the parsed JSON.
        Files are read in sorted order so the schema order — and therefore the
        tool order sent to the model — is deterministic across filesystems.
        An empty dict is returned (with a message) when the directory is missing.
    """
    directory = Path(schema_dir)
    if not directory.is_dir():
        # Path.glob on a missing directory silently yields nothing; say why instead.
        print(f" Schema directory not found: {directory.resolve()}")
        return {}

    schemas = {}
    for path in sorted(directory.glob("*.json")):
        with path.open("r", encoding="utf-8") as fh:
            schemas[path.stem] = json.load(fh)
    print(f" Loaded {len(schemas)} schemas: {list(schemas.keys())}")
    return schemas
def validate_schema(schema: dict) -> bool:
    """Check that *schema* is a well-formed Draft-7 JSON Schema.

    Prints the reason and returns ``False`` when the schema is malformed, or when
    the ``jsonschema`` package is not installed. The missing dependency is
    reported separately, so it is not mistaken for a broken schema file.

    Returns:
        ``True`` if ``jsonschema.Draft7Validator.check_schema`` accepts the schema.
    """
    try:
        import jsonschema
    except ImportError:
        print(" jsonschema is not installed — cannot validate schemas (pip install jsonschema).")
        return False

    try:
        jsonschema.Draft7Validator.check_schema(schema)
    except jsonschema.exceptions.SchemaError as e:
        # e.message is the one-line reason; str(e) would dump the whole schema.
        location = "/".join(str(part) for part in e.absolute_path) or "<root>"
        print(f" Schema validation error at {location}: {e.message}")
        return False
    return True
# ── Tool Definitions for Gemma 4 ──────────────────────────────────────────

def build_tool_definitions(schemas: dict) -> list:
    """
    Wrap each extraction schema as a ``{"type": "function", ...}`` tool entry.

    The tool is named ``extract_<schema name>``. Its description comes from the
    schema's own ``description`` (or a generic fallback), and the schema object
    itself is passed through unchanged as ``parameters``. This is the
    ``tools=`` shape used by HuggingFace ``apply_chat_template`` and Ollama.
    """
    return [
        {
            "type": "function",
            "function": {
                "name": f"extract_{schema_name}",
                "description": schema_body.get(
                    "description", f"Extract {schema_name} data from conversation"
                ),
                "parameters": schema_body,
            },
        }
        for schema_name, schema_body in schemas.items()
    ]
# ── Ollama Function Calling Test ──────────────────────────────────────────

def test_ollama_function_calling(transcript: str, schemas: dict, model: str = "gemma4:e4b-it-q4_K_M"):
    """Test function calling via Ollama (text-only path).

    Only the ``anc_visit`` and ``danger_signs`` schemas are offered as tools.

    Args:
        transcript: Conversation text to extract from.
        schemas: Mapping of schema name -> JSON Schema (see ``load_schemas``).
        model: Ollama model tag to query; the default is the tag this test targets.

    Returns:
        A list of ``{"function": <tool name>, "arguments": <dict>}`` entries. If a
        call's arguments arrive as a string that is not valid JSON, the raw string
        is kept. Returns ``None`` when the ``ollama`` package is missing, when no
        matching tools were built, when the request fails, or when the model
        answers with plain text instead of a tool call.
    """
    try:
        import ollama
    except ImportError:
        print(" ollama package not installed. Skipping.")
        return None

    print("\n=== Ollama Function Calling Test ===")
    tools = build_tool_definitions(schemas)
    # Use ANC + danger signs schemas for this test
    wanted = ("extract_anc_visit", "extract_danger_signs")
    test_tools = [t for t in tools if t["function"]["name"] in wanted]
    if not test_tools:
        print(f" None of the expected tools {list(wanted)} could be built from the loaded schemas. Skipping.")
        return None

    system_prompt = (
        "You are a clinical data extraction system for India's ASHA health worker program. "
        "Extract structured data from the Hindi/Hinglish conversation transcript. "
        "ONLY extract information explicitly stated in the conversation. "
        "Use null for any field not mentioned. "
        "For danger signs, you MUST provide the exact utterance from the conversation as evidence. "
        "If no danger signs are present, return an empty danger_signs array."
    )

    t0 = time.time()
    try:
        response = ollama.chat(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Extract clinical data from this ASHA home visit conversation:\n\n{transcript}"},
            ],
            tools=test_tools,
        )
    except Exception as e:  # test boundary: server down, model not pulled, etc.
        print(f" Error: {e}")
        return None
    elapsed = time.time() - t0
    print(f" Response time: {elapsed:.1f}s")

    # A text-only reply still has a ``tool_calls`` attribute, set to None or
    # empty, so test its value rather than whether the attribute exists.
    message = getattr(response, "message", None)
    tool_calls = getattr(message, "tool_calls", None)
    if not tool_calls:
        # Model responded with text instead of tool call
        text = (getattr(message, "content", None) or "") if message is not None else str(response)
        print(f" Model returned text (no tool call):\n {text[:500]}")
        return None

    print(f" Tool calls: {len(tool_calls)}")
    results = []
    for tc in tool_calls:
        name = tc.function.name
        args = tc.function.arguments
        print(f"\n Function: {name}")
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except json.JSONDecodeError as e:
                # Keep the raw string so one malformed call does not discard the others.
                print(f" Arguments are not valid JSON ({e}); keeping raw string.")
        print(f" Output:\n{json.dumps(args, ensure_ascii=False, indent=2)[:2000]}")
        results.append({"function": name, "arguments": args})
    return results
# ── Transformers Function Calling Test ─────────────────────────────────────

def _parse_json_payload(text: str):
    """Return the first JSON object or array embedded in *text*, or ``None``.

    Scans for the first ``{`` / ``[`` from which a complete JSON value decodes,
    so replies wrapped in markdown fences or preceded by prose still parse.
    """
    decoder = json.JSONDecoder()
    for start, char in enumerate(text):
        if char not in "{[":
            continue
        try:
            value, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue
        return value
    return None


def test_transformers_function_calling(transcript: str, schemas: dict, device: str = "cuda"):
    """Test function calling via HuggingFace Transformers (needed for audio pipeline).

    Only the ``anc_visit`` and ``danger_signs`` schemas are offered as tools, and
    decoding is greedy (``do_sample=False``).

    Args:
        transcript: Conversation text to extract from.
        schemas: Mapping of schema name -> JSON Schema (see ``load_schemas``).
        device: ``"cuda"`` (default) keeps automatic placement
            (``device_map="auto"``). Any other value, e.g. ``"cpu"`` or
            ``"cuda:1"``, pins the whole model to that device.

    Returns:
        The first JSON object/array found in the decoded reply, or the raw reply
        string when none could be decoded.
    """
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    print("\n=== Transformers Function Calling Test ===")
    model_id = "google/gemma-4-E4B-it"
    print(f" Loading {model_id}...")
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # NOTE(review): E4B is a multimodal checkpoint; confirm AutoModelForCausalLM is
    # the intended auto class for the audio pipeline.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        # Previously "auto" was always used, so --device cpu was silently ignored.
        device_map="auto" if device == "cuda" else device,
        trust_remote_code=True,
    )
    print(f" Model loaded in {time.time() - t0:.1f}s")

    tools = build_tool_definitions(schemas)
    test_tools = [t for t in tools if t["function"]["name"] in ("extract_anc_visit", "extract_danger_signs")]
    messages = [
        {"role": "system", "content": (
            "You are a clinical data extraction system for India's ASHA health worker program. "
            "Extract structured data from the Hindi/Hinglish conversation transcript. "
            "ONLY extract information explicitly stated in the conversation. "
            "Use null for any field not mentioned. "
            "For danger signs, you MUST provide the exact utterance as evidence."
        )},
        {"role": "user", "content": f"Extract clinical data:\n\n{transcript}"},
    ]

    # Apply chat template with tools
    t0 = time.time()
    inputs = tokenizer.apply_chat_template(
        messages,
        tools=test_tools,
        # Open the assistant turn; without it generate() continues the user turn.
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)  # follow wherever the model was actually placed
    prompt_len = inputs["input_ids"].shape[-1]
    print(f" Input tokens: {prompt_len}")

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,
        )
    # NOTE(review): if Gemma 4 wraps tool calls in special tokens, skip_special_tokens
    # strips those markers; decode without it when implementing tool-call parsing.
    response = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    elapsed = time.time() - t0
    print(f" Inference time: {elapsed:.1f}s")
    print(f" Raw response:\n{response[:2000]}")

    parsed = _parse_json_payload(response)
    if parsed is not None:
        print("\n Parsed JSON successfully")
        return parsed
    print("\n No JSON found in response — may need parsing of tool call format")
    return response
# ── Validation ─────────────────────────────────────────────────────────────

def validate_extraction(result: dict, schema: dict, schema_name: str, transcript: "str | None" = None) -> dict:
    """Validate one extraction result against its schema and the evidence rules.

    Checks performed:
      * JSON Schema validation. Every violation is reported, not just the first.
      * For ``schema_name == "danger_signs"``: each danger sign must carry a
        non-empty ``utterance_evidence``, and a referral decision of
        ``refer_immediately`` / ``refer_within_24h`` must carry ``evidence_utterances``.
      * If *transcript* is given: evidence strings that do not occur in it
        (after collapsing whitespace) are added as warnings.
      * Fill rate: share of flattened fields that are non-null / non-empty.

    Missing or null ``danger_signs`` / ``referral_decision`` keys are treated as
    empty, since the extraction prompt asks the model to use null for
    unmentioned fields.

    Args:
        result: Parsed tool-call arguments.
        schema: JSON Schema the result should satisfy.
        schema_name: Schema name (file stem), used to select the custom checks.
        transcript: Optional source transcript used for the grounding warnings.

    Returns:
        ``{"schema", "valid", "errors", "warnings"}``, plus ``"fill_rate"`` (e.g.
        ``"42%"``) when at least one field was counted.
    """
    import jsonschema

    report = {"schema": schema_name, "valid": True, "errors": [], "warnings": []}

    def fail(message):
        report["valid"] = False
        report["errors"].append(message)

    if not isinstance(result, dict):
        fail(f"Extraction result must be a JSON object, got {type(result).__name__}")
        return report

    # Select the validator class the same way jsonschema.validate() does (from
    # "$schema"), but collect every error instead of raising on one.
    validator_cls = jsonschema.validators.validator_for(schema)
    try:
        validator_cls.check_schema(schema)
    except jsonschema.exceptions.SchemaError as e:
        fail(f"Schema itself is invalid: {e.message}")
    else:
        for error in validator_cls(schema).iter_errors(result):
            location = "/".join(str(part) for part in error.absolute_path)
            fail(f"{location}: {error.message}" if location else error.message)

    normalized_transcript = " ".join(transcript.split()) if transcript is not None else None

    def ungrounded(evidence):
        """Evidence strings (str or list of str) absent from the normalized transcript."""
        if isinstance(evidence, str):
            quotes = [evidence]
        elif isinstance(evidence, list):
            quotes = evidence
        else:
            quotes = []
        return [q for q in quotes if isinstance(q, str) and " ".join(q.split()) not in normalized_transcript]

    # Custom validation: danger sign evidence check
    if schema_name == "danger_signs":
        danger_signs = result.get("danger_signs") or []
        if not isinstance(danger_signs, list):
            fail("danger_signs must be a list")
            danger_signs = []
        for ds in danger_signs:
            if not isinstance(ds, dict):
                fail(f"Danger sign entry is not an object: {ds!r}")
                continue
            evidence = ds.get("utterance_evidence")
            if not evidence:
                fail(f"Danger sign '{ds.get('sign')}' has no utterance_evidence — HALLUCINATION")
            elif normalized_transcript is not None:
                for quote in ungrounded(evidence):
                    report["warnings"].append(
                        f"Danger sign '{ds.get('sign')}' evidence not found in transcript: {quote!r}"
                    )

        # Check referral decision has evidence
        referral = result.get("referral_decision") or {}
        if isinstance(referral, dict) and referral.get("decision") in ("refer_immediately", "refer_within_24h"):
            evidence = referral.get("evidence_utterances")
            if not evidence:
                fail("Referral decision has no evidence utterances")
            elif normalized_transcript is not None:
                for quote in ungrounded(evidence):
                    report["warnings"].append(f"Referral evidence not found in transcript: {quote!r}")

    # Null field check: share of flattened fields that carry a value
    flat = _flatten(result)
    if flat:
        non_null = sum(1 for value in flat.values() if value is not None)
        fill_rate = non_null / len(flat) * 100
        report["fill_rate"] = f"{fill_rate:.0f}%"
        if fill_rate < 10:
            report["warnings"].append(f"Very low fill rate ({fill_rate:.0f}%) — model may not be extracting")
    return report
def _flatten(d, parent_key="", sep="."):
    """Collapse nested dicts into a single level of ``sep``-joined keys.

    Lists are leaves: an empty list becomes ``None`` (so it counts as unfilled)
    and a non-empty list is kept as-is. Non-dict input yields an empty dict.
    """
    flat = {}
    if not isinstance(d, dict):
        return flat
    for key, value in d.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(_flatten(value, full_key, sep))
        elif isinstance(value, list):
            flat[full_key] = value or None
        else:
            flat[full_key] = value
    return flat
# ── Main ───────────────────────────────────────────────────────────────────

def _print_validation(results, schemas: dict) -> None:
    """Run validate_extraction on each ``extract_<schema>`` tool call and print the report.

    *results* is expected to be the list of ``{"function": ..., "arguments": ...}``
    dicts produced by the Ollama path. Any other shape (e.g. raw text or bare
    JSON from the Transformers path) is skipped, because no tool name is
    available to map it to a schema.
    """
    if not isinstance(results, list):
        print("\n Skipping validation: results are not a list of tool calls.")
        return

    print("\n=== Validation ===")
    prefix = "extract_"
    for call in results:
        if not isinstance(call, dict):
            continue
        function_name = str(call.get("function", ""))
        schema_name = function_name[len(prefix):] if function_name.startswith(prefix) else function_name
        if schema_name not in schemas:
            print(f" [SKIP] {function_name or '<unnamed>'}: no matching schema")
            continue
        arguments = call.get("arguments")
        if not isinstance(arguments, dict):
            print(f" \033[91m[INVALID]\033[0m {schema_name}: arguments are not a JSON object")
            continue
        try:
            report = validate_extraction(arguments, schemas[schema_name], schema_name)
        except Exception as exc:  # reporting boundary: a validator crash must not abort the run
            print(f" \033[91m[ERROR]\033[0m {schema_name}: validation could not run: {exc}")
            continue
        status = "\033[92m[VALID]\033[0m" if report["valid"] else "\033[91m[INVALID]\033[0m"
        print(f" {status} {schema_name} (fill rate: {report.get('fill_rate', 'n/a')})")
        for error in report["errors"]:
            print(f"   error: {error}")
        for warning in report["warnings"]:
            print(f"   warning: {warning}")


def main():
    """CLI entry point: validate schemas, run one backend, save and validate its output."""
    parser = argparse.ArgumentParser(description="MedScribe v2 — Function Calling Test")
    parser.add_argument("--mode", choices=["ollama", "transformers"], default="ollama")
    parser.add_argument("--validate-only", action="store_true", help="Only validate schemas")
    parser.add_argument("--transcript", type=str, help="Custom transcript file (UTF-8)")
    parser.add_argument("--normal", action="store_true", help="Use normal (no danger signs) transcript")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument(
        "--output",
        type=str,
        default="data/temp/function_calling_test_result.json",
        help="Where to write the extraction results (JSON)",
    )
    args = parser.parse_args()

    # Load and validate schemas; never call the model with missing or broken tools.
    print("=== Schema Validation ===")
    schemas = load_schemas()
    if not schemas:
        print("\nNo schemas found in configs/schemas/ (path is relative to the working directory).")
        sys.exit(1)

    all_valid = True
    for name, schema in schemas.items():
        valid = validate_schema(schema)  # validate every schema so all problems are listed
        status = "\033[92m[VALID]\033[0m" if valid else "\033[91m[INVALID]\033[0m"
        print(f" {status} {name}")
        all_valid = all_valid and valid
    if not all_valid:
        print("\nSchema validation failed. Fix schemas before testing model.")
        sys.exit(1)
    if args.validate_only:
        print("\nAll schemas valid.")
        return

    # Select transcript: custom file > --normal > high-risk default.
    if args.transcript:
        with open(args.transcript, "r", encoding="utf-8") as f:
            transcript = f.read()
        print(f"\n Using custom transcript: {args.transcript}")
    elif args.normal:
        transcript = SAMPLE_TRANSCRIPT_NORMAL
        print("\n Using NORMAL transcript (expect no danger signs)")
    else:
        transcript = SAMPLE_TRANSCRIPT_HINDI
        print("\n Using HIGH-RISK transcript (expect danger signs)")

    # Run test
    if args.mode == "ollama":
        results = test_ollama_function_calling(transcript, schemas)
    else:
        results = test_transformers_function_calling(transcript, schemas, args.device)

    if not results:
        print("\n No structured results returned. Model may need fine-tuning for this task.")
        return

    # Save first so the raw output survives whatever validation reports.
    print("\n=== Extraction Results ===")
    output_path = Path(args.output)
    # Path("x.json").parent is "." (os.path.dirname would give "" and makedirs would fail).
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f" Saved to {output_path}")

    # Schema + evidence checks on each tool call.
    _print_validation(results, schemas)


if __name__ == "__main__":
    main()