Spaces:
Sleeping
Sleeping
| """ | |
| MedScribe v2 — Synthetic Training Data Generation | |
| Generates paired (Hindi ASHA conversation, structured JSON extraction) training data | |
| via GPT-4o Mini API. Each sample produces: | |
| 1. A realistic Hindi ASHA home visit conversation transcript | |
| 2. The correct structured JSON extraction (form fields + danger signs) | |
| Adapted from MedScribe v1's proven generation pipeline with: | |
| - Budget cap + cost tracking | |
| - Per-sample validation (schema + clinical consistency) | |
| - Checkpoint/resume support | |
| - Batch quality monitoring with abort threshold | |
| - Dry-run mode | |
| - Negative examples (no danger signs, many null fields) | |
| Usage: | |
| python scripts/generate_training_data.py --dry-run # First 5 samples | |
| python scripts/generate_training_data.py # Full generation | |
| python scripts/generate_training_data.py --resume # Resume from checkpoint | |
| python scripts/generate_training_data.py --count 500 # Generate N samples | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import re | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| load_dotenv() | |
# ============================================================
# CONFIG
# ============================================================
# Output locations. main() creates OUTPUT_DIR if it is missing.
OUTPUT_DIR = "data/processed"
# Accepted samples are written here, one JSON object per line.
OUTPUT_FILE = os.path.join(OUTPUT_DIR, "training_data_raw.jsonl")
# Running counters (generated/valid/failed/cost) used by --resume.
CHECKPOINT_FILE = os.path.join(OUTPUT_DIR, "generation_checkpoint.json")

# Model used for all three calls per sample (transcript, form, danger signs).
MODEL = "gpt-4o-mini"
TEMPERATURE = 0.7  # Higher for diversity in conversations
JSON_TEMPERATURE = 0.1  # Lower for accurate structured extraction
# Passed as max_tokens to the transcript call and to both extraction calls.
MAX_TOKENS_TRANSCRIPT = 2500
MAX_TOKENS_EXTRACTION = 2000

# The checkpoint is saved once every BATCH_SIZE loop iterations.
BATCH_SIZE = 10
# The run aborts once the tracked total cost exceeds this value.
MAX_BUDGET = 10.0
# The run aborts when failed / (valid + failed) exceeds this fraction...
FAILURE_RATE_ABORT = 0.30
# ...but only after at least this many samples have been attempted.
MIN_BATCH_FOR_CHECK = 20
# ============================================================
# VISIT SCENARIO DEFINITIONS
# ============================================================
# Each scenario defines the visit type, clinical profile, and
# whether danger signs should be present. "weight" is the relative
# sampling weight consumed by sample_scenario().
#
# Negative scenarios (has_danger_signs=False) teach the model restraint.
# With the weights below they account for 40 of 69 weight units, i.e.
# roughly 58% of sampled scenarios.
# NOTE(review): an earlier comment stated a ~40% negative target; if that
# target still applies, the weights need rebalancing.
SCENARIOS = [
    # -- ANC visits: total weight 31 of 69 (~45% of samples) --
    {
        "type": "anc_visit",
        "label": "ANC — Normal pregnancy, routine visit",
        "has_danger_signs": False,
        "weight": 8,
        "clinical_profile": (
            "Healthy primigravida, 24 weeks, all vitals normal, no complaints. "
            "BP 110/70, Hb 11.5, weight gain normal. Routine ANC visit. "
            "ASHA counsels on diet, IFA tablets, birth preparedness."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Mild anemia, otherwise normal",
        "has_danger_signs": False,
        "weight": 5,
        "clinical_profile": (
            "Second pregnancy, 28 weeks, mild anemia (Hb 9.8). No other complaints. "
            "BP normal, fetal movements good. ASHA gives IFA, advises diet."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Preeclampsia signs (HIGH RISK)",
        "has_danger_signs": True,
        "weight": 5,
        "clinical_profile": (
            "Third trimester (32-36 weeks), elevated BP (140-160/90-110), "
            "headache, blurred vision, facial/hand swelling, excessive weight gain. "
            "Possible preeclampsia. Needs immediate referral."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Severe anemia (HIGH RISK)",
        "has_danger_signs": True,
        "weight": 3,
        "clinical_profile": (
            "Second trimester, severe anemia (Hb <7), extreme weakness, breathlessness, "
            "pallor, dizziness. Needs urgent referral for IV iron/transfusion."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Reduced fetal movement",
        "has_danger_signs": True,
        "weight": 3,
        "clinical_profile": (
            "Late third trimester, patient reports baby moving less than usual "
            "for 2 days. Other vitals may be normal. Needs monitoring/referral."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Vaginal bleeding",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "Any trimester, reports vaginal bleeding (spotting to heavy). "
            "Immediate danger sign requiring emergency referral."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Gestational diabetes",
        "has_danger_signs": False,
        "weight": 2,
        "clinical_profile": (
            "28-32 weeks, elevated blood sugar found on testing. No acute danger signs "
            "but needs dietary counseling and follow-up. BP normal."
        ),
    },
    {
        "type": "anc_visit",
        "label": "ANC — Young primigravida, many null fields",
        "has_danger_signs": False,
        "weight": 3,
        "clinical_profile": (
            "First visit, early pregnancy (8-10 weeks). Minimal information available — "
            "no labs done yet, no vitals taken at this visit (home visit, no equipment beyond BP). "
            "Many form fields should be null. Just registration and counseling."
        ),
    },
    # -- PNC / HBNC visits: total weight 16 of 69 (~23% of samples) --
    {
        "type": "pnc_visit",
        "label": "PNC — Normal postpartum + healthy newborn",
        "has_danger_signs": False,
        "weight": 6,
        "clinical_profile": (
            "Day 7 HBNC visit. Mother recovering well, no fever, bleeding light. "
            "Baby weight 3.0 kg (birth weight 2.8), breastfeeding well, "
            "cord clean and dry, active, no jaundice."
        ),
    },
    {
        "type": "pnc_visit",
        "label": "PNC — Low birth weight newborn, needs monitoring",
        "has_danger_signs": False,
        "weight": 3,
        "clinical_profile": (
            "Day 14 visit. Baby 2.1 kg (birth weight 1.9), gaining slowly. "
            "Breastfeeding adequate but not frequent enough. Mother well. "
            "No acute danger signs but close monitoring needed."
        ),
    },
    {
        "type": "pnc_visit",
        "label": "PNC — Newborn not feeding well (DANGER)",
        "has_danger_signs": True,
        "weight": 3,
        "clinical_profile": (
            "Day 3 visit. Newborn lethargic, not latching, weak cry. "
            "May have fever or hypothermia. Mother reports baby sleeping too much. "
            "IMNCI danger signs present — urgent referral."
        ),
    },
    {
        "type": "pnc_visit",
        "label": "PNC — Postpartum hemorrhage signs (DANGER)",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "Day 1-3 visit. Mother reports heavy bleeding, soaking through cloth. "
            "Feeling dizzy and weak. Possible postpartum hemorrhage. "
            "Immediate referral needed."
        ),
    },
    {
        "type": "pnc_visit",
        "label": "PNC — Newborn jaundice",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "Day 3-5 visit. Newborn has yellow skin, possibly yellow palms/soles. "
            "Feeding okay or slightly reduced. Needs assessment for severity — "
            "jaundice within 24h of birth or palms/soles = severe."
        ),
    },
    # -- Delivery: total weight 7 of 69 (~10% of samples) --
    {
        "type": "delivery",
        "label": "Delivery — Normal institutional delivery",
        "has_danger_signs": False,
        "weight": 3,
        "clinical_profile": (
            "Full-term normal delivery at PHC/district hospital. "
            "Healthy baby boy/girl, cried immediately, breastfed within 1 hour. "
            "Birth weight 2.8-3.5 kg. Mother stable. Vaccines given at birth."
        ),
    },
    {
        "type": "delivery",
        "label": "Delivery — Home delivery (partial info)",
        "has_danger_signs": False,
        "weight": 2,
        "clinical_profile": (
            "Home delivery attended by dai/family. ASHA visiting after the fact. "
            "Limited info on delivery details. Baby seems okay. "
            "Need to check birth weight, breastfeeding, vaccines."
        ),
    },
    {
        "type": "delivery",
        "label": "Delivery — Preterm with complications (DANGER)",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "Preterm delivery (34-36 weeks). Low birth weight (<2.5 kg). "
            "Baby may have breathing difficulty. Mother may have had complications. "
            "Needs close monitoring or referral."
        ),
    },
    # -- Child health / HBYC: total weight 15 of 69 (~22% of samples) --
    {
        "type": "child_health",
        "label": "HBYC — Healthy 6-month-old, routine visit",
        "has_danger_signs": False,
        "weight": 5,
        "clinical_profile": (
            "6-month HBYC visit. Good weight gain, breastfeeding + complementary food started. "
            "Immunizations up to date. Milestones appropriate. No illness."
        ),
    },
    {
        "type": "child_health",
        "label": "HBYC — Underweight child, no acute danger",
        "has_danger_signs": False,
        "weight": 3,
        "clinical_profile": (
            "9-month visit. Child slightly underweight, pallor present (mild anemia). "
            "Eating poorly. No acute illness. ASHA counsels on feeding. "
            "Deworming due. Vitamin A due."
        ),
    },
    {
        "type": "child_health",
        "label": "HBYC — Diarrhea + dehydration (DANGER)",
        "has_danger_signs": True,
        "weight": 3,
        "clinical_profile": (
            "12-month visit. Child has watery diarrhea for 3 days, "
            "not drinking well, sunken eyes, lethargic. Possible severe dehydration. "
            "IMNCI general danger signs may be present."
        ),
    },
    {
        "type": "child_health",
        "label": "HBYC — Pneumonia signs (DANGER)",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "9-month visit. Child has cough for 5 days, fast breathing, "
            "chest indrawing visible. Fever present. Possible pneumonia. "
            "Needs urgent referral per IMNCI."
        ),
    },
    {
        "type": "child_health",
        "label": "HBYC — Severe malnutrition (DANGER)",
        "has_danger_signs": True,
        "weight": 2,
        "clinical_profile": (
            "15-month visit. Visible severe wasting, very low weight for age, "
            "child not active. Possible edema of feet. Severe acute malnutrition "
            "requiring NRC referral."
        ),
    },
]
# ============================================================
# PROMPTS
# ============================================================
# System prompt for step 1 of generate_one_sample(): the model writes the
# Hindi ASHA/patient dialogue that becomes the training input.
SYSTEM_PROMPT_TRANSCRIPT = """You generate realistic Hindi conversations between ASHA health workers and patients during home visits in rural India. These will train a medical AI — clinical accuracy and natural dialogue are both critical.
OUTPUT FORMAT (strict):
- ONLY dialogue lines. Each line: "ASHA:", "Patient:", or "Mother:" followed by spoken words.
- ABSOLUTELY NO narration, stage directions, action descriptions, or parentheticals.
WRONG: (BP चेक करते हुए) / [measures weight] / *takes temperature* / (बच्चे को देख कर)
RIGHT: "ASHA: चलिए, BP देख लेती हूँ... 150/95 आ रहा है, ये तो ज़्यादा है।"
- ALL text MUST be in Devanagari script. Do NOT use Romanized Hindi (no "Aapka BP", no "theek hai").
Only English words allowed: medical terms (BP, Hb, TT, IFA, ORS, kg, mg, PHC, CHC).
CLINICAL REQUIREMENTS:
- ASHA must verbally state EVERY measurement with its value in Devanagari: "आपका BP 140/90 आ रहा है", "बच्चे का वज़न 3.1 kg है", "Hb 9.5 आया है"
- Include at least 6 distinct clinical data points spoken naturally in dialogue (vitals, history, medications, plans).
- Visit-specific data the ASHA should cover:
ANC: gestational age, BP, weight, Hb, IFA compliance, TT status, fetal movement, birth preparedness (transport, facility, money, blood donor).
PNC/Newborn: mother's bleeding/fever/pain, baby weight, feeding pattern (frequency, latch), cord condition, jaundice check, vaccination status.
Child health: age, weight, feeding/diet, immunization status, milestones, illness symptoms, deworming/Vitamin A.
DIALOGUE STYLE:
- Vary openings — sometimes ASHA calls from the door, sometimes patient greets first, sometimes mid-activity. Do NOT always start with "नमस्ते, कैसे हैं आप?"
- ASHA is warm but efficient — covers clinical ground without sounding like a form.
- Patient speaks colloquially: approximate dates ("लगभग 6 महीना"), local terms, sometimes vague or tangential.
- Patient may volunteer info, ask questions, express worry, or dismiss concerns.
- 20-30 dialogue turns with substantive content (not single-word responses).
- End with ASHA's concrete plan: next visit date, any referral, medications given, counseling summary."""

