sakhi / scripts /prepare_training.py
Tushar9802's picture
HF Space deploy — initial
745f62a
"""
Sakhi — Prepare Training Data for Unsloth
Converts raw generated data into chat-format JSONL for SFTTrainer.
Fixes from v1: strips schema metadata from assistant outputs, uses trimmed
danger schema (matching production), correct system prompts.
Usage:
python scripts/prepare_training.py
"""
import json
import os
import random
import sys
from pathlib import Path
# ============================================================
# CONFIG
# ============================================================
# All paths are relative to the current working directory.

# Raw samples; main() reads this file as one JSON object per non-blank line.
DEFAULT_INPUT_FILE = "data/processed/training_data_raw.jsonl"
# Alternate raw file that takes precedence over the default when it exists.
AUGMENTED_INPUT_FILE = "data/processed/training_data_raw_augmented.jsonl"
# Resolved once, at import time: the augmented file if present, else the default.
INPUT_FILE = AUGMENTED_INPUT_FILE if os.path.exists(AUGMENTED_INPUT_FILE) else DEFAULT_INPUT_FILE
# Outputs written by main(): the two chat-format JSONL splits and a JSON
# summary of example counts.
TRAIN_FILE = "data/processed/train.jsonl"
VAL_FILE = "data/processed/val.jsonl"
STATS_FILE = "data/processed/data_stats.json"
# Match production prompts exactly (from app.py)
# NOTE(review): app.py is not visible from this file — confirm the text below
# is still identical to production before retraining.
# Used verbatim as the `system` message of every form-extraction example.
FORM_SYSTEM_PROMPT = (
    "You are a clinical data extraction system for India's ASHA health worker program. "
    "Extract structured data from the Hindi/Hinglish home visit conversation into the requested JSON schema. "
    "ONLY extract information explicitly stated in the conversation. Use null for any field not mentioned.\n\n"
    "STRICT RULES:\n"
    "1. Do NOT invent names, dates, phone numbers, or addresses. If the patient is only called 'दीदी' or 'बहन', set name to null.\n"
    "2. If age is not explicitly stated as a number, set age to null. Do NOT guess from context.\n"
    "3. If blood group, HIV status, or other lab tests are not discussed, they MUST be null — never assume 'negative' or a default group.\n"
    "4. If the conversation has no speaker labels (ASHA/Patient), still extract data but be extra strict about nulls.\n"
    "5. Numbers may appear as Hindi words (e.g., 'एक सो दस बटा सत्तर' = 110/70). Convert them to digits.\n"
    "Return valid JSON only."
)
# Used verbatim as the `system` message of every danger-sign detection example.
# The wording is runtime data: any edit here changes what the model is trained on.
DANGER_SYSTEM_PROMPT = (
    "You are a clinical danger sign detection system for India's ASHA health worker program. "
    "Analyze the Hindi/Hinglish home visit conversation for NHM-defined danger signs.\n\n"
    "STRICT RULES:\n"
    "1. ONLY flag a danger sign if the EXACT words proving it appear in the conversation.\n"
    "2. utterance_evidence MUST be a verbatim copy-paste from the conversation — do NOT paraphrase or fabricate.\n"
    "3. If a vital sign is NORMAL (e.g., BP 110/70, temperature 37°C), that is NOT a danger sign.\n"
    "4. Most routine visits have ZERO danger signs. Return an empty danger_signs array when none exist.\n"
    "5. When in doubt, do NOT flag — a missed flag is better than a false alarm.\n"
    "Return valid JSON only."
)
def load_schema(name: str) -> dict:
    """Return the parsed contents of ``configs/schemas/<name>.json``.

    The path is relative to the current working directory and the file is
    decoded as UTF-8.
    """
    schema_path = f"configs/schemas/{name}.json"
    with open(schema_path, mode="r", encoding="utf-8") as schema_file:
        raw_text = schema_file.read()
    return json.loads(raw_text)
