import gradio as gr
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, BlipProcessor, BlipForConditionalGeneration, pipeline
from sentence_transformers import SentenceTransformer, util
import pickle
import numpy as np
from PIL import Image, ImageEnhance
import os
import concurrent.futures
import warnings

warnings.filterwarnings("ignore")

class CLIPModelHandler:
    def __init__(self, model_name):
        self.model_name = model_name
        self.img_names, self.img_emb = self.load_precomputed_embeddings()

    def load_precomputed_embeddings(self):
        # Precomputed CLIP embeddings for the Unsplash photo collection
        emb_filename = 'unsplash-25k-photos-embeddings.pkl'
        with open(emb_filename, 'rb') as fIn:
            img_names, img_emb = pickle.load(fIn)
        return img_names, img_emb

    def search_text(self, query, top_k=1):
        # Embed the text query with CLIP and find the closest photo embeddings
        model = CLIPModel.from_pretrained(self.model_name)
        tokenizer = CLIPTokenizer.from_pretrained(self.model_name)
        inputs = tokenizer([query], padding=True, return_tensors="pt")
        query_emb = model.get_text_features(**inputs)
        hits = util.semantic_search(query_emb, self.img_emb, top_k=top_k)[0]
        images = [Image.open(os.path.join("photos/", self.img_names[hit['corpus_id']])) for hit in hits]
        return images

    def search_image(self, image_path, top_k=1):
        model = CLIPModel.from_pretrained(self.model_name)
        processor = CLIPProcessor.from_pretrained(self.model_name)
        # Load and preprocess the image
        image = Image.open(image_path)
        inputs = processor(images=image, return_tensors="pt")
        # Get the image embedding (a bare forward pass would also require text
        # inputs, and logits_per_image is a similarity score, not an embedding)
        image_emb = model.get_image_features(**inputs)
        # Perform semantic search against the precomputed photo embeddings
        hits = util.semantic_search(image_emb, self.img_emb, top_k=top_k)[0]
        # Retrieve and return the matching images
        result_images = []
        for hit in hits:
            img = Image.open(os.path.join("photos/", self.img_names[hit['corpus_id']]))
            result_images.append(img)
        return result_images
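
# The pickle loaded above matches the sentence-transformers image-search
# example for the Unsplash photo set. A minimal sketch of how such a file
# could be produced, assuming a local "photos/" directory; the helper name
# and batch size are illustrative, not part of the original app:
def build_embeddings_file(photo_dir="photos/", out_file="unsplash-25k-photos-embeddings.pkl"):
    clip_encoder = SentenceTransformer('clip-ViT-B-32')
    img_names = sorted(os.listdir(photo_dir))
    images = [Image.open(os.path.join(photo_dir, name)) for name in img_names]
    # Encode all photos with CLIP and store (names, embeddings) as one tuple
    img_emb = clip_encoder.encode(images, batch_size=32, convert_to_tensor=True, show_progress_bar=True)
    with open(out_file, 'wb') as fOut:
        pickle.dump((img_names, img_emb), fOut)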

class BLIPImageCaptioning:
    def __init__(self, blip_model_name):
        self.blip_model_name = blip_model_name

    def preprocess_image(self, image):
        if isinstance(image, str):
            return Image.open(image).convert('RGB')
        elif isinstance(image, np.ndarray):
            return Image.fromarray(np.uint8(image)).convert('RGB')
        else:
            raise ValueError("Invalid input type for image. Supported types: str (file path) or np.ndarray.")

    def generate_caption(self, image):
        try:
            model = BlipForConditionalGeneration.from_pretrained(self.blip_model_name)
            processor = BlipProcessor.from_pretrained(self.blip_model_name)
            raw_image = self.preprocess_image(image)
            inputs = processor(raw_image, return_tensors="pt")
            out = model.generate(**inputs)
            unconditional_caption = processor.decode(out[0], skip_special_tokens=True)
            return unconditional_caption
        except Exception as e:
            return {"error": str(e)}

    def generate_captions_parallel(self, images):
        # Caption several images concurrently; threads suffice here because
        # the heavy work happens inside PyTorch, which releases the GIL
        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = list(executor.map(self.generate_caption, images))
        return results
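
# A minimal end-to-end sketch tying the two handlers together; the query
# string is an illustrative assumption, and the function is defined but not
# called so it does not affect the app launch below:
def demo_search_and_caption(query="two dogs playing in the snow"):
    handler = CLIPModelHandler("openai/clip-vit-base-patch32")
    captioner = BLIPImageCaptioning("Salesforce/blip-image-captioning-base")
    # Retrieve the best-matching photo for the query, then caption it
    best_match = handler.search_text(query, top_k=1)[0]
    return captioner.generate_caption(np.array(best_match))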

# Initialize the CLIP model handler
clip_handler = CLIPModelHandler("openai/clip-vit-base-patch32")

# Initialize the zero-shot image classification pipeline
clip_classifier = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")

# BLIP checkpoint used for image captioning
blip_model_name = "Salesforce/blip-image-captioning-base"

# Function for text-to-image search
def text_to_image_interface(query, top_k):
    try:
        # Perform text-to-image search
        result_images = clip_handler.search_text(query, top_k)
        # Resize images before displaying
        result_images_resized = [image.resize((224, 224)) for image in result_images]
        # Report the file names of the matches only, not the whole collection
        result_info = [{"Image Name": os.path.basename(image.filename)} for image in result_images]
        return result_images_resized, result_info
    except Exception as e:
        raise gr.Error(f"Error in text-to-image search: {str(e)}")

# Gradio Interface function for zero-shot classification
def zero_shot_classification(image, labels_text):
    try:
        # Convert image to PIL format
        PIL_image = Image.fromarray(np.uint8(image)).convert('RGB')
        # Split labels_text into a list of labels, trimming stray whitespace
        labels = [label.strip() for label in labels_text.split(",")]
        # Perform zero-shot classification
        res = clip_classifier(images=PIL_image, candidate_labels=labels, hypothesis_template="This is a photo of a {}")
        # Format the result as the dictionary gr.Label expects
        formatted_results = {dic["label"]: dic["score"] for dic in res}
        return formatted_results
    except Exception as e:
        raise gr.Error(f"Error in zero-shot classification: {str(e)}")

def image_preprocessing(original_image, brightness_slider, contrast_slider, saturation_slider, sharpness_slider, rotation_slider):
    try:
        # Convert NumPy array to PIL Image
        PIL_image = Image.fromarray(np.uint8(original_image)).convert('RGB')
        # Map slider values from [0, 100] to enhancement factors in [0, 1]
        brightness_normalized = brightness_slider / 100.0
        contrast_normalized = contrast_slider / 100.0
        saturation_normalized = saturation_slider / 100.0
        sharpness_normalized = sharpness_slider / 100.0
        # Rotate based on user input
        PIL_image = PIL_image.rotate(rotation_slider)
        # Adjust brightness
        enhancer = ImageEnhance.Brightness(PIL_image)
        PIL_image = enhancer.enhance(brightness_normalized)
        # Adjust contrast
        enhancer = ImageEnhance.Contrast(PIL_image)
        PIL_image = enhancer.enhance(contrast_normalized)
        # Adjust saturation
        enhancer = ImageEnhance.Color(PIL_image)
        PIL_image = enhancer.enhance(saturation_normalized)
        # Adjust sharpness
        enhancer = ImageEnhance.Sharpness(PIL_image)
        PIL_image = enhancer.enhance(sharpness_normalized)
        # Return the processed image
        return PIL_image
    except Exception as e:
        raise gr.Error(f"Error in preprocessing: {str(e)}")

def generate_captions(images):
    # Caption a batch of images with the shared BLIP handler defined below
    return blip_model.generate_captions_parallel(images)

# Gradio Interfaces
zero_shot_classification_interface = gr.Interface(
    fn=zero_shot_classification,
    inputs=[
        gr.Image(label="Image Query", elem_id="image_input"),
        gr.Textbox(label="Labels (comma-separated)", elem_id="labels_input"),
    ],
    outputs=gr.Label(elem_id="label_image"),
)

text_to_image_interface = gr.Interface(
    fn=text_to_image_interface,
    inputs=[
        gr.Textbox(
            lines=2,
            label="Text Query",
            placeholder="Enter text here...",
        ),
        gr.Slider(1, 5, step=1, value=1, label="Top K Results"),
    ],
    outputs=[
        gr.Gallery(
            label="Text-to-Image Search Results",
            elem_id="gallery_text",
            columns=2,
            height="auto",
        ),
        gr.JSON(label="Result Information", elem_id="text_info"),
    ],
)

# Instantiate the shared BLIP captioning handler
blip_model = BLIPImageCaptioning(blip_model_name)

blip_captioning_interface = gr.Interface(
    fn=blip_model.generate_caption,
    inputs=gr.Image(label="Image for Captioning", elem_id="blip_caption_image"),
    outputs=gr.Textbox(label="Generated Caption", elem_id="blip_generated_captions"),
)

# Wire up the slider-driven preprocessing function defined above; the slider
# ranges follow its /100.0 normalization (100 = unchanged)
preprocessing_interface = gr.Interface(
    fn=image_preprocessing,
    inputs=[
        gr.Image(label="Original Image", elem_id="original_image"),
        gr.Slider(0, 100, value=100, label="Brightness"),
        gr.Slider(0, 100, value=100, label="Contrast"),
        gr.Slider(0, 100, value=100, label="Saturation"),
        gr.Slider(0, 100, value=100, label="Sharpness"),
        gr.Slider(-180, 180, value=0, label="Rotation (degrees)"),
    ],
    outputs=gr.Image(label="Processed Image", elem_id="processed_image"),
)

# Tabbed Interface
app = gr.TabbedInterface(
    interface_list=[text_to_image_interface, zero_shot_classification_interface, blip_captioning_interface, preprocessing_interface],
    tab_names=["Text-to-Image Search", "Zero-Shot Classification", "BLIP Image Captioning", "Image Preprocessing"],
)

# Launch the Gradio interface
app.launch(debug=True, share=True)