| """
|
| # Copyright (c) 2022, salesforce.com, inc.
|
| # All rights reserved.
|
| # SPDX-License-Identifier: BSD-3-Clause
|
| # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| """
|
|
|
| import random
|
| from collections import OrderedDict
|
| from functools import reduce
|
| from tkinter import N
|
|
|
| import streamlit as st
|
| from lavis.common.registry import registry
|
| from lavis.datasets.builders import dataset_zoo, load_dataset
|
| from lavis.datasets.builders.base_dataset_builder import load_dataset_config
|
| from PIL import Image
|
|
|
# Grid shape as (rows, columns): image splits get a 3x4 grid, video splits 1x2.
IMAGE_LAYOUT = (3, 4)
VIDEO_LAYOUT = (1, 2)

# Labels shown on the pagination buttons.
PREV_STR = "Prev"
NEXT_STR = "Next"
|
|
|
|
|
def sample_dataset(dataset, indices):
    """Collect the display form of each requested dataset item.

    Args:
        dataset: object exposing ``displ_item(idx)``.
        indices: iterable of item indices, visited in order.

    Returns:
        list: ``dataset.displ_item(i)`` for every ``i`` in ``indices``.
    """
    displayed = []
    for position in indices:
        displayed.append(dataset.displ_item(position))
    return displayed
|
|
|
|
|
def get_concat_v(im1, im2):
    """Join two images side by side on a white RGB canvas.

    Despite the ``_v`` suffix, the layout is horizontal: ``im2`` is pasted to
    the right of ``im1`` with a 5-pixel gap, both aligned to the top edge.
    The canvas height is that of the taller image.

    Args:
        im1: left image (PIL image).
        im2: right image (PIL image).

    Returns:
        A new ``"RGB"`` PIL image containing both inputs.
    """
    gap = 5

    total_width = im1.width + gap + im2.width
    total_height = max(im1.height, im2.height)

    combined = Image.new("RGB", (total_width, total_height), "White")
    combined.paste(im1, (0, 0))
    combined.paste(im2, (im1.width + gap, 0))
    return combined
