techfreakworm committed on
Commit
ecf13ab
·
unverified ·
1 Parent(s): 5043eea

feat(progress): /api/progress SSE endpoint; wrap /api/generate in session

Browse files
Files changed (2) hide show
  1. server/main.py +19 -3
  2. tests/test_main_progress_sse.py +139 -0
server/main.py CHANGED
@@ -21,6 +21,7 @@ from server.dialog import (
21
  DialogReferenceError,
22
  generate_dialog,
23
  )
 
24
  from server.registry import Registry
25
  from server.zerogpu import decorate
26
 
@@ -106,6 +107,18 @@ def build_app() -> FastAPI:
106
 
107
  return EventSourceResponse(gen())
108
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  @app.post("/api/generate")
110
  async def generate(
111
  text: str = Form(...),
@@ -155,10 +168,13 @@ def build_app() -> FastAPI:
155
  ref_path = tmp.name
156
 
157
  gen_fn = decorate(adapter.generate)
 
158
  try:
159
- wav_bytes, _sr, seed_used = gen_fn(
160
- text, ref_path, language, json.loads(params or "{}")
161
- )
 
 
162
  except Exception as exc:
163
  return JSONResponse(
164
  status_code=500,
 
21
  DialogReferenceError,
22
  generate_dialog,
23
  )
24
+ from server.progress import get_bus
25
  from server.registry import Registry
26
  from server.zerogpu import decorate
27
 
 
107
 
108
  return EventSourceResponse(gen())
109
 
110
@app.get("/api/progress")
async def progress_events():
    """Stream progress-bus events to the caller as Server-Sent Events.

    Each event taken from the subscription queue is converted with
    ``to_dict()`` and JSON-encoded into the ``data`` field of one SSE message.

    Returns:
        An ``EventSourceResponse`` wrapping the async event generator.
    """
    bus = get_bus()

    async def event_stream():
        # The subscription stays open for as long as this generator is alive.
        # The loop has no exit condition of its own.
        # NOTE(review): presumably it ends when the response is cancelled on
        # client disconnect -- confirm against EventSourceResponse behaviour.
        async with bus.subscribe() as queue:
            while True:
                event = await queue.get()
                payload = json.dumps(event.to_dict())
                yield {"data": payload}

    return EventSourceResponse(event_stream())
121
+
122
  @app.post("/api/generate")
123
  async def generate(
124
  text: str = Form(...),
 
168
  ref_path = tmp.name
169
 
170
  gen_fn = decorate(adapter.generate)
171
+ bus = get_bus()
172
  try:
173
+ async with bus.session("single", total_turns=1) as sess:
174
+ wav_bytes, _sr, seed_used = gen_fn(
175
+ text, ref_path, language, json.loads(params or "{}")
176
+ )
177
+ sess.set_seed(seed_used)
178
  except Exception as exc:
179
  return JSONResponse(
180
  status_code=500,
tests/test_main_progress_sse.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+
4
+ import httpx
5
+ import pytest
6
+
7
+ from server.main import build_app
8
+
9
+
10
+ pytestmark = pytest.mark.asyncio
11
+
12
+
13
async def _run_sse_until_done(app, path="/api/progress", timeout=3.0):
    """Drive the ASGI SSE endpoint manually and collect parsed events until
    a 'done'/'error' event arrives or the timeout fires.

    Note: httpx ASGITransport buffers the entire response before returning,
    so it can't be used to stream a long-lived SSE response. We invoke the
    ASGI app directly with a bespoke receive/send pair and parse SSE frames
    out of the body chunks as they're emitted.

    Body bytes are buffered across ``http.response.body`` messages and only
    complete lines are parsed, so a ``data:`` line that the server happens to
    split over two chunks is still decoded instead of being silently dropped.

    Args:
        app: ASGI application callable ``app(scope, receive, send)``.
        path: Request path placed in the ASGI scope.
        timeout: Seconds to wait for a terminal ('done'/'error') event.

    Returns:
        ``(events, timed_out)`` -- the list of JSON-decoded ``data:`` payloads
        in arrival order, and whether the timeout expired before a terminal
        event was seen.
    """
    events: list[dict] = []
    request_consumed = asyncio.Event()
    stop = asyncio.Event()
    pending = bytearray()  # bytes received but not yet terminated by a newline

    def _consume_line(line: str) -> None:
        """Parse one complete SSE line; record it if it is a JSON data field."""
        line = line.strip()
        if not line.startswith("data:"):
            return
        payload = line[len("data:"):].strip()
        if not payload:
            return
        try:
            evt = json.loads(payload)
        except json.JSONDecodeError:
            # Non-JSON data fields (e.g. keep-alive text) are ignored.
            return
        events.append(evt)
        # Guard the lookup: a JSON payload is not necessarily an object.
        if isinstance(evt, dict) and evt.get("type") in ("done", "error"):
            stop.set()

    async def receive():
        if request_consumed.is_set():
            # Hold here until the test signals the client wants to disconnect.
            await stop.wait()
            return {"type": "http.disconnect"}
        request_consumed.set()
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        if message["type"] != "http.response.body":
            return
        pending.extend(message.get("body", b""))
        # On the final body message flush everything; otherwise only consume
        # up to the last newline and keep the unterminated tail for later.
        is_final = not message.get("more_body", False)
        cut = len(pending) if is_final else pending.rfind(b"\n") + 1
        complete = bytes(pending[:cut])
        del pending[:cut]
        for line in complete.decode("utf-8", errors="replace").splitlines():
            _consume_line(line)

    scope = {
        "type": "http",
        "asgi": {"version": "3.0"},
        "http_version": "1.1",
        "method": "GET",
        "headers": [],
        "scheme": "http",
        "path": path,
        "raw_path": path.encode(),
        "query_string": b"",
        "server": ("test", 80),
        "client": ("test", 12345),
        "root_path": "",
    }

    app_task = asyncio.create_task(app(scope, receive, send))
    timed_out = False
    try:
        await asyncio.wait_for(stop.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        timed_out = True
    # Setting `stop` releases the pending receive() so the app sees a disconnect.
    stop.set()
    # Allow disconnect to propagate, then cancel the app task if still alive.
    await asyncio.sleep(0.05)
    if not app_task.done():
        app_task.cancel()
    try:
        await app_task
    except (asyncio.CancelledError, Exception):
        # Deliberately best-effort: teardown errors from the app under test
        # must not mask the events already collected.
        pass
    return events, timed_out
83
+
84
+
85
async def test_single_generate_emits_start_and_done(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """A successful single /api/generate call publishes 'start' then 'done'.

    An SSE listener is attached to /api/progress first, then one generate
    request is posted. The listener must see a 'start' event first and a
    'done' event carrying ``seed_used == 0`` and ``kind == "single"``.
    """
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    app = build_app()
    from tests.conftest import lifespan_ctx
    transport = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(
            transport=transport, base_url="http://t",
        ) as client:
            # Collect SSE events through a parallel ASGI invocation.
            listener = asyncio.create_task(_run_sse_until_done(app, timeout=3.0))
            # Let the subscriber register before the generate request fires.
            await asyncio.sleep(0.05)
            form = {"text": "hi", "model_id": "fake", "params": "{}"}
            gen_resp = await client.post("/api/generate", data=form)
            events, timed_out = await listener

    assert gen_resp.status_code == 200
    assert not timed_out, f"SSE timed out before 'done'; got events: {events}"
    seen_types = [evt["type"] for evt in events]
    assert seen_types[0] == "start"
    assert "done" in seen_types
    done_evt = next(evt for evt in events if evt["type"] == "done")
    assert done_evt["seed_used"] == 0
    assert done_evt["kind"] == "single"
114
+
115
+
116
async def test_unknown_engine_does_not_emit_progress(
    monkeypatch, fake_classes, reset_progress_bus,
):
    """A generate request for an unknown model id 404s and emits no events.

    The SSE listener is given a short timeout and is expected to hit it with
    an empty event list.
    """
    monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
    monkeypatch.setattr("server.main.select_device", lambda: "cpu")
    app = build_app()
    from tests.conftest import lifespan_ctx
    transport = httpx.ASGITransport(app=app)
    async with lifespan_ctx(app):
        async with httpx.AsyncClient(
            transport=transport, base_url="http://t",
        ) as client:
            listener = asyncio.create_task(_run_sse_until_done(app, timeout=0.6))
            # Let the subscriber register before the request fires.
            await asyncio.sleep(0.05)
            form = {"text": "x", "model_id": "nope", "params": "{}"}
            response = await client.post("/api/generate", data=form)
            events, timed_out = await listener

    assert response.status_code == 404
    # Bus stayed quiet — no start/done fired because the route 404'd before
    # entering the session.
    assert timed_out
    assert events == []