Diffusers
Safetensors
EvalMDE / DepthMaster /src /util /loss.py
zeyuren2002's picture
Add files using upload-large-folder tool
4b7b610 verified
# Last modified: 2025-01-14
#
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
#
# This file has been modified from the original version.
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
# --------------------------------------------------------------------------
import torch
import torch.nn as nn
def get_loss(loss_name, **kwargs):
if "silog_mse" == loss_name:
criterion = SILogMSELoss(**kwargs)
elif "silog_rmse" == loss_name:
criterion = SILogRMSELoss(**kwargs)
elif "mse_loss" == loss_name:
criterion = torch.nn.MSELoss(**kwargs)
elif "l1_loss" == loss_name:
criterion = torch.nn.L1Loss(**kwargs)
elif "l1_loss_with_mask" == loss_name:
criterion = L1LossWithMask(**kwargs)
elif "mean_abs_rel" == loss_name:
criterion = MeanAbsRelLoss()
elif "huber_loss" == loss_name:
criterion = HuberLoss(**kwargs)
else:
raise NotImplementedError
return criterion
class L1LossWithMask:
def __init__(self, batch_reduction=False):
self.batch_reduction = batch_reduction
def __call__(self, depth_pred, depth_gt, valid_mask=None):
diff = depth_pred - depth_gt
if valid_mask is not None:
diff[~valid_mask] = 0
n = valid_mask.sum((-1, -2))
else:
n = depth_gt.shape[-2] * depth_gt.shape[-1]
loss = torch.sum(torch.abs(diff)) / n
if self.batch_reduction:
loss = loss.mean()
return loss
class MeanAbsRelLoss:
def __init__(self) -> None:
# super().__init__()
pass
def __call__(self, pred, gt):
diff = pred - gt
rel_abs = torch.abs(diff / gt)
loss = torch.mean(rel_abs, dim=0)
return loss
class SILogMSELoss:
def __init__(self, lamb, log_pred=True, batch_reduction=True):
"""Scale Invariant Log MSE Loss
Args:
lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss
log_pred (bool, optional): True if model prediction is logarithmic depht. Will not do log for depth_pred
"""
super(SILogMSELoss, self).__init__()
self.lamb = lamb
self.pred_in_log = log_pred
self.batch_reduction = batch_reduction
def __call__(self, depth_pred, depth_gt, valid_mask=None):
log_depth_pred = (
depth_pred if self.pred_in_log else torch.log(torch.clip(depth_pred, 1e-8))
)
log_depth_gt = torch.log(depth_gt)
diff = log_depth_pred - log_depth_gt
if valid_mask is not None:
diff[~valid_mask] = 0
n = valid_mask.sum((-1, -2))
else:
n = depth_gt.shape[-2] * depth_gt.shape[-1]
diff2 = torch.pow(diff, 2)
first_term = torch.sum(diff2, (-1, -2)) / n
second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
loss = first_term - second_term
if self.batch_reduction:
loss = loss.mean()
return loss
class HuberLoss:
def __init__(self, delta=0.5):
self.delta = delta
def __call__(self, depth_pred, depth_gt, valid_mask=None):
# huber 损失
# 计算预测值与真实值的差值
diff = depth_gt - depth_pred
# 计算绝对值和差值的平方
abs_diff = torch.abs(diff)
squared_diff = diff ** 2
# 使用条件语句选择L2损失或L1损失
loss = torch.where(abs_diff > self.delta, 0.5 * squared_diff, self.delta * abs_diff + 0.5 * self.delta ** 2)
# 返回所有样本损失的总和
if valid_mask is not None:
return torch.mean(loss[valid_mask])
else:
return torch.mean(loss)
class SILogRMSELoss:
def __init__(self, lamb, log_pred=True):
"""Scale Invariant Log RMSE Loss
Args:
lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss
alpha:
log_pred (bool, optional): True if model prediction is logarithmic depht. Will not do log for depth_pred
"""
super(SILogRMSELoss, self).__init__()
self.lamb = lamb
# self.alpha = alpha
self.pred_in_log = log_pred
# def __call__(self, depth_pred, depth_gt, valid_mask):
# log_depth_pred = depth_pred if self.pred_in_log else torch.log(depth_pred)
# log_depth_gt = torch.log(depth_gt)
# # borrowed from https://github.com/aliyun/NeWCRFs
# # diff = log_depth_pred[valid_mask] - log_depth_gt[valid_mask]
# # return torch.sqrt((diff ** 2).mean() - self.lamb * (diff.mean() ** 2)) * self.alpha
# diff = log_depth_pred - log_depth_gt
# if valid_mask is not None:
# diff[~valid_mask] = 0
# n = valid_mask.sum((-1, -2))
# else:
# n = depth_gt.shape[-2] * depth_gt.shape[-1]
# diff2 = torch.pow(diff, 2)
# first_term = torch.sum(diff2, (-1, -2)) / n
# second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
# loss = torch.sqrt(first_term - second_term).mean()
# return loss
def __call__(self, depth_pred, depth_gt, valid_mask):
valid_mask = valid_mask.detach()
log_depth_pred = torch.log(depth_pred[valid_mask])
log_depth_gt = torch.log(depth_gt[valid_mask])
diff = log_depth_gt - log_depth_pred
first_term = torch.pow(diff, 2).mean()
second_term = self.lamb * torch.pow(diff.mean(), 2)
loss = torch.sqrt(first_term - second_term)
return loss
def get_smooth_loss(disp, img):
"""Computes the smoothness loss for a disparity image
The color image is used for edge-aware smoothness
"""
mean_disp = disp.mean(2, True).mean(3, True)
disp = disp / (mean_disp + 1e-7)
grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
grad_disp_x *= torch.exp(-grad_img_x)
grad_disp_y *= torch.exp(-grad_img_y)
return grad_disp_x.mean() + grad_disp_y.mean()
class SSIM(nn.Module):
"""Layer to compute the SSIM loss between a pair of images
"""
def __init__(self):
super(SSIM, self).__init__()
self.mu_x_pool = nn.AvgPool2d(3, 1)
self.mu_y_pool = nn.AvgPool2d(3, 1)
self.sig_x_pool = nn.AvgPool2d(3, 1)
self.sig_y_pool = nn.AvgPool2d(3, 1)
self.sig_xy_pool = nn.AvgPool2d(3, 1)
self.refl = nn.ReflectionPad2d(1)
self.C1 = 0.01 ** 2
self.C2 = 0.03 ** 2
def forward(self, x, y):
x = self.refl(x)
y = self.refl(y)
mu_x = self.mu_x_pool(x)
mu_y = self.mu_y_pool(y)
sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)