| import streamlit as st |
| import streamlit_analytics |
|
|
| import torch |
| import torchvision.transforms as transforms |
| from transformers import ViTModel, ViTConfig |
| from PIL import Image |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import io |
|
|
streamlit_analytics.start_tracking()

st.set_page_config(page_title="ViewViz", layout="wide")

# Dark slate theme with teal accents for buttons and sliders.
_CUSTOM_CSS = """
<style>
.stApp {
    background-color: #2b3d4f;
    color: #ffffff;
}
.stButton>button {
    color: #2b3d4f;
    background-color: #4fd1c5;
    border-radius: 5px;
}
.stSlider>div>div>div>div {
    background-color: #4fd1c5;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Flip to True to allow CUDA when a GPU is present; inference falls back
# to CPU otherwise.
USE_GPU = False
_cuda_ok = USE_GPU and torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')

# Display-name -> matplotlib colormap, offered to the user for the heatmap.
COLOR_SCHEMES = dict(
    Plasma=plt.cm.plasma,
    Viridis=plt.cm.viridis,
    Magma=plt.cm.magma,
    Inferno=plt.cm.inferno,
    Cividis=plt.cm.cividis,
    Spectral=plt.cm.Spectral,
    Coolwarm=plt.cm.coolwarm,
)
|
|
| |
@st.cache_resource
def load_model():
    """Load the pretrained ViT (base, patch-16, 384px) with attentions enabled.

    Cached by Streamlit so the checkpoint is downloaded and initialised once
    per server process.  Returns the model in eval mode on the global
    ``device``.
    """
    checkpoint = 'google/vit-base-patch16-384'
    # "eager" attention is required so attention weights are materialised
    # and returned via output_attentions.
    cfg = ViTConfig.from_pretrained(
        checkpoint,
        output_attentions=True,
        attn_implementation="eager",
    )
    vit = ViTModel.from_pretrained(checkpoint, config=cfg)
    return vit.eval().to(device)
|
|
model = load_model()

# ImageNet channel statistics used during ViT pre-training.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize to the model's expected 384x384 input, convert to a tensor,
# and normalise channel-wise.
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
|
|
def get_attention_map(img):
    """Compute an attention-rollout heatmap for a PIL image.

    Runs the image through the ViT, averages each layer's attention over
    heads, adds the identity matrix (residual connections), row-normalises,
    and multiplies the per-layer matrices together ("attention rollout").
    The CLS token's attention to the patch tokens from the final rollout
    matrix is returned as a square numpy array over the patch grid
    (24x24 for the 384px / patch-16 model).
    """
    batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch, output_attentions=True)

    # (layers, tokens, tokens): head-averaged attention for each layer.
    layer_att = torch.stack(outputs.attentions).squeeze(1).mean(dim=1)

    # Account for residual connections, then renormalise each row to sum to 1.
    identity = torch.eye(layer_att.size(-1)).unsqueeze(0).to(device)
    normed = layer_att + identity
    normed = normed / normed.sum(dim=-1).unsqueeze(-1)

    # Roll attention forward through the layers by cumulative matrix product.
    rollout = torch.zeros(normed.size()).to(device)
    rollout[0] = normed[0]
    for layer in range(1, normed.size(0)):
        rollout[layer] = torch.matmul(normed[layer], rollout[layer - 1])

    # CLS -> patch attention of the final rollout, reshaped to the patch grid.
    side = int(np.sqrt(normed.size(-1)))
    return rollout[-1][0, 1:].reshape(side, side).detach().cpu().numpy()
|
|
def overlay_attention_map(image, attention_map, overlay_strength, color_scheme):
    """Blend a colorized attention heatmap over a PIL image.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image (any mode convertible to RGBA).
    attention_map : np.ndarray
        2-D float attention map; it is upsampled to ``image.size`` and
        min-max normalised before colorization.
    overlay_strength : float
        Blend weight in [0, 1]; 0 shows only the image, 1 only the heatmap.
    color_scheme : matplotlib colormap
        Callable mapping normalised values to RGBA floats in [0, 1].

    Returns
    -------
    PIL.Image.Image
        RGBA image of the same size as ``image``.
    """
    # Upsample the small patch-grid map to the full image resolution.
    attention_map = Image.fromarray(attention_map).resize(image.size, Image.BICUBIC)
    attention_map = np.array(attention_map)

    # Min-max normalise to [0, 1].  Guard against a constant map, which
    # would otherwise divide by zero and fill the overlay with NaNs.
    span = attention_map.max() - attention_map.min()
    if span > 0:
        attention_map = (attention_map - attention_map.min()) / span
    else:
        attention_map = np.zeros_like(attention_map)

    # Colormap call returns an (H, W, 4) RGBA float array.
    attention_map_color = color_scheme(attention_map)

    image_rgba = image.convert("RGBA")
    image_array = np.array(image_rgba) / 255.0

    # Convex blend of image and heatmap, then back to 8-bit RGBA.
    overlayed_image = image_array * (1 - overlay_strength) + attention_map_color * overlay_strength

    return Image.fromarray((overlayed_image * 255).astype(np.uint8))
|
|
st.title("ViewViz")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('RGB')

    st.success("Starting Prediction Process...")
    # NOTE(review): Streamlit reruns this script on every widget interaction,
    # so inference repeats per slider change — caching the map per upload
    # could avoid redundant work; confirm before changing.
    attention_map = get_attention_map(image)

    strength_col, scheme_col = st.columns(2)
    with strength_col:
        overlay_strength = st.slider("Heatmap Overlay Percentage", 0, 100, 50) / 100.0
    with scheme_col:
        color_scheme_name = st.selectbox("Choose Heatmap Color Scheme", list(COLOR_SCHEMES.keys()))

    overlayed_image = overlay_attention_map(
        image,
        attention_map,
        overlay_strength,
        COLOR_SCHEMES[color_scheme_name],
    )

    st.image(overlayed_image, caption='Image with Heatmap Overlay', use_column_width=True)

    # Offer the rendered overlay as a PNG download.
    png_buffer = io.BytesIO()
    overlayed_image.save(png_buffer, format="PNG")
    st.download_button(
        label="Download Image with Attention Map",
        data=png_buffer.getvalue(),
        file_name="attention_map_overlay.png",
        mime="image/png",
    )

streamlit_analytics.stop_tracking()