| """ |
| Delaunay meshing + SDF labeling + surface extraction. |
| Pure Python using scipy.spatial.Delaunay instead of CGAL. |
| """ |
| import numpy as np |
| import torch |
|
|
|
|
def compute_circumsphere_centers(tetrahedra):
    """
    Compute the circumcenter of every tetrahedron in a batch.

    The center X is equidistant from the four vertices P0..P3, which yields the
    3x3 linear system  (Pi - P0) . X = 0.5 * (|Pi|^2 - |P0|^2),  i = 1..3,
    solved independently for each tetrahedron.

    Args:
        tetrahedra: (n, 4, 3) array-like of vertex coordinates.
    Returns:
        (n, 3) float32 torch tensor of circumcenters.
    Raises:
        numpy.linalg.LinAlgError: if some tetrahedron's system is exactly
            singular (e.g. a flat / degenerate cell).
    """
    # Solve in float64: sliver cells make this system ill-conditioned.
    tetra = np.asarray(tetrahedra, dtype=np.float64)
    rows = tetra[:, 1:, :] - tetra[:, :1, :]            # (n, 3, 3): P1-P0, P2-P0, P3-P0
    sq_norms = np.sum(tetra * tetra, axis=-1)           # (n, 4): |Pi|^2
    rhs = 0.5 * (sq_norms[:, 1:] - sq_norms[:, :1])     # (n, 3)
    # Give b an explicit trailing axis: since NumPy 2.0 a 2-D b is treated as a
    # stack of (M, K) matrices, not as one vector per system, so (n, 3) would
    # not broadcast per tetrahedron.
    center = np.linalg.solve(rows, rhs[..., np.newaxis])[..., 0]
    return torch.from_numpy(center).float()
|
|
|
|
def random_sampling_tetra(cell_vertex, k_samples):
    """
    Draw k_samples points inside every tetrahedron from barycentric weights.

    A single (k_samples, 4) weight table is drawn from NumPy's global RNG,
    each row normalised to sum to ~1, and shared by all tetrahedra.

    Args:
        cell_vertex: (n, 4, 3) tensor of tetrahedron vertices.
        k_samples: number of points per tetrahedron.
    Returns:
        (n, k_samples, 3) tensor of sample points.
    """
    weights = np.random.rand(k_samples, 4).astype(np.float32)
    row_totals = weights.sum(axis=1, keepdims=True) + 1e-8
    weights = torch.from_numpy(weights / row_totals).float()
    # (1, k, 4, 1) * (n, 1, 4, 3) -> (n, k, 4, 3); then collapse the vertex axis.
    weighted = weights[None, :, :, None] * cell_vertex[:, None, :, :]
    return weighted.sum(dim=2)
|
|
|
|
def labeling(sdf_network, queries, sdf_threshold=0.0, device='cpu', batch_size=10000,
             vote_ratio=0.45):
    """
    Label each cell by a vote over SDF values queried at its sample points.

    Args:
        sdf_network: object exposing ``sdf(points)`` that returns one value per
            point (only ``.sdf`` is called here).
        queries: (n, k, 3) tensor of query points, k per cell.
        sdf_threshold: a sample votes 1 when its SDF value is >= this threshold.
        device: device the query points are moved to before evaluation.
        batch_size: number of points passed to each ``sdf`` call.
        vote_ratio: minimum fraction of 1-votes for a cell to get label 1.
    Returns:
        labels (n,) int64 tensor, 0=outside, 1=inside.
    """
    n, k, _ = queries.shape
    if n == 0 or k == 0:
        # Nothing to evaluate; torch.cat would fail on an empty list.
        return torch.zeros(n, dtype=torch.long)

    # reshape (not view) so non-contiguous query tensors are accepted.
    queries_flat = queries.reshape(-1, 3).to(device)
    chunks = []
    with torch.no_grad():
        for start in range(0, queries_flat.shape[0], batch_size):
            chunks.append(sdf_network.sdf(queries_flat[start:start + batch_size]).cpu())
    # Works whether sdf returns (b,) or (b, 1).
    sdf = torch.cat(chunks, dim=0).reshape(n, k)

    vote_fraction = (sdf >= sdf_threshold).float().mean(dim=1)
    return (vote_fraction >= vote_ratio).long()
