| """ |
| CLIP Image Search Application |
| |
| A Gradio-based application for searching similar images using OpenAI's CLIP model. |
| Supports multiple image formats and provides a web interface for uploading and searching images. |
| """ |
|
|
import os
import pickle
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
|
|
| |
# Hugging Face checkpoint used for both the model weights and the matching
# preprocessor; a single constant guarantees the two never drift apart.
CLIP_MODEL_ID: str = "openai/clip-vit-large-patch14"

# Loaded once at import time and shared by every function in this module.
model: CLIPModel = CLIPModel.from_pretrained(CLIP_MODEL_ID)
processor: CLIPProcessor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
# Inference only: eval() puts the module in evaluation mode (e.g. disables dropout).
model.eval()
|
|
# Folder holding the searchable images, and the pickle file caching their embeddings.
DATASET_DIR: Path = Path("dataset")
CACHE_FILE: str = "cache.pkl"

# Glob patterns ("*.ext") for every image format the app indexes.
IMAGE_EXTENSIONS: List[str] = [
    f"*.{suffix}"
    for suffix in ("jpg", "jpeg", "png", "bmp", "gif", "webp", "tiff", "tif")
]
|
|
def get_all_image_files(
    directory: Optional[Path] = None,
    extensions: Optional[List[str]] = None,
) -> List[Path]:
    """
    Get all image files from the dataset directory.

    Matching is case-insensitive on the file suffix, so "a.jpg", "a.JPG" and
    "a.Jpg" are all found. Each file is returned exactly once; globbing
    "*.jpg" and "*.JPG" separately would return every file twice on
    case-insensitive filesystems.

    Args:
        directory (Optional[Path]): Folder to scan. Defaults to DATASET_DIR.
        extensions (Optional[List[str]]): "*.ext"-style patterns to accept.
            Defaults to IMAGE_EXTENSIONS.

    Returns:
        List[Path]: Matching regular files, sorted by file name. Empty if the
            directory does not exist.
    """
    root = DATASET_DIR if directory is None else Path(directory)
    patterns = IMAGE_EXTENSIONS if extensions is None else extensions
    # "*.jpg" -> ".jpg"; comparing lowercase suffixes makes matching case-insensitive.
    wanted_suffixes = {pattern.lstrip("*").lower() for pattern in patterns}

    if not root.is_dir():
        return []

    return sorted(
        (
            path
            for path in root.iterdir()
            if path.is_file() and path.suffix.lower() in wanted_suffixes
        ),
        key=lambda path: path.name,
    )
|
|
def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
    """
    Encode an image with CLIP's vision tower and return a unit-length feature vector.

    Args:
        image (Image.Image): PIL image to embed.
        device (str, optional): Torch device for the forward pass. Defaults to "cpu".

    Returns:
        torch.Tensor: Image features divided by their L2 norm over the last
            dimension, left on ``device``. Presumably shape (1, D) for one
            image; TODO confirm against the processor's batching.

    Raises:
        RuntimeError: Propagated from torch if ``device`` is unavailable
            (e.g. CUDA requested on a CPU-only host).
    """
    batch = processor(images=image, return_tensors="pt").to(device)
    # nn.Module.to() relocates the shared global model in place and returns it.
    clip = model.to(device)

    with torch.no_grad():
        features: torch.Tensor = clip.get_image_features(**batch)

    lengths = features.norm(p=2, dim=-1, keepdim=True)
    return features / lengths
|
|
def _load_cached_embeddings() -> Dict[str, torch.Tensor]:
    """Return the pickled embedding cache, or an empty dict if it is missing or unreadable."""
    if not os.path.exists(CACHE_FILE):
        return {}
    try:
        # NOTE(security): unpickling can execute arbitrary code. CACHE_FILE must
        # only ever be written by this application, never taken from users.
        with open(CACHE_FILE, "rb") as f:
            cached = pickle.load(f)
    except Exception as e:  # best-effort cache: any unreadable file is a cache miss
        print(f"Ignoring unreadable cache {CACHE_FILE}: {e}")
        return {}
    return cached if isinstance(cached, dict) else {}


