| """ |
| download pretrained weights to ./weights |
| wget https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth |
| wget https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth |
| """ |
|
|
| import sys |
|
|
| sys.path.append("maskcut") |
| import numpy as np |
| import PIL.Image as Image |
| import torch |
| from scipy import ndimage |
| from colormap import random_color |
|
|
| import dino |
| from third_party.token_cut.unsupervised_saliency_detection import metric |
| from crf import densecrf |
| from maskcut import maskcut |
|
|
| from cog import BasePredictor, Input, Path |
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor that runs MaskCut unsupervised segmentation on one image.

    ``setup`` loads two ``dino.ViTFeat`` backbones ("base" and "small") onto the
    GPU once. Each ``predict`` call picks one backbone and asks ``maskcut`` for
    up to ``n_pseudo_masks`` bipartitions. Each bipartition is refined with a
    dense CRF. The result is a PNG with the surviving masks overlaid in random
    colours.
    """

    # Value forwarded as ``maskcut(..., fixed_size=...)``.
    FIXED_SIZE = 480
    # A CRF-refined mask whose IoU with its raw bipartition is below this
    # threshold is discarded (replaced by an all-zero mask).
    MIN_CRF_IOU = 0.5

    def setup(self):
        """Load both backbones into GPU memory so predictions can reuse them."""
        vit_features = "k"  # feature type forwarded to dino.ViTFeat
        # The checkpoint names (vitbase8 / deitsmall8) indicate patch size 8.
        self.patch_size = 8

        self.backbone_base = dino.ViTFeat(
            "weights/dino_vitbase8_pretrain.pth",
            768,
            "base",
            vit_features,
            self.patch_size,
        )
        self.backbone_small = dino.ViTFeat(
            "weights/dino_deitsmall8_300ep_pretrain.pth",
            384,
            "small",
            vit_features,
            self.patch_size,
        )
        # Inference only: switch each backbone to eval mode and move it to GPU.
        for backbone in (self.backbone_base, self.backbone_small):
            backbone.eval()
            backbone.cuda()

    def predict(
        self,
        image: Path = Input(
            description="Input image",
        ),
        model: str = Input(
            description="Choose the model architecture",
            default="base",
            choices=["small", "base"],
        ),
        n_pseudo_masks: int = Input(
            description="The maximum number of pseudo-masks per image",
            default=3,
        ),
        tau: float = Input(
            description="Threshold used for producing binary graph",
            default=0.15,
        ),
    ) -> Path:
        """Segment ``image`` with MaskCut and return a PNG with masks overlaid.

        Returns:
            Path to ``/tmp/out.png``. If ``maskcut`` yields no bipartitions
            (e.g. ``n_pseudo_masks`` <= 0), the unmodified RGB image is written
            instead of crashing.
        """
        backbone = self.backbone_base if model == "base" else self.backbone_small

        bipartitions, _, resized_image = maskcut(
            str(image),
            backbone,
            self.patch_size,
            tau,
            N=n_pseudo_masks,
            fixed_size=self.FIXED_SIZE,
            cpu=False,
        )

        # Close the file handle promptly. convert() returns a new image whose
        # pixel data is already loaded, so it remains usable after the `with`.
        with Image.open(str(image)) as src:
            original = src.convert("RGB")
        size = original.size  # (width, height)

        pseudo_masks = [
            self._refine_mask(resized_image, bipartition, size)
            for bipartition in bipartitions
        ]

        # Keep `out` as an ndarray between iterations (vis_mask returns a PIL
        # image), so the final save also works when there are no masks at all.
        out = np.array(original)
        for pseudo_mask in pseudo_masks:
            out = np.asarray(vis_mask(out, pseudo_mask, random_color(rgb=True)))

        output_path = "/tmp/out.png"
        Image.fromarray(out).save(output_path)
        return Path(output_path)

    def _refine_mask(self, resized_image, bipartition, size):
        """Turn one MaskCut bipartition into a binary uint8 mask of ``size``.

        Args:
            resized_image: third value returned by ``maskcut``. It is converted
                with ``np.array`` and passed to ``densecrf``.
            bipartition: one ndarray from ``maskcut``'s first return value.
            size: ``(width, height)`` of the original input image.

        Returns:
            A 2-D uint8 array of shape ``(height, width)``. Each pixel is set to
            the resized mask's maximum if it exceeds half that maximum, and to
            the resized mask's minimum otherwise. The array is all zeros when
            the mask was discarded.
        """
        refined = densecrf(np.array(resized_image), bipartition)
        refined = ndimage.binary_fill_holes(refined >= 0.5)

        # Discard the mask when CRF refinement changed it too much relative to
        # the raw bipartition.
        raw_t = torch.from_numpy(bipartition).cuda()
        refined_t = torch.from_numpy(refined).cuda()
        if metric.IoU(raw_t, refined_t) < self.MIN_CRF_IOU:
            refined = np.zeros_like(refined)

        mask_img = Image.fromarray(np.uint8(refined * 255))
        # resize() interpolates here (only "1"/"P" images default to NEAREST),
        # so the upsampled mask has intermediate values. Snap them back to the
        # mask's own min/max.
        resized = np.asarray(mask_img.resize(size)).astype(np.uint8)
        upper = resized.max()
        lower = resized.min()
        return np.where(resized > upper / 2.0, upper, lower).astype(np.uint8)
|
|
|
|
def vis_mask(input, mask, mask_color):
    """Tint the masked region of an image with a solid colour.

    Args:
        input: image data accepted by ``np.array`` (an ndarray or a PIL image).
            It is copied, never modified. Assumed to be H x W x 3 so that it
            matches a 3-component ``mask_color`` (TODO confirm for other
            callers).
        mask: array of the same height/width. Pixels where ``mask > 0.5`` are
            tinted.
        mask_color: colour sequence to blend in.

    Returns:
        A PIL image. Masked pixels become ``0.3 * original + 0.7 * mask_color``
        (truncated to uint8). All other pixels are unchanged.
    """
    canvas = np.array(input, copy=True)
    selected = mask > 0.5
    tint = np.array(mask_color)
    blended = canvas[selected] * 0.3 + tint * 0.7
    canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
|
|