Leacb4 commited on
Commit
4807234
·
verified ·
1 Parent(s): 48572da

Upload evaluation/utils/embeddings.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/utils/embeddings.py +200 -0
evaluation/utils/embeddings.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared embedding extraction utilities for GAP-CLIP evaluation scripts.
3
+
4
+ Consolidates the batch embedding extraction logic that was duplicated across
5
+ sec51, sec52, sec533, and sec536 into two reusable functions:
6
+
7
+ - extract_clip_embeddings() — for any CLIP-based model (GAP-CLIP, Fashion-CLIP)
8
+ - extract_color_model_embeddings() — for the specialized 16D ColorCLIP model
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import List, Tuple, Union
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.utils.data import DataLoader
19
+ from torchvision import transforms
20
+ from tqdm import tqdm
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Helpers
25
+ # ---------------------------------------------------------------------------
26
+
27
def _batch_tensors_to_pil(images: torch.Tensor) -> list:
    """Convert a batch of ImageNet-normalised tensors back to PIL images.

    This is the shared denormalization logic that was duplicated in every
    evaluator's image-embedding extraction method.

    Args:
        images: A ``(B, C, H, W)`` float tensor. Per-image values outside
            ``[0, 1]`` are taken to mean the image is still ImageNet-normalised
            and it is denormalised before conversion.

    Returns:
        A list of ``B`` PIL images.
    """
    # Hoist loop invariants: the converter and the ImageNet constants were
    # previously rebuilt once per image inside the loop.
    to_pil = transforms.ToPILImage()
    mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(3, 1, 1)

    pil_images = []
    for t in images:
        # Heuristic: values outside [0, 1] indicate the tensor is normalised.
        if t.min() < 0 or t.max() > 1:
            t = torch.clamp(t * std + mean, 0, 1)
        pil_images.append(to_pil(t.cpu()))
    return pil_images
42
+
43
+
44
+ def _normalize_label(value: object, default: str = "unknown") -> str:
45
+ """Convert label-like values to consistent non-empty strings."""
46
+ if value is None:
47
+ return default
48
+
49
+ # Handle pandas/NumPy missing values without importing pandas here.
50
+ try:
51
+ if bool(np.isnan(value)): # type: ignore[arg-type]
52
+ return default
53
+ except Exception:
54
+ pass
55
+
56
+ label = str(value).strip().lower()
57
+ if not label or label in {"none", "nan"}:
58
+ return default
59
+ return label.replace("grey", "gray")
60
+
61
+
62
+ # ---------------------------------------------------------------------------
63
+ # CLIP-based embedding extraction (GAP-CLIP or Fashion-CLIP)
64
+ # ---------------------------------------------------------------------------
65
+
66
def extract_clip_embeddings(
    model,
    processor,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str], List[str]]:
    """Extract L2-normalised embeddings from any CLIP-based model.

    Works with both 3-element batches ``(image, text, color)`` and 4-element
    batches ``(image, text, color, hierarchy)``. Always returns three lists
    (embeddings, colors, hierarchies); when the batch has no hierarchy column
    the third list is filled with ``"unknown"``.

    Args:
        model: A ``CLIPModel`` (GAP-CLIP, Fashion-CLIP, etc.).
        processor: Matching ``CLIPProcessor``.
        dataloader: PyTorch DataLoader yielding 3- or 4-element tuples.
        device: Target torch device.
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Stop after collecting this many samples (the final batch
            may overshoot by up to ``batch_size - 1``).
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors, hierarchies)`` where *embeddings* is an
        ``(N, D)`` numpy array and the other two are lists of strings.
        An empty or immediately-exhausted dataloader yields a ``(0, 0)``
        array and two empty lists.
    """
    if desc is None:
        desc = f"Extracting {embedding_type} embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    all_hierarchies: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            if sample_count >= max_samples:
                break

            # Support both 3-element and 4-element batch tuples.
            if len(batch) == 4:
                images, texts, colors, hierarchies = batch
            else:
                images, texts, colors = batch
                hierarchies = ["unknown"] * len(colors)

            if embedding_type == "image":
                # Only the image path needs the batch on-device; moving it
                # unconditionally wasted a host->device copy per text batch.
                images = images.to(device).expand(-1, 3, -1, -1)
                pil_images = _batch_tensors_to_pil(images)
                inputs = processor(images=pil_images, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_image_features(**inputs)
            else:
                inputs = processor(
                    text=list(texts),
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77,  # CLIP's text-encoder context limit
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_text_features(**inputs)

            # L2-normalise so cosine similarity reduces to a dot product.
            emb = F.normalize(emb, dim=-1)

            all_embeddings.append(emb.cpu().numpy())
            all_colors.extend(_normalize_label(c) for c in colors)
            all_hierarchies.extend(_normalize_label(h) for h in hierarchies)
            # len(colors) == batch size on both code paths (images may stay
            # on the CPU for text extraction).
            sample_count += len(colors)

            del images, emb
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # np.vstack raises on an empty list; guard the empty-dataloader /
    # max_samples == 0 edge case explicitly.
    if not all_embeddings:
        return np.empty((0, 0), dtype=np.float32), all_colors, all_hierarchies

    return np.vstack(all_embeddings), all_colors, all_hierarchies
145
+
146
+
147
+ # ---------------------------------------------------------------------------
148
+ # Specialized ColorCLIP embedding extraction
149
+ # ---------------------------------------------------------------------------
150
+
151
def extract_color_model_embeddings(
    color_model,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str]]:
    """Extract L2-normalised embeddings from the 16D ColorCLIP model.

    Args:
        color_model: A ``ColorCLIP`` instance.
        dataloader: DataLoader yielding at least ``(image, text, color, ...)``.
        device: Target torch device.
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Stop after collecting this many samples (the final batch
            may overshoot by up to ``batch_size - 1``).
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors)`` — embeddings is an ``(N, 16)`` numpy array.
        An empty or immediately-exhausted dataloader yields a ``(0, 0)``
        array and an empty list.
    """
    if desc is None:
        desc = f"Extracting {embedding_type} color-model embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            if sample_count >= max_samples:
                break

            images, texts, colors = batch[0], batch[1], batch[2]
            images = images.to(device).expand(-1, 3, -1, -1)

            if embedding_type == "text":
                emb = color_model.get_text_embeddings(list(texts))
            else:
                emb = color_model.get_image_embeddings(images)
            # Normalise both text and image embeddings so the two spaces are
            # directly comparable via cosine similarity — matches
            # extract_clip_embeddings, which normalises unconditionally.
            emb = F.normalize(emb, dim=-1)

            all_embeddings.append(emb.cpu().numpy())
            # Consistency: reuse the shared _normalize_label helper instead of
            # duplicating the lower/strip/grey->gray logic inline; missing
            # labels now also collapse to "unknown" here.
            all_colors.extend(_normalize_label(c) for c in colors)
            sample_count += len(colors)

    # np.vstack raises on an empty list; guard the empty-dataloader /
    # max_samples == 0 edge case explicitly.
    if not all_embeddings:
        return np.empty((0, 0), dtype=np.float32), all_colors

    return np.vstack(all_embeddings), all_colors