| from typing import Any |
|
|
| from pytorch_toolbelt.losses import BinaryFocalLoss |
| from torch import nn |
| from torch.nn.modules.loss import BCEWithLogitsLoss |
|
|
|
|
class WeightedLosses(nn.Module):
    """Weighted sum of several loss criteria evaluated on the same inputs.

    Every call evaluates each criterion with the same positional and keyword
    arguments and returns ``sum(w_i * loss_i(*args, **kwargs))``.

    Args:
        losses: Iterable of loss callables. When every entry is an
            ``nn.Module`` they are stored in an ``nn.ModuleList``, so their
            parameters/buffers follow ``.to()`` / ``.cuda()`` and
            ``train()`` / ``eval()`` propagates to them. Otherwise they are
            kept in a plain list.
        weights: Iterable of multipliers, one per loss.

    Raises:
        ValueError: If ``losses`` and ``weights`` have different lengths
            (``zip`` would otherwise silently drop the unmatched entries).
    """

    def __init__(self, losses, weights):
        super().__init__()
        # Materialise both so one-shot iterables (generators) are not
        # exhausted after the first forward pass.
        losses = list(losses)
        weights = list(weights)
        if len(losses) != len(weights):
            raise ValueError(
                f"Expected one weight per loss, got {len(losses)} losses "
                f"and {len(weights)} weights"
            )
        if all(isinstance(loss, nn.Module) for loss in losses):
            # Register as submodules; a plain list hides them from nn.Module
            # bookkeeping (device/dtype moves, train/eval, parameters()).
            self.losses = nn.ModuleList(losses)
        else:
            self.losses = losses
        self.weights = weights

    def forward(self, *args: Any, **kwargs: Any):
        """Return the weighted sum of all losses for the given arguments.

        Returns the integer ``0`` when no losses were configured.
        """
        cum_loss = 0
        for loss, weight in zip(self.losses, self.weights):
            # Call the object itself (not ``.forward``) so nn.Module hooks run.
            cum_loss += weight * loss(*args, **kwargs)
        return cum_loss
|
|
|
|
class BinaryCrossentropy(BCEWithLogitsLoss):
    """Alias of ``BCEWithLogitsLoss`` under a project-specific name.

    It adds no behaviour: inputs are raw logits and the constructor accepts
    exactly the same arguments as ``BCEWithLogitsLoss``.
    """
|
|
|
|
class FocalLoss(BinaryFocalLoss):
    """``BinaryFocalLoss`` whose only change is a default ``gamma`` of 3.

    Every argument is forwarded unchanged to ``BinaryFocalLoss``. See the
    pytorch_toolbelt documentation for the meaning of each parameter.
    """

    def __init__(
        self,
        alpha=None,
        gamma=3,
        ignore_index=None,
        reduction="mean",
        normalized=False,
        reduced_threshold=None,
    ):
        super().__init__(
            alpha=alpha,
            gamma=gamma,
            ignore_index=ignore_index,
            reduction=reduction,
            normalized=normalized,
            reduced_threshold=reduced_threshold,
        )