|
|
|
|
def relabeling(labels, infinite_cell_id, cell_adj):
    """
    Smooth cell labels with a vote over each cell's four neighbors.

    All votes read the input labels, so this is one simultaneous pass:
    - a cell with 0 or 1 label-1 neighbors becomes 0,
    - a cell with 3 or 4 label-1 neighbors becomes 1,
    - a 2-2 tie keeps its label.
    The cells in ``infinite_cell_id`` are then forced to 0.

    Args:
        labels: (n,) integer tensor of 0/1 labels; not modified.
        infinite_cell_id: iterable (e.g. set) of cell indices to force to 0.
        cell_adj: (n, 4) long tensor of neighbor indices; a negative entry
            (no neighbor) votes as label 0.
    Returns:
        (n,) tensor of relabeled cells.
    """
    new_labels = labels.clone()

    has_neighbor = cell_adj >= 0
    # Gather neighbor labels; missing neighbors are masked to 0 instead of
    # wrapping around to the last cell through negative indexing.
    neighbor_labels = labels[cell_adj.clamp(min=0)].masked_fill(~has_neighbor, 0)
    votes = neighbor_labels.sum(dim=-1)  # (n,) -- 1-D so it masks `labels` directly

    new_labels[(votes == 0) | (votes == 1)] = 0
    new_labels[(votes == 3) | (votes == 4)] = 1

    forced_outside = torch.as_tensor(list(infinite_cell_id), dtype=torch.long,
                                     device=new_labels.device)
    new_labels[forced_outside] = 0
    return new_labels
|
|
|
|
def create_mesh_from_delaunay(points, labels, delaunay):
    """
    Extract the surface separating label-1 cells from label-0 cells.

    One triangle is emitted for:
    - every facet shared by two cells with different labels, taken once,
      from the lower-indexed cell;
    - every convex-hull facet of a label-1 cell.
    Triangles appear in (cell, facet) order. Each is wound so that its
    right-hand normal points away from the label-1 cell.

    Args:
        points: (m, 3) array of triangulated points; returned as the vertices.
        labels: (n,) per-simplex labels (numpy array, sequence or torch tensor).
        delaunay: scipy.spatial.Delaunay built on ``points`` (uses
            ``simplices`` and ``neighbors``; neighbor k is opposite vertex k,
            -1 on the hull).
    Returns:
        vertices (m, 3) -- ``points`` itself -- and faces (p, 3) point indices.
        Faces is an int32 array of shape (0, 3) when nothing is extracted.
    """
    # Columns that remain when vertex j is dropped, i.e. the facet opposite j.
    facet_columns = np.array([[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]])

    simplices = np.asarray(delaunay.simplices)
    neighbors = np.asarray(delaunay.neighbors)
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    lab = np.asarray(labels).reshape(-1)

    on_hull = neighbors == -1
    own = lab[:, None]                                  # (n, 1)
    other = lab[np.where(on_hull, 0, neighbors)]        # (n, 4); hull entries unused
    cell_ids = np.arange(len(simplices))[:, None]
    keep = (on_hull & (own == 1)) | (
        ~on_hull & (own != other) & (cell_ids < neighbors))

    # np.nonzero is row-major, i.e. the same (cell, facet) order as a nested loop.
    cell_idx, opp_idx = np.nonzero(keep)
    if cell_idx.size == 0:
        return points, np.zeros((0, 3), dtype=np.int32)

    faces = simplices[cell_idx[:, None], facet_columns[opp_idx]]

    # Orient: the facet normal points into the source cell when it has a positive
    # component toward the opposite (apex) vertex.
    pts = np.asarray(points, dtype=np.float64)
    a, b, c = pts[faces[:, 0]], pts[faces[:, 1]], pts[faces[:, 2]]
    apex = pts[simplices[cell_idx, opp_idx]]
    toward_apex = np.sum(np.cross(b - a, c - a) * (apex - a), axis=1)
    source_is_inside = lab[cell_idx] == 1
    # Flip when the normal points into a label-1 source cell, or out of a
    # label-0 source cell (i.e. into its label-1 neighbor).
    flip = np.where(source_is_inside, toward_apex > 0, toward_apex < 0)
    faces[flip] = faces[flip][:, [0, 2, 1]]
    return points, faces