def _embed_image_file(img_path: Path, device: str) -> Optional[torch.Tensor]:
    """Embed one image file and return the embedding on CPU, or None (after logging) on failure."""
    print(f"Processing {img_path.name}...")
    try:
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(img_path) as raw:
            img: Image.Image = raw.convert("RGB")
        return get_embedding(img, device=device).cpu()
    except Exception as e:  # one bad file must not abort indexing of the rest
        print(f"Error processing {img_path.name}: {e}")
        return None


@spaces.GPU
def get_reference_embeddings() -> Dict[str, torch.Tensor]:
    """
    Load or compute embeddings for all reference images in the dataset.

    The pickle cache is returned as-is when its keys match the image file
    names currently in the dataset. Otherwise it is updated incrementally:
    - entries for files that no longer exist are dropped;
    - only files missing from the cache are embedded;
    - the refreshed mapping is written back to CACHE_FILE.

    A missing or corrupt cache file is treated as an empty cache.

    Returns:
        Dict[str, torch.Tensor]: Mapping of image file name to its CPU embedding.
            Files that cannot be opened or embedded are logged and omitted.

    Raises:
        OSError: If the refreshed cache cannot be written (e.g. PermissionError).
    """
    current_image_files: List[Path] = get_all_image_files()
    current_images: set = {img_path.name for img_path in current_image_files}

    cached_embeddings: Dict[str, torch.Tensor] = _load_cached_embeddings()
    cached_images: set = set(cached_embeddings)

    if current_images == cached_images:
        print(f"Using cached embeddings for {len(cached_embeddings)} images")
        return cached_embeddings

    print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Reuse still-valid entries. Only new files (and files that failed on an
    # earlier run) are embedded, instead of re-embedding the whole dataset.
    embeddings: Dict[str, torch.Tensor] = {
        name: emb for name, emb in cached_embeddings.items() if name in current_images
    }
    for img_path in current_image_files:
        if img_path.name in embeddings:  # already cached, or a duplicate path
            continue
        emb = _embed_image_file(img_path, device)
        if emb is not None:
            embeddings[img_path.name] = emb

    with open(CACHE_FILE, "wb") as f:
        pickle.dump(embeddings, f)
    print(f"Cache updated with {len(embeddings)} images")
    return embeddings


# Build (or load) the index once at import time so it is ready before the UI starts.
# search_similar rebinds this global on every query; add_image adds entries to it.
reference_embeddings: Dict[str, torch.Tensor] = get_reference_embeddings()
|
|
@spaces.GPU
def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
    """
    Find dataset images similar to the query image using CLIP embeddings.

    The reference embeddings are refreshed first, so images placed in the
    dataset folder since the last call are included.

    Args:
        query_img (Image.Image): Query image. None when nothing was uploaded.

    Returns:
        List[Tuple[str, str]]: Up to 5 ``(image_path, "Score: x.xxxx")`` pairs,
            best match first. Only matches with cosine similarity above 0.2
            are kept. When nothing qualifies, returns an empty list and shows
            a UI warning; every returned item is a real image path the Gallery
            can display.

    Raises:
        gr.Error: If no query image was provided.
    """
    global reference_embeddings

    SIMILARITY_THRESHOLD: float = 0.2
    TOP_K: int = 5

    if query_img is None:
        raise gr.Error("Please upload a query image.")

    reference_embeddings = get_reference_embeddings()
    if not reference_embeddings:
        gr.Warning("No similar images found: the dataset is empty.")
        return []

    # Same CPU fallback as the indexing code, instead of hard-coding CUDA.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb: torch.Tensor = get_embedding(query_img, device=device)

    names: List[str] = list(reference_embeddings)
    # One (N, D) matrix and a single transfer, instead of moving every reference
    # separately. The (1, D) query broadcasts against it.
    ref_matrix: torch.Tensor = torch.stack(
        [reference_embeddings[name].reshape(-1) for name in names]
    ).to(device)
    scores: List[float] = torch.nn.functional.cosine_similarity(
        query_emb, ref_matrix, dim=1
    ).tolist()

    # sorted() is stable, so ties keep dataset order, as the original loop did.
    ranked = sorted(zip(names, scores), key=lambda item: item[1], reverse=True)
    matches = [(name, score) for name, score in ranked if score > SIMILARITY_THRESHOLD][:TOP_K]

    if not matches:
        # A text placeholder is not a valid Gallery image, so warn and return nothing.
        gr.Warning("No similar images found above the similarity threshold.")
        return []

    return [(str(DATASET_DIR / name), f"Score: {score:.4f}") for name, score in matches]
|
|
@spaces.GPU
def add_image(name: str, image: Image.Image) -> str:
    """
    Add a new image to the dataset and update embeddings.

    Steps, in order:
    1. Embed the image.
    2. Save it as ``<name>.png`` in DATASET_DIR.
    3. Record it in ``reference_embeddings`` and write that dict to CACHE_FILE.

    Args:
        name (str): Name for the new image, without extension. Surrounding
            whitespace is ignored. The name must start with a letter, digit or
            underscore and contain only those characters plus spaces, dots and
            hyphens, so it cannot point outside the dataset directory.
        image (Image.Image): PIL image to add. None when nothing was uploaded.

    Returns:
        str: Status message for the user: success with the total image count,
            or the reason the image was rejected (invalid name, missing image,
            name already taken).

    Raises:
        OSError: If the image file or the cache cannot be written.
        RuntimeError: If embedding computation fails.
    """
    clean_name = (name or "").strip()
    # Whitelist plain file names. The name comes from a public web form / MCP
    # call, so values like "../x" or "a/b" must never reach the filesystem.
    if not re.fullmatch(r"\w[\w .-]*", clean_name):
        return "Please provide a valid image name."
    if image is None:
        return "Please upload an image to add."

    filename = f"{clean_name}.png"
    path: Path = DATASET_DIR / filename
    if path.exists():
        return f"An image named '{clean_name}' already exists. Please choose a unique name."

    # Embed before writing anything, so a failure leaves no orphan file behind.
    # RGB conversion matches how the indexing code prepares dataset images.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    emb: torch.Tensor = get_embedding(image.convert("RGB"), device=device)

    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    image.save(path, "PNG")

    reference_embeddings[filename] = emb.cpu()
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)

    return f"Image '{clean_name}' added to dataset. Total images: {len(reference_embeddings)}"
