# evaluation/utils/model_loader.py (uploaded via huggingface_hub, commit afbd922)
"""
Shared model loading and embedding extraction utilities.
All evaluation scripts that need to load GAP-CLIP, the Fashion-CLIP baseline,
or the specialized color model should import from here instead of duplicating
the loading logic.
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Tuple
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel as CLIPModelTransformers
from transformers import CLIPProcessor
# Make project root importable when running evaluation scripts directly.
# parents[2] climbs model_loader.py -> utils/ -> evaluation/ -> project root,
# so `from training.color_model import ...` below resolves without installing
# the package. Guarded to avoid duplicate sys.path entries on re-import.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
# ---------------------------------------------------------------------------
# GAP-CLIP (main model)
# ---------------------------------------------------------------------------
def load_gap_clip(
    model_path: str,
    device: torch.device,
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load GAP-CLIP (LAION CLIP backbone + fine-tuned checkpoint) and its processor.

    Args:
        model_path: Path to the ``gap_clip.pth`` checkpoint. The file may hold
            either a raw state dict or a dict wrapping it under
            ``"model_state_dict"`` (both forms are handled below).
        device: Target device for the model and checkpoint tensors.

    Returns:
        (model, processor) with the model on `device` and in eval mode.
    """
    # Single source of truth for the backbone repo id — previously duplicated
    # for the model and processor calls, which risked silent drift.
    backbone_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    model = CLIPModelTransformers.from_pretrained(backbone_id)
    # SECURITY: weights_only=False unpickles arbitrary objects, which can
    # execute code on load. Only pass checkpoints from trusted sources here;
    # switch to weights_only=True if the checkpoint is tensors-only.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    # Accept both {"model_state_dict": ...} wrappers and bare state dicts.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    processor = CLIPProcessor.from_pretrained(backbone_id)
    return model, processor
# ---------------------------------------------------------------------------
# Fashion-CLIP baseline
# ---------------------------------------------------------------------------
def load_baseline_fashion_clip(
    device: torch.device,
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load the Fashion-CLIP baseline (patrickjohncyh/fashion-clip).

    Args:
        device: Target device for the baseline model.

    Returns:
        (model, processor) with the model on `device` and in eval mode.
    """
    repo_id = "patrickjohncyh/fashion-clip"
    # Processor first (matches the download order used historically), then
    # the model, moved to the target device and frozen into eval mode.
    processor = CLIPProcessor.from_pretrained(repo_id)
    baseline = CLIPModelTransformers.from_pretrained(repo_id).to(device)
    baseline.eval()
    return baseline, processor
# ---------------------------------------------------------------------------
# Specialized 16D color model
# ---------------------------------------------------------------------------
def load_color_model(
    color_model_path: str,
    device: torch.device,
):
    """Load the specialized 16D color model (CLIP-backbone).

    Args:
        color_model_path: Path to the ColorCLIP checkpoint.
        device: Target device.

    Returns:
        (color_model, None) -- the trailing None is kept purely so existing
        callers that unpack a pair keep working.
    """
    # Imported lazily: scripts that never touch the color model skip the
    # import cost, and the project root is already on sys.path by now.
    from training.color_model import ColorCLIP  # type: ignore

    print("Loading ColorCLIP (CLIP-backbone, 16D) ...")
    model = ColorCLIP.from_checkpoint(color_model_path, device=device)
    print("Color model loaded successfully")
    return model, None
def load_hierarchy_model(
    hierarchy_model_path: str,
    device: torch.device,
):
    """Load the hierarchy model (CLIP-backbone).

    Args:
        hierarchy_model_path: Path to the HierarchyModel checkpoint.
        device: Target device.

    Returns:
        hierarchy_model ready for inference.
    """
    # Lazy import keeps this optional dependency off the module import path.
    from training.hierarchy_model import HierarchyModel  # type: ignore

    print("Loading HierarchyModel (CLIP-backbone, 64D) ...")
    hierarchy = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device)
    print("Hierarchy model loaded successfully")
    return hierarchy
# ---------------------------------------------------------------------------
# Core encoding helpers (same as notebook)
# ---------------------------------------------------------------------------
def encode_text(model, processor, text_queries, device):
    """Encode text queries into embeddings (unnormalized).

    Accepts a single string or a list of strings; a bare string is wrapped
    into a one-element batch before tokenization.
    """
    queries = [text_queries] if isinstance(text_queries, str) else text_queries
    batch = processor(text=queries, return_tensors="pt", padding=True, truncation=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        return model.get_text_features(**batch)
def encode_image(model, processor, images, device):
    """Encode images into embeddings (unnormalized).

    Accepts a single image or a list of images; a single image is wrapped
    into a one-element batch before preprocessing.
    """
    batch_images = images if isinstance(images, list) else [images]
    batch = processor(images=batch_images, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        return model.get_image_features(**batch)
# ---------------------------------------------------------------------------
# Normalized wrappers (preserve old call signatures used across eval scripts)
# ---------------------------------------------------------------------------
def get_text_embedding(model, processor, device, text):
    """Single normalized text embedding (shape: [512])."""
    raw = encode_text(model, processor, text, device)
    return F.normalize(raw, dim=-1).squeeze(0)
def get_text_embeddings_batch(model, processor, device, texts):
    """Normalized text embeddings for a batch (shape: [N, 512])."""
    raw = encode_text(model, processor, texts, device)
    return F.normalize(raw, dim=-1)
def get_image_embedding_from_pil(model, processor, device, pil_image):
    """Normalized image embedding from a PIL image (shape: [512])."""
    raw = encode_image(model, processor, pil_image, device)
    return F.normalize(raw, dim=-1).squeeze(0)