Spaces:
Sleeping
Sleeping
File size: 38,645 Bytes
745f62a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 | """
MedScribe v2 — Synthetic Training Data Generation
Generates paired (Hindi ASHA conversation, structured JSON extraction) training data
via GPT-4o Mini API. Each sample produces:
1. A realistic Hindi ASHA home visit conversation transcript
2. The correct structured JSON extraction (form fields + danger signs)
Adapted from MedScribe v1's proven generation pipeline with:
- Budget cap + cost tracking
- Per-sample validation (schema + clinical consistency)
- Checkpoint/resume support
- Batch quality monitoring with abort threshold
- Dry-run mode
- Negative examples (no danger signs, many null fields)
Usage:
python scripts/generate_training_data.py --dry-run # First 5 samples
python scripts/generate_training_data.py # Full generation
python scripts/generate_training_data.py --resume # Resume from checkpoint
python scripts/generate_training_data.py --count 500 # Generate N samples
"""
import argparse
import json
import os
import random
import re
import sys
import time
from pathlib import Path
import pandas as pd
from dotenv import load_dotenv
from openai import OpenAI
load_dotenv()
# ============================================================
# CONFIG
# ============================================================
# Output locations: generated samples are written to OUTPUT_FILE (JSONL) and
# run progress/cost counters to CHECKPOINT_FILE, both under OUTPUT_DIR.
OUTPUT_DIR = "data/processed"
OUTPUT_FILE = os.path.join(OUTPUT_DIR, "training_data_raw.jsonl")
CHECKPOINT_FILE = os.path.join(OUTPUT_DIR, "generation_checkpoint.json")
# Model name passed to every chat-completion request in this script.
MODEL = "gpt-4o-mini"
TEMPERATURE = 0.7 # Higher for diversity in conversations
JSON_TEMPERATURE = 0.1 # Lower for accurate structured extraction
# max_tokens caps: one for transcript generation, one shared by both
# JSON extraction calls (form data and danger signs).
MAX_TOKENS_TRANSCRIPT = 2500
MAX_TOKENS_EXTRACTION = 2000
# NOTE(review): the four values below are consumed outside the code visible
# here. Presumably BATCH_SIZE is the progress/checkpoint interval, MAX_BUDGET
# is a spend cap in USD (cost tracking is in dollars), and the run aborts once
# the failure rate exceeds FAILURE_RATE_ABORT after at least
# MIN_BATCH_FOR_CHECK samples — confirm against the main loop.
BATCH_SIZE = 10
MAX_BUDGET = 10.0
FAILURE_RATE_ABORT = 0.30
MIN_BATCH_FOR_CHECK = 20
# ============================================================
# VISIT SCENARIO DEFINITIONS
# ============================================================
# Each scenario defines the visit type, clinical profile, and
# whether danger signs should be present. Scenarios are drawn at random in
# proportion to their "weight" (see sample_scenario); the weights sum to 69.
# NOTE(review): the stated goal is that ~40% of samples be NEGATIVE (no
# danger signs) to teach the model restraint, but the weights below put
# 40 of 69 units (~58%) on negative scenarios and 29 of 69 (~42%) on
# danger-sign scenarios — confirm which split is intended.
SCENARIOS = [
    # ── ANC VISITS (weights sum to 31 of 69, ~45% of data) ──
    {
        "type": "anc_visit",
        "label": "ANC — Normal pregnancy, routine visit",
        "has_danger_signs": False,
        "weight": 8,
        "clinical_profile": (
            "Healthy primigravida, 24 weeks, all vitals normal, no complaints. "
            "BP 110/70, Hb 11.5, weight gain normal. Routine ANC visit. "
            "ASHA counsels on diet, IFA tablets, birth preparedness."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Mild anemia, otherwise normal",
        "has_danger_signs": False,
        "weight": 5,
        "clinical_profile": (
            "Second pregnancy, 28 weeks, mild anemia (Hb 9.8). No other complaints. "
            "BP normal, fetal movements good. ASHA gives IFA, advises diet."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Preeclampsia signs (HIGH RISK)",
        "has_danger_signs": True,
        "weight": 5,
        "clinical_profile": (
            "Third trimester (32-36 weeks), elevated BP (140-160/90-110), "
            "headache, blurred vision, facial/hand swelling, excessive weight gain. "
            "Possible preeclampsia. Needs immediate referral."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Severe anemia (HIGH RISK)",
        "has_danger_signs": True,
        "weight": 3,
        "clinical_profile": (
            "Second trimester, severe anemia (Hb <7), extreme weakness, breathlessness, "
            "pallor, dizziness. Needs urgent referral for IV iron/transfusion."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Reduced fetal movement",
        "has_danger_signs": True,
        "weight": 3,
        "clinical_profile": (
            "Late third trimester, patient reports baby moving less than usual "
            "for 2 days. Other vitals may be normal. Needs monitoring/referral."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Vaginal bleeding",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "Any trimester, reports vaginal bleeding (spotting to heavy). "
            "Immediate danger sign requiring emergency referral."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Gestational diabetes",
        "has_danger_signs": False,
        "weight": 2,
        "clinical_profile": (
            "28-32 weeks, elevated blood sugar found on testing. No acute danger signs "
            "but needs dietary counseling and follow-up. BP normal."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Young primigravida, many null fields",
        "has_danger_signs": False,
        "weight": 3,
        "clinical_profile": (
            "First visit, early pregnancy (8-10 weeks). Minimal information available — "
            "no labs done yet, no vitals taken at this visit (home visit, no equipment beyond BP). "
            "Many form fields should be null. Just registration and counseling."
        ),
    },
    # ── PNC / HBNC VISITS (weights sum to 16 of 69, ~23% of data) ──
    {
        "type": "pnc_visit",
        "label": "PNC — Normal postpartum + healthy newborn",
        "has_danger_signs": False,
        "weight": 6,
        "clinical_profile": (
            "Day 7 HBNC visit. Mother recovering well, no fever, bleeding light. "
            "Baby weight 3.0 kg (birth weight 2.8), breastfeeding well, "
            "cord clean and dry, active, no jaundice."
        ),
    },
    {
        "type": "pnc_visit",
        "label": "PNC — Low birth weight newborn, needs monitoring",
        "has_danger_signs": False,
        "weight": 3,
        "clinical_profile": (
            "Day 14 visit. Baby 2.1 kg (birth weight 1.9), gaining slowly. "
            "Breastfeeding adequate but not frequent enough. Mother well. "
            "No acute danger signs but close monitoring needed."
        ),
    },
    {
        "type": "pnc_visit",
        "label": "PNC — Newborn not feeding well (DANGER)",
        "has_danger_signs": True,
        "weight": 3,
        "clinical_profile": (
            "Day 3 visit. Newborn lethargic, not latching, weak cry. "
            "May have fever or hypothermia. Mother reports baby sleeping too much. "
            "IMNCI danger signs present — urgent referral."
        ),
    },
    {
        "type": "pnc_visit",
        "label": "PNC — Postpartum hemorrhage signs (DANGER)",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "Day 1-3 visit. Mother reports heavy bleeding, soaking through cloth. "
            "Feeling dizzy and weak. Possible postpartum hemorrhage. "
            "Immediate referral needed."
        ),
    },
    {
        "type": "pnc_visit",
        "label": "PNC — Newborn jaundice",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "Day 3-5 visit. Newborn has yellow skin, possibly yellow palms/soles. "
            "Feeding okay or slightly reduced. Needs assessment for severity — "
            "jaundice within 24h of birth or palms/soles = severe."
        ),
    },
    # ── DELIVERY (weights sum to 7 of 69, ~10% of data) ──
    {
        "type": "delivery",
        "label": "Delivery — Normal institutional delivery",
        "has_danger_signs": False,
        "weight": 3,
        "clinical_profile": (
            "Full-term normal delivery at PHC/district hospital. "
            "Healthy baby boy/girl, cried immediately, breastfed within 1 hour. "
            "Birth weight 2.8-3.5 kg. Mother stable. Vaccines given at birth."
        ),
    },
    {
        "type": "delivery",
        "label": "Delivery — Home delivery (partial info)",
        "has_danger_signs": False,
        "weight": 2,
        "clinical_profile": (
            "Home delivery attended by dai/family. ASHA visiting after the fact. "
            "Limited info on delivery details. Baby seems okay. "
            "Need to check birth weight, breastfeeding, vaccines."
        ),
    },
    {
        "type": "delivery",
        "label": "Delivery — Preterm with complications (DANGER)",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "Preterm delivery (34-36 weeks). Low birth weight (<2.5 kg). "
            "Baby may have breathing difficulty. Mother may have had complications. "
            "Needs close monitoring or referral."
        ),
    },
    # ── CHILD HEALTH / HBYC (weights sum to 15 of 69, ~22% of data) ──
    {
        "type": "child_health",
        "label": "HBYC — Healthy 6-month-old, routine visit",
        "has_danger_signs": False,
        "weight": 5,
        "clinical_profile": (
            "6-month HBYC visit. Good weight gain, breastfeeding + complementary food started. "
            "Immunizations up to date. Milestones appropriate. No illness."
        ),
    },
    {
        "type": "child_health",
        "label": "HBYC — Underweight child, no acute danger",
        "has_danger_signs": False,
        "weight": 3,
        "clinical_profile": (
            "9-month visit. Child slightly underweight, pallor present (mild anemia). "
            "Eating poorly. No acute illness. ASHA counsels on feeding. "
            "Deworming due. Vitamin A due."
        ),
    },
    {
        "type": "child_health",
        "label": "HBYC — Diarrhea + dehydration (DANGER)",
        "has_danger_signs": True,
        "weight": 3,
        "clinical_profile": (
            "12-month visit. Child has watery diarrhea for 3 days, "
            "not drinking well, sunken eyes, lethargic. Possible severe dehydration. "
            "IMNCI general danger signs may be present."
        ),
    },
    {
        "type": "child_health",
        "label": "HBYC — Pneumonia signs (DANGER)",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "9-month visit. Child has cough for 5 days, fast breathing, "
            "chest indrawing visible. Fever present. Possible pneumonia. "
            "Needs urgent referral per IMNCI."
        ),
    },
    {
        "type": "child_health",
        "label": "HBYC — Severe malnutrition (DANGER)",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "15-month visit. Visible severe wasting, very low weight for age, "
            "child not active. Possible edema of feet. Severe acute malnutrition "
            "requiring NRC referral."
        ),
    },
]
# ============================================================
# PROMPTS
# ============================================================
# System prompt for step 1 (transcript generation). It is sent together with
# the per-scenario user prompt from build_transcript_prompt. The text is part
# of the program's behavior: any edit changes what the model generates.
SYSTEM_PROMPT_TRANSCRIPT = """You generate realistic Hindi conversations between ASHA health workers and patients during home visits in rural India. These will train a medical AI — clinical accuracy and natural dialogue are both critical.
OUTPUT FORMAT (strict):
- ONLY dialogue lines. Each line: "ASHA:", "Patient:", or "Mother:" followed by spoken words.
- ABSOLUTELY NO narration, stage directions, action descriptions, or parentheticals.
WRONG: (BP चेक करते हुए) / [measures weight] / *takes temperature* / (बच्चे को देख कर)
RIGHT: "ASHA: चलिए, BP देख लेती हूँ... 150/95 आ रहा है, ये तो ज़्यादा है।"
- ALL text MUST be in Devanagari script. Do NOT use Romanized Hindi (no "Aapka BP", no "theek hai").
Only English words allowed: medical terms (BP, Hb, TT, IFA, ORS, kg, mg, PHC, CHC).
CLINICAL REQUIREMENTS:
- ASHA must verbally state EVERY measurement with its value in Devanagari: "आपका BP 140/90 आ रहा है", "बच्चे का वज़न 3.1 kg है", "Hb 9.5 आया है"
- Include at least 6 distinct clinical data points spoken naturally in dialogue (vitals, history, medications, plans).
- Visit-specific data the ASHA should cover:
ANC: gestational age, BP, weight, Hb, IFA compliance, TT status, fetal movement, birth preparedness (transport, facility, money, blood donor).
PNC/Newborn: mother's bleeding/fever/pain, baby weight, feeding pattern (frequency, latch), cord condition, jaundice check, vaccination status.
Child health: age, weight, feeding/diet, immunization status, milestones, illness symptoms, deworming/Vitamin A.
DIALOGUE STYLE:
- Vary openings — sometimes ASHA calls from the door, sometimes patient greets first, sometimes mid-activity. Do NOT always start with "नमस्ते, कैसे हैं आप?"
- ASHA is warm but efficient — covers clinical ground without sounding like a form.
- Patient speaks colloquially: approximate dates ("लगभग 6 महीना"), local terms, sometimes vague or tangential.
- Patient may volunteer info, ask questions, express worry, or dismiss concerns.
- 20-30 dialogue turns with substantive content (not single-word responses).
- End with ASHA's concrete plan: next visit date, any referral, medications given, counseling summary."""
# System prompt for step 2 (structured form extraction). It is paired with
# the user prompt from build_extraction_prompt, which carries the transcript
# and the visit-specific JSON schema. The request uses JSON response mode.
SYSTEM_PROMPT_FORM_EXTRACTION = """You are a clinical data extraction system. Extract structured medical data from an ASHA home visit conversation transcript into the provided JSON schema.
RULES:
1. Extract ONLY information EXPLICITLY stated or clearly implied in the conversation.
2. Use null for anything not mentioned — never guess or fill in "expected" values.
3. Numbers must match exactly as stated in conversation (BP, weight, Hb, temperature, age, etc.).
4. For array fields (symptoms_reported, counseling_provided), extract all relevant items mentioned.
5. If ASHA states a measurement value, record the exact number, not just "normal".
6. Patient's approximate statements: convert to best numeric estimate ("लगभग 6 महीना" → gestational_weeks: 24).
7. Return valid JSON matching the schema. No markdown formatting."""
# System prompt for step 3 (danger-sign extraction). It is paired with the
# user prompt from build_danger_signs_prompt. The category names
# (immediate_referral / urgent_care / monitor_closely) and decision names
# (refer_immediately / refer_within_24h / continue_monitoring /
# routine_followup) used here are the same strings validate_danger_signs
# checks, so keep the two in sync when editing either.
SYSTEM_PROMPT_DANGER_EXTRACTION = """You are a clinical danger sign extraction system for Indian ASHA worker home visits. Extract danger signs from conversation transcripts with high precision.
CORE RULES:
1. ONLY flag danger signs with DIRECT, EXPLICIT evidence in the conversation text.
2. Each flag MUST include utterance_evidence — the exact Hindi quote that triggered it.
3. If NO danger signs exist in the conversation, return an empty danger_signs array. This is correct and expected for normal visits.
4. NEVER invent or hallucinate danger signs. When in doubt, do not flag.
CLASSIFICATION GUIDANCE:
- immediate_referral: life-threatening — heavy uncontrolled bleeding, convulsions, unconsciousness, BP ≥160/110, newborn not breathing
- urgent_care: serious — elevated BP with symptoms (headache/vision/swelling), Hb <7 with symptoms, signs of severe dehydration (sunken eyes + lethargic + not drinking), fast breathing with chest indrawing
- monitor_closely: borderline — isolated mild findings, low-grade fever (99-100°F), mild swelling alone
ANTI-HALLUCINATION:
- Do NOT flag normal values as danger signs. BP 110/70 is normal. Temperature 98.6°F is normal.
- Do NOT flag a sign just because the scenario suggests it — only flag what the CONVERSATION actually says.
- If a value is borderline (e.g., temp 99°F), classify as monitor_closely at most, NOT urgent_care.
REFERRAL LOGIC:
- Any immediate_referral sign → refer_immediately (district_hospital/FRU)
- Only urgent_care signs → refer_within_24h (PHC/CHC)
- Only monitor_closely → continue_monitoring
- No signs → routine_followup
Fill the relevant checklist (maternal or newborn): "detected" if found, "not_detected" if assessed as normal, "not_assessed" if not discussed.
Return valid JSON only."""
def build_transcript_prompt(scenario: dict) -> str:
    """Build the user prompt asking the model for one Hindi visit transcript.

    Args:
        scenario: A scenario dict providing "type", "label",
            "clinical_profile" and "has_danger_signs".

    Returns:
        The prompt text, with the danger-sign and symptom instructions chosen
        according to ``scenario["has_danger_signs"]``.
    """
    # Both instruction lines flip together on the same flag.
    if scenario['has_danger_signs']:
        danger_instruction = "Yes — include clear, unambiguous danger signs in the conversation"
        symptom_instruction = (
            "Patient describes symptoms matching danger signs in natural colloquial Hindi (not medical jargon)."
        )
    else:
        danger_instruction = "No — this is a normal visit with no danger signs"
        symptom_instruction = "Patient has no concerning symptoms. Routine, healthy visit."

    return f"""Generate an ASHA home visit conversation in Hindi (Devanagari only):
VISIT: {scenario['type']} — {scenario['label']}
CLINICAL DETAILS: {scenario['clinical_profile']}
DANGER SIGNS: {danger_instruction}
Rules:
- Pure dialogue, zero narration/parentheticals. ASHA speaks measurements aloud in Devanagari.
- {symptom_instruction}
- At least 6 clinical data points woven naturally into conversation.
- 20-30 turns. Feel like a real village home visit, not a medical interview."""
def build_extraction_prompt(transcript: str, visit_type: str, schema: dict) -> str:
    """Build the user prompt for structured form extraction.

    Args:
        transcript: The generated conversation text.
        visit_type: Accepted for signature parity with
            build_danger_signs_prompt; not included in the prompt text.
        schema: JSON-serializable schema, embedded pretty-printed with
            non-ASCII characters left unescaped.

    Returns:
        The prompt text containing the transcript followed by the schema.
    """
    schema_json = json.dumps(schema, ensure_ascii=False, indent=2)
    return f"""Extract structured data from this ASHA home visit conversation into the provided JSON schema.
CONVERSATION TRANSCRIPT:
{transcript}
OUTPUT JSON SCHEMA:
{schema_json}
Extract the data now. Return ONLY valid JSON matching the schema. Use null for fields not mentioned in the conversation."""
def build_danger_signs_prompt(transcript: str, visit_type: str, schema: dict) -> str:
    """Build the user prompt for danger-sign extraction.

    Args:
        transcript: The generated conversation text.
        visit_type: Visit type label, stated to the model in the prompt.
        schema: JSON-serializable danger-sign schema, embedded pretty-printed
            with non-ASCII characters left unescaped.

    Returns:
        The prompt text: transcript, visit type, schema, then the rules.
    """
    schema_json = json.dumps(schema, ensure_ascii=False, indent=2)
    return f"""Analyze this ASHA home visit conversation for danger signs.
CONVERSATION TRANSCRIPT:
{transcript}
VISIT TYPE: {visit_type}
OUTPUT JSON SCHEMA:
{schema_json}
CRITICAL RULES:
- ONLY flag danger signs that have DIRECT evidence in the conversation
- Each danger sign MUST include utterance_evidence — the exact Hindi quote that triggered it
- If NO danger signs are present, return an empty danger_signs array
- Referral decision must be based ONLY on detected danger signs
Return ONLY valid JSON matching the schema."""
# ============================================================
# SCHEMA MAPPING
# ============================================================
def load_schema(name: str) -> dict:
    """Read and parse ``configs/schemas/<name>.json`` (UTF-8).

    The path is relative to the current working directory. A missing file or
    malformed JSON propagates as the underlying exception.
    """
    schema_path = Path(f"configs/schemas/{name}.json")
    return json.loads(schema_path.read_text(encoding="utf-8"))
# Maps a scenario's visit type to the name of its form schema file. Each
# visit type currently uses a schema of the same name.
VISIT_TYPE_TO_SCHEMA = {
    visit_type: visit_type
    for visit_type in ("anc_visit", "pnc_visit", "delivery", "child_health")
}
# ============================================================
# VALIDATION
# ============================================================
def validate_transcript(text: str) -> dict:
    """Check a generated transcript for basic quality.

    Returns:
        ``{"valid": bool, "issues": [str, ...]}``. A transcript is invalid
        when it is empty or under 200 characters, has fewer than 50
        Devanagari characters, or fewer than 8 speaker tags. Narration
        markers are listed in "issues" but do not make it invalid.
    """
    if not text or len(text) < 200:
        return {"valid": False, "issues": ["Transcript too short"]}

    is_valid = True
    problems = []

    # Devanagari block is U+0900..U+097F.
    devanagari_total = len([ch for ch in text if '\u0900' <= ch <= '\u097F'])
    if devanagari_total < 50:
        is_valid = False
        problems.append(f"Too few Hindi characters ({devanagari_total})")

    # Each speaker tag occurrence counts as one turn.
    turn_total = sum(text.count(tag) for tag in ("ASHA:", "Patient:", "Mother:"))
    if turn_total < 8:
        is_valid = False
        problems.append(f"Too few dialogue turns ({turn_total})")

    # Narration markers are matched case-insensitively and only reported.
    lowered = text.lower()
    problems.extend(
        f"Contains narration marker: {marker}"
        for marker in ("[", "]", "(walks", "(checks", "(measures", "Scene:", "Setting:")
        if marker.lower() in lowered
    )
    return {"valid": is_valid, "issues": problems}
def validate_extraction(data: dict, visit_type: str, has_danger_signs: bool) -> dict:
    """Reject form extractions that are not a dict or are almost empty.

    Args:
        data: Parsed JSON returned by the form-extraction call.
        visit_type: Currently unused by the check.
        has_danger_signs: Currently unused by the check.

    Returns:
        ``{"valid": bool, "issues": [str, ...]}``; invalid when ``data`` is
        not a dict or ``_count_non_null`` finds fewer than 3 populated values.
    """
    if not isinstance(data, dict):
        return {"valid": False, "issues": ["Not a dict"]}

    populated = _count_non_null(data)
    if populated >= 3:
        return {"valid": True, "issues": []}
    return {
        "valid": False,
        "issues": [f"Almost empty extraction ({populated} non-null fields)"],
    }
def validate_danger_signs(data: dict, has_danger_signs: bool) -> dict:
    """Validate danger sign extraction — the most critical validation.

    The input is parsed model output, so its shape is not trusted: anything
    malformed is reported as a validation failure instead of raising.

    Args:
        data: Parsed JSON from the danger-sign extraction call.
        has_danger_signs: Whether the scenario was meant to contain danger
            signs.

    Returns:
        ``{"valid": bool, "issues": [str, ...]}``. Invalid when:
        - ``data`` is not a dict, or "danger_signs" is present but not a list;
        - signs were expected but none were extracted, or the reverse;
        - any sign is not an object or lacks "utterance_evidence";
        - "referral_decision" is present but neither null nor an object;
        - an "immediate_referral" sign is paired with a "routine_followup"
          or "continue_monitoring" decision;
        - no signs were expected but the decision is "refer_immediately" or
          "refer_within_24h".
    """
    if not isinstance(data, dict):
        return {"valid": False, "issues": ["Not a dict"]}
    result = {"valid": True, "issues": []}

    signs = data.get("danger_signs", [])
    if not isinstance(signs, list):
        # e.g. "danger_signs": null — len()/iteration below would raise.
        return {"valid": False, "issues": ["danger_signs is not a list"]}

    # If we expect danger signs, there should be some
    if has_danger_signs and not signs:
        result["valid"] = False
        result["issues"].append("Expected danger signs but got none")
    # If we DON'T expect danger signs, there should be none
    if not has_danger_signs and signs:
        result["valid"] = False
        result["issues"].append(
            f"Expected no danger signs but got {len(signs)} — hallucination in training data"
        )

    # Every danger sign must be an object carrying utterance_evidence
    for i, sign in enumerate(signs):
        if not isinstance(sign, dict):
            result["valid"] = False
            result["issues"].append(f"Danger sign [{i}] is not an object")
        elif not sign.get("utterance_evidence"):
            result["valid"] = False
            result["issues"].append(
                f"Danger sign [{i}] '{sign.get('sign')}' missing utterance_evidence"
            )

    # Referral decision consistency. A missing or null referral_decision is
    # treated as "no decision given"; any other non-object value is malformed.
    referral = data.get("referral_decision")
    if referral is None:
        referral = {}
    if isinstance(referral, dict):
        decision = referral.get("decision", "")
    else:
        result["valid"] = False
        result["issues"].append("referral_decision is not an object")
        decision = ""

    if has_danger_signs and signs and decision in ("routine_followup", "continue_monitoring"):
        # Has flags but says routine — inconsistent when any flag is life-threatening
        categories = [s.get("category") for s in signs if isinstance(s, dict)]
        if "immediate_referral" in categories:
            result["valid"] = False
            result["issues"].append(
                "Has immediate_referral signs but decision is not refer_immediately"
            )
    if not has_danger_signs and decision in ("refer_immediately", "refer_within_24h"):
        result["valid"] = False
        result["issues"].append("No danger signs but referral decision is urgent — hallucination")
    return result
def _count_non_null(d, count=0):
    """Count populated values in a nested JSON-like structure.

    Dicts are descended into through their values. A list contributes its
    length as-is (its items are not inspected, so ``None`` items and nested
    objects each count as one). Any other value counts as one unless it is
    ``None``.

    Args:
        d: The structure to inspect.
        count: Starting total that the result is added to.

    Returns:
        ``count`` plus the number of populated values found.
    """
    pending = [d]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            pending.extend(node.values())
        elif isinstance(node, list):
            count += len(node)
        elif node is not None:
            count += 1
    return count
# ============================================================
# CHECKPOINT
# ============================================================
def load_checkpoint() -> dict:
    """Load run progress from CHECKPOINT_FILE, or return a fresh state.

    Returns:
        The parsed checkpoint when the file exists. Otherwise a zeroed state
        with the same sections that main() puts in a fresh checkpoint,
        including "quality", so code reading ``cp["quality"]`` behaves the
        same whether or not a checkpoint file was found.

    Raises:
        json.JSONDecodeError: If the checkpoint file exists but is not valid
            JSON.
    """
    try:
        with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        pass  # no checkpoint yet — start from zero
    return {
        "generated": 0, "valid": 0, "invalid": 0, "failed": 0,
        "total_cost": 0.0, "by_type": {}, "by_danger": {"positive": 0, "negative": 0},
        "quality": {
            "narration_leaks": 0, "low_hindi": 0, "low_turns": 0,
            "polarity_errors": 0, "avg_clinical_density": 0.0,
            "total_clinical_density": 0, "total_scored": 0,
        },
    }
def save_checkpoint(cp: dict):
    """Persist run progress to CHECKPOINT_FILE.

    The state is written to a sibling ``.tmp`` file and then moved over the
    real path with ``os.replace``, so an interruption mid-write leaves the
    previous checkpoint intact rather than a truncated file that would break
    resuming.

    Args:
        cp: JSON-serializable checkpoint state.
    """
    tmp_path = f"{CHECKPOINT_FILE}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(cp, f, indent=2)
    os.replace(tmp_path, CHECKPOINT_FILE)
# ============================================================
# COST TRACKING
# ============================================================
# GPT-4o Mini pricing (as of 2026)
INPUT_COST_PER_M = 0.15 # $0.15 per 1M input tokens
OUTPUT_COST_PER_M = 0.60 # $0.60 per 1M output tokens


def estimate_cost(input_text: str, output_text: str) -> float:
    """Estimate the dollar cost of one request from text lengths.

    Token counts are approximated as ``len(text) // 4`` for both the prompt
    and the completion, then priced with the per-million rates above.
    NOTE(review): four characters per token is a rough heuristic and is
    likely to undercount Devanagari-heavy text — treat the result as an
    estimate, not billing data.
    """
    prompt_tokens, completion_tokens = (
        len(chunk) // 4 for chunk in (input_text, output_text)
    )
    prompt_cost = prompt_tokens * INPUT_COST_PER_M / 1_000_000
    completion_cost = completion_tokens * OUTPUT_COST_PER_M / 1_000_000
    return prompt_cost + completion_cost
# ============================================================
# WEIGHTED SCENARIO SAMPLING
# ============================================================
def sample_scenario() -> dict:
    """Pick one entry from SCENARIOS, with probability proportional to its "weight"."""
    (picked,) = random.choices(
        SCENARIOS,
        weights=[scenario["weight"] for scenario in SCENARIOS],
        k=1,
    )
    return picked
# ============================================================
# GENERATION
# ============================================================
def _request_completion(client, system_prompt, user_prompt, *, max_tokens, temperature, json_mode=False):
    """Send one system+user chat request; return (stripped text, estimated cost).

    An empty or missing message content is returned as "". Exceptions raised
    by the API client propagate to the caller.
    """
    extra = {"response_format": {"type": "json_object"}} if json_mode else {}
    resp = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        **extra,
    )
    text = (resp.choices[0].message.content or "").strip()
    return text, estimate_cost(system_prompt + user_prompt, text)


