| """ |
| Loss functions for the SDF and VG stages. |
| PyTorch plus SciPy's KDTree for neighbour search; no custom compiled extensions needed. |
| """ |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
|
|
|
|
def gaussian_kernel(points, queries):
    """
    Normalised Gaussian weights of neighbour points around their queries.

    points:  (n, k, 3) neighbours per query, or (n, 3) for one neighbour each
    queries: (n, 3)
    Returns: (n, k) weights; a 2-D `points` is treated as k = 1, giving (n, 1).
    """
    neighbours = points.unsqueeze(1) if points.dim() == 2 else points
    offsets = neighbours - queries.unsqueeze(1)
    dist = torch.linalg.norm(offsets, ord=2, dim=-1)
    # Bandwidth is the per-query mean neighbour distance; the epsilon keeps
    # the division finite when every neighbour coincides with the query.
    bandwidth = dist.mean(dim=-1, keepdim=True) + 1e-8
    unnormalised = torch.exp(-dist ** 2 / bandwidth ** 2)
    # Normalise along the neighbour axis.
    total = unnormalised.sum(dim=-1, keepdim=True) + 1e-8
    return unnormalised / total
|
|
|
|
def pull_knn_loss(points, samples_moved, samples):
    """
    KNN-weighted pull loss.

    Distance from each moved sample to its neighbour point(s), combined with
    Gaussian weights computed from the un-moved sample positions.

    points:        (n, k, 3), or (n, 3) which is treated as k = 1
    samples_moved: presumably (n, 3), broadcast against the neighbour axis
    samples:       (n, 3) positions used only for the weights
    Returns: scalar loss.
    """
    neighbours = points if points.dim() != 2 else points.unsqueeze(1)
    # Weights are detached, so they act as constants during backprop.
    weights = gaussian_kernel(neighbours, samples).detach()
    residual = torch.linalg.norm(
        neighbours - samples_moved.unsqueeze(1), ord=2, dim=-1)
    weighted = residual * weights
    return weighted.sum(dim=-1).mean()
|
|
|
|
def grad_consis_knn_loss(gradient_points_norm, gradients_samples_norm, points, samples):
    """
    KNN-weighted gradient consistency loss.

    Penalises 1 - cosine similarity between the gradient at each neighbour
    point and the gradient at its sample, combined with detached Gaussian
    weights of the neighbour points around the samples.

    gradient_points_norm:   (n, k, 3) or (n, 3) gradients at the neighbour points
    gradients_samples_norm: (n, 3) gradients at the samples
    points:                 (n, k, 3) or (n, 3) neighbour points
    samples:                (n, 3)
    Returns: scalar loss.
    """
    if points.dim() == 2:
        points = points.unsqueeze(1)
    # Give 2-D neighbour gradients the same k = 1 axis as `points`; otherwise
    # (n, 3) against (n, 1, 3) broadcasts into an (n, n) all-pairs comparison.
    if gradient_points_norm.dim() == 2:
        gradient_points_norm = gradient_points_norm.unsqueeze(1)
    cos_sim = F.cosine_similarity(
        gradient_points_norm, gradients_samples_norm.unsqueeze(1), dim=-1)
    loss_consis = 1 - cos_sim
    # Weights are detached, so they act as constants during backprop.
    g_weight = gaussian_kernel(points, samples).detach()
    return (loss_consis * g_weight).sum(dim=-1).mean()
|
|
|
|
def eikonal_loss(gradients_samples):
    """Mean squared deviation of the gradient L2 norm (last axis) from 1."""
    grad_norm = gradients_samples.norm(2, dim=-1)
    deviation = grad_norm - 1
    return deviation.square().mean()
|
|
|
|
def div_loss(samples, gradients_samples):
    """
    Divergence penalty on the gradient field.

    samples:           (n, 3) positions that `gradients_samples` was computed
                       from; marked as requiring grad in place.
    gradients_samples: (n, 3) field values, differentiable w.r.t. `samples`.
    Returns: mean of the squared divergence, clamped to [0.1, 50].
    """
    samples.requires_grad_(True)

    diagonal_terms = []
    for axis in range(3):
        component = gradients_samples[:, axis]
        # Back-propagating ones through one field component gives its spatial
        # derivatives per sample; create_graph keeps the result differentiable.
        (component_grad,) = torch.autograd.grad(
            outputs=component,
            inputs=samples,
            grad_outputs=torch.ones_like(component),
            create_graph=True,
            retain_graph=True,
        )
        # Only d(component_axis)/d(axis) contributes to the divergence.
        diagonal_terms.append(component_grad[:, axis])

    divergence = diagonal_terms[0] + diagonal_terms[1] + diagonal_terms[2]
    clamped_sq = torch.clamp(torch.square(divergence), 0.1, 50)
    return clamped_sq.mean()
|
|
|
|
def sdf_loss(sample_sdf):
    """Mean absolute value of the given SDF predictions."""
    magnitude = sample_sdf.abs()
    return magnitude.mean()
