bdck commited on
Commit
2b1915f
Β·
verified Β·
1 Parent(s): 68f2e33

add loss functions (chamfer, beam-gap, normal alignment)

Browse files
Files changed (1) hide show
  1. point2mesh/losses.py +188 -0
point2mesh/losses.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss functions for Point2Mesh optimisation.
3
+
4
+ * Bidirectional Chamfer distance
5
+ * Beam-gap loss (pulls the mesh into narrow cavities)
6
+ * Normal-alignment loss (cosine, handles unoriented normals)
7
+ * Differentiable surface sampling
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from typing import Optional, Tuple
15
+
16
+
17
+ # ──────────────────────────────────────────────────────────────────────
18
+ # Differentiable surface sampling
19
+ # ──────────────────────────────────────────────────────────────────────
20
+ def sample_surface(
21
+ verts: torch.Tensor, # (N_v, 3)
22
+ faces: torch.Tensor, # (N_f, 3) long
23
+ n_samples: int,
24
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
25
+ """
26
+ Uniformly sample points on the mesh surface (area-weighted).
27
+
28
+ Returns
29
+ -------
30
+ pts : (n_samples, 3) β€” differentiable w.r.t. verts
31
+ face_ids: (n_samples,) β€” which face each point was sampled from
32
+ """
33
+ v0 = verts[faces[:, 0]]
34
+ v1 = verts[faces[:, 1]]
35
+ v2 = verts[faces[:, 2]]
36
+
37
+ # Face areas (differentiable)
38
+ cross = torch.cross(v1 - v0, v2 - v0, dim=1)
39
+ areas = 0.5 * cross.norm(dim=1) # (N_f,)
40
+ probs = areas / areas.sum().clamp(min=1e-12)
41
+
42
+ # Sample faces proportional to area
43
+ face_ids = torch.multinomial(probs, n_samples, replacement=True)
44
+
45
+ # Barycentric sampling: r1, r2 ~ U(0,1) with r1+r2 < 1
46
+ r1 = torch.rand(n_samples, device=verts.device)
47
+ r2 = torch.rand(n_samples, device=verts.device)
48
+ mask = (r1 + r2) >= 1.0
49
+ r1[mask] = 1.0 - r1[mask]
50
+ r2[mask] = 1.0 - r2[mask]
51
+
52
+ fv0 = verts[faces[face_ids, 0]]
53
+ fv1 = verts[faces[face_ids, 1]]
54
+ fv2 = verts[faces[face_ids, 2]]
55
+
56
+ pts = (1 - r1 - r2).unsqueeze(1) * fv0 + r1.unsqueeze(1) * fv1 + r2.unsqueeze(1) * fv2
57
+
58
+ return pts, face_ids
59
+
60
+
61
+ # ──────────────────────────────────────────────────────────────────────
62
+ # Chamfer Distance (bidirectional, L2)
63
+ # ──────────────────────────────────────────────────────────────────────
64
+ def chamfer_loss(
65
+ X: torch.Tensor, # (N, 3) target point cloud
66
+ Y: torch.Tensor, # (M, 3) sampled mesh points
67
+ batch_size: int = 4096,
68
+ ) -> torch.Tensor:
69
+ """
70
+ Bidirectional Chamfer distance (mean of per-point min-dists).
71
+ Batched to avoid OOM on large point sets.
72
+ """
73
+ def _one_way(src, tgt):
74
+ """For each point in src, find min squared distance to tgt."""
75
+ total = torch.tensor(0.0, device=src.device)
76
+ n = src.shape[0]
77
+ for i in range(0, n, batch_size):
78
+ chunk = src[i : i + batch_size]
79
+ dists = torch.cdist(chunk, tgt) # (chunk, M)
80
+ total = total + dists.min(dim=1).values.sum()
81
+ return total / max(n, 1)
82
+
83
+ return _one_way(X, Y) + _one_way(Y, X)
84
+
85
+
86
+ # ──────────────────────────────────────────────────────────────────────
87
+ # Beam-Gap Loss
88
+ # ──────────────────────────────────────────────────────────────────────
89
+ @torch.no_grad()
90
+ def _mutual_knn_mask(
91
+ Y: torch.Tensor, X: torch.Tensor, k: int = 3
92
+ ) -> torch.Tensor:
93
+ """
94
+ Boolean mask (M,): True where point y already has a 'good fit'
95
+ (mutual k-NN with X), so beam-gap should skip it.
96
+ """
97
+ # Y→X nearest k
98
+ d_yx = torch.cdist(Y, X) # (M, N)
99
+ _, idx_yx = d_yx.topk(k, dim=1, largest=False) # (M, k) indices into X
100
+
101
+ # X→Y nearest k
102
+ d_xy = torch.cdist(X, Y) # (N, M)
103
+ _, idx_xy = d_xy.topk(k, dim=1, largest=False) # (N, k) indices into Y
104
+
105
+ # For each y_i check: is y_i in the k-NN of any of its own k-NN x-targets?
106
+ good = torch.zeros(Y.shape[0], dtype=torch.bool, device=Y.device)
107
+ for yi in range(Y.shape[0]):
108
+ for xi in idx_yx[yi]:
109
+ if yi in idx_xy[xi.item()]:
110
+ good[yi] = True
111
+ break
112
+ return good
113
+
114
+
115
+ def beam_gap_loss(
116
+ Y: torch.Tensor, # (M, 3) sampled mesh points
117
+ normals: torch.Tensor, # (M, 3) mesh face normals at each sample
118
+ X: torch.Tensor, # (N, 3) target point cloud
119
+ epsilon: float = 0.5,
120
+ knn_k: int = 3,
121
+ max_samples: int = 2000,
122
+ ) -> torch.Tensor:
123
+ """
124
+ Beam-gap loss (Point2Mesh Β§3.3).
125
+
126
+ For each mesh sample Ε·, cast a beam along its normal, find the closest
127
+ target point inside an Ξ΅-cylinder, and penalise the gap.
128
+ Skips points that already have a mutual k-NN match.
129
+ """
130
+ M = Y.shape[0]
131
+ if M > max_samples:
132
+ # Subsample for efficiency
133
+ idx = torch.randperm(M, device=Y.device)[:max_samples]
134
+ Y = Y[idx]
135
+ normals = normals[idx]
136
+ M = max_samples
137
+
138
+ # Identify good-fit points to skip
139
+ good_mask = _mutual_knn_mask(Y, X, k=knn_k)
140
+
141
+ loss = torch.tensor(0.0, device=Y.device)
142
+ count = 0
143
+
144
+ # Vectorised cylinder test
145
+ # For each y: project (X - y) onto normal n
146
+ # along = dot(X - y, n); perp = ||(X - y) - along * n||
147
+ diffs = X.unsqueeze(0) - Y.unsqueeze(1) # (M, N, 3)
148
+ along = (diffs * normals.unsqueeze(1)).sum(dim=2) # (M, N)
149
+ perp = (diffs - along.unsqueeze(2) * normals.unsqueeze(1)).norm(dim=2) # (M, N)
150
+
151
+ # Inside cylinder: perp < epsilon AND along > 0 (ahead of surface)
152
+ in_cyl = (perp < epsilon) & (along > 0)
153
+
154
+ for i in range(M):
155
+ if good_mask[i]:
156
+ continue
157
+ cand = in_cyl[i]
158
+ if not cand.any():
159
+ continue
160
+ # Closest along the beam direction
161
+ dists_along = along[i].clone()
162
+ dists_along[~cand] = float("inf")
163
+ best_j = dists_along.argmin()
164
+ target = X[best_j]
165
+ loss = loss + (Y[i] - target).pow(2).sum()
166
+ count += 1
167
+
168
+ return loss / max(count, 1)
169
+
170
+
171
+ # ──────────────────────────────────────────────────────────────────────
172
+ # Normal alignment loss
173
+ # ──────────────────────────────────────────────────────────────────────
174
+ def normal_loss(
175
+ Y: torch.Tensor, # (M, 3) sampled mesh points
176
+ normals_mesh: torch.Tensor, # (M, 3) mesh face normals at samples
177
+ X: torch.Tensor, # (N, 3) target point cloud
178
+ normals_pc: torch.Tensor, # (N, 3) target point-cloud normals
179
+ ) -> torch.Tensor:
180
+ """
181
+ Unoriented normal alignment: 1 βˆ’ |n_mesh Β· n_pc|.
182
+ Pairs each mesh sample with its nearest point-cloud point.
183
+ """
184
+ dists = torch.cdist(Y, X) # (M, N)
185
+ nn_idx = dists.argmin(dim=1) # (M,)
186
+ nn_normals = normals_pc[nn_idx]
187
+ dot = (normals_mesh * nn_normals).sum(dim=1)
188
+ return (1 - dot.abs()).mean()