techfreakworm commited on
Commit
edf3bf7
·
unverified ·
1 Parent(s): 5d81907

feat(dialog): /api/generate/dialog endpoint + per-turn dispatcher with seed reuse

Browse files
Files changed (3) hide show
  1. server/dialog.py +87 -0
  2. server/main.py +85 -0
  3. tests/test_dialog_endpoint.py +113 -0
server/dialog.py CHANGED
@@ -38,3 +38,90 @@ def parse_dialog(text: str) -> list[DialogTurn]:
38
  if not turns:
39
  raise DialogParseError("No non-empty speaker turns found.")
40
  return turns
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  if not turns:
39
  raise DialogParseError("No non-empty speaker turns found.")
40
  return turns
41
+
42
+
43
+ import io as _io
44
+ import tempfile as _tempfile
45
+ from typing import Optional
46
+
47
+ import numpy as _np
48
+ import soundfile as _sf
49
+
50
+ from server.audio import AudioValidationError, validate_reference_clip, write_wav_bytes
51
+ from server.registry import Registry
52
+ from server.seed import apply_seed
53
+
54
+
55
+ SILENCE_GAP_MS = 250
56
+
57
+
58
+ class DialogReferenceError(ValueError):
59
+ """Raised when a turn references a speaker without an uploaded clip."""
60
+
61
+
62
+ def _decode_wav_to_mono_float(wav_bytes: bytes) -> tuple[_np.ndarray, int]:
63
+ arr, sr = _sf.read(_io.BytesIO(wav_bytes), dtype="float32", always_2d=False)
64
+ if arr.ndim == 2:
65
+ arr = arr.mean(axis=1)
66
+ return arr.astype(_np.float32), int(sr)
67
+
68
+
69
+ def _save_temp_wav(data: bytes) -> str:
70
+ tmp = _tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
71
+ tmp.write(data)
72
+ tmp.flush()
73
+ tmp.close()
74
+ return tmp.name
75
+
76
+
77
+ async def generate_dialog(
78
+ *,
79
+ registry: Registry,
80
+ engine_id: str,
81
+ text: str,
82
+ language: Optional[str],
83
+ params: dict,
84
+ speaker_clips: dict[str, bytes], # letter -> raw upload bytes (already validated)
85
+ silence_ms: int = SILENCE_GAP_MS,
86
+ ) -> tuple[bytes, int, int]:
87
+ turns = parse_dialog(text)
88
+
89
+ # Verify every referenced speaker has a clip.
90
+ referenced = {t.speaker for t in turns}
91
+ missing = referenced - set(speaker_clips.keys())
92
+ if missing:
93
+ raise DialogReferenceError(
94
+ f"missing reference for speaker {sorted(missing)[0]}"
95
+ )
96
+
97
+ # Persist each clip to a tempfile path once (the adapter expects a path).
98
+ paths: dict[str, str] = {
99
+ letter: _save_temp_wav(blob) for letter, blob in speaker_clips.items()
100
+ }
101
+
102
+ adapter = await registry.get_or_load(engine_id)
103
+
104
+ # Resolve and re-apply one seed for the whole dialog.
105
+ seed_used = apply_seed(params.get("seed"))
106
+ params_for_call = {**params, "seed": seed_used}
107
+
108
+ sr_out: int | None = None
109
+ adapter_seed_used: int = seed_used
110
+ chunks: list[_np.ndarray] = []
111
+ for turn in turns:
112
+ # Re-apply the same seed before each turn so the run is reproducible.
113
+ apply_seed(seed_used)
114
+ wav_bytes, sr, adapter_seed_used = adapter.generate(
115
+ turn.text, paths[turn.speaker], language, params_for_call,
116
+ )
117
+ arr, _ = _decode_wav_to_mono_float(wav_bytes)
118
+ chunks.append(arr)
119
+ if sr_out is None:
120
+ sr_out = sr
121
+ if silence_ms > 0:
122
+ chunks.append(_np.zeros(int(silence_ms * sr / 1000), dtype=_np.float32))
123
+
124
+ assert sr_out is not None
125
+ full = _np.concatenate(chunks) if chunks else _np.zeros(0, dtype=_np.float32)
126
+ out = write_wav_bytes(full, sr_out)
127
+ return out, sr_out, adapter_seed_used
server/main.py CHANGED
@@ -16,6 +16,11 @@ from sse_starlette.sse import EventSourceResponse
16
 
