| import torch |
| import torch.nn as nn |
| from rscd.losses.loss_func import CELoss, FocalLoss, dice_loss, BCEDICE_loss, LOVASZ |
| from rscd.losses.mask2formerLoss import Mask2formerLoss |
| from rscd.losses.RSMambaLoss import FCCDN_loss_without_seg |
|
|
| class myLoss(nn.Module): |
| def __init__(self, param, loss_name=['CELoss'], loss_weight=[1.0], **kwargs): |
| super(myLoss, self).__init__() |
| self.loss_weight = loss_weight |
| self.loss = list() |
| for _loss in loss_name: |
| self.loss.append(eval(_loss)(**param[_loss],**kwargs)) |
| |
| def forward(self, preds, target): |
| loss = 0 |
| for i in range(0, len(self.loss)): |
| loss += self.loss[i](preds, target) * self.loss_weight[i] |
| return loss |
| |
| def build_loss(cfg): |
| loss_type = cfg.pop('type') |
| obj_cls = eval(loss_type) |
| obj = obj_cls(**cfg) |
| return obj |
|
|