| import torch |
| import torch.nn as nn |
|
|
|
|
class dice_loss(nn.Module):
    """Soft Dice loss: ``1 - (2 * |P * T| + smooth) / (|P| + |T| + smooth)``.

    Args:
        batch: When True (default) the sums are taken over the whole batch
            at once, giving a single Dice score. When False a Dice score is
            computed per sample (dimension 0 is treated as the batch axis)
            and the scores are averaged.
    """

    def __init__(self, batch=True):
        super(dice_loss, self).__init__()
        self.batch = batch

    @staticmethod
    def _sum_per_sample(tensor):
        """Sum *tensor* over every dimension except the first one."""
        reduce_dims = tuple(range(1, tensor.dim()))
        # A 1-D tensor already holds one value per sample; nothing to reduce.
        return tensor.sum(dim=reduce_dims) if reduce_dims else tensor

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the (mean) soft Dice coefficient of *y_pred* vs *y_true*.

        Both tensors must be broadcast-compatible. In per-sample mode any
        rank >= 1 is accepted (previously only 4-D tensors worked).
        """
        # Small constant that keeps the ratio defined when both inputs are
        # all zeros.
        smooth = 0.00001
        overlap = y_true * y_pred
        if self.batch:
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(overlap)
        else:
            i = self._sum_per_sample(y_true)
            j = self._sum_per_sample(y_pred)
            intersection = self._sum_per_sample(overlap)

        score = (2. * intersection + smooth) / (i + j + smooth)
        # In batch mode ``score`` is already 0-dim, so mean() is a no-op;
        # in per-sample mode it averages the per-sample scores.
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Compute the Dice loss; *y_pred* is cast to float32 first.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module goes through ``nn.Module.__call__`` and its hooks.
        """
        return self.soft_dice_loss(y_pred.to(dtype=torch.float32), y_true)
|
|
|
|
class dice_focal_loss(nn.Module):
    """Pair of losses for binary logits: soft Dice and BCE-with-logits.

    NOTE: despite the class and attribute names, the second term is
    ``nn.BCEWithLogitsLoss`` with default settings, not a focal loss.
    """

    def __init__(self):
        super(dice_focal_loss, self).__init__()
        # Attribute names (including the "binnary" spelling) are kept as-is
        # so external code that reads them keeps working.
        self.focal_loss = nn.BCEWithLogitsLoss()
        self.binnary_dice = dice_loss()

    def forward(self, scores, labels):
        """Return ``(dice_loss, bce_loss)`` for raw logits *scores*.

        Args:
            scores: Raw (pre-sigmoid) predictions.
            labels: Targets; passed unchanged to both losses, so they must
                be a floating-point tensor accepted by ``BCEWithLogitsLoss``.

        Returns:
            Tuple ``(diceloss, foclaloss)``; the two terms are NOT summed
            here.
        """
        # Neither sigmoid nor BCEWithLogitsLoss modifies its input in place,
        # so no defensive clone of ``scores`` is needed.
        diceloss = self.binnary_dice(torch.sigmoid(scores), labels)
        foclaloss = self.focal_loss(scores, labels)
        return diceloss, foclaloss
|
|
class FCCDN_loss_without_seg(nn.Module):
    """Change-detection loss for the binary task: Dice term + BCE term."""

    def __init__(self):
        super(FCCDN_loss_without_seg, self).__init__()
        # Build the criterion once instead of re-creating it on every call.
        self.criterion_change = dice_focal_loss()

    def forward(self, scores, labels):
        """Return the scalar sum of the Dice and BCE-with-logits losses.

        Args:
            scores: Raw logits. If the tensor has more than 3 dimensions,
                dimension 1 is squeezed (a no-op unless its size is 1).
            labels: Targets; treated like ``scores`` and then cast to float.
        """
        # Drop a singleton dim 1 so scores and labels line up
        # (presumably (B, 1, H, W) -> (B, H, W); verify against callers).
        if scores.dim() > 3:
            scores = scores.squeeze(1)
        if labels.dim() > 3:
            labels = labels.squeeze(1)

        diceloss, foclaloss = self.criterion_change(scores, labels.float())
        loss_change = diceloss + foclaloss
        return loss_change.mean()
|
|