File size: 2,155 Bytes
6ceaa94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Model loading and caching utilities."""

import os
from typing import Optional, Any


class ModelLoader:
    """Lazily load models and cache them on the loader instance.

    Nothing in this class enforces a single instance: every ``ModelLoader``
    owns its own cache and picks its own device when it is constructed.
    """

    # Model id handed to ``from_pretrained`` when the config names none.
    _DEFAULT_CLIP_MODEL_ID = "openai/clip-vit-base-patch32"

    def __init__(self, config: Optional[dict] = None):
        """Create a loader.

        Args:
            config: Optional settings mapping. The CLIP model id is read from
                ``config["models"]["clip"]["model_id"]`` when present.
        """
        self.config = config or {}
        # Maps a model key (currently only "clip") to its loaded objects, or
        # to ``None`` when a previous load attempt failed.
        self._cache = {}
        self._device = self._get_device()

    def _get_device(self) -> str:
        """Return the device name to load models on: "cuda", "mps" or "cpu"."""
        try:
            import torch
        except ImportError:
            # Without torch no accelerator can be used; report "cpu" so that
            # construction succeeds and model loading fails softly later.
            return "cpu"
        if torch.cuda.is_available():
            return "cuda"
        # Check for MPS (Apple Silicon); the backend attribute may be absent.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def _clip_model_id(self) -> str:
        """Return the configured CLIP model id, falling back to the default."""
        # ``or {}`` tolerates sections that exist but are empty or None.
        models_cfg = self.config.get("models") or {}
        clip_cfg = models_cfg.get("clip") or {}
        return clip_cfg.get("model_id") or self._DEFAULT_CLIP_MODEL_ID

    def get_clip_model(self) -> Optional[tuple]:
        """Load the CLIP model and processor, caching the outcome.

        Returns:
            ``(model, processor)`` on success, or ``None`` when loading
            failed. A failure is cached as well, so later calls return
            ``None`` without retrying until ``clear()`` is called.
        """
        if "clip" in self._cache:
            return self._cache["clip"]

        model_id = self._clip_model_id()

        try:
            from transformers import CLIPModel, CLIPProcessor

            print(f"[ModelLoader] Loading CLIP: {model_id}...")
            model = CLIPModel.from_pretrained(model_id)
            processor = CLIPProcessor.from_pretrained(model_id)
            model = model.to(self._device)
            if self._device == "cpu":
                model = model.float()  # Use full precision on CPU
        except Exception as e:
            # Best-effort boundary: any load problem is reported and turned
            # into a cached ``None`` instead of propagating to the caller.
            print(f"[ModelLoader] CLIP load FAILED: {e}")
            self._cache["clip"] = None
            return None

        self._cache["clip"] = (model, processor)
        print(f"[ModelLoader] CLIP loaded on {self._device}")
        return self._cache["clip"]

    def get_device(self) -> str:
        """Return the device name chosen when the loader was constructed."""
        return self._device

    def warmup(self) -> None:
        """Preload critical models."""
        self.get_clip_model()

    def clear(self) -> None:
        """Clear the model cache (including cached failures) to free memory."""
        import gc
        import sys

        self._cache.clear()
        # Collect right away so the dropped models are released before CUDA
        # is asked to hand back its cached memory.
        gc.collect()
        # Look torch up instead of importing it: if it was never imported
        # there is no CUDA cache to empty, and a missing torch must not make
        # clearing the cache fail.
        torch = sys.modules.get("torch")
        if torch is not None and torch.cuda.is_available():
            torch.cuda.empty_cache()