bdck commited on
Commit
dba24bf
·
verified ·
1 Parent(s): d4dafad

Upload learn_region_grow/lrg_net.py

Browse files
Files changed (1) hide show
  1. learn_region_grow/lrg_net.py +181 -0
learn_region_grow/lrg_net.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LrgNet — Dual-branch 1D PointNet for learnable region growing.
3
+
4
+ This is a direct PyTorch port of the TensorFlow 1.x model described in:
5
+
6
+ LRGNet: Learnable Region Growing for Class-Agnostic Point Cloud Segmentation
7
+ Jingdao Chen, Zsolt Kira, Yong K. Cho
8
+ IEEE Robotics and Automation Letters (RAL), 2021
9
+ arXiv:2103.09160
10
+
11
+ Architecture overview
12
+ -------------------
13
+ The network takes two point sets as input:
14
+ 1. **Inlier branch**: the current region (points already assigned to the object).
15
+ 2. **Neighbor branch**: candidate points lying on the region boundary.
16
+
17
+ Each branch runs an independent 1D PointNet (shared weights between conv layers
18
+ in the original TensorFlow code, but kept independent here for clarity).
19
+ After local per-point convolutions, a global max-pool extracts a single
20
+ feature vector summarising the whole set. That global vector is tiled back to
21
+ match the point counts and concatenated with the per-point features.
22
+
23
+ Two classification heads then predict:
24
+ - **add_head** : per-neighbor binary logits (should this point join the region?)
25
+ - **remove_head** : per-inlier binary logits (should this point leave the region?)
26
+
27
+ Both heads are trained jointly with cross-entropy.
28
+ """
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+ from typing import Tuple
34
+
35
+
36
+ class PointNetBranch(nn.Module):
37
+ """
38
+ Single 1D PointNet branch: 1D conv layers + global max pool.
39
+
40
+ In the original TF code this is a sequence of:
41
+ Conv1D(13 -> 64), Conv1D(64 -> 64), Conv1D(64 -> 64),
42
+ Conv1D(64 -> 128), Conv1D(128 -> 512)
43
+ followed by max-pooling over the spatial (point) dimension.
44
+ """
45
+
46
+ def __init__(self, in_channels: int = 13):
47
+ super().__init__()
48
+ self.conv1 = nn.Conv1d(in_channels, 64, 1)
49
+ self.conv2 = nn.Conv1d(64, 64, 1)
50
+ self.conv3 = nn.Conv1d(64, 64, 1)
51
+ self.conv4 = nn.Conv1d(64, 128, 1)
52
+ self.conv5 = nn.Conv1d(128, 512, 1)
53
+
54
+ self.bn1 = nn.BatchNorm1d(64)
55
+ self.bn2 = nn.BatchNorm1d(64)
56
+ self.bn3 = nn.BatchNorm1d(64)
57
+ self.bn4 = nn.BatchNorm1d(128)
58
+ self.bn5 = nn.BatchNorm1d(512)
59
+
60
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
61
+ """
62
+ Parameters
63
+ ----------
64
+ x : torch.Tensor, shape (B, C, N)
65
+
66
+ Returns
67
+ -------
68
+ local_feat : torch.Tensor, shape (B, 512, N)
69
+ Per-point features from the deepest conv layer.
70
+ global_feat : torch.Tensor, shape (B, 512, 1)
71
+ Max-pooled global vector.
72
+ """
73
+ # x: (B, C, N)
74
+ x = F.relu(self.bn1(self.conv1(x)))
75
+ x = F.relu(self.bn2(self.conv2(x)))
76
+ x = F.relu(self.bn3(self.conv3(x)))
77
+ x = F.relu(self.bn4(self.conv4(x)))
78
+ x = F.relu(self.bn5(self.conv5(x)))
79
+ local_feat = x # (B, 512, N)
80
+ global_feat = torch.max(x, dim=2, keepdim=True)[0] # (B, 512, 1)
81
+ return local_feat, global_feat
82
+
83
+
84
+ class LrgNet(nn.Module):
85
+ """
86
+ LrgNet — Dual-branch network for learned region growing.
87
+
88
+ Parameters
89
+ ----------
90
+ in_channels : int
91
+ Number of feature channels per point (default 13 from the paper).
92
+ lite : int
93
+ 0 = full channels, 1 = half, 2 = quarter.
94
+ Lite variants run faster on edge devices with negligible accuracy loss.
95
+ """
96
+
97
+ def __init__(self, in_channels: int = 13, lite: int = 0):
98
+ super().__init__()
99
+ factor = 1 / (2 ** lite) # 1, 0.5, 0.25
100
+ c0 = int(64 * factor)
101
+ c1 = int(64 * factor)
102
+ c2 = int(64 * factor)
103
+ c3 = int(128 * factor)
104
+ c4 = int(512 * factor)
105
+
106
+ # Independent branches (original TF code shares conv weights conceptually,
107
+ # but we keep them separate to avoid accidental information leakage).
108
+ self.inlier_branch = self._make_branch(in_channels, c0, c1, c2, c3, c4)
109
+ self.neighbor_branch = self._make_branch(in_channels, c0, c1, c2, c3, c4)
110
+
111
+ # Classification heads
112
+ # Input: 512 (local) + 512 (global inlier) + 512 (global neighbor) = 1536
113
+ self.add_head = self._make_head(c4 * 3, 256, 128, 1)
114
+ self.remove_head = self._make_head(c4 * 3, 256, 128, 1)
115
+
116
+ def _make_branch(self, cin, c0, c1, c2, c3, c4):
117
+ layers = [
118
+ nn.Conv1d(cin, c0, 1), nn.BatchNorm1d(c0), nn.ReLU(),
119
+ nn.Conv1d(c0, c1, 1), nn.BatchNorm1d(c1), nn.ReLU(),
120
+ nn.Conv1d(c1, c2, 1), nn.BatchNorm1d(c2), nn.ReLU(),
121
+ nn.Conv1d(c2, c3, 1), nn.BatchNorm1d(c3), nn.ReLU(),
122
+ nn.Conv1d(c3, c4, 1), nn.BatchNorm1d(c4), nn.ReLU(),
123
+ ]
124
+ return nn.Sequential(*layers)
125
+
126
+ def _make_head(self, cin, h1, h2, out):
127
+ return nn.Sequential(
128
+ nn.Conv1d(cin, h1, 1),
129
+ nn.BatchNorm1d(h1),
130
+ nn.ReLU(),
131
+ nn.Conv1d(h1, h2, 1),
132
+ nn.BatchNorm1d(h2),
133
+ nn.ReLU(),
134
+ nn.Conv1d(h2, out, 1),
135
+ )
136
+
137
+ def forward(self, inliers: torch.Tensor, neighbors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """
139
+ Parameters
140
+ ----------
141
+ inliers : torch.Tensor, shape (B, C, Ni)
142
+ Current region points.
143
+ neighbors : torch.Tensor, shape (B, C, Nn)
144
+ Candidate boundary points.
145
+
146
+ Returns
147
+ -------
148
+ add_logits : torch.Tensor, shape (B, 1, Nn)
149
+ Log-odds for adding each neighbor.
150
+ remove_logits : torch.Tensor, shape (B, 1, Ni)
151
+ Log-odds for removing each inlier.
152
+ """
153
+ # Run branches
154
+ inlier_local = self.inlier_branch(inliers) # (B, c4, Ni)
155
+ neighbor_local = self.neighbor_branch(neighbors) # (B, c4, Nn)
156
+
157
+ # Global max-pool
158
+ inlier_global = torch.max(inlier_local, dim=2, keepdim=True)[0] # (B, c4, 1)
159
+ neighbor_global = torch.max(neighbor_local, dim=2, keepdim=True)[0] # (B, c4, 1)
160
+
161
+ # Tile globals to match point counts
162
+ inlier_global_tiled = inlier_global.expand(-1, -1, inliers.shape[2]) # (B, c4, Ni)
163
+ neighbor_global_tiled = neighbor_global.expand(-1, -1, neighbors.shape[2]) # (B, c4, Nn)
164
+
165
+ # Fuse for add head: neighbor local + neighbor global + inlier global
166
+ add_input = torch.cat([
167
+ neighbor_local,
168
+ neighbor_global_tiled,
169
+ inlier_global.expand(-1, -1, neighbors.shape[2])
170
+ ], dim=1) # (B, c4*3, Nn)
171
+ add_logits = self.add_head(add_input) # (B, 1, Nn)
172
+
173
+ # Fuse for remove head: inlier local + inlier global + neighbor global
174
+ remove_input = torch.cat([
175
+ inlier_local,
176
+ inlier_global_tiled,
177
+ neighbor_global.expand(-1, -1, inliers.shape[2])
178
+ ], dim=1) # (B, c4*3, Ni)
179
+ remove_logits = self.remove_head(remove_input) # (B, 1, Ni)
180
+
181
+ return add_logits, remove_logits