| import torch |
| import cv2 |
| import os |
| import gradio as gr |
| import numpy as np |
| import random |
| from pathlib import Path |
| import json |
| import spaces |
|
|
|
|
| |
# Heading text used for the page title and the top-level Markdown header.
_TITLE = 'Gradio Demo of ScaleLSD for Structured Representation of Images'
# Inclusive upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = 1000


# Environment bootstrap, executed at import time and strictly in this order:
# create the checkpoint folder, fetch both checkpoints, then install the
# local package in editable mode. Return codes are ignored.
_SETUP_COMMANDS = (
    'mkdir -p models',
    'wget https://huggingface.co/cherubicxn/scalelsd/resolve/main/scalelsd-vitbase-v2-train-sa1b.pt -O models/scalelsd-vitbase-v2-train-sa1b.pt',
    'wget https://huggingface.co/cherubicxn/scalelsd/resolve/main/scalelsd-vitbase-v1-train-sa1b.pt -O models/scalelsd-vitbase-v1-train-sa1b.pt',
    'pip install -e .',
)
for _setup_command in _SETUP_COMMANDS:
    os.system(_setup_command)
|
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for this run.

    Args:
        seed: Seed supplied by the caller.
        randomize_seed: When truthy, ``seed`` is discarded and a fresh value
            is drawn uniformly from ``[0, MAX_SEED]`` (both ends inclusive).

    Returns:
        The caller's seed unchanged, or the randomly drawn replacement.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
def stop_run():
    """Return the button updates applied when a run is stopped.

    Returns:
        A 2-tuple of Gradio updates: the first makes a button visible as a
        primary "Run" button, the second hides a button. The caller maps
        them onto components through its ``outputs`` list.
    """
    show_run_button = gr.update(value="Run", variant="primary", visible=True)
    hide_stop_button = gr.update(visible=False)
    return show_run_button, hide_stop_button
