| import gradio as gr |
| import os |
| import tempfile |
| import cv2 |
| from test import predict_one |
| from plot import ( |
| autocrop, get_json_corners, extract_points_from_xml, |
| draw_feature_matching, stack_images_side_by_side |
| ) |
|
|
| |
# Checkpoint file handed to predict_one. Defaults to "best_model.pth"
# (a relative path, so it is resolved against the current working directory);
# set the MODEL_CKPT environment variable to point at another checkpoint
# without editing this file.
MODEL_CKPT = os.environ.get("MODEL_CKPT", "best_model.pth")
|
|
|
|
| |
| |
| |
# Drawing colours. Every image is converted from OpenCV's BGR to RGB on load,
# so these tuples are (R, G, B): green = ground truth, blue = prediction,
# red = divider between the two panels.
_GT_COLOR = (0, 255, 0)
_PRED_COLOR = (0, 0, 255)
_DIVIDER_COLOR = (255, 0, 0)


def _load_rgb(path, label):
    """Load the image at *path* as an autocropped RGB array.

    Raises ``gr.Error`` (shown to the user in the UI) when no file was supplied
    or OpenCV cannot decode it. ``cv2.imread`` returns ``None`` instead of
    raising, which would otherwise surface as an obscure ``cvtColor`` failure.
    """
    if not path:
        raise gr.Error(f"Please upload the {label}.")
    bgr = cv2.imread(path)
    if bgr is None:
        raise gr.Error(f"Could not read the {label}: {os.path.basename(path)}")
    return autocrop(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))


def _annotate_comparison(stacked):
    """Draw the centre divider and the two panel captions onto *stacked* in place."""
    h, w = stacked.shape[:2]
    center_x = w // 2
    cv2.line(stacked, (center_x, 0), (center_x, h), _DIVIDER_COLOR, 4)
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(stacked, "Ground Truth", (50, 50), font, 2, _GT_COLOR, 3, cv2.LINE_AA)
    cv2.putText(
        stacked, "Our Result", (center_x + 50, 50), font, 2, _PRED_COLOR, 3, cv2.LINE_AA
    )


def run_pipeline(flat_img, pers_img, mockup_json, xml_gt):
    """Predict keypoints for one sample and render a GT-vs-prediction comparison.

    Args:
        flat_img: Path to the flat image. Its corners come from the mockup JSON.
        pers_img: Path to the perspective image. It is passed to the model
            together with ``mockup_json``.
        mockup_json: Path to the mockup JSON file.
        xml_gt: Path to the ground-truth XML annotation.

    Returns:
        ``(result_path, xml_pred_path)``: the side-by-side PNG (ground truth on
        the left, prediction on the right) and the XML written by
        ``predict_one``. Both files live in a fresh temporary directory.

    Raises:
        gr.Error: if an input is missing or unreadable, or an output could not
            be produced. Gradio then shows a readable message instead of a
            stack trace.
    """
    # Validate every input before running the (expensive) model.
    for value, label in ((mockup_json, "Mockup JSON"), (xml_gt, "Ground Truth XML")):
        if not value:
            raise gr.Error(f"Please upload the {label}.")
    img_json = _load_rgb(flat_img, "Flat Image")
    img_xml = _load_rgb(pers_img, "Perspective Image")

    # Not cleaned up here: the returned paths must still exist when Gradio
    # reads them for the output components.
    tmpdir = tempfile.mkdtemp()
    xml_pred_path = os.path.join(tmpdir, "pred.xml")
    result_path = os.path.join(tmpdir, "result.png")

    predict_one(mockup_json, pers_img, MODEL_CKPT, out_path=xml_pred_path)
    if not os.path.isfile(xml_pred_path):
        raise gr.Error("The model did not produce a prediction XML.")

    json_pts = get_json_corners(mockup_json)
    gt_pts = extract_points_from_xml(xml_gt)
    pred_pts = extract_points_from_xml(xml_pred_path)

    # Fresh copies for each drawing so the two panels never share pixels,
    # in case draw_feature_matching draws onto its inputs.
    match_json_gt = draw_feature_matching(
        img_json.copy(), json_pts, img_xml.copy(), gt_pts, _GT_COLOR, draw_boxes=True
    )
    match_json_pred = draw_feature_matching(
        img_json.copy(), json_pts, img_xml.copy(), pred_pts, _PRED_COLOR, draw_boxes=True
    )

    stacked = stack_images_side_by_side(match_json_gt, match_json_pred)
    _annotate_comparison(stacked)

    # cv2.imwrite expects BGR and signals failure by returning False.
    if not cv2.imwrite(result_path, cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR)):
        raise gr.Error("Failed to write the comparison image.")

    return result_path, xml_pred_path
|
|
|
|
| |
| |
| |
# Gradio UI: two images plus two annotation files go in; a comparison image
# and the predicted XML come out.
with gr.Blocks() as demo:
    gr.Markdown("## Mesh Key Point Transformer Demo")

    # Input images, handed to run_pipeline as file paths.
    with gr.Row():
        flat_in = gr.Image(
            type="filepath", label="Flat Image", width=300, height=300
        )
        pers_in = gr.Image(
            type="filepath", label="Perspective Image", width=300, height=300
        )

    # Annotation inputs: the mockup JSON and the ground-truth XML.
    with gr.Row():
        mockup_json_in = gr.File(type="filepath", label="Mockup JSON")
        xml_gt_in = gr.File(type="filepath", label="Ground Truth XML")

    run_btn = gr.Button("Run Prediction + Visualization")

    # Outputs: the side-by-side comparison and the downloadable prediction.
    with gr.Row():
        out_img = gr.Image(
            type="filepath", label="Comparison Output", width=800, height=600
        )
        out_xml = gr.File(type="filepath", label="Predicted XML")

    # The input list order must match run_pipeline's parameters. The output
    # list order must match the tuple it returns.
    run_btn.click(
        run_pipeline,
        [flat_in, pers_in, mockup_json_in, xml_gt_in],
        [out_img, out_xml],
    )


if __name__ == "__main__":
    # share=True additionally requests a public Gradio share link.
    demo.launch(share=True)
|
|