| import matplotlib.pyplot as plt |
| import numpy as np |
| from PIL import Image, ImageEnhance, ImageDraw |
| import torch |
| import streamlit as st |
| from model.inference_cpu import inference_case |
|
|
# Default rectangle used as the starting drawing before the user edits it.
# The layout (a top-level "version" plus a list of "objects", each carrying
# geometry / stroke / transform properties) appears to follow the Fabric.js
# JSON serialization format, presumably consumed by a drawable-canvas widget;
# confirm against the caller. reflect_json_data_to_3D_box reads "top",
# "left", "height", "width", "scaleX" and "scaleY" from the first object of
# data shaped like this.
initial_rectangle = dict(
    version="4.4.0",
    objects=[
        dict(
            # Object identity and anchor point.
            type="rect",
            version="4.4.0",
            originX="left",
            originY="top",
            # Position and size.
            left=50,
            top=50,
            width=100,
            height=100,
            # Fill and stroke styling.
            fill="rgba(255, 165, 0, 0.3)",
            stroke="#2909F1",
            strokeWidth=3,
            strokeDashArray=None,
            strokeLineCap="butt",
            strokeDashOffset=0,
            strokeLineJoin="miter",
            strokeUniform=True,
            strokeMiterLimit=4,
            # Scale, rotation and mirroring.
            scaleX=1,
            scaleY=1,
            angle=0,
            flipX=False,
            flipY=False,
            # Rendering options.
            opacity=1,
            shadow=None,
            visible=True,
            backgroundColor="",
            fillRule="nonzero",
            paintFirst="fill",
            globalCompositeOperation="source-over",
            # Skew and corner radii.
            skewX=0,
            skewY=0,
            rx=0,
            ry=0,
        )
    ],
)
|
|
def run():
    """Gather the enabled prompts from session state and run inference.

    Reads the image pair from ``st.session_state.data_item``, converts both
    with ``.float()``, and builds each prompt only when its toggle is on (the
    point prompt additionally requires at least one entry in
    ``st.session_state.points``). It then clears ``inference_case``, writes the
    prompts and the image shape to the page, and stores the two values returned
    by ``inference_case`` in ``st.session_state.preds_3D`` and
    ``st.session_state.preds_3D_ori``.
    """
    state = st.session_state
    data_item = state.data_item
    image = data_item["image"].float()
    image_zoom_out = data_item["zoom_out_image"].float()

    text_prompt = state.text_prompt if state.use_text_prompt else None
    point_prompt = (
        reflect_points_into_model(state.points)
        if state.use_point_prompt and len(state.points) > 0
        else None
    )
    box_prompt = (
        reflect_box_into_model(state.rectangle_3Dbox)
        if state.use_box_prompt
        else None
    )

    # NOTE(review): inference_case is presumably a cached function whose
    # underscore-prefixed prompt arguments are excluded from the cache key;
    # clearing it would then force a fresh run for new prompts -- confirm.
    inference_case.clear()

    debug_fields = (
        ("text_prompt", text_prompt),
        ("box_prompt", box_prompt),
        ("point_prompt", point_prompt),
        ("image shape", image.shape),
    )
    for label, value in debug_fields:
        st.write(f"{label}: {value}")

    state.preds_3D, state.preds_3D_ori = inference_case(
        image,
        image_zoom_out,
        text_prompt=text_prompt,
        _point_prompt=point_prompt,
        _box_prompt=box_prompt,
    )
|
|
def reflect_box_into_model(box_3d, *, display_size=325.0, xy_size=256.0, z_size=32.0):
    """Rescale a 3-D box from display coordinates to model coordinates.

    Each coordinate is computed as ``int(value * target / display_size)``,
    i.e. scaled and then truncated toward zero.

    Args:
        box_3d: Six numbers ordered ``(z1, y1, x1, z2, y2, x2)``.
        display_size: Extent every input axis is divided by (default 325).
        xy_size: Target extent for the y and x axes (default 256);
            presumably the model's in-plane input size -- confirm.
        z_size: Target extent for the z axis (default 32); presumably the
            model's depth -- confirm.

    Returns:
        A 1-D ``torch`` tensor of six integers in the same
        ``(z1, y1, x1, z2, y2, x2)`` order.
    """
    z1, y1, x1, z2, y2, x2 = box_3d

    def _scale(value, target):
        # Same operation order as the original literals: value * target / display.
        return int(value * target / display_size)

    scaled = [
        _scale(z1, z_size), _scale(y1, xy_size), _scale(x1, xy_size),
        _scale(z2, z_size), _scale(y2, xy_size), _scale(x2, xy_size),
    ]
    return torch.tensor(np.array(scaled))