def generate_one_sample(client: OpenAI, scenario: dict, schemas: dict) -> dict:
    """
    Generate one complete training sample:
    1. Generate Hindi conversation transcript
    2. Extract structured form data
    3. Extract danger signs
    4. Validate all three

    Args:
        client: OpenAI client used for all three requests.
        scenario: One entry of SCENARIOS.
        schemas: Loaded schemas keyed by name; must contain the form schema
            for the scenario's visit type and "danger_signs".

    Returns:
        On success, ``{"sample": {...}, "validation": {...}, "cost": float}``.
        On any API, parsing or validation failure,
        ``{"error": str, "cost": float}``. This function never returns None.
        "cost" covers every request whose response was received, including
        one whose output was then rejected as invalid JSON or failed
        validation, so budget tracking reflects what was actually billed.
    """
    visit_type = scenario["type"]
    form_schema_name = VISIT_TYPE_TO_SCHEMA[visit_type]
    form_schema = schemas[form_schema_name]
    danger_schema = schemas["danger_signs"]
    total_cost = 0.0

    # ── Step 1: Generate transcript ──
    try:
        transcript, call_cost = _request_completion(
            client,
            SYSTEM_PROMPT_TRANSCRIPT,
            build_transcript_prompt(scenario),
            max_tokens=MAX_TOKENS_TRANSCRIPT,
            temperature=TEMPERATURE,
        )
    except Exception as e:
        return {"error": f"Transcript generation failed: {e}", "cost": total_cost}
    total_cost += call_cost

    # Validate transcript (an empty response is rejected here as too short)
    tv = validate_transcript(transcript)
    if not tv["valid"]:
        return {"error": f"Invalid transcript: {tv['issues']}", "cost": total_cost}

    # ── Step 2: Extract form data ──
    try:
        form_text, call_cost = _request_completion(
            client,
            SYSTEM_PROMPT_FORM_EXTRACTION,
            build_extraction_prompt(transcript, visit_type, form_schema),
            max_tokens=MAX_TOKENS_EXTRACTION,
            temperature=JSON_TEMPERATURE,
            json_mode=True,
        )
    except Exception as e:
        return {"error": f"Form extraction failed: {e}", "cost": total_cost}
    # Count the call before parsing: it was billed even if the JSON is unusable.
    total_cost += call_cost
    try:
        form_data = json.loads(form_text)
    except json.JSONDecodeError as e:
        return {"error": f"Form extraction not valid JSON: {e}", "cost": total_cost}

    # Validate form extraction
    ev = validate_extraction(form_data, visit_type, scenario["has_danger_signs"])
    if not ev["valid"]:
        return {"error": f"Invalid extraction: {ev['issues']}", "cost": total_cost}

    # ── Step 3: Extract danger signs ──
    try:
        danger_text, call_cost = _request_completion(
            client,
            SYSTEM_PROMPT_DANGER_EXTRACTION,
            build_danger_signs_prompt(transcript, visit_type, danger_schema),
            max_tokens=MAX_TOKENS_EXTRACTION,
            temperature=JSON_TEMPERATURE,
            json_mode=True,
        )
    except Exception as e:
        return {"error": f"Danger signs extraction failed: {e}", "cost": total_cost}
    total_cost += call_cost
    try:
        danger_data = json.loads(danger_text)
    except json.JSONDecodeError as e:
        return {"error": f"Danger signs not valid JSON: {e}", "cost": total_cost}

    # Validate danger signs — this is the most critical validation
    dv = validate_danger_signs(danger_data, scenario["has_danger_signs"])
    if not dv["valid"]:
        return {"error": f"Invalid danger signs: {dv['issues']}", "cost": total_cost}

    # ── Success — build training sample ──
    return {
        "sample": {
            "id": None,  # assigned later
            "visit_type": visit_type,
            "scenario_label": scenario["label"],
            "has_danger_signs": scenario["has_danger_signs"],
            "transcript": transcript,
            "form_extraction": form_data,
            "danger_signs_extraction": danger_data,
            "form_schema": form_schema_name,
        },
        "validation": {
            "transcript": tv,
            "extraction": ev,
            "danger_signs": dv,
        },
        "cost": total_cost,
    }
