Leacb4 committed on
Commit
afbd922
·
verified ·
1 Parent(s): b3bf2b7

Upload evaluation/utils/model_loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/utils/model_loader.py +53 -116
evaluation/utils/model_loader.py CHANGED
@@ -8,8 +8,6 @@ the loading logic.
8
 
9
  from __future__ import annotations
10
 
11
- import json
12
- import os
13
  import sys
14
  from pathlib import Path
15
  from typing import Tuple
@@ -44,7 +42,7 @@ def load_gap_clip(
44
  (model, processor) ready for inference.
45
  """
46
  model = CLIPModelTransformers.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
47
- checkpoint = torch.load(model_path, map_location=device)
48
 
49
  if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
50
  model.load_state_dict(checkpoint["model_state_dict"])
@@ -82,140 +80,79 @@ def load_baseline_fashion_clip(
82
 
83
def load_color_model(
    color_model_path: str,
    tokenizer_path: str,
    color_emb_dim: int,
    device: torch.device,
    repo_id: str = "Leacb4/gap-clip",
    cache_dir: str = "./models_cache",
):
    """Load the specialized 16D color model (ColorCLIP) and its tokenizer.

    Local checkpoint/vocab files are preferred; when either one is missing,
    both artifacts are fetched from the Hugging Face Hub instead.

    Returns:
        (color_model, color_tokenizer)
    """
    from training.color_model import ColorCLIP, Tokenizer  # type: ignore

    have_local = os.path.exists(color_model_path) and os.path.exists(tokenizer_path)

    if have_local:
        print("Loading specialized color model (16D) from local files...")
        weights_file, vocab_file = color_model_path, tokenizer_path
    else:
        from huggingface_hub import hf_hub_download  # type: ignore

        print(f"Local color model/tokenizer not found. Loading from Hugging Face ({repo_id})...")
        weights_file = hf_hub_download(
            repo_id=repo_id, filename="color_model.pt", cache_dir=cache_dir
        )
        vocab_file = hf_hub_download(
            repo_id=repo_id, filename="tokenizer_vocab.json", cache_dir=cache_dir
        )

    state_dict = torch.load(weights_file, map_location=device)
    with open(vocab_file, "r") as f:
        vocab = json.load(f)

    # The embedding table's row count pins the vocabulary size used at training time.
    vocab_size = state_dict["text_encoder.embedding.weight"].shape[0]
    print(f" Detected vocab size from checkpoint: {vocab_size}")

    tokenizer = Tokenizer()
    tokenizer.load_vocab(vocab)

    color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=color_emb_dim)
    color_model.load_state_dict(state_dict)
    color_model.to(device)
    color_model.eval()
    print("Color model loaded successfully")
    return color_model, tokenizer
134
 
135
 
136
  # ---------------------------------------------------------------------------
137
- # Embedding extraction helpers
138
  # ---------------------------------------------------------------------------
139
 
140
def get_text_embedding(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    device: torch.device,
    text: str,
) -> torch.Tensor:
    """Extract a single normalized text embedding (shape: [512])."""
    batch = processor(text=[text], padding=True, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        encoded = model.text_model(**batch)
        projected = model.text_projection(encoded.pooler_output)
        normalized = F.normalize(projected, dim=-1)

    # Drop the singleton batch dimension so callers get a flat vector.
    return normalized.squeeze(0)
156
 
157
 
158
def get_text_embeddings_batch(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    device: torch.device,
    texts: list[str],
) -> torch.Tensor:
    """Extract normalized text embeddings for a batch of strings (shape: [N, 512])."""
    # Cap at 77 tokens — CLIP's text-encoder context length.
    batch = processor(text=texts, padding=True, return_tensors="pt", truncation=True, max_length=77)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        encoded = model.text_model(**batch)
        projected = model.text_projection(encoded.pooler_output)
        normalized = F.normalize(projected, dim=-1)

    return normalized
174
 
 
 
 
175
 
176
def get_image_embedding(
    model: CLIPModelTransformers,
    image: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Extract a normalized image embedding from a preprocessed tensor.

    Args:
        model: GAP-CLIP model.
        image: Tensor of shape (C, H, W) or (1, C, H, W) or (N, C, H, W).
        device: Target device.

    Returns:
        Normalized embedding tensor of shape (1, 512) or (N, 512).
    """
    model.eval()
    with torch.no_grad():
        pixels = image
        # Broadcast single-channel (grayscale) inputs to three RGB channels.
        if pixels.dim() == 3 and pixels.size(0) == 1:
            pixels = pixels.expand(3, -1, -1)
        elif pixels.dim() == 4 and pixels.size(1) == 1:
            pixels = pixels.expand(-1, 3, -1, -1)
        # Promote an unbatched (C, H, W) tensor to a batch of one.
        if pixels.dim() == 3:
            pixels = pixels.unsqueeze(0)

        pixels = pixels.to(device)
        encoded = model.vision_model(pixel_values=pixels)
        projected = model.visual_projection(encoded.pooler_output)
        return F.normalize(projected, dim=-1)
204
-
205
-
206
def get_image_embedding_from_pil(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    device: torch.device,
    pil_image: Image.Image,
) -> torch.Tensor:
    """Extract a normalized image embedding from a PIL image (shape: [512])."""
    batch = processor(images=pil_image, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        encoded = model.vision_model(**batch)
        projected = model.visual_projection(encoded.pooler_output)
        normalized = F.normalize(projected, dim=-1)

    # Drop the singleton batch dimension so callers get a flat vector.
    return normalized.squeeze(0)
 
 
 
8
 
9
  from __future__ import annotations
10
 
 
 
11
  import sys
12
  from pathlib import Path
13
  from typing import Tuple
 
42
  (model, processor) ready for inference.
43
  """
44
  model = CLIPModelTransformers.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
45
+ checkpoint = torch.load(model_path, map_location=device, weights_only=False)
46
 
47
  if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
48
  model.load_state_dict(checkpoint["model_state_dict"])
 
80
 
81
def load_color_model(
    color_model_path: str,
    device: torch.device,
):
    """Load the specialized 16D color model (CLIP-backbone).

    Returns:
        (color_model, None) -- second element kept for API compatibility
    """
    from training.color_model import ColorCLIP  # type: ignore

    print("Loading ColorCLIP (CLIP-backbone, 16D) ...")
    loaded = ColorCLIP.from_checkpoint(color_model_path, device=device)
    print("Color model loaded successfully")
    # The None stands in for the old tokenizer slot so (model, tok) unpacking still works.
    return loaded, None
96
+
97
+
98
def load_hierarchy_model(
    hierarchy_model_path: str,
    device: torch.device,
):
    """Load the hierarchy model (CLIP-backbone).

    Returns:
        hierarchy_model ready for inference.
    """
    from training.hierarchy_model import HierarchyModel  # type: ignore

    print("Loading HierarchyModel (CLIP-backbone, 64D) ...")
    loaded = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device)
    print("Hierarchy model loaded successfully")
    return loaded
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  # ---------------------------------------------------------------------------
117
+ # Core encoding helpers (same as notebook)
118
  # ---------------------------------------------------------------------------
119
 
120
def encode_text(model, processor, text_queries, device):
    """Encode text queries into embeddings (unnormalized)."""
    # Accept a bare string as a convenience for single-query callers.
    queries = [text_queries] if isinstance(text_queries, str) else text_queries
    tokenized = processor(text=queries, return_tensors="pt", padding=True, truncation=True)
    tokenized = {key: value.to(device) for key, value in tokenized.items()}
    with torch.no_grad():
        embeddings = model.get_text_features(**tokenized)
    return embeddings
 
 
 
129
 
130
 
131
def encode_image(model, processor, images, device):
    """Encode images into embeddings (unnormalized)."""
    # Accept a single image as a convenience for single-item callers.
    batch = images if isinstance(images, list) else [images]
    processed = processor(images=batch, return_tensors="pt")
    processed = {key: value.to(device) for key, value in processed.items()}
    with torch.no_grad():
        features = model.get_image_features(**processed)
    return features
 
140
 
 
141
 
142
+ # ---------------------------------------------------------------------------
143
+ # Normalized wrappers (preserve old call signatures used across eval scripts)
144
+ # ---------------------------------------------------------------------------
145
 
146
def get_text_embedding(model, processor, device, text):
    """Single normalized text embedding (shape: [512])."""
    raw = encode_text(model, processor, text, device)
    # Normalize, then drop the singleton batch dimension.
    return F.normalize(raw, dim=-1).squeeze(0)
 
 
 
149
 
 
 
 
 
150
 
151
def get_text_embeddings_batch(model, processor, device, texts):
    """Normalized text embeddings for a batch (shape: [N, 512])."""
    raw = encode_text(model, processor, texts, device)
    return F.normalize(raw, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
 
 
 
 
155
 
156
def get_image_embedding_from_pil(model, processor, device, pil_image):
    """Normalized image embedding from a PIL image (shape: [512])."""
    raw = encode_image(model, processor, pil_image, device)
    # Normalize, then drop the singleton batch dimension.
    return F.normalize(raw, dim=-1).squeeze(0)