# lightweightmr / meshing.py
# Uploaded by bdck ("Upload meshing.py", commit 0946eea, verified)
"""
Delaunay meshing + SDF labeling + surface extraction.
Pure Python using scipy.spatial.Delaunay instead of CGAL.
"""
import numpy as np
import torch
def compute_circumsphere_centers(tetrahedra):
    """
    Compute the circumcenter of every tetrahedron in a batch.

    The circumcenter X is equidistant from the four vertices A, B, C, D.
    Subtracting the equation for A from the ones for B, C and D gives one
    3x3 linear system per tetrahedron:

        (P - A) . X = 0.5 * (|P|^2 - |A|^2)    for P in (B, C, D)

    Args:
        tetrahedra: (n, 4, 3) array-like of vertex coordinates.

    Returns:
        (n, 3) float32 torch tensor of circumcenters.

    Raises:
        ValueError: if the input is not shaped (n, 4, 3).
        numpy.linalg.LinAlgError: if any system is exactly singular
            (degenerate tetrahedron with coplanar vertices).
    """
    # Solve in float64: the right-hand side is a difference of squared norms,
    # which loses precision quickly in float32.
    tets = np.asarray(tetrahedra, dtype=np.float64)
    if tets.ndim != 3 or tets.shape[1:] != (4, 3):
        raise ValueError(
            f"tetrahedra must have shape (n, 4, 3), got {tets.shape}"
        )
    if tets.shape[0] == 0:
        return torch.zeros((0, 3), dtype=torch.float32)

    first = tets[:, :1, :]                               # (n, 1, 3) vertex A
    lhs = tets[:, 1:, :] - first                         # (n, 3, 3) rows B-A, C-A, D-A
    sq_norm = np.sum(tets ** 2, axis=-1)                 # (n, 4)
    rhs = 0.5 * (sq_norm[:, 1:] - sq_norm[:, :1])        # (n, 3)

    # Pass the right-hand side as a stack of (3, 1) column matrices. A bare
    # (n, 3) array is only interpreted as a stack of vectors by NumPy < 2.0;
    # NumPy >= 2.0 treats it as a single (M, K) matrix and raises.
    center = np.linalg.solve(lhs, rhs[..., np.newaxis])[..., 0]
    return torch.from_numpy(np.ascontiguousarray(center)).float()
def random_sampling_tetra(cell_vertex, k_samples):
    """
    Draw random points inside each tetrahedron via barycentric weights.

    One set of ``k_samples`` weight vectors is drawn from NumPy's global RNG
    and shared by every tetrahedron. Each weight vector is four uniform
    draws normalised to sum to (almost exactly) one, so every sample is a
    convex combination of the four vertices. Note that normalised uniform
    draws are not uniformly distributed over the tetrahedron's volume.

    Args:
        cell_vertex: (n, 4, 3) torch tensor of tetrahedron vertices.
        k_samples: number of samples per tetrahedron.

    Returns:
        (n, k_samples, 3) torch tensor of sample positions, on the same
        device as ``cell_vertex``.
    """
    weights = np.random.rand(k_samples, 4).astype(np.float32)
    # The epsilon guards against a zero row sum.
    weights = weights / (weights.sum(axis=1, keepdims=True) + 1e-8)
    # Only the device is aligned (not the dtype) so that dtype promotion
    # behaves exactly as a plain float32 weight tensor would.
    weights = torch.from_numpy(weights).float().to(cell_vertex.device)

    # (n, 1, 4, 3) * (1, k, 4, 1) -> (n, k, 4, 3); sum over the vertex axis.
    weighted = cell_vertex.unsqueeze(1) * weights.view(1, k_samples, 4, 1)
    return weighted.sum(dim=2)
