# sakhi / scripts/test_function_calling.py
# Committed by Tushar9802 — commit 745f62a ("HF Space deploy — initial")
"""
MedScribe v2 — Function Calling Test (Gate 2)
Tests Gemma 4 E4B's native function calling with MCTS extraction schemas.
Tests:
1. Schema loading and validation
2. Function calling with a sample Hindi transcript
3. Output validation against JSON schema
4. Danger sign evidence grounding check
5. Ollama function calling (text-only path)
6. Transformers function calling (for audio pipeline)
Usage:
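python scripts/test_function_calling.py --mode ollama
python scripts/test_function_calling.py --mode transformers
python scripts/test_function_calling.py --validate-only # schema validation only
Run from the repository root: schemas are read from configs/schemas/ and any
returned results are written to data/temp/function_calling_test_result.json.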
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path
# ── Sample Test Data ────────────────────────────────────────────────────────
# Synthetic high-risk antenatal home visit (Hindi ASHA/patient dialogue).
# The patient reports occasional dizziness, swelling of the feet (worse in the
# evening) and hands, a headache since last night and blurred vision, and that
# the baby is moving less than before. The ASHA records BP 145/95 (130/85 last
# time), 8th month (~33-34 weeks), and weight 62 kg (59 kg last time, +3 kg in
# two weeks). Also mentioned: daily iron tablets, second TT dose taken, and a
# transport plan (husband's auto to the district hospital).
# Default input for main(), which announces "expect danger signs" for it.
SAMPLE_TRANSCRIPT_HINDI = """
ASHA: नमस्ते बहन जी, कैसी तबीयत है आपकी?
Patient: नमस्ते दीदी, ठीक हूँ, बस थोड़ा चक्कर आता है कभी-कभी।
ASHA: पिछली बार जब आई थी तब आपने बताया था कि पैर सूज रहे हैं, अभी कैसे हैं?
Patient: हाँ दीदी, अभी भी सूजे हुए हैं, खासकर शाम को ज्यादा सूज जाते हैं। हाथ भी थोड़े सूजे लग रहे हैं।
ASHA: अच्छा, चलिए बी.पी. चेक करते हैं... बहन जी आपका बी.पी. 145/95 आ रहा है, ये थोड़ा ज़्यादा है। पिछली बार 130/85 था।
Patient: अरे, ये तो बढ़ गया दीदी! क्या करना चाहिए?
ASHA: और सिर में दर्द तो नहीं हो रहा?
Patient: हाँ, कल रात से सिर में दर्द हो रहा है, और आँखों के सामने थोड़ा धुंधला भी दिखा।
ASHA: ये सुनकर मुझे थोड़ी चिंता हो रही है। आपकी प्रेगनेंसी का कितना महीना चल रहा है?
Patient: 8वां महीना है दीदी, करीब 33-34 हफ्ते हो गए।
ASHA: वज़न चेक करते हैं... 62 किलो है। पिछली बार 59 किलो था, 3 किलो बढ़ा है दो हफ्ते में। बच्चा हिल रहा है ठीक से?
Patient: हाँ, हिल तो रहा है, लेकिन पहले से कम लग रहा है।
ASHA: आयरन की गोली खा रही हैं?
Patient: हाँ दीदी, रोज़ खा रही हूँ। TT का दूसरा टीका भी लगवा लिया है।
ASHA: अच्छा, अस्पताल जाने का इंतज़ाम किया है? कौन ले जाएगा?
Patient: हाँ, पति जी का ऑटो है, वो ले जाएंगे। ज़िला अस्पताल जाएंगे।
"""
# Synthetic routine post-delivery home visit (Hindi ASHA/patient dialogue).
# It covers:
#   - exclusive breastfeeding (nothing else given), feeds every 2-3 hours;
#   - baby weight 3.2 kg vs 2.8 kg at birth;
#   - umbilical cord dry and clean;
#   - BCG and first OPV dose given at the hospital at birth, next vaccine due at 6 weeks;
#   - mother taking iron tablets daily.
# Selected with --normal; main() announces "expect no danger signs" for it.
SAMPLE_TRANSCRIPT_NORMAL = """
ASHA: नमस्ते बहन जी, कैसी हैं? बच्चे की तबीयत कैसी है?
Patient: नमस्ते दीदी, मैं ठीक हूँ, बच्चा भी ठीक है।
ASHA: बच्चे को दूध पिला रही हैं?
Patient: हाँ दीदी, सिर्फ अपना दूध दे रही हूँ, ऊपर का कुछ नहीं दिया।
ASHA: बहुत अच्छा! बच्चे का वज़न देखते हैं... 3.2 किलो है, जन्म के समय 2.8 था। अच्छा बढ़ रहा है।
Patient: हाँ, ठीक से पी रहा है, हर 2-3 घंटे में।
ASHA: नाभि कैसी है?
Patient: सूख गई है दीदी, साफ है।
ASHA: बच्चे की BCG और OPV की पहली खुराक लगवा ली थी ना?
Patient: हाँ, अस्पताल में ही लगा दी थी जन्म के समय।
ASHA: अगला टीका 6 हफ्ते पर लगेगा, याद रखना। और आप आयरन की गोली खा रही हैं?
Patient: हाँ दीदी, रोज़ खा रही हूँ।
"""
# ── Schema Loading ─────────────────────────────────────────────────────────
def load_schemas(schema_dir="configs/schemas") -> dict:
    """Load every MCTS extraction schema (``*.json``) from ``schema_dir``.

    Args:
        schema_dir: Directory holding the schema files. The default,
            ``configs/schemas``, is relative to the current working directory,
            so the script is meant to be run from the repository root.

    Returns:
        Mapping of file stem (e.g. ``"anc_visit"``) to the parsed JSON
        document, in sorted filename order so runs are reproducible.

    Raises:
        FileNotFoundError: if ``schema_dir`` is not an existing directory.
            Without this check a wrong working directory silently yields zero
            schemas and every later step "passes" vacuously.
    """
    schema_path = Path(schema_dir)
    if not schema_path.is_dir():
        raise FileNotFoundError(
            f"Schema directory not found: {schema_path.resolve()} "
            "(run from the repository root or pass schema_dir)"
        )
    schemas = {}
    # glob() order depends on the filesystem; sort for a stable schema/tool order.
    for path in sorted(schema_path.glob("*.json")):
        with open(path, "r", encoding="utf-8") as fh:
            schemas[path.stem] = json.load(fh)
    print(f" Loaded {len(schemas)} schemas: {list(schemas.keys())}")
    return schemas
def validate_schema(schema: dict) -> bool:
    """Check that ``schema`` is itself a well-formed JSON Schema.

    The metaschema is taken from the schema's ``$schema`` keyword when it names
    a draft that jsonschema knows; otherwise Draft 7 is assumed.

    Returns:
        True if the schema passes its metaschema, False otherwise (the reason
        is printed).

    Raises:
        ImportError: if the ``jsonschema`` package is not installed. This is
            deliberately not reported as "invalid schema", which would send the
            user hunting for a schema bug that does not exist.
    """
    import jsonschema

    if isinstance(schema, dict):
        validator_cls = jsonschema.validators.validator_for(
            schema, default=jsonschema.Draft7Validator
        )
    else:
        # validator_for() only inspects mappings; let the metaschema reject the rest.
        validator_cls = jsonschema.Draft7Validator
    try:
        validator_cls.check_schema(schema)
    except jsonschema.exceptions.SchemaError as e:
        print(f" Schema validation error: {e}")
        return False
    return True
# ── Tool Definitions for Gemma 4 ──────────────────────────────────────────
def build_tool_definitions(schemas: dict) -> list:
    """
    Wrap each extraction schema in a Gemma 4 / HuggingFace tool definition.

    Every ``name -> schema`` entry becomes one ``{"type": "function", ...}``
    record (the shape accepted by ``apply_chat_template(tools=...)``) whose
    function is named ``extract_<name>``. The schema's own ``description`` is
    used when present, otherwise a generic one is generated. The schema object
    itself is passed through, not copied, as the tool's ``parameters``.
    Output order follows the iteration order of ``schemas``.
    """
    def _as_tool(schema_name, json_schema):
        fallback = f"Extract {schema_name} data from conversation"
        return {
            "type": "function",
            "function": {
                "name": f"extract_{schema_name}",
                "description": json_schema.get("description", fallback),
                "parameters": json_schema,
            },
        }

    return [_as_tool(schema_name, json_schema) for schema_name, json_schema in schemas.items()]
# ── Ollama Function Calling Test ──────────────────────────────────────────
def test_ollama_function_calling(transcript: str, schemas: dict, model: str = "gemma4:e4b-it-q4_K_M"):
    """Run the ANC + danger-sign extraction through Ollama's tool-calling API.

    This is the text-only path. Only the ``extract_anc_visit`` and
    ``extract_danger_signs`` tools (built from ``schemas``) are offered.

    Args:
        transcript: Conversation text to extract from.
        schemas: Schema mapping as returned by ``load_schemas``.
        model: Ollama model tag to query.

    Returns:
        A list of ``{"function": <tool name>, "arguments": <dict or raw str>}``
        entries, one per tool call. Returns None when:
          * the ``ollama`` package is missing;
          * neither expected tool exists;
          * the model answered in plain text;
          * the request itself failed.
    """
    try:
        import ollama
    except ImportError:
        print(" ollama package not installed. Skipping.")
        return None

    print("\n=== Ollama Function Calling Test ===")
    wanted = ("extract_anc_visit", "extract_danger_signs")
    test_tools = [t for t in build_tool_definitions(schemas) if t["function"]["name"] in wanted]
    if not test_tools:
        print(f" None of the expected tools {list(wanted)} exist in the loaded schemas. Skipping.")
        return None

    system_prompt = (
        "You are a clinical data extraction system for India's ASHA health worker program. "
        "Extract structured data from the Hindi/Hinglish conversation transcript. "
        "ONLY extract information explicitly stated in the conversation. "
        "Use null for any field not mentioned. "
        "For danger signs, you MUST provide the exact utterance from the conversation as evidence. "
        "If no danger signs are present, return an empty danger_signs array."
    )
    t0 = time.time()
    try:
        response = ollama.chat(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Extract clinical data from this ASHA home visit conversation:\n\n{transcript}"},
            ],
            tools=test_tools,
        )
    except Exception as e:  # server not running, model not pulled, bad request, ...
        print(f" Error: {e}")
        return None
    print(f" Response time: {time.time() - t0:.1f}s")

    message = getattr(response, "message", None)
    # The client's Message model always *has* a tool_calls attribute; it is None
    # when the model replied in prose, so test the value rather than hasattr().
    tool_calls = getattr(message, "tool_calls", None)
    if not tool_calls:
        text = str(response) if message is None else (getattr(message, "content", None) or "")
        print(f" Model returned text (no tool call):\n {text[:500]}")
        return None

    print(f" Tool calls: {len(tool_calls)}")
    results = []
    for tc in tool_calls:
        name = tc.function.name
        args = tc.function.arguments
        print(f"\n Function: {name}")
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                # Keep the raw text so one malformed call doesn't discard the others.
                print(" Arguments are not valid JSON; keeping raw string")
        print(f" Output:\n{json.dumps(args, ensure_ascii=False, indent=2)[:2000]}")
        results.append({"function": name, "arguments": args})
    return results


# The name matches pytest's test_* pattern and this file is named test_*.py,
# but it needs real arguments; stop pytest from collecting it as a test.
test_ollama_function_calling.__test__ = False
# ── Transformers Function Calling Test ─────────────────────────────────────
def test_transformers_function_calling(transcript: str, schemas: dict, device: str = "cuda"):
    """Run the ANC + danger-sign extraction through HuggingFace Transformers.

    This is the path the audio pipeline needs. The model is loaded in bfloat16.
    The ``extract_anc_visit`` / ``extract_danger_signs`` tools are rendered into
    the prompt by the tokenizer's chat template, and generation is greedy
    (``do_sample=False``, up to 2048 new tokens).

    Args:
        transcript: Conversation text to extract from.
        schemas: Schema mapping as returned by ``load_schemas``.
        device: Where to run. The default ``"cuda"`` keeps ``device_map="auto"``
            (accelerate decides placement). Any other value (``"cpu"``,
            ``"cuda:1"``, ...) pins the whole model to that device.

    Returns:
        The generated text parsed with ``json.loads`` when it is pure JSON;
        otherwise the raw decoded string (Gemma's native tool-call syntax is
        not parsed here). None if neither expected tool exists in ``schemas``.
    """
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    print("\n=== Transformers Function Calling Test ===")
    wanted = ("extract_anc_visit", "extract_danger_signs")
    # Select tools before the (slow, multi-GB) model load so a schema problem fails fast.
    test_tools = [t for t in build_tool_definitions(schemas) if t["function"]["name"] in wanted]
    if not test_tools:
        print(f" None of the expected tools {list(wanted)} exist in the loaded schemas. Skipping.")
        return None

    model_id = "google/gemma-4-E4B-it"
    print(f" Loading {model_id}...")
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto" if device == "cuda" else device,
        trust_remote_code=True,
    )
    print(f" Model loaded in {time.time() - t0:.1f}s")

    # Same instructions as the Ollama path so the two backends are comparable.
    messages = [
        {"role": "system", "content": (
            "You are a clinical data extraction system for India's ASHA health worker program. "
            "Extract structured data from the Hindi/Hinglish conversation transcript. "
            "ONLY extract information explicitly stated in the conversation. "
            "Use null for any field not mentioned. "
            "For danger signs, you MUST provide the exact utterance from the conversation as evidence. "
            "If no danger signs are present, return an empty danger_signs array."
        )},
        {"role": "user", "content": f"Extract clinical data:\n\n{transcript}"},
    ]

    t0 = time.time()
    inputs = tokenizer.apply_chat_template(
        messages,
        tools=test_tools,
        # Open the assistant turn so the model answers rather than continuing the user turn.
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)  # wherever the weights actually landed, not a guessed device
    prompt_len = inputs["input_ids"].shape[-1]
    print(f" Input tokens: {prompt_len}")
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,
        )
    # generate() echoes the prompt; decode only the newly generated tokens.
    response = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    print(f" Inference time: {time.time() - t0:.1f}s")
    print(f" Raw response:\n{response[:2000]}")

    try:
        parsed = json.loads(response)
    except json.JSONDecodeError:
        print("\n Response is not pure JSON — may need parsing of tool call format")
        return response
    print("\n Parsed JSON successfully")
    return parsed


# The name matches pytest's test_* pattern and this file is named test_*.py,
# but it needs real arguments and a GPU; stop pytest from collecting it.
test_transformers_function_calling.__test__ = False
# ── Validation ─────────────────────────────────────────────────────────────
def validate_extraction(result: dict, schema: dict, schema_name: str) -> dict:
    """Validate one extraction result against its schema and grounding rules.

    Checks performed:
      * JSON Schema validation of ``result`` against ``schema`` (Draft 7 unless
        the schema's ``$schema`` names another draft). Every violation is
        reported, not only the first.
      * For ``schema_name == "danger_signs"``: every listed danger sign must
        carry non-empty ``utterance_evidence``. A sign without it is reported
        as a hallucination.
      * A ``referral_decision`` of ``refer_immediately`` / ``refer_within_24h``
        must list ``evidence_utterances``.
      * Fill rate: share of flattened fields that are non-null. Below 10% adds
        a warning.

    Null values are tolerated where the extraction prompt allows them ("Use
    null for any field not mentioned"), e.g. ``"referral_decision": null``
    no longer crashes the custom checks.

    Returns:
        Report dict with keys ``schema``, ``valid`` (bool), ``errors``
        (list of str) and ``warnings`` (list of str). When the flattened
        result has at least one field, it also has ``fill_rate`` (e.g. ``"40%"``).

    Raises:
        jsonschema.exceptions.SchemaError: if ``schema`` itself is malformed.
    """
    import jsonschema

    report = {"schema": schema_name, "valid": True, "errors": [], "warnings": []}

    def _fail(message):
        # Any hard error makes the whole extraction invalid.
        report["valid"] = False
        report["errors"].append(message)

    # Collect every violation instead of stopping at the first (as validate() does).
    if isinstance(schema, dict):
        validator_cls = jsonschema.validators.validator_for(schema, default=jsonschema.Draft7Validator)
    else:
        validator_cls = jsonschema.Draft7Validator
    validator_cls.check_schema(schema)
    for err in validator_cls(schema).iter_errors(result):
        location = ".".join(str(part) for part in err.absolute_path)
        _fail(f"{location}: {err.message}" if location else err.message)

    if isinstance(result, dict):
        # Custom validation: every danger sign must be grounded in an utterance.
        if schema_name == "danger_signs":
            danger_signs = result.get("danger_signs")
            if not isinstance(danger_signs, list):
                danger_signs = []  # null / wrong type is left to the schema check above
            for ds in danger_signs:
                sign = ds.get("sign") if isinstance(ds, dict) else ds
                evidence = ds.get("utterance_evidence") if isinstance(ds, dict) else None
                if not evidence:
                    _fail(f"Danger sign '{sign}' has no utterance_evidence — HALLUCINATION")

        # An urgent referral must cite the utterances that justify it.
        referral = result.get("referral_decision")
        if isinstance(referral, dict) and referral.get("decision") in ("refer_immediately", "refer_within_24h"):
            if not referral.get("evidence_utterances"):
                _fail("Referral decision has no evidence utterances")

    # Null field check: share of flattened fields that carry a value.
    fields = _flatten(result)
    if fields:
        non_null = sum(1 for value in fields.values() if value is not None)
        fill_rate = non_null / len(fields) * 100
        report["fill_rate"] = f"{fill_rate:.0f}%"
        if fill_rate < 10:
            report["warnings"].append(f"Very low fill rate ({fill_rate:.0f}%) — model may not be extracting")
    return report
def _flatten(d, parent_key="", sep="."):
"""Flatten nested dict for field counting."""
items = []
if isinstance(d, dict):
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(_flatten(v, new_key, sep).items())
elif isinstance(v, list):
items.append((new_key, v if v else None))
else:
items.append((new_key, v))
return dict(items)
# ── Main ───────────────────────────────────────────────────────────────────
def _print_validation_reports(results, schemas):
    """Validate each Ollama-style tool call in ``results`` and print a report per call.

    ``results`` is expected to be a list of ``{"function": "extract_<name>",
    "arguments": {...}}`` dicts, where ``<name>`` selects the schema. Any other
    shape (e.g. the raw text or parsed JSON returned by the Transformers path)
    has no tool name to map to a schema and is skipped with a note.
    """
    if not isinstance(results, list):
        print(" Skipping schema validation: output is not a list of tool calls")
        return
    print("\n=== Output Validation ===")
    prefix = "extract_"
    for call in results:
        if not isinstance(call, dict):
            continue
        fn_name = str(call.get("function", ""))
        schema_name = fn_name[len(prefix):] if fn_name.startswith(prefix) else fn_name
        schema = schemas.get(schema_name)
        if schema is None:
            print(f" [SKIP] {fn_name}: no schema named '{schema_name}'")
            continue
        try:
            report = validate_extraction(call.get("arguments"), schema, schema_name)
        except Exception as e:  # top-level report loop: one bad call must not abort the rest
            print(f" [ERROR] {schema_name}: validation failed to run: {e}")
            continue
        status = "\033[92m[VALID]\033[0m" if report["valid"] else "\033[91m[INVALID]\033[0m"
        print(f" {status} {schema_name} (fill rate: {report.get('fill_rate', 'n/a')})")
        for err in report["errors"]:
            print(f"   ERROR: {err}")
        for warning in report["warnings"]:
            print(f"   WARNING: {warning}")


def main():
    """Command-line entry point: validate schemas, run one extraction, save and check it.

    Steps:
      1. Load every schema via ``load_schemas()`` and check each with
         ``validate_schema``. Exit with status 1 if any is invalid.
      2. Stop there when ``--validate-only`` is given.
      3. Pick the transcript: the ``--transcript`` file (UTF-8), the
         ``--normal`` sample, or the default high-risk sample. Run extraction
         via Ollama or Transformers (``--mode``).
      4. Save any returned results to
         ``data/temp/function_calling_test_result.json``. Then validate each
         tool call against its schema, including the danger-sign evidence
         check, and print the reports.
    """
    parser = argparse.ArgumentParser(description="MedScribe v2 — Function Calling Test")
    parser.add_argument("--mode", choices=["ollama", "transformers"], default="ollama")
    parser.add_argument("--validate-only", action="store_true", help="Only validate schemas")
    parser.add_argument("--transcript", type=str, help="Custom transcript file (UTF-8)")
    parser.add_argument("--normal", action="store_true", help="Use normal (no danger signs) transcript")
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    # Load and validate schemas (every schema is checked so all problems are listed).
    print("=== Schema Validation ===")
    schemas = load_schemas()
    all_valid = True
    for name, schema in schemas.items():
        valid = validate_schema(schema)
        status = "\033[92m[VALID]\033[0m" if valid else "\033[91m[INVALID]\033[0m"
        print(f" {status} {name}")
        if not valid:
            all_valid = False
    if not all_valid:
        print("\nSchema validation failed. Fix schemas before testing model.")
        sys.exit(1)
    if args.validate_only:
        print("\nAll schemas valid.")
        return

    # Select transcript
    if args.transcript:
        with open(args.transcript, "r", encoding="utf-8") as f:
            transcript = f.read()
    elif args.normal:
        transcript = SAMPLE_TRANSCRIPT_NORMAL
        print("\n Using NORMAL transcript (expect no danger signs)")
    else:
        transcript = SAMPLE_TRANSCRIPT_HINDI
        print("\n Using HIGH-RISK transcript (expect danger signs)")

    # Run test
    if args.mode == "ollama":
        results = test_ollama_function_calling(transcript, schemas)
    else:
        results = test_transformers_function_calling(transcript, schemas, args.device)

    if not results:
        print("\n No structured results returned. Model may need fine-tuning for this task.")
        return

    print("\n=== Extraction Results ===")
    # Save first so the raw output survives even if validation below misbehaves.
    output_path = "data/temp/function_calling_test_result.json"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f" Saved to {output_path}")
    # Schema + evidence-grounding checks (tests 3 and 4 in the module docstring).
    _print_validation_reports(results, schemas)
# Run the CLI only when executed as a script; importing this module defines
# the helpers without starting a model run.
if __name__ == "__main__":
    main()