| import json |
| import os |
| import pickle |
| import random |
| import time |
| from collections import Counter |
| from datetime import datetime |
| from glob import glob |
|
|
| import gdown |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import seaborn as sns |
| import streamlit as st |
| from PIL import Image |
|
|
| import SessionState |
| from download_utils import * |
| from image_utils import * |
|
|
# Seed both RNGs from the wall clock so each session samples a different set of
# queries. Since Python 3.11, random.seed() only accepts None/int/float/str/bytes/
# bytearray, so passing the datetime object itself raises TypeError; its POSIX
# timestamp (a float) is an accepted seed carrying the same information.
random.seed(datetime.now().timestamp())
np.random.seed(int(time.time()))
|
|
# ---------------------------------------------------------------------------
# Study configuration
# ---------------------------------------------------------------------------
# Queries per session: resmaple_queries draws half of them from the
# classifier's correct predictions and half from its wrong ones.
NUMBER_OF_TRIALS: int = 20

# Both are reassigned by main() once an explanation tool is picked:
# CLASSIFIER_TAG names the classifier whose predictions are evaluated
# ("KNN", "CHM" or "EMD"), selected_xai_tool is the loader whose result is
# shown as the explanation image (None means no explanation is shown).
CLASSIFIER_TAG: str = ""
selected_xai_tool = None

# Placeholders; both are rebound later in this module, after the
# demonstration folder and the prediction pickle have been loaded.
folder_to_name: dict = {}
classifier_predictions: dict = {}

selected_dataset: str = "CUB-iNAt-Unified"

# Remote assets (Google Drive links) paired with their local names.
root_visualization_dir: str = "./visualizations/"
viz_url: str = "https://drive.google.com/uc?id=1ifNYT2Jfnj61U1t144ZRE3c1PI_K89YA"
viz_archivefile: str = "CUB-Final.zip"

demonstration_url: str = "https://drive.google.com/uc?id=1YbAT-M_ZdjLlPuPPAaXbAOqYOqhAIJWs"
demonst_zipfile: str = "demonstrations.zip"

picklefile_url: str = "https://drive.google.com/uc?id=175CO7kfqedrmqO-3wXCN9gFFjnVhXFn8"
prediction_root: str = "./predictions/"
prediction_pickle: str = prediction_root + "predictions.pickle"
|
|
|
|
| |
# Fetch the study assets. download_files comes from one of the star imports;
# it receives each remote URL together with its local target name/directory.
# Whether it skips files that already exist is not visible from here.
download_files(
    root_visualization_dir,
    viz_url,
    viz_archivefile,
    demonstration_url,
    demonst_zipfile,
    picklefile_url,
    prediction_root,
    prediction_pickle,
)

app_mode = ""

# One demonstration image "<class>.jpg" per class; render_experiment opens
# f"CUB-Demonstrations/{wnid}.jpg", so only .jpg entries are real classes.
# Anything else in the folder (OS metadata files, sub-folders, ...) is ignored
# instead of becoming a bogus class name.
birds_list = sorted(
    name[: -len(".jpg")]
    for name in os.listdir("./CUB-Demonstrations")
    if name.endswith(".jpg")
)
id_to_bird = dict(enumerate(birds_list))
# Identity mapping: the file stem serves both as class id and display name.
folder_to_name = {name: name for name in birds_list}

# SECURITY: pickle.load can execute arbitrary code. The file is presumably the
# one fetched from picklefile_url above; never point prediction_pickle at
# untrusted data.
with open(prediction_pickle, "rb") as f:
    classifier_predictions = pickle.load(f)

