Spaces:
Sleeping
Sleeping
| """ | |
| Sakhi — Prepare Training Data for Unsloth | |
| Converts raw generated data into chat-format JSONL for SFTTrainer. | |
| Fixes from v1: strips schema metadata from assistant outputs, uses trimmed | |
| danger schema (matching production), correct system prompts. | |
| Usage: | |
| python scripts/prepare_training.py | |
| """ | |
| import json | |
| import os | |
| import random | |
| import sys | |
| from pathlib import Path | |
# ============================================================
# CONFIG
# ============================================================
# All paths are relative to the current working directory.
DEFAULT_INPUT_FILE = "data/processed/training_data_raw.jsonl"
AUGMENTED_INPUT_FILE = "data/processed/training_data_raw_augmented.jsonl"
# Prefer the augmented raw file when it is present; decided once, at import time.
if os.path.exists(AUGMENTED_INPUT_FILE):
    INPUT_FILE = AUGMENTED_INPUT_FILE
else:
    INPUT_FILE = DEFAULT_INPUT_FILE
# Outputs: chat-format train/val JSONL plus a small JSON stats summary.
TRAIN_FILE = "data/processed/train.jsonl"
VAL_FILE = "data/processed/val.jsonl"
STATS_FILE = "data/processed/data_stats.json"
# Match production prompts exactly (from app.py)
# These strings become the system turn of every training example, so they are
# kept byte-identical to the production prompts; edit both places together.
#
# System prompt for the form-extraction task.
# NOTE(review): rule 5 spells the example 'एक सो दस' — the standard spelling of
# "hundred" is 'सौ'. Left verbatim so training matches production; TODO confirm
# whether app.py and this file should be corrected together.
FORM_SYSTEM_PROMPT = (
    "You are a clinical data extraction system for India's ASHA health worker program. "
    "Extract structured data from the Hindi/Hinglish home visit conversation into the requested JSON schema. "
    "ONLY extract information explicitly stated in the conversation. Use null for any field not mentioned.\n\n"
    "STRICT RULES:\n"
    "1. Do NOT invent names, dates, phone numbers, or addresses. If the patient is only called 'दीदी' or 'बहन', set name to null.\n"
    "2. If age is not explicitly stated as a number, set age to null. Do NOT guess from context.\n"
    "3. If blood group, HIV status, or other lab tests are not discussed, they MUST be null — never assume 'negative' or a default group.\n"
    "4. If the conversation has no speaker labels (ASHA/Patient), still extract data but be extra strict about nulls.\n"
    "5. Numbers may appear as Hindi words (e.g., 'एक सो दस बटा सत्तर' = 110/70). Convert them to digits.\n"
    "Return valid JSON only."
)
# System prompt for the danger-sign detection task (paired with the trimmed
# danger schema built by build_trimmed_danger_schema()).
DANGER_SYSTEM_PROMPT = (
    "You are a clinical danger sign detection system for India's ASHA health worker program. "
    "Analyze the Hindi/Hinglish home visit conversation for NHM-defined danger signs.\n\n"
    "STRICT RULES:\n"
    "1. ONLY flag a danger sign if the EXACT words proving it appear in the conversation.\n"
    "2. utterance_evidence MUST be a verbatim copy-paste from the conversation — do NOT paraphrase or fabricate.\n"
    "3. If a vital sign is NORMAL (e.g., BP 110/70, temperature 37°C), that is NOT a danger sign.\n"
    "4. Most routine visits have ZERO danger signs. Return an empty danger_signs array when none exist.\n"
    "5. When in doubt, do NOT flag — a missed flag is better than a false alarm.\n"
    "Return valid JSON only."
)
def load_schema(name: str) -> dict:
    """Load the JSON schema stored at ``configs/schemas/<name>.json``.

    The path is resolved relative to the current working directory.
    """
    schema_path = "configs/schemas/" + f"{name}" + ".json"
    with open(schema_path, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
def build_trimmed_danger_schema():
    """Return the trimmed danger-sign output schema used in production (no checklists).

    A new dict is built on every call. Key order is significant: the schema is
    serialized with json.dumps into the danger-detection user message.
    """
    visit_type_prop = {
        "type": "string",
        "enum": ["antenatal", "postnatal_mother", "newborn", "child_under5"],
    }
    sign_item = {
        "type": "object",
        "properties": {
            "sign": {"type": "string"},
            "category": {
                "type": "string",
                "enum": ["immediate_referral", "urgent_care", "monitor_closely"],
            },
            "clinical_value": {"type": ["string", "null"]},
            "utterance_evidence": {
                "type": "string",
                "description": "REQUIRED: exact verbatim quote",
            },
        },
        "required": ["sign", "category", "utterance_evidence"],
    }
    danger_signs_prop = {
        "type": "array",
        "description": "Detected danger signs. Empty array [] if none found.",
        "items": sign_item,
    }
    referral_prop = {
        "type": "object",
        "properties": {
            "decision": {
                "type": "string",
                "enum": [
                    "refer_immediately",
                    "refer_within_24h",
                    "continue_monitoring",
                    "routine_followup",
                ],
            },
            "reason": {"type": "string"},
        },
        "required": ["decision", "reason"],
    }
    return {
        "type": "object",
        "properties": {
            "visit_type": visit_type_prop,
            "danger_signs": danger_signs_prop,
            "referral_decision": referral_prop,
        },
        "required": ["visit_type", "danger_signs", "referral_decision"],
    }
def clean_form_output(form_data: dict) -> dict:
    """Drop top-level JSON-Schema metadata keys from a form-extraction result.

    GPT-4o sometimes echoes keys such as ``$schema`` or ``title`` into its
    output. The dict is edited in place and the same object is returned.
    """
    metadata_keys = ("$schema", "title", "description", "$id", "$ref")
    for meta_key in metadata_keys:
        if meta_key in form_data:
            del form_data[meta_key]
    return form_data
def clean_danger_output(danger_data: dict) -> dict:
    """Strip schema metadata and checklists so the target matches the trimmed production format.

    Removes:
      * top-level JSON-Schema metadata keys (``$schema``, ``title``, ...);
      * ``maternal_danger_signs_checklist`` / ``newborn_danger_signs_checklist``
        (production derives these programmatically);
      * ``evidence_utterances`` and ``recommended_facility`` from
        ``referral_decision``;
      * ``confidence`` from every entry of ``danger_signs``.

    The top level of *danger_data* is modified in place and returned. The
    nested ``referral_decision`` dict and ``danger_signs`` entries are replaced
    with cleaned copies instead of being mutated. A caller that passes a
    shallow ``dict(...)`` copy therefore does not get its raw sample edited
    behind its back. ``None`` (JSON null) or other non-dict/non-list values
    under those keys are left as they are instead of raising.
    """
    # Remove schema metadata
    for key in ("$schema", "title", "description", "$id", "$ref"):
        danger_data.pop(key, None)
    # Remove checklists (production derives these programmatically)
    danger_data.pop("maternal_danger_signs_checklist", None)
    danger_data.pop("newborn_danger_signs_checklist", None)

    # Remove evidence_utterances / recommended_facility from referral
    # (production builds the evidence list from the signs).
    ref = danger_data.get("referral_decision")
    if isinstance(ref, dict):
        danger_data["referral_decision"] = {
            k: v for k, v in ref.items()
            if k not in ("evidence_utterances", "recommended_facility")
        }

    # Strip confidence from individual signs (not in the trimmed schema).
    signs = danger_data.get("danger_signs")
    if isinstance(signs, list):
        danger_data["danger_signs"] = [
            {k: v for k, v in sign.items() if k != "confidence"} if isinstance(sign, dict) else sign
            for sign in signs
        ]
    return danger_data
def build_form_user_message(transcript: str, schema: dict) -> str:
    """Compose the user turn for the form-extraction task.

    Three sections are separated by blank lines: the instruction, the
    transcript, and the compact JSON schema (non-ASCII characters kept as-is).
    """
    sections = [
        "Extract structured data from this ASHA home visit conversation:",
        f"{transcript}",
        "Output JSON schema:\n" + json.dumps(schema, ensure_ascii=False),
    ]
    return "\n\n".join(sections)
def build_danger_user_message(transcript: str, visit_type: str, schema: dict) -> str:
    """Compose the user turn for the danger-sign detection task.

    Sections, separated by blank lines: the instruction, a ``Visit type:``
    line, the transcript, and the compact JSON schema (non-ASCII kept as-is).
    """
    sections = [
        "Analyze this ASHA home visit conversation for danger signs.",
        f"Visit type: {visit_type}",
        f"{transcript}",
        "Output JSON schema:\n" + json.dumps(schema, ensure_ascii=False),
    ]
    return "\n\n".join(sections)
def raw_to_training_examples(sample: dict, schemas: dict, danger_schema_trimmed: dict) -> list[dict]:
    """Turn one raw generated sample into two chat-format training examples.

    Returns ``[form_extraction_example, danger_signs_example]``. Each example
    is a dict with ``messages`` (system, user and assistant turns, where the
    assistant turn is the cleaned target JSON) and ``metadata`` used for
    balancing and bookkeeping.
    """
    def _chat(system_prompt: str, user_msg: str, answer: dict) -> list[dict]:
        # One system/user/assistant exchange; the answer is serialized as the target.
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": json.dumps(answer, ensure_ascii=False)},
        ]

    transcript = sample["transcript"]
    visit_type = sample["visit_type"]
    schema_name = sample["form_schema"]
    form_schema = schemas[schema_name]

    # Form extraction: clean a shallow copy so the top-level keys of the raw sample survive.
    form_answer = clean_form_output(dict(sample["form_extraction"]))
    form_messages = _chat(FORM_SYSTEM_PROMPT, build_form_user_message(transcript, form_schema), form_answer)
    flagged = sample["has_danger_signs"]
    source = sample["id"]
    form_example = {
        "messages": form_messages,
        "metadata": {
            "task": "form_extraction",
            "visit_type": visit_type,
            "schema": schema_name,
            "has_danger_signs": flagged,
            "source_id": source,
        },
    }

    # Danger-sign detection, using the trimmed schema that production uses.
    danger_answer = clean_danger_output(dict(sample["danger_signs_extraction"]))
    danger_messages = _chat(
        DANGER_SYSTEM_PROMPT,
        build_danger_user_message(transcript, visit_type, danger_schema_trimmed),
        danger_answer,
    )
    danger_example = {
        "messages": danger_messages,
        "metadata": {
            "task": "danger_signs",
            "visit_type": visit_type,
            "has_danger_signs": sample["has_danger_signs"],
            "source_id": sample["id"],
        },
    }
    return [form_example, danger_example]
