| import onnxruntime as ort |
| import numpy |
| import gradio as gr |
| from PIL import Image |
|
|
# Load the pre-trained doodle-embedding model once at module import; every
# call to compare() below reuses this single inference session.
# NOTE(review): the path is relative — assumes the script is launched from
# the directory containing the .onnx file; confirm in deployment.
ort_sess = ort.InferenceSession('tiny_doodle_embedding.onnx')
|
|
| |
|
|
def get_bounds(img):
    """Return the inclusive bounding box of the bright pixels in *img*.

    A pixel counts as foreground when it is strictly brighter than the
    midpoint between the image's darkest and brightest values.

    Args:
        img: 2-D numpy array of gray levels.

    Returns:
        Tuple ``(top, bottom, left, right)`` with inclusive coordinates.
        When no pixel exceeds the threshold (e.g. a constant image) the
        degenerate sentinel ``(img.shape[0], 0, img.shape[1], 0)`` is
        returned, matching the original loop-based behavior.
    """
    # Vectorized replacement for the original per-pixel Python loop:
    # numpy.nonzero finds all foreground coordinates in one C-level pass.
    threshold = 0.5 * (numpy.min(img) + numpy.max(img))
    ys, xs = numpy.nonzero(img > threshold)
    if ys.size == 0:
        # No foreground at all: preserve the historical sentinel values.
        return (img.shape[0], 0, img.shape[1], 0)
    return (int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max()))
|
|
def resize_maxpool(img, out_width: int, out_height: int):
    """Downscale *img* to ``(out_height, out_width)`` by max-pooling.

    Each output cell is the maximum over a block of
    ``(img.shape[0] // out_height) x (img.shape[1] // out_width)`` input
    pixels; any remainder rows/columns on the bottom/right edge are
    dropped, exactly as in the original per-pixel loop.

    Args:
        img: 2-D numpy array.
        out_width: Output width in pixels.
        out_height: Output height in pixels.

    Returns:
        2-D numpy array of shape ``(out_height, out_width)`` with the
        same dtype as *img*.

    Raises:
        ValueError: If *img* is smaller than the requested output in
            either dimension (the pooling block would be empty; the
            original code crashed with an opaque numpy reduction error).
    """
    scale_y = img.shape[0] // out_height
    scale_x = img.shape[1] // out_width
    if scale_y < 1 or scale_x < 1:
        raise ValueError(
            f"input {img.shape} is smaller than output ({out_height}, {out_width})"
        )
    # Trim the ragged edge, then fold each pooling block into its own pair
    # of axes and reduce — one vectorized pass instead of a Python loop.
    trimmed = img[:out_height * scale_y, :out_width * scale_x]
    blocks = trimmed.reshape(out_height, scale_y, out_width, scale_x)
    return blocks.max(axis=(1, 3))
|
|
def process_input(input_msg):
    """Convert a gradio Sketchpad payload into a model-ready 32x32 tensor.

    Steps: binarize (strokes darker than background become 1.0), crop to
    the stroke's bounding box, max-pool down to 32x32, and add a leading
    batch axis.

    Args:
        input_msg: Sketchpad dict; the ``"composite"`` entry is the drawn
            grayscale image as a 2-D numpy array.

    Returns:
        Float array of shape ``(1, 32, 32)`` with values in {0.0, 1.0}.
    """
    img = input_msg["composite"]

    # Invert while binarizing: sketch strokes are darker than the canvas,
    # so "< mean" marks the drawn pixels as foreground (1.0).
    img_mean = 0.5 * (numpy.max(img) + numpy.min(img))
    img = 1.0 * (img < img_mean)

    top, bottom, left, right = get_bounds(img)
    # BUG FIX: get_bounds returns *inclusive* coordinates, but they were
    # used as exclusive slice ends, silently dropping the last row and
    # column of the doodle. Add 1 so the full bounding box is kept.
    img = img[top:bottom + 1, left:right + 1]
    img = resize_maxpool(img, 32, 32)

    # Leading batch dimension expected by the ONNX model input.
    img = numpy.expand_dims(img, axis=0)
    return img
| |
|
|
def compare(input_img_a, input_img_b):
    """Embed two sketches and report their cosine similarity.

    Args:
        input_img_a: Sketchpad payload for the first doodle.
        input_img_b: Sketchpad payload for the second doodle.

    Returns:
        Tuple of (side-by-side PIL preview image, similarity score,
        debug text containing both normalized embeddings).
    """
    text_out = ""

    img_a = process_input(input_img_a)
    img_b = process_input(input_img_b)

    a_embedding = ort_sess.run(None, {'input': img_a.astype(numpy.float32)})[0]
    b_embedding = ort_sess.run(None, {'input': img_b.astype(numpy.float32)})[0]

    # BUG FIX: the magnitudes were hard-coded placeholders (1.0), so the
    # dot product below was an unbounded score rather than the cosine
    # similarity the variable names intend. Use the real L2 norms, with a
    # zero guard so an all-zero embedding cannot divide by zero.
    a_mag = numpy.linalg.norm(a_embedding) or 1.0
    b_mag = numpy.linalg.norm(b_embedding) or 1.0
    a_embedding /= a_mag
    b_embedding /= b_mag

    text_out += f"img_a_embedding: {a_embedding}\n"
    text_out += f"img_b_embedding: {b_embedding}\n"

    sim = numpy.dot(a_embedding, b_embedding.T)
    print(sim)
    print(text_out)

    # Preview: the two 32x32 binary crops side by side, scaled to uint8.
    preview = numpy.clip(
        numpy.hstack([img_a[0], img_b[0]]) * 254, 0, 255
    ).astype(numpy.uint8)
    return Image.fromarray(preview), sim[0][0], text_out
| |
|
|
|
|
# Two-sketchpad UI: the user draws a doodle in each pad; compare() returns
# a side-by-side preview image, the similarity score, and debug text.
demo = gr.Interface(
    fn=compare,
    inputs=[
        gr.Sketchpad(image_mode='L', type='numpy'),
        gr.Sketchpad(image_mode='L', type='numpy'),
    ],
    outputs=["image", "number", "text"],
)

# share=True publishes a temporary public gradio.live URL in addition to
# the local server.
demo.launch(share=True)
|
|
|
|