|
|
| |
@spaces.GPU
def process_image(
    input_image,
    model_name='scalelsd-vitbase-v2-train-sa1b.pt',
    save_name='temp_output',
    threshold=10,
    junction_threshold_hm=0.008,
    num_junctions_inference=512,
    width=512,
    height=512,
    line_width=2,
    juncs_size=4,
    whitebg=0.0,
    draw_junctions_only=False,
    use_lsd=False,
    use_nms=False,
    edge_color='orange',
    vertex_color='Cyan',
    output_format='png',
    seed=0,
    randomize_seed=False
):
    """Core processing function for image inference.

    Loads the selected ScaleLSD checkpoint, runs it on ``input_image``, renders
    the detections on top of the input and exports the wireframe as JSON.
    All files are written into the ``temp_output`` folder.

    Args:
        input_image: RGB ``numpy`` array, or a path readable by ``cv2.imread``.
        model_name: Checkpoint file name inside ``models/``.
        save_name: Base name (without extension) of the written files; any
            directory components are discarded.
        threshold: Assigned to the painter's ``confidence_threshold``.
        junction_threshold_hm: Assigned to ``model.junction_threshold_hm``.
        num_junctions_inference: Assigned to ``model.num_junctions_inference``.
        width: Width the grayscale image is resized to before inference.
        height: Height the grayscale image is resized to before inference.
        line_width: Assigned to the painter's ``line_width``.
        juncs_size: Assigned to the painter's ``marker_size``.
        whitebg: When greater than 0, assigned to ``show.Canvas.white_overlay``.
        draw_junctions_only: Draw only junctions instead of the full wireframe.
        use_lsd: Currently ignored; see the note at the top of the body.
        use_nms: Forwarded to the model through the ``meta`` dict.
        edge_color: Edge color used when drawing the wireframe.
        vertex_color: Vertex color used when drawing the wireframe.
        output_format: Extension of the downloadable figure; a PNG is always
            rendered as well because it provides the preview image.
        seed: Random seed passed to ``fix_seeds``.
        randomize_seed: Replace ``seed`` with a random value before seeding.

    Returns:
        Tuple ``(preview, json_path, figure_path)``: the rendered PNG as an
        RGB array, the path of the JSON export and the path of the rendered
        figure in ``output_format``.

    Raises:
        gr.Error: If no input image is given or the image path is unreadable.
    """
    # NOTE(review): the caller-supplied value is unconditionally overridden, so
    # the "Use LSD-Rectifier" option has no effect. Kept as-is; confirm intent.
    use_lsd = False

    # Imported inside the function as in the original code; presumably because
    # the package is only installed by the module-level `pip install -e .`.
    from scalelsd.ssl.models.detector import ScaleLSD  # noqa: F401  (kept from original)
    from scalelsd.base import show, WireframeGraph
    from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model

    if input_image is None:
        raise gr.Error("Please provide an input image before running.")

    seed = int(randomize_seed_fn(seed, randomize_seed))
    fix_seeds(seed)

    # Use only the final path component so a crafted name cannot escape models/.
    ckpt = "models/" + os.path.basename(str(model_name))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The checkpoint is loaded anew on every call.
    model = load_scalelsd_model(ckpt, device)

    model.junction_threshold_hm = junction_threshold_hm
    # UI number fields may deliver floats; a junction count must be integral.
    model.num_junctions_inference = int(num_junctions_inference)

    if isinstance(input_image, np.ndarray):
        image = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
    else:
        image = cv2.imread(input_image, 0)
        if image is None:
            raise gr.Error(f"Could not read input image: {input_image}")

    ori_shape = image.shape[:2]
    # cv2.resize requires an integer (width, height) tuple.
    image_resized = cv2.resize(image, (int(width), int(height)))
    image_tensor = torch.from_numpy(image_resized).float() / 255.0
    # Add batch and channel axes and move to the same device as the model
    # (the original hard-coded 'cuda', which fails on CPU-only hosts).
    image_tensor = image_tensor[None, None].to(device)

    # The original image size is passed alongside the resized tensor.
    meta = {
        'width': ori_shape[1],
        'height': ori_shape[0],
        'filename': '',
        'use_lsd': use_lsd,
        'use_nms': use_nms,
    }

    with torch.no_grad():
        outputs, _ = model(image_tensor, meta)
        outputs = outputs[0]

    painter = show.painters.HAWPainter()
    painter.confidence_threshold = threshold
    painter.line_width = line_width
    painter.marker_size = juncs_size
    # NOTE(review): this is a class-level attribute and is only written when
    # whitebg > 0, so a previously set overlay is not reset by whitebg == 0.
    if whitebg > 0.0:
        show.Canvas.white_overlay = whitebg

    def _render(fig_path):
        """Draw the detections over the input image and save them to fig_path."""
        with show.image_canvas(input_image, fig_file=fig_path) as ax:
            if draw_junctions_only:
                painter.draw_junctions(ax, outputs)
            else:
                painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)

    temp_folder = "temp_output"
    os.makedirs(temp_folder, exist_ok=True)
    # Drop directory components from the user-supplied name so that output
    # files always stay inside temp_folder.
    safe_name = os.path.basename(str(save_name).strip()) or 'temp_output'

    # A PNG is always rendered: it is read back as the preview image.
    fig_file = f"{temp_folder}/{safe_name}.png"
    _render(fig_file)
    result_image = cv2.imread(fig_file)

    # Render a second time when a different download format was requested.
    if output_format != 'png':
        fig_file = f"{temp_folder}/{safe_name}.{output_format}"
        _render(fig_file)

    json_file = f"{temp_folder}/{safe_name}.json"
    indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
    wireframe = WireframeGraph(
        outputs['juncs_pred'],
        outputs['juncs_score'],
        indices,
        outputs['lines_score'],
        outputs['width'],
        outputs['height'],
    )
    with open(json_file, 'w') as f:
        json.dump(wireframe.jsonize(), f)

    # cv2.imread yields BGR; reverse the channel axis to return RGB.
    return result_image[:, :, ::-1], json_file, fig_file
