"""
Loss functions for SDF and VG stages.
All pure PyTorch, no compiled extensions needed.
"""
import torch
import torch.nn.functional as F
import numpy as np
def gaussian_kernel(points, queries):
    """
    Normalized Gaussian weights of neighbor points around query points.

    points: (n, k, 3) or (n, 3) neighbor coordinates.
    queries: (n, 3) query coordinates.
    Returns: (n, k) weights (near-normalized along the k axis).
    """
    if points.dim() == 2:
        points = points[:, None, :]  # promote to (n, 1, 3)
    dists = (points - queries[:, None, :]).norm(dim=-1)  # (n, k)
    # Adaptive per-query bandwidth: mean neighbor distance (eps guards /0).
    bandwidth = dists.mean(dim=-1, keepdim=True) + 1e-8
    raw = torch.exp(-((dists / bandwidth) ** 2))
    return raw / (raw.sum(dim=-1, keepdim=True) + 1e-8)
def pull_knn_loss(points, samples_moved, samples):
    """
    KNN-weighted pull loss.

    Gaussian-weighted mean distance from each moved sample to its neighbor
    points; weights come from the original (unmoved) sample positions and
    are detached so they act as constants.

    points: (n, k, 3) or (n, 3); samples_moved: (n, 3); samples: (n, 3)
    """
    if points.dim() == 2:
        points = points[:, None, :]
    per_neighbor_dist = torch.linalg.norm(
        points - samples_moved[:, None, :], ord=2, dim=-1)  # (n, k)
    weights = gaussian_kernel(points, samples).detach()
    return (per_neighbor_dist * weights).sum(dim=-1).mean()
def grad_consis_knn_loss(gradient_points_norm, gradients_samples_norm, points, samples):
    """
    Gradient-consistency loss over k-NN: penalize cosine dissimilarity
    between each sample's gradient and its neighbors' gradients, weighted
    by a (detached) Gaussian kernel around the sample.
    """
    if points.dim() == 2:
        points = points[:, None, :]
    cos_dissim = 1 - F.cosine_similarity(
        gradient_points_norm, gradients_samples_norm[:, None, :], dim=-1)
    weights = gaussian_kernel(points, samples).detach()
    return (cos_dissim * weights).sum(dim=-1).mean()
def eikonal_loss(gradients_samples):
    """Eikonal regularizer: mean squared deviation of gradient norms from 1."""
    grad_norms = torch.linalg.norm(gradients_samples, ord=2, dim=-1)
    return torch.mean((grad_norms - 1.0) ** 2)
