| import gradio as gr |
| from PIL import Image |
| from vit_mosaic import make_vit_mosaic |
| import tempfile |
| import os |
| import svgwrite |
| import base64 |
| import requests |
| from io import BytesIO |
| from rembg import new_session, remove |
| import argparse |
|
|
| |
| |
| |
|
|
# Background-removal session, created once at import time and reused by
# generate() for every request; the prints report load progress on stdout.
_REMBG_MODEL_NAME = "u2net"

print("Loading rembg model...")
rembg_session = new_session(_REMBG_MODEL_NAME)
print("rembg model loaded")
|
|
|
|
| |
| |
| |
|
|
# Example images offered in the gallery, keyed by display name.
# Insertion order matters: the gallery is populated from EXAMPLES.values()
# and a gallery selection index is mapped back to an entry by position.
EXAMPLES = {
    "Dogs": {
        "url": "https://raw.githubusercontent.com/daidedou/daidedou.github.io/master/images/dogs.jpg",
        "credit": "Photo by Jackielsy \u2014 Pixabay (CC0)",
    },
    "Landscape": {
        "url": "https://raw.githubusercontent.com/daidedou/daidedou.github.io/master/images/landscape.jpg",
        "credit": "Photo by brenkee \u2014 Pixabay (CC0)",
    },
    "Illustration": {
        "url": "https://raw.githubusercontent.com/daidedou/daidedou.github.io/master/images/illustration.jpg",
        "credit": "Illustration by the_iop \u2014 Pixabay (CC0)",
    },
}
|
|
| |
| |
| |
|
|
def rgb_to_hex(r, g, b):
    """Format an RGB triple as an uppercase ``#RRGGBB`` color string.

    Channel values coming from UI sliders may arrive as floats (e.g. ``255.0``),
    which the ``X`` format spec rejects with ``ValueError``; values outside
    0-255 would also yield a malformed color. Each channel is therefore
    rounded to the nearest integer and clamped to 0-255 before formatting.

    Args:
        r: Red channel value (int or float).
        g: Green channel value (int or float).
        b: Blue channel value (int or float).

    Returns:
        A 7-character hex color such as ``"#00FFFF"``.
    """
    channels = (max(0, min(255, int(round(c)))) for c in (r, g, b))
    return "#" + "".join(f"{c:02X}" for c in channels)
|
|
|
|
def update_color_preview(r, g, b):
    """Return an HTML snippet: a full-width, 40px-tall rounded swatch
    filled with the color given by the ``r``, ``g``, ``b`` channels."""
    css_rules = [
        "width:100%;",
        "height:40px;",
        "border-radius:8px;",
        "border:1px solid #ccc;",
        f"background:{rgb_to_hex(r, g, b)};",
    ]
    # Leading/trailing empty items reproduce the surrounding newlines.
    return "\n".join(["", '<div style="', *css_rules, '"></div>', ""])
|
|
|
|
def toggle_clipping(enabled):
    """Return a Gradio update that makes a component interactive iff *enabled*.

    Hooked to the rounded-corners checkbox's change event, targeting the
    clipping checkbox.
    """
    interactivity = gr.update(interactive=enabled)
    return interactivity
|
|
|
|
def load_image_from_url(url):
    """Download *url* and decode it as an RGB PIL image.

    The request uses a 10-second timeout; a non-success HTTP status raises
    via ``raise_for_status``.
    """
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    raw = BytesIO(resp.content)
    decoded = Image.open(raw)
    return decoded.convert("RGB")
|
|
|
|
def on_gallery_select(evt: gr.SelectData):
    """Load the example clicked in the gallery along with its credit line.

    The gallery is filled from ``EXAMPLES`` in insertion order, so the
    selection index identifies the entry at that position.

    Returns:
        ``(image, markdown_credit)`` for the input image and credit widgets.
    """
    _name, entry = list(EXAMPLES.items())[evt.index]
    picked = load_image_from_url(entry["url"])
    return picked, f"**Image Credit:** {entry['credit']}"
