"""
Example usage of GAP-CLIP models.

This file provides example code for loading and using the three models
(color, hierarchy, and main GAP-CLIP) from local checkpoints or the Hugging
Face Hub. It shows how to load the models, extract embeddings, and perform
similarity comparisons.
"""
|
|
import io
import os

import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
from huggingface_hub import hf_hub_download

from training.color_model import ColorCLIP
from training.hierarchy_model import HierarchyModel
import config
|
|
# Base CLIP checkpoint: its architecture and processor are instantiated, and the
# GAP-CLIP main-model weights are then loaded into it via load_state_dict.
CLIP_MODEL_NAME = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
# Default Hugging Face Hub repository used by load_gap_clip().
HF_REPO_ID = "Leacb4/gap-clip"
|
|
|
|
| |
| |
| |
|
|
def load_gap_clip(repo_id: str = HF_REPO_ID):
    """
    Fetch the GAP-CLIP model and its processor straight from the Hugging Face Hub.

    This is the quickest way to get started. The model is switched to
    evaluation mode before it is handed back.

    Args:
        repo_id: Hub repository to load from (defaults to ``HF_REPO_ID``).

    Returns:
        A ``(model, processor)`` tuple.

    Example::

        model, processor = load_gap_clip()
        emb = get_image_embedding_from_url(
            "https://www.gap.com/webcontent/0060/662/817/cn60662817.jpg",
            model, processor,
        )
        print(emb.shape)  # torch.Size([1, 512])
    """
    clip_model = CLIPModel_transformers.from_pretrained(repo_id)
    clip_processor = CLIPProcessor.from_pretrained(repo_id)
    clip_model.eval()
    return clip_model, clip_processor
|
|
|
|
def get_image_embedding_from_url(url: str, model, processor, device=None, timeout: float = 30.0):
    """
    Download an image from a URL and return its 512D GAP-CLIP embedding.

    Args:
        url: Image URL.
        model: CLIPModel loaded via load_gap_clip() or from_pretrained().
        processor: CLIPProcessor matching the model.
        device: Device to run on (defaults to config.device).
        timeout: Seconds ``requests`` waits to connect / between received
            bytes. Without it, an unresponsive server blocks forever.

    Returns:
        Tensor of shape [1, 512] (L2-normalized).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        PIL.UnidentifiedImageError: If the downloaded bytes are not an image.
    """
    device = device or config.device
    # The context manager guarantees the connection is released.
    with requests.get(url, timeout=timeout) as response:
        # Fail loudly instead of passing an HTML error page to PIL.
        response.raise_for_status()
        # ``response.content`` is transfer-decoded (gzip/deflate), whereas
        # ``response.raw`` would hand PIL the undecoded byte stream.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model = model.to(device)
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    return F.normalize(image_features, dim=-1)
|
|
|
|
def get_text_embedding(text: str, model, processor, device=None):
    """
    Embed a single text query with GAP-CLIP.

    Args:
        text: Query string, e.g. "red dress".
        model: CLIPModel obtained from load_gap_clip() or from_pretrained().
        processor: CLIPProcessor paired with ``model``.
        device: Target device; falls back to config.device when falsy.

    Returns:
        L2-normalized tensor of shape [1, 512].
    """
    target = device or config.device
    encoded = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    batch = {name: tensor.to(target) for name, tensor in encoded.items()}
    model = model.to(target)
    with torch.no_grad():
        features = model.get_text_features(**batch)
    return F.normalize(features, dim=-1)
|
|
|
|
| |
| |
| |
|
|
def encode_text(model, processor, text_queries, device):
    """
    Run the text tower of ``model`` on one or more queries.

    Args:
        model: CLIP-style model exposing ``get_text_features``.
        processor: Processor used to tokenize the queries.
        text_queries: A single string or a list of strings.
        device: Device the tokenized inputs are moved to. The model itself is
            not moved; it is assumed to already live there.

    Returns:
        Raw (unnormalized) features from ``model.get_text_features``.
    """
    queries = [text_queries] if isinstance(text_queries, str) else text_queries
    tokenized = processor(text=queries, return_tensors="pt", padding=True, truncation=True)
    on_device = {key: value.to(device) for key, value in tokenized.items()}
    with torch.no_grad():
        return model.get_text_features(**on_device)
