# viral-images/models/loader.py
"""Model loading and caching utilities."""
import os
from typing import Optional, Any
class ModelLoader:
"""Singleton-like model loader that lazily loads and caches models."""
def __init__(self, config: dict = None):
self.config = config or {}
self._cache = {}
self._device = self._get_device()

    def _get_device(self) -> str:
        # torch is imported lazily so this module can be imported even
        # before torch is installed.
        import torch

        if torch.cuda.is_available():
            return "cuda"
        # Check for MPS (Apple Silicon); the hasattr guard keeps this safe
        # on older torch builds that lack the mps backend.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def get_clip_model(self):
        """Load the CLIP model and processor, caching the result."""
        if "clip" in self._cache:
            return self._cache["clip"]

        model_id = (
            self.config.get("models", {})
            .get("clip", {})
            .get("model_id", "openai/clip-vit-base-patch32")
        )
        try:
            from transformers import CLIPModel, CLIPProcessor

            print(f"[ModelLoader] Loading CLIP: {model_id}...")
            model = CLIPModel.from_pretrained(model_id)
            processor = CLIPProcessor.from_pretrained(model_id)
            model = model.to(self._device)
            if self._device == "cpu":
                model = model.float()  # use full precision on CPU
            self._cache["clip"] = (model, processor)
            print(f"[ModelLoader] CLIP loaded on {self._device}")
            return self._cache["clip"]
        except Exception as e:
            print(f"[ModelLoader] CLIP load FAILED: {e}")
            # Cache the failure so repeated calls don't retry a broken load.
            self._cache["clip"] = None
            return None
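
    # The chained .get() lookups above imply a nested config shape roughly
    # like the sketch below. This is inferred from this file alone (an
    # assumption; the project's actual config file isn't shown here):
    #
    #   {
    #       "models": {
    #           "clip": {"model_id": "openai/clip-vit-base-patch32"}
    #       }
    #   }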

    def get_device(self) -> str:
        return self._device

    def warmup(self):
        """Preload critical models."""
        self.get_clip_model()

    def clear(self):
        """Clear model cache to free memory."""
        import gc

        import torch

        # Drop all cached references so the collector can reclaim the models;
        # the original per-key deletion loop plus a trailing clear() was
        # redundant, so a single clear() does the same job.
        self._cache.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
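

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: pick a device,
    # warm the cache, fetch the cached pair, then free memory. Assumes torch
    # and transformers are installed and the default CLIP checkpoint is
    # reachable.
    loader = ModelLoader()
    print(f"Selected device: {loader.get_device()}")
    loader.warmup()  # triggers the first (and only) CLIP load
    cached = loader.get_clip_model()  # served from the cache, no reload
    if cached is not None:
        model, processor = cached
        print(f"CLIP ready: {type(model).__name__} / {type(processor).__name__}")
    loader.clear()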