techfreakworm committed on
Commit
e6b3389
·
unverified ·
1 Parent(s): f9569f4

feat(registry): active-model swap with async lock and SSE event bus

Browse files
Files changed (3) hide show
  1. server/registry.py +98 -0
  2. tests/conftest.py +51 -0
  3. tests/test_registry.py +90 -0
server/registry.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Active-model registry with async swap lock and SSE event bus."""
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ import gc
6
+ from typing import AsyncIterator
7
+
8
+ import torch
9
+
10
+ from server.models.base import ModelAdapter
11
+
12
+
13
class Registry:
    """Keep at most one model adapter loaded and broadcast its lifecycle.

    Swaps are serialized through an ``asyncio.Lock``. Every state change
    (``loading`` / ``loaded`` / ``error``) is pushed to all queues registered
    through :meth:`stream_events`.
    """

    def __init__(self, adapter_classes: dict[str, type], device: str):
        """Store the adapter classes (keyed by model id) and the target device.

        Nothing is loaded here; the registry starts in the ``"idle"`` state.
        """
        self._classes = adapter_classes
        self._device = device
        self._active: ModelAdapter | None = None
        self._active_id: str | None = None
        self._status: str = "idle"
        self._last_error: str | None = None
        self._lock = asyncio.Lock()
        self._subscribers: list[asyncio.Queue[dict]] = []

    @property
    def device(self) -> str:
        """Device string handed to every adapter constructor."""
        return self._device

    def status(self) -> dict:
        """Return a snapshot with the active id, status and last error."""
        return {
            "id": self._active_id,
            "status": self._status,
            "last_error": self._last_error,
        }

    def list_models(self) -> list[dict]:
        """Describe every registered adapter class as a plain dict."""
        return [
            {
                "id": cls.id,
                "label": cls.label,
                "description": cls.description,
                "languages": [lang.model_dump() for lang in cls.languages],
                "paralinguistic_tags": cls.paralinguistic_tags,
                "supports_voice_clone": cls.supports_voice_clone,
                "params": [spec.model_dump() for spec in cls.params],
            }
            for cls in self._classes.values()
        ]

    async def get_or_load(self, model_id: str) -> ModelAdapter:
        """Return the adapter for ``model_id``, swapping it in if necessary.

        If ``model_id`` is already active the existing instance is returned.
        Otherwise the current adapter (if any) is unloaded, caches are freed,
        and a new adapter is constructed and loaded.

        Raises:
            KeyError: ``model_id`` is not a registered adapter id.
            Exception: whatever the old adapter's ``unload()`` or the new
                adapter's constructor / ``load()`` raised. In both cases the
                registry is left with no active adapter, status ``"error"``,
                ``last_error`` set, and an ``"error"`` event published.
        """
        if model_id not in self._classes:
            raise KeyError(model_id)
        async with self._lock:
            if self._active_id == model_id and self._active is not None:
                return self._active
            await self._publish({"id": model_id, "status": "loading"})
            self._status = "loading"
            self._last_error = None
            try:
                if self._active is not None:
                    try:
                        self._active.unload()
                    finally:
                        # Drop our reference first so gc.collect() inside
                        # _free_caches() can actually reclaim the old adapter.
                        self._active = None
                        self._free_caches()
                # NOTE(review): construction and load() run synchronously on
                # the event-loop thread, so subscribers cannot consume the
                # "loading" event until they return. Consider
                # asyncio.to_thread if adapters are safe to load off-thread.
                instance = self._classes[model_id](self._device)
                instance.load()
            except Exception as exc:
                # Covers a failing unload as well as a failing load, so the
                # status never stays stuck at "loading" with a stale id.
                self._status = "error"
                self._last_error = str(exc)
                self._active_id = None
                await self._publish(
                    {"id": model_id, "status": "error", "error": str(exc)}
                )
                raise
            self._active = instance
            self._active_id = model_id
            self._status = "loaded"
            await self._publish({"id": model_id, "status": "loaded"})
            return instance

    async def stream_events(self) -> AsyncIterator[dict]:
        """Yield the current state, then every event published afterwards.

        The subscriber queue is removed again when the generator is closed.
        """
        q: asyncio.Queue[dict] = asyncio.Queue()
        self._subscribers.append(q)
        try:
            # Seed the queue so a new subscriber sees the current state at once.
            await q.put({"id": self._active_id, "status": self._status})
            while True:
                yield await q.get()
        finally:
            self._subscribers.remove(q)

    async def _publish(self, event: dict) -> None:
        """Put ``event`` on every subscriber queue."""
        # Iterate over a copy: subscribers may unregister while we publish.
        for q in list(self._subscribers):
            await q.put(event)

    def _free_caches(self) -> None:
        """Run the garbage collector and empty CUDA / MPS caches if present."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        mps = getattr(torch.backends, "mps", None)
        if mps and mps.is_available():
            try:
                torch.mps.empty_cache()  # type: ignore[attr-defined]
            except Exception:
                # Deliberately best-effort: a failed cache flush must not
                # abort a model swap.
                pass
