File size: 5,182 Bytes
745f62a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Sakhi — Augment training data with unlabeled (no speaker labels) transcript variants.

Simulates raw Whisper ASR output by stripping "ASHA:" / "Patient:" / "Mother:" labels
from existing training samples. This teaches the model to handle audio transcripts
where no speaker diarization is available.

Ground truth adjustments:
  - Patient-name fields (e.g. patient.name, visit_info.patient_name) → null if the name was only a generic address ("दीदी", "बहन")
  - All other fields stay the same (vitals, danger signs, etc.)

Usage:
    python scripts/augment_unlabeled.py
    python scripts/augment_unlabeled.py --ratio 0.3  # select 30% of samples for unlabeled copies (originals are kept)
"""

import argparse
import copy
import json
import os
import random
import re
import sys

INPUT_FILE = "data/processed/training_data_raw.jsonl"
OUTPUT_FILE = "data/processed/training_data_raw_augmented.jsonl"

# Speaker label patterns to strip, anchored at the start of each line.
# Only horizontal whitespace ([^\S\n] = any whitespace except newline) is allowed
# before the label and around the colon. A plain \s* after the colon could eat the
# newline plus the next line's indentation, leaving that next line's label no
# longer at a line start, so it would survive stripping. Indented labels
# ("  Patient: ...") are matched as well.
SPEAKER_LABELS = re.compile(
    r'^[^\S\n]*(ASHA|Patient|Mother|Father|Husband|Doctor|ANM|Nurse|CHW|दाई)[^\S\n]*:[^\S\n]*',
    re.MULTILINE | re.IGNORECASE
)

# Generic address terms that are NOT real names. Entries are lower-case because
# candidate names are lower-cased before lookup. Romanized forms mirror the
# Devanagari terms.
GENERIC_ADDRESSES = {
    "दीदी", "बहन", "बहनजी", "भाई", "भैया", "जी", "अम्मा", "माँ", "मां", "माताजी",
    "patient", "didi", "bahen", "amma",
    "bahan", "bahenji", "bhai", "bhaiya", "ji", "maa", "mataji",
}


def strip_speaker_labels(transcript):
    """Return *transcript* with speaker tags such as 'ASHA:' or 'Patient:' removed.

    SPEAKER_LABELS (a line-anchored pattern) is applied to the whole text. Every
    resulting line is then trimmed and empty lines are discarded, so the result
    is the non-blank lines joined by single newlines.
    """
    without_tags = SPEAKER_LABELS.sub('', transcript)
    # Dropping empty lines after trimming also collapses any run of blank lines,
    # so no separate newline-squeezing pass is needed.
    trimmed = (chunk.strip() for chunk in without_tags.split('\n'))
    return '\n'.join(chunk for chunk in trimmed if chunk)


def _is_generic_address(value, terms):
    """Return True if *value* consists solely of generic forms of address.

    The value is stringified, trimmed and lower-cased. It qualifies if it is one
    of *terms* itself, or if every token left after splitting on whitespace and
    common punctuation (including the Devanagari danda '।') is in *terms*.
    Falsy values (None, "", 0) never qualify.
    """
    if not value:
        return False
    text = str(value).strip().lower()
    if text in terms:
        return True
    # Split on explicit separators instead of matching \w: Python's \w does not
    # treat Devanagari vowel signs (combining marks) as word characters, which
    # would break words like "दीदी" apart.
    tokens = [tok for tok in re.split(r"[\s,.;:!?।]+", text) if tok]
    return bool(tokens) and all(tok in terms for tok in tokens)


def fix_ground_truth(form_extraction, generic_addresses=None):
    """Null out patient-name fields that hold only a generic form of address.

    Without speaker labels the model cannot tell who is being addressed, so a
    name such as "दीदी" or "बहन जी" must not be kept as the patient's name.
    Every other field is left unchanged.

    Args:
        form_extraction: Ground-truth extraction for one sample (normally a
            dict); it is deep-copied and never mutated.
        generic_addresses: Optional collection of lower-case address terms to
            treat as "not a name"; defaults to GENERIC_ADDRESSES.

    Returns:
        The adjusted deep copy of *form_extraction*.
    """
    terms = GENERIC_ADDRESSES if generic_addresses is None else generic_addresses
    form = copy.deepcopy(form_extraction)

    # Locations where the patient's name may appear across extraction schemas.
    name_paths = (
        ("patient", "name"),
        ("patient", "patient_name"),
        ("patient_details", "name"),
        ("mother_assessment", "patient_name"),
        ("visit_info", "patient_name"),
    )
    for *parents, field in name_paths:
        container = form
        for key in parents:
            # A missing or non-dict level ends the walk for this path.
            container = container.get(key) if isinstance(container, dict) else None
        if isinstance(container, dict) and _is_generic_address(container.get(field), terms):
            container[field] = None

    return form


def _load_jsonl(path):
    """Load a JSON-Lines file into a list of objects, skipping blank lines.

    If a non-blank line is not valid JSON, prints an ABORT message naming the
    file and 1-based line number and exits with status 1.
    """
    samples = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as err:
                print(f"ABORT: {path} line {lineno} is not valid JSON: {err}")
                sys.exit(1)
    return samples


def main():
    """Write OUTPUT_FILE: every sample of INPUT_FILE plus unlabeled variants.

    A random ``--ratio`` fraction of the samples is selected; the selection is
    reproducible via ``--seed``. Each selected sample whose transcript contains
    speaker labels gets a copy right after it. The copy has the labels stripped,
    its id suffixed with ``_unlabeled``, and generic-address patient names nulled.

    Selected samples with no labels are skipped. Their "variant" would repeat the
    original transcript verbatim, possibly with a conflicting (name-nulled) target.
    """
    parser = argparse.ArgumentParser(description="Augment training data with unlabeled variants")
    parser.add_argument("--ratio", type=float, default=0.3,
                        help="Fraction of samples to create unlabeled variants for (default: 0.3)")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # random.sample() raises on a negative or larger-than-population k, and NaN
    # fails this comparison too, so reject bad ratios with a usage error.
    if not 0.0 <= args.ratio <= 1.0:
        parser.error(f"--ratio must be between 0 and 1, got {args.ratio}")

    random.seed(args.seed)

    if not os.path.exists(INPUT_FILE):
        print(f"ABORT: {INPUT_FILE} not found. Run generate_training_data.py first.")
        sys.exit(1)

    raw_samples = _load_jsonl(INPUT_FILE)
    print(f"Loaded {len(raw_samples)} raw samples")

    # Select samples to augment
    n_selected = int(len(raw_samples) * args.ratio)
    augment_indices = set(random.sample(range(len(raw_samples)), n_selected))
    print(f"Selected {n_selected} samples for unlabeled variants ({args.ratio:.0%} of total)")

    augmented = []
    n_added = 0
    n_skipped = 0
    example = None  # (index, before, after) of the first variant actually added

    for i, sample in enumerate(raw_samples):
        # Always include original
        augmented.append(sample)
        if i not in augment_indices:
            continue

        original_transcript = sample["transcript"]
        if SPEAKER_LABELS.search(original_transcript) is None:
            n_skipped += 1
            continue

        variant = copy.deepcopy(sample)
        variant["transcript"] = strip_speaker_labels(original_transcript)
        variant["id"] = f"{variant['id']}_unlabeled"
        variant["form_extraction"] = fix_ground_truth(variant["form_extraction"])
        augmented.append(variant)
        n_added += 1
        if example is None:
            example = (i, original_transcript, variant["transcript"])

    # Write augmented dataset
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        for sample in augmented:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")

    print("\nResults:")
    print(f"  Original samples: {len(raw_samples)}")
    print(f"  Unlabeled variants added: {n_added}")
    print(f"  Skipped (no speaker labels found): {n_skipped}")
    print(f"  Total samples: {len(augmented)}")
    print(f"  Written to: {OUTPUT_FILE}")
    print(f"\nNext: Run prepare_training.py with --input {OUTPUT_FILE}")

    # Show a before/after example
    if example is not None:
        idx, orig, stripped = example
        print(f"\n--- Example (sample {idx}) ---")
        print(f"BEFORE (first 200 chars):\n  {orig[:200]}")
        print(f"\nAFTER (first 200 chars):\n  {stripped[:200]}")


# Run the augmentation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()