def build_trimmed_danger_schema():
    """Return a fresh copy of the trimmed danger-sign JSON schema.

    The schema has three required top-level properties — ``visit_type``,
    ``danger_signs`` and ``referral_decision`` — and no checklist sections.
    Key order is significant because the schema is serialized into prompts.
    """
    visit_type_schema = {
        "type": "string",
        "enum": ["antenatal", "postnatal_mother", "newborn", "child_under5"],
    }

    # One detected sign; clinical_value is the only optional, nullable field.
    sign_schema = {
        "type": "object",
        "properties": {
            "sign": {"type": "string"},
            "category": {
                "type": "string",
                "enum": ["immediate_referral", "urgent_care", "monitor_closely"],
            },
            "clinical_value": {"type": ["string", "null"]},
            "utterance_evidence": {
                "type": "string",
                "description": "REQUIRED: exact verbatim quote",
            },
        },
        "required": ["sign", "category", "utterance_evidence"],
    }

    danger_signs_schema = {
        "type": "array",
        "description": "Detected danger signs. Empty array [] if none found.",
        "items": sign_schema,
    }

    referral_schema = {
        "type": "object",
        "properties": {
            "decision": {
                "type": "string",
                "enum": [
                    "refer_immediately",
                    "refer_within_24h",
                    "continue_monitoring",
                    "routine_followup",
                ],
            },
            "reason": {"type": "string"},
        },
        "required": ["decision", "reason"],
    }

    return {
        "type": "object",
        "properties": {
            "visit_type": visit_type_schema,
            "danger_signs": danger_signs_schema,
            "referral_decision": referral_schema,
        },
        "required": ["visit_type", "danger_signs", "referral_decision"],
    }
def clean_form_output(form_data: dict) -> dict:
    """Drop top-level JSON Schema metadata keys from a form extraction.

    The dict is modified in place and the same object is returned. Only
    top-level keys are touched; nested values are left as they are.
    """
    # Keys that describe a schema rather than extracted data.
    metadata_keys = ("$schema", "title", "description", "$id", "$ref")
    present = [key for key in metadata_keys if key in form_data]
    for key in present:
        del form_data[key]
    return form_data
def clean_danger_output(danger_data: dict) -> dict:
    """Reduce a danger-sign extraction to the trimmed output format.

    Removes, in place:
      * top-level JSON Schema metadata keys,
      * the maternal and newborn checklist sections,
      * ``evidence_utterances`` and ``recommended_facility`` from
        ``referral_decision``,
      * ``confidence`` from every entry of ``danger_signs``.

    Args:
        danger_data: Parsed danger-sign extraction. Nested dicts are modified
            in place as well, so a shallow copy made by the caller still
            shares them with the original.

    Returns:
        The same ``danger_data`` object, cleaned.
    """
    # Remove schema metadata
    for key in ("$schema", "title", "description", "$id", "$ref"):
        danger_data.pop(key, None)

    # Remove checklists (production derives these programmatically)
    danger_data.pop("maternal_danger_signs_checklist", None)
    danger_data.pop("newborn_danger_signs_checklist", None)

    # dict.get's default only applies to a *missing* key. A key that is present
    # with a JSON null (or any non-dict value) must be skipped explicitly,
    # otherwise .pop() raises AttributeError.
    referral = danger_data.get("referral_decision")
    if isinstance(referral, dict):
        # Production builds these from the detected signs.
        referral.pop("evidence_utterances", None)
        referral.pop("recommended_facility", None)

    # Same reasoning for the sign list: tolerate null and non-dict entries.
    signs = danger_data.get("danger_signs")
    if isinstance(signs, (list, tuple)):
        for sign in signs:
            if isinstance(sign, dict):
                # Not part of the trimmed schema.
                sign.pop("confidence", None)

    return danger_data