def _load_raw_samples(path: str) -> list[dict]:
    """Read *path* as JSONL (one JSON value per non-blank line).

    Aborts the script, naming the offending line, if a line is not valid JSON.
    """
    samples = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as err:
                print(f"ABORT: {path}:{lineno} is not valid JSON ({err.msg})")
                sys.exit(1)
    return samples


def _has_schema_leak(extraction) -> bool:
    """Return True when a raw extraction is a dict carrying a top-level ``$schema`` key."""
    return isinstance(extraction, dict) and "$schema" in extraction


def _split_by_source(examples: list[dict], val_fraction: float) -> tuple[list[dict], list[dict]]:
    """Split examples into (train, val), holding out whole raw samples.

    Every example built from one raw sample shares its ``source_id``. Keeping
    them on the same side stops validation from scoring transcripts that were
    trained on.
    """
    # dict.fromkeys keeps first-seen order, so the shuffle below is reproducible
    # under random.seed (iterating a set of str would depend on hash seeding).
    source_ids = list(dict.fromkeys(ex["metadata"]["source_id"] for ex in examples))
    random.shuffle(source_ids)
    # At least one validation source, but never every source.
    n_val = min(max(1, int(len(source_ids) * val_fraction)), max(0, len(source_ids) - 1))
    val_ids = set(source_ids[:n_val])
    train = [ex for ex in examples if ex["metadata"]["source_id"] not in val_ids]
    val = [ex for ex in examples if ex["metadata"]["source_id"] in val_ids]
    return train, val