17
  from server.audio import AudioValidationError, validate_reference_clip
18
  from server.device import select_device
 
 
 
 
 
19
  from server.registry import Registry
20
  from server.zerogpu import decorate
21
 
@@ -165,6 +170,86 @@ def build_app() -> FastAPI:
165
  headers={"X-Seed-Used": str(seed_used), "Access-Control-Expose-Headers": "X-Seed-Used"},
166
  )
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  @app.exception_handler(HTTPException)
169
  async def _http_exc(request, exc: HTTPException):
170
  if isinstance(exc.detail, dict) and "error" in exc.detail:
 
16
 
17
  from server.audio import AudioValidationError, validate_reference_clip
18
  from server.device import select_device
19
+ from server.dialog import (
20
+ DialogParseError,
21
+ DialogReferenceError,
22
+ generate_dialog,
23
+ )
24
  from server.registry import Registry
25
  from server.zerogpu import decorate
26
 
 
170
  headers={"X-Seed-Used": str(seed_used), "Access-Control-Expose-Headers": "X-Seed-Used"},
171
  )
172
 
173
+ @app.post("/api/generate/dialog")
174
+ async def generate_dialog_route(
175
+ text: str = Form(...),
176
+ engine_id: str = Form(...),
177
+ params: str = Form("{}"),
178
+ language: str | None = Form(None),
179
+ reference_wav_a: UploadFile | None = File(None),
180
+ reference_wav_b: UploadFile | None = File(None),
181
+ reference_wav_c: UploadFile | None = File(None),
182
+ reference_wav_d: UploadFile | None = File(None),
183
+ ):
184
+ speaker_clips: dict[str, bytes] = {}
185
+ upload_map = {
186
+ "A": reference_wav_a,
187
+ "B": reference_wav_b,
188
+ "C": reference_wav_c,
189
+ "D": reference_wav_d,
190
+ }
191
+ for letter, upload in upload_map.items():
192
+ if upload is None:
193
+ continue
194
+ data = await upload.read()
195
+ try:
196
+ validate_reference_clip(data)
197
+ except AudioValidationError as exc:
198
+ return JSONResponse(
199
+ status_code=400,
200
+ content={
201
+ "error": {
202
+ "code": "reference_invalid",
203
+ "message": f"speaker {letter}: {exc}",
204
+ }
205
+ },
206
+ )
207
+ speaker_clips[letter] = data
208
+
209
+ try:
210
+ wav_bytes, _sr, seed_used = await generate_dialog(
211
+ registry=app.state.registry,
212
+ engine_id=engine_id,
213
+ text=text,
214
+ language=language,
215
+ params=json.loads(params or "{}"),
216
+ speaker_clips=speaker_clips,
217
+ )
218
+ except KeyError:
219
+ raise HTTPException(
220
+ status_code=404,
221
+ detail={"error": {"code": "model_not_found", "message": engine_id}},
222
+ )
223
+ except DialogParseError as exc:
224
+ return JSONResponse(
225
+ status_code=400,
226
+ content={
227
+ "error": {"code": "dialog_format_invalid", "message": str(exc)}
228
+ },
229
+ )
230
+ except DialogReferenceError as exc:
231
+ return JSONResponse(
232
+ status_code=400,
233
+ content={
234
+ "error": {"code": "dialog_missing_reference", "message": str(exc)}
235
+ },
236
+ )
237
+ except Exception as exc:
238
+ return JSONResponse(
239
+ status_code=500,
240
+ content={
241
+ "error": {"code": "generation_failed", "message": str(exc)}
242
+ },
243
+ )
244
+ return Response(
245
+ content=wav_bytes,
246
+ media_type="audio/wav",
247
+ headers={
248
+ "X-Seed-Used": str(seed_used),
249
+ "Access-Control-Expose-Headers": "X-Seed-Used",
250
+ },
251
+ )
252
+
253
  @app.exception_handler(HTTPException)
254
  async def _http_exc(request, exc: HTTPException):
255
  if isinstance(exc.detail, dict) and "error" in exc.detail:
