| print("Importing standard...") |
| from abc import ABC, abstractmethod |
|
|
| print("Importing external...") |
| import torch |
| from torch.nn.functional import binary_cross_entropy |
|
|
| |
|
|
| print("Importing internal...") |
| from utils import preprocess_masks_features, get_row_col, symlog, calculate_iou |
|
|
|
|
| |
def my_lovasz_hinge(logits, gt, downsample=False):
    """Batched Lovász hinge loss for binary segmentation.

    Implements the Lovász extension of the Jaccard (IoU) error applied
    to hinge margins, averaged over the batch with NaN entries ignored.

    Args:
        logits: (B, P) real-valued per-pixel scores.
        gt: (B, P) binary ground truth (any dtype castable to float).
        downsample: falsy to use every pixel; an integer stride to score
            only a random phase-offset subsample of the pixels.

    Returns:
        Scalar tensor: nan-mean of the per-sample Lovász hinge losses.
    """
    if downsample:
        # Random phase so different pixel subsets are used across calls.
        # NOTE(review): randint(downsample - 1) never yields phase
        # downsample - 1 — confirm this off-by-one is intentional.
        start = int(torch.randint(downsample - 1, (1,)))
        logits = logits[:, start::downsample]
        gt = gt[:, start::downsample]

    gt = gt * 1.0
    pos_area = gt.sum(dim=1, keepdim=True)

    # Hinge margins: 1 - logit * sign, with sign in {-1, +1}.
    signs = 2 * gt - 1
    margins = 1 - logits * signs
    margins_desc, order = torch.sort(margins, dim=1, descending=True)
    gt_desc = torch.gather(gt, 1, order)

    # Jaccard error after including the first k sorted pixels.
    inter = pos_area - gt_desc.cumsum(dim=1)
    union = pos_area + (1 - gt_desc).cumsum(dim=1)
    iou_err = 1 - inter / union
    # Discrete gradient of the Jaccard error along the sorted axis; the
    # first column is kept as-is (equivalent to the in-place diff form).
    grad = torch.cat([iou_err[:, :1], iou_err[:, 1:] - iou_err[:, :-1]], dim=1)
    per_sample = (torch.relu(margins_desc) * grad).sum(dim=1)
    return torch.nanmean(per_sample)
|
|
|
|
def focal_loss(scores, targets, alpha=0.25, gamma=2):
    """Elementwise binary focal loss on probabilities.

    Args:
        scores: predicted probabilities in [0, 1].
        targets: binary ground-truth labels, same shape as ``scores``.
        alpha: balance weight applied to the positive class; any
            negative value disables the alpha weighting.
        gamma: focusing exponent that down-weights easy examples.

    Returns:
        Unreduced tensor of per-element focal-loss values.
    """
    bce = binary_cross_entropy(scores, targets, reduction="none")
    # Probability the model assigned to the true class of each element.
    true_class_prob = scores * targets + (1 - scores) * (1 - targets)
    focal = bce * (1 - true_class_prob) ** gamma
    if alpha < 0:
        return focal
    weight = alpha * targets + (1 - alpha) * (1 - targets)
    return weight * focal
|
|
|
|
| |
|
|
|
|
| |
def get_distances(features, refs, sigma, norm_p, square_distances, H, W):
    """Scaled per-reference distance maps between features and refs.

    Args:
        features: feature tensor broadcastable against ``refs``; the
            p-norm is taken over dim 2 (the channel axis).
        refs: reference vectors whose leading dims give (B, M).
        sigma: bandwidth (scalar or tensor); distances are divided by
            ``2 * sigma**2``.
        norm_p: order of the vector norm.
        square_distances: if True, square the norm before scaling.
        H, W: spatial extent used to flatten the result.

    Returns:
        Tensor of shape (B, M, H*W).
    """
    batch, n_refs = refs.shape[0], refs.shape[1]
    dist = torch.norm(features - refs, dim=2, p=norm_p, keepdim=True)
    if square_distances:
        dist = dist * dist
    scaled = dist / (2 * sigma**2)
    return scaled.reshape(batch, n_refs, H * W)
