| import numpy as np |
| import matplotlib.pyplot as plt |
| import torch |
| import torch.nn.functional as F |
|
|
|
|
def get_next_click3D_torch_no_gt(prev_seg, img3D, threshold=170):
    """Pick one prompt click per sample without using a ground truth.

    Voxels whose intensity exceeds ``threshold`` but which are not part of the
    current prediction (``prev_seg > 0.5``) are treated as likely misses. One
    of them is drawn at random and labelled positive. When a sample has no
    such voxel, a uniformly random location is returned with a negative label.

    Args:
        prev_seg (torch.tensor): segmentation output of the previous iteration;
            values above 0.5 count as predicted foreground.
        img3D (torch.tensor): input images, compared element-wise to ``threshold``.
        threshold (int, optional): intensity above which a voxel is considered
            likely foreground. Defaults to 170.

    Returns:
        batch_points (list of torch.tensor): one int64 tensor of shape
            (1, 1, K) per sample, holding the click coordinates.
        batch_labels (list of torch.tensor): one tensor of shape (1, 1) per
            sample; 1 for a positive click, 0 for a negative one.
            NOTE: labels derive from the thresholded image, not a ground truth.
    """
    pred_masks = prev_seg > 0.5
    # "Likely foreground" voxels that the prediction has not covered yet.
    missed = torch.logical_and(img3D > threshold, torch.logical_not(pred_masks))
    target_device = pred_masks.device

    batch_points, batch_labels = [], []
    for sample_idx in range(prev_seg.shape[0]):
        sample_missed = missed[sample_idx]
        candidates = torch.argwhere(sample_missed)

        if len(candidates) > 0:
            chosen = candidates[np.random.randint(len(candidates))]
            label = 1
        else:
            # Nothing to correct: fall back to a random location, marked negative.
            chosen = torch.Tensor(
                [np.random.randint(extent) for extent in sample_missed.size()]
            ).long()
            label = 0

        # The first coordinate indexes the leading axis of the per-sample mask
        # (presumably the channel axis — confirm against callers); it is dropped.
        coords = chosen[1:].clone().detach().reshape(1, 1, -1)
        batch_points.append(coords.to(target_device))
        batch_labels.append(torch.tensor([label]).reshape(1, 1).to(target_device))

    return batch_points, batch_labels
|
|
|
|
def get_next_click3D_torch_no_gt_naive(prev_seg):
    """Pick one prompt click per sample using only the previous prediction.

    For every sample a voxel is drawn at random from the predicted foreground
    (``prev_seg[i, 0] > 0.5``). When the prediction is empty, a uniformly
    random voxel of the volume is used instead. The label is read back from
    the predicted mask at the chosen location.

    NOTE(review): the original summary said clicks come from the area
    *outside* the predicted masks, but the implementation has always sampled
    from inside them (its "uncertain" mask was a tensor XOR-ed with itself,
    i.e. all False). The sampling region is kept as implemented; confirm the
    intended behaviour with the author.

    Args:
        prev_seg (torch.tensor): segmentation output of the previous
            iteration, indexed as ``[batch, 0, d, h, w]``.

    Returns:
        batch_points (list of torch.tensor): one int64 tensor of shape
            (1, 1, 3) per sample with the (d, h, w) click coordinates.
        batch_labels (list of torch.tensor): one int64 tensor of shape (1, 1)
            per sample; 1 if the click lies in the predicted mask, else 0.
            NOTE: labels are based on the predicted masks, not a ground truth.
    """
    mask_threshold = 0.5

    batch_points = []
    batch_labels = []

    pred_masks = prev_seg > mask_threshold
    device = pred_masks.device

    for i in range(prev_seg.shape[0]):
        sample_mask = pred_masks[i, 0]
        # argwhere on the 3-D per-sample mask yields rows of (d, h, w); there
        # is no leading channel index to strip here.
        points = torch.argwhere(sample_mask)

        if len(points) > 0:
            point = points[np.random.randint(len(points))]
        else:
            # Empty prediction: fall back to a random voxel of the volume.
            point = torch.tensor(
                [np.random.randint(sz) for sz in sample_mask.size()],
                dtype=torch.int64,
            )

        d, h, w = point.tolist()
        is_positive = bool(sample_mask[d, h, w])

        bp = point.clone().detach().reshape(1, 1, 3).to(device)
        bl = (
            torch.tensor([int(is_positive)], dtype=torch.long)
            .reshape(1, 1)
            .to(device)
        )
        batch_points.append(bp)
        batch_labels.append(bl)

    return batch_points, batch_labels
