techfreakworm commited on
Commit
422829d
·
unverified ·
1 Parent(s): e859fb0

feat(progress): ProgressBus with sessions, ticks, and turn-complete events

Browse files
Files changed (3) hide show
  1. server/progress.py +174 -0
  2. tests/conftest.py +11 -0
  3. tests/test_progress.py +121 -0
server/progress.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Progress event bus for in-flight generations.
2
+
3
+ Endpoints (`/api/generate`, `/api/generate/dialog`) wrap their work in
4
+ `bus.session(...)` which emits `start` and `done`/`error` events plus a
5
+ 0.5s `tick` heartbeat. Dialog mode also emits `turn_complete` between
6
+ adapter calls. Subscribers receive events via `subscribe()` (used by
7
+ the SSE endpoint).
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import asyncio
12
+ import time
13
+ from contextlib import asynccontextmanager
14
+ from dataclasses import dataclass, field
15
+ from typing import AsyncIterator, Literal
16
+
17
+
18
+ EventType = Literal["start", "tick", "turn_complete", "done", "error"]
19
+
20
+
21
+ @dataclass
22
+ class ProgressEvent:
23
+ type: EventType
24
+ elapsed_s: float
25
+ payload: dict = field(default_factory=dict)
26
+
27
+ def to_dict(self) -> dict:
28
+ return {"type": self.type, "elapsed_s": round(self.elapsed_s, 2), **self.payload}
29
+
30
+
31
+ class ProgressBus:
32
+ def __init__(self) -> None:
33
+ self._subscribers: list[asyncio.Queue[ProgressEvent]] = []
34
+ self._lock = asyncio.Lock()
35
+ self._current_session: "_Session | None" = None
36
+
37
+ async def publish(self, event: ProgressEvent) -> None:
38
+ async with self._lock:
39
+ subs = list(self._subscribers)
40
+ for q in subs:
41
+ await q.put(event)
42
+
43
+ @asynccontextmanager
44
+ async def subscribe(self) -> AsyncIterator[asyncio.Queue[ProgressEvent]]:
45
+ q: asyncio.Queue[ProgressEvent] = asyncio.Queue()
46
+ async with self._lock:
47
+ self._subscribers.append(q)
48
+ if self._current_session is not None:
49
+ snapshot = self._current_session.snapshot_event()
50
+ if snapshot is not None:
51
+ await q.put(snapshot)
52
+ try:
53
+ yield q
54
+ finally:
55
+ async with self._lock:
56
+ if q in self._subscribers:
57
+ self._subscribers.remove(q)
58
+
59
+ @asynccontextmanager
60
+ async def session(
61
+ self, kind: Literal["single", "dialog"], total_turns: int = 1,
62
+ ) -> AsyncIterator["_Session"]:
63
+ session = _Session(bus=self, kind=kind, total_turns=total_turns)
64
+ async with self._lock:
65
+ self._current_session = session
66
+ await self.publish(
67
+ ProgressEvent(
68
+ type="start",
69
+ elapsed_s=0.0,
70
+ payload={"kind": kind, "total_turns": total_turns, "turn": 0},
71
+ ),
72
+ )
73
+ ticker = asyncio.create_task(session._tick_loop())
74
+ try:
75
+ yield session
76
+ await self.publish(
77
+ ProgressEvent(
78
+ type="done",
79
+ elapsed_s=session.elapsed(),
80
+ payload={
81
+ "kind": kind,
82
+ "seed_used": session.seed_used,
83
+ "turn": session.turn,
84
+ "total_turns": total_turns,
85
+ },
86
+ ),
87
+ )
88
+ except Exception as exc:
89
+ await self.publish(
90
+ ProgressEvent(
91
+ type="error",
92
+ elapsed_s=session.elapsed(),
93
+ payload={"message": str(exc)},
94
+ ),
95
+ )
96
+ raise
97
+ finally:
98
+ ticker.cancel()
99
+ try:
100
+ await ticker
101
+ except asyncio.CancelledError:
102
+ pass
103
+ async with self._lock:
104
+ if self._current_session is session:
105
+ self._current_session = None
106
+
107
+
108
+ @dataclass
109
+ class _Session:
110
+ bus: ProgressBus
111
+ kind: Literal["single", "dialog"]
112
+ total_turns: int
113
+ started_at: float = field(default_factory=time.monotonic)
114
+ turn: int = 0
115
+ seed_used: int | None = None
116
+
117
+ def elapsed(self) -> float:
118
+ return time.monotonic() - self.started_at
119
+
120
+ def set_seed(self, seed: int) -> None:
121
+ self.seed_used = seed
122
+
123
+ async def turn_complete(self, turn_index: int) -> None:
124
+ self.turn = turn_index
125
+ await self.bus.publish(
126
+ ProgressEvent(
127
+ type="turn_complete",
128
+ elapsed_s=self.elapsed(),
129
+ payload={
130
+ "turn": turn_index,
131
+ "total_turns": self.total_turns,
132
+ "kind": self.kind,
133
+ },
134
+ ),
135
+ )
136
+
137
+ async def _tick_loop(self) -> None:
138
+ try:
139
+ while True:
140
+ await asyncio.sleep(0.5)
141
+ await self.bus.publish(
142
+ ProgressEvent(
143
+ type="tick",
144
+ elapsed_s=self.elapsed(),
145
+ payload={
146
+ "kind": self.kind,
147
+ "turn": self.turn,
148
+ "total_turns": self.total_turns,
149
+ },
150
+ ),
151
+ )
152
+ except asyncio.CancelledError:
153
+ pass
154
+
155
+ def snapshot_event(self) -> ProgressEvent | None:
156
+ return ProgressEvent(
157
+ type="tick",
158
+ elapsed_s=self.elapsed(),
159
+ payload={
160
+ "kind": self.kind,
161
+ "turn": self.turn,
162
+ "total_turns": self.total_turns,
163
+ },
164
+ )
165
+
166
+
167
+ _BUS: ProgressBus | None = None
168
+
169
+
170
+ def get_bus() -> ProgressBus:
171
+ global _BUS
172
+ if _BUS is None:
173
+ _BUS = ProgressBus()
174
+ return _BUS
tests/conftest.py CHANGED
@@ -59,3 +59,14 @@ async def lifespan_ctx(app):
59
  """Run an ASGI app's lifespan startup/shutdown around an `httpx.AsyncClient`."""
