"""Gradio demo: given an uploaded image, retrieve the most similar items from a
precomputed feature database and/or produce a drawing-style reconstruction."""

import pickle
import zipfile
from operator import itemgetter

import cv2
import gradio as gr
import kornia.enhance
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from skimage.transform import resize
from torchvision import models, transforms

from get_models import Resnet_with_skip


def create_retrieval_figure(res, num_results):
    """Plot the top retrieved images and collect their item UUIDs."""
    fig = plt.figure(figsize=[10 * 2, 10 * 2])
    cols = min(5, num_results)  # limit to 5 columns per row
    rows = (num_results // 5) + (num_results % 5 > 0)
    ax_query = fig.add_subplot(rows, 1, 1)
    ax_query.axis('off')
    ax_query.set_title(f'Top {num_results} most similar items', fontsize=40)
    names = ""
    # Convert the {image_path: distance} dict to a list sorted by ascending
    # distance, i.e. most similar first.
    sorted_res = sorted(res.items(), key=itemgetter(1))
    with zipfile.ZipFile('dataset.zip', 'r') as archive:
        for i, (image_path, _) in enumerate(sorted_res[:num_results]):
            current_image_path = "dataset/" + image_path.split("/")[3] + "/" + image_path.split("/")[4]
            try:
                imgfile = archive.read(current_image_path)
                image = cv2.imdecode(np.frombuffer(imgfile, np.uint8), 1)
            except Exception:
                # Fall back to a white placeholder if the image is missing from the archive.
                image = np.ones((224, 224, 3), dtype=np.uint8) * 255
                cv2.putText(image, "File not found", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
            ax = fig.add_subplot(rows, cols, i + 1)
            ax.axis('off')
            ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            item_uuid = current_image_path.split("/")[2].split("_photoUUID")[0].split("itemUUID_")[1]
            ax.set_title(f'Top {i + 1}', fontsize=20)
            names += f"Top {i + 1} item UUID: {item_uuid}\n"
    return fig, names


def knn_calc(image_name, query_feature, features):
    """Return the negated cosine similarity, so smaller values mean more similar."""
    current_image_feature = features[image_name]
    criterion = torch.nn.CosineSimilarity(dim=1)
    dist = criterion(query_feature, current_image_feature).mean()
    return -dist.item()


checkpoint_path = "multi_label.pth.tar"
resnet = models.resnet101()
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)
model = Resnet_with_skip(resnet)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()

# Drop the final layer to obtain an embedding model for retrieval.
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))

# Inverse of the Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) transform,
# used to bring the reconstruction back to displayable pixel values.
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
    transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
])

with open('query_images_paths.pkl', 'rb') as fp:
    query_images_paths = pickle.load(fp)

with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


def predict(inp, use_retrieval, use_drawing, num_results):
    num_results = int(num_results)  # slider values may arrive as floats
    image_tensor = transform(inp)
    with torch.no_grad():
        feature = embedding_model_test(image_tensor.unsqueeze(0)) if use_retrieval else None
    results = {}

    if use_drawing:
        with torch.no_grad():
            classification, reconstruction = model(image_tensor.unsqueeze(0))
        recon_tensor = reconstruction[0].repeat(3, 1, 1)
        recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
        plot_recon = recon_tensor.to("cpu").permute(1, 2, 0).detach().numpy()
        # Resize the reconstruction back to the original image dimensions.
        w, h = inp.size
        plot_recon = resize(plot_recon, (h, w))
        results['Drawing'] = plot_recon

    if use_retrieval:
        dists = {image_name: knn_calc(image_name, feature, features) for image_name in query_images_paths}
        res = dict(sorted(dists.items(), key=itemgetter(1)))
        fig, names = create_retrieval_figure(res, num_results)
        results['Retrieval'] = (fig, names)

    retrieval_fig, retrieval_text = results.get('Retrieval', (None, ""))
(None, "")) return retrieval_fig, results.get('Drawing', None), retrieval_text gr.Interface( fn=predict, inputs=[ gr.Image(type="pil", label="Upload Image", height=300, width=300), gr.Checkbox(label="Use Retrieval"), gr.Checkbox(label="Use Drawing"), gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Retrieved Images") # Added slider ], outputs=[ gr.Plot(label="Retrieval Results"), gr.Image(label="Drawing", height=300, width=300), gr.Textbox(label="Item UUIDs") # Display item UUIDs ] ).launch(share=True)