|
|
|
|
def get_next_click3D_torch(prev_seg, gt_semantic_seg):
    """Pick one corrective click per sample from the current error regions.

    The prediction (``prev_seg > 0.5``) is compared with the ground truth
    (``gt_semantic_seg > 0``). A click is drawn at random from the false
    negatives (positive label) or the false positives (negative label); when
    both kinds of error exist, one of the two is chosen by a coin flip. If a
    sample has no error at all, a uniformly random location with a negative
    label is returned.

    Args:
        prev_seg (torch.tensor): segmentation output of the previous iteration.
        gt_semantic_seg (torch.tensor): ground-truth segmentation; values
            above 0 count as foreground.

    Returns:
        batch_points (list of torch.tensor): one int64 tensor of shape
            (1, 1, K) per sample, holding the click coordinates.
        batch_labels (list of torch.tensor): one tensor of shape (1, 1) per
            sample; 1 for a positive click, 0 for a negative one.
    """
    pred_masks = prev_seg > 0.5
    true_masks = gt_semantic_seg > 0
    missed = torch.logical_and(true_masks, torch.logical_not(pred_masks))
    spurious = torch.logical_and(torch.logical_not(true_masks), pred_masks)
    target_device = pred_masks.device

    batch_points, batch_labels = [], []
    for sample_idx in range(gt_semantic_seg.shape[0]):
        fn_candidates = torch.argwhere(missed[sample_idx])
        fp_candidates = torch.argwhere(spurious[sample_idx])
        has_fn = len(fn_candidates) > 0
        has_fp = len(fp_candidates) > 0

        if has_fn or has_fp:
            if has_fn and has_fp:
                # Both error types present: flip a coin between them.
                is_positive = bool(np.random.random() > 0.5)
            else:
                is_positive = has_fn
            pool = fn_candidates if is_positive else fp_candidates
            chosen = pool[np.random.randint(len(pool))]
        else:
            # Perfect prediction: fall back to a random negative click.
            chosen = torch.Tensor(
                [np.random.randint(extent) for extent in missed[sample_idx].size()]
            ).long()
            is_positive = False

        # The first coordinate indexes the leading axis of the per-sample mask
        # (presumably the channel axis — confirm against callers); it is dropped.
        coords = chosen[1:].clone().detach().reshape(1, 1, -1)
        label = torch.tensor([int(is_positive)]).reshape(1, 1)

        batch_points.append(coords.to(target_device))
        batch_labels.append(label.to(target_device))

    return batch_points, batch_labels
|
|
|
|
| import edt |
|
|
|
|
def get_next_click3D_torch_ritm(prev_seg, gt_semantic_seg):
    """Pick one corrective click per sample from the core of the larger error region.

    For every sample the false-negative and false-positive masks are turned
    into Euclidean distance maps (via ``edt.edt``). The map with the larger
    maximum distance is selected, and the click is drawn at random from the
    voxels whose distance exceeds half of that maximum, i.e. away from the
    border of the error region.

    Unlike the earlier implementation, the distance maps are computed for each
    sample rather than only for the first one, so batches larger than one no
    longer raise an IndexError. A sample without any error voxel receives a
    uniformly random negative click instead of raising a ValueError.

    Args:
        prev_seg (torch.tensor): segmentation output of the previous
            iteration, indexed as ``[batch, 0, d, h, w]``; values above 0.5
            count as predicted foreground.
        gt_semantic_seg (torch.tensor): ground-truth segmentation with the
            same layout; values above 0 count as foreground.

    Returns:
        batch_points (list of torch.tensor): one int64 tensor of shape
            (1, 1, 3) per sample with the (d, h, w) click coordinates.
        batch_labels (list of torch.tensor): one tensor of shape (1, 1) per
            sample; 1 if the click lies in a false-negative voxel, else 0.
    """
    mask_threshold = 0.5

    def _distance_map(region):
        # Zero-pad one voxel on every side so the volume border is treated as
        # background, run the distance transform on CPU, then crop the padding.
        padded = F.pad(
            region.to(torch.uint8), (1, 1, 1, 1, 1, 1), "constant", value=0
        )
        dist = edt.edt(padded.cpu().numpy(), black_border=True, parallel=4)
        return torch.tensor(dist)[1:-1, 1:-1, 1:-1]

    batch_points = []
    batch_labels = []

    pred_masks = prev_seg > mask_threshold
    true_masks = gt_semantic_seg > 0
    fn_masks = torch.logical_and(true_masks, torch.logical_not(pred_masks))
    fp_masks = torch.logical_and(torch.logical_not(true_masks), pred_masks)

    for i in range(gt_semantic_seg.shape[0]):
        fn_mask_dt = _distance_map(fn_masks[i, 0])
        fp_mask_dt = _distance_map(fp_masks[i, 0])
        fn_max_dist = torch.max(fn_mask_dt)
        fp_max_dist = torch.max(fp_mask_dt)

        # Work on whichever error region reaches deeper.
        dt = fn_mask_dt if fn_max_dist > fp_max_dist else fp_mask_dt
        to_point_mask = dt > (max(fn_max_dist, fp_max_dist) / 2.0)
        points = torch.argwhere(to_point_mask)

        if len(points) > 0:
            point = points[np.random.randint(len(points))]
            d, h, w = point.tolist()
            is_positive = bool(fn_masks[i, 0, d, h, w])
        else:
            # No error voxel in this sample: random negative click.
            point = torch.tensor(
                [np.random.randint(sz) for sz in fn_masks[i, 0].size()],
                dtype=torch.int64,
            )
            is_positive = False

        bp = point.clone().detach().reshape(1, 1, 3)
        bl = torch.tensor(
            [
                int(is_positive),
            ]
        ).reshape(1, 1)
        batch_points.append(bp)
        batch_labels.append(bl)

    return batch_points, batch_labels
