Spaces:
Sleeping
Sleeping
File size: 4,895 Bytes
# app/models/clip_model.py
# CLIP model for multimodal text-image alignment (deep analysis only)
from PIL import Image
import numpy as np
from app.config import get_settings
from app.observability.logging import get_logger
logger = get_logger(__name__)
class CLIPModel:
"""
CLIP (Contrastive Language-Image Pre-Training) model.
Used in the deep analysis path to compute semantic alignment
between text descriptions and image content. This helps detect
subtle multimodal threats (e.g., threatening text overlaid on images).
"""
def __init__(self):
self.settings = get_settings()
self.model = None
self.preprocess = None
self.tokenizer = None
self._loaded = False
self.device = None
def load(self) -> None:
"""Load the CLIP model and preprocessor."""
import torch
try:
import open_clip
model_name = self.settings.clip_model_name
cache_dir = self.settings.model_cache_path / "clip"
cache_dir.mkdir(parents=True, exist_ok=True)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info("loading_clip_model", model=model_name)
# Use OpenCLIP for flexibility
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
"ViT-B-32",
pretrained="laion2b_s34b_b79k",
)
self.model = self.model.to(self.device)
self.model.eval()
self.tokenizer = open_clip.get_tokenizer("ViT-B-32")
self._loaded = True
logger.info("clip_model_loaded")
except ImportError:
logger.warning("clip_not_available", reason="open_clip not installed")
self._loaded = False
except Exception as e:
logger.error("clip_load_failed", error=str(e))
self._loaded = False
def compute_similarity(self, image: Image.Image, texts: list[str]) -> dict:
"""
Compute cosine similarity between an image and a list of text descriptions.
Args:
image: PIL Image.
texts: List of text descriptions to compare against.
Returns:
Dict with similarities, best_match, and best_score.
"""
if not self._loaded:
return {"error": "CLIP model not loaded", "similarities": []}
import torch
# Preprocess image
image_input = self.preprocess(image).unsqueeze(0).to(self.device)
# Tokenize texts
text_tokens = self.tokenizer(texts).to(self.device)
with torch.no_grad():
image_features = self.model.encode_image(image_input)
text_features = self.model.encode_text(text_tokens)
# Normalize
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Cosine similarity
similarities = (image_features @ text_features.T).squeeze(0).cpu().numpy()
sim_list = similarities.tolist()
best_idx = int(np.argmax(sim_list))
return {
"similarities": dict(zip(texts, sim_list)),
"best_match": texts[best_idx],
"best_score": sim_list[best_idx],
}
def align_content(self, image: Image.Image, context_text: str | None = None) -> dict:
"""
Analyze image alignment with harmful content categories.
Args:
image: Image to analyze.
context_text: Optional surrounding text context.
Returns:
Dict with category alignment scores.
"""
harmful_descriptions = [
"a photo containing violence, fighting, or physical harm",
"a photo containing nudity or sexual content",
"a photo containing self-harm or suicide imagery",
"a photo containing hate symbols or extremist content",
"a photo containing drugs or substance abuse",
"a safe and appropriate photo for children",
]
result = self.compute_similarity(image, harmful_descriptions)
if "error" in result:
return result
# Also check text-image alignment if context provided
text_alignment = None
if context_text:
text_result = self.compute_similarity(image, [context_text, "unrelated content"])
text_alignment = text_result["similarities"].get(context_text, 0.0)
return {
"category_scores": result["similarities"],
"most_aligned": result["best_match"],
"alignment_score": result["best_score"],
"text_image_alignment": text_alignment,
}
@property
def is_loaded(self) -> bool:
return self._loaded
|