|
|
|
|
def encode_image(model, processor, images, device):
    """
    Run the image tower of ``model`` on one or more images.

    Args:
        model: CLIP-style model exposing ``get_image_features``.
        processor: Processor used to turn the images into pixel tensors.
        images: A single image, or a list/tuple of images.
        device: Device the processed inputs are moved to. The model itself is
            not moved; it is assumed to already live there.

    Returns:
        Raw (unnormalized) features from ``model.get_image_features``.
    """
    # Flatten any list/tuple into a list. A tuple must not be wrapped as a
    # single nested item, or the processor receives one bogus "image".
    if isinstance(images, (list, tuple)):
        images = list(images)
    else:
        images = [images]
    inputs = processor(images=images, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    return image_features
|
|
|
|
def _load_main_clip_model(checkpoint_path: str, device):
    """Build the base CLIP architecture, load GAP-CLIP weights into it, and return (model, processor)."""
    clip_model = CLIPModel_transformers.from_pretrained(CLIP_MODEL_NAME)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from sources you
    # trust -- this includes files downloaded from the Hugging Face Hub.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Accept both a training checkpoint that wraps the weights under
    # 'model_state_dict' and a bare state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    clip_model.load_state_dict(state_dict)
    clip_model = clip_model.to(device)
    clip_model.eval()
    processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
    return clip_model, processor


def _bundle_models(color_model, hierarchy_model, clip_model, processor, device):
    """Package the loaded components into the dict read by the example_* functions."""
    return {
        'color_model': color_model,
        'hierarchy_model': hierarchy_model,
        'main_model': clip_model,
        'processor': processor,
        'device': device,
    }


def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache", device=None):
    """
    Download the color, hierarchy, and main GAP-CLIP checkpoints from the Hub and load them.

    Args:
        repo_id: ID of the Hugging Face repository.
        cache_dir: Local directory for the downloaded checkpoint files
            (created if missing). The base CLIP model and processor are
            fetched with ``from_pretrained`` and are not stored here.
        device: Device to load the models onto (defaults to config.device).

    Returns:
        Dict with keys 'color_model', 'hierarchy_model', 'main_model',
        'processor' and 'device'.
    """
    os.makedirs(cache_dir, exist_ok=True)
    if device is None:
        device = config.device

    print(f"Loading models from '{repo_id}'...")

    print(" Loading color model...")
    color_model_path = hf_hub_download(
        repo_id=repo_id,
        filename="models/color_model.pt",
        cache_dir=cache_dir,
    )
    color_model = ColorCLIP.from_checkpoint(color_model_path, device=device)
    print(" Color model loaded")

    print(" Loading hierarchy model...")
    hierarchy_model_path = hf_hub_download(
        repo_id=repo_id,
        filename="models/hierarchy_model.pth",
        cache_dir=cache_dir,
    )
    hierarchy_model = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device)
    print(" Hierarchy model loaded")

    print(" Loading main CLIP model...")
    main_model_path = hf_hub_download(
        repo_id=repo_id,
        filename="models/gap_clip.pth",
        cache_dir=cache_dir,
    )
    clip_model, processor = _load_main_clip_model(main_model_path, device)
    print(" Main CLIP model loaded")

    print("\nAll models loaded!")
    return _bundle_models(color_model, hierarchy_model, clip_model, processor, device)


def load_models_from_local(
    color_model_path: str = None,
    hierarchy_model_path: str = None,
    main_model_path: str = None,
    device=None,
):
    """
    Load models from local checkpoint files.

    Args:
        color_model_path: Path to color_model.pt (defaults to config.color_model_path)
        hierarchy_model_path: Path to hierarchy_model.pth (defaults to config.hierarchy_model_path)
        main_model_path: Path to gap_clip.pth (defaults to config.main_model_path)
        device: Device to load the models onto (defaults to config.device).

    Returns:
        Dict with keys 'color_model', 'hierarchy_model', 'main_model',
        'processor' and 'device'.
    """
    if device is None:
        device = config.device
    color_model_path = color_model_path or config.color_model_path
    hierarchy_model_path = hierarchy_model_path or config.hierarchy_model_path
    main_model_path = main_model_path or config.main_model_path

    print(f"Loading models from local checkpoints (device={device})...")

    print(" Loading color model...")
    color_model = ColorCLIP.from_checkpoint(color_model_path, device=device)
    print(" Color model loaded")

    print(" Loading hierarchy model...")
    hierarchy_model = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device)
    print(" Hierarchy model loaded")

    print(" Loading main CLIP model...")
    clip_model, processor = _load_main_clip_model(main_model_path, device)
    print(" Main CLIP model loaded")

    print("\nAll models loaded!")
    return _bundle_models(color_model, hierarchy_model, clip_model, processor, device)
