File size: 3,430 Bytes
"""Active-model registry with async swap lock and SSE event bus."""
from __future__ import annotations
import asyncio
import gc
from typing import AsyncIterator
import torch
from server.models.base import ModelAdapter
class Registry:
def __init__(self, adapter_classes: dict[str, type], device: str):
self._classes = adapter_classes
self._device = device
self._active: ModelAdapter | None = None
self._active_id: str | None = None
self._status: str = "idle"
self._last_error: str | None = None
self._lock = asyncio.Lock()
self._subscribers: list[asyncio.Queue[dict]] = []
@property
def device(self) -> str:
return self._device
def status(self) -> dict:
return {"id": self._active_id, "status": self._status, "last_error": self._last_error}
def list_models(self) -> list[dict]:
return [
{
"id": cls.id,
"label": cls.label,
"description": cls.description,
"languages": [l.model_dump() for l in cls.languages],
"paralinguistic_tags": cls.paralinguistic_tags,
"supports_voice_clone": cls.supports_voice_clone,
"params": [p.model_dump() for p in cls.params],
}
for cls in self._classes.values()
]
async def get_or_load(self, model_id: str) -> ModelAdapter:
if model_id not in self._classes:
raise KeyError(model_id)
async with self._lock:
if self._active_id == model_id and self._active is not None:
return self._active
await self._publish({"id": model_id, "status": "loading"})
self._status = "loading"
self._last_error = None
if self._active is not None:
try:
self._active.unload()
finally:
self._active = None
self._free_caches()
try:
instance = self._classes[model_id](self._device)
instance.load()
except Exception as exc:
self._status = "error"
self._last_error = str(exc)
self._active_id = None
await self._publish({"id": model_id, "status": "error", "error": str(exc)})
raise
self._active = instance
self._active_id = model_id
self._status = "loaded"
await self._publish({"id": model_id, "status": "loaded"})
return instance
async def stream_events(self) -> AsyncIterator[dict]:
q: asyncio.Queue[dict] = asyncio.Queue()
self._subscribers.append(q)
try:
await q.put({"id": self._active_id, "status": self._status})
while True:
yield await q.get()
finally:
self._subscribers.remove(q)
async def _publish(self, event: dict) -> None:
for q in list(self._subscribers):
await q.put(event)
def _free_caches(self) -> None:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
mps = getattr(torch.backends, "mps", None)
if mps and mps.is_available():
try:
torch.mps.empty_cache() # type: ignore[attr-defined]
except Exception:
pass
|