# System prompt for step 2 of generate_one_sample(): form-field extraction
# from the transcript into the visit-type JSON schema.
SYSTEM_PROMPT_FORM_EXTRACTION = """You are a clinical data extraction system. Extract structured medical data from an ASHA home visit conversation transcript into the provided JSON schema.
RULES:
1. Extract ONLY information EXPLICITLY stated or clearly implied in the conversation.
2. Use null for anything not mentioned — never guess or fill in "expected" values.
3. Numbers must match exactly as stated in conversation (BP, weight, Hb, temperature, age, etc.).
4. For array fields (symptoms_reported, counseling_provided), extract all relevant items mentioned.
5. If ASHA states a measurement value, record the exact number, not just "normal".
6. Patient's approximate statements: convert to best numeric estimate ("लगभग 6 महीना" → gestational_weeks: 24).
7. Return valid JSON matching the schema. No markdown formatting."""

# System prompt for step 3 of generate_one_sample(): danger-sign extraction.
# The category names (immediate_referral, ...) and decision names
# (refer_immediately, routine_followup, ...) used here are the same strings
# that validate_danger_signs() compares against.
SYSTEM_PROMPT_DANGER_EXTRACTION = """You are a clinical danger sign extraction system for Indian ASHA worker home visits. Extract danger signs from conversation transcripts with high precision.
CORE RULES:
1. ONLY flag danger signs with DIRECT, EXPLICIT evidence in the conversation text.
2. Each flag MUST include utterance_evidence — the exact Hindi quote that triggered it.
3. If NO danger signs exist in the conversation, return an empty danger_signs array. This is correct and expected for normal visits.
4. NEVER invent or hallucinate danger signs. When in doubt, do not flag.
CLASSIFICATION GUIDANCE:
- immediate_referral: life-threatening — heavy uncontrolled bleeding, convulsions, unconsciousness, BP ≥160/110, newborn not breathing
- urgent_care: serious — elevated BP with symptoms (headache/vision/swelling), Hb <7 with symptoms, signs of severe dehydration (sunken eyes + lethargic + not drinking), fast breathing with chest indrawing
- monitor_closely: borderline — isolated mild findings, low-grade fever (99-100°F), mild swelling alone
ANTI-HALLUCINATION:
- Do NOT flag normal values as danger signs. BP 110/70 is normal. Temperature 98.6°F is normal.
- Do NOT flag a sign just because the scenario suggests it — only flag what the CONVERSATION actually says.
- If a value is borderline (e.g., temp 99°F), classify as monitor_closely at most, NOT urgent_care.
REFERRAL LOGIC:
- Any immediate_referral sign → refer_immediately (district_hospital/FRU)
- Only urgent_care signs → refer_within_24h (PHC/CHC)
- Only monitor_closely → continue_monitoring
- No signs → routine_followup
Fill the relevant checklist (maternal or newborn): "detected" if found, "not_detected" if assessed as normal, "not_assessed" if not discussed.
Return valid JSON only."""
def build_transcript_prompt(scenario: dict) -> str:
    """Compose the user message that asks the model for one Hindi transcript.

    Args:
        scenario: A scenario entry; its ``has_danger_signs``, ``type``,
            ``label`` and ``clinical_profile`` keys are read.

    Returns:
        The prompt text, with danger-sign and symptom instructions chosen
        according to ``scenario['has_danger_signs']``.
    """
    if scenario['has_danger_signs']:
        danger_line = "Yes — include clear, unambiguous danger signs in the conversation"
        symptom_line = (
            "Patient describes symptoms matching danger signs in natural "
            "colloquial Hindi (not medical jargon)."
        )
    else:
        danger_line = "No — this is a normal visit with no danger signs"
        symptom_line = "Patient has no concerning symptoms. Routine, healthy visit."

    # Adjacent literals are concatenated into one string, newline by newline.
    return (
        "Generate an ASHA home visit conversation in Hindi (Devanagari only):\n"
        f"VISIT: {scenario['type']} — {scenario['label']}\n"
        f"CLINICAL DETAILS: {scenario['clinical_profile']}\n"
        f"DANGER SIGNS: {danger_line}\n"
        "Rules:\n"
        "- Pure dialogue, zero narration/parentheticals. ASHA speaks measurements aloud in Devanagari.\n"
        f"- {symptom_line}\n"
        "- At least 6 clinical data points woven naturally into conversation.\n"
        "- 20-30 turns. Feel like a real village home visit, not a medical interview."
    )
