bdck commited on
Commit
0946eea
·
verified ·
1 Parent(s): 74ead5e

Upload meshing.py

Browse files
Files changed (1) hide show
  1. meshing.py +231 -0
meshing.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Delaunay meshing + SDF labeling + surface extraction.
3
+ Pure Python using scipy.spatial.Delaunay instead of CGAL.
4
+ """
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
+ def compute_circumsphere_centers(tetrahedra):
10
+ """
11
+ tetrahedra: (n, 4, 3)
12
+ Returns: (n, 3) circumcenters.
13
+ """
14
+ A = tetrahedra[:, 0, :]
15
+ B = tetrahedra[:, 1, :]
16
+ C = tetrahedra[:, 2, :]
17
+ D = tetrahedra[:, 3, :]
18
+
19
+ x1, y1, z1 = A[:, 0:1], A[:, 1:2], A[:, 2:3]
20
+ x2, y2, z2 = B[:, 0:1], B[:, 1:2], B[:, 2:3]
21
+ x3, y3, z3 = C[:, 0:1], C[:, 1:2], C[:, 2:3]
22
+ x4, y4, z4 = D[:, 0:1], D[:, 1:2], D[:, 2:3]
23
+
24
+ A_matrix = np.stack([
25
+ np.concatenate([x2 - x1, y2 - y1, z2 - z1], axis=-1),
26
+ np.concatenate([x3 - x1, y3 - y1, z3 - z1], axis=-1),
27
+ np.concatenate([x4 - x1, y4 - y1, z4 - z1], axis=-1),
28
+ ], axis=1) # (n, 3, 3)
29
+
30
+ b_vector = 0.5 * np.concatenate([
31
+ x2 ** 2 - x1 ** 2 + y2 ** 2 - y1 ** 2 + z2 ** 2 - z1 ** 2,
32
+ x3 ** 2 - x1 ** 2 + y3 ** 2 - y1 ** 2 + z3 ** 2 - z1 ** 2,
33
+ x4 ** 2 - x1 ** 2 + y4 ** 2 - y1 ** 2 + z4 ** 2 - z1 ** 2,
34
+ ], axis=-1) # (n, 3)
35
+
36
+ center = np.linalg.solve(A_matrix, b_vector)
37
+ return torch.from_numpy(center).float()
38
+
39
+
40
+ def random_sampling_tetra(cell_vertex, k_samples):
41
+ """
42
+ Random barycentric samples inside each tetrahedron.
43
+ cell_vertex: (n, 4, 3)
44
+ Returns: (n, k, 3)
45
+ """
46
+ n = cell_vertex.shape[0]
47
+ random = np.random.rand(k_samples, 4).astype(np.float32)
48
+ random = random / (random.sum(axis=1, keepdims=True) + 1e-8)
49
+ random = torch.from_numpy(random).float()
50
+ random_samples = cell_vertex.unsqueeze(1) * random.view(1, k_samples, 4, 1)
51
+ random_samples = random_samples.sum(dim=2)
52
+ return random_samples
53
+
54
+
55
+ def labeling(sdf_network, queries, sdf_threshold=0.0, device='cpu', batch_size=10000):
56
+ """
57
+ Query SDF and label tetrahedra.
58
+ queries: (n, k, 3)
59
+ Returns: labels (n,) int, 0=outside, 1=inside.
60
+ """
61
+ n, k, _ = queries.shape
62
+ queries_flat = queries.view(-1, 3).to(device)
63
+ sdf_vals = []
64
+ with torch.no_grad():
65
+ for i in range(0, len(queries_flat), batch_size):
66
+ batch = queries_flat[i:i + batch_size]
67
+ s = sdf_network.sdf(batch).cpu()
68
+ sdf_vals.append(s)
69
+ sdf = torch.cat(sdf_vals, dim=0).view(n, k, 1)
70
+
71
+ ref = torch.where(sdf >= sdf_threshold, 1.0, 0.0)
72
+ ref_sum = ref.mean(dim=1) # (n, 1)
73
+ labels = torch.where(ref_sum >= 0.45, 1, 0).squeeze(-1)
74
+ return labels
75
+
76
+
77
+ def relabeling(labels, infinite_cell_id, cell_adj):
78
+ """
79
+ Relabel using adjacency consistency.
80
+ labels: (n,) int
81
+ infinite_cell_id: set of infinite cell indices
82
+ cell_adj: (n, 4) neighbor indices
83
+ """
84
+ labels = labels.clone()
85
+ adj_labels = labels[cell_adj] # (n, 4)
86
+ adj_labels_sum = adj_labels.sum(dim=-1, keepdim=True)
87
+ inside = torch.where((adj_labels_sum == 0) | (adj_labels_sum == 1))
88
+ outside = torch.where((adj_labels_sum == 3) | (adj_labels_sum == 4))
89
+ labels[inside] = 0
90
+ labels[outside] = 1
91
+ for idx in infinite_cell_id:
92
+ labels[idx] = 0
93
+ return labels
94
+
95
+
96
+ def create_mesh_from_delaunay(points, labels, delaunay):
97
+ """
98
+ Extract surface mesh from labeled Delaunay tetrahedra.
99
+ Returns: vertices (m, 3), faces (p, 3)
100
+ """
101
+ # facets between cells with different labels form the surface
102
+ # Build adjacency from Delaunay.neighbors
103
+ simplices = delaunay.simplices # (n, 4)
104
+ neighbors = delaunay.neighbors # (n, 4)
105
+
106
+ faces = []
107
+ for i in range(len(simplices)):
108
+ for j in range(4):
109
+ ni = neighbors[i, j]
110
+ if ni == -1:
111
+ # boundary facet
112
+ if labels[i] == 1:
113
+ f = np.delete(simplices[i], j)
114
+ faces.append(f)
115
+ elif labels[i] != labels[ni] and i < ni:
116
+ # shared facet between inside and outside
117
+ f = np.delete(simplices[i], j)
118
+ faces.append(f)
119
+
120
+ if len(faces) == 0:
121
+ return points, np.zeros((0, 3), dtype=np.int32)
122
+
123
+ faces = np.array(faces)
124
+ return points, faces
125
+
126
+
127
+ def delaunay_meshing(points, sdf_network, sdf_threshold=0.0, k_samples=21, device='cpu'):
128
+ """
129
+ Full pipeline: Delaunay -> sample -> label -> extract mesh.
130
+
131
+ Args:
132
+ points: (N, 3) numpy array or tensor of generated vertices
133
+ sdf_network: trained SDFNetwork
134
+ sdf_threshold: surface level
135
+ k_samples: random samples per tetrahedron
136
+ device: torch device
137
+ Returns:
138
+ vertices, faces as numpy arrays
139
+ """
140
+ from scipy.spatial import Delaunay
141
+
142
+ if torch.is_tensor(points):
143
+ points_np = points.detach().cpu().numpy()
144
+ else:
145
+ points_np = np.asarray(points, dtype=np.float32)
146
+
147
+ print("Building Delaunay triangulation...")
148
+ delaunay = Delaunay(points_np)
149
+ simplices = delaunay.simplices # (n_tets, 4)
150
+
151
+ # Identify infinite cells (any vertex == -1 in neighbor means boundary)
152
+ # scipy Delaunay marks boundary neighbors as -1
153
+ n_tets = len(simplices)
154
+ infinite_cell_id = set()
155
+ for i in range(n_tets):
156
+ if np.any(delaunay.neighbors[i] == -1):
157
+ infinite_cell_id.add(i)
158
+
159
+ # Compute circumcenters for constraint sampling
160
+ cell_vertex = torch.from_numpy(points_np[simplices]).float() # (n_tets, 4, 3)
161
+ ball_centers = cell_vertex.mean(dim=1, keepdim=True) # (n_tets, 1, 3)
162
+
163
+ try:
164
+ c_centers = compute_circumsphere_centers(cell_vertex.numpy()).unsqueeze(1) # (n_tets, 1, 3)
165
+ use_cc = True
166
+ except Exception:
167
+ c_centers = None
168
+ use_cc = False
169
+
170
+ samples = random_sampling_tetra(cell_vertex, k_samples)
171
+ samples = torch.cat([samples, ball_centers], dim=1) # (n_tets, k+1, 3)
172
+ if use_cc:
173
+ samples = torch.cat([samples, c_centers], dim=1)
174
+
175
+ print(f"Labeling {n_tets} cells with SDF queries...")
176
+ labels = labeling(sdf_network, samples, sdf_threshold=sdf_threshold, device=device)
177
+
178
+ # Build adjacency for relabeling
179
+ neighbors_t = torch.from_numpy(delaunay.neighbors).long()
180
+ neighbors_t = torch.where(neighbors_t < 0, torch.tensor(0), neighbors_t) # dummy for -1
181
+ labels = relabeling(labels, infinite_cell_id, neighbors_t)
182
+
183
+ vertices, faces = create_mesh_from_delaunay(points_np, labels, delaunay)
184
+ print(f"Mesh: {len(vertices)} vertices, {len(faces)} faces")
185
+ return vertices, faces
186
+
187
+
188
+ def add_mid_vertices(vertices, faces):
189
+ """
190
+ Fix non-manifold edges by adding midpoint vertices.
191
+ Simple pure-Python version.
192
+ """
193
+ import collections
194
+ edges = []
195
+ for f in faces:
196
+ edges.append(tuple(sorted([f[0], f[1]])))
197
+ edges.append(tuple(sorted([f[1], f[2]])))
198
+ edges.append(tuple(sorted([f[2], f[0]])))
199
+
200
+ edge_count = collections.Counter(edges)
201
+ bad_edges = [e for e, c in edge_count.items() if c != 2]
202
+
203
+ if len(bad_edges) == 0:
204
+ return vertices, faces
205
+
206
+ vert_list = list(vertices)
207
+ face_list = list(faces)
208
+
209
+ for e in bad_edges:
210
+ v0, v1 = e
211
+ mid = (vert_list[v0] + vert_list[v1]) * 0.5
212
+ mid_idx = len(vert_list)
213
+ vert_list.append(mid)
214
+
215
+ new_faces = []
216
+ for f in face_list:
217
+ ef = [tuple(sorted([f[0], f[1]])),
218
+ tuple(sorted([f[1], f[2]])),
219
+ tuple(sorted([f[2], f[0]]))]
220
+ if e in ef:
221
+ # split this face along the edge
222
+ other = [fv for fv in f if fv not in e][0]
223
+ f1 = np.array([v0, mid_idx, other])
224
+ f2 = np.array([mid_idx, v1, other])
225
+ new_faces.append(f1)
226
+ new_faces.append(f2)
227
+ else:
228
+ new_faces.append(f)
229
+ face_list = new_faces
230
+
231
+ return np.array(vert_list), np.array(face_list)