|
|
|
|
def _print_text_subspace_report(models, text_query: str):
    """Compare the color/hierarchy text embeddings with the matching slices of the main text embedding."""
    color_model = models['color_model']
    hierarchy_model = models['hierarchy_model']
    color_end = config.color_emb_dim
    hierarchy_end = color_end + config.hierarchy_emb_dim

    print(f" Text query: '{text_query}'")

    # NOTE(review): it is not visible here whether these helpers disable
    # autograd themselves; no_grad ensures no graph is built either way.
    with torch.no_grad():
        color_emb = color_model.get_text_embeddings([text_query])
        hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])

    print(f" Color embedding shape: {color_emb.shape}, norm: {color_emb.norm(dim=-1).item():.4f}")
    print(f" Hierarchy embedding shape: {hierarchy_emb.shape}, norm: {hierarchy_emb.norm(dim=-1).item():.4f}")

    text_features = encode_text(models['main_model'], models['processor'], text_query, models['device'])
    text_features = F.normalize(text_features, dim=-1)

    print(f" Main embedding: {text_features.shape}")
    print(f" First 10 dims of main embedding: {text_features[0, :10]}")

    # The leading dims of the main embedding are compared against the
    # dedicated color model, the following block against the hierarchy model.
    main_color_emb = text_features[:, :color_end]
    main_hierarchy_emb = text_features[:, color_end:hierarchy_end]

    print(f"\n Subspace comparison (color model vs main model dims [0:{color_end}]):")
    print(f" color_model first 5 dims: {color_emb[0, :5].tolist()}")
    print(f" main_model first 5 dims: {main_color_emb[0, :5].tolist()}")
    print(f" Subspace comparison (hierarchy model vs main model dims [{color_end}:{hierarchy_end}]):")
    print(f" hierarchy_model first 5 dims: {hierarchy_emb[0, :5].tolist()}")
    print(f" main_model first 5 dims: {main_hierarchy_emb[0, :5].tolist()}")

    color_cosine_sim = F.cosine_similarity(color_emb, main_color_emb, dim=1)
    print(f"\n Cosine similarity between color embeddings: {color_cosine_sim.item():.4f}")

    hierarchy_cosine_sim = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)
    print(f" Cosine similarity between hierarchy embeddings: {hierarchy_cosine_sim.item():.4f}")


def _print_image_subspace_report(models, image_path: str):
    """Compare the color/hierarchy image embeddings with the matching slices of the main image embedding."""
    color_model = models['color_model']
    hierarchy_model = models['hierarchy_model']
    device = models['device']
    color_end = config.color_emb_dim
    hierarchy_end = color_end + config.hierarchy_emb_dim

    print(f"\n Image: {image_path}")
    image = Image.open(image_path).convert("RGB")

    image_features = encode_image(models['main_model'], models['processor'], image, device)
    image_features = F.normalize(image_features, dim=-1)
    print(f" Main image embedding shape: {image_features.shape}")

    # Each specialist model preprocesses the image with its own processor.
    with torch.no_grad():
        color_pixel_values = color_model.processor(
            images=image, return_tensors="pt"
        )["pixel_values"].to(device)
        color_img_emb = color_model.get_image_embeddings(color_pixel_values)
    print(f" Color image embedding shape: {color_img_emb.shape}")

    with torch.no_grad():
        hierarchy_pixel_values = hierarchy_model.processor(
            images=image, return_tensors="pt"
        )["pixel_values"].to(device)
        hierarchy_img_emb = hierarchy_model.get_image_embeddings(hierarchy_pixel_values)
    print(f" Hierarchy image embedding shape: {hierarchy_img_emb.shape}")

    main_color_img = image_features[:, :color_end]
    main_hierarchy_img = image_features[:, color_end:hierarchy_end]
    color_img_sim = F.cosine_similarity(color_img_emb, main_color_img, dim=1)
    hierarchy_img_sim = F.cosine_similarity(hierarchy_img_emb, main_hierarchy_img, dim=1)
    print(f" Image color subspace cosine similarity: {color_img_sim.item():.4f}")
    print(f" Image hierarchy subspace cosine similarity: {hierarchy_img_sim.item():.4f}")


