| import asyncio |
|
|
| import pytest |
|
|
| from server.registry import Registry |
|
|
|
|
# Module-level marker: applies the asyncio mark to every test in this file,
# so the `async def` tests below are collected as asyncio tests.
pytestmark = pytest.mark.asyncio
|
|
|
|
async def test_get_or_load_loads_first_time(fake_classes):
    """The first ``get_or_load`` for an id returns a loaded adapter.

    Also checks that ``status()`` then reports ``"loaded"`` for that id.
    """
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    adapter = await registry.get_or_load("fake")

    assert adapter.loaded is True
    # status() is queried separately for each field, as two independent reads.
    assert registry.status()["status"] == "loaded"
    assert registry.status()["id"] == "fake"
|
|
|
|
async def test_get_or_load_reuses_active(fake_classes):
    """Asking for the same id twice returns the very same adapter object."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    first = await registry.get_or_load("fake")
    second = await registry.get_or_load("fake")

    # Identity, not equality: the second call must not build a new adapter.
    assert first is second
|
|
|
|
async def test_get_or_load_swaps_to_different(fake_classes):
    """Requesting a second id loads it and unloads the previous adapter.

    After the swap, ``status()`` must report the newly requested id.
    """
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    previous = await registry.get_or_load("fake")
    current = await registry.get_or_load("fake-b")

    assert current.loaded is True
    # The fake adapter records that its unload hook ran.
    assert previous.unload_called is True
    assert registry.status()["id"] == "fake-b"
|
|
|
|
async def test_get_or_load_unknown_id_raises(fake_classes):
    """``get_or_load`` raises ``KeyError`` for the id ``"nope"``.

    NOTE(review): presumably ``"nope"`` is absent from ``fake_classes`` --
    confirm against the fixture.
    """
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    with pytest.raises(KeyError):
        await registry.get_or_load("nope")
|
|
|
|
async def test_load_failure_sets_error_status(fake_classes):
    """A failing load surfaces ``RuntimeError`` and leaves an error status.

    The fake adapter class is patched for the duration of the test so every
    new instance has ``load_should_fail`` set; the original ``__init__`` is
    restored afterwards whatever the outcome.
    """
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    fake_cls = fake_classes["fake"]
    fake_cls.instances.clear()
    original_init = fake_cls.__init__

    def failing_init(self, device):
        # Run the real constructor first, then flip the failure switch.
        original_init(self, device)
        self.load_should_fail = True

    fake_cls.__init__ = failing_init
    try:
        with pytest.raises(RuntimeError):
            await registry.get_or_load("fake")

        snapshot = registry.status()
        assert snapshot["status"] == "error"
        # NOTE(review): presumably the fake's failure message contains
        # "simulated" -- confirm against the fixture.
        assert "simulated" in snapshot["last_error"]
    finally:
        # Undo the patch so other tests see the unmodified fake class.
        fake_cls.__init__ = original_init
|
|
|
|
async def test_concurrent_activations_serialize(fake_classes):
    """Concurrent ``get_or_load`` calls end in one consistent loaded state.

    Three requests are issued together through ``asyncio.gather``; afterwards
    the registry must report ``"loaded"`` for ``"fake"``, the id of the last
    request passed to ``gather``.
    """
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    # Coroutines are created and handed to gather in this exact order.
    requested_ids = ("fake", "fake-b", "fake")
    await asyncio.gather(
        *(registry.get_or_load(adapter_id) for adapter_id in requested_ids)
    )

    snapshot = registry.status()
    assert snapshot["status"] == "loaded"
    assert snapshot["id"] == "fake"
|
|
|
|
async def test_emits_loading_then_loaded_events(fake_classes):
    """Loading an adapter emits a ``"loading"`` and a ``"loaded"`` event.

    A background task consumes ``reg.stream_events()`` and stops at the first
    event whose status is ``"loaded"``. Because collection stops there, a
    ``"loading"`` entry in the collected list necessarily precedes it.
    """
    reg = Registry(adapter_classes=fake_classes, device="cpu")
    seen: list[dict] = []

    async def collect():
        # Record every event up to and including the first "loaded" one.
        async for evt in reg.stream_events():
            seen.append(evt)
            if evt["status"] == "loaded":
                return

    consumer = asyncio.create_task(collect())
    try:
        # Yield once so the consumer task starts running before the load.
        # NOTE(review): assumes one loop iteration is enough for
        # stream_events() to begin receiving events -- confirm.
        await asyncio.sleep(0)
        await reg.get_or_load("fake")
        await asyncio.wait_for(consumer, timeout=2)
    finally:
        # If the load raised, the consumer would otherwise be left pending
        # on the event loop; cancel it and drain the cancellation.
        if not consumer.done():
            consumer.cancel()
            await asyncio.gather(consumer, return_exceptions=True)

    statuses = [e["status"] for e in seen]
    assert "loading" in statuses
    assert "loaded" in statuses
|
|