| import torch |
| import torchvision.transforms as transforms |
| from PIL import Image |
| import numpy as np |
| import os |
| import requests |
| from model.u2net import U2NET |
|
|
# Remote location of the pretrained U^2-Net checkpoint fetched by download_model().
MODEL_URL = "https://huggingface.co/flashingtt/U-2-Net/resolve/main/u2net.pth"

# Local cache for the checkpoint.
# NOTE(review): these are relative paths, so they resolve against the current
# working directory rather than this file's directory.
MODEL_DIR = "saved_models/u2net"
MODEL_PATH = os.path.join(MODEL_DIR, "u2net.pth")
|
|
def download_model():
    """Download the pretrained weights from MODEL_URL to MODEL_PATH if absent.

    Does nothing when MODEL_PATH already exists. Otherwise the response is
    streamed into a temporary ``.part`` file. That file is moved into place
    only after the transfer completes, so an interrupted or failed download
    never leaves a truncated MODEL_PATH that later runs would mistake for a
    valid checkpoint.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
        OSError: if the file cannot be written or moved into place.
    """
    if os.path.exists(MODEL_PATH):
        return

    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading model...")
    tmp_path = MODEL_PATH + ".part"
    try:
        # The timeout (seconds, applied to connect and to each read) stops a
        # stalled server from hanging the import forever. The ``with`` block
        # releases the connection when done.
        with requests.get(MODEL_URL, stream=True, timeout=60) as r:
            # Without this check, a 404/5xx error page would be saved as the weights.
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        # Atomic move: MODEL_PATH only ever appears fully written.
        os.replace(tmp_path, MODEL_PATH)
    finally:
        # On failure, discard partial data so the next run retries cleanly.
        # After a successful replace this is a no-op.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Model downloaded.")


# Fetch the weights at import time so load_model() can read them.
download_model()
|
|
def load_model():
    """Build a U2NET network, load the weights from MODEL_PATH, and return it.

    Weights are mapped onto the CPU, and the network is switched to eval mode
    before being returned.
    """
    net = U2NET(3, 1)  # presumably (in_channels, out_channels); verify against model.u2net
    # The checkpoint comes from a third-party URL. weights_only=True limits
    # unpickling to tensors and plain containers, so a tampered file cannot
    # execute arbitrary code while loading. Requires torch >= 1.13.
    state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    net.load_state_dict(state_dict)
    net.eval()
    return net


# Module-level network shared by remove_background(); loaded once at import.
model = load_model()
|
|
# Built once at import time. The pipeline holds no per-image state, so there
# is no need to reconstruct it on every call.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((320, 320)),
    transforms.ToTensor(),
    # Per-channel mean/std matching the commonly used ImageNet statistics.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def preprocess(image):
    """Turn a PIL image into a normalised model input batch of one.

    The image is converted to RGB first, because the 3-value Normalize step
    would fail on RGBA, grayscale or palette images. It is then resized to
    320x320, converted to a tensor and normalised.

    Args:
        image: a PIL image of any mode.

    Returns:
        A float tensor of shape (1, 3, 320, 320).
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _PREPROCESS_TRANSFORM(image).unsqueeze(0)
|
|
def postprocess(mask, original_size):
    """Convert a raw saliency tensor into an 8-bit PIL mask of the given size.

    Values are min-max rescaled to [0, 255]. If the map has no dynamic range
    (all values equal, or NaN present), an all-zero mask is produced instead
    of dividing by zero.

    Args:
        mask: model output tensor. Assumed to squeeze down to a 2-D (H, W)
            map; TODO confirm against the U2NET output shape.
        original_size: ``(width, height)`` target size, as given by
            ``PIL.Image.size``.

    Returns:
        A PIL image resized to ``original_size`` with bilinear resampling.
    """
    arr = mask.squeeze().detach().cpu().numpy()
    lo = arr.min()
    span = arr.max() - lo
    # `span > 0` is False both for a constant map and for NaN, so neither
    # case reaches the division.
    if span > 0:
        arr = (arr - lo) / span
    else:
        arr = np.zeros_like(arr)
    img = Image.fromarray((arr * 255).astype(np.uint8))
    return img.resize(original_size, Image.BILINEAR)
|
|
def remove_background(image):
    """Return an RGBA copy of *image* whose alpha channel is the predicted mask.

    The model runs on an RGB version of the input, so RGBA, grayscale and
    palette images are accepted. The caller's image is not modified:
    ``convert`` returns a new image.

    Args:
        image: a PIL image of any mode.

    Returns:
        A new RGBA PIL image. Its RGB channels come from *image* and its
        alpha channel is the mask produced by postprocess().
    """
    # The network expects 3-channel input.
    input_tensor = preprocess(image.convert("RGB"))
    with torch.no_grad():
        d1, *_ = model(input_tensor)
    mask = postprocess(d1, image.size)

    result = image.convert("RGBA")
    # One C-level call replaces the alpha band. This gives the same output as
    # the former per-pixel Python loop over getdata()/putdata().
    result.putalpha(mask)
    return result
|
|