| """
|
| # Copyright (c) 2022, salesforce.com, inc.
|
| # All rights reserved.
|
| # SPDX-License-Identifier: BSD-3-Clause
|
| # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| """
|
|
|
| import numpy as np
|
| import streamlit as st
|
| import torch
|
| from lavis.models import BlipBase, load_model
|
| from matplotlib import pyplot as plt
|
| from PIL import Image
|
| from scipy.ndimage import filters
|
| from skimage import transform as skimage_transform
|
|
|
|
|
def resize_img(raw_img, target_width=240):
    """Scale an image to a fixed width while keeping its aspect ratio.

    Args:
        raw_img: Image-like object exposing ``size`` as ``(width, height)``
            and a ``resize((width, height))`` method (e.g. a PIL image).
        target_width: Width of the result in pixels. Defaults to 240, the
            value that used to be hard-coded.

    Returns:
        Whatever ``raw_img.resize`` returns for the computed size.
    """
    w, h = raw_img.size
    # Use the requested width directly: w * (target_width / w) goes through a
    # float round trip that is not guaranteed to yield target_width exactly.
    # Integer arithmetic gives the exact floored height; clamp to 1 so very
    # wide images do not produce a zero-height (invalid) size.
    new_height = max(1, (h * target_width) // w)
    return raw_img.resize((target_width, new_height))
|
|
|
|
|
def read_img(filepath):
    """Open the image at ``filepath`` and return it converted to RGB mode.

    Args:
        filepath: Anything accepted by ``PIL.Image.open``.

    Returns:
        The result of ``convert("RGB")`` on the opened image.
    """
    # convert() hands back a separate image object, so the source image (and
    # the file handle behind it) can be released deterministically here
    # instead of waiting for garbage collection.
    with Image.open(filepath) as source:
        return source.convert("RGB")
|
|
|
|
|
@st.cache(
    hash_funcs={
        # Hash a parameter by a NumPy copy of its detached CPU data.
        torch.nn.parameter.Parameter: lambda param: param.data.detach().cpu().numpy()
    },
    allow_output_mutation=True,
)
def load_model_cache(name, model_type, is_eval, device):
    """Streamlit-cached pass-through to ``load_model``.

    The four arguments are forwarded positionally and unchanged; the return
    value is whatever ``load_model`` produces. Output mutation is allowed by
    the cache decorator.
    """
    model = load_model(name, model_type, is_eval, device)
    return model
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def init_bert_tokenizer():
    """Return the result of ``BlipBase.init_tokenizer()``, cached by Streamlit."""
    return BlipBase.init_tokenizer()
|
|
|
|
|
def _rescale_to_unit(values):
    """Shift ``values`` so its minimum is 0 and, if the peak is positive, scale the peak to 1."""
    # Out-of-place arithmetic: the caller's array is never modified and
    # integer inputs are promoted to float by the division.
    values = values - values.min()
    peak = values.max()
    if peak > 0:
        values = values / peak
    return values


def getAttMap(img, attMap, blur=True, overlap=True):
    """Normalize an attention map to the size of ``img`` and optionally overlay it.

    Args:
        img: Image array. ``img.shape[:2]`` sets the output size and the
            array is blended with the heat map when ``overlap`` is true.
            Presumably H x W x 3 floats in [0, 1] -- TODO confirm with callers.
        attMap: 2-D attention map. It is not modified.
        blur: If true, Gaussian-blur the resized map with
            ``sigma = 0.02 * max(img.shape[:2])`` and normalize it again.
        overlap: If true, return the blend of ``img`` and the "jet"-colored
            map; otherwise return the normalized map itself.

    Returns:
        The blended array when ``overlap`` is true, else the normalized map
        resized to ``img.shape[:2]``.
    """
    # Imported here from the public namespace; going through
    # ``scipy.ndimage.filters`` is deprecated.
    from scipy.ndimage import gaussian_filter

    target_shape = img.shape[:2]

    att = _rescale_to_unit(attMap)
    att = skimage_transform.resize(att, target_shape, order=3, mode="constant")
    if blur:
        att = gaussian_filter(att, 0.02 * max(target_shape))
        # Same guarded normalization as above, so a flat map stays 0 rather
        # than turning into NaN through a division by zero.
        att = _rescale_to_unit(att)

    if not overlap:
        # The colored map is only needed for blending; skip building it.
        return att

    heat = plt.get_cmap("jet")(att)
    # Drop index 3 along the last axis (the alpha channel of the RGBA output).
    heat = np.delete(heat, 3, 2)

    # Per-pixel blend weight with a trailing axis so it broadcasts over channels.
    weight = (att**0.7).reshape(att.shape + (1,))
    return (1 - weight) * img + weight * heat
|
|
|
|
|
@st.cache(
    hash_funcs={
        # Hash a parameter by a NumPy copy of its detached CPU data.
        torch.nn.parameter.Parameter: lambda param: param.data.detach().cpu().numpy()
    },
    allow_output_mutation=True,
)
def load_blip_itm_model(device, model_type="base"):
    """Streamlit-cached loader for the ``"blip_image_text_matching"`` model.

    Calls ``load_model`` with ``is_eval=True`` and the given ``device`` and
    ``model_type``, returning its result unchanged.
    """
    return load_model(
        "blip_image_text_matching", model_type, is_eval=True, device=device
    )
|
|
|