# SessionState is a project helper module; the keyword arguments are
# presumably the initial values for a fresh browser session.
session_state = SessionState.get(
    page=1,
    first_run=1,
    user_feedback={},
    queries=[],
    is_classifier_correct={},
    XAI_tool="Unselected",
)
| |
|
|
|
|
def resmaple_queries():
    """(Re)sample the session's query images, but only on a fresh start.

    When ``session_state.first_run`` is 1: takes ``NUMBER_OF_TRIALS // 2``
    visualizations that the ``CLASSIFIER_TAG`` classifier got right and as
    many it got wrong (numpy sampling without replacement), shuffles them
    together into ``session_state.queries``, sets ``first_run`` to -1 and
    clears the recorded user feedback. For any other ``first_run`` value this
    is a no-op, so Streamlit reruns keep the same queries.
    """
    if session_state.first_run != 1:
        return

    outcome_key = f"{CLASSIFIER_TAG}-Output"
    plot_dir = f"{root_visualization_dir}{selected_dataset}"

    # Split every prediction record by whether the classifier was right.
    correct_plots = []
    wrong_plots = []
    for sample_id, record in classifier_predictions.items():
        plot = f"{plot_dir}/cub-inat-{sample_id}.jpeg"
        if record[outcome_key]:
            correct_plots.append(plot)
        else:
            wrong_plots.append(plot)

    per_group = NUMBER_OF_TRIALS // 2
    picked_correct = list(np.random.choice(a=correct_plots, size=per_group, replace=False))
    picked_wrong = list(np.random.choice(a=wrong_plots, size=per_group, replace=False))

    queries = picked_correct + picked_wrong
    random.shuffle(queries)

    session_state.queries = queries
    session_state.first_run = -1
    # A new query set invalidates every answer recorded so far.
    session_state.user_feedback = {}
    session_state.is_classifier_correct = {}
|
|
|
|
def render_experiment(query):
    """Render one study trial: the query, the predicted class, and the user's verdict.

    Args:
        query: zero-based index into ``session_state.queries``.

    Side effects: stores whether the classifier was correct on this query in
    ``session_state.is_classifier_correct`` and the user's radio answer
    ("-", "Correct" or "Wrong") in ``session_state.user_feedback``, both
    keyed by the numeric id parsed from the visualization's file name.
    """
    query_path = session_state.queries[query]
    # File names look like "cub-inat-<id>.jpeg": the id is the third
    # dash-separated field, minus its extension.
    file_name = os.path.basename(query_path)
    query_id = int(file_name.split("-")[2].split(".")[0])

    record = classifier_predictions[query_id]
    predicted_wnid = record[f"{CLASSIFIER_TAG}-predictions"]
    prediction_confidence = record[f"{CLASSIFIER_TAG}-confidence"]
    prediction_label = folder_to_name[predicted_wnid]

    session_state.is_classifier_correct[query_id] = record[
        f"{CLASSIFIER_TAG.upper()}-Output"
    ]

    query_col, info_col = st.columns(2)
    with query_col:
        # load_query comes from a star import; presumably it extracts the
        # query image from the combined visualization.
        st.image(load_query(query_path), caption=f"Query ID: {query_id}")
    with info_col:
        with st.expander("Show Class Description"):
            st.write(f"**Name**: {prediction_label}")
            st.write("**Class Definition**:")
            st.image(
                Image.open(f"CUB-Demonstrations/{predicted_wnid}.jpg"),
                caption="Class Explanation",
                use_column_width=True,
            )

        # Re-select whatever the user already answered for this query
        # (radio index 0 is the "-" placeholder).
        previous_answer = session_state.user_feedback.get(query_id)
        answer_index = {"Correct": 1, "Wrong": 2}.get(previous_answer, 0)

        session_state.user_feedback[query_id] = st.radio(
            "What do you think about model's prediction?",
            ("-", "Correct", "Wrong"),
            key=query_id,
            index=answer_index,
        )
        st.write(f"**Model Prediction**: {prediction_label}")
        st.write(f"**Model Confidence**: {prediction_confidence}")

    # The "NOXAI" condition has no explanation loader.
    if selected_xai_tool is not None:
        st.image(
            selected_xai_tool(query_path),
            caption="Explaination",
            use_column_width=True,
        )

    if st.button("Debug: Show Everything"):
        st.image(Image.open(query_path))