def example_search(models, image_path: str = None, text_query: str = None):
    """
    Print embeddings from each model and compare the main model's
    color/hierarchy subspaces with the dedicated models.

    Args:
        models: Dictionary of loaded models (as returned by the load_models_* functions)
        image_path: Path to an image (optional). A warning is printed when
            the path does not exist.
        text_query: Text query (optional)
    """
    print("\nExample search...")

    if text_query:
        _print_text_subspace_report(models, text_query)

    if image_path:
        if os.path.exists(image_path):
            _print_image_subspace_report(models, image_path)
        else:
            # Mirror example_similarity_search rather than silently skipping.
            print(f" Warning: {image_path} not found, skipping")
|
|
|
|
def _subspace_cosine_scores(text_features, image_features, start: int, end: int):
    """Cosine similarity of the (single) text row against every image row, restricted to dims [start:end]."""
    text_part = F.normalize(text_features[:, start:end], dim=-1)
    image_part = F.normalize(image_features[:, start:end], dim=-1)
    return (text_part @ image_part.T).squeeze(0)


def example_similarity_search(models, image_paths: list, text_query: str):
    """
    Rank images by similarity to a text query using GAP-CLIP.

    Images are ranked by cosine similarity of the full main-model
    embeddings. The color and hierarchy subspace similarities are printed
    alongside each result for inspection but do not affect the order.

    Args:
        models: Dictionary of loaded models
        image_paths: List of image file paths to rank (missing files are
            skipped with a warning)
        text_query: Text query to match against

    Returns:
        List of ``(path, full_score, color_score, hierarchy_score)`` tuples
        sorted by ``full_score`` descending; empty when no path exists.
    """
    main_model = models['main_model']
    processor = models['processor']
    device = models['device']

    print(f"\nSimilarity search: '{text_query}' against {len(image_paths)} images")

    images = []
    valid_paths = []
    for path in image_paths:
        if os.path.exists(path):
            images.append(Image.open(path).convert("RGB"))
            valid_paths.append(path)
        else:
            print(f" Warning: {path} not found, skipping")

    # Bail out before doing any model work when there is nothing to rank.
    if not images:
        print(" No valid images found.")
        return []

    text_features = F.normalize(encode_text(main_model, processor, text_query, device), dim=-1)
    image_features = F.normalize(encode_image(main_model, processor, images, device), dim=-1)

    # Both sides are already unit-norm, so the dot product is the cosine.
    full_scores = (text_features @ image_features.T).squeeze(0)

    color_end = config.color_emb_dim
    hierarchy_end = color_end + config.hierarchy_emb_dim
    color_scores = _subspace_cosine_scores(text_features, image_features, 0, color_end)
    hierarchy_scores = _subspace_cosine_scores(text_features, image_features, color_end, hierarchy_end)

    ranking = []
    print("\n Ranking (by full 512D cosine similarity):")
    for rank, i in enumerate(full_scores.argsort(descending=True).tolist(), start=1):
        full = full_scores[i].item()
        color = color_scores[i].item()
        hierarchy = hierarchy_scores[i].item()
        print(
            f" {rank}. {os.path.basename(valid_paths[i]):30s}"
            f" full={full:.4f}"
            f" color={color:.4f}"
            f" hierarchy={hierarchy:.4f}"
        )
        ranking.append((valid_paths[i], full, color, hierarchy))
    return ranking
|
|
|
|
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Example usage of GAP-CLIP models")
    cli.add_argument(
        "--repo-id", type=str, default=None,
        help="Hugging Face repo ID (e.g., Leacb4/gap-clip). If omitted, loads from local paths.",
    )
    cli.add_argument("--text", type=str, default="red dress", help="Text query for search")
    cli.add_argument("--image", type=str, default=None, help="Path to a single image for example_search")
    cli.add_argument(
        "--images", type=str, nargs="+", default=None,
        help="Paths to multiple images for similarity ranking",
    )
    cli_args = cli.parse_args()

    # Use the Hub when a repo ID is supplied, otherwise the local checkpoints.
    models = load_models_from_hf(cli_args.repo_id) if cli_args.repo_id else load_models_from_local()

    example_search(models, image_path=cli_args.image, text_query=cli_args.text)

    # Ranking only runs when several image paths were given.
    if cli_args.images:
        example_similarity_search(models, cli_args.images, cli_args.text)
|