Spaces:
Sleeping
Sleeping
File size: 3,350 Bytes
71c1ad2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 | # app/models/model_registry.py
# Singleton model registry — loads and manages all ML models
from app.models.text_model import TextToxicityModel
from app.models.image_model import ImageClassificationModel
from app.models.clip_model import CLIPModel
from app.observability.logging import get_logger
logger = get_logger(__name__)


class ModelRegistry:
    """
    Central registry for all ML models.

    Models are constructed and loaded by ``load_all()`` (called during
    application startup) and then reused across requests. The accessor
    properties raise ``RuntimeError`` for any model that is missing or whose
    ``is_loaded`` flag is False, so callers never receive an unloaded model.
    """

    def __init__(self) -> None:
        # Each slot stays None until load_all() has constructed that model.
        self._text_model: TextToxicityModel | None = None
        self._image_model: ImageClassificationModel | None = None
        self._clip_model: CLIPModel | None = None

    @staticmethod
    def _is_loaded(model) -> bool:
        """Return True when *model* exists and reports itself as loaded."""
        return model is not None and model.is_loaded

    def _load_one(
        self,
        attr: str,
        factory,
        *,
        label: str,
        failure_event: str,
        required: bool,
    ) -> bool:
        """
        Construct one model, store it in the slot *attr*, then call ``load()``.

        Args:
            attr: Name of the private slot to store into (e.g. "_text_model").
            factory: Zero-argument callable (the model class) building the model.
            label: Value logged as ``model=`` in the "registry_loading" event.
            failure_event: Log event emitted if construction or loading raises.
            required: Log failures at error level when True, warning otherwise.

        Returns:
            The model's ``is_loaded`` flag after ``load()``, or False if an
            exception was raised.
        """
        logger.info("registry_loading", model=label)
        try:
            model = factory()
            # Stored before load() (as originally), so a failed load leaves the
            # unloaded instance in place; accessors reject it via is_loaded.
            # If construction itself raises, the previous slot value is kept.
            setattr(self, attr, model)
            model.load()
        except Exception as e:
            # Best-effort: one failing model must not stop the others loading.
            log = logger.error if required else logger.warning
            log(failure_event, error=str(e))
            return False
        return bool(model.is_loaded)

    async def load_all(self) -> dict[str, bool]:
        """
        Load all models. Called during application startup.

        Models are loaded one after another. A failure is logged and recorded
        as False without preventing the remaining models from loading.

        Returns:
            Dict of model name ("text", "image", "clip") → loaded status, taken
            from each model's ``is_loaded`` flag.
        """
        # NOTE(review): load() runs synchronously inside this coroutine, so the
        # event loop is blocked until every model has loaded. Presumably fine at
        # startup; consider asyncio.to_thread if that matters — confirm.
        results: dict[str, bool] = {}

        # Text model (required)
        results["text"] = self._load_one(
            "_text_model",
            TextToxicityModel,
            label="text_toxicity",
            failure_event="text_model_load_failed",
            required=True,
        )

        # Image model (required)
        results["image"] = self._load_one(
            "_image_model",
            ImageClassificationModel,
            label="image_classifier",
            failure_event="image_model_load_failed",
            required=True,
        )

        # CLIP model (optional — only for deep analysis)
        results["clip"] = self._load_one(
            "_clip_model",
            CLIPModel,
            label="clip",
            failure_event="clip_model_load_failed",
            required=False,
        )

        logger.info("registry_loaded", results=results)
        return results

    @property
    def text_model(self) -> TextToxicityModel:
        """
        Return the loaded text toxicity model.

        Raises:
            RuntimeError: If the model was never constructed or is not loaded.
        """
        if not self._is_loaded(self._text_model):
            raise RuntimeError("Text model not available")
        return self._text_model

    @property
    def image_model(self) -> ImageClassificationModel:
        """
        Return the loaded image classification model.

        Raises:
            RuntimeError: If the model was never constructed or is not loaded.
        """
        if not self._is_loaded(self._image_model):
            raise RuntimeError("Image model not available")
        return self._image_model

    @property
    def clip_model(self) -> CLIPModel:
        """
        Return the loaded CLIP model.

        Raises:
            RuntimeError: If the model was never constructed or is not loaded
                (previously an unloaded instance could be returned here).
        """
        if not self.clip_available:
            raise RuntimeError("CLIP model not available")
        return self._clip_model

    @property
    def clip_available(self) -> bool:
        """True when the optional CLIP model exists and is loaded."""
        return self._is_loaded(self._clip_model)

    def get_status(self) -> dict:
        """Get health status of all models (model key → loaded flag)."""
        return {
            "text_model": self._is_loaded(self._text_model),
            "image_model": self._is_loaded(self._image_model),
            "clip_model": self._is_loaded(self._clip_model),
        }
# Module-level shared registry instance: import this object instead of
# constructing another ModelRegistry, so the models loaded once at startup
# are the ones reused across requests.
model_registry: ModelRegistry = ModelRegistry()
|