bdck commited on
Commit
21f6803
Β·
verified Β·
1 Parent(s): a944a57

add MeshCNN conv/pool/unpool layers

Browse files
Files changed (1) hide show
  1. point2mesh/layers.py +269 -0
point2mesh/layers.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MeshCNN layers β€” convolution, pooling and unpooling on triangle meshes.
3
+
4
+ Convolution works on edges: each edge has 4 topological neighbors from its
5
+ two incident faces. Symmetric aggregation removes the face-ordering
6
+ ambiguity. Pooling collapses edges by L2-norm priority; unpooling restores
7
+ them from stored history.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ from typing import List, Optional
17
+ from .mesh import Mesh
18
+
19
+
20
+ # ──────────────────────────────────────────────────────────────────────
21
+ # MeshConv
22
+ # ──────────────────────────────────────────────────────────────────────
23
+ class MeshConv(nn.Module):
24
+ """
25
+ Edge-based convolution (MeshCNN Β§4.1).
26
+
27
+ For each edge *e* with 4 neighbors (a, b, c, d) we form 5 inputs:
28
+ [e, |aβˆ’c|, a+c, |bβˆ’d|, b+d]
29
+ and apply a learned linear combination (via Conv2d with kernel (1,5)).
30
+ """
31
+
32
+ def __init__(self, in_ch: int, out_ch: int, bias: bool = True):
33
+ super().__init__()
34
+ self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=(1, 5), bias=bias)
35
+
36
+ def forward(self, x: torch.Tensor, mesh: Mesh) -> torch.Tensor:
37
+ """
38
+ x : (1, C_in, N_e)
39
+ mesh : Mesh with .gemm_edges [N_e, 4]
40
+ Returns: (1, C_out, N_e)
41
+ """
42
+ # Gather the 4 neighbor features
43
+ G = self._gather_neighbors(x, mesh) # (1, C, N_e, 4)
44
+
45
+ # Symmetric aggregation
46
+ a, b, c, d = G[..., 0], G[..., 1], G[..., 2], G[..., 3]
47
+ sym = torch.stack([
48
+ torch.abs(a - c),
49
+ a + c,
50
+ torch.abs(b - d),
51
+ b + d,
52
+ ], dim=-1) # (1, C, N_e, 4)
53
+
54
+ # Concatenate center edge + 4 symmetric descriptors β†’ width-5
55
+ x_5 = torch.cat([x.unsqueeze(-1), sym], dim=-1) # (1, C, N_e, 5)
56
+ out = self.conv(x_5) # (1, C_out, N_e, 1)
57
+ return out.squeeze(-1)
58
+
59
+ @staticmethod
60
+ def _gather_neighbors(x: torch.Tensor, mesh: Mesh) -> torch.Tensor:
61
+ """Gather features of the 4 neighbor edges for every edge."""
62
+ # x: (1, C, N_e)
63
+ B, C, N_e = x.shape
64
+ gemm = mesh.gemm_edges # (N_e, 4) on mesh.device
65
+ # Clamp to handle any βˆ’1 (boundary mirror already filled, but be safe)
66
+ gemm = gemm.clamp(min=0)
67
+ flat = gemm.reshape(-1) # (N_e*4,)
68
+ gathered = x[:, :, flat] # (1, C, N_e*4)
69
+ return gathered.view(B, C, N_e, 4)
70
+
71
+
72
+ # ──────────────────────────────────────────────────────────────────────
73
+ # MeshPool (edge collapse)
74
+ # ──────────────────────────────────────────────────────────────────────
75
+ class MeshPool(nn.Module):
76
+ """
77
+ Mesh pooling via edge collapse (MeshCNN Β§4.2).
78
+
79
+ Edges are prioritised by L2-norm of their feature vector; the
80
+ smallest-norm edges are collapsed first. After each collapse the
81
+ features of the two resulting edges are set to the average of the
82
+ three merged edges.
83
+
84
+ The collapse history is stored so that `MeshUnpool` can invert the
85
+ operation.
86
+ """
87
+
88
+ def __init__(self, target: int):
89
+ """target : number of edges to keep after pooling."""
90
+ super().__init__()
91
+ self.target = target
92
+
93
+ def forward(
94
+ self,
95
+ x: torch.Tensor,
96
+ mesh: Mesh,
97
+ ) -> tuple[torch.Tensor, Mesh, dict]:
98
+ """
99
+ x : (1, C, N_e)
100
+ mesh : current Mesh
101
+ Returns
102
+ -------
103
+ x_pooled : (1, C, target)
104
+ mesh_new : Mesh with updated topology (MUTATED)
105
+ history : dict consumed by MeshUnpool
106
+ """
107
+ B, C, N_e = x.shape
108
+ device = x.device
109
+
110
+ # Work on CPU numpy for topology manipulation (small meshes are fast)
111
+ gemm = mesh.gemm_edges.cpu().numpy().copy() # (N_e, 4)
112
+ edges_np = mesh.edges.cpu().numpy().copy()
113
+ feat = x.squeeze(0).detach().cpu().numpy().copy() # (C, N_e)
114
+
115
+ active = np.ones(N_e, dtype=bool)
116
+ n_active = int(active.sum())
117
+
118
+ # Priority: L2 norm of each edge's feature vector
119
+ norms = np.linalg.norm(feat, axis=0) # (N_e,)
120
+
121
+ # Sorted order of edges by ascending norm
122
+ order = np.argsort(norms)
123
+
124
+ # History bookkeeping for unpooling
125
+ # Maps: new_edge_idx β†’ set of old edge indices that contributed
126
+ merge_log: List[tuple] = [] # (surviving_edge, [merged edges], [merge weights])
127
+
128
+ collapse_map = np.arange(N_e) # edge redirect after collapses
129
+
130
+ idx = 0
131
+ while n_active > self.target and idx < len(order):
132
+ e = order[idx]
133
+ idx += 1
134
+ if not active[e]:
135
+ continue
136
+
137
+ a, b, c, d = gemm[e]
138
+ # Validity checks
139
+ if a < 0 or b < 0 or c < 0 or d < 0:
140
+ continue
141
+ if not (active[a] and active[b] and active[c] and active[d]):
142
+ continue
143
+ # Non-manifold guard: skip if collapsing would merge two boundary verts
144
+ # (simplified: skip if any neighbor is already dead or re-targeted)
145
+ if a == b or c == d:
146
+ continue
147
+
148
+ # Collapse edge e:
149
+ # Face-0 edges (a, b) β†’ surviving edge p
150
+ # Face-1 edges (c, d) β†’ surviving edge q
151
+ # Merged features: p = avg(e, a, b), q = avg(e, c, d)
152
+ p, q = b, d # surviving edge labels (keep the "second" edge of each face)
153
+
154
+ # Update features (on numpy)
155
+ feat[:, p] = (feat[:, e] + feat[:, a] + feat[:, b]) / 3.0
156
+ feat[:, q] = (feat[:, e] + feat[:, c] + feat[:, d]) / 3.0
157
+
158
+ merge_log.append((p, [e, a, b]))
159
+ merge_log.append((q, [e, c, d]))
160
+
161
+ # Deactivate collapsed edges
162
+ active[e] = False
163
+ active[a] = False
164
+ active[c] = False
165
+ n_active -= 3
166
+
167
+ # Redirect any neighbor pointers that point to a or c
168
+ gemm[gemm == a] = p
169
+ gemm[gemm == c] = q
170
+ gemm[gemm == e] = p # default redirect to p
171
+
172
+ # Build new compact edge set
173
+ kept = np.where(active)[0]
174
+ old2new = np.full(N_e, -1, dtype=np.int64)
175
+ for new_i, old_i in enumerate(kept):
176
+ old2new[old_i] = new_i
177
+
178
+ # Re-index gemm for surviving edges
179
+ new_gemm = gemm[kept].copy()
180
+ for i in range(new_gemm.shape[0]):
181
+ for j in range(4):
182
+ mapped = old2new[new_gemm[i, j]]
183
+ new_gemm[i, j] = mapped if mapped >= 0 else i # self-loop fallback
184
+
185
+ # Build new feature tensor (differentiable path)
186
+ kept_t = torch.tensor(kept, dtype=torch.long, device=device)
187
+ x_pooled = x[:, :, kept_t] # (1, C, n_kept)
188
+
189
+ # Overwrite collapsed features differentiably
190
+ # We re-run the averaging in torch for grad flow
191
+ x_work = x.squeeze(0) # (C, N_e)
192
+ new_feats = []
193
+ for old_i in kept:
194
+ new_feats.append(x_work[:, old_i])
195
+ # Override with merged averages
196
+ merge_map = {}
197
+ for surv, sources in merge_log:
198
+ if surv in merge_map:
199
+ continue # keep first
200
+ merge_map[surv] = sources
201
+ new_feat_list = []
202
+ for ni, old_i in enumerate(kept):
203
+ if old_i in merge_map:
204
+ srcs = merge_map[old_i]
205
+ avg = sum(x_work[:, s] for s in srcs) / len(srcs)
206
+ new_feat_list.append(avg)
207
+ else:
208
+ new_feat_list.append(x_work[:, old_i])
209
+ x_pooled = torch.stack(new_feat_list, dim=1).unsqueeze(0) # (1, C, n_kept)
210
+
211
+ # Construct new Mesh from surviving edges/faces
212
+ # (for simplicity we update the mesh in-place rather than rebuild faces)
213
+ mesh_new = mesh.clone()
214
+ mesh_new.gemm_edges = torch.tensor(new_gemm, dtype=torch.long, device=mesh.device)
215
+ mesh_new.n_edges = len(kept)
216
+ # Keep edges array updated (vertex indices of surviving edges)
217
+ mesh_new.edges = mesh.edges[kept_t.to(mesh.device)]
218
+ mesh_new.edge_v0 = mesh_new.edges[:, 0]
219
+ mesh_new.edge_v1 = mesh_new.edges[:, 1]
220
+
221
+ history = {
222
+ "kept": kept, # indices of surviving edges in old ordering
223
+ "old2new": old2new, # old_edge β†’ new_edge mapping
224
+ "merge_log": merge_log, # how features were merged
225
+ "n_old": N_e,
226
+ }
227
+ return x_pooled, mesh_new, history
228
+
229
+
230
+ # ──────────────────────────────────────────────────────────────────────
231
+ # MeshUnpool (restore topology from history)
232
+ # ──────────────────────────────────────────────────────────────────────
233
+ class MeshUnpool(nn.Module):
234
+ """
235
+ Restore the pre-pooling edge topology using stored history.
236
+
237
+ Unpooled edge features are set to the feature of the surviving edge
238
+ they were merged into (broadcast).
239
+ """
240
+
241
+ def forward(
242
+ self,
243
+ x: torch.Tensor,
244
+ history: dict,
245
+ ) -> torch.Tensor:
246
+ """
247
+ x : (1, C, N_pooled)
248
+ history : dict from MeshPool.forward
249
+ Returns : (1, C, N_old)
250
+ """
251
+ B, C, N_pooled = x.shape
252
+ N_old = history["n_old"]
253
+ device = x.device
254
+ kept = history["kept"]
255
+
256
+ out = torch.zeros(B, C, N_old, device=device, dtype=x.dtype)
257
+
258
+ # Place surviving edge features at their original indices
259
+ kept_t = torch.tensor(kept, dtype=torch.long, device=device)
260
+ out[:, :, kept_t] = x
261
+
262
+ # For collapsed edges, copy from the surviving edge they merged into
263
+ for surv, sources in history["merge_log"]:
264
+ surv_new = int(history["old2new"][surv])
265
+ for s in sources:
266
+ if s not in kept:
267
+ out[:, :, s] = x[:, :, surv_new]
268
+
269
+ return out