techfreakworm commited on
Commit
93f7cf1
·
unverified ·
1 Parent(s): 2d745c3

fix(progress): run sync adapter.generate in a thread so SSE events stream during generation

Browse files
Files changed (2) hide show
  1. server/dialog.py +3 -1
  2. server/main.py +3 -2
server/dialog.py CHANGED
@@ -5,6 +5,7 @@ Generator is in this same file but added in Task 12.
5
  """
6
  from __future__ import annotations
7
 
 
8
  import re
9
  from dataclasses import dataclass
10
 
@@ -112,7 +113,8 @@ async def generate_dialog(
112
  for i, turn in enumerate(turns):
113
  # Re-apply the same seed before each turn so the run is reproducible.
114
  apply_seed(seed_used)
115
- wav_bytes, sr, adapter_seed_used = adapter.generate(
 
116
  turn.text, paths[turn.speaker], language, params_for_call,
117
  )
118
  arr, _ = _decode_wav_to_mono_float(wav_bytes)
 
5
  """
6
  from __future__ import annotations
7
 
8
+ import asyncio
9
  import re
10
  from dataclasses import dataclass
11
 
 
113
  for i, turn in enumerate(turns):
114
  # Re-apply the same seed before each turn so the run is reproducible.
115
  apply_seed(seed_used)
116
+ wav_bytes, sr, adapter_seed_used = await asyncio.to_thread(
117
+ adapter.generate,
118
  turn.text, paths[turn.speaker], language, params_for_call,
119
  )
120
  arr, _ = _decode_wav_to_mono_float(wav_bytes)
server/main.py CHANGED
@@ -1,6 +1,7 @@
1
  """FastAPI application factory."""
2
  from __future__ import annotations
3
 
 
4
  import json
5
  import os
6
  import tempfile
@@ -172,8 +173,8 @@ def build_app() -> FastAPI:
172
  bus = get_bus()
173
  try:
174
  async with bus.session("single", total_turns=1) as sess:
175
- wav_bytes, _sr, seed_used = gen_fn(
176
- text, ref_path, language, json.loads(params or "{}")
177
  )
178
  sess.set_seed(seed_used)
179
  except Exception as exc:
 
1
  """FastAPI application factory."""
2
  from __future__ import annotations
3
 
4
+ import asyncio
5
  import json
6
  import os
7
  import tempfile
 
173
  bus = get_bus()
174
  try:
175
  async with bus.session("single", total_turns=1) as sess:
176
+ wav_bytes, _sr, seed_used = await asyncio.to_thread(
177
+ gen_fn, text, ref_path, language, json.loads(params or "{}"),
178
  )
179
  sess.set_seed(seed_used)
180
  except Exception as exc: