File size: 5,721 Bytes
# Adapted from https://github.com/EPFL-VILAB/omnidata/blob/1af855042a05778d029d420b2a4bc1b9a0c09f30/omnidata_tools/torch/losses/midas_loss.py#L10
# Only lines 154-155 of the upstream file were changed, because our ground truth (gt) is 1/depth rather than depth.
# Based on https://gist.github.com/dvdhfnr/732c26b61a0e63a0abc8a5d769dbebd0
import torch
import torch.nn as nn
import numpy as np
def masked_l1_loss(preds, target, mask_valid):
    """Mean absolute error between ``preds`` and ``target`` over valid elements.

    Args:
        preds: Predicted tensor.
        target: Target tensor with the same shape as ``preds``.
        mask_valid: Boolean tensor selecting the elements that contribute.
            It must be ``torch.bool``: ``~`` is used as logical NOT.

    Returns:
        Scalar tensor ``sum(|preds - target|[mask]) / count(mask)``.
        When no element is valid the result is ``0`` instead of ``0 / 0 = NaN``.
    """
    element_wise_loss = torch.abs(preds - target)
    element_wise_loss[~mask_valid] = 0
    # With an empty mask the numerator is already 0, so clamping the count to
    # at least 1 yields a loss of 0 instead of NaN (same convention as
    # reduction_batch_based).
    return element_wise_loss.sum() / mask_valid.sum().clamp(min=1)
def compute_scale_and_shift(prediction, target, mask):
    """Fit a per-item scale and shift that align ``prediction`` to ``target``.

    For every index of dim 0, solve the 2x2 normal equations of the masked
    least-squares problem ``min_{s, t} sum(mask * (s * prediction + t - target) ** 2)``.
    The sums are taken over dims 1 and 2.

    Returns:
        ``(scale, shift)``: two tensors with one entry per dim-0 item.
        Items whose system matrix is singular (determinant exactly 0) get
        ``(0, 0)``. Note that ``1e-6`` is added to the determinant before dividing.
    """
    dims = (1, 2)

    # Symmetric system matrix [[a_00, a_01], [a_01, a_11]].
    a_00 = (mask * prediction * prediction).sum(dims)
    a_01 = (mask * prediction).sum(dims)
    a_11 = mask.sum(dims)

    # Right-hand side [b_0, b_1].
    b_0 = (mask * prediction * target).sum(dims)
    b_1 = (mask * target).sum(dims)

    det = a_00 * a_11 - a_01 * a_01
    scale = torch.zeros_like(b_0)
    shift = torch.zeros_like(b_1)

    # Closed-form 2x2 inverse, applied only where the matrix is invertible.
    ok = det != 0
    denom = det[ok] + 1e-6
    scale[ok] = (a_11[ok] * b_0[ok] - a_01[ok] * b_1[ok]) / denom
    shift[ok] = (a_00[ok] * b_1[ok] - a_01[ok] * b_0[ok]) / denom
    return scale, shift
def _masked_median_and_mad(values, mask_valid, count):
    """Return a masked median shift and a mean-absolute-deviation scale per (dim0, dim1) slice.

    Both results have shape ``(values.shape[0], values.shape[1], 1, 1)`` so they
    broadcast against a 4-D ``values``.
    """
    # Invalid entries become NaN so that nanmedian ignores them.
    values_nan = values.clone()
    values_nan[~mask_valid] = float("nan")
    # reshape (not view): inputs may be non-contiguous, and clone() keeps strides.
    flat = values_nan.reshape(*values_nan.shape[:2], -1)
    median = flat.nanmedian(-1, keepdim=True)[0].unsqueeze(-1)
    # A slice with no valid entry has a NaN median; fall back to 0.
    median[torch.isnan(median)] = 0

    abs_dev = torch.abs(values - median)
    abs_dev[~mask_valid] = 0
    mad = (abs_dev.reshape(*abs_dev.shape[:2], -1).sum(-1, keepdim=True) / count).unsqueeze(-1)
    return median, mad


