techfreakworm committed on
Commit
b066638
·
unverified ·
1 Parent(s): 451dece

feat(api): adapter generate returns seed_used; expose X-Seed-Used header

Browse files
server/main.py CHANGED
@@ -151,13 +151,19 @@ def build_app() -> FastAPI:
151
 
152
  gen_fn = decorate(adapter.generate)
153
  try:
154
- wav_bytes, _sr = gen_fn(text, ref_path, language, json.loads(params or "{}"))
 
 
155
  except Exception as exc:
156
  return JSONResponse(
157
  status_code=500,
158
  content={"error": {"code": "generation_failed", "message": str(exc)}},
159
  )
160
- return Response(content=wav_bytes, media_type="audio/wav")
 
 
 
 
161
 
162
  @app.exception_handler(HTTPException)
163
  async def _http_exc(request, exc: HTTPException):
 
151
 
152
  gen_fn = decorate(adapter.generate)
153
  try:
154
+ wav_bytes, _sr, seed_used = gen_fn(
155
+ text, ref_path, language, json.loads(params or "{}")
156
+ )
157
  except Exception as exc:
158
  return JSONResponse(
159
  status_code=500,
160
  content={"error": {"code": "generation_failed", "message": str(exc)}},
161
  )
162
+ return Response(
163
+ content=wav_bytes,
164
+ media_type="audio/wav",
165
+ headers={"X-Seed-Used": str(seed_used), "Access-Control-Expose-Headers": "X-Seed-Used"},
166
+ )
167
 
168
  @app.exception_handler(HTTPException)
169
  async def _http_exc(request, exc: HTTPException):
server/models/base.py CHANGED
@@ -29,7 +29,7 @@ class ModelAdapter(Protocol):
29
  reference_wav_path: str | None,
30
  language: str | None,
31
  params: dict[str, Any],
32
- ) -> tuple[bytes, int]: ...
33
 
34
 
35
  def is_valid_adapter(cls: type) -> bool:
 
29
  reference_wav_path: str | None,
30
  language: str | None,
31
  params: dict[str, Any],
32
+ ) -> tuple[bytes, int, int]: ... # (wav_bytes, sample_rate, seed_used)
33
 
34
 
35
  def is_valid_adapter(cls: type) -> bool:
server/models/chatterbox_en.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, ClassVar
7
  import soundfile as sf
8
 
9
  from server.schemas import Lang, ParamSpec
 
10
 
11
 
12
  class Adapter:
@@ -52,9 +53,10 @@ class Adapter:
52
  reference_wav_path: str | None,
53
  language: str | None,
54
  params: dict[str, Any],
55
- ) -> tuple[bytes, int]:
56
  if self._model is None:
57
  raise RuntimeError("model not loaded")
 
58
  wav = self._model.generate(
59
  text,
60
  audio_prompt_path=reference_wav_path,
@@ -73,4 +75,4 @@ class Adapter:
73
  sr = getattr(self._model, "sr", 24000)
74
  buf = io.BytesIO()
75
  sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
76
- return buf.getvalue(), sr
 
7
  import soundfile as sf
8
 
9
  from server.schemas import Lang, ParamSpec
10
+ from server.seed import apply_seed
11
 
12
 
13
  class Adapter:
 
53
  reference_wav_path: str | None,
54
  language: str | None,
55
  params: dict[str, Any],
56
+ ) -> tuple[bytes, int, int]:
57
  if self._model is None:
58
  raise RuntimeError("model not loaded")