def build_form_user_message(transcript: str, schema: dict) -> str:
    """Compose the user turn for the form-extraction task.

    The message is three sections separated by blank lines: a fixed
    instruction, the transcript, and the target schema serialized as
    single-line JSON with non-ASCII characters kept as-is.
    """
    sections = [
        "Extract structured data from this ASHA home visit conversation:",
        f"{transcript}",
        "Output JSON schema:\n" + json.dumps(schema, ensure_ascii=False),
    ]
    return "\n\n".join(sections)
def build_danger_user_message(transcript: str, visit_type: str, schema: dict) -> str:
    """Compose the user turn for the danger-sign detection task.

    The message is four sections separated by blank lines: a fixed
    instruction, the visit type, the transcript, and the target schema
    serialized as single-line JSON with non-ASCII characters kept as-is.
    """
    sections = [
        "Analyze this ASHA home visit conversation for danger signs.",
        f"Visit type: {visit_type}",
        f"{transcript}",
        "Output JSON schema:\n" + json.dumps(schema, ensure_ascii=False),
    ]
    return "\n\n".join(sections)
def raw_to_training_examples(sample: dict, schemas: dict, danger_schema_trimmed: dict) -> list[dict]:
    """Turn one raw sample into its two chat-format training examples.

    Args:
        sample: Raw record. Reads the keys ``transcript``, ``visit_type``,
            ``form_schema``, ``form_extraction``, ``danger_signs_extraction``,
            ``has_danger_signs`` and ``id``.
        schemas: Form schemas keyed by the name stored in ``form_schema``.
        danger_schema_trimmed: Schema embedded in the danger-sign prompt.

    Returns:
        ``[form_example, danger_example]``. Each is a dict with a
        system/user/assistant ``messages`` list and a ``metadata`` dict.
    """

    def as_chat(system_prompt, user_prompt, answer):
        # The assistant turn is the cleaned extraction serialized as JSON.
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": json.dumps(answer, ensure_ascii=False)},
        ]

    transcript = sample["transcript"]
    visit_type = sample["visit_type"]
    schema_name = sample["form_schema"]
    form_schema = schemas[schema_name]

    # ── Form extraction ──
    # Shallow copy so top-level key removal does not touch the raw sample.
    form_answer = clean_form_output(dict(sample["form_extraction"]))
    form_example = {
        "messages": as_chat(
            FORM_SYSTEM_PROMPT,
            build_form_user_message(transcript, form_schema),
            form_answer,
        ),
        "metadata": {
            "task": "form_extraction",
            "visit_type": visit_type,
            "schema": schema_name,
            "has_danger_signs": sample["has_danger_signs"],
            "source_id": sample["id"],
        },
    }

    # ── Danger sign detection (trimmed schema, matching production) ──
    danger_answer = clean_danger_output(dict(sample["danger_signs_extraction"]))
    danger_example = {
        "messages": as_chat(
            DANGER_SYSTEM_PROMPT,
            build_danger_user_message(transcript, visit_type, danger_schema_trimmed),
            danger_answer,
        ),
        "metadata": {
            "task": "danger_signs",
            "visit_type": visit_type,
            "has_danger_signs": sample["has_danger_signs"],
            "source_id": sample["id"],
        },
    }

    return [form_example, danger_example]
def _read_jsonl(path: str) -> list[dict]:
    """Parse every non-blank line of *path* as one JSON value."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if stripped:
                records.append(json.loads(stripped))
    return records


def _write_jsonl(path: str, examples: list[dict]) -> None:
    """Write *examples* to *path* as JSON Lines, creating the parent directory."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex, ensure_ascii=False) + "\n" for ex in examples)


