File size: 7,077 Bytes
ecf13ab 2d745c3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 | import asyncio
import json
import httpx
import pytest
from server.main import build_app
# Module-level marker list: every coroutine test below runs under pytest-asyncio.
pytestmark = [pytest.mark.asyncio]
async def _run_sse_until_done(app, path="/api/progress", timeout=3.0):
    """Drive an ASGI SSE endpoint directly and collect its ``data:`` events.

    Collection stops as soon as an event whose ``type`` is ``"done"`` or
    ``"error"`` arrives, or when *timeout* seconds elapse.

    httpx's ``ASGITransport`` buffers the entire response before returning,
    so it cannot stream a long-lived SSE response. Instead this invokes the
    ASGI app with a bespoke receive/send pair and parses SSE lines out of the
    body chunks as they are emitted. A line may be split across several
    ``http.response.body`` messages, so an incomplete line is buffered until
    its terminator (or the final body message) arrives.

    Args:
        app: ASGI application to call.
        path: Request path. An optional ``?query`` suffix is placed in the
            scope's ``query_string`` rather than in ``path``/``raw_path``.
        timeout: Seconds to wait for a terminal event.

    Returns:
        ``(events, timed_out)``: the JSON-decoded payload of every ``data:``
        line seen, in order, and whether the timeout fired before a terminal
        event arrived.
    """
    events: list[dict] = []
    request_consumed = asyncio.Event()
    stop = asyncio.Event()
    # Tail of the body stream that follows the last line terminator seen.
    pending = b""

    async def receive():
        if request_consumed.is_set():
            # Hold here until the test signals the client wants to disconnect.
            await stop.wait()
            return {"type": "http.disconnect"}
        request_consumed.set()
        return {"type": "http.request", "body": b"", "more_body": False}

    def handle_line(raw: bytes) -> None:
        """Parse one complete SSE line and record it if it is a JSON data line."""
        line = raw.decode("utf-8", errors="replace").strip()
        if not line.startswith("data:"):
            return  # comments, other SSE fields, blank event separators
        payload = line[len("data:"):].strip()
        if not payload:
            return
        try:
            evt = json.loads(payload)
        except json.JSONDecodeError:
            return
        events.append(evt)
        # Guard .get(): a non-object JSON payload must not crash the app's send().
        if isinstance(evt, dict) and evt.get("type") in ("done", "error"):
            stop.set()

    async def send(message):
        nonlocal pending
        if message["type"] != "http.response.body":
            return
        # SSE lines end in CRLF, LF or a lone CR; normalise them all to LF.
        # Splitting raw bytes (not per-chunk decoded text) keeps multi-byte
        # UTF-8 characters that straddle a chunk boundary intact.
        data = (pending + message.get("body", b"")).replace(b"\r\n", b"\n")
        *complete, pending = data.replace(b"\r", b"\n").split(b"\n")
        for raw in complete:
            handle_line(raw)
        if not message.get("more_body", False):
            # Final body message: flush an unterminated trailing line.
            handle_line(pending)
            pending = b""

    # ASGI keeps the query string out of path/raw_path.
    route, _, query = path.partition("?")
    scope = {
        "type": "http",
        "asgi": {"version": "3.0"},
        "http_version": "1.1",
        "method": "GET",
        "headers": [],
        "scheme": "http",
        "path": route,
        "raw_path": route.encode(),
        "query_string": query.encode(),
        "server": ("test", 80),
        "client": ("test", 12345),
        "root_path": "",
    }
    app_task = asyncio.create_task(app(scope, receive, send))
    timed_out = False
    try:
        await asyncio.wait_for(stop.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        timed_out = True
    finally:
        # Always tear the app down, even when our caller cancels us, so no
        # orphaned app task outlives the test. Setting ``stop`` makes the
        # pending receive() report http.disconnect.
        stop.set()
        # Give the disconnect a moment to propagate before cancelling.
        await asyncio.sleep(0.05)
        if not app_task.done():
            app_task.cancel()
        try:
            await app_task
        except (asyncio.CancelledError, Exception):
            # Best-effort teardown: cancellation or errors raised while the
            # app unwinds after the disconnect are not what callers assert on.
            pass
    return events, timed_out
async def test_single_generate_emits_start_and_done(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """A single /api/generate call publishes 'start' first, then a 'done'
    event reporting a single-kind run with seed 0."""
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    app = build_app()
    from tests.conftest import lifespan_ctx

    async with lifespan_ctx(app):
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app), base_url="http://t",
        ) as client:
            # Listen on the SSE stream through a separate, direct ASGI call.
            listener = asyncio.create_task(_run_sse_until_done(app, timeout=3.0))
            # Let the subscriber register before the generate request fires.
            await asyncio.sleep(0.05)
            response = await client.post(
                "/api/generate",
                data={"text": "hi", "model_id": "fake", "params": "{}"},
            )
            collected, hit_timeout = await listener

            assert response.status_code == 200
            assert not hit_timeout, f"SSE timed out before 'done'; got events: {collected}"
            kinds = [evt["type"] for evt in collected]
            assert kinds[0] == "start"
            assert "done" in kinds
            finished = collected[kinds.index("done")]
            assert finished["seed_used"] == 0
            assert finished["kind"] == "single"