def build_extraction_prompt(transcript: str, visit_type: str, schema: dict) -> str:
    """Compose the user message for the form-field extraction call.

    Args:
        transcript: The generated conversation text.
        visit_type: Accepted for signature parity with
            ``build_danger_signs_prompt``; not used in the prompt text.
        schema: JSON schema dict, embedded pretty-printed with non-ASCII
            characters kept as-is.

    Returns:
        The prompt text.
    """
    schema_json = json.dumps(schema, ensure_ascii=False, indent=2)
    return (
        "Extract structured data from this ASHA home visit conversation into the provided JSON schema.\n"
        "CONVERSATION TRANSCRIPT:\n"
        f"{transcript}\n"
        "OUTPUT JSON SCHEMA:\n"
        f"{schema_json}\n"
        "Extract the data now. Return ONLY valid JSON matching the schema. "
        "Use null for fields not mentioned in the conversation."
    )
def build_danger_signs_prompt(transcript: str, visit_type: str, schema: dict) -> str:
    """Compose the user message for the danger-sign extraction call.

    Args:
        transcript: The generated conversation text.
        visit_type: Visit type string, echoed in the prompt.
        schema: Danger-sign JSON schema dict, embedded pretty-printed with
            non-ASCII characters kept as-is.

    Returns:
        The prompt text.
    """
    schema_json = json.dumps(schema, ensure_ascii=False, indent=2)
    return (
        "Analyze this ASHA home visit conversation for danger signs.\n"
        "CONVERSATION TRANSCRIPT:\n"
        f"{transcript}\n"
        f"VISIT TYPE: {visit_type}\n"
        "OUTPUT JSON SCHEMA:\n"
        f"{schema_json}\n"
        "CRITICAL RULES:\n"
        "- ONLY flag danger signs that have DIRECT evidence in the conversation\n"
        "- Each danger sign MUST include utterance_evidence — the exact Hindi quote that triggered it\n"
        "- If NO danger signs are present, return an empty danger_signs array\n"
        "- Referral decision must be based ONLY on detected danger signs\n"
        "Return ONLY valid JSON matching the schema."
    )
