"""
Sakhi — Augment training data with unlabeled (no speaker labels) transcript variants.

Simulates raw Whisper ASR output by stripping "ASHA:" / "Patient:" / "Mother:" labels
from existing training samples. This teaches the model to handle audio transcripts
where no speaker diarization is available.

Ground truth adjustments:
  - patient.name → null if the name was only a generic address ("दीदी", "बहन")
  - All other fields stay the same (vitals, danger signs, etc.)

Every original sample is kept; the unlabeled variants are appended next to the
samples they were derived from.

Usage:
    python scripts/augment_unlabeled.py
    python scripts/augment_unlabeled.py --ratio 0.3   # add unlabeled variants for 30% of samples
"""

import argparse
import copy
import json
import os
import random
import re
import sys

INPUT_FILE = "data/processed/training_data_raw.jsonl"
OUTPUT_FILE = "data/processed/training_data_raw_augmented.jsonl"

# Speaker label patterns to strip. A label only counts at the start of a line;
# leading spaces/tabs are tolerated so that indented lines do not keep their label.
SPEAKER_LABELS = re.compile(
    r'^[ \t]*(ASHA|Patient|Mother|Father|Husband|Doctor|ANM|Nurse|CHW|दाई)\s*:\s*',
    re.MULTILINE | re.IGNORECASE
)

# Generic address terms that are NOT real names
GENERIC_ADDRESSES = {
    "दीदी", "बहन", "बहनजी", "भाई", "भैया", "जी",
    "अम्मा", "माँ", "माताजी",
    "patient", "didi", "bahen", "amma",
}

# Key paths inside a form extraction where a patient name may be stored.
_PATIENT_NAME_PATHS = (
    ("patient", "name"),
    ("patient", "patient_name"),
    ("patient_details", "name"),
    ("mother_assessment", "patient_name"),
    ("visit_info", "patient_name"),
)


def strip_speaker_labels(transcript):
    """Remove speaker labels like 'ASHA:', 'Patient:' from transcript.

    Labels are matched case-insensitively at the start of each line. After
    removal every line is trimmed and empty lines are dropped.

    Args:
        transcript: The labeled transcript text.

    Returns:
        The transcript without speaker labels, one utterance per line.
    """
    unlabeled = SPEAKER_LABELS.sub('', transcript)
    # Dropping blank lines below also takes care of runs of consecutive newlines.
    trimmed = (line.strip() for line in unlabeled.split('\n'))
    return '\n'.join(line for line in trimmed if line)


def fix_ground_truth(form_extraction):
    """Null out fields that would be hallucinated on unlabeled transcripts.

    A patient name that is only a generic address term (see GENERIC_ADDRESSES)
    is replaced with None at every known name location. The input is not
    modified.

    Args:
        form_extraction: The ground-truth form (normally a nested dict).

    Returns:
        A deep copy of the form with generic-address names set to None.
    """
    form = copy.deepcopy(form_extraction)

    for *parents, leaf in _PATIENT_NAME_PATHS:
        node = form
        for key in parents:
            # Anything that is not a dict ends the walk with an empty placeholder.
            node = node.get(key, {}) if isinstance(node, dict) else {}
        if not isinstance(node, dict) or leaf not in node:
            continue
        name = node[leaf]
        if name and str(name).strip().lower() in GENERIC_ADDRESSES:
            node[leaf] = None

    return form


def _load_samples(path):
    """Read a JSONL file into a list, skipping blank lines.

    Raises:
        ValueError: If a non-blank line is not valid JSON; the message carries
            the file path and line number.
    """
    samples = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{line_no}: invalid JSON ({err.msg})") from err
    return samples


def _make_unlabeled_variant(sample):
    """Return a copy of *sample* with labels stripped, id suffixed and ground truth fixed."""
    variant = copy.deepcopy(sample)
    variant["transcript"] = strip_speaker_labels(sample["transcript"])
    variant["id"] = f"{sample['id']}_unlabeled"
    variant["form_extraction"] = fix_ground_truth(sample["form_extraction"])
    return variant


def main():
    """Command-line entry point: read INPUT_FILE, add unlabeled variants, write OUTPUT_FILE."""
    parser = argparse.ArgumentParser(description="Augment training data with unlabeled variants")
    parser.add_argument("--ratio", type=float, default=0.3,
                        help="Fraction of samples to create unlabeled variants for (default: 0.3)")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # random.sample() raises ValueError for a negative or too-large sample size,
    # so reject an out-of-range ratio up front (this also rejects NaN).
    if not 0.0 <= args.ratio <= 1.0:
        parser.error(f"--ratio must be between 0 and 1 (got {args.ratio})")

    random.seed(args.seed)

    if not os.path.exists(INPUT_FILE):
        print(f"ABORT: {INPUT_FILE} not found. Run generate_training_data.py first.")
        sys.exit(1)

    # Load raw data
    try:
        raw_samples = _load_samples(INPUT_FILE)
    except ValueError as err:
        print(f"ABORT: {err}")
        sys.exit(1)

    print(f"Loaded {len(raw_samples)} raw samples")

    # Select samples to augment
    n_augment = int(len(raw_samples) * args.ratio)
    augment_indices = set(random.sample(range(len(raw_samples)), n_augment))

    print(f"Creating {n_augment} unlabeled variants ({args.ratio:.0%} of total)")

    # Create augmented dataset
    augmented = []
    n_labels_stripped = 0

    for i, sample in enumerate(raw_samples):
        # Always include original
        augmented.append(sample)

        if i not in augment_indices:
            continue

        # Count only samples that really carried a label; whitespace
        # normalization alone does not count as stripping a label.
        if SPEAKER_LABELS.search(sample["transcript"]):
            n_labels_stripped += 1

        augmented.append(_make_unlabeled_variant(sample))

    # Write augmented dataset
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(sample, ensure_ascii=False) + "\n" for sample in augmented)

    print(f"\nResults:")
    print(f" Original samples: {len(raw_samples)}")
    print(f" Unlabeled variants added: {n_augment}")
    print(f" Labels actually stripped: {n_labels_stripped}")
    print(f" Total samples: {len(augmented)}")
    print(f" Written to: {OUTPUT_FILE}")
    print(f"\nNext: Run prepare_training.py with --input {OUTPUT_FILE}")

    # Show a before/after example
    if augment_indices:
        idx = min(augment_indices)
        orig = raw_samples[idx]["transcript"]
        stripped = strip_speaker_labels(orig)
        print(f"\n--- Example (sample {idx}) ---")
        print(f"BEFORE (first 200 chars):\n {orig[:200]}")
        print(f"\nAFTER (first 200 chars):\n {stripped[:200]}")


if __name__ == "__main__":
    main()