| """
|
| # Copyright (c) 2022, salesforce.com, inc.
|
| # All rights reserved.
|
| # SPDX-License-Identifier: BSD-3-Clause
|
| # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| """
|
|
|
| import plotly.graph_objects as go
|
| import requests
|
| import streamlit as st
|
| import torch
|
| from lavis.models import load_model
|
| from lavis.processors import load_processor
|
| from lavis.processors.blip_processors import BlipCaptionProcessor
|
| from PIL import Image
|
|
|
| from app import device, load_demo_image
|
| from app.utils import load_blip_itm_model
|
| from lavis.processors.clip_processors import ClipImageEvalProcessor
|
|
|
|
|
@st.cache()
def load_demo_image(img_url=None):
    """Download and return the demo image as an RGB PIL image.

    Args:
        img_url: Optional image URL. When falsy, a default Merlion photo
            is fetched instead.

    Returns:
        PIL.Image.Image: the downloaded image converted to RGB mode.

    Raises:
        requests.exceptions.RequestException: on network failure or timeout.
    """
    if not img_url:
        img_url = "https://img.atlasobscura.com/yDJ86L8Ou6aIjBsxnlAy5f164w1rjTgcHZcx2yUs4mo/rt:fit/w:1200/q:81/sm:1/scp:1/ar:1/aHR0cHM6Ly9hdGxh/cy1kZXYuczMuYW1h/em9uYXdzLmNvbS91/cGxvYWRzL3BsYWNl/X2ltYWdlcy85MDll/MDRjOS00NTJjLTQx/NzQtYTY4MS02NmQw/MzI2YWIzNjk1ZGVk/MGZhMTJiMTM5MmZi/NGFfUmVhcl92aWV3/X29mX3RoZV9NZXJs/aW9uX3N0YXR1ZV9h/dF9NZXJsaW9uX1Bh/cmssX1NpbmdhcG9y/ZSxfd2l0aF9NYXJp/bmFfQmF5X1NhbmRz/X2luX3RoZV9kaXN0/YW5jZV8tXzIwMTQw/MzA3LmpwZw.jpg"

    # stream=True lets PIL read directly from the response body; the timeout
    # keeps the Streamlit app from hanging forever on a stalled connection
    # (the original call had no timeout).
    response = requests.get(img_url, stream=True, timeout=10)
    raw_image = Image.open(response.raw).convert("RGB")

    return raw_image
|
|
|
|
|
@st.cache(
    # torch Parameters are not hashable by Streamlit's default hasher;
    # hash them via their detached CPU numpy values instead.
    hash_funcs={
        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
        .cpu()
        .numpy()
    },
    allow_output_mutation=True,
)
def load_model_cache(model_type, device):
    """Load a feature-extractor model, cached across Streamlit reruns.

    Args:
        model_type: One of "blip", "albef", "CLIP_ViT-B-32",
            "CLIP_ViT-B-16", or "CLIP_ViT-L-14".
        device: torch device to place the model on.

    Returns:
        The loaded model in eval mode.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names
            (the previous version crashed with an UnboundLocalError here).
    """
    if model_type == "blip":
        model = load_model(
            "blip_feature_extractor", model_type="base", is_eval=True, device=device
        )
    elif model_type == "albef":
        model = load_model(
            "albef_feature_extractor", model_type="base", is_eval=True, device=device
        )
    elif model_type == "CLIP_ViT-B-32":
        model = load_model(
            "clip_feature_extractor", "ViT-B-32", is_eval=True, device=device
        )
    elif model_type == "CLIP_ViT-B-16":
        model = load_model(
            "clip_feature_extractor", "ViT-B-16", is_eval=True, device=device
        )
    elif model_type == "CLIP_ViT-L-14":
        model = load_model(
            "clip_feature_extractor", "ViT-L-14", is_eval=True, device=device
        )
    else:
        raise ValueError(f"Unknown model type {model_type}")

    return model
