| """ |
| Mesh data structure with edge-based topology for MeshCNN operations. |
| |
| Stores vertices, faces, edges, and the GEMM (edge-neighbor) adjacency |
| required by MeshCNN convolutions. Also handles PartMesh spatial splitting. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| from collections import defaultdict |
| from typing import List, Dict, Tuple, Optional |
|
|
|
|
| |
| |
| |
class Mesh:
    """Half-edge-inspired mesh representation for MeshCNN / Point2Mesh.

    Attributes populated by the constructor
    ---------------------------------------
    device       : ``torch.device`` all tensors live on
    vs           : (N_v, 3) float32 tensor of vertex positions
    faces        : (N_f, 3) long tensor of vertex indices per triangle
    edges        : (N_e, 2) long tensor; each row is ``(min_vertex, max_vertex)``,
                   rows ordered by first appearance while walking the faces
    n_edges      : number of unique undirected edges
    n_faces      : number of faces
    gemm_edges   : (N_e, 4) long tensor with the four neighbouring edge
                   indices of every edge (see ``_build_topology``)
    vertex_edges : dict mapping a vertex index to the list of incident edges
    edge_v0/v1   : (N_e,) long tensors holding the two endpoints of each edge
    face_areas   : (N_f,) tensor of triangle areas
    """

    def __init__(
        self,
        vertices: np.ndarray,
        faces: np.ndarray,
        device: str = "cpu",
    ):
        """
        Parameters
        ----------
        vertices : (N_v, 3) float array
        faces : (N_f, 3) int array - vertex indices per triangle
        device : torch device string
        """
        self.device = torch.device(device)
        self.vs = torch.tensor(vertices, dtype=torch.float32, device=self.device)
        self.faces = torch.tensor(faces, dtype=torch.long, device=self.device)
        self._build_topology()

    def _build_topology(self):
        """Derive edges, GEMM neighbours, vertex->edge lists and face areas.

        Edges are undirected and keyed by ``(min_vertex, max_vertex)``.
        For every edge the GEMM row holds the two other edges of the first
        incident face followed by the two other edges of the second one.
        Boundary edges (one incident face) repeat that face's pair, and for
        non-manifold edges only the first two incident faces are considered.
        """
        F_np = self.faces.cpu().numpy()
        n_faces = len(F_np)

        edge_to_idx: Dict[Tuple[int, int], int] = {}
        # edge index -> [(face index, slot of the edge inside that face), ...]
        edge_faces: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
        # face_edges[f, k] is the edge joining face vertices k and (k + 1) % 3
        face_edges = np.zeros((n_faces, 3), dtype=np.int64)

        for fi, face in enumerate(F_np):
            for k in range(3):
                v0, v1 = int(face[k]), int(face[(k + 1) % 3])
                key = (min(v0, v1), max(v0, v1))
                eidx = edge_to_idx.setdefault(key, len(edge_to_idx))
                face_edges[fi, k] = eidx
                edge_faces[eidx].append((fi, k))

        # Dicts preserve insertion order, so this is the first-seen edge order.
        edges_list: List[Tuple[int, int]] = list(edge_to_idx)
        # reshape keeps the (N_e, 2) layout even when there are no edges at
        # all; a bare np.array([]) would be 1-D and break column indexing.
        edges_np = np.array(edges_list, dtype=np.int64).reshape(-1, 2)
        n_edges = len(edges_np)

        self.edges = torch.tensor(edges_np, dtype=torch.long, device=self.device)
        self.n_edges = n_edges
        self.n_faces = n_faces

        gemm = np.full((n_edges, 4), -1, dtype=np.int64)
        for eidx, flist in edge_faces.items():
            fi0, k0 = flist[0]
            # Boundary edge: reuse the only incident face for both halves.
            fi1, k1 = flist[1] if len(flist) > 1 else flist[0]
            gemm[eidx] = [
                face_edges[fi0, (k0 + 1) % 3],
                face_edges[fi0, (k0 + 2) % 3],
                face_edges[fi1, (k1 + 1) % 3],
                face_edges[fi1, (k1 + 2) % 3],
            ]
        self.gemm_edges = torch.tensor(gemm, dtype=torch.long, device=self.device)

        ve: Dict[int, List[int]] = defaultdict(list)
        for ei, (v0, v1) in enumerate(edges_list):
            ve[v0].append(ei)
            ve[v1].append(ei)
        self.vertex_edges = dict(ve)

        self._build_edge_vertex_tables(edges_np)
        self._update_face_areas()

    def _build_edge_vertex_tables(self, edges_np: np.ndarray):
        """Index tables for scatter-adding edge deltas onto vertex deltas.

        Stores the first and second endpoint of every edge as the 1-D long
        tensors ``edge_v0`` and ``edge_v1``.
        """
        # Tolerate an empty 1-D input so the column slices below stay valid.
        edges_np = np.asarray(edges_np, dtype=np.int64).reshape(-1, 2)
        self.edge_v0 = torch.tensor(edges_np[:, 0], dtype=torch.long, device=self.device)
        self.edge_v1 = torch.tensor(edges_np[:, 1], dtype=torch.long, device=self.device)

    def _face_cross(self, V: torch.Tensor) -> torch.Tensor:
        """Return the un-normalised per-face normal (v1 - v0) x (v2 - v0)."""
        v0 = V[self.faces[:, 0]]
        v1 = V[self.faces[:, 1]]
        v2 = V[self.faces[:, 2]]
        return torch.cross(v1 - v0, v2 - v0, dim=1)

    def _update_face_areas(self):
        """Recompute ``face_areas`` from the current vertex positions."""
        # Triangle area is half the length of the edge-vector cross product.
        self.face_areas = 0.5 * self._face_cross(self.vs).norm(dim=1)

    def face_normals(self, verts: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return normalised per-face normals as an (N_f, 3) tensor.

        Parameters
        ----------
        verts : optional vertex tensor used instead of ``self.vs``; it is
            indexed with ``self.faces``.
        """
        V = verts if verts is not None else self.vs
        return torch.nn.functional.normalize(self._face_cross(V), dim=1)

    @property
    def n_vertices(self) -> int:
        """Number of vertices, i.e. the first dimension of ``vs``."""
        return self.vs.shape[0]

    def clone(self) -> "Mesh":
        """Return a copy whose tensors and edge lists are independent of self."""
        # __new__ skips __init__, so the topology is copied, not rebuilt.
        m = Mesh.__new__(Mesh)
        m.device = self.device
        m.vs = self.vs.clone()
        m.faces = self.faces.clone()
        m.edges = self.edges.clone()
        m.n_edges = self.n_edges
        m.n_faces = self.n_faces
        m.gemm_edges = self.gemm_edges.clone()
        m.vertex_edges = {k: list(v) for k, v in self.vertex_edges.items()}
        m.edge_v0 = self.edge_v0.clone()
        m.edge_v1 = self.edge_v1.clone()
        m.face_areas = self.face_areas.clone()
        return m
