""" Shared model loading and embedding extraction utilities. All evaluation scripts that need to load GAP-CLIP, the Fashion-CLIP baseline, or the specialized color model should import from here instead of duplicating the loading logic. """ from __future__ import annotations import sys from pathlib import Path from typing import Tuple import torch import torch.nn.functional as F from PIL import Image from transformers import CLIPModel as CLIPModelTransformers from transformers import CLIPProcessor # Make project root importable when running evaluation scripts directly. _PROJECT_ROOT = Path(__file__).resolve().parents[2] if str(_PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(_PROJECT_ROOT)) # --------------------------------------------------------------------------- # GAP-CLIP (main model) # --------------------------------------------------------------------------- def load_gap_clip( model_path: str, device: torch.device, ) -> Tuple[CLIPModelTransformers, CLIPProcessor]: """Load GAP-CLIP (LAION CLIP + fine-tuned checkpoint) and its processor. Args: model_path: Path to the `gap_clip.pth` checkpoint. device: Target device. Returns: (model, processor) ready for inference. """ model = CLIPModelTransformers.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K") checkpoint = torch.load(model_path, map_location=device, weights_only=False) if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: model.load_state_dict(checkpoint["model_state_dict"]) else: model.load_state_dict(checkpoint) model = model.to(device) model.eval() processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K") return model, processor # --------------------------------------------------------------------------- # Fashion-CLIP baseline # --------------------------------------------------------------------------- def load_baseline_fashion_clip( device: torch.device, ) -> Tuple[CLIPModelTransformers, CLIPProcessor]: """Load the Fashion-CLIP baseline (patrickjohncyh/fashion-clip). Returns: (model, processor) ready for inference. """ model_name = "patrickjohncyh/fashion-clip" processor = CLIPProcessor.from_pretrained(model_name) model = CLIPModelTransformers.from_pretrained(model_name).to(device) model.eval() return model, processor # --------------------------------------------------------------------------- # Specialized 16D color model # --------------------------------------------------------------------------- def load_color_model( color_model_path: str, device: torch.device, ): """Load the specialized 16D color model (CLIP-backbone). Returns: (color_model, None) -- second element kept for API compatibility """ from training.color_model import ColorCLIP # type: ignore print("Loading ColorCLIP (CLIP-backbone, 16D) ...") color_model = ColorCLIP.from_checkpoint(color_model_path, device=device) print("Color model loaded successfully") return color_model, None def load_hierarchy_model( hierarchy_model_path: str, device: torch.device, ): """Load the hierarchy model (CLIP-backbone). Returns: hierarchy_model ready for inference. """ from training.hierarchy_model import HierarchyModel # type: ignore print("Loading HierarchyModel (CLIP-backbone, 64D) ...") model = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device) print("Hierarchy model loaded successfully") return model # --------------------------------------------------------------------------- # Core encoding helpers (same as notebook) # --------------------------------------------------------------------------- def encode_text(model, processor, text_queries, device): """Encode text queries into embeddings (unnormalized).""" if isinstance(text_queries, str): text_queries = [text_queries] inputs = processor(text=text_queries, return_tensors="pt", padding=True, truncation=True) inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): text_features = model.get_text_features(**inputs) return text_features def encode_image(model, processor, images, device): """Encode images into embeddings (unnormalized).""" if not isinstance(images, list): images = [images] inputs = processor(images=images, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): image_features = model.get_image_features(**inputs) return image_features # --------------------------------------------------------------------------- # Normalized wrappers (preserve old call signatures used across eval scripts) # --------------------------------------------------------------------------- def get_text_embedding(model, processor, device, text): """Single normalized text embedding (shape: [512]).""" return F.normalize(encode_text(model, processor, text, device), dim=-1).squeeze(0) def get_text_embeddings_batch(model, processor, device, texts): """Normalized text embeddings for a batch (shape: [N, 512]).""" return F.normalize(encode_text(model, processor, texts, device), dim=-1) def get_image_embedding_from_pil(model, processor, device, pil_image): """Normalized image embedding from a PIL image (shape: [512]).""" return F.normalize(encode_image(model, processor, pil_image, device), dim=-1).squeeze(0)