bdck committed on
Commit
bf02e8b
·
verified ·
1 Parent(s): 4e56ffc

Upload learn_region_grow/growing.py

Browse files
Files changed (1) hide show
  1. learn_region_grow/growing.py +224 -0
learn_region_grow/growing.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference engine: learned region growing on a preprocessed point cloud.
3
+
4
+ Given a trained LrgNet model and a point cloud (with 13-channel features),
5
+ this module performs class-agnostic instance segmentation by growing regions
6
+ from seed points sorted by curvature.
7
+ """
8
+
9
+ import numpy as np
10
+ import torch
11
+ from typing import Tuple, List, Optional
12
+ from .utils import sample_or_pad, center_features
13
+
14
+
15
class RegionGrower:
    """
    Wraps a trained LrgNet and runs iterative region growing on a point cloud.

    Parameters
    ----------
    model : LrgNet
        Trained PyTorch model.
    device : torch.device
    num_inlier : int
        Max points sampled from current region for network input (default 512).
    num_neighbor : int
        Max points sampled from boundary candidates (default 512).
    add_threshold : float
        Confidence threshold for adding a neighbor (default 0.5).
    remove_threshold : float
        Confidence threshold for removing an inlier (default 0.5).
    cluster_threshold : int
        Minimum region size to keep (default 10).
    max_steps : int
        Maximum growing steps per seed (default 100).
    stuck_patience : int
        Stop if no updates for this many consecutive steps (default 2).
    stochastic : bool
        If True, use random sampling weighted by confidence (paper default).
        If False, deterministic thresholding.
    """

    def __init__(self, model, device,
                 num_inlier: int = 512,
                 num_neighbor: int = 512,
                 add_threshold: float = 0.5,
                 remove_threshold: float = 0.5,
                 cluster_threshold: int = 10,
                 max_steps: int = 100,
                 stuck_patience: int = 2,
                 stochastic: bool = True):
        # Model is moved to the target device and frozen into eval mode;
        # all inference runs under @torch.no_grad() in grow().
        self.model = model.to(device).eval()
        self.device = device
        self.num_inlier = num_inlier
        self.num_neighbor = num_neighbor
        self.add_threshold = add_threshold
        self.remove_threshold = remove_threshold
        self.cluster_threshold = cluster_threshold
        self.max_steps = max_steps
        self.stuck_patience = stuck_patience
        self.stochastic = stochastic

    @torch.no_grad()
    def grow(self, xyz: np.ndarray, features: np.ndarray,
             voxel_map: dict, resolution: float = 0.1) -> np.ndarray:
        """
        Run learned region growing on the whole point cloud.

        Parameters
        ----------
        xyz : np.ndarray, shape (N, 3)
            Equalized point coordinates.
        features : np.ndarray, shape (N, 13)
            Precomputed feature vectors.
        voxel_map : dict
            Voxel key -> point index mapping (from voxel_equalize).
        resolution : float
            Voxel grid resolution.

        Returns
        -------
        labels : np.ndarray, shape (N,), int
            Instance label per point. -1 = unlabeled.
        """
        n = len(xyz)
        labels = np.full(n, -1, dtype=np.int32)
        visited = np.zeros(n, dtype=bool)

        # Sort seeds by curvature (channel 12), ascending -> flat surfaces first
        seed_order = np.argsort(features[:, 12])

        instance_id = 0
        for seed_idx in seed_order:
            if visited[seed_idx]:
                continue

            region, _ = self._grow_seed(
                seed_idx, xyz, features, voxel_map, resolution, visited
            )

            # Keep only sufficiently large regions. Points of discarded
            # (too-small) regions stay unvisited so a later seed may still
            # absorb them into another instance.
            if len(region) >= self.cluster_threshold:
                region_arr = np.fromiter(region, dtype=np.int64, count=len(region))
                labels[region_arr] = instance_id
                visited[region_arr] = True
                instance_id += 1

        # Fill unlabeled points by nearest labeled neighbor (if any region survived)
        unlabeled = np.where(labels == -1)[0]
        if len(unlabeled) > 0:
            labeled = np.where(labels >= 0)[0]
            if len(labeled) > 0:
                # Local import keeps scipy optional unless fill-in is needed.
                from scipy.spatial import cKDTree
                tree = cKDTree(xyz[labeled])
                _, nn_idx = tree.query(xyz[unlabeled])
                labels[unlabeled] = labels[labeled[nn_idx]]

        return labels

    def _grow_seed(self, seed_idx: int, xyz: np.ndarray, features: np.ndarray,
                   voxel_map: dict, resolution: float, visited: np.ndarray) -> Tuple[set, bool]:
        """
        Grow a single region from a seed point.

        Returns
        -------
        region : set of point indices
        changed : bool
            Whether the region ever grew (used by caller for bookkeeping).
        """
        region = {seed_idx}
        stuck_counter = 0
        changed = False

        for _ in range(self.max_steps):
            # Candidate points in voxels face-adjacent to the current region.
            neighbors = self._find_boundary_neighbors(region, xyz, voxel_map, resolution)
            # Exclude points already claimed by other instances.
            # (_find_boundary_neighbors never returns region members, so the
            # original `or p in region` clause was dead and is dropped.)
            neighbors = [p for p in neighbors if not visited[p]]

            if not neighbors:
                break

            # Snapshot the region order ONCE and reuse it for both feature
            # gathering and removal-mask mapping; the original relied on two
            # separate list(region) calls iterating in the same order.
            region_list = list(region)
            inlier_pts = features[np.asarray(region_list, dtype=np.int64)]
            neighbor_pts = features[np.asarray(neighbors, dtype=np.int64)]

            inlier_s = sample_or_pad(inlier_pts, self.num_inlier)
            neighbor_s = sample_or_pad(neighbor_pts, self.num_neighbor)

            inlier_c, neighbor_c = center_features(inlier_s, neighbor_s)

            # To tensor: (B, C, N)
            inl_t = torch.from_numpy(inlier_c.T).unsqueeze(0).float().to(self.device)
            nbr_t = torch.from_numpy(neighbor_c.T).unsqueeze(0).float().to(self.device)

            add_logits, remove_logits = self.model(inl_t, nbr_t)
            # atleast_1d guards the single-point edge case where squeeze()
            # collapses the output to a 0-d array (unsliceable).
            add_prob = np.atleast_1d(torch.sigmoid(add_logits).cpu().numpy().squeeze())     # (Nn,)
            remove_prob = np.atleast_1d(torch.sigmoid(remove_logits).cpu().numpy().squeeze())  # (Ni,)

            # BUGFIX: clamp the "real point" counts to the network input size.
            # Previously n_real/i_real were len(neighbors)/len(region); once a
            # region exceeded num_inlier (or candidates exceeded num_neighbor),
            # the boolean mask was shorter than the index array and boolean
            # indexing raised IndexError.
            # NOTE(review): assumes sample_or_pad keeps the first k points in
            # order when truncating -- TODO confirm against utils.sample_or_pad.
            n_real = min(len(neighbors), self.num_neighbor, len(add_prob))
            i_real = min(len(region_list), self.num_inlier, len(remove_prob))

            add_prob = add_prob[:n_real]
            remove_prob = remove_prob[:i_real]

            if self.stochastic:
                # Bernoulli sampling weighted by confidence (paper behavior).
                add_mask = np.random.rand(n_real) < add_prob
                remove_mask = np.random.rand(i_real) < remove_prob
            else:
                add_mask = add_prob > self.add_threshold
                remove_mask = remove_prob > self.remove_threshold

            to_add = set(np.asarray(neighbors[:n_real])[add_mask].tolist())
            to_remove = set(np.asarray(region_list[:i_real])[remove_mask].tolist())

            new_region = (region | to_add) - to_remove
            if new_region == region:
                stuck_counter += 1
                if stuck_counter >= self.stuck_patience:
                    break
            else:
                stuck_counter = 0
                changed = True
                region = new_region

        return region, changed

    def _find_boundary_neighbors(self, region: set, xyz: np.ndarray,
                                 voxel_map: dict, resolution: float) -> List[int]:
        """
        Find points in voxels adjacent to the current region that are not yet
        inside the region.

        This is a 6-connected voxel search: look at all voxels that share a face
        with any voxel occupied by the region, and return their point indices.
        """
        # Collect occupied voxels of the region (vectorized quantization).
        region_pts = np.fromiter(region, dtype=np.int64, count=len(region))
        region_voxels = {tuple(v) for v in np.round(xyz[region_pts] / resolution).astype(int)}

        # Expand by the 6 face-adjacent offsets.
        offsets = ((-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1))
        adjacent_voxels = {
            (vx + dx, vy + dy, vz + dz)
            for (vx, vy, vz) in region_voxels
            for (dx, dy, dz) in offsets
        }

        candidates = [
            voxel_map[v]
            for v in adjacent_voxels
            if v in voxel_map and voxel_map[v] not in region
        ]

        # Deduplicate while preserving first-seen order.
        return list(dict.fromkeys(candidates))