|
|
|
|
def cal_curvature_with_normal(pts, normals, knn=16):
    """
    Estimate curvature from normal variation over k-NN.

    pts:     (n, 3)
    normals: (n, 3)
    knn:     number of neighbours per point (capped at n - 1)
    Returns: (n, 1) curvature, min-max rescaled so the largest value is 1.

    NOTE(review): assumes n >= 2; with a single point there are no
    neighbours to compare against.
    """
    from scipy.spatial import KDTree

    # Only the positions are needed on the host for the neighbour search;
    # the normals stay on their device.
    pts_np = pts.detach().cpu().numpy()
    tree = KDTree(pts_np)
    _, idx = tree.query(pts_np, k=min(knn + 1, len(pts_np)))
    # Column 0 is dropped as the query point itself (its own nearest neighbour).
    idx = torch.from_numpy(idx[:, 1:]).long().to(pts.device)

    neigh_pts = pts[idx]
    neigh_normals = normals[idx]
    # 1 - cos(angle) between each normal and its neighbours' normals.
    neigh_curvature = 1.0 - F.cosine_similarity(
        normals.unsqueeze(1), neigh_normals, dim=-1)
    g_weight = gaussian_kernel(neigh_pts, pts).detach()
    curvature = (neigh_curvature * g_weight).sum(dim=-1, keepdim=True)

    # Centre on the mean and squash with a sigmoid, then min-max rescale.
    curvature = torch.sigmoid(curvature - torch.mean(curvature))
    lowest = torch.min(curvature)
    highest = torch.max(curvature)
    return (curvature - lowest + 1e-6) / (highest - lowest + 1e-6)
|
|
|
|
def cal_nc_loss(surface_normals, sample_normals, sur_neigh_idx):
    """
    Normal consistency loss.

    Gathers `sample_normals` at `sur_neigh_idx` and returns one minus the mean
    cosine similarity (last axis) with `surface_normals`.
    """
    matched_normals = sample_normals[sur_neigh_idx]
    mean_cos = F.cosine_similarity(surface_normals, matched_normals, dim=-1).mean()
    return 1.0 - mean_cos
|
|
|
|
def cal_chamfer_loss(sur_pts, sample_pts, curvature_surface, loss_w, nearest_clamp):
    """
    Bidirectional Chamfer with curvature weighting + repulsion.

    sur_pts:           (n, 3) surface points
    sample_pts:        (m, 3) sample points
    curvature_surface: per-surface-point weight, multiplied against an (n, 1)
                       squared distance, so (n, 1) is expected
    loss_w:            [w_curv, w_plain, w_reverse, w_repulsion]
    nearest_clamp:     repulsion distances are capped at nearest_clamp times
                       their mean
    Returns: (chamfer_loss, sur_neigh_idx, sample_neigh_idx), where
        sur_neigh_idx[i] is the nearest sample to surface point i and
        sample_neigh_idx[j] is the nearest surface point to sample j.
    """
    from scipy.spatial import KDTree

    sample_np = sample_pts.detach().cpu().numpy()
    sur_np = sur_pts.detach().cpu().numpy()
    # One tree over the samples serves both the surface->sample lookup and
    # the sample->sample lookup used for repulsion.
    sample_tree = KDTree(sample_np)

    # Surface -> nearest sample, with curvature-weighted and plain terms.
    _, sur_neigh_idx = sample_tree.query(sur_np, k=1)
    sur_neigh_idx = torch.from_numpy(sur_neigh_idx).long().to(sur_pts.device)
    sur_neigh_pts = sample_pts[sur_neigh_idx]
    dist_1 = torch.linalg.norm(
        (sur_pts - sur_neigh_pts), ord=2, dim=-1).unsqueeze(-1) ** 2
    loss_part_1 = (loss_w[0] * (dist_1 * curvature_surface).mean()
                   + loss_w[1] * dist_1.mean())

    # Sample -> nearest surface point.
    surface_tree = KDTree(sur_np)
    _, sample_neigh_idx = surface_tree.query(sample_np, k=1)
    sample_neigh_idx = torch.from_numpy(sample_neigh_idx).long().to(sample_pts.device)
    sample_neigh_pts = sur_pts[sample_neigh_idx]
    dist_2 = torch.linalg.norm(
        (sample_pts - sample_neigh_pts), ord=2, dim=-1) ** 2
    loss_part_2 = loss_w[2] * dist_2.mean()

    # Repulsion: reward distance to the nearest *other* sample (negative sign).
    if sample_np.shape[0] > 1:
        # k=2 because the closest hit is the sample itself.
        _, self_idx = sample_tree.query(sample_np, k=2)
        self_idx = torch.from_numpy(self_idx[:, 1]).long().to(sample_pts.device)
        relative_dist = sample_pts[self_idx] - sample_pts
        norm_dist = torch.linalg.norm(relative_dist, ord=2, dim=-1) ** 2
        # Cap outliers so isolated samples do not dominate the term.
        norm_dist_ = torch.clamp(norm_dist, max=nearest_clamp * norm_dist.mean())
        loss_part_3 = -norm_dist_.mean() * loss_w[3]
    else:
        # A lone sample has no other sample to be repelled from; the k=2
        # query would return an out-of-range "missing neighbour" index.
        loss_part_3 = sample_pts.new_zeros(())

    chamfer_loss = loss_part_1 + loss_part_2 + loss_part_3
    return chamfer_loss, sur_neigh_idx, sample_neigh_idx
|
|
|
|
def cal_vg_loss(sur_pts, sur_normals, curvature_ave, sample_pts, sample_grad, loss_w, nearest_clamp):
    """
    Total VG stage loss: Chamfer term plus weighted normal consistency.

    loss_w: the first four entries go to `cal_chamfer_loss`; entry 4 scales
            the normal consistency term.
    `sample_grad` is detached and unit-normalised along the last axis before
    being compared with `sur_normals`.
    """
    unit_sample_normals = F.normalize(sample_grad.detach(), dim=-1)
    chamfer_term, nearest_sample_idx, _ = cal_chamfer_loss(
        sur_pts, sample_pts, curvature_ave, loss_w[:4], nearest_clamp)
    # Each surface normal is compared with the normal of its nearest sample.
    normal_term = cal_nc_loss(sur_normals, unit_sample_normals, nearest_sample_idx)
    return chamfer_term + loss_w[4] * normal_term
|
|