|
|
|
|
def export_svg(image, path):
    """Save *image* to *path* as an SVG wrapping an embedded base64 PNG.

    The SVG canvas has the image's pixel dimensions and holds a single
    ``<image>`` element anchored at the origin.
    """
    w, h = image.size

    # Serialize the raster as PNG and inline it as a data URI.
    png_bytes = BytesIO()
    image.save(png_bytes, format="PNG")
    payload = base64.b64encode(png_bytes.getvalue()).decode()
    data_uri = f"data:image/png;base64,{payload}"

    drawing = svgwrite.Drawing(path, size=(w, h))
    drawing.add(drawing.image(href=data_uri, insert=(0, 0), size=(w, h)))
    drawing.save()
|
|
|
|
| |
| |
| |
|
|
def _has_transparency(image):
    """Return True if *image* has an alpha band containing any non-opaque pixel."""
    if "A" not in image.getbands():
        return False
    lowest_alpha, _ = image.getchannel("A").getextrema()
    return lowest_alpha < 255


def _flatten_on_white(image):
    """Alpha-composite *image* over an opaque white canvas (result is RGBA, fully opaque)."""
    foreground = image.convert("RGBA")
    canvas = Image.new("RGBA", foreground.size, (255, 255, 255, 255))
    # alpha_composite (unlike paste with a mask) yields an opaque result;
    # paste would also blend the alpha band, leaving semi-transparent edges.
    return Image.alpha_composite(canvas, foreground)


def generate(
    image,
    patches,
    max_size,
    spacing,
    border_thickness,
    border_r,
    border_g,
    border_b,
    use_padding,
    pad_r,
    pad_g,
    pad_b,
    supersample,
    scale_mode,
    dpi_value,
    background_mode,
    rounded,
    clipping,
    remove_bg,
):
    """Build a ViT patch mosaic from *image* and export it as PNG and SVG.

    Each argument is the current value of the matching UI control.

    Args:
        image: Input PIL image, or None when nothing is selected.
        patches: Target patch count; passed to ``make_vit_mosaic`` as the
            single entry of ``target_total_patches``.
        max_size: Passed as ``max_long_side``.
        spacing: Passed through unchanged.
        border_thickness: Passed through unchanged.
        border_r, border_g, border_b: Border color channels (0-255).
        use_padding: Request an explicit padding color.
        pad_r, pad_g, pad_b: Padding color channels (0-255).
        supersample: Passed through unchanged.
        scale_mode: Passed as ``output_scale_mode``.
        dpi_value: DPI metadata written into the saved PNG.
        background_mode: ``"White"`` flattens the mosaic onto opaque white;
            any other value keeps the mosaic as returned.
        rounded: Passed through unchanged.
        clipping: Passed as ``true_clipping``.
        remove_bg: Run rembg background removal on the input first.

    Returns:
        ``(mosaic, png_path, svg_path)``, or ``(None, None, None)`` when
        *image* is None.
    """
    if image is None:
        return None, None, None

    if remove_bg:
        image = remove(image, session=rembg_session)

    border_color = rgb_to_hex(border_r, border_g, border_b)

    # The input widget is configured with image_mode="RGBA", so checking the
    # mode alone would always skip the padding color and make the option a
    # no-op. Only leave padding uncolored when the image actually contains
    # transparent pixels (e.g. a background-removed cut-out).
    padding_color = None
    if use_padding and not _has_transparency(image):
        padding_color = rgb_to_hex(pad_r, pad_g, pad_b)

    mosaic, _ = make_vit_mosaic(
        image,
        target_total_patches=(patches,),
        max_long_side=max_size,
        spacing=spacing,
        border_thickness=border_thickness,
        border_color=border_color,
        padding_color=padding_color,
        supersample=supersample,
        output_scale_mode=scale_mode,
        rounded=rounded,
        true_clipping=clipping,
    )

    if background_mode == "White":
        mosaic = _flatten_on_white(mosaic)

    # NOTE(review): a fresh temp dir is created per call and never removed.
    tmp_dir = tempfile.mkdtemp()
    png_path = os.path.join(tmp_dir, "vit_mosaic.png")
    svg_path = os.path.join(tmp_dir, "vit_mosaic.svg")

    mosaic.save(png_path, dpi=(dpi_value, dpi_value))
    export_svg(mosaic, svg_path)

    return mosaic, png_path, svg_path
|
|
|
|
| |
| |
| |
|
|
# Initial border color; shared by the RGB sliders and the color swatch.
_DEFAULT_BORDER_RGB = (0, 255, 255)

