Leacb4 committed on
Commit
4477181
·
verified ·
1 Parent(s): afbd922

Upload example_usage.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. example_usage.py +105 -135
example_usage.py CHANGED
@@ -1,238 +1,208 @@
1
  #!/usr/bin/env python3
2
  """
3
- Example usage of models from Hugging Face.
4
- This file provides example code for loading and using the models (color, hierarchy, main)
5
- from the Hugging Face Hub. It shows how to load models, extract embeddings,
6
- and perform searches or similarity comparisons.
 
7
  """
8
 
 
 
9
  import torch
10
  import torch.nn.functional as F
11
  from PIL import Image
12
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
13
  from huggingface_hub import hf_hub_download
14
- import json
15
- import os
16
 
17
- # Import local models (to adapt to your structure)
18
- from training.color_model import ColorCLIP, Tokenizer
19
- from training.hierarchy_model import Model as HierarchyModel, HierarchyExtractor
20
  import config
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
23
  """
24
- Load models from Hugging Face
25
-
26
  Args:
27
  repo_id: ID of the Hugging Face repository
28
  cache_dir: Local cache directory
29
  """
30
-
31
  os.makedirs(cache_dir, exist_ok=True)
32
  device = config.device
33
-
34
- print(f"📥 Loading models from '{repo_id}'...")
35
-
36
  # 1. Loading color model
37
- print(" 📦 Loading color model...")
38
  color_model_path = hf_hub_download(
39
  repo_id=repo_id,
40
  filename="models/color_model.pt",
41
- cache_dir=cache_dir
42
- )
43
-
44
- # Loading vocabulary
45
- vocab_path = hf_hub_download(
46
- repo_id=repo_id,
47
- filename="tokenizer_vocab.json",
48
- cache_dir=cache_dir
49
  )
50
-
51
- with open(vocab_path, 'r') as f:
52
- vocab_dict = json.load(f)
53
-
54
- tokenizer = Tokenizer()
55
- tokenizer.load_vocab(vocab_dict)
56
-
57
- checkpoint = torch.load(color_model_path, map_location=device)
58
- vocab_size = checkpoint['text_encoder.embedding.weight'].shape[0]
59
- color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=config.color_emb_dim).to(device)
60
- color_model.tokenizer = tokenizer
61
- color_model.load_state_dict(checkpoint)
62
- color_model.eval()
63
- print(" ✅ Color model loaded")
64
-
65
  # 2. Loading hierarchy model
66
- print(" 📦 Loading hierarchy model...")
67
  hierarchy_model_path = hf_hub_download(
68
  repo_id=repo_id,
69
  filename="models/hierarchy_model.pth",
70
- cache_dir=cache_dir
71
  )
72
-
73
- hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=device)
74
- hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
75
-
76
- hierarchy_model = HierarchyModel(
77
- num_hierarchy_classes=len(hierarchy_classes),
78
- embed_dim=config.hierarchy_emb_dim
79
- ).to(device)
80
- hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
81
-
82
- hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
83
- hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
84
- hierarchy_model.eval()
85
- print(" ✅ Hierarchy model loaded")
86
-
87
  # 3. Loading main CLIP model
88
- print(" 📦 Loading main CLIP model...")
89
  main_model_path = hf_hub_download(
90
  repo_id=repo_id,
91
  filename="models/gap_clip.pth",
92
- cache_dir=cache_dir
93
  )
94
-
95
  clip_model = CLIPModel_transformers.from_pretrained(
96
  'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
97
  )
98
  checkpoint = torch.load(main_model_path, map_location=device)
99
-
100
  # Handle different checkpoint structures
101
- if isinstance(checkpoint, dict):
102
- if 'model_state_dict' in checkpoint:
103
- clip_model.load_state_dict(checkpoint['model_state_dict'])
104
- else:
105
- # If the checkpoint is directly the state_dict
106
- clip_model.load_state_dict(checkpoint)
107
  else:
108
  clip_model.load_state_dict(checkpoint)
109
-
110
  clip_model = clip_model.to(device)
111
  clip_model.eval()
112
-
113
  processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