| # ============================================================ | |
| # SCHEMA MAPPING | |
| # ============================================================ | |
def load_schema(name: str) -> dict:
    """Read ``configs/schemas/<name>.json`` (relative to the working directory).

    Args:
        name: Schema file stem, without the ``.json`` suffix.

    Returns:
        The parsed JSON content.
    """
    schema_file = Path(f"configs/schemas/{name}.json")
    raw_text = schema_file.read_text(encoding="utf-8")
    return json.loads(raw_text)
# Maps a scenario's "type" to the schema name used to look up its form schema.
# Every visit type currently uses a schema of the same name.
VISIT_TYPE_TO_SCHEMA = {
    visit: visit
    for visit in ("anc_visit", "pnc_visit", "delivery", "child_health")
}
| # ============================================================ | |
| # VALIDATION | |
| # ============================================================ | |
def validate_transcript(text: str) -> dict:
    """Run basic quality checks on a generated transcript.

    Checks, in order: minimum length (200 characters), at least 50
    Devanagari characters, at least 8 speaker tags, and a scan for
    narration markers. Narration markers are reported in ``issues`` but do
    not make the transcript invalid.

    Returns:
        ``{"valid": bool, "issues": list[str]}``.
    """
    # An empty or very short transcript is rejected without further checks.
    if not text or len(text) < 200:
        return {"valid": False, "issues": ["Transcript too short"]}

    is_valid = True
    problems = []

    # Characters in the Devanagari Unicode block.
    devanagari_total = len([ch for ch in text if '\u0900' <= ch <= '\u097F'])
    if devanagari_total < 50:
        is_valid = False
        problems.append(f"Too few Hindi characters ({devanagari_total})")

    # Each speaker tag occurrence counts as one dialogue turn.
    turn_total = sum(text.count(tag) for tag in ("ASHA:", "Patient:", "Mother:"))
    if turn_total < 8:
        is_valid = False
        problems.append(f"Too few dialogue turns ({turn_total})")

    # Case-insensitive scan; these are warnings only.
    lowered = text.lower()
    for marker in ("[", "]", "(walks", "(checks", "(measures", "Scene:", "Setting:"):
        if marker.lower() in lowered:
            problems.append(f"Contains narration marker: {marker}")

    return {"valid": is_valid, "issues": problems}
def validate_extraction(data: dict, visit_type: str, has_danger_signs: bool) -> dict:
    """Reject form extractions that are not a dict or are almost empty.

    Args:
        data: Parsed JSON returned by the form-extraction call.
        visit_type: Not used by the current checks.
        has_danger_signs: Not used by the current checks.

    Returns:
        ``{"valid": bool, "issues": list[str]}``; invalid when ``data`` is
        not a dict or ``_count_non_null`` reports fewer than 3 values.
    """
    if not isinstance(data, dict):
        return {"valid": False, "issues": ["Not a dict"]}

    filled = _count_non_null(data)
    if filled >= 3:
        return {"valid": True, "issues": []}
    return {
        "valid": False,
        "issues": [f"Almost empty extraction ({filled} non-null fields)"],
    }
