| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from typing import List |
|
|
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over a list of per-item mask logits.

    Each prediction is flattened, passed through a sigmoid and compared with
    the matching flattened ground truth. The result is the mean of
    ``1 - dice`` over all (prediction, ground-truth) pairs.

    Args:
        epsilon: Smoothing term added to both the numerator and the
            denominator of the Dice ratio to avoid division by zero.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, logits: List[torch.Tensor], gts: List[torch.Tensor]) -> torch.Tensor:
        """Compute the mean soft Dice loss.

        Args:
            logits: Raw (pre-sigmoid) predictions, one tensor per item.
            gts: Ground-truth masks, one per item. Each is flattened, cast to
                float32 and moved to its prediction's device.

        Returns:
            A scalar tensor. If ``logits`` is empty, this is a zero tensor on
            the device of ``gts[0]`` (or on CPU when ``gts`` is also empty).

        Raises:
            ValueError: If ``logits`` is non-empty and its length differs
                from that of ``gts``.
        """
        if len(logits) == 0:
            dev = gts[0].device if len(gts) else "cpu"
            return torch.tensor(0.0, device=dev)
        # zip() would silently drop unmatched items while the mean below still
        # divides by len(logits), so reject mismatched inputs explicitly.
        if len(logits) != len(gts):
            raise ValueError(
                f"DiceLoss: got {len(logits)} predictions but {len(gts)} ground truths"
            )
        total = 0.0
        for pred, gt in zip(logits, gts):
            # Compute in float32: under fp16/autocast, summing many
            # probabilities can exceed the fp16 maximum (65504) and become inf.
            p = pred.flatten().float().sigmoid()
            g = gt.flatten().to(p.device, dtype=torch.float)
            inter = (p * g).sum()
            denom = p.sum() + g.sum()
            dice = (2 * inter + self.epsilon) / (denom + self.epsilon)
            total += 1 - dice
        return total / len(logits)
|
|
class BCELoss(nn.Module):
    """Binary cross-entropy with logits, averaged over a list of masks.

    Each (prediction, ground-truth) pair is flattened and scored with
    ``F.binary_cross_entropy_with_logits`` (default mean reduction). The
    per-pair losses are then averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits: List[torch.Tensor], gts: List[torch.Tensor]) -> torch.Tensor:
        """Compute the mean BCE-with-logits loss.

        Args:
            logits: Raw (pre-sigmoid) predictions, one tensor per item.
            gts: Ground-truth masks, one per item. Each is flattened, moved
                to its prediction's device and cast to float32.

        Returns:
            A scalar tensor. If ``logits`` is empty, this is a zero tensor on
            the device of ``gts[0]`` (or on CPU when ``gts`` is also empty).

        Raises:
            ValueError: If ``logits`` is non-empty and its length differs
                from that of ``gts``.
        """
        if len(logits) == 0:
            dev = gts[0].device if len(gts) else "cpu"
            return torch.tensor(0.0, device=dev)
        # zip() would silently drop unmatched items while the mean below still
        # divides by len(logits), so reject mismatched inputs explicitly.
        if len(logits) != len(gts):
            raise ValueError(
                f"BCELoss: got {len(logits)} predictions but {len(gts)} ground truths"
            )
        total = 0.0
        for pred, gt in zip(logits, gts):
            # float32 on both sides keeps the loss stable under fp16 inputs and
            # accepts integer/bool ground-truth masks.
            total += F.binary_cross_entropy_with_logits(
                pred.flatten().float(),
                gt.flatten().to(pred.device).float(),
            )
        return total / len(logits)
|
|
class MaskLoss(nn.Module):
    """Weighted combination of the Dice and BCE mask losses.

    Args:
        dice_weight: Multiplier applied to the Dice term.
        bce_weight: Multiplier applied to the BCE term.
        epsilon: Smoothing constant forwarded to ``DiceLoss``.
    """

    def __init__(self, dice_weight=1.0, bce_weight=0.1, epsilon=1e-6):
        super().__init__()
        # Sub-losses are registered as child modules; the weights are plain
        # attributes.
        self.dice = DiceLoss(epsilon)
        self.bce = BCELoss()
        self.dw = dice_weight
        self.bw = bce_weight

    def forward(self, logits, gts):
        """Return ``dw * dice(logits, gts) + bw * bce(logits, gts)``."""
        weighted_dice = self.dw * self.dice(logits, gts)
        weighted_bce = self.bw * self.bce(logits, gts)
        return weighted_dice + weighted_bce