|
|
|
|
def activate(features, masks, activation, use_sigma, offset_pos, ret_prediction):
    """Turn raw network features into an activated embedding space.

    Two calling modes:
      * ``masks is None`` (single-image inference): ``features`` is
        reshaped to (1, 1, F, H*W) and the function returns the 4-tuple
        ``(features, sigma, H, W)`` with features flattened to (F, H*W).
      * otherwise (training): shapes come from
        ``preprocess_masks_features`` and the 9-tuple
        ``(features, masks, sigma, prediction, B, M, F, H, W)`` is
        returned.

    Args:
        features: raw feature tensor.
        masks: instance masks, or None for the single-image path.
        activation: "sigmoid" or "symlog"; elementwise squashing applied
            to the feature channels.
        use_sigma: if True, the last channel is consumed as a per-pixel
            softplus bandwidth instead of an embedding channel.
        offset_pos: if True, add row/col coordinate grids to the first
            two feature channels.
        ret_prediction: if True, also build a (B, 1, -1, H, W) view of
            the activated features (e.g. for visualization).
    """
    assert activation in ["sigmoid", "symlog"]
    if masks is None:
        B, M = 1, 1
        # sorted() picks the smaller dim as the channel count F; assumes
        # a 2-D (F, N) input with F < N — TODO confirm with callers.
        F, N = sorted(features.shape)
        # Assumes the flattened spatial size N is a perfect square.
        H, W = [int(N ** (0.5))] * 2
        features = features.reshape(1, 1, -1, H * W)
    else:
        masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)

    if use_sigma:
        # Last channel -> strictly positive bandwidth via softplus; the
        # remaining channels stay as the embedding.
        sigma = torch.nn.functional.softplus(features)[:, :, -1:]
        features = features[:, :, :-1]
        F = features.shape[2]
    else:
        sigma = 1
    features = symlog(features) if activation == "symlog" else torch.sigmoid(features)
    if offset_pos:
        assert F >= 2
        # Broadcast row/col grids to (B, 1, 1, H*W) each and add them to
        # the first two embedding channels so the space is position-aware.
        row, col = get_row_col(H, W, features.device)
        row = row.reshape(1, 1, 1, H, 1).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        col = col.reshape(1, 1, 1, 1, W).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        positional_features = torch.cat([row, col], dim=2)
        features[:, :, :2] = features[:, :, :2] + positional_features
    prediction = features.reshape(B, 1, -1, H, W) if ret_prediction else None
    if masks is None:
        # Inference path: drop the leading singleton dims.
        features = features.reshape(-1, H * W)
        sigma = sigma.reshape(-1, H * W) if use_sigma else 1
        return features, sigma, H, W
    return features, masks, sigma, prediction, B, M, F, H, W
|
|
|
|
class AbstractLoss(ABC):
    """Interface every loss implementation in this module must satisfy.

    Concrete subclasses provide a training loss and the matching
    inference routine that turns a single query pixel into a mask.
    """

    @staticmethod
    @abstractmethod
    def loss(features, masks, ret_prediction=False, **kwargs):
        """Compute the training loss; return (loss, prediction-or-None)."""

    @staticmethod
    @abstractmethod
    def get_mask_from_query(features, sindex, **kwargs):
        """Predict a binary mask from features and a query pixel index."""
