File size: 4,081 Bytes
5d81907
 
 
 
 
 
 
93f7cf1
5d81907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edf3bf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d745c3
edf3bf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d745c3
edf3bf7
 
93f7cf1
 
edf3bf7
 
 
 
 
 
 
 
2d745c3
 
edf3bf7
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Dialog mode: parse SPEAKER X: scripts into ordered turns and stitch
per-turn outputs into a single concatenated WAV.

Parsing is done by ``parse_dialog``; per-turn synthesis and stitching are
done by ``generate_dialog``, defined further down in this module.
"""
from __future__ import annotations

import asyncio
import os
import re
from dataclasses import dataclass


# Turn marker: "SPEAKER <A-D>:" at the start of a line, leading whitespace
# allowed. Group 1 captures the speaker letter. The trailing \s* also swallows
# whitespace (newlines included) between the colon and the turn text.
_SPEAKER_RE = re.compile(r"^\s*SPEAKER\s+([A-D])\s*:\s*", flags=re.M)


@dataclass(frozen=True)
class DialogTurn:
    """One parsed turn of a dialog script: who speaks, and what they say."""

    speaker: str  # single letter captured by _SPEAKER_RE: "A", "B", "C" or "D"
    text: str     # the turn's body, stripped of surrounding whitespace


class DialogParseError(ValueError):
    """A dialog script could not be split into speaker turns."""


def parse_dialog(text: str) -> list[DialogTurn]:
    """Split a ``SPEAKER X:`` script into ordered :class:`DialogTurn` objects.

    A turn's text runs from the end of its marker up to the start of the
    next marker (or the end of *text*) and is whitespace-stripped. Markers
    with no text after them are skipped. Anything before the first marker
    is ignored. Markers are case-sensitive and only speakers ``A``-``D``
    are recognised.

    Raises:
        DialogParseError: if *text* contains no markers at all, or if every
            marker is followed by empty text.
    """
    markers = list(_SPEAKER_RE.finditer(text))
    if not markers:
        raise DialogParseError(
            "Use SPEAKER A: ... / SPEAKER B: ... lines to define turns."
        )

    # Each body stops where the following marker begins; the last one at EOF.
    stops = [following.start() for following in markers[1:]]
    stops.append(len(text))

    bodies = (
        (marker.group(1), text[marker.end():stop].strip())
        for marker, stop in zip(markers, stops)
    )
    turns = [DialogTurn(speaker=who, text=said) for who, said in bodies if said]
    if not turns:
        raise DialogParseError("No non-empty speaker turns found.")
    return turns


import io as _io
import tempfile as _tempfile
from typing import Optional

import numpy as _np
import soundfile as _sf

from server.audio import AudioValidationError, validate_reference_clip, write_wav_bytes
from server.registry import Registry
from server.seed import apply_seed


# Default pause, in milliseconds, that generate_dialog appends after each turn.
SILENCE_GAP_MS: int = 250


class DialogReferenceError(ValueError):
    """A dialog turn names a speaker for which no reference clip was supplied."""


def _decode_wav_to_mono_float(wav_bytes: bytes) -> tuple[_np.ndarray, int]:
    """Decode in-memory audio into float32 samples plus its sample rate.

    Two-dimensional (multi-channel) results are down-mixed by averaging
    across the channel axis; one-dimensional results are kept as-is.
    """
    samples, rate = _sf.read(
        _io.BytesIO(wav_bytes), dtype="float32", always_2d=False
    )
    mono = samples.mean(axis=1) if samples.ndim == 2 else samples
    return mono.astype(_np.float32), int(rate)


def _save_temp_wav(data: bytes) -> str:
    """Write *data* to a fresh ``.wav`` temp file and return its path.

    The file is created with ``delete=False`` so it survives this call. The
    caller owns the returned path and is responsible for removing it. If the
    write fails, the handle is closed and the partial file is removed before
    the error propagates, so a failure leaks neither a descriptor nor a file.
    """
    tmp = _tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with tmp:  # closing also flushes the buffered write
            tmp.write(data)
    except BaseException:
        # delete=False means nothing else will ever clean this file up.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass  # best effort: surface the original error, not this one
        raise
    return tmp.name


async def generate_dialog(
    *,
    registry: Registry,
    engine_id: str,
    text: str,
    language: Optional[str],
    params: dict,
    speaker_clips: dict[str, bytes],   # letter -> raw upload bytes (already validated)
    silence_ms: int = SILENCE_GAP_MS,
    session: "object | None" = None,   # _Session from server.progress, or None
) -> tuple[bytes, int, int]:
    """Synthesize a ``SPEAKER X:`` dialog script into a single WAV.

    Each turn is generated by the engine's adapter using that speaker's
    reference clip, then decoded to mono float32. Turns are concatenated in
    script order, with ``silence_ms`` of silence appended after every turn
    (the last one included).

    Args:
        registry: Source of the adapter, via ``get_or_load(engine_id)``.
        engine_id: Engine whose adapter generates every turn.
        text: Dialog script, parsed with ``parse_dialog``.
        language: Forwarded unchanged to ``adapter.generate``.
        params: Forwarded to ``adapter.generate`` with ``"seed"`` replaced by
            the seed resolved from ``params.get("seed")``.
        speaker_clips: Speaker letter -> reference clip bytes. Only clips for
            speakers that appear in the script are written to disk.
        silence_ms: Gap appended after each turn; ``<= 0`` disables it.
        session: If given, ``await session.turn_complete(n)`` is called after
            each turn with the 1-based count of finished turns.

    Returns:
        ``(wav_bytes, sample_rate, seed)``. ``sample_rate`` is the rate the
        adapter reported for the first turn. ``seed`` is the third value
        ``adapter.generate`` returned for the last turn.

    Raises:
        DialogParseError: propagated from ``parse_dialog``.
        DialogReferenceError: a turn references a speaker with no clip.
    """
    turns = parse_dialog(text)

    # Verify every referenced speaker has a clip.
    referenced = {t.speaker for t in turns}
    missing = referenced - set(speaker_clips.keys())
    if missing:
        raise DialogReferenceError(
            f"missing reference for speaker {sorted(missing)[0]}"
        )

    # The adapter expects a file path, so each needed clip is spilled to a
    # temp file. Those are created with delete=False, so they must be removed
    # on every exit path (success, load failure, generation failure) or they
    # accumulate on disk.
    paths: dict[str, str] = {}
    try:
        for letter in referenced:
            paths[letter] = _save_temp_wav(speaker_clips[letter])

        adapter = await registry.get_or_load(engine_id)

        # Resolve one seed for the whole dialog.
        seed_used = apply_seed(params.get("seed"))
        params_for_call = {**params, "seed": seed_used}

        sr_out: int | None = None
        gap: _np.ndarray | None = None
        adapter_seed_used: int = seed_used
        chunks: list[_np.ndarray] = []
        for done, turn in enumerate(turns, start=1):
            # Re-apply the same seed before each turn so the run is reproducible.
            apply_seed(seed_used)
            wav_bytes, sr, adapter_seed_used = await asyncio.to_thread(
                adapter.generate,
                turn.text, paths[turn.speaker], language, params_for_call,
            )
            arr, _ = _decode_wav_to_mono_float(wav_bytes)
            chunks.append(arr)
            if sr_out is None:
                # The first turn fixes the output rate.
                # NOTE(review): later turns are not resampled; this assumes
                # the adapter always returns the same rate — confirm.
                sr_out = sr
                if silence_ms > 0:
                    # Size the gap at the rate the WAV is written with, so it
                    # really lasts silence_ms.
                    gap = _np.zeros(
                        int(silence_ms * sr_out / 1000), dtype=_np.float32
                    )
            if gap is not None:
                chunks.append(gap)
            if session is not None:
                await session.turn_complete(done)

        # parse_dialog never returns an empty list, so the loop ran at least
        # once: sr_out is set and chunks is non-empty.
        assert sr_out is not None
        out = write_wav_bytes(_np.concatenate(chunks), sr_out)
        return out, sr_out, adapter_seed_used
    finally:
        for path in paths.values():
            try:
                os.unlink(path)
            except OSError:
                pass  # best-effort cleanup must not mask the real result/error