|
|
|
|
|
def app():
    """Render the zero-shot classification demo page.

    The sidebar selects a model family (ALBEF, BLIP, or a CLIP variant) and a
    scoring method ("Cosine" similarity of projected embeddings, or BLIP's
    "Multimodal" image-text matching head). The user supplies an image (upload
    or demo) and up to five category names; on submit, per-class probabilities
    are shown as a horizontal bar chart with percentage labels.
    """
    model_type = st.sidebar.selectbox(
        "Model:",
        ["ALBEF", "BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"],
    )
    score_type = st.sidebar.selectbox("Score type:", ["Cosine", "Multimodal"])

    st.markdown(
        "<h1 style='text-align: center;'>Zero-shot Classification</h1>",
        unsafe_allow_html=True,
    )

    instructions = """Try the provided image or upload your own:"""
    file = st.file_uploader(instructions)

    st.header("Image")
    if file:
        raw_img = Image.open(file).convert("RGB")
    else:
        raw_img = load_demo_image()

    st.image(raw_img)

    col1, col2 = st.columns(2)

    col1.header("Categories")

    cls_0 = col1.text_input("category 1", value="merlion")
    cls_1 = col1.text_input("category 2", value="sky")
    cls_2 = col1.text_input("category 3", value="giraffe")
    cls_3 = col1.text_input("category 4", value="fountain")
    cls_4 = col1.text_input("category 5", value="marina bay")

    # Drop empty inputs; duplicates would make the prediction ambiguous.
    cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4]
    cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0]

    if len(cls_names) != len(set(cls_names)):
        st.error("Please provide unique class names")
        return

    button = st.button("Submit")

    col2.header("Prediction")

    if button:
        if model_type.startswith("BLIP"):
            # BLIP uses a caption prompt prefix for each class name.
            text_processor = BlipCaptionProcessor(prompt="A picture of ")
            cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]

            if score_type == "Cosine":
                vis_processor = load_processor("blip_image_eval").build(image_size=224)
                img = vis_processor(raw_img).unsqueeze(0).to(device)

                feature_extractor = load_model_cache(model_type="blip", device=device)

                sample = {"image": img, "text_input": cls_prompt}

                with torch.no_grad():
                    # [:, 0] selects the CLS-token projection for image/text.
                    image_features = feature_extractor.extract_features(
                        sample, mode="image"
                    ).image_embeds_proj[:, 0]
                    text_features = feature_extractor.extract_features(
                        sample, mode="text"
                    ).text_embeds_proj[:, 0]
                    # Temperature-scaled cosine logits, one per class.
                    sims = (image_features @ text_features.t())[
                        0
                    ] / feature_extractor.temp

            else:
                # Multimodal scoring uses the ITM head at its native 384 res.
                vis_processor = load_processor("blip_image_eval").build(image_size=384)
                img = vis_processor(raw_img).unsqueeze(0).to(device)

                model = load_blip_itm_model(device)

                output = model(img, cls_prompt, match_head="itm")
                # Column 1 of the ITM output is the "match" logit.
                sims = output[:, 1]

            sims = torch.nn.Softmax(dim=0)(sims)
            # Reverse so the bar chart (bottom-up) lists classes in input
            # order; scale to percentages for display.
            inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]

        elif model_type.startswith("ALBEF"):
            vis_processor = load_processor("blip_image_eval").build(image_size=224)
            img = vis_processor(raw_img).unsqueeze(0).to(device)

            text_processor = BlipCaptionProcessor(prompt="A picture of ")
            cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]

            feature_extractor = load_model_cache(model_type="albef", device=device)

            sample = {"image": img, "text_input": cls_prompt}

            with torch.no_grad():
                image_features = feature_extractor.extract_features(
                    sample, mode="image"
                ).image_embeds_proj[:, 0]
                text_features = feature_extractor.extract_features(
                    sample, mode="text"
                ).text_embeds_proj[:, 0]

                # (Removed leftover debug st.write of the feature shapes.)
                sims = (image_features @ text_features.t())[0] / feature_extractor.temp

            sims = torch.nn.Softmax(dim=0)(sims)
            inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]

        elif model_type.startswith("CLIP"):
            if model_type == "CLIP_ViT-B-32":
                model = load_model_cache(model_type="CLIP_ViT-B-32", device=device)
            elif model_type == "CLIP_ViT-B-16":
                model = load_model_cache(model_type="CLIP_ViT-B-16", device=device)
            elif model_type == "CLIP_ViT-L-14":
                model = load_model_cache(model_type="CLIP_ViT-L-14", device=device)
            else:
                raise ValueError(f"Unknown model type {model_type}")

            if score_type == "Cosine":
                image_preprocess = ClipImageEvalProcessor(image_size=224)
                img = image_preprocess(raw_img).unsqueeze(0).to(device)

                sample = {"image": img, "text_input": cls_names}

                with torch.no_grad():
                    clip_features = model.extract_features(sample)

                    image_features = clip_features.image_embeds_proj
                    text_features = clip_features.text_embeds_proj

                    # CLIP's standard 100x logit scale before softmax.
                    sims = (100.0 * image_features @ text_features.T)[0].softmax(dim=-1)
                # Scale to percentages so CLIP bars match the BLIP/ALBEF
                # branches (previously this branch displayed raw 0-1 probs).
                inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]
            else:
                st.warning("CLIP does not support multimodal scoring.")
                return

        fig = go.Figure(
            go.Bar(
                x=inv_sims,
                y=cls_names[::-1],
                text=["{:.2f}".format(s) for s in inv_sims],
                orientation="h",
            )
        )
        fig.update_traces(
            textfont_size=12,
            textangle=0,
            textposition="outside",
            cliponaxis=False,
        )
        col2.plotly_chart(fig, use_container_width=True)
|
|
|