File size: 4,116 Bytes
08ec965 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | """
Download the pretrained DINO weights into ./weights before running the predictor:
wget https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth
wget https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth
"""
import sys
sys.path.append("maskcut")
import numpy as np
import PIL.Image as Image
import torch
from scipy import ndimage
from colormap import random_color
import dino
from third_party.token_cut.unsupervised_saliency_detection import metric
from crf import densecrf
from maskcut import maskcut
from cog import BasePredictor, Input, Path
class Predictor(BasePredictor):
    """Cog predictor that draws MaskCut pseudo-masks on top of an input image."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        # DINO pre-trained model
        vit_features = "k"
        self.patch_size = 8

        # adapted dino.ViTFeat to load from local pretrained_path
        self.backbone_base = dino.ViTFeat(
            "weights/dino_vitbase8_pretrain.pth",
            768,
            "base",
            vit_features,
            self.patch_size,
        )
        self.backbone_small = dino.ViTFeat(
            "weights/dino_deitsmall8_300ep_pretrain.pth",
            384,
            "small",
            vit_features,
            self.patch_size,
        )

        # Both backbones are put in eval mode and moved to the GPU up front so
        # predict() only has to pick one of them.
        for backbone in (self.backbone_base, self.backbone_small):
            backbone.eval()
            backbone.cuda()

    def predict(
        self,
        image: Path = Input(
            description="Input image",
        ),
        model: str = Input(
            description="Choose the model architecture",
            default="base",
            choices=["small", "base"],
        ),
        n_pseudo_masks: int = Input(
            description="The maximum number of pseudo-masks per image",
            default=3,
        ),
        tau: float = Input(
            description="Threshold used for producing binary graph",
            default=0.15,
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Runs MaskCut on ``image`` with the selected backbone, refines every
        returned bipartition with ``densecrf``, drops the masks whose refined
        version disagrees too much with the bipartition, and overlays the
        remaining masks on the original image in random colors.

        Args:
            image: Path of the image to segment.
            model: ``"base"`` selects ``self.backbone_base``; any other value
                selects ``self.backbone_small``.
            n_pseudo_masks: Forwarded to ``maskcut`` as ``N``.
            tau: Forwarded to ``maskcut`` as its threshold argument.

        Returns:
            Path of the PNG written to ``/tmp/out.png``. When no mask survives,
            the PNG is the unmodified RGB input.
        """
        backbone = self.backbone_base if model == "base" else self.backbone_small

        # MaskCut hyperparameters
        fixed_size = 480

        # get pseudo-masks with MaskCut
        bipartitions, _, I_new = maskcut(
            str(image),
            backbone,
            self.patch_size,
            tau,
            N=n_pseudo_masks,
            fixed_size=fixed_size,
            cpu=False,
        )

        # convert() yields an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(str(image)) as source:
            original = source.convert("RGB")
        width, height = original.size

        pseudo_mask_list = []
        for bipartition in bipartitions:
            # post-process pseudo-masks with CRF
            pseudo_mask = densecrf(np.array(I_new), bipartition)
            pseudo_mask = ndimage.binary_fill_holes(pseudo_mask >= 0.5)

            # filter out the masks that have a very different pseudo-mask after the CRF
            mask1 = torch.from_numpy(bipartition).cuda()
            mask2 = torch.from_numpy(pseudo_mask).cuda()
            if metric.IoU(mask1, mask2) < 0.5:
                # Previously such a mask was negated and clamped to all zeros,
                # which draws nothing; skipping it gives the same picture.
                continue

            # construct binary pseudo-masks at the original image resolution
            resized = Image.fromarray(np.uint8(pseudo_mask * 255)).resize((width, height))
            resized = np.asarray(resized).astype(np.uint8)

            # Resizing can introduce intermediate gray levels; snap every pixel
            # back to the mask's own extreme values around the midpoint.
            upper = np.max(resized)
            lower = np.min(resized)
            thresh = upper / 2.0
            pseudo_mask_list.append(
                np.where(resized > thresh, upper, lower).astype(np.uint8)
            )

        out = np.array(original)
        for pseudo_mask in pseudo_mask_list:
            # vis_mask returns a PIL image; go back to an array so the canvas
            # has one type on every pass and also when there are no masks.
            out = np.asarray(vis_mask(out, pseudo_mask, random_color(rgb=True)))

        output_path = "/tmp/out.png"
        Image.fromarray(out).save(output_path)

        return Path(output_path)
def vis_mask(input, mask, mask_color):
    """Tint the masked pixels of ``input`` with ``mask_color``.

    Pixels where ``mask > 0.5`` become a blend of 30% original value and 70%
    ``mask_color``; all other pixels are copied unchanged. ``input`` itself is
    not modified because the blend is applied to a copy.

    Args:
        input: Image data accepted by ``np.copy`` (presumably an H x W x 3
            uint8 array or PIL image -- confirm against callers).
        mask: Array compared against 0.5 to select the pixels to tint.
        mask_color: Color added to the selected pixels (presumably one value
            per channel -- confirm against ``random_color``).

    Returns:
        A PIL image created from the blended array.
    """
    canvas = np.copy(input)
    selected = mask > 0.5
    tint = np.array(mask_color) * 0.7
    blended = canvas[selected] * 0.3 + tint
    canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
|