#!/usr/bin/env python3
"""
Annex 9.4 — Search Engine Demo
===============================

Interactive fashion search engine using pre-computed GAP-CLIP text embeddings.
Demonstrates real-world retrieval quality by accepting free-text queries and
returning the most similar items from the internal dataset, with images and
similarity scores displayed in a grid layout.

Run directly:
    python annex94_search_demo.py

Paper reference: Section 9.4 (Appendix), Figure 5.
"""

import os
import sys
import warnings
from pathlib import Path
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPModel as CLIPModel_transformers
from transformers import CLIPProcessor

# Ensure project root is importable when running this file directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Import custom models via shared loaders
from evaluation.utils.model_loader import load_color_model, load_hierarchy_model
import config

warnings.filterwarnings("ignore")


class FashionSearchEngine:
    """Fashion search engine using multi-modal embeddings with category emphasis."""

    def __init__(
        self, top_k: int = 10, max_items: int = 10000, use_baseline: bool = False
    ):
        """
        Initialize the fashion search engine.

        Args:
            top_k: Number of top results to return.
            max_items: Maximum number of items to embed (caps the stratified
                sample for faster initialization).
            use_baseline: If True, use the Fashion-CLIP baseline instead of
                GAP-CLIP.
        """
        self.device = config.device
        self.top_k = top_k
        self.max_items = max_items
        self.color_dim = config.color_emb_dim
        self.hierarchy_dim = config.hierarchy_emb_dim
        self.use_baseline = use_baseline

        # Load models
        self._load_models()

        # Load dataset
        self._load_dataset()

        # Pre-compute embeddings for all items
        self._precompute_embeddings()

        print("āœ… Fashion Search Engine ready!")

    def _load_models(self):
        """Load color/hierarchy models and the main CLIP text encoder.

        The color model is optional for text search: a missing checkpoint or
        load failure is reported and tolerated, leaving ``self.color_model``
        as ``None``.
        """
        print("šŸ“¦ Loading models...")

        # Load color model (optional for search in this script).
        self.color_model = None
        color_model_path = getattr(config, "color_model_path", None)
        if not color_model_path or not Path(color_model_path).exists():
            print("āš ļø color model checkpoint not found; continuing without color model.")
        else:
            try:
                self.color_model, _ = load_color_model(
                    color_model_path=config.color_model_path,
                    device=self.device,
                )
            except Exception as e:
                print(f"āš ļø Failed to load color model: {e}; continuing without it.")

        # Load hierarchy model (required).
        self.hierarchy_model = load_hierarchy_model(
            hierarchy_model_path=config.hierarchy_model_path,
            device=self.device,
        )

        # Load main CLIP model (baseline or fine-tuned GAP-CLIP).
        if self.use_baseline:
            baseline_name = "patrickjohncyh/fashion-clip"
            print(f"šŸ“¦ Loading baseline Fashion-CLIP model ({baseline_name})...")
            self.main_model = CLIPModel_transformers.from_pretrained(baseline_name).to(
                self.device
            )
            self.main_model.eval()
            self.clip_processor = CLIPProcessor.from_pretrained(baseline_name)
        else:
            # Start from the LAION checkpoint, then overlay the fine-tuned
            # GAP-CLIP weights.  Checkpoints may or may not be wrapped in a
            # {"model_state_dict": ...} dict.
            self.main_model = CLIPModel_transformers.from_pretrained(
                "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
            )
            checkpoint = torch.load(config.main_model_path, map_location=self.device)
            if "model_state_dict" in checkpoint:
                self.main_model.load_state_dict(checkpoint["model_state_dict"])
            else:
                self.main_model.load_state_dict(checkpoint)
            self.main_model.to(self.device)
            self.main_model.eval()
            self.clip_processor = CLIPProcessor.from_pretrained(
                "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
            )

        model_label = "Fashion-CLIP baseline" if self.use_baseline else "GAP-CLIP"
        print(
            f"āœ… Models loaded ({model_label}) - Colors: {self.color_dim}D, Hierarchy: {self.hierarchy_dim}D"
        )

    def _load_dataset(self):
        """Load the fashion dataset.

        Tries ``config.local_dataset_path`` first. If it doesn't exist, falls
        back to ``data/data.csv`` (the raw catalogue without
        ``local_image_path``).

        Raises:
            FileNotFoundError: If neither dataset file exists.
        """
        print("šŸ“Š Loading dataset...")
        dataset_path = config.local_dataset_path
        if not Path(dataset_path).exists():
            fallback = Path(config.ROOT_DIR) / "data" / "data.csv"
            if fallback.exists():
                print(f"āš ļø {dataset_path} not found, falling back to {fallback}")
                dataset_path = str(fallback)
            else:
                raise FileNotFoundError(
                    f"Neither {config.local_dataset_path} nor {fallback} found."
                )

        self.df = pd.read_csv(dataset_path)

        # If local_image_path column is missing, create an empty one so the
        # rest of the pipeline can proceed (text-only search still works).
        if config.column_local_image_path not in self.df.columns:
            self.df[config.column_local_image_path] = ""

        self.df_clean = self.df.dropna(subset=[config.text_column])
        print(f"āœ… {len(self.df_clean)} items loaded for search")

    def _precompute_embeddings(self):
        """Pre-compute text embeddings using stratified sampling.

        Samples up to 20 items per (color, category) pair, capped at
        ``self.max_items`` overall, then encodes the item descriptions with
        the CLIP text tower in batches.
        """
        print("šŸ”„ Pre-computing embeddings with stratified sampling...")

        sampled_df = self.df_clean.groupby(
            [config.color_column, config.hierarchy_column],
        ).apply(lambda g: g.sample(n=min(20, len(g)), replace=False))
        sampled_df = sampled_df.reset_index(drop=True)

        # Honour the max_items cap (previously accepted but silently ignored).
        # Random sampling keeps the cap roughly stratified; a fixed seed keeps
        # repeated runs deterministic.
        if len(sampled_df) > self.max_items:
            sampled_df = sampled_df.sample(
                n=self.max_items, random_state=0
            ).reset_index(drop=True)

        all_embeddings = []
        all_texts = []
        all_colors = []
        all_hierarchies = []
        all_images = []
        all_urls = []

        batch_size = 32
        from tqdm import tqdm

        total_batches = (len(sampled_df) + batch_size - 1) // batch_size
        for i in tqdm(
            range(0, len(sampled_df), batch_size),
            desc="Computing embeddings",
            total=total_batches,
        ):
            batch = sampled_df.iloc[i : i + batch_size]
            texts = batch[config.text_column].tolist()

            all_texts.extend(texts)
            all_colors.extend(batch[config.color_column].tolist())
            all_hierarchies.extend(batch[config.hierarchy_column].tolist())
            all_images.extend(batch[config.column_local_image_path].tolist())
            all_urls.extend(batch[config.column_url_image].tolist())

            with torch.no_grad():
                text_inputs = self.clip_processor(
                    text=texts,
                    padding=True,
                    truncation=True,
                    max_length=77,
                    return_tensors="pt",
                )
                text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
                # get_text_features skips the vision tower entirely.  The
                # previous code ran the full model with dummy zero images just
                # to read outputs.text_embeds; cosine similarity is
                # scale-invariant, so rankings and scores are unchanged.
                embeddings = (
                    self.main_model.get_text_features(**text_inputs).cpu().numpy()
                )
                all_embeddings.extend(embeddings)

        self.all_embeddings = np.array(all_embeddings)
        self.all_texts = all_texts
        self.all_colors = all_colors
        self.all_hierarchies = all_hierarchies
        self.all_images = all_images
        self.all_urls = all_urls
        print(f"āœ… Pre-computed embeddings for {len(self.all_embeddings)} items")

    def search_by_text(
        self, query_text: str, filter_category: Optional[str] = None
    ) -> List[dict]:
        """Search for clothing items using a text query.

        Args:
            query_text: Free-text description (e.g. "red summer dress").
            filter_category: Optional category filter (e.g. "dress").

        Returns:
            List of result dicts with keys: rank, image_path, text, color,
            hierarchy, similarity, index, url.
        """
        print(f"šŸ” Searching for: '{query_text}'")

        with torch.no_grad():
            # truncation/max_length added so a long query cannot overflow the
            # 77-token CLIP context window (the precompute path already did).
            text_inputs = self.clip_processor(
                text=[query_text],
                padding=True,
                truncation=True,
                max_length=77,
                return_tensors="pt",
            )
            text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
            # Text tower only — no dummy image needed (see _precompute_embeddings).
            query_embedding = (
                self.main_model.get_text_features(**text_inputs).cpu().numpy()
            )

        similarities = cosine_similarity(query_embedding, self.all_embeddings)[0]
        # Over-fetch 2x so category filtering can still fill top_k slots.
        top_indices = np.argsort(similarities)[::-1][: self.top_k * 2]

        results = []
        for idx in top_indices:
            # -0.5 is a permissive floor: cosine similarity rarely goes below
            # it, so this mainly guards against degenerate embeddings.
            if similarities[idx] > -0.5:
                if (
                    filter_category
                    and filter_category.lower() not in self.all_hierarchies[idx].lower()
                ):
                    continue
                results.append(
                    {
                        "rank": len(results) + 1,
                        "image_path": self.all_images[idx],
                        "text": self.all_texts[idx],
                        "color": self.all_colors[idx],
                        "hierarchy": self.all_hierarchies[idx],
                        "similarity": float(similarities[idx]),
                        "index": int(idx),
                        "url": self.all_urls[idx],
                    }
                )
                if len(results) >= self.top_k:
                    break

        print(f"āœ… Found {len(results)} results")
        return results

    @staticmethod
    def _fetch_image_from_url(url: str, timeout: int = 5):
        """Try to download an image from *url*; return a PIL Image or None."""
        import requests
        from io import BytesIO

        try:
            resp = requests.get(url, timeout=timeout)
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception:
            # Best-effort fetch: any network/decoding failure falls back to
            # the text placeholder in display_results().
            return None

    def display_results(
        self, results: List[dict], query_info: str = "", save_path: Optional[str] = None
    ):
        """Display search results as an image grid with similarity scores.

        Args:
            results: List of result dicts from search_by_text().
            query_info: Label shown in the plot title.
            save_path: If given, save the figure to this path instead of
                plt.show().
        """
        if not results:
            print("āŒ No results found")
            return

        print(f"\nšŸŽÆ Search Results for: {query_info}")
        print("=" * 80)

        n_results = len(results)
        cols = min(5, n_results)
        rows = (n_results + cols - 1) // cols

        # squeeze=False guarantees a 2-D axes array for every grid shape.
        # The previous reshape() calls crashed for a single result, because
        # plt.subplots(1, 1) returns a bare Axes with no .reshape().
        fig, axes = plt.subplots(
            rows, cols, figsize=(4 * cols, 5 * rows), squeeze=False
        )

        for i, result in enumerate(results):
            row = i // cols
            col = i % cols
            ax = axes[row, col]

            title = (
                f"#{result['rank']} (Sim: {result['similarity']:.3f})\n"
                f"{result['color']} {result['hierarchy']}"
            )

            # Try local file → URL download → text fallback
            img = None
            if result.get("image_path") and Path(result["image_path"]).is_file():
                try:
                    img = Image.open(result["image_path"])
                except Exception:
                    pass
            if img is None and result.get("url"):
                img = self._fetch_image_from_url(result["url"])

            if img is not None:
                ax.imshow(img)
            else:
                ax.set_facecolor("#f0f0f0")
                snippet = result["text"][:80]
                ax.text(
                    0.5,
                    0.5,
                    snippet,
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                    fontsize=8,
                    wrap=True,
                )

            ax.set_title(title, fontsize=10)
            ax.axis("off")

        # Blank out any unused grid cells.
        for i in range(n_results, rows * cols):
            axes[i // cols, i % cols].axis("off")

        fig.suptitle(f'Search: "{query_info}"', fontsize=14, fontweight="bold")
        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
            print(f"šŸ“Š Figure saved to {save_path}")
        else:
            plt.show()
        plt.close(fig)

        print("\nšŸ“‹ Detailed Results:")
        for result in results:
            print(
                f"#{result['rank']:2d} | Similarity: {result['similarity']:.3f} | "
                f"Color: {result['color']:12s} | Category: {result['hierarchy']:15s} | "
                f"Text: {result['text'][:50]}..."
            )
            print(f"    šŸ”— URL: {result['url']}")
            print()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Annex 9.4 — Fashion Search Engine Demo"
    )
    parser.add_argument(
        "--baseline",
        action="store_true",
        help="Use the Fashion-CLIP baseline instead of GAP-CLIP",
    )
    parser.add_argument(
        "--queries",
        nargs="*",
        default=None,
        help="Queries to run (e.g. 'red dress' 'blue pants')",
    )
    args = parser.parse_args()

    label = "Baseline Fashion-CLIP" if args.baseline else "GAP-CLIP"
    print(f"šŸŽÆ Initializing Fashion Search Engine ({label})")
    engine = FashionSearchEngine(top_k=10, max_items=10000, use_baseline=args.baseline)
    print("āœ… Engine initialized (models loaded, embeddings precomputed).")

    if args.queries:
        all_results = {}
        figures_dir = Path("evaluation") / "figures"
        figures_dir.mkdir(parents=True, exist_ok=True)
        prefix = "baseline" if args.baseline else "gapclip"
        for query in args.queries:
            results = engine.search_by_text(query)
            slug = query.replace(" ", "_")
            fig_path = figures_dir / f"{prefix}_{slug}.png"
            engine.display_results(results, query_info=query, save_path=str(fig_path))
            all_results[query] = results