def div_loss(samples, gradients_samples):
    """Divergence of gradient field."""
    # samples: (n, 3) query points; gradients_samples: (n, 3) gradient of the
    # implicit field at `samples`. The second derivatives taken below require
    # that gradients_samples was itself produced from `samples` with
    # create_graph=True — TODO confirm against the caller.
    def gradient(inputs, outputs):
        # d(outputs)/d(inputs) via autograd; create_graph keeps the result
        # differentiable so the divergence can be backpropagated through.
        inputs.requires_grad_(True)
        d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
        points_grad = torch.autograd.grad(
            outputs=outputs,
            inputs=inputs,
            grad_outputs=d_points,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return points_grad
    # Row i of div_dx holds d(g_x)/d(x, y, z) at sample i; likewise for y, z.
    div_dx = gradient(samples, gradients_samples[:, 0])
    div_dy = gradient(samples, gradients_samples[:, 1])
    div_dz = gradient(samples, gradients_samples[:, 2])
    # div g = dg_x/dx + dg_y/dy + dg_z/dz (trace of the Jacobian).
    divergence = div_dx[:, 0] + div_dy[:, 1] + div_dz[:, 2]
    # Squared divergence clamped to [0.1, 50]: the floor stops near-zero values
    # from vanishing in the mean, the ceiling caps outlier contributions.
    loss_div = (torch.clamp(torch.square(divergence), 0.1, 50)).mean()
    return loss_div
def sdf_loss(sample_sdf):
    """Mean absolute SDF value (drives predicted SDF toward zero)."""
    return sample_sdf.abs().mean()
def cal_curvature_with_normal(pts, normals, knn=16):
    """
    Estimate curvature from normal variation over k-NN.

    pts: (n, 3) point coordinates; normals: (n, 3) point normals.
    knn: number of neighbors used (capped at n - 1 implicitly via the query).
    Returns: (n, 1) curvature rescaled to [0, 1].
    """
    from scipy.spatial import KDTree
    pts_np = pts.detach().cpu().numpy()
    # Fix: the original also converted `normals` to numpy but never used it.
    tree = KDTree(pts_np)
    # k + 1 because each point is its own nearest neighbor; drop that column.
    _, idx = tree.query(pts_np, k=min(knn + 1, len(pts_np)))
    idx = torch.from_numpy(idx[:, 1:]).long().to(pts.device)  # (n, k)
    neigh_pts = pts[idx]          # (n, k, 3)
    neigh_normals = normals[idx]  # (n, k, 3)
    # 1 - cos(normal, neighbor normal): larger where normals bend more.
    neigh_curvature = 1.0 - F.cosine_similarity(normals.unsqueeze(1), neigh_normals, dim=-1)  # (n, k)
    g_weight = gaussian_kernel(neigh_pts, pts).detach()
    pts_curvature_ave = (neigh_curvature * g_weight).sum(dim=-1, keepdim=True)  # (n, 1)
    # Center around the batch mean, squash, then min-max rescale to [0, 1].
    pts_curvature_ave = torch.sigmoid(pts_curvature_ave - torch.mean(pts_curvature_ave))
    pts_curvature_ave = (pts_curvature_ave - torch.min(pts_curvature_ave) + 1e-6) / (
        torch.max(pts_curvature_ave) - torch.min(pts_curvature_ave) + 1e-6)
    return pts_curvature_ave
def cal_nc_loss(surface_normals, sample_normals, sur_neigh_idx):
    """
    Normal consistency loss: one minus the mean cosine similarity between
    each surface normal and the normal of its matched sample point.
    """
    matched_normals = sample_normals[sur_neigh_idx]
    cos_sim = F.cosine_similarity(surface_normals, matched_normals, dim=-1)
    return 1.0 - cos_sim.mean()
def cal_chamfer_loss(sur_pts, sample_pts, curvature_surface, loss_w, nearest_clamp):
    """
    Bidirectional Chamfer with curvature weighting + repulsion.

    sur_pts: (n_s, 3) surface points; sample_pts: (n_q, 3) sample points.
    curvature_surface: (n_s, 1) per-surface-point curvature weights.
    loss_w: [w_curv, w_plain, w_reverse, w_repulsion]
    nearest_clamp: multiple of the mean self-NN squared distance at which the
        repulsion term is clamped.
    Returns: (chamfer_loss, sur_neigh_idx, sample_neigh_idx)
    """
    from scipy.spatial import KDTree
    sample_np = sample_pts.detach().cpu().numpy()
    sur_np = sur_pts.detach().cpu().numpy()
    # ---- nearest from surface to samples ----
    sample_tree = KDTree(sample_np)
    _, sur_neigh_idx = sample_tree.query(sur_np, k=1)
    sur_neigh_idx = torch.from_numpy(sur_neigh_idx).long().to(sur_pts.device)
    sur_neigh_pts = sample_pts[sur_neigh_idx]
    dist_1 = torch.linalg.norm((sur_pts - sur_neigh_pts), ord=2, dim=-1).unsqueeze(-1) ** 2
    # Curvature-weighted term pulls samples toward high-curvature regions;
    # the plain term is a standard one-sided Chamfer.
    loss_part_1 = loss_w[0] * (dist_1 * curvature_surface).mean() + loss_w[1] * dist_1.mean()
    # ---- nearest from samples to surface ----
    sur_tree = KDTree(sur_np)
    _, sample_neigh_idx = sur_tree.query(sample_np, k=1)
    sample_neigh_idx = torch.from_numpy(sample_neigh_idx).long().to(sample_pts.device)
    sample_neigh_pts = sur_pts[sample_neigh_idx]
    dist_2 = torch.linalg.norm((sample_pts - sample_neigh_pts), ord=2, dim=-1) ** 2
    loss_part_2 = loss_w[2] * dist_2.mean()
    # ---- repulsion (self-nearest-neighbor distance) ----
    # Fix: reuse sample_tree instead of rebuilding an identical KDTree on
    # sample_np. k=2 because each sample's nearest neighbor in its own set
    # is itself; column 1 is the true nearest other sample.
    _, self_neigh_idx = sample_tree.query(sample_np, k=2)
    self_neigh_idx = torch.from_numpy(self_neigh_idx[:, 1]).long().to(sample_pts.device)
    relative_dist = sample_pts[self_neigh_idx] - sample_pts
    norm_dist = torch.linalg.norm(relative_dist, ord=2, dim=-1) ** 2
    # Clamp so isolated samples cannot dominate the (negated) repulsion term.
    norm_dist_ = torch.clamp(norm_dist, max=nearest_clamp * norm_dist.mean())
    loss_part_3 = -norm_dist_.mean() * loss_w[3]
    chamfer_loss = loss_part_1 + loss_part_2 + loss_part_3
    return chamfer_loss, sur_neigh_idx, sample_neigh_idx
def cal_vg_loss(sur_pts, sur_normals, curvature_ave, sample_pts, sample_grad, loss_w, nearest_clamp):
    """
    Total VG stage loss: curvature-weighted Chamfer plus weighted normal
    consistency between surface normals and (detached, normalized) sample
    gradients.

    loss_w: [w_curv, w_plain, w_reverse, w_repulsion, w_normal]
    """
    sample_normal = F.normalize(sample_grad.detach(), dim=-1)
    chamfer, matched_idx, _ = cal_chamfer_loss(
        sur_pts, sample_pts, curvature_ave, loss_w[:4], nearest_clamp)
    nc = cal_nc_loss(sur_normals, sample_normal, matched_idx)
    return chamfer + loss_w[4] * nc