114
- print(" Main CLIP model loaded")
115
-
116
- print("\n✅ All models loaded!")
117
-
118
  return {
119
  'color_model': color_model,
120
  'hierarchy_model': hierarchy_model,
121
  'main_model': clip_model,
122
  'processor': processor,
123
- 'device': device
124
  }
125
 
126
 
127
  def example_search(models, image_path: str = None, text_query: str = None):
128
  """
129
- Example search with the models
130
-
131
  Args:
132
  models: Dictionary of loaded models
133
  image_path: Path to an image (optional)
134
  text_query: Text query (optional)
135
  """
136
-
137
  color_model = models['color_model']
138
  hierarchy_model = models['hierarchy_model']
139
  main_model = models['main_model']
140
  processor = models['processor']
141
  device = models['device']
142
-
143
- print("\n🔍 Example search...")
144
-
145
  if text_query:
146
- print(f" 📝 Text query: '{text_query}'")
147
-
148
  # Get color and hierarchy embeddings
149
  color_emb = color_model.get_text_embeddings([text_query])
150
  hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
151
-
152
- print(f" 🎨 Color embedding: {color_emb.shape}")
153
- print(f"color_emb: {color_emb}")
154
- print(f" 📂 Hierarchy embedding: {hierarchy_emb.shape}")
155
- print(f"hierarchy_emb: {hierarchy_emb}")
156
-
157
  # Get main model embeddings
158
- text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
159
- text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
160
-
161
- with torch.no_grad():
162
- # Use text_model directly for text-only processing
163
- text_outputs = main_model.text_model(**text_inputs)
164
- text_features = main_model.text_projection(text_outputs.pooler_output)
165
- text_features = F.normalize(text_features, dim=-1)
166
-
167
- print(f" 🎯 Main embedding: {text_features.shape}")
168
- print(f" 🎯 First logits of main embedding: {text_features[0:10]}")
169
-
170
  # Extract color and hierarchy embeddings from main embedding
171
- main_color_emb = text_features[:, :config.color_emb_dim]
172
- main_hierarchy_emb = text_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim]
173
-
174
- print(f"\n 📊 Comparison:")
175
- print(f" 🎨 Color embedding from color model: {color_emb[0]}")
176
- print(f" 🎨 Color embedding from main model (first {config.color_emb_dim} dims): {main_color_emb[0]}")
177
- print(f" 📂 Hierarchy embedding from hierarchy model: {hierarchy_emb[0]}")
178
- print(f" 📂 Hierarchy embedding from main model (dims {config.color_emb_dim}-{config.color_emb_dim+config.hierarchy_emb_dim}): {main_hierarchy_emb[0]}")
179
-
180
  # Calculate cosine similarity between color embeddings
181
  color_cosine_sim = F.cosine_similarity(color_emb, main_color_emb, dim=1)
182
- print(f"\n 🔍 Cosine similarity between color embeddings: {color_cosine_sim.item():.4f}")
183
-
184
  # Calculate cosine similarity between hierarchy embeddings
185
  hierarchy_cosine_sim = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)
186
- print(f" 🔍 Cosine similarity between hierarchy embeddings: {hierarchy_cosine_sim.item():.4f}")
187
-
188
  if image_path and os.path.exists(image_path):
189
- print(f" 🖼️ Image: {image_path}")
190
  image = Image.open(image_path).convert("RGB")
191
-
192
  # Get image embeddings
193
- image_inputs = processor(images=[image], return_tensors="pt")
194
- image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
195
-
196
- with torch.no_grad():
197
- # Use vision_model directly for image-only processing
198
- vision_outputs = main_model.vision_model(**image_inputs)
199
- image_features = main_model.visual_projection(vision_outputs.pooler_output)
200
- image_features = F.normalize(image_features, dim=-1)
201
-
202
- print(f" 🎯 Image embedding: {image_features.shape}")
203
 
204
 
205
  if __name__ == "__main__":
206
  import argparse
207
-
208
  parser = argparse.ArgumentParser(description="Example usage of models")
209
  parser.add_argument(
210
  "--repo-id",
211
  type=str,
212
  required=True,
213
- help="ID of the Hugging Face repository"
214
  )
