"""Model loading and caching utilities.""" import os from typing import Optional, Any class ModelLoader: """Singleton-like model loader that lazily loads and caches models.""" def __init__(self, config: dict = None): self.config = config or {} self._cache = {} self._device = self._get_device() def _get_device(self) -> str: import torch if torch.cuda.is_available(): return "cuda" # Check for MPS (Apple Silicon) if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): return "mps" return "cpu" def get_clip_model(self): """Load CLIP model and processor.""" if "clip" in self._cache: return self._cache["clip"] model_id = self.config.get("models", {}).get("clip", {}).get("model_id", "openai/clip-vit-base-patch32") try: import torch from transformers import CLIPModel, CLIPProcessor print(f"[ModelLoader] Loading CLIP: {model_id}...") model = CLIPModel.from_pretrained(model_id) processor = CLIPProcessor.from_pretrained(model_id) model = model.to(self._device) if self._device == "cpu": model = model.float() # Use full precision on CPU self._cache["clip"] = (model, processor) print(f"[ModelLoader] CLIP loaded on {self._device}") return self._cache["clip"] except Exception as e: print(f"[ModelLoader] CLIP load FAILED: {e}") self._cache["clip"] = None return None def get_device(self) -> str: return self._device def warmup(self): """Preload critical models.""" self.get_clip_model() def clear(self): """Clear model cache to free memory.""" import torch import gc for key in list(self._cache.keys()): del self._cache[key] gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() self._cache.clear()