bdck commited on
Commit
68f2e33
Β·
verified Β·
1 Parent(s): 21f6803

add Point2Mesh U-Net network

Browse files
Files changed (1) hide show
  1. point2mesh/network.py +169 -0
point2mesh/network.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Point2Mesh network β€” MeshCNN-based U-Net that regresses per-edge
3
+ vertex displacements from a fixed random input signal.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from typing import List, Optional
11
+ from .mesh import Mesh
12
+ from .layers import MeshConv, MeshPool, MeshUnpool
13
+
14
+
15
+ class ConvBlock(nn.Module):
16
+ """MeshConv + GroupNorm + ReLU."""
17
+
18
+ def __init__(self, in_ch: int, out_ch: int, n_groups: int = 16):
19
+ super().__init__()
20
+ self.conv = MeshConv(in_ch, out_ch)
21
+ # Ensure n_groups divides out_ch
22
+ ng = min(n_groups, out_ch)
23
+ while out_ch % ng != 0:
24
+ ng -= 1
25
+ self.norm = nn.GroupNorm(ng, out_ch)
26
+ self.act = nn.ReLU(inplace=True)
27
+
28
+ def forward(self, x: torch.Tensor, mesh: Mesh) -> torch.Tensor:
29
+ x = self.conv(x, mesh)
30
+ x = self.norm(x)
31
+ x = self.act(x)
32
+ return x
33
+
34
+
35
+ class Point2MeshNet(nn.Module):
36
+ """
37
+ U-Net encoder-decoder on mesh edges.
38
+
39
+ Encoder: Conv β†’ Pool β†’ Conv β†’ Pool β†’ Conv β†’ Pool
40
+ Decoder: Unpool + skip β†’ Conv β†’ Unpool + skip β†’ Conv β†’ Unpool + skip β†’ Conv
41
+ Output : Linear β†’ [N_e, 6] β†’ reshaped to [N_e, 2, 3] edge displacements
42
+
43
+ The input is a *fixed* random tensor C_l ∈ [0,1)^{N_e Γ— in_ch} that is
44
+ NOT optimised; only the network weights are.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ in_ch: int = 6,
50
+ enc_channels: List[int] = None,
51
+ pool_targets: Optional[List[int]] = None,
52
+ ):
53
+ """
54
+ Parameters
55
+ ----------
56
+ in_ch : feature dim of the random input (default 6)
57
+ enc_channels : channel widths for encoder stages (default [64,128,256,256])
58
+ pool_targets : #edges after each pool stage β€” **set at runtime** from mesh
59
+ """
60
+ super().__init__()
61
+ if enc_channels is None:
62
+ enc_channels = [64, 128, 256, 256]
63
+
64
+ self.in_ch = in_ch
65
+ self.enc_channels = enc_channels
66
+
67
+ # ── Encoder ──
68
+ self.enc_convs = nn.ModuleList()
69
+ ch = in_ch
70
+ for out in enc_channels:
71
+ self.enc_convs.append(ConvBlock(ch, out))
72
+ ch = out
73
+
74
+ # ── Pools (instantiated lazily when pool_targets is known) ──
75
+ self.pools: Optional[nn.ModuleList] = None
76
+
77
+ # ── Decoder (symmetric) ──
78
+ dec_channels = list(reversed(enc_channels[:-1])) + [enc_channels[0]]
79
+ self.dec_convs = nn.ModuleList()
80
+ for i, out in enumerate(dec_channels):
81
+ # skip connection doubles the channel count
82
+ skip_ch = enc_channels[-(i + 2)] if i < len(enc_channels) - 1 else in_ch
83
+ self.dec_convs.append(ConvBlock(ch + skip_ch, out))
84
+ ch = out
85
+
86
+ self.unpools = nn.ModuleList([MeshUnpool() for _ in dec_channels])
87
+
88
+ # Output head: edge displacement [N_e, 6] β†’ reshaped [N_e, 2, 3]
89
+ self.head = nn.Conv1d(ch, 6, 1)
90
+ # Initialise head to zero so the first forward pass predicts no displacement
91
+ nn.init.zeros_(self.head.weight)
92
+ nn.init.zeros_(self.head.bias)
93
+
94
+ # ------------------------------------------------------------------
95
+ def _ensure_pools(self, n_edges: int):
96
+ """Create pool layers with target counts proportional to the mesh."""
97
+ if self.pools is not None:
98
+ return
99
+ n_pools = len(self.enc_channels) - 1 # one fewer pool than conv
100
+ targets = []
101
+ ne = n_edges
102
+ for _ in range(n_pools):
103
+ ne = max(ne // 2, 64)
104
+ targets.append(ne)
105
+ self.pools = nn.ModuleList([MeshPool(t) for t in targets])
106
+
107
+ # ------------------------------------------------------------------
108
+ def forward(
109
+ self,
110
+ C_l: torch.Tensor, # (1, in_ch, N_e) β€” random input, fixed
111
+ mesh: Mesh,
112
+ ) -> torch.Tensor:
113
+ """
114
+ Returns
115
+ -------
116
+ delta_edges : (N_e, 2, 3) vertex displacements at each edge endpoint
117
+ """
118
+ self._ensure_pools(mesh.n_edges)
119
+
120
+ x = C_l # (1, in_ch, N_e)
121
+
122
+ # ── Encoder ──
123
+ skips = [x]
124
+ histories = []
125
+ mesh_levels = [mesh]
126
+ for i, conv in enumerate(self.enc_convs):
127
+ x = conv(x, mesh_levels[-1])
128
+ if i < len(self.pools):
129
+ skips.append(x)
130
+ x, mesh_pooled, hist = self.pools[i](x, mesh_levels[-1])
131
+ histories.append(hist)
132
+ mesh_levels.append(mesh_pooled)
133
+
134
+ # ── Decoder ──
135
+ for i, (conv, unpool) in enumerate(zip(self.dec_convs, self.unpools)):
136
+ hist = histories[-(i + 1)]
137
+ x = unpool(x, hist)
138
+ skip = skips[-(i + 1)]
139
+ # Match sizes (in case of rounding during pool/unpool)
140
+ if x.shape[-1] != skip.shape[-1]:
141
+ min_e = min(x.shape[-1], skip.shape[-1])
142
+ x = x[:, :, :min_e]
143
+ skip = skip[:, :, :min_e]
144
+ x = torch.cat([x, skip], dim=1)
145
+ # Use the corresponding mesh level for this decoder stage
146
+ mesh_idx = len(mesh_levels) - 2 - i
147
+ x = conv(x, mesh_levels[max(mesh_idx, 0)])
148
+
149
+ # ── Head ──
150
+ x = self.head(x) # (1, 6, N_e)
151
+ x = x.squeeze(0).t() # (N_e, 6)
152
+ delta_edges = x.reshape(-1, 2, 3) # (N_e, 2, 3)
153
+ return delta_edges
154
+
155
+ def reset_weights(self):
156
+ """Re-initialize all weights (called at each coarse-to-fine level)."""
157
+ for m in self.modules():
158
+ if isinstance(m, (nn.Conv2d, nn.Conv1d)):
159
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
160
+ if m.bias is not None:
161
+ nn.init.zeros_(m.bias)
162
+ elif isinstance(m, nn.GroupNorm):
163
+ nn.init.ones_(m.weight)
164
+ nn.init.zeros_(m.bias)
165
+ # Always zero-init the output head
166
+ nn.init.zeros_(self.head.weight)
167
+ nn.init.zeros_(self.head.bias)
168
+ # Reset pool targets so they are re-created for new mesh
169
+ self.pools = None