fix(progress): run sync adapter.generate in a thread so SSE events stream during generation
Browse files- server/dialog.py +3 -1
- server/main.py +3 -2
server/dialog.py
CHANGED
|
@@ -5,6 +5,7 @@ Generator is in this same file but added in Task 12.
|
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
| 7 |
|
|
|
|
| 8 |
import re
|
| 9 |
from dataclasses import dataclass
|
| 10 |
|
|
@@ -112,7 +113,8 @@ async def generate_dialog(
|
|
| 112 |
for i, turn in enumerate(turns):
|
| 113 |
# Re-apply the same seed before each turn so the run is reproducible.
|
| 114 |
apply_seed(seed_used)
|
| 115 |
-
wav_bytes, sr, adapter_seed_used =
|
|
|
|
| 116 |
turn.text, paths[turn.speaker], language, params_for_call,
|
| 117 |
)
|
| 118 |
arr, _ = _decode_wav_to_mono_float(wav_bytes)
|
|
|
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
+
import asyncio
|
| 9 |
import re
|
| 10 |
from dataclasses import dataclass
|
| 11 |
|
|
|
|
| 113 |
for i, turn in enumerate(turns):
|
| 114 |
# Re-apply the same seed before each turn so the run is reproducible.
|
| 115 |
apply_seed(seed_used)
|
| 116 |
+
wav_bytes, sr, adapter_seed_used = await asyncio.to_thread(
|
| 117 |
+
adapter.generate,
|
| 118 |
turn.text, paths[turn.speaker], language, params_for_call,
|
| 119 |
)
|
| 120 |
arr, _ = _decode_wav_to_mono_float(wav_bytes)
|
server/main.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""FastAPI application factory."""
|
| 2 |
from __future__ import annotations
|
| 3 |
|
|
|
|
| 4 |
import json
|
| 5 |
import os
|
| 6 |
import tempfile
|
|
@@ -172,8 +173,8 @@ def build_app() -> FastAPI:
|
|
| 172 |
bus = get_bus()
|
| 173 |
try:
|
| 174 |
async with bus.session("single", total_turns=1) as sess:
|
| 175 |
-
wav_bytes, _sr, seed_used =
|
| 176 |
-
text, ref_path, language, json.loads(params or "{}")
|
| 177 |
)
|
| 178 |
sess.set_seed(seed_used)
|
| 179 |
except Exception as exc:
|
|
|
|
| 1 |
"""FastAPI application factory."""
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
+
import asyncio
|
| 5 |
import json
|
| 6 |
import os
|
| 7 |
import tempfile
|
|
|
|
| 173 |
bus = get_bus()
|
| 174 |
try:
|
| 175 |
async with bus.session("single", total_turns=1) as sess:
|
| 176 |
+
wav_bytes, _sr, seed_used = await asyncio.to_thread(
|
| 177 |
+
gen_fn, text, ref_path, language, json.loads(params or "{}"),
|
| 178 |
)
|
| 179 |
sess.set_seed(seed_used)
|
| 180 |
except Exception as exc:
|