""" Loss functions for SDF and VG stages. All pure PyTorch, no compiled extensions needed. """ import torch import torch.nn.functional as F import numpy as np def gaussian_kernel(points, queries): """ points: (n, k, 3) or (n, 3) queries: (n, 3) Returns: (n, k) or (n,) Gaussian weights. """ if points.dim() == 2: points = points.unsqueeze(1) # (n, 1, 3) pts_dist = torch.linalg.norm((points - queries.unsqueeze(1)), ord=2, dim=-1) # (n, k) h = pts_dist.mean(dim=-1, keepdim=True) + 1e-8 dist_exp = torch.exp(-pts_dist ** 2 / h ** 2) gaussian_weight = dist_exp / (dist_exp.sum(dim=-1, keepdim=True) + 1e-8) return gaussian_weight def pull_knn_loss(points, samples_moved, samples): """KNN-weighted pull loss.""" # points: (n, k, 3); samples_moved: (n, 3); samples: (n, 3) if points.dim() == 2: points = points.unsqueeze(1) loss_pull = torch.linalg.norm((points - samples_moved.unsqueeze(1)), ord=2, dim=-1) g_weight = gaussian_kernel(points, samples).detach() loss_pull = (loss_pull * g_weight).sum(dim=-1).mean() return loss_pull def grad_consis_knn_loss(gradient_points_norm, gradients_samples_norm, points, samples): if points.dim() == 2: points = points.unsqueeze(1) loss_consis = (1 - F.cosine_similarity(gradient_points_norm, gradients_samples_norm.unsqueeze(1), dim=-1)) g_weight = gaussian_kernel(points, samples).detach() loss_consis = (loss_consis * g_weight).sum(dim=-1).mean() return loss_consis def eikonal_loss(gradients_samples): loss_eikonal = ((gradients_samples.norm(2, dim=-1) - 1).square()).mean() return loss_eikonal def div_loss(samples, gradients_samples): """Divergence of gradient field.""" def gradient(inputs, outputs): inputs.requires_grad_(True) d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device) points_grad = torch.autograd.grad( outputs=outputs, inputs=inputs, grad_outputs=d_points, create_graph=True, retain_graph=True, only_inputs=True)[0] return points_grad div_dx = gradient(samples, gradients_samples[:, 0]) div_dy = gradient(samples, gradients_samples[:, 1]) div_dz = gradient(samples, gradients_samples[:, 2]) divergence = div_dx[:, 0] + div_dy[:, 1] + div_dz[:, 2] loss_div = (torch.clamp(torch.square(divergence), 0.1, 50)).mean() return loss_div def sdf_loss(sample_sdf): return torch.abs(sample_sdf).mean() def cal_curvature_with_normal(pts, normals, knn=16): """ Estimate curvature from normal variation over k-NN. pts: (n, 3) normals: (n, 3) Returns: (n, 1) curvature in [0, 1]. """ from scipy.spatial import KDTree pts_np = pts.detach().cpu().numpy() normals_np = normals.detach().cpu().numpy() tree = KDTree(pts_np) _, idx = tree.query(pts_np, k=min(knn + 1, len(pts_np))) idx = torch.from_numpy(idx[:, 1:]).long().to(pts.device) # (n, k) neigh_pts = pts[idx] # (n, k, 3) neigh_normals = normals[idx] # (n, k, 3) neigh_curvature = 1.0 - F.cosine_similarity(normals.unsqueeze(1), neigh_normals, dim=-1) # (n, k) g_weight = gaussian_kernel(neigh_pts, pts).detach() pts_curvature_ave = (neigh_curvature * g_weight).sum(dim=-1, keepdim=True) # (n, 1) pts_curvature_ave = torch.sigmoid(pts_curvature_ave - torch.mean(pts_curvature_ave)) pts_curvature_ave = (pts_curvature_ave - torch.min(pts_curvature_ave) + 1e-6) / ( torch.max(pts_curvature_ave) - torch.min(pts_curvature_ave) + 1e-6) return pts_curvature_ave def cal_nc_loss(surface_normals, sample_normals, sur_neigh_idx): """Normal consistency loss.""" neigh_normals = sample_normals[sur_neigh_idx] nc_loss = 1.0 - F.cosine_similarity(surface_normals, neigh_normals, dim=-1).mean() return nc_loss def cal_chamfer_loss(sur_pts, sample_pts, curvature_surface, loss_w, nearest_clamp): """ Bidirectional Chamfer with curvature weighting + repulsion. loss_w: [w_curv, w_plain, w_reverse, w_repulsion] """ from scipy.spatial import KDTree # ---- nearest from surface to samples ---- sample_np = sample_pts.detach().cpu().numpy() sur_np = sur_pts.detach().cpu().numpy() tree = KDTree(sample_np) _, sur_neigh_idx = tree.query(sur_np, k=1) sur_neigh_idx = torch.from_numpy(sur_neigh_idx).long().to(sur_pts.device) sur_neigh_pts = sample_pts[sur_neigh_idx] dist_1 = torch.linalg.norm((sur_pts - sur_neigh_pts), ord=2, dim=-1).unsqueeze(-1) ** 2 weight_sur_curvature = curvature_surface loss_part_1 = loss_w[0] * (dist_1 * weight_sur_curvature).mean() + loss_w[1] * dist_1.mean() # ---- nearest from samples to surface ---- tree2 = KDTree(sur_np) _, sample_neigh_idx = tree2.query(sample_np, k=1) sample_neigh_idx = torch.from_numpy(sample_neigh_idx).long().to(sample_pts.device) sample_neigh_pts = sur_pts[sample_neigh_idx] dist_2 = torch.linalg.norm((sample_pts - sample_neigh_pts), ord=2, dim=-1) ** 2 loss_part_2 = loss_w[2] * dist_2.mean() # ---- repulsion (self-nearest-neighbor distance) ---- tree3 = KDTree(sample_np) _, sample_neigh_idx_self = tree3.query(sample_np, k=2) sample_neigh_idx_self = torch.from_numpy(sample_neigh_idx_self[:, 1]).long().to(sample_pts.device) sample_neigh_pts_self = sample_pts[sample_neigh_idx_self] relative_dist = sample_neigh_pts_self - sample_pts norm_dist = torch.linalg.norm(relative_dist, ord=2, dim=-1) ** 2 norm_dist_ = torch.clamp(norm_dist, max=nearest_clamp * norm_dist.mean()) loss_part_3 = -norm_dist_.mean() * loss_w[3] chamfer_loss = loss_part_1 + loss_part_2 + loss_part_3 return chamfer_loss, sur_neigh_idx, sample_neigh_idx def cal_vg_loss(sur_pts, sur_normals, curvature_ave, sample_pts, sample_grad, loss_w, nearest_clamp): """Total VG stage loss.""" sample_normal = F.normalize(sample_grad.detach(), dim=-1) chamfer_loss, sur_neigh_idx, _ = cal_chamfer_loss( sur_pts, sample_pts, curvature_ave, loss_w[:4], nearest_clamp) normal_loss = cal_nc_loss(sur_normals, sample_normal, sur_neigh_idx) loss = chamfer_loss + loss_w[4] * normal_loss return loss