Upload losses.py
losses.py (ADDED):

"""
Loss functions for the SDF and VG stages.
All pure PyTorch; no compiled extensions needed.
"""
import torch
import torch.nn.functional as F


def gaussian_kernel(points, queries):
    """
    points: (n, k, 3) or (n, 3)
    queries: (n, 3)
    Returns: (n, k) Gaussian weights (k = 1 when points is (n, 3)).
    """
    if points.dim() == 2:
        points = points.unsqueeze(1)  # (n, 1, 3)
    pts_dist = torch.linalg.norm(points - queries.unsqueeze(1), ord=2, dim=-1)  # (n, k)
    # Adaptive bandwidth: mean neighbor distance per query.
    h = pts_dist.mean(dim=-1, keepdim=True) + 1e-8
    dist_exp = torch.exp(-pts_dist ** 2 / h ** 2)
    gaussian_weight = dist_exp / (dist_exp.sum(dim=-1, keepdim=True) + 1e-8)
    return gaussian_weight


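# The weights above form a normalized Gaussian kernel over the k neighbors,
#     w_i = exp(-d_i^2 / h^2) / sum_j exp(-d_j^2 / h^2),
# so each row sums to ~1 (up to the 1e-8 stabilizer). Quick sanity check
# (illustrative only):
#     w = gaussian_kernel(torch.randn(8, 16, 3), torch.randn(8, 3))
#     assert torch.allclose(w.sum(-1), torch.ones(8), atol=1e-4)

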
def pull_knn_loss(points, samples_moved, samples):
    """KNN-weighted pull loss.

    points: (n, k, 3); samples_moved: (n, 3); samples: (n, 3)
    """
    if points.dim() == 2:
        points = points.unsqueeze(1)
    # Distance from each moved sample to its k nearest surface points.
    loss_pull = torch.linalg.norm(points - samples_moved.unsqueeze(1), ord=2, dim=-1)
    # Kernel weights are evaluated at the original sample positions and
    # detached, so no gradient flows through the weighting itself.
    g_weight = gaussian_kernel(points, samples).detach()
    loss_pull = (loss_pull * g_weight).sum(dim=-1).mean()
    return loss_pull


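# Typical caller pattern (an assumption about usage, not defined in this
# file): samples_moved = samples - sdf_pred * F.normalize(grad_pred, dim=-1),
# i.e. query samples pulled toward the surface along the predicted gradient
# by the predicted distance.

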
def grad_consis_knn_loss(gradient_points_norm, gradients_samples_norm, points, samples):
    """Kernel-weighted gradient-consistency loss (1 - cosine similarity)."""
    if points.dim() == 2:
        points = points.unsqueeze(1)
    loss_consis = 1 - F.cosine_similarity(gradient_points_norm, gradients_samples_norm.unsqueeze(1), dim=-1)
    g_weight = gaussian_kernel(points, samples).detach()
    loss_consis = (loss_consis * g_weight).sum(dim=-1).mean()
    return loss_consis


def eikonal_loss(gradients_samples):
    """Eikonal regularizer on the sample gradients."""
    loss_eikonal = (gradients_samples.norm(2, dim=-1) - 1).square().mean()
    return loss_eikonal


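# A true signed distance field satisfies ||grad f|| = 1 everywhere, so
# eikonal_loss = mean((||g|| - 1)^2) pushes the network toward unit-norm
# gradients.

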
def div_loss(samples, gradients_samples):
    """Divergence of the gradient field at the sample points."""
    def gradient(inputs, outputs):
        inputs.requires_grad_(True)
        d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
        points_grad = torch.autograd.grad(
            outputs=outputs,
            inputs=inputs,
            grad_outputs=d_points,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return points_grad

    # Each call returns the gradient of one component of the field, i.e. one
    # row of its Jacobian; the divergence is the sum of the diagonal entries.
    div_dx = gradient(samples, gradients_samples[:, 0])
    div_dy = gradient(samples, gradients_samples[:, 1])
    div_dz = gradient(samples, gradients_samples[:, 2])
    divergence = div_dx[:, 0] + div_dy[:, 1] + div_dz[:, 2]
    loss_div = torch.clamp(torch.square(divergence), 0.1, 50).mean()
    return loss_div


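# When gradients_samples = grad f(samples), the divergence computed above is
# the Laplacian of f (trace of its Hessian). Note that clamp(., 0.1, 50) both
# caps outliers and floors each squared term at 0.1, so this loss term never
# reaches zero.

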
def sdf_loss(sample_sdf):
    """Mean absolute predicted SDF value."""
    return torch.abs(sample_sdf).mean()


def cal_curvature_with_normal(pts, normals, knn=16):
    """
    Estimate curvature from normal variation over the k nearest neighbors.
    pts: (n, 3); normals: (n, 3)
    Returns: (n, 1) curvature scores in [0, 1].
    """
    from scipy.spatial import KDTree
    pts_np = pts.detach().cpu().numpy()
    tree = KDTree(pts_np)
    # Query k + 1 because the nearest neighbor of each point is itself.
    _, idx = tree.query(pts_np, k=min(knn + 1, len(pts_np)))
    idx = torch.from_numpy(idx[:, 1:]).long().to(pts.device)  # (n, k)
    neigh_pts = pts[idx]  # (n, k, 3)
    neigh_normals = normals[idx]  # (n, k, 3)
    # Per-neighbor curvature proxy: 1 - cos(angle between normals).
    neigh_curvature = 1.0 - F.cosine_similarity(normals.unsqueeze(1), neigh_normals, dim=-1)  # (n, k)
    g_weight = gaussian_kernel(neigh_pts, pts).detach()
    pts_curvature_ave = (neigh_curvature * g_weight).sum(dim=-1, keepdim=True)  # (n, 1)

    # Center with a sigmoid, then min-max normalize into (0, 1].
    pts_curvature_ave = torch.sigmoid(pts_curvature_ave - torch.mean(pts_curvature_ave))
    pts_curvature_ave = (pts_curvature_ave - torch.min(pts_curvature_ave) + 1e-6) / (
        torch.max(pts_curvature_ave) - torch.min(pts_curvature_ave) + 1e-6)
    return pts_curvature_ave


def cal_nc_loss(surface_normals, sample_normals, sur_neigh_idx):
    """Normal consistency loss between surface normals and matched sample normals."""
    neigh_normals = sample_normals[sur_neigh_idx]
    # Equivalent to mean(1 - cos(theta)) over the matched normal pairs.
    nc_loss = 1.0 - F.cosine_similarity(surface_normals, neigh_normals, dim=-1).mean()
    return nc_loss


def cal_chamfer_loss(sur_pts, sample_pts, curvature_surface, loss_w, nearest_clamp):
    """
    Bidirectional Chamfer distance with curvature weighting plus a repulsion term.
    loss_w: [w_curv, w_plain, w_reverse, w_repulsion]
    """
    from scipy.spatial import KDTree

    # ---- nearest neighbor from surface points to samples ----
    sample_np = sample_pts.detach().cpu().numpy()
    sur_np = sur_pts.detach().cpu().numpy()
    tree = KDTree(sample_np)
    _, sur_neigh_idx = tree.query(sur_np, k=1)
    sur_neigh_idx = torch.from_numpy(sur_neigh_idx).long().to(sur_pts.device)
    sur_neigh_pts = sample_pts[sur_neigh_idx]
    dist_1 = torch.linalg.norm(sur_pts - sur_neigh_pts, ord=2, dim=-1).unsqueeze(-1) ** 2  # (n_sur, 1)
    weight_sur_curvature = curvature_surface
    loss_part_1 = loss_w[0] * (dist_1 * weight_sur_curvature).mean() + loss_w[1] * dist_1.mean()

    # ---- nearest neighbor from samples to surface points ----
    tree2 = KDTree(sur_np)
    _, sample_neigh_idx = tree2.query(sample_np, k=1)
    sample_neigh_idx = torch.from_numpy(sample_neigh_idx).long().to(sample_pts.device)
    sample_neigh_pts = sur_pts[sample_neigh_idx]
    dist_2 = torch.linalg.norm(sample_pts - sample_neigh_pts, ord=2, dim=-1) ** 2
    loss_part_2 = loss_w[2] * dist_2.mean()

    # ---- repulsion (distance to each sample's nearest other sample) ----
    # Reuse the KD-tree already built on the samples; k=2 because the nearest
    # neighbor of each sample is itself.
    _, sample_neigh_idx_self = tree.query(sample_np, k=2)
    sample_neigh_idx_self = torch.from_numpy(sample_neigh_idx_self[:, 1]).long().to(sample_pts.device)
    sample_neigh_pts_self = sample_pts[sample_neigh_idx_self]
    relative_dist = sample_neigh_pts_self - sample_pts
    norm_dist = torch.linalg.norm(relative_dist, ord=2, dim=-1) ** 2
    # Clamp so a few isolated samples cannot dominate the repulsion term.
    norm_dist_ = torch.clamp(norm_dist, max=nearest_clamp * norm_dist.mean())
    loss_part_3 = -norm_dist_.mean() * loss_w[3]

    chamfer_loss = loss_part_1 + loss_part_2 + loss_part_3
    return chamfer_loss, sur_neigh_idx, sample_neigh_idx


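# Summary of the Chamfer terms, with weights taken from loss_w:
#     loss = w_curv      * mean(curvature * d(sur -> samp)^2)
#          + w_plain     * mean(d(sur -> samp)^2)
#          + w_reverse   * mean(d(samp -> sur)^2)
#          - w_repulsion * mean(clamped d(samp -> nearest other samp)^2)

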
def cal_vg_loss(sur_pts, sur_normals, curvature_ave, sample_pts, sample_grad, loss_w, nearest_clamp):
    """Total VG-stage loss: weighted Chamfer + repulsion + normal consistency."""
    sample_normal = F.normalize(sample_grad.detach(), dim=-1)
    chamfer_loss, sur_neigh_idx, _ = cal_chamfer_loss(
        sur_pts, sample_pts, curvature_ave, loss_w[:4], nearest_clamp)
    normal_loss = cal_nc_loss(sur_normals, sample_normal, sur_neigh_idx)
    loss = chamfer_loss + loss_w[4] * normal_loss
    return loss
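

if __name__ == "__main__":
    # Minimal smoke test on random data, a sketch only: shapes and loss_w
    # values are arbitrary placeholders, and real training would feed surface
    # scans plus network-predicted SDF gradients instead of random tensors.
    torch.manual_seed(0)
    n, k = 64, 8
    pts = torch.randn(n, k, 3)
    samples = torch.randn(n, 3)
    moved = samples + 0.01 * torch.randn(n, 3)
    print("pull:", pull_knn_loss(pts, moved, samples).item())
    print("eikonal:", eikonal_loss(torch.randn(n, 3)).item())
    print("sdf:", sdf_loss(torch.randn(n, 1)).item())

    sur = torch.randn(128, 3)
    sur_n = F.normalize(torch.randn(128, 3), dim=-1)
    curv = cal_curvature_with_normal(sur, sur_n, knn=8)  # (128, 1)
    samp, samp_grad = torch.randn(96, 3), torch.randn(96, 3)
    vg = cal_vg_loss(sur, sur_n, curv, samp, samp_grad,
                     loss_w=[1.0, 1.0, 1.0, 0.1, 0.1], nearest_clamp=2.0)
    print("vg:", vg.item())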