| import asyncio |
| import json |
|
|
| import httpx |
| import pytest |
|
|
| from server.main import build_app |
|
|
|
|
# Apply pytest-asyncio's `asyncio` marker to every test in this module so the
# `async def` tests below are run on an event loop.
pytestmark = pytest.mark.asyncio
|
|
|
|
async def _run_sse_until_done(app, path="/api/progress", timeout=3.0):
    """Drive the ASGI SSE endpoint at *path* and collect its JSON events.

    httpx's ``ASGITransport`` buffers the entire response before returning,
    so it cannot be used to observe a long-lived SSE stream. Instead this
    helper invokes the ASGI app directly with a hand-rolled ``receive`` /
    ``send`` pair and parses ``data:`` lines out of the body chunks as they
    are emitted. A line may be split across several body chunks; the
    unterminated tail is buffered until its line terminator arrives.

    Collection stops when an event whose ``"type"`` is ``"done"`` or
    ``"error"`` arrives, or when *timeout* seconds elapse. The fake client
    then reports a disconnect. The app task gets a short grace period to
    exit on its own before it is cancelled.

    Args:
        app: ASGI application to drive.
        path: Request path of the SSE endpoint.
        timeout: Seconds to wait for a terminal ('done'/'error') event.

    Returns:
        ``(events, timed_out)``: the parsed event dicts in arrival order, and
        whether the timeout fired before a terminal event was seen.

    Raises:
        AssertionError: if the endpoint answered with a non-200 status. This
            stops a missing or broken endpoint from masquerading as "no
            events were emitted".
    """
    events: list[dict] = []
    request_consumed = asyncio.Event()
    stop = asyncio.Event()
    # Bytes of a line whose terminator has not arrived yet. ASGI apps may
    # split one SSE line across two body messages.
    pending = bytearray()
    response_status = None

    async def receive():
        # The first call delivers the (empty) GET request. Every later call
        # blocks until collection is finished, then reports a client
        # disconnect so the app can shut the stream down.
        if request_consumed.is_set():
            await stop.wait()
            return {"type": "http.disconnect"}
        request_consumed.set()
        return {"type": "http.request", "body": b"", "more_body": False}

    def handle_line(raw: bytes) -> None:
        # Only `data:` lines carrying a JSON payload are collected. Comments,
        # `event:`/`id:` fields and blank separator lines are ignored.
        line = raw.decode("utf-8", errors="replace").strip()
        if not line.startswith("data:"):
            return
        payload = line[len("data:"):].strip()
        if not payload:
            return
        try:
            evt = json.loads(payload)
        except json.JSONDecodeError:
            return
        events.append(evt)
        if evt.get("type") in ("done", "error"):
            stop.set()

    async def send(message):
        nonlocal response_status
        if message["type"] == "http.response.start":
            response_status = message.get("status")
            return
        if message["type"] != "http.response.body":
            return
        pending.extend(message.get("body", b""))
        # bytes.splitlines splits on exactly the SSE terminators (CR, LF,
        # CRLF). UTF-8 continuation bytes never equal CR/LF, so splitting
        # before decoding cannot cut a multi-byte character.
        lines = bytes(pending).splitlines(keepends=True)
        pending.clear()
        # Hold back an unterminated trailing line until more bytes arrive.
        # On the final body message, flush it as-is.
        if (
            lines
            and not lines[-1].endswith((b"\n", b"\r"))
            and message.get("more_body", False)
        ):
            pending.extend(lines.pop())
        for raw in lines:
            handle_line(raw)

    scope = {
        "type": "http",
        "asgi": {"version": "3.0"},
        "http_version": "1.1",
        "method": "GET",
        "headers": [],
        "scheme": "http",
        "path": path,
        "raw_path": path.encode(),
        "query_string": b"",
        "server": ("test", 80),
        "client": ("test", 12345),
        "root_path": "",
    }

    app_task = asyncio.create_task(app(scope, receive, send))
    timed_out = False
    try:
        await asyncio.wait_for(stop.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        timed_out = True
    # Release any pending receive() so the app observes the disconnect.
    stop.set()
    # Give the app a moment to exit cleanly before forcing cancellation.
    await asyncio.sleep(0.05)
    if not app_task.done():
        app_task.cancel()
    try:
        await app_task
    except (asyncio.CancelledError, Exception):
        # Deliberate best-effort teardown: cancellation, or an error raised
        # while the stream is torn down, is not what these tests check.
        pass
    if response_status is not None and response_status != 200:
        raise AssertionError(
            f"{path} answered HTTP {response_status}, expected 200; "
            f"events: {events}"
        )
    return events, timed_out
|
|
|
|
async def test_single_generate_emits_start_and_done(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """A plain /api/generate request publishes a 'start' event first and a
    'done' event (seed 0, kind 'single') on the progress SSE stream."""
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    app = build_app()
    from tests.conftest import lifespan_ctx

    async with lifespan_ctx(app):
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app),
            base_url="http://t",
        ) as client:
            # Subscribe to the progress stream before triggering generation.
            listener = asyncio.create_task(_run_sse_until_done(app, timeout=3.0))
            # Give the subscriber a moment to attach before generating.
            await asyncio.sleep(0.05)
            response = await client.post(
                "/api/generate",
                data={"text": "hi", "model_id": "fake", "params": "{}"},
            )
            events, timed_out = await listener

    assert response.status_code == 200
    assert not timed_out, f"SSE timed out before 'done'; got events: {events}"
    event_types = [event["type"] for event in events]
    assert event_types[0] == "start"
    assert "done" in event_types
    done_event = events[event_types.index("done")]
    assert done_event["seed_used"] == 0
    assert done_event["kind"] == "single"
