| """ |
| Shared model loading and embedding extraction utilities. |
| |
| All evaluation scripts that need to load GAP-CLIP, the Fashion-CLIP baseline, |
| or the specialized color model should import from here instead of duplicating |
| the loading logic. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
| from typing import Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from transformers import CLIPModel as CLIPModelTransformers |
| from transformers import CLIPProcessor |
|
|
| |
| _PROJECT_ROOT = Path(__file__).resolve().parents[2] |
| if str(_PROJECT_ROOT) not in sys.path: |
| sys.path.insert(0, str(_PROJECT_ROOT)) |
|
|
|
|
| |
| |
| |
|
|
def load_gap_clip(
    model_path: str,
    device: torch.device,
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load GAP-CLIP (LAION CLIP backbone + fine-tuned checkpoint) and its processor.

    Args:
        model_path: Path to the `gap_clip.pth` checkpoint. The file may contain
            either a bare state dict or a dict wrapping it under
            ``"model_state_dict"``.
        device: Target device.

    Returns:
        (model, processor) ready for inference (model moved to `device`, eval mode).
    """
    # Single source of truth for the backbone id so the model and its
    # processor can never be loaded from different checkpoints.
    base_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    model = CLIPModelTransformers.from_pretrained(base_model_id)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # Training scripts save either the raw state dict or a wrapper dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()
    processor = CLIPProcessor.from_pretrained(base_model_id)
    return model, processor
|
|
|
|
| |
| |
| |
|
|
def load_baseline_fashion_clip(
    device: torch.device,
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load the Fashion-CLIP baseline (patrickjohncyh/fashion-clip).

    Args:
        device: Target device.

    Returns:
        (model, processor) ready for inference (model in eval mode on `device`).
    """
    baseline_id = "patrickjohncyh/fashion-clip"
    processor = CLIPProcessor.from_pretrained(baseline_id)
    model = CLIPModelTransformers.from_pretrained(baseline_id)
    model = model.to(device)
    model.eval()
    return model, processor
|
|
|
|
| |
| |
| |
|
|
def load_color_model(
    color_model_path: str,
    device: torch.device,
):
    """Load the specialized 16D color model (CLIP-backbone).

    Args:
        color_model_path: Path to the ColorCLIP checkpoint.
        device: Target device.

    Returns:
        (color_model, None) -- second element kept for API compatibility
    """
    # Imported lazily so scripts that never touch the color model do not
    # require the training package at import time.
    from training.color_model import ColorCLIP

    print("Loading ColorCLIP (CLIP-backbone, 16D) ...")
    loaded_model = ColorCLIP.from_checkpoint(color_model_path, device=device)
    print("Color model loaded successfully")
    return loaded_model, None
|
|
|
|
def load_hierarchy_model(
    hierarchy_model_path: str,
    device: torch.device,
):
    """Load the hierarchy model (CLIP-backbone).

    Args:
        hierarchy_model_path: Path to the HierarchyModel checkpoint.
        device: Target device.

    Returns:
        hierarchy_model ready for inference.
    """
    # Lazy import keeps the training package optional for callers that
    # only use the plain CLIP loaders.
    from training.hierarchy_model import HierarchyModel

    print("Loading HierarchyModel (CLIP-backbone, 64D) ...")
    hierarchy_model = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device)
    print("Hierarchy model loaded successfully")
    return hierarchy_model
|
|
|
|
|
|
| |
| |
| |
|
|
def encode_text(model, processor, text_queries, device):
    """Embed one text query or a list of queries; embeddings are NOT normalized."""
    queries = [text_queries] if isinstance(text_queries, str) else text_queries
    batch = processor(text=queries, return_tensors="pt", padding=True, truncation=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        return model.get_text_features(**batch)
|
|
|
|
def encode_image(model, processor, images, device):
    """Embed one image or a list of images; embeddings are NOT normalized."""
    image_list = images if isinstance(images, list) else [images]
    batch = processor(images=image_list, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        return model.get_image_features(**batch)
|
|
|
|
| |
| |
| |
|
|
def get_text_embedding(model, processor, device, text):
    """Return one L2-normalized text embedding as a 1-D tensor (shape: [512])."""
    raw_embedding = encode_text(model, processor, text, device)
    normalized = F.normalize(raw_embedding, dim=-1)
    return normalized.squeeze(0)
|
|
|
|
def get_text_embeddings_batch(model, processor, device, texts):
    """Return L2-normalized embeddings for a batch of texts (shape: [N, 512])."""
    raw_embeddings = encode_text(model, processor, texts, device)
    return F.normalize(raw_embeddings, dim=-1)
|
|
|
|
def get_image_embedding_from_pil(model, processor, device, pil_image):
    """Return one L2-normalized image embedding from a PIL image (shape: [512])."""
    raw_features = encode_image(model, processor, pil_image, device)
    normalized = F.normalize(raw_features, dim=-1)
    return normalized.squeeze(0)
|
|