Spaces:
Sleeping
Sleeping
| import asyncio | |
| import pytest | |
| from server.registry import Registry | |
| pytestmark = pytest.mark.asyncio | |
| async def test_get_or_load_loads_first_time(fake_classes): | |
| reg = Registry(adapter_classes=fake_classes, device="cpu") | |
| a = await reg.get_or_load("fake") | |
| assert a.loaded is True | |
| assert reg.status()["status"] == "loaded" | |
| assert reg.status()["id"] == "fake" | |
| async def test_get_or_load_reuses_active(fake_classes): | |
| reg = Registry(adapter_classes=fake_classes, device="cpu") | |
| a1 = await reg.get_or_load("fake") | |
| a2 = await reg.get_or_load("fake") | |
| assert a1 is a2 | |
| async def test_get_or_load_swaps_to_different(fake_classes): | |
| reg = Registry(adapter_classes=fake_classes, device="cpu") | |
| a = await reg.get_or_load("fake") | |
| b = await reg.get_or_load("fake-b") | |
| assert b.loaded is True | |
| assert a.unload_called is True | |
| assert reg.status()["id"] == "fake-b" | |
| async def test_get_or_load_unknown_id_raises(fake_classes): | |
| reg = Registry(adapter_classes=fake_classes, device="cpu") | |
| with pytest.raises(KeyError): | |
| await reg.get_or_load("nope") | |
| async def test_load_failure_sets_error_status(fake_classes): | |
| reg = Registry(adapter_classes=fake_classes, device="cpu") | |
| fake_classes["fake"].instances.clear() | |
| orig_init = fake_classes["fake"].__init__ | |
| def patched_init(self, device): | |
| orig_init(self, device) | |
| self.load_should_fail = True | |
| fake_classes["fake"].__init__ = patched_init | |
| try: | |
| with pytest.raises(RuntimeError): | |
| await reg.get_or_load("fake") | |
| s = reg.status() | |
| assert s["status"] == "error" | |
| assert "simulated" in s["last_error"] | |
| finally: | |
| fake_classes["fake"].__init__ = orig_init | |
| async def test_concurrent_activations_serialize(fake_classes): | |
| reg = Registry(adapter_classes=fake_classes, device="cpu") | |
| # All three resolve without error; the final state is the last requested model. | |
| await asyncio.gather( | |
| reg.get_or_load("fake"), | |
| reg.get_or_load("fake-b"), | |
| reg.get_or_load("fake"), | |
| ) | |
| s = reg.status() | |
| assert s["status"] == "loaded" | |
| assert s["id"] == "fake" | |
| async def test_emits_loading_then_loaded_events(fake_classes): | |
| reg = Registry(adapter_classes=fake_classes, device="cpu") | |
| seen: list[dict] = [] | |
| async def collect(): | |
| async for evt in reg.stream_events(): | |
| seen.append(evt) | |
| if evt["status"] == "loaded": | |
| return | |
| consumer = asyncio.create_task(collect()) | |
| await asyncio.sleep(0) | |
| await reg.get_or_load("fake") | |
| await asyncio.wait_for(consumer, timeout=2) | |
| statuses = [e["status"] for e in seen] | |
| assert "loading" in statuses | |
| assert "loaded" in statuses | |