| import io |
| import os |
|
|
| import json |
| import base64 |
| import random |
| import numpy as np |
| import pandas as pd |
| import gradio as gr |
| from pathlib import Path |
| from PIL import Image |
|
|
| from plots import get_pre_define_colors |
| from utils.load_model import load_xclip |
| from utils.predict import xclip_pred |
|
|
|
|
# Inference runs on CPU; the demo is deployed without GPU access.
DEVICE = "cpu"
# XCLIP model and its OWL-ViT processor, loaded once at import time.
XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
# Default per-class, per-part textual descriptions (CUB birds).
XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
# Convenience wrapper: PIL image -> model-ready tensors.
PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
IMAGES_FOLDER = "data/images"
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
# File names of the demo images; list order defines the gallery indices.
IMAGE_FILE_LIST = json.load(open("data/jsons/file_list.json", "r"))

# Pre-loaded gallery images, same order as IMAGE_FILE_LIST.
IMAGE_GALLERY = [Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB') for file_name in IMAGE_FILE_LIST]

# Part order as stored in the description JSON ...
ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
# ... versus the top-down order used when rendering explanations.
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
# One distinct RGB color per bird part for the explanation bars.
COLORS = get_pre_define_colors(12, cmap_set=['Set2', 'tab10'])
SACHIT_COLOR = "#ADD8E6"  # light blue used for the Sachit-style explanation bars

# Per-image flags for which of the 12 parts are visible. Eastern_Bluebird.jpg
# has no entry in the JSON, so mark all of its parts visible explicitly.
VISIBILITY_DICT = json.load(open("data/jsons/cub_vis_dict_binary.json", 'r'))
VISIBILITY_DICT['Eastern_Bluebird.jpg'] = dict(zip(ORDERED_PARTS, [True]*12))
|
|
| |
def img_to_base64(img):
    """Encode a PIL image (or numpy array) as a base64 JPEG string."""
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode()
|
|
def create_blank_image(width=500, height=500, color=(255, 255, 255)):
    """Create a blank RGB image of the given size and color.

    Args:
        width: image width in pixels.
        height: image height in pixels.
        color: fill color as an (r, g, b) tuple of 0-255 ints.

    Returns:
        np.ndarray of shape (height, width, 3), dtype uint8.
    """
    # Build the array directly instead of round-tripping through PIL:
    # Image.new(...) + np.array(...) produced this exact buffer.
    return np.full((height, width, 3), color, dtype=np.uint8)
|
|
| |
def rgb_to_hex(rgb):
    """Convert an (r, g, b) tuple of 0-255 ints to a '#rrggbb' hex string."""
    return "#" + "".join(format(channel, "02x") for channel in rgb)
|
|
def load_part_images(file_name: str) -> dict:
    """Return {part_name: base64 JPEG} for each part crop found on disk.

    Crops live at data/images/boxes/{stem}_{part}.jpg; parts whose crop
    file is missing are simply omitted from the result.
    """
    stem = Path(file_name).stem
    boxes_dir = Path(IMAGES_FOLDER) / "boxes"
    encoded = {}
    for part in ORDERED_PARTS:
        crop_path = boxes_dir / f"{stem}_{part}.jpg"
        if crop_path.exists():
            encoded[part] = img_to_base64(np.array(Image.open(crop_path)))
    return encoded
|
|
def generate_xclip_explanations(result_dict:dict, visibility: dict, part_mask: dict = dict(zip(ORDERED_PARTS, [1]*12))):
    """
    Build an HTML/SVG panel explaining the XCLIP prediction part by part.

    The result_dict needs three keys: 'descriptions', 'pred_scores', 'file_name'
    descriptions: {part_name1: desc_1, part_name2: desc_2, ...}
    pred_scores: {part_name1: score_1, part_name2: score_2, ...}
    file_name: str

    visibility maps part name -> bool (is the part visible in this image);
    part_mask maps part name -> 0/1 (did the user keep the part's description).
    Parts failing either check are skipped entirely. The mutable default for
    part_mask is only read, never mutated, so sharing it across calls is safe.

    Returns the whole panel as a single HTML string.
    """
    descriptions = result_dict['descriptions']
    image_name = result_dict['file_name']
    # Base64 part crops pre-computed at module load time (PART_IMAGES_DICT).
    part_images = PART_IMAGES_DICT[image_name]
    MAX_LENGTH = 50   # characters per rendered text line
    exp_length = 400  # panel width in px; also the bar length for score 1.0
    fontsize = 15

    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]

    y_offset = 0
    for part in ORDERED_PARTS:
        if visibility[part] and part_mask[part]:
            # Clamp negative scores so bars never extend backwards.
            part_score = max(result_dict['pred_scores'][part], 0)
            bar_length = part_score * exp_length

            # Hovering text/bar swaps the matching part crop into the
            # #overlayImage element rendered on top of the bird photo.
            mouseover_action1 = f"document.getElementById('overlayImage').src = 'data:image/jpeg;base64,{part_images[part]}'; document.getElementById('overlayImage').style.opacity = 1;"
            mouseout_action1 = "document.getElementById('overlayImage').style.opacity = 0;"

            combined_mouseover = f"javascript: {mouseover_action1};"
            combined_mouseout = f"javascript: {mouseout_action1};"

            # Wrap the description into MAX_LENGTH-character lines.
            # NOTE(review): when len(description) is an exact multiple of
            # MAX_LENGTH this emits one extra empty <text> line.
            num_lines = len(descriptions[part]) // MAX_LENGTH + 1
            for line in range(num_lines):
                desc_line = descriptions[part][line*MAX_LENGTH:(line+1)*MAX_LENGTH]
                y_offset += fontsize
                svg_parts.append(f"""
                    <text x="0" y="{y_offset}" font-size="{fontsize}"
                        onmouseover="{combined_mouseover}"
                        onmouseout="{combined_mouseout}">
                        {desc_line}
                    </text>
                """)

            # Score bar in the part's color, directly under the text.
            svg_parts.append(f"""
                <rect x="0" y="{y_offset +3}" width="{bar_length}" height="{fontsize*0.7}" fill="{PART_COLORS[part]}"
                    onmouseover="{combined_mouseover}"
                    onmouseout="{combined_mouseout}">
                </rect>
            """)

            # Numeric score right-aligned at the end of the row.
            svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{PART_COLORS[part]}">{part_score:.2f}</text>')

            y_offset += fontsize + 3
    svg_parts.extend(("</svg>", "</div>"))

    html = "".join(svg_parts)

    return html
