Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| import math | |
| import torch.nn as nn | |
| def _compute_sobel_gradients_chunked(image, chunk_size=4): | |
| """ | |
| Computes Sobel gradients for a batch of images in chunks to save memory. | |
| Args: | |
| image: Tensor of shape [B, H, W] | |
| chunk_size: The size of each chunk to process. | |
| Returns: | |
| tuple: (grad_x, grad_y), each of shape [B, H, W] | |
| """ | |
| # Sobel kernels | |
| sobel_x_kernel = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], | |
| dtype=image.dtype, device=image.device).view(1, 1, 3, 3) | |
| sobel_y_kernel = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], | |
| dtype=image.dtype, device=image.device).view(1, 1, 3, 3) | |
| batch_size = image.shape[0] | |
| # 如果 batch_size 小于等于 chunk_size,就不需要分块 | |
| if batch_size <= chunk_size: | |
| image_unsqueezed = image.unsqueeze(1) | |
| grad_x = F.conv2d(image_unsqueezed, sobel_x_kernel, padding=1).squeeze(1) | |
| grad_y = F.conv2d(image_unsqueezed, sobel_y_kernel, padding=1).squeeze(1) | |
| return grad_x, grad_y | |
| # --- 分块计算 --- | |
| grads_x_list = [] | |
| grads_y_list = [] | |
| # 使用 torch.no_grad() 上下文可以进一步节省显存,因为我们不需要为卷积操作本身计算梯度 | |
| # Sobel核是固定的,不需要梯度。输入image的梯度会通过torch自动追踪。 | |
| with torch.no_grad(): | |
| for i in range(0, batch_size, chunk_size): | |
| chunk = image[i:i + chunk_size] | |
| chunk_unsqueezed = chunk.unsqueeze(1) | |
| # 对每个块进行卷积 | |
| chunk_grad_x = F.conv2d(chunk_unsqueezed, sobel_x_kernel, padding=1) | |
| chunk_grad_y = F.conv2d(chunk_unsqueezed, sobel_y_kernel, padding=1) | |
| grads_x_list.append(chunk_grad_x) | |
| grads_y_list.append(chunk_grad_y) | |
| # 合并结果 | |
| grad_x = torch.cat(grads_x_list, dim=0).squeeze(1) | |
| grad_y = torch.cat(grads_y_list, dim=0).squeeze(1) | |
| return grad_x, grad_y | |
| # def get_cycle_consistency_depth_loss(pred_depth, gt_depth, gt_mask=None, depth_normalization="log",eps=1e-6): | |
| # """ | |
| # Cycle consistency depth loss based on AbsRel metric (Optimized batch processing version) | |
| # Args: | |
| # pred_depth: [B,C,H,W] or [B,1,H,W] - predicted depth maps | |
| # gt_depth: [B,C,H,W] or [B,1,H,W] - ground truth depth maps | |
| # gt_mask: [B,H,W] or [B,1,H,W] - mask for valid GT positions (optional) | |
| # eps: float - small value to avoid division by zero | |
| # depth_normalization: str - depth normalization method | |
| # Returns: | |
| # torch.Tensor - scalar loss value for backpropagation | |
| # """ | |
| # # Ensure consistent shape -> [B, H, W] | |
| # if pred_depth.dim() == 4 and pred_depth.shape[1] != 1: | |
| # pred_depth = torch.mean(pred_depth, dim=1, keepdim=False) | |
| # elif pred_depth.dim() == 4: | |
| # pred_depth = pred_depth.squeeze(1) | |
| # if gt_depth.dim() == 4 and gt_depth.shape[1] != 1: | |
| # gt_depth = torch.mean(gt_depth, dim=1, keepdim=False) | |
| # elif gt_depth.dim() == 4: | |
| # gt_depth = gt_depth.squeeze(1) | |
| # # Handle shape mismatch by resizing prediction to match GT | |
| # if pred_depth.shape != gt_depth.shape: | |
| # pred_depth = F.interpolate(pred_depth.unsqueeze(1), size=gt_depth.shape[-2:], | |
| # mode='bilinear', align_corners=False).squeeze(1) | |
| # # Handle invalid values in GT and prediction (batch processing) | |
| # gt_depth = torch.where(torch.isinf(gt_depth), torch.zeros_like(gt_depth), gt_depth) | |
| # gt_depth = torch.where(torch.isnan(gt_depth), torch.zeros_like(gt_depth), gt_depth) | |
| # pred_depth = torch.where(torch.isinf(pred_depth), torch.zeros_like(pred_depth), pred_depth) | |
| # pred_depth = torch.where(torch.isnan(pred_depth), torch.zeros_like(pred_depth), pred_depth) | |
| # # Create valid mask | |
| # if gt_mask is not None: | |
| # if gt_mask.dim() == 4: | |
| # gt_mask = gt_mask.squeeze(1) # [B, H, W] | |
| # valid_mask = gt_mask.bool() | |
| # else: | |
| # valid_mask = torch.ones_like(gt_depth, dtype=torch.bool, device=gt_depth.device) | |
| # # Clamp depths to valid range | |
| # pred_depth = torch.clamp(pred_depth, -1.0, 1.0) | |
| # gt_depth = torch.clamp(gt_depth, -1.0, 1.0) | |
| # # Apply depth normalization (batch processing) | |
| # if depth_normalization == "log": | |
| # pred_depth = torch.exp(pred_depth) | |
| # gt_depth = torch.exp(gt_depth) | |
| # elif depth_normalization == "sqrt": | |
| # pred_depth = pred_depth**2 | |
| # gt_depth = gt_depth**2 | |
| # elif depth_normalization == "disp": | |
| # pred_depth = 1.0/(pred_depth + eps) | |
| # gt_depth = 1.0/(gt_depth + eps) | |
| # elif depth_normalization == "sqrt_disp": | |
| # pred_depth = 1.0/(pred_depth**2 + eps) | |
| # gt_depth = 1.0/(gt_depth**2 + eps) | |
| # elif depth_normalization == "uniform": | |
| # pass # already in the same scale as real relative depth | |
| # else: | |
| # raise ValueError(f"Unknown depth normalization: {depth_normalization}") | |
| # # Compute AbsRel metric for each sample in the batch | |
| # pred_depth = torch.log(pred_depth + eps) # [B, H, W] | |
| # gt_depth = torch.log(gt_depth + eps) # [B, H, W | |
| # abs_diff = torch.abs(pred_depth - gt_depth) # [B, H, W] | |
| # rel_diff = abs_diff | |
| # # rel_diff = abs_diff / (gt_depth + eps) # [B, H, W], add eps to avoid division by zero | |
| # # Apply mask and compute mean for each sample | |
| # masked_rel_diff = rel_diff * valid_mask.float() # [B, H, W] | |
| # # Sum over spatial dimensions and divide by number of valid pixels per sample | |
| # valid_pixel_count = valid_mask.sum(dim=[1, 2]).float() # [B] | |
| # sample_losses = masked_rel_diff.sum(dim=[1, 2]) # [B] | |
| # # Avoid division by zero for samples with no valid pixels | |
| # valid_samples_mask = valid_pixel_count > 0 | |
| # sample_losses = sample_losses / torch.clamp(valid_pixel_count, min=1.0) | |
| # # Only consider samples with valid pixels | |
| # if valid_samples_mask.any(): | |
| # final_loss = sample_losses[valid_samples_mask].mean() | |
| # else: | |
| # # If no valid samples, return zero loss | |
| # final_loss = torch.tensor(0.0, device=pred_depth.device, requires_grad=True) | |
| # return final_loss | |
class ScaleAndShiftInvariantLoss(nn.Module):
    """Scale-and-shift-invariant L1 loss in a normalized depth space.

    The target depth is mapped into the requested normalization space
    (``log`` or ``sqrt``). A per-sample least-squares scale and shift are then
    fitted so that ``scale * prediction + shift`` best matches the normalized
    target over the valid pixels, and the loss is the mean absolute residual
    of that fit over the valid pixels.

    Relies on the module-level ``compute_scale_and_shift_masked`` helper.
    """

    def __init__(self):
        super().__init__()
        self.name = "SSILoss"

    def forward(self, prediction, target, mask, depth_normalization="log"):
        """Compute the loss.

        Args:
            prediction: [B, H, W] or [B, C, H, W] (channels are averaged).
            target: [B, H, W] or [B, C, H, W] (channels are averaged). Valid
                pixels must be > 0 for "log" and >= 0 for "sqrt".
            mask: [B, H, W] or [B, C, H, W]; non-zero marks a valid pixel. A
                4-D mask is reduced over its channel axis with ``any``.
            depth_normalization: "log" or "sqrt".

        Returns:
            Scalar tensor. Zero (with ``requires_grad=True``) when the mask
            selects no pixel.

        Raises:
            ValueError: If ``depth_normalization`` is not supported.
        """
        if prediction.ndim == 4:
            prediction = torch.mean(prediction, dim=1, keepdim=False)  # [B,H,W]
        if target.ndim == 4:
            target = torch.mean(target, dim=1, keepdim=False)  # [B,H,W]
        mask = mask.bool()
        if mask.ndim == 4:
            # Bring the mask to the same [B,H,W] layout as prediction/target.
            mask = mask.any(dim=1)

        # Do the least-squares fit in fp32 even under mixed precision.
        with torch.autocast(device_type='cuda', enabled=False):
            prediction = prediction.float()
            target = target.float()

            if depth_normalization == "log":
                target_cp = torch.log(target)
            elif depth_normalization == "sqrt":
                target_cp = torch.sqrt(target)
            else:
                raise ValueError(f"Unknown depth normalization: {depth_normalization}")

            # Masked-out pixels typically hold 0 / invalid depth, which turns
            # into -inf or NaN after log/sqrt. Zero them so they cannot poison
            # the masked sums (0 * inf == NaN). The target never needs grad.
            target_cp = torch.where(mask, target_cp, torch.zeros_like(target_cp)).detach()

            if not bool(mask.any()):
                # No valid pixel: return a zero that keeps backward() usable
                # instead of the NaN that l1_loss yields on empty input.
                return torch.zeros((), device=prediction.device,
                                   dtype=prediction.dtype, requires_grad=True)

            scale, shift = compute_scale_and_shift_masked(prediction, target_cp, mask)
            scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)

            # The residual is measured against the same normalized target the
            # scale/shift were fitted to; comparing with the raw target would
            # mix two different depth spaces.
            loss = nn.functional.l1_loss(scaled_prediction[mask], target_cp[mask])
        return loss
