# Spaces:
# Sleeping
# Sleeping
# ==============================
# SECTION 1 — INSTALL + IMPORTS
# ==============================
import logging

import clip
import gradio as gr
import lpips
import numpy as np
import torch
import torchvision.transforms as T
from bert_score import score
from PIL import Image
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
# Run models on CUDA when torch can see a GPU, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def free_gpu_cache():
    """Release unoccupied memory held by PyTorch's CUDA caching allocator.

    Does nothing when CUDA is unavailable. Always returns None.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# ==============================
# SECTION 2 — LOAD LIGHTWEIGHT MODELS
# ==============================
# transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
_pipeline_device = 0 if device == "cuda" else -1

# --- Captioners ---
blip_large_captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=_pipeline_device,
)
vit_gpt2_captioner = pipeline(
    "image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
    device=_pipeline_device,
)

# --- NLP Pipelines ---
# Models are named explicitly instead of relying on the library's task defaults
# (which transformers warns about and may change between releases), and the NLP
# pipelines share the captioners' device instead of always running on CPU.
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
    device=_pipeline_device,
)
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=_pipeline_device,
)
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=_pipeline_device,
)

# --- Metrics ---
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net='alex').to(device)
# ToTensor yields values in [0, 1]; LPIPS inputs are rescaled to [-1, 1] by the caller.
lpips_transform = T.Compose([T.ToTensor(), T.Resize((128, 128))])
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")  # for cosine similarity
# ==============================
# SECTION 2b — LAZY LOAD HEAVY MODELS
# ==============================
# The heavy models start as None and are created on first use by get_blip2()
# and get_vqa_model().
blip2_captioner = None
vqa_model = None
# The VQA processor is loaded eagerly; only the VQA model weights are deferred.
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
def get_blip2():
    """Return the BLIP-2 (OPT-2.7B) image-to-text pipeline, creating it on first call.

    The pipeline is cached in the module-level ``blip2_captioner`` so later
    calls reuse it instead of reloading the checkpoint.
    """
    global blip2_captioner
    if blip2_captioner is not None:
        return blip2_captioner
    # transformers device index: 0 = first GPU, -1 = CPU.
    pipeline_device = -1
    if device == "cuda":
        pipeline_device = 0
    blip2_captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip2-opt-2.7b",
        device=pipeline_device,
    )
    return blip2_captioner
def get_vqa_model():
    """Return the BLIP VQA model, loading it onto ``device`` on first call.

    The loaded model is cached in the module-level ``vqa_model`` and reused by
    subsequent calls.
    """
    global vqa_model
    if vqa_model is not None:
        return vqa_model
    loaded = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
    vqa_model = loaded.to(device)
    return vqa_model
# ==============================
# SECTION 3 — FUNCTIONS
# ==============================
def make_captions(img):
    """Caption *img* with each of the three captioning models, in a fixed order.

    Args:
        img: image handed directly to each image-to-text pipeline (the UI passes
            a PIL image).

    Returns:
        list[str]: exactly three entries, in the order BLIP-large, ViT-GPT2, BLIP2-opt.
        A model that raises contributes "<name> failed." instead, so callers can
        always index positions 0..2.
    """
    # Each runner does the model lookup/lazy load inside the try below, so a model
    # that cannot be resolved or loaded is reported the same way as one that fails
    # to run.
    runners = (
        ("BLIP-large", lambda: blip_large_captioner(img)),
        ("ViT-GPT2", lambda: vit_gpt2_captioner(img)),
        ("BLIP2-opt", lambda: get_blip2()(img)),
    )
    captions = []
    for name, run in runners:
        try:
            captions.append(run()[0]["generated_text"])
        except Exception:
            # Best-effort per model, but record why instead of silently swallowing it.
            # (Unlike a bare except, this lets KeyboardInterrupt/SystemExit propagate.)
            logging.getLogger(__name__).exception("%s captioning failed", name)
            captions.append(f"{name} failed.")
    return captions
# ---------------- Metrics Computation ---------------------
def _jaccard_similarity(text_a, text_b):
    """Return intersection-over-union of lowercased whitespace tokens (1.0 if both are empty)."""
    tokens_a = set(text_a.lower().split())
    tokens_b = set(text_b.lower().split())
    union = tokens_a | tokens_b
    if not union:
        # Two empty captions are identical; also avoids ZeroDivisionError.
        return 1.0
    return len(tokens_a & tokens_b) / len(union)


def compute_metrics_button(images, captions, idx1, idx2):
    """Compare entry *idx1* with entry *idx2* using image and caption metrics.

    Args:
        images: sequence of PIL images, indexed in parallel with *captions*.
        captions: sequence of caption strings.
        idx1, idx2: positions of the two entries to compare.

    Returns:
        str: a Markdown bullet list with CLIP image-embedding cosine similarity,
        LPIPS distance, BERTScore F1, sentence-embedding cosine similarity,
        token Jaccard similarity, and ROUGE-1 / ROUGE-L F-measures, each formatted
        to 4 decimal places.
    """
    # Normalize to 3-channel RGB: ToTensor on RGBA / L / P images yields 4 or 1
    # channels, which the LPIPS network cannot consume.
    img1 = images[idx1].convert("RGB")
    img2 = images[idx2].convert("RGB")
    cap1, cap2 = captions[idx1], captions[idx2]

    # CLIP: cosine similarity of image embeddings.
    img1_clip = clip_preprocess(img1).unsqueeze(0).to(device)
    img2_clip = clip_preprocess(img2).unsqueeze(0).to(device)
    with torch.no_grad():
        feat1 = clip_model.encode_image(img1_clip)
        feat2 = clip_model.encode_image(img2_clip)
    clip_sim = float(torch.cosine_similarity(feat1, feat2).item())

    # LPIPS: rescale ToTensor's [0, 1] range to [-1, 1] before scoring.
    img1_lp = lpips_transform(img1).unsqueeze(0).to(device) * 2 - 1
    img2_lp = lpips_transform(img2).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        lpips_score = float(lpips_model(img1_lp, img2_lp).item())

    # BERTScore F1 (cap1 passed as candidate, cap2 as reference).
    _, _, f1 = score([cap1], [cap2], lang="en", verbose=False)
    bert_f1 = float(f1.mean().item())

    # Cosine similarity of sentence embeddings.
    emb1 = sentence_model.encode([cap1])
    emb2 = sentence_model.encode([cap2])
    cosine_sim = float(cosine_similarity(emb1, emb2)[0][0])

    # Jaccard similarity over lowercased whitespace tokens.
    jaccard_sim = float(_jaccard_similarity(cap1, cap2))

    # ROUGE-1 / ROUGE-L F-measures, with stemming.
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge_scores = scorer.score(cap1, cap2)

    return f"""