|
|
| |
# Search tab: upload a query image and show the closest dataset matches with their scores.
search_interface: gr.Interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    # One gallery column per result; search_similar returns at most 5 matches.
    outputs=gr.Gallery(label="Top Matches", columns=5),
    # `allow_flagging` was renamed to `flagging_mode` in newer Gradio releases.
    flagging_mode="never",
    title="Image Similarity Search",
    description="Upload an image to find similar images in the dataset",
)
|
|
# Add tab: store a named image in the dataset and index it immediately.
add_interface: gr.Interface = gr.Interface(
    fn=add_image,
    inputs=[
        gr.Text(label="Image Name", placeholder="Enter a unique name for your image"),
        gr.Image(type="pil", label="Product Image"),
    ],
    outputs="text",
    # `allow_flagging` was renamed to `flagging_mode` in newer Gradio releases.
    flagging_mode="never",
    title="Add Image to Dataset",
    description="Add a new image to the searchable dataset",
)
|
|
| |
# Single app exposing both tools as tabs; tab labels follow the interface order.
demo: gr.TabbedInterface = gr.TabbedInterface(
    [search_interface, add_interface],
    title="CLIP Image Search System",
    tab_names=["Search", "Add Product"],
    theme=gr.themes.Soft(),
)
|
|
if __name__ == "__main__":
    # Ensure the upload target folder exists before the UI starts accepting images.
    DATASET_DIR.mkdir(parents=False, exist_ok=True)
    # share=True requests a public Gradio link; mcp_server=True also serves the
    # app's endpoints over MCP.
    demo.launch(mcp_server=True, share=True)