Upload learn_region_grow/growing.py
Browse files- learn_region_grow/growing.py +224 -0
learn_region_grow/growing.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference engine: learned region growing on a preprocessed point cloud.
|
| 3 |
+
|
| 4 |
+
Given a trained LrgNet model and a point cloud (with 13-channel features),
|
| 5 |
+
this module performs class-agnostic instance segmentation by growing regions
|
| 6 |
+
from seed points sorted by curvature.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from typing import Tuple, List, Optional
|
| 12 |
+
from .utils import sample_or_pad, center_features
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RegionGrower:
|
| 16 |
+
"""
|
| 17 |
+
Wraps a trained LrgNet and runs iterative region growing on a point cloud.
|
| 18 |
+
|
| 19 |
+
Parameters
|
| 20 |
+
----------
|
| 21 |
+
model : LrgNet
|
| 22 |
+
Trained PyTorch model.
|
| 23 |
+
device : torch.device
|
| 24 |
+
num_inlier : int
|
| 25 |
+
Max points sampled from current region for network input (default 512).
|
| 26 |
+
num_neighbor : int
|
| 27 |
+
Max points sampled from boundary candidates (default 512).
|
| 28 |
+
add_threshold : float
|
| 29 |
+
Confidence threshold for adding a neighbor (default 0.5).
|
| 30 |
+
remove_threshold : float
|
| 31 |
+
Confidence threshold for removing an inlier (default 0.5).
|
| 32 |
+
cluster_threshold : int
|
| 33 |
+
Minimum region size to keep (default 10).
|
| 34 |
+
max_steps : int
|
| 35 |
+
Maximum growing steps per seed (default 100).
|
| 36 |
+
stuck_patience : int
|
| 37 |
+
Stop if no updates for this many consecutive steps (default 2).
|
| 38 |
+
stochastic : bool
|
| 39 |
+
If True, use random sampling weighted by confidence (paper default).
|
| 40 |
+
If False, deterministic top-k thresholding.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, model, device,
|
| 44 |
+
num_inlier: int = 512,
|
| 45 |
+
num_neighbor: int = 512,
|
| 46 |
+
add_threshold: float = 0.5,
|
| 47 |
+
remove_threshold: float = 0.5,
|
| 48 |
+
cluster_threshold: int = 10,
|
| 49 |
+
max_steps: int = 100,
|
| 50 |
+
stuck_patience: int = 2,
|
| 51 |
+
stochastic: bool = True):
|
| 52 |
+
self.model = model.to(device).eval()
|
| 53 |
+
self.device = device
|
| 54 |
+
self.num_inlier = num_inlier
|
| 55 |
+
self.num_neighbor = num_neighbor
|
| 56 |
+
self.add_threshold = add_threshold
|
| 57 |
+
self.remove_threshold = remove_threshold
|
| 58 |
+
self.cluster_threshold = cluster_threshold
|
| 59 |
+
self.max_steps = max_steps
|
| 60 |
+
self.stuck_patience = stuck_patience
|
| 61 |
+
self.stochastic = stochastic
|
| 62 |
+
|
| 63 |
+
@torch.no_grad()
|
| 64 |
+
def grow(self, xyz: np.ndarray, features: np.ndarray,
|
| 65 |
+
voxel_map: dict, resolution: float = 0.1) -> np.ndarray:
|
| 66 |
+
"""
|
| 67 |
+
Run learned region growing on the whole point cloud.
|
| 68 |
+
|
| 69 |
+
Parameters
|
| 70 |
+
----------
|
| 71 |
+
xyz : np.ndarray, shape (N, 3)
|
| 72 |
+
Equalized point coordinates.
|
| 73 |
+
features : np.ndarray, shape (N, 13)
|
| 74 |
+
Precomputed feature vectors.
|
| 75 |
+
voxel_map : dict
|
| 76 |
+
Voxel key -> point index mapping (from voxel_equalize).
|
| 77 |
+
resolution : float
|
| 78 |
+
Voxel grid resolution.
|
| 79 |
+
|
| 80 |
+
Returns
|
| 81 |
+
-------
|
| 82 |
+
labels : np.ndarray, shape (N,), int
|
| 83 |
+
Instance label per point. -1 = unlabeled.
|
| 84 |
+
"""
|
| 85 |
+
n = len(xyz)
|
| 86 |
+
labels = np.full(n, -1, dtype=np.int32)
|
| 87 |
+
visited = np.zeros(n, dtype=bool)
|
| 88 |
+
|
| 89 |
+
# Sort seeds by curvature (channel 12), ascending -> flat surfaces first
|
| 90 |
+
seed_order = np.argsort(features[:, 12])
|
| 91 |
+
|
| 92 |
+
instance_id = 0
|
| 93 |
+
for seed_idx in seed_order:
|
| 94 |
+
if visited[seed_idx]:
|
| 95 |
+
continue
|
| 96 |
+
|
| 97 |
+
region, changed = self._grow_seed(
|
| 98 |
+
seed_idx, xyz, features, voxel_map, resolution, visited
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
if len(region) >= self.cluster_threshold:
|
| 102 |
+
labels[list(region)] = instance_id
|
| 103 |
+
instance_id += 1
|
| 104 |
+
visited[list(region)] = True
|
| 105 |
+
|
| 106 |
+
# Fill unlabeled points by nearest-neighbor assignment
|
| 107 |
+
unlabeled = np.where(labels == -1)[0]
|
| 108 |
+
if len(unlabeled) > 0:
|
| 109 |
+
labeled = np.where(labels >= 0)[0]
|
| 110 |
+
if len(labeled) > 0:
|
| 111 |
+
# Brute-force nearest labeled neighbor
|
| 112 |
+
from scipy.spatial import cKDTree
|
| 113 |
+
tree = cKDTree(xyz[labeled])
|
| 114 |
+
_, nn_idx = tree.query(xyz[unlabeled])
|
| 115 |
+
labels[unlabeled] = labels[labeled[nn_idx]]
|
| 116 |
+
|
| 117 |
+
return labels
|
| 118 |
+
|
| 119 |
+
def _grow_seed(self, seed_idx: int, xyz: np.ndarray, features: np.ndarray,
|
| 120 |
+
voxel_map: dict, resolution: float, visited: np.ndarray) -> Tuple[set, bool]:
|
| 121 |
+
"""
|
| 122 |
+
Grow a single region from a seed point.
|
| 123 |
+
|
| 124 |
+
Returns
|
| 125 |
+
-------
|
| 126 |
+
region : set of point indices
|
| 127 |
+
changed : bool
|
| 128 |
+
Whether the region ever grew (used by caller for bookkeeping).
|
| 129 |
+
"""
|
| 130 |
+
region = {seed_idx}
|
| 131 |
+
stuck_counter = 0
|
| 132 |
+
changed = False
|
| 133 |
+
|
| 134 |
+
for step in range(self.max_steps):
|
| 135 |
+
# Find candidate neighbors on boundary
|
| 136 |
+
neighbors = self._find_boundary_neighbors(region, xyz, voxel_map, resolution)
|
| 137 |
+
# Exclude already visited seeds from other instances
|
| 138 |
+
neighbors = [p for p in neighbors if not visited[p] or p in region]
|
| 139 |
+
|
| 140 |
+
if len(neighbors) == 0:
|
| 141 |
+
break
|
| 142 |
+
|
| 143 |
+
# Sample / pad
|
| 144 |
+
inlier_pts = features[np.array(list(region), dtype=np.int64)]
|
| 145 |
+
neighbor_pts = features[np.array(neighbors, dtype=np.int64)]
|
| 146 |
+
|
| 147 |
+
inlier_s = sample_or_pad(inlier_pts, self.num_inlier)
|
| 148 |
+
neighbor_s = sample_or_pad(neighbor_pts, self.num_neighbor)
|
| 149 |
+
|
| 150 |
+
inlier_c, neighbor_c = center_features(inlier_s, neighbor_s)
|
| 151 |
+
|
| 152 |
+
# To tensor: (B, C, N)
|
| 153 |
+
inl_t = torch.from_numpy(inlier_c.T).unsqueeze(0).float().to(self.device)
|
| 154 |
+
nbr_t = torch.from_numpy(neighbor_c.T).unsqueeze(0).float().to(self.device)
|
| 155 |
+
|
| 156 |
+
add_logits, remove_logits = self.model(inl_t, nbr_t)
|
| 157 |
+
add_prob = torch.sigmoid(add_logits).cpu().numpy().squeeze() # (Nn,)
|
| 158 |
+
remove_prob = torch.sigmoid(remove_logits).cpu().numpy().squeeze() # (Ni,)
|
| 159 |
+
|
| 160 |
+
# Map back to actual points (first len(neighbors) entries correspond to real points)
|
| 161 |
+
n_real = len(neighbors)
|
| 162 |
+
i_real = len(region)
|
| 163 |
+
|
| 164 |
+
add_prob = add_prob[:n_real]
|
| 165 |
+
remove_prob = remove_prob[:i_real]
|
| 166 |
+
|
| 167 |
+
if self.stochastic:
|
| 168 |
+
add_mask = np.random.rand(n_real) < add_prob
|
| 169 |
+
remove_mask = np.random.rand(i_real) < remove_prob
|
| 170 |
+
else:
|
| 171 |
+
add_mask = add_prob > self.add_threshold
|
| 172 |
+
remove_mask = remove_prob > self.remove_threshold
|
| 173 |
+
|
| 174 |
+
to_add = set(np.array(neighbors)[add_mask].tolist())
|
| 175 |
+
region_list = list(region)
|
| 176 |
+
to_remove = set(np.array(region_list)[remove_mask].tolist())
|
| 177 |
+
|
| 178 |
+
new_region = (region | to_add) - to_remove
|
| 179 |
+
if new_region == region:
|
| 180 |
+
stuck_counter += 1
|
| 181 |
+
if stuck_counter >= self.stuck_patience:
|
| 182 |
+
break
|
| 183 |
+
else:
|
| 184 |
+
stuck_counter = 0
|
| 185 |
+
changed = True
|
| 186 |
+
region = new_region
|
| 187 |
+
|
| 188 |
+
return region, changed
|
| 189 |
+
|
| 190 |
+
def _find_boundary_neighbors(self, region: set, xyz: np.ndarray,
|
| 191 |
+
voxel_map: dict, resolution: float) -> List[int]:
|
| 192 |
+
"""
|
| 193 |
+
Find points in voxels adjacent to the current region's bounding box
|
| 194 |
+
that are not yet inside the region.
|
| 195 |
+
|
| 196 |
+
This is a 6-connected voxel search: look at all voxels that share a face
|
| 197 |
+
with any voxel occupied by the region, and return their point indices.
|
| 198 |
+
"""
|
| 199 |
+
# Collect occupied voxels of the region
|
| 200 |
+
region_pts = np.array(list(region), dtype=np.int64)
|
| 201 |
+
region_voxels = set()
|
| 202 |
+
for idx in region_pts:
|
| 203 |
+
v = tuple(np.round(xyz[idx] / resolution).astype(int))
|
| 204 |
+
region_voxels.add(v)
|
| 205 |
+
|
| 206 |
+
# Find adjacent voxels
|
| 207 |
+
adjacent_voxels = set()
|
| 208 |
+
for v in region_voxels:
|
| 209 |
+
for dx, dy, dz in [(-1,0,0), (1,0,0), (0,-1,0), (0,1,0), (0,0,-1), (0,0,1)]:
|
| 210 |
+
adjacent_voxels.add((v[0]+dx, v[1]+dy, v[2]+dz))
|
| 211 |
+
|
| 212 |
+
candidates = []
|
| 213 |
+
for v in adjacent_voxels:
|
| 214 |
+
if v in voxel_map and voxel_map[v] not in region:
|
| 215 |
+
candidates.append(voxel_map[v])
|
| 216 |
+
|
| 217 |
+
# Deduplicate
|
| 218 |
+
seen = set()
|
| 219 |
+
unique = []
|
| 220 |
+
for c in candidates:
|
| 221 |
+
if c not in seen:
|
| 222 |
+
seen.add(c)
|
| 223 |
+
unique.append(c)
|
| 224 |
+
return unique
|