| import torch.nn as nn |
| import torch |
| from .general import bbox_iou |
| from .postprocess import build_targets |
| from lib.core.evaluate import SegmentationMetric |
|
|
class MultiHeadLoss(nn.Module):
    """
    Collect all the losses of the multi-task heads (detection, drivable-area
    segmentation, lane-line segmentation) and combine them into one scalar.
    """

    def __init__(self, losses, cfg, lambdas=None):
        """
        Inputs:
        - losses: (list)[nn.Module, nn.Module, ...]; `_forward_impl` unpacks
          exactly three entries as (BCEcls, BCEobj, BCEseg)
        - cfg: config object; `cfg.LOSS.*_GAIN` and `cfg.TRAIN.*_ONLY` are read
          every time the loss is evaluated
        - lambdas: (list) non-negative weight for each loss term, indexed as
          [cls, obj, box, da_seg, ll_seg, ll_iou]; when empty or None, every
          one of the `len(losses) + 3` terms gets weight 1.0
        """
        super().__init__()

        if not lambdas:
            lambdas = [1.0] * (len(losses) + 3)
        assert all(lam >= 0.0 for lam in lambdas), "loss weights must be non-negative"

        self.losses = nn.ModuleList(losses)
        self.lambdas = lambdas
        self.cfg = cfg

    def forward(self, head_fields, head_targets, shapes, model):
        """
        Inputs:
        - head_fields: (list) output from each task head
        - head_targets: (list) ground-truth for each task head
        - shapes: per-sample shape/padding records; only `shapes[0][1][1]`
          is read (as `(pad_w, pad_h)`)
        - model: the network; `model.gr` and `model.nc` are read here and the
          object is forwarded to `build_targets`

        Returns:
        - total_loss: sum of all the loss
        - head_losses: (tuple) contain all loss[loss1, loss2, ...]
        """
        return self._forward_impl(head_fields, head_targets, shapes, model)

    def _forward_impl(self, predictions, targets, shapes, model):
        """
        Compute every individual loss term and their weighted sum.

        Args:
            predictions: predicts of [[det_head1, det_head2, det_head3], drive_area_seg_head, lane_line_seg_head]
            targets: gts [det_targets, segment_targets, lane_targets]
            shapes: per-sample shape/padding records; `shapes[0][1][1]` gives
                the `(pad_w, pad_h)` cropped away before the lane IoU is measured
            model: the network (provides `gr`, `nc`; passed to `build_targets`)

        Returns:
            total_loss: sum of all the loss
            head_losses: tuple of Python scalars
                (lbox, lobj, lcls, lseg_da, lseg_ll, liou_ll, total)
        """
        cfg = self.cfg
        device = targets[0].device
        lcls = torch.zeros(1, device=device)
        lbox = torch.zeros(1, device=device)
        lobj = torch.zeros(1, device=device)
        tcls, tbox, indices, anchors = build_targets(cfg, predictions[0], targets[0], model)

        # Positive / negative class targets; eps=0.0 means no label smoothing.
        cp, cn = smooth_BCE(eps=0.0)

        BCEcls, BCEobj, BCEseg = self.losses

        no = len(predictions[0])  # number of detection output layers
        # Per-layer weights applied to the objectness loss.
        balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1]

        # ---- detection losses, one pass per detection layer ----
        for i, pi in enumerate(predictions[0]):
            b, a, gj, gi = indices[i]
            tobj = torch.zeros_like(pi[..., 0], device=device)  # objectness target

            n = b.shape[0]  # number of matched targets on this layer
            if n:
                ps = pi[b, a, gj, gi]  # predictions at the matched locations

                # Box regression
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1).to(device)
                iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)
                lbox += (1.0 - iou).mean()

                # Objectness target blends a constant 1 with the detached IoU
                # according to model.gr.
                tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype)

                # Classification loss is only meaningful with several classes.
                if model.nc > 1:
                    t = torch.full_like(ps[:, 5:], cn, device=device)
                    t[range(n), tcls[i]] = cp
                    lcls += BCEcls(ps[:, 5:], t)

            # Objectness is scored on every layer, matched targets or not.
            lobj += BCEobj(pi[..., 4], tobj) * balance[i]

        # ---- segmentation losses on flattened logits / targets ----
        drive_area_seg_predicts = predictions[1].view(-1)
        drive_area_seg_targets = targets[1].view(-1)
        lseg_da = BCEseg(drive_area_seg_predicts, drive_area_seg_targets)

        lane_line_seg_predicts = predictions[2].view(-1)
        lane_line_seg_targets = targets[2].view(-1)
        lseg_ll = BCEseg(lane_line_seg_predicts, lane_line_seg_targets)

        # ---- lane-line IoU term ----
        metric = SegmentationMetric(2)
        # Unpacking four values also checks that targets[1] is 4-dimensional.
        _, _, height, width = targets[1].shape
        # Only shapes[0] is consulted; the same padding is cropped from the
        # whole batch.
        pad_w, pad_h = shapes[0][1][1]
        pad_w = int(pad_w)
        pad_h = int(pad_h)
        _, lane_line_pred = torch.max(predictions[2], 1)
        _, lane_line_gt = torch.max(targets[2], 1)
        lane_line_pred = lane_line_pred[:, pad_h:height - pad_h, pad_w:width - pad_w]
        lane_line_gt = lane_line_gt[:, pad_h:height - pad_h, pad_w:width - pad_w]
        metric.reset()
        metric.addBatch(lane_line_pred.cpu(), lane_line_gt.cpu())
        IoU = metric.IntersectionOverUnion()
        # NOTE(review): built from argmax indices moved to CPU, so this term
        # presumably contributes no gradient - confirm.
        liou_ll = 1 - IoU

        # ---- apply gains and per-term weights ----
        s = 3 / no  # scale detection gains relative to a 3-layer head
        lcls *= cfg.LOSS.CLS_GAIN * s * self.lambdas[0]
        lobj *= cfg.LOSS.OBJ_GAIN * s * (1.4 if no == 4 else 1.) * self.lambdas[1]
        lbox *= cfg.LOSS.BOX_GAIN * s * self.lambdas[2]

        lseg_da *= cfg.LOSS.DA_SEG_GAIN * self.lambdas[3]
        lseg_ll *= cfg.LOSS.LL_SEG_GAIN * self.lambdas[4]
        liou_ll *= cfg.LOSS.LL_IOU_GAIN * self.lambdas[5]

        # ---- zero out the terms excluded by the active training mode ----
        # Multiplying by 0 (instead of assigning 0) keeps each term's type so
        # the `.item()` calls below still work.
        if cfg.TRAIN.DET_ONLY or cfg.TRAIN.ENC_DET_ONLY:
            lseg_da = 0 * lseg_da
            lseg_ll = 0 * lseg_ll
            liou_ll = 0 * liou_ll

        if cfg.TRAIN.SEG_ONLY or cfg.TRAIN.ENC_SEG_ONLY:
            lcls = 0 * lcls
            lobj = 0 * lobj
            lbox = 0 * lbox

        if cfg.TRAIN.LANE_ONLY:
            lcls = 0 * lcls
            lobj = 0 * lobj
            lbox = 0 * lbox
            lseg_da = 0 * lseg_da

        if cfg.TRAIN.DRIVABLE_ONLY:
            lcls = 0 * lcls
            lobj = 0 * lobj
            lbox = 0 * lbox
            lseg_ll = 0 * lseg_ll
            liou_ll = 0 * liou_ll

        loss = lbox + lobj + lcls + lseg_da + lseg_ll + liou_ll

        return loss, (lbox.item(), lobj.item(), lcls.item(), lseg_da.item(), lseg_ll.item(), liou_ll.item(), loss.item())