|
|
|
|
def render_results():
    """Show how accurately the user judged the classifier, with CSV downloads.

    Every visited query has an entry in ``session_state.user_feedback``
    ("-", "Correct" or "Wrong") and in ``session_state.is_classifier_correct``.
    A guess is right when it names the classifier's actual outcome; an
    unanswered query ("-") counts as a wrong guess. Renders, in order: the
    overall score, a per-query summary table, the score grouped by the
    classifier's outcome, and a confusion matrix, each table with a CSV
    download button. If no query has been visited yet, only a short notice is
    shown.
    """
    feedback = session_state.user_feedback
    if not feedback:
        # Nothing to summarize; skip aggregating an empty DataFrame.
        st.write('No answers recorded yet. Open "Start Experiment" first.')
        return

    summary_rows = []
    for query_id, user_answer in feedback.items():
        category = "Correct" if session_state.is_classifier_correct[query_id] else "Wrong"
        prediction = classifier_predictions[query_id]
        summary_rows.append(
            [
                query_id,
                prediction["gt_wnid"],
                folder_to_name[prediction[f"{CLASSIFIER_TAG}-predictions"]],
                category,
                user_answer,
                # "-" never equals a category, so unanswered counts as wrong.
                category == user_answer,
            ]
        )

    experiment_summary_df = pd.DataFrame.from_records(
        summary_rows,
        columns=[
            "Query",
            "GT Labels",
            f"{CLASSIFIER_TAG} Prediction",
            "Category",
            "User Prediction",
            "Is User Prediction Correct",
        ],
    )

    # Derive the headline score from the same per-row rule as the table.
    user_correct_guess = int(experiment_summary_df["Is User Prediction Correct"].sum())
    st.write(
        f"User performance on {CLASSIFIER_TAG}: {user_correct_guess} out of {len(feedback)} Correct"
    )
    st.markdown("## User Performance Breakdown")

    st.write("Summary", experiment_summary_df)
    csv = convert_df(experiment_summary_df)
    st.download_button(
        "Press to Download", csv, "summary.csv", "text/csv", key="download-records"
    )

    # Accuracy of the user, split by whether the classifier itself was right.
    user_pf_by_model_pred = experiment_summary_df.groupby("Category").agg(
        {"Is User Prediction Correct": ["count", "sum", "mean"]}
    )
    user_pf_by_model_pred.columns = user_pf_by_model_pred.columns.droplevel(0)
    user_pf_by_model_pred.columns = [
        "Count",
        "Correct User Guess",
        "Mean User Performance",
    ]
    user_pf_by_model_pred.index.name = "Model Prediction"
    st.write("User performance break down by Model prediction:", user_pf_by_model_pred)
    csv = convert_df(user_pf_by_model_pred)
    st.download_button(
        "Press to Download",
        csv,
        "user-performance-by-model-prediction.csv",
        "text/csv",
        key="download-performance-by-model-prediction",
    )

    # Rows: the classifier's actual outcome; columns: the user's verdict.
    confusion_matrix = pd.crosstab(
        experiment_summary_df["Category"],
        experiment_summary_df["User Prediction"],
        rownames=["Actual"],
        colnames=["Predicted"],
    )
    st.write("Confusion Matrix", confusion_matrix)
    csv = convert_df(confusion_matrix)
    st.download_button(
        "Press to Download",
        csv,
        "confusion-matrix.csv",
        "text/csv",
        key="download-confusiion-matrix",
    )
|
|
|
|
# Widget key of the page selector, so a button callback can switch pages.
_PAGE_SELECT_KEY = "app_mode_select"


def _go_to_results_page():
    """Button callback: point the page selector at the results page.

    Callbacks run before the next script rerun, i.e. before the selectbox is
    instantiated again, which is when Streamlit allows setting a widget's
    value through ``st.session_state``.
    """
    st.session_state[_PAGE_SELECT_KEY] = "See the Results"


