| import gradio as gr |
| from gradio.components.image_editor import EditorValue |
| from gradio_imageslider import ImageSlider |
| from PIL import Image |
| from typing import cast |
| import numpy as np |
| from simple_lama_inpainting import SimpleLama |
|
|
|
|
# Module-level LaMa inpainting model, constructed once when this module is
# imported and reused by every call to process_image().
# NOTE(review): construction presumably loads (and may download) model
# weights, which makes importing this module slow — confirm.
simple_lama = SimpleLama()
|
|
def HWC3(x):
    """Convert an image array to a 3-channel (H, W, 3) array.

    Supported inputs:
      * (H, W) or (H, W, 1) grayscale: the single channel is replicated 3x.
      * (H, W, 2) gray + alpha: gray is composited over white, then replicated.
      * (H, W, 3): returned unchanged (the very same array object).
      * (H, W, 4) RGBA: RGB is composited over a white background using the
        alpha channel, and the result is returned as uint8.

    Assumes 8-bit style channel values in [0, 255] (alpha is divided by 255).

    Raises:
        ValueError: if the array is not 2-D/3-D or has an unsupported number
            of channels (previously such inputs silently returned ``None``).
    """
    if x.ndim == 2:
        x = x[:, :, None]
    if x.ndim != 3:
        raise ValueError(f"Expected a 2-D or 3-D image array, got shape {x.shape}")

    channels = x.shape[2]
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    if channels in (2, 4):
        # Last channel is alpha; the rest is color (gray or RGB).
        color = x[:, :, :-1].astype(np.float32)
        alpha = x[:, :, -1:].astype(np.float32) / 255.0
        # Alpha-composite over white so transparent regions become white.
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        if channels == 2:
            y = np.concatenate([y, y, y], axis=2)
        return y
    raise ValueError(f"Unsupported channel count {channels}; expected 1, 2, 3 or 4")
|
|
def _open_image(path: str) -> Image.Image:
    """Open the image at *path*, fully load it, and close the file handle."""
    # Image.open is lazy and keeps the file open; copy() forces a load into an
    # independent image so the handle can be released by the `with` block.
    with Image.open(path) as opened:
        return opened.copy()


def process_image(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
    save_path: str | None = "inpainted.png",
) -> Image.Image | None:
    """Inpaint the masked region of *image* with the shared LaMa model.

    Args:
        image: Input image, either a PIL image or a path to an image file.
        mask: Inpainting mask, either a PIL image or a path to an image file.
        progress: Gradio progress tracker; reset to 0 before work starts.
        save_path: Where to also write the result. Defaults to
            ``"inpainted.png"`` (the previous hard-coded behavior). Pass
            ``None`` to skip writing, e.g. when several requests run
            concurrently and would otherwise overwrite the same file.

    Returns:
        The object returned by ``simple_lama``, or ``None`` if *image* or
        *mask* is missing.
    """
    progress(0, desc="Preparing inputs...")
    if image is None or mask is None:
        return None

    if isinstance(mask, str):
        mask = _open_image(mask)
    if isinstance(image, str):
        image = _open_image(image)
    # Normalize to a 3-channel array (grayscale/alpha inputs are handled by HWC3).
    image_np = HWC3(np.array(image))

    result = simple_lama(image_np, mask)
    if save_path is not None:
        result.save(save_path)
    return result
|
|
def resize_image(img: "Image.Image", min_side_length: int = 768) -> "Image.Image":
    """Downscale *img* so that its shorter side equals *min_side_length*.

    Images whose shorter side is already at most *min_side_length* are
    returned unchanged (the same object), so inputs are never upscaled.
    Previously an image such as 500x1000 was enlarged to 768x1536. The
    aspect ratio is preserved, with the longer side rounded to the nearest
    pixel.

    Args:
        img: Image exposing ``width``, ``height`` and ``resize((w, h))``.
        min_side_length: Target length, in pixels, for the shorter side.

    Returns:
        The original image, or a resized copy.
    """
    width, height = img.width, img.height
    if min(width, height) <= min_side_length:
        return img

    if width < height:
        # Portrait: width is the short side.
        return img.resize((min_side_length, round(height * min_side_length / width)))
    # Landscape or square: height is the short side.
    return img.resize((round(width * min_side_length / height), min_side_length))
|
|
|
|
async def process(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Gradio handler: erase the painted region of the uploaded image.

    Args:
        image_and_mask: Value of the ``gr.ImageMask`` editor. Its
            ``"background"`` entry is the uploaded image and its first
            ``"layers"`` entry holds the brush strokes. Any stroke pixel
            with non-zero alpha becomes part of the mask.
        progress: Gradio progress tracker, forwarded to ``process_image``.

    Returns:
        ``(resized_input, inpainted)`` for the before/after slider, or
        ``None`` (after showing a ``gr.Info`` message) when the input is
        incomplete or inpainting produced no result.
    """
    import asyncio  # only needed here, for asyncio.to_thread

    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    # A missing or all-zero background means nothing was uploaded.
    background = image_and_mask["background"]
    if background is None or not np.any(background):
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, background)

    # Brush strokes are read from the first layer; a missing/empty layer list
    # is treated the same as an empty drawing instead of raising IndexError.
    layers = image_and_mask["layers"] or []
    stroke_layer = layers[0] if layers else None
    mask_np = None
    if stroke_layer is not None:
        stroke_layer = cast(np.ndarray, stroke_layer)
        # Assumes an RGBA layer (channel 3 = alpha): painted pixels are opaque.
        mask_np = np.where(stroke_layer[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if mask_np is None or not np.any(mask_np):
        gr.Info("Please mark the areas you want to remove")
        return None

    # Both go through the same resize rule, so equal-sized inputs stay equal.
    mask = resize_image(Image.fromarray(mask_np))
    image = resize_image(Image.fromarray(image_np))

    # Model inference is blocking; running it in a worker thread keeps this
    # coroutine from stalling the event loop while the model runs.
    output = await asyncio.to_thread(process_image, image, mask, progress)

    if output is None:
        gr.Info("Processing failed")
        return None
    # gr.Progress takes a completion fraction in [0, 1]; 1.0 means done.
    progress(1.0, desc="Processing completed")
    return image, output
|
|
|
|
# UI: image/mask editor on the left, before/after result slider and Run
# button on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Upload-only editor with a fixed black brush; the painted strokes
            # form the removal mask consumed by process().
            image_and_mask = gr.ImageMask(
                label="Upload Image and Draw Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )

        with gr.Column():
            image_slider = ImageSlider(
                label="Result",
                interactive=False,
            )

            # A ClearButton also clears the result slider when clicked, so a
            # stale result is not shown while a new run is in progress.
            process_btn = gr.ClearButton(
                value="Run",
                variant="primary",
                size="lg",
                components=[image_slider],
            )

    # With inputs=[] Gradio calls these callbacks with no arguments, so the
    # button-state lambdas must take none (a `lambda _:` would fail).
    process_btn.click(
        fn=lambda: gr.update(interactive=False, value="Processing..."),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    ).then(
        fn=process,
        inputs=[
            image_and_mask,
        ],
        outputs=[image_slider],
        api_name=False,
    ).then(
        # `.then` runs after the previous step even if it failed, so the
        # button is always re-enabled.
        fn=lambda: gr.update(interactive=True, value="Run"),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    )
|
|
|
|
if __name__ == "__main__":
    # Start the demo with debug mode, share links and the API page all off.
    launch_options = {"debug": False, "share": False, "show_api": False}
    demo.launch(**launch_options)