60
  async with app.router.lifespan_context(app):
61
  yield
 
 
 
 
 
 
 
 
 
 
 
 
59
  """Run an ASGI app's lifespan startup/shutdown around an `httpx.AsyncClient`."""
60
  async with app.router.lifespan_context(app):
61
  yield
62
+
63
+
64
+ @pytest.fixture(autouse=False)
65
+ def reset_progress_bus():
66
+ """Reset server.progress._BUS so each test gets a fresh bus."""
67
+ import server.progress as p
68
+ p._BUS = None
69
+ try:
70
+ yield
71
+ finally:
72
+ p._BUS = None
tests/test_progress.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ import pytest
4
+
5
+ from server.progress import ProgressEvent, get_bus
6
+
7
+
8
+ pytestmark = pytest.mark.asyncio
9
+
10
+
11
+ async def test_subscribe_receives_published_events(reset_progress_bus):
12
+ bus = get_bus()
13
+ async with bus.subscribe() as q:
14
+ await bus.publish(ProgressEvent(type="tick", elapsed_s=0.1, payload={"foo": 1}))
15
+ evt = await asyncio.wait_for(q.get(), 0.5)
16
+ assert evt.type == "tick"
17
+ assert evt.payload == {"foo": 1}
18
+
19
+
20
+ async def test_two_subscribers_both_receive_events(reset_progress_bus):
21
+ bus = get_bus()
22
+ async with bus.subscribe() as q1, bus.subscribe() as q2:
23
+ await bus.publish(ProgressEvent(type="tick", elapsed_s=0.0))
24
+ a = await asyncio.wait_for(q1.get(), 0.5)
25
+ b = await asyncio.wait_for(q2.get(), 0.5)
26
+ assert a.type == "tick"
27
+ assert b.type == "tick"
28
+
29
+
30
+ async def test_session_emits_start_and_done(reset_progress_bus):
31
+ bus = get_bus()
32
+ received: list[ProgressEvent] = []
33
+
34
+ async def collect():
35
+ async with bus.subscribe() as q:
36
+ while True:
37
+ received.append(await q.get())
38
+ if received[-1].type == "done":
39
+ return
40
+
41
+ consumer = asyncio.create_task(collect())
42
+ await asyncio.sleep(0) # let subscriber register
43
+
44
+ async with bus.session("single", total_turns=1) as sess:
45
+ sess.set_seed(42)
46
+
47
+ await asyncio.wait_for(consumer, 1.0)
48
+ types = [e.type for e in received]
49
+ assert types[0] == "start"
50
+ assert types[-1] == "done"
51
+ done_payload = received[-1].payload
52
+ assert done_payload["seed_used"] == 42
53
+
54
+
55
+ async def test_session_emits_error_on_exception_and_reraises(reset_progress_bus):
56
+ bus = get_bus()
57
+ received: list[ProgressEvent] = []
58
+
59
+ async def collect():
60
+ async with bus.subscribe() as q:
61
+ while True:
62
+ received.append(await q.get())
63
+ if received[-1].type in ("done", "error"):
64
+ return
65
+
66
+ consumer = asyncio.create_task(collect())
67
+ await asyncio.sleep(0)
68
+
69
+ with pytest.raises(RuntimeError):
70
+ async with bus.session("single", total_turns=1):
71
+ raise RuntimeError("boom")
72
+
73
+ await asyncio.wait_for(consumer, 1.0)
74
+ types = [e.type for e in received]
75
+ assert "error" in types
76
+ assert any(e.payload.get("message") == "boom" for e in received)
77
+
78
+
79
+ async def test_turn_complete_event_carries_turn_payload(reset_progress_bus):
80
+ bus = get_bus()
81
+ received: list[ProgressEvent] = []
82
+
83
+ async def collect():
84
+ async with bus.subscribe() as q:
85
+ while True:
86
+ received.append(await q.get())
87
+ if received[-1].type == "done":
88
+ return
89
+
90
+ consumer = asyncio.create_task(collect())
91
+ await asyncio.sleep(0)
92
+
93
+ async with bus.session("dialog", total_turns=3) as sess:
94
+ await sess.turn_complete(1)
95
+ await sess.turn_complete(2)
96
+ await sess.turn_complete(3)
97
+
98
+ await asyncio.wait_for(consumer, 1.0)
99
+ turn_events = [e for e in received if e.type == "turn_complete"]
100
+ assert [e.payload["turn"] for e in turn_events] == [1, 2, 3]
101
+ assert all(e.payload["total_turns"] == 3 for e in turn_events)
102
+
103
+
104
+ async def test_late_subscriber_gets_snapshot(reset_progress_bus):
105
+ bus = get_bus()
106
+ received: list[ProgressEvent] = []
107
+
108
+ async def collect():
109
+ async with bus.subscribe() as q:
110
+ received.append(await asyncio.wait_for(q.get(), 1.0))
111
+
112
+ async with bus.session("dialog", total_turns=4) as sess:
113
+ await sess.turn_complete(2)
114
+ # join AFTER the session started
115
+ consumer = asyncio.create_task(collect())
116
+ await asyncio.wait_for(consumer, 1.0)
117
+
118
+ assert received[0].type == "tick"
119
+ assert received[0].payload["kind"] == "dialog"
120
+ assert received[0].payload["turn"] == 2
121
+ assert received[0].payload["total_turns"] == 4