def render_menu():
    """Render the page selector and the selected page.

    Pages:
      * "Experiment Instruction": the instructions text only.
      * "Start Experiment": one query at a time with Previous/Next/Resample
        controls; the page number is clamped to 1..NUMBER_OF_TRIALS and kept
        in ``session_state.page``. On the last page a "View" button jumps to
        the results page.
      * "See the Results": ``render_results``.
    """
    readme_text = st.markdown(
        """
# Instructions
```
When testing this study, you should first see the class definition, then hide the expander and see the query.
```
"""
    )

    app_mode = st.selectbox(
        "Choose the page to show:",
        ["Experiment Instruction", "Start Experiment", "See the Results"],
        key=_PAGE_SELECT_KEY,
    )

    if app_mode == "Experiment Instruction":
        st.success("To continue select an option in the dropdown menu.")
    elif app_mode == "Start Experiment":
        readme_text.empty()

        page_id = session_state.page
        col1, col4, col2, col3 = st.columns(4)

        if col1.button("Previous Image"):
            page_id = max(page_id - 1, 1)
        if col2.button("Next Image"):
            page_id = min(page_id + 1, NUMBER_OF_TRIALS)

        if page_id == NUMBER_OF_TRIALS:
            st.success(
                'You have reached the last image. Please go to the "Results" page to see your performance.'
            )
            # The page has already been dispatched above, so assigning
            # app_mode here had no effect; switch the selector via a callback.
            st.button("View", on_click=_go_to_results_page)

        if col3.button("Resample"):
            st.write("Restarting ...")
            page_id = 1
            session_state.first_run = 1
            resmaple_queries()

        session_state.page = page_id
        st.write(f"Render Experiment: {session_state.page}")
        render_experiment(session_state.page - 1)
    elif app_mode == "See the Results":
        readme_text.empty()
        st.write("Results Summary")
        render_results()
|
|
|
|
def main():
    """Run the study app: choose an explanation tool, then show the study pages.

    The tool picker is rendered only while ``session_state.XAI_tool`` is still
    "Unselected"; once a tool is chosen it stays fixed for the session. The
    choice sets the module-level ``selected_xai_tool`` (explanation loader,
    or None for "NOXAI") and ``CLASSIFIER_TAG`` (whose predictions are
    evaluated) before queries are sampled and the menu is drawn. Nothing
    beyond the picker is rendered until a tool is chosen.
    """
    global selected_xai_tool
    global CLASSIFIER_TAG

    st.set_page_config(layout="wide")
    st.title("Visual Correspondence Human Study - CUB")

    options = [
        "Unselected",
        "NOXAI",
        "KNN",
        "EMD-Corr Nearest Neighbors",
        "EMD-Corr Correspondence",
        "CHM-Corr Nearest Neighbors",
        "CHM-Corr Correspondence",
    ]

    # Hide the first option of every radio group: the "Unselected"
    # placeholder below and the "-" placeholder of the per-query radio.
    st.markdown(
        """ <style>
div[role="radiogroup"] > :first-child{
    display: none !important;
}
</style>
""",
        unsafe_allow_html=True,
    )

    if session_state.XAI_tool == "Unselected":
        session_state.XAI_tool = st.radio(
            "What explanation tool do you want to evaluate?",
            options,
            key="which_xai",
            index=options.index(session_state.XAI_tool),
        )

    if session_state.XAI_tool == "Unselected":
        # No tool yet: CLASSIFIER_TAG is still "", so resmaple_queries would
        # look up "-Output" (presumably absent from the prediction records)
        # and, even if that worked, would lock in (first_run = -1) a query
        # set that ignores the tool chosen later. Wait for a choice instead.
        return

    st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")

    # tool name -> (explanation loader or None, classifier tag)
    xai_tools = {
        "NOXAI": (None, "KNN"),
        "KNN": (load_knn_nns, "KNN"),
        "CHM-Corr Nearest Neighbors": (load_chm_nns, "CHM"),
        "CHM-Corr Correspondence": (load_chm_corrs, "CHM"),
        "EMD-Corr Nearest Neighbors": (load_emd_nns, "EMD"),
        "EMD-Corr Correspondence": (load_emd_corrs, "EMD"),
    }
    selected_xai_tool, CLASSIFIER_TAG = xai_tools[session_state.XAI_tool]

    resmaple_queries()
    render_menu()


if __name__ == "__main__":
    main()
|
|