# chatterbox-voice-studio/tests/test_dialog_endpoint.py
# techfreakworm's picture
# feat(dialog): /api/generate/dialog endpoint + per-turn dispatcher with seed reuse
# edf3bf7 unverified
import io
import httpx
import numpy as np
import pytest
import soundfile as sf
from server.main import build_app
# Module-level pytest mark: applies the `asyncio` marker to every test in this
# file, so the `async def` tests below do not each need their own decorator.
pytestmark = pytest.mark.asyncio
def _silent_wav(seconds: float = 1.0, sr: int = 24000) -> bytes:
    """Return the bytes of an in-memory WAV file containing only silence.

    Args:
        seconds: Clip duration; the sample count is ``int(seconds * sr)``.
        sr: Sample rate passed to the WAV writer.

    Returns:
        The encoded WAV (format ``WAV``, subtype ``PCM_16``) as raw bytes.
    """
    frame_count = int(seconds * sr)
    stream = io.BytesIO()
    sf.write(
        stream,
        np.zeros(frame_count, dtype=np.float32),
        sr,
        format="WAV",
        subtype="PCM_16",
    )
    return stream.getvalue()
async def test_dialog_generates_concatenated_wav(monkeypatch, fake_classes):
    """A two-speaker script with both reference clips yields a WAV response.

    Asserts a 200 status, an ``audio/wav`` content type, the ``RIFF`` magic at
    the start of the body, and an ``x-seed-used`` header equal to ``"0"``.
    """

    def _emit_silence(self, text, ref, lang, p):
        # Have FakeAdapter emit a real silent WAV so the dialog generator can decode it.
        return (_silent_wav(0.2), 24000, 0)

    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    monkeypatch.setattr(fake_classes["fake"], "generate", _emit_silence)

    app = build_app()
    from tests.conftest import lifespan_ctx

    transport = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(transport=transport, base_url="http://t") as client:
            uploads = {
                "reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav"),
                "reference_wav_b": ("b.wav", _silent_wav(1.0), "audio/wav"),
            }
            form_fields = {
                "text": "SPEAKER A: hi\nSPEAKER B: hello",
                "engine_id": "fake",
                "params": "{}",
            }
            response = await client.post(
                "/api/generate/dialog",
                data=form_fields,
                files=uploads,
            )
            assert response.status_code == 200
            assert response.headers["content-type"].startswith("audio/wav")
            assert response.content[:4] == b"RIFF"
            assert response.headers["x-seed-used"] == "0"
async def test_dialog_format_invalid(monkeypatch, fake_classes):
    """A script without speaker tags is rejected.

    Asserts a 400 status and the error code ``dialog_format_invalid`` in the
    JSON body.
    """
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")

    app = build_app()
    from tests.conftest import lifespan_ctx

    transport = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(transport=transport, base_url="http://t") as client:
            form_fields = {"text": "no speaker tags", "engine_id": "fake", "params": "{}"}
            uploads = {
                "reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav"),
            }
            response = await client.post(
                "/api/generate/dialog",
                data=form_fields,
                files=uploads,
            )
            assert response.status_code == 400
            error_body = response.json()["error"]
            assert error_body["code"] == "dialog_format_invalid"
async def test_dialog_missing_reference(monkeypatch, fake_classes):
    """A script naming speakers A and B with only clip A uploaded is rejected.

    Asserts a 400 status and the error code ``dialog_missing_reference`` in
    the JSON body.
    """

    def _emit_silence(self, text, ref, lang, p):
        # Stand-in for the adapter's generate(): a short silent WAV, 24 kHz, seed 0.
        return (_silent_wav(0.2), 24000, 0)

    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    monkeypatch.setattr(fake_classes["fake"], "generate", _emit_silence)

    app = build_app()
    from tests.conftest import lifespan_ctx

    transport = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(transport=transport, base_url="http://t") as client:
            form_fields = {
                "text": "SPEAKER A: hi\nSPEAKER B: hello",
                "engine_id": "fake",
                "params": "{}",
            }
            # Only speaker A gets a reference clip; speaker B's is left out.
            uploads = {"reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav")}
            response = await client.post(
                "/api/generate/dialog",
                data=form_fields,
                files=uploads,
            )
            assert response.status_code == 400
            error_body = response.json()["error"]
            assert error_body["code"] == "dialog_missing_reference"
async def test_dialog_unknown_engine_404(monkeypatch, fake_classes):
    """Requesting the engine id ``"nope"`` produces a not-found error.

    Asserts a 404 status and the error code ``model_not_found`` in the JSON
    body.
    """
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")

    app = build_app()
    from tests.conftest import lifespan_ctx

    transport = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(transport=transport, base_url="http://t") as client:
            form_fields = {
                "text": "SPEAKER A: hi",
                "engine_id": "nope",
                "params": "{}",
            }
            uploads = {"reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav")}
            response = await client.post(
                "/api/generate/dialog",
                data=form_fields,
                files=uploads,
            )
            assert response.status_code == 404
            error_body = response.json()["error"]
            assert error_body["code"] == "model_not_found"