tests/conftest.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared test fixtures."""
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+
6
+ import pytest
7
+
8
+ from server.schemas import Lang, ParamSpec
9
+
10
+
11
class FakeAdapter:
    """In-memory stand-in for a model adapter, used by the unit tests.

    Every constructed instance is appended to the class-level ``instances``
    list so tests can inspect what was created.
    """

    # Static metadata read by the registry when listing models.
    id = "fake"
    label = "Fake"
    description = "Test fake"
    languages = [Lang(code="en", label="English")]
    paralinguistic_tags: list[str] = ["[laugh]"]
    supports_voice_clone = True
    params = [
        ParamSpec(name="t", label="T", type="float", default=0.5, min=0.0, max=1.0),
    ]

    # Shared across the class (and subclasses that do not override it).
    instances: list["FakeAdapter"] = []

    def __init__(self, device: str):
        """Record the device, reset all flags and register the instance."""
        self.load_should_fail = False
        self.unload_called = False
        self.loaded = False
        self.device = device
        FakeAdapter.instances.append(self)

    def load(self) -> None:
        """Mark the adapter loaded, or raise if ``load_should_fail`` is set."""
        if not self.load_should_fail:
            self.loaded = True
            return
        raise RuntimeError("simulated load failure")

    def unload(self) -> None:
        """Mark the adapter as unloaded and remember that unload was called."""
        self.loaded = False
        self.unload_called = True

    def generate(self, text, reference_wav_path, language, params):
        """Return a fixed ``(bytes, int)`` pair regardless of the inputs."""
        return b"FAKEWAV", 24000
41
+
42
+
43
class FakeAdapterB(FakeAdapter):
    """Second fake adapter; differs from ``FakeAdapter`` only in id and label."""

    id = "fake-b"
    label = "Fake B"
46
+
47
+
48
@pytest.fixture
def fake_classes():
    """Return the fake adapter classes keyed by id, with ``instances`` reset."""
    FakeAdapter.instances.clear()
    return {cls.id: cls for cls in (FakeAdapter, FakeAdapterB)}
tests/test_registry.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ import pytest
4
+
5
+ from server.registry import Registry
6
+
7
+
8
+ pytestmark = pytest.mark.asyncio
9
+
10
+
11
async def test_get_or_load_loads_first_time(fake_classes):
    """A first request loads the adapter and reports it as the active model."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    adapter = await registry.get_or_load("fake")
    assert adapter.loaded is True
    snapshot = registry.status()
    assert snapshot["status"] == "loaded"
    assert snapshot["id"] == "fake"
17
+
18
+
19
async def test_get_or_load_reuses_active(fake_classes):
    """Requesting the already-active model returns the very same instance."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    first = await registry.get_or_load("fake")
    second = await registry.get_or_load("fake")
    assert first is second
24
+
25
+
26
async def test_get_or_load_swaps_to_different(fake_classes):
    """Requesting another model unloads the old adapter and activates the new."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    old_adapter = await registry.get_or_load("fake")
    new_adapter = await registry.get_or_load("fake-b")
    assert new_adapter.loaded is True
    assert old_adapter.unload_called is True
    assert registry.status()["id"] == "fake-b"
33
+
34
+
35
async def test_get_or_load_unknown_id_raises(fake_classes):
    """An id that is not registered is rejected with ``KeyError``."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    with pytest.raises(KeyError):
        await registry.get_or_load("nope")
39
+
40
+
41
async def test_load_failure_sets_error_status(fake_classes):
    """A failing ``load()`` propagates and leaves the registry in ``error``."""

    # Use a local subclass instead of patching ``__init__`` on the shared fake
    # class, so no global state needs restoring even if the test fails.
    class FailingAdapter(fake_classes["fake"]):
        def __init__(self, device):
            super().__init__(device)
            self.load_should_fail = True

    reg = Registry(
        adapter_classes={**fake_classes, "fake": FailingAdapter},
        device="cpu",
    )
    with pytest.raises(RuntimeError):
        await reg.get_or_load("fake")
    s = reg.status()
    assert s["status"] == "error"
    assert "simulated" in s["last_error"]
59
+
60
+
61
async def test_concurrent_activations_serialize(fake_classes):
    """Concurrent requests all succeed and end on the last requested model."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    requested = ("fake", "fake-b", "fake")
    # All three resolve without error; the final state is the last requested model.
    await asyncio.gather(*(registry.get_or_load(model_id) for model_id in requested))
    snapshot = registry.status()
    assert snapshot["status"] == "loaded"
    assert snapshot["id"] == "fake"
72
+
73
+
74
async def test_emits_loading_then_loaded_events(fake_classes):
    """A subscriber sees ``loading`` strictly before ``loaded`` during a load."""
    reg = Registry(adapter_classes=fake_classes, device="cpu")
    seen: list[dict] = []

    async def collect():
        # Stop as soon as the terminal "loaded" event arrives.
        async for evt in reg.stream_events():
            seen.append(evt)
            if evt["status"] == "loaded":
                return

    consumer = asyncio.create_task(collect())
    # Yield once so the consumer registers its queue before the load starts.
    await asyncio.sleep(0)
    await reg.get_or_load("fake")
    await asyncio.wait_for(consumer, timeout=2)
    statuses = [e["status"] for e in seen]
    assert "loading" in statuses
    assert "loaded" in statuses
    # The test name promises ordering, so check it and not just membership.
    assert statuses.index("loading") < statuses.index("loaded")