|
|
|
|
| |
| |
| |
def edge_to_vertex_displacement(
    delta_edges: torch.Tensor,
    mesh: Mesh,
) -> torch.Tensor:
    """Average per-edge endpoint displacements into per-vertex displacements.

    Parameters
    ----------
    delta_edges : tensor indexed as ``delta_edges[:, 0]`` / ``delta_edges[:, 1]``
        for the displacement of the first / second endpoint of each edge;
        expected shape is (n_edges, 2, 3).
    mesh : object providing ``n_vertices``, ``n_edges``, ``edge_v0`` and
        ``edge_v1`` (the per-edge endpoint index tensors).

    Returns
    -------
    (n_vertices, 3) tensor with the same dtype and device as ``delta_edges``.
    Each row is the mean of all displacements proposed for that vertex;
    vertices that belong to no edge stay at zero.
    """
    device, dtype = delta_edges.device, delta_edges.dtype
    n_v = mesh.n_vertices

    # scatter_add_ needs the index on the same device as the target tensor.
    idx0 = mesh.edge_v0.to(device)
    idx1 = mesh.edge_v1.to(device)

    delta_v = torch.zeros(n_v, 3, device=device, dtype=dtype)
    count = torch.zeros(n_v, 1, device=device, dtype=dtype)
    # Must share count's dtype: scatter_add_ rejects mismatched src/self
    # dtypes, which previously broke any non-float32 input.
    ones = torch.ones(mesh.n_edges, 1, device=device, dtype=dtype)

    delta_v.scatter_add_(0, idx0.unsqueeze(1).expand(-1, 3), delta_edges[:, 0])
    delta_v.scatter_add_(0, idx1.unsqueeze(1).expand(-1, 3), delta_edges[:, 1])
    count.scatter_add_(0, idx0.unsqueeze(1), ones)
    count.scatter_add_(0, idx1.unsqueeze(1), ones)

    # clamp avoids 0/0 for vertices without incident edges.
    return delta_v / count.clamp(min=1)