def masked_shift_and_scale(depth_preds, depth_gt, mask_valid):
    """Normalise predictions and ground truth by their masked median and deviation.

    Each input is shifted by its masked median and divided by its masked mean
    absolute deviation, computed per (dim 0, dim 1) slice over all remaining
    elements. The code is written for 4-D tensors (presumably (B, C, H, W)).

    Args:
        depth_preds: Predicted tensor.
        depth_gt: Ground-truth tensor with the same shape.
        mask_valid: Boolean tensor with the same shape; True marks valid entries.

    Returns:
        ``(depth_pred_aligned, depth_gt_aligned)``, each with the input's shape.
    """
    # Valid-entry count per slice. The +1 guards against division by zero for
    # empty slices (kept from upstream; it slightly shrinks the deviation).
    count = mask_valid.reshape(*mask_valid.shape[:2], -1).sum(-1, keepdim=True) + 1

    t_gt, s_gt = _masked_median_and_mad(depth_gt, mask_valid, count)
    depth_gt_aligned = (depth_gt - t_gt) / (s_gt + 1e-6)

    t_pred, s_pred = _masked_median_and_mad(depth_preds, mask_valid, count)
    depth_pred_aligned = (depth_preds - t_pred) / (s_pred + 1e-6)
    return depth_pred_aligned, depth_gt_aligned
def reduction_batch_based(image_loss, M):
    """Average the loss over every valid pixel of the whole batch.

    Args:
        image_loss: Per-item summed loss (as produced by ``gradient_loss``).
        M: Per-item valid-pixel counts.

    Returns:
        ``sum(image_loss) / sum(M)``, or the int ``0`` when ``sum(M) == 0``
        (no valid pixel anywhere, so no division is attempted).
    """
    total_valid = M.sum()
    if total_valid == 0:
        return 0
    return image_loss.sum() / total_valid
def reduction_image_based(image_loss, M):
    """Mean over items of each item's loss normalised by its valid-pixel count.

    Args:
        image_loss: Per-item summed loss (as produced by ``gradient_loss``).
            It is no longer modified in place.
        M: Per-item valid-pixel counts.

    Returns:
        ``mean(image_loss / M)``. Items with ``M == 0`` are left un-normalised
        (their loss is used as-is), which avoids a division by zero.
    """
    # Dividing by 1 where M == 0 leaves those entries unchanged. This gives the
    # same values as the old masked in-place division, without mutating the
    # caller's tensor and without ever forming inf/NaN (so gradients stay finite).
    safe_M = torch.where(M != 0, M, torch.ones_like(M))
    return torch.mean(image_loss / safe_M)
def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
    """Gradient-matching loss between ``prediction`` and ``target`` on valid pixels.

    The masked residual ``mask * (prediction - target)`` is finite-differenced
    along dim 2 ("x") and dim 1 ("y"). A difference counts only when both
    neighbouring pixels are valid.

    Args:
        prediction: Predicted tensor, presumably (B, H, W) — TODO confirm.
        target: Target tensor with the same shape.
        mask: Validity mask with the same shape.
        reduction: Callable ``(image_loss, M)`` that turns per-item sums into a loss.

    Returns:
        ``reduction(per-item sum of |d residual| over dims 1 and 2,
        per-item sum of mask over dims 1 and 2)``.
    """
    residual = mask * (prediction - target)

    pair_x = mask[:, :, 1:] * mask[:, :, :-1]
    grad_x = pair_x * (residual[:, :, 1:] - residual[:, :, :-1]).abs()

    pair_y = mask[:, 1:, :] * mask[:, :-1, :]
    grad_y = pair_y * (residual[:, 1:, :] - residual[:, :-1, :]).abs()

    per_item = grad_x.sum((1, 2)) + grad_y.sum((1, 2))
    valid_count = mask.sum((1, 2))
    return reduction(per_item, valid_count)
class SSIMAE(nn.Module):
    """Scale-and-shift-invariant mean absolute error.

    Both inputs are first normalised by ``masked_shift_and_scale`` (a
    median-based shift and a deviation-based scale over valid pixels). They
    are then compared with ``masked_l1_loss`` on the same mask.
    """

    def __init__(self):
        super().__init__()

    def forward(self, depth_preds, depth_gt, mask_valid):
        """Return the masked L1 distance between the normalised inputs."""
        aligned_pair = masked_shift_and_scale(depth_preds, depth_gt, mask_valid)
        return masked_l1_loss(*aligned_pair, mask_valid)
