| import numpy as np
|
| import itertools
|
| from typing import Any, Dict, List, Tuple, Union
|
|
|
|
|
|
|
|
|
|
|
| """
|
| MaskFormer criterion.
|
| """
|
| import torch
|
| import torch.nn.functional as F
|
| from torch import nn
|
|
|
| from rscd.losses.loss_util.criterion import SetCriterion
|
| from rscd.losses.loss_util.matcher import HungarianMatcher
|
|
|
class Mask2formerLoss(nn.Module):
    """Scalar training loss built on ``HungarianMatcher`` and ``SetCriterion``.

    On every call a matcher and a criterion are constructed from the stored
    hyper-parameters, the predicted masks are upsampled, the criterion is
    evaluated, and its weighted ``*_ce`` / ``*_dice`` / ``*_mask`` terms are
    summed into a single value.

    Args:
        class_weight: Used both as the matcher's ``cost_class`` and as the
            weight of the ``loss_ce`` terms.
        dice_weight: Used both as the matcher's ``cost_dice`` and as the
            weight of the ``loss_dice`` terms.
        mask_weight: Used both as the matcher's ``cost_mask`` and as the
            weight of the ``loss_mask`` terms.
        no_object_weight: Forwarded to ``SetCriterion`` as ``eos_coef``.
        dec_layers: Weights are generated for ``dec_layers - 1`` auxiliary
            outputs (keys suffixed ``_0`` .. ``_{dec_layers - 2}``).
        num_classes: Forwarded to ``SetCriterion``.
        device: Anything accepted by ``torch.device``; forwarded to
            ``SetCriterion``.
        num_points: Forwarded to both the matcher and the criterion.
        oversample_ratio: Forwarded to ``SetCriterion``.
        importance_sample_ratio: Forwarded to ``SetCriterion``.
        mask_upsample_factor: Factor applied to both spatial dimensions of
            every ``pred_masks`` tensor before the criterion is evaluated.
    """

    def __init__(self, class_weight=2.0,
                 dice_weight=5.0,
                 mask_weight=5.0,
                 no_object_weight=0.1,
                 dec_layers=10,
                 num_classes=1,
                 device="cuda:0",
                 num_points=12544,
                 oversample_ratio=3.0,
                 importance_sample_ratio=0.75,
                 mask_upsample_factor=4):
        super().__init__()
        self.device = device
        self.class_weight = class_weight
        self.dice_weight = dice_weight
        self.mask_weight = mask_weight
        self.no_object_weight = no_object_weight
        self.dec_layers = dec_layers
        self.num_classes = num_classes
        # Previously hard-coded inside ``forward``; the defaults reproduce
        # the old values exactly.
        self.num_points = num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio
        self.mask_upsample_factor = mask_upsample_factor

    def _build_weight_dict(self) -> Dict[str, float]:
        """Return the weights for the main loss terms and every auxiliary copy."""
        base = {
            "loss_ce": self.class_weight,
            "loss_mask": self.mask_weight,
            "loss_dice": self.dice_weight,
        }
        weights = dict(base)
        for layer in range(self.dec_layers - 1):
            weights.update({f"{name}_{layer}": w for name, w in base.items()})
        return weights

    def _build_criterion(self):
        """Construct the matcher and the ``SetCriterion`` that consumes it."""
        matcher = HungarianMatcher(
            cost_class=self.class_weight,
            cost_mask=self.mask_weight,
            cost_dice=self.dice_weight,
            num_points=self.num_points,
        )
        return SetCriterion(
            num_classes=self.num_classes,
            matcher=matcher,
            weight_dict=self._build_weight_dict(),
            eos_coef=self.no_object_weight,
            losses=["labels", "masks"],
            num_points=self.num_points,
            oversample_ratio=self.oversample_ratio,
            importance_sample_ratio=self.importance_sample_ratio,
            device=torch.device(self.device),
        )

    def _upsample(self, masks):
        """Bilinearly enlarge ``masks`` by ``mask_upsample_factor`` per axis."""
        factor = self.mask_upsample_factor
        return F.interpolate(
            masks,
            scale_factor=(factor, factor),
            mode="bilinear",
            align_corners=False,
        )

    def _upsample_pred_masks(self, preds) -> None:
        """Replace every ``pred_masks`` entry of ``preds`` with its upsampled version.

        The write-back is in place, so the caller's dict (and each auxiliary
        dict) holds the enlarged masks afterwards.
        """
        preds["pred_masks"] = self._upsample(preds["pred_masks"])
        # "aux_outputs" is optional: a missing or empty entry simply means
        # there is nothing extra to upsample (this used to raise KeyError).
        for aux in preds.get("aux_outputs") or ():
            aux["pred_masks"] = self._upsample(aux["pred_masks"])

    @staticmethod
    def _weighted_total(losses, weight_dict):
        """Sum the weighted loss terms, ignoring names absent from ``weight_dict``.

        Terms are accumulated per family (``_ce``, ``_dice``, ``_mask``) and
        the three partial sums are added at the end. Returns the float ``0.0``
        when no term qualifies.
        """
        loss_ce = 0.0
        loss_dice = 0.0
        loss_mask = 0.0
        for name, value in losses.items():
            if name not in weight_dict:
                # No weight configured for this term: it does not contribute.
                continue
            weighted = value * weight_dict[name]
            if "_ce" in name:
                loss_ce += weighted
            elif "_dice" in name:
                loss_dice += weighted
            elif "_mask" in name:
                loss_mask += weighted
        return loss_ce + loss_dice + loss_mask

    def forward(self, preds, target):
        """Compute the total weighted loss for one batch.

        Args:
            preds: Dict with a ``"pred_masks"`` tensor and, optionally, an
                ``"aux_outputs"`` list of dicts that each hold their own
                ``"pred_masks"``. The tensors must be 4-D, as required by
                bilinear interpolation. NOTE: the mask tensors are replaced
                in place by their upsampled versions.
            target: Passed through unchanged to ``SetCriterion``.

        Returns:
            The sum of all weighted classification, dice and mask terms
            reported by the criterion.
        """
        # Built per call, as before, so nothing is registered on this module.
        criterion = self._build_criterion()
        self._upsample_pred_masks(preds)
        raw_losses = criterion(preds, target)
        return self._weighted_total(raw_losses, criterion.weight_dict)