with gr.Blocks() as demo:
    gr.Markdown("# \U0001F9E9 ViT Patch Mosaic Generator")

    with gr.Row():
        # Left column: example picker and all generation parameters.
        with gr.Column(scale=1):
            gr.Markdown("### \u2728 Example Images")

            gallery = gr.Gallery(
                value=[v["url"] for v in EXAMPLES.values()],
                columns=3,
                height=250,
                allow_preview=False,
                object_fit="contain",
            )

            gr.Markdown("### \u2699 Parameters")

            patches = gr.Radio([12, 16], value=16, label="Number of patches")
            max_size = gr.Slider(
                128, 1024, value=512, step=64, label="Max long side"
            )
            spacing = gr.Slider(0, 40, value=12, label="Spacing")
            border_thickness = gr.Slider(
                0, 50, value=14, label="Border thickness"
            )

            gr.Markdown("### \U0001F3A8 Border Color")

            # step=1 keeps channel values integral for rgb_to_hex.
            border_r = gr.Slider(0, 255, value=_DEFAULT_BORDER_RGB[0], step=1, label="R")
            border_g = gr.Slider(0, 255, value=_DEFAULT_BORDER_RGB[1], step=1, label="G")
            border_b = gr.Slider(0, 255, value=_DEFAULT_BORDER_RGB[2], step=1, label="B")

            color_preview = gr.HTML(update_color_preview(*_DEFAULT_BORDER_RGB))

            with gr.Accordion("\U0001F9F1 Padding Settings", open=False):
                use_padding = gr.Checkbox(value=False, label="Enable padding color")
                pad_r = gr.Slider(0, 255, value=255, step=1, label="Pad R")
                pad_g = gr.Slider(0, 255, value=255, step=1, label="Pad G")
                pad_b = gr.Slider(0, 255, value=255, step=1, label="Pad B")

            with gr.Accordion("\u2699 Advanced Settings", open=False):
                rounded = gr.Checkbox(value=True, label="Enable rounded corners")
                clipping = gr.Checkbox(value=True, label="True rounded clipping")
                remove_bg = gr.Checkbox(value=False, label="Remove Background (rembg)")

                supersample = gr.Slider(1, 4, value=2, step=1, label="Supersample")
                scale_mode = gr.Radio(
                    ["keep", "downscale"], value="keep", label="Output scale mode"
                )
                dpi_value = gr.Slider(72, 600, value=300, step=1, label="PNG DPI")
                background_mode = gr.Radio(
                    ["Transparent", "White"], value="Transparent", label="Background"
                )

            generate_btn = gr.Button("Generate Mosaic")

        # Right column: selected input and generated outputs.
        with gr.Column(scale=1):
            gr.Markdown("### \U0001F4E5 Selected Image")

            input_image = gr.Image(type="pil", image_mode="RGBA", height=250)
            credit_display = gr.Markdown("")

            gr.Markdown("### \U0001F5BC Mosaic Preview")

            output_image = gr.Image(type="pil", height=350)

            download_png = gr.File(label="Download PNG")
            download_svg = gr.File(label="Download SVG")

    # ---- Event wiring ----
    gallery.select(
        fn=on_gallery_select,
        outputs=[input_image, credit_display],
    )

    # Any change to one border channel refreshes the swatch from all three.
    border_sliders = [border_r, border_g, border_b]
    for channel_slider in border_sliders:
        channel_slider.change(update_color_preview, border_sliders, color_preview)

    # Clipping is only user-editable while rounded corners are enabled.
    rounded.change(toggle_clipping, rounded, clipping)

    generate_btn.click(
        fn=generate,
        inputs=[
            input_image,
            patches,
            max_size,
            spacing,
            border_thickness,
            border_r,
            border_g,
            border_b,
            use_padding,
            pad_r,
            pad_g,
            pad_b,
            supersample,
            scale_mode,
            dpi_value,
            background_mode,
            rounded,
            clipping,
            remove_bg,
        ],
        outputs=[output_image, download_png, download_svg],
    )
|
|
|
|
if __name__ == "__main__":
    # CLI entry point; the --share flag is forwarded to demo.launch(share=...).
    cli = argparse.ArgumentParser(description="Launch the gradio demo")
    cli.add_argument("--share", action="store_true")
    options = cli.parse_args()
    demo.launch(share=options.share)