- CLIP: {clip_sim:.4f}
- LPIPS: {lpips_score:.4f}
- BERT-F1: {bert_f1:.4f}
- Cosine: {cosine_sim:.4f}
- Jaccard: {jaccard_sim:.4f}
- ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}
- ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}
"""
# ---- NLP ----
def nlp_bundle(caption, candidate_labels=("people", "animals", "objects", "food", "nature")):
    """Run sentiment analysis, NER and zero-shot topic classification on one caption.

    Each analysis is independent and best-effort. If one raises, its slot holds a
    "<X> failed." message, the error is logged, and the other analyses still run.

    Args:
        caption: text to analyse.
        candidate_labels: topic labels for zero-shot classification; defaults to
            the original five labels.

    Returns:
        tuple[str, str, str]: (sentiment, entities, topics). Each is an HTML
        fragment with one "label: value" entry per line, joined by ``<br>``.
        Entities is "None" when no entity is found.
    """
    log = logging.getLogger(__name__)

    try:
        sentiment = "<br>".join(
            f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)
        )
    except Exception:
        log.exception("Sentiment analysis failed")
        sentiment = "Sentiment failed."

    try:
        ents = "<br>".join(
            f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)
        ) or "None"
    except Exception:
        log.exception("NER failed")
        ents = "NER failed."

    try:
        topics_raw = topic_model(caption, candidate_labels=list(candidate_labels))
        topics = "<br>".join(
            f"{lbl}: {float(scr):.2f}"
            for lbl, scr in zip(topics_raw["labels"], topics_raw["scores"])
        )
    except Exception:
        log.exception("Topic classification failed")
        topics = "Topics failed."

    return sentiment, ents, topics
# ---------------- VQA ----------------
def answer_vqa(question, image):
    """Answer a free-form *question* about *image* with the BLIP VQA model.

    Args:
        question: the user's question. None, empty or whitespace-only counts as missing.
        image: the image to ask about (the UI passes a PIL image, or None).

    Returns:
        str: the decoded answer, or a prompt asking for both inputs when either
        one is missing.
    """
    if image is None or not question or not question.strip():
        return "Upload an image and enter a question."
    try:
        model = get_vqa_model()
        inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs)
        return vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
    finally:
        # Release cached GPU memory even when loading or generation fails (e.g. OOM).
        free_gpu_cache()
# Serialize a PIL image into PNG-encoded bytes.
def to_bytes(img):
    """Encode *img* as PNG and return the raw bytes.

    Args:
        img: an object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        bytes: everything ``img.save`` wrote with ``format="PNG"``.
    """
    from io import BytesIO

    with BytesIO() as stream:
        img.save(stream, format="PNG")
        return stream.getvalue()
# ==============================
# SECTION 4 — UI (GRADIO)
# ==============================
def build_ui():
    """Assemble the Gradio Blocks interface.

    The page has five sections, in order:
    - image source tabs (upload / webcam / URL);
    - three caption previews;
    - pairwise caption metrics;
    - per-caption NLP analysis;
    - visual question answering on the first preview image.

    Returns:
        gr.Blocks: the built (not yet launched) demo.
    """
    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
        # Custom CSS covers:
        # - orange headings and buttons, and teal buttons;
        # - an animated "loading-line" bar shown while handlers run;
        # - rounded image inputs;
        # - a flex row that lays out the three metric/NLP panels side by side.
        gr.HTML("""