def _oversample_positive_danger(examples: list[dict], target_ratio: float = 0.45) -> list[dict]:
    """Return extra copies of positive danger-sign examples.

    The copies raise positives to roughly *target_ratio* of the danger-sign
    examples in *examples*. An empty list is returned when there are no
    positives, no negatives, or the ratio is already met.
    """
    danger = [ex for ex in examples if ex["metadata"]["task"] == "danger_signs"]
    positive = [ex for ex in danger if ex["metadata"]["has_danger_signs"]]
    negative_count = len(danger) - len(positive)
    if not positive or not negative_count:
        return []
    current_ratio = len(positive) / len(danger)
    if current_ratio >= target_ratio:
        return []
    # Solve pos / (pos + neg) = target for pos, minus the positives already present.
    extra_needed = int((target_ratio * negative_count) / (1 - target_ratio)) - len(positive)
    if extra_needed <= 0:
        return []
    new_pos = len(positive) + extra_needed
    print(f"Oversampled: +{extra_needed} positive danger examples "
          f"({current_ratio:.0%} -> {new_pos / (new_pos + negative_count):.0%})")
    return random.choices(positive, k=extra_needed)


def main():
    """Build chat-format train/val JSONL files and a stats summary from INPUT_FILE.

    Steps:
      1. Load the form schemas and raw samples; abort if INPUT_FILE is missing.
      2. Convert each raw sample into form-extraction and danger-sign examples.
      3. Check the assistant turns for leftover schema metadata.
      4. Split train/val by source sample, then oversample positive danger-sign
         examples in the TRAIN split only. Duplicates never reach validation,
         and validation keeps the natural class balance.
      5. Write TRAIN_FILE, VAL_FILE and STATS_FILE.
    """
    random.seed(42)
    if not os.path.exists(INPUT_FILE):
        print(f"ABORT: Input not found: {INPUT_FILE}")
        sys.exit(1)

    # Load schemas
    schemas = {
        name: load_schema(name)
        for name in ["anc_visit", "pnc_visit", "delivery", "child_health"]
    }
    danger_schema_trimmed = build_trimmed_danger_schema()

    # Load raw data
    raw_samples = _load_raw_samples(INPUT_FILE)
    print(f"Loaded {len(raw_samples)} raw samples from {INPUT_FILE}")

    # Convert to training examples; count schema leakage before cleaning strips it.
    all_examples = []
    schema_leak_fixed = 0
    for sample in raw_samples:
        schema_leak_fixed += int(_has_schema_leak(sample.get("danger_signs_extraction")))
        schema_leak_fixed += int(_has_schema_leak(sample.get("form_extraction")))
        all_examples.extend(raw_to_training_examples(sample, schemas, danger_schema_trimmed))
    print(f"Produced {len(all_examples)} training examples")
    if schema_leak_fixed:
        print(f"Fixed schema leakage in {schema_leak_fixed} examples")

    # Verify no leakage remains in the assistant (target) turns.
    leaked = 0
    for ex in all_examples:
        content = ex["messages"][2]["content"]
        if '"$schema"' in content or '"title": "' in content[:100]:
            leaked += 1
    if leaked:
        print(f"WARNING: {leaked} examples still have schema leakage!")
    else:
        print("Schema leakage check: CLEAN")

    # Split first, so oversampled duplicates can only land in train.
    train_examples, val_examples = _split_by_source(all_examples, 0.15)
    oversampled = _oversample_positive_danger(train_examples, target_ratio=0.45)
    train_examples.extend(oversampled)
    random.shuffle(train_examples)
    random.shuffle(val_examples)
    print(f"Split: {len(train_examples)} train / {len(val_examples)} val")

    # Write
    for path, examples in [(TRAIN_FILE, train_examples), (VAL_FILE, val_examples)]:
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(ex, ensure_ascii=False) + "\n" for ex in examples)
        print(f"Wrote {path}")

    # Stats
    stats = {
        "raw_samples": len(raw_samples),
        "total_examples": len(train_examples) + len(val_examples),
        "train": len(train_examples),
        "val": len(val_examples),
        "schema_leaks_fixed": schema_leak_fixed,
        "oversampled_positive": len(oversampled),
    }
    os.makedirs(os.path.dirname(STATS_FILE) or ".", exist_ok=True)
    with open(STATS_FILE, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)
    print("\nReady for training: python scripts/train_unsloth.py")


if __name__ == "__main__":
    main()