""" Shared embedding extraction utilities for GAP-CLIP evaluation scripts. Consolidates the batch embedding extraction logic that was duplicated across sec51, sec52, sec533, and sec536 into two reusable functions: - extract_clip_embeddings() — for any CLIP-based model (GAP-CLIP, Fashion-CLIP) - extract_color_model_embeddings() — for the specialized 16D ColorCLIP model """ from __future__ import annotations from typing import List, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision import transforms from tqdm import tqdm # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _batch_tensors_to_pil(images: torch.Tensor) -> list: """Convert a batch of ImageNet-normalised tensors back to PIL images. This is the shared denormalization logic that was duplicated in every evaluator's image-embedding extraction method. """ pil_images = [] for i in range(images.shape[0]): t = images[i] if t.min() < 0 or t.max() > 1: mean = torch.tensor([0.485, 0.456, 0.406], device=t.device).view(3, 1, 1) std = torch.tensor([0.229, 0.224, 0.225], device=t.device).view(3, 1, 1) t = torch.clamp(t * std + mean, 0, 1) pil_images.append(transforms.ToPILImage()(t.cpu())) return pil_images def _normalize_label(value: object, default: str = "unknown") -> str: """Convert label-like values to consistent non-empty strings.""" if value is None: return default # Handle pandas/NumPy missing values without importing pandas here. try: if bool(np.isnan(value)): # type: ignore[arg-type] return default except Exception: pass label = str(value).strip().lower() if not label or label in {"none", "nan"}: return default return label.replace("grey", "gray") # --------------------------------------------------------------------------- # CLIP-based embedding extraction (GAP-CLIP or Fashion-CLIP) # --------------------------------------------------------------------------- def extract_clip_embeddings( model, processor, dataloader: DataLoader, device: torch.device, embedding_type: str = "text", max_samples: int = 10_000, desc: str | None = None, ) -> Tuple[np.ndarray, List[str], List[str]]: """Extract L2-normalised embeddings from any CLIP-based model. Works with both 3-element batches ``(image, text, color)`` and 4-element batches ``(image, text, color, hierarchy)``. Always returns three lists (embeddings, colors, hierarchies); when the batch has no hierarchy column the third list is filled with ``"unknown"``. Args: model: A ``CLIPModel`` (GAP-CLIP, Fashion-CLIP, etc.). processor: Matching ``CLIPProcessor``. dataloader: PyTorch DataLoader yielding 3- or 4-element tuples. device: Target torch device. embedding_type: ``"text"`` or ``"image"``. max_samples: Stop after collecting this many samples. desc: Optional tqdm description override. Returns: ``(embeddings, colors, hierarchies)`` where *embeddings* is an ``(N, D)`` numpy array and the other two are lists of strings. """ if desc is None: desc = f"Extracting {embedding_type} embeddings" all_embeddings: list[np.ndarray] = [] all_colors: list[str] = [] all_hierarchies: list[str] = [] sample_count = 0 with torch.no_grad(): for batch in tqdm(dataloader, desc=desc): if sample_count >= max_samples: break # Support both 3-element and 4-element batch tuples if len(batch) == 4: images, texts, colors, hierarchies = batch else: images, texts, colors = batch hierarchies = ["unknown"] * len(colors) images = images.to(device).expand(-1, 3, -1, -1) if embedding_type == "image": pil_images = _batch_tensors_to_pil(images) inputs = processor(images=pil_images, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} emb = model.get_image_features(**inputs) else: inputs = processor( text=list(texts), return_tensors="pt", padding=True, truncation=True, max_length=77, ) inputs = {k: v.to(device) for k, v in inputs.items()} emb = model.get_text_features(**inputs) emb = F.normalize(emb, dim=-1) all_embeddings.append(emb.cpu().numpy()) all_colors.extend(_normalize_label(c) for c in colors) all_hierarchies.extend(_normalize_label(h) for h in hierarchies) sample_count += len(images) del images, emb if torch.cuda.is_available(): torch.cuda.empty_cache() return np.vstack(all_embeddings), all_colors, all_hierarchies # --------------------------------------------------------------------------- # Specialized ColorCLIP embedding extraction # --------------------------------------------------------------------------- def extract_color_model_embeddings( color_model, dataloader: DataLoader, device: torch.device, embedding_type: str = "text", max_samples: int = 10_000, desc: str | None = None, ) -> Tuple[np.ndarray, List[str]]: """Extract L2-normalised embeddings from the 16D ColorCLIP model. Args: color_model: A ``ColorCLIP`` instance. dataloader: DataLoader yielding at least ``(image, text, color, ...)``. device: Target torch device. embedding_type: ``"text"`` or ``"image"``. max_samples: Stop after collecting this many samples. desc: Optional tqdm description override. Returns: ``(embeddings, colors)`` — embeddings is ``(N, 16)`` numpy array. """ if desc is None: desc = f"Extracting {embedding_type} color-model embeddings" all_embeddings: list[np.ndarray] = [] all_colors: list[str] = [] sample_count = 0 with torch.no_grad(): for batch in tqdm(dataloader, desc=desc): if sample_count >= max_samples: break images, texts, colors = batch[0], batch[1], batch[2] images = images.to(device).expand(-1, 3, -1, -1) if embedding_type == "text": emb = color_model.get_text_embeddings(list(texts)) else: emb = color_model.get_image_embeddings(images) emb = F.normalize(emb, dim=-1) all_embeddings.append(emb.cpu().numpy()) normalized_colors = [ str(c).lower().strip().replace("grey", "gray") for c in colors ] all_colors.extend(normalized_colors) sample_count += len(images) return np.vstack(all_embeddings), all_colors