|
|
|
|
async def test_unknown_engine_does_not_emit_progress(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """Requesting an unknown model is rejected with 404 and nothing is
    published on the progress stream within the listening window."""
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    app = build_app()
    from tests.conftest import lifespan_ctx

    async with lifespan_ctx(app):
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app),
            base_url="http://t",
        ) as client:
            # Short window: we expect silence, so this is meant to time out.
            listener = asyncio.create_task(_run_sse_until_done(app, timeout=0.6))
            await asyncio.sleep(0.05)
            response = await client.post(
                "/api/generate",
                data={"text": "x", "model_id": "nope", "params": "{}"},
            )
            events, timed_out = await listener

    assert response.status_code == 404
    # No terminal event arrived, so the listener had to give up on its timer.
    assert timed_out
    assert not events
|
|
|
|
async def test_dialog_emits_per_turn_events(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """A two-speaker dialog request publishes a 'start' event (kind 'dialog',
    two turns), a 'turn_complete' event per turn in order, and ends with
    'done'."""
    import io

    import numpy as np
    import soundfile as sf

    def make_silence(duration: float = 0.2, rate: int = 24000) -> bytes:
        # Encode `duration` seconds of zero samples as a 16-bit PCM WAV file.
        with io.BytesIO() as stream:
            sf.write(
                stream,
                np.zeros(int(duration * rate), dtype=np.float32),
                rate,
                format="WAV",
                subtype="PCM_16",
            )
            return stream.getvalue()

    def fake_generate(self, text, ref, lang, p):
        # Every turn yields 0.1 s of silence at 24 kHz with seed 0.
        return (make_silence(0.1), 24000, 0)

    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    monkeypatch.setattr(fake_classes["fake"], "generate", fake_generate)
    app = build_app()
    from tests.conftest import lifespan_ctx

    async with lifespan_ctx(app):
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app),
            base_url="http://t",
        ) as client:
            listener = asyncio.create_task(_run_sse_until_done(app, timeout=5.0))
            await asyncio.sleep(0.05)
            reference_wav = make_silence(1.0)
            response = await client.post(
                "/api/generate/dialog",
                data={
                    "text": "SPEAKER A: hi\nSPEAKER B: hello",
                    "engine_id": "fake",
                    "params": "{}",
                },
                files={
                    "reference_wav_a": ("a.wav", reference_wav, "audio/wav"),
                    "reference_wav_b": ("b.wav", reference_wav, "audio/wav"),
                },
            )
            events, timed_out = await listener

    assert response.status_code == 200
    assert not timed_out, f"SSE timed out before 'done'; got events: {events}"
    assert events[0]["type"] == "start"
    first = events[0]
    assert first["kind"] == "dialog"
    assert first["total_turns"] == 2
    completed_turns = [
        event["turn"] for event in events if event["type"] == "turn_complete"
    ]
    assert completed_turns == [1, 2]
    assert events[-1]["type"] == "done"
|
|