bdck commited on
Commit
8e3ab85
·
verified ·
1 Parent(s): 9ef936d

Upload losses.py

Browse files
Files changed (1) hide show
  1. losses.py +154 -0
losses.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss functions for SDF and VG stages.
3
+ All pure PyTorch, no compiled extensions needed.
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+
9
+
10
+ def gaussian_kernel(points, queries):
11
+ """
12
+ points: (n, k, 3) or (n, 3)
13
+ queries: (n, 3)
14
+ Returns: (n, k) or (n,) Gaussian weights.
15
+ """
16
+ if points.dim() == 2:
17
+ points = points.unsqueeze(1) # (n, 1, 3)
18
+ pts_dist = torch.linalg.norm((points - queries.unsqueeze(1)), ord=2, dim=-1) # (n, k)
19
+ h = pts_dist.mean(dim=-1, keepdim=True) + 1e-8
20
+ dist_exp = torch.exp(-pts_dist ** 2 / h ** 2)
21
+ gaussian_weight = dist_exp / (dist_exp.sum(dim=-1, keepdim=True) + 1e-8)
22
+ return gaussian_weight
23
+
24
+
25
+ def pull_knn_loss(points, samples_moved, samples):
26
+ """KNN-weighted pull loss."""
27
+ # points: (n, k, 3); samples_moved: (n, 3); samples: (n, 3)
28
+ if points.dim() == 2:
29
+ points = points.unsqueeze(1)
30
+ loss_pull = torch.linalg.norm((points - samples_moved.unsqueeze(1)), ord=2, dim=-1)
31
+ g_weight = gaussian_kernel(points, samples).detach()
32
+ loss_pull = (loss_pull * g_weight).sum(dim=-1).mean()
33
+ return loss_pull
34
+
35
+
36
+ def grad_consis_knn_loss(gradient_points_norm, gradients_samples_norm, points, samples):
37
+ if points.dim() == 2:
38
+ points = points.unsqueeze(1)
39
+ loss_consis = (1 - F.cosine_similarity(gradient_points_norm, gradients_samples_norm.unsqueeze(1), dim=-1))
40
+ g_weight = gaussian_kernel(points, samples).detach()
41
+ loss_consis = (loss_consis * g_weight).sum(dim=-1).mean()
42
+ return loss_consis
43
+
44
+
45
+ def eikonal_loss(gradients_samples):
46
+ loss_eikonal = ((gradients_samples.norm(2, dim=-1) - 1).square()).mean()
47
+ return loss_eikonal
48
+
49
+
50
+ def div_loss(samples, gradients_samples):
51
+ """Divergence of gradient field."""
52
+ def gradient(inputs, outputs):
53
+ inputs.requires_grad_(True)
54
+ d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
55
+ points_grad = torch.autograd.grad(
56
+ outputs=outputs,
57
+ inputs=inputs,
58
+ grad_outputs=d_points,
59
+ create_graph=True,
60
+ retain_graph=True,
61
+ only_inputs=True)[0]
62
+ return points_grad
63
+
64
+ div_dx = gradient(samples, gradients_samples[:, 0])
65
+ div_dy = gradient(samples, gradients_samples[:, 1])
66
+ div_dz = gradient(samples, gradients_samples[:, 2])
67
+ divergence = div_dx[:, 0] + div_dy[:, 1] + div_dz[:, 2]
68
+ loss_div = (torch.clamp(torch.square(divergence), 0.1, 50)).mean()
69
+ return loss_div
70
+
71
+
72
+ def sdf_loss(sample_sdf):
73
+ return torch.abs(sample_sdf).mean()
74
+
75
+
76
+ def cal_curvature_with_normal(pts, normals, knn=16):
77
+ """
78
+ Estimate curvature from normal variation over k-NN.
79
+ pts: (n, 3) normals: (n, 3)
80
+ Returns: (n, 1) curvature in [0, 1].
81
+ """
82
+ from scipy.spatial import KDTree
83
+ pts_np = pts.detach().cpu().numpy()
84
+ normals_np = normals.detach().cpu().numpy()
85
+ tree = KDTree(pts_np)
86
+ _, idx = tree.query(pts_np, k=min(knn + 1, len(pts_np)))
87
+ idx = torch.from_numpy(idx[:, 1:]).long().to(pts.device) # (n, k)
88
+ neigh_pts = pts[idx] # (n, k, 3)
89
+ neigh_normals = normals[idx] # (n, k, 3)
90
+ neigh_curvature = 1.0 - F.cosine_similarity(normals.unsqueeze(1), neigh_normals, dim=-1) # (n, k)
91
+ g_weight = gaussian_kernel(neigh_pts, pts).detach()
92
+ pts_curvature_ave = (neigh_curvature * g_weight).sum(dim=-1, keepdim=True) # (n, 1)
93
+
94
+ pts_curvature_ave = torch.sigmoid(pts_curvature_ave - torch.mean(pts_curvature_ave))
95
+ pts_curvature_ave = (pts_curvature_ave - torch.min(pts_curvature_ave) + 1e-6) / (
96
+ torch.max(pts_curvature_ave) - torch.min(pts_curvature_ave) + 1e-6)
97
+ return pts_curvature_ave
98
+
99
+
100
+ def cal_nc_loss(surface_normals, sample_normals, sur_neigh_idx):
101
+ """Normal consistency loss."""
102
+ neigh_normals = sample_normals[sur_neigh_idx]
103
+ nc_loss = 1.0 - F.cosine_similarity(surface_normals, neigh_normals, dim=-1).mean()
104
+ return nc_loss
105
+
106
+
107
+ def cal_chamfer_loss(sur_pts, sample_pts, curvature_surface, loss_w, nearest_clamp):
108
+ """
109
+ Bidirectional Chamfer with curvature weighting + repulsion.
110
+ loss_w: [w_curv, w_plain, w_reverse, w_repulsion]
111
+ """
112
+ from scipy.spatial import KDTree
113
+
114
+ # ---- nearest from surface to samples ----
115
+ sample_np = sample_pts.detach().cpu().numpy()
116
+ sur_np = sur_pts.detach().cpu().numpy()
117
+ tree = KDTree(sample_np)
118
+ _, sur_neigh_idx = tree.query(sur_np, k=1)
119
+ sur_neigh_idx = torch.from_numpy(sur_neigh_idx).long().to(sur_pts.device)
120
+ sur_neigh_pts = sample_pts[sur_neigh_idx]
121
+ dist_1 = torch.linalg.norm((sur_pts - sur_neigh_pts), ord=2, dim=-1).unsqueeze(-1) ** 2
122
+ weight_sur_curvature = curvature_surface
123
+ loss_part_1 = loss_w[0] * (dist_1 * weight_sur_curvature).mean() + loss_w[1] * dist_1.mean()
124
+
125
+ # ---- nearest from samples to surface ----
126
+ tree2 = KDTree(sur_np)
127
+ _, sample_neigh_idx = tree2.query(sample_np, k=1)
128
+ sample_neigh_idx = torch.from_numpy(sample_neigh_idx).long().to(sample_pts.device)
129
+ sample_neigh_pts = sur_pts[sample_neigh_idx]
130
+ dist_2 = torch.linalg.norm((sample_pts - sample_neigh_pts), ord=2, dim=-1) ** 2
131
+ loss_part_2 = loss_w[2] * dist_2.mean()
132
+
133
+ # ---- repulsion (self-nearest-neighbor distance) ----
134
+ tree3 = KDTree(sample_np)
135
+ _, sample_neigh_idx_self = tree3.query(sample_np, k=2)
136
+ sample_neigh_idx_self = torch.from_numpy(sample_neigh_idx_self[:, 1]).long().to(sample_pts.device)
137
+ sample_neigh_pts_self = sample_pts[sample_neigh_idx_self]
138
+ relative_dist = sample_neigh_pts_self - sample_pts
139
+ norm_dist = torch.linalg.norm(relative_dist, ord=2, dim=-1) ** 2
140
+ norm_dist_ = torch.clamp(norm_dist, max=nearest_clamp * norm_dist.mean())
141
+ loss_part_3 = -norm_dist_.mean() * loss_w[3]
142
+
143
+ chamfer_loss = loss_part_1 + loss_part_2 + loss_part_3
144
+ return chamfer_loss, sur_neigh_idx, sample_neigh_idx
145
+
146
+
147
+ def cal_vg_loss(sur_pts, sur_normals, curvature_ave, sample_pts, sample_grad, loss_w, nearest_clamp):
148
+ """Total VG stage loss."""
149
+ sample_normal = F.normalize(sample_grad.detach(), dim=-1)
150
+ chamfer_loss, sur_neigh_idx, _ = cal_chamfer_loss(
151
+ sur_pts, sample_pts, curvature_ave, loss_w[:4], nearest_clamp)
152
+ normal_loss = cal_nc_loss(sur_normals, sample_normal, sur_neigh_idx)
153
+ loss = chamfer_loss + loss_w[4] * normal_loss
154
+ return loss