|
|
|
|
def get_next_click3D_torch_2(prev_seg, gt_semantic_seg):
    """Pick one corrective click per sample from the union of all errors.

    The click is drawn uniformly at random from every voxel that is either a
    false negative or a false positive of the prediction (``prev_seg > 0.5``)
    with respect to the ground truth (``gt_semantic_seg > 0``). It is labelled
    positive when it lies in a false-negative voxel, otherwise negative.

    A sample without any error voxel receives a uniformly random negative
    click; previously this case raised a ValueError from
    ``np.random.randint(0)``.

    Args:
        prev_seg (torch.tensor): segmentation output of the previous
            iteration, indexed as ``[batch, 0, d, h, w]``.
        gt_semantic_seg (torch.tensor): ground-truth segmentation with the
            same layout.

    Returns:
        batch_points (list of torch.tensor): one int64 tensor of shape
            (1, 1, 3) per sample with the click coordinates.
        batch_labels (list of torch.tensor): one tensor of shape (1, 1) per
            sample; 1 for a positive click, 0 for a negative one.
    """
    mask_threshold = 0.5

    batch_points = []
    batch_labels = []

    pred_masks = prev_seg > mask_threshold
    true_masks = gt_semantic_seg > 0
    fn_masks = torch.logical_and(true_masks, torch.logical_not(pred_masks))
    fp_masks = torch.logical_and(torch.logical_not(true_masks), pred_masks)

    to_point_mask = torch.logical_or(fn_masks, fp_masks)

    for i in range(gt_semantic_seg.shape[0]):
        points = torch.argwhere(to_point_mask[i])

        if len(points) > 0:
            point = points[np.random.randint(len(points))]
            is_positive = bool(fn_masks[i, 0, point[1], point[2], point[3]])
        else:
            # Perfect prediction: nothing to correct, so emit a random
            # negative click on the same device argwhere would have used.
            point = torch.tensor(
                [np.random.randint(sz) for sz in to_point_mask[i].size()],
                dtype=torch.int64,
                device=pred_masks.device,
            )
            is_positive = False

        # point[0] indexes the leading axis of the per-sample mask and is dropped.
        bp = point[1:].clone().detach().reshape(1, 1, 3)
        bl = torch.tensor(
            [
                int(is_positive),
            ]
        ).reshape(1, 1)
        batch_points.append(bp)
        batch_labels.append(bl)

    return batch_points, batch_labels
