File size: 7,077 Bytes
ecf13ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d745c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import asyncio
import codecs
import json

import httpx
import pytest

from server.main import build_app


# Module-level marker: pytest applies it to every test in this file, so each
# `async def test_*` below is run as a coroutine by pytest-asyncio.
pytestmark = pytest.mark.asyncio


async def _run_sse_until_done(app, path="/api/progress", timeout=3.0):
    """Drive the ASGI SSE endpoint manually and collect parsed events until
    a 'done'/'error' event arrives or the timeout fires.

    Note: httpx ASGITransport buffers the entire response before returning,
    so it can't be used to stream a long-lived SSE response. We invoke the
    ASGI app directly with a bespoke receive/send pair and parse SSE frames
    out of the body chunks as they're emitted. Returns (events, timed_out).
    """
    events: list[dict] = []
    request_consumed = asyncio.Event()
    stop = asyncio.Event()

    async def receive():
        if request_consumed.is_set():
            # Hold here until the test signals the client wants to disconnect.
            await stop.wait()
            return {"type": "http.disconnect"}
        request_consumed.set()
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        if message["type"] == "http.response.body":
            body = message.get("body", b"")
            for line in body.decode("utf-8", errors="replace").splitlines():
                line = line.strip()
                if not line.startswith("data:"):
                    continue
                payload = line[len("data:") :].strip()
                if not payload:
                    continue
                try:
                    evt = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                events.append(evt)
                if evt.get("type") in ("done", "error"):
                    stop.set()

    scope = {
        "type": "http",
        "asgi": {"version": "3.0"},
        "http_version": "1.1",
        "method": "GET",
        "headers": [],
        "scheme": "http",
        "path": path,
        "raw_path": path.encode(),
        "query_string": b"",
        "server": ("test", 80),
        "client": ("test", 12345),
        "root_path": "",
    }

    app_task = asyncio.create_task(app(scope, receive, send))
    timed_out = False
    try:
        await asyncio.wait_for(stop.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        timed_out = True
        stop.set()
    # Allow disconnect to propagate, then cancel the app task if still alive.
    await asyncio.sleep(0.05)
    if not app_task.done():
        app_task.cancel()
        try:
            await app_task
        except (asyncio.CancelledError, Exception):
            pass
    return events, timed_out


async def test_single_generate_emits_start_and_done(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """A single /api/generate call publishes 'start' first, then a 'done'
    event reporting seed_used 0 and kind 'single'."""
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    app = build_app()
    from tests.conftest import lifespan_ctx

    asgi_transport = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app), httpx.AsyncClient(
        transport=asgi_transport, base_url="http://t",
    ) as client:
        # Listen to the progress stream through a second, directly driven
        # ASGI invocation that runs alongside the HTTP request.
        collector = asyncio.create_task(_run_sse_until_done(app, timeout=3.0))
        # Brief pause so the subscriber is registered before generation starts.
        await asyncio.sleep(0.05)
        response = await client.post(
            "/api/generate",
            data={"text": "hi", "model_id": "fake", "params": "{}"},
        )
        events, timed_out = await collector

    assert response.status_code == 200
    assert not timed_out, f"SSE timed out before 'done'; got events: {events}"
    event_types = [evt["type"] for evt in events]
    assert event_types[0] == "start"
    assert "done" in event_types
    done_event, *_ = [evt for evt in events if evt["type"] == "done"]
    assert done_event["seed_used"] == 0
    assert done_event["kind"] == "single"


async def test_unknown_engine_does_not_emit_progress(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """Requesting an unregistered model id returns 404 without publishing any
    progress events, so the SSE listener runs out its short timeout."""
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    app = build_app()
    from tests.conftest import lifespan_ctx

    asgi_transport = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app), httpx.AsyncClient(
        transport=asgi_transport, base_url="http://t",
    ) as client:
        # Deliberately short timeout: this listener is expected to expire.
        collector = asyncio.create_task(_run_sse_until_done(app, timeout=0.6))
        await asyncio.sleep(0.05)
        response = await client.post(
            "/api/generate",
            data={"text": "x", "model_id": "nope", "params": "{}"},
        )
        events, timed_out = await collector

    assert response.status_code == 404
    # Nothing was published: the route rejected the request with a 404
    # before entering the progress session, so no start/done ever fired.
    assert timed_out
    assert events == []


async def test_dialog_emits_per_turn_events(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """A two-speaker dialog request publishes a 'start' event (kind 'dialog',
    total_turns 2), a 'turn_complete' for turns 1 and 2 in order, and finishes
    with 'done'."""
    import io

    import numpy as np
    import soundfile as sf

    def _silent_wav(seconds: float = 0.2, sr: int = 24000) -> bytes:
        """Encode `seconds` of float32 silence at `sr` Hz as 16-bit PCM WAV bytes."""
        with io.BytesIO() as wav_buf:
            silence = np.zeros(int(seconds * sr), dtype=np.float32)
            sf.write(wav_buf, silence, sr, format="WAV", subtype="PCM_16")
            return wav_buf.getvalue()

    def _fake_generate(self, text, ref, lang, p):
        # Real (silent) WAV bytes so the dialog generator can decode each turn.
        return _silent_wav(0.1), 24000, 0

    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    monkeypatch.setattr(fake_classes["fake"], "generate", _fake_generate)
    app = build_app()
    from tests.conftest import lifespan_ctx

    dialog_form = {
        "text": "SPEAKER A: hi\nSPEAKER B: hello",
        "engine_id": "fake",
        "params": "{}",
    }
    reference_files = {
        "reference_wav_a": ("a.wav", _silent_wav(1.0), "audio/wav"),
        "reference_wav_b": ("b.wav", _silent_wav(1.0), "audio/wav"),
    }

    asgi_transport = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app), httpx.AsyncClient(
        transport=asgi_transport, base_url="http://t",
    ) as client:
        collector = asyncio.create_task(_run_sse_until_done(app, timeout=5.0))
        # Brief pause so the subscriber is registered before generation starts.
        await asyncio.sleep(0.05)
        response = await client.post(
            "/api/generate/dialog",
            data=dialog_form,
            files=reference_files,
        )
        events, timed_out = await collector

    assert response.status_code == 200
    assert not timed_out, f"SSE timed out before 'done'; got events: {events}"
    event_types = [evt["type"] for evt in events]
    assert event_types[0] == "start"
    start_event = events[0]
    assert start_event["kind"] == "dialog"
    assert start_event["total_turns"] == 2
    completed_turns = [evt["turn"] for evt in events if evt["type"] == "turn_complete"]
    assert completed_turns == [1, 2]
    assert events[-1]["type"] == "done"