sakhi / scripts /generate_training_data.py
Tushar9802's picture
HF Space deploy — initial
745f62a
"""
MedScribe v2 — Synthetic Training Data Generation
Generates paired (Hindi ASHA conversation, structured JSON extraction) training data
via GPT-4o Mini API. Each sample produces:
1. A realistic Hindi ASHA home visit conversation transcript
2. The correct structured JSON extraction (form fields + danger signs)
Adapted from MedScribe v1's proven generation pipeline with:
- Budget cap + cost tracking
- Per-sample validation (schema + clinical consistency)
- Checkpoint/resume support
- Batch quality monitoring with abort threshold
- Dry-run mode
- Negative examples (no danger signs, many null fields)
Usage:
python scripts/generate_training_data.py --dry-run # First 5 samples
python scripts/generate_training_data.py # Full generation
python scripts/generate_training_data.py --resume # Resume from checkpoint
python scripts/generate_training_data.py --count 500 # Generate N samples
"""
import argparse
import json
import os
import random
import re
import sys
import time
from pathlib import Path
import pandas as pd
from dotenv import load_dotenv
from openai import OpenAI
load_dotenv()
# ============================================================
# CONFIG
# ============================================================
# Directory for generated artifacts; main() creates it if missing.
OUTPUT_DIR = "data/processed"
# JSONL file that receives one validated sample per line.
OUTPUT_FILE = os.path.join(OUTPUT_DIR, "training_data_raw.jsonl")
# JSON file holding the running counters used by --resume.
CHECKPOINT_FILE = os.path.join(OUTPUT_DIR, "generation_checkpoint.json")
# Chat model used for all three calls per sample (transcript, form, danger signs).
MODEL = "gpt-4o-mini"
TEMPERATURE = 0.7 # Higher for diversity in conversations
JSON_TEMPERATURE = 0.1 # Lower for accurate structured extraction
# max_tokens for the transcript generation call.
MAX_TOKENS_TRANSCRIPT = 2500
# max_tokens for both JSON extraction calls (form data and danger signs).
MAX_TOKENS_EXTRACTION = 2000
# Number of loop iterations between checkpoint saves in main().
BATCH_SIZE = 10
# main() aborts once the accumulated estimated cost (USD) exceeds this value.
MAX_BUDGET = 10.0
# main() aborts when failed / (valid + failed) exceeds this fraction...
FAILURE_RATE_ABORT = 0.30
# ...but only after at least this many samples have been attempted.
MIN_BATCH_FOR_CHECK = 20
# ============================================================
# VISIT SCENARIO DEFINITIONS
# ============================================================
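# Each scenario defines the visit type, clinical profile, a sampling
# weight, and whether danger signs should be present. The target is ~40%
# NEGATIVE samples (no danger signs) to teach the model restraint.
# NOTE: the weights below sum to 40 negative vs. 29 positive (~58% negative),
# so the sampled mix is more negative than the stated ~40% target.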
SCENARIOS = [
# ── ANC VISITS (40% of data) ──
{
"type": "anc_visit",
"label": "ANC — Normal pregnancy, routine visit",
"has_danger_signs": False,
"weight": 8,
"clinical_profile": (
"Healthy primigravida, 24 weeks, all vitals normal, no complaints. "
"BP 110/70, Hb 11.5, weight gain normal. Routine ANC visit. "
"ASHA counsels on diet, IFA tablets, birth preparedness."
),
},
{
"type": "anc_visit",
"label": "ANC — Mild anemia, otherwise normal",
"has_danger_signs": False,
"weight": 5,
"clinical_profile": (
"Second pregnancy, 28 weeks, mild anemia (Hb 9.8). No other complaints. "
"BP normal, fetal movements good. ASHA gives IFA, advises diet."
),
},
{
"type": "anc_visit",
"label": "ANC — Preeclampsia signs (HIGH RISK)",
"has_danger_signs": True,
"weight": 5,
"clinical_profile": (
"Third trimester (32-36 weeks), elevated BP (140-160/90-110), "
"headache, blurred vision, facial/hand swelling, excessive weight gain. "
"Possible preeclampsia. Needs immediate referral."
),
},
{
"type": "anc_visit",
"label": "ANC — Severe anemia (HIGH RISK)",
"has_danger_signs": True,
"weight": 3,
"clinical_profile": (
"Second trimester, severe anemia (Hb <7), extreme weakness, breathlessness, "
"pallor, dizziness. Needs urgent referral for IV iron/transfusion."
),
},
{
"type": "anc_visit",
"label": "ANC — Reduced fetal movement",
"has_danger_signs": True,
"weight": 3,
"clinical_profile": (
"Late third trimester, patient reports baby moving less than usual "
"for 2 days. Other vitals may be normal. Needs monitoring/referral."
),
},
{
"type": "anc_visit",
"label": "ANC — Vaginal bleeding",
"has_danger_signs": True,
"weight": 2,
"clinical_profile": (
"Any trimester, reports vaginal bleeding (spotting to heavy). "
"Immediate danger sign requiring emergency referral."
),
},
{
"type": "anc_visit",
"label": "ANC — Gestational diabetes",
"has_danger_signs": False,
"weight": 2,
"clinical_profile": (
"28-32 weeks, elevated blood sugar found on testing. No acute danger signs "
"but needs dietary counseling and follow-up. BP normal."
),
},
{
"type": "anc_visit",
"label": "ANC — Young primigravida, many null fields",
"has_danger_signs": False,
"weight": 3,
"clinical_profile": (
"First visit, early pregnancy (8-10 weeks). Minimal information available — "
"no labs done yet, no vitals taken at this visit (home visit, no equipment beyond BP). "
"Many form fields should be null. Just registration and counseling."
),
},
# ── PNC / HBNC VISITS (25% of data) ──
{
"type": "pnc_visit",
"label": "PNC — Normal postpartum + healthy newborn",
"has_danger_signs": False,
"weight": 6,
"clinical_profile": (
"Day 7 HBNC visit. Mother recovering well, no fever, bleeding light. "
"Baby weight 3.0 kg (birth weight 2.8), breastfeeding well, "
"cord clean and dry, active, no jaundice."
),
},
{
"type": "pnc_visit",
"label": "PNC — Low birth weight newborn, needs monitoring",
"has_danger_signs": False,
"weight": 3,
"clinical_profile": (
"Day 14 visit. Baby 2.1 kg (birth weight 1.9), gaining slowly. "
"Breastfeeding adequate but not frequent enough. Mother well. "
"No acute danger signs but close monitoring needed."
),
},
{
"type": "pnc_visit",
"label": "PNC — Newborn not feeding well (DANGER)",
"has_danger_signs": True,
"weight": 3,
"clinical_profile": (
"Day 3 visit. Newborn lethargic, not latching, weak cry. "
"May have fever or hypothermia. Mother reports baby sleeping too much. "
"IMNCI danger signs present — urgent referral."
),
},
{
"type": "pnc_visit",
"label": "PNC — Postpartum hemorrhage signs (DANGER)",
"has_danger_signs": True,
"weight": 2,
"clinical_profile": (
"Day 1-3 visit. Mother reports heavy bleeding, soaking through cloth. "
"Feeling dizzy and weak. Possible postpartum hemorrhage. "
"Immediate referral needed."
),
},
{
"type": "pnc_visit",
"label": "PNC — Newborn jaundice",
"has_danger_signs": True,
"weight": 2,
"clinical_profile": (
"Day 3-5 visit. Newborn has yellow skin, possibly yellow palms/soles. "
"Feeding okay or slightly reduced. Needs assessment for severity — "
"jaundice within 24h of birth or palms/soles = severe."
),
},
# ── DELIVERY (10% of data) ──
{
"type": "delivery",
"label": "Delivery — Normal institutional delivery",
"has_danger_signs": False,
"weight": 3,
"clinical_profile": (
"Full-term normal delivery at PHC/district hospital. "
"Healthy baby boy/girl, cried immediately, breastfed within 1 hour. "
"Birth weight 2.8-3.5 kg. Mother stable. Vaccines given at birth."
),
},
{
"type": "delivery",
"label": "Delivery — Home delivery (partial info)",
"has_danger_signs": False,
"weight": 2,
"clinical_profile": (
"Home delivery attended by dai/family. ASHA visiting after the fact. "
"Limited info on delivery details. Baby seems okay. "
"Need to check birth weight, breastfeeding, vaccines."
),
},
{
"type": "delivery",
"label": "Delivery — Preterm with complications (DANGER)",
"has_danger_signs": True,
"weight": 2,
"clinical_profile": (
"Preterm delivery (34-36 weeks). Low birth weight (<2.5 kg). "
"Baby may have breathing difficulty. Mother may have had complications. "
"Needs close monitoring or referral."
),
},
# ── CHILD HEALTH / HBYC (25% of data) ──
{
"type": "child_health",
"label": "HBYC — Healthy 6-month-old, routine visit",
"has_danger_signs": False,
"weight": 5,
"clinical_profile": (
"6-month HBYC visit. Good weight gain, breastfeeding + complementary food started. "
"Immunizations up to date. Milestones appropriate. No illness."
),
},
{
"type": "child_health",
"label": "HBYC — Underweight child, no acute danger",
"has_danger_signs": False,
"weight": 3,
"clinical_profile": (
"9-month visit. Child slightly underweight, pallor present (mild anemia). "
"Eating poorly. No acute illness. ASHA counsels on feeding. "
"Deworming due. Vitamin A due."
),
},
{
"type": "child_health",
"label": "HBYC — Diarrhea + dehydration (DANGER)",
"has_danger_signs": True,
"weight": 3,
"clinical_profile": (
"12-month visit. Child has watery diarrhea for 3 days, "
"not drinking well, sunken eyes, lethargic. Possible severe dehydration. "
"IMNCI general danger signs may be present."
),
},
{
"type": "child_health",
"label": "HBYC — Pneumonia signs (DANGER)",
"has_danger_signs": True,
"weight": 2,
"clinical_profile": (
"9-month visit. Child has cough for 5 days, fast breathing, "
"chest indrawing visible. Fever present. Possible pneumonia. "
"Needs urgent referral per IMNCI."
),
},
{
"type": "child_health",
"label": "HBYC — Severe malnutrition (DANGER)",
"has_danger_signs": True,
"weight": 2,
"clinical_profile": (
"15-month visit. Visible severe wasting, very low weight for age, "
"child not active. Possible edema of feet. Severe acute malnutrition "
"requiring NRC referral."
),
},
]
# ============================================================
# PROMPTS
# ============================================================
SYSTEM_PROMPT_TRANSCRIPT = """You generate realistic Hindi conversations between ASHA health workers and patients during home visits in rural India. These will train a medical AI — clinical accuracy and natural dialogue are both critical.
OUTPUT FORMAT (strict):
- ONLY dialogue lines. Each line: "ASHA:", "Patient:", or "Mother:" followed by spoken words.
- ABSOLUTELY NO narration, stage directions, action descriptions, or parentheticals.
WRONG: (BP चेक करते हुए) / [measures weight] / *takes temperature* / (बच्चे को देख कर)
RIGHT: "ASHA: चलिए, BP देख लेती हूँ... 150/95 आ रहा है, ये तो ज़्यादा है।"
- ALL text MUST be in Devanagari script. Do NOT use Romanized Hindi (no "Aapka BP", no "theek hai").
Only English words allowed: medical terms (BP, Hb, TT, IFA, ORS, kg, mg, PHC, CHC).
CLINICAL REQUIREMENTS:
- ASHA must verbally state EVERY measurement with its value in Devanagari: "आपका BP 140/90 आ रहा है", "बच्चे का वज़न 3.1 kg है", "Hb 9.5 आया है"
- Include at least 6 distinct clinical data points spoken naturally in dialogue (vitals, history, medications, plans).
- Visit-specific data the ASHA should cover:
ANC: gestational age, BP, weight, Hb, IFA compliance, TT status, fetal movement, birth preparedness (transport, facility, money, blood donor).
PNC/Newborn: mother's bleeding/fever/pain, baby weight, feeding pattern (frequency, latch), cord condition, jaundice check, vaccination status.
Child health: age, weight, feeding/diet, immunization status, milestones, illness symptoms, deworming/Vitamin A.
DIALOGUE STYLE:
- Vary openings — sometimes ASHA calls from the door, sometimes patient greets first, sometimes mid-activity. Do NOT always start with "नमस्ते, कैसे हैं आप?"
- ASHA is warm but efficient — covers clinical ground without sounding like a form.
- Patient speaks colloquially: approximate dates ("लगभग 6 महीना"), local terms, sometimes vague or tangential.
- Patient may volunteer info, ask questions, express worry, or dismiss concerns.
- 20-30 dialogue turns with substantive content (not single-word responses).
- End with ASHA's concrete plan: next visit date, any referral, medications given, counseling summary."""
SYSTEM_PROMPT_FORM_EXTRACTION = """You are a clinical data extraction system. Extract structured medical data from an ASHA home visit conversation transcript into the provided JSON schema.
RULES:
1. Extract ONLY information EXPLICITLY stated or clearly implied in the conversation.
2. Use null for anything not mentioned — never guess or fill in "expected" values.
3. Numbers must match exactly as stated in conversation (BP, weight, Hb, temperature, age, etc.).
4. For array fields (symptoms_reported, counseling_provided), extract all relevant items mentioned.
5. If ASHA states a measurement value, record the exact number, not just "normal".
6. Patient's approximate statements: convert to best numeric estimate ("लगभग 6 महीना" → gestational_weeks: 24).
7. Return valid JSON matching the schema. No markdown formatting."""
SYSTEM_PROMPT_DANGER_EXTRACTION = """You are a clinical danger sign extraction system for Indian ASHA worker home visits. Extract danger signs from conversation transcripts with high precision.
CORE RULES:
1. ONLY flag danger signs with DIRECT, EXPLICIT evidence in the conversation text.
2. Each flag MUST include utterance_evidence — the exact Hindi quote that triggered it.
3. If NO danger signs exist in the conversation, return an empty danger_signs array. This is correct and expected for normal visits.
4. NEVER invent or hallucinate danger signs. When in doubt, do not flag.
CLASSIFICATION GUIDANCE:
- immediate_referral: life-threatening — heavy uncontrolled bleeding, convulsions, unconsciousness, BP ≥160/110, newborn not breathing
- urgent_care: serious — elevated BP with symptoms (headache/vision/swelling), Hb <7 with symptoms, signs of severe dehydration (sunken eyes + lethargic + not drinking), fast breathing with chest indrawing
- monitor_closely: borderline — isolated mild findings, low-grade fever (99-100°F), mild swelling alone
ANTI-HALLUCINATION:
- Do NOT flag normal values as danger signs. BP 110/70 is normal. Temperature 98.6°F is normal.
- Do NOT flag a sign just because the scenario suggests it — only flag what the CONVERSATION actually says.
- If a value is borderline (e.g., temp 99°F), classify as monitor_closely at most, NOT urgent_care.
REFERRAL LOGIC:
- Any immediate_referral sign → refer_immediately (district_hospital/FRU)
- Only urgent_care signs → refer_within_24h (PHC/CHC)
- Only monitor_closely → continue_monitoring
- No signs → routine_followup
Fill the relevant checklist (maternal or newborn): "detected" if found, "not_detected" if assessed as normal, "not_assessed" if not discussed.
Return valid JSON only."""
def build_transcript_prompt(scenario: dict) -> str:
    """Build the user prompt that asks the model for one Hindi visit transcript.

    Args:
        scenario: Scenario dict providing the keys ``type``, ``label``,
            ``clinical_profile`` and ``has_danger_signs``.

    Returns:
        The prompt text describing the visit, its clinical details and
        whether danger signs must appear in the dialogue.
    """
    # Both instructions depend on the same flag, so pick them together.
    if scenario['has_danger_signs']:
        danger_instruction = "Yes — include clear, unambiguous danger signs in the conversation"
        symptom_instruction = (
            "Patient describes symptoms matching danger signs in natural colloquial Hindi (not medical jargon)."
        )
    else:
        danger_instruction = "No — this is a normal visit with no danger signs"
        symptom_instruction = "Patient has no concerning symptoms. Routine, healthy visit."

    # The visit type and the label are separated explicitly; without a
    # separator they fuse into one token such as "anc_visitANC — ...".
    return f"""Generate an ASHA home visit conversation in Hindi (Devanagari only):
VISIT: {scenario['type']} | {scenario['label']}
CLINICAL DETAILS: {scenario['clinical_profile']}
DANGER SIGNS: {danger_instruction}
Rules:
- Pure dialogue, zero narration/parentheticals. ASHA speaks measurements aloud in Devanagari.
- {symptom_instruction}
- At least 6 clinical data points woven naturally into conversation.
- 20-30 turns. Feel like a real village home visit, not a medical interview."""
def build_extraction_prompt(transcript: str, visit_type: str, schema: dict) -> str:
    """Build the user prompt for structured form extraction.

    Args:
        transcript: Conversation text to extract from.
        visit_type: Accepted for signature parity with the danger-sign
            prompt builder; not used in the prompt text.
        schema: JSON schema dict, embedded pretty-printed with non-ASCII
            characters preserved.

    Returns:
        The prompt text, one section per line.
    """
    schema_json = json.dumps(schema, ensure_ascii=False, indent=2)
    sections = [
        "Extract structured data from this ASHA home visit conversation into the provided JSON schema.",
        "CONVERSATION TRANSCRIPT:",
        f"{transcript}",
        "OUTPUT JSON SCHEMA:",
        schema_json,
        "Extract the data now. Return ONLY valid JSON matching the schema. "
        "Use null for fields not mentioned in the conversation.",
    ]
    return "\n".join(sections)
def build_danger_signs_prompt(transcript: str, visit_type: str, schema: dict) -> str:
    """Build the user prompt for danger-sign extraction.

    Args:
        transcript: Conversation text to analyze.
        visit_type: Visit type identifier, echoed on the ``VISIT TYPE`` line.
        schema: JSON schema dict, embedded pretty-printed with non-ASCII
            characters preserved.

    Returns:
        The prompt text, one section per line.
    """
    schema_json = json.dumps(schema, ensure_ascii=False, indent=2)
    header = [
        "Analyze this ASHA home visit conversation for danger signs.",
        "CONVERSATION TRANSCRIPT:",
        f"{transcript}",
        f"VISIT TYPE: {visit_type}",
        "OUTPUT JSON SCHEMA:",
        schema_json,
    ]
    rules = [
        "CRITICAL RULES:",
        "- ONLY flag danger signs that have DIRECT evidence in the conversation",
        "- Each danger sign MUST include utterance_evidence — the exact Hindi quote that triggered it",
        "- If NO danger signs are present, return an empty danger_signs array",
        "- Referral decision must be based ONLY on detected danger signs",
        "Return ONLY valid JSON matching the schema.",
    ]
    return "\n".join(header + rules)
# ============================================================
# SCHEMA MAPPING
# ============================================================
def load_schema(name: str) -> dict:
    """Read ``configs/schemas/<name>.json`` (relative to the CWD) and parse it.

    Args:
        name: Schema file stem, without the ``.json`` suffix.

    Returns:
        The parsed JSON document.
    """
    schema_path = Path(f"configs/schemas/{name}.json")
    raw_text = schema_path.read_text(encoding="utf-8")
    return json.loads(raw_text)
# Maps a scenario's visit type to the name of its form schema. Every visit
# type currently uses a schema with the same name.
VISIT_TYPE_TO_SCHEMA = {
    visit_type: visit_type
    for visit_type in ("anc_visit", "pnc_visit", "delivery", "child_health")
}
# ============================================================
# VALIDATION
# ============================================================
def validate_transcript(text: str) -> dict:
    """Check a generated transcript for minimum length, script and turn count.

    Narration markers are reported in ``issues`` but do not by themselves
    mark the transcript invalid.

    Args:
        text: The generated transcript.

    Returns:
        ``{"valid": bool, "issues": list[str]}``.
    """
    # Empty or very short output is rejected before any further checks.
    if not text or len(text) < 200:
        return {"valid": False, "issues": ["Transcript too short"]}

    is_valid = True
    problems = []

    # Count characters in the Devanagari Unicode block.
    devanagari_total = len([ch for ch in text if '\u0900' <= ch <= '\u097F'])
    if devanagari_total < 50:
        is_valid = False
        problems.append(f"Too few Hindi characters ({devanagari_total})")

    # Speaker tags stand in for dialogue turns.
    turn_total = sum(text.count(tag) for tag in ("ASHA:", "Patient:", "Mother:"))
    if turn_total < 8:
        is_valid = False
        problems.append(f"Too few dialogue turns ({turn_total})")

    # Case-insensitive scan for narration / stage-direction markers.
    lowered = text.lower()
    problems.extend(
        f"Contains narration marker: {marker}"
        for marker in ("[", "]", "(walks", "(checks", "(measures", "Scene:", "Setting:")
        if marker.lower() in lowered
    )
    return {"valid": is_valid, "issues": problems}
def validate_extraction(data: dict, visit_type: str, has_danger_signs: bool) -> dict:
    """Reject form extractions that are not a dict or are almost empty.

    Args:
        data: Parsed form-extraction JSON.
        visit_type: Not used by the current checks.
        has_danger_signs: Not used by the current checks.

    Returns:
        ``{"valid": bool, "issues": list[str]}``.
    """
    if not isinstance(data, dict):
        return {"valid": False, "issues": ["Not a dict"]}

    # Fewer than three populated values means the model extracted nothing useful.
    populated = _count_non_null(data)
    if populated >= 3:
        return {"valid": True, "issues": []}
    return {
        "valid": False,
        "issues": [f"Almost empty extraction ({populated} non-null fields)"],
    }
def validate_danger_signs(data: dict, has_danger_signs: bool) -> dict:
    """Validate a danger-sign extraction against the scenario's expected polarity.

    The input is raw model output, so malformed shapes (``danger_signs`` that
    is not a list, entries that are not objects, a ``referral_decision`` that
    is not an object) are reported as validation failures instead of raising.

    Args:
        data: Parsed danger-sign extraction JSON.
        has_danger_signs: Whether the scenario was generated with danger signs.

    Returns:
        ``{"valid": bool, "issues": list[str]}``.
    """
    if not isinstance(data, dict):
        return {"valid": False, "issues": ["Not a dict"]}

    result = {"valid": True, "issues": []}

    def fail(message: str) -> None:
        result["valid"] = False
        result["issues"].append(message)

    # A missing key counts as "no signs"; any non-list value (including
    # JSON null) is malformed and cannot be checked further.
    signs = data.get("danger_signs", [])
    if not isinstance(signs, list):
        return {"valid": False, "issues": ["danger_signs is not a list"]}

    # If we expect danger signs, there should be some
    if has_danger_signs and len(signs) == 0:
        fail("Expected danger signs but got none")
    # If we DON'T expect danger signs, there should be none
    if not has_danger_signs and len(signs) > 0:
        fail(f"Expected no danger signs but got {len(signs)} — hallucination in training data")

    # Every danger sign must be an object carrying utterance_evidence
    for i, sign in enumerate(signs):
        if not isinstance(sign, dict):
            fail(f"Danger sign [{i}] is not an object")
        elif not sign.get("utterance_evidence"):
            fail(f"Danger sign [{i}] '{sign.get('sign')}' missing utterance_evidence")

    # Referral decision consistency; a non-object referral block is treated
    # as carrying no decision.
    referral = data.get("referral_decision", {})
    decision = referral.get("decision", "") if isinstance(referral, dict) else ""

    if has_danger_signs and signs and decision in ("routine_followup", "continue_monitoring"):
        # Only the most severe category is treated as a hard inconsistency.
        categories = [s.get("category") for s in signs if isinstance(s, dict)]
        if "immediate_referral" in categories:
            fail("Has immediate_referral signs but decision is not refer_immediately")
    if not has_danger_signs and decision in ("refer_immediately", "refer_within_24h"):
        fail("No danger signs but referral decision is urgent — hallucination")
    return result
def _count_non_null(d, count=0):
if isinstance(d, dict):
for v in d.values():
count = _count_non_null(v, count)
elif isinstance(d, list):
count += len(d)
elif d is not None:
count += 1
return count
# ============================================================
# CHECKPOINT
# ============================================================
def load_checkpoint() -> dict:
    """Return the saved checkpoint, or zeroed counters when no file exists.

    The zeroed structure carries the run counters, accumulated cost and the
    per-type / per-polarity tallies; it has no ``quality`` section.
    """
    if not os.path.exists(CHECKPOINT_FILE):
        fresh = {"generated": 0, "valid": 0, "invalid": 0, "failed": 0}
        fresh["total_cost"] = 0.0
        fresh["by_type"] = {}
        fresh["by_danger"] = {"positive": 0, "negative": 0}
        return fresh

    with open(CHECKPOINT_FILE, "r") as handle:
        return json.load(handle)
def save_checkpoint(cp: dict):
    """Persist the checkpoint dict to ``CHECKPOINT_FILE`` as indented JSON.

    The data is written to a sibling temporary file and then moved over the
    real checkpoint, so an interrupted or failed write never leaves a
    truncated checkpoint behind for ``--resume`` to choke on.

    Args:
        cp: JSON-serialisable checkpoint state.
    """
    tmp_path = f"{CHECKPOINT_FILE}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(cp, f, indent=2)
    # os.replace overwrites the destination in a single rename step.
    os.replace(tmp_path, CHECKPOINT_FILE)
# ============================================================
# COST TRACKING
# ============================================================
# GPT-4o Mini pricing (as of 2026)
INPUT_COST_PER_M = 0.15  # USD per 1M input tokens
OUTPUT_COST_PER_M = 0.60  # USD per 1M output tokens


def estimate_cost(input_text: str, output_text: str) -> float:
    """Estimate the USD cost of one call from the lengths of its texts.

    Token counts are approximated as one token per four characters
    (integer division), then priced with the per-million rates above.

    Args:
        input_text: Everything sent to the model (system + user prompt).
        output_text: The text the model returned.

    Returns:
        Estimated cost in USD.
    """
    prompt_tokens = len(input_text) // 4
    completion_tokens = len(output_text) // 4
    prompt_cost = prompt_tokens * INPUT_COST_PER_M / 1_000_000
    completion_cost = completion_tokens * OUTPUT_COST_PER_M / 1_000_000
    return prompt_cost + completion_cost
# ============================================================
# WEIGHTED SCENARIO SAMPLING
# ============================================================
def sample_scenario() -> dict:
    """Draw one entry from ``SCENARIOS``, weighted by each entry's ``weight``."""
    scenario_weights = [entry["weight"] for entry in SCENARIOS]
    (chosen,) = random.choices(SCENARIOS, weights=scenario_weights, k=1)
    return chosen
# ============================================================
# GENERATION
# ============================================================
def _chat_text(client: OpenAI, system_prompt: str, user_prompt: str, *,
               max_tokens: int, temperature: float, json_mode: bool = False) -> str:
    """Run one chat completion and return the stripped text of the first choice."""
    extra = {"response_format": {"type": "json_object"}} if json_mode else {}
    resp = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        **extra,
    )
    return resp.choices[0].message.content.strip()


def generate_one_sample(client: OpenAI, scenario: dict, schemas: dict) -> dict:
    """Generate and validate one complete training sample.

    Steps:
    1. Generate a Hindi conversation transcript for the scenario.
    2. Extract structured form data from the transcript.
    3. Extract danger signs from the transcript.
    Each step is validated before the next one runs.

    Args:
        client: OpenAI client used for all three calls.
        scenario: One entry of ``SCENARIOS``.
        schemas: Loaded schemas keyed by name; must contain the form schema
            for the scenario's visit type and ``"danger_signs"``.

    Returns:
        On success ``{"sample": ..., "validation": ..., "cost": float}``.
        On any API, parsing or validation failure
        ``{"error": str, "cost": float}``. ``cost`` always covers every call
        that returned a response, including one whose output was rejected.
        This function never returns ``None``.
    """
    visit_type = scenario["type"]
    form_schema_name = VISIT_TYPE_TO_SCHEMA[visit_type]
    form_schema = schemas[form_schema_name]
    danger_schema = schemas["danger_signs"]
    total_cost = 0.0

    # Broad ``except Exception`` handlers below are deliberate: this is the
    # per-sample boundary, and any failure becomes an error result so the
    # caller's loop can keep going.

    # ── Step 1: Generate transcript ──
    try:
        transcript_prompt = build_transcript_prompt(scenario)
        transcript = _chat_text(
            client, SYSTEM_PROMPT_TRANSCRIPT, transcript_prompt,
            max_tokens=MAX_TOKENS_TRANSCRIPT, temperature=TEMPERATURE,
        )
        total_cost += estimate_cost(SYSTEM_PROMPT_TRANSCRIPT + transcript_prompt, transcript)
    except Exception as e:
        return {"error": f"Transcript generation failed: {e}", "cost": total_cost}

    # Validate transcript
    tv = validate_transcript(transcript)
    if not tv["valid"]:
        return {"error": f"Invalid transcript: {tv['issues']}", "cost": total_cost}

    # ── Step 2: Extract form data ──
    try:
        extraction_prompt = build_extraction_prompt(transcript, visit_type, form_schema)
        form_text = _chat_text(
            client, SYSTEM_PROMPT_FORM_EXTRACTION, extraction_prompt,
            max_tokens=MAX_TOKENS_EXTRACTION, temperature=JSON_TEMPERATURE, json_mode=True,
        )
        # Account for the call before parsing: it is billed even when the
        # returned text turns out not to be valid JSON.
        total_cost += estimate_cost(SYSTEM_PROMPT_FORM_EXTRACTION + extraction_prompt, form_text)
        form_data = json.loads(form_text)
    except json.JSONDecodeError as e:
        return {"error": f"Form extraction not valid JSON: {e}", "cost": total_cost}
    except Exception as e:
        return {"error": f"Form extraction failed: {e}", "cost": total_cost}

    # Validate form extraction
    ev = validate_extraction(form_data, visit_type, scenario["has_danger_signs"])
    if not ev["valid"]:
        return {"error": f"Invalid extraction: {ev['issues']}", "cost": total_cost}

    # ── Step 3: Extract danger signs ──
    try:
        danger_prompt = build_danger_signs_prompt(transcript, visit_type, danger_schema)
        danger_text = _chat_text(
            client, SYSTEM_PROMPT_DANGER_EXTRACTION, danger_prompt,
            max_tokens=MAX_TOKENS_EXTRACTION, temperature=JSON_TEMPERATURE, json_mode=True,
        )
        # Same as step 2: count the cost before attempting to parse.
        total_cost += estimate_cost(SYSTEM_PROMPT_DANGER_EXTRACTION + danger_prompt, danger_text)
        danger_data = json.loads(danger_text)
    except json.JSONDecodeError as e:
        return {"error": f"Danger signs not valid JSON: {e}", "cost": total_cost}
    except Exception as e:
        return {"error": f"Danger signs extraction failed: {e}", "cost": total_cost}

    # Validate danger signs — this is the most critical validation
    dv = validate_danger_signs(danger_data, scenario["has_danger_signs"])
    if not dv["valid"]:
        return {"error": f"Invalid danger signs: {dv['issues']}", "cost": total_cost}

    # ── Success — build training sample ──
    return {
        "sample": {
            "id": None,  # assigned later by the caller
            "visit_type": visit_type,
            "scenario_label": scenario["label"],
            "has_danger_signs": scenario["has_danger_signs"],
            "transcript": transcript,
            "form_extraction": form_data,
            "danger_signs_extraction": danger_data,
            "form_schema": form_schema_name,
        },
        "validation": {
            "transcript": tv,
            "extraction": ev,
            "danger_signs": dv,
        },
        "cost": total_cost,
    }
# ============================================================
# MAIN
# ============================================================
def _new_quality_stats() -> dict:
    """Return zeroed counters for the transcript quality monitor."""
    return {
        "narration_leaks": 0, "low_hindi": 0, "low_turns": 0,
        "polarity_errors": 0, "avg_clinical_density": 0.0,
        "total_clinical_density": 0, "total_scored": 0,
    }


def _new_checkpoint() -> dict:
    """Return the checkpoint state used for a fresh (non-resumed) run."""
    return {
        "generated": 0, "valid": 0, "invalid": 0, "failed": 0,
        "total_cost": 0.0, "by_type": {}, "by_danger": {"positive": 0, "negative": 0},
        "quality": _new_quality_stats(),
    }


def _update_quality_stats(q: dict, transcript: str, has_danger_signs: bool, n_signs: int) -> None:
    """Fold one accepted sample into the running quality counters ``q``."""
    # Check narration (exclude single English words in parens like "(latch)")
    paren_matches = re.findall(r'\([^)]{5,}\)', transcript)  # only flag parens with 5+ chars
    bracket_matches = re.findall(r'\[[^\]]+\]', transcript)
    if paren_matches or bracket_matches or '*' in transcript:
        q["narration_leaks"] += 1

    # Check Hindi density
    hindi_chars = sum(1 for c in transcript if '\u0900' <= c <= '\u097F')
    if hindi_chars / max(len(transcript), 1) < 0.5:
        q["low_hindi"] += 1

    # Check turns
    turns = transcript.count("ASHA:") + transcript.count("Patient:") + transcript.count("Mother:")
    if turns < 15:
        q["low_turns"] += 1

    # Clinical density (count numbers near medical terms). Each number is
    # examined at its own match position; looking the text up again with
    # str.find would always land on the first occurrence of that string.
    clin_kws = ('bp', 'weight', 'वजन', 'kg', 'hb', 'हीमोग्लोबिन', 'तापमान',
                'हफ्ता', 'महीना', 'tablet', 'ग्राम', 'डिग्री', 'किलो')
    clin_count = 0
    for match in re.finditer(r'\d+\.?\d*', transcript):
        pos = match.start()
        ctx = transcript[max(0, pos - 40):pos + 40].lower()
        if any(kw in ctx for kw in clin_kws):
            clin_count += 1
    q["total_clinical_density"] += clin_count
    q["total_scored"] += 1

    # Polarity check: scenario expectation and extracted signs must agree.
    if bool(has_danger_signs) != (n_signs > 0):
        q["polarity_errors"] += 1


def _run_generation(client: OpenAI, schemas: dict, cp: dict, outfile, start_idx: int, target: int) -> None:
    """Generate samples ``start_idx..target-1``, updating ``cp`` in place.

    Returns early (after printing an ABORT message) when the budget,
    failure-rate or polarity-error threshold is exceeded.
    """
    batch_count = 0
    for i in range(start_idx, target):
        scenario = sample_scenario()
        vtype = scenario["type"]
        result = generate_one_sample(client, scenario, schemas) or {}

        # Track cost regardless
        cp["total_cost"] += result.get("cost", 0)

        if "error" in result:
            cp["failed"] += 1
            cp["generated"] += 1
            status = f"FAIL: {result['error'][:80]}"
        elif "sample" in result:
            sample = result["sample"]
            sample["id"] = f"ms2_{i:04d}"
            # Write to JSONL and flush so a crash loses at most this sample
            outfile.write(json.dumps(sample, ensure_ascii=False) + "\n")
            outfile.flush()
            cp["valid"] += 1
            cp["generated"] += 1
            cp["by_type"][vtype] = cp["by_type"].get(vtype, 0) + 1
            polarity = "positive" if scenario["has_danger_signs"] else "negative"
            cp["by_danger"][polarity] += 1
            n_signs = len(sample["danger_signs_extraction"].get("danger_signs", []) or [])
            status = f"OK [{vtype}] signs={n_signs}"

            # ── Real-time quality monitoring ──
            q = cp.setdefault("quality", _new_quality_stats())
            _update_quality_stats(q, sample["transcript"], scenario["has_danger_signs"], n_signs)
        else:
            cp["failed"] += 1
            cp["generated"] += 1
            status = "UNKNOWN"

        batch_count += 1
        pct = cp["generated"] / target * 100
        print(f" [{cp['generated']}/{target}] ({pct:.0f}%) ${cp['total_cost']:.4f} | {scenario['label'][:50]} | {status}")

        # ── Safety checks ──
        if cp["total_cost"] > MAX_BUDGET:
            print(f"\nABORT: Budget exceeded (${cp['total_cost']:.2f} > ${MAX_BUDGET})")
            return
        if cp["generated"] >= MIN_BATCH_FOR_CHECK:
            total_attempted = cp["valid"] + cp["failed"]
            if total_attempted > 0 and cp["failed"] / total_attempted > FAILURE_RATE_ABORT:
                print(f"\nABORT: Failure rate {cp['failed']}/{total_attempted} exceeds threshold")
                return

        # Checkpoint + quality report
        if batch_count >= BATCH_SIZE:
            save_checkpoint(cp)
            batch_count = 0
            # Periodic quality report every 50 samples
            q = cp.get("quality", {})
            scored = q.get("total_scored", 0)
            if scored > 0 and scored % 50 < BATCH_SIZE:
                avg_clin = q["total_clinical_density"] / scored
                print(f"\n ┌── QUALITY REPORT (n={scored}) ──")
                print(f" │ Narration leaks: {q['narration_leaks']} ({q['narration_leaks']/scored*100:.0f}%)")
                print(f" │ Low Hindi: {q['low_hindi']} ({q['low_hindi']/scored*100:.0f}%)")
                print(f" │ Low turns (<15): {q['low_turns']} ({q['low_turns']/scored*100:.0f}%)")
                print(f" │ Polarity errors: {q['polarity_errors']} ({q['polarity_errors']/scored*100:.0f}%)")
                print(f" │ Avg clinical #s: {avg_clin:.1f} per sample")
                print(f" └{'─' * 35}")
            # ABORT on polarity errors > 5%
            if scored >= 30 and q["polarity_errors"] / scored > 0.05:
                print(f"\n ABORT: Polarity error rate {q['polarity_errors']}/{scored} > 5%")
                print(" This means the model is hallucinating danger signs or missing real ones.")
                print(" Fix prompts before continuing.")
                return

        # Rate limit
        time.sleep(0.3)


def _print_summary(cp: dict, args) -> None:
    """Print the end-of-run report built from the checkpoint counters."""
    print(f"\n{'=' * 60}")
    print("GENERATION SUMMARY")
    print("=" * 60)
    print(f" Total generated: {cp['generated']}")
    print(f" Valid samples: {cp['valid']}")
    print(f" Failed: {cp['failed']}")
    print(f" Total cost: ${cp['total_cost']:.4f}")
    print("\n By visit type:")
    for vt, count in sorted(cp["by_type"].items()):
        print(f" {vt}: {count}")
    print("\n Danger sign balance:")
    print(f" Positive (has danger signs): {cp['by_danger']['positive']}")
    print(f" Negative (no danger signs): {cp['by_danger']['negative']}")
    total_with_labels = cp['by_danger']['positive'] + cp['by_danger']['negative']
    if total_with_labels > 0:
        neg_pct = cp['by_danger']['negative'] / total_with_labels * 100
        print(f" Negative ratio: {neg_pct:.0f}% (target: ~40%)")

    # Quality summary
    q = cp.get("quality", {})
    scored = q.get("total_scored", 0)
    if scored > 0:
        avg_clin = q["total_clinical_density"] / scored
        print("\n Quality metrics:")
        print(f" Narration leaks: {q['narration_leaks']}/{scored} ({q['narration_leaks']/scored*100:.0f}%)")
        print(f" Low Hindi: {q['low_hindi']}/{scored} ({q['low_hindi']/scored*100:.0f}%)")
        print(f" Low turns (<15): {q['low_turns']}/{scored} ({q['low_turns']/scored*100:.0f}%)")
        print(f" Polarity errors: {q['polarity_errors']}/{scored} ({q['polarity_errors']/scored*100:.0f}%)")
        print(f" Avg clinical #s: {avg_clin:.1f} per sample")
        if q['polarity_errors'] > 0:
            print(f"\n WARNING: {q['polarity_errors']} polarity errors detected!")
            print(" Review these samples before training.")

    print(f"\n Output: {OUTPUT_FILE}")
    if args.dry_run:
        print("\nDRY RUN complete. Review the output file.")
        print(f"If quality looks good, run: python scripts/generate_training_data.py --count {args.count}")
    if cp["valid"] >= 200:
        print("\nREADY for next step: python scripts/04_prepare_training.py")
    elif cp["valid"] > 0:
        print(f"\n{cp['valid']} samples generated. May need more for good fine-tune results.")
    print("=" * 60)


def main():
    """Command-line entry point: parse arguments, generate samples, report.

    Verifies the API key and connection, loads the schemas, then runs the
    generation loop. The output file is closed and the checkpoint saved even
    when the loop is interrupted or raises, so ``--resume`` can pick up from
    the last recorded state.
    """
    parser = argparse.ArgumentParser(description="MedScribe v2 — Training Data Generation")
    parser.add_argument("--dry-run", action="store_true", help="Generate 5 samples only")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
    parser.add_argument("--count", type=int, default=500, help="Number of samples to generate")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    target = 5 if args.dry_run else args.count

    print("=" * 60)
    print(f"Training Data Generation — {MODEL}" + (" [DRY RUN]" if args.dry_run else ""))
    print(f"Target: {target} samples")
    print("=" * 60)

    # ── Gate: API key ──
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("\nABORT: OPENAI_API_KEY not found. Set in .env file.")
        sys.exit(1)
    client = OpenAI()

    # Quick API test
    try:
        client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": "Say OK"}],
            max_tokens=5,
        )
        print(f"API connection verified: {MODEL}")
    except Exception as e:
        print(f"\nABORT: API test failed: {e}")
        sys.exit(1)

    # ── Load schemas ──
    schemas = {}
    for name in ["anc_visit", "pnc_visit", "delivery", "child_health", "danger_signs"]:
        schemas[name] = load_schema(name)
    print(f"Loaded {len(schemas)} schemas")

    # ── Load checkpoint (only read from disk when actually resuming) ──
    if args.resume:
        cp = load_checkpoint()
        # Checkpoints without quality counters get zeroed ones.
        cp.setdefault("quality", _new_quality_stats())
        start_idx = cp["generated"]
    else:
        cp = _new_checkpoint()
        start_idx = 0

    # Open output file (append if resuming)
    mode = "a" if args.resume and os.path.exists(OUTPUT_FILE) else "w"
    print(f"\nStarting from sample {start_idx + 1}...")
    if args.dry_run:
        print("DRY RUN: Generating 5 samples. Review output before full run.\n")

    try:
        with open(OUTPUT_FILE, mode, encoding="utf-8") as outfile:
            _run_generation(client, schemas, cp, outfile, start_idx, target)
    finally:
        # Runs on normal completion, on ABORT, and on exceptions/Ctrl-C.
        save_checkpoint(cp)

    # ── Summary ──
    _print_summary(cp, args)


if __name__ == "__main__":
    main()