|
|
|
|
|
|
def generate_sachit_explanations(result_dict:dict):
    """
    Build an HTML/SVG panel for Sachit-style (whole-description) explanations.

    result_dict needs two keys:
        descriptions: list of description strings
        scores: list of matching scores, same order as descriptions

    Descriptions are rendered top-down by descending score, each followed by
    a bar whose length is proportional to the (clamped non-negative) score.
    Returns the panel as a single HTML string.
    """
    descriptions = result_dict['descriptions']
    scores = result_dict['scores']
    MAX_LENGTH = 50   # characters per rendered text line
    exp_length = 400  # panel width in px; also the bar length for score 1.0
    fontsize = 15

    # Pair scores with their descriptions and sort best-first.
    descriptions = zip(scores, descriptions)
    descriptions = sorted(descriptions, key=lambda x: x[0], reverse=True)

    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]

    y_offset = 0
    for score, desc in descriptions:
        # Clamp negative scores so bars never extend backwards.
        part_score = max(score, 0)
        bar_length = part_score * exp_length

        # Wrap the description into MAX_LENGTH-character lines.
        num_lines = len(desc) // MAX_LENGTH + 1
        for line in range(num_lines):
            desc_line = desc[line*MAX_LENGTH:(line+1)*MAX_LENGTH]
            y_offset += fontsize
            svg_parts.append(f"""
                <text x="0" y="{y_offset}" font-size="{fontsize}" fill="black">
                    {desc_line}
                </text>
            """)

        # Score bar directly under the text.
        svg_parts.append(f"""
            <rect x="0" y="{y_offset+3}" width="{bar_length}" height="{fontsize*0.7}" fill="{SACHIT_COLOR}">
            </rect>
        """)

        # BUG FIX: font-size was the literal string "fontsize" (missing the
        # f-string braces), so browsers fell back to their default font size.
        svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{SACHIT_COLOR}">{part_score:.2f}</text>')

        y_offset += fontsize + 3

    svg_parts.extend(("</svg>", "</div>"))

    html = "".join(svg_parts)

    return html