|
|
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first drawn object's extent into ``st.session_state.rectangle_3Dbox``.

    Only ``view == 'xy'`` is handled. For that view the object's ``top`` and
    ``left`` are written to box entries 1 and 2. Its bottom edge
    (``top + height * scaleY``) and right edge (``left + width * scaleX``) are
    written to entries 4 and 5. Entries 0 and 3 are not touched. Any other view
    leaves the box unchanged.
    """
    if view != 'xy':
        return
    rect = json_data['objects'][0]
    box = st.session_state.rectangle_3Dbox
    box[1] = rect['top']
    box[2] = rect['left']
    box[4] = rect['top'] + rect['height'] * rect['scaleY']
    box[5] = rect['left'] + rect['width'] * rect['scaleX']
    # NOTE(review): the source copy lost its indentation; this print is
    # assumed to belong to the 'xy' branch -- confirm.
    print(box)
|
|
def reflect_points_into_model(points, *, display_size=325.0, xy_size=256.0, z_size=32.0):
    """Rescale clicked points from display coordinates to model coordinates.

    Each coordinate is computed as ``int(value * target / display_size)``,
    i.e. scaled and then truncated toward zero.

    Args:
        points: Iterable of ``(z, y, x)`` triples.
        display_size: Extent every input axis is divided by (default 325).
        xy_size: Target extent for the y and x axes (default 256);
            presumably the model's in-plane input size -- confirm.
        z_size: Target extent for the z axis (default 32); presumably the
            model's depth -- confirm.

    Returns:
        ``(coords, labels)``. ``coords`` is a tensor of shape ``(N, 3)`` holding
        the rescaled ``(z, y, x)`` values. ``labels`` is a float64 tensor of
        ``N`` ones. An empty ``points`` yields a ``(0, 3)`` coords tensor and
        an empty labels tensor.
    """
    scaled = [
        [
            int(z * z_size / display_size),
            int(y * xy_size / display_size),
            int(x * xy_size / display_size),
        ]
        for z, y, x in points
    ]
    # reshape keeps the (N, 3) layout even with no points; np.array([]) alone
    # would be 1-D with shape (0,).
    points_prompt = np.array(scaled).reshape(-1, 3)
    points_label = np.ones(points_prompt.shape[0])
    print(points_prompt, points_label)
    return (torch.tensor(points_prompt), torch.tensor(points_label))
|
|
def show_points(points_ax, points_label, ax):
    """Scatter one point on ``ax`` at ``(points_ax[0], points_ax[1])``.

    The marker is a circle of size 200. It is red when ``points_label == 0``
    and blue otherwise.
    """
    x, y = points_ax[0], points_ax[1]
    if points_label == 0:
        color = 'red'
    else:
        color = 'blue'
    ax.scatter(x, y, c=color, marker='o', s=200)
|
|
def make_fig(image, preds, point_axs=None, current_idx=None, view=None):
    """Render an intensity array as a contrast-enhanced RGB image with overlays.

    Args:
        image: Array of intensities, presumably a 2-D slice normalised to
            [0, 1] -- confirm. Values are clipped to [0, 1] before conversion
            to 8 bit.
        preds: Optional 2-D array of the same size as ``image``. Pixels equal
            to 1 get a yellow overlay blended with weight
            ``st.session_state.transparency``.
        point_axs: Optional iterable of ``(z, y, x)`` points to mark with blue
            dots of radius 5.
        current_idx: Index of the displayed slice. A point is drawn only when
            its z (``'xy'`` view) or its y (``'xz'`` view) equals it.
        view: ``'xy'`` draws points at ``(x, y)``; ``'xz'`` draws them at
            ``(x, z)``. Any other value draws no points.

    Returns:
        A PIL image in RGB mode.
    """
    # Clip before the uint8 cast: values above 1 would otherwise wrap around
    # (e.g. 1.004 * 255 -> 256 -> 0), and negative floats have no defined
    # uint8 conversion.
    pixels = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    image = Image.fromarray(pixels).convert("RGB")
    image = ImageEnhance.Contrast(image).enhance(2.0)

    if preds is not None:
        # Yellow (full red + green, no blue) where preds == 1, black elsewhere.
        mask = np.where(preds == 1, 255, 0).astype(np.uint8)
        overlay = Image.fromarray(np.stack([mask, mask, np.zeros_like(mask)], axis=-1))
        # NOTE(review): Image.blend mixes the whole frame with the overlay, so
        # pixels outside the mask are also pulled toward black by
        # `transparency`; confirm that dimming is intended.
        image = Image.blend(image, overlay, alpha=st.session_state.transparency)

    if point_axs is not None:
        draw = ImageDraw.Draw(image)
        radius = 5
        for z, y, x in point_axs:
            if view == 'xy' and z == current_idx:
                cx, cy = x, y
            elif view == 'xz' and y == current_idx:
                cx, cy = x, z
            else:
                continue
            draw.ellipse((cx - radius, cy - radius, cx + radius, cy + radius), fill="blue")
    return image