Leacb4 commited on
Commit
38783e1
·
verified ·
1 Parent(s): 3af73e2

Add simple API (load_gap_clip, get_image_embedding_from_url, get_text_embedding) and fix bugs

Browse files
Files changed (2) hide show
  1. __init__.py +9 -2
  2. example_usage.py +257 -26
__init__.py CHANGED
@@ -26,15 +26,22 @@ __email__ = "lea.attia@gmail.com"
26
  # Import main components for easy access
27
  try:
28
  from .training.color_model import ColorCLIP
29
- from .training.hierarchy_model import Model as HierarchyModel, HierarchyExtractor
30
- from .example_usage import load_models_from_hf, example_search
 
 
 
31
  from . import config
32
 
33
  __all__ = [
34
  'ColorCLIP',
35
  'HierarchyModel',
36
  'HierarchyExtractor',
 
 
 
37
  'load_models_from_hf',
 
38
  'example_search',
39
  'config',
40
  '__version__',
 
26
  # Import main components for easy access
27
  try:
28
  from .training.color_model import ColorCLIP
29
+ from .training.hierarchy_model import HierarchyModel, HierarchyExtractor
30
+ from .example_usage import (
31
+ load_gap_clip, get_image_embedding_from_url, get_text_embedding,
32
+ load_models_from_hf, load_models_from_local, example_search,
33
+ )
34
  from . import config
35
 
36
  __all__ = [
37
  'ColorCLIP',
38
  'HierarchyModel',
39
  'HierarchyExtractor',
40
+ 'load_gap_clip',
41
+ 'get_image_embedding_from_url',
42
+ 'get_text_embedding',
43
  'load_models_from_hf',
44
+ 'load_models_from_local',
45
  'example_search',
46
  'config',
47
  '__version__',
example_usage.py CHANGED
@@ -11,6 +11,7 @@ import os
11
 
12
  import torch
13
  import torch.nn.functional as F
 
14
  from PIL import Image
15
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
16
  from huggingface_hub import hf_hub_download
@@ -19,6 +20,83 @@ from training.color_model import ColorCLIP
19
  from training.hierarchy_model import HierarchyModel
20
  import config
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def encode_text(model, processor, text_queries, device):
24
  """Encode text queries into embeddings (unnormalized)."""
@@ -83,10 +161,8 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
83
  cache_dir=cache_dir,
84
  )
85
 
86
- clip_model = CLIPModel_transformers.from_pretrained(
87
- 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
88
- )
89
- checkpoint = torch.load(main_model_path, map_location=device)
90
 
91
  # Handle different checkpoint structures
92
  if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
@@ -97,7 +173,60 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
97
  clip_model = clip_model.to(device)
98
  clip_model.eval()
99
 
100
- processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  print(" Main CLIP model loaded")
102
 
103
  print("\nAll models loaded!")
@@ -135,27 +264,26 @@ def example_search(models, image_path: str = None, text_query: str = None):
135
  color_emb = color_model.get_text_embeddings([text_query])
136
  hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
137
 
138
- print(f" Color embedding: {color_emb.shape}")
139
- print(f" color_emb: {color_emb}")
140
- print(f" Hierarchy embedding: {hierarchy_emb.shape}")
141
- print(f" hierarchy_emb: {hierarchy_emb}")
142
 
143
  # Get main model embeddings
144
  text_features = encode_text(main_model, processor, text_query, device)
145
  text_features = F.normalize(text_features, dim=-1)
146
 
147
  print(f" Main embedding: {text_features.shape}")
148
- print(f" First logits of main embedding: {text_features[0:10]}")
149
 
150
  # Extract color and hierarchy embeddings from main embedding
151
  main_color_emb = text_features[:, :config.color_emb_dim]
152
  main_hierarchy_emb = text_features[:, config.color_emb_dim:config.color_emb_dim + config.hierarchy_emb_dim]
153
 
154
- print(f"\n Comparison:")
155
- print(f" Color embedding from color model: {color_emb[0]}")
156
- print(f" Color embedding from main model (first {config.color_emb_dim} dims): {main_color_emb[0]}")
157
- print(f" Hierarchy embedding from hierarchy model: {hierarchy_emb[0]}")
158
- print(f" Hierarchy embedding from main model (dims {config.color_emb_dim}-{config.color_emb_dim + config.hierarchy_emb_dim}): {main_hierarchy_emb[0]}")
 
159
 
160
  # Calculate cosine similarity between color embeddings
161
  color_cosine_sim = F.cosine_similarity(color_emb, main_color_emb, dim=1)
@@ -166,25 +294,114 @@ def example_search(models, image_path: str = None, text_query: str = None):
166
  print(f" Cosine similarity between hierarchy embeddings: {hierarchy_cosine_sim.item():.4f}")
167
 
168
  if image_path and os.path.exists(image_path):
169
- print(f" Image: {image_path}")
170
  image = Image.open(image_path).convert("RGB")
171
 
172
- # Get image embeddings
173
  image_features = encode_image(main_model, processor, image, device)
174
  image_features = F.normalize(image_features, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- print(f" Image embedding: {image_features.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
 
179
  if __name__ == "__main__":
180
  import argparse
181
 
182
- parser = argparse.ArgumentParser(description="Example usage of models")
183
  parser.add_argument(
184
  "--repo-id",
185
  type=str,
186
- required=True,
187
- help="ID of the Hugging Face repository",
188
  )
189
  parser.add_argument(
190
  "--text",
@@ -195,14 +412,28 @@ if __name__ == "__main__":
195
  parser.add_argument(
196
  "--image",
197
  type=str,
198
- default="red_dress.png",
199
- help="Path to an image",
 
 
 
 
 
 
 
200
  )
201
 
202
  args = parser.parse_args()
203
 
204
- # Load models
205
- models = load_models_from_hf(args.repo_id)
 
 
 
206
 
207
- # Example search
208
  example_search(models, image_path=args.image, text_query=args.text)
 
 
 
 
 
11
 
12
  import torch
13
  import torch.nn.functional as F
14
+ import requests
15
  from PIL import Image
16
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
17
  from huggingface_hub import hf_hub_download
 
20
  from training.hierarchy_model import HierarchyModel
21
  import config
22
 
23
+ CLIP_MODEL_NAME = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
24
+ HF_REPO_ID = "Leacb4/gap-clip"
25
+
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Simple API — load from HF and get 512D embeddings
29
+ # ---------------------------------------------------------------------------
30
+
31
+ def load_gap_clip(repo_id: str = HF_REPO_ID):
32
+ """
33
+ Load the GAP-CLIP model directly from Hugging Face.
34
+
35
+ This is the simplest way to use the model. Returns (model, processor).
36
+
37
+ Example::
38
+
39
+ model, processor = load_gap_clip()
40
+ emb = get_image_embedding_from_url(
41
+ "https://www.gap.com/webcontent/0060/662/817/cn60662817.jpg",
42
+ model, processor,
43
+ )
44
+ print(emb.shape) # torch.Size([1, 512])
45
+ """
46
+ model = CLIPModel_transformers.from_pretrained(repo_id)
47
+ processor = CLIPProcessor.from_pretrained(repo_id)
48
+ model.eval()
49
+ return model, processor
50
+
51
+
52
+ def get_image_embedding_from_url(url: str, model, processor, device=None):
53
+ """
54
+ Download an image from a URL and return its 512D GAP-CLIP embedding.
55
+
56
+ Args:
57
+ url: Image URL.
58
+ model: CLIPModel loaded via load_gap_clip() or from_pretrained().
59
+ processor: CLIPProcessor matching the model.
60
+ device: Device to run on (defaults to config.device).
61
+
62
+ Returns:
63
+ Tensor of shape [1, 512] (L2-normalized).
64
+ """
65
+ device = device or config.device
66
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
67
+ inputs = processor(images=image, return_tensors="pt")
68
+ inputs = {k: v.to(device) for k, v in inputs.items()}
69
+ model = model.to(device)
70
+ with torch.no_grad():
71
+ image_features = model.get_image_features(**inputs)
72
+ return F.normalize(image_features, dim=-1)
73
+
74
+
75
+ def get_text_embedding(text: str, model, processor, device=None):
76
+ """
77
+ Return a 512D GAP-CLIP embedding for a text query.
78
+
79
+ Args:
80
+ text: Text query (e.g., "red dress").
81
+ model: CLIPModel loaded via load_gap_clip() or from_pretrained().
82
+ processor: CLIPProcessor matching the model.
83
+ device: Device to run on (defaults to config.device).
84
+
85
+ Returns:
86
+ Tensor of shape [1, 512] (L2-normalized).
87
+ """
88
+ device = device or config.device
89
+ inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
90
+ inputs = {k: v.to(device) for k, v in inputs.items()}
91
+ model = model.to(device)
92
+ with torch.no_grad():
93
+ text_features = model.get_text_features(**inputs)
94
+ return F.normalize(text_features, dim=-1)
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Internal helpers for encode_text / encode_image (used by advanced examples)
99
+ # ---------------------------------------------------------------------------
100
 
101
  def encode_text(model, processor, text_queries, device):
102
  """Encode text queries into embeddings (unnormalized)."""
 
161
  cache_dir=cache_dir,
162
  )
163
 
164
+ clip_model = CLIPModel_transformers.from_pretrained(CLIP_MODEL_NAME)
165
+ checkpoint = torch.load(main_model_path, map_location=device, weights_only=False)
 
 
166
 
167
  # Handle different checkpoint structures
168
  if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
 
173
  clip_model = clip_model.to(device)
174
  clip_model.eval()
175
 
176
+ processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
177
+ print(" Main CLIP model loaded")
178
+
179
+ print("\nAll models loaded!")
180
+
181
+ return {
182
+ 'color_model': color_model,
183
+ 'hierarchy_model': hierarchy_model,
184
+ 'main_model': clip_model,
185
+ 'processor': processor,
186
+ 'device': device,
187
+ }
188
+
189
+
190
+ def load_models_from_local(
191
+ color_model_path: str = None,
192
+ hierarchy_model_path: str = None,
193
+ main_model_path: str = None,
194
+ ):
195
+ """
196
+ Load models from local checkpoint files.
197
+
198
+ Args:
199
+ color_model_path: Path to color_model.pt (defaults to config.color_model_path)
200
+ hierarchy_model_path: Path to hierarchy_model.pth (defaults to config.hierarchy_model_path)
201
+ main_model_path: Path to gap_clip.pth (defaults to config.main_model_path)
202
+ """
203
+ device = config.device
204
+ color_model_path = color_model_path or config.color_model_path
205
+ hierarchy_model_path = hierarchy_model_path or config.hierarchy_model_path
206
+ main_model_path = main_model_path or config.main_model_path
207
+
208
+ print(f"Loading models from local checkpoints (device={device})...")
209
+
210
+ # 1. Color model
211
+ print(" Loading color model...")
212
+ color_model = ColorCLIP.from_checkpoint(color_model_path, device=device)
213
+ print(" Color model loaded")
214
+
215
+ # 2. Hierarchy model
216
+ print(" Loading hierarchy model...")
217
+ hierarchy_model = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device)
218
+ print(" Hierarchy model loaded")
219
+
220
+ # 3. Main CLIP model
221
+ print(" Loading main CLIP model...")
222
+ clip_model = CLIPModel_transformers.from_pretrained(CLIP_MODEL_NAME)
223
+ checkpoint = torch.load(main_model_path, map_location=device, weights_only=False)
224
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
225
+ clip_model.load_state_dict(checkpoint['model_state_dict'])
226
+ else:
227
+ clip_model.load_state_dict(checkpoint)
228
+ clip_model.to(device).eval()
229
+ processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
230
  print(" Main CLIP model loaded")
231
 
232
  print("\nAll models loaded!")
 
264
  color_emb = color_model.get_text_embeddings([text_query])
265
  hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
266
 
267
+ print(f" Color embedding shape: {color_emb.shape}, norm: {color_emb.norm(dim=-1).item():.4f}")
268
+ print(f" Hierarchy embedding shape: {hierarchy_emb.shape}, norm: {hierarchy_emb.norm(dim=-1).item():.4f}")
 
 
269
 
270
  # Get main model embeddings
271
  text_features = encode_text(main_model, processor, text_query, device)
272
  text_features = F.normalize(text_features, dim=-1)
273
 
274
  print(f" Main embedding: {text_features.shape}")
275
+ print(f" First 10 dims of main embedding: {text_features[0, :10]}")
276
 
277
  # Extract color and hierarchy embeddings from main embedding
278
  main_color_emb = text_features[:, :config.color_emb_dim]
279
  main_hierarchy_emb = text_features[:, config.color_emb_dim:config.color_emb_dim + config.hierarchy_emb_dim]
280
 
281
+ print(f"\n Subspace comparison (color model vs main model dims [0:{config.color_emb_dim}]):")
282
+ print(f" color_model first 5 dims: {color_emb[0, :5].tolist()}")
283
+ print(f" main_model first 5 dims: {main_color_emb[0, :5].tolist()}")
284
+ print(f" Subspace comparison (hierarchy model vs main model dims [{config.color_emb_dim}:{config.color_emb_dim + config.hierarchy_emb_dim}]):")
285
+ print(f" hierarchy_model first 5 dims: {hierarchy_emb[0, :5].tolist()}")
286
+ print(f" main_model first 5 dims: {main_hierarchy_emb[0, :5].tolist()}")
287
 
288
  # Calculate cosine similarity between color embeddings
289
  color_cosine_sim = F.cosine_similarity(color_emb, main_color_emb, dim=1)
 
294
  print(f" Cosine similarity between hierarchy embeddings: {hierarchy_cosine_sim.item():.4f}")
295
 
296
  if image_path and os.path.exists(image_path):
297
+ print(f"\n Image: {image_path}")
298
  image = Image.open(image_path).convert("RGB")
299
 
300
+ # Main model image embedding
301
  image_features = encode_image(main_model, processor, image, device)
302
  image_features = F.normalize(image_features, dim=-1)
303
+ print(f" Main image embedding shape: {image_features.shape}")
304
+
305
+ # Color model image embedding (preprocess through model's own processor)
306
+ color_pixel_values = color_model.processor(
307
+ images=image, return_tensors="pt"
308
+ )["pixel_values"].to(device)
309
+ color_img_emb = color_model.get_image_embeddings(color_pixel_values)
310
+ print(f" Color image embedding shape: {color_img_emb.shape}")
311
+
312
+ # Hierarchy model image embedding
313
+ hierarchy_pixel_values = hierarchy_model.processor(
314
+ images=image, return_tensors="pt"
315
+ )["pixel_values"].to(device)
316
+ hierarchy_img_emb = hierarchy_model.get_image_embeddings(hierarchy_pixel_values)
317
+ print(f" Hierarchy image embedding shape: {hierarchy_img_emb.shape}")
318
+
319
+ # Compare subspace alignment for images
320
+ main_color_img = image_features[:, :config.color_emb_dim]
321
+ main_hierarchy_img = image_features[:, config.color_emb_dim:config.color_emb_dim + config.hierarchy_emb_dim]
322
+ color_img_sim = F.cosine_similarity(color_img_emb, main_color_img, dim=1)
323
+ hierarchy_img_sim = F.cosine_similarity(hierarchy_img_emb, main_hierarchy_img, dim=1)
324
+ print(f" Image color subspace cosine similarity: {color_img_sim.item():.4f}")
325
+ print(f" Image hierarchy subspace cosine similarity: {hierarchy_img_sim.item():.4f}")
326
+
327
+
328
+ def example_similarity_search(models, image_paths: list, text_query: str):
329
+ """
330
+ Rank images by similarity to a text query using GAP-CLIP.
331
+
332
+ Shows the key use case: computing text-to-image similarity scores
333
+ for ranking, combining color, hierarchy, and general CLIP subspaces.
334
 
335
+ Args:
336
+ models: Dictionary of loaded models
337
+ image_paths: List of image file paths to rank
338
+ text_query: Text query to match against
339
+ """
340
+ main_model = models['main_model']
341
+ processor = models['processor']
342
+ device = models['device']
343
+
344
+ print(f"\nSimilarity search: '{text_query}' against {len(image_paths)} images")
345
+
346
+ # Encode the text query
347
+ text_features = encode_text(main_model, processor, text_query, device)
348
+ text_features = F.normalize(text_features, dim=-1) # [1, 512]
349
+
350
+ # Encode all images
351
+ images = []
352
+ valid_paths = []
353
+ for p in image_paths:
354
+ if os.path.exists(p):
355
+ images.append(Image.open(p).convert("RGB"))
356
+ valid_paths.append(p)
357
+ else:
358
+ print(f" Warning: {p} not found, skipping")
359
+
360
+ if not images:
361
+ print(" No valid images found.")
362
+ return
363
+
364
+ image_features = encode_image(main_model, processor, images, device)
365
+ image_features = F.normalize(image_features, dim=-1) # [N, 512]
366
+
367
+ # Full 512D similarity
368
+ full_scores = (text_features @ image_features.T).squeeze(0) # [N]
369
+
370
+ # Subspace similarities
371
+ color_dim = config.color_emb_dim
372
+ hierarchy_end = color_dim + config.hierarchy_emb_dim
373
+
374
+ color_text = F.normalize(text_features[:, :color_dim], dim=-1)
375
+ color_imgs = F.normalize(image_features[:, :color_dim], dim=-1)
376
+ color_scores = (color_text @ color_imgs.T).squeeze(0)
377
+
378
+ hier_text = F.normalize(text_features[:, color_dim:hierarchy_end], dim=-1)
379
+ hier_imgs = F.normalize(image_features[:, color_dim:hierarchy_end], dim=-1)
380
+ hierarchy_scores = (hier_text @ hier_imgs.T).squeeze(0)
381
+
382
+ # Rank by full similarity
383
+ ranked_indices = full_scores.argsort(descending=True)
384
+
385
+ print(f"\n Ranking (by full 512D cosine similarity):")
386
+ for rank, idx in enumerate(ranked_indices):
387
+ i = idx.item()
388
+ print(
389
+ f" {rank + 1}. {os.path.basename(valid_paths[i]):30s}"
390
+ f" full={full_scores[i]:.4f}"
391
+ f" color={color_scores[i]:.4f}"
392
+ f" hierarchy={hierarchy_scores[i]:.4f}"
393
+ )
394
 
395
 
396
  if __name__ == "__main__":
397
  import argparse
398
 
399
+ parser = argparse.ArgumentParser(description="Example usage of GAP-CLIP models")
400
  parser.add_argument(
401
  "--repo-id",
402
  type=str,
403
+ default=None,
404
+ help="Hugging Face repo ID (e.g., Leacb4/gap-clip). If omitted, loads from local paths.",
405
  )
406
  parser.add_argument(
407
  "--text",
 
412
  parser.add_argument(
413
  "--image",
414
  type=str,
415
+ default=None,
416
+ help="Path to a single image for example_search",
417
+ )
418
+ parser.add_argument(
419
+ "--images",
420
+ type=str,
421
+ nargs="+",
422
+ default=None,
423
+ help="Paths to multiple images for similarity ranking",
424
  )
425
 
426
  args = parser.parse_args()
427
 
428
+ # Load models (HF or local)
429
+ if args.repo_id:
430
+ models = load_models_from_hf(args.repo_id)
431
+ else:
432
+ models = load_models_from_local()
433
 
434
+ # Example search (embedding inspection)
435
  example_search(models, image_path=args.image, text_query=args.text)
436
+
437
+ # Similarity ranking (if multiple images provided)
438
+ if args.images:
439
+ example_similarity_search(models, args.images, args.text)