# chatterbox-voice-studio/tests/test_registry.py
# techfreakworm's picture
# feat(registry): active-model swap with async lock and SSE event bus
# e6b3389 unverified
import asyncio
import pytest
from server.registry import Registry
# Module-level mark: applies ``pytest.mark.asyncio`` to every test in this
# file, so the ``async def`` tests below are run as coroutines
# (presumably by the pytest-asyncio plugin -- confirm in the project config).
pytestmark = pytest.mark.asyncio
async def test_get_or_load_loads_first_time(fake_classes):
    """Check that the first request for a model id loads it.

    ``fake_classes`` is a fixture supplied from outside this file; it is
    handed to the registry as its ``adapter_classes``.

    After ``get_or_load("fake")`` the returned adapter must report
    ``loaded is True`` and ``status()`` must show ``"loaded"`` for the id
    ``"fake"``.
    """
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    adapter = await registry.get_or_load("fake")

    assert adapter.loaded is True
    assert registry.status()["status"] == "loaded"
    assert registry.status()["id"] == "fake"
async def test_get_or_load_reuses_active(fake_classes):
    """Check that asking twice for the same id returns the same adapter object."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    first = await registry.get_or_load("fake")
    second = await registry.get_or_load("fake")

    # Identity, not equality: the second call must hand back the very same
    # instance rather than a fresh one.
    assert first is second
async def test_get_or_load_swaps_to_different(fake_classes):
    """Check that requesting a different id replaces the active adapter.

    Loads ``"fake"`` and then ``"fake-b"``. The new adapter must be loaded,
    the previous one must have ``unload_called`` set, and ``status()`` must
    report the new id.
    """
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    previous = await registry.get_or_load("fake")
    current = await registry.get_or_load("fake-b")

    assert current.loaded is True
    assert previous.unload_called is True
    assert registry.status()["id"] == "fake-b"
async def test_get_or_load_unknown_id_raises(fake_classes):
    """Check that an id the test expects to be unregistered raises ``KeyError``."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    # "nope" is expected not to be a key of the fake_classes fixture.
    with pytest.raises(KeyError):
        await registry.get_or_load("nope")
async def test_load_failure_sets_error_status(fake_classes):
    """Check that a failed load is surfaced through ``status()``.

    The ``"fake"`` adapter class is temporarily given a constructor that
    sets ``load_should_fail`` on each new instance. ``get_or_load`` is then
    expected to raise ``RuntimeError``, and ``status()`` must report
    ``"error"`` with a ``last_error`` containing ``"simulated"``.
    """
    registry = Registry(adapter_classes=fake_classes, device="cpu")

    adapter_cls = fake_classes["fake"]
    adapter_cls.instances.clear()
    original_init = adapter_cls.__init__

    def failing_init(self, device):
        # Construct the adapter as usual, then flag it; the flag presumably
        # makes the fake adapter's load step raise (defined outside this file).
        original_init(self, device)
        self.load_should_fail = True

    adapter_cls.__init__ = failing_init
    try:
        with pytest.raises(RuntimeError):
            await registry.get_or_load("fake")

        snapshot = registry.status()
        assert snapshot["status"] == "error"
        assert "simulated" in snapshot["last_error"]
    finally:
        # The class comes from a fixture, so put the real constructor back
        # even when an assertion above fails.
        adapter_cls.__init__ = original_init
async def test_concurrent_activations_serialize(fake_classes):
    """Check that overlapping ``get_or_load`` calls settle on the last request.

    Three activations are started together through ``asyncio.gather``. All
    of them must complete without raising, and the registry must end up
    ``"loaded"`` on the id requested last (``"fake"``).
    """
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    requested = ("fake", "fake-b", "fake")

    # Coroutines are created, and handed to gather, in request order.
    await asyncio.gather(*(registry.get_or_load(model) for model in requested))

    snapshot = registry.status()
    assert snapshot["status"] == "loaded"
    assert snapshot["id"] == "fake"
async def test_emits_loading_then_loaded_events(fake_classes):
    """Check that a load publishes ``"loading"`` and then ``"loaded"`` events.

    A background task consumes ``reg.stream_events()`` while the model is
    loaded and records every event up to and including the first one whose
    status is ``"loaded"``. Because collection stops at that event, any
    ``"loading"`` entry in the recorded list necessarily came before it.

    The consumer task is always cleaned up: if ``get_or_load`` raises, the
    still-pending task is cancelled and drained instead of being left
    running after the test has failed.
    """
    reg = Registry(adapter_classes=fake_classes, device="cpu")
    seen: list[dict] = []

    async def collect():
        async for evt in reg.stream_events():
            seen.append(evt)
            if evt["status"] == "loaded":
                return

    consumer = asyncio.create_task(collect())
    try:
        # Yield to the event loop once so the consumer task gets to start
        # iterating the event stream before the load is triggered.
        await asyncio.sleep(0)
        await reg.get_or_load("fake")
        # On timeout, wait_for cancels the consumer itself and raises.
        await asyncio.wait_for(consumer, timeout=2)
    finally:
        if not consumer.done():
            # Only reached when the load itself failed: stop the consumer
            # and wait for the cancellation to land. return_exceptions=True
            # keeps the resulting CancelledError from masking the original
            # failure.
            consumer.cancel()
            await asyncio.gather(consumer, return_exceptions=True)

    statuses = [e["status"] for e in seen]
    assert "loading" in statuses
    assert "loaded" in statuses