# sakhi/scripts/augment_unlabeled.py
# Tushar9802 — "HF Space deploy — initial" (commit 745f62a)
"""
Sakhi — Augment training data with unlabeled (no speaker labels) transcript variants.
Simulates raw Whisper ASR output by stripping "ASHA:" / "Patient:" / "Mother:" labels
from existing training samples. This teaches the model to handle audio transcripts
where no speaker diarization is available.
Ground truth adjustments:
- Patient-name fields (patient.name, patient_details.name, etc.) → null when the value is only a generic form of address ("दीदी", "बहन")
- All other fields stay the same (vitals, danger signs, etc.)
Usage:
python scripts/augment_unlabeled.py
python scripts/augment_unlabeled.py --ratio 0.3 # add an unlabeled copy of 30% of samples (originals are kept)
"""
import argparse
import copy
import json
import os
import random
import re
import sys
# Input JSONL; the abort message in main() expects it to come from
# generate_training_data.py.
INPUT_FILE = "data/processed/training_data_raw.jsonl"
# Destination for originals plus their unlabeled variants.
OUTPUT_FILE = "data/processed/training_data_raw_augmented.jsonl"

# A speaker tag such as "ASHA:" or "patient :" at the start of any line
# (re.MULTILINE makes ^ match after every newline), matched case-insensitively
# together with the whitespace that follows the colon.
SPEAKER_LABELS = re.compile(
    r'^(ASHA|Patient|Mother|Father|Husband|Doctor|ANM|Nurse|CHW|दाई)\s*:\s*',
    flags=re.IGNORECASE | re.MULTILINE,
)

# Forms of address that are NOT real names. fix_ground_truth compares the
# stripped, lower-cased name value against this set, so entries are lowercase.
GENERIC_ADDRESSES = {
    # Hindi honorifics / kinship terms
    "दीदी",
    "बहन",
    "बहनजी",
    "भाई",
    "भैया",
    "जी",
    "अम्मा",
    "माँ",
    "माताजी",
    # Romanized / English equivalents
    "patient",
    "didi",
    "bahen",
    "amma",
}
def strip_speaker_labels(transcript, pattern=None):
    """Remove speaker labels such as 'ASHA:' or 'Patient:' from a transcript.

    Every line is whitespace-trimmed before the label regex runs, so labels
    preceded by indentation ("  Patient: ...") are removed as well as those at
    column 0. After removal, lines are trimmed again and blank lines dropped,
    leaving one utterance per line.

    Args:
        transcript: Transcript text whose lines may start with a speaker label.
        pattern: Optional compiled regex matching a label at the start of a
            line (compile it with re.MULTILINE so ^ matches on every line).
            Defaults to the module-level SPEAKER_LABELS.

    Returns:
        The transcript with labels and blank lines removed, lines joined by
        single newlines.
    """
    if pattern is None:
        pattern = SPEAKER_LABELS
    # Trim first: the label regex is anchored with ^, so leading spaces/tabs
    # before "ASHA:" would otherwise prevent a match.
    trimmed = '\n'.join(line.strip() for line in transcript.split('\n'))
    # Substitute on the whole text (not line by line) so the pattern's trailing
    # whitespace match can also consume the line break after a label that sits
    # alone on its line.
    without_labels = pattern.sub('', trimmed)
    lines = (line.strip() for line in without_labels.split('\n'))
    return '\n'.join(line for line in lines if line)
def fix_ground_truth(form_extraction):
    """Return a deep copy of *form_extraction* with generic-address names nulled.

    Several known patient-name locations are checked. When one exists and its
    value, stripped and lower-cased, is in GENERIC_ADDRESSES (e.g. "दीदी"),
    that value is replaced by None. Missing keys, None sections and non-dict
    sections are skipped. The input object itself is never modified.
    """
    cleaned = copy.deepcopy(form_extraction)
    name_locations = (
        ("patient", "name"),
        ("patient", "patient_name"),
        ("patient_details", "name"),
        ("mother_assessment", "patient_name"),
        ("visit_info", "patient_name"),
    )
    for *parents, field in name_locations:
        container = cleaned
        for parent in parents:
            if isinstance(container, dict):
                container = container.get(parent, {})
            else:
                container = {}
        if not isinstance(container, dict) or field not in container:
            continue
        value = container[field]
        if not value:
            continue
        if str(value).strip().lower() in GENERIC_ADDRESSES:
            container[field] = None
    return cleaned