def compute_scale_and_shift_masked(prediction, target, mask):
    """Per-sample least-squares scale and shift aligning prediction to target.

    Solves, independently for every batch element,
    ``min_{s,t} sum(mask * (s * prediction + t - target)^2)``
    in closed form via the 2x2 normal equations.

    Args:
        prediction: [B, H, W] tensor.
        target: [B, H, W] tensor.
        mask: [B, H, W] tensor (bool or numeric weights); zero excludes a pixel.

    Returns:
        tuple: (scale, shift), each of shape [B]. Samples whose system is
        degenerate (determinant <= 0, e.g. empty mask or constant prediction)
        get scale = shift = 0.
    """
    # Zero excluded pixels explicitly: a plain `mask * x` yields NaN when the
    # excluded pixel holds NaN/inf (0 * inf), which would corrupt every sum.
    excluded = mask == 0
    prediction = torch.where(excluded, torch.zeros_like(prediction), prediction)
    target = torch.where(excluded, torch.zeros_like(target), target)

    # system matrix: A = [[a_00, a_01], [a_10, a_11]]  (a_10 == a_01)
    a_00 = torch.sum(mask * prediction * prediction, (1, 2))
    a_01 = torch.sum(mask * prediction, (1, 2))
    a_11 = torch.sum(mask, (1, 2))

    # right hand side: b = [b_0, b_1]
    b_0 = torch.sum(mask * prediction * target, (1, 2))
    b_1 = torch.sum(mask * target, (1, 2))

    # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / det . b
    x_0 = torch.zeros_like(b_0)
    x_1 = torch.zeros_like(b_1)

    det = a_00 * a_11 - a_01 * a_01
    # A must be positive definite; otherwise keep the zero solution.
    valid = det > 0

    x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
    x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]

    return x_0, x_1
