feat(api): adapter generate returns seed_used; expose X-Seed-Used header
Browse files
- server/main.py +8 -2
- server/models/base.py +1 -1
- server/models/chatterbox_en.py +4 -2
- server/models/chatterbox_mtl.py +4 -2
- server/models/chatterbox_turbo.py +4 -2
- tests/conftest.py +4 -1
- tests/test_main_generate.py +1 -0
server/main.py
CHANGED
|
@@ -151,13 +151,19 @@ def build_app() -> FastAPI:
|
|
| 151 |
|
| 152 |
gen_fn = decorate(adapter.generate)
|
| 153 |
try:
|
| 154 |
-
wav_bytes, _sr = gen_fn(
|
|
|
|
|
|
|
| 155 |
except Exception as exc:
|
| 156 |
return JSONResponse(
|
| 157 |
status_code=500,
|
| 158 |
content={"error": {"code": "generation_failed", "message": str(exc)}},
|
| 159 |
)
|
| 160 |
-
return Response(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
@app.exception_handler(HTTPException)
|
| 163 |
async def _http_exc(request, exc: HTTPException):
|
|
|
|
| 151 |
|
| 152 |
gen_fn = decorate(adapter.generate)
|
| 153 |
try:
|
| 154 |
+
wav_bytes, _sr, seed_used = gen_fn(
|
| 155 |
+
text, ref_path, language, json.loads(params or "{}")
|
| 156 |
+
)
|
| 157 |
except Exception as exc:
|
| 158 |
return JSONResponse(
|
| 159 |
status_code=500,
|
| 160 |
content={"error": {"code": "generation_failed", "message": str(exc)}},
|
| 161 |
)
|
| 162 |
+
return Response(
|
| 163 |
+
content=wav_bytes,
|
| 164 |
+
media_type="audio/wav",
|
| 165 |
+
headers={"X-Seed-Used": str(seed_used), "Access-Control-Expose-Headers": "X-Seed-Used"},
|
| 166 |
+
)
|
| 167 |
|
| 168 |
@app.exception_handler(HTTPException)
|
| 169 |
async def _http_exc(request, exc: HTTPException):
|
server/models/base.py
CHANGED
|
@@ -29,7 +29,7 @@ class ModelAdapter(Protocol):
|
|
| 29 |
reference_wav_path: str | None,
|
| 30 |
language: str | None,
|
| 31 |
params: dict[str, Any],
|
| 32 |
-
) -> tuple[bytes, int]: ...
|
| 33 |
|
| 34 |
|
| 35 |
def is_valid_adapter(cls: type) -> bool:
|
|
|
|
| 29 |
reference_wav_path: str | None,
|
| 30 |
language: str | None,
|
| 31 |
params: dict[str, Any],
|
| 32 |
+
) -> tuple[bytes, int, int]: ... # (wav_bytes, sample_rate, seed_used)
|
| 33 |
|
| 34 |
|
| 35 |
def is_valid_adapter(cls: type) -> bool:
|
server/models/chatterbox_en.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import Any, ClassVar
|
|
| 7 |
import soundfile as sf
|
| 8 |
|
| 9 |
from server.schemas import Lang, ParamSpec
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class Adapter:
|
|
@@ -52,9 +53,10 @@ class Adapter:
|
|
| 52 |
reference_wav_path: str | None,
|
| 53 |
language: str | None,
|
| 54 |
params: dict[str, Any],
|
| 55 |
-
) -> tuple[bytes, int]:
|
| 56 |
if self._model is None:
|
| 57 |
raise RuntimeError("model not loaded")
|
|
|
|
| 58 |
wav = self._model.generate(
|
| 59 |
text,
|
| 60 |
audio_prompt_path=reference_wav_path,
|
|
@@ -73,4 +75,4 @@ class Adapter:
|
|
| 73 |
sr = getattr(self._model, "sr", 24000)
|
| 74 |
buf = io.BytesIO()
|
| 75 |
sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
|
| 76 |
-
return buf.getvalue(), sr
|
|
|
|
| 7 |
import soundfile as sf
|
| 8 |
|
| 9 |
from server.schemas import Lang, ParamSpec
|
| 10 |
+
from server.seed import apply_seed
|
| 11 |
|
| 12 |
|
| 13 |
class Adapter:
|
|
|
|
| 53 |
reference_wav_path: str | None,
|
| 54 |
language: str | None,
|
| 55 |
params: dict[str, Any],
|
| 56 |
+
) -> tuple[bytes, int, int]:
|
| 57 |
if self._model is None:
|
| 58 |
raise RuntimeError("model not loaded")
|
| 59 |
+
seed_used = apply_seed(params.get("seed"))
|
| 60 |
wav = self._model.generate(
|
| 61 |
text,
|
| 62 |
audio_prompt_path=reference_wav_path,
|
|
|
|
| 75 |
sr = getattr(self._model, "sr", 24000)
|
| 76 |
buf = io.BytesIO()
|
| 77 |
sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
|
| 78 |
+
return buf.getvalue(), sr, seed_used
|
server/models/chatterbox_mtl.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import Any, ClassVar
|
|
| 7 |
import soundfile as sf
|
| 8 |
|
| 9 |
from server.schemas import Lang, ParamSpec
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
_MTL_LANGS: list[Lang] = [
|
|
@@ -70,11 +71,12 @@ class Adapter:
|
|
| 70 |
reference_wav_path: str | None,
|
| 71 |
language: str | None,
|
| 72 |
params: dict[str, Any],
|
| 73 |
-
) -> tuple[bytes, int]:
|
| 74 |
if self._model is None:
|
| 75 |
raise RuntimeError("model not loaded")
|
| 76 |
if not language:
|
| 77 |
raise ValueError("language is required for chatterbox-mtl")
|
|
|
|
| 78 |
wav = self._model.generate(
|
| 79 |
text,
|
| 80 |
language_id=language,
|
|
@@ -93,4 +95,4 @@ class Adapter:
|
|
| 93 |
sr = getattr(self._model, "sr", 24000)
|
| 94 |
buf = io.BytesIO()
|
| 95 |
sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
|
| 96 |
-
return buf.getvalue(), sr
|
|
|
|
| 7 |
import soundfile as sf
|
| 8 |
|
| 9 |
from server.schemas import Lang, ParamSpec
|
| 10 |
+
from server.seed import apply_seed
|
| 11 |
|
| 12 |
|
| 13 |
_MTL_LANGS: list[Lang] = [
|
|
|
|
| 71 |
reference_wav_path: str | None,
|
| 72 |
language: str | None,
|
| 73 |
params: dict[str, Any],
|
| 74 |
+
) -> tuple[bytes, int, int]:
|
| 75 |
if self._model is None:
|
| 76 |
raise RuntimeError("model not loaded")
|
| 77 |
if not language:
|
| 78 |
raise ValueError("language is required for chatterbox-mtl")
|
| 79 |
+
seed_used = apply_seed(params.get("seed"))
|
| 80 |
wav = self._model.generate(
|
| 81 |
text,
|
| 82 |
language_id=language,
|
|
|
|
| 95 |
sr = getattr(self._model, "sr", 24000)
|
| 96 |
buf = io.BytesIO()
|
| 97 |
sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
|
| 98 |
+
return buf.getvalue(), sr, seed_used
|
server/models/chatterbox_turbo.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import Any, ClassVar
|
|
| 7 |
import soundfile as sf
|
| 8 |
|
| 9 |
from server.schemas import Lang, ParamSpec
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class Adapter:
|
|
@@ -54,9 +55,10 @@ class Adapter:
|
|
| 54 |
reference_wav_path: str | None,
|
| 55 |
language: str | None,
|
| 56 |
params: dict[str, Any],
|
| 57 |
-
) -> tuple[bytes, int]:
|
| 58 |
if self._model is None:
|
| 59 |
raise RuntimeError("model not loaded")
|
|
|
|
| 60 |
wav = self._model.generate(
|
| 61 |
text,
|
| 62 |
audio_prompt_path=reference_wav_path,
|
|
@@ -74,4 +76,4 @@ class Adapter:
|
|
| 74 |
sr = getattr(self._model, "sr", 24000)
|
| 75 |
buf = io.BytesIO()
|
| 76 |
sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
|
| 77 |
-
return buf.getvalue(), sr
|
|
|
|
| 7 |
import soundfile as sf
|
| 8 |
|
| 9 |
from server.schemas import Lang, ParamSpec
|
| 10 |
+
from server.seed import apply_seed
|
| 11 |
|
| 12 |
|
| 13 |
class Adapter:
|
|
|
|
| 55 |
reference_wav_path: str | None,
|
| 56 |
language: str | None,
|
| 57 |
params: dict[str, Any],
|
| 58 |
+
) -> tuple[bytes, int, int]:
|
| 59 |
if self._model is None:
|
| 60 |
raise RuntimeError("model not loaded")
|
| 61 |
+
seed_used = apply_seed(params.get("seed"))
|
| 62 |
wav = self._model.generate(
|
| 63 |
text,
|
| 64 |
audio_prompt_path=reference_wav_path,
|
|
|
|
| 76 |
sr = getattr(self._model, "sr", 24000)
|
| 77 |
buf = io.BytesIO()
|
| 78 |
sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
|
| 79 |
+
return buf.getvalue(), sr, seed_used
|
tests/conftest.py
CHANGED
|
@@ -37,7 +37,10 @@ class FakeAdapter:
|
|
| 37 |
self.loaded = False
|
| 38 |
|
| 39 |
def generate(self, text, reference_wav_path, language, params):
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
class FakeAdapterB(FakeAdapter):
|
|
|
|
| 37 |
self.loaded = False
|
| 38 |
|
| 39 |
def generate(self, text, reference_wav_path, language, params):
|
| 40 |
+
# FakeAdapter never actually applies a seed; report the input or 0.
|
| 41 |
+
seed_in = params.get("seed", 0) if isinstance(params, dict) else 0
|
| 42 |
+
seed_used = 0 if seed_in is None or seed_in < 0 else int(seed_in)
|
| 43 |
+
return (b"FAKEWAV", 24000, seed_used)
|
| 44 |
|
| 45 |
|
| 46 |
class FakeAdapterB(FakeAdapter):
|
tests/test_main_generate.py
CHANGED
|
@@ -25,6 +25,7 @@ async def test_generate_returns_wav_bytes(monkeypatch, fake_classes):
|
|
| 25 |
assert r.status_code == 200
|
| 26 |
assert r.headers["content-type"].startswith("audio/wav")
|
| 27 |
assert r.content == b"FAKEWAV"
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
async def test_generate_unknown_model_404(monkeypatch, fake_classes):
|
|
|
|
| 25 |
assert r.status_code == 200
|
| 26 |
assert r.headers["content-type"].startswith("audio/wav")
|
| 27 |
assert r.content == b"FAKEWAV"
|
| 28 |
+
assert r.headers["x-seed-used"] == "0"
|
| 29 |
|
| 30 |
|
| 31 |
async def test_generate_unknown_model_404(monkeypatch, fake_classes):
|