File size: 4,298 Bytes
edf3bf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import io

import httpx
import numpy as np
import pytest
import soundfile as sf

from server.main import build_app


# Module-wide marker: every test below is a coroutine, so mark them all as
# asyncio tests instead of decorating each one individually.
pytestmark = pytest.mark.asyncio


def _silent_wav(seconds: float = 1.0, sr: int = 24000) -> bytes:
    """Return the bytes of an in-memory WAV file containing only zeros.

    Args:
        seconds: Clip length; the sample count is ``int(seconds * sr)``.
        sr: Sample rate passed to the WAV writer.

    Returns:
        The encoded file (WAV container, ``PCM_16`` subtype) as ``bytes``.
    """
    frame_count = int(seconds * sr)
    silence = np.zeros(frame_count, dtype=np.float32)
    sink = io.BytesIO()
    # Encode straight into the in-memory buffer; nothing touches the disk.
    sf.write(sink, silence, sr, format="WAV", subtype="PCM_16")
    return sink.getvalue()


async def test_dialog_generates_concatenated_wav(monkeypatch, fake_classes):
    """Two speakers with both reference clips: expect a 200 WAV response.

    Checks the status code, the ``audio/wav`` content type, the ``RIFF``
    magic at the start of the body, and the ``x-seed-used`` header.
    """
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")

    # Have FakeAdapter emit a real silent WAV so the dialog generator can decode it.
    def _emit_silence(self, text, ref, lang, p):
        return (_silent_wav(0.2), 24000, 0)

    monkeypatch.setattr(fake_classes["fake"], "generate", _emit_silence)

    app = build_app()
    from tests.conftest import lifespan_ctx

    asgi = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(transport=asgi, base_url="http://t") as client:
            uploads = {
                "reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav"),
                "reference_wav_b": ("b.wav", _silent_wav(1.0), "audio/wav"),
            }
            form_fields = {
                "text": "SPEAKER A: hi\nSPEAKER B: hello",
                "engine_id": "fake",
                "params": "{}",
            }
            response = await client.post(
                "/api/generate/dialog",
                data=form_fields,
                files=uploads,
            )

    assert response.status_code == 200
    assert response.headers["content-type"].startswith("audio/wav")
    assert response.content[:4] == b"RIFF"
    assert response.headers["x-seed-used"] == "0"


async def test_dialog_format_invalid(monkeypatch, fake_classes):
    """Text without speaker tags: expect 400 with ``dialog_format_invalid``."""
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")

    app = build_app()
    from tests.conftest import lifespan_ctx

    asgi = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(transport=asgi, base_url="http://t") as client:
            form_fields = {"text": "no speaker tags", "engine_id": "fake", "params": "{}"}
            uploads = {
                "reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav"),
            }
            response = await client.post(
                "/api/generate/dialog",
                data=form_fields,
                files=uploads,
            )

    assert response.status_code == 400
    error_code = response.json()["error"]["code"]
    assert error_code == "dialog_format_invalid"


async def test_dialog_missing_reference(monkeypatch, fake_classes):
    """Script names speakers A and B but only clip A is uploaded.

    Expect 400 with error code ``dialog_missing_reference``.
    """
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")

    # Stub generation with a short silent WAV, matching the happy-path test.
    def _emit_silence(self, text, ref, lang, p):
        return (_silent_wav(0.2), 24000, 0)

    monkeypatch.setattr(fake_classes["fake"], "generate", _emit_silence)

    app = build_app()
    from tests.conftest import lifespan_ctx

    asgi = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(transport=asgi, base_url="http://t") as client:
            form_fields = {
                "text": "SPEAKER A: hi\nSPEAKER B: hello",
                "engine_id": "fake",
                "params": "{}",
            }
            # Only speaker A's reference is supplied on purpose.
            uploads = {"reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav")}
            response = await client.post(
                "/api/generate/dialog",
                data=form_fields,
                files=uploads,
            )

    assert response.status_code == 400
    error_code = response.json()["error"]["code"]
    assert error_code == "dialog_missing_reference"


async def test_dialog_unknown_engine_404(monkeypatch, fake_classes):
    """Unregistered ``engine_id``: expect 404 with ``model_not_found``."""
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")

    app = build_app()
    from tests.conftest import lifespan_ctx

    asgi = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(transport=asgi, base_url="http://t") as client:
            form_fields = {
                "text": "SPEAKER A: hi",
                "engine_id": "nope",
                "params": "{}",
            }
            uploads = {"reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav")}
            response = await client.post(
                "/api/generate/dialog",
                data=form_fields,
                files=uploads,
            )

    assert response.status_code == 404
    error_code = response.json()["error"]["code"]
    assert error_code == "model_not_found"