bdck commited on
Commit
e3423d1
·
verified ·
1 Parent(s): 6f54288

Upload learn_region_grow/stage_data.py

Browse files
Files changed (1) hide show
  1. learn_region_grow/stage_data.py +175 -0
learn_region_grow/stage_data.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training data generator.
3
+
4
+ Simulates the region growing process on ground-truth labeled point clouds to create
5
+ supervised training examples: (inlier_points, neighbor_points, add_labels, remove_labels).
6
+
7
+ The key trick in the paper is **controlled noise injection**:
8
+ - add_mistake_prob : probability of including an outlier in the inlier set
9
+ - remove_mistake_prob : probability of removing a true inlier
10
+ This forces the network to learn to *recover from errors*, making the growing
11
+ process robust to early mistakes (e.g. a bad seed or an initial over-growth).
12
+ """
13
+
14
+ import numpy as np
15
+ from typing import Tuple, Optional
16
+ from pathlib import Path
17
+ import h5py
18
+ from .preprocess import voxel_equalize, compute_normals_and_curvature, build_feature_vector
19
+ from .utils import sample_or_pad, center_features
20
+
21
+
22
+ def stage_labeled_cloud(xyz: np.ndarray, rgb: Optional[np.ndarray],
23
+ labels: np.ndarray,
24
+ add_mistake_prob: float = 0.2,
25
+ remove_mistake_prob: float = 0.2,
26
+ resolution: float = 0.1,
27
+ num_inlier: int = 512,
28
+ num_neighbor: int = 512,
29
+ seeds_per_instance: int = 5,
30
+ max_steps_per_seed: int = 20) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
31
+ """
32
+ Generate supervised training tuples from a labeled point cloud.
33
+
34
+ For each unique instance label, pick `seeds_per_instance` random seeds
35
+ inside the instance and simulate noisy region growing.
36
+
37
+ Parameters
38
+ ----------
39
+ xyz : np.ndarray, shape (N, 3)
40
+ rgb : np.ndarray, shape (N, 3), uint8 or None
41
+ labels : np.ndarray, shape (N,), int
42
+ Instance IDs. Background / wall should have label < 0 or a special value.
43
+ add_mistake_prob : float
44
+ Probability of wrongly keeping an outlier in the growing set.
45
+ remove_mistake_prob : float
46
+ Probability of wrongly discarding a true inlier.
47
+ resolution : float
48
+ Voxel grid resolution.
49
+ num_inlier / num_neighbor : int
50
+ Network input sizes.
51
+ seeds_per_instance : int
52
+ max_steps_per_seed : int
53
+
54
+ Returns
55
+ -------
56
+ inlier_batches : np.ndarray, shape (M, num_inlier, 13)
57
+ neighbor_batches: np.ndarray, shape (M, num_neighbor, 13)
58
+ add_labels : np.ndarray, shape (M, num_neighbor)
59
+ remove_labels : np.ndarray, shape (M, num_inlier)
60
+ """
61
+ # Preprocess
62
+ eq_xyz, eq_idx, voxel_map = voxel_equalize(xyz, resolution)
63
+ eq_labels = labels[eq_idx]
64
+
65
+ normals, curvature = compute_normals_and_curvature(eq_xyz, resolution)
66
+ features = build_feature_vector(eq_xyz, rgb[eq_idx] if rgb is not None else None,
67
+ normals, curvature)
68
+
69
+ unique_instances = np.unique(eq_labels[eq_labels >= 0])
70
+
71
+ all_inliers = []
72
+ all_neighbors = []
73
+ all_add = []
74
+ all_remove = []
75
+
76
+ for inst in unique_instances:
77
+ inst_mask = eq_labels == inst
78
+ inst_indices = np.where(inst_mask)[0]
79
+ if len(inst_indices) < 5:
80
+ continue
81
+
82
+ rng = np.random.default_rng()
83
+ seeds = rng.choice(inst_indices, min(seeds_per_instance, len(inst_indices)), replace=False)
84
+
85
+ for seed in seeds:
86
+ region = {int(seed)}
87
+ for step in range(max_steps_per_seed):
88
+ # Find boundary neighbors using voxel adjacency
89
+ neighbors = _boundary_neighbors(region, eq_xyz, voxel_map, resolution)
90
+ # GT labels
91
+ gt_neighbors = eq_labels[np.array(neighbors)] == inst
92
+ gt_region = eq_labels[np.array(list(region))] == inst
93
+
94
+ # Inject noise into labels (the *target* for supervision)
95
+ noisy_add = gt_neighbors.astype(bool).copy()
96
+ noisy_remove = (~gt_region).copy()
97
+ # Flip some correct labels to incorrect ones
98
+ noisy_add = _flip_mask(noisy_add, add_mistake_prob, rng)
99
+ noisy_remove = _flip_mask(noisy_remove, remove_mistake_prob, rng)
100
+
101
+ # Build input tensors
102
+ inlier_pts = features[np.array(list(region), dtype=np.int64)]
103
+ neighbor_pts = features[np.array(neighbors, dtype=np.int64)] if len(neighbors) else np.zeros((0, 13), dtype=np.float32)
104
+
105
+ inlier_s = sample_or_pad(inlier_pts, num_inlier)
106
+ neighbor_s = sample_or_pad(neighbor_pts, num_neighbor)
107
+
108
+ inlier_c, neighbor_c = center_features(inlier_s, neighbor_s)
109
+
110
+ # Pad labels to match padded lengths
111
+ add_label = np.zeros(num_neighbor, dtype=np.int64)
112
+ remove_label = np.zeros(num_inlier, dtype=np.int64)
113
+ n_real = min(len(neighbors), num_neighbor)
114
+ i_real = min(len(region), num_inlier)
115
+ add_label[:n_real] = noisy_add[:n_real].astype(np.int64)
116
+ remove_label[:i_real] = noisy_remove[:i_real].astype(np.int64)
117
+
118
+ all_inliers.append(inlier_c)
119
+ all_neighbors.append(neighbor_c)
120
+ all_add.append(add_label)
121
+ all_remove.append(remove_label)
122
+
123
+ # Update region for next step (use noisy labels as "simulated" current state)
124
+ for idx, flag in zip(neighbors[:n_real], noisy_add[:n_real]):
125
+ if flag:
126
+ region.add(int(idx))
127
+ for idx, flag in zip(list(region)[:i_real], ~noisy_remove[:i_real]):
128
+ if not flag:
129
+ region.discard(int(idx))
130
+
131
+ if len(all_inliers) == 0:
132
+ # Return empty arrays with correct shape
133
+ return (np.zeros((0, num_inlier, 13), dtype=np.float32),
134
+ np.zeros((0, num_neighbor, 13), dtype=np.float32),
135
+ np.zeros((0, num_neighbor), dtype=np.int64),
136
+ np.zeros((0, num_inlier), dtype=np.int64))
137
+
138
+ return (np.stack(all_inliers), np.stack(all_neighbors),
139
+ np.stack(all_add), np.stack(all_remove))
140
+
141
+
142
+ def _flip_mask(mask: np.ndarray, prob: float, rng: np.random.Generator) -> np.ndarray:
143
+ """Randomly flip `prob` fraction of True entries to False and vice-versa."""
144
+ out = mask.copy()
145
+ flip = rng.random(len(mask)) < prob
146
+ out[flip] = ~out[flip]
147
+ return out
148
+
149
+
150
+ def _boundary_neighbors(region: set, xyz: np.ndarray, voxel_map: dict, resolution: float):
151
+ """Find adjacent voxel points not in region (6-connected)."""
152
+ region_pts = np.array(list(region), dtype=np.int64)
153
+ voxels = set()
154
+ for idx in region_pts:
155
+ v = tuple(np.round(xyz[idx] / resolution).astype(int))
156
+ voxels.add(v)
157
+
158
+ adjacent = set()
159
+ for v in voxels:
160
+ for dx, dy, dz in [(-1,0,0),(1,0,0),(0,-1,0),(0,1,0),(0,0,-1),(0,0,1)]:
161
+ nv = (v[0]+dx, v[1]+dy, v[2]+dz)
162
+ if nv in voxel_map and voxel_map[nv] not in region:
163
+ adjacent.add(voxel_map[nv])
164
+ return list(adjacent)
165
+
166
+
167
+ def save_staged_h5(path: str, inliers, neighbors, add_labels, remove_labels):
168
+ """Save staged training data to an HDF5 file."""
169
+ path = Path(path)
170
+ path.parent.mkdir(parents=True, exist_ok=True)
171
+ with h5py.File(path, 'w') as f:
172
+ f.create_dataset('inliers', data=inliers, compression='gzip')
173
+ f.create_dataset('neighbors', data=neighbors, compression='gzip')
174
+ f.create_dataset('add_labels', data=add_labels, compression='gzip')
175
+ f.create_dataset('remove_labels', data=remove_labels, compression='gzip')