59
+ seed_used = apply_seed(params.get("seed"))
60
  wav = self._model.generate(
61
  text,
62
  audio_prompt_path=reference_wav_path,
 
75
  sr = getattr(self._model, "sr", 24000)
76
  buf = io.BytesIO()
77
  sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
78
+ return buf.getvalue(), sr, seed_used
server/models/chatterbox_mtl.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, ClassVar
7
  import soundfile as sf
8
 
9
  from server.schemas import Lang, ParamSpec
 
10
 
11
 
12
  _MTL_LANGS: list[Lang] = [
@@ -70,11 +71,12 @@ class Adapter:
70
  reference_wav_path: str | None,
71
  language: str | None,
72
  params: dict[str, Any],
73
- ) -> tuple[bytes, int]:
74
  if self._model is None:
75
  raise RuntimeError("model not loaded")
76
  if not language:
77
  raise ValueError("language is required for chatterbox-mtl")
 
78
  wav = self._model.generate(
79
  text,
80
  language_id=language,
@@ -93,4 +95,4 @@ class Adapter:
93
  sr = getattr(self._model, "sr", 24000)
94
  buf = io.BytesIO()
95
  sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
96
- return buf.getvalue(), sr
 
7
  import soundfile as sf
8
 
9
  from server.schemas import Lang, ParamSpec
10
+ from server.seed import apply_seed
11
 
12
 
13
  _MTL_LANGS: list[Lang] = [
 
71
  reference_wav_path: str | None,
72
  language: str | None,
73
  params: dict[str, Any],
74
+ ) -> tuple[bytes, int, int]:
75
  if self._model is None:
76
  raise RuntimeError("model not loaded")
77
  if not language:
78
  raise ValueError("language is required for chatterbox-mtl")
79
+ seed_used = apply_seed(params.get("seed"))
80
  wav = self._model.generate(
81
  text,
82
  language_id=language,
 
95
  sr = getattr(self._model, "sr", 24000)
96
  buf = io.BytesIO()
97
  sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
98
+ return buf.getvalue(), sr, seed_used
server/models/chatterbox_turbo.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, ClassVar
7
  import soundfile as sf
8
 
9
  from server.schemas import Lang, ParamSpec
 
10
 
11
 
12
  class Adapter:
@@ -54,9 +55,10 @@ class Adapter:
54
  reference_wav_path: str | None,
55
  language: str | None,
56
  params: dict[str, Any],
57
- ) -> tuple[bytes, int]:
58
  if self._model is None:
59
  raise RuntimeError("model not loaded")
 
60
  wav = self._model.generate(
61
  text,
62
  audio_prompt_path=reference_wav_path,
@@ -74,4 +76,4 @@ class Adapter:
74
  sr = getattr(self._model, "sr", 24000)
75
  buf = io.BytesIO()
76
  sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
77
- return buf.getvalue(), sr
 
7
  import soundfile as sf
8
 
9
  from server.schemas import Lang, ParamSpec
10
+ from server.seed import apply_seed
11
 
12
 
13
  class Adapter:
 
55
  reference_wav_path: str | None,
56
  language: str | None,
57
  params: dict[str, Any],
58
+ ) -> tuple[bytes, int, int]:
59
  if self._model is None:
60
  raise RuntimeError("model not loaded")
61
+ seed_used = apply_seed(params.get("seed"))
62
  wav = self._model.generate(
63
  text,
64
  audio_prompt_path=reference_wav_path,
 
76
  sr = getattr(self._model, "sr", 24000)
77
  buf = io.BytesIO()
78
  sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
79
+ return buf.getvalue(), sr, seed_used
tests/conftest.py CHANGED
@@ -37,7 +37,10 @@ class FakeAdapter:
37
  self.loaded = False
38
 
39
  def generate(self, text, reference_wav_path, language, params):
40
- return (b"FAKEWAV", 24000)
 
 
 
41
 
42
 
43
  class FakeAdapterB(FakeAdapter):
 
37
  self.loaded = False
38
 
39
  def generate(self, text, reference_wav_path, language, params):
40
+ # FakeAdapter never actually applies a seed; report the input or 0.
41
+ seed_in = params.get("seed", 0) if isinstance(params, dict) else 0
42
+ seed_used = 0 if seed_in is None or seed_in < 0 else int(seed_in)
43
+ return (b"FAKEWAV", 24000, seed_used)
44
 
45
 
46
  class FakeAdapterB(FakeAdapter):
tests/test_main_generate.py CHANGED
@@ -25,6 +25,7 @@ async def test_generate_returns_wav_bytes(monkeypatch, fake_classes):
25
  assert r.status_code == 200
26
  assert r.headers["content-type"].startswith("audio/wav")
27
  assert r.content == b"FAKEWAV"
 
28
 
29
 
30
  async def test_generate_unknown_model_404(monkeypatch, fake_classes):
 
25
  assert r.status_code == 200
26
  assert r.headers["content-type"].startswith("audio/wav")
27
  assert r.content == b"FAKEWAV"
28
+ assert r.headers["x-seed-used"] == "0"
29
 
30
 
31
  async def test_generate_unknown_model_404(monkeypatch, fake_classes):