import cv2
import numpy as np
import torch
import threading
from torchvision import transforms
from clip.clipseg import CLIPDensePredT
from roop.typing import Frame
THREAD_LOCK_CLIP = threading.Lock()


class Mask_Clip2Seg:
    plugin_options: dict = None
    model_clip = None

    processorname = "clip2seg"
    type = "mask"

    def Initialize(self, plugin_options: dict):
        if self.plugin_options is not None:
            # Drop the current model so it can be reloaded on the new device.
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_clip is None:
            # CLIPSeg (ViT-B/16) with the refined rd64 checkpoint, loaded on
            # CPU first and moved to the configured device below.
            self.model_clip = CLIPDensePredT(
                version="ViT-B/16", reduce_dim=64, complex_trans_conv=True
            )
            self.model_clip.eval()
            self.model_clip.load_state_dict(
                torch.load(
                    "models/CLIP/rd64-uni-refined.pth", map_location=torch.device("cpu")
                ),
                strict=False,
            )

        device = torch.device(self.plugin_options["devicename"])
        self.model_clip.to(device)

    def Run(self, img1, keywords: str) -> Frame:
        if keywords is None or len(keywords) < 1 or img1 is None:
            return img1

        # Work on a fixed 256x256 copy of the frame.
        source_image_small = cv2.resize(img1, (256, 256))

        # Start from an all-black float mask.
        img_mask = np.full(
            (source_image_small.shape[0], source_image_small.shape[1]),
            0,
            dtype=np.float32,
        )
        mask_border = 1
        left = 0
        top = 0
        right = 1
        bottom = 1
        mask_blur = 5
        clip_blur = 5

        # Filled white rectangle inset by the border, feathered with a
        # Gaussian blur, then normalized to the 0..1 range.
        img_mask = cv2.rectangle(
            img_mask,
            (mask_border + int(left), mask_border + int(top)),
            (256 - mask_border - int(right), 256 - mask_border - int(bottom)),
            (255, 255, 255),
            -1,
        )
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur * 2 + 1, mask_blur * 2 + 1), 0)
        img_mask /= 255

        input_image = source_image_small

        # Standard ImageNet mean/std normalization.
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
                transforms.Resize((256, 256)),
            ]
        )
        img = transform(input_image).unsqueeze(0)

        thresh = 0.5
        prompts = keywords.split(",")
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                # One prediction per prompt: repeat the image along the batch axis.
                preds = self.model_clip(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
        # Sum the per-prompt probability maps into one mask, clipped to 0..1.
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts) - 1):
            clip_mask += torch.sigmoid(preds[i + 1][0])
        clip_mask = clip_mask.data.cpu().numpy()
        clip_mask = np.clip(clip_mask, 0, 1)

        # Binarize against the threshold.
        clip_mask[clip_mask > thresh] = 1.0
        clip_mask[clip_mask <= thresh] = 0.0

        # Grow the binary mask slightly, then soften its edges.
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(
            clip_mask, (clip_blur * 2 + 1, clip_blur * 2 + 1), 0
        )

        # Combine the border mask with the CLIPSeg mask.
        img_mask *= clip_mask
        img_mask[img_mask < 0.0] = 0.0
        return img_mask

    def Release(self):
        self.model_clip = None
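

# Minimal usage sketch (illustrative, not part of the module). The device name
# "cuda", the input file "face.jpg", and the output file "mask.png" are
# assumptions for this example.
if __name__ == "__main__":
    masker = Mask_Clip2Seg()
    masker.Initialize({"devicename": "cuda"})  # or "cpu"
    frame = cv2.imread("face.jpg")  # hypothetical input frame
    # Comma-separated prompts; returns a 256x256 float mask in the 0..1 range.
    mask = masker.Run(frame, "face,hair")
    cv2.imwrite("mask.png", (mask * 255).astype(np.uint8))
    masker.Release()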