| import gradio as gr
|
| import torch
|
| import numpy as np
|
| import cv2
|
| from torchvision import transforms
|
| from PIL import Image
|
| from u2net import U2NET
|
|
|
|
|
model_path = "u2net.pth"

# Construct the network (presumably in_ch=3, out_ch=1 — confirm against the
# U2NET signature), load the checkpoint onto the CPU, and switch to inference
# mode so dropout/batch-norm layers behave deterministically.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = U2NET(3, 1)
model.load_state_dict(
    torch.load(model_path, map_location="cpu")
)
model.eval()
|
|
|
|
|
| def preprocess(image):
|
| transform = transforms.Compose([
|
| transforms.Resize((512, 512)),
|
| transforms.ToTensor(),
|
| transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| ])
|
| return transform(image).unsqueeze(0)
|
|
|
|
|
def postprocess(output, original_size, *, threshold=127):
    """Convert a raw saliency map into a clean binary mask at the image size.

    Args:
        output: Saliency tensor produced by the model. Singleton dimensions
            are squeezed away; the result is assumed to be a 2-D ``(H, W)``
            map (TODO confirm for other model heads).
        original_size: ``(width, height)`` tuple as returned by PIL's
            ``Image.size``; ``cv2.resize`` takes ``dsize`` in the same order.
        threshold: 0-255 cutoff; pixels strictly above it become foreground.

    Returns:
        ``uint8`` array of shape ``(height, width)`` containing only 0 and 255.
    """
    # .cpu() keeps this working even if the model/tensors live on a GPU.
    saliency = output.squeeze().detach().cpu().numpy()

    # Min-max normalise to [0, 1]. A flat map (max == min) would otherwise
    # divide by zero and produce NaNs, whose uint8 cast is undefined; treat it
    # as all-background instead.
    lo = saliency.min()
    span = saliency.max() - lo
    if span > 0:
        saliency = (saliency - lo) / span
    else:
        saliency = np.zeros_like(saliency)

    mask = (saliency * 255).astype(np.uint8)

    # Upsample to the source resolution before thresholding so the binary
    # edge is decided at full resolution.
    mask = cv2.resize(mask, original_size, interpolation=cv2.INTER_LINEAR)
    _, mask = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    # Closing fills small holes inside the object; opening then removes
    # small isolated specks outside it.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    return mask
|
|
|
|
|
|
|
def remove_background(image):
    """Cut the salient foreground out of *image* and return it with transparency.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        An RGBA PIL image of the same size whose alpha channel is the feathered
        foreground mask, or ``None`` if no image was supplied.
    """
    if image is None:
        return None

    # Normalise the mode up front: the model needs 3 channels, and an "L" or
    # "P" upload would otherwise produce a 2-D array with no channel axis.
    rgb = image.convert("RGB")

    with torch.no_grad():
        # The model output is indexed with [0]; presumably the first of
        # several side outputs — confirm against the U2NET definition.
        saliency = model(preprocess(rgb))[0]
    mask = postprocess(saliency, rgb.size)

    # Feather the hard 0/255 mask so the cut-out edge is soft. PIL's RGBA is
    # straight (non-premultiplied) alpha, so the softness belongs in the alpha
    # channel and the colour channels stay untouched; darkening RGB while
    # keeping a hard alpha would leave a dark fringe along the edge.
    alpha = cv2.GaussianBlur(mask, (15, 15), 0)

    # (H, W, 3) uint8 + (H, W) uint8 -> (H, W, 4) uint8, inferred as RGBA.
    rgba = np.dstack([np.asarray(rgb, dtype=np.uint8), alpha])
    return Image.fromarray(rgba)
|
|
|
|
|
# Gradio UI: a single PIL image goes in, the background-free PIL image comes out.
_input_image = gr.Image(type="pil")
_output_image = gr.Image(type="pil")

iface = gr.Interface(
    fn=remove_background,
    inputs=_input_image,
    outputs=_output_image,
    title="Background Remover",
    description="Upload an image to remove its background using U^2-Net.",
)


def _main():
    """Start the Gradio server for the background-removal interface."""
    iface.launch()


if __name__ == "__main__":
    _main()
|
|
|