<style>
.heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
.orange-btn button { background-color:#ff5500; color:white; border-radius:6px; height:36px; font-weight:bold; }
.teal-btn button { background-color:#008080; color:white; border-radius:6px; height:36px; font-weight:bold; }
.loading-line {
  height:4px; background:linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%);
  background-size:200% 100%; animation: loading 1s linear infinite;
}
@keyframes loading { 0% {background-position:200% 0;} 100% {background-position:-200% 0;} }
.circular-img img {
  border-radius: 21%;
  object-fit: cover;
  width: 400px;
  height: 200px;
  box-shadow: inset -10px -10px 30px rgba(255,255,255,0.3),
              5px 5px 15px rgba(0,0,0,0.3);
  border: 2px solid rgba(255,255,255,0.6);
}
.metrics-row {
  display: flex;
  flex-direction: row;
  gap: 20px;
}
.metrics-row > div {
  flex: 1;
}
</style>
""")
        gr.Markdown("## Multimodal AI Image Studio: Comparative Image-to-Text Analysis", elem_classes="heading-orange")
        # Per-session state: a one-element list holding the current image, and the
        # three captions in model order (BLIP-large, ViT-GPT2, BLIP2).
        images_state = gr.State([])
        captions_state = gr.State([])

        # ---------------- Image Input ----------------
        gr.Markdown("### Select Image Source", elem_classes="heading-orange")
        # NOTE(review): the "π" in the tab labels and the "β" in the metric headings
        # below look like mis-encoded emoji/arrow characters. Restore the intended glyphs.
        with gr.Tabs():
            with gr.Tab("π Upload Image"):
                upload_input = gr.Image(type="pil", sources=["upload"], label="Upload Image", height=900, width=960, elem_classes="circular-img")
                upload_btn = gr.Button("Generate Captions", elem_classes="orange-btn")
            with gr.Tab("π· Webcam"):
                webcam_input = gr.Image(type="pil", sources=["webcam"], label="Webcam", height=900, width=960, elem_classes="circular-img")
                webcam_btn = gr.Button("Capture & Generate Captions", elem_classes="orange-btn")
            with gr.Tab("π From URL"):
                url_input = gr.Textbox(label="Paste Image URL")
                url_btn = gr.Button("Fetch & Generate Captions", elem_classes="orange-btn")

        # ---------------- Previews ----------------
        with gr.Row():
            with gr.Column(scale=1, min_width=200):
                preview1 = gr.Image(type="pil", label="Preview 1", interactive=False, height=230)
                blip_caption_box = gr.Markdown()
            with gr.Column(scale=1, min_width=200):
                preview2 = gr.Image(type="pil", label="Preview 2", interactive=False, height=230)
                vit_caption_box = gr.Markdown()
            with gr.Column(scale=1, min_width=200):
                preview3 = gr.Image(type="pil", label="Preview 3", interactive=False, height=230)
                blip2_caption_box = gr.Markdown()

        # ---------------- Generate Captions ----------------
        # Every caption-producing handler fills these eight outputs, in this order.
        caption_outputs = [
            preview1, preview2, preview3,
            blip_caption_box, vit_caption_box, blip2_caption_box,
            images_state, captions_state,
        ]

        def generate_all(img, images_state, captions_state):
            # The incoming state values are wired as inputs but not read: each run
            # replaces the state with the new image and its captions.
            if img is None:
                return (None, None, None, "No image.", "No image.", "No image.", [], [])
            captions = make_captions(img)
            return (img, img, img, captions[0], captions[1], captions[2], [img], captions)

        upload_btn.click(generate_all, inputs=[upload_input, images_state, captions_state], outputs=caption_outputs)
        webcam_btn.click(generate_all, inputs=[webcam_input, images_state, captions_state], outputs=caption_outputs)

        def load_from_url(url, images_state, captions_state):
            import requests
            from io import BytesIO

            try:
                # A timeout keeps an unresponsive host from hanging the worker forever,
                # and raise_for_status() stops an HTML error page being decoded as an image.
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                # Image.open is lazy; convert() forces the full decode inside this try.
                # It also normalizes RGBA/P/L images to 3-channel RGB.
                img = Image.open(BytesIO(response.content)).convert("RGB")
            except Exception:
                logging.getLogger(__name__).warning("Could not load image from URL %r", url, exc_info=True)
                return (None, None, None, "Bad URL.", "Bad URL.", "Bad URL.", [], [])
            return generate_all(img, images_state, captions_state)

        url_btn.click(load_from_url, inputs=[url_input, images_state, captions_state], outputs=caption_outputs)

        loading_bar = "<div class='loading-line'></div>"

        # ---------------- Metrics ----------------
        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
        with gr.Row(elem_classes="metrics-row"):
            metrics_A = gr.Markdown()
            metrics_B = gr.Markdown()
            metrics_C = gr.Markdown()

        def compute_metrics_all_pairs_ui(images, captions):
            # Show a loading bar in all three panels while the metrics run.
            yield (loading_bar, loading_bar, loading_bar)
            if len(images) < 1 or len(captions) < 3:
                msg = "<b>Upload 1 image and generate all 3 captions.</b>"
                yield (msg, msg, msg)
                return
            # Only one image is stored. Repeating it gives each caption an aligned
            # image slot, so the image metrics (CLIP, LPIPS) compare that image with itself.
            imgs = images * 3
            A = compute_metrics_button(imgs, captions, 0, 1)
            B = compute_metrics_button(imgs, captions, 0, 2)
            C = compute_metrics_button(imgs, captions, 1, 2)
            yield (
                f"### BLIP-large β ViT-GPT2\n{A}",
                f"### BLIP-large β BLIP2\n{B}",
                f"### ViT-GPT2 β BLIP2\n{C}",
            )

        metrics_btn.click(
            compute_metrics_all_pairs_ui,
            inputs=[images_state, captions_state],
            outputs=[metrics_A, metrics_B, metrics_C],
        )

        # ---------------- NLP ----------------
        gr.Markdown("### NLP Analysis", elem_classes="heading-orange")
        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
        with gr.Row(elem_classes="metrics-row"):  # reuse metrics-row for flex layout
            nlp_A = gr.Markdown()
            nlp_B = gr.Markdown()
            nlp_C = gr.Markdown()

        def do_nlp_all(captions):
            # Same loading bars as the metrics section.
            yield (loading_bar, loading_bar, loading_bar)
            if len(captions) < 3:
                msg = "<b>All 3 captions required.</b>"
                yield (msg, msg, msg)
                return
            labels = ["BLIP-large", "ViT-GPT2", "BLIP2"]
            results = []
            for label, cap in zip(labels, captions):
                s, e, t = nlp_bundle(cap)
                block = f"""
<h3><u>{label}</u></h3>
<b>Sentiment</b><br>{s}<br><br>
<b>Entities</b><br>{e}<br><br>
<b>Topics</b><br>{t}
"""
                results.append(block)
            yield (results[0], results[1], results[2])

        nlp_btn.click(do_nlp_all, inputs=[captions_state], outputs=[nlp_A, nlp_B, nlp_C])

        # ---------------- VQA ----------------
        gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
        with gr.Row():
            vqa_input = gr.Textbox(label="Ask about the image")
            vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
        vqa_out = gr.Markdown()

        def vqa_ui(question, image):
            yield loading_bar
            yield answer_vqa(question, image)

        # The question is asked about the first preview image (the generated-from image).
        vqa_btn.click(vqa_ui, inputs=[vqa_input, preview1], outputs=[vqa_out])
    return demo
# ==============================
# LAUNCH
# ==============================
# `demo` stays a module-level name so tooling that imports this file can find it.
demo = build_ui()

# Start the server only when the file is executed as a script (e.g. `python app.py`),
# so merely importing the module does not launch a second server.
if __name__ == "__main__":
    demo.launch(share=True, debug=False)