def _pick_oversampled_positives(examples: list[dict], target_ratio: float = 0.45) -> list[dict]:
    """Draw extra positive danger-sign examples to reach ~target_ratio of that task."""
    danger = [ex for ex in examples if ex["metadata"]["task"] == "danger_signs"]
    positives = [ex for ex in danger if ex["metadata"]["has_danger_signs"]]
    negatives = [ex for ex in danger if not ex["metadata"]["has_danger_signs"]]
    if not positives or not negatives:
        return []

    current_ratio = len(positives) / len(danger)
    if current_ratio >= target_ratio:
        return []

    # Solve (pos + extra) / (pos + extra + neg) = target_ratio for extra.
    extra_needed = int((target_ratio * len(negatives)) / (1 - target_ratio)) - len(positives)
    if extra_needed <= 0:
        return []

    new_pos = len(positives) + extra_needed
    new_total = new_pos + len(negatives)
    print(f"Oversampled: +{extra_needed} positive danger examples "
          f"({current_ratio:.0%} -> {new_pos/new_total:.0%})")
    return random.choices(positives, k=extra_needed)


def main():
    """Convert the raw samples into train/val chat-format JSONL files.

    Steps: load schemas and raw samples, build two examples per sample, check
    the assistant outputs for leftover schema metadata, shuffle, split 85/15,
    oversample positive danger-sign examples in the train split only, then
    write both splits and a stats file. Exits with status 1 when the input
    file is missing. The RNG is seeded, so repeated runs on the same input
    give the same output.
    """
    random.seed(42)
    if not os.path.exists(INPUT_FILE):
        print(f"ABORT: Input not found: {INPUT_FILE}")
        sys.exit(1)

    # Load schemas
    schemas = {
        name: load_schema(name)
        for name in ["anc_visit", "pnc_visit", "delivery", "child_health"]
    }
    danger_schema_trimmed = build_trimmed_danger_schema()

    # Load raw data
    raw_samples = _read_jsonl(INPUT_FILE)
    print(f"Loaded {len(raw_samples)} raw samples from {INPUT_FILE}")

    # Convert to training examples
    all_examples = []
    schema_leak_fixed = 0
    for sample in raw_samples:
        # Count leaks before conversion, since cleaning removes the evidence.
        if "$schema" in sample.get("danger_signs_extraction", {}):
            schema_leak_fixed += 1
        if "$schema" in sample.get("form_extraction", {}):
            schema_leak_fixed += 1
        all_examples.extend(raw_to_training_examples(sample, schemas, danger_schema_trimmed))
    print(f"Produced {len(all_examples)} training examples")
    if schema_leak_fixed:
        print(f"Fixed schema leakage in {schema_leak_fixed} examples")

    # Verify no leakage remains in the assistant turn (messages[2]).
    leaked = 0
    for ex in all_examples:
        content = ex["messages"][2]["content"]
        if '"$schema"' in content or '"title": "' in content[:100]:
            leaked += 1
    if leaked:
        print(f"WARNING: {leaked} examples still have schema leakage!")
    else:
        print("Schema leakage check: CLEAN")

    # Split BEFORE oversampling. Duplicating positives first would let copies
    # of the same example land in both train and val, which contaminates the
    # validation set and makes validation loss look better than it is.
    random.shuffle(all_examples)
    val_count = max(1, int(len(all_examples) * 0.15))
    val_examples = all_examples[:val_count]
    train_examples = all_examples[val_count:]

    # ── Oversample positive danger sign examples to ~45% (train split only) ──
    extras = _pick_oversampled_positives(train_examples)
    if extras:
        train_examples.extend(extras)
        # Re-shuffle so the duplicates are not clustered at the end.
        random.shuffle(train_examples)
    print(f"Split: {len(train_examples)} train / {len(val_examples)} val")

    # Write
    for path, examples in [(TRAIN_FILE, train_examples), (VAL_FILE, val_examples)]:
        _write_jsonl(path, examples)
        print(f"Wrote {path}")

    # Stats (total_examples counts everything written, duplicates included)
    stats = {
        "raw_samples": len(raw_samples),
        "total_examples": len(train_examples) + len(val_examples),
        "train": len(train_examples),
        "val": len(val_examples),
        "schema_leaks_fixed": schema_leak_fixed,
    }
    Path(STATS_FILE).parent.mkdir(parents=True, exist_ok=True)
    with open(STATS_FILE, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)
    print("\nReady for training: python scripts/train_unsloth.py")


if __name__ == "__main__":
    main()