215
  parser.add_argument(
216
  "--text",
217
  type=str,
218
  default="red dress",
219
- help="Text query for search"
220
  )
221
  parser.add_argument(
222
  "--image",
223
  type=str,
224
  default="red_dress.png",
225
- help="Path to an image"
226
  )
227
-
228
  args = parser.parse_args()
229
-
230
  # Load models
231
  models = load_models_from_hf(args.repo_id)
232
-
233
  # Example search
234
  example_search(models, image_path=args.image, text_query=args.text)
235
-
236
-
237
-
238
-
 
1
  #!/usr/bin/env python3
2
  """
3
+ Example usage of GAP-CLIP models.
4
+
5
+ This file provides example code for loading and using the models (color,
6
+ hierarchy, main) from local checkpoints or the Hugging Face Hub. It shows
7
+ how to load models, extract embeddings, and perform similarity comparisons.
8
  """
9
 
10
+ import os
11
+
12
  import torch
13
  import torch.nn.functional as F
14
  from PIL import Image
15
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
16
  from huggingface_hub import hf_hub_download
 
 
17
 
18
+ from training.color_model import ColorCLIP
19
+ from training.hierarchy_model import HierarchyModel
 
20
  import config
21
 
22
+
23
def encode_text(model, processor, text_queries, device):
    """Encode one or more text queries into unnormalized embedding vectors.

    Args:
        model: CLIP-style model exposing ``get_text_features``.
        processor: Tokenizing processor returning "pt" tensors.
        text_queries: A single string or a list of strings.
        device: Target device for the tokenized inputs.

    Returns:
        Tensor of text embeddings (not L2-normalized).
    """
    queries = [text_queries] if isinstance(text_queries, str) else text_queries
    encoded = processor(text=queries, return_tensors="pt", padding=True, truncation=True)
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
    # Inference only: no gradients needed.
    with torch.no_grad():
        embeddings = model.get_text_features(**encoded)
    return embeddings
32
+
33
+
34
def encode_image(model, processor, images, device):
    """Encode one or more images into unnormalized embedding vectors.

    Args:
        model: CLIP-style model exposing ``get_image_features``.
        processor: Image processor returning "pt" tensors.
        images: A single image or a list of images.
        device: Target device for the processed inputs.

    Returns:
        Tensor of image embeddings (not L2-normalized).
    """
    batch = images if isinstance(images, list) else [images]
    encoded = processor(images=batch, return_tensors="pt")
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
    # Inference only: no gradients needed.
    with torch.no_grad():
        embeddings = model.get_image_features(**encoded)
    return embeddings
43
+
44
+
45
def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
    """
    Load the color, hierarchy, and main CLIP models from Hugging Face.

    Args:
        repo_id: ID of the Hugging Face repository
        cache_dir: Local cache directory

    Returns:
        Dict with keys 'color_model', 'hierarchy_model', 'main_model',
        'processor', and 'device'.
    """
    os.makedirs(cache_dir, exist_ok=True)
    device = config.device

    # Base checkpoint shared by the main model AND its processor — defined
    # once so the two can never drift apart.
    base_clip_repo = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'

    def _download(filename: str) -> str:
        # All three checkpoints come from the same repo and cache directory.
        return hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)

    print(f"Loading models from '{repo_id}'...")

    # 1. Loading color model
    print("  Loading color model...")
    color_model = ColorCLIP.from_checkpoint(_download("models/color_model.pt"), device=device)
    print("  Color model loaded")

    # 2. Loading hierarchy model
    print("  Loading hierarchy model...")
    hierarchy_model = HierarchyModel.from_checkpoint(_download("models/hierarchy_model.pth"), device=device)
    print("  Hierarchy model loaded")

    # 3. Loading main CLIP model
    print("  Loading main CLIP model...")
    main_model_path = _download("models/gap_clip.pth")
    clip_model = CLIPModel_transformers.from_pretrained(base_clip_repo)
    checkpoint = torch.load(main_model_path, map_location=device)

    # Handle both {'model_state_dict': ...} wrappers and raw state dicts.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        clip_model.load_state_dict(checkpoint['model_state_dict'])
    else:
        clip_model.load_state_dict(checkpoint)

    clip_model = clip_model.to(device)
    clip_model.eval()

    processor = CLIPProcessor.from_pretrained(base_clip_repo)
    print("  Main CLIP model loaded")

    print("\nAll models loaded!")

    return {
        'color_model': color_model,
        'hierarchy_model': hierarchy_model,
        'main_model': clip_model,
        'processor': processor,
        'device': device,
    }
