# Chyd19's picture
# Update app.py
# c2b4d4f verified
# ==============================
# SECTION 1 — INSTALL + IMPORTS
# ==============================
import logging

import torch
import gradio as gr
from PIL import Image
from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
import lpips
import clip
from bert_score import score
import torchvision.transforms as T
from sentence_transformers import SentenceTransformer
from rouge_score import rouge_scorer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
device = "cuda" if torch.cuda.is_available() else "cpu"
def free_gpu_cache():
    """Ask PyTorch to empty its CUDA cache; do nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# ==============================
# SECTION 2 — LOAD LIGHTWEIGHT MODELS
# ==============================
# Device argument shared by the captioning pipelines below: index 0 when
# `device` is "cuda", otherwise -1.
_pipeline_device = 0 if device == "cuda" else -1

# --- Captioning pipelines (created eagerly, at import time) ---
blip_large_captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=_pipeline_device,
)
vit_gpt2_captioner = pipeline(
    "image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
    device=_pipeline_device,
)

# --- NLP Pipelines ---
# NOTE(review): no checkpoint is named for the first two pipelines, so the
# library picks whatever its default is for the task -- consider pinning one.
sentiment_model = pipeline("sentiment-analysis")
ner_model = pipeline("ner", aggregation_strategy="simple")
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# --- Metrics ---
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net='alex').to(device)
# Image -> tensor, then resize to 128x128 before the LPIPS comparison.
lpips_transform = T.Compose([
    T.ToTensor(),
    T.Resize((128, 128)),
])
# Sentence embeddings used for the caption cosine-similarity metric.
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
# ==============================
# SECTION 2b — LAZY LOAD HEAVY MODELS
# ==============================
# Cache slot for the BLIP-2 captioning pipeline; stays None until get_blip2()
# builds it on first use.
blip2_captioner = None
# The VQA processor is loaded eagerly; only the VQA model itself is deferred.
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
# Cache slot for the BLIP VQA model; stays None until get_vqa_model() loads it.
vqa_model = None
def get_blip2():
    """Return the shared BLIP-2 captioning pipeline, creating it on first call.

    The pipeline is stored in the module-level `blip2_captioner` so later
    calls reuse the same object instead of loading the model again.
    """
    global blip2_captioner
    if blip2_captioner is not None:
        return blip2_captioner

    device_index = 0 if device == "cuda" else -1
    blip2_captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip2-opt-2.7b",
        device=device_index,
    )
    return blip2_captioner
def get_vqa_model():
    """Return the shared BLIP VQA model, loading it onto `device` on first call.

    The loaded model is stored in the module-level `vqa_model` and reused by
    every later call.
    """
    global vqa_model
    if vqa_model is not None:
        return vqa_model

    loaded = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
    vqa_model = loaded.to(device)
    return vqa_model
# ==============================
# SECTION 3 β€” FUNCTIONS
# ==============================
def make_captions(img):
    """Caption `img` with BLIP-large, ViT-GPT2 and BLIP-2, in that order.

    Captioning is best-effort: if a model fails to load or run, the error is
    logged and a fixed "<model> failed." placeholder takes that caption's slot,
    so the result always has exactly three entries.

    Args:
        img: The image to caption, in whatever form the captioning pipelines
            accept (the UI passes PIL images).

    Returns:
        list[str]: ``[blip_large_caption, vit_gpt2_caption, blip2_caption]``.
    """
    log = logging.getLogger(__name__)

    # (placeholder on failure, zero-argument getter for the captioner).
    # The getters defer the lookup so that BLIP-2's lazy load also happens
    # inside the try block below.
    captioners = (
        ("BLIP-large failed.", lambda: blip_large_captioner),
        ("ViT-GPT2 failed.", lambda: vit_gpt2_captioner),
        ("BLIP2-opt failed.", lambda: get_blip2()),
    )

    captions = []
    for fallback, get_captioner in captioners:
        try:
            captions.append(get_captioner()(img)[0]["generated_text"])
        except Exception:
            # Deliberately broad: any model failure degrades to a placeholder.
            # Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit through.
            log.exception("Captioning failed; using placeholder %r", fallback)
            captions.append(fallback)
    return captions
# ---------------- Metrics Computation ---------------------
def _jaccard_similarity(text_a, text_b):
    """Return the Jaccard index of the lowercased whitespace tokens of two strings.

    Returns 0.0 when neither string contains a token, instead of dividing by zero.
    """
    tokens_a = set(text_a.lower().split())
    tokens_b = set(text_b.lower().split())
    union = tokens_a | tokens_b
    if not union:
        return 0.0
    return len(tokens_a & tokens_b) / len(union)


def compute_metrics_button(images, captions, idx1, idx2):
    """Compare one (image, caption) pair against another and format the scores.

    Image metrics (CLIP cosine similarity, LPIPS) are computed between
    ``images[idx1]`` and ``images[idx2]``; text metrics (BERTScore F1,
    sentence-embedding cosine, Jaccard, ROUGE-1, ROUGE-L) between
    ``captions[idx1]`` and ``captions[idx2]``.

    Args:
        images: Indexable collection of PIL images.
        captions: Indexable collection of caption strings.
        idx1: Index of the first image/caption.
        idx2: Index of the second image/caption.

    Returns:
        str: A Markdown bullet list, one metric per line, four decimals each.
    """
    image_a, image_b = images[idx1], images[idx2]
    caption_a, caption_b = captions[idx1], captions[idx2]

    # CLIP similarity
    clip_a = clip_preprocess(image_a).unsqueeze(0).to(device)
    clip_b = clip_preprocess(image_b).unsqueeze(0).to(device)
    with torch.no_grad():
        feat_a = clip_model.encode_image(clip_a)
        feat_b = clip_model.encode_image(clip_b)
    clip_sim = float(torch.cosine_similarity(feat_a, feat_b).item())

    # LPIPS. Force 3-channel RGB first: RGBA / grayscale / palette images
    # would otherwise produce a tensor with the wrong channel count for the
    # LPIPS network. `* 2 - 1` rescales ToTensor's [0, 1] output to [-1, 1].
    lp_a = lpips_transform(image_a.convert("RGB")).unsqueeze(0).to(device) * 2 - 1
    lp_b = lpips_transform(image_b.convert("RGB")).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        lpips_score = float(lpips_model(lp_a, lp_b).item())

    # BERTScore (only the F1 component is reported)
    _, _, f1 = score([caption_a], [caption_b], lang="en", verbose=False)
    bert_f1 = float(f1.mean().item())

    # Cosine similarity of sentence embeddings
    emb_a = sentence_model.encode([caption_a])
    emb_b = sentence_model.encode([caption_b])
    cosine_sim = float(cosine_similarity(emb_a, emb_b)[0][0])

    # Jaccard similarity over word sets (safe for empty captions)
    jaccard_sim = float(_jaccard_similarity(caption_a, caption_b))

    # ROUGE
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge_scores = scorer.score(caption_a, caption_b)

    report_lines = [
        f"- CLIP: {clip_sim:.4f}",
        f"- LPIPS: {lpips_score:.4f}",
        f"- BERT-F1: {bert_f1:.4f}",
        f"- Cosine: {cosine_sim:.4f}",
        f"- Jaccard: {jaccard_sim:.4f}",
        f"- ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}",
        f"- ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}",
    ]
    # Leading and trailing newline keep the list separate from the heading
    # the caller puts in front of it.
    return "\n" + "\n".join(report_lines) + "\n"
# ---- NLP ----
def nlp_bundle(caption):
    """Run sentiment, entity and topic analysis on one caption.

    Each analysis is best-effort and independent of the others: if one fails,
    the error is logged and a fixed "... failed." message is returned in its
    place while the remaining analyses still run.

    Args:
        caption: The caption text to analyse.

    Returns:
        tuple[str, str, str]: ``(sentiment, entities, topics)``, each an HTML
        fragment whose entries are separated by ``<br>``. ``entities`` is
        ``"None"`` when no entity is found.
    """
    log = logging.getLogger(__name__)

    try:
        sentiment = "<br>".join(
            f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)
        )
    except Exception:
        log.exception("Sentiment analysis failed")
        sentiment = "Sentiment failed."

    try:
        ents = "<br>".join(
            f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)
        ) or "None"
    except Exception:
        log.exception("Named-entity recognition failed")
        ents = "NER failed."

    try:
        topics_raw = topic_model(
            caption,
            candidate_labels=["people", "animals", "objects", "food", "nature"],
        )
        topics = "<br>".join(
            f"{lbl}: {float(scr):.2f}"
            for lbl, scr in zip(topics_raw["labels"], topics_raw["scores"])
        )
    except Exception:
        log.exception("Topic classification failed")
        topics = "Topics failed."

    return sentiment, ents, topics
# ---------------- VQA ----------------
def answer_vqa(question, image):
    """Answer a free-text question about an image with the BLIP VQA model.

    Args:
        question: The question text. ``None``, empty and whitespace-only
            values are all treated as "no question".
        image: The image to ask about, or ``None`` when nothing is loaded.

    Returns:
        str: The decoded answer, or a prompt asking for an image and a
        question when either one is missing.
    """
    # `question or ""` keeps a None question from raising on .strip().
    if image is None or not (question or "").strip():
        return "Upload an image and enter a question."

    model = get_vqa_model()
    try:
        inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs)
        return vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
    finally:
        # Release cached GPU memory even when generation fails.
        free_gpu_cache()
# Convert a PIL.Image to PNG byte stream
def to_bytes(img):
    """Encode `img` as PNG into an in-memory buffer and return the raw bytes."""
    from io import BytesIO

    with BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        return png_buffer.getvalue()
# ==============================
# SECTION 4 — UI (GRADIO)
# ==============================
def build_ui():
    """Assemble the Gradio Blocks interface and return it without launching.

    Layout, top to bottom: image source tabs (upload / webcam / URL), three
    previews with one caption each, pairwise caption metrics, NLP analysis of
    the captions, and visual question answering on the first preview.

    Returns:
        gr.Blocks: The fully wired app.
    """
    log = logging.getLogger(__name__)
    # Animated bar shown in an output while its handler is still working.
    spinner = "<div class='loading-line'></div>"

    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
        gr.HTML("""
        <style>
        .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
        .orange-btn button { background-color:#ff5500; color:white; border-radius:6px; height:36px; font-weight:bold; }
        .teal-btn button { background-color:#008080; color:white; border-radius:6px; height:36px; font-weight:bold; }
        .loading-line {
            height:4px; background:linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%);
            background-size:200% 100%; animation: loading 1s linear infinite;
        }
        @keyframes loading { 0% {background-position:200% 0;} 100% {background-position:-200% 0;} }
        .circular-img img {
            border-radius: 21%;
            object-fit: cover;
            width: 400px;
            height: 200px;
            box-shadow: inset -10px -10px 30px rgba(255,255,255,0.3),
                        5px 5px 15px rgba(0,0,0,0.3);
            border: 2px solid rgba(255,255,255,0.6);
        }
        .metrics-row {
            display: flex;
            flex-direction: row;
            gap: 20px;
        }
        .metrics-row > div {
            flex: 1;
        }
        </style>
        """)
        gr.Markdown("## Multimodal AI Image Studio: Comparative Image-to-Text Analysis", elem_classes="heading-orange")

        # Per-session state: the current image (as a one-item list) and its
        # three captions, shared by the metrics and NLP handlers.
        images_state = gr.State([])
        captions_state = gr.State([])

        # ---------------- Image Input ----------------
        gr.Markdown("### Select Image Source", elem_classes="heading-orange")
        with gr.Tabs():
            with gr.Tab("📁 Upload Image"):
                upload_input = gr.Image(type="pil", sources=["upload"], label="Upload Image", height=900, width=960, elem_classes="circular-img")
                upload_btn = gr.Button("Generate Captions", elem_classes="orange-btn")
            with gr.Tab("📷 Webcam"):
                webcam_input = gr.Image(type="pil", sources=["webcam"], label="Webcam", height=900, width=960, elem_classes="circular-img")
                webcam_btn = gr.Button("Capture & Generate Captions", elem_classes="orange-btn")
            with gr.Tab("🔗 From URL"):
                url_input = gr.Textbox(label="Paste Image URL")
                url_btn = gr.Button("Fetch & Generate Captions", elem_classes="orange-btn")

        # ---------------- Previews ----------------
        with gr.Row():
            with gr.Column(scale=1, min_width=200):
                preview1 = gr.Image(type="pil", label="Preview 1", interactive=False, height=230)
                blip_caption_box = gr.Markdown()
            with gr.Column(scale=1, min_width=200):
                preview2 = gr.Image(type="pil", label="Preview 2", interactive=False, height=230)
                vit_caption_box = gr.Markdown()
            with gr.Column(scale=1, min_width=200):
                preview3 = gr.Image(type="pil", label="Preview 3", interactive=False, height=230)
                blip2_caption_box = gr.Markdown()

        # Every caption-producing handler fills the same eight outputs.
        caption_outputs = [
            preview1, preview2, preview3,
            blip_caption_box, vit_caption_box, blip2_caption_box,
            images_state, captions_state,
        ]

        # ---------------- Generate Captions ----------------
        def generate_all(img):
            """Caption `img` and return values for the eight caption outputs."""
            if img is None:
                return (None, None, None, "No image.", "No image.", "No image.", [], [])
            captions = make_captions(img)
            return (img, img, img, captions[0], captions[1], captions[2], [img], captions)

        def load_from_url(url):
            """Download an image from `url` and caption it; report "Bad URL." on failure."""
            import requests
            from io import BytesIO

            bad_url = (None, None, None, "Bad URL.", "Bad URL.", "Bad URL.", [], [])
            if not url or not url.strip():
                return bad_url
            # NOTE(security): the URL is user-supplied and is fetched from the
            # server, so it can be pointed at internal addresses. Restrict the
            # allowed hosts if this app runs on a trusted network.
            try:
                # Timeout so a stalled server cannot hang the handler forever.
                response = requests.get(url.strip(), timeout=10)
                response.raise_for_status()
                # convert() forces a full decode inside the try block and
                # normalises RGBA / palette / grayscale images to RGB.
                img = Image.open(BytesIO(response.content)).convert("RGB")
            except Exception:
                # UI boundary: any download or decode problem is a "Bad URL.".
                log.exception("Could not load an image from URL %r", url)
                return bad_url
            return generate_all(img)

        upload_btn.click(generate_all, inputs=[upload_input], outputs=caption_outputs)
        webcam_btn.click(generate_all, inputs=[webcam_input], outputs=caption_outputs)
        url_btn.click(load_from_url, inputs=[url_input], outputs=caption_outputs)

        # ---------------- Metrics ----------------
        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
        with gr.Row(elem_classes="metrics-row"):
            metrics_A = gr.Markdown()
            metrics_B = gr.Markdown()
            metrics_C = gr.Markdown()

        # (heading, first caption index, second caption index) for each pair.
        metric_pairs = (
            ("BLIP-large ↔ ViT-GPT2", 0, 1),
            ("BLIP-large ↔ BLIP2", 0, 2),
            ("ViT-GPT2 ↔ BLIP2", 1, 2),
        )

        def compute_metrics_all_pairs_ui(images, captions):
            """Yield spinners, then one metrics report per caption pair."""
            yield (spinner, spinner, spinner)
            if not images or not captions or len(captions) < 3:
                msg = "<b>Upload 1 image and generate all 3 captions.</b>"
                yield (msg, msg, msg)
                return
            # The state holds a single image; repeat it so indices 0-2 are all
            # valid. The image metrics therefore compare the image with itself.
            imgs = images * 3
            try:
                reports = tuple(
                    f"### {title}\n{compute_metrics_button(imgs, captions, i, j)}"
                    for title, i, j in metric_pairs
                )
            except Exception:
                # Replace the spinners rather than leaving them running forever.
                log.exception("Metric computation failed")
                msg = "<b>Metric computation failed.</b>"
                yield (msg, msg, msg)
                return
            yield reports

        metrics_btn.click(
            compute_metrics_all_pairs_ui,
            inputs=[images_state, captions_state],
            outputs=[metrics_A, metrics_B, metrics_C],
        )

        # ---------------- NLP ----------------
        gr.Markdown("### NLP Analysis", elem_classes="heading-orange")
        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
        with gr.Row(elem_classes="metrics-row"):  # reuse metrics-row for flex layout
            nlp_A = gr.Markdown()
            nlp_B = gr.Markdown()
            nlp_C = gr.Markdown()

        def do_nlp_all(captions):
            """Yield spinners, then one NLP summary block per caption."""
            yield (spinner, spinner, spinner)
            if not captions or len(captions) < 3:
                msg = "<b>All 3 captions required.</b>"
                yield (msg, msg, msg)
                return
            labels = ["BLIP-large", "ViT-GPT2", "BLIP2"]
            results = []
            for label, cap in zip(labels, captions):
                s, e, t = nlp_bundle(cap)
                # Lines are kept flush-left so Markdown does not treat the
                # HTML as an indented code block.
                results.append(
                    f"\n<h3><u>{label}</u></h3>\n"
                    f"<b>Sentiment</b><br>{s}<br><br>\n"
                    f"<b>Entities</b><br>{e}<br><br>\n"
                    f"<b>Topics</b><br>{t}\n"
                )
            yield (results[0], results[1], results[2])

        nlp_btn.click(do_nlp_all, inputs=[captions_state], outputs=[nlp_A, nlp_B, nlp_C])

        # ---------------- VQA ----------------
        gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
        with gr.Row():
            vqa_input = gr.Textbox(label="Ask about the image")
            vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
        vqa_out = gr.Markdown()

        def vqa_ui(question, image):
            """Yield a spinner, then the answer to `question` about `image`."""
            yield spinner
            try:
                answer = answer_vqa(question, image)
            except Exception:
                # Replace the spinner rather than leaving it running forever.
                log.exception("Visual question answering failed")
                answer = "<b>VQA failed.</b>"
            yield answer

        # The question is asked about whatever image Preview 1 is showing.
        vqa_btn.click(vqa_ui, inputs=[vqa_input, preview1], outputs=[vqa_out])

    return demo
# ==============================
# LAUNCH
# ==============================
# Built at import time so the app object stays available as `demo`.
demo = build_ui()

# Only start the server when this file is run as a script, so importing the
# module does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch(share=True, debug=False)