def _load_jsonl(path):
    """Return the JSON objects stored one per line in *path*, skipping blank lines.

    Raises:
        ValueError: if a non-blank line is not valid JSON; the message names
            the file and 1-based line number of the bad record.
    """
    samples = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{line_no}: invalid JSON: {err}") from err
    return samples


def _write_jsonl(path, samples):
    """Write *samples* to *path* as UTF-8 JSONL, creating the parent directory if needed."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")


def _make_unlabeled_variant(sample):
    """Return a deep copy of *sample* with labels stripped, ground truth fixed and an '_unlabeled' id."""
    variant = copy.deepcopy(sample)
    variant["transcript"] = strip_speaker_labels(sample["transcript"])
    variant["id"] = f"{sample['id']}_unlabeled"
    variant["form_extraction"] = fix_ground_truth(sample["form_extraction"])
    return variant


def main():
    """Build OUTPUT_FILE from INPUT_FILE by adding unlabeled transcript variants.

    Every original sample is kept. A seeded random subset of size
    int(len(samples) * --ratio) additionally gets an unlabeled copy appended
    right after it. Exits with status 1 if INPUT_FILE is missing, and with an
    argparse usage error if --ratio lies outside [0, 1].
    """
    parser = argparse.ArgumentParser(description="Augment training data with unlabeled variants")
    parser.add_argument("--ratio", type=float, default=0.3,
                        help="Fraction of samples to create unlabeled variants for (default: 0.3)")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    # random.sample() rejects sizes outside [0, population]; fail fast with a
    # clear usage message instead of a ValueError after loading the data.
    # (Written so NaN is rejected too.)
    if not 0.0 <= args.ratio <= 1.0:
        parser.error(f"--ratio must be between 0 and 1, got {args.ratio}")

    random.seed(args.seed)

    if not os.path.exists(INPUT_FILE):
        print(f"ABORT: {INPUT_FILE} not found. Run generate_training_data.py first.")
        sys.exit(1)

    raw_samples = _load_jsonl(INPUT_FILE)
    print(f"Loaded {len(raw_samples)} raw samples")

    # Select samples to augment (same RNG call as before, so a given seed
    # always picks the same indices).
    n_augment = int(len(raw_samples) * args.ratio)
    augment_indices = set(random.sample(range(len(raw_samples)), n_augment))
    print(f"Creating {n_augment} unlabeled variants ({args.ratio:.0%} of total)")

    augmented = []
    n_labels_stripped = 0
    for i, sample in enumerate(raw_samples):
        augmented.append(sample)  # originals are always kept
        if i not in augment_indices:
            continue
        variant = _make_unlabeled_variant(sample)
        # Count variants whose transcript actually changed (i.e. had labels).
        if variant["transcript"] != sample["transcript"]:
            n_labels_stripped += 1
        augmented.append(variant)

    _write_jsonl(OUTPUT_FILE, augmented)

    print("\nResults:")
    print(f" Original samples: {len(raw_samples)}")
    print(f" Unlabeled variants added: {n_augment}")
    print(f" Labels actually stripped: {n_labels_stripped}")
    print(f" Total samples: {len(augmented)}")
    print(f" Written to: {OUTPUT_FILE}")
    print(f"\nNext: Run prepare_training.py with --input {OUTPUT_FILE}")

    # Show a before/after example for the lowest augmented index.
    if augment_indices:
        idx = min(augment_indices)
        orig = raw_samples[idx]["transcript"]
        stripped = strip_speaker_labels(orig)
        print(f"\n--- Example (sample {idx}) ---")
        print(f"BEFORE (first 200 chars):\n {orig[:200]}")
        print(f"\nAFTER (first 200 chars):\n {stripped[:200]}")


if __name__ == "__main__":
    main()