File size: 16,871 Bytes
0422215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
"""
Step 6: Audio sync β€” match synthesised segment durations to original timestamps.

For each segment:
  - Too long  β†’ speed up using ffmpeg atempo filter
  - Too short β†’ pad with silence at the end
Then stitch all segments into a single final audio track.
"""
import array
import math
import os
import subprocess
import wave
from pathlib import Path


def _get_wav_duration(wav_path: str) -> float:
    with wave.open(wav_path, 'r') as f:
        frames = f.getnframes()
        rate = f.getframerate()
        return frames / float(rate)


def _speedup_audio(input_path: str, output_path: str, factor: float) -> None:
    """Speed up/slow down audio by factor using ffmpeg atempo (supports 0.5–100x via chaining)."""
    # atempo supports 0.5 to 2.0, chain filters for larger factors
    filters = []
    remaining = factor
    while remaining > 2.0:
        filters.append("atempo=2.0")
        remaining /= 2.0
    while remaining < 0.5:
        filters.append("atempo=0.5")
        remaining /= 0.5
    filters.append(f"atempo={remaining:.4f}")
    filter_str = ",".join(filters)

    cmd = [
        "ffmpeg", "-y", "-i", input_path,
        "-filter:a", filter_str,
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg atempo failed:\n{result.stderr}")


def _pad_silence(input_path: str, output_path: str, target_duration: float) -> None:
    """Append trailing silence so the output lasts ``target_duration`` seconds.

    The amount of padding is the difference between the target and the
    current WAV duration (never negative). ffmpeg's ``-t`` option then caps
    the output at ``target_duration``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    missing = max(0, target_duration - _get_wav_duration(input_path))

    ffmpeg_cmd = ["ffmpeg", "-y", "-i", input_path]
    ffmpeg_cmd += ["-af", f"apad=pad_dur={missing:.4f}"]
    ffmpeg_cmd += ["-t", str(target_duration), output_path]

    proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
    if proc.returncode == 0:
        return
    raise RuntimeError(f"ffmpeg apad failed:\n{proc.stderr}")


def _trim_audio(input_path: str, output_path: str, duration: float) -> None:
    """Cut audio down to ``duration`` seconds.

    ffmpeg writes to a scratch file next to ``output_path`` which is then
    moved over ``output_path``, so input and output may be the same file.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    scratch = output_path + ".trim.wav"
    proc = subprocess.run(
        ["ffmpeg", "-y", "-i", input_path, "-t", str(duration), scratch],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg trim failed:\n{proc.stderr}")
    os.replace(scratch, output_path)


def _detect_pauses(words: list[dict], min_pause: float = 0.15) -> list[dict]:
    """Find gaps between consecutive words that exceed min_pause seconds.

    Returns list of {after_word_idx, position, duration} sorted by position.
    """
    pauses = []
    for i in range(len(words) - 1):
        gap = words[i + 1]["start"] - words[i]["end"]
        if gap >= min_pause:
            pauses.append({
                "after_word_idx": i,
                "position": words[i]["end"],
                "duration": gap,
            })
    return pauses


def _find_tts_silences(wav_path: str, threshold_db: float = -35.0,
                       min_dur: float = 0.08) -> list[dict]:
    """Find silence regions in a TTS WAV using RMS energy.

    Returns list of {start, end, duration} for each detected silence region.
    """
    with wave.open(wav_path, "r") as f:
        n_frames = f.getnframes()
        sample_rate = f.getframerate()
        raw = f.readframes(n_frames)

    # Convert raw bytes to 16-bit signed samples
    samples = array.array("h", raw)

    window_size = int(0.02 * sample_rate)  # 20 ms windows
    hop = window_size // 2
    threshold_linear = 10 ** (threshold_db / 20.0) * 32768  # dBFS to linear amplitude

    silences: list[dict] = []
    in_silence = False
    silence_start = 0.0

    for pos in range(0, len(samples) - window_size, hop):
        chunk = samples[pos:pos + window_size]
        rms = math.sqrt(sum(s * s for s in chunk) / window_size)
        t = pos / sample_rate

        if rms < threshold_linear:
            if not in_silence:
                in_silence = True
                silence_start = t
        else:
            if in_silence:
                dur = t - silence_start
                if dur >= min_dur:
                    silences.append({"start": silence_start, "end": t, "duration": dur})
                in_silence = False

    # Close trailing silence
    if in_silence:
        t_end = len(samples) / sample_rate
        dur = t_end - silence_start
        if dur >= min_dur:
            silences.append({"start": silence_start, "end": t_end, "duration": dur})

    return silences


def _read_wav_samples(wav_path: str) -> tuple[array.array, int]:
    """Read a mono 16-bit WAV and return (samples, sample_rate)."""
    with wave.open(wav_path, "r") as f:
        sr = f.getframerate()
        raw = f.readframes(f.getnframes())
    return array.array("h", raw), sr


def _write_wav_samples(samples: array.array, sample_rate: int, output_path: str) -> None:
    """Write 16-bit mono samples to a WAV file."""
    with wave.open(output_path, "w") as f:
        f.setnchannels(1)
        f.setsampwidth(2)
        f.setframerate(sample_rate)
        f.writeframes(samples.tobytes())


def _pause_aware_sync(tts_path: str, synced_path: str, target_duration: float,
                      words: list[dict], max_speed: float,
                      max_overflow: float = 0.0) -> None:
    """Fit a TTS clip to its time window, preferring pause edits over tempo changes.

    Strategy, by case:
      * TTS longer than the target (by more than 2%):
          - if it still fits inside ``target_duration + max_overflow`` it is
            copied unchanged, whether or not silences were detected — the clip
            is never slowed down or stretched to fill the borrow budget;
          - otherwise detected silences are shrunk first, and an atempo
            speed-up (capped at ``max_speed``) is applied only if shrinking
            silences alone cannot reach the cap;
          - with no detected silence, a uniform speed-up is applied.
      * TTS shorter than the target (by more than 2%): extra silence is spread
        over the detected TTS silences when the original words also contain
        pauses, otherwise appended at the end.
      * Otherwise the clip is copied unchanged.

    Args:
        tts_path: Source TTS WAV.
        synced_path: Where the adjusted WAV is written.
        target_duration: Length in seconds of the original segment window.
        words: Word dicts with "start"/"end", used to detect original pauses.
        max_speed: Upper bound for the atempo factor; if the required factor
            is larger, the output stays longer than the cap.
        max_overflow: Extra seconds the output may exceed ``target_duration``.
            The caller borrows this from the silence that follows the
            segment, so trailing words are not cut off.
    """
    import shutil  # local import: used only by the copy-through paths

    tts_duration = _get_wav_duration(tts_path)
    original_pauses = _detect_pauses(words)
    tts_silences = _find_tts_silences(tts_path)

    total_tts_silence = sum(s["duration"] for s in tts_silences)
    hard_cap = target_duration + max_overflow
    overshoot_vs_cap = tts_duration - hard_cap

    if tts_duration > target_duration * 1.02:
        if overshoot_vs_cap <= 0:
            # Already within hard_cap once the borrow budget is counted — keep
            # the TTS as-is. Checked before anything else so the no-silence
            # path cannot compute a factor below 1 and slow the clip down.
            shutil.copy(tts_path, synced_path)
            print(f"[s5]   pause-aware: within +{max_overflow:.2f}s borrow, no compression")
        elif tts_silences and total_tts_silence > 0:
            # Remove at most 90% of the detected silence.
            removable = min(total_tts_silence * 0.9, overshoot_vs_cap)
            if removable >= overshoot_vs_cap:
                # Shrinking silences alone is enough to reach hard_cap.
                compression_ratio = 1.0 - (removable / total_tts_silence)
                _compress_silences(tts_path, synced_path, tts_silences, compression_ratio)
                print(f"[s5]   pause-aware: compressed silences (ratio {compression_ratio:.2f}, +{max_overflow:.2f}s borrow)")
            else:
                # Not enough silence: shrink it as far as allowed, then speed up the rest.
                _compress_silences(tts_path, synced_path, tts_silences, 0.1)  # keep 10%
                remaining_dur = _get_wav_duration(synced_path)
                speed_factor = remaining_dur / hard_cap if hard_cap > 0 else max_speed
                if speed_factor > max_speed:
                    print(f"[s5]   pause-aware: WARNING speed x{speed_factor:.2f} exceeds max, capping at x{max_speed} (will overflow next gap)")
                    speed_factor = max_speed
                print(f"[s5]   pause-aware: compressed silences + speedup x{speed_factor:.2f} (+{max_overflow:.2f}s borrow)")
                tmp = synced_path + ".tmp.wav"
                _speedup_audio(synced_path, tmp, speed_factor)
                os.replace(tmp, synced_path)
        else:
            # No silences detected — uniform speedup, with hard_cap as the target.
            speed_factor = tts_duration / hard_cap if hard_cap > 0 else max_speed
            if speed_factor > max_speed:
                print(f"[s5]   pause-aware: WARNING speed x{speed_factor:.2f} exceeds max, capping at x{max_speed} (will overflow next gap)")
                speed_factor = max_speed
            print(f"[s5]   pause-aware: uniform speedup x{speed_factor:.2f} (no silences, +{max_overflow:.2f}s borrow)")
            _speedup_audio(tts_path, synced_path, speed_factor)

    elif tts_duration < target_duration * 0.98:
        shortfall = target_duration - tts_duration
        if tts_silences and original_pauses:
            # Distribute padding at detected silence positions
            _distribute_padding(tts_path, synced_path, tts_silences, shortfall, target_duration)
            print(f"[s5]   pause-aware: distributed {shortfall:.2f}s padding across {len(tts_silences)} pause points")
        else:
            # No pause points — pad at end
            _pad_silence(tts_path, synced_path, target_duration)
            print(f"[s5]   pause-aware: padded {shortfall:.2f}s at end (no pause points)")
    else:
        # Within ±2% of the target: nothing to adjust.
        shutil.copy(tts_path, synced_path)


def _compress_silences(input_path: str, output_path: str,
                       silences: list[dict], keep_ratio: float) -> None:
    """Write a copy of the WAV in which each silence region is shortened.

    Every region in ``silences`` ({start, end} in seconds) is reduced to
    ``keep_ratio`` of its original length by keeping only its leading part;
    audio outside the regions is copied unchanged. Regions are processed in
    the order given.
    """
    pcm, rate = _read_wav_samples(input_path)
    squeezed = array.array("h")

    cursor = 0  # index of the first sample not yet emitted
    for region in silences:
        first = int(region["start"] * rate)
        last = int(region["end"] * rate)

        # Audio between the previous region and this one.
        squeezed += pcm[cursor:first]

        # Leading slice of the silence that survives compression.
        n_keep = int((last - first) * keep_ratio)
        if n_keep > 0:
            squeezed += pcm[first:first + n_keep]

        cursor = last

    # Whatever follows the final region.
    squeezed += pcm[cursor:]
    _write_wav_samples(squeezed, rate, output_path)


def _distribute_padding(input_path: str, output_path: str,
                        tts_silences: list[dict], shortfall: float,
                        target_duration: float) -> None:
    """Lengthen a WAV by inserting extra silence at its detected pause points.

    ``shortfall`` seconds are split evenly across the regions in
    ``tts_silences``; each share is inserted right after the end of its
    region. If ``tts_silences`` is empty, the whole shortfall is appended at
    the end instead of dividing by zero.

    Args:
        input_path: Source WAV (read via ``_read_wav_samples``).
        output_path: Destination WAV.
        tts_silences: Silence regions ({start, end, ...} in seconds).
        shortfall: Total seconds of silence to add.
        target_duration: Desired final length; the result is trimmed to it
            if it ends up more than 2% longer.
    """
    samples, sr = _read_wav_samples(input_path)
    zero = array.array("h", [0])

    out = array.array("h")
    prev_end_sample = 0

    if tts_silences:
        # Same padding for every pause point, so build the buffer once.
        pad_per_point = shortfall / len(tts_silences)
        pad_block = zero * max(0, int(pad_per_point * sr))

        for sil in tts_silences:
            sil_end = int(sil["end"] * sr)

            # Copy everything up to the end of this silence region, then extend it.
            out.extend(samples[prev_end_sample:sil_end])
            out.extend(pad_block)

            prev_end_sample = sil_end

        # Copy remaining audio
        out.extend(samples[prev_end_sample:])
    else:
        # No pause points to use: keep the audio intact and pad at the end.
        out.extend(samples)
        out.extend(zero * max(0, int(shortfall * sr)))

    _write_wav_samples(out, sr, output_path)

    # Trim back to the target if the result ended up noticeably too long.
    actual = len(out) / sr
    if actual > target_duration * 1.02:
        _trim_audio(output_path, output_path, target_duration)


def _generate_silence(output_path: str, duration: float, sample_rate: int = 16000) -> None:
    """Generate a silent WAV file of given duration."""
    num_samples = int(duration * sample_rate)
    with wave.open(output_path, "w") as f:
        f.setnchannels(1)
        f.setsampwidth(2)  # 16-bit
        f.setframerate(sample_rate)
        f.writeframes(b"\x00\x00" * num_samples)


def sync_and_stitch(
    segments: list[dict],
    output_path: str = "tmp/audio/final_audio.wav",
    synced_dir: str = "tmp/audio/tts_synced",
    max_speed: float = 1.8,
) -> str:
    """
    Sync each TTS segment to its original timestamp window and stitch into a single WAV.

    For every segment the gap since the previous audio is filled with silence,
    the TTS clip is fitted to its window (pause-aware when the segment carries
    more than one entry under "words", otherwise plain speed-up / end padding),
    and the pieces are concatenated with ffmpeg's concat demuxer.

    Args:
        segments: Non-empty list of dicts with {start, end, tts_path} and an
            optional "words" list of {start, end} dicts.
        output_path: Where to write the final stitched audio.
        synced_dir: Temp directory for per-segment synced WAVs.
        max_speed: Maximum allowed speedup factor (default 1.8x to preserve naturalness).

    Returns:
        Path to the final stitched audio WAV.

    Raises:
        ValueError: If ``segments`` is empty.
        RuntimeError: If an ffmpeg invocation fails.
    """
    import shutil  # local import: used only by the copy-through paths

    # Fail early and clearly, before touching the filesystem.
    if not segments:
        raise ValueError("sync_and_stitch requires at least one segment")

    Path(synced_dir).mkdir(parents=True, exist_ok=True)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Detect TTS sample rate from the first segment
    with wave.open(segments[0]["tts_path"], 'r') as f:
        tts_sample_rate = f.getframerate()
    print(f"[s5] TTS sample rate: {tts_sample_rate} Hz")

    concat_list_path = "tmp/concat_list.txt"
    # With custom output_path/synced_dir nothing above guarantees "tmp/" exists.
    Path(concat_list_path).parent.mkdir(parents=True, exist_ok=True)
    concat_entries = []

    # Track the real wall-clock playback cursor. When a segment overflows its
    # original window, the cursor moves past the segment's nominal end, and the
    # next inter-segment silence shrinks accordingly — overflow is absorbed by
    # the following gap instead of being trimmed off the end of the audio.
    playback_cursor = 0.0
    for i, seg in enumerate(segments):
        start = seg["start"]
        end = seg["end"]
        target_duration = end - start
        tts_path = seg["tts_path"]

        # Fill gap before this segment with silence — but only as much as the
        # cursor is actually behind. If a prior segment overflowed past `start`,
        # `gap` goes negative and we skip the silence (and start slightly late).
        gap = start - playback_cursor
        if gap > 0.01:
            sil_path = os.path.join(synced_dir, f"silence_{i:04d}.wav")
            _generate_silence(sil_path, gap, sample_rate=tts_sample_rate)
            concat_entries.append(sil_path)
            playback_cursor += gap
        elif gap < -0.05:
            print(f"[s5] Seg {i}: running {-gap:.2f}s behind original timeline (prior overflow absorbed)")

        # Borrow budget: how much we may overflow `target_duration` without
        # trimming. We can use the silence between this segment's `end` and the
        # next segment's `start`. Last segment has no follower → 0.
        if i + 1 < len(segments):
            allowed_overflow = max(segments[i + 1]["start"] - end, 0.0)
        else:
            allowed_overflow = 0.0

        tts_duration = _get_wav_duration(tts_path)
        synced_path = os.path.join(synced_dir, f"synced_{i:04d}.wav")
        hard_cap = target_duration + allowed_overflow

        words = seg.get("words")
        if words and len(words) > 1:
            print(f"[s5] Seg {i}: pause-aware sync ({tts_duration:.2f}s -> {target_duration:.2f}s, +{allowed_overflow:.2f}s borrow)")
            _pause_aware_sync(tts_path, synced_path, target_duration, words, max_speed,
                              max_overflow=allowed_overflow)
        elif tts_duration > target_duration * 1.02:
            # Speed up only as far as needed to land within hard_cap; if the
            # required factor exceeds max_speed, cap it and let it overflow —
            # the next gap will shrink to absorb it. Never trim.
            if tts_duration <= hard_cap:
                speed_factor = 1.0
            else:
                speed_factor = tts_duration / hard_cap if hard_cap > 0 else max_speed
            if speed_factor > max_speed:
                print(f"[s5] Seg {i}: WARNING speed x{speed_factor:.2f} exceeds max, capping at x{max_speed} (will overflow next gap)")
                speed_factor = max_speed
            if speed_factor > 1.001:
                print(f"[s5] Seg {i}: speeding up x{speed_factor:.2f} (+{allowed_overflow:.2f}s borrow)")
                _speedup_audio(tts_path, synced_path, speed_factor)
            else:
                shutil.copy(tts_path, synced_path)
                print(f"[s5] Seg {i}: within +{allowed_overflow:.2f}s borrow, no speedup")
        elif tts_duration < target_duration * 0.98:
            print(f"[s5] Seg {i}: padding {target_duration - tts_duration:.2f}s silence")
            _pad_silence(tts_path, synced_path, target_duration)
        else:
            # Within ±2% of the window: use the clip unchanged.
            shutil.copy(tts_path, synced_path)

        concat_entries.append(synced_path)
        # Advance by the real length of what was produced, not the nominal window.
        playback_cursor += _get_wav_duration(synced_path)

    # Write concat list for ffmpeg. Entries are single-quoted, so a literal
    # quote inside a path has to be written as '\'' to keep the list parseable.
    with open(concat_list_path, "w", encoding="utf-8") as f:
        for entry in concat_entries:
            abs_entry = os.path.abspath(entry).replace("'", "'\\''")
            f.write(f"file '{abs_entry}'\n")

    # Concatenate all segments (re-encode to normalize sample rates)
    cmd = [
        "ffmpeg", "-y",
        "-f", "concat", "-safe", "0",
        "-i", concat_list_path,
        "-ar", str(tts_sample_rate),
        "-ac", "1",
        "-acodec", "pcm_s16le",
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg concat failed:\n{result.stderr}")

    print(f"[s5] Audio sync complete β†’ {output_path} βœ“")
    return output_path