| |
def run_demo():
    """Create the Gradio demo interface.

    Builds the Blocks layout (input/output images, run/stop buttons, download
    slots, advanced settings and example gallery) and wires the Run button to
    ``process_image`` and the Stop button to ``stop_run``.

    Returns:
        The constructed ``gr.Blocks`` application; it is not launched here.
    """
    css = """
    #col-container {
        margin: 0 auto;
        max-width: 800px;
    }
    """

    # Collect example images deterministically: sorted, regular files only,
    # hidden entries skipped. A missing or empty folder yields no examples
    # instead of raising.
    figs_root = "assets/figs"
    example_images = []
    if os.path.isdir(figs_root):
        example_images = [
            os.path.join(figs_root, iname)
            for iname in sorted(os.listdir(figs_root))
            if not iname.startswith('.')
            and os.path.isfile(os.path.join(figs_root, iname))
        ]
    default_image = example_images[0] if example_images else None

    # Checkpoints available for selection; fall back to the first one found
    # when the preferred default is not present.
    model_choices = sorted(ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt'))
    default_model = 'scalelsd-vitbase-v2-train-sa1b.pt'
    if default_model not in model_choices:
        default_model = model_choices[0] if model_choices else None

    with gr.Blocks(css=css, title=_TITLE) as demo:
        with gr.Column(elem_id="col-container"):
            gr.Markdown(f'# {_TITLE}')
            gr.Markdown("Detect wireframe structures in images using ScaleLSD model")

            # Not referenced by any event below; kept from the original layout.
            pid = gr.State()

            with gr.Row():
                input_image = gr.Image(default_image, label="Input Image", type="numpy")
                output_image = gr.Image(label="Detection Result")

            with gr.Row():
                run_btn = gr.Button(value="Run", variant="primary")
                # NOTE(review): nothing in this function makes the Stop button
                # visible, so it stays hidden for the whole session.
                stop_btn = gr.Button(value="Stop", variant="stop", visible=False)

            with gr.Row():
                json_file = gr.File(label="Download JSON Output", type="filepath")
                image_file = gr.File(label="Download Image Output", type="filepath")

            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    model_name = gr.Dropdown(
                        model_choices,
                        value=default_model,
                        label="Model Selection"
                    )

                with gr.Row():
                    save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files")

                with gr.Row():
                    with gr.Column():
                        threshold = gr.Number(10, label="Line Threshold")
                        junction_threshold_hm = gr.Number(0.008, label="Junction Threshold")
                        num_junctions_inference = gr.Number(1024, label="Max Number of Junctions")
                        width = gr.Number(512, label="Input Width")
                        height = gr.Number(512, label="Input Height")

                    with gr.Column():
                        draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only")
                        use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier")
                        use_nms = gr.Checkbox(True, label="Use NMS")
                        output_format = gr.Dropdown(
                            ['png', 'jpg', 'pdf'],
                            value='png',
                            label="Output Format"
                        )
                        whitebg = gr.Slider(0.0, 1.0, value=0.7, label="White Background Opacity")
                        line_width = gr.Number(2, label="Line Width")
                        juncs_size = gr.Number(8, label="Junctions Size")

                with gr.Row():
                    edge_color = gr.Dropdown(
                        ['orange', 'midnightblue', 'red', 'green'],
                        value='orange',
                        label="Edge Color"
                    )
                    vertex_color = gr.Dropdown(
                        ['Cyan', 'deeppink', 'yellow', 'purple'],
                        value='Cyan',
                        label="Vertex Color"
                    )

                with gr.Row():
                    randomize_seed = gr.Checkbox(False, label="Randomize Seed")
                    seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")

            # Only add the gallery when there is something to show.
            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=input_image,
                )

            # The input order must match the positional parameters of
            # process_image.
            run_event = run_btn.click(
                fn=process_image,
                inputs=[
                    input_image,
                    model_name,
                    save_name,
                    threshold,
                    junction_threshold_hm,
                    num_junctions_inference,
                    width,
                    height,
                    line_width,
                    juncs_size,
                    whitebg,
                    draw_junctions_only,
                    use_lsd,
                    use_nms,
                    edge_color,
                    vertex_color,
                    output_format,
                    seed,
                    randomize_seed
                ],
                outputs=[output_image, json_file, image_file],
            )

            # Cancels the pending/running inference event and resets the buttons.
            stop_btn.click(
                fn=stop_run,
                outputs=[run_btn, stop_btn],
                cancels=[run_event],
                queue=False,
            )

    return demo
|
|
# Build the interface and start serving it as soon as this script is executed.
demo = run_demo()
demo.launch()
|
|