|
|
| |
# Transparent placeholder shown before any part description is hovered.
BLANK_OVERLAY = img_to_base64(create_blank_image())
# Hex color per part, aligned index-wise with ORDERED_PARTS.
PART_COLORS = {part: rgb_to_hex(COLORS[i]) for i, part in enumerate(ORDERED_PARTS)}
# NOTE(review): blank_image appears unused in this file — confirm before removing.
blank_image = np.array(Image.open('data/images/final.png').convert('RGB'))
# Cache of {image file name: {part: base64 crop}} for every gallery image.
PART_IMAGES_DICT = {file_name: load_part_images(file_name) for file_name in IMAGE_FILE_LIST}
|
|
| |
def update_selected_image(event: gr.SelectData):
    """Gallery select handler: show the chosen image, run the default XCLIP
    prediction on it, and pre-fill the editable description textbox.

    Returns (ground-truth markdown, image HTML, prediction markdown,
    explanation HTML, current_image state, textbox update).
    """
    image_height = 400
    # Gallery index of the clicked thumbnail; maps 1:1 to IMAGE_FILE_LIST.
    index = event.index

    image_name = IMAGE_FILE_LIST[index]
    current_image.state = image_name
    org_image = Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)).convert('RGB')
    # Two stacked <img> tags: the bird photo plus an initially transparent
    # overlay that the explanation SVG swaps part crops into on hover.
    img_base64 = f"""
    <div style="position: relative; height: {image_height}px; display: inline-block;">
        <img id="birdImage" src="data:image/jpeg;base64,{img_to_base64(org_image)}" style="height: {image_height}px; width: auto;">
        <img id="overlayImage" src="data:image/jpeg;base64,{BLANK_OVERLAY}" style="position:absolute; top:0; left:0; width:auto; height: {image_height}px; opacity: 0;">
    </div>
    """
    # NOTE(review): XCLIP_RESULTS is not defined anywhere in this file as
    # shown — presumably loaded from a results JSON elsewhere; confirm.
    gt_label = XCLIP_RESULTS[image_name]['ground_truth']
    gt_class.state = gt_label

    # Run the model with the unmodified class descriptions.
    out_dict = xclip_pred(new_desc=None, new_part_mask=None, new_class=None, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
    xclip_label = out_dict['pred_class']
    clip_pred_scores = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    # Descriptions come back in ORG_PART_ORDER; re-key them by part name.
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
    xclip_exp = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask=dict(zip(ORDERED_PARTS, [1]*12)))

    # Green when the prediction matches ground truth, red otherwise.
    xclip_color = "green" if xclip_label.strip() == gt_label.strip() else "red"
    xclip_pred_markdown = f"""
    ### <span style='color:{xclip_color}'>XCLIP: {xclip_label} {clip_pred_scores:.4f}</span>
    """

    gt_label = f"""
    ## {gt_label}
    """
    current_predicted_class.state = xclip_label

    # Pre-fill the textbox with the predicted class's default descriptions,
    # re-ordered from JSON order to the top-down display order.
    custom_class_name = "class name: custom"
    descs = XCLIP_DESC[xclip_label]
    descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
    descs = {k: descs[k] for k in ORDERED_PARTS}
    custom_text = [custom_class_name] + list(descs.values())
    descriptions = ";\n".join(custom_text)
    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)

    return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox
|
|
def on_edit_button_click_xclip():
    """Reset handler: hide the custom explanation panel and repopulate the
    textbox with the default descriptions of the current predicted class."""
    hidden_exp = gr.HTML.update(visible=False)

    # Re-order the stored descriptions from JSON order to display order.
    part_descs = XCLIP_DESC[current_predicted_class.state]
    by_part = dict(zip(ORG_PART_ORDER, part_descs))
    ordered = [by_part[part] for part in ORDERED_PARTS]
    text_value = ";\n".join(["class name: custom"] + ordered)
    textbox = gr.Textbox.update(value=text_value, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)

    return textbox, hidden_exp
|
|
def convert_input_text_to_xclip_format(textbox_input: str):
    """Parse the textbox contents into XCLIP prediction inputs.

    Expected format: first entry "class name: {name}", then one
    "{part name}: {description}" entry per line, entries separated by ";\n".

    Returns:
        descriptions_dict: {part: full description line} for every part in
            ORDERED_PARTS ("" when the user removed the part).
        part_mask: {part: 1 if the part was provided, else 0}.
        new_class_name: the user-chosen class name, stripped.
    """
    descriptions_list = textbox_input.split(";\n")

    # First entry names the custom class. BUG FIX: split only on the first
    # ':' so a name containing ':' is kept intact instead of truncated.
    class_name_line = descriptions_list[0]
    new_class_name = class_name_line.split(":", 1)[1].strip()

    descriptions_list = descriptions_list[1:]

    descriptions_dict = {}
    for desc in descriptions_list:
        if desc.strip() == "":
            continue
        # BUG FIX: split on the first ':' only — a description containing a
        # second ':' used to raise ValueError on tuple unpacking.
        part_name, _ = desc.split(":", 1)
        descriptions_dict[part_name.strip()] = desc

    # Mask out parts the user deleted so the model ignores them.
    part_mask = {}
    for part in ORDERED_PARTS:
        if part not in descriptions_dict:
            descriptions_dict[part] = ""
            part_mask[part] = 0
        else:
            part_mask[part] = 1
    return descriptions_dict, part_mask, new_class_name