class GradientMatchingTerm(nn.Module):
    """Multi-scale gradient-matching loss.

    Sums ``gradient_loss`` over ``scales`` levels. Level ``k`` subsamples
    dims 1 and 2 of every input with stride ``2 ** k``.
    """

    def __init__(self, scales=4, reduction='batch-based'):
        """
        Args:
            scales: Number of levels (strides 1, 2, 4, ...).
            reduction: ``'batch-based'`` selects ``reduction_batch_based``;
                any other value selects ``reduction_image_based``.
        """
        super().__init__()
        self.__reduction = (
            reduction_batch_based if reduction == 'batch-based' else reduction_image_based
        )
        self.__scales = scales

    def forward(self, prediction, target, mask):
        """Return the sum of ``gradient_loss`` over strides 1, 2, ..., 2 ** (scales - 1).

        Returns the int ``0`` when ``scales`` is 0.
        """
        strides = [2 ** level for level in range(self.__scales)]
        return sum(
            gradient_loss(
                prediction[:, ::s, ::s],
                target[:, ::s, ::s],
                mask[:, ::s, ::s],
                reduction=self.__reduction,
            )
            for s in strides
        )
class MidasLoss(nn.Module):
    """MiDaS-style loss: SSI-MAE plus an ``alpha``-weighted multi-scale gradient term.

    Adapted from omnidata so that the ground truth given to ``forward`` is
    inverse depth (``1 / depth``). It is inverted back for the SSI-MAE term,
    while the gradient term compares inverse quantities.
    """

    def __init__(self, alpha=0.1, scales=4, reduction='image-based'):
        """
        Args:
            alpha: Weight of the gradient-matching term in ``total``. When
                ``alpha <= 0`` the gradient term is still computed and returned,
                but ``total`` equals the SSI-MAE loss alone.
            scales: Number of scales for ``GradientMatchingTerm``.
            reduction: Passed through to ``GradientMatchingTerm``
                (``'batch-based'`` or image-based otherwise).
        """
        super().__init__()
        self.__ssi_mae_loss = SSIMAE()
        self.__gradient_matching_term = GradientMatchingTerm(scales=scales, reduction=reduction)
        self.__alpha = alpha

    def forward(self, prediction, target_inverse, mask):
        """Compute the combined loss.

        Args:
            prediction: Predicted tensor, presumably depth (it is inverted with
                ``1 / (x + 1e-6)`` for the gradient term) — TODO confirm.
            target_inverse: Ground truth as inverse depth.
            mask: Boolean validity mask. Dim 1 is squeezed for the gradient
                term, so presumably (B, 1, H, W) — TODO confirm.

        Returns:
            ``(total, ssi_loss, reg_loss)``.
        """
        prediction_inverse = 1 / (prediction + 1e-6)
        target = 1 / (target_inverse + 1e-6)
        # SSI-MAE works on the un-squeezed, non-inverted tensors.
        ssi_loss = self.__ssi_mae_loss(prediction, target, mask)

        # The gradient term and the LS alignment index tensors as 3-D, so drop dim 1.
        target_inverse = target_inverse.squeeze(1)
        prediction_inverse = prediction_inverse.squeeze(1)
        mask = mask.squeeze(1)
        scale, shift = compute_scale_and_shift(prediction_inverse, target_inverse, mask)
        prediction_ssi = scale.view(-1, 1, 1) * prediction_inverse + shift.view(-1, 1, 1)
        reg_loss = self.__gradient_matching_term(prediction_ssi, target_inverse, mask)

        # `total` used to be bound only when alpha > 0, so alpha <= 0 raised
        # UnboundLocalError; fall back to the SSI-MAE loss alone instead.
        if self.__alpha > 0:
            total = ssi_loss + self.__alpha * reg_loss
        else:
            total = ssi_loss
        return total, ssi_loss, reg_loss