async def test_unknown_engine_does_not_emit_progress(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """An unknown model_id is rejected with 404 and nothing reaches the
    progress stream."""
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    app = build_app()
    from tests.conftest import lifespan_ctx

    async with lifespan_ctx(app):
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app), base_url="http://t",
        ) as client:
            listener = asyncio.create_task(_run_sse_until_done(app, timeout=0.6))
            await asyncio.sleep(0.05)
            response = await client.post(
                "/api/generate",
                data={"text": "x", "model_id": "nope", "params": "{}"},
            )
            collected, hit_timeout = await listener

            assert response.status_code == 404
            # The bus stayed quiet: the route 404'd before entering the
            # session, so neither start nor done was published and the
            # listener could only stop by timing out.
            assert hit_timeout
            assert collected == []
async def test_dialog_emits_per_turn_events(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """A two-speaker dialog publishes a dialog 'start' with total_turns == 2,
    one 'turn_complete' per turn numbered 1 and 2 in order, and ends with
    'done'."""
    import io

    import numpy as np
    import soundfile as sf

    def silent_wav(seconds: float = 0.2, sr: int = 24000) -> bytes:
        """Return *seconds* of float32 silence encoded as a 16-bit PCM WAV."""
        with io.BytesIO() as buf:
            sf.write(
                buf,
                np.zeros(int(seconds * sr), dtype=np.float32),
                sr,
                format="WAV",
                subtype="PCM_16",
            )
            return buf.getvalue()

    def fake_generate(self, text, ref, lang, p):
        return silent_wav(0.1), 24000, 0

    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    # Have FakeAdapter emit a real silent WAV so the dialog generator can decode it.
    monkeypatch.setattr(fake_classes["fake"], "generate", fake_generate)
    app = build_app()
    from tests.conftest import lifespan_ctx

    form = {
        "text": "SPEAKER A: hi\nSPEAKER B: hello",
        "engine_id": "fake",
        "params": "{}",
    }
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app), base_url="http://t",
        ) as client:
            listener = asyncio.create_task(_run_sse_until_done(app, timeout=5.0))
            # Let the subscriber register before the generate request fires.
            await asyncio.sleep(0.05)
            response = await client.post(
                "/api/generate/dialog",
                data=form,
                files={
                    "reference_wav_a": ("a.wav", silent_wav(1.0), "audio/wav"),
                    "reference_wav_b": ("b.wav", silent_wav(1.0), "audio/wav"),
                },
            )
            collected, hit_timeout = await listener

            assert response.status_code == 200
            assert not hit_timeout, f"SSE timed out before 'done'; got events: {collected}"
            kinds = [evt["type"] for evt in collected]
            assert kinds[0] == "start"
            opening = collected[0]
            assert opening["kind"] == "dialog"
            assert opening["total_turns"] == 2
            finished_turns = [
                evt["turn"] for evt in collected if evt["type"] == "turn_complete"
            ]
            assert finished_turns == [1, 2]
            assert collected[-1]["type"] == "done"
|