def get_cycle_consistency_depth_loss(pred, gt, mask=None, depth_normalization="log", eps=1e-8,
                                     alpha=0.85,  # weight of the variance term in SILog
                                     beta=1.0,    # weight of the gradient loss
                                     gamma=0.0,   # weight of the optional normal loss
                                     per_sample_weights=None):
    """Geometry-oriented depth loss: SILog + Sobel gradient matching (+ normals).

    Inputs are averaged over channels, clamped to [-1, 1] and mapped to
    [eps, 1 + eps]. The loss is

        SILog(log(1 + pred), log(1 + gt)) + beta * grad_loss + gamma * normal_loss

    where ``grad_loss`` is the L1 difference of Sobel gradients evaluated on
    the mask eroded by one pixel, and ``normal_loss`` (only when gamma > 0) is
    ``get_cycle_consistency_normal_loss`` applied to normals ``[-gx, -gy, 1]``
    built from those gradients.

    Args:
        pred: [B,3,H,W], [B,H,W,3] or [B,H,W], expected range [-1, 1].
        gt: same layouts as ``pred``.
        mask: None, or [B,H,W] / [B,1,H,W] / [B,H,W,1]; non-zero marks valid
            pixels. A multi-channel mask is valid where any channel is non-zero.
        depth_normalization: accepted for interface compatibility; unused.
        eps: offset added after mapping to [0, 1].
        alpha: weight of the squared-mean term in SILog.
        beta: weight of the gradient loss.
        gamma: weight of the normal-consistency loss (computed only if > 0).
        per_sample_weights: accepted for interface compatibility; unused.

    Returns:
        Scalar tensor. Zero (with ``requires_grad=True``) if no pixel is valid.
    """
    # --- layout normalization to [B, C, H, W] ---
    if pred.dim() == 3:
        pred = pred.unsqueeze(1)
    elif pred.dim() == 4 and pred.shape[1] != 3 and pred.shape[-1] == 3:
        pred = pred.permute(0, 3, 1, 2)
    if gt.dim() == 3:
        gt = gt.unsqueeze(1)
    elif gt.dim() == 4 and gt.shape[1] != 3 and gt.shape[-1] == 3:
        gt = gt.permute(0, 3, 1, 2)

    device = pred.device
    dtype = pred.dtype

    # --- 1. preprocessing ---
    # Average the channels into a single-channel depth map.
    pred = pred.mean(dim=1, keepdim=True)  # [B,1,H,W]
    gt = gt.mean(dim=1, keepdim=True)      # [B,1,H,W]

    # Map [-1, 1] to the positive range [eps, 1 + eps] ahead of the log.
    pred = (pred.clamp(min=-1, max=1) + 1.0) / 2.0 + eps
    gt = (gt.clamp(min=-1, max=1) + 1.0) / 2.0 + eps

    # --- mask handling: always end up with a bool [B,1,H,W] tensor ---
    if mask is None:
        mask = torch.ones_like(pred, dtype=torch.bool, device=device)
    else:
        mask = mask.to(device)
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)
        elif mask.ndim == 4 and mask.shape[1] != 1 and mask.shape[-1] == 1:
            mask = mask.permute(0, 3, 1, 2)  # channel-last [B,H,W,1]
        # `!= 0` works for bool and numeric masks alike (torch.mean does not
        # accept bool tensors); `any` collapses a multi-channel mask.
        mask = (mask != 0).any(dim=1, keepdim=True)

    # Nothing to supervise: return a zero that still supports backward().
    if not bool(mask.any()):
        return torch.zeros((), device=device, dtype=dtype, requires_grad=True)

    # --- 2. composite loss ---
    # Component A: scale-invariant log loss, E[d^2] - alpha * (E[d])^2.
    log_diff = torch.log(1 + pred[mask]) - torch.log(1 + gt[mask])
    loss_silog = torch.mean(log_diff ** 2) - alpha * torch.mean(log_diff) ** 2

    # Component B: gradient matching with Sobel kernels (zero padding).
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                           device=device, dtype=dtype).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                           device=device, dtype=dtype).view(1, 1, 3, 3)

    pred_grad_x = F.conv2d(pred, sobel_x, padding=1)
    pred_grad_y = F.conv2d(pred, sobel_y, padding=1)
    gt_grad_x = F.conv2d(gt, sobel_x, padding=1)
    gt_grad_y = F.conv2d(gt, sobel_y, padding=1)

    grad_loss_x = torch.abs(pred_grad_x - gt_grad_x)
    grad_loss_y = torch.abs(pred_grad_y - gt_grad_y)

    # Erode the mask by one pixel: a 3x3 gradient at the mask boundary mixes in
    # invalid neighbours. Erosion is a min-pool, written as -max_pool(-mask);
    # a plain max_pool would dilate the mask instead. max_pool2d pads with
    # -inf, so the image border itself does not erode the mask.
    mask_eroded = (-F.max_pool2d(-mask.float(), kernel_size=3, stride=1, padding=1)) > 0.5
    if not bool(mask_eroded.any()):
        # Mask too thin to survive erosion: fall back to the raw mask rather
        # than averaging over an empty selection (which would give NaN).
        mask_eroded = mask

    loss_grad = grad_loss_x[mask_eroded].mean() + grad_loss_y[mask_eroded].mean()

    # Component C (optional): consistency of normals rebuilt from the gradients.
    loss_normal = torch.tensor(0.0, device=device, dtype=dtype)
    if gamma > 0:
        # Normal n = [-gx, -gy, 1]; normalization happens inside the normal
        # loss. This assumes unit focal lengths (f_x = f_y = 1).
        ones = torch.ones_like(pred)
        pred_normal = torch.cat([-pred_grad_x, -pred_grad_y, ones], dim=1)
        gt_normal = torch.cat([-gt_grad_x, -gt_grad_y, ones], dim=1)
        loss_normal = get_cycle_consistency_normal_loss(
            pred_normal, gt_normal, mask=mask_eroded.squeeze(1),
        )

    # --- weighted sum ---
    total_loss = loss_silog + beta * loss_grad + gamma * loss_normal
    return total_loss
