SentinelAI / app /models /model_registry.py
sajith-0701's picture
initial deployment for HF Spaces
71c1ad2
# app/models/model_registry.py
# Singleton model registry β€” loads and manages all ML models
from app.models.text_model import TextToxicityModel
from app.models.image_model import ImageClassificationModel
from app.models.clip_model import CLIPModel
from app.observability.logging import get_logger
logger = get_logger(__name__)
class ModelRegistry:
"""
Central registry for all ML models.
Provides lazy-loading and lifecycle management.
Models are loaded once and reused across requests.
"""
def __init__(self):
self._text_model: TextToxicityModel | None = None
self._image_model: ImageClassificationModel | None = None
self._clip_model: CLIPModel | None = None
async def load_all(self) -> dict[str, bool]:
"""
Load all models. Called during application startup.
Returns:
Dict of model name β†’ loaded status.
"""
results = {}
# Text model (required)
logger.info("registry_loading", model="text_toxicity")
try:
self._text_model = TextToxicityModel()
self._text_model.load()
results["text"] = True
except Exception as e:
logger.error("text_model_load_failed", error=str(e))
results["text"] = False
# Image model (required)
logger.info("registry_loading", model="image_classifier")
try:
self._image_model = ImageClassificationModel()
self._image_model.load()
results["image"] = True
except Exception as e:
logger.error("image_model_load_failed", error=str(e))
results["image"] = False
# CLIP model (optional β€” only for deep analysis)
logger.info("registry_loading", model="clip")
try:
self._clip_model = CLIPModel()
self._clip_model.load()
results["clip"] = self._clip_model.is_loaded
except Exception as e:
logger.warning("clip_model_load_failed", error=str(e))
results["clip"] = False
logger.info("registry_loaded", results=results)
return results
@property
def text_model(self) -> TextToxicityModel:
if self._text_model is None or not self._text_model.is_loaded:
raise RuntimeError("Text model not available")
return self._text_model
@property
def image_model(self) -> ImageClassificationModel:
if self._image_model is None or not self._image_model.is_loaded:
raise RuntimeError("Image model not available")
return self._image_model
@property
def clip_model(self) -> CLIPModel:
if self._clip_model is None:
raise RuntimeError("CLIP model not available")
return self._clip_model
@property
def clip_available(self) -> bool:
return self._clip_model is not None and self._clip_model.is_loaded
def get_status(self) -> dict:
"""Get health status of all models."""
return {
"text_model": self._text_model.is_loaded if self._text_model else False,
"image_model": self._image_model.is_loaded if self._image_model else False,
"clip_model": self._clip_model.is_loaded if self._clip_model else False,
}
# Global singleton
model_registry = ModelRegistry()