def labeling(sdf_network, queries, sdf_threshold=0.0, device='cpu', batch_size=10000,
             inside_ratio=0.45):
    """
    Query the SDF at the sample points and label each tetrahedron by vote.

    A sample votes for label 1 when its SDF value is >= ``sdf_threshold``.
    A tetrahedron gets label 1 when the fraction of its samples voting 1 is
    at least ``inside_ratio``; otherwise it gets label 0.

    NOTE(review): the labels are documented as 0=outside, 1=inside, which
    means values >= threshold count as "inside" here. Confirm this matches
    the sign convention of the SDF network in use.

    Args:
        sdf_network: object exposing ``sdf(points)``, called with a (b, 3)
            tensor and returning b SDF values (shape (b,) or (b, 1)).
        queries: (n, k, 3) tensor of k sample points per tetrahedron.
        sdf_threshold: level at which a sample switches its vote.
        device: device the query batches are moved to before evaluation.
        batch_size: maximum number of points per network call.
        inside_ratio: minimum fraction of label-1 votes for a tetrahedron
            to receive label 1.

    Returns:
        (n,) int64 CPU tensor of labels, 0=outside, 1=inside.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n, k, _ = queries.shape
    # reshape (not view) so non-contiguous query tensors are accepted too.
    flat_queries = queries.reshape(-1, 3).to(device)
    if flat_queries.shape[0] == 0:
        # Nothing to evaluate; torch.cat on an empty list would raise.
        return torch.zeros(n, dtype=torch.long)

    chunks = []
    with torch.no_grad():
        for start in range(0, flat_queries.shape[0], batch_size):
            batch = flat_queries[start:start + batch_size]
            chunks.append(sdf_network.sdf(batch).cpu())

    sdf = torch.cat(chunks, dim=0).reshape(n, k, 1)
    votes = (sdf >= sdf_threshold).float()
    vote_ratio = votes.mean(dim=1)  # (n, 1)
    return (vote_ratio >= inside_ratio).long().squeeze(-1)
def relabeling(labels, infinite_cell_id, cell_adj):
    """
    Smooth cell labels by a majority vote over the four face neighbours.

    Using the labels as they were on entry, a cell becomes 0 when the labels
    of its neighbours sum to 0 or 1, becomes 1 when they sum to 3 or 4, and
    is left unchanged on a 2/2 tie. Afterwards every cell listed in
    ``infinite_cell_id`` is forced to 0.

    Args:
        labels: (n,) integer tensor of 0/1 cell labels. Not modified.
        infinite_cell_id: iterable of cell indices to force to label 0.
        cell_adj: (n, 4) neighbour cell indices. Negative entries mean
            "no neighbour" and count as label 0 in the vote.

    Returns:
        (n,) tensor of relabelled cells (a modified copy of ``labels``).
    """
    labels = labels.clone()
    cell_adj = torch.as_tensor(cell_adj, dtype=torch.long, device=labels.device)

    # Missing neighbours must not wrap around to the last cell through
    # negative indexing: look up a dummy cell and zero its vote instead.
    has_neighbor = cell_adj >= 0
    neighbor_labels = labels[cell_adj.clamp(min=0)]                    # (n, 4)
    neighbor_labels = torch.where(
        has_neighbor, neighbor_labels, torch.zeros_like(neighbor_labels)
    )

    # Keep the vote 1-D so the masks below index the 1-D label vector
    # directly. A (n, 1) keepdim sum would yield 2-D index tuples, which
    # cannot index a 1-D tensor.
    votes = neighbor_labels.sum(dim=-1)                                # (n,)
    mostly_zero = (votes == 0) | (votes == 1)
    mostly_one = (votes == 3) | (votes == 4)
    labels[mostly_zero] = 0
    labels[mostly_one] = 1

    infinite = [int(idx) for idx in infinite_cell_id]
    if infinite:
        forced = torch.tensor(infinite, dtype=torch.long, device=labels.device)
        labels[forced] = 0
    return labels
def create_mesh_from_delaunay(points, labels, delaunay):
    """
    Extract the surface mesh separating label-1 cells from everything else.

    A facet is emitted when it lies between two cells with different labels
    (once per pair, from the lower-indexed cell), or when it is a hull facet
    (no neighbour, marked -1) of a label-1 cell. Every face is wound so that
    its normal ``(v1 - v0) x (v2 - v0)`` points away from the label-1 cell.

    Args:
        points: (m, 3) vertex coordinates used by the triangulation.
        labels: (n,) 0/1 cell labels, torch tensor or array-like.
        delaunay: object exposing ``simplices`` (n, 4) vertex indices and
            ``neighbors`` (n, 4) cell indices, where ``neighbors[i, j]`` is
            the cell opposite vertex ``j`` of cell ``i`` (-1 if none), as in
            ``scipy.spatial.Delaunay``.

    Returns:
        ``(points, faces)``: ``points`` is returned unchanged; ``faces`` is a
        (p, 3) integer array of vertex indices ((0, 3) int32 when there is
        no surface).
    """
    simplices = np.asarray(delaunay.simplices)
    neighbors = np.asarray(delaunay.neighbors)
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    cell_labels = np.asarray(labels).reshape(-1)

    no_faces = np.zeros((0, 3), dtype=np.int32)
    n_cells = simplices.shape[0]
    if n_cells == 0:
        return points, no_faces

    has_neighbor = neighbors != -1
    # Index 0 is only a placeholder for missing neighbours; those entries
    # are masked out by has_neighbor below.
    neighbor_labels = cell_labels[np.where(has_neighbor, neighbors, 0)]   # (n, 4)
    own_labels = cell_labels[:, None]                                     # (n, 1)
    cell_ids = np.arange(n_cells)[:, None]

    hull_facets = ~has_neighbor & (own_labels == 1)
    # cell_ids < neighbors keeps each shared facet exactly once.
    shared_facets = has_neighbor & (own_labels != neighbor_labels) & (cell_ids < neighbors)
    cell_idx, opp_col = np.nonzero(hull_facets | shared_facets)
    if cell_idx.size == 0:
        return points, no_faces

    # Row j lists the simplex columns left after dropping vertex j.
    facet_columns = np.array([[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]])
    faces = simplices[cell_idx[:, None], facet_columns[opp_col]]          # (p, 3)

    # Orient each face using the vertex of the emitting cell that is opposite
    # the facet: the normal must point away from it when that cell is the
    # label-1 cell, and towards it when the label-1 cell is the neighbour.
    coords = np.asarray(points, dtype=np.float64)
    corners = coords[faces]                                               # (p, 3, 3)
    normals = np.cross(corners[:, 1] - corners[:, 0], corners[:, 2] - corners[:, 0])
    opposite = coords[simplices[cell_idx, opp_col]]                       # (p, 3)
    side = np.einsum('ij,ij->i', normals, opposite - corners[:, 0])
    emitting_cell_inside = cell_labels[cell_idx] == 1
    flip = np.where(emitting_cell_inside, side > 0, side < 0)
    faces[flip] = faces[flip][:, [0, 2, 1]]
    return points, faces
def delaunay_meshing(points, sdf_network, sdf_threshold=0.0, k_samples=21, device='cpu'):
    """
    Full pipeline: Delaunay -> sample -> label -> relabel -> extract mesh.

    Steps:
      1. Tetrahedralise ``points`` with ``scipy.spatial.Delaunay``.
      2. For each tetrahedron gather ``k_samples`` random interior points,
         its centroid and (when computable) its circumcenter.
      3. Label each tetrahedron from the SDF values at those points.
      4. Smooth the labels over face neighbours and force every cell that
         touches the convex hull to label 0.
      5. Emit the facets separating differently labelled cells.

    Args:
        points: (N, 3) numpy array or tensor of generated vertices.
        sdf_network: trained SDF network exposing ``sdf(points)``.
        sdf_threshold: surface level passed to the labelling step.
        k_samples: random samples per tetrahedron.
        device: torch device used for the SDF queries.

    Returns:
        vertices, faces as numpy arrays.
    """
    import warnings

    from scipy.spatial import Delaunay

    if torch.is_tensor(points):
        points_np = points.detach().cpu().numpy()
    else:
        points_np = np.asarray(points, dtype=np.float32)

    print("Building Delaunay triangulation...")
    delaunay = Delaunay(points_np)
    simplices = delaunay.simplices  # (n_tets, 4)
    n_tets = len(simplices)

    # scipy has no explicit infinite cells; a missing neighbour is marked -1.
    # Cells with at least one missing neighbour (i.e. touching the convex
    # hull) play the role of the infinite cells.
    touches_hull = (delaunay.neighbors == -1).any(axis=1)
    infinite_cell_id = set(np.flatnonzero(touches_hull).tolist())

    cell_vertex = torch.from_numpy(points_np[simplices]).float()  # (n_tets, 4, 3)
    centroids = cell_vertex.mean(dim=1, keepdim=True)             # (n_tets, 1, 3)

    # Circumcenters are an optional extra sample. Degenerate tetrahedra make
    # the solve fail, in which case the pipeline continues without them,
    # but visibly rather than silently.
    try:
        c_centers = compute_circumsphere_centers(cell_vertex.numpy()).unsqueeze(1)
    except (np.linalg.LinAlgError, ValueError) as err:
        warnings.warn(f"Circumcenter samples skipped: {err}")
        c_centers = None

    samples = random_sampling_tetra(cell_vertex, k_samples)
    samples = torch.cat([samples, centroids], dim=1)  # (n_tets, k+1, 3)
    if c_centers is not None:
        # Near-degenerate cells can produce non-finite circumcenters; fall
        # back to the centroid so no NaN/inf is fed to the SDF network.
        finite = torch.isfinite(c_centers).all(dim=-1, keepdim=True)
        c_centers = torch.where(finite, c_centers, centroids)
        samples = torch.cat([samples, c_centers], dim=1)  # (n_tets, k+2, 3)

    print(f"Labeling {n_tets} cells with SDF queries...")
    labels = labeling(sdf_network, samples, sdf_threshold=sdf_threshold, device=device)

    # Missing neighbours (-1) are replaced by the dummy index 0. This only
    # affects cells touching the hull, and those are forced to label 0 by
    # relabeling() anyway because they are all in infinite_cell_id.
    neighbors_t = torch.from_numpy(delaunay.neighbors).long().clamp(min=0)
    labels = relabeling(labels, infinite_cell_id, neighbors_t)

    vertices, faces = create_mesh_from_delaunay(points_np, labels, delaunay)
    print(f"Mesh: {len(vertices)} vertices, {len(faces)} faces")
    return vertices, faces
def add_mid_vertices(vertices, faces):
    """
    Split every edge that is not shared by exactly two faces at its midpoint.

    Edges are counted over all faces; an edge used by one face (boundary) or
    by more than two faces (non-manifold) is "bad". For each bad edge a
    midpoint vertex is appended, and the first occurrence of that edge in
    every face containing it is replaced by two sub-triangles that keep the
    winding of the face they came from.

    Simple pure-Python version: the cost is O(bad_edges * faces).

    Args:
        vertices: (m, 3) array-like of vertex positions.
        faces: (p, 3) array-like of vertex indices.

    Returns:
        ``(vertices, faces)``. The inputs are returned unchanged when no bad
        edge exists; otherwise new arrays with the midpoints appended and
        the affected faces split.
    """
    import collections

    def edge_key(a, b):
        """Orientation-independent key for the edge between a and b."""
        return (a, b) if a <= b else (b, a)

    face_list = [tuple(int(v) for v in face) for face in faces]
    edge_count = collections.Counter(
        edge_key(face[i], face[(i + 1) % 3])
        for face in face_list
        for i in range(3)
    )
    bad_edges = [edge for edge, count in edge_count.items() if count != 2]
    if not bad_edges:
        return vertices, faces

    # Array form so midpoints are averaged even for list-of-list input.
    verts = np.asarray(vertices)
    n_original = len(verts)
    midpoints = []

    for edge in bad_edges:
        v0, v1 = edge
        mid_idx = n_original + len(midpoints)
        midpoints.append((verts[v0] + verts[v1]) * 0.5)

        split_faces = []
        for face in face_list:
            for i in range(3):
                a, b, c = face[i], face[(i + 1) % 3], face[(i + 2) % 3]
                if edge_key(a, b) == edge:
                    # Insert the midpoint between a and b in the face's own
                    # traversal order so both halves keep its winding.
                    split_faces.append((a, mid_idx, c))
                    split_faces.append((mid_idx, b, c))
                    break
            else:
                split_faces.append(face)
        face_list = split_faces

    new_vertices = np.concatenate([verts, np.stack(midpoints)], axis=0)
    face_dtype = faces.dtype if isinstance(faces, np.ndarray) else None
    return new_vertices, np.asarray(face_list, dtype=face_dtype)