|
|
|
|
class IISLoss(AbstractLoss):
    """Loss for interactive (click-based) segmentation embeddings.

    Pixels are embedded so that a Gaussian kernel around a clicked
    pixel's embedding reproduces the object mask containing that click.
    """

    @staticmethod
    def loss(features, masks, ret_prediction=False, K=3, logger=None):
        """Focal + Lovász-hinge loss over K random clicks per mask.

        Args:
            features: raw network embeddings (shapes handled by `activate`).
            masks: ground-truth instance masks.
            ret_prediction: if True, also return the activated feature map.
            K: number of click pixels sampled inside each mask.
            logger: unused; kept for interface compatibility.

        Returns:
            (loss, prediction) where prediction is None unless requested.
        """
        # symlog activation, no learned sigma, no positional offset.
        features, masks, sigma, prediction, B, M, F, H, W = activate(
            features, masks, "symlog", False, False, ret_prediction
        )
        # One random ordering of all pixels, shared across (batch, mask).
        rindices = torch.randperm(H * W, device=masks.device)

        # For each mask, the first K shuffled indices that land inside
        # the mask become the simulated clicks -> sindices is (B, M, K).
        # NOTE(review): breaks if any mask has fewer than K pixels —
        # confirm upstream guarantees enough foreground per mask.
        sindices = torch.stack(
            [
                torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
                for b in range(B)
            ]
        )
        # Gather the embedding vector at each click.
        # presumably features comes back as (B, M, F, H*W) with singleton
        # dims making the expand a pure broadcast — TODO confirm.
        feats_at_sindices = torch.gather(
            features.permute(0, 3, 1, 2).expand(B, H * W, K, F),
            dim=1,
            index=sindices.reshape(B, M, K, 1).expand(B, M, K, F),
        )
        feats_at_sindices = feats_at_sindices.reshape(B, M, K, F, 1)
        # Squared L2 distance from every pixel to every click embedding.
        dists = get_distances(
            features, feats_at_sindices.reshape(B, M * K, F, 1), sigma, 2, True, H, W
        )
        # Gaussian kernel: per-click soft mask scores in (0, 1].
        score = torch.exp(-dists)
        # Each click's target is the mask it was sampled from, repeated
        # K times to align with the M*K click axis.
        targets = (
            masks.expand(B, M, K, H * W).reshape(B, M * K, H * W).float()
        )
        floss = focal_loss(score, targets).mean()
        # Map (0, 1] scores to [-1, 1] pseudo-logits for the hinge loss.
        lloss = my_lovasz_hinge(
            score.view(B * M * K, H * W) * 2 - 1,
            targets.view(B * M * K, H * W),
        )
        loss = floss + lloss
        return loss, prediction

    @staticmethod
    def get_mask_from_query(features, sindex):
        """Predict a binary mask from a single query pixel.

        Args:
            features: raw feature map for one image (handled by the
                masks-is-None path of `activate`).
            sindex: flat pixel index of the query click.

        Returns:
            Boolean tensor (shape (1, 1, H*W)): pixels whose Gaussian
            affinity to the query embedding exceeds 0.5.
        """
        features, _, H, W = activate(features, None, "symlog", False, False, False)
        F = features.shape[0]
        query_feat = features[:, sindex]
        dists = get_distances(
            features.reshape(1, 1, F, H * W),
            query_feat.reshape(1, 1, F, 1),
            1,
            2,
            True,
            H,
            W,
        )
        score = torch.exp(-dists)
        pred = score > 0.5
        return pred
|
|
|
|
def iis_iou(features, masks, get_mask_from_query, K=20):
    """Estimate single-click segmentation quality as a mean IoU.

    For every (batch, mask) pair, K query pixels are drawn uniformly at
    random from inside the ground-truth mask; each query is converted to
    a predicted mask via ``get_mask_from_query`` and scored against the
    ground truth with ``calculate_iou``.

    Args:
        features: embedding tensor (consumed by preprocess_masks_features).
        masks: ground-truth masks (same preprocessing).
        get_mask_from_query: callable (features, pixel_index) -> mask.
        K: number of query pixels sampled per mask.

    Returns:
        Mean IoU over all B * M * K sampled queries.
    """
    masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)

    # One shared random ordering of all pixels; per mask, the first K
    # entries of that ordering inside the mask are the sampled queries.
    shuffled = torch.randperm(H * W).to(masks.device)
    query_indices = torch.stack(
        [
            torch.stack([shuffled[masks[b, m, 0, shuffled]][:K] for m in range(M)])
            for b in range(B)
        ]
    )
    scores = [
        calculate_iou(
            get_mask_from_query(features[b, 0], query_indices[b, m, k]),
            masks[b, m, 0, :],
        )
        for b in range(B)
        for m in range(M)
        for k in range(K)
    ]
    return sum(scores) / len(scores)
|
|
|
|
# Loss identifiers accepted by get_loss_class() below.
losses_names = [
    "iis",
]
| |
|
|
|
|
def get_loss_class(loss_name):
    """Map a loss identifier to its implementing class.

    Raises:
        NotImplementedError: if *loss_name* is not a known loss.
    """
    registry = {"iis": IISLoss}
    if loss_name not in registry:
        raise NotImplementedError
    return registry[loss_name]
|
|
|
|
def get_get_mask_from_query(loss_name):
    """Resolve *loss_name* to its class's query-to-mask function."""
    return get_loss_class(loss_name).get_mask_from_query
|
|
|
|
def get_loss(loss_name):
    """Resolve *loss_name* to its class's loss function."""
    return get_loss_class(loss_name).loss
|
|