|
|
|
|
def get_next_click3D_torch_with_dice(prev_seg, gt_semantic_seg):
    """Pick one corrective click per sample and report the mean Dice score.

    Click selection: a voxel is drawn at random from the false negatives
    (positive label) or the false positives (negative label) of the
    prediction (``prev_seg > 0.5``) against the ground truth
    (``gt_semantic_seg > 0``); when both exist, one kind is chosen by a coin
    flip. A sample without any error voxel receives a uniformly random
    negative click. Previously that case left ``point`` unbound, or silently
    reused the click of the preceding sample.

    Args:
        prev_seg (torch.tensor): segmentation output of the previous iteration.
        gt_semantic_seg (torch.tensor): ground-truth segmentation.

    Returns:
        batch_points (list of torch.tensor): one int64 tensor of shape
            (1, 1, 3) per sample with the click coordinates.
        batch_labels (list of torch.tensor): one tensor of shape (1, 1) per
            sample; 1 for a positive click, 0 for a negative one.
        float: Dice score between prediction and ground truth, averaged over
            the batch. A sample whose prediction and ground truth are both
            empty contributes NaN, which makes the average NaN.
    """

    def compute_dice(mask_pred, mask_gt):
        """Return the Dice coefficient of two masks as a 0-dim tensor."""
        mask_threshold = 0.5

        mask_pred = mask_pred > mask_threshold
        mask_gt = mask_gt > 0

        volume_sum = mask_gt.sum() + mask_pred.sum()
        if volume_sum == 0:
            # Return a tensor (not a Python float) so the caller's
            # sum(...)/len(...).item() works; np.NaN no longer exists in
            # NumPy 2.0.
            return torch.tensor(float("nan"), device=mask_pred.device)
        volume_intersect = (mask_gt & mask_pred).sum()
        return 2 * volume_intersect / volume_sum

    mask_threshold = 0.5

    batch_points = []
    batch_labels = []
    dice_list = []

    pred_masks = prev_seg > mask_threshold
    true_masks = gt_semantic_seg > 0
    fn_masks = torch.logical_and(true_masks, torch.logical_not(pred_masks))
    fp_masks = torch.logical_and(torch.logical_not(true_masks), pred_masks)

    for i in range(gt_semantic_seg.shape[0]):
        fn_points = torch.argwhere(fn_masks[i])
        fp_points = torch.argwhere(fp_masks[i])
        if len(fn_points) > 0 and len(fp_points) > 0:
            if np.random.random() > 0.5:
                point = fn_points[np.random.randint(len(fn_points))]
                is_positive = True
            else:
                point = fp_points[np.random.randint(len(fp_points))]
                is_positive = False
        elif len(fn_points) > 0:
            point = fn_points[np.random.randint(len(fn_points))]
            is_positive = True
        elif len(fp_points) > 0:
            point = fp_points[np.random.randint(len(fp_points))]
            is_positive = False
        else:
            # No error voxel: random negative click, on the same device the
            # argwhere results above live on.
            point = torch.tensor(
                [np.random.randint(sz) for sz in fn_masks[i].size()],
                dtype=torch.int64,
                device=pred_masks.device,
            )
            is_positive = False

        # point[0] indexes the leading axis of the per-sample mask and is dropped.
        bp = point[1:].clone().detach().reshape(1, 1, 3)
        bl = torch.tensor(
            [
                int(is_positive),
            ]
        ).reshape(1, 1)
        batch_points.append(bp)
        batch_labels.append(bl)
        dice_list.append(compute_dice(pred_masks[i], true_masks[i]))

    return batch_points, batch_labels, (sum(dice_list) / len(dice_list)).item()
|
|
|
|
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (height, width) and which
            can be reshaped to (height, width, 1).
        ax: object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color (bool, optional): use a random RGB colour instead of the
            fixed yellow one. The alpha value is 0.6 in both cases.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])

    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) colour.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
|
|
|
|
def show_point(point, label, ax):
    """Mark a click on ``ax`` with a radius-1 circle.

    Args:
        point: indexable click position; ``point[1]`` is used as the x
            coordinate and ``point[0]`` as the y coordinate of the circle.
        label: click label; 0 is drawn in red, anything else in green.
        ax: object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    marker_color = "red" if label == 0 else "green"
    center = (point[1], point[0])
    ax.add_patch(plt.Circle(center, 1, color=marker_color))
| |
|
|
|
|
if __name__ == "__main__":
    # Smoke test: sample one click per item for a random "ground truth" and an
    # all-zero previous prediction. Runs on the GPU when one is available and
    # falls back to the CPU otherwise, instead of failing on .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gt2D = torch.randn((2, 1, 256, 256)).to(device)
    prev_masks = torch.zeros_like(gt2D)  # zeros_like already matches gt2D's device
    batch_points, batch_labels = get_next_click3D_torch(prev_masks, gt2D)
    print(batch_points)
|
|