def validate_danger_signs(data: dict, has_danger_signs: bool) -> dict:
    """Validate danger sign extraction — the most critical validation.

    The input is model-generated JSON, so its shape is not trusted: a
    non-list ``danger_signs``, non-object entries, or a non-object
    ``referral_decision`` are reported as issues (or ignored, for the
    referral block) instead of raising.

    Args:
        data: Parsed JSON returned by the danger-sign extraction call.
        has_danger_signs: Whether the scenario was generated with danger signs.

    Returns:
        ``{"valid": bool, "issues": list[str]}``.
    """
    result = {"valid": True, "issues": []}
    if not isinstance(data, dict):
        return {"valid": False, "issues": ["Not a dict"]}

    # A missing key is treated as "no signs"; a present but non-list value
    # (e.g. JSON null) is malformed and cannot be checked further.
    signs = data.get("danger_signs", [])
    if not isinstance(signs, list):
        return {
            "valid": False,
            "issues": [f"danger_signs must be a list, got {type(signs).__name__}"],
        }

    # If we expect danger signs, there should be some
    if has_danger_signs and not signs:
        result["valid"] = False
        result["issues"].append("Expected danger signs but got none")

    # If we DON'T expect danger signs, there should be none
    if not has_danger_signs and signs:
        result["valid"] = False
        result["issues"].append(
            f"Expected no danger signs but got {len(signs)} — hallucination in training data"
        )

    # Every danger sign must be an object carrying utterance_evidence
    for i, sign in enumerate(signs):
        if not isinstance(sign, dict):
            result["valid"] = False
            result["issues"].append(f"Danger sign [{i}] is not an object")
            continue
        if not sign.get("utterance_evidence"):
            result["valid"] = False
            result["issues"].append(
                f"Danger sign [{i}] '{sign.get('sign')}' missing utterance_evidence"
            )

    # Referral decision consistency. A missing or non-object referral block
    # is treated as "no decision given".
    referral = data.get("referral_decision", {})
    decision = referral.get("decision", "") if isinstance(referral, dict) else ""

    # NOTE(review): only a routine/monitoring decision is checked against
    # immediate_referral signs; "refer_within_24h" with an immediate_referral
    # sign is not flagged here even though the issue text suggests it could be.
    if has_danger_signs and signs and decision in ("routine_followup", "continue_monitoring"):
        categories = [s.get("category") for s in signs if isinstance(s, dict)]
        if "immediate_referral" in categories:
            result["valid"] = False
            result["issues"].append(
                "Has immediate_referral signs but decision is not refer_immediately"
            )

    if not has_danger_signs and decision in ("refer_immediately", "refer_within_24h"):
        result["valid"] = False
        result["issues"].append("No danger signs but referral decision is urgent — hallucination")

    return result
| def _count_non_null(d, count=0): | |
| if isinstance(d, dict): | |
| for v in d.values(): | |
| count = _count_non_null(v, count) | |
| elif isinstance(d, list): | |
| count += len(d) | |
| elif d is not None: | |
| count += 1 | |
| return count | |
| # ============================================================ | |
| # CHECKPOINT | |
| # ============================================================ | |
def load_checkpoint() -> dict:
    """Return the saved progress counters, or fresh zeroed counters.

    Reads CHECKPOINT_FILE when it exists; otherwise returns a new dict with
    all counters at zero.
    """
    if not os.path.exists(CHECKPOINT_FILE):
        return {
            "generated": 0, "valid": 0, "invalid": 0, "failed": 0,
            "total_cost": 0.0, "by_type": {}, "by_danger": {"positive": 0, "negative": 0},
        }
    with open(CHECKPOINT_FILE, "r") as f:
        return json.load(f)
def save_checkpoint(cp: dict):
    """Persist the progress counters to CHECKPOINT_FILE.

    The JSON is written to a sibling temporary file and then moved over the
    checkpoint with ``os.replace``. An interruption mid-write therefore
    leaves the previous checkpoint intact instead of a truncated file that
    ``--resume`` could not parse.

    Args:
        cp: JSON-serializable checkpoint dict.
    """
    tmp_path = f"{CHECKPOINT_FILE}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(cp, f, indent=2)
    os.replace(tmp_path, CHECKPOINT_FILE)
| # ============================================================ | |
| # COST TRACKING | |
| # ============================================================ | |
# GPT-4o Mini pricing (as of 2026), in USD per one million tokens.
INPUT_COST_PER_M = 0.15  # input (prompt) tokens
OUTPUT_COST_PER_M = 0.60  # output (completion) tokens


def estimate_cost(input_text: str, output_text: str) -> float:
    """Estimate the cost of one call from the text lengths.

    Token counts are approximated as ``len(text) // 4`` for each side and
    priced with the per-million rates above.
    """
    approx_in = len(input_text) // 4
    approx_out = len(output_text) // 4
    in_cost = approx_in * INPUT_COST_PER_M / 1_000_000
    out_cost = approx_out * OUTPUT_COST_PER_M / 1_000_000
    return in_cost + out_cost
| # ============================================================ | |
| # WEIGHTED SCENARIO SAMPLING | |
| # ============================================================ | |
def sample_scenario() -> dict:
    """Draw one entry from SCENARIOS at random, in proportion to its 'weight'."""
    (picked,) = random.choices(
        SCENARIOS,
        weights=[entry["weight"] for entry in SCENARIOS],
    )
    return picked