|
|
|
|
|
def resize_img_w(raw_img, new_w=224, multi_w=196):
    """Resize an image, or a list of images, to a target width keeping aspect ratio.

    Args:
        raw_img: a PIL image, or a list of PIL images. A list is resized
            element-wise to ``multi_w`` and the results are joined left to
            right with ``get_concat_v``.
        new_w: target width in pixels for a single image.
        multi_w: per-image target width used when ``raw_img`` is a list
            (formerly hard-coded to 196, which remains the default).

    Returns:
        The resized (or concatenated) PIL image.

    Raises:
        ValueError: if ``raw_img`` is an empty list.
    """
    if isinstance(raw_img, list):
        if not raw_img:
            # reduce() would otherwise fail with an opaque TypeError.
            raise ValueError("resize_img_w received an empty list of images.")
        resized_imgs = [resize_img_w(img, multi_w) for img in raw_img]
        return reduce(get_concat_v, resized_imgs)

    w, h = raw_img.size
    # Integer arithmetic: the float form int(w * (new_w / w)) can come out one
    # pixel short (e.g. w=49, new_w=128 yields 127). Height is clamped to >= 1
    # because PIL's resize rejects a zero dimension for very wide images.
    target_w = int(new_w)
    target_h = max(1, int(h * new_w // w))
    return raw_img.resize((target_w, target_h))
|
|
|
|
|
def get_visual_key(dataset):
    """Return the key under which ``dataset`` items carry their visual input.

    Only the first item is inspected. Items keyed ``"image"`` and multi-image
    items keyed ``"image0"`` are both reported as ``"image"``; items keyed
    ``"video"`` are reported as ``"video"``.

    Args:
        dataset: indexable dataset whose items support ``in`` key checks.

    Returns:
        ``"image"`` or ``"video"``.

    Raises:
        ValueError: if the first item has none of the recognised keys.
    """
    # Fetch once: indexing a dataset may load and preprocess media, and the
    # original re-indexed for every check.
    first = dataset[0]

    if "image" in first or "image0" in first:
        return "image"
    if "video" in first:
        return "video"
    raise ValueError("Visual key not found.")
|
|
|
|
|
def gather_items(samples, exclude=()):
    """Copy each sample, dropping the keys listed in ``exclude``.

    Args:
        samples: iterable of mapping-like samples.
        exclude: container of keys to omit. Defaults to an empty tuple, which
            is immutable and so cannot be shared/mutated across calls (the
            previous ``[]`` default could be).

    Returns:
        list[OrderedDict]: one dict per sample, preserving the sample's key order.
    """
    return [
        OrderedDict((k, s[k]) for k in s.keys() if k not in exclude)
        for s in samples
    ]
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_dataset_cache(name):
    """Load the dataset registered as ``name``, memoised by Streamlit.

    ``allow_output_mutation=True`` tells ``st.cache`` to skip its check that
    the cached return value was not mutated, so later calls with the same
    ``name`` get the already-loaded object back.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource``; confirm against the pinned version.
    """
    loaded = load_dataset(name)
    return loaded
|
|
|
|
|
def format_text(text):
    """Render a mapping as markdown, one bold ``key: value`` line per entry.

    Entries are separated by a blank line so each renders as its own paragraph.

    Args:
        text: mapping of field name to value.

    Returns:
        str: the markdown string (empty for an empty mapping).
    """
    lines = []
    for key, value in text.items():
        lines.append(f"**{key}**: {value}")
    return "\n\n".join(lines)
|
|
|
|
|
def _draw_divider():
    """Draw the thin grey horizontal rule that frames the sample grid."""
    st.markdown(
        """<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """,
        unsafe_allow_html=True,
    )


def show_samples(dataset, offset=0, is_next=False):
    """Render one page of samples from ``dataset`` as a Streamlit grid.

    The grid is ``IMAGE_LAYOUT`` for image datasets and ``VIDEO_LAYOUT``
    otherwise. Each cell shows the sample's non-visual fields as markdown,
    followed by the resized image for image datasets.

    Paging reads state owned by the ``__main__`` script: the module-level
    ``shuffle`` flag and ``st.session_state.start_idx`` (the last validated
    "Begin from" value). In sequential mode the first index of the page shown
    is written to ``st.session_state.last_start``. The page size is always
    written to ``st.session_state.n_display``.

    Args:
        dataset: one dataset split; must support ``len()``, ``[0]`` and
            ``displ_item(idx)``.
        offset: in sequential mode, ``last_start - start_idx``, i.e. where the
            previously shown page started relative to the user's start index.
        is_next: in sequential mode, page forward when True and backward when
            False. Ignored when shuffling.
    """
    visual_key = get_visual_key(dataset)

    num_rows, num_cols = IMAGE_LAYOUT if visual_key == "image" else VIDEO_LAYOUT
    n_samples = num_rows * num_cols
    num_items = len(dataset)

    if not shuffle:
        # Use the validated index from session state rather than the raw
        # "Begin from" text, which may contain non-digits that int() rejects.
        base = st.session_state.start_idx + offset
        if is_next:
            start = min(base + n_samples, num_items - n_samples)
        else:
            start = base - n_samples
        # Clamp at 0: for splits smaller than one page num_items - n_samples
        # is negative and would produce negative (wrap-around) indices.
        start = max(0, start)

        st.session_state.last_start = start
        end = min(start + n_samples, num_items)

        indices = list(range(start, end))
    else:
        # random.sample raises ValueError if asked for more items than exist.
        indices = random.sample(range(num_items), min(n_samples, num_items))
    samples = sample_dataset(dataset, indices)

    if visual_key == "image":
        visual_info = iter([resize_img_w(s[visual_key]) for s in samples])
    else:
        # NOTE(review): video file paths are collected but never rendered below.
        visual_info = iter([s["file"] for s in samples])
    text_info = iter(
        [format_text(s) for s in gather_items(samples, exclude=["image", "video"])]
    )

    _draw_divider()
    for _ in range(num_rows):
        with st.container():
            for col in st.columns(num_cols):
                try:
                    col.markdown(next(text_info))
                    if visual_key == "image":
                        col.image(next(visual_info), use_column_width=True, clamp=True)
                    elif visual_key == "video":
                        # NOTE(review): renders an empty markdown placeholder;
                        # the video itself is not displayed here.
                        col.markdown("")
                except StopIteration:
                    # Fewer samples than grid cells: stop filling this row.
                    break
    _draw_divider()

    st.session_state.n_display = n_samples
|
|
|
|
|
if __name__ == "__main__":
    # Entry point of the Streamlit dataset browser.
    st.set_page_config(
        page_title="LAVIS Dataset Explorer",
        initial_sidebar_state="expanded",
    )

    # Sidebar: pick a registered dataset and the tool to run on it.
    dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names())

    function = st.sidebar.selectbox("Function:", ["Browser"], index=0)

    if function == "Browser":
        # True -> random pages with a single "Next" button;
        # False -> sequential paging with "Prev"/"Next" and a start index.
        shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0)

        dataset = load_dataset_cache(dataset_name)
        split = st.sidebar.selectbox("Split:", dataset.keys())

        dataset_len = len(dataset[split])
        st.success(
            f"Loaded {dataset_name}/{split} with **{dataset_len}** records. **Image/video directory**: {dataset[split].vis_root}"
        )

        # Initialise session-state keys the first time they are needed.
        if "last_dataset" not in st.session_state:
            st.session_state.last_dataset = dataset_name
            st.session_state.last_split = split

        # last_start: first index of the most recently shown sequential page.
        if "last_start" not in st.session_state:
            st.session_state.last_start = 0

        # start_idx: last validated value of the "Begin from" box.
        if "start_idx" not in st.session_state:
            st.session_state.start_idx = 0

        if "shuffle" not in st.session_state:
            st.session_state.shuffle = shuffle

        # Request a fresh page on first load, when the dataset/split changes,
        # or when the shuffle mode is toggled.
        if "first_run" not in st.session_state:
            st.session_state.first_run = True
        elif (
            st.session_state.last_dataset != dataset_name
            or st.session_state.last_split != split
        ):
            st.session_state.first_run = True

            st.session_state.last_dataset = dataset_name
            st.session_state.last_split = split
        elif st.session_state.shuffle != shuffle:
            st.session_state.shuffle = shuffle
            st.session_state.first_run = True

        # Sequential mode gets Prev and Next; shuffle mode only Next.
        # prev_button is defined only in sequential mode.
        if not shuffle:
            n_col, p_col = st.columns([0.05, 1])

            prev_button = n_col.button(PREV_STR)
            next_button = p_col.button(NEXT_STR)

        else:
            next_button = st.button(NEXT_STR)

        if not shuffle:
            start_idx = st.sidebar.text_input(f"Begin from (total {dataset_len})", 0)

            # Only digit strings are accepted; a new value also resets the
            # paging position (last_start) to that index.
            if not start_idx.isdigit():
                st.error(f"Input to 'Begin from' must be digits, found {start_idx}.")
            else:
                if int(start_idx) != st.session_state.start_idx:
                    st.session_state.start_idx = int(start_idx)
                    st.session_state.last_start = int(start_idx)

            # offset = last_start - start_idx, so show_samples pages relative
            # to the previously shown page.
            if prev_button:
                show_samples(
                    dataset[split],
                    offset=st.session_state.last_start - st.session_state.start_idx,
                    is_next=False,
                )

        if next_button:
            show_samples(
                dataset[split],
                offset=st.session_state.last_start - st.session_state.start_idx,
                is_next=True,
            )

        # NOTE(review): with is_next=True the first sequential page starts up
        # to one page past last_start (e.g. index 12 for the 3x4 image grid
        # when starting from 0), not at last_start itself; confirm intended.
        if st.session_state.first_run:
            st.session_state.first_run = False

            show_samples(
                dataset[split],
                offset=st.session_state.last_start - st.session_state.start_idx,
                is_next=True,
            )
|
|
|