Spaces:
Sleeping
Sleeping
import asyncio

import pytest

from server.progress import ProgressEvent, get_bus

# Apply the asyncio marker to every test in this module so each `async def`
# test runs on an event loop (the marker is presumably provided by the
# pytest-asyncio plugin). Every test also takes the `reset_progress_bus`
# fixture, which is defined outside this file — presumably it resets the
# shared bus returned by get_bus() between tests; confirm in conftest.py.
pytestmark = pytest.mark.asyncio
async def test_subscribe_receives_published_events(reset_progress_bus):
    """An event published while a subscription is open reaches that subscriber
    with its type and payload intact."""
    event_bus = get_bus()
    async with event_bus.subscribe() as queue:
        outgoing = ProgressEvent(type="tick", elapsed_s=0.1, payload={"foo": 1})
        await event_bus.publish(outgoing)
        incoming = await asyncio.wait_for(queue.get(), 0.5)
        assert incoming.type == "tick"
        assert incoming.payload == {"foo": 1}
async def test_two_subscribers_both_receive_events(reset_progress_bus):
    """A single publish is fanned out to every open subscription."""
    event_bus = get_bus()
    async with event_bus.subscribe() as first_q, event_bus.subscribe() as second_q:
        await event_bus.publish(ProgressEvent(type="tick", elapsed_s=0.0))
        # Drain both queues first, then check what each one delivered.
        deliveries = [
            await asyncio.wait_for(queue.get(), 0.5) for queue in (first_q, second_q)
        ]
        assert [event.type for event in deliveries] == ["tick", "tick"]
async def test_session_emits_start_and_done(reset_progress_bus):
    """A session's event stream opens with "start" and closes with "done".

    The seed recorded through ``sess.set_seed`` must be reported back in the
    "done" payload under ``seed_used``.
    """
    bus = get_bus()
    received: list[ProgressEvent] = []
    # Set once the consumer's subscription is live. Waiting on it (instead of
    # a single ``asyncio.sleep(0)``) guarantees "start" is not published before
    # the subscriber exists, however many times subscribe() suspends.
    subscribed = asyncio.Event()

    async def collect():
        async with bus.subscribe() as q:
            subscribed.set()
            while True:
                received.append(await q.get())
                if received[-1].type == "done":
                    return

    consumer = asyncio.create_task(collect())
    try:
        await asyncio.wait_for(subscribed.wait(), 1.0)
        async with bus.session("single", total_turns=1) as sess:
            sess.set_seed(42)
        await asyncio.wait_for(consumer, 1.0)
    finally:
        # Don't leave the consumer pending if the session or a wait failed.
        if not consumer.done():
            consumer.cancel()
            await asyncio.gather(consumer, return_exceptions=True)

    types = [e.type for e in received]
    assert types[0] == "start"
    assert types[-1] == "done"
    done_payload = received[-1].payload
    assert done_payload["seed_used"] == 42
async def test_session_emits_error_on_exception_and_reraises(reset_progress_bus):
    """An exception raised inside a session is re-raised to the caller and
    announced to subscribers as an "error" event carrying its message."""
    bus = get_bus()
    received: list[ProgressEvent] = []
    # Set once the consumer's subscription is live, so no event can be
    # published before the subscriber is registered.
    subscribed = asyncio.Event()

    async def collect():
        async with bus.subscribe() as q:
            subscribed.set()
            while True:
                received.append(await q.get())
                if received[-1].type in ("done", "error"):
                    return

    consumer = asyncio.create_task(collect())
    try:
        await asyncio.wait_for(subscribed.wait(), 1.0)
        # match= ensures the caller sees the original exception, not a wrapper.
        with pytest.raises(RuntimeError, match="boom"):
            async with bus.session("single", total_turns=1):
                raise RuntimeError("boom")
        await asyncio.wait_for(consumer, 1.0)
    finally:
        # Don't leave the consumer pending if the session or a wait failed.
        if not consumer.done():
            consumer.cancel()
            await asyncio.gather(consumer, return_exceptions=True)

    types = [e.type for e in received]
    assert "error" in types
    # The message must be on the error event itself, not on some other event.
    error_events = [e for e in received if e.type == "error"]
    assert error_events[-1].payload.get("message") == "boom"
async def test_turn_complete_event_carries_turn_payload(reset_progress_bus):
    """Each ``sess.turn_complete(n)`` emits a "turn_complete" event whose
    payload reports the turn number and the session's ``total_turns``."""
    bus = get_bus()
    received: list[ProgressEvent] = []
    # Set once the consumer's subscription is live, so no event can be
    # published before the subscriber is registered.
    subscribed = asyncio.Event()

    async def collect():
        async with bus.subscribe() as q:
            subscribed.set()
            while True:
                received.append(await q.get())
                if received[-1].type == "done":
                    return

    consumer = asyncio.create_task(collect())
    try:
        await asyncio.wait_for(subscribed.wait(), 1.0)
        async with bus.session("dialog", total_turns=3) as sess:
            await sess.turn_complete(1)
            await sess.turn_complete(2)
            await sess.turn_complete(3)
        await asyncio.wait_for(consumer, 1.0)
    finally:
        # Don't leave the consumer pending if the session or a wait failed.
        if not consumer.done():
            consumer.cancel()
            await asyncio.gather(consumer, return_exceptions=True)

    turn_events = [e for e in received if e.type == "turn_complete"]
    assert [e.payload["turn"] for e in turn_events] == [1, 2, 3]
    assert all(e.payload["total_turns"] == 3 for e in turn_events)
async def test_late_subscriber_gets_snapshot(reset_progress_bus):
    """A subscriber that joins mid-session first receives a "tick" snapshot of
    the session's kind and turn progress so far."""
    event_bus = get_bus()
    received: list[ProgressEvent] = []

    async def grab_first_event():
        async with event_bus.subscribe() as queue:
            first_event = await asyncio.wait_for(queue.get(), 1.0)
            received.append(first_event)

    async with event_bus.session("dialog", total_turns=4) as sess:
        await sess.turn_complete(2)
        # Subscribe only now, after the session has started and progressed.
        late_joiner = asyncio.create_task(grab_first_event())
        await asyncio.wait_for(late_joiner, 1.0)

    snapshot = received[0]
    assert snapshot.type == "tick"
    expected = {"kind": "dialog", "turn": 2, "total_turns": 4}
    assert {key: snapshot.payload[key] for key in expected} == expected