|
|
|
|
| |
| |
| |
class PartMesh:
    """Spatially partition a mesh for memory-efficient processing.

    The mesh bounding box is cut into a regular grid with ``n_parts`` cells
    along each axis (so up to ``n_parts ** 3`` cells). A face is assigned to
    every cell that contains at least one of its vertices, so faces that
    straddle a cell boundary appear in several parts. Each non-empty cell
    becomes its own ``Mesh`` with locally re-indexed vertices.

    Attributes
    ----------
    mesh        : the source mesh
    n_parts     : grid resolution per axis
    parts       : sub-meshes, ordered by ascending cell id
    vertex_maps : for each part, the array of source-mesh vertex indices
                  such that local vertex ``i`` is global ``vertex_maps[p][i]``
    """

    def __init__(self, mesh: Mesh, n_parts: int = 2):
        self.mesh = mesh
        self.n_parts = n_parts
        self.parts: List[Mesh] = []
        self.vertex_maps: List[np.ndarray] = []
        self._split()

    def _split(self):
        """Populate ``parts`` and ``vertex_maps`` from ``self.mesh``."""
        # detach() keeps the NumPy conversion valid if vs tracks gradients.
        vs = self.mesh.vs.detach().cpu().numpy()
        F_np = self.mesh.faces.cpu().numpy()
        if vs.shape[0] == 0 or F_np.shape[0] == 0:
            # Nothing to partition; min()/max() below would raise on an
            # empty vertex array.
            return

        n = self.n_parts
        lo = vs.min(axis=0)
        hi = vs.max(axis=0)
        span = hi - lo + 1e-8  # epsilon guards against a zero-extent axis

        # Integer grid coordinate of every vertex, flattened to one cell id.
        cell = np.floor(((vs - lo) / span) * n).astype(int)
        cell = np.clip(cell, 0, n - 1)
        cell_id = cell[:, 0] * n * n + cell[:, 1] * n + cell[:, 2]

        # A face joins every cell touched by one of its three vertices.
        cell_faces: Dict[int, set] = defaultdict(set)
        for fi, cells in enumerate(cell_id[F_np]):
            for cid in cells:
                cell_faces[int(cid)].add(fi)

        device = str(self.mesh.device)
        for cid in sorted(cell_faces):
            sub_faces = F_np[sorted(cell_faces[cid])]
            # unique_verts is the sorted set of global vertex ids used by this
            # part; the inverse array is exactly the global->local remapping.
            unique_verts, inverse = np.unique(sub_faces, return_inverse=True)
            local_faces = inverse.reshape(sub_faces.shape)
            self.parts.append(Mesh(vs[unique_verts], local_faces, device=device))
            self.vertex_maps.append(unique_verts)

    def aggregate_displacements(
        self, part_deltas: List[torch.Tensor]
    ) -> torch.Tensor:
        """Merge per-part vertex displacements back onto the source mesh.

        Parameters
        ----------
        part_deltas : one tensor per part, in the order of ``self.parts``;
            row ``i`` of each tensor belongs to that part's local vertex ``i``.

        Returns
        -------
        Tensor shaped like ``self.mesh.vs``. A vertex shared by several parts
        receives the mean of their displacements; vertices covered by no part
        stay at zero.
        """
        device = self.mesh.device
        delta = torch.zeros_like(self.mesh.vs)
        count = torch.zeros(self.mesh.n_vertices, 1, device=device)
        for vmap, dv in zip(self.vertex_maps, part_deltas):
            # Indices inside one vmap are unique, so indexed += is safe here.
            idx = torch.tensor(vmap, dtype=torch.long, device=device)
            delta[idx] += dv.to(device)
            count[idx] += 1
        return delta / count.clamp(min=1)
|
|