| # =================================================================== | |
| # 核心辅助函数:从深度图计算法线 | |
| # =================================================================== | |
| # def depth_to_normals(depth, camera_intrinsics=None, eps=1e-8): | |
| # """ | |
| # 从深度图计算表面法线。 | |
| # Args: | |
| # depth: [B, 1, H, W] 深度图 | |
| # camera_intrinsics: [B, 3, 3] 相机内参矩阵 K。如果为 None,则在像素空间计算。 | |
| # Returns: | |
| # normals: [B, 3, H, W] 法线图 | |
| # """ | |
| # B, _, H, W = depth.shape | |
| # device = depth.device | |
| # dtype = depth.dtype | |
| # # 创建像素坐标网格 | |
| # y_coords, x_coords = torch.meshgrid(torch.arange(H, device=device, dtype=dtype), | |
| # torch.arange(W, device=device, dtype=dtype), | |
| # indexing='ij') | |
| # coords = torch.stack([x_coords, y_coords, torch.ones_like(x_coords)], dim=0).unsqueeze(0).repeat(B, 1, 1, 1) | |
| # if camera_intrinsics is not None: | |
| # # 反投影到相机坐标系 | |
| # K_inv = torch.inverse(camera_intrinsics).view(B, 3, 3) | |
| # cam_coords = K_inv @ coords.to(torch.float32).view(B, 3, -1) | |
| # cam_coords = cam_coords.view(B, 3, H, W) | |
| # point_cloud = cam_coords * depth | |
| # else: | |
| # # 在像素空间计算(一个合理的近似) | |
| # point_cloud = torch.cat([coords[:, :2, ...], depth], dim=1) | |
| # # 使用 padding + 卷积计算梯度,更稳定 | |
| # kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=device, dtype=dtype).view(1, 1, 3, 3) / 8.0 | |
| # kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=device, dtype=dtype).view(1, 1, 3, 3) / 8.0 | |
| # # 对点云的每个分量求梯度 | |
| # point_cloud = point_cloud.to(torch.bfloat16) | |
| # p_x, p_y, p_z = point_cloud[:, 0:1], point_cloud[:, 1:2], point_cloud[:, 2:3] | |
| # tg_x = torch.cat([F.conv2d(p, kernel_x, padding='same') for p in (p_x, p_y, p_z)], dim=1) | |
| # tg_y = torch.cat([F.conv2d(p, kernel_y, padding='same') for p in (p_x, p_y, p_z)], dim=1) | |
| # # 计算法线 (cross product) | |
| # # tg_y x tg_x 的方向通常指向相机 | |
| # normals = torch.cross(tg_y, tg_x, dim=1) | |
| # # 单位化 | |
| # normals = normals / normals.norm(dim=1, keepdim=True).clamp_min(eps) | |
| # return normals | |
| # # =================================================================== | |
| # # 你的全新损失函数! | |
| # # =================================================================== | |
| # def get_cycle_consistency_depth_loss(pred_depth, gt_depth, mask=None, camera_intrinsics=None, depth_normalization="log",eps=1e-8,per_sample_weights=None): | |
| # """ | |
| # 将深度估计问题转化为法线几何一致性问题的损失函数。 | |
| # 1. 使用最小二乘法对齐 pred 和 gt_depth。 | |
| # 2. 将对齐后的 pred 和 gt_depth 转换为法线图。 | |
| # 3. 使用 get_cycle_consistency_normal_loss 计算法线间的损失。 | |
| # Args: | |
| # pred: [B,3,H,W] or [B,H,W,3], 模型预测的深度图,值域 [-1, 1] | |
| # input_depth: [B,H,W] or [B,1,H,W], 原始的、未归一化的真实深度图 | |
| # mask: None or [B,H,W], 有效像素区域 | |
| # camera_intrinsics: 可选, [B, 3, 3] 相机内参, 用于更精确的法线计算 | |
| # eps: 数值稳定项 | |
| # Returns: | |
| # 单标量 tensor (可反传) | |
| # """ | |
| # # --- 形状、device、dtype 规范化 --- | |
| # device = pred_depth.device | |
| # dtype = pred_depth.dtype | |
| # if pred_depth.ndim ==4: | |
| # pred_depth = torch.mean(pred_depth, dim=1) # [B, H, W] | |
| # if gt_depth.ndim == 4: | |
| # gt_depth = torch.mean(gt_depth, dim=1) # [B, H, W] | |
| # B, H, W = pred_depth.shape | |
| # if depth_normalization == "log": | |
| # pred_depth = (pred_depth + 1.0) / 2.0 + eps | |
| # pred_depth = torch.exp(pred_depth) | |
| # elif depth_normalization == "sqrt": | |
| # pred_depth = (pred_depth + 1.0) / 2.0 + eps | |
| # pred_depth = pred_depth**2 | |
| # elif depth_normalization == "disp": | |
| # pred_depth = (pred_depth + 1.0) / 2.0 + eps | |
| # pred_depth = 1.0/(pred_depth + eps) | |
| # elif depth_normalization == "sqrt_disp": | |
| # pred_depth = (pred_depth + 1.0) / 2.0 + eps | |
| # pred_depth = 1.0/(pred_depth**2 + eps) | |
| # # --- Mask 处理 --- | |
| # if mask is None: | |
| # mask = (gt_depth > eps).detach() | |
| # else: | |
| # # if mask.dim() == 3: | |
| # # mask = mask.unsqueeze(1) | |
| # mask = mask.bool().to(device) | |
| # # --- 核心步骤 1: 最小二乘法对齐尺度和偏移 --- | |
| # aligned_pred_depths = [] | |
| # for i in range(B): # 逐个样本处理 | |
| # p = pred_depth[i][mask[i]] | |
| # g = gt_depth[i][mask[i]] | |
| # if p.numel() < 2: # 有效点太少,无法拟合 | |
| # aligned_pred_depths.append(pred_depth[i]) | |
| # continue | |
| # # 构造最小二乘问题 y = s*x + t, A*[s,t]^T = y | |
| # A = torch.stack([p, torch.ones_like(p)], dim=1) | |
| # y = g | |
| # # 使用 torch.linalg.lstsq 求解,稳定且高效 | |
| # try: | |
| # solution = torch.linalg.lstsq(A.to(torch.float32), y.to(torch.float32)).solution | |
| # s, t = solution[0], solution[1] | |
| # except torch.linalg.LinAlgError: | |
| # # 如果矩阵奇异,退化为只匹配均值 | |
| # s = torch.tensor(1.0, device=device, dtype=dtype) | |
| # t = g.mean() - p.mean() | |
| # # 对齐预测值,不让梯度流过 s 和 t | |
| # aligned_pred_depths.append(pred_depth[i] * s.detach() + t.detach()) | |
| # aligned_pred_depth = torch.stack(aligned_pred_depths) | |
| # # --- 核心步骤 2: 从对齐后的深度图计算法线 --- | |
| # # pred_normals = depth_to_normals(aligned_pred_depth, camera_intrinsics, k=5, d=1, gamma=0.05, min_nghbr=4, eps=eps) | |
| # # gt_normals = depth_to_normals(gt_depth, camera_intrinsics, k=5, d=1, gamma=0.05, min_nghbr=4, eps=eps) | |
| # # pred_normals = depth_to_normals(aligned_pred_depth, camera_intrinsics,) | |
| # # gt_normals = depth_to_normals(gt_depth, camera_intrinsics) | |
| # # # --- 核心步骤 3: 复用成功的法线损失 --- | |
| # # # 注意,法线计算在边缘处可能无效,mask需要结合深度mask | |
| # # final_mask = mask & (gt_depth > eps) & (pred_normals.norm(dim=1, keepdim=True) > eps) & (gt_normals.norm(dim=1, keepdim=True) > eps) | |
| # # loss = get_cycle_consistency_normal_loss(pred_normals, gt_normals, final_mask) | |
| # if per_sample_weights is None: | |
| # per_sample_weights = torch.ones((B,), device=device, dtype=dtype) | |
| # loss = per_sample_weights.view(B,1,1)*(aligned_pred_depth-gt_depth)**2*mask # [B,H,W] -> [B] | |
| # return loss.mean().clip(1e-6,1.0) | |
| import torch | |
| import math | |
def get_cycle_consistency_normal_loss(pred, gt, mask=None, eps=1e-8, per_sample_weights=None):
    """Mean angular error between two normal maps, usable as a training loss.

    Both maps are normalized to unit length, the per-pixel angle is obtained
    from ``atan2(|pred x gt|, pred . gt)`` (in degrees), averaged over the
    valid pixels of every sample, weighted per sample, averaged over the batch
    and finally divided by 100.

    Args:
        pred: [B,3,H,W] or [B,H,W,3] normal map.
        gt: [B,3,H,W] or [B,H,W,3] normal map.
        mask: None (all pixels valid) or a [B,H,W] / [B,1,H,W] mask.
        eps: lower bound applied to the vector norms before dividing.
        per_sample_weights: None (all ones) or a [B] tensor of sample weights.

    Returns:
        Scalar tensor that supports backward().
    """
    def _channels_first(normals):
        # Channel-last input is recognised by a trailing axis of size 3.
        if normals.dim() == 4 and normals.shape[1] != 3 and normals.shape[-1] == 3:
            return normals.permute(0, 3, 1, 2)
        return normals

    pred = _channels_first(pred)
    gt = _channels_first(gt)

    n_batch, n_chan, height, width = pred.shape
    assert n_chan == 3, "输入必须是3通道法线向量"
    dev = pred.device
    out_dtype = pred.dtype

    # Unit-length normals; the clamp guards against division by zero.
    pred_unit = pred / pred.norm(dim=1, keepdim=True).clamp_min(eps)
    gt_unit = gt / gt.norm(dim=1, keepdim=True).clamp_min(eps)

    # Bring the mask to a boolean tensor on the prediction's device.
    if mask is None:
        mask = torch.ones((n_batch, height, width), dtype=torch.bool, device=dev)
    if mask.dim() == 4:
        mask = mask.squeeze(1)
    if mask.dim() == 3 and mask.shape[1] == 1:
        mask = mask.squeeze(1)
    mask = mask.bool().to(dev)

    # Flatten the spatial axes: vectors become [B,3,N], the mask [B,N].
    pred_flat = pred_unit.view(n_batch, 3, -1)
    gt_flat = gt_unit.view(n_batch, 3, -1)
    valid = mask.view(n_batch, -1).bool()

    # cos from the dot product, sin from the cross-product magnitude; atan2 of
    # the pair avoids the unbounded derivative acos has near +-1.
    cos_val = (pred_flat * gt_flat).sum(dim=1)
    sin_val = torch.cross(pred_flat, gt_flat, dim=1).norm(dim=1)
    cos_val = cos_val.clamp(-1.0 + 1e-7, 1.0 - 1e-7)
    sin_val = sin_val.clamp_min(1e-12)
    angle_deg = torch.atan2(sin_val, cos_val) * (180.0 / math.pi)  # [B, N]

    # Per-sample masked mean; the clamp keeps empty masks from dividing by 0.
    valid_f = valid.float()
    n_valid = valid_f.sum(dim=1)
    if per_sample_weights is None:
        per_sample_weights = torch.ones((n_batch,), device=dev, dtype=out_dtype)
    per_sample_angle = (angle_deg * valid_f).sum(dim=1) / n_valid.clamp_min(1)

    # Division by 100 rescales the loss, as in the original implementation.
    return (per_sample_angle * per_sample_weights).mean() / 100.0
