Exploration_Platform / webui /tab_generative_inpainting.py
HZSDU's picture
Add files using upload-large-folder tool
dfb6163 verified
import os
from PIL import Image
import gradio as gr
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
def calculate_metrics(original, restored):
    """
    Compute PSNR and SSIM between an original image and its restored version.

    Both images are converted to RGB first. If their sizes differ, the
    restored image is resized (OpenCV default interpolation) to the
    original's width/height so the metrics compare pixel-for-pixel.

    :param original: reference image (PIL Image) or None
    :param restored: restored / inpainted image (PIL Image) or None
    :return: (psnr, ssim) as Python floats, rounded to 2 and 4 decimals.
             Returns (0.0, 0.0) when either image is None. PSNR is
             presumably ``inf`` for identical images (zero MSE).
    """
    if original is None or restored is None:
        return 0.0, 0.0  # Default values when an image is missing.

    original_np = np.array(original.convert("RGB"))
    restored_np = np.array(restored.convert("RGB"))

    if original_np.shape != restored_np.shape:
        # cv2.resize takes dsize as (width, height), not (rows, cols).
        height, width = original_np.shape[:2]
        restored_np = cv2.resize(restored_np, (width, height))

    psnr_val = psnr(original_np, restored_np, data_range=255)
    # channel_axis=-1 marks the last (RGB) axis as channels. The old
    # ``multichannel`` flag is deprecated in favour of it and was redundant.
    ssim_val = ssim(original_np, restored_np, data_range=255, channel_axis=-1)
    return round(float(psnr_val), 2), round(float(ssim_val), 4)
def save_mask(input_image):
    """
    Extract a binary mask from the red strokes drawn on an image.

    Accepts either the dict value produced by ``gr.ImageEditor`` (the
    flattened drawing is read from its ``"composite"`` entry) or a plain
    image. Pixels whose HSV colour falls in one of two red hue bands become
    255 and everything else 0. A 5x5 morphological close followed by a 5x5
    open then fills small gaps in the strokes and removes isolated specks.

    :param input_image: ImageEditor dict, image, or None
    :return: gallery value ``[(mask, "mask")]`` where ``mask`` is a mode "L"
             PIL Image, or ``[]`` when there is no image to process.
    """
    if input_image is None:
        return []

    if isinstance(input_image, dict) and "composite" in input_image:
        img = input_image["composite"]
        # The editor reports composite=None until an image has been loaded.
        if img is None:
            return []
    else:
        img = input_image  # Value from a plain Image component.

    # Normalise PIL inputs (RGBA, L, P, ...) to 3-channel RGB so the colour
    # conversions below always receive a valid array.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    hsv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2HSV)

    # Red wraps around hue 0 on OpenCV's 8-bit hue scale, so two bands are
    # thresholded (low hues and high hues) and then combined.
    lower_red1 = np.array([0, 100, 100])
    upper_red1 = np.array([10, 255, 255])
    lower_red2 = np.array([160, 100, 100])
    upper_red2 = np.array([180, 255, 255])
    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    red_mask = cv2.bitwise_or(mask1, mask2)

    # Morphological clean-up: close small holes, then drop small noise.
    kernel = np.ones((5, 5), np.uint8)
    red_mask = cv2.morphologyEx(red_mask, cv2.MORPH_CLOSE, kernel)
    red_mask = cv2.morphologyEx(red_mask, cv2.MORPH_OPEN, kernel)

    mask_image = Image.fromarray(red_mask).convert('L')
    return [(mask_image, "mask")]
def create_interactive_generative_inpainting(runner):
    """
    Build the interactive generative-inpainting UI.

    The user uploads an image into an editor and paints a red mask. Clicking
    "Save" turns the strokes into a mask image via ``save_mask``. Clicking
    "Inpainting" sends the editor value and the saved mask gallery to
    ``runner.run_generative_inpainting``. The result is shown next to the
    original image, together with PSNR/SSIM from ``calculate_metrics``.

    :param runner: backend object exposing
        ``run_generative_inpainting(sketch, mask)``. Its return value is
        assumed to be an image or a list of images — TODO confirm against
        the runner implementation.
    """

    def run_inpainting(sketch, mask):
        """Run inpainting; return (gallery value, original image, psnr, ssim)."""
        inpainting_result = runner.run_generative_inpainting(sketch, mask)

        # The unpainted input is the editor's background layer. It is None
        # when nothing has been uploaded, so guard before converting.
        original_img = None
        if isinstance(sketch, dict) and sketch.get("background") is not None:
            original_img = sketch["background"].convert("RGB")

        # When the runner returns several images, only the first is shown
        # and scored. An empty list is treated as "no result".
        if isinstance(inpainting_result, list):
            inpainting_result = inpainting_result[0] if inpainting_result else None

        # Metrics are only defined for two PIL images. Anything else (missing
        # original, missing or non-PIL result) keeps the 0.0 defaults.
        psnr_val, ssim_val = 0.0, 0.0
        if original_img is not None and isinstance(inpainting_result, Image.Image):
            psnr_val, ssim_val = calculate_metrics(original_img, inpainting_result)

        gallery = [] if inpainting_result is None else [(inpainting_result, "inpainting")]
        return gallery, original_img, psnr_val, ssim_val

    with gr.Blocks():
        with gr.Row():
            gr.Markdown(
                '1. 上传输入图片.\n'
                '2. 在上传的图片上绘制掩码,然后点击“Save”保存,或者你可以上传你自己的掩码\n'
                '3. 点击`Inpainting`修复图像.'
            )
        with gr.Row():
            with gr.Column():
                sketch = gr.ImageEditor(label='Drawing mask', type="pil")
                output = gr.Gallery(label='Mask', elem_id='gallery', columns=2, height='auto', preview=True)
                btn = gr.Button(value="Save")
                # Convert the red strokes in the editor into a mask image.
                btn.click(
                    fn=save_mask,
                    inputs=[sketch],
                    outputs=[output]
                )
                run_button = gr.Button(value="Inpainting")
            with gr.Column():
                gr.Markdown('#### 修复图像:\n')
                # NOTE(review): elem_id='gallery' is also used by the mask
                # gallery (duplicate DOM id). Kept unchanged in case CSS
                # elsewhere targets #gallery.
                result_gallery = gr.Gallery(label='Inpainting Image', elem_id='gallery', columns=2, height='auto', preview=True)
                gr.Markdown('#### 原始图像:\n')
                original_preview = gr.Image(label='Original Image', type="pil")
                psnr_display = gr.Number(label="PSNR")
                ssim_display = gr.Number(label="SSIM")
        # The saved mask gallery is what gets passed to the runner as the mask.
        ips = [sketch, output]
        run_button.click(
            fn=run_inpainting,
            inputs=ips,
            outputs=[result_gallery, original_preview, psnr_display, ssim_display]
        )