| import os
|
| from PIL import Image
|
| import gradio as gr
|
| import cv2
|
| import numpy as np
|
| from skimage.metrics import structural_similarity as ssim
|
| from skimage.metrics import peak_signal_noise_ratio as psnr
|
|
|
|
|
def calculate_metrics(original, restored):
    """Compute PSNR and SSIM between an original and a restored image.

    Both images are converted to RGB before comparison. If the resulting
    arrays differ in shape, the restored image is resized to the
    original's width and height.

    :param original: reference image (PIL Image), or None
    :param restored: restored image (PIL Image), or None
    :return: ``(psnr, ssim)`` as plain floats rounded to 2 and 4 decimal
        places respectively; ``(0.0, 0.0)`` when either image is None
    """
    if original is None or restored is None:
        return 0.0, 0.0

    original_np = np.array(original.convert("RGB"))
    restored_np = np.array(restored.convert("RGB"))

    if original_np.shape != restored_np.shape:
        # cv2.resize expects the target size as (width, height), while the
        # array shape is (rows, columns, channels).
        height, width = original_np.shape[:2]
        restored_np = cv2.resize(restored_np, (width, height))

    psnr_val = psnr(original_np, restored_np, data_range=255)

    # Only ``channel_axis`` is passed. ``multichannel`` is its deprecated
    # predecessor in scikit-image; supplying both is redundant and makes
    # releases that still recognise the old name emit a deprecation warning.
    ssim_val = ssim(
        original_np,
        restored_np,
        data_range=255,
        channel_axis=-1,
    )

    # Cast to built-in floats so callers receive plain Python numbers
    # rather than NumPy scalar types.
    return round(float(psnr_val), 2), round(float(ssim_val), 4)
|
|
|
def save_mask(input_image):
    """Extract a mask of the red-painted regions of an image.

    Red pixels are detected in HSV space, then the raw mask is cleaned
    with a morphological close followed by an open using a 5x5 kernel.

    :param input_image: an image-editor value (a dict whose "composite"
        entry holds a PIL Image), a PIL Image, an RGB array, or None
    :return: ``[(mask, "mask")]`` where ``mask`` is a mode-"L" PIL Image,
        or ``[]`` when there is no image to process
    """
    if input_image is None:
        return []

    if isinstance(input_image, dict):
        # The editor value can be a dict whose "composite" is missing or
        # None (nothing uploaded yet); treat that the same as "no image".
        input_image = input_image.get("composite")
        if input_image is None:
            return []

    if isinstance(input_image, Image.Image):
        # Normalise every PIL mode (RGBA, L, P, ...) to 3-channel RGB so
        # the colour conversion below always receives what it expects.
        input_image = input_image.convert("RGB")

    hsv = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2HSV)

    # Red sits at both ends of OpenCV's hue axis, so two ranges are needed.
    lower_red1 = np.array([0, 100, 100])
    upper_red1 = np.array([10, 255, 255])
    lower_red2 = np.array([160, 100, 100])
    upper_red2 = np.array([180, 255, 255])

    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    # The two hue ranges are disjoint, so OR-ing the masks is equivalent
    # to adding them and cannot overflow uint8.
    red_mask = mask1 | mask2

    # Close first to fill small holes inside strokes, then open to drop
    # isolated specks.
    kernel = np.ones((5, 5), np.uint8)
    red_mask = cv2.morphologyEx(red_mask, cv2.MORPH_CLOSE, kernel)
    red_mask = cv2.morphologyEx(red_mask, cv2.MORPH_OPEN, kernel)

    mask_image = Image.fromarray(red_mask).convert('L')
    return [(mask_image, "mask")]
|
|
|
|
|
|
|
def create_interactive_generative_inpainting(runner):
    """Build the Gradio UI for interactive generative inpainting.

    The layout has a left column (image editor, mask gallery, "Save" and
    "Inpainting" buttons) and a right column (result gallery, original
    image preview, PSNR and SSIM readouts).

    :param runner: object exposing
        ``run_generative_inpainting(sketch, mask)``; its return value is
        used as the inpainted image (or a list whose first item is)
    """

    def run_inpainting(sketch, mask):
        """Run inpainting and score the result against the background.

        :param sketch: image-editor value; when it is a dict, its
            "background" entry is used as the original image
        :param mask: mask gallery value, forwarded to the runner as-is
        :return: ``(gallery_items, original_image, psnr, ssim)``
        """
        inpainting_result = runner.run_generative_inpainting(sketch, mask)

        # When the runner returns a list only the first image is used;
        # an empty list means there is no result at all.
        if isinstance(inpainting_result, list):
            inpainting_result = inpainting_result[0] if inpainting_result else None

        original_img = None
        if isinstance(sketch, dict):
            # "background" may be absent or None when nothing was uploaded.
            background = sketch.get("background")
            if background is not None:
                original_img = background.convert("RGB")

        psnr_val, ssim_val = 0.0, 0.0
        if original_img is not None and inpainting_result is not None:
            psnr_val, ssim_val = calculate_metrics(original_img, inpainting_result)

        # Do not hand the gallery a (None, caption) entry when there is
        # no result image; show an empty gallery instead.
        if inpainting_result is None:
            gallery_items = []
        else:
            gallery_items = [(inpainting_result, "inpainting")]
        return gallery_items, original_img, psnr_val, ssim_val

    with gr.Blocks():
        with gr.Row():
            gr.Markdown(
                '1. 上传输入图片.\n'
                '2. 在上传的图片上绘制掩码,然后点击“Save”保存,或者你可以上传你自己的掩码\n'
                '3. 点击`Inpainting`修复图像.'
            )
        with gr.Row():
            with gr.Column():
                sketch = gr.ImageEditor(label='Drawing mask', type="pil")
                # NOTE(review): both galleries share elem_id='gallery';
                # confirm no CSS/JS relies on it before making them unique.
                output = gr.Gallery(label='Mask', elem_id='gallery', columns=2, height='auto', preview=True)
                btn = gr.Button(value="Save")
                btn.click(
                    fn=save_mask,
                    inputs=[sketch],
                    outputs=[output]
                )
                run_button = gr.Button(value="Inpainting")
            with gr.Column():
                gr.Markdown('#### 修复图像:\n')
                result_gallery = gr.Gallery(label='Inpainting Image', elem_id='gallery', columns=2, height='auto', preview=True)
                gr.Markdown('#### 原始图像:\n')
                original_preview = gr.Image(label='Original Image', type="pil")
                psnr_display = gr.Number(label="PSNR")
                ssim_display = gr.Number(label="SSIM")
        ips = [sketch, output]
        run_button.click(
            fn=run_inpainting,
            inputs=ips,
            outputs=[result_gallery, original_preview, psnr_display, ssim_display],
        )
|
|
|
|
|