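# Scraped Hugging Face page header (kept as comments so this module parses):
# Tags: Diffusers, Safetensors
# Path: EvalMDE / Edit2Perceive / utils / cycle_loss.py
# Uploaded by zeyuren2002 ("Add files using upload-large-folder tool"), commit 7f921f4 (verified)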
import torch
import numpy as np
import torch.nn.functional as F
import math
import torch.nn as nn
def _compute_sobel_gradients_chunked(image, chunk_size=4):
"""
Computes Sobel gradients for a batch of images in chunks to save memory.
Args:
image: Tensor of shape [B, H, W]
chunk_size: The size of each chunk to process.
Returns:
tuple: (grad_x, grad_y), each of shape [B, H, W]
"""
# Sobel kernels
sobel_x_kernel = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
dtype=image.dtype, device=image.device).view(1, 1, 3, 3)
sobel_y_kernel = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
dtype=image.dtype, device=image.device).view(1, 1, 3, 3)
batch_size = image.shape[0]
# 如果 batch_size 小于等于 chunk_size,就不需要分块
if batch_size <= chunk_size:
image_unsqueezed = image.unsqueeze(1)
grad_x = F.conv2d(image_unsqueezed, sobel_x_kernel, padding=1).squeeze(1)
grad_y = F.conv2d(image_unsqueezed, sobel_y_kernel, padding=1).squeeze(1)
return grad_x, grad_y
# --- 分块计算 ---
grads_x_list = []
grads_y_list = []
# 使用 torch.no_grad() 上下文可以进一步节省显存,因为我们不需要为卷积操作本身计算梯度
# Sobel核是固定的,不需要梯度。输入image的梯度会通过torch自动追踪。
with torch.no_grad():
for i in range(0, batch_size, chunk_size):
chunk = image[i:i + chunk_size]
chunk_unsqueezed = chunk.unsqueeze(1)
# 对每个块进行卷积
chunk_grad_x = F.conv2d(chunk_unsqueezed, sobel_x_kernel, padding=1)
chunk_grad_y = F.conv2d(chunk_unsqueezed, sobel_y_kernel, padding=1)
grads_x_list.append(chunk_grad_x)
grads_y_list.append(chunk_grad_y)
# 合并结果
grad_x = torch.cat(grads_x_list, dim=0).squeeze(1)
grad_y = torch.cat(grads_y_list, dim=0).squeeze(1)
return grad_x, grad_y
# def get_cycle_consistency_depth_loss(pred_depth, gt_depth, gt_mask=None, depth_normalization="log",eps=1e-6):
# """
# Cycle consistency depth loss based on AbsRel metric (Optimized batch processing version)
# Args:
# pred_depth: [B,C,H,W] or [B,1,H,W] - predicted depth maps
# gt_depth: [B,C,H,W] or [B,1,H,W] - ground truth depth maps
# gt_mask: [B,H,W] or [B,1,H,W] - mask for valid GT positions (optional)
# eps: float - small value to avoid division by zero
# depth_normalization: str - depth normalization method
# Returns:
# torch.Tensor - scalar loss value for backpropagation
# """
# # Ensure consistent shape -> [B, H, W]
# if pred_depth.dim() == 4 and pred_depth.shape[1] != 1:
# pred_depth = torch.mean(pred_depth, dim=1, keepdim=False)
# elif pred_depth.dim() == 4:
# pred_depth = pred_depth.squeeze(1)
# if gt_depth.dim() == 4 and gt_depth.shape[1] != 1:
# gt_depth = torch.mean(gt_depth, dim=1, keepdim=False)
# elif gt_depth.dim() == 4:
# gt_depth = gt_depth.squeeze(1)
# # Handle shape mismatch by resizing prediction to match GT
# if pred_depth.shape != gt_depth.shape:
# pred_depth = F.interpolate(pred_depth.unsqueeze(1), size=gt_depth.shape[-2:],
# mode='bilinear', align_corners=False).squeeze(1)
# # Handle invalid values in GT and prediction (batch processing)
# gt_depth = torch.where(torch.isinf(gt_depth), torch.zeros_like(gt_depth), gt_depth)
# gt_depth = torch.where(torch.isnan(gt_depth), torch.zeros_like(gt_depth), gt_depth)
# pred_depth = torch.where(torch.isinf(pred_depth), torch.zeros_like(pred_depth), pred_depth)
# pred_depth = torch.where(torch.isnan(pred_depth), torch.zeros_like(pred_depth), pred_depth)
# # Create valid mask
# if gt_mask is not None:
# if gt_mask.dim() == 4:
# gt_mask = gt_mask.squeeze(1) # [B, H, W]
# valid_mask = gt_mask.bool()
# else:
# valid_mask = torch.ones_like(gt_depth, dtype=torch.bool, device=gt_depth.device)
# # Clamp depths to valid range
# pred_depth = torch.clamp(pred_depth, -1.0, 1.0)
# gt_depth = torch.clamp(gt_depth, -1.0, 1.0)
# # Apply depth normalization (batch processing)
# if depth_normalization == "log":
# pred_depth = torch.exp(pred_depth)
# gt_depth = torch.exp(gt_depth)
# elif depth_normalization == "sqrt":
# pred_depth = pred_depth**2
# gt_depth = gt_depth**2
# elif depth_normalization == "disp":
# pred_depth = 1.0/(pred_depth + eps)
# gt_depth = 1.0/(gt_depth + eps)
# elif depth_normalization == "sqrt_disp":
# pred_depth = 1.0/(pred_depth**2 + eps)
# gt_depth = 1.0/(gt_depth**2 + eps)
# elif depth_normalization == "uniform":
# pass # already in the same scale as real relative depth
# else:
# raise ValueError(f"Unknown depth normalization: {depth_normalization}")
# # Compute AbsRel metric for each sample in the batch
# pred_depth = torch.log(pred_depth + eps) # [B, H, W]
# gt_depth = torch.log(gt_depth + eps) # [B, H, W
# abs_diff = torch.abs(pred_depth - gt_depth) # [B, H, W]
# rel_diff = abs_diff
# # rel_diff = abs_diff / (gt_depth + eps) # [B, H, W], add eps to avoid division by zero
# # Apply mask and compute mean for each sample
# masked_rel_diff = rel_diff * valid_mask.float() # [B, H, W]
# # Sum over spatial dimensions and divide by number of valid pixels per sample
# valid_pixel_count = valid_mask.sum(dim=[1, 2]).float() # [B]
# sample_losses = masked_rel_diff.sum(dim=[1, 2]) # [B]
# # Avoid division by zero for samples with no valid pixels
# valid_samples_mask = valid_pixel_count > 0
# sample_losses = sample_losses / torch.clamp(valid_pixel_count, min=1.0)
# # Only consider samples with valid pixels
# if valid_samples_mask.any():
# final_loss = sample_losses[valid_samples_mask].mean()
# else:
# # If no valid samples, return zero loss
# final_loss = torch.tensor(0.0, device=pred_depth.device, requires_grad=True)
# return final_loss
class ScaleAndShiftInvariantLoss(nn.Module):
    """Masked scale-and-shift-invariant L1 loss.

    For each sample, a scale ``s`` and shift ``t`` are fitted by masked least
    squares (via ``compute_scale_and_shift_masked``) so that ``s * prediction + t``
    approximates a transformed copy of the target (``log`` or ``sqrt``). The loss
    is the L1 distance between the aligned prediction and the target over the
    masked pixels.
    """

    _SUPPORTED_NORMALIZATIONS = ("log", "sqrt")

    def __init__(self):
        super().__init__()
        self.name = "SSILoss"

    def forward(self, prediction, target, mask, depth_normalization="log"):
        """Compute the loss.

        Args:
            prediction: [B,C,H,W] (channels are averaged) or [B,H,W].
            target: [B,C,H,W] (channels are averaged) or [B,H,W]. Presumably
                strictly positive, since ``log``/``sqrt`` of it is taken — TODO confirm.
            mask: [B,H,W], or [B,C,H,W] (a pixel is valid if any channel is non-zero).
            depth_normalization: "log" or "sqrt"; the transform applied to the
                target before fitting scale and shift.

        Returns:
            Scalar L1 loss over masked pixels.

        Raises:
            ValueError: if ``depth_normalization`` is not supported.
        """
        if depth_normalization not in self._SUPPORTED_NORMALIZATIONS:
            raise ValueError(
                f"Unknown depth normalization: {depth_normalization!r}; "
                f"expected one of {self._SUPPORTED_NORMALIZATIONS}"
            )
        if prediction.ndim == 4:
            prediction = torch.mean(prediction, dim=1, keepdim=False)  # [B,H,W]
        if target.ndim == 4:
            target = torch.mean(target, dim=1, keepdim=False)  # [B,H,W]
        if mask.ndim == 4:
            # Collapse the channel axis so the mask matches [B,H,W]; otherwise the
            # products inside the least-squares fit would broadcast to [B,B,H,W].
            mask = (mask != 0).any(dim=1)
        mask = mask.bool()
        # Run the fit and the loss in fp32 even under CUDA mixed precision.
        with torch.autocast(device_type='cuda', enabled=False):
            prediction = prediction.float()
            target = target.float()
            # detach() (rather than requires_grad_(False), which raises on non-leaf
            # tensors) keeps the fit targets out of the autograd graph.
            if depth_normalization == "log":
                target_cp = torch.log(target).detach()
            else:
                target_cp = torch.sqrt(target).detach()
            scale, shift = compute_scale_and_shift_masked(prediction, target_cp, mask)
            del target_cp
            scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
            # NOTE(review): scale/shift are fitted against the *transformed* target, but
            # the L1 below compares with the *untransformed* target. Preserved as-is;
            # confirm this is intended.
            loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask])
        return loss
def compute_scale_and_shift_masked(prediction, target, mask):
    """Per-sample masked least-squares fit of ``target ~ scale * prediction + shift``.

    Solves the 2x2 normal equations
        [[sum m*p*p, sum m*p], [sum m*p, sum m]] @ [scale, shift] = [sum m*p*g, sum m*g]
    in closed form, with all sums taken over dims 1 and 2.

    Args:
        prediction: [B, H, W] tensor ``p``.
        target: [B, H, W] tensor ``g``.
        mask: bool or numeric tensor ``m`` broadcastable to [B, H, W].

    Returns:
        (scale, shift), each of shape [B]. Samples whose determinant is not strictly
        positive (e.g. empty mask or constant prediction over the mask) get (0, 0).
    """
    reduce_dims = (1, 2)
    sum_pp = torch.sum(mask * prediction * prediction, reduce_dims)
    sum_p = torch.sum(mask * prediction, reduce_dims)
    count = torch.sum(mask, reduce_dims)
    sum_pg = torch.sum(mask * prediction * target, reduce_dims)
    sum_g = torch.sum(mask * target, reduce_dims)

    det = sum_pp * count - sum_p * sum_p
    # The system matrix must be positive definite to be solvable.
    solvable = det > 0
    # Use 1 as a stand-in denominator where unsolvable so the division stays
    # finite (no NaN/inf leaking into backward); those entries are replaced by 0.
    safe_det = torch.where(solvable, det, torch.ones_like(det))
    scale = (count * sum_pg - sum_p * sum_g) / safe_det
    shift = (-sum_p * sum_pg + sum_pp * sum_g) / safe_det
    return (
        torch.where(solvable, scale, torch.zeros_like(sum_pg)),
        torch.where(solvable, shift, torch.zeros_like(sum_g)),
    )
def get_cycle_consistency_depth_loss(pred, gt, mask=None, depth_normalization="log", eps=1e-8,
                                     alpha=0.85,  # weight of the SILog squared-mean term
                                     beta=1.0,  # weight of the gradient-matching loss
                                     gamma=0.0,  # weight of the optional normal-consistency loss
                                     per_sample_weights=None):
    """
    Geometry-focused, numerically stable depth consistency loss.

    Combines a scale-invariant logarithmic loss (SILog) with a local-structure
    term (Sobel gradient matching) and, optionally, a normal-consistency term
    built from those gradients.

    Args:
        pred: [B,C,H,W], [B,H,W,3] (channels-last is detected when the last dim is 3)
            or [B,H,W]; values expected in [-1, 1]. Channels are averaged.
        gt: same conventions as ``pred``; moved to ``pred``'s device and dtype.
        mask: None, or [B,H,W] / [B,1,H,W] / [B,H,W,1]; non-zero marks valid pixels.
        depth_normalization: currently unused; kept for API compatibility.
        eps: positive offset added after mapping [-1, 1] to [0, 1].
        alpha: weight of the squared-mean term in SILog (commonly 0.5 or 0.85).
        beta: weight of the gradient-matching loss.
        gamma: weight of the normal-consistency loss; only computed when > 0
            (uses ``get_cycle_consistency_normal_loss``).
        per_sample_weights: currently unused; kept for API compatibility.

    Returns:
        Scalar tensor suitable for backpropagation. If the mask has no valid
        pixel, a zero tensor (requires_grad=True, not connected to the inputs).
    """
    # --- Shape / device / dtype normalization ---
    # Channels-last [B,H,W,3] -> channels-first [B,3,H,W].
    if pred.dim() == 4 and pred.shape[1] != 3 and pred.shape[-1] == 3:
        pred = pred.permute(0, 3, 1, 2)
    if gt.dim() == 4 and gt.shape[1] != 3 and gt.shape[-1] == 3:
        gt = gt.permute(0, 3, 1, 2)
    # A plain [B,H,W] map is treated as a single channel.
    if pred.dim() == 3:
        pred = pred.unsqueeze(1)
    if gt.dim() == 3:
        gt = gt.unsqueeze(1)
    device = pred.device
    dtype = pred.dtype
    # The Sobel kernels below use pred's dtype/device; gt must match for conv2d.
    gt = gt.to(device=device, dtype=dtype)

    # --- 1. Preprocessing ---
    # Average channels into one depth channel [B,1,H,W], then map [-1, 1] to
    # [eps, 1 + eps] so values are positive, depth-like quantities.
    pred = (pred.mean(dim=1, keepdim=True).clamp(min=-1, max=1) + 1.0) / 2.0 + eps
    gt = (gt.mean(dim=1, keepdim=True).clamp(min=-1, max=1) + 1.0) / 2.0 + eps

    # --- Mask handling -> bool [B,1,H,W] ---
    if mask is None:
        mask = torch.ones_like(pred, dtype=torch.bool)
    else:
        mask = mask.to(device)
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)  # [B,H,W] -> [B,1,H,W]
        elif mask.ndim == 4 and mask.shape[1] != 1 and mask.shape[-1] == 1:
            mask = mask.permute(0, 3, 1, 2)  # [B,H,W,1] -> [B,1,H,W]
        # Collapse any channel axis: valid where any channel is non-zero.
        # (torch.mean is not defined for bool tensors, so it is avoided here.)
        mask = (mask != 0).any(dim=1, keepdim=True)

    if not bool(mask.any()):
        return torch.zeros((), device=device, dtype=dtype, requires_grad=True)

    # --- Component A: scale-invariant logarithmic (SILog) loss ---
    # SILog: E[d^2] - alpha * (E[d])^2, with d the log difference over valid pixels.
    log_diff = torch.log(1 + pred[mask]) - torch.log(1 + gt[mask])
    loss_silog = torch.mean(log_diff ** 2) - alpha * torch.mean(log_diff) ** 2

    # --- Component B: gradient matching loss (local slope consistency) ---
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=device, dtype=dtype).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=device, dtype=dtype).view(1, 1, 3, 3)
    pred_grad_x = F.conv2d(pred, sobel_x, padding=1)
    pred_grad_y = F.conv2d(pred, sobel_y, padding=1)
    gt_grad_x = F.conv2d(gt, sobel_x, padding=1)
    gt_grad_y = F.conv2d(gt, sobel_y, padding=1)
    grad_loss_x = torch.abs(pred_grad_x - gt_grad_x)
    grad_loss_y = torch.abs(pred_grad_y - gt_grad_y)

    # Erode the mask (min-pool == -max_pool(-x)) so only pixels whose whole 3x3
    # neighbourhood is valid contribute: Sobel responses next to invalid pixels are
    # unreliable. max_pool2d pads with -inf, so the image border itself does not erode.
    mask_eroded = (-F.max_pool2d(-mask.float(), kernel_size=3, stride=1, padding=1)).bool()
    if mask_eroded.any():
        loss_grad = grad_loss_x[mask_eroded].mean() + grad_loss_y[mask_eroded].mean()
    else:
        # Mask too thin to survive erosion: no pixel has a reliable gradient.
        loss_grad = pred.new_zeros(())

    # --- Component C (optional): normal consistency loss ---
    loss_normal = torch.tensor(0.0, device=device, dtype=dtype)
    if gamma > 0:
        # Normals from depth gradients: n = [-gx, -gy, 1] (then normalized inside the
        # normal loss). Assumes unit focal lengths.
        ones = torch.ones_like(pred)
        pred_normal = torch.cat([-pred_grad_x, -pred_grad_y, ones], dim=1)
        gt_normal = torch.cat([-gt_grad_x, -gt_grad_y, ones], dim=1)
        loss_normal = get_cycle_consistency_normal_loss(
            pred_normal, gt_normal, mask=mask_eroded.squeeze(1),
        )

    # --- Weighted sum (beta=1.0, gamma=0.0 is the suggested starting point) ---
    return loss_silog + beta * loss_grad + gamma * loss_normal
# ===================================================================
# 核心辅助函数:从深度图计算法线
# ===================================================================
# def depth_to_normals(depth, camera_intrinsics=None, eps=1e-8):
# """
# 从深度图计算表面法线。
# Args:
# depth: [B, 1, H, W] 深度图
# camera_intrinsics: [B, 3, 3] 相机内参矩阵 K。如果为 None,则在像素空间计算。
# Returns:
# normals: [B, 3, H, W] 法线图
# """
# B, _, H, W = depth.shape
# device = depth.device
# dtype = depth.dtype
# # 创建像素坐标网格
# y_coords, x_coords = torch.meshgrid(torch.arange(H, device=device, dtype=dtype),
# torch.arange(W, device=device, dtype=dtype),
# indexing='ij')
# coords = torch.stack([x_coords, y_coords, torch.ones_like(x_coords)], dim=0).unsqueeze(0).repeat(B, 1, 1, 1)
# if camera_intrinsics is not None:
# # 反投影到相机坐标系
# K_inv = torch.inverse(camera_intrinsics).view(B, 3, 3)
# cam_coords = K_inv @ coords.to(torch.float32).view(B, 3, -1)
# cam_coords = cam_coords.view(B, 3, H, W)
# point_cloud = cam_coords * depth
# else:
# # 在像素空间计算(一个合理的近似)
# point_cloud = torch.cat([coords[:, :2, ...], depth], dim=1)
# # 使用 padding + 卷积计算梯度,更稳定
# kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=device, dtype=dtype).view(1, 1, 3, 3) / 8.0
# kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=device, dtype=dtype).view(1, 1, 3, 3) / 8.0
# # 对点云的每个分量求梯度
# point_cloud = point_cloud.to(torch.bfloat16)
# p_x, p_y, p_z = point_cloud[:, 0:1], point_cloud[:, 1:2], point_cloud[:, 2:3]
# tg_x = torch.cat([F.conv2d(p, kernel_x, padding='same') for p in (p_x, p_y, p_z)], dim=1)
# tg_y = torch.cat([F.conv2d(p, kernel_y, padding='same') for p in (p_x, p_y, p_z)], dim=1)
# # 计算法线 (cross product)
# # tg_y x tg_x 的方向通常指向相机
# normals = torch.cross(tg_y, tg_x, dim=1)
# # 单位化
# normals = normals / normals.norm(dim=1, keepdim=True).clamp_min(eps)
# return normals
# # ===================================================================
# # 你的全新损失函数!
# # ===================================================================
# def get_cycle_consistency_depth_loss(pred_depth, gt_depth, mask=None, camera_intrinsics=None, depth_normalization="log",eps=1e-8,per_sample_weights=None):
# """
# 将深度估计问题转化为法线几何一致性问题的损失函数。
# 1. 使用最小二乘法对齐 pred 和 gt_depth。
# 2. 将对齐后的 pred 和 gt_depth 转换为法线图。
# 3. 使用 get_cycle_consistency_normal_loss 计算法线间的损失。
# Args:
# pred: [B,3,H,W] or [B,H,W,3], 模型预测的深度图,值域 [-1, 1]
# input_depth: [B,H,W] or [B,1,H,W], 原始的、未归一化的真实深度图
# mask: None or [B,H,W], 有效像素区域
# camera_intrinsics: 可选, [B, 3, 3] 相机内参, 用于更精确的法线计算
# eps: 数值稳定项
# Returns:
# 单标量 tensor (可反传)
# """
# # --- 形状、device、dtype 规范化 ---
# device = pred_depth.device
# dtype = pred_depth.dtype
# if pred_depth.ndim ==4:
# pred_depth = torch.mean(pred_depth, dim=1) # [B, H, W]
# if gt_depth.ndim == 4:
# gt_depth = torch.mean(gt_depth, dim=1) # [B, H, W]
# B, H, W = pred_depth.shape
# if depth_normalization == "log":
# pred_depth = (pred_depth + 1.0) / 2.0 + eps
# pred_depth = torch.exp(pred_depth)
# elif depth_normalization == "sqrt":
# pred_depth = (pred_depth + 1.0) / 2.0 + eps
# pred_depth = pred_depth**2
# elif depth_normalization == "disp":
# pred_depth = (pred_depth + 1.0) / 2.0 + eps
# pred_depth = 1.0/(pred_depth + eps)
# elif depth_normalization == "sqrt_disp":
# pred_depth = (pred_depth + 1.0) / 2.0 + eps
# pred_depth = 1.0/(pred_depth**2 + eps)
# # --- Mask 处理 ---
# if mask is None:
# mask = (gt_depth > eps).detach()
# else:
# # if mask.dim() == 3:
# # mask = mask.unsqueeze(1)
# mask = mask.bool().to(device)
# # --- 核心步骤 1: 最小二乘法对齐尺度和偏移 ---
# aligned_pred_depths = []
# for i in range(B): # 逐个样本处理
# p = pred_depth[i][mask[i]]
# g = gt_depth[i][mask[i]]
# if p.numel() < 2: # 有效点太少,无法拟合
# aligned_pred_depths.append(pred_depth[i])
# continue
# # 构造最小二乘问题 y = s*x + t, A*[s,t]^T = y
# A = torch.stack([p, torch.ones_like(p)], dim=1)
# y = g
# # 使用 torch.linalg.lstsq 求解,稳定且高效
# try:
# solution = torch.linalg.lstsq(A.to(torch.float32), y.to(torch.float32)).solution
# s, t = solution[0], solution[1]
# except torch.linalg.LinAlgError:
# # 如果矩阵奇异,退化为只匹配均值
# s = torch.tensor(1.0, device=device, dtype=dtype)
# t = g.mean() - p.mean()
# # 对齐预测值,不让梯度流过 s 和 t
# aligned_pred_depths.append(pred_depth[i] * s.detach() + t.detach())
# aligned_pred_depth = torch.stack(aligned_pred_depths)
# # --- 核心步骤 2: 从对齐后的深度图计算法线 ---
# # pred_normals = depth_to_normals(aligned_pred_depth, camera_intrinsics, k=5, d=1, gamma=0.05, min_nghbr=4, eps=eps)
# # gt_normals = depth_to_normals(gt_depth, camera_intrinsics, k=5, d=1, gamma=0.05, min_nghbr=4, eps=eps)
# # pred_normals = depth_to_normals(aligned_pred_depth, camera_intrinsics,)
# # gt_normals = depth_to_normals(gt_depth, camera_intrinsics)
# # # --- 核心步骤 3: 复用成功的法线损失 ---
# # # 注意,法线计算在边缘处可能无效,mask需要结合深度mask
# # final_mask = mask & (gt_depth > eps) & (pred_normals.norm(dim=1, keepdim=True) > eps) & (gt_normals.norm(dim=1, keepdim=True) > eps)
# # loss = get_cycle_consistency_normal_loss(pred_normals, gt_normals, final_mask)
# if per_sample_weights is None:
# per_sample_weights = torch.ones((B,), device=device, dtype=dtype)
# loss = per_sample_weights.view(B,1,1)*(aligned_pred_depth-gt_depth)**2*mask # [B,H,W] -> [B]
# return loss.mean().clip(1e-6,1.0)
import torch
import math
def get_cycle_consistency_normal_loss(pred, gt, mask=None, eps=1e-8, per_sample_weights=None):
    """
    Numerically stable angular loss between two normal maps (differentiable).

    Both inputs are normalized to unit length. The per-pixel angle is computed as
    atan2(|pred x gt|, pred . gt) in degrees, averaged over masked pixels for
    each sample, weighted per sample, averaged over the batch, and divided by 100.

    Args:
        pred: [B,3,H,W] or [B,H,W,3]
        gt: [B,3,H,W] or [B,H,W,3]
        mask: None, or [B,H,W] / [B,1,H,W] / [B,H,W,1]
        eps: lower bound on vector norms during normalization
        per_sample_weights: None or [B], weight applied to each sample's mean angle

    Returns:
        Scalar tensor (differentiable). A sample with no valid pixel contributes 0.
    """
    def _channels_first(t):
        # Heuristic: 4-D and last dim is 3 while dim 1 is not -> channels-last.
        is_channels_last = t.dim() == 4 and t.shape[1] != 3 and t.shape[-1] == 3
        return t.permute(0, 3, 1, 2) if is_channels_last else t

    pred = _channels_first(pred)
    gt = _channels_first(gt)
    batch, channels, height, width = pred.shape
    assert channels == 3, "输入必须是3通道法线向量"
    device, dtype = pred.device, pred.dtype

    # Unit-normalize; clamping the norm avoids division by zero for null vectors.
    pred_unit = pred / pred.norm(dim=1, keepdim=True).clamp_min(eps)
    gt_unit = gt / gt.norm(dim=1, keepdim=True).clamp_min(eps)

    # Bring the mask to a flat [B, N] bool layout.
    if mask is None:
        mask = torch.ones((batch, height, width), dtype=torch.bool, device=device)
    if mask.dim() == 4:
        mask = mask.squeeze(1)
    if mask.dim() == 3 and mask.shape[1] == 1:
        mask = mask.squeeze(1)
    valid = mask.bool().to(device).view(batch, -1).bool()

    pred_flat = pred_unit.view(batch, 3, -1)
    gt_flat = gt_unit.view(batch, 3, -1)

    # atan2(sin, cos) is better conditioned than acos(cos) near 0 and 180 degrees.
    sin_term = torch.cross(pred_flat, gt_flat, dim=1).norm(dim=1).clamp_min(1e-12)
    cos_term = (pred_flat * gt_flat).sum(dim=1).clamp(-1.0 + 1e-7, 1.0 - 1e-7)
    angle_deg = torch.atan2(sin_term, cos_term) * (180.0 / math.pi)  # [B, N]

    valid_f = valid.float()
    # clamp_min(1) keeps samples with an empty mask finite (their mean becomes 0).
    per_sample = (angle_deg * valid_f).sum(dim=1) / valid_f.sum(dim=1).clamp_min(1)
    if per_sample_weights is None:
        per_sample_weights = torch.ones((batch,), device=device, dtype=dtype)
    # Divide by 100 to keep the loss magnitude small.
    return (per_sample * per_sample_weights).mean() / 100.0
def _gaussian_derivative_kernels(sigma, epsilon, device, dtype):
    """Return unit-L2-norm derivative-of-Gaussian conv kernels (kernel_x, kernel_y), each [1, 1, k, k]."""
    # Half-width where the normalized Gaussian density drops to ``epsilon``.
    halfsize = math.ceil(sigma * math.sqrt(-2 * math.log(math.sqrt(2 * math.pi) * sigma * epsilon)))
    coords = torch.arange(-halfsize, halfsize + 1, dtype=torch.float32, device=device)
    gauss_vals = torch.exp(-coords ** 2 / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))
    dgauss_vals = -coords * gauss_vals / (sigma ** 2)
    # Separable kernel: Gaussian along rows x derivative along columns (x direction).
    hx = gauss_vals.unsqueeze(1) * dgauss_vals.unsqueeze(0)
    hy = hx.t()
    hx = hx / torch.sqrt(torch.sum(torch.abs(hx) * torch.abs(hx)))
    hy = hy / torch.sqrt(torch.sum(torch.abs(hy) * torch.abs(hy)))
    return hx.to(dtype).unsqueeze(0).unsqueeze(0), hy.to(dtype).unsqueeze(0).unsqueeze(0)


def _safe_gradient_magnitude(grad_x, grad_y):
    """Return sqrt(grad_x**2 + grad_y**2) whose backward is 0 (not NaN) where the magnitude is 0."""
    squared = grad_x ** 2 + grad_y ** 2
    nonzero = squared > 0
    # sqrt's backward divides by its output; feeding it a dummy 1 where squared == 0
    # avoids 0/0, and the outer where restores the exact 0 in the forward pass.
    root = torch.sqrt(torch.where(nonzero, squared, torch.ones_like(squared)))
    return torch.where(nonzero, root, torch.zeros_like(squared))


def get_cycle_consistency_matting_loss(pred_alpha, gt_alpha, trimap):
    """
    Combined differentiable matting loss (PyTorch, batched).

    Sums SAD, MSE and gradient-magnitude losses, each restricted to the unknown
    region of the trimap, per sample, and averages over the batch.

    Args:
        pred_alpha (torch.Tensor): predicted alpha matte, [B, C, H, W] (channels are
            averaged), values expected in [0, 1].
        gt_alpha (torch.Tensor): ground-truth alpha matte, same layout as ``pred_alpha``.
        trimap (torch.Tensor): [B, H, W] or [B, C, H, W] with values 0 / 128 / 255;
            pixels with value 128 (+-0.5) form the unknown region. A multi-channel
            trimap is averaged over channels.

    Returns:
        torch.Tensor: scalar loss for backpropagation.
    """
    # --- 1. Preprocessing: collapse channels and build the unknown-region mask ---
    if pred_alpha.ndim == 4:
        pred_alpha = torch.mean(pred_alpha, dim=1, keepdim=True)  # [B,1,H,W]
    if gt_alpha.ndim == 4:
        gt_alpha = torch.mean(gt_alpha, dim=1, keepdim=True)  # [B,1,H,W]
    if trimap.ndim == 3:
        trimap = trimap.unsqueeze(1)  # [B,1,H,W]
    elif trimap.ndim == 4 and trimap.shape[1] != 1:
        # Average channels like pred/gt; otherwise broadcasting against [B,1,H,W]
        # would count each unknown pixel once per trimap channel.
        trimap = trimap.float().mean(dim=1, keepdim=True)
    unknown_mask = ((trimap >= 127.5) & (trimap <= 128.5)).float()
    reduce_dims = (1, 2, 3)

    # --- 2. SAD (sum of absolute differences), scaled by 1/1000 ---
    error_map_sad = torch.abs(pred_alpha - gt_alpha)
    loss_sad_per_sample = (error_map_sad * unknown_mask).sum(dim=reduce_dims) / 1000.0

    # --- 3. MSE over the unknown region (eps guards an empty region) ---
    error_map_mse = (pred_alpha - gt_alpha) ** 2
    loss_mse_per_sample = (error_map_mse * unknown_mask).sum(dim=reduce_dims) / (
        unknown_mask.sum(dim=reduce_dims) + 1e-8)

    # --- 4. Gradient loss on derivative-of-Gaussian magnitudes, scaled by 1/1000 ---
    kernel_x, kernel_y = _gaussian_derivative_kernels(
        sigma=1.4, epsilon=1e-2, device=pred_alpha.device, dtype=pred_alpha.dtype)
    pred_amp = _safe_gradient_magnitude(F.conv2d(pred_alpha, kernel_x, padding='same'),
                                        F.conv2d(pred_alpha, kernel_y, padding='same'))
    gt_amp = _safe_gradient_magnitude(F.conv2d(gt_alpha, kernel_x, padding='same'),
                                      F.conv2d(gt_alpha, kernel_y, padding='same'))
    error_map_grad = (pred_amp - gt_amp) ** 2
    loss_grad_per_sample = (error_map_grad * unknown_mask).sum(dim=reduce_dims) / 1000.0

    # --- 5. Combine (unweighted) and average over the batch ---
    total_loss_per_sample = loss_sad_per_sample + loss_mse_per_sample + loss_grad_per_sample
    return total_loss_per_sample.mean()
def get_disperse_loss(eta, tau=1.0):
    """
    Compute the Dispersive Loss.

    Formula: L_disp = log E_{i,j} [exp(-||eta_i - eta_j||_2^2 / tau)], where the
    expectation runs over all B*B ordered pairs of samples (including i == j).

    Args:
        eta (torch.Tensor): batch of intermediate features. The first dimension is
            the batch; all remaining dimensions are flattened, so (B, seq_len, dim),
            (B, features) and (B,) are accepted, contiguous or not.
        tau (float): temperature hyperparameter (1.0 per the paper).

    Returns:
        torch.Tensor: scalar dispersive loss. A constant 0 (no graph) when B <= 1,
        since no pairwise distances exist.
    """
    batch_size = eta.shape[0]
    if batch_size <= 1:
        return torch.tensor(0.0, device=eta.device, dtype=eta.dtype)
    # reshape (not view): works for non-contiguous inputs and lifts 1-D input to (B, 1).
    flat = eta.reshape(batch_size, -1)
    # Pairwise squared L2 distances via ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2.
    sq_norms = torch.sum(flat ** 2, dim=1, keepdim=True)
    dist_sq = sq_norms - 2 * torch.mm(flat, flat.t()) + sq_norms.t()
    # Cancellation can make tiny distances slightly negative.
    dist_sq = torch.clamp(dist_sq, min=0.0)
    # log(mean(exp(z))) == logsumexp(z) - log(N): same value, no under/overflow.
    logits = (-dist_sq / tau).flatten()
    return torch.logsumexp(logits, dim=0) - math.log(logits.numel())