|
|
|
|
def get_loss(cfg, device):
    """
    Build the MultiHeadLoss described by the configuration.

    Inputs:
    -cfg: configuration; reads cfg.LOSS.CLS_POS_WEIGHT, OBJ_POS_WEIGHT,
          SEG_POS_WEIGHT, FL_GAMMA and MULTI_HEAD_LAMBDA
    -device: cpu or gpu device the criteria are moved to

    Returns:
    -loss: (MultiHeadLoss)

    """
    loss_cfg = cfg.LOSS

    def _weighted_bce(pos_weight):
        # BCE-with-logits criterion with a single positive-class weight.
        return nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_weight])).to(device)

    cls_criterion = _weighted_bce(loss_cfg.CLS_POS_WEIGHT)
    obj_criterion = _weighted_bce(loss_cfg.OBJ_POS_WEIGHT)
    seg_criterion = _weighted_bce(loss_cfg.SEG_POS_WEIGHT)

    # A positive gamma wraps the two detection criteria in focal loss;
    # the segmentation criterion is never wrapped.
    focal_gamma = loss_cfg.FL_GAMMA
    if focal_gamma > 0:
        cls_criterion = FocalLoss(cls_criterion, focal_gamma)
        obj_criterion = FocalLoss(obj_criterion, focal_gamma)

    return MultiHeadLoss(
        [cls_criterion, obj_criterion, seg_criterion],
        cfg=cfg,
        lambdas=loss_cfg.MULTI_HEAD_LAMBDA,
    )
|
|
| |
| |
|
|
|
|
def smooth_BCE(eps=0.1):
    """Return the (positive, negative) label-smoothed BCE targets for `eps`.

    The negative target is `0.5 * eps` and the positive target is one minus
    that value, so `eps=0` yields the hard targets `(1.0, 0.0)`.
    """
    negative_target = 0.5 * eps
    positive_target = 1.0 - negative_target
    return positive_target, negative_target
|
|
|
|
class FocalLoss(nn.Module):
    """Apply focal-loss weighting on top of an existing element-wise criterion."""

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        """
        Inputs:
        - loss_fcn: criterion to wrap; its `reduction` is remembered here and
          then switched to 'none' on the passed object so per-element losses
          can be re-weighted
        - gamma: exponent of the modulating factor
        - alpha: weight for positive targets (negatives get `1 - alpha`)
        """
        super().__init__()
        self.loss_fcn = loss_fcn
        self.gamma = gamma
        self.alpha = alpha
        # Remember the caller's reduction before overriding it on the wrapped loss.
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'

    def forward(self, pred, true):
        """Return the focal-weighted loss of logits `pred` against targets `true`."""
        elementwise = self.loss_fcn(pred, true)

        probability = torch.sigmoid(pred)
        # Probability assigned to the target value of each element.
        target_prob = true * probability + (1 - true) * (1 - probability)
        class_weight = true * self.alpha + (1 - true) * (1 - self.alpha)
        focusing = (1.0 - target_prob) ** self.gamma
        elementwise *= class_weight * focusing

        # Re-apply the reduction the wrapped criterion was configured with.
        if self.reduction == 'mean':
            return elementwise.mean()
        if self.reduction == 'sum':
            return elementwise.sum()
        return elementwise
|
|