# ============================================================
# MAIN
# ============================================================
# Terms that mark a nearby number as a clinical measurement. They are matched
# against lower-cased transcript text, so the Latin-script entries are lowercase.
_CLINICAL_KEYWORDS = (
    'bp', 'weight', 'वजन', 'kg', 'hb', 'हीमोग्लोबिन', 'तापमान',
    'हफ्ता', 'महीना', 'tablet', 'ग्राम', 'डिग्री', 'किलो',
)
_NUMBER_RE = re.compile(r'\d+\.?\d*')
# Parenthesised spans of 5+ chars; shorter ones (e.g. "(latch)") are tolerated.
_PAREN_NARRATION_RE = re.compile(r'\([^)]{5,}\)')
_BRACKET_NARRATION_RE = re.compile(r'\[[^\]]+\]')


def _new_quality_stats() -> dict:
    """Return a zeroed block of real-time quality counters for the checkpoint."""
    return {
        "narration_leaks": 0, "low_hindi": 0, "low_turns": 0,
        "polarity_errors": 0, "avg_clinical_density": 0.0,
        "total_clinical_density": 0, "total_scored": 0,
    }


def _count_clinical_numbers(transcript: str) -> int:
    """Count numbers that have a clinical keyword within 40 chars of their start.

    Each occurrence is judged by its own surrounding text. Using
    ``str.find`` here would map every repeat of the same digit string back to
    the first occurrence and mis-score the later ones.
    """
    count = 0
    for match in _NUMBER_RE.finditer(transcript):
        pos = match.start()
        context = transcript[max(0, pos - 40):pos + 40].lower()
        if any(kw in context for kw in _CLINICAL_KEYWORDS):
            count += 1
    return count


