Chyd19 commited on
Commit
c2b4d4f
·
verified ·
1 Parent(s): 262865d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -373
app.py CHANGED
@@ -386,376 +386,3 @@ def build_ui():
386
  demo = build_ui()
387
  demo.launch(share=True, debug=False)
388
 
389
- """
390
- # ==============================
391
- # SECTION 1 — INSTALL + IMPORTS
392
- # ==============================
393
-
394
- import torch
395
- import gradio as gr
396
- from PIL import Image
397
- from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
398
- import lpips
399
- import clip
400
- from bert_score import score
401
- import torchvision.transforms as T
402
- from sentence_transformers import SentenceTransformer
403
- from rouge_score import rouge_scorer
404
- import numpy as np
405
- from sklearn.metrics.pairwise import cosine_similarity
406
-
407
- device = "cuda" if torch.cuda.is_available() else "cpu"
408
-
409
- def free_gpu_cache():
410
- if torch.cuda.is_available():
411
- torch.cuda.empty_cache()
412
-
413
- # ==============================
414
- # SECTION 2 — LOAD LIGHTWEIGHT MODELS
415
- # ==============================
416
- blip_large_captioner = pipeline(
417
- "image-to-text",
418
- model="Salesforce/blip-image-captioning-large",
419
- device=0 if device=="cuda" else -1
420
- )
421
-
422
- vit_gpt2_captioner = pipeline(
423
- "image-to-text",
424
- model="nlpconnect/vit-gpt2-image-captioning",
425
- device=0 if device=="cuda" else -1
426
- )
427
-
428
- # --- NLP Pipelines ---
429
- sentiment_model = pipeline("sentiment-analysis")
430
- ner_model = pipeline("ner", aggregation_strategy="simple")
431
- topic_model = pipeline("zero-shot-classification",
432
- model="facebook/bart-large-mnli")
433
-
434
- # --- Metrics ---
435
- clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
436
- lpips_model = lpips.LPIPS(net='alex').to(device)
437
- lpips_transform = T.Compose([T.ToTensor(), T.Resize((128,128))])
438
- sentence_model = SentenceTransformer("all-MiniLM-L6-v2") # for cosine similarity
439
-
440
- # ==============================
441
- # SECTION 2b — LAZY LOAD HEAVY MODELS
442
- # ==============================
443
- blip2_captioner = None
444
- vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
445
- vqa_model = None
446
-
447
- def get_blip2():
448
- global blip2_captioner
449
- if blip2_captioner is None:
450
- blip2_captioner = pipeline(
451
- "image-to-text",
452
- model="Salesforce/blip2-opt-2.7b",
453
- device=0 if device=="cuda" else -1
454
- )
455
- return blip2_captioner
456
-
457
- def get_vqa_model():
458
- global vqa_model
459
- if vqa_model is None:
460
- vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
461
- return vqa_model
462
-
463
- # ==============================
464
- # SECTION 3 — FUNCTIONS
465
- # ==============================
466
- def make_captions(img):
467
- captions = []
468
- try: captions.append(blip_large_captioner(img)[0]["generated_text"])
469
- except: captions.append("BLIP-large failed.")
470
- try: captions.append(vit_gpt2_captioner(img)[0]["generated_text"])
471
- except: captions.append("ViT-GPT2 failed.")
472
- try:
473
- blip2 = get_blip2()
474
- captions.append(blip2(img)[0]["generated_text"])
475
- except: captions.append("BLIP2-opt failed.")
476
- return captions
477
-
478
- # ---------------- Metrics Computation ---------------------
479
- def compute_metrics_button(images, captions, idx1, idx2):
480
- # CLIP similarity
481
- img1_clip = clip_preprocess(images[idx1]).unsqueeze(0).to(device)
482
- img2_clip = clip_preprocess(images[idx2]).unsqueeze(0).to(device)
483
- with torch.no_grad():
484
- feat1 = clip_model.encode_image(img1_clip)
485
- feat2 = clip_model.encode_image(img2_clip)
486
- clip_sim = float(torch.cosine_similarity(feat1, feat2).item())
487
-
488
- # LPIPS
489
- img1_lp = lpips_transform(images[idx1]).unsqueeze(0).to(device) * 2 - 1
490
- img2_lp = lpips_transform(images[idx2]).unsqueeze(0).to(device) * 2 - 1
491
- with torch.no_grad():
492
- lpips_score = float(lpips_model(img1_lp, img2_lp).item())
493
-
494
- # BERTScore
495
- _, _, F1 = score([captions[idx1]], [captions[idx2]], lang="en", verbose=False)
496
- bert_f1 = float(F1.mean().item())
497
-
498
- # Cosine similarity of embeddings
499
- emb1 = sentence_model.encode([captions[idx1]])
500
- emb2 = sentence_model.encode([captions[idx2]])
501
- cosine_sim = float(cosine_similarity(emb1, emb2)[0][0])
502
-
503
- # Jaccard similarity
504
- tokens1 = set(captions[idx1].lower().split())
505
- tokens2 = set(captions[idx2].lower().split())
506
- jaccard_sim = float(len(tokens1 & tokens2) / len(tokens1 | tokens2))
507
-
508
- # ROUGE
509
- scorer = rouge_scorer.RougeScorer(['rouge1','rougeL'], use_stemmer=True)
510
- rouge_scores = scorer.score(captions[idx1], captions[idx2])
511
-
512
- return f""
513
- **Metrics Comparison**
514
- - CLIP Similarity: {clip_sim:.4f}
515
- - LPIPS Score: {lpips_score:.4f}
516
- - BERTScore F1: {bert_f1:.4f}
517
- - Cosine Similarity: {cosine_sim:.4f}
518
- - Jaccard Similarity: {jaccard_sim:.4f}
519
- - ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}
520
- - ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}
521
- ""
522
-
523
- # ---- NLP ----
524
- def nlp_bundle(caption):
525
- try:
526
- sentiment = sentiment_model(caption)
527
- sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment])
528
- except: sentiment = "Sentiment failed."
529
-
530
- try:
531
- ents_list = ner_model(caption)
532
- ents = "<br>".join([f"{e['entity_group']}: {e['word']}" for e in ents_list]) or "None"
533
- except: ents = "NER failed."
534
-
535
- try:
536
- topics_raw = topic_model(caption, candidate_labels=["people","animals","objects","food","nature"])
537
- topics = "<br>".join([f"{lbl}: {float(scr):.2f}" for lbl, scr in zip(topics_raw["labels"], topics_raw["scores"])])
538
- except: topics = "Topics failed."
539
-
540
- return sentiment, ents, topics
541
-
542
- # ---------------- VQA ----------------
543
- def answer_vqa(question, image):
544
- if image is None or question.strip() == "":
545
- return "Upload an image and enter a question."
546
- model = get_vqa_model()
547
- inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
548
- with torch.no_grad():
549
- generated_ids = model.generate(**inputs)
550
- answer = vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
551
- free_gpu_cache()
552
- return answer
553
-
554
- # Convert a PIL.Image to PNG byte stream
555
- def to_bytes(img):
556
- import io
557
- buf = io.BytesIO()
558
- img.save(buf, format="PNG")
559
- return buf.getvalue()
560
-
561
- # ==============================
562
- # SECTION 4 — UI (GRADIO)
563
- # ==============================
564
- def build_ui():
565
- with gr.Blocks(title="Multimodal AI Image Studio") as demo:
566
-
567
- gr.HTML(
568
- <style>
569
- .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
570
- .orange-btn button { background-color:#ff5500; color:white; border-radius:6px; height:36px; font-weight:bold; }
571
- .teal-btn button { background-color:#008080; color:white; border-radius:6px; height:36px; font-weight:bold; }
572
- .loading-line {
573
- height:4px; background:linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%);
574
- background-size:200% 100%; animation: loading 1s linear infinite;
575
- }
576
- @keyframes loading { 0% {background-position:200% 0;} 100% {background-position:-200% 0;} }
577
- .circular-img img {
578
- border-radius: 21%;
579
- object-fit: cover;
580
- width: 400px;
581
- height: 200px;
582
- box-shadow: inset -10px -10px 30px rgba(255,255,255,0.3),
583
- 5px 5px 15px rgba(0,0,0,0.3);
584
- border: 2px solid rgba(255,255,255,0.6);
585
- }
586
- </style>
587
- )
588
-
589
- gr.Markdown("## Multimodal AI Image Studio: Comparative Image-to-Text Analysis", elem_classes="heading-orange")
590
- images_state = gr.State([])
591
- captions_state = gr.State([])
592
-
593
- # ---------------- Image Input ----------------
594
- gr.Markdown("### Select Image Source", elem_classes="heading-orange")
595
- with gr.Tabs():
596
- with gr.Tab("📁 Upload Image"):
597
- upload_input = gr.Image(type="pil", sources=["upload"], label="Upload Image", height=900, width=960, elem_classes="circular-img")
598
- upload_btn = gr.Button("Generate Captions", elem_classes="orange-btn")
599
- with gr.Tab("📷 Webcam"):
600
- webcam_input = gr.Image(type="pil", sources=["webcam"], label="Webcam", height=900, width=960, elem_classes="circular-img")
601
- webcam_btn = gr.Button("Capture & Generate Captions", elem_classes="orange-btn")
602
- with gr.Tab("🔗 From URL"):
603
- url_input = gr.Textbox(label="Paste Image URL")
604
- url_btn = gr.Button("Fetch & Generate Captions", elem_classes="orange-btn")
605
-
606
- # ---------------- Previews ----------------
607
- with gr.Row():
608
- with gr.Column(scale=1, min_width=200):
609
- preview1 = gr.Image(type="pil",label="Preview 1", interactive=False, height=230)
610
- blip_caption_box = gr.Markdown()
611
- with gr.Column(scale=1, min_width=200):
612
- preview2 = gr.Image(type="pil",label="Preview 2", interactive=False, height=230)
613
- vit_caption_box = gr.Markdown()
614
- with gr.Column(scale=1, min_width=200):
615
- preview3 = gr.Image(type="pil",label="Preview 3", interactive=False, height=230)
616
- blip2_caption_box = gr.Markdown()
617
-
618
- # ---------------- Generate Captions ----------------
619
- def generate_all(img, images_state, captions_state):
620
- if img is None:
621
- return (None, None, None, "No image.", "No image.", "No image.", [], [])
622
- captions = make_captions(img)
623
- return (img, img, img, captions[0], captions[1], captions[2], [img], captions)
624
-
625
- upload_btn.click(generate_all, inputs=[upload_input, images_state, captions_state],
626
- outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
627
- webcam_btn.click(generate_all, inputs=[webcam_input, images_state, captions_state],
628
- outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
629
-
630
- def load_from_url(url, images_state, captions_state):
631
- import requests
632
- from io import BytesIO
633
- try:
634
- img = Image.open(BytesIO(requests.get(url).content))
635
- except:
636
- return (None, None, None, "Bad URL.", "Bad URL.", "Bad URL.", [], [])
637
- return generate_all(img, images_state, captions_state)
638
-
639
- url_btn.click(load_from_url, inputs=[url_input, images_state, captions_state],
640
- outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
641
-
642
- # ---------------- Metrics ----------------
643
-
644
- ""
645
- gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
646
- metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
647
- metrics_A = gr.Markdown()
648
- metrics_B = gr.Markdown()
649
- metrics_C = gr.Markdown()
650
-
651
- def compute_metrics_all_pairs_ui(images, captions):
652
- yield ("<div class='loading-line'></div>", "<div class='loading-line'></div>", "<div class='loading-line'></div>")
653
- if len(images) < 1 or len(captions) < 3:
654
- msg = "Upload 1 image and generate all 3 captions."
655
- yield msg, msg, msg
656
- return
657
- imgs = images * 3
658
- A = compute_metrics_button(imgs, captions, 0, 1)
659
- B = compute_metrics_button(imgs, captions, 0, 2)
660
- C = compute_metrics_button(imgs, captions, 1, 2)
661
- yield (f"**BLIP-large ↔ ViT-GPT2**<br>{A}",
662
- f"**BLIP-large ↔ BLIP2**<br>{B}",
663
- f"**ViT-GPT2 ↔ BLIP2**<br>{C}")
664
-
665
- metrics_btn.click(compute_metrics_all_pairs_ui, inputs=[images_state, captions_state],
666
- outputs=[metrics_A, metrics_B, metrics_C])""
667
-
668
-
669
- # ---------------- Metrics ----------------
670
- gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
671
- metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
672
-
673
- with gr.Row(elem_classes="metrics-row"):
674
- metrics_A = gr.Markdown()
675
- metrics_B = gr.Markdown()
676
- metrics_C = gr.Markdown()
677
-
678
- def compute_metrics_all_pairs_ui(images, captions):
679
-
680
- # 3 spinners (one for each column)
681
- yield (
682
- "<div class='loading-line'></div>",
683
- "<div class='loading-line'></div>",
684
- "<div class='loading-line'></div>"
685
- )
686
-
687
- if len(images) < 1 or len(captions) < 3:
688
- msg = "<b>Upload 1 image and generate all 3 captions.</b>"
689
- yield (msg, msg, msg)
690
- return
691
-
692
- # duplicate image for internal function
693
- imgs = images * 3
694
-
695
- # compute
696
- A = compute_metrics_button(imgs, captions, 0, 1)
697
- B = compute_metrics_button(imgs, captions, 0, 2)
698
- C = compute_metrics_button(imgs, captions, 1, 2)
699
-
700
- # return 3 separate markdown blocks (side-by-side)
701
- yield (
702
- f"### BLIP-large ↔ ViT-GPT2\n{A}",
703
- f"### BLIP-large ↔ BLIP2\n{B}",
704
- f"### ViT-GPT2 ↔ BLIP2\n{C}"
705
- )
706
-
707
- metrics_btn.click(
708
- compute_metrics_all_pairs_ui,
709
- inputs=[images_state, captions_state],
710
- outputs=[metrics_A, metrics_B, metrics_C]
711
- )
712
-
713
- # ---------------- NLP ----------------
714
- gr.Markdown("### NLP Analysis", elem_classes="heading-orange")
715
- nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
716
- nlp_out = gr.HTML()
717
-
718
- def do_nlp(captions):
719
- yield "<div class='loading-line'></div>"
720
- if len(captions) < 3:
721
- yield "<b>All captions required.</b>"
722
- return
723
- labels = ["BLIP-large", "ViT-GPT2", "BLIP2"]
724
- blocks = []
725
- for label, cap in zip(labels, captions):
726
- s, e, t = nlp_bundle(cap)
727
- block = f""
728
- <div style='flex:1;padding:10px;min-width:240px;'>
729
- <h3><u>{label}</u></h3>
730
- <b>Sentiment</b><br>{s}<br><br>
731
- <b>Entities</b><br>{e}<br><br>
732
- <b>Topics</b><br>{t}
733
- </div>
734
- ""
735
- blocks.append(block)
736
- yield f"<div style='display:flex; gap:20px;'>{''.join(blocks)}</div>"
737
-
738
- nlp_btn.click(do_nlp, inputs=[captions_state], outputs=[nlp_out])
739
-
740
- # ---------------- VQA ----------------
741
- gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
742
- with gr.Row():
743
- vqa_input = gr.Textbox(label="Ask about the image")
744
- vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
745
- vqa_out = gr.Markdown()
746
-
747
- def vqa_ui(question, image):
748
- yield "<div class='loading-line'></div>"
749
- yield answer_vqa(question, image)
750
-
751
- vqa_btn.click(vqa_ui, inputs=[vqa_input, preview1], outputs=[vqa_out])
752
-
753
- return demo
754
-
755
- # ==============================
756
- # LAUNCH
757
- # ==============================
758
- demo = build_ui()
759
- demo.launch(share=True, debug=False)
760
-
761
- """
 
386
  demo = build_ui()
387
  demo.launch(share=True, debug=False)
388