tests/test_dialog_endpoint.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ import httpx
4
+ import numpy as np
5
+ import pytest
6
+ import soundfile as sf
7
+
8
+ from server.main import build_app
9
+
10
+
11
+ pytestmark = pytest.mark.asyncio
12
+
13
+
14
+ def _silent_wav(seconds: float = 1.0, sr: int = 24000) -> bytes:
15
+ samples = np.zeros(int(seconds * sr), dtype=np.float32)
16
+ buf = io.BytesIO()
17
+ sf.write(buf, samples, sr, format="WAV", subtype="PCM_16")
18
+ return buf.getvalue()
19
+
20
+
21
+ async def test_dialog_generates_concatenated_wav(monkeypatch, fake_classes):
22
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
23
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
24
+ # Have FakeAdapter emit a real silent WAV so the dialog generator can decode it.
25
+ monkeypatch.setattr(
26
+ fake_classes["fake"],
27
+ "generate",
28
+ lambda self, text, ref, lang, p: (_silent_wav(0.2), 24000, 0),
29
+ )
30
+ app = build_app()
31
+ from tests.conftest import lifespan_ctx
32
+ transport = httpx.ASGITransport(app=app)
33
+ async with lifespan_ctx(app), httpx.AsyncClient(transport=transport, base_url="http://t") as c:
34
+ files = {
35
+ "reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav"),
36
+ "reference_wav_b": ("b.wav", _silent_wav(1.0), "audio/wav"),
37
+ }
38
+ r = await c.post(
39
+ "/api/generate/dialog",
40
+ data={
41
+ "text": "SPEAKER A: hi\nSPEAKER B: hello",
42
+ "engine_id": "fake",
43
+ "params": "{}",
44
+ },
45
+ files=files,
46
+ )
47
+ assert r.status_code == 200
48
+ assert r.headers["content-type"].startswith("audio/wav")
49
+ assert r.content[:4] == b"RIFF"
50
+ assert r.headers["x-seed-used"] == "0"
51
+
52
+
53
+ async def test_dialog_format_invalid(monkeypatch, fake_classes):
54
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
55
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
56
+ app = build_app()
57
+ from tests.conftest import lifespan_ctx
58
+ transport = httpx.ASGITransport(app=app)
59
+ async with lifespan_ctx(app), httpx.AsyncClient(transport=transport, base_url="http://t") as c:
60
+ r = await c.post(
61
+ "/api/generate/dialog",
62
+ data={"text": "no speaker tags", "engine_id": "fake", "params": "{}"},
63
+ files={
64
+ "reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav"),
65
+ },
66
+ )
67
+ assert r.status_code == 400
68
+ assert r.json()["error"]["code"] == "dialog_format_invalid"
69
+
70
+
71
+ async def test_dialog_missing_reference(monkeypatch, fake_classes):
72
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
73
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
74
+ monkeypatch.setattr(
75
+ fake_classes["fake"],
76
+ "generate",
77
+ lambda self, text, ref, lang, p: (_silent_wav(0.2), 24000, 0),
78
+ )
79
+ app = build_app()
80
+ from tests.conftest import lifespan_ctx
81
+ transport = httpx.ASGITransport(app=app)
82
+ async with lifespan_ctx(app), httpx.AsyncClient(transport=transport, base_url="http://t") as c:
83
+ r = await c.post(
84
+ "/api/generate/dialog",
85
+ data={
86
+ "text": "SPEAKER A: hi\nSPEAKER B: hello",
87
+ "engine_id": "fake",
88
+ "params": "{}",
89
+ },
90
+ files={"reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav")},
91
+ )
92
+ assert r.status_code == 400
93
+ assert r.json()["error"]["code"] == "dialog_missing_reference"
94
+
95
+
96
+ async def test_dialog_unknown_engine_404(monkeypatch, fake_classes):
97
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
98
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
99
+ app = build_app()
100
+ from tests.conftest import lifespan_ctx
101
+ transport = httpx.ASGITransport(app=app)
102
+ async with lifespan_ctx(app), httpx.AsyncClient(transport=transport, base_url="http://t") as c:
103
+ r = await c.post(
104
+ "/api/generate/dialog",
105
+ data={
106
+ "text": "SPEAKER A: hi",
107
+ "engine_id": "nope",
108
+ "params": "{}",
109
+ },
110
+ files={"reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav")},
111
+ )
112
+ assert r.status_code == 404
113
+ assert r.json()["error"]["code"] == "model_not_found"