def _update_quality_stats(q: dict, transcript: str, expects_danger: bool, n_signs: int) -> None:
    """Score one accepted sample and fold the result into the counters in *q*."""
    # Narration leak: stage directions in parens/brackets, or markdown asterisks.
    if (_PAREN_NARRATION_RE.search(transcript)
            or _BRACKET_NARRATION_RE.search(transcript)
            or '*' in transcript):
        q["narration_leaks"] += 1

    # Hindi density: share of characters in the Devanagari block (U+0900-U+097F).
    hindi_chars = sum(1 for c in transcript if '\u0900' <= c <= '\u097F')
    if hindi_chars / max(len(transcript), 1) < 0.5:
        q["low_hindi"] += 1

    # Dialogue length, measured by speaker-label occurrences.
    turns = (transcript.count("ASHA:") + transcript.count("Patient:")
             + transcript.count("Mother:"))
    if turns < 15:
        q["low_turns"] += 1

    q["total_clinical_density"] += _count_clinical_numbers(transcript)
    q["total_scored"] += 1
    # Keep the stored average in sync; it was initialised but never written before.
    q["avg_clinical_density"] = q["total_clinical_density"] / q["total_scored"]

    # Polarity error: scenario label and extracted danger signs disagree.
    if bool(expects_danger) != (n_signs > 0):
        q["polarity_errors"] += 1