| # ============================================================ | |
| # GENERATION | |
| # ============================================================ | |
def _chat_completion(client, system_prompt, user_prompt, *, max_tokens, temperature, json_mode=False):
    """Run one chat completion and return ``(text, cost)``.

    The cost uses the token usage reported on the response when it is
    available, and falls back to the ``estimate_cost`` length heuristic
    otherwise. A missing message content is returned as an empty string so
    that the call's cost is still reported to the caller.
    """
    request = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    if json_mode:
        request["response_format"] = {"type": "json_object"}
    resp = client.chat.completions.create(**request)
    text = (resp.choices[0].message.content or "").strip()

    usage = getattr(resp, "usage", None)
    prompt_tokens = getattr(usage, "prompt_tokens", None)
    completion_tokens = getattr(usage, "completion_tokens", None)
    if prompt_tokens is not None and completion_tokens is not None:
        cost = (prompt_tokens * INPUT_COST_PER_M / 1_000_000) + (
            completion_tokens * OUTPUT_COST_PER_M / 1_000_000
        )
    else:
        cost = estimate_cost(system_prompt + user_prompt, text)
    return text, cost


def generate_one_sample(client: OpenAI, scenario: dict, schemas: dict) -> dict | None:
    """
    Generate one complete training sample:
    1. Generate Hindi conversation transcript
    2. Extract structured form data
    3. Extract danger signs
    4. Validate all three

    Args:
        client: OpenAI client used for all three calls.
        scenario: One entry of SCENARIOS.
        schemas: Loaded schemas keyed by name; must contain the scenario's
            form schema and "danger_signs".

    Returns:
        On success, a dict with "sample", "validation" and "cost".
        On any API, parsing or validation failure, a dict with "error" and
        "cost". "cost" always covers every call that completed, including a
        call whose output later failed to parse or validate. This
        implementation never returns None.
    """
    visit_type = scenario["type"]
    form_schema_name = VISIT_TYPE_TO_SCHEMA[visit_type]
    form_schema = schemas[form_schema_name]
    danger_schema = schemas["danger_signs"]
    total_cost = 0.0

    # ── Step 1: Generate transcript ──
    try:
        transcript, call_cost = _chat_completion(
            client,
            SYSTEM_PROMPT_TRANSCRIPT,
            build_transcript_prompt(scenario),
            max_tokens=MAX_TOKENS_TRANSCRIPT,
            temperature=TEMPERATURE,
        )
    except Exception as e:
        return {"error": f"Transcript generation failed: {e}", "cost": total_cost}
    total_cost += call_cost

    # Validate transcript
    tv = validate_transcript(transcript)
    if not tv["valid"]:
        return {"error": f"Invalid transcript: {tv['issues']}", "cost": total_cost}

    # ── Step 2: Extract form data ──
    try:
        form_text, call_cost = _chat_completion(
            client,
            SYSTEM_PROMPT_FORM_EXTRACTION,
            build_extraction_prompt(transcript, visit_type, form_schema),
            max_tokens=MAX_TOKENS_EXTRACTION,
            temperature=JSON_TEMPERATURE,
            json_mode=True,
        )
    except Exception as e:
        return {"error": f"Form extraction failed: {e}", "cost": total_cost}
    # Count the call before parsing: it was billed even if the JSON is bad.
    total_cost += call_cost
    try:
        form_data = json.loads(form_text)
    except json.JSONDecodeError as e:
        return {"error": f"Form extraction not valid JSON: {e}", "cost": total_cost}

    # Validate form extraction
    ev = validate_extraction(form_data, visit_type, scenario["has_danger_signs"])
    if not ev["valid"]:
        return {"error": f"Invalid extraction: {ev['issues']}", "cost": total_cost}

    # ── Step 3: Extract danger signs ──
    try:
        danger_text, call_cost = _chat_completion(
            client,
            SYSTEM_PROMPT_DANGER_EXTRACTION,
            build_danger_signs_prompt(transcript, visit_type, danger_schema),
            max_tokens=MAX_TOKENS_EXTRACTION,
            temperature=JSON_TEMPERATURE,
            json_mode=True,
        )
    except Exception as e:
        return {"error": f"Danger signs extraction failed: {e}", "cost": total_cost}
    # Same as above: the call counts toward the budget regardless of parsing.
    total_cost += call_cost
    try:
        danger_data = json.loads(danger_text)
    except json.JSONDecodeError as e:
        return {"error": f"Danger signs not valid JSON: {e}", "cost": total_cost}

    # Validate danger signs — this is the most critical validation
    dv = validate_danger_signs(danger_data, scenario["has_danger_signs"])
    if not dv["valid"]:
        return {"error": f"Invalid danger signs: {dv['issues']}", "cost": total_cost}

    # ── Success — build training sample ──
    return {
        "sample": {
            "id": None,  # assigned later
            "visit_type": visit_type,
            "scenario_label": scenario["label"],
            "has_danger_signs": scenario["has_danger_signs"],
            "transcript": transcript,
            "form_extraction": form_data,
            "danger_signs_extraction": danger_data,
            "form_schema": form_schema_name,
        },
        "validation": {
            "transcript": tv,
            "extraction": ev,
            "danger_signs": dv,
        },
        "cost": total_cost,
    }