112
 
113
 
114
def example_search(models, image_path: str = None, text_query: str = None):
    """
    Example search with the models.

    Prints embeddings from the specialist color/hierarchy models and compares
    them against the corresponding slices of the main CLIP embedding.

    Args:
        models: Dictionary of loaded models (as returned by
            ``load_models_from_hf``)
        image_path: Path to an image (optional)
        text_query: Text query (optional)
    """
    color_model = models['color_model']
    hierarchy_model = models['hierarchy_model']
    main_model = models['main_model']
    processor = models['processor']
    device = models['device']

    print("\nExample search...")

    if text_query:
        print(f"  Text query: '{text_query}'")

        # Get color and hierarchy embeddings from the specialist models
        color_emb = color_model.get_text_embeddings([text_query])
        hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])

        print(f"  Color embedding: {color_emb.shape}")
        print(f"  color_emb: {color_emb}")
        print(f"  Hierarchy embedding: {hierarchy_emb.shape}")
        print(f"  hierarchy_emb: {hierarchy_emb}")

        # Get main model embeddings (normalized for cosine comparisons)
        text_features = encode_text(main_model, processor, text_query, device)
        text_features = F.normalize(text_features, dim=-1)

        print(f"  Main embedding: {text_features.shape}")
        print(f"  First logits of main embedding: {text_features[0:10]}")

        # The main embedding packs the color dims first, then the hierarchy
        # dims — slice the sub-embeddings out for comparison.
        main_color_emb = text_features[:, :config.color_emb_dim]
        main_hierarchy_emb = text_features[:, config.color_emb_dim:config.color_emb_dim + config.hierarchy_emb_dim]

        # Plain string: was an f-string with no placeholders (F541).
        print("\n  Comparison:")
        print(f"  Color embedding from color model: {color_emb[0]}")
        print(f"  Color embedding from main model (first {config.color_emb_dim} dims): {main_color_emb[0]}")
        print(f"  Hierarchy embedding from hierarchy model: {hierarchy_emb[0]}")
        print(f"  Hierarchy embedding from main model (dims {config.color_emb_dim}-{config.color_emb_dim + config.hierarchy_emb_dim}): {main_hierarchy_emb[0]}")

        # Calculate cosine similarity between color embeddings
        color_cosine_sim = F.cosine_similarity(color_emb, main_color_emb, dim=1)
        print(f"\n  Cosine similarity between color embeddings: {color_cosine_sim.item():.4f}")

        # Calculate cosine similarity between hierarchy embeddings
        hierarchy_cosine_sim = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)
        print(f"  Cosine similarity between hierarchy embeddings: {hierarchy_cosine_sim.item():.4f}")

    if image_path:
        if os.path.exists(image_path):
            print(f"  Image: {image_path}")
            image = Image.open(image_path).convert("RGB")

            # Get image embeddings
            image_features = encode_image(main_model, processor, image, device)
            image_features = F.normalize(image_features, dim=-1)

            print(f"  Image embedding: {image_features.shape}")
        else:
            # Previously a missing file was skipped silently — tell the user
            # why no image results appeared.
            print(f"  Image not found, skipping image search: {image_path}")
 
 
 
 
 
 
177
 
178
 
179
if __name__ == "__main__":
    import argparse

    # CLI for running the example end-to-end against a HF repository.
    parser = argparse.ArgumentParser(description="Example usage of models")
    parser.add_argument("--repo-id", type=str, required=True,
                        help="ID of the Hugging Face repository")
    parser.add_argument("--text", type=str, default="red dress",
                        help="Text query for search")
    parser.add_argument("--image", type=str, default="red_dress.png",
                        help="Path to an image")
    args = parser.parse_args()

    # Load models, then run the example search with them.
    models = load_models_from_hf(args.repo_id)
    example_search(models, image_path=args.image, text_query=args.text)