def get_cycle_consistency_matting_loss(pred_alpha, gt_alpha, trimap):
    """Combined, differentiable matting loss over the unknown trimap region.

    Sums three terms per sample, each restricted to pixels whose trimap value
    is 128 (within +-0.5), then averages over the batch:

    * SAD: sum of absolute alpha differences, divided by 1000.
    * MSE: mean squared alpha difference over the unknown pixels.
    * Gradient: sum of squared differences between gradient magnitudes
      (derivative-of-Gaussian filter, sigma = 1.4), divided by 1000.

    Args:
        pred_alpha (torch.Tensor): predicted alpha, [B, C, H, W] (channels are
            averaged) or [B, H, W], values in [0, 1].
        gt_alpha (torch.Tensor): ground-truth alpha, same layouts.
        trimap (torch.Tensor): [B, 1, H, W] or [B, H, W], values 0 / 128 / 255.

    Returns:
        torch.Tensor: scalar loss for back-propagation.
    """
    # --- 1. preprocessing: bring everything to [B, 1, H, W] ---
    if pred_alpha.ndim == 3:
        pred_alpha = pred_alpha.unsqueeze(1)
    elif pred_alpha.ndim == 4:
        pred_alpha = torch.mean(pred_alpha, dim=1, keepdim=True)
    if gt_alpha.ndim == 3:
        gt_alpha = gt_alpha.unsqueeze(1)
    elif gt_alpha.ndim == 4:
        gt_alpha = torch.mean(gt_alpha, dim=1, keepdim=True)
    if trimap.ndim == 3:
        trimap = trimap.unsqueeze(1)

    # Unknown region: trimap pixels equal to 128.
    unknown_mask = ((trimap >= 127.5) & (trimap <= 128.5)).float()

    # --- 2. SAD loss (scaled by 1/1000 as in the reference metric) ---
    error_map_sad = torch.abs(pred_alpha - gt_alpha)
    loss_sad_per_sample = (error_map_sad * unknown_mask).sum(dim=(1, 2, 3)) / 1000.0

    # --- 3. MSE loss, normalized by the number of unknown pixels ---
    error_map_mse = (pred_alpha - gt_alpha) ** 2
    loss_mse_per_sample = (
        (error_map_mse * unknown_mask).sum(dim=(1, 2, 3))
        / (unknown_mask.sum(dim=(1, 2, 3)) + 1e-8)
    )

    # --- 4. gradient loss ---
    # Build the derivative-of-Gaussian kernels on the input's device.
    device = pred_alpha.device
    sigma = 1.4
    epsilon = 1e-2
    # Kernel half-width: where the Gaussian falls below `epsilon`.
    halfsize = math.ceil(sigma * math.sqrt(-2 * math.log(math.sqrt(2 * math.pi) * sigma * epsilon)))

    coords = torch.arange(-halfsize, halfsize + 1, dtype=torch.float32, device=device)
    gauss_vals = torch.exp(-coords**2 / (2 * sigma**2)) / (sigma * math.sqrt(2 * math.pi))
    dgauss_vals = -coords * gauss_vals / (sigma**2)

    # Separable kernels: smooth along one axis, differentiate along the other.
    hx = gauss_vals.unsqueeze(1) * dgauss_vals.unsqueeze(0)
    hy = hx.t()

    # Normalize each kernel to unit L2 norm.
    hx = hx / torch.sqrt(torch.sum(torch.abs(hx) * torch.abs(hx)))
    hy = hy / torch.sqrt(torch.sum(torch.abs(hy) * torch.abs(hy)))

    # conv2d weight layout: [out_channels, in_channels, kH, kW].
    kernel_x = hx.to(pred_alpha.dtype).unsqueeze(0).unsqueeze(0)
    kernel_y = hy.to(pred_alpha.dtype).unsqueeze(0).unsqueeze(0)

    pred_gx = F.conv2d(pred_alpha, kernel_x, padding='same')
    pred_gy = F.conv2d(pred_alpha, kernel_y, padding='same')
    gt_gx = F.conv2d(gt_alpha, kernel_x, padding='same')
    gt_gy = F.conv2d(gt_alpha, kernel_y, padding='same')

    # Gradient magnitude. sqrt has an infinite derivative at 0, so a perfectly
    # flat region (very common in alpha mattes) would back-propagate
    # inf * 0 = NaN, even where the mask is 0. Adding the smallest normal
    # float keeps the derivative finite without measurably changing the value.
    amp_eps = torch.finfo(pred_gx.dtype).tiny
    pred_amp = torch.sqrt(pred_gx**2 + pred_gy**2 + amp_eps)
    gt_amp = torch.sqrt(gt_gx**2 + gt_gy**2 + amp_eps)

    error_map_grad = (pred_amp - gt_amp) ** 2
    loss_grad_per_sample = (error_map_grad * unknown_mask).sum(dim=(1, 2, 3)) / 1000.0

    # --- 5. combine the terms (unit weights) and average over the batch ---
    total_loss_per_sample = loss_sad_per_sample + loss_mse_per_sample + loss_grad_per_sample
    final_loss = total_loss_per_sample.mean()

    return final_loss
