Add simple API: load_gap_clip, get_image_embedding_from_url, get_text_embedding + fix bugs
Browse files- __init__.py +9 -2
- example_usage.py +257 -26
__init__.py
CHANGED
|
@@ -26,15 +26,22 @@ __email__ = "lea.attia@gmail.com"
|
|
| 26 |
# Import main components for easy access
|
| 27 |
try:
|
| 28 |
from .training.color_model import ColorCLIP
|
| 29 |
-
from .training.hierarchy_model import
|
| 30 |
-
from .example_usage import
|
|
|
|
|
|
|
|
|
|
| 31 |
from . import config
|
| 32 |
|
| 33 |
__all__ = [
|
| 34 |
'ColorCLIP',
|
| 35 |
'HierarchyModel',
|
| 36 |
'HierarchyExtractor',
|
|
|
|
|
|
|
|
|
|
| 37 |
'load_models_from_hf',
|
|
|
|
| 38 |
'example_search',
|
| 39 |
'config',
|
| 40 |
'__version__',
|
|
|
|
| 26 |
# Import main components for easy access
|
| 27 |
try:
|
| 28 |
from .training.color_model import ColorCLIP
|
| 29 |
+
from .training.hierarchy_model import HierarchyModel, HierarchyExtractor
|
| 30 |
+
from .example_usage import (
|
| 31 |
+
load_gap_clip, get_image_embedding_from_url, get_text_embedding,
|
| 32 |
+
load_models_from_hf, load_models_from_local, example_search,
|
| 33 |
+
)
|
| 34 |
from . import config
|
| 35 |
|
| 36 |
__all__ = [
|
| 37 |
'ColorCLIP',
|
| 38 |
'HierarchyModel',
|
| 39 |
'HierarchyExtractor',
|
| 40 |
+
'load_gap_clip',
|
| 41 |
+
'get_image_embedding_from_url',
|
| 42 |
+
'get_text_embedding',
|
| 43 |
'load_models_from_hf',
|
| 44 |
+
'load_models_from_local',
|
| 45 |
'example_search',
|
| 46 |
'config',
|
| 47 |
'__version__',
|
example_usage.py
CHANGED
|
@@ -11,6 +11,7 @@ import os
|
|
| 11 |
|
| 12 |
import torch
|
| 13 |
import torch.nn.functional as F
|
|
|
|
| 14 |
from PIL import Image
|
| 15 |
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 16 |
from huggingface_hub import hf_hub_download
|
|
@@ -19,6 +20,83 @@ from training.color_model import ColorCLIP
|
|
| 19 |
from training.hierarchy_model import HierarchyModel
|
| 20 |
import config
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def encode_text(model, processor, text_queries, device):
|
| 24 |
"""Encode text queries into embeddings (unnormalized)."""
|
|
@@ -83,10 +161,8 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
|
| 83 |
cache_dir=cache_dir,
|
| 84 |
)
|
| 85 |
|
| 86 |
-
clip_model = CLIPModel_transformers.from_pretrained(
|
| 87 |
-
|
| 88 |
-
)
|
| 89 |
-
checkpoint = torch.load(main_model_path, map_location=device)
|
| 90 |
|
| 91 |
# Handle different checkpoint structures
|
| 92 |
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
|
@@ -97,7 +173,60 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
|
| 97 |
clip_model = clip_model.to(device)
|
| 98 |
clip_model.eval()
|
| 99 |
|
| 100 |
-
processor = CLIPProcessor.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
print(" Main CLIP model loaded")
|
| 102 |
|
| 103 |
print("\nAll models loaded!")
|
|
@@ -135,27 +264,26 @@ def example_search(models, image_path: str = None, text_query: str = None):
|
|
| 135 |
color_emb = color_model.get_text_embeddings([text_query])
|
| 136 |
hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
|
| 137 |
|
| 138 |
-
print(f" Color embedding: {color_emb.shape}")
|
| 139 |
-
print(f"
|
| 140 |
-
print(f" Hierarchy embedding: {hierarchy_emb.shape}")
|
| 141 |
-
print(f" hierarchy_emb: {hierarchy_emb}")
|
| 142 |
|
| 143 |
# Get main model embeddings
|
| 144 |
text_features = encode_text(main_model, processor, text_query, device)
|
| 145 |
text_features = F.normalize(text_features, dim=-1)
|
| 146 |
|
| 147 |
print(f" Main embedding: {text_features.shape}")
|
| 148 |
-
print(f" First
|
| 149 |
|
| 150 |
# Extract color and hierarchy embeddings from main embedding
|
| 151 |
main_color_emb = text_features[:, :config.color_emb_dim]
|
| 152 |
main_hierarchy_emb = text_features[:, config.color_emb_dim:config.color_emb_dim + config.hierarchy_emb_dim]
|
| 153 |
|
| 154 |
-
print(f"\n
|
| 155 |
-
print(f"
|
| 156 |
-
print(f"
|
| 157 |
-
print(f"
|
| 158 |
-
print(f"
|
|
|
|
| 159 |
|
| 160 |
# Calculate cosine similarity between color embeddings
|
| 161 |
color_cosine_sim = F.cosine_similarity(color_emb, main_color_emb, dim=1)
|
|
@@ -166,25 +294,114 @@ def example_search(models, image_path: str = None, text_query: str = None):
|
|
| 166 |
print(f" Cosine similarity between hierarchy embeddings: {hierarchy_cosine_sim.item():.4f}")
|
| 167 |
|
| 168 |
if image_path and os.path.exists(image_path):
|
| 169 |
-
print(f" Image: {image_path}")
|
| 170 |
image = Image.open(image_path).convert("RGB")
|
| 171 |
|
| 172 |
-
#
|
| 173 |
image_features = encode_image(main_model, processor, image, device)
|
| 174 |
image_features = F.normalize(image_features, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
if __name__ == "__main__":
|
| 180 |
import argparse
|
| 181 |
|
| 182 |
-
parser = argparse.ArgumentParser(description="Example usage of models")
|
| 183 |
parser.add_argument(
|
| 184 |
"--repo-id",
|
| 185 |
type=str,
|
| 186 |
-
|
| 187 |
-
help="ID
|
| 188 |
)
|
| 189 |
parser.add_argument(
|
| 190 |
"--text",
|
|
@@ -195,14 +412,28 @@ if __name__ == "__main__":
|
|
| 195 |
parser.add_argument(
|
| 196 |
"--image",
|
| 197 |
type=str,
|
| 198 |
-
default=
|
| 199 |
-
help="Path to
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
)
|
| 201 |
|
| 202 |
args = parser.parse_args()
|
| 203 |
|
| 204 |
-
# Load models
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
-
# Example search
|
| 208 |
example_search(models, image_path=args.image, text_query=args.text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
import torch
|
| 13 |
import torch.nn.functional as F
|
| 14 |
+
import requests
|
| 15 |
from PIL import Image
|
| 16 |
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 17 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 20 |
from training.hierarchy_model import HierarchyModel
|
| 21 |
import config
|
| 22 |
|
| 23 |
+
CLIP_MODEL_NAME = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
| 24 |
+
HF_REPO_ID = "Leacb4/gap-clip"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Simple API — load from HF and get 512D embeddings
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
def load_gap_clip(repo_id: str = HF_REPO_ID):
    """
    Fetch the GAP-CLIP model and its processor from Hugging Face.

    Both objects are built with ``from_pretrained(repo_id)``; the model is
    switched to evaluation mode before it is handed back.

    Args:
        repo_id: Hugging Face repo ID to load from (defaults to HF_REPO_ID).

    Returns:
        Tuple ``(model, processor)``.

    Example::

        model, processor = load_gap_clip()
        emb = get_image_embedding_from_url(
            "https://www.gap.com/webcontent/0060/662/817/cn60662817.jpg",
            model, processor,
        )
        print(emb.shape)  # torch.Size([1, 512])
    """
    gap_clip = CLIPModel_transformers.from_pretrained(repo_id)
    gap_processor = CLIPProcessor.from_pretrained(repo_id)
    # Inference-only usage: switch the module to eval mode.
    gap_clip.eval()
    return gap_clip, gap_processor
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_image_embedding_from_url(url: str, model, processor, device=None, *, timeout: float = 30.0):
    """
    Download an image from a URL and return its GAP-CLIP image embedding.

    Args:
        url: Image URL.
        model: CLIPModel loaded via load_gap_clip() or from_pretrained().
        processor: CLIPProcessor matching the model.
        device: Device to run on (defaults to config.device).
        timeout: Seconds to wait for the HTTP server before giving up.

    Returns:
        L2-normalized image-feature tensor (shape [1, 512] according to the
        project's documentation).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    device = device or config.device

    # The context manager returns the connection to the pool even on failure;
    # the timeout keeps an unresponsive host from blocking forever.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # The raw urllib3 stream only undoes gzip/deflate when asked to.
        response.raw.decode_content = True
        # convert() forces the pixel data to be read while the stream is open.
        image = Image.open(response.raw).convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model = model.to(device)
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    return F.normalize(image_features, dim=-1)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_text_embedding(text: "str | list[str]", model, processor, device=None):
    """
    Return L2-normalized GAP-CLIP text embeddings.

    Args:
        text: A single query (e.g., "red dress") or a sequence of queries
            to encode as one batch.
        model: CLIPModel loaded via load_gap_clip() or from_pretrained().
        processor: CLIPProcessor matching the model.
        device: Device to run on (defaults to config.device).

    Returns:
        Tensor with one L2-normalized row per query: a single row for a
        string input, N rows for a sequence of N queries.

    Raises:
        ValueError: If ``text`` is an empty sequence.
    """
    if device is None:
        device = config.device

    # A lone string is wrapped so the processor always receives a batch.
    queries = [text] if isinstance(text, str) else list(text)
    if not queries:
        raise ValueError("text must be a non-empty string or sequence of strings")

    inputs = processor(text=queries, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model = model.to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    return F.normalize(text_features, dim=-1)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
# Internal helpers for encode_text / encode_image (used by advanced examples)
|
| 99 |
+
# ---------------------------------------------------------------------------
|
| 100 |
|
| 101 |
def encode_text(model, processor, text_queries, device):
|
| 102 |
"""Encode text queries into embeddings (unnormalized)."""
|
|
|
|
| 161 |
cache_dir=cache_dir,
|
| 162 |
)
|
| 163 |
|
| 164 |
+
clip_model = CLIPModel_transformers.from_pretrained(CLIP_MODEL_NAME)
|
| 165 |
+
checkpoint = torch.load(main_model_path, map_location=device, weights_only=False)
|
|
|
|
|
|
|
| 166 |
|
| 167 |
# Handle different checkpoint structures
|
| 168 |
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
|
|
|
| 173 |
clip_model = clip_model.to(device)
|
| 174 |
clip_model.eval()
|
| 175 |
|
| 176 |
+
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
|
| 177 |
+
print(" Main CLIP model loaded")
|
| 178 |
+
|
| 179 |
+
print("\nAll models loaded!")
|
| 180 |
+
|
| 181 |
+
return {
|
| 182 |
+
'color_model': color_model,
|
| 183 |
+
'hierarchy_model': hierarchy_model,
|
| 184 |
+
'main_model': clip_model,
|
| 185 |
+
'processor': processor,
|
| 186 |
+
'device': device,
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def load_models_from_local(
|
| 191 |
+
color_model_path: str = None,
|
| 192 |
+
hierarchy_model_path: str = None,
|
| 193 |
+
main_model_path: str = None,
|
| 194 |
+
):
|
| 195 |
+
"""
|
| 196 |
+
Load models from local checkpoint files.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
color_model_path: Path to color_model.pt (defaults to config.color_model_path)
|
| 200 |
+
hierarchy_model_path: Path to hierarchy_model.pth (defaults to config.hierarchy_model_path)
|
| 201 |
+
main_model_path: Path to gap_clip.pth (defaults to config.main_model_path)
|
| 202 |
+
"""
|
| 203 |
+
device = config.device
|
| 204 |
+
color_model_path = color_model_path or config.color_model_path
|
| 205 |
+
hierarchy_model_path = hierarchy_model_path or config.hierarchy_model_path
|
| 206 |
+
main_model_path = main_model_path or config.main_model_path
|
| 207 |
+
|
| 208 |
+
print(f"Loading models from local checkpoints (device={device})...")
|
| 209 |
+
|
| 210 |
+
# 1. Color model
|
| 211 |
+
print(" Loading color model...")
|
| 212 |
+
color_model = ColorCLIP.from_checkpoint(color_model_path, device=device)
|
| 213 |
+
print(" Color model loaded")
|
| 214 |
+
|
| 215 |
+
# 2. Hierarchy model
|
| 216 |
+
print(" Loading hierarchy model...")
|
| 217 |
+
hierarchy_model = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device)
|
| 218 |
+
print(" Hierarchy model loaded")
|
| 219 |
+
|
| 220 |
+
# 3. Main CLIP model
|
| 221 |
+
print(" Loading main CLIP model...")
|
| 222 |
+
clip_model = CLIPModel_transformers.from_pretrained(CLIP_MODEL_NAME)
|
| 223 |
+
checkpoint = torch.load(main_model_path, map_location=device, weights_only=False)
|
| 224 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 225 |
+
clip_model.load_state_dict(checkpoint['model_state_dict'])
|
| 226 |
+
else:
|
| 227 |
+
clip_model.load_state_dict(checkpoint)
|
| 228 |
+
clip_model.to(device).eval()
|
| 229 |
+
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
|
| 230 |
print(" Main CLIP model loaded")
|
| 231 |
|
| 232 |
print("\nAll models loaded!")
|
|
|
|
| 264 |
color_emb = color_model.get_text_embeddings([text_query])
|
| 265 |
hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
|
| 266 |
|
| 267 |
+
print(f" Color embedding shape: {color_emb.shape}, norm: {color_emb.norm(dim=-1).item():.4f}")
|
| 268 |
+
print(f" Hierarchy embedding shape: {hierarchy_emb.shape}, norm: {hierarchy_emb.norm(dim=-1).item():.4f}")
|
|
|
|
|
|
|
| 269 |
|
| 270 |
# Get main model embeddings
|
| 271 |
text_features = encode_text(main_model, processor, text_query, device)
|
| 272 |
text_features = F.normalize(text_features, dim=-1)
|
| 273 |
|
| 274 |
print(f" Main embedding: {text_features.shape}")
|
| 275 |
+
print(f" First 10 dims of main embedding: {text_features[0, :10]}")
|
| 276 |
|
| 277 |
# Extract color and hierarchy embeddings from main embedding
|
| 278 |
main_color_emb = text_features[:, :config.color_emb_dim]
|
| 279 |
main_hierarchy_emb = text_features[:, config.color_emb_dim:config.color_emb_dim + config.hierarchy_emb_dim]
|
| 280 |
|
| 281 |
+
print(f"\n Subspace comparison (color model vs main model dims [0:{config.color_emb_dim}]):")
|
| 282 |
+
print(f" color_model first 5 dims: {color_emb[0, :5].tolist()}")
|
| 283 |
+
print(f" main_model first 5 dims: {main_color_emb[0, :5].tolist()}")
|
| 284 |
+
print(f" Subspace comparison (hierarchy model vs main model dims [{config.color_emb_dim}:{config.color_emb_dim + config.hierarchy_emb_dim}]):")
|
| 285 |
+
print(f" hierarchy_model first 5 dims: {hierarchy_emb[0, :5].tolist()}")
|
| 286 |
+
print(f" main_model first 5 dims: {main_hierarchy_emb[0, :5].tolist()}")
|
| 287 |
|
| 288 |
# Calculate cosine similarity between color embeddings
|
| 289 |
color_cosine_sim = F.cosine_similarity(color_emb, main_color_emb, dim=1)
|
|
|
|
| 294 |
print(f" Cosine similarity between hierarchy embeddings: {hierarchy_cosine_sim.item():.4f}")
|
| 295 |
|
| 296 |
if image_path and os.path.exists(image_path):
|
| 297 |
+
print(f"\n Image: {image_path}")
|
| 298 |
image = Image.open(image_path).convert("RGB")
|
| 299 |
|
| 300 |
+
# Main model image embedding
|
| 301 |
image_features = encode_image(main_model, processor, image, device)
|
| 302 |
image_features = F.normalize(image_features, dim=-1)
|
| 303 |
+
print(f" Main image embedding shape: {image_features.shape}")
|
| 304 |
+
|
| 305 |
+
# Color model image embedding (preprocess through model's own processor)
|
| 306 |
+
color_pixel_values = color_model.processor(
|
| 307 |
+
images=image, return_tensors="pt"
|
| 308 |
+
)["pixel_values"].to(device)
|
| 309 |
+
color_img_emb = color_model.get_image_embeddings(color_pixel_values)
|
| 310 |
+
print(f" Color image embedding shape: {color_img_emb.shape}")
|
| 311 |
+
|
| 312 |
+
# Hierarchy model image embedding
|
| 313 |
+
hierarchy_pixel_values = hierarchy_model.processor(
|
| 314 |
+
images=image, return_tensors="pt"
|
| 315 |
+
)["pixel_values"].to(device)
|
| 316 |
+
hierarchy_img_emb = hierarchy_model.get_image_embeddings(hierarchy_pixel_values)
|
| 317 |
+
print(f" Hierarchy image embedding shape: {hierarchy_img_emb.shape}")
|
| 318 |
+
|
| 319 |
+
# Compare subspace alignment for images
|
| 320 |
+
main_color_img = image_features[:, :config.color_emb_dim]
|
| 321 |
+
main_hierarchy_img = image_features[:, config.color_emb_dim:config.color_emb_dim + config.hierarchy_emb_dim]
|
| 322 |
+
color_img_sim = F.cosine_similarity(color_img_emb, main_color_img, dim=1)
|
| 323 |
+
hierarchy_img_sim = F.cosine_similarity(hierarchy_img_emb, main_hierarchy_img, dim=1)
|
| 324 |
+
print(f" Image color subspace cosine similarity: {color_img_sim.item():.4f}")
|
| 325 |
+
print(f" Image hierarchy subspace cosine similarity: {hierarchy_img_sim.item():.4f}")
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def example_similarity_search(models, image_paths: list, text_query: str):
    """
    Print a ranking of images by similarity to a text query.

    The query and every existing image are encoded with the main model and
    L2-normalized. Images are ordered by the dot product over the full
    embedding; for each one, the similarity restricted to the color slice
    and to the hierarchy slice (as delimited by ``config.color_emb_dim`` and
    ``config.hierarchy_emb_dim``) is printed alongside.

    Args:
        models: Dictionary providing 'main_model', 'processor' and 'device'.
        image_paths: Image file paths to rank; missing files are skipped
            with a warning.
        text_query: Text query to match against.

    Returns:
        None. Results are printed to stdout.
    """
    main_model = models['main_model']
    processor = models['processor']
    device = models['device']

    print(f"\nSimilarity search: '{text_query}' against {len(image_paths)} images")

    # Text side: a single normalized query vector.
    query_vec = F.normalize(
        encode_text(main_model, processor, text_query, device), dim=-1
    )

    # Image side: keep only the paths that exist on disk.
    loaded = []
    kept_paths = []
    for path in image_paths:
        if not os.path.exists(path):
            print(f" Warning: {path} not found, skipping")
            continue
        loaded.append(Image.open(path).convert("RGB"))
        kept_paths.append(path)

    if len(loaded) == 0:
        print(" No valid images found.")
        return

    image_vecs = F.normalize(
        encode_image(main_model, processor, loaded, device), dim=-1
    )

    def _slice_scores(start, stop):
        # Cosine similarity restricted to embedding dims [start:stop].
        q = F.normalize(query_vec[:, start:stop], dim=-1)
        m = F.normalize(image_vecs[:, start:stop], dim=-1)
        return (q @ m.T).squeeze(0)

    # Similarity over the full embedding.
    full_scores = (query_vec @ image_vecs.T).squeeze(0)

    color_dim = config.color_emb_dim
    hierarchy_end = color_dim + config.hierarchy_emb_dim
    color_scores = _slice_scores(0, color_dim)
    hierarchy_scores = _slice_scores(color_dim, hierarchy_end)

    # Best match first.
    order = full_scores.argsort(descending=True)

    print(f"\n Ranking (by full 512D cosine similarity):")
    for position, i in enumerate(order.tolist(), start=1):
        line = (
            f" {position}. {os.path.basename(kept_paths[i]):30s}"
            f" full={full_scores[i]:.4f}"
            f" color={color_scores[i]:.4f}"
            f" hierarchy={hierarchy_scores[i]:.4f}"
        )
        print(line)
|
| 394 |
|
| 395 |
|
| 396 |
if __name__ == "__main__":
|
| 397 |
import argparse
|
| 398 |
|
| 399 |
+
parser = argparse.ArgumentParser(description="Example usage of GAP-CLIP models")
|
| 400 |
parser.add_argument(
|
| 401 |
"--repo-id",
|
| 402 |
type=str,
|
| 403 |
+
default=None,
|
| 404 |
+
help="Hugging Face repo ID (e.g., Leacb4/gap-clip). If omitted, loads from local paths.",
|
| 405 |
)
|
| 406 |
parser.add_argument(
|
| 407 |
"--text",
|
|
|
|
| 412 |
parser.add_argument(
|
| 413 |
"--image",
|
| 414 |
type=str,
|
| 415 |
+
default=None,
|
| 416 |
+
help="Path to a single image for example_search",
|
| 417 |
+
)
|
| 418 |
+
parser.add_argument(
|
| 419 |
+
"--images",
|
| 420 |
+
type=str,
|
| 421 |
+
nargs="+",
|
| 422 |
+
default=None,
|
| 423 |
+
help="Paths to multiple images for similarity ranking",
|
| 424 |
)
|
| 425 |
|
| 426 |
args = parser.parse_args()
|
| 427 |
|
| 428 |
+
# Load models (HF or local)
|
| 429 |
+
if args.repo_id:
|
| 430 |
+
models = load_models_from_hf(args.repo_id)
|
| 431 |
+
else:
|
| 432 |
+
models = load_models_from_local()
|
| 433 |
|
| 434 |
+
# Example search (embedding inspection)
|
| 435 |
example_search(models, image_path=args.image, text_query=args.text)
|
| 436 |
+
|
| 437 |
+
# Similarity ranking (if multiple images provided)
|
| 438 |
+
if args.images:
|
| 439 |
+
example_similarity_search(models, args.images, args.text)
|