feat(registry): active-model swap with async lock and SSE event bus
Browse files- server/registry.py +98 -0
- tests/conftest.py +51 -0
- tests/test_registry.py +90 -0
server/registry.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Active-model registry with async swap lock and SSE event bus."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import asyncio
|
| 5 |
+
import gc
|
| 6 |
+
from typing import AsyncIterator
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from server.models.base import ModelAdapter
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Registry:
|
| 14 |
+
def __init__(self, adapter_classes: dict[str, type], device: str):
|
| 15 |
+
self._classes = adapter_classes
|
| 16 |
+
self._device = device
|
| 17 |
+
self._active: ModelAdapter | None = None
|
| 18 |
+
self._active_id: str | None = None
|
| 19 |
+
self._status: str = "idle"
|
| 20 |
+
self._last_error: str | None = None
|
| 21 |
+
self._lock = asyncio.Lock()
|
| 22 |
+
self._subscribers: list[asyncio.Queue[dict]] = []
|
| 23 |
+
|
| 24 |
+
@property
|
| 25 |
+
def device(self) -> str:
|
| 26 |
+
return self._device
|
| 27 |
+
|
| 28 |
+
def status(self) -> dict:
|
| 29 |
+
return {"id": self._active_id, "status": self._status, "last_error": self._last_error}
|
| 30 |
+
|
| 31 |
+
def list_models(self) -> list[dict]:
|
| 32 |
+
return [
|
| 33 |
+
{
|
| 34 |
+
"id": cls.id,
|
| 35 |
+
"label": cls.label,
|
| 36 |
+
"description": cls.description,
|
| 37 |
+
"languages": [l.model_dump() for l in cls.languages],
|
| 38 |
+
"paralinguistic_tags": cls.paralinguistic_tags,
|
| 39 |
+
"supports_voice_clone": cls.supports_voice_clone,
|
| 40 |
+
"params": [p.model_dump() for p in cls.params],
|
| 41 |
+
}
|
| 42 |
+
for cls in self._classes.values()
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
async def get_or_load(self, model_id: str) -> ModelAdapter:
|
| 46 |
+
if model_id not in self._classes:
|
| 47 |
+
raise KeyError(model_id)
|
| 48 |
+
async with self._lock:
|
| 49 |
+
if self._active_id == model_id and self._active is not None:
|
| 50 |
+
return self._active
|
| 51 |
+
await self._publish({"id": model_id, "status": "loading"})
|
| 52 |
+
self._status = "loading"
|
| 53 |
+
self._last_error = None
|
| 54 |
+
if self._active is not None:
|
| 55 |
+
try:
|
| 56 |
+
self._active.unload()
|
| 57 |
+
finally:
|
| 58 |
+
self._active = None
|
| 59 |
+
self._free_caches()
|
| 60 |
+
try:
|
| 61 |
+
instance = self._classes[model_id](self._device)
|
| 62 |
+
instance.load()
|
| 63 |
+
except Exception as exc:
|
| 64 |
+
self._status = "error"
|
| 65 |
+
self._last_error = str(exc)
|
| 66 |
+
self._active_id = None
|
| 67 |
+
await self._publish({"id": model_id, "status": "error", "error": str(exc)})
|
| 68 |
+
raise
|
| 69 |
+
self._active = instance
|
| 70 |
+
self._active_id = model_id
|
| 71 |
+
self._status = "loaded"
|
| 72 |
+
await self._publish({"id": model_id, "status": "loaded"})
|
| 73 |
+
return instance
|
| 74 |
+
|
| 75 |
+
async def stream_events(self) -> AsyncIterator[dict]:
|
| 76 |
+
q: asyncio.Queue[dict] = asyncio.Queue()
|
| 77 |
+
self._subscribers.append(q)
|
| 78 |
+
try:
|
| 79 |
+
await q.put({"id": self._active_id, "status": self._status})
|
| 80 |
+
while True:
|
| 81 |
+
yield await q.get()
|
| 82 |
+
finally:
|
| 83 |
+
self._subscribers.remove(q)
|
| 84 |
+
|
| 85 |
+
async def _publish(self, event: dict) -> None:
|
| 86 |
+
for q in list(self._subscribers):
|
| 87 |
+
await q.put(event)
|
| 88 |
+
|
| 89 |
+
def _free_caches(self) -> None:
|
| 90 |
+
gc.collect()
|
| 91 |
+
if torch.cuda.is_available():
|
| 92 |
+
torch.cuda.empty_cache()
|
| 93 |
+
mps = getattr(torch.backends, "mps", None)
|
| 94 |
+
if mps and mps.is_available():
|
| 95 |
+
try:
|
| 96 |
+
torch.mps.empty_cache() # type: ignore[attr-defined]
|
| 97 |
+
except Exception:
|
| 98 |
+
pass
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared test fixtures."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import asyncio
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from server.schemas import Lang, ParamSpec
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class FakeAdapter:
|
| 12 |
+
"""Minimal in-memory adapter for unit tests."""
|
| 13 |
+
id = "fake"
|
| 14 |
+
label = "Fake"
|
| 15 |
+
description = "Test fake"
|
| 16 |
+
languages = [Lang(code="en", label="English")]
|
| 17 |
+
paralinguistic_tags: list[str] = ["[laugh]"]
|
| 18 |
+
supports_voice_clone = True
|
| 19 |
+
params = [ParamSpec(name="t", label="T", type="float", default=0.5, min=0.0, max=1.0)]
|
| 20 |
+
|
| 21 |
+
instances: list["FakeAdapter"] = []
|
| 22 |
+
|
| 23 |
+
def __init__(self, device: str):
|
| 24 |
+
self.device = device
|
| 25 |
+
self.loaded = False
|
| 26 |
+
self.unload_called = False
|
| 27 |
+
self.load_should_fail = False
|
| 28 |
+
FakeAdapter.instances.append(self)
|
| 29 |
+
|
| 30 |
+
def load(self) -> None:
|
| 31 |
+
if self.load_should_fail:
|
| 32 |
+
raise RuntimeError("simulated load failure")
|
| 33 |
+
self.loaded = True
|
| 34 |
+
|
| 35 |
+
def unload(self) -> None:
|
| 36 |
+
self.unload_called = True
|
| 37 |
+
self.loaded = False
|
| 38 |
+
|
| 39 |
+
def generate(self, text, reference_wav_path, language, params):
|
| 40 |
+
return (b"FAKEWAV", 24000)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class FakeAdapterB(FakeAdapter):
|
| 44 |
+
id = "fake-b"
|
| 45 |
+
label = "Fake B"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@pytest.fixture
|
| 49 |
+
def fake_classes():
|
| 50 |
+
FakeAdapter.instances.clear()
|
| 51 |
+
return {FakeAdapter.id: FakeAdapter, FakeAdapterB.id: FakeAdapterB}
|
tests/test_registry.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from server.registry import Registry
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
pytestmark = pytest.mark.asyncio
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
async def test_get_or_load_loads_first_time(fake_classes):
|
| 12 |
+
reg = Registry(adapter_classes=fake_classes, device="cpu")
|
| 13 |
+
a = await reg.get_or_load("fake")
|
| 14 |
+
assert a.loaded is True
|
| 15 |
+
assert reg.status()["status"] == "loaded"
|
| 16 |
+
assert reg.status()["id"] == "fake"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
async def test_get_or_load_reuses_active(fake_classes):
|
| 20 |
+
reg = Registry(adapter_classes=fake_classes, device="cpu")
|
| 21 |
+
a1 = await reg.get_or_load("fake")
|
| 22 |
+
a2 = await reg.get_or_load("fake")
|
| 23 |
+
assert a1 is a2
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
async def test_get_or_load_swaps_to_different(fake_classes):
|
| 27 |
+
reg = Registry(adapter_classes=fake_classes, device="cpu")
|
| 28 |
+
a = await reg.get_or_load("fake")
|
| 29 |
+
b = await reg.get_or_load("fake-b")
|
| 30 |
+
assert b.loaded is True
|
| 31 |
+
assert a.unload_called is True
|
| 32 |
+
assert reg.status()["id"] == "fake-b"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
async def test_get_or_load_unknown_id_raises(fake_classes):
|
| 36 |
+
reg = Registry(adapter_classes=fake_classes, device="cpu")
|
| 37 |
+
with pytest.raises(KeyError):
|
| 38 |
+
await reg.get_or_load("nope")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def test_load_failure_sets_error_status(fake_classes):
|
| 42 |
+
reg = Registry(adapter_classes=fake_classes, device="cpu")
|
| 43 |
+
fake_classes["fake"].instances.clear()
|
| 44 |
+
orig_init = fake_classes["fake"].__init__
|
| 45 |
+
|
| 46 |
+
def patched_init(self, device):
|
| 47 |
+
orig_init(self, device)
|
| 48 |
+
self.load_should_fail = True
|
| 49 |
+
|
| 50 |
+
fake_classes["fake"].__init__ = patched_init
|
| 51 |
+
try:
|
| 52 |
+
with pytest.raises(RuntimeError):
|
| 53 |
+
await reg.get_or_load("fake")
|
| 54 |
+
s = reg.status()
|
| 55 |
+
assert s["status"] == "error"
|
| 56 |
+
assert "simulated" in s["last_error"]
|
| 57 |
+
finally:
|
| 58 |
+
fake_classes["fake"].__init__ = orig_init
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
async def test_concurrent_activations_serialize(fake_classes):
|
| 62 |
+
reg = Registry(adapter_classes=fake_classes, device="cpu")
|
| 63 |
+
# All three resolve without error; the final state is the last requested model.
|
| 64 |
+
await asyncio.gather(
|
| 65 |
+
reg.get_or_load("fake"),
|
| 66 |
+
reg.get_or_load("fake-b"),
|
| 67 |
+
reg.get_or_load("fake"),
|
| 68 |
+
)
|
| 69 |
+
s = reg.status()
|
| 70 |
+
assert s["status"] == "loaded"
|
| 71 |
+
assert s["id"] == "fake"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
async def test_emits_loading_then_loaded_events(fake_classes):
|
| 75 |
+
reg = Registry(adapter_classes=fake_classes, device="cpu")
|
| 76 |
+
seen: list[dict] = []
|
| 77 |
+
|
| 78 |
+
async def collect():
|
| 79 |
+
async for evt in reg.stream_events():
|
| 80 |
+
seen.append(evt)
|
| 81 |
+
if evt["status"] == "loaded":
|
| 82 |
+
return
|
| 83 |
+
|
| 84 |
+
consumer = asyncio.create_task(collect())
|
| 85 |
+
await asyncio.sleep(0)
|
| 86 |
+
await reg.get_or_load("fake")
|
| 87 |
+
await asyncio.wait_for(consumer, timeout=2)
|
| 88 |
+
statuses = [e["status"] for e in seen]
|
| 89 |
+
assert "loading" in statuses
|
| 90 |
+
assert "loaded" in statuses
|