def main():
    """Command-line entry point: generate training samples and write them as JSONL.

    Flags:
        --dry-run  generate 5 samples only
        --resume   continue from the saved checkpoint and append to the output file
        --count    number of samples to generate (default 500)
        --seed     seed passed to ``random.seed`` (default 42)

    Exits with status 1 when OPENAI_API_KEY is unset or the API smoke test
    fails. The loop stops early when the budget, failure-rate, or
    polarity-error limit is exceeded. The checkpoint is saved whenever the loop
    ends, including on Ctrl-C or an unexpected exception, so that ``--resume``
    stays aligned with the samples already written.
    """
    parser = argparse.ArgumentParser(description="MedScribe v2 — Training Data Generation")
    parser.add_argument("--dry-run", action="store_true", help="Generate 5 samples only")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
    parser.add_argument("--count", type=int, default=500, help="Number of samples to generate")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    target = 5 if args.dry_run else args.count

    print("=" * 60)
    print(f"Training Data Generation — {MODEL}" + (" [DRY RUN]" if args.dry_run else ""))
    print(f"Target: {target} samples")
    print("=" * 60)

    # ── Gate: API key ──
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("\nABORT: OPENAI_API_KEY not found. Set in .env file.")
        sys.exit(1)
    client = OpenAI()

    # Quick API test so a bad key or model name fails before any real spend.
    try:
        client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": "Say OK"}],
            max_tokens=5,
        )
        print(f"API connection verified: {MODEL}")
    except Exception as e:
        print(f"\nABORT: API test failed: {e}")
        sys.exit(1)

    # ── Load schemas ──
    schemas = {}
    for name in ["anc_visit", "pnc_visit", "delivery", "child_health", "danger_signs"]:
        schemas[name] = load_schema(name)
    print(f"Loaded {len(schemas)} schemas")

    # ── Load checkpoint ──
    cp = load_checkpoint()
    start_idx = cp["generated"] if args.resume else 0
    if not args.resume:
        cp = {
            "generated": 0, "valid": 0, "invalid": 0, "failed": 0,
            "total_cost": 0.0, "by_type": {}, "by_danger": {"positive": 0, "negative": 0},
            "quality": _new_quality_stats(),
        }

    # Open output file (append if resuming)
    mode = "a" if args.resume and os.path.exists(OUTPUT_FILE) else "w"
    with open(OUTPUT_FILE, mode, encoding="utf-8") as outfile:
        print(f"\nStarting from sample {start_idx + 1}...")
        if args.dry_run:
            print("DRY RUN: Generating 5 samples. Review output before full run.\n")

        batch_count = 0
        try:
            for i in range(start_idx, target):
                scenario = sample_scenario()
                vtype = scenario["type"]
                result = generate_one_sample(client, scenario, schemas)

                # Track cost regardless
                cp["total_cost"] += result.get("cost", 0)

                if "error" in result:
                    cp["failed"] += 1
                    cp["generated"] += 1
                    status = f"FAIL: {result['error'][:80]}"
                elif "sample" in result:
                    sample = result["sample"]
                    sample["id"] = f"ms2_{i:04d}"

                    # Write to JSONL; flush so each sample is handed to the OS right away.
                    outfile.write(json.dumps(sample, ensure_ascii=False) + "\n")
                    outfile.flush()

                    cp["valid"] += 1
                    cp["generated"] += 1
                    cp["by_type"][vtype] = cp["by_type"].get(vtype, 0) + 1
                    if scenario["has_danger_signs"]:
                        cp["by_danger"]["positive"] += 1
                    else:
                        cp["by_danger"]["negative"] += 1

                    n_signs = len(sample["danger_signs_extraction"].get("danger_signs", []))
                    status = f"OK [{vtype}] signs={n_signs}"

                    # ── Real-time quality monitoring ──
                    # setdefault covers checkpoints that have no "quality" entry.
                    q = cp.setdefault("quality", _new_quality_stats())
                    _update_quality_stats(
                        q, sample["transcript"], scenario["has_danger_signs"], n_signs
                    )
                else:
                    cp["failed"] += 1
                    cp["generated"] += 1
                    status = "UNKNOWN"

                batch_count += 1
                pct = cp["generated"] / target * 100
                print(f" [{cp['generated']}/{target}] ({pct:.0f}%) ${cp['total_cost']:.4f} | {scenario['label'][:50]} | {status}")

                # ── Safety checks ──
                if cp["total_cost"] > MAX_BUDGET:
                    print(f"\nABORT: Budget exceeded (${cp['total_cost']:.2f} > ${MAX_BUDGET})")
                    break
                if cp["generated"] >= MIN_BATCH_FOR_CHECK:
                    total_attempted = cp["valid"] + cp["failed"]
                    if total_attempted > 0 and cp["failed"] / total_attempted > FAILURE_RATE_ABORT:
                        print(f"\nABORT: Failure rate {cp['failed']}/{total_attempted} exceeds threshold")
                        break

                # Checkpoint + quality report
                if batch_count >= BATCH_SIZE:
                    save_checkpoint(cp)
                    batch_count = 0

                    # Periodic quality report: fires at the first checkpoint
                    # after the scored count passes a multiple of 50.
                    q = cp.get("quality", {})
                    scored = q.get("total_scored", 0)
                    if scored > 0 and scored % 50 < BATCH_SIZE:
                        avg_clin = q["total_clinical_density"] / scored
                        print(f"\n ┌── QUALITY REPORT (n={scored}) ──")
                        print(f" │ Narration leaks: {q['narration_leaks']} ({q['narration_leaks']/scored*100:.0f}%)")
                        print(f" │ Low Hindi: {q['low_hindi']} ({q['low_hindi']/scored*100:.0f}%)")
                        print(f" │ Low turns (<15): {q['low_turns']} ({q['low_turns']/scored*100:.0f}%)")
                        print(f" │ Polarity errors: {q['polarity_errors']} ({q['polarity_errors']/scored*100:.0f}%)")
                        print(f" │ Avg clinical #s: {avg_clin:.1f} per sample")
                        print(f" └{'─' * 35}")

                        # ABORT on polarity errors > 5%
                        if scored >= 30 and q["polarity_errors"] / scored > 0.05:
                            print(f"\n ABORT: Polarity error rate {q['polarity_errors']}/{scored} > 5%")
                            print(" This means the model is hallucinating danger signs or missing real ones.")
                            print(" Fix prompts before continuing.")
                            break

                # Rate limit
                time.sleep(0.3)
        finally:
            # Save on every exit path so a resumed run does not re-emit samples
            # that were already written since the last batch checkpoint.
            save_checkpoint(cp)

    # ── Summary ──
    print(f"\n{'=' * 60}")
    print("GENERATION SUMMARY")
    print("=" * 60)
    print(f" Total generated: {cp['generated']}")
    print(f" Valid samples: {cp['valid']}")
    print(f" Failed: {cp['failed']}")
    print(f" Total cost: ${cp['total_cost']:.4f}")

    print("\n By visit type:")
    for vt, count in sorted(cp["by_type"].items()):
        print(f" {vt}: {count}")

    print("\n Danger sign balance:")
    print(f" Positive (has danger signs): {cp['by_danger']['positive']}")
    print(f" Negative (no danger signs): {cp['by_danger']['negative']}")
    total_with_labels = cp['by_danger']['positive'] + cp['by_danger']['negative']
    if total_with_labels > 0:
        neg_pct = cp['by_danger']['negative'] / total_with_labels * 100
        print(f" Negative ratio: {neg_pct:.0f}% (target: ~40%)")

    # Quality summary
    q = cp.get("quality", {})
    scored = q.get("total_scored", 0)
    if scored > 0:
        avg_clin = q["total_clinical_density"] / scored
        print("\n Quality metrics:")
        print(f" Narration leaks: {q['narration_leaks']}/{scored} ({q['narration_leaks']/scored*100:.0f}%)")
        print(f" Low Hindi: {q['low_hindi']}/{scored} ({q['low_hindi']/scored*100:.0f}%)")
        print(f" Low turns (<15): {q['low_turns']}/{scored} ({q['low_turns']/scored*100:.0f}%)")
        print(f" Polarity errors: {q['polarity_errors']}/{scored} ({q['polarity_errors']/scored*100:.0f}%)")
        print(f" Avg clinical #s: {avg_clin:.1f} per sample")
        if q['polarity_errors'] > 0:
            print(f"\n WARNING: {q['polarity_errors']} polarity errors detected!")
            print(" Review these samples before training.")

    print(f"\n Output: {OUTPUT_FILE}")
    if args.dry_run:
        print("\nDRY RUN complete. Review the output file.")
        print(f"If quality looks good, run: python scripts/generate_training_data.py --count {args.count}")
    if cp["valid"] >= 200:
        print("\nREADY for next step: python scripts/04_prepare_training.py")
    elif cp["valid"] > 0:
        print(f"\n{cp['valid']} samples generated. May need more for good fine-tune results.")
    print("=" * 60)


if __name__ == "__main__":
    main()
|