| # ============================================================ | |
| # MAIN | |
| # ============================================================ | |
| def main(): | |
| parser = argparse.ArgumentParser(description="MedScribe v2 — Training Data Generation") | |
| parser.add_argument("--dry-run", action="store_true", help="Generate 5 samples only") | |
| parser.add_argument("--resume", action="store_true", help="Resume from checkpoint") | |
| parser.add_argument("--count", type=int, default=500, help="Number of samples to generate") | |
| parser.add_argument("--seed", type=int, default=42) | |
| args = parser.parse_args() | |
| random.seed(args.seed) | |
| os.makedirs(OUTPUT_DIR, exist_ok=True) | |
| target = 5 if args.dry_run else args.count | |
| print("=" * 60) | |
| print(f"Training Data Generation — {MODEL}" + (" [DRY RUN]" if args.dry_run else "")) | |
| print(f"Target: {target} samples") | |
| print("=" * 60) | |
| # ── Gate: API key ── | |
| api_key = os.getenv("OPENAI_API_KEY") | |
| if not api_key: | |
| print("\nABORT: OPENAI_API_KEY not found. Set in .env file.") | |
| sys.exit(1) | |
| client = OpenAI() | |
| # Quick API test | |
| try: | |
| client.chat.completions.create( | |
| model=MODEL, | |
| messages=[{"role": "user", "content": "Say OK"}], | |
| max_tokens=5, | |
| ) | |
| print(f"API connection verified: {MODEL}") | |
| except Exception as e: | |
| print(f"\nABORT: API test failed: {e}") | |
| sys.exit(1) | |
| # ── Load schemas ── | |
| schemas = {} | |
| for name in ["anc_visit", "pnc_visit", "delivery", "child_health", "danger_signs"]: | |
| schemas[name] = load_schema(name) | |
| print(f"Loaded {len(schemas)} schemas") | |
| # ── Load checkpoint ── | |
| cp = load_checkpoint() | |
| start_idx = cp["generated"] if args.resume else 0 | |
| if not args.resume: | |
| cp = { | |
| "generated": 0, "valid": 0, "invalid": 0, "failed": 0, | |
| "total_cost": 0.0, "by_type": {}, "by_danger": {"positive": 0, "negative": 0}, | |
| "quality": { | |
| "narration_leaks": 0, "low_hindi": 0, "low_turns": 0, | |
| "polarity_errors": 0, "avg_clinical_density": 0.0, | |
| "total_clinical_density": 0, "total_scored": 0, | |
| }, | |
| } | |
| # Open output file (append if resuming) | |
| mode = "a" if args.resume and os.path.exists(OUTPUT_FILE) else "w" | |
| outfile = open(OUTPUT_FILE, mode, encoding="utf-8") | |
| print(f"\nStarting from sample {start_idx + 1}...") | |
| if args.dry_run: | |
| print("DRY RUN: Generating 5 samples. Review output before full run.\n") | |
| batch_count = 0 | |
| for i in range(start_idx, target): | |
| scenario = sample_scenario() | |
| vtype = scenario["type"] | |
| result = generate_one_sample(client, scenario, schemas) | |
| # Track cost regardless | |
| cp["total_cost"] += result.get("cost", 0) | |
| if "error" in result: | |
| cp["failed"] += 1 | |
| cp["generated"] += 1 | |
| status = f"FAIL: {result['error'][:80]}" | |
| elif "sample" in result: | |
| sample = result["sample"] | |
| sample["id"] = f"ms2_{i:04d}" | |
| # Write to JSONL | |
| outfile.write(json.dumps(sample, ensure_ascii=False) + "\n") | |
| outfile.flush() | |
| cp["valid"] += 1 | |
| cp["generated"] += 1 | |
| cp["by_type"][vtype] = cp["by_type"].get(vtype, 0) + 1 | |
| if scenario["has_danger_signs"]: | |
| cp["by_danger"]["positive"] += 1 | |
| else: | |
| cp["by_danger"]["negative"] += 1 | |
| n_signs = len(sample["danger_signs_extraction"].get("danger_signs", [])) | |
| status = f"OK [{vtype}] signs={n_signs}" | |
| # ── Real-time quality monitoring ── | |
| q = cp.setdefault("quality", { | |
| "narration_leaks": 0, "low_hindi": 0, "low_turns": 0, | |
| "polarity_errors": 0, "avg_clinical_density": 0.0, | |
| "total_clinical_density": 0, "total_scored": 0, | |
| }) | |
| t = sample["transcript"] | |
| # Check narration (exclude single English words in parens like "(latch)") | |
| paren_matches = re.findall(r'\([^)]{5,}\)', t) # only flag parens with 5+ chars | |
| bracket_matches = re.findall(r'\[[^\]]+\]', t) | |
| if paren_matches or bracket_matches or '*' in t: | |
| q["narration_leaks"] += 1 | |
| # Check Hindi density | |
| hindi_chars = sum(1 for c in t if '\u0900' <= c <= '\u097F') | |
| if hindi_chars / max(len(t), 1) < 0.5: | |
| q["low_hindi"] += 1 | |
| # Check turns | |
| turns = t.count("ASHA:") + t.count("Patient:") + t.count("Mother:") | |
| if turns < 15: | |
| q["low_turns"] += 1 | |
| # Clinical density (count numbers near medical terms) | |
| clin_kws = ['bp', 'weight', 'वजन', 'kg', 'hb', 'हीमोग्लोबिन', 'तापमान', | |
| 'हफ्ता', 'महीना', 'tablet', 'ग्राम', 'डिग्री', 'किलो'] | |
| nums = re.findall(r'\d+\.?\d*', t) | |
| clin_count = 0 | |
| for n in nums: | |
| pos = t.find(n) | |
| ctx = t[max(0, pos-40):pos+40].lower() | |
| if any(kw in ctx for kw in clin_kws): | |
| clin_count += 1 | |
| q["total_clinical_density"] += clin_count | |
| q["total_scored"] += 1 | |
| # Polarity check | |
| if scenario["has_danger_signs"] and n_signs == 0: | |
| q["polarity_errors"] += 1 | |
| elif not scenario["has_danger_signs"] and n_signs > 0: | |
| q["polarity_errors"] += 1 | |
| else: | |
| cp["failed"] += 1 | |
| cp["generated"] += 1 | |
| status = "UNKNOWN" | |
| batch_count += 1 | |
| pct = cp["generated"] / target * 100 | |
| print(f" [{cp['generated']}/{target}] ({pct:.0f}%) ${cp['total_cost']:.4f} | {scenario['label'][:50]} | {status}") | |
| # ── Safety checks ── | |
| if cp["total_cost"] > MAX_BUDGET: | |
| print(f"\nABORT: Budget exceeded (${cp['total_cost']:.2f} > ${MAX_BUDGET})") | |
| break | |
| if cp["generated"] >= MIN_BATCH_FOR_CHECK: | |
| total_attempted = cp["valid"] + cp["failed"] | |
| if total_attempted > 0 and cp["failed"] / total_attempted > FAILURE_RATE_ABORT: | |
| print(f"\nABORT: Failure rate {cp['failed']}/{total_attempted} exceeds threshold") | |
| break | |
| # Checkpoint + quality report | |
| if batch_count >= BATCH_SIZE: | |
| save_checkpoint(cp) | |
| batch_count = 0 | |
| # Periodic quality report every 50 samples | |
| q = cp.get("quality", {}) | |
| scored = q.get("total_scored", 0) | |
| if scored > 0 and scored % 50 < BATCH_SIZE: | |
| avg_clin = q["total_clinical_density"] / scored | |
| print(f"\n ┌── QUALITY REPORT (n={scored}) ──") | |
| print(f" │ Narration leaks: {q['narration_leaks']} ({q['narration_leaks']/scored*100:.0f}%)") | |
| print(f" │ Low Hindi: {q['low_hindi']} ({q['low_hindi']/scored*100:.0f}%)") | |
| print(f" │ Low turns (<15): {q['low_turns']} ({q['low_turns']/scored*100:.0f}%)") | |
| print(f" │ Polarity errors: {q['polarity_errors']} ({q['polarity_errors']/scored*100:.0f}%)") | |
| print(f" │ Avg clinical #s: {avg_clin:.1f} per sample") | |
| print(f" └{'─' * 35}") | |
| # ABORT on polarity errors > 5% | |
| if scored >= 30 and q["polarity_errors"] / scored > 0.05: | |
| print(f"\n ABORT: Polarity error rate {q['polarity_errors']}/{scored} > 5%") | |
| print(f" This means the model is hallucinating danger signs or missing real ones.") | |
| print(f" Fix prompts before continuing.") | |
| break | |
| # Rate limit | |
| time.sleep(0.3) | |
| outfile.close() | |
| save_checkpoint(cp) | |
| # ── Summary ── | |
| print(f"\n{'=' * 60}") | |
| print("GENERATION SUMMARY") | |
| print("=" * 60) | |
| print(f" Total generated: {cp['generated']}") | |
| print(f" Valid samples: {cp['valid']}") | |
| print(f" Failed: {cp['failed']}") | |
| print(f" Total cost: ${cp['total_cost']:.4f}") | |
| print(f"\n By visit type:") | |
| for vt, count in sorted(cp["by_type"].items()): | |
| print(f" {vt}: {count}") | |
| print(f"\n Danger sign balance:") | |
| print(f" Positive (has danger signs): {cp['by_danger']['positive']}") | |
| print(f" Negative (no danger signs): {cp['by_danger']['negative']}") | |
| total_with_labels = cp['by_danger']['positive'] + cp['by_danger']['negative'] | |
| if total_with_labels > 0: | |
| neg_pct = cp['by_danger']['negative'] / total_with_labels * 100 | |
| print(f" Negative ratio: {neg_pct:.0f}% (target: ~40%)") | |
| # Quality summary | |
| q = cp.get("quality", {}) | |
| scored = q.get("total_scored", 0) | |
| if scored > 0: | |
| avg_clin = q["total_clinical_density"] / scored | |
| print(f"\n Quality metrics:") | |
| print(f" Narration leaks: {q['narration_leaks']}/{scored} ({q['narration_leaks']/scored*100:.0f}%)") | |
| print(f" Low Hindi: {q['low_hindi']}/{scored} ({q['low_hindi']/scored*100:.0f}%)") | |
| print(f" Low turns (<15): {q['low_turns']}/{scored} ({q['low_turns']/scored*100:.0f}%)") | |
| print(f" Polarity errors: {q['polarity_errors']}/{scored} ({q['polarity_errors']/scored*100:.0f}%)") | |
| print(f" Avg clinical #s: {avg_clin:.1f} per sample") | |
| if q['polarity_errors'] > 0: | |
| print(f"\n WARNING: {q['polarity_errors']} polarity errors detected!") | |
| print(f" Review these samples before training.") | |
| print(f"\n Output: {OUTPUT_FILE}") | |
| if args.dry_run: | |
| print(f"\nDRY RUN complete. Review the output file.") | |
| print(f"If quality looks good, run: python scripts/generate_training_data.py --count {args.count}") | |
| if cp["valid"] >= 200: | |
| print(f"\nREADY for next step: python scripts/04_prepare_training.py") | |
| elif cp["valid"] > 0: | |
| print(f"\n{cp['valid']} samples generated. May need more for good fine-tune results.") | |
| print("=" * 60) | |
# Script entry point: call main() only when this file is executed directly, not when it is imported as a module.
| if __name__ == "__main__": | |
| main() | |