def get_disperse_loss(eta, tau=1.0):
    """Dispersive loss over a batch of intermediate features.

    Formula: L_disp = log E_{i,j} [exp(-|| eta_i - eta_j ||_2^2 / tau)],
    where the expectation runs over all B*B ordered pairs in the batch
    (including i == j).

    Args:
        eta (torch.Tensor): batch of features, shape (batch_size, features) or
            any higher-rank shape with the batch on axis 0, which is flattened
            to (batch_size, -1).
        tau (float): temperature.

    Returns:
        torch.Tensor: scalar loss; 0 when the batch holds at most one sample.
    """
    n_samples = eta.shape[0]

    # A single sample has no pair to disperse from.
    if n_samples <= 1:
        return torch.tensor(0.0, device=eta.device, dtype=eta.dtype)

    # One flat feature vector per sample.
    features = eta.view(n_samples, -1) if eta.dim() > 2 else eta

    # Pairwise squared distances via ||x||^2 - 2 x.y + ||y||^2 -> [B, B].
    sq_norms = torch.sum(features**2, dim=1, keepdim=True)
    gram = torch.mm(features, features.t())
    # Rounding can push tiny distances below zero; clip them back.
    pairwise_sq = torch.clamp(sq_norms - 2 * gram + sq_norms.t(), min=0.0)

    # log of the mean similarity over every pair.
    similarity = torch.exp(-pairwise_sq / tau)
    return torch.log(torch.mean(similarity))