|
|
|
|
def delaunay_meshing(points, sdf_network, sdf_threshold=0.0, k_samples=21, device='cpu'):
    """
    Full pipeline: Delaunay -> sample -> label -> relabel -> extract mesh.

    Args:
        points: (N, 3) numpy array or tensor of generated vertices
        sdf_network: trained SDFNetwork (only its ``sdf`` method is queried,
            via ``labeling``)
        sdf_threshold: surface level
        k_samples: random samples per tetrahedron; each cell additionally gets
            its centroid and, when computable, its circumcenter as samples
        device: torch device used for the SDF queries
    Returns:
        vertices, faces as numpy arrays
    """
    from scipy.spatial import Delaunay

    if torch.is_tensor(points):
        points_np = points.detach().cpu().numpy()
    else:
        points_np = np.asarray(points, dtype=np.float32)

    print("Building Delaunay triangulation...")
    delaunay = Delaunay(points_np)
    simplices = delaunay.simplices
    n_tets = len(simplices)

    # Cells on the convex hull have at least one missing (-1) neighbor.
    on_hull = np.any(delaunay.neighbors == -1, axis=1)
    infinite_cell_id = set(np.flatnonzero(on_hull).tolist())

    cell_vertex = torch.from_numpy(points_np[simplices]).float()   # (n, 4, 3)
    ball_centers = cell_vertex.mean(dim=1, keepdim=True)          # (n, 1, 3)

    # Circumcenters are an optional extra sample per cell; a degenerate cell
    # makes the linear solve fail, in which case they are skipped (reported,
    # not silently).
    try:
        c_centers = compute_circumsphere_centers(cell_vertex.numpy()).unsqueeze(1)
    except (np.linalg.LinAlgError, ValueError) as err:
        print(f"Skipping circumcenter samples: {err}")
        c_centers = None

    samples = random_sampling_tetra(cell_vertex, k_samples)
    extra_samples = [ball_centers] if c_centers is None else [ball_centers, c_centers]
    samples = torch.cat([samples, *extra_samples], dim=1)

    print(f"Labeling {n_tets} cells with SDF queries...")
    labels = labeling(sdf_network, samples, sdf_threshold=sdf_threshold, device=device)

    # The adjacency table is used as an index into labels, so -1 (no neighbor)
    # is mapped to cell 0. Only hull cells have -1 entries, and relabeling
    # forces exactly those cells (infinite_cell_id) to 0 afterwards.
    neighbors_t = torch.from_numpy(delaunay.neighbors).long().clamp(min=0)
    labels = relabeling(labels, infinite_cell_id, neighbors_t)

    vertices, faces = create_mesh_from_delaunay(points_np, labels, delaunay)
    print(f"Mesh: {len(vertices)} vertices, {len(faces)} faces")
    return vertices, faces
|
|
|
|
def add_mid_vertices(vertices, faces):
    """
    Split every edge not shared by exactly two faces at its midpoint.

    Edges used by one face (open boundary) or by three or more faces are
    processed in first-seen order. For each such edge, a midpoint vertex is
    appended and every face currently containing the edge is replaced, in
    place, by two faces. The two new faces keep the winding of the face they
    replace.

    NOTE(review): splitting does not change how many faces meet along the two
    half-edges, so non-manifold configurations persist after this pass; it
    only subdivides them.

    Args:
        vertices: (m, 3) array of vertex positions.
        faces: (p, 3) array of vertex indices.
    Returns:
        (vertices, faces). If no edge needs splitting, the inputs are returned
        unchanged (same objects). Otherwise new numpy arrays are returned, with
        the midpoints appended after the original vertices.
    """
    import collections

    def edge_key(a, b):
        return (a, b) if a <= b else (b, a)

    def face_edges(tri):
        # Edges in winding order: (v0, v1), (v1, v2), (v2, v0).
        return [edge_key(tri[k], tri[(k + 1) % 3]) for k in range(3)]

    edge_count = collections.Counter(e for f in faces for e in face_edges(f))
    bad_edges = [e for e, c in edge_count.items() if c != 2]
    if not bad_edges:
        return vertices, faces

    vert_list = list(vertices)
    live = {}                                   # face id -> (order key, triangle)
    edge_faces = collections.defaultdict(set)   # edge -> ids of live faces using it

    def add_face(fid, order, tri):
        live[fid] = (order, tri)
        for e in face_edges(tri):
            edge_faces[e].add(fid)

    def drop_face(fid):
        order, tri = live.pop(fid)
        for e in face_edges(tri):
            edge_faces[e].discard(fid)
        return order, tri

    # Order keys are tuples: a split face's children extend its key with 0/1,
    # so sorting keys reproduces "replace the face in place" ordering.
    for i, f in enumerate(faces):
        add_face(i, (i,), tuple(f))
    next_id = len(live)

    for e in bad_edges:
        v0, v1 = e
        mid_idx = len(vert_list)
        vert_list.append((vert_list[v0] + vert_list[v1]) * 0.5)
        for fid in sorted(edge_faces.get(e, ())):
            order, tri = drop_face(fid)
            k = face_edges(tri).index(e)
            a, b, c = tri[k], tri[(k + 1) % 3], tri[(k + 2) % 3]
            # The midpoint lies on segment a-b, so both halves keep a->b->c winding.
            add_face(next_id, order + (0,), (a, mid_idx, c))
            add_face(next_id + 1, order + (1,), (mid_idx, b, c))
            next_id += 2

    ordered = sorted(live.values(), key=lambda item: item[0])
    return np.array(vert_list), np.array([list(tri) for _, tri in ordered])
|
|