| |
| |
|
|
| import numpy as np |
| import pydensecrf.densecrf as dcrf |
| import pydensecrf.utils as utils |
| import torch |
| import torch.nn.functional as F |
| import torchvision.transforms.functional as VF |
|
|
# Hyper-parameters consumed by densecrf() below.

# Iteration count handed to DenseCRF2D.inference().
MAX_ITER = 10

# Position-only Gaussian pairwise term (addPairwiseGaussian):
# POS_W is passed as `compat`, POS_XY_STD as `sxy`.
POS_W = 7
POS_XY_STD = 3

# Bilateral (position + colour) pairwise term (addPairwiseBilateral):
# Bi_W is passed as `compat`, Bi_XY_STD as `sxy`, Bi_RGB_STD as `srgb`.
Bi_W = 10
Bi_XY_STD = 50
Bi_RGB_STD = 5
|
|
def densecrf(image, mask):
    """Refine a 2-D mask with a dense CRF whose bilateral term uses ``image``.

    The mask is turned into a two-channel (background, foreground) score map,
    bilinearly resized to the image resolution, pushed through a softmax and
    used as the unary term of a ``DenseCRF2D`` with one Gaussian and one
    bilateral pairwise term.

    Args:
        image: Array whose first two dimensions are (H, W); it is made
            C-contiguous and passed as ``rgbim`` to ``addPairwiseBilateral``.
            NOTE(review): pydensecrf presumably requires a uint8 H x W x 3
            array here -- confirm against callers.
        mask: 2-D NumPy array of foreground scores; it may have a different
            resolution from ``image``.
            NOTE(review): values look like they are expected in [0, 1]
            (background is computed as ``1 - mask``) -- confirm.

    Returns:
        ``np.float32`` array of shape (H, W) holding the arg-max label per
        pixel: 0.0 for the background channel, 1.0 for the foreground channel.

    Raises:
        ValueError: If ``mask`` is not 2-D.
    """
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2-D (h, w), got shape {mask.shape}")

    # Stack background / foreground scores into a (2, mask_h, mask_w) tensor.
    mask_h, mask_w = mask.shape
    fg = mask.reshape(1, mask_h, mask_w).astype(float)
    bg = 1 - fg
    output_logits = torch.from_numpy(np.concatenate((bg, fg), axis=0))

    img_h, img_w = image.shape[:2]
    image = np.ascontiguousarray(image)

    # Resize the scores to the image resolution. squeeze(0) removes only the
    # batch axis; a bare squeeze() would also drop a spatial axis of size 1
    # and break the (c, h, w) unpacking below for 1-pixel-high/wide images.
    output_logits = F.interpolate(
        output_logits.unsqueeze(0),
        size=(img_h, img_w),
        mode="bilinear",
        align_corners=False,  # explicit form of the default behaviour
    ).squeeze(0)

    # The resized scores are treated as logits; softmax over the class axis
    # yields the per-pixel class probabilities used for the unary term.
    output_probs = F.softmax(output_logits, dim=0).cpu().numpy()
    c, h, w = output_probs.shape

    U = np.ascontiguousarray(utils.unary_from_softmax(output_probs))

    # DenseCRF2D takes width before height.
    d = dcrf.DenseCRF2D(w, h, c)
    d.setUnaryEnergy(U)
    d.addPairwiseGaussian(sxy=POS_XY_STD, compat=POS_W)
    d.addPairwiseBilateral(
        sxy=Bi_XY_STD, srgb=Bi_RGB_STD, rgbim=image, compat=Bi_W
    )

    Q = np.array(d.inference(MAX_ITER)).reshape((c, h, w))

    # Most probable label per pixel, returned as a float32 map.
    return np.argmax(Q, axis=0).astype(np.float32)
|
|