|
|
def on_predict_button_click_xclip(textbox_input: str):
    """Predict handler: re-run XCLIP with the user-edited descriptions and
    show both the stock prediction and the custom-description prediction.

    Returns (textbox update, XCLIP markdown, XCLIP explanation HTML,
    custom markdown, custom explanation update).
    """
    descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)

    # One forward pass returns both the stock ('pred_*') and the modified
    # ('modified_*') predictions.
    out_dict = xclip_pred(new_desc=descriptions_dict, new_part_mask=part_mask, new_class=new_class_name, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
    xclip_label = out_dict['pred_class']
    xclip_pred_score = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    custom_label = out_dict['modified_class']
    custom_pred_score = out_dict['modified_score']
    custom_part_scores = out_dict['modified_desc_scores']

    # Build both explanation panels (descriptions arrive in ORG_PART_ORDER).
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
    xclip_explanation = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask)
    modified_result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["modified_descriptions"])), 'pred_scores': custom_part_scores, 'file_name': current_image.state}
    modified_explanation = generate_xclip_explanations(modified_result_dict, VISIBILITY_DICT[current_image.state], part_mask)

    # Green when the prediction matches ground truth, red otherwise.
    xclip_color = "green" if xclip_label.strip() == gt_class.state.strip() else "red"
    xclip_pred_markdown = f"""
    ### <span style='color:{xclip_color}'>XCLIP: {xclip_label} {xclip_pred_score:.4f}</span>
    """
    custom_color = "green" if custom_label.strip() == gt_class.state.strip() else "red"
    custom_pred_markdown = f"""
    ### <span style='color:{custom_color}'>XCLIP: {custom_label} {custom_pred_score:.4f}</span>
    """
    # Hide the editor until the user asks to edit again.
    textbox = gr.Textbox.update(visible=False)

    # BUG FIX: was gr.HTML().update(...) — instantiating a throwaway
    # component; use the class-level update as done elsewhere in this file.
    modified_exp = gr.HTML.update(value=modified_explanation, visible=True)
    return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp
|
|
|
|
# CSS injected into the Blocks page: centers the 400x400 image container
# and makes the overlaid canvas fill it exactly.
custom_css = """
html, body {
    margin: 0;
    padding: 0;
}

#container {
    position: relative;
    width: 400px;
    height: 400px;
    border: 1px solid #000;
    margin: 0 auto; /* This will center the container horizontally */
}

#canvas {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    object-fit: cover;
}

"""
|
|
| |
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
    # Session state: selected image file name, its latest predicted class,
    # and its ground-truth class (set by the gallery-select handler).
    current_image = gr.State("")
    current_predicted_class = gr.State("")
    gt_class = gr.State("")

    with gr.Column():
        title_text = gr.Markdown("# PEEB - demo")
        gr.Markdown(
            "- In this demo, you can edit the descriptions of a class and see how to model react to it."
        )

    with gr.Column():
        gr.Markdown("## Select an image to start!")
        image_gallery = gr.Gallery(value=IMAGE_GALLERY, label=None, preview=False, allow_preview=False, columns=10, height=250)
        # BUG FIX: corrected user-facing typos ("descritions", "remianing").
        gr.Markdown("### Custom descriptions: \n The first row should be **class name: {some name};**, where you can name your descriptions. \n For the remaining descriptions, please use **;** to separate the descriptions for each part, and use the format **{part name}: {descriptions}**. \n Note that you can delete a part completely, in such cases, all descriptions will remove the corresponding part.")

    with gr.Row():
        # Left column: selected image with its ground-truth class name.
        with gr.Column():
            image_label = gr.Markdown("### Class Name")
            org_image = gr.HTML()

        # Middle column: stock XCLIP prediction and explanation.
        with gr.Column():
            with gr.Row():
                xclip_predict_button = gr.Button(value="Predict")
            xclip_pred_label = gr.Markdown("### XCLIP:")
            xclip_explanation = gr.HTML()

        # Right column: editable descriptions and the custom prediction.
        with gr.Column():
            xclip_edit_button = gr.Button(value="Reset Descriptions")
            # BUG FIX: corrected user-facing typo ("Descritpions").
            custom_pred_label = gr.Markdown(
                "### Custom Descriptions:"
            )
            xclip_textbox = gr.Textbox(lines=12, placeholder="Edit the descriptions here", visible=False)
            custom_explanation = gr.HTML()

    gr.HTML("<br>")

    # Wire up the interactions.
    image_gallery.select(update_selected_image, inputs=None, outputs=[image_label, org_image, xclip_pred_label, xclip_explanation, current_image, xclip_textbox])
    xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
    xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])

demo.launch(server_port=5000, share=True)