jskvrna commited on
Commit
33113fd
·
1 Parent(s): 2affd35

Preparation of the files for the public release.

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. color_visu.py +8 -0
  2. end_to_end.py +24 -0
  3. end_to_end_deeper.py +0 -946
  4. fast_pointnet.py +0 -520
  5. fast_pointnet_class.py +7 -0
  6. fast_pointnet_class_10d.py +0 -405
  7. fast_pointnet_class_10d_2048.py +0 -405
  8. fast_pointnet_class_10d_deeper.py +0 -438
  9. fast_pointnet_class_deeper.py +0 -527
  10. fast_pointnet_class_v2.py +0 -508
  11. fast_pointnet_v2.py +11 -1
  12. fast_pointnet_v3.py +0 -605
  13. fast_voxel.py +0 -591
  14. find_best_results.py +7 -0
  15. fully_deep.py +0 -1082
  16. generate_pcloud_dataset.py +12 -0
  17. hoho_cpu.batch +0 -17
  18. hoho_cpu_gpu_intel.batch +0 -19
  19. hoho_gpu.batch +0 -19
  20. hoho_gpu_class.batch +0 -19
  21. hoho_gpu_class_10d.batch +0 -19
  22. hoho_gpu_class_10d_2048.batch +0 -19
  23. hoho_gpu_class_10d_deeper.batch +0 -19
  24. hoho_gpu_h200.batch +0 -19
  25. hoho_gpu_voxel.batch +0 -19
  26. initial_epoch_100.pth +0 -3
  27. initial_epoch_100_class_v2.pth +0 -3
  28. initial_epoch_100_v2.pth +0 -3
  29. initial_epoch_100_v2_aug.pth +0 -3
  30. initial_epoch_60.pth +0 -3
  31. initial_epoch_60_v2.pth +0 -3
  32. iterate.batch +0 -50
  33. pnet.pth +2 -2
  34. predict.py +12 -0
  35. predict_end.py +0 -73
  36. script.py +1 -1
  37. train.py +19 -10
  38. train_end.py +0 -73
  39. train_pnet.py +0 -13
  40. train_pnet_class.py +13 -1
  41. train_pnet_class_cluster.py +0 -13
  42. train_pnet_class_cluster_10d.py +0 -13
  43. train_pnet_class_cluster_10d_2048.py +0 -13
  44. train_pnet_class_cluster_10d_deeper.py +0 -13
  45. train_pnet_cluster.py +0 -10
  46. train_pnet_cluster_class_v2.py +0 -10
  47. train_pnet_cluster_v3.py +0 -10
  48. train_pnet_cluster_v2.py → train_pnet_v2.py +2 -2
  49. train_voxel.py +0 -13
  50. train_voxel_cluster.py +0 -13
color_visu.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  import cv2
2
  import numpy as np
3
 
 
1
+ """
2
+ This file generates a color legend image for building component visualization.
3
+ It creates a PNG image showing color swatches and labels for two categories:
4
+ 1. Gestalt Colors - for various building components like roof, walls, windows, etc.
5
+ 2. Edge Colors - for architectural edges like ridges, eaves, hips, valleys, etc.
6
+ The legend helps visualize the color mappings used in building analysis and annotation.
7
+ """
8
+
9
  import cv2
10
  import numpy as np
11
 
end_to_end.py CHANGED
@@ -1,3 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import pickle
3
  import torch
 
1
+ """
2
+ End-to-End Voxel-Based Vertex Detection Pipeline
3
+
4
+ This file implements a complete pipeline for detecting wireframe vertices from 3D point clouds using
5
+ a voxel-based deep learning approach. The pipeline includes:
6
+
7
+ 1. Data preprocessing: Converting 14D point clouds into 3D voxel grids with averaged features
8
+ 2. Ground truth generation: Creating binary vertex labels and refinement targets from wireframe vertices
9
+ 3. Model architecture: VoxelUNet with encoder-decoder structure and 1x1x1 bottleneck for vertex detection
10
+ 4. Training: Combined loss function with BCE, Dice loss, and MSE for offset regression
11
+ 5. Inference: Predicting vertex locations from new point clouds with visualization
12
+
13
+ Key components:
14
+ - Voxelization with configurable grid size and metric voxel size
15
+ - Per-voxel MLP before convolutional processing
16
+ - Gaussian smoothing of ground truth labels
17
+ - Refinement prediction for sub-voxel accuracy
18
+ - PyVista-based visualization for results analysis
19
+
20
+ Usage:
21
+ - Set inference=False to train a new model
22
+ - Set inference=True to run predictions on existing data
23
+ """
24
+
25
  import os
26
  import pickle
27
  import torch
end_to_end_deeper.py DELETED
@@ -1,946 +0,0 @@
1
- import os
2
- import pickle
3
- import torch
4
- import torch.nn as nn
5
- import torch.optim as optim
6
- import numpy as np
7
- from typing import Dict, Any, Tuple, List
8
- from torch.utils.data import Dataset, DataLoader
9
- import glob
10
- import pyvista as pv
11
- import torch
12
-
13
- # [Previous code from the existing document remains unchanged up to CombinedLoss class]
14
- # ... (save_data, load_data, get_data_files, voxelize_points, create_ground_truth, VoxelUNet, VoxelDataset) ...
15
-
16
def save_data(dict_to_save: Dict[str, Any], filename: str, data_folder: str = "data") -> None:
    """Serialize *dict_to_save* to ``<data_folder>/<filename>.pkl``.

    Creates *data_folder* if it does not already exist.

    Args:
        dict_to_save: Arbitrary picklable dictionary.
        filename: Base file name, without the ``.pkl`` extension.
        data_folder: Destination directory (created on demand).
    """
    os.makedirs(data_folder, exist_ok=True)
    # Bug fix: the path previously used a fixed literal instead of `filename`,
    # so every call wrote to the same file and clobbered earlier saves.
    filepath = os.path.join(data_folder, f"{filename}.pkl")
    with open(filepath, 'wb') as f:
        pickle.dump(dict_to_save, f)
    #print(f"Data saved to {filepath}")
23
-
24
def load_data(filepath: str) -> Dict[str, Any]:
    """Deserialize and return the dictionary stored in the pickle file at *filepath*."""
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)
30
-
31
def get_data_files(data_folder: str = "data", pattern: str = "*.pkl") -> List[str]:
    """Return the paths of files in *data_folder* whose names match the glob *pattern*."""
    return glob.glob(os.path.join(data_folder, pattern))
37
-
38
def voxelize_points(points: np.ndarray,
                    grid_size_xy: int = 64,
                    voxel_size_metric: float = 0.25
                    ) -> Tuple[torch.Tensor, np.ndarray, Dict[str, Any]]:
    """
    Voxelize a 14D point cloud into a 3D grid with a fixed number of voxels and fixed metric voxel size.

    The grid is CUBIC: X, Y and Z all have grid_size_xy voxels (the code below
    sets dim_z = grid_size_xy; an earlier note about Z having grid_size_xy / 4
    voxels no longer matches the implementation).
    The point cloud is centered within this metric grid. Points outside are discarded.

    Per-voxel features are SUMMED, not averaged: channels 0-2 accumulate each
    point's offset from the voxel center (in grid units), channels 3+ accumulate
    the remaining 11 input features of every point assigned to that voxel.

    Args:
        points: (N, 14) array where first 3 dims are xyz (original coordinates).
        grid_size_xy: Number of voxels along X and Y dimensions (and Z — cubic grid).
        voxel_size_metric: The physical size of each voxel (e.g., 0.5 units).

    Returns:
        voxel_grid: (NUM_FEATURES, dim_z, dim_y, dim_x) tensor of summed features.
        voxel_indices_for_points: (N_points_in_grid, 3) integer voxel indices (z, y, x)
                                  for each input point that falls within the grid.
        scale_info: Dict with transformation parameters:
            'grid_origin_metric': Real-world metric coordinate of the corner of voxel [0,0,0] (x,y,z).
            'voxel_size_metric': The metric size of a voxel.
            'grid_dims_voxels': Tuple (dim_x, dim_y, dim_z) representing number of voxels.
            'pc_centroid_metric': Centroid of the input point cloud (x,y,z).
    """
    NUM_FEATURES = 14
    dim_x = grid_size_xy
    dim_y = grid_size_xy
    dim_z = grid_size_xy  # Cubic grid, matching create_ground_truth and VoxelDataset usage

    if dim_z == 0: dim_z = 1  # Ensure at least one voxel in Z

    # grid_dims_voxels stores (num_voxels_x, num_voxels_y, num_voxels_z)
    grid_dims_voxels = np.array([dim_x, dim_y, dim_z], dtype=int)

    def _get_empty_return(reason: str = ""):
        # Fallback for an empty input cloud: all-zero grid, no indices,
        # origin/centroid at zero. Voxel grid shape is fixed: (NUM_FEATURES, dim_z, dim_y, dim_x).
        voxel_grid_empty = torch.zeros(NUM_FEATURES, grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.float32)
        voxel_indices_empty = np.empty((0, 3), dtype=int)
        scale_info_empty = {
            'grid_origin_metric': np.zeros(3, dtype=float),
            'voxel_size_metric': voxel_size_metric,
            'grid_dims_voxels': tuple(grid_dims_voxels.tolist()),
            'pc_centroid_metric': np.zeros(3, dtype=float),
        }
        return voxel_grid_empty, voxel_indices_empty, scale_info_empty

    if points.shape[0] == 0:
        return _get_empty_return("Initial empty point cloud")

    xyz = points[:, :3]  # Metric coordinates of points
    features_other = points[:, 3:]  # Other features

    pc_centroid_metric = xyz.mean(axis=0)  # (cx, cy, cz)

    # Calculate the metric origin of the grid such that the point cloud centroid
    # aligns with the center of the metric grid.
    # grid_metric_span is (total_metric_width_x, total_metric_height_y, total_metric_depth_z)
    grid_metric_span = grid_dims_voxels * voxel_size_metric
    # grid_origin_metric is the real-world (x,y,z) coordinate of the corner of voxel (0,0,0)
    grid_origin_metric = pc_centroid_metric - (grid_metric_span / 2.0)

    # Initialize voxel_grid: PyTorch expects (C, D, H, W)
    # Here, D=dim_z, H=dim_y, W=dim_x
    voxel_grid = torch.zeros(NUM_FEATURES, grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.float32)

    # Convert point metric coordinates to continuous voxel coordinates (potentially fractional and outside [0, dim-1])
    # continuous_voxel_coords[i] = (px_i, py_i, pz_i) in continuous voxel grid space
    continuous_voxel_coords = (xyz - grid_origin_metric) / voxel_size_metric

    # To store (z_idx, y_idx, x_idx) for each point, for easier indexing into torch tensors
    voxel_indices_for_points_zyx_order = []

    for i in range(points.shape[0]):
        # current_point_continuous_coord_xyz is (x_coord, y_coord, z_coord) in continuous voxel space
        current_point_continuous_coord_xyz = continuous_voxel_coords[i]

        # Voxel integer indices (ix, iy, iz) by ROUNDING to the nearest index
        # (np.round, not floor) — the point is assigned to the voxel whose
        # integer index is closest, so offsets below can reach +/-0.5 grid units.
        voxel_idx_int_xyz = np.round(current_point_continuous_coord_xyz).astype(int)

        # Check if the point falls outside the grid boundaries
        # grid_dims_voxels is (dim_x, dim_y, dim_z)
        idx_x, idx_y, idx_z = voxel_idx_int_xyz[0], voxel_idx_int_xyz[1], voxel_idx_int_xyz[2]

        if not (0 <= idx_x < grid_dims_voxels[0] and \
                0 <= idx_y < grid_dims_voxels[1] and \
                0 <= idx_z < grid_dims_voxels[2]):
            # Point is outside the grid, skip it
            continue

        # At this point, idx_x, idx_y, idx_z are guaranteed to be within grid bounds.
        # No explicit clipping is needed here, but using them directly.

        voxel_indices_for_points_zyx_order.append([idx_z, idx_y, idx_x])

        # Calculate offset for the first 3 features:
        # Center of the assigned voxel in continuous grid index space (e.g., [0.5,0.5,0.5] for voxel [0,0,0])
        assigned_voxel_center_grid_idx_space = np.array([idx_x, idx_y, idx_z], dtype=float) + 0.5

        # Offset of the point from its assigned voxel center, in grid units.
        # NOTE(review): with round-based assignment above, these offsets are
        # measured against idx + 0.5, so they lie in [-1.0, 0.0) rather than
        # being centered at 0 — confirm this asymmetry is intended.
        offset_xyz_in_grid_units = current_point_continuous_coord_xyz - assigned_voxel_center_grid_idx_space

        # Accumulate features in voxel_grid (C, Z, Y, X)
        # Store dx, dy, dz (offsets in X, Y, Z dimensions)
        voxel_grid[0, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[0]  # dx
        voxel_grid[1, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[1]  # dy
        voxel_grid[2, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[2]  # dz

        # Accumulate other original features (from index 3 onwards)
        if NUM_FEATURES > 3:
            current_point_other_features = features_other[i]
            voxel_grid[3:, idx_z, idx_y, idx_x] += torch.tensor(current_point_other_features, dtype=torch.float32)

    final_voxel_indices_for_points_zyx = np.array(voxel_indices_for_points_zyx_order, dtype=int) if voxel_indices_for_points_zyx_order else np.empty((0,3), dtype=int)

    scale_info = {
        'grid_origin_metric': grid_origin_metric,
        'voxel_size_metric': voxel_size_metric,
        'grid_dims_voxels': tuple(grid_dims_voxels.tolist()),
        'pc_centroid_metric': pc_centroid_metric,
    }

    return voxel_grid, final_voxel_indices_for_points_zyx, scale_info
160
-
161
-
162
def create_ground_truth(vertices: np.ndarray,
                        scale_info: Dict[str, Any]
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build ground-truth voxel labels and sub-voxel refinement targets.

    Grid geometry (origin, voxel size, dimensions) is read from *scale_info*
    as produced by voxelize_points.

    Args:
        vertices: (M, 3) vertex coordinates in original metric space.
        scale_info: Must contain 'grid_origin_metric', 'voxel_size_metric'
            and 'grid_dims_voxels'.

    Returns:
        vertex_labels: (dim_z, dim_y, dim_x) binary labels — 1.0 where a
            vertex falls inside (or is clipped into) the voxel.
        refinement_targets: (3, dim_z, dim_y, dim_x) offsets (dx, dy, dz)
            of the vertex from the assigned voxel's center, in grid units
            (approx. [-0.5, 0.5) for in-bounds vertices).
    """
    origin = scale_info['grid_origin_metric']
    vox_size = scale_info['voxel_size_metric']
    dims = np.array(scale_info['grid_dims_voxels'])  # (nx, ny, nz)
    nx, ny, nz = dims[0], dims[1], dims[2]

    labels = torch.zeros(nz, ny, nx, dtype=torch.float32)
    targets = torch.zeros(3, nz, ny, nx, dtype=torch.float32)

    if vertices.shape[0] == 0:
        return labels, targets

    # Metric coordinates -> continuous voxel-grid coordinates (may be fractional
    # and may fall outside [0, dim-1]).
    grid_coords = (vertices - origin) / vox_size

    for coord in grid_coords:
        # Floor gives the voxel cell the vertex lies in; clip keeps
        # out-of-bounds vertices on the grid border.
        cell = np.floor(coord).astype(int)
        ix = np.clip(cell[0], 0, nx - 1)
        iy = np.clip(cell[1], 0, ny - 1)
        iz = np.clip(cell[2], 0, nz - 1)

        # Tensors are indexed (z, y, x).
        labels[iz, iy, ix] = 1.0

        # Offset of the vertex from the *assigned* (possibly clipped) voxel
        # center, expressed in grid units.
        center = np.array([ix, iy, iz], dtype=float) + 0.5
        delta = coord - center
        targets[0, iz, iy, ix] = delta[0]  # dx
        targets[1, iz, iy, ix] = delta[1]  # dy
        targets[2, iz, iy, ix] = delta[2]  # dz

    return labels, targets
226
-
227
class VoxelUNet(nn.Module):
    """Enhanced U-Net for voxel-based vertex detection with increased capacity and advanced features.

    Input:  (B, in_channels, D, H, W) voxel feature grids.
    Output: tuple of
      - vertex_logits: (B, 1, D, H, W) per-voxel vertex-presence logits.
      - refinement:    (B, 3, D, H, W) per-voxel (dx, dy, dz) offsets, squashed
        to [-0.5, 0.5] via tanh * 0.5 (at most half a voxel per axis).

    NOTE(review): the bottleneck collapses the deepest feature map to 1x1x1 via
    adaptive average pooling, so the first decoder step interpolates a single
    voxel back up to the deepest encoder resolution.
    """

    def __init__(self, in_channels: int = 14, base_channels: int = 64, bottleneck_expansion: int = 4,
                 use_attention: bool = True, use_residual: bool = True, dropout_rate: float = 0.1):
        super(VoxelUNet, self).__init__()

        bc = base_channels
        self.use_attention = use_attention
        self.use_residual = use_residual

        # Encoder with increased depth and capacity
        self.enc1 = self._conv_block(in_channels, bc, use_residual=False)  # bc
        self.enc2 = self._conv_block(bc, bc * 2, dropout_rate)  # bc*2
        self.enc3 = self._conv_block(bc * 2, bc * 4, dropout_rate)  # bc*4
        self.enc4 = self._conv_block(bc * 4, bc * 8, dropout_rate)  # bc*8
        self.enc5 = self._conv_block(bc * 8, bc * 16, dropout_rate)  # bc*16
        self.enc6 = self._conv_block(bc * 16, bc * 32, dropout_rate)  # bc*32 (new layer)

        self.pool = nn.MaxPool3d(2)

        # Enhanced bottleneck with multiple processing paths
        self.adaptive_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        bottleneck_in_channels = bc * 32
        bottleneck_width = bottleneck_in_channels * bottleneck_expansion

        # Three stacked 1x1x1 convs act as an MLP over the globally pooled vector.
        self.bottleneck = nn.Sequential(
            nn.Conv3d(bottleneck_in_channels, bottleneck_width, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm3d(bottleneck_width),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout_rate),
            nn.Conv3d(bottleneck_width, bottleneck_width, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm3d(bottleneck_width),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout_rate),
            nn.Conv3d(bottleneck_width, bottleneck_width, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm3d(bottleneck_width),
            nn.ReLU(inplace=True)
        )

        # Attention modules for skip connections (if enabled)
        if self.use_attention:
            self.att6 = self._attention_block(bottleneck_width, bc * 32, bc * 16)
            self.att5 = self._attention_block(bc * 32, bc * 16, bc * 8)
            self.att4 = self._attention_block(bc * 16, bc * 8, bc * 4)
            self.att3 = self._attention_block(bc * 8, bc * 4, bc * 2)
            self.att2 = self._attention_block(bc * 4, bc * 2, bc)

        # Enhanced decoder with more capacity
        self.dec6 = self._conv_block(bottleneck_width + bc * 32, bc * 32, dropout_rate)

        self.up5 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.dec5 = self._conv_block(bc * 32 + bc * 16, bc * 16, dropout_rate)

        self.up4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.dec4 = self._conv_block(bc * 16 + bc * 8, bc * 8, dropout_rate)

        self.up3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.dec3 = self._conv_block(bc * 8 + bc * 4, bc * 4, dropout_rate)

        self.up2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.dec2 = self._conv_block(bc * 4 + bc * 2, bc * 2, dropout_rate)

        self.up1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.dec1 = self._conv_block(bc * 2 + bc, bc, dropout_rate)

        # Enhanced output heads with intermediate processing
        self.vertex_intermediate = self._conv_block(bc, bc // 2, 0.0, use_residual=False)
        self.vertex_head = nn.Conv3d(bc // 2, 1, kernel_size=1)

        self.refinement_intermediate = self._conv_block(bc, bc // 2, 0.0, use_residual=False)
        self.refinement_head = nn.Conv3d(bc // 2, 3, kernel_size=1)

        self.tanh = nn.Tanh()

    def _conv_block(self, in_channels: int, out_channels: int, dropout_rate: float = 0.0,
                    use_residual: bool = None) -> nn.Sequential:
        """Enhanced convolutional block with optional residual connections and dropout.

        `use_residual=None` means "inherit self.use_residual"; an explicit
        True/False overrides the instance-wide setting for this block only.
        """
        if use_residual is None:
            use_residual = self.use_residual

        layers = []

        # First conv
        layers.extend([
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        ])

        if dropout_rate > 0:
            layers.append(nn.Dropout3d(dropout_rate))

        # Second conv
        layers.extend([
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_channels)
        ])

        # Third conv for extra capacity (only for the wider blocks)
        if out_channels >= 128:
            layers.extend([
                nn.ReLU(inplace=True),
                nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm3d(out_channels)
            ])

        # Add residual connection if channels match and residual is enabled;
        # otherwise finish the plain block with a trailing ReLU.
        block = nn.Sequential(*layers)

        if use_residual and in_channels == out_channels:
            return ResidualBlock(block)
        else:
            return nn.Sequential(block, nn.ReLU(inplace=True))

    def _attention_block(self, gate_channels: int, skip_channels: int, out_channels: int) -> nn.Module:
        """Attention gate for focusing on relevant features in skip connections."""
        return AttentionGate(gate_channels, skip_channels, out_channels)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run encoder -> bottleneck -> decoder; return (vertex_logits, refinement)."""
        # Encoder path with increased depth
        e1 = self.enc1(x)  # bc
        p1 = self.pool(e1)

        e2 = self.enc2(p1)  # bc*2
        p2 = self.pool(e2)

        e3 = self.enc3(p2)  # bc*4
        p3 = self.pool(e3)

        e4 = self.enc4(p3)  # bc*8
        p4 = self.pool(e4)

        e5 = self.enc5(p4)  # bc*16
        p5 = self.pool(e5)

        e6 = self.enc6(p5)  # bc*32
        p6 = self.pool(e6)

        # Enhanced bottleneck: global context, pooled down to a single voxel
        b_pooled = self.adaptive_pool(p6)
        b = self.bottleneck(b_pooled)

        # Enhanced decoder path with attention; the bottleneck output is
        # interpolated from 1x1x1 up to the deepest encoder resolution.
        u6_from_b = nn.functional.interpolate(b, size=e6.shape[2:], mode='trilinear', align_corners=True)
        if self.use_attention:
            e6_att = self.att6(u6_from_b, e6)
            cat6 = torch.cat([u6_from_b, e6_att], dim=1)
        else:
            cat6 = torch.cat([u6_from_b, e6], dim=1)
        d6 = self.dec6(cat6)

        u5 = self.up5(d6)
        if self.use_attention:
            e5_att = self.att5(u5, e5)
            cat5 = torch.cat([u5, e5_att], dim=1)
        else:
            cat5 = torch.cat([u5, e5], dim=1)
        d5 = self.dec5(cat5)

        u4 = self.up4(d5)
        if self.use_attention:
            e4_att = self.att4(u4, e4)
            cat4 = torch.cat([u4, e4_att], dim=1)
        else:
            cat4 = torch.cat([u4, e4], dim=1)
        d4 = self.dec4(cat4)

        u3 = self.up3(d4)
        if self.use_attention:
            e3_att = self.att3(u3, e3)
            cat3 = torch.cat([u3, e3_att], dim=1)
        else:
            cat3 = torch.cat([u3, e3], dim=1)
        d3 = self.dec3(cat3)

        u2 = self.up2(d3)
        if self.use_attention:
            e2_att = self.att2(u2, e2)
            cat2 = torch.cat([u2, e2_att], dim=1)
        else:
            cat2 = torch.cat([u2, e2], dim=1)
        d2 = self.dec2(cat2)

        u1 = self.up1(d2)
        cat1 = torch.cat([u1, e1], dim=1)  # shallowest skip has no attention gate
        d1 = self.dec1(cat1)

        # Enhanced output heads
        vertex_features = self.vertex_intermediate(d1)
        vertex_logits = self.vertex_head(vertex_features)

        refinement_features = self.refinement_intermediate(d1)
        # tanh * 0.5 constrains each offset to at most half a voxel.
        refinement = self.tanh(self.refinement_head(refinement_features)) * 0.5

        return vertex_logits, refinement
423
-
424
-
425
class ResidualBlock(nn.Module):
    """Wraps a sub-module with an additive identity shortcut followed by ReLU."""

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # relu(F(x) + x) — standard residual formulation.
        residual = self.block(x)
        return torch.relu(residual + x)
433
-
434
-
435
class AttentionGate(nn.Module):
    """Additive attention gate for U-Net skip connections.

    Projects the gating signal and the skip features to a common width,
    combines them additively, and produces a per-voxel sigmoid weight that
    re-scales the skip features.
    """

    def __init__(self, gate_channels, skip_channels, out_channels):
        super().__init__()
        # Layer creation order is significant for parameter initialization;
        # attribute names match the original for state_dict compatibility.
        self.gate_conv = nn.Conv3d(gate_channels, out_channels, kernel_size=1, bias=True)
        self.skip_conv = nn.Conv3d(skip_channels, out_channels, kernel_size=1, bias=True)
        self.attention_conv = nn.Conv3d(out_channels, 1, kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, gate, skip):
        g = self.gate_conv(gate)
        s = self.skip_conv(skip)

        # Bring the gate projection onto the skip's spatial grid if they differ.
        if g.shape[2:] != s.shape[2:]:
            g = nn.functional.interpolate(
                g, size=s.shape[2:],
                mode='trilinear', align_corners=True
            )

        weights = self.sigmoid(self.attention_conv(self.relu(g + s)))
        return skip * weights
460
-
461
class VoxelDataset(Dataset):
    """Dataset of pickled samples yielding (voxel_grid, vertex_labels, refinement_targets, scale_info).

    Each sample file must contain 'pcloud_14d' (N, 14 point cloud) and
    'wf_vertices' (wireframe vertex coordinates).
    """

    def __init__(self, data_files: List[str], voxel_size: float = 0.1, grid_size: int = 64):
        self.data_files = data_files
        self.voxel_size = voxel_size
        self.grid_size = grid_size

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        sample = load_data(self.data_files[idx])

        # Voxelize the raw 14D cloud; the returned scale_info fixes the grid
        # geometry that the ground truth must share.
        grid, _, scale_info = voxelize_points(
            sample['pcloud_14d'], self.grid_size, self.voxel_size
        )

        vertices = np.array(sample['wf_vertices'])
        labels, targets = create_ground_truth(vertices, scale_info)

        return grid, labels, targets, scale_info
483
-
484
- import torch.nn as nn
485
- import torch.nn.functional as F
486
-
487
class CombinedLoss(nn.Module):
    """
    Combined loss for vertex classification and offset regression.

    Components:
    - Weighted BCE-with-logits against a Gaussian-blurred copy of the GT labels.
    - Dice loss against the blurred labels re-binarized at 0.5.
    - MSE loss on refinement offsets, computed only over voxels whose
      *original* (un-blurred) label is positive.

    NOTE(review): the Gaussian kernel is NOT normalized (peak value 1), so the
    blur inflates label mass; the clamp(0, 1) in forward() keeps targets valid.
    NOTE(review): self.bce_loss is constructed but never used — forward()
    builds its own BCEWithLogitsLoss(reduction='none') for per-voxel weighting.
    """
    def __init__(self,
                 vertex_weight: float = 1.0,
                 refinement_weight: float = 0.1,
                 dice_weight: float = 0.5,
                 blur_kernel_size: int = 5,
                 blur_sigma: float = 1.0,
                 eps: float = 1e-6):
        super().__init__()
        self.vertex_weight = vertex_weight          # scales the classification term
        self.refinement_weight = refinement_weight  # scales the offset-regression term
        self.dice_weight = dice_weight              # scales dice within the vertex loss
        self.eps = eps                              # numerical floor for the dice ratio

        # BCE with logits (unused — see class NOTE above)
        self.bce_loss = nn.BCEWithLogitsLoss()
        # MSE for offset regression
        self.mse_loss = nn.MSELoss()

        # build 3D gaussian kernel (unnormalized; value 1 at the center)
        k = blur_kernel_size
        coords = torch.arange(k, dtype=torch.float32) - (k - 1) / 2
        xx, yy, zz = torch.meshgrid(coords, coords, coords, indexing='ij')
        kernel = torch.exp(-(xx**2 + yy**2 + zz**2) / (2 * blur_sigma**2))
        # shape (1,1,k,k,k)
        kernel = kernel.view(1, 1, k, k, k)
        # registered as a buffer so .to(device) moves it with the module
        self.register_buffer('gaussian_kernel', kernel)
        self.pad = k // 2

    def forward(self,
                vertex_logits_pred: torch.Tensor,  # (B,1,D,H,W)
                refinement_pred: torch.Tensor,     # (B,3,D,H,W)
                vertex_gt: torch.Tensor,           # (B,D,H,W), 0/1
                refinement_gt: torch.Tensor        # (B,3,D,H,W)
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (total_loss, vertex_loss, refinement_loss) for one batch."""

        # logits & gt
        logits = vertex_logits_pred.squeeze(1)  # (B,D,H,W)
        gt = vertex_gt.float()                  # (B,D,H,W)

        # apply gaussian blur on gt to soften hard 0/1 targets
        gt_unsq = gt.unsqueeze(1)  # (B,1,D,H,W)
        gt_blur = F.conv3d(gt_unsq, self.gaussian_kernel, padding=self.pad)  # (B,1,D,H,W)
        gt_blur = gt_blur.clamp(0, 1)  # ensure values are in [0, 1]
        gt_smooth = gt_blur.squeeze(1)  # (B,D,H,W)

        # 1) Weighted BCE loss - positive when gt_smooth > 1e-3
        pos_mask = gt_smooth > 1e-3
        neg_mask = ~pos_mask

        # Compute BCE separately for positive and negative samples so each
        # group contributes its own mean (balances the heavy class imbalance).
        bce_loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        bce_all = bce_loss_fn(logits, gt_smooth)

        # Calculate weighted BCE (both weights currently 1.0 — kept as knobs)
        pos_weight = 1.0
        neg_weight = 1.0

        if pos_mask.sum() > 0 and neg_mask.sum() > 0:
            pos_loss = bce_all[pos_mask].mean()
            neg_loss = bce_all[neg_mask].mean()
            bce = pos_weight * pos_loss + neg_weight * neg_loss
        elif pos_mask.sum() > 0:
            bce = pos_weight * bce_all[pos_mask].mean()
        elif neg_mask.sum() > 0:
            bce = neg_weight * bce_all[neg_mask].mean()
        else:
            bce = torch.tensor(0.0, device=logits.device)

        # 2) Dice loss over the re-binarized smoothed labels
        prob = torch.sigmoid(logits)
        gt_smooth_round = (gt_smooth > 0.5).float()  # binary mask
        intersection = (prob * gt_smooth_round).sum(dim=[1,2,3])
        union = prob.sum(dim=[1,2,3]) + gt_smooth_round.sum(dim=[1,2,3])
        dice_score = (2 * intersection + self.eps) / (union + self.eps)
        dice_loss = 1 - dice_score.mean()

        vertex_loss = bce + self.dice_weight * dice_loss

        # 3) Refinement MSE (only where original gt==1)
        mask_pos = (gt > 0.5).unsqueeze(1)  # use hard mask for offsets
        if mask_pos.sum() > 0:
            pred_offsets = refinement_pred[mask_pos.expand_as(refinement_pred)] \
                               .view(-1, 3)
            gt_offsets = refinement_gt[mask_pos.expand_as(refinement_gt)] \
                             .view(-1, 3)
            refinement_loss = self.mse_loss(pred_offsets, gt_offsets)
        else:
            refinement_loss = torch.tensor(0., device=logits.device)

        # 4) Total loss
        total_loss = (self.vertex_weight * vertex_loss +
                      self.refinement_weight * refinement_loss)

        return total_loss, vertex_loss, refinement_loss
590
-
591
def train_epoch(model, dataloader, optimizer, criterion, device, current_epoch: int):
    """Run one training epoch and return averaged (total, vertex, refinement) losses.

    Side effects:
    - Prints per-batch loss values.
    - Saves a model checkpoint every 200 batches.
    - Contains an `if False:` PyVista debug-visualization block (flip to True
      to inspect the first sample of each batch interactively).
    """
    model.train()
    total_loss_epoch = 0.0
    vertex_loss_epoch = 0.0
    refinement_loss_epoch = 0.0

    for batch_idx, (voxel_grid_batch, vertex_labels_batch, refinement_targets_batch, _) in enumerate(dataloader):
        voxel_grid_batch = voxel_grid_batch.to(device)
        vertex_labels_batch = vertex_labels_batch.to(device)
        refinement_targets_batch = refinement_targets_batch.to(device)

        # Debug visualization — deliberately disabled; renders occupied voxels
        # and GT vertices of the first sample with PyVista when enabled.
        if False:
            print(f'Epoch {current_epoch+1}, Batch {batch_idx+1}/{len(dataloader)}')

            sample_voxel_features = voxel_grid_batch[0].cpu().numpy()
            sample_gt_labels = vertex_labels_batch[0].cpu().numpy()
            sample_gt_refinement = refinement_targets_batch[0].cpu().numpy()

            # A voxel is "occupied" if any of its summed xyz-offset channels is nonzero.
            summed_xyz_in_voxels = sample_voxel_features[:3]
            occupied_voxel_mask = np.any(summed_xyz_in_voxels != 0, axis=0)

            plotter = pv.Plotter(window_size=[800,600])
            plotter.background_color = 'white'

            if np.any(occupied_voxel_mask):
                occupied_voxel_indices = np.array(np.where(occupied_voxel_mask)).T
                input_points_display = pv.PolyData(occupied_voxel_indices + 0.5)
                plotter.add_mesh(input_points_display, color='cornflowerblue', point_size=5, render_points_as_spheres=True, label='Occupied Voxels (Centers)')

                gt_vertex_voxel_mask = sample_gt_labels > 0.5
                if np.any(gt_vertex_voxel_mask):
                    gt_vertex_indices_int = np.array(np.where(gt_vertex_voxel_mask)).T
                    gt_offsets = sample_gt_refinement[:, gt_vertex_voxel_mask].T
                    # GT positions = voxel center + stored sub-voxel offset, in grid space
                    gt_vertex_positions_grid_space = gt_vertex_indices_int.astype(float) + 0.5 + gt_offsets

                    target_vertices_display = pv.PolyData(gt_vertex_positions_grid_space)
                    plotter.add_mesh(target_vertices_display, color='crimson', point_size=10, render_points_as_spheres=True, label='Target Vertices (GT)')

                plotter.show(title=f"Debug Viz E{current_epoch+1} B{batch_idx+1}", auto_close=False)
            else:
                print(f"Epoch {current_epoch+1} Batch {batch_idx+1}: No data to visualize for the first sample.")

        optimizer.zero_grad()
        vertex_logits_pred, refinement_pred = model(voxel_grid_batch)

        loss, vertex_loss, refinement_loss = criterion(
            vertex_logits_pred, refinement_pred, vertex_labels_batch, refinement_targets_batch
        )

        print(f"Batch {batch_idx+1}/{len(dataloader)}: Loss={loss.item():.4f}, Vertex Loss={vertex_loss.item():.4f}, Refinement Loss={refinement_loss.item():.4f}")

        # Skip backprop on a (near-)zero loss to avoid a useless update step.
        if loss > 0.000001:
            loss.backward()
            optimizer.step()

        total_loss_epoch += loss.item()
        vertex_loss_epoch += vertex_loss.item()
        refinement_loss_epoch += refinement_loss.item()

        # Periodic intra-epoch checkpointing
        if (batch_idx + 1) % 200 == 0:
            checkpoint_path = f"model_epoch_{current_epoch+1}_batch_{batch_idx+1}_grid_128v7.pth" # Consider updating filename if grid size changes
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved batch checkpoint: {checkpoint_path}")

    # Guard against an empty dataloader when averaging
    avg_total_loss = total_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0
    avg_vertex_loss = vertex_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0
    avg_refinement_loss = refinement_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0

    return avg_total_loss, avg_vertex_loss, avg_refinement_loss
660
-
661
def train_model(data_folder: str = "data", num_epochs: int = 100, batch_size: int = 4, neg_pos_ratio_val: float = 1.0):
    """Train a VoxelUNet on pickled samples from *data_folder*, checkpointing every epoch.

    Args:
        data_folder: Directory containing *.pkl training samples.
        num_epochs: Number of passes over the dataset.
        batch_size: DataLoader batch size.
        neg_pos_ratio_val: Only used in checkpoint file names; it does NOT
            influence the loss (NOTE(review): likely a leftover from an
            earlier balanced-sampling scheme).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    data_files = get_data_files(data_folder)
    if not data_files:
        print(f"No data files found in {data_folder}. Exiting.")
        return

    # Grid configuration: 64^3 voxels, each 0.75 metric units wide.
    GRID_SIZE_CFG = 64
    VOXEL_SIZE_CFG = 0.75

    dataset = VoxelDataset(data_files, voxel_size=VOXEL_SIZE_CFG, grid_size=GRID_SIZE_CFG)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

    model = VoxelUNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # refinement_weight=0.0: offset regression is disabled for this run,
    # only the vertex-classification terms contribute to the gradient.
    criterion = CombinedLoss(
        vertex_weight=1.0,
        refinement_weight=0.0,
        dice_weight=0.2
    ).to(device)

    print(f"Starting training: {num_epochs} epochs, Batch size: {batch_size}, Grid size: {GRID_SIZE_CFG}, Voxel size: {VOXEL_SIZE_CFG}, Initial LR: {optimizer.param_groups[0]['lr']}")

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        avg_loss, avg_vertex_loss, avg_refinement_loss = train_epoch(
            model, dataloader, optimizer, criterion, device, epoch
        )

        print(f"Epoch {epoch+1} Summary: Avg Loss: {avg_loss:.4f}, "
              f"Avg Vertex Loss: {avg_vertex_loss:.4f}, "
              f"Avg Refinement Loss: {avg_refinement_loss:.4f}, "
              f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Per-epoch checkpoint (file name encodes grid size and run tag)
        checkpoint_path = f"model_epoch_{epoch+1}_grid{GRID_SIZE_CFG}_smooth_bal{neg_pos_ratio_val}_v7.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Saved checkpoint: {checkpoint_path}")

    final_model_path = f"final_model_grid{GRID_SIZE_CFG}_epochs{num_epochs}_smooth_bal{neg_pos_ratio_val}_v7.pth"
    torch.save(model.state_dict(), final_model_path)
    print(f"Training completed! Final model saved as {final_model_path}")
707
- def load_model_for_inference(model_path: str, device: torch.device,
708
- in_channels: int = 14, base_channels: int = 32) -> VoxelUNet:
709
- """Load a VoxelUNet model for inference."""
710
- model = VoxelUNet(in_channels=in_channels, base_channels=base_channels)
711
- model.load_state_dict(torch.load(model_path, map_location=device))
712
- model.to(device)
713
- model.eval()
714
- print(f"Model loaded from {model_path} and set to evaluation mode on {device}.")
715
- return model
716
-
717
- def predict_vertices(model: VoxelUNet,
718
- point_cloud_14d: np.ndarray,
719
- grid_size: int,
720
- device: torch.device,
721
- voxel_size_metric: float = 0.35, # Added for consistency, default matches voxelize_points
722
- vertex_threshold: float = 0.5) -> np.ndarray:
723
- """
724
- Predict vertices from a 14D point cloud.
725
-
726
- Args:
727
- model: The trained VoxelUNet model.
728
- point_cloud_14d: (N, 14) NumPy array of the input point cloud.
729
- grid_size: The size of the voxel grid along X and Y dimensions (must match training).
730
- device: PyTorch device ('cuda' or 'cpu').
731
- voxel_size_metric: The metric size of each voxel (must match training).
732
- vertex_threshold: Threshold for classifying a voxel as containing a vertex.
733
-
734
- Returns:
735
- predicted_vertices_original_space: (M, 3) NumPy array of predicted vertex
736
- coordinates in the original point cloud space (X, Y, Z order).
737
- Returns an empty array if no vertices are predicted
738
- or if the input point cloud results in an empty voxel grid.
739
- """
740
- voxel_grid_tensor, _, scale_info = voxelize_points(
741
- point_cloud_14d,
742
- grid_size_xy=grid_size,
743
- voxel_size_metric=voxel_size_metric
744
- )
745
-
746
- # Check if voxelization produced a valid grid (e.g., if input point cloud was empty)
747
- # voxelize_points returns a zero tensor for grid if input points are empty.
748
- # If voxel_grid_tensor is all zeros and no points were input, scale_info might be default.
749
- if voxel_grid_tensor.sum() == 0 and point_cloud_14d.shape[0] == 0:
750
- # This case implies empty input point cloud, voxelize_points handles this.
751
- # Predictions will naturally be empty if the grid is empty.
752
- pass # Continue, model will predict on zero grid.
753
-
754
- input_tensor = voxel_grid_tensor.unsqueeze(0).to(device)
755
-
756
- with torch.no_grad():
757
- vertex_logits_pred_tensor, refinement_pred_tensor = model(input_tensor)
758
-
759
- vertex_prob_pred_tensor = torch.sigmoid(vertex_logits_pred_tensor)
760
-
761
- vertex_prob_pred_np = vertex_prob_pred_tensor.squeeze(0).squeeze(0).cpu().numpy()
762
- refinement_pred_np = refinement_pred_tensor.squeeze(0).cpu().numpy() # Shape (3, D, H, W) -> (dx,dy,dz channels)
763
-
764
- print(f"Vertex Probabilities Stats: Min={np.min(vertex_prob_pred_np):.4f}, Max={np.max(vertex_prob_pred_np):.4f}, Mean={np.mean(vertex_prob_pred_np):.4f}, Median={np.median(vertex_prob_pred_np):.4f}")
765
- if refinement_pred_np.size > 0:
766
- print(f"Refinement Predictions Stats: Min={np.min(refinement_pred_np):.4f}, Max={np.max(refinement_pred_np):.4f}, Mean={np.mean(refinement_pred_np):.4f}, Median={np.median(refinement_pred_np):.4f}")
767
- for i in range(refinement_pred_np.shape[0]): # Iterate over dx, dy, dz components
768
- print(f" Refinement Dim {i} (dx,dy,dz order) Stats: Min={np.min(refinement_pred_np[i]):.4f}, Max={np.max(refinement_pred_np[i]):.4f}, Mean={np.mean(refinement_pred_np[i]):.4f}, Median={np.median(refinement_pred_np[i]):.4f}")
769
- else:
770
- print("Refinement Predictions Stats: Array is empty.")
771
-
772
- predicted_mask = vertex_prob_pred_np > vertex_threshold
773
- # predicted_voxel_indices are (N_preds, 3) with columns (idx_z, idx_y, idx_x)
774
- predicted_voxel_indices_zyx = np.argwhere(predicted_mask)
775
-
776
- if not predicted_voxel_indices_zyx.size:
777
- return np.empty((0, 3), dtype=np.float32)
778
-
779
- # Extract refinement offsets for the predicted voxels
780
- # offsets_channels_first will be (3, N_preds) where channels are (dx, dy, dz)
781
- offsets_channels_first = refinement_pred_np[:,
782
- predicted_voxel_indices_zyx[:, 0], # z_indices
783
- predicted_voxel_indices_zyx[:, 1], # y_indices
784
- predicted_voxel_indices_zyx[:, 2]] # x_indices
785
-
786
- # Transpose to (N_preds, 3) where columns are (dx, dy, dz)
787
- offsets_xyz_order = offsets_channels_first.T
788
-
789
- # Calculate refined coordinates in continuous voxel grid space (X, Y, Z order)
790
- # Voxel center is at index + 0.5
791
- # Refinement is added to this center.
792
- # predicted_voxel_indices_zyx[:, 2] is x_idx
793
- # predicted_voxel_indices_zyx[:, 1] is y_idx
794
- # predicted_voxel_indices_zyx[:, 0] is z_idx
795
-
796
- # offsets_xyz_order[:, 0] is dx
797
- # offsets_xyz_order[:, 1] is dy
798
- # offsets_xyz_order[:, 2] is dz
799
-
800
- refined_x_grid = predicted_voxel_indices_zyx[:, 2].astype(np.float32) + 0.5 #+ offsets_xyz_order[:, 0]
801
- refined_y_grid = predicted_voxel_indices_zyx[:, 1].astype(np.float32) + 0.5 #+ offsets_xyz_order[:, 1]
802
- refined_z_grid = predicted_voxel_indices_zyx[:, 0].astype(np.float32) + 0.5 #+ offsets_xyz_order[:, 2]
803
-
804
- # Stack to get (N_preds, 3) array in (X, Y, Z) order
805
- refined_grid_coords_xyz = np.stack((refined_x_grid, refined_y_grid, refined_z_grid), axis=-1)
806
-
807
- # Convert refined grid coordinates to original metric space
808
- grid_origin_metric = np.array(scale_info['grid_origin_metric']) # (ox, oy, oz)
809
- # Voxel_size_metric from scale_info should match the input voxel_size_metric parameter
810
- current_voxel_size_metric = scale_info['voxel_size_metric']
811
-
812
- # predicted_vertices_original_space are (N_preds, 3) in (X,Y,Z) order
813
- predicted_vertices_original_space = refined_grid_coords_xyz * current_voxel_size_metric + grid_origin_metric
814
-
815
- return predicted_vertices_original_space.astype(np.float32)
816
-
817
- # Simple inference script
818
- def run_inference(model_path: str,
819
- data_file_path: str,
820
- output_file: str = None,
821
- grid_size: int = 128,
822
- voxel_size: float = 0.5,
823
- vertex_threshold: float = 0.5):
824
- """
825
- Run inference on all data files in a directory, visualize with pyvista, and save results.
826
- """
827
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
828
- print(f"Using device: {device}")
829
-
830
- # Load model
831
- model = load_model_for_inference(model_path, device)
832
-
833
- # Get all data files from the directory
834
- data_files = get_data_files(data_file_path)
835
- if not data_files:
836
- print(f"No data files found in {data_file_path}")
837
- return
838
-
839
- print(f"Found {len(data_files)} data files to process")
840
-
841
- for i, file_path in enumerate(data_files):
842
- print(f"\n--- Processing file {i+1}/{len(data_files)}: {os.path.basename(file_path)} ---")
843
-
844
- # Load input data
845
- try:
846
- data = load_data(file_path)
847
- except Exception as e:
848
- print(f"Error loading {file_path}: {e}")
849
- continue
850
-
851
- if 'pcloud_14d' not in data:
852
- print(f"Error: File {file_path} does not contain 'pcloud_14d' key, skipping")
853
- continue
854
-
855
- # Extract original point cloud and ground-truth vertices
856
- pcloud = data['pcloud_14d'][:, :3] # (N,3)
857
- gt_vertices = np.array(data.get('wf_vertices', [])) # (M,3) or empty
858
-
859
- print(f"Input point cloud shape: {pcloud.shape}")
860
- if gt_vertices.size:
861
- print(f"GT vertices shape: {gt_vertices.shape}")
862
-
863
- # Run prediction
864
- print("Running inference...")
865
- try:
866
- predicted_vertices = predict_vertices(
867
- model=model,
868
- point_cloud_14d=data['pcloud_14d'],
869
- grid_size=grid_size,
870
- device=device,
871
- voxel_size_metric=voxel_size,
872
- vertex_threshold=vertex_threshold
873
- )
874
- except Exception as e:
875
- print(f"Error during prediction for {file_path}: {e}")
876
- continue
877
-
878
- print(f"Predicted {len(predicted_vertices)} vertices")
879
-
880
- # --- Visualization ---
881
- plotter = pv.Plotter(window_size=[800,600])
882
- plotter.background_color = 'white'
883
-
884
- # Original point cloud in light gray
885
- if pcloud.size:
886
- pc_cloud = pv.PolyData(pcloud)
887
- plotter.add_mesh(pc_cloud, color='lightgray', point_size=2, render_points_as_spheres=True, label='Input PC')
888
-
889
- # Ground-truth vertices in red
890
- if gt_vertices.size:
891
- gt_pd = pv.PolyData(gt_vertices)
892
- plotter.add_mesh(gt_pd, color='red', point_size=8, render_points_as_spheres=True, label='GT Vertices')
893
-
894
- # Predicted vertices in blue
895
- if predicted_vertices.size:
896
- pred_pd = pv.PolyData(predicted_vertices)
897
- plotter.add_mesh(pred_pd, color='blue', point_size=8, render_points_as_spheres=True, label='Predicted Vertices')
898
-
899
- plotter.add_legend()
900
- plotter.show(title=os.path.basename(file_path))
901
-
902
- # Prepare output data
903
- output_data = {
904
- 'predicted_vertices': predicted_vertices,
905
- 'input_file': file_path,
906
- 'model_used': model_path,
907
- 'grid_size': grid_size,
908
- 'voxel_size': voxel_size,
909
- 'vertex_threshold': vertex_threshold,
910
- 'original_data': data
911
- }
912
-
913
- # Save results
914
- base_name = os.path.splitext(os.path.basename(file_path))[0]
915
- output_filename = f"{base_name}_predictions"
916
- try:
917
- save_data(output_data, output_filename)
918
- print(f"Results saved to: {output_filename}.pkl")
919
- except Exception as e:
920
- print(f"Error saving results for {file_path}: {e}")
921
-
922
- print(f"\nCompleted processing {len(data_files)} files")
923
-
924
- if __name__ == "__main__":
925
- inference = False
926
-
927
- data_folder_train = '/mnt/personal/skvrnjan/hoho_end/'
928
- #data_folder_train = '/home/skvrnjan/personal/hoho_end'
929
- num_epochs_train = 100
930
- batch_size_train = 4
931
- # This parameter now controls the ratio of negative to positive samples for BCE loss
932
- negative_to_positive_bce_ratio = 1
933
-
934
- if inference:
935
- run_inference(model_path='/home/skvrnjan/personal/hoho/model_epoch_100_grid128_smooth_bal1.pth',
936
- data_file_path=data_folder_train,
937
- output_file=None,
938
- grid_size=128,
939
- voxel_size=0.5,
940
- vertex_threshold=0.5
941
- )
942
- else:
943
- train_model(data_folder=data_folder_train,
944
- num_epochs=num_epochs_train,
945
- batch_size=batch_size_train,
946
- neg_pos_ratio_val=negative_to_positive_bce_ratio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fast_pointnet.py DELETED
@@ -1,520 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import numpy as np
6
- import pickle
7
- from torch.utils.data import Dataset, DataLoader
8
- from typing import List, Dict, Tuple, Optional
9
- import json
10
-
11
- class FastPointNet(nn.Module):
12
- """
13
- Fast PointNet implementation for 3D vertex prediction from point cloud patches.
14
- Takes 7D point clouds (x,y,z,r,g,b,filtered_flag) and predicts 3D vertex coordinates.
15
- Enhanced with deeper architecture and more parameters for better generalization.
16
- """
17
- def __init__(self, input_dim=7, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1):
18
- super(FastPointNet, self).__init__()
19
- self.max_points = max_points
20
- self.predict_score = predict_score
21
- self.predict_class = predict_class
22
- self.num_classes = num_classes
23
-
24
- # Enhanced point-wise MLPs with deeper architecture
25
- self.conv1 = nn.Conv1d(input_dim, 128, 1)
26
- self.conv2 = nn.Conv1d(128, 256, 1)
27
- self.conv3 = nn.Conv1d(256, 512, 1)
28
- self.conv4 = nn.Conv1d(512, 1024, 1)
29
-
30
- # Additional layers for better feature extraction
31
- self.conv5 = nn.Conv1d(1024, 1024, 1)
32
- self.conv6 = nn.Conv1d(1024, 2048, 1)
33
-
34
- # Larger shared features
35
- self.shared_fc1 = nn.Linear(2048, 1024)
36
- self.shared_fc2 = nn.Linear(1024, 512)
37
-
38
- # Enhanced position prediction head
39
- self.pos_fc1 = nn.Linear(512, 512)
40
- self.pos_fc2 = nn.Linear(512, 256)
41
- self.pos_fc3 = nn.Linear(256, 128)
42
- self.pos_fc4 = nn.Linear(128, output_dim)
43
-
44
- # Enhanced score prediction head
45
- if self.predict_score:
46
- self.score_fc1 = nn.Linear(512, 512)
47
- self.score_fc2 = nn.Linear(512, 256)
48
- self.score_fc3 = nn.Linear(256, 128)
49
- self.score_fc4 = nn.Linear(128, 64)
50
- self.score_fc5 = nn.Linear(64, 1)
51
-
52
- # Classification head
53
- if self.predict_class:
54
- self.class_fc1 = nn.Linear(512, 512)
55
- self.class_fc2 = nn.Linear(512, 256)
56
- self.class_fc3 = nn.Linear(256, 128)
57
- self.class_fc4 = nn.Linear(128, 64)
58
- self.class_fc5 = nn.Linear(64, num_classes)
59
-
60
- # Batch normalization layers
61
- self.bn1 = nn.BatchNorm1d(128)
62
- self.bn2 = nn.BatchNorm1d(256)
63
- self.bn3 = nn.BatchNorm1d(512)
64
- self.bn4 = nn.BatchNorm1d(1024)
65
- self.bn5 = nn.BatchNorm1d(1024)
66
- self.bn6 = nn.BatchNorm1d(2048)
67
-
68
- # Dropout with different rates
69
- self.dropout_light = nn.Dropout(0.2)
70
- self.dropout_medium = nn.Dropout(0.3)
71
- self.dropout_heavy = nn.Dropout(0.4)
72
-
73
- def forward(self, x):
74
- """
75
- Forward pass
76
- Args:
77
- x: (batch_size, input_dim, max_points) tensor
78
- Returns:
79
- Tuple containing predictions based on configuration:
80
- - position: (batch_size, output_dim) tensor of predicted 3D coordinates
81
- - score: (batch_size, 1) tensor of predicted distance to GT (if predict_score=True)
82
- - classification: (batch_size, num_classes) tensor of class logits (if predict_class=True)
83
- """
84
- batch_size = x.size(0)
85
-
86
- # Enhanced point-wise feature extraction with residual-like connections
87
- x1 = F.relu(self.bn1(self.conv1(x)))
88
- x2 = F.relu(self.bn2(self.conv2(x1)))
89
- x3 = F.relu(self.bn3(self.conv3(x2)))
90
- x4 = F.relu(self.bn4(self.conv4(x3)))
91
- x5 = F.relu(self.bn5(self.conv5(x4)))
92
- x6 = F.relu(self.bn6(self.conv6(x5)))
93
-
94
- # Global max pooling with additional global average pooling
95
- max_pool = torch.max(x6, 2)[0] # (batch_size, 2048)
96
- avg_pool = torch.mean(x6, 2) # (batch_size, 2048)
97
-
98
- # Combine max and average pooling for richer global features
99
- global_features = max_pool + avg_pool # (batch_size, 2048)
100
-
101
- # Enhanced shared features with residual connection
102
- shared1 = F.relu(self.shared_fc1(global_features))
103
- shared1 = self.dropout_light(shared1)
104
- shared2 = F.relu(self.shared_fc2(shared1))
105
- shared_features = self.dropout_medium(shared2)
106
-
107
- # Enhanced position prediction with skip connections
108
- pos1 = F.relu(self.pos_fc1(shared_features))
109
- pos1 = self.dropout_light(pos1)
110
- pos2 = F.relu(self.pos_fc2(pos1))
111
- pos2 = self.dropout_medium(pos2)
112
- pos3 = F.relu(self.pos_fc3(pos2))
113
- pos3 = self.dropout_light(pos3)
114
- position = self.pos_fc4(pos3)
115
-
116
- outputs = [position]
117
-
118
- if self.predict_score:
119
- # Enhanced score prediction
120
- score1 = F.relu(self.score_fc1(shared_features))
121
- score1 = self.dropout_light(score1)
122
- score2 = F.relu(self.score_fc2(score1))
123
- score2 = self.dropout_medium(score2)
124
- score3 = F.relu(self.score_fc3(score2))
125
- score3 = self.dropout_light(score3)
126
- score4 = F.relu(self.score_fc4(score3))
127
- score4 = self.dropout_light(score4)
128
- score = F.relu(self.score_fc5(score4)) # Ensure positive distance
129
- outputs.append(score)
130
-
131
- if self.predict_class:
132
- # Classification prediction
133
- class1 = F.relu(self.class_fc1(shared_features))
134
- class1 = self.dropout_light(class1)
135
- class2 = F.relu(self.class_fc2(class1))
136
- class2 = self.dropout_medium(class2)
137
- class3 = F.relu(self.class_fc3(class2))
138
- class3 = self.dropout_light(class3)
139
- class4 = F.relu(self.class_fc4(class3))
140
- class4 = self.dropout_light(class4)
141
- classification = self.class_fc5(class4) # Raw logits
142
- outputs.append(classification)
143
-
144
- # Return outputs based on configuration
145
- if len(outputs) == 1:
146
- return outputs[0] # Only position
147
- elif len(outputs) == 2:
148
- if self.predict_score:
149
- return outputs[0], outputs[1] # position, score
150
- else:
151
- return outputs[0], outputs[1] # position, classification
152
- else:
153
- return outputs[0], outputs[1], outputs[2] # position, score, classification
154
-
155
- class PatchDataset(Dataset):
156
- """
157
- Dataset class for loading saved patches for PointNet training.
158
- """
159
-
160
- def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
161
- self.dataset_dir = dataset_dir
162
- self.max_points = max_points
163
- self.augment = augment
164
-
165
- # Load patch files
166
- self.patch_files = []
167
- for file in os.listdir(dataset_dir):
168
- if file.endswith('.pkl'):
169
- self.patch_files.append(os.path.join(dataset_dir, file))
170
-
171
- print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
172
-
173
- def __len__(self):
174
- return len(self.patch_files)
175
-
176
- def __getitem__(self, idx):
177
- """
178
- Load and process a patch for training.
179
- Returns:
180
- patch_data: (7, max_points) tensor of point cloud data
181
- target: (3,) tensor of target 3D coordinates
182
- valid_mask: (max_points,) boolean tensor indicating valid points
183
- distance_to_gt: scalar tensor of distance from initial prediction to GT
184
- classification: scalar tensor for binary classification (1 if GT vertex present, 0 if not)
185
- """
186
- patch_file = self.patch_files[idx]
187
-
188
- with open(patch_file, 'rb') as f:
189
- patch_info = pickle.load(f)
190
-
191
- patch_7d = patch_info['patch_7d'] # (N, 7)
192
- target = patch_info.get('assigned_wf_vertex', None) # (3,) or None
193
- initial_pred = patch_info.get('cluster_center', None) # (3,) or None
194
-
195
- # Determine classification label based on GT vertex presence
196
- has_gt_vertex = 1.0 if target is not None else 0.0
197
-
198
- # Handle patches without ground truth
199
- if target is None:
200
- # Use a dummy target for consistency, but mark as invalid with classification
201
- target = np.zeros(3)
202
- else:
203
- target = np.array(target)
204
-
205
- # Pad or sample points to max_points
206
- num_points = patch_7d.shape[0]
207
-
208
- if num_points >= self.max_points:
209
- # Randomly sample max_points
210
- indices = np.random.choice(num_points, self.max_points, replace=False)
211
- patch_sampled = patch_7d[indices]
212
- valid_mask = np.ones(self.max_points, dtype=bool)
213
- else:
214
- # Pad with zeros
215
- patch_sampled = np.zeros((self.max_points, 7))
216
- patch_sampled[:num_points] = patch_7d
217
- valid_mask = np.zeros(self.max_points, dtype=bool)
218
- valid_mask[:num_points] = True
219
-
220
- # Data augmentation (only if GT vertex is present)
221
- if self.augment and has_gt_vertex > 0:
222
- patch_sampled, target = self._augment_patch(patch_sampled, valid_mask, target)
223
-
224
- # Convert to tensors and transpose for conv1d (channels first)
225
- patch_tensor = torch.from_numpy(patch_sampled.T).float() # (7, max_points)
226
- target_tensor = torch.from_numpy(target).float() # (3,)
227
- valid_mask_tensor = torch.from_numpy(valid_mask)
228
-
229
- # Handle initial_pred
230
- if initial_pred is not None:
231
- initial_pred_tensor = torch.from_numpy(initial_pred).float()
232
- else:
233
- initial_pred_tensor = torch.zeros(3).float()
234
-
235
- # Classification tensor
236
- classification_tensor = torch.tensor(has_gt_vertex).float()
237
-
238
- return patch_tensor, target_tensor, valid_mask_tensor, initial_pred_tensor, classification_tensor
239
-
240
- def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
241
- """
242
- Save patches from prediction pipeline to create a training dataset.
243
-
244
- Args:
245
- patches: List of patch dictionaries from generate_patches()
246
- dataset_dir: Directory to save the dataset
247
- entry_id: Unique identifier for this entry/image
248
- """
249
- os.makedirs(dataset_dir, exist_ok=True)
250
-
251
- for i, patch in enumerate(patches):
252
- # Create unique filename
253
- filename = f"{entry_id}_patch_{i}.pkl"
254
- filepath = os.path.join(dataset_dir, filename)
255
-
256
- # Skip if file already exists
257
- if os.path.exists(filepath):
258
- continue
259
-
260
- # Save patch data
261
- with open(filepath, 'wb') as f:
262
- pickle.dump(patch, f)
263
-
264
- print(f"Saved {len(patches)} patches for entry {entry_id}")
265
-
266
- # Create dataloader with custom collate function to filter invalid samples
267
- def collate_fn(batch):
268
- valid_batch = []
269
- for patch_data, target, valid_mask, initial_pred, classification in batch:
270
- # Filter out invalid samples (no valid points)
271
- if valid_mask.sum() > 0:
272
- valid_batch.append((patch_data, target, valid_mask, initial_pred, classification))
273
-
274
- if len(valid_batch) == 0:
275
- return None
276
-
277
- # Stack valid samples
278
- patch_data = torch.stack([item[0] for item in valid_batch])
279
- targets = torch.stack([item[1] for item in valid_batch])
280
- valid_masks = torch.stack([item[2] for item in valid_batch])
281
- initial_preds = torch.stack([item[3] for item in valid_batch])
282
- classifications = torch.stack([item[4] for item in valid_batch])
283
-
284
- return patch_data, targets, valid_masks, initial_preds, classifications
285
-
286
- # Initialize weights using Xavier/Glorot initialization
287
- def init_weights(m):
288
- if isinstance(m, nn.Conv1d):
289
- nn.init.xavier_uniform_(m.weight)
290
- if m.bias is not None:
291
- nn.init.zeros_(m.bias)
292
- elif isinstance(m, nn.Linear):
293
- nn.init.xavier_uniform_(m.weight)
294
- if m.bias is not None:
295
- nn.init.zeros_(m.bias)
296
- elif isinstance(m, nn.BatchNorm1d):
297
- nn.init.ones_(m.weight)
298
- nn.init.zeros_(m.bias)
299
-
300
- def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
301
- score_weight: float = 0.1, class_weight: float = 0.5):
302
- """
303
- Train the FastPointNet model on saved patches.
304
-
305
- Args:
306
- dataset_dir: Directory containing saved patch files
307
- model_save_path: Path to save the trained model
308
- epochs: Number of training epochs
309
- batch_size: Training batch size
310
- lr: Learning rate
311
- score_weight: Weight for the distance prediction loss
312
- class_weight: Weight for the classification loss
313
- """
314
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
315
- print(f"Training on device: {device}")
316
-
317
- # Create dataset and dataloader
318
- dataset = PatchDataset(dataset_dir, max_points=1024, augment=False)
319
- print(f"Dataset loaded with {len(dataset)} samples")
320
-
321
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
322
- collate_fn=collate_fn, drop_last=True)
323
-
324
- # Initialize model with score and classification prediction
325
- model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1)
326
-
327
- model.apply(init_weights)
328
- model.to(device)
329
-
330
- # Loss functions
331
- position_criterion = nn.MSELoss()
332
- score_criterion = nn.MSELoss()
333
- classification_criterion = nn.BCEWithLogitsLoss()
334
-
335
- optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
336
- scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
337
-
338
- # Training loop
339
- model.train()
340
- for epoch in range(epochs):
341
- total_loss = 0.0
342
- total_pos_loss = 0.0
343
- total_score_loss = 0.0
344
- total_class_loss = 0.0
345
- num_batches = 0
346
-
347
- for batch_idx, batch_data in enumerate(dataloader):
348
- if batch_data is None: # Skip invalid batches
349
- continue
350
-
351
- patch_data, targets, valid_masks, initial_preds, classifications = batch_data
352
- patch_data = patch_data.to(device) # (batch_size, 7, max_points)
353
- targets = targets.to(device) # (batch_size, 3)
354
- classifications = classifications.to(device) # (batch_size,)
355
-
356
- # Forward pass
357
- optimizer.zero_grad()
358
- predictions, predicted_scores, predicted_classes = model(patch_data)
359
-
360
- # Compute actual distance from predictions to targets
361
- actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True)
362
-
363
- # Only compute position and score losses for samples with GT vertices
364
- has_gt_mask = classifications > 0.5
365
-
366
- if has_gt_mask.sum() > 0:
367
- # Position loss only for samples with GT vertices
368
- pos_loss = position_criterion(predictions[has_gt_mask], targets[has_gt_mask])
369
- score_loss = score_criterion(predicted_scores[has_gt_mask], actual_distances[has_gt_mask])
370
- else:
371
- pos_loss = torch.tensor(0.0, device=device)
372
- score_loss = torch.tensor(0.0, device=device)
373
-
374
- # Classification loss for all samples
375
- class_loss = classification_criterion(predicted_classes.squeeze(), classifications)
376
-
377
- # Combined loss
378
- total_batch_loss = pos_loss + score_weight * score_loss + class_weight * class_loss
379
-
380
- # Backward pass
381
- total_batch_loss.backward()
382
- optimizer.step()
383
-
384
- total_loss += total_batch_loss.item()
385
- total_pos_loss += pos_loss.item()
386
- total_score_loss += score_loss.item()
387
- total_class_loss += class_loss.item()
388
- num_batches += 1
389
-
390
- if batch_idx % 50 == 0:
391
- print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
392
- f"Total Loss: {total_batch_loss.item():.6f}, "
393
- f"Pos Loss: {pos_loss.item():.6f}, "
394
- f"Score Loss: {score_loss.item():.6f}, "
395
- f"Class Loss: {class_loss.item():.6f}")
396
-
397
- avg_loss = total_loss / num_batches if num_batches > 0 else 0
398
- avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
399
- avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
400
- avg_class_loss = total_class_loss / num_batches if num_batches > 0 else 0
401
-
402
- print(f"Epoch {epoch+1}/{epochs} completed, "
403
- f"Avg Total Loss: {avg_loss:.6f}, "
404
- f"Avg Pos Loss: {avg_pos_loss:.6f}, "
405
- f"Avg Score Loss: {avg_score_loss:.6f}, "
406
- f"Avg Class Loss: {avg_class_loss:.6f}")
407
-
408
- scheduler.step()
409
-
410
- # Save model checkpoint every epoch
411
- checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
412
- torch.save({
413
- 'model_state_dict': model.state_dict(),
414
- 'optimizer_state_dict': optimizer.state_dict(),
415
- 'epoch': epoch + 1,
416
- 'loss': avg_loss,
417
- }, checkpoint_path)
418
-
419
- # Save the trained model
420
- torch.save({
421
- 'model_state_dict': model.state_dict(),
422
- 'optimizer_state_dict': optimizer.state_dict(),
423
- 'epoch': epochs,
424
- }, model_save_path)
425
-
426
- print(f"Model saved to {model_save_path}")
427
- return model
428
-
429
- def load_pointnet_model(model_path: str, device: torch.device = None, predict_score: bool = True) -> FastPointNet:
430
- """
431
- Load a trained FastPointNet model.
432
-
433
- Args:
434
- model_path: Path to the saved model
435
- device: Device to load the model on
436
- predict_score: Whether the model predicts scores
437
-
438
- Returns:
439
- Loaded FastPointNet model
440
- """
441
- if device is None:
442
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
443
-
444
- model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=predict_score)
445
-
446
- checkpoint = torch.load(model_path, map_location=device)
447
- model.load_state_dict(checkpoint['model_state_dict'])
448
-
449
- model.to(device)
450
- model.eval()
451
-
452
- return model
453
-
454
- def predict_vertex_from_patch(model: FastPointNet, patch: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float, float]:
455
- """
456
- Predict 3D vertex coordinates, confidence score, and classification from a patch using trained PointNet.
457
-
458
- Args:
459
- model: Trained FastPointNet model
460
- patch: Dictionary containing patch data with 'patch_7d' and 'offset' keys
461
- device: Device to run prediction on
462
-
463
- Returns:
464
- tuple of (predicted_coordinates, confidence_score, classification_score)
465
- predicted_coordinates: (3,) numpy array of predicted 3D coordinates
466
- confidence_score: float representing predicted distance to GT (lower is better)
467
- classification_score: float representing probability of GT vertex presence (0-1)
468
- """
469
- if device is None:
470
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
471
-
472
- patch_7d = patch['patch_7d'] # (N, 7)
473
-
474
- # Prepare input
475
- max_points = 1024
476
- num_points = patch_7d.shape[0]
477
-
478
- if num_points >= max_points:
479
- # Sample points
480
- indices = np.random.choice(num_points, max_points, replace=False)
481
- patch_sampled = patch_7d[indices]
482
- else:
483
- # Pad with zeros
484
- patch_sampled = np.zeros((max_points, 7))
485
- patch_sampled[:num_points] = patch_7d
486
-
487
- # Convert to tensor
488
- patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 7, max_points)
489
- patch_tensor = patch_tensor.to(device)
490
-
491
- # Predict
492
- with torch.no_grad():
493
- outputs = model(patch_tensor)
494
-
495
- if model.predict_score and model.predict_class:
496
- position, score, classification = outputs
497
- position = position.cpu().numpy().squeeze()
498
- score = score.cpu().numpy().squeeze()
499
- classification = torch.sigmoid(classification).cpu().numpy().squeeze() # Apply sigmoid for probability
500
- elif model.predict_score:
501
- position, score = outputs
502
- position = position.cpu().numpy().squeeze()
503
- score = score.cpu().numpy().squeeze()
504
- classification = None
505
- elif model.predict_class:
506
- position, classification = outputs
507
- position = position.cpu().numpy().squeeze()
508
- score = None
509
- classification = torch.sigmoid(classification).cpu().numpy().squeeze() # Apply sigmoid for probability
510
- else:
511
- position = outputs
512
- position = position.cpu().numpy().squeeze()
513
- score = None
514
- classification = None
515
-
516
- # Apply offset correction
517
- offset = patch['cluster_center']
518
- position += offset
519
-
520
- return position, score, classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fast_pointnet_class.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import os
2
  import torch
3
  import torch.nn as nn
@@ -403,3 +409,4 @@ def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device:
403
  predicted_class = int(probability > 0.5)
404
 
405
  return predicted_class, probability
 
 
1
+ # This file defines a PointNet-based model for binary classification of 6D point cloud patches.
2
+ # It includes the model architecture (ClassificationPointNet), a custom dataset class
3
+ # (PatchClassificationDataset) for loading and augmenting patches, functions for saving
4
+ # patches to create a dataset, a training loop (train_pointnet), a function to load
5
+ # a trained model (load_pointnet_model), and a function for predicting class labels
6
+ # from new patches (predict_class_from_patch).
7
  import os
8
  import torch
9
  import torch.nn as nn
 
409
  predicted_class = int(probability > 0.5)
410
 
411
  return predicted_class, probability
412
+
fast_pointnet_class_10d.py DELETED
@@ -1,405 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import numpy as np
6
- import pickle
7
- from torch.utils.data import Dataset, DataLoader
8
- from typing import List, Dict, Tuple, Optional
9
- import json
10
-
11
- class ClassificationPointNet(nn.Module):
12
- """
13
- PointNet implementation for binary classification from 10D point cloud patches.
14
- Takes 10D point clouds and predicts binary classification (edge/not edge).
15
- """
16
- def __init__(self, input_dim=10, max_points=1024):
17
- super(ClassificationPointNet, self).__init__()
18
- self.max_points = max_points
19
-
20
- # Point-wise MLPs for feature extraction (deeper network)
21
- self.conv1 = nn.Conv1d(input_dim, 64, 1)
22
- self.conv2 = nn.Conv1d(64, 128, 1)
23
- self.conv3 = nn.Conv1d(128, 256, 1)
24
- self.conv4 = nn.Conv1d(256, 512, 1)
25
- self.conv5 = nn.Conv1d(512, 1024, 1)
26
- self.conv6 = nn.Conv1d(1024, 2048, 1) # Additional layer
27
-
28
- # Classification head (deeper with more capacity)
29
- self.fc1 = nn.Linear(2048, 1024)
30
- self.fc2 = nn.Linear(1024, 512)
31
- self.fc3 = nn.Linear(512, 256)
32
- self.fc4 = nn.Linear(256, 128)
33
- self.fc5 = nn.Linear(128, 64)
34
- self.fc6 = nn.Linear(64, 1) # Single output for binary classification
35
-
36
- # Batch normalization layers
37
- self.bn1 = nn.BatchNorm1d(64)
38
- self.bn2 = nn.BatchNorm1d(128)
39
- self.bn3 = nn.BatchNorm1d(256)
40
- self.bn4 = nn.BatchNorm1d(512)
41
- self.bn5 = nn.BatchNorm1d(1024)
42
- self.bn6 = nn.BatchNorm1d(2048)
43
-
44
- # Dropout layers
45
- self.dropout1 = nn.Dropout(0.3)
46
- self.dropout2 = nn.Dropout(0.4)
47
- self.dropout3 = nn.Dropout(0.5)
48
- self.dropout4 = nn.Dropout(0.4)
49
- self.dropout5 = nn.Dropout(0.3)
50
-
51
- def forward(self, x):
52
- """
53
- Forward pass
54
- Args:
55
- x: (batch_size, input_dim, max_points) tensor
56
- Returns:
57
- classification: (batch_size, 1) tensor of logits (sigmoid for probability)
58
- """
59
- batch_size = x.size(0)
60
-
61
- # Point-wise feature extraction
62
- x1 = F.relu(self.bn1(self.conv1(x)))
63
- x2 = F.relu(self.bn2(self.conv2(x1)))
64
- x3 = F.relu(self.bn3(self.conv3(x2)))
65
- x4 = F.relu(self.bn4(self.conv4(x3)))
66
- x5 = F.relu(self.bn5(self.conv5(x4)))
67
- x6 = F.relu(self.bn6(self.conv6(x5)))
68
-
69
- # Global max pooling
70
- global_features = torch.max(x6, 2)[0] # (batch_size, 2048)
71
-
72
- # Classification head
73
- x = F.relu(self.fc1(global_features))
74
- x = self.dropout1(x)
75
- x = F.relu(self.fc2(x))
76
- x = self.dropout2(x)
77
- x = F.relu(self.fc3(x))
78
- x = self.dropout3(x)
79
- x = F.relu(self.fc4(x))
80
- x = self.dropout4(x)
81
- x = F.relu(self.fc5(x))
82
- x = self.dropout5(x)
83
- classification = self.fc6(x) # (batch_size, 1)
84
-
85
- return classification
86
-
87
- class PatchClassificationDataset(Dataset):
88
- """
89
- Dataset class for loading saved patches for PointNet classification training.
90
- """
91
-
92
- def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
93
- self.dataset_dir = dataset_dir
94
- self.max_points = max_points
95
- self.augment = augment
96
-
97
- # Load patch files
98
- self.patch_files = []
99
- for file in os.listdir(dataset_dir):
100
- if file.endswith('.pkl'):
101
- self.patch_files.append(os.path.join(dataset_dir, file))
102
-
103
- print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
104
-
105
- def __len__(self):
106
- return len(self.patch_files)
107
-
108
- def __getitem__(self, idx):
109
- """
110
- Load and process a patch for training.
111
- Returns:
112
- patch_data: (10, max_points) tensor of point cloud data
113
- label: scalar tensor for binary classification (0 or 1)
114
- valid_mask: (max_points,) boolean tensor indicating valid points
115
- """
116
- patch_file = self.patch_files[idx]
117
-
118
- with open(patch_file, 'rb') as f:
119
- patch_info = pickle.load(f)
120
-
121
- patch_10d = patch_info['patch_10d'] # (N, 10)
122
- label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
123
-
124
- # Pad or sample points to max_points
125
- num_points = patch_10d.shape[0]
126
-
127
- if num_points >= self.max_points:
128
- # Randomly sample max_points
129
- indices = np.random.choice(num_points, self.max_points, replace=False)
130
- patch_sampled = patch_10d[indices]
131
- valid_mask = np.ones(self.max_points, dtype=bool)
132
- else:
133
- # Pad with zeros
134
- patch_sampled = np.zeros((self.max_points, 10))
135
- patch_sampled[:num_points] = patch_10d
136
- valid_mask = np.zeros(self.max_points, dtype=bool)
137
- valid_mask[:num_points] = True
138
-
139
- # Data augmentation
140
- if self.augment:
141
- patch_sampled = self._augment_patch(patch_sampled, valid_mask)
142
-
143
- # Convert to tensors and transpose for conv1d (channels first)
144
- patch_tensor = torch.from_numpy(patch_sampled.T).float() # (10, max_points)
145
- label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
146
- valid_mask_tensor = torch.from_numpy(valid_mask)
147
-
148
- return patch_tensor, label_tensor, valid_mask_tensor
149
-
150
- def _augment_patch(self, patch, valid_mask):
151
- """
152
- Apply data augmentation to the patch.
153
- """
154
- valid_points = patch[valid_mask]
155
-
156
- if len(valid_points) == 0:
157
- return patch
158
-
159
- # Random rotation around z-axis (only for xyz coordinates, first 3 dimensions)
160
- angle = np.random.uniform(0, 2 * np.pi)
161
- cos_angle = np.cos(angle)
162
- sin_angle = np.sin(angle)
163
- rotation_matrix = np.array([
164
- [cos_angle, -sin_angle, 0],
165
- [sin_angle, cos_angle, 0],
166
- [0, 0, 1]
167
- ])
168
-
169
- # Apply rotation to xyz coordinates (first 3 dimensions)
170
- valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
171
-
172
- # Random jittering (only for xyz coordinates)
173
- noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
174
- valid_points[:, :3] += noise
175
-
176
- # Random scaling (only for xyz coordinates)
177
- scale = np.random.uniform(0.9, 1.1)
178
- valid_points[:, :3] *= scale
179
-
180
- patch[valid_mask] = valid_points
181
- return patch
182
-
183
- def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
184
- """
185
- Save patches from prediction pipeline to create a training dataset.
186
-
187
- Args:
188
- patches: List of patch dictionaries from generate_patches()
189
- dataset_dir: Directory to save the dataset
190
- entry_id: Unique identifier for this entry/image
191
- """
192
- os.makedirs(dataset_dir, exist_ok=True)
193
-
194
- for i, patch in enumerate(patches):
195
- # Create unique filename
196
- filename = f"{entry_id}_patch_{i}.pkl"
197
- filepath = os.path.join(dataset_dir, filename)
198
-
199
- # Skip if file already exists
200
- if os.path.exists(filepath):
201
- continue
202
-
203
- # Save patch data
204
- with open(filepath, 'wb') as f:
205
- pickle.dump(patch, f)
206
-
207
- print(f"Saved {len(patches)} patches for entry {entry_id}")
208
-
209
- # Create dataloader with custom collate function to filter invalid samples
210
- def collate_fn(batch):
211
- valid_batch = []
212
- for patch_data, label, valid_mask in batch:
213
- # Filter out invalid samples (no valid points)
214
- if valid_mask.sum() > 0:
215
- valid_batch.append((patch_data, label, valid_mask))
216
-
217
- if len(valid_batch) == 0:
218
- return None
219
-
220
- # Stack valid samples
221
- patch_data = torch.stack([item[0] for item in valid_batch])
222
- labels = torch.stack([item[1] for item in valid_batch])
223
- valid_masks = torch.stack([item[2] for item in valid_batch])
224
-
225
- return patch_data, labels, valid_masks
226
-
227
- # Initialize weights using Xavier/Glorot initialization
228
- def init_weights(m):
229
- if isinstance(m, nn.Conv1d):
230
- nn.init.xavier_uniform_(m.weight)
231
- if m.bias is not None:
232
- nn.init.zeros_(m.bias)
233
- elif isinstance(m, nn.Linear):
234
- nn.init.xavier_uniform_(m.weight)
235
- if m.bias is not None:
236
- nn.init.zeros_(m.bias)
237
- elif isinstance(m, nn.BatchNorm1d):
238
- nn.init.ones_(m.weight)
239
- nn.init.zeros_(m.bias)
240
-
241
- def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
242
- lr: float = 0.001):
243
- """
244
- Train the ClassificationPointNet model on saved patches.
245
-
246
- Args:
247
- dataset_dir: Directory containing saved patch files
248
- model_save_path: Path to save the trained model
249
- epochs: Number of training epochs
250
- batch_size: Training batch size
251
- lr: Learning rate
252
- """
253
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
254
- print(f"Training on device: {device}")
255
-
256
- # Create dataset and dataloader
257
- dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
258
- print(f"Dataset loaded with {len(dataset)} samples")
259
-
260
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
261
- collate_fn=collate_fn, drop_last=True)
262
-
263
- # Initialize model
264
- model = ClassificationPointNet(input_dim=10, max_points=1024)
265
- model.apply(init_weights)
266
- model.to(device)
267
-
268
- # Loss function and optimizer (BCE for binary classification)
269
- criterion = nn.BCEWithLogitsLoss()
270
- optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
271
- scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
272
-
273
- # Training loop
274
- model.train()
275
- for epoch in range(epochs):
276
- total_loss = 0.0
277
- correct = 0
278
- total = 0
279
- num_batches = 0
280
-
281
- for batch_idx, batch_data in enumerate(dataloader):
282
- if batch_data is None: # Skip invalid batches
283
- continue
284
-
285
- patch_data, labels, valid_masks = batch_data
286
- patch_data = patch_data.to(device) # (batch_size, 10, max_points)
287
- labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
288
-
289
- # Forward pass
290
- optimizer.zero_grad()
291
- outputs = model(patch_data) # (batch_size, 1)
292
- loss = criterion(outputs, labels)
293
-
294
- # Backward pass
295
- loss.backward()
296
- optimizer.step()
297
-
298
- # Statistics
299
- total_loss += loss.item()
300
- predicted = (torch.sigmoid(outputs) > 0.5).float()
301
- total += labels.size(0)
302
- correct += (predicted == labels).sum().item()
303
- num_batches += 1
304
-
305
- if batch_idx % 50 == 0:
306
- print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
307
- f"Loss: {loss.item():.6f}, "
308
- f"Accuracy: {100 * correct / total:.2f}%")
309
-
310
- avg_loss = total_loss / num_batches if num_batches > 0 else 0
311
- accuracy = 100 * correct / total if total > 0 else 0
312
-
313
- print(f"Epoch {epoch+1}/{epochs} completed, "
314
- f"Avg Loss: {avg_loss:.6f}, "
315
- f"Accuracy: {accuracy:.2f}%")
316
-
317
- scheduler.step()
318
-
319
- # Save model checkpoint every epoch
320
- checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
321
- torch.save({
322
- 'model_state_dict': model.state_dict(),
323
- 'optimizer_state_dict': optimizer.state_dict(),
324
- 'epoch': epoch + 1,
325
- 'loss': avg_loss,
326
- 'accuracy': accuracy,
327
- }, checkpoint_path)
328
-
329
- # Save the trained model
330
- torch.save({
331
- 'model_state_dict': model.state_dict(),
332
- 'optimizer_state_dict': optimizer.state_dict(),
333
- 'epoch': epochs,
334
- }, model_save_path)
335
-
336
- print(f"Model saved to {model_save_path}")
337
- return model
338
-
339
- def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
340
- """
341
- Load a trained ClassificationPointNet model.
342
-
343
- Args:
344
- model_path: Path to the saved model
345
- device: Device to load the model on
346
-
347
- Returns:
348
- Loaded ClassificationPointNet model
349
- """
350
- if device is None:
351
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
352
-
353
- model = ClassificationPointNet(input_dim=10, max_points=1024)
354
-
355
- checkpoint = torch.load(model_path, map_location=device)
356
- model.load_state_dict(checkpoint['model_state_dict'])
357
-
358
- model.to(device)
359
- model.eval()
360
-
361
- return model
362
-
363
- def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
364
- """
365
- Predict binary classification from a patch using trained PointNet.
366
-
367
- Args:
368
- model: Trained ClassificationPointNet model
369
- patch: Dictionary containing patch data with 'patch_10d' key
370
- device: Device to run prediction on
371
-
372
- Returns:
373
- tuple of (predicted_class, confidence)
374
- predicted_class: int (0 for not edge, 1 for edge)
375
- confidence: float representing confidence score (0-1)
376
- """
377
- if device is None:
378
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
379
-
380
- patch_10d = patch['patch_10d'] # (N, 10)
381
-
382
- # Prepare input
383
- max_points = 1024
384
- num_points = patch_10d.shape[0]
385
-
386
- if num_points >= max_points:
387
- # Sample points
388
- indices = np.random.choice(num_points, max_points, replace=False)
389
- patch_sampled = patch_10d[indices]
390
- else:
391
- # Pad with zeros
392
- patch_sampled = np.zeros((max_points, 10))
393
- patch_sampled[:num_points] = patch_10d
394
-
395
- # Convert to tensor
396
- patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 10, max_points)
397
- patch_tensor = patch_tensor.to(device)
398
-
399
- # Predict
400
- with torch.no_grad():
401
- outputs = model(patch_tensor) # (1, 1)
402
- probability = torch.sigmoid(outputs).item()
403
- predicted_class = int(probability > 0.5)
404
-
405
- return predicted_class, probability
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fast_pointnet_class_10d_2048.py DELETED
@@ -1,405 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import numpy as np
6
- import pickle
7
- from torch.utils.data import Dataset, DataLoader
8
- from typing import List, Dict, Tuple, Optional
9
- import json
10
-
11
- class ClassificationPointNet(nn.Module):
12
- """
13
- PointNet implementation for binary classification from 10D point cloud patches.
14
- Takes 10D point clouds and predicts binary classification (edge/not edge).
15
- """
16
- def __init__(self, input_dim=10, max_points=2048):
17
- super(ClassificationPointNet, self).__init__()
18
- self.max_points = max_points
19
-
20
- # Point-wise MLPs for feature extraction (deeper network)
21
- self.conv1 = nn.Conv1d(input_dim, 64, 1)
22
- self.conv2 = nn.Conv1d(64, 128, 1)
23
- self.conv3 = nn.Conv1d(128, 256, 1)
24
- self.conv4 = nn.Conv1d(256, 512, 1)
25
- self.conv5 = nn.Conv1d(512, 1024, 1)
26
- self.conv6 = nn.Conv1d(1024, 2048, 1) # Additional layer
27
-
28
- # Classification head (deeper with more capacity)
29
- self.fc1 = nn.Linear(2048, 1024)
30
- self.fc2 = nn.Linear(1024, 512)
31
- self.fc3 = nn.Linear(512, 256)
32
- self.fc4 = nn.Linear(256, 128)
33
- self.fc5 = nn.Linear(128, 64)
34
- self.fc6 = nn.Linear(64, 1) # Single output for binary classification
35
-
36
- # Batch normalization layers
37
- self.bn1 = nn.BatchNorm1d(64)
38
- self.bn2 = nn.BatchNorm1d(128)
39
- self.bn3 = nn.BatchNorm1d(256)
40
- self.bn4 = nn.BatchNorm1d(512)
41
- self.bn5 = nn.BatchNorm1d(1024)
42
- self.bn6 = nn.BatchNorm1d(2048)
43
-
44
- # Dropout layers
45
- self.dropout1 = nn.Dropout(0.3)
46
- self.dropout2 = nn.Dropout(0.4)
47
- self.dropout3 = nn.Dropout(0.5)
48
- self.dropout4 = nn.Dropout(0.4)
49
- self.dropout5 = nn.Dropout(0.3)
50
-
51
- def forward(self, x):
52
- """
53
- Forward pass
54
- Args:
55
- x: (batch_size, input_dim, max_points) tensor
56
- Returns:
57
- classification: (batch_size, 1) tensor of logits (sigmoid for probability)
58
- """
59
- batch_size = x.size(0)
60
-
61
- # Point-wise feature extraction
62
- x1 = F.relu(self.bn1(self.conv1(x)))
63
- x2 = F.relu(self.bn2(self.conv2(x1)))
64
- x3 = F.relu(self.bn3(self.conv3(x2)))
65
- x4 = F.relu(self.bn4(self.conv4(x3)))
66
- x5 = F.relu(self.bn5(self.conv5(x4)))
67
- x6 = F.relu(self.bn6(self.conv6(x5)))
68
-
69
- # Global max pooling
70
- global_features = torch.max(x6, 2)[0] # (batch_size, 2048)
71
-
72
- # Classification head
73
- x = F.relu(self.fc1(global_features))
74
- x = self.dropout1(x)
75
- x = F.relu(self.fc2(x))
76
- x = self.dropout2(x)
77
- x = F.relu(self.fc3(x))
78
- x = self.dropout3(x)
79
- x = F.relu(self.fc4(x))
80
- x = self.dropout4(x)
81
- x = F.relu(self.fc5(x))
82
- x = self.dropout5(x)
83
- classification = self.fc6(x) # (batch_size, 1)
84
-
85
- return classification
86
-
87
- class PatchClassificationDataset(Dataset):
88
- """
89
- Dataset class for loading saved patches for PointNet classification training.
90
- """
91
-
92
- def __init__(self, dataset_dir: str, max_points: int = 2048, augment: bool = True):
93
- self.dataset_dir = dataset_dir
94
- self.max_points = max_points
95
- self.augment = augment
96
-
97
- # Load patch files
98
- self.patch_files = []
99
- for file in os.listdir(dataset_dir):
100
- if file.endswith('.pkl'):
101
- self.patch_files.append(os.path.join(dataset_dir, file))
102
-
103
- print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
104
-
105
- def __len__(self):
106
- return len(self.patch_files)
107
-
108
- def __getitem__(self, idx):
109
- """
110
- Load and process a patch for training.
111
- Returns:
112
- patch_data: (10, max_points) tensor of point cloud data
113
- label: scalar tensor for binary classification (0 or 1)
114
- valid_mask: (max_points,) boolean tensor indicating valid points
115
- """
116
- patch_file = self.patch_files[idx]
117
-
118
- with open(patch_file, 'rb') as f:
119
- patch_info = pickle.load(f)
120
-
121
- patch_10d = patch_info['patch_10d'] # (N, 10)
122
- label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
123
-
124
- # Pad or sample points to max_points
125
- num_points = patch_10d.shape[0]
126
-
127
- if num_points >= self.max_points:
128
- # Randomly sample max_points
129
- indices = np.random.choice(num_points, self.max_points, replace=False)
130
- patch_sampled = patch_10d[indices]
131
- valid_mask = np.ones(self.max_points, dtype=bool)
132
- else:
133
- # Pad with zeros
134
- patch_sampled = np.zeros((self.max_points, 10))
135
- patch_sampled[:num_points] = patch_10d
136
- valid_mask = np.zeros(self.max_points, dtype=bool)
137
- valid_mask[:num_points] = True
138
-
139
- # Data augmentation
140
- if self.augment:
141
- patch_sampled = self._augment_patch(patch_sampled, valid_mask)
142
-
143
- # Convert to tensors and transpose for conv1d (channels first)
144
- patch_tensor = torch.from_numpy(patch_sampled.T).float() # (10, max_points)
145
- label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
146
- valid_mask_tensor = torch.from_numpy(valid_mask)
147
-
148
- return patch_tensor, label_tensor, valid_mask_tensor
149
-
150
- def _augment_patch(self, patch, valid_mask):
151
- """
152
- Apply data augmentation to the patch.
153
- """
154
- valid_points = patch[valid_mask]
155
-
156
- if len(valid_points) == 0:
157
- return patch
158
-
159
- # Random rotation around z-axis (only for xyz coordinates, first 3 dimensions)
160
- angle = np.random.uniform(0, 2 * np.pi)
161
- cos_angle = np.cos(angle)
162
- sin_angle = np.sin(angle)
163
- rotation_matrix = np.array([
164
- [cos_angle, -sin_angle, 0],
165
- [sin_angle, cos_angle, 0],
166
- [0, 0, 1]
167
- ])
168
-
169
- # Apply rotation to xyz coordinates (first 3 dimensions)
170
- valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
171
-
172
- # Random jittering (only for xyz coordinates)
173
- noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
174
- valid_points[:, :3] += noise
175
-
176
- # Random scaling (only for xyz coordinates)
177
- scale = np.random.uniform(0.9, 1.1)
178
- valid_points[:, :3] *= scale
179
-
180
- patch[valid_mask] = valid_points
181
- return patch
182
-
183
- def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
184
- """
185
- Save patches from prediction pipeline to create a training dataset.
186
-
187
- Args:
188
- patches: List of patch dictionaries from generate_patches()
189
- dataset_dir: Directory to save the dataset
190
- entry_id: Unique identifier for this entry/image
191
- """
192
- os.makedirs(dataset_dir, exist_ok=True)
193
-
194
- for i, patch in enumerate(patches):
195
- # Create unique filename
196
- filename = f"{entry_id}_patch_{i}.pkl"
197
- filepath = os.path.join(dataset_dir, filename)
198
-
199
- # Skip if file already exists
200
- if os.path.exists(filepath):
201
- continue
202
-
203
- # Save patch data
204
- with open(filepath, 'wb') as f:
205
- pickle.dump(patch, f)
206
-
207
- print(f"Saved {len(patches)} patches for entry {entry_id}")
208
-
209
- # Create dataloader with custom collate function to filter invalid samples
210
- def collate_fn(batch):
211
- valid_batch = []
212
- for patch_data, label, valid_mask in batch:
213
- # Filter out invalid samples (no valid points)
214
- if valid_mask.sum() > 0:
215
- valid_batch.append((patch_data, label, valid_mask))
216
-
217
- if len(valid_batch) == 0:
218
- return None
219
-
220
- # Stack valid samples
221
- patch_data = torch.stack([item[0] for item in valid_batch])
222
- labels = torch.stack([item[1] for item in valid_batch])
223
- valid_masks = torch.stack([item[2] for item in valid_batch])
224
-
225
- return patch_data, labels, valid_masks
226
-
227
- # Initialize weights using Xavier/Glorot initialization
228
- def init_weights(m):
229
- if isinstance(m, nn.Conv1d):
230
- nn.init.xavier_uniform_(m.weight)
231
- if m.bias is not None:
232
- nn.init.zeros_(m.bias)
233
- elif isinstance(m, nn.Linear):
234
- nn.init.xavier_uniform_(m.weight)
235
- if m.bias is not None:
236
- nn.init.zeros_(m.bias)
237
- elif isinstance(m, nn.BatchNorm1d):
238
- nn.init.ones_(m.weight)
239
- nn.init.zeros_(m.bias)
240
-
241
- def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
242
- lr: float = 0.001):
243
- """
244
- Train the ClassificationPointNet model on saved patches.
245
-
246
- Args:
247
- dataset_dir: Directory containing saved patch files
248
- model_save_path: Path to save the trained model
249
- epochs: Number of training epochs
250
- batch_size: Training batch size
251
- lr: Learning rate
252
- """
253
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
254
- print(f"Training on device: {device}")
255
-
256
- # Create dataset and dataloader
257
- dataset = PatchClassificationDataset(dataset_dir, max_points=2048, augment=True)
258
- print(f"Dataset loaded with {len(dataset)} samples")
259
-
260
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
261
- collate_fn=collate_fn, drop_last=True)
262
-
263
- # Initialize model
264
- model = ClassificationPointNet(input_dim=10, max_points=2048)
265
- model.apply(init_weights)
266
- model.to(device)
267
-
268
- # Loss function and optimizer (BCE for binary classification)
269
- criterion = nn.BCEWithLogitsLoss()
270
- optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
271
- scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
272
-
273
- # Training loop
274
- model.train()
275
- for epoch in range(epochs):
276
- total_loss = 0.0
277
- correct = 0
278
- total = 0
279
- num_batches = 0
280
-
281
- for batch_idx, batch_data in enumerate(dataloader):
282
- if batch_data is None: # Skip invalid batches
283
- continue
284
-
285
- patch_data, labels, valid_masks = batch_data
286
- patch_data = patch_data.to(device) # (batch_size, 10, max_points)
287
- labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
288
-
289
- # Forward pass
290
- optimizer.zero_grad()
291
- outputs = model(patch_data) # (batch_size, 1)
292
- loss = criterion(outputs, labels)
293
-
294
- # Backward pass
295
- loss.backward()
296
- optimizer.step()
297
-
298
- # Statistics
299
- total_loss += loss.item()
300
- predicted = (torch.sigmoid(outputs) > 0.5).float()
301
- total += labels.size(0)
302
- correct += (predicted == labels).sum().item()
303
- num_batches += 1
304
-
305
- if batch_idx % 50 == 0:
306
- print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
307
- f"Loss: {loss.item():.6f}, "
308
- f"Accuracy: {100 * correct / total:.2f}%")
309
-
310
- avg_loss = total_loss / num_batches if num_batches > 0 else 0
311
- accuracy = 100 * correct / total if total > 0 else 0
312
-
313
- print(f"Epoch {epoch+1}/{epochs} completed, "
314
- f"Avg Loss: {avg_loss:.6f}, "
315
- f"Accuracy: {accuracy:.2f}%")
316
-
317
- scheduler.step()
318
-
319
- # Save model checkpoint every epoch
320
- checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
321
- torch.save({
322
- 'model_state_dict': model.state_dict(),
323
- 'optimizer_state_dict': optimizer.state_dict(),
324
- 'epoch': epoch + 1,
325
- 'loss': avg_loss,
326
- 'accuracy': accuracy,
327
- }, checkpoint_path)
328
-
329
- # Save the trained model
330
- torch.save({
331
- 'model_state_dict': model.state_dict(),
332
- 'optimizer_state_dict': optimizer.state_dict(),
333
- 'epoch': epochs,
334
- }, model_save_path)
335
-
336
- print(f"Model saved to {model_save_path}")
337
- return model
338
-
339
- def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
340
- """
341
- Load a trained ClassificationPointNet model.
342
-
343
- Args:
344
- model_path: Path to the saved model
345
- device: Device to load the model on
346
-
347
- Returns:
348
- Loaded ClassificationPointNet model
349
- """
350
- if device is None:
351
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
352
-
353
- model = ClassificationPointNet(input_dim=10, max_points=2048)
354
-
355
- checkpoint = torch.load(model_path, map_location=device)
356
- model.load_state_dict(checkpoint['model_state_dict'])
357
-
358
- model.to(device)
359
- model.eval()
360
-
361
- return model
362
-
363
- def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
364
- """
365
- Predict binary classification from a patch using trained PointNet.
366
-
367
- Args:
368
- model: Trained ClassificationPointNet model
369
- patch: Dictionary containing patch data with 'patch_10d' key
370
- device: Device to run prediction on
371
-
372
- Returns:
373
- tuple of (predicted_class, confidence)
374
- predicted_class: int (0 for not edge, 1 for edge)
375
- confidence: float representing confidence score (0-1)
376
- """
377
- if device is None:
378
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
379
-
380
- patch_10d = patch['patch_10d'] # (N, 10)
381
-
382
- # Prepare input
383
- max_points = 2048
384
- num_points = patch_10d.shape[0]
385
-
386
- if num_points >= max_points:
387
- # Sample points
388
- indices = np.random.choice(num_points, max_points, replace=False)
389
- patch_sampled = patch_10d[indices]
390
- else:
391
- # Pad with zeros
392
- patch_sampled = np.zeros((max_points, 10))
393
- patch_sampled[:num_points] = patch_10d
394
-
395
- # Convert to tensor
396
- patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 10, max_points)
397
- patch_tensor = patch_tensor.to(device)
398
-
399
- # Predict
400
- with torch.no_grad():
401
- outputs = model(patch_tensor) # (1, 1)
402
- probability = torch.sigmoid(outputs).item()
403
- predicted_class = int(probability > 0.5)
404
-
405
- return predicted_class, probability
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fast_pointnet_class_10d_deeper.py DELETED
@@ -1,438 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import numpy as np
6
- import pickle
7
- from torch.utils.data import Dataset, DataLoader
8
- from typing import List, Dict, Tuple, Optional
9
- import json
10
-
11
- class ClassificationPointNet(nn.Module):
12
- """
13
- PointNet implementation for binary classification from 10D point cloud patches.
14
- Takes 10D point clouds and predicts binary classification (edge/not edge).
15
- Enhanced with residual connections and attention mechanism.
16
- """
17
- def __init__(self, input_dim=10, max_points=1024):
18
- super(ClassificationPointNet, self).__init__()
19
- self.max_points = max_points
20
-
21
- # Point-wise MLPs for feature extraction (deeper with residual connections)
22
- self.conv1 = nn.Conv1d(input_dim, 64, 1)
23
- self.conv2 = nn.Conv1d(64, 128, 1)
24
- self.conv3 = nn.Conv1d(128, 256, 1)
25
- self.conv4 = nn.Conv1d(256, 512, 1)
26
- self.conv5 = nn.Conv1d(512, 1024, 1)
27
- self.conv6 = nn.Conv1d(1024, 2048, 1)
28
- self.conv7 = nn.Conv1d(2048, 2048, 1) # Additional layer
29
-
30
- # Residual connection layers
31
- self.residual1 = nn.Conv1d(128, 256, 1)
32
- self.residual2 = nn.Conv1d(512, 1024, 1)
33
-
34
- # Attention mechanism
35
- self.attention = nn.Conv1d(2048, 1, 1)
36
-
37
- # Classification head (deeper with more capacity)
38
- self.fc1 = nn.Linear(2048, 1536)
39
- self.fc2 = nn.Linear(1536, 1024)
40
- self.fc3 = nn.Linear(1024, 512)
41
- self.fc4 = nn.Linear(512, 256)
42
- self.fc5 = nn.Linear(256, 128)
43
- self.fc6 = nn.Linear(128, 64)
44
- self.fc7 = nn.Linear(64, 32)
45
- self.fc8 = nn.Linear(32, 1) # Single output for binary classification
46
-
47
- # Batch normalization layers
48
- self.bn1 = nn.BatchNorm1d(64)
49
- self.bn2 = nn.BatchNorm1d(128)
50
- self.bn3 = nn.BatchNorm1d(256)
51
- self.bn4 = nn.BatchNorm1d(512)
52
- self.bn5 = nn.BatchNorm1d(1024)
53
- self.bn6 = nn.BatchNorm1d(2048)
54
- self.bn7 = nn.BatchNorm1d(2048)
55
-
56
- # Dropout layers with varying rates
57
- self.dropout1 = nn.Dropout(0.2)
58
- self.dropout2 = nn.Dropout(0.3)
59
- self.dropout3 = nn.Dropout(0.4)
60
- self.dropout4 = nn.Dropout(0.5)
61
- self.dropout5 = nn.Dropout(0.4)
62
- self.dropout6 = nn.Dropout(0.3)
63
- self.dropout7 = nn.Dropout(0.2)
64
-
65
- def forward(self, x):
66
- """
67
- Forward pass with residual connections and attention
68
- Args:
69
- x: (batch_size, input_dim, max_points) tensor
70
- Returns:
71
- classification: (batch_size, 1) tensor of logits (sigmoid for probability)
72
- """
73
- batch_size = x.size(0)
74
-
75
- # Point-wise feature extraction with residual connections
76
- x1 = F.relu(self.bn1(self.conv1(x)))
77
- x2 = F.relu(self.bn2(self.conv2(x1)))
78
- x3 = F.relu(self.bn3(self.conv3(x2)))
79
-
80
- # First residual connection
81
- x3_res = x3 + self.residual1(x2)
82
-
83
- x4 = F.relu(self.bn4(self.conv4(x3_res)))
84
- x5 = F.relu(self.bn5(self.conv5(x4)))
85
-
86
- # Second residual connection
87
- x5_res = x5 + self.residual2(x4)
88
-
89
- x6 = F.relu(self.bn6(self.conv6(x5_res)))
90
- x7 = F.relu(self.bn7(self.conv7(x6)))
91
-
92
- # Attention mechanism
93
- attention_weights = F.softmax(self.attention(x7), dim=2) # (batch_size, 1, max_points)
94
- x7_weighted = x7 * attention_weights # Apply attention
95
-
96
- # Global max pooling combined with attention-weighted average pooling
97
- global_max = torch.max(x7, 2)[0] # (batch_size, 2048)
98
- global_avg = torch.sum(x7_weighted, 2) # (batch_size, 2048)
99
- global_features = global_max + global_avg # Combine features
100
-
101
- # Classification head with residual connections
102
- x = F.relu(self.fc1(global_features))
103
- x = self.dropout1(x)
104
- x = F.relu(self.fc2(x))
105
- x = self.dropout2(x)
106
- x_mid = F.relu(self.fc3(x))
107
- x = self.dropout3(x_mid)
108
- x = F.relu(self.fc4(x))
109
- x = self.dropout4(x)
110
- x = F.relu(self.fc5(x))
111
- x = self.dropout5(x)
112
- x = F.relu(self.fc6(x))
113
- x = self.dropout6(x)
114
- x = F.relu(self.fc7(x))
115
- x = self.dropout7(x)
116
- classification = self.fc8(x) # (batch_size, 1)
117
-
118
- return classification
119
-
120
- class PatchClassificationDataset(Dataset):
121
- """
122
- Dataset class for loading saved patches for PointNet classification training.
123
- """
124
-
125
- def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
126
- self.dataset_dir = dataset_dir
127
- self.max_points = max_points
128
- self.augment = augment
129
-
130
- # Load patch files
131
- self.patch_files = []
132
- for file in os.listdir(dataset_dir):
133
- if file.endswith('.pkl'):
134
- self.patch_files.append(os.path.join(dataset_dir, file))
135
-
136
- print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
137
-
138
- def __len__(self):
139
- return len(self.patch_files)
140
-
141
- def __getitem__(self, idx):
142
- """
143
- Load and process a patch for training.
144
- Returns:
145
- patch_data: (10, max_points) tensor of point cloud data
146
- label: scalar tensor for binary classification (0 or 1)
147
- valid_mask: (max_points,) boolean tensor indicating valid points
148
- """
149
- patch_file = self.patch_files[idx]
150
-
151
- with open(patch_file, 'rb') as f:
152
- patch_info = pickle.load(f)
153
-
154
- patch_10d = patch_info['patch_10d'] # (N, 10)
155
- label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
156
-
157
- # Pad or sample points to max_points
158
- num_points = patch_10d.shape[0]
159
-
160
- if num_points >= self.max_points:
161
- # Randomly sample max_points
162
- indices = np.random.choice(num_points, self.max_points, replace=False)
163
- patch_sampled = patch_10d[indices]
164
- valid_mask = np.ones(self.max_points, dtype=bool)
165
- else:
166
- # Pad with zeros
167
- patch_sampled = np.zeros((self.max_points, 10))
168
- patch_sampled[:num_points] = patch_10d
169
- valid_mask = np.zeros(self.max_points, dtype=bool)
170
- valid_mask[:num_points] = True
171
-
172
- # Data augmentation
173
- if self.augment:
174
- patch_sampled = self._augment_patch(patch_sampled, valid_mask)
175
-
176
- # Convert to tensors and transpose for conv1d (channels first)
177
- patch_tensor = torch.from_numpy(patch_sampled.T).float() # (10, max_points)
178
- label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
179
- valid_mask_tensor = torch.from_numpy(valid_mask)
180
-
181
- return patch_tensor, label_tensor, valid_mask_tensor
182
-
183
- def _augment_patch(self, patch, valid_mask):
184
- """
185
- Apply data augmentation to the patch.
186
- """
187
- valid_points = patch[valid_mask]
188
-
189
- if len(valid_points) == 0:
190
- return patch
191
-
192
- # Random rotation around z-axis (only for xyz coordinates, first 3 dimensions)
193
- angle = np.random.uniform(0, 2 * np.pi)
194
- cos_angle = np.cos(angle)
195
- sin_angle = np.sin(angle)
196
- rotation_matrix = np.array([
197
- [cos_angle, -sin_angle, 0],
198
- [sin_angle, cos_angle, 0],
199
- [0, 0, 1]
200
- ])
201
-
202
- # Apply rotation to xyz coordinates (first 3 dimensions)
203
- valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
204
-
205
- # Random jittering (only for xyz coordinates)
206
- noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
207
- valid_points[:, :3] += noise
208
-
209
- # Random scaling (only for xyz coordinates)
210
- scale = np.random.uniform(0.9, 1.1)
211
- valid_points[:, :3] *= scale
212
-
213
- patch[valid_mask] = valid_points
214
- return patch
215
-
216
- def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
217
- """
218
- Save patches from prediction pipeline to create a training dataset.
219
-
220
- Args:
221
- patches: List of patch dictionaries from generate_patches()
222
- dataset_dir: Directory to save the dataset
223
- entry_id: Unique identifier for this entry/image
224
- """
225
- os.makedirs(dataset_dir, exist_ok=True)
226
-
227
- for i, patch in enumerate(patches):
228
- # Create unique filename
229
- filename = f"{entry_id}_patch_{i}.pkl"
230
- filepath = os.path.join(dataset_dir, filename)
231
-
232
- # Skip if file already exists
233
- if os.path.exists(filepath):
234
- continue
235
-
236
- # Save patch data
237
- with open(filepath, 'wb') as f:
238
- pickle.dump(patch, f)
239
-
240
- print(f"Saved {len(patches)} patches for entry {entry_id}")
241
-
242
- # Create dataloader with custom collate function to filter invalid samples
243
- def collate_fn(batch):
244
- valid_batch = []
245
- for patch_data, label, valid_mask in batch:
246
- # Filter out invalid samples (no valid points)
247
- if valid_mask.sum() > 0:
248
- valid_batch.append((patch_data, label, valid_mask))
249
-
250
- if len(valid_batch) == 0:
251
- return None
252
-
253
- # Stack valid samples
254
- patch_data = torch.stack([item[0] for item in valid_batch])
255
- labels = torch.stack([item[1] for item in valid_batch])
256
- valid_masks = torch.stack([item[2] for item in valid_batch])
257
-
258
- return patch_data, labels, valid_masks
259
-
260
- # Initialize weights using Xavier/Glorot initialization
261
- def init_weights(m):
262
- if isinstance(m, nn.Conv1d):
263
- nn.init.xavier_uniform_(m.weight)
264
- if m.bias is not None:
265
- nn.init.zeros_(m.bias)
266
- elif isinstance(m, nn.Linear):
267
- nn.init.xavier_uniform_(m.weight)
268
- if m.bias is not None:
269
- nn.init.zeros_(m.bias)
270
- elif isinstance(m, nn.BatchNorm1d):
271
- nn.init.ones_(m.weight)
272
- nn.init.zeros_(m.bias)
273
-
274
- def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
275
- lr: float = 0.001):
276
- """
277
- Train the ClassificationPointNet model on saved patches.
278
-
279
- Args:
280
- dataset_dir: Directory containing saved patch files
281
- model_save_path: Path to save the trained model
282
- epochs: Number of training epochs
283
- batch_size: Training batch size
284
- lr: Learning rate
285
- """
286
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
287
- print(f"Training on device: {device}")
288
-
289
- # Create dataset and dataloader
290
- dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
291
- print(f"Dataset loaded with {len(dataset)} samples")
292
-
293
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
294
- collate_fn=collate_fn, drop_last=True)
295
-
296
- # Initialize model
297
- model = ClassificationPointNet(input_dim=10, max_points=1024)
298
- model.apply(init_weights)
299
- model.to(device)
300
-
301
- # Loss function and optimizer (BCE for binary classification)
302
- criterion = nn.BCEWithLogitsLoss()
303
- optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
304
- scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
305
-
306
- # Training loop
307
- model.train()
308
- for epoch in range(epochs):
309
- total_loss = 0.0
310
- correct = 0
311
- total = 0
312
- num_batches = 0
313
-
314
- for batch_idx, batch_data in enumerate(dataloader):
315
- if batch_data is None: # Skip invalid batches
316
- continue
317
-
318
- patch_data, labels, valid_masks = batch_data
319
- patch_data = patch_data.to(device) # (batch_size, 10, max_points)
320
- labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
321
-
322
- # Forward pass
323
- optimizer.zero_grad()
324
- outputs = model(patch_data) # (batch_size, 1)
325
- loss = criterion(outputs, labels)
326
-
327
- # Backward pass
328
- loss.backward()
329
- optimizer.step()
330
-
331
- # Statistics
332
- total_loss += loss.item()
333
- predicted = (torch.sigmoid(outputs) > 0.5).float()
334
- total += labels.size(0)
335
- correct += (predicted == labels).sum().item()
336
- num_batches += 1
337
-
338
- if batch_idx % 50 == 0:
339
- print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
340
- f"Loss: {loss.item():.6f}, "
341
- f"Accuracy: {100 * correct / total:.2f}%")
342
-
343
- avg_loss = total_loss / num_batches if num_batches > 0 else 0
344
- accuracy = 100 * correct / total if total > 0 else 0
345
-
346
- print(f"Epoch {epoch+1}/{epochs} completed, "
347
- f"Avg Loss: {avg_loss:.6f}, "
348
- f"Accuracy: {accuracy:.2f}%")
349
-
350
- scheduler.step()
351
-
352
- # Save model checkpoint every epoch
353
- checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
354
- torch.save({
355
- 'model_state_dict': model.state_dict(),
356
- 'optimizer_state_dict': optimizer.state_dict(),
357
- 'epoch': epoch + 1,
358
- 'loss': avg_loss,
359
- 'accuracy': accuracy,
360
- }, checkpoint_path)
361
-
362
- # Save the trained model
363
- torch.save({
364
- 'model_state_dict': model.state_dict(),
365
- 'optimizer_state_dict': optimizer.state_dict(),
366
- 'epoch': epochs,
367
- }, model_save_path)
368
-
369
- print(f"Model saved to {model_save_path}")
370
- return model
371
-
372
- def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
373
- """
374
- Load a trained ClassificationPointNet model.
375
-
376
- Args:
377
- model_path: Path to the saved model
378
- device: Device to load the model on
379
-
380
- Returns:
381
- Loaded ClassificationPointNet model
382
- """
383
- if device is None:
384
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
385
-
386
- model = ClassificationPointNet(input_dim=10, max_points=1024)
387
-
388
- checkpoint = torch.load(model_path, map_location=device)
389
- model.load_state_dict(checkpoint['model_state_dict'])
390
-
391
- model.to(device)
392
- model.eval()
393
-
394
- return model
395
-
396
- def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
397
- """
398
- Predict binary classification from a patch using trained PointNet.
399
-
400
- Args:
401
- model: Trained ClassificationPointNet model
402
- patch: Dictionary containing patch data with 'patch_10d' key
403
- device: Device to run prediction on
404
-
405
- Returns:
406
- tuple of (predicted_class, confidence)
407
- predicted_class: int (0 for not edge, 1 for edge)
408
- confidence: float representing confidence score (0-1)
409
- """
410
- if device is None:
411
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
412
-
413
- patch_10d = patch['patch_10d'] # (N, 10)
414
-
415
- # Prepare input
416
- max_points = 1024
417
- num_points = patch_10d.shape[0]
418
-
419
- if num_points >= max_points:
420
- # Sample points
421
- indices = np.random.choice(num_points, max_points, replace=False)
422
- patch_sampled = patch_10d[indices]
423
- else:
424
- # Pad with zeros
425
- patch_sampled = np.zeros((max_points, 10))
426
- patch_sampled[:num_points] = patch_10d
427
-
428
- # Convert to tensor
429
- patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 10, max_points)
430
- patch_tensor = patch_tensor.to(device)
431
-
432
- # Predict
433
- with torch.no_grad():
434
- outputs = model(patch_tensor) # (1, 1)
435
- probability = torch.sigmoid(outputs).item()
436
- predicted_class = int(probability > 0.5)
437
-
438
- return predicted_class, probability
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fast_pointnet_class_deeper.py DELETED
@@ -1,527 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import numpy as np
6
- import pickle
7
- from torch.utils.data import Dataset, DataLoader
8
- from typing import List, Dict, Tuple, Optional
9
- import json
10
-
11
- class ClassificationPointNet(nn.Module):
12
- """
13
- Enhanced PointNet implementation for binary classification from 6D point cloud patches.
14
- Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
15
- Features: Residual connections, attention mechanism, multi-scale features, deeper architecture.
16
- """
17
- def __init__(self, input_dim=6, max_points=1024):
18
- super(ClassificationPointNet, self).__init__()
19
- self.max_points = max_points
20
-
21
- # Point-wise MLPs with residual connections (much deeper)
22
- self.conv1 = nn.Conv1d(input_dim, 64, 1)
23
- self.conv2 = nn.Conv1d(64, 64, 1)
24
- self.conv3 = nn.Conv1d(64, 128, 1)
25
- self.conv4 = nn.Conv1d(128, 128, 1)
26
- self.conv5 = nn.Conv1d(128, 256, 1)
27
- self.conv6 = nn.Conv1d(256, 256, 1)
28
- self.conv7 = nn.Conv1d(256, 512, 1)
29
- self.conv8 = nn.Conv1d(512, 512, 1)
30
- self.conv9 = nn.Conv1d(512, 1024, 1)
31
- self.conv10 = nn.Conv1d(1024, 1024, 1)
32
- self.conv11 = nn.Conv1d(1024, 2048, 1)
33
-
34
- # Residual connection layers
35
- self.res_conv1 = nn.Conv1d(64, 128, 1)
36
- self.res_conv2 = nn.Conv1d(128, 256, 1)
37
- self.res_conv3 = nn.Conv1d(256, 512, 1)
38
- self.res_conv4 = nn.Conv1d(512, 1024, 1)
39
-
40
- # Self-attention mechanism
41
- self.attention = nn.MultiheadAttention(embed_dim=2048, num_heads=8, batch_first=True)
42
- self.attention_norm = nn.LayerNorm(2048)
43
-
44
- # Multi-scale feature aggregation
45
- self.scale_conv1 = nn.Conv1d(2048, 512, 1)
46
- self.scale_conv2 = nn.Conv1d(2048, 512, 1)
47
- self.scale_conv3 = nn.Conv1d(2048, 512, 1)
48
-
49
- # Enhanced classification head with residual connections
50
- self.fc1 = nn.Linear(7680, 2048) # Updated input size: 2048*3 + 512*3 = 7680
51
- self.fc2 = nn.Linear(2048, 2048)
52
- self.fc3 = nn.Linear(2048, 1024)
53
- self.fc4 = nn.Linear(1024, 1024)
54
- self.fc5 = nn.Linear(1024, 512)
55
- self.fc6 = nn.Linear(512, 512)
56
- self.fc7 = nn.Linear(512, 256)
57
- self.fc8 = nn.Linear(256, 128)
58
- self.fc9 = nn.Linear(128, 64)
59
- self.fc10 = nn.Linear(64, 1)
60
-
61
- # Residual connections for FC layers
62
- self.fc_res1 = nn.Linear(2048, 1024)
63
- self.fc_res2 = nn.Linear(1024, 512)
64
- self.fc_res3 = nn.Linear(512, 128)
65
-
66
- # Batch normalization layers
67
- self.bn1 = nn.BatchNorm1d(64)
68
- self.bn2 = nn.BatchNorm1d(64)
69
- self.bn3 = nn.BatchNorm1d(128)
70
- self.bn4 = nn.BatchNorm1d(128)
71
- self.bn5 = nn.BatchNorm1d(256)
72
- self.bn6 = nn.BatchNorm1d(256)
73
- self.bn7 = nn.BatchNorm1d(512)
74
- self.bn8 = nn.BatchNorm1d(512)
75
- self.bn9 = nn.BatchNorm1d(1024)
76
- self.bn10 = nn.BatchNorm1d(1024)
77
- self.bn11 = nn.BatchNorm1d(2048)
78
-
79
- # Scale batch norms
80
- self.scale_bn1 = nn.BatchNorm1d(512)
81
- self.scale_bn2 = nn.BatchNorm1d(512)
82
- self.scale_bn3 = nn.BatchNorm1d(512)
83
-
84
- # FC batch norms
85
- self.fc_bn1 = nn.BatchNorm1d(2048)
86
- self.fc_bn2 = nn.BatchNorm1d(2048)
87
- self.fc_bn3 = nn.BatchNorm1d(1024)
88
- self.fc_bn4 = nn.BatchNorm1d(1024)
89
- self.fc_bn5 = nn.BatchNorm1d(512)
90
- self.fc_bn6 = nn.BatchNorm1d(512)
91
- self.fc_bn7 = nn.BatchNorm1d(256)
92
- self.fc_bn8 = nn.BatchNorm1d(128)
93
-
94
- # Dropout layers with varying rates
95
- self.dropout1 = nn.Dropout(0.1)
96
- self.dropout2 = nn.Dropout(0.2)
97
- self.dropout3 = nn.Dropout(0.3)
98
- self.dropout4 = nn.Dropout(0.4)
99
- self.dropout5 = nn.Dropout(0.5)
100
- self.dropout6 = nn.Dropout(0.4)
101
- self.dropout7 = nn.Dropout(0.3)
102
- self.dropout8 = nn.Dropout(0.2)
103
-
104
- def forward(self, x):
105
- """
106
- Forward pass with residual connections and attention
107
- Args:
108
- x: (batch_size, input_dim, max_points) tensor
109
- Returns:
110
- classification: (batch_size, 1) tensor of logits
111
- """
112
- batch_size = x.size(0)
113
-
114
- # Deep point-wise feature extraction with residual connections
115
- x1 = F.relu(self.bn1(self.conv1(x)))
116
- x2 = F.relu(self.bn2(self.conv2(x1)))
117
- x2 = x2 + x1 # Residual connection
118
-
119
- x3 = F.relu(self.bn3(self.conv3(x2)))
120
- x4 = F.relu(self.bn4(self.conv4(x3)))
121
- res1 = self.res_conv1(x2)
122
- x4 = x4 + res1 # Residual connection
123
-
124
- x5 = F.relu(self.bn5(self.conv5(x4)))
125
- x6 = F.relu(self.bn6(self.conv6(x5)))
126
- res2 = self.res_conv2(x4)
127
- x6 = x6 + res2 # Residual connection
128
-
129
- x7 = F.relu(self.bn7(self.conv7(x6)))
130
- x8 = F.relu(self.bn8(self.conv8(x7)))
131
- res3 = self.res_conv3(x6)
132
- x8 = x8 + res3 # Residual connection
133
-
134
- x9 = F.relu(self.bn9(self.conv9(x8)))
135
- x10 = F.relu(self.bn10(self.conv10(x9)))
136
- res4 = self.res_conv4(x8)
137
- x10 = x10 + res4 # Residual connection
138
-
139
- x11 = F.relu(self.bn11(self.conv11(x10)))
140
-
141
- # Multi-scale global pooling
142
- # Max pooling
143
- global_max = torch.max(x11, 2)[0] # (batch_size, 2048)
144
-
145
- # Average pooling
146
- global_avg = torch.mean(x11, 2) # (batch_size, 2048)
147
-
148
- # Attention-based pooling
149
- x11_transposed = x11.transpose(1, 2) # (batch_size, max_points, 2048)
150
- attended, _ = self.attention(x11_transposed, x11_transposed, x11_transposed)
151
- attended = self.attention_norm(attended + x11_transposed)
152
- global_att = torch.mean(attended, 1) # (batch_size, 2048)
153
-
154
- # Multi-scale feature extraction
155
- scale1 = F.relu(self.scale_bn1(self.scale_conv1(x11)))
156
- scale1_pool = torch.max(scale1, 2)[0]
157
-
158
- scale2 = F.relu(self.scale_bn2(self.scale_conv2(x11)))
159
- scale2_pool = torch.mean(scale2, 2)
160
-
161
- scale3 = F.relu(self.scale_bn3(self.scale_conv3(x11)))
162
- scale3_pool = torch.std(scale3, 2)
163
-
164
- # Concatenate all global features
165
- global_features = torch.cat([
166
- global_max, global_avg, global_att,
167
- scale1_pool, scale2_pool, scale3_pool
168
- ], dim=1) # (batch_size, 4096)
169
-
170
- # Enhanced classification head with residual connections
171
- x = F.relu(self.fc_bn1(self.fc1(global_features)))
172
- x = self.dropout1(x)
173
-
174
- x = F.relu(self.fc_bn2(self.fc2(x)))
175
- identity1 = x
176
- x = self.dropout2(x)
177
-
178
- x = F.relu(self.fc_bn3(self.fc3(x)))
179
- x = self.dropout3(x)
180
-
181
- x = F.relu(self.fc_bn4(self.fc4(x)))
182
- res_fc1 = self.fc_res1(identity1)
183
- x = x + res_fc1 # Residual connection
184
- identity2 = x
185
- x = self.dropout4(x)
186
-
187
- x = F.relu(self.fc_bn5(self.fc5(x)))
188
- x = self.dropout5(x)
189
-
190
- x = F.relu(self.fc_bn6(self.fc6(x)))
191
- res_fc2 = self.fc_res2(identity2)
192
- x = x + res_fc2 # Residual connection
193
- identity3 = x
194
- x = self.dropout6(x)
195
-
196
- x = F.relu(self.fc_bn7(self.fc7(x)))
197
- x = self.dropout7(x)
198
-
199
- x = F.relu(self.fc_bn8(self.fc8(x)))
200
- res_fc3 = self.fc_res3(identity3)
201
- x = x + res_fc3 # Residual connection
202
- x = self.dropout8(x)
203
-
204
- x = F.relu(self.fc9(x))
205
- classification = self.fc10(x) # (batch_size, 1)
206
-
207
- return classification
208
-
209
- class PatchClassificationDataset(Dataset):
210
- """
211
- Dataset class for loading saved patches for PointNet classification training.
212
- """
213
-
214
- def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
215
- self.dataset_dir = dataset_dir
216
- self.max_points = max_points
217
- self.augment = augment
218
-
219
- # Load patch files
220
- self.patch_files = []
221
- for file in os.listdir(dataset_dir):
222
- if file.endswith('.pkl'):
223
- self.patch_files.append(os.path.join(dataset_dir, file))
224
-
225
- print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
226
-
227
- def __len__(self):
228
- return len(self.patch_files)
229
-
230
- def __getitem__(self, idx):
231
- """
232
- Load and process a patch for training.
233
- Returns:
234
- patch_data: (6, max_points) tensor of point cloud data
235
- label: scalar tensor for binary classification (0 or 1)
236
- valid_mask: (max_points,) boolean tensor indicating valid points
237
- """
238
- patch_file = self.patch_files[idx]
239
-
240
- with open(patch_file, 'rb') as f:
241
- patch_info = pickle.load(f)
242
-
243
- patch_6d = patch_info['patch_6d'] # (N, 6)
244
- label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
245
-
246
- # Pad or sample points to max_points
247
- num_points = patch_6d.shape[0]
248
-
249
- if num_points >= self.max_points:
250
- # Randomly sample max_points
251
- indices = np.random.choice(num_points, self.max_points, replace=False)
252
- patch_sampled = patch_6d[indices]
253
- valid_mask = np.ones(self.max_points, dtype=bool)
254
- else:
255
- # Pad with zeros
256
- patch_sampled = np.zeros((self.max_points, 6))
257
- patch_sampled[:num_points] = patch_6d
258
- valid_mask = np.zeros(self.max_points, dtype=bool)
259
- valid_mask[:num_points] = True
260
-
261
- # Data augmentation
262
- if self.augment:
263
- patch_sampled = self._augment_patch(patch_sampled, valid_mask)
264
-
265
- # Convert to tensors and transpose for conv1d (channels first)
266
- patch_tensor = torch.from_numpy(patch_sampled.T).float() # (6, max_points)
267
- label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
268
- valid_mask_tensor = torch.from_numpy(valid_mask)
269
-
270
- return patch_tensor, label_tensor, valid_mask_tensor
271
-
272
- def _augment_patch(self, patch, valid_mask):
273
- """
274
- Apply data augmentation to the patch.
275
- """
276
- valid_points = patch[valid_mask]
277
-
278
- if len(valid_points) == 0:
279
- return patch
280
-
281
- # Random rotation around z-axis
282
- angle = np.random.uniform(0, 2 * np.pi)
283
- cos_angle = np.cos(angle)
284
- sin_angle = np.sin(angle)
285
- rotation_matrix = np.array([
286
- [cos_angle, -sin_angle, 0],
287
- [sin_angle, cos_angle, 0],
288
- [0, 0, 1]
289
- ])
290
-
291
- # Apply rotation to xyz coordinates
292
- valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
293
-
294
- # Random jittering
295
- noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
296
- valid_points[:, :3] += noise
297
-
298
- # Random scaling
299
- scale = np.random.uniform(0.9, 1.1)
300
- valid_points[:, :3] *= scale
301
-
302
- patch[valid_mask] = valid_points
303
- return patch
304
-
305
- def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
306
- """
307
- Save patches from prediction pipeline to create a training dataset.
308
-
309
- Args:
310
- patches: List of patch dictionaries from generate_patches()
311
- dataset_dir: Directory to save the dataset
312
- entry_id: Unique identifier for this entry/image
313
- """
314
- os.makedirs(dataset_dir, exist_ok=True)
315
-
316
- for i, patch in enumerate(patches):
317
- # Create unique filename
318
- filename = f"{entry_id}_patch_{i}.pkl"
319
- filepath = os.path.join(dataset_dir, filename)
320
-
321
- # Skip if file already exists
322
- if os.path.exists(filepath):
323
- continue
324
-
325
- # Save patch data
326
- with open(filepath, 'wb') as f:
327
- pickle.dump(patch, f)
328
-
329
- print(f"Saved {len(patches)} patches for entry {entry_id}")
330
-
331
- # Create dataloader with custom collate function to filter invalid samples
332
- def collate_fn(batch):
333
- valid_batch = []
334
- for patch_data, label, valid_mask in batch:
335
- # Filter out invalid samples (no valid points)
336
- if valid_mask.sum() > 0:
337
- valid_batch.append((patch_data, label, valid_mask))
338
-
339
- if len(valid_batch) == 0:
340
- return None
341
-
342
- # Stack valid samples
343
- patch_data = torch.stack([item[0] for item in valid_batch])
344
- labels = torch.stack([item[1] for item in valid_batch])
345
- valid_masks = torch.stack([item[2] for item in valid_batch])
346
-
347
- return patch_data, labels, valid_masks
348
-
349
- # Initialize weights using Xavier/Glorot initialization
350
- def init_weights(m):
351
- if isinstance(m, nn.Conv1d):
352
- nn.init.xavier_uniform_(m.weight)
353
- if m.bias is not None:
354
- nn.init.zeros_(m.bias)
355
- elif isinstance(m, nn.Linear):
356
- nn.init.xavier_uniform_(m.weight)
357
- if m.bias is not None:
358
- nn.init.zeros_(m.bias)
359
- elif isinstance(m, nn.BatchNorm1d):
360
- nn.init.ones_(m.weight)
361
- nn.init.zeros_(m.bias)
362
-
363
- def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
364
- lr: float = 0.001):
365
- """
366
- Train the ClassificationPointNet model on saved patches.
367
-
368
- Args:
369
- dataset_dir: Directory containing saved patch files
370
- model_save_path: Path to save the trained model
371
- epochs: Number of training epochs
372
- batch_size: Training batch size
373
- lr: Learning rate
374
- """
375
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
376
- print(f"Training on device: {device}")
377
-
378
- # Create dataset and dataloader
379
- dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
380
- print(f"Dataset loaded with {len(dataset)} samples")
381
-
382
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
383
- collate_fn=collate_fn, drop_last=True)
384
-
385
- # Initialize model
386
- model = ClassificationPointNet(input_dim=6, max_points=1024)
387
- model.apply(init_weights)
388
- model.to(device)
389
-
390
- # Loss function and optimizer (BCE for binary classification)
391
- criterion = nn.BCEWithLogitsLoss()
392
- optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
393
- scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
394
-
395
- # Training loop
396
- model.train()
397
- for epoch in range(epochs):
398
- total_loss = 0.0
399
- correct = 0
400
- total = 0
401
- num_batches = 0
402
-
403
- for batch_idx, batch_data in enumerate(dataloader):
404
- if batch_data is None: # Skip invalid batches
405
- continue
406
-
407
- patch_data, labels, valid_masks = batch_data
408
- patch_data = patch_data.to(device) # (batch_size, 6, max_points)
409
- labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
410
-
411
- # Forward pass
412
- optimizer.zero_grad()
413
- outputs = model(patch_data) # (batch_size, 1)
414
- loss = criterion(outputs, labels)
415
-
416
- # Backward pass
417
- loss.backward()
418
- optimizer.step()
419
-
420
- # Statistics
421
- total_loss += loss.item()
422
- predicted = (torch.sigmoid(outputs) > 0.5).float()
423
- total += labels.size(0)
424
- correct += (predicted == labels).sum().item()
425
- num_batches += 1
426
-
427
- if batch_idx % 50 == 0:
428
- print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
429
- f"Loss: {loss.item():.6f}, "
430
- f"Accuracy: {100 * correct / total:.2f}%")
431
-
432
- avg_loss = total_loss / num_batches if num_batches > 0 else 0
433
- accuracy = 100 * correct / total if total > 0 else 0
434
-
435
- print(f"Epoch {epoch+1}/{epochs} completed, "
436
- f"Avg Loss: {avg_loss:.6f}, "
437
- f"Accuracy: {accuracy:.2f}%")
438
-
439
- scheduler.step()
440
-
441
- # Save model checkpoint every epoch
442
- checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
443
- torch.save({
444
- 'model_state_dict': model.state_dict(),
445
- 'optimizer_state_dict': optimizer.state_dict(),
446
- 'epoch': epoch + 1,
447
- 'loss': avg_loss,
448
- 'accuracy': accuracy,
449
- }, checkpoint_path)
450
-
451
- # Save the trained model
452
- torch.save({
453
- 'model_state_dict': model.state_dict(),
454
- 'optimizer_state_dict': optimizer.state_dict(),
455
- 'epoch': epochs,
456
- }, model_save_path)
457
-
458
- print(f"Model saved to {model_save_path}")
459
- return model
460
-
461
- def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
462
- """
463
- Load a trained ClassificationPointNet model.
464
-
465
- Args:
466
- model_path: Path to the saved model
467
- device: Device to load the model on
468
-
469
- Returns:
470
- Loaded ClassificationPointNet model
471
- """
472
- if device is None:
473
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
474
-
475
- model = ClassificationPointNet(input_dim=6, max_points=1024)
476
-
477
- checkpoint = torch.load(model_path, map_location=device)
478
- model.load_state_dict(checkpoint['model_state_dict'])
479
-
480
- model.to(device)
481
- model.eval()
482
-
483
- return model
484
-
485
- def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
486
- """
487
- Predict binary classification from a patch using trained PointNet.
488
-
489
- Args:
490
- model: Trained ClassificationPointNet model
491
- patch: Dictionary containing patch data with 'patch_6d' key
492
- device: Device to run prediction on
493
-
494
- Returns:
495
- tuple of (predicted_class, confidence)
496
- predicted_class: int (0 for not edge, 1 for edge)
497
- confidence: float representing confidence score (0-1)
498
- """
499
- if device is None:
500
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
501
-
502
- patch_6d = patch['patch_6d'] # (N, 6)
503
-
504
- # Prepare input
505
- max_points = 1024
506
- num_points = patch_6d.shape[0]
507
-
508
- if num_points >= max_points:
509
- # Sample points
510
- indices = np.random.choice(num_points, max_points, replace=False)
511
- patch_sampled = patch_6d[indices]
512
- else:
513
- # Pad with zeros
514
- patch_sampled = np.zeros((max_points, 6))
515
- patch_sampled[:num_points] = patch_6d
516
-
517
- # Convert to tensor
518
- patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 6, max_points)
519
- patch_tensor = patch_tensor.to(device)
520
-
521
- # Predict
522
- with torch.no_grad():
523
- outputs = model(patch_tensor) # (1, 1)
524
- probability = torch.sigmoid(outputs).item()
525
- predicted_class = int(probability > 0.5)
526
-
527
- return predicted_class, probability
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fast_pointnet_class_v2.py DELETED
@@ -1,508 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import numpy as np
6
- import pickle
7
- from torch.utils.data import Dataset, DataLoader
8
- from typing import List, Dict, Tuple, Optional
9
- import json
10
-
11
- class ClassificationPointNet(nn.Module):
12
- """
13
- Fast PointNet-like implementation for binary classification from point cloud patches.
14
- Adapted from FastPointNet, focusing only on classification.
15
- Takes N-dimensional point clouds and predicts a binary class.
16
- """
17
- def __init__(self, input_dim: int = 10, max_points: int = 1024, num_classes: int = 1):
18
- super(ClassificationPointNet, self).__init__()
19
- self.max_points = max_points
20
- self.num_classes = num_classes
21
-
22
- # Enhanced point-wise MLPs with residual connections
23
- self.conv1 = nn.Conv1d(input_dim, 64, 1)
24
- self.conv2 = nn.Conv1d(64, 128, 1)
25
- self.conv3 = nn.Conv1d(128, 256, 1)
26
- self.conv4 = nn.Conv1d(256, 512, 1)
27
- self.conv5 = nn.Conv1d(512, 1024, 1)
28
- self.conv6 = nn.Conv1d(1024, 1024, 1) # Matches FastPointNet structure
29
- self.conv7 = nn.Conv1d(1024, 2048, 1) # Matches FastPointNet structure
30
-
31
- # Lightweight channel attention mechanism
32
- self.channel_attention = nn.Sequential(
33
- nn.AdaptiveAvgPool1d(1),
34
- nn.Conv1d(2048, 128, 1),
35
- nn.ReLU(inplace=True),
36
- nn.Conv1d(128, 2048, 1),
37
- nn.Sigmoid()
38
- )
39
-
40
- # Enhanced shared features with residual connections
41
- self.shared_fc1 = nn.Linear(2048, 1024)
42
- self.shared_fc2 = nn.Linear(1024, 512)
43
- self.shared_fc3 = nn.Linear(512, 512)
44
-
45
- # Classification head
46
- self.class_fc1 = nn.Linear(512, 512)
47
- self.class_fc2 = nn.Linear(512, 256)
48
- self.class_fc3 = nn.Linear(256, 128)
49
- self.class_fc4 = nn.Linear(128, 64)
50
- self.class_fc5 = nn.Linear(64, self.num_classes) # Output for classification
51
-
52
- # Batch normalization layers with momentum
53
- self.bn1 = nn.BatchNorm1d(64, momentum=0.1)
54
- self.bn2 = nn.BatchNorm1d(128, momentum=0.1)
55
- self.bn3 = nn.BatchNorm1d(256, momentum=0.1)
56
- self.bn4 = nn.BatchNorm1d(512, momentum=0.1)
57
- self.bn5 = nn.BatchNorm1d(1024, momentum=0.1)
58
- self.bn6 = nn.BatchNorm1d(1024, momentum=0.1)
59
- self.bn7 = nn.BatchNorm1d(2048, momentum=0.1)
60
-
61
- # Group normalization for shared layers
62
- self.gn1 = nn.GroupNorm(32, 1024) # Assuming 1024 channels, 32 groups
63
- self.gn2 = nn.GroupNorm(16, 512) # Assuming 512 channels, 16 groups
64
-
65
- # Dropout layers
66
- self.dropout_light = nn.Dropout(0.1)
67
- self.dropout_medium = nn.Dropout(0.2)
68
- # self.dropout_heavy = nn.Dropout(0.3) # Not used in the direct path to classification in this adaptation
69
-
70
- def forward(self, x: torch.Tensor) -> torch.Tensor:
71
- """
72
- Forward pass with residual connections and attention for classification.
73
- Args:
74
- x: (batch_size, input_dim, max_points) tensor
75
- Returns:
76
- classification: (batch_size, num_classes) tensor of logits
77
- """
78
- # Enhanced point-wise feature extraction
79
- x1 = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.01, inplace=True)
80
- x2 = F.leaky_relu(self.bn2(self.conv2(x1)), negative_slope=0.01, inplace=True)
81
- x3 = F.leaky_relu(self.bn3(self.conv3(x2)), negative_slope=0.01, inplace=True)
82
- x4 = F.leaky_relu(self.bn4(self.conv4(x3)), negative_slope=0.01, inplace=True)
83
- x5 = F.leaky_relu(self.bn5(self.conv5(x4)), negative_slope=0.01, inplace=True)
84
-
85
- # Residual connection for conv6
86
- x6_conv = self.bn6(self.conv6(x5))
87
- x6 = F.leaky_relu(x6_conv + x5, negative_slope=0.01, inplace=True) # Add residual before ReLU
88
-
89
- x7 = F.leaky_relu(self.bn7(self.conv7(x6)), negative_slope=0.01, inplace=True)
90
-
91
- # Apply channel attention
92
- attention_weights = self.channel_attention(x7)
93
- x7_attended = x7 * attention_weights
94
-
95
- # Multi-scale global pooling
96
- max_pool = torch.max(x7_attended, 2)[0]
97
- avg_pool = torch.mean(x7_attended, 2)
98
- global_features = 0.7 * max_pool + 0.3 * avg_pool
99
-
100
- # Enhanced shared features
101
- shared1_fc = self.shared_fc1(global_features)
102
- shared1 = F.leaky_relu(self.gn1(shared1_fc.unsqueeze(-1)).squeeze(-1), negative_slope=0.01, inplace=True)
103
- shared1 = self.dropout_light(shared1)
104
-
105
- shared2_fc = self.shared_fc2(shared1)
106
- shared2 = F.leaky_relu(self.gn2(shared2_fc.unsqueeze(-1)).squeeze(-1), negative_slope=0.01, inplace=True)
107
- shared2 = self.dropout_medium(shared2)
108
-
109
- shared3_fc = self.shared_fc3(shared2)
110
- # Residual connection for shared_fc3
111
- shared_features = F.leaky_relu(shared3_fc + shared2, negative_slope=0.01, inplace=True) # Add residual before ReLU
112
- shared_features = self.dropout_light(shared_features) # Apply dropout after residual and ReLU
113
-
114
- # Classification head
115
- class1 = F.leaky_relu(self.class_fc1(shared_features), negative_slope=0.01, inplace=True)
116
- class1 = self.dropout_light(class1)
117
-
118
- class2 = F.leaky_relu(self.class_fc2(class1), negative_slope=0.01, inplace=True)
119
- class2 = self.dropout_medium(class2)
120
-
121
- class3 = F.leaky_relu(self.class_fc3(class2), negative_slope=0.01, inplace=True)
122
- class3 = self.dropout_light(class3)
123
-
124
- class4 = F.leaky_relu(self.class_fc4(class3), negative_slope=0.01, inplace=True)
125
- # No dropout before the final layer typically
126
-
127
- classification = self.class_fc5(class4) # Raw logits
128
-
129
- return classification
130
-
131
- class PatchClassificationDataset(Dataset):
132
- """
133
- Dataset class for loading saved patches for PointNet classification training.
134
- """
135
-
136
- def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = False, input_dim: int = 10): # Added input_dim
137
- self.dataset_dir = dataset_dir
138
- self.max_points = max_points
139
- self.augment = augment
140
- self.input_dim = input_dim # Store input_dim
141
-
142
- # Load patch files
143
- self.patch_files = []
144
- for file in os.listdir(dataset_dir):
145
- if file.endswith('.pkl'):
146
- self.patch_files.append(os.path.join(dataset_dir, file))
147
-
148
- print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
149
-
150
- def __len__(self):
151
- return len(self.patch_files)
152
-
153
- def __getitem__(self, idx):
154
- """
155
- Load and process a patch for training.
156
- Returns:
157
- patch_data: (input_dim, max_points) tensor of point cloud data
158
- label: scalar tensor for binary classification (0 or 1)
159
- valid_mask: (max_points,) boolean tensor indicating valid points
160
- """
161
- patch_file = self.patch_files[idx]
162
-
163
- with open(patch_file, 'rb') as f:
164
- patch_info = pickle.load(f)
165
-
166
- # Assuming the key in patch_info is now 'patch_10d' or similar, or that patch_info['patch_data'] is (N, 10)
167
- # For this example, let's assume the key is 'patch_data' and it holds the 10D data.
168
- # If your key is 'patch_10d', change 'patch_data' to 'patch_10d' below.
169
- patch_data_nd = patch_info.get('patch_data', patch_info.get('patch_10d', patch_info.get('patch_6d'))) # Try to get 10d, fallback to 6d for now
170
- if patch_data_nd.shape[1] != self.input_dim:
171
- # This is a fallback or error handling if the loaded data isn't 10D.
172
- # You might want to raise an error or handle this case specifically.
173
- # For now, if it's 6D, we'll pad it to 10D with zeros as a placeholder.
174
- # This part needs to be adjusted based on how your 10D data is actually stored.
175
- print(f"Warning: Patch {patch_file} has {patch_data_nd.shape[1]} dimensions, expected {self.input_dim}. Padding with zeros if necessary.")
176
- if patch_data_nd.shape[1] < self.input_dim:
177
- padding = np.zeros((patch_data_nd.shape[0], self.input_dim - patch_data_nd.shape[1]))
178
- patch_data_nd = np.concatenate((patch_data_nd, padding), axis=1)
179
- elif patch_data_nd.shape[1] > self.input_dim:
180
- patch_data_nd = patch_data_nd[:, :self.input_dim]
181
-
182
-
183
- label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
184
-
185
- # Pad or sample points to max_points
186
- num_points = patch_data_nd.shape[0]
187
-
188
- if num_points >= self.max_points:
189
- # Randomly sample max_points
190
- indices = np.random.choice(num_points, self.max_points, replace=False)
191
- patch_sampled = patch_data_nd[indices]
192
- valid_mask = np.ones(self.max_points, dtype=bool)
193
- else:
194
- # Pad with zeros
195
- patch_sampled = np.zeros((self.max_points, self.input_dim)) # Changed to self.input_dim
196
- patch_sampled[:num_points] = patch_data_nd
197
- valid_mask = np.zeros(self.max_points, dtype=bool)
198
- valid_mask[:num_points] = True
199
-
200
- # Data augmentation
201
- if self.augment:
202
- # Note: _augment_patch currently only augments xyz (first 3 dims).
203
- # If other dimensions are geometric and need augmentation, this function needs an update.
204
- patch_sampled = self._augment_patch(patch_sampled, valid_mask)
205
-
206
- # Convert to tensors and transpose for conv1d (channels first)
207
- patch_tensor = torch.from_numpy(patch_sampled.T).float() # (input_dim, max_points)
208
- label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
209
- valid_mask_tensor = torch.from_numpy(valid_mask)
210
-
211
- return patch_tensor, label_tensor, valid_mask_tensor
212
-
213
- def _augment_patch(self, patch, valid_mask):
214
- """
215
- Apply data augmentation to the patch.
216
- Note: This implementation only augments the first 3 dimensions (assumed to be XYZ).
217
- If your 10D representation has other geometric features that need augmentation,
218
- this function should be updated accordingly.
219
- """
220
- valid_points_data = patch[valid_mask]
221
-
222
- if len(valid_points_data) == 0:
223
- return patch
224
-
225
- # Extract XYZ for augmentation (first 3 columns)
226
- valid_points_xyz = valid_points_data[:, :3].copy() # Operate on a copy
227
-
228
- # Random rotation around z-axis
229
- angle = np.random.uniform(0, 2 * np.pi)
230
- cos_angle = np.cos(angle)
231
- sin_angle = np.sin(angle)
232
- rotation_matrix = np.array([
233
- [cos_angle, -sin_angle, 0],
234
- [sin_angle, cos_angle, 0],
235
- [0, 0, 1]
236
- ])
237
-
238
- # Apply rotation to xyz coordinates
239
- valid_points_xyz = valid_points_xyz @ rotation_matrix.T
240
-
241
- # Random jittering
242
- noise = np.random.normal(0, 0.01, valid_points_xyz.shape)
243
- valid_points_xyz += noise
244
-
245
- # Random scaling
246
- scale = np.random.uniform(0.9, 1.1)
247
- valid_points_xyz *= scale
248
-
249
- # Update the original patch data
250
- augmented_patch = patch.copy()
251
- augmented_patch[valid_mask, :3] = valid_points_xyz
252
-
253
- return augmented_patch
254
-
255
- def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
256
- """
257
- Save patches from prediction pipeline to create a training dataset.
258
- Ensure 'patch_data' (or 'patch_10d') in the patch dictionary contains the 10D data.
259
-
260
- Args:
261
- patches: List of patch dictionaries from generate_patches()
262
- dataset_dir: Directory to save the dataset
263
- entry_id: Unique identifier for this entry/image
264
- """
265
- os.makedirs(dataset_dir, exist_ok=True)
266
-
267
- for i, patch in enumerate(patches):
268
- # Create unique filename
269
- filename = f"{entry_id}_patch_{i}.pkl"
270
- filepath = os.path.join(dataset_dir, filename)
271
-
272
- # Skip if file already exists
273
- if os.path.exists(filepath):
274
- continue
275
-
276
- # Ensure the patch data being saved is 10D.
277
- # Example: patch_data_key = 'patch_10d' or 'patch_data'
278
- # if 'patch_data' not in patch or patch['patch_data'].shape[1] != 10:
279
- # print(f"Warning: Patch {i} for entry {entry_id} does not seem to be 10D. Skipping or error handling needed.")
280
- # continue
281
-
282
- with open(filepath, 'wb') as f:
283
- pickle.dump(patch, f)
284
-
285
- print(f"Saved {len(patches)} patches for entry {entry_id}")
286
-
287
- # Create dataloader with custom collate function to filter invalid samples
288
- def collate_fn(batch):
289
- valid_batch = []
290
- for patch_data, label, valid_mask in batch:
291
- # Filter out invalid samples (no valid points)
292
- if valid_mask.sum() > 0:
293
- valid_batch.append((patch_data, label, valid_mask))
294
-
295
- if len(valid_batch) == 0:
296
- return None
297
-
298
- # Stack valid samples
299
- patch_data = torch.stack([item[0] for item in valid_batch])
300
- labels = torch.stack([item[1] for item in valid_batch])
301
- valid_masks = torch.stack([item[2] for item in valid_batch])
302
-
303
- return patch_data, labels, valid_masks
304
-
305
- # Initialize weights using Xavier/Glorot initialization
306
- def init_weights(m):
307
- if isinstance(m, nn.Conv1d):
308
- nn.init.xavier_uniform_(m.weight)
309
- if m.bias is not None:
310
- nn.init.zeros_(m.bias)
311
- elif isinstance(m, nn.Linear):
312
- nn.init.xavier_uniform_(m.weight)
313
- if m.bias is not None:
314
- nn.init.zeros_(m.bias)
315
- elif isinstance(m, nn.BatchNorm1d):
316
- nn.init.ones_(m.weight)
317
- nn.init.zeros_(m.bias)
318
-
319
- def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
320
- lr: float = 0.001, input_dim: int = 10): # Added input_dim
321
- """
322
- Train the ClassificationPointNet model on saved patches.
323
-
324
- Args:
325
- dataset_dir: Directory containing saved patch files
326
- model_save_path: Path to save the trained model
327
- epochs: Number of training epochs
328
- batch_size: Training batch size
329
- lr: Learning rate
330
- input_dim: Dimensionality of the input points (e.g., 10 for 10D)
331
- """
332
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
333
- print(f"Training on device: {device}")
334
-
335
- # Create dataset and dataloader
336
- dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=False, input_dim=input_dim) # Pass input_dim
337
- print(f"Dataset loaded with {len(dataset)} samples")
338
-
339
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=20,
340
- collate_fn=collate_fn, drop_last=True)
341
-
342
- # Initialize model
343
- model = ClassificationPointNet(input_dim=input_dim, max_points=1024) # Pass input_dim
344
- model.apply(init_weights)
345
- model.to(device)
346
-
347
- # Loss function and optimizer (BCE for binary classification)
348
- criterion = nn.BCEWithLogitsLoss()
349
- optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
350
- scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
351
-
352
- # Training loop
353
- model.train()
354
- for epoch in range(epochs):
355
- total_loss = 0.0
356
- correct = 0
357
- total = 0
358
- num_batches = 0
359
-
360
- for batch_idx, batch_data in enumerate(dataloader):
361
- if batch_data is None: # Skip invalid batches
362
- continue
363
-
364
- patch_data, labels, valid_masks = batch_data
365
- patch_data = patch_data.to(device) # (batch_size, input_dim, max_points)
366
- labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
367
-
368
- # Forward pass
369
- optimizer.zero_grad()
370
- outputs = model(patch_data) # (batch_size, 1)
371
- loss = criterion(outputs, labels)
372
-
373
- # Backward pass
374
- loss.backward()
375
- optimizer.step()
376
-
377
- # Statistics
378
- total_loss += loss.item()
379
- predicted = (torch.sigmoid(outputs) > 0.5).float()
380
- total += labels.size(0)
381
- correct += (predicted == labels).sum().item()
382
- num_batches += 1
383
-
384
- if batch_idx % 50 == 0:
385
- print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
386
- f"Loss: {loss.item():.6f}, "
387
- f"Accuracy: {100 * correct / total:.2f}%")
388
-
389
- avg_loss = total_loss / num_batches if num_batches > 0 else 0
390
- accuracy = 100 * correct / total if total > 0 else 0
391
-
392
- print(f"Epoch {epoch+1}/{epochs} completed, "
393
- f"Avg Loss: {avg_loss:.6f}, "
394
- f"Accuracy: {accuracy:.2f}%")
395
-
396
- scheduler.step()
397
-
398
- # Save model checkpoint every epoch
399
- checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
400
- torch.save({
401
- 'model_state_dict': model.state_dict(),
402
- 'optimizer_state_dict': optimizer.state_dict(),
403
- 'epoch': epoch + 1,
404
- 'loss': avg_loss,
405
- 'accuracy': accuracy,
406
- 'input_dim': input_dim, # Save input_dim with checkpoint
407
- }, checkpoint_path)
408
-
409
- # Save the trained model
410
- torch.save({
411
- 'model_state_dict': model.state_dict(),
412
- 'optimizer_state_dict': optimizer.state_dict(),
413
- 'epoch': epochs,
414
- 'input_dim': input_dim, # Save input_dim with final model
415
- }, model_save_path)
416
-
417
- print(f"Model saved to {model_save_path}")
418
- return model
419
-
420
- def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
421
- """
422
- Load a trained ClassificationPointNet model.
423
-
424
- Args:
425
- model_path: Path to the saved model
426
- device: Device to load the model on
427
-
428
- Returns:
429
- Loaded ClassificationPointNet model
430
- """
431
- if device is None:
432
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
433
-
434
- checkpoint = torch.load(model_path, map_location=device)
435
-
436
- # Load input_dim from checkpoint if available, otherwise default to 10
437
- # For older models saved without input_dim, you might need to specify it or assume a default.
438
- input_dim = checkpoint.get('input_dim', 10)
439
-
440
- model = ClassificationPointNet(input_dim=input_dim, max_points=1024) # Use loaded or default input_dim
441
- model.load_state_dict(checkpoint['model_state_dict'])
442
-
443
- model.to(device)
444
- model.eval()
445
-
446
- return model
447
-
448
- def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
449
- """
450
- Predict binary classification from a patch using trained PointNet.
451
- Assumes the model's input_dim matches the data.
452
-
453
- Args:
454
- model: Trained ClassificationPointNet model
455
- patch: Dictionary containing patch data. Expects a key like 'patch_data' or 'patch_10d' with (N, 10) shape.
456
- device: Device to run prediction on
457
-
458
- Returns:
459
- tuple of (predicted_class, confidence)
460
- predicted_class: int (0 for not edge, 1 for edge)
461
- confidence: float representing confidence score (0-1)
462
- """
463
- if device is None:
464
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
465
-
466
- # Determine input_dim from the model
467
- input_dim = model.conv1.in_channels
468
-
469
- # Assuming the key in patch_info is now 'patch_10d' or similar, or that patch_info['patch_data'] is (N, 10)
470
- # For this example, let's assume the key is 'patch_data' and it holds the 10D data.
471
- # If your key is 'patch_10d', change 'patch_data' to 'patch_10d' below.
472
- patch_data_nd = patch.get('patch_data', patch.get('patch_10d', patch.get('patch_6d'))) # Try to get 10d, fallback to 6d
473
-
474
- if patch_data_nd.shape[1] != input_dim:
475
- # Handle dimension mismatch, e.g., by padding or raising an error
476
- print(f"Warning: Input patch has {patch_data_nd.shape[1]} dimensions, but model expects {input_dim}. Adjusting...")
477
- if patch_data_nd.shape[1] < input_dim:
478
- padding = np.zeros((patch_data_nd.shape[0], input_dim - patch_data_nd.shape[1]))
479
- patch_data_nd = np.concatenate((patch_data_nd, padding), axis=1)
480
- elif patch_data_nd.shape[1] > input_dim:
481
- patch_data_nd = patch_data_nd[:, :input_dim]
482
-
483
- # Prepare input
484
- max_points = model.max_points # Use max_points from the model instance
485
- num_points = patch_data_nd.shape[0]
486
-
487
- if num_points >= max_points:
488
- # Sample points
489
- indices = np.random.choice(num_points, max_points, replace=False)
490
- patch_sampled = patch_data_nd[indices]
491
- else:
492
- # Pad with zeros
493
- patch_sampled = np.zeros((max_points, input_dim)) # Use model's input_dim
494
- patch_sampled[:num_points] = patch_data_nd
495
-
496
- # Convert to tensor
497
- patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, input_dim, max_points)
498
- patch_tensor = patch_tensor.to(device)
499
-
500
- # Predict
501
- model.eval() # Ensure model is in eval mode
502
- with torch.no_grad():
503
- outputs = model(patch_tensor) # (1, 1)
504
- probability = torch.sigmoid(outputs).item()
505
- predicted_class = int(probability > 0.5)
506
-
507
- return predicted_class, probability
508
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fast_pointnet_v2.py CHANGED
@@ -1,3 +1,13 @@
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import torch
3
  import torch.nn as nn
@@ -568,4 +578,4 @@ def predict_vertex_from_patch(model: FastPointNet, patch: np.ndarray, device: to
568
  offset = patch['cluster_center']
569
  position += offset
570
 
571
- return position, score, classification
 
1
+ # This file defines a FastPointNet model for 3D vertex prediction from point clouds.
2
+ # It includes:
3
+ # 1. `FastPointNet`: A deep neural network with enhancements like residual connections,
4
+ # channel attention, and multi-scale pooling. It predicts 3D coordinates,
5
+ # and optionally, confidence scores and classification labels.
6
+ # 2. `PatchDataset`: A PyTorch Dataset for loading, preprocessing, and augmenting
7
+ # 11-dimensional point cloud patches.
8
+ # 3. Utility functions for:
9
+ # - Training the model (`train_pointnet`) with custom loss and optimization.
10
+ # - Loading/saving models, and performing inference (`predict_vertex_from_patch`).
11
  import os
12
  import torch
13
  import torch.nn as nn
 
578
  offset = patch['cluster_center']
579
  position += offset
580
 
581
+ return position, score, classification
fast_pointnet_v3.py DELETED
@@ -1,605 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import numpy as np
6
- import pickle
7
- from torch.utils.data import Dataset, DataLoader
8
- from typing import List, Dict, Tuple, Optional
9
- import json
10
-
11
- class FastPointNet(nn.Module):
12
- """
13
- Fast PointNet implementation for 3D vertex prediction from point cloud patches.
14
- Takes 11D point clouds and predicts 3D vertex coordinates.
15
- Enhanced with transformer attention, deeper architecture, and moderate capacity increase.
16
- """
17
- def __init__(self, input_dim=11, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1):
18
- super(FastPointNet, self).__init__()
19
- self.max_points = max_points
20
- self.predict_score = predict_score
21
- self.predict_class = predict_class
22
- self.num_classes = num_classes
23
-
24
- # Enhanced point-wise MLPs with moderate capacity increase
25
- self.conv1 = nn.Conv1d(input_dim, 96, 1) # 64 -> 96
26
- self.conv2 = nn.Conv1d(96, 192, 1) # 128 -> 192
27
- self.conv3 = nn.Conv1d(192, 384, 1) # 256 -> 384
28
- self.conv4 = nn.Conv1d(384, 768, 1) # 512 -> 768
29
- self.conv5 = nn.Conv1d(768, 1536, 1) # 1024 -> 1536
30
- self.conv6 = nn.Conv1d(1536, 1536, 1) # Keep same
31
- self.conv7 = nn.Conv1d(1536, 2048, 1) # Reduce from 1536 to 2048 for transformer
32
-
33
- # Lightweight Self-Attention Transformer Block
34
- self.transformer_dim = 2048
35
- self.num_heads = 8
36
- self.transformer_block = nn.MultiheadAttention(
37
- embed_dim=self.transformer_dim,
38
- num_heads=self.num_heads,
39
- dropout=0.1,
40
- batch_first=False
41
- )
42
- self.transformer_norm1 = nn.LayerNorm(self.transformer_dim)
43
- self.transformer_norm2 = nn.LayerNorm(self.transformer_dim)
44
-
45
- # Transformer FFN
46
- self.transformer_ffn = nn.Sequential(
47
- nn.Linear(self.transformer_dim, self.transformer_dim * 2),
48
- nn.GELU(),
49
- nn.Dropout(0.1),
50
- nn.Linear(self.transformer_dim * 2, self.transformer_dim),
51
- nn.Dropout(0.1)
52
- )
53
-
54
- # Enhanced channel attention mechanism
55
- self.channel_attention = nn.Sequential(
56
- nn.AdaptiveAvgPool1d(1),
57
- nn.Conv1d(2048, 192, 1), # 128 -> 192
58
- nn.GELU(),
59
- nn.Conv1d(192, 2048, 1),
60
- nn.Sigmoid()
61
- )
62
-
63
- # Enhanced shared features with moderate increase
64
- self.shared_fc1 = nn.Linear(2048, 1536) # 1024 -> 1536
65
- self.shared_fc2 = nn.Linear(1536, 768) # 512 -> 768
66
- self.shared_fc3 = nn.Linear(768, 768) # Additional layer
67
-
68
- # Enhanced position prediction head
69
- self.pos_fc1 = nn.Linear(768, 768) # 512 -> 768
70
- self.pos_fc2 = nn.Linear(768, 384) # 256 -> 384
71
- self.pos_fc3 = nn.Linear(384, 192) # 128 -> 192
72
- self.pos_fc4 = nn.Linear(192, 96) # 64 -> 96
73
- self.pos_fc5 = nn.Linear(96, output_dim)
74
-
75
- # Enhanced score prediction head
76
- if self.predict_score:
77
- self.score_fc1 = nn.Linear(768, 768)
78
- self.score_fc2 = nn.Linear(768, 384)
79
- self.score_fc3 = nn.Linear(384, 192)
80
- self.score_fc4 = nn.Linear(192, 96)
81
- self.score_fc5 = nn.Linear(96, 1)
82
-
83
- # Enhanced classification head
84
- if self.predict_class:
85
- self.class_fc1 = nn.Linear(768, 768)
86
- self.class_fc2 = nn.Linear(768, 384)
87
- self.class_fc3 = nn.Linear(384, 192)
88
- self.class_fc4 = nn.Linear(192, 96)
89
- self.class_fc5 = nn.Linear(96, num_classes)
90
-
91
- # Batch normalization layers
92
- self.bn1 = nn.BatchNorm1d(96, momentum=0.1)
93
- self.bn2 = nn.BatchNorm1d(192, momentum=0.1)
94
- self.bn3 = nn.BatchNorm1d(384, momentum=0.1)
95
- self.bn4 = nn.BatchNorm1d(768, momentum=0.1)
96
- self.bn5 = nn.BatchNorm1d(1536, momentum=0.1)
97
- self.bn6 = nn.BatchNorm1d(1536, momentum=0.1)
98
- self.bn7 = nn.BatchNorm1d(2048, momentum=0.1)
99
-
100
- # Group normalization for shared layers
101
- self.gn1 = nn.GroupNorm(48, 1536) # 32 -> 48 groups
102
- self.gn2 = nn.GroupNorm(24, 768) # 16 -> 24 groups
103
-
104
- # Dropout with different rates
105
- self.dropout_light = nn.Dropout(0.1)
106
- self.dropout_medium = nn.Dropout(0.2)
107
- self.dropout_heavy = nn.Dropout(0.3)
108
-
109
- def forward(self, x):
110
- """
111
- Forward pass with transformer attention and residual connections
112
- Args:
113
- x: (batch_size, input_dim, max_points) tensor
114
- Returns:
115
- Tuple containing predictions based on configuration
116
- """
117
- batch_size = x.size(0)
118
-
119
- # Enhanced point-wise feature extraction
120
- x1 = F.gelu(self.bn1(self.conv1(x)))
121
- x2 = F.gelu(self.bn2(self.conv2(x1)))
122
- x3 = F.gelu(self.bn3(self.conv3(x2)))
123
- x4 = F.gelu(self.bn4(self.conv4(x3)))
124
- x5 = F.gelu(self.bn5(self.conv5(x4)))
125
-
126
- # Residual connection
127
- x6 = F.gelu(self.bn6(self.conv6(x5)) + x5)
128
- x7 = F.gelu(self.bn7(self.conv7(x6)))
129
-
130
- # Apply transformer attention
131
- # Reshape for transformer: (seq_len, batch_size, embed_dim)
132
- x7_reshaped = x7.permute(2, 0, 1) # (max_points, batch_size, 2048)
133
-
134
- # Self-attention with residual connection
135
- attn_out, _ = self.transformer_block(x7_reshaped, x7_reshaped, x7_reshaped)
136
- x7_attn = self.transformer_norm1(x7_reshaped + attn_out)
137
-
138
- # Transformer FFN with residual connection
139
- ffn_out = self.transformer_ffn(x7_attn)
140
- x7_transformer = self.transformer_norm2(x7_attn + ffn_out)
141
-
142
- # Reshape back: (batch_size, embed_dim, seq_len)
143
- x7_transformer = x7_transformer.permute(1, 2, 0)
144
-
145
- # Apply channel attention
146
- attention_weights = self.channel_attention(x7_transformer)
147
- x7_attended = x7_transformer * attention_weights
148
-
149
- # Multi-scale global pooling
150
- max_pool = torch.max(x7_attended, 2)[0] # (batch_size, 2048)
151
- avg_pool = torch.mean(x7_attended, 2) # (batch_size, 2048)
152
-
153
- # Weighted combination of pooling operations
154
- global_features = 0.7 * max_pool + 0.3 * avg_pool
155
-
156
- # Enhanced shared features with residual connections
157
- shared1 = F.gelu(self.gn1(self.shared_fc1(global_features).unsqueeze(-1)).squeeze(-1))
158
- shared1 = self.dropout_light(shared1)
159
-
160
- shared2 = F.gelu(self.gn2(self.shared_fc2(shared1).unsqueeze(-1)).squeeze(-1))
161
- shared2 = self.dropout_medium(shared2)
162
-
163
- # Additional shared layer with residual connection
164
- shared3 = F.gelu(self.shared_fc3(shared2))
165
- shared_features = self.dropout_light(shared3) + shared2
166
-
167
- # Enhanced position prediction
168
- pos1 = F.gelu(self.pos_fc1(shared_features))
169
- pos1 = self.dropout_light(pos1)
170
-
171
- pos2 = F.gelu(self.pos_fc2(pos1))
172
- pos2 = self.dropout_medium(pos2)
173
-
174
- pos3 = F.gelu(self.pos_fc3(pos2))
175
- pos3 = self.dropout_light(pos3)
176
-
177
- pos4 = F.gelu(self.pos_fc4(pos3))
178
- position = self.pos_fc5(pos4)
179
-
180
- outputs = [position]
181
-
182
- if self.predict_score:
183
- # Enhanced score prediction
184
- score1 = F.gelu(self.score_fc1(shared_features))
185
- score1 = self.dropout_light(score1)
186
- score2 = F.gelu(self.score_fc2(score1))
187
- score2 = self.dropout_medium(score2)
188
- score3 = F.gelu(self.score_fc3(score2))
189
- score3 = self.dropout_light(score3)
190
- score4 = F.gelu(self.score_fc4(score3))
191
- score = F.softplus(self.score_fc5(score4))
192
- outputs.append(score)
193
-
194
- if self.predict_class:
195
- # Classification prediction
196
- class1 = F.gelu(self.class_fc1(shared_features))
197
- class1 = self.dropout_light(class1)
198
- class2 = F.gelu(self.class_fc2(class1))
199
- class2 = self.dropout_medium(class2)
200
- class3 = F.gelu(self.class_fc3(class2))
201
- class3 = self.dropout_light(class3)
202
- class4 = F.gelu(self.class_fc4(class3))
203
- classification = self.class_fc5(class4)
204
- outputs.append(classification)
205
-
206
- # Return outputs based on configuration
207
- if len(outputs) == 1:
208
- return outputs[0]
209
- elif len(outputs) == 2:
210
- if self.predict_score:
211
- return outputs[0], outputs[1]
212
- else:
213
- return outputs[0], outputs[1]
214
- else:
215
- return outputs[0], outputs[1], outputs[2]
216
-
217
- class PatchDataset(Dataset):
218
- """
219
- Dataset class for loading saved patches for PointNet training.
220
- Updated for 11D patches.
221
- """
222
-
223
- def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
224
- self.dataset_dir = dataset_dir
225
- self.max_points = max_points
226
- self.augment = augment
227
-
228
- # Load patch files
229
- self.patch_files = []
230
- for file in os.listdir(dataset_dir):
231
- if file.endswith('.pkl'):
232
- self.patch_files.append(os.path.join(dataset_dir, file))
233
-
234
- print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
235
-
236
- def __len__(self):
237
- return len(self.patch_files)
238
-
239
- def __getitem__(self, idx):
240
- """
241
- Load and process a patch for training.
242
- Returns:
243
- patch_data: (11, max_points) tensor of point cloud data
244
- target: (3,) tensor of target 3D coordinates
245
- valid_mask: (max_points,) boolean tensor indicating valid points
246
- distance_to_gt: scalar tensor of distance from initial prediction to GT
247
- classification: scalar tensor for binary classification (1 if GT vertex present, 0 if not)
248
- """
249
- patch_file = self.patch_files[idx]
250
-
251
- with open(patch_file, 'rb') as f:
252
- patch_info = pickle.load(f)
253
-
254
- patch_11d = patch_info['patch_11d'] # (N, 11) - Updated for 11D
255
- target = patch_info.get('assigned_wf_vertex', None) # (3,) or None
256
- initial_pred = patch_info.get('cluster_center', None) # (3,) or None
257
-
258
- # Determine classification label based on GT vertex presence
259
- has_gt_vertex = 1.0 if target is not None else 0.0
260
-
261
- # Handle patches without ground truth
262
- if target is None:
263
- # Use a dummy target for consistency, but mark as invalid with classification
264
- target = np.zeros(3)
265
- else:
266
- target = np.array(target)
267
-
268
- # Pad or sample points to max_points
269
- num_points = patch_11d.shape[0]
270
-
271
- if num_points >= self.max_points:
272
- # Randomly sample max_points
273
- indices = np.random.choice(num_points, self.max_points, replace=False)
274
- patch_sampled = patch_11d[indices]
275
- valid_mask = np.ones(self.max_points, dtype=bool)
276
- else:
277
- # Pad with zeros
278
- patch_sampled = np.zeros((self.max_points, 11)) # Updated for 11D
279
- patch_sampled[:num_points] = patch_11d
280
- valid_mask = np.zeros(self.max_points, dtype=bool)
281
- valid_mask[:num_points] = True
282
-
283
- # Data augmentation (only if GT vertex is present)
284
- if self.augment and has_gt_vertex > 0:
285
- patch_sampled, target = self._augment_patch(patch_sampled, valid_mask, target)
286
-
287
- # Convert to tensors and transpose for conv1d (channels first)
288
- patch_tensor = torch.from_numpy(patch_sampled.T).float() # (11, max_points)
289
- target_tensor = torch.from_numpy(target).float() # (3,)
290
- valid_mask_tensor = torch.from_numpy(valid_mask)
291
-
292
- # Handle initial_pred
293
- if initial_pred is not None:
294
- initial_pred_tensor = torch.from_numpy(initial_pred).float()
295
- else:
296
- initial_pred_tensor = torch.zeros(3).float()
297
-
298
- # Classification tensor
299
- classification_tensor = torch.tensor(has_gt_vertex).float()
300
-
301
- return patch_tensor, target_tensor, valid_mask_tensor, initial_pred_tensor, classification_tensor
302
-
303
- def _augment_patch(self, patch_sampled, valid_mask, target):
304
- """
305
- Apply data augmentation to patch and target.
306
- Only augment valid points and update target accordingly.
307
- """
308
- valid_points = patch_sampled[valid_mask]
309
-
310
- if len(valid_points) > 0:
311
- # Random rotation around Z-axis (small angle)
312
- angle = np.random.uniform(-np.pi/12, np.pi/12) # ±15 degrees
313
- cos_a, sin_a = np.cos(angle), np.sin(angle)
314
- rotation_matrix = np.array([[cos_a, -sin_a, 0],
315
- [sin_a, cos_a, 0],
316
- [0, 0, 1]])
317
-
318
- # Apply rotation to xyz coordinates
319
- valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
320
- target = target @ rotation_matrix.T
321
-
322
- # Small random translation
323
- translation = np.random.uniform(-0.05, 0.05, 3)
324
- valid_points[:, :3] += translation
325
- target += translation
326
-
327
- # Random scaling (small)
328
- scale = np.random.uniform(0.95, 1.05)
329
- valid_points[:, :3] *= scale
330
- target *= scale
331
-
332
- # Add small noise to features (not coordinates)
333
- if valid_points.shape[1] > 3:
334
- noise = np.random.normal(0, 0.01, valid_points[:, 3:].shape)
335
- valid_points[:, 3:] += noise
336
-
337
- # Update patch with augmented valid points
338
- patch_sampled[valid_mask] = valid_points
339
-
340
- return patch_sampled, target
341
-
342
- def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
343
- """
344
- Save patches from prediction pipeline to create a training dataset.
345
-
346
- Args:
347
- patches: List of patch dictionaries from generate_patches()
348
- dataset_dir: Directory to save the dataset
349
- entry_id: Unique identifier for this entry/image
350
- """
351
- os.makedirs(dataset_dir, exist_ok=True)
352
-
353
- for i, patch in enumerate(patches):
354
- # Create unique filename
355
- filename = f"{entry_id}_patch_{i}.pkl"
356
- filepath = os.path.join(dataset_dir, filename)
357
-
358
- # Skip if file already exists
359
- if os.path.exists(filepath):
360
- continue
361
-
362
- # Save patch data
363
- with open(filepath, 'wb') as f:
364
- pickle.dump(patch, f)
365
-
366
- print(f"Saved {len(patches)} patches for entry {entry_id}")
367
-
368
- # Create dataloader with custom collate function to filter invalid samples
369
- def collate_fn(batch):
370
- valid_batch = []
371
- for patch_data, target, valid_mask, initial_pred, classification in batch:
372
- # Filter out invalid samples (no valid points)
373
- if valid_mask.sum() > 0:
374
- valid_batch.append((patch_data, target, valid_mask, initial_pred, classification))
375
-
376
- if len(valid_batch) == 0:
377
- return None
378
-
379
- # Stack valid samples
380
- patch_data = torch.stack([item[0] for item in valid_batch])
381
- targets = torch.stack([item[1] for item in valid_batch])
382
- valid_masks = torch.stack([item[2] for item in valid_batch])
383
- initial_preds = torch.stack([item[3] for item in valid_batch])
384
- classifications = torch.stack([item[4] for item in valid_batch])
385
-
386
- return patch_data, targets, valid_masks, initial_preds, classifications
387
-
388
- # Initialize weights using Kaiming initialization for LeakyReLU
389
- def init_weights(m):
390
- if isinstance(m, nn.Conv1d):
391
- nn.init.kaiming_uniform_(m.weight, a=0.01, mode='fan_in', nonlinearity='leaky_relu')
392
- if m.bias is not None:
393
- nn.init.zeros_(m.bias)
394
- elif isinstance(m, nn.Linear):
395
- nn.init.kaiming_uniform_(m.weight, a=0.01, mode='fan_in', nonlinearity='leaky_relu')
396
- if m.bias is not None:
397
- nn.init.zeros_(m.bias)
398
- elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
399
- nn.init.ones_(m.weight)
400
- nn.init.zeros_(m.bias)
401
-
402
- def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
403
- score_weight: float = 0.1, class_weight: float = 0.5):
404
- """
405
- Train the FastPointNet model on saved patches.
406
- Updated for 11D input.
407
- """
408
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
409
- print(f"Training on device: {device}")
410
-
411
- # Create dataset and dataloader
412
- dataset = PatchDataset(dataset_dir, max_points=1024, augment=True) # Enable augmentation
413
- print(f"Dataset loaded with {len(dataset)} samples")
414
-
415
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=20,
416
- collate_fn=collate_fn, drop_last=True)
417
-
418
- # Initialize model with 11D input
419
- model = FastPointNet(input_dim=11, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1)
420
-
421
- model.apply(init_weights)
422
- model.to(device)
423
-
424
- # Loss functions with label smoothing for classification
425
- position_criterion = nn.SmoothL1Loss() # More robust than MSE
426
- score_criterion = nn.SmoothL1Loss()
427
- classification_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(2.0)) # Weight positive class more
428
-
429
- # AdamW optimizer with weight decay
430
- optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4, betas=(0.9, 0.999))
431
-
432
- # Cosine annealing scheduler for better convergence
433
- scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)
434
-
435
- # Training loop
436
- model.train()
437
- for epoch in range(epochs):
438
- total_loss = 0.0
439
- total_pos_loss = 0.0
440
- total_score_loss = 0.0
441
- total_class_loss = 0.0
442
- num_batches = 0
443
-
444
- for batch_idx, batch_data in enumerate(dataloader):
445
- if batch_data is None: # Skip invalid batches
446
- continue
447
-
448
- patch_data, targets, valid_masks, initial_preds, classifications = batch_data
449
- patch_data = patch_data.to(device) # (batch_size, 11, max_points)
450
- targets = targets.to(device) # (batch_size, 3)
451
- classifications = classifications.to(device) # (batch_size,)
452
-
453
- # Forward pass
454
- optimizer.zero_grad()
455
- predictions, predicted_scores, predicted_classes = model(patch_data)
456
-
457
- # Compute actual distance from predictions to targets
458
- actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True)
459
-
460
- # Only compute position and score losses for samples with GT vertices
461
- has_gt_mask = classifications > 0.5
462
-
463
- if has_gt_mask.sum() > 0:
464
- # Position loss only for samples with GT vertices
465
- pos_loss = position_criterion(predictions[has_gt_mask], targets[has_gt_mask])
466
- score_loss = score_criterion(predicted_scores[has_gt_mask], actual_distances[has_gt_mask])
467
- else:
468
- pos_loss = torch.tensor(0.0, device=device)
469
- score_loss = torch.tensor(0.0, device=device)
470
-
471
- # Classification loss for all samples
472
- class_loss = classification_criterion(predicted_classes.squeeze(), classifications)
473
-
474
- # Combined loss
475
- total_batch_loss = pos_loss + score_weight * score_loss + class_weight * class_loss
476
-
477
- # Backward pass
478
- total_batch_loss.backward()
479
-
480
- # Gradient clipping for stability
481
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
482
-
483
- optimizer.step()
484
-
485
- total_loss += total_batch_loss.item()
486
- total_pos_loss += pos_loss.item()
487
- total_score_loss += score_loss.item()
488
- total_class_loss += class_loss.item()
489
- num_batches += 1
490
-
491
- if batch_idx % 50 == 0:
492
- print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
493
- f"Total Loss: {total_batch_loss.item():.6f}, "
494
- f"Pos Loss: {pos_loss.item():.6f}, "
495
- f"Score Loss: {score_loss.item():.6f}, "
496
- f"Class Loss: {class_loss.item():.6f}")
497
-
498
- avg_loss = total_loss / num_batches if num_batches > 0 else 0
499
- avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
500
- avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
501
- avg_class_loss = total_class_loss / num_batches if num_batches > 0 else 0
502
-
503
- print(f"Epoch {epoch+1}/{epochs} completed, "
504
- f"Avg Total Loss: {avg_loss:.6f}, "
505
- f"Avg Pos Loss: {avg_pos_loss:.6f}, "
506
- f"Avg Score Loss: {avg_score_loss:.6f}, "
507
- f"Avg Class Loss: {avg_class_loss:.6f}")
508
-
509
- scheduler.step()
510
-
511
- # Save model checkpoint every 10 epochs
512
- if (epoch + 1) % 10 == 0:
513
- checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
514
- torch.save({
515
- 'model_state_dict': model.state_dict(),
516
- 'optimizer_state_dict': optimizer.state_dict(),
517
- 'epoch': epoch + 1,
518
- 'loss': avg_loss,
519
- }, checkpoint_path)
520
-
521
- # Save the trained model
522
- torch.save({
523
- 'model_state_dict': model.state_dict(),
524
- 'optimizer_state_dict': optimizer.state_dict(),
525
- 'epoch': epochs,
526
- }, model_save_path)
527
-
528
- print(f"Model saved to {model_save_path}")
529
- return model
530
-
531
- def load_pointnet_model(model_path: str, device: torch.device = None, predict_score: bool = True) -> FastPointNet:
532
- """
533
- Load a trained FastPointNet model.
534
- Updated for 11D input.
535
- """
536
- if device is None:
537
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
538
-
539
- model = FastPointNet(input_dim=11, output_dim=3, max_points=1024, predict_score=predict_score)
540
-
541
- checkpoint = torch.load(model_path, map_location=device)
542
- model.load_state_dict(checkpoint['model_state_dict'])
543
-
544
- model.to(device)
545
- model.eval()
546
-
547
- return model
548
-
549
- def predict_vertex_from_patch(model: FastPointNet, patch: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float, float]:
550
- """
551
- Predict 3D vertex coordinates, confidence score, and classification from a patch using trained PointNet.
552
- Updated for 11D patches.
553
- """
554
- if device is None:
555
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
556
-
557
- patch_11d = patch['patch_11d'] # (N, 11) - Updated for 11D
558
-
559
- # Prepare input
560
- max_points = 1024
561
- num_points = patch_11d.shape[0]
562
-
563
- if num_points >= max_points:
564
- # Sample points
565
- indices = np.random.choice(num_points, max_points, replace=False)
566
- patch_sampled = patch_11d[indices]
567
- else:
568
- # Pad with zeros
569
- patch_sampled = np.zeros((max_points, 11)) # Updated for 11D
570
- patch_sampled[:num_points] = patch_11d
571
-
572
- # Convert to tensor
573
- patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 11, max_points)
574
- patch_tensor = patch_tensor.to(device)
575
-
576
- # Predict
577
- with torch.no_grad():
578
- outputs = model(patch_tensor)
579
-
580
- if model.predict_score and model.predict_class:
581
- position, score, classification = outputs
582
- position = position.cpu().numpy().squeeze()
583
- score = score.cpu().numpy().squeeze()
584
- classification = torch.sigmoid(classification).cpu().numpy().squeeze() # Apply sigmoid for probability
585
- elif model.predict_score:
586
- position, score = outputs
587
- position = position.cpu().numpy().squeeze()
588
- score = score.cpu().numpy().squeeze()
589
- classification = None
590
- elif model.predict_class:
591
- position, classification = outputs
592
- position = position.cpu().numpy().squeeze()
593
- score = None
594
- classification = torch.sigmoid(classification).cpu().numpy().squeeze() # Apply sigmoid for probability
595
- else:
596
- position = outputs
597
- position = position.cpu().numpy().squeeze()
598
- score = None
599
- classification = None
600
-
601
- # Apply offset correction
602
- offset = patch['cluster_center']
603
- position += offset
604
-
605
- return position, score, classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fast_voxel.py DELETED
@@ -1,591 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import numpy as np
6
- import pickle
7
- from torch.utils.data import Dataset, DataLoader
8
- from typing import List, Dict, Tuple, Optional
9
- import json
10
-
11
- class Fast3DCNN(nn.Module):
12
- """
13
- Fast 3D CNN implementation for 3D vertex prediction from voxelized point cloud patches.
14
- Takes 7D point clouds (x,y,z,r,g,b,filtered_flag) and predicts 3D vertex coordinates.
15
- Uses voxelization and 3D convolutions instead of PointNet architecture.
16
- """
17
- def __init__(self, input_channels=7, output_dim=3, voxel_size=32, predict_score=True, predict_class=True, num_classes=1):
18
- super(Fast3DCNN, self).__init__()
19
- self.voxel_size = voxel_size
20
- self.predict_score = predict_score
21
- self.predict_class = predict_class
22
- self.num_classes = num_classes
23
-
24
- # 3D Convolutional layers for feature extraction
25
- self.conv1 = nn.Conv3d(input_channels, 64, kernel_size=3, padding=1)
26
- self.conv2 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
27
- self.conv3 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
28
- self.conv4 = nn.Conv3d(256, 512, kernel_size=3, padding=1)
29
- self.conv5 = nn.Conv3d(512, 512, kernel_size=3, padding=1)
30
-
31
- # Additional convolutional layers for deeper feature extraction
32
- self.conv6 = nn.Conv3d(512, 1024, kernel_size=3, padding=1)
33
-
34
- # Batch normalization layers
35
- self.bn1 = nn.BatchNorm3d(64)
36
- self.bn2 = nn.BatchNorm3d(128)
37
- self.bn3 = nn.BatchNorm3d(256)
38
- self.bn4 = nn.BatchNorm3d(512)
39
- self.bn5 = nn.BatchNorm3d(512)
40
- self.bn6 = nn.BatchNorm3d(1024)
41
-
42
- # Max pooling layers
43
- self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
44
-
45
- # Calculate the size after convolutions and pooling
46
- # Starting with voxel_size^3, after 3 pooling operations: voxel_size / 8
47
- final_size = voxel_size // 8
48
- flattened_size = 1024 * (final_size ** 3)
49
-
50
- # Adaptive pooling to handle variable sizes
51
- self.adaptive_pool = nn.AdaptiveAvgPool3d((4, 4, 4))
52
- flattened_size = 1024 * 4 * 4 * 4
53
-
54
- # Shared fully connected layers
55
- self.shared_fc1 = nn.Linear(flattened_size, 1024)
56
- self.shared_fc2 = nn.Linear(1024, 512)
57
-
58
- # Position prediction head
59
- self.pos_fc1 = nn.Linear(512, 512)
60
- self.pos_fc2 = nn.Linear(512, 256)
61
- self.pos_fc3 = nn.Linear(256, 128)
62
- self.pos_fc4 = nn.Linear(128, output_dim)
63
-
64
- # Score prediction head
65
- if self.predict_score:
66
- self.score_fc1 = nn.Linear(512, 512)
67
- self.score_fc2 = nn.Linear(512, 256)
68
- self.score_fc3 = nn.Linear(256, 128)
69
- self.score_fc4 = nn.Linear(128, 64)
70
- self.score_fc5 = nn.Linear(64, 1)
71
-
72
- # Classification head
73
- if self.predict_class:
74
- self.class_fc1 = nn.Linear(512, 512)
75
- self.class_fc2 = nn.Linear(512, 256)
76
- self.class_fc3 = nn.Linear(256, 128)
77
- self.class_fc4 = nn.Linear(128, 64)
78
- self.class_fc5 = nn.Linear(64, num_classes)
79
-
80
- # Dropout layers
81
- self.dropout_light = nn.Dropout(0.2)
82
- self.dropout_medium = nn.Dropout(0.3)
83
- self.dropout_heavy = nn.Dropout(0.4)
84
-
85
- def forward(self, x):
86
- """
87
- Forward pass
88
- Args:
89
- x: (batch_size, input_channels, voxel_size, voxel_size, voxel_size) tensor
90
- Returns:
91
- Tuple containing predictions based on configuration:
92
- - position: (batch_size, output_dim) tensor of predicted 3D coordinates
93
- - score: (batch_size, 1) tensor of predicted distance to GT (if predict_score=True)
94
- - classification: (batch_size, num_classes) tensor of class logits (if predict_class=True)
95
- """
96
- batch_size = x.size(0)
97
-
98
- # 3D Convolutional feature extraction
99
- x1 = F.relu(self.bn1(self.conv1(x)))
100
- x1 = self.pool(x1)
101
-
102
- x2 = F.relu(self.bn2(self.conv2(x1)))
103
- x2 = self.pool(x2)
104
-
105
- x3 = F.relu(self.bn3(self.conv3(x2)))
106
- x3 = self.pool(x3)
107
-
108
- x4 = F.relu(self.bn4(self.conv4(x3)))
109
- x5 = F.relu(self.bn5(self.conv5(x4)))
110
- x6 = F.relu(self.bn6(self.conv6(x5)))
111
-
112
- # Adaptive pooling to ensure consistent size
113
- x6 = self.adaptive_pool(x6)
114
-
115
- # Flatten for fully connected layers
116
- global_features = x6.view(batch_size, -1)
117
-
118
- # Shared features
119
- shared1 = F.relu(self.shared_fc1(global_features))
120
- shared1 = self.dropout_light(shared1)
121
- shared2 = F.relu(self.shared_fc2(shared1))
122
- shared_features = self.dropout_medium(shared2)
123
-
124
- # Position prediction
125
- pos1 = F.relu(self.pos_fc1(shared_features))
126
- pos1 = self.dropout_light(pos1)
127
- pos2 = F.relu(self.pos_fc2(pos1))
128
- pos2 = self.dropout_medium(pos2)
129
- pos3 = F.relu(self.pos_fc3(pos2))
130
- pos3 = self.dropout_light(pos3)
131
- position = self.pos_fc4(pos3)
132
-
133
- outputs = [position]
134
-
135
- if self.predict_score:
136
- # Score prediction
137
- score1 = F.relu(self.score_fc1(shared_features))
138
- score1 = self.dropout_light(score1)
139
- score2 = F.relu(self.score_fc2(score1))
140
- score2 = self.dropout_medium(score2)
141
- score3 = F.relu(self.score_fc3(score2))
142
- score3 = self.dropout_light(score3)
143
- score4 = F.relu(self.score_fc4(score3))
144
- score4 = self.dropout_light(score4)
145
- score = F.relu(self.score_fc5(score4))
146
- outputs.append(score)
147
-
148
- if self.predict_class:
149
- # Classification prediction
150
- class1 = F.relu(self.class_fc1(shared_features))
151
- class1 = self.dropout_light(class1)
152
- class2 = F.relu(self.class_fc2(class1))
153
- class2 = self.dropout_medium(class2)
154
- class3 = F.relu(self.class_fc3(class2))
155
- class3 = self.dropout_light(class3)
156
- class4 = F.relu(self.class_fc4(class3))
157
- class4 = self.dropout_light(class4)
158
- classification = self.class_fc5(class4)
159
- outputs.append(classification)
160
-
161
- # Return outputs based on configuration
162
- if len(outputs) == 1:
163
- return outputs[0]
164
- elif len(outputs) == 2:
165
- if self.predict_score:
166
- return outputs[0], outputs[1]
167
- else:
168
- return outputs[0], outputs[1]
169
- else:
170
- return outputs[0], outputs[1], outputs[2]
171
-
172
- def voxelize_patch(patch_7d: np.ndarray, voxel_size: int = 32, patch_size: float = 1.0) -> np.ndarray:
173
- """
174
- Convert point cloud patch to voxel grid.
175
-
176
- Args:
177
- patch_7d: (N, 7) array of points with [x, y, z, r, g, b, filtered_flag]
178
- voxel_size: Size of the voxel grid (voxel_size^3)
179
- patch_size: Physical size of the patch in world coordinates
180
-
181
- Returns:
182
- voxels: (7, voxel_size, voxel_size, voxel_size) array of voxelized features
183
- """
184
- if len(patch_7d) == 0:
185
- return np.zeros((7, voxel_size, voxel_size, voxel_size))
186
-
187
- # Extract coordinates and features
188
- coords = patch_7d[:, :3] # x, y, z
189
- features = patch_7d[:, 3:] # r, g, b, filtered_flag
190
-
191
- # Normalize coordinates to [0, voxel_size-1]
192
- coords_min = coords.min(axis=0)
193
- coords_max = coords.max(axis=0)
194
- coords_range = coords_max - coords_min
195
- coords_range[coords_range == 0] = 1 # Avoid division by zero
196
-
197
- normalized_coords = (coords - coords_min) / coords_range * (voxel_size - 1)
198
- voxel_indices = normalized_coords.astype(int)
199
-
200
- # Clip to valid range
201
- voxel_indices = np.clip(voxel_indices, 0, voxel_size - 1)
202
-
203
- # Initialize voxel grid
204
- voxels = np.zeros((7, voxel_size, voxel_size, voxel_size))
205
-
206
- # Fill voxels with features (average if multiple points fall in same voxel)
207
- counts = np.zeros((voxel_size, voxel_size, voxel_size))
208
-
209
- for i in range(len(patch_7d)):
210
- x, y, z = voxel_indices[i]
211
- # Store normalized coordinates in first 3 channels
212
- voxels[0, x, y, z] += normalized_coords[i, 0] / (voxel_size - 1) # normalized x
213
- voxels[1, x, y, z] += normalized_coords[i, 1] / (voxel_size - 1) # normalized y
214
- voxels[2, x, y, z] += normalized_coords[i, 2] / (voxel_size - 1) # normalized z
215
- # Store RGB and filtered_flag in remaining channels
216
- voxels[3:, x, y, z] += features[i]
217
- counts[x, y, z] += 1
218
-
219
- # Average features where multiple points exist
220
- mask = counts > 0
221
- for c in range(7):
222
- voxels[c][mask] /= counts[mask]
223
-
224
- return voxels
225
-
226
- class VoxelPatchDataset(Dataset):
227
- """
228
- Dataset class for loading saved patches and converting them to voxel grids for 3D CNN training.
229
- """
230
-
231
- def __init__(self, dataset_dir: str, voxel_size: int = 32, augment: bool = False):
232
- self.dataset_dir = dataset_dir
233
- self.voxel_size = voxel_size
234
- self.augment = augment
235
-
236
- # Load patch files
237
- self.patch_files = []
238
- for file in os.listdir(dataset_dir):
239
- if file.endswith('.pkl'):
240
- self.patch_files.append(os.path.join(dataset_dir, file))
241
-
242
- print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
243
-
244
- def __len__(self):
245
- return len(self.patch_files)
246
-
247
- def __getitem__(self, idx):
248
- """
249
- Load and process a patch for training.
250
- Returns:
251
- voxel_data: (7, voxel_size, voxel_size, voxel_size) tensor of voxelized data
252
- target: (3,) tensor of target 3D coordinates
253
- valid_mask: scalar tensor indicating if this is a valid sample
254
- distance_to_gt: scalar tensor of distance from initial prediction to GT
255
- classification: scalar tensor for binary classification (1 if GT vertex present, 0 if not)
256
- """
257
- patch_file = self.patch_files[idx]
258
-
259
- with open(patch_file, 'rb') as f:
260
- patch_info = pickle.load(f)
261
-
262
- patch_7d = patch_info['patch_7d'] # (N, 7)
263
- target = patch_info.get('assigned_wf_vertex', None) # (3,) or None
264
- initial_pred = patch_info.get('cluster_center', None) # (3,) or None
265
-
266
- # Determine classification label based on GT vertex presence
267
- has_gt_vertex = 1.0 if target is not None else 0.0
268
-
269
- # Handle patches without ground truth
270
- if target is None:
271
- target = np.zeros(3)
272
- else:
273
- target = np.array(target)
274
-
275
- # Voxelize the patch
276
- voxel_data = voxelize_patch(patch_7d, self.voxel_size)
277
-
278
- # Data augmentation (only if GT vertex is present)
279
- if self.augment and has_gt_vertex > 0:
280
- voxel_data, target = self._augment_voxels(voxel_data, target)
281
-
282
- # Convert to tensors (copy arrays to handle negative strides from augmentation)
283
- voxel_tensor = torch.from_numpy(voxel_data.copy()).float() # (7, voxel_size, voxel_size, voxel_size)
284
- target_tensor = torch.from_numpy(target.copy()).float() # (3,)
285
-
286
- # Valid mask (check if voxel grid has any non-zero values)
287
- valid_mask = torch.tensor(1.0 if voxel_data.sum() > 0 else 0.0)
288
-
289
- # Handle initial_pred
290
- if initial_pred is not None:
291
- initial_pred_tensor = torch.from_numpy(initial_pred).float()
292
- else:
293
- initial_pred_tensor = torch.zeros(3).float()
294
-
295
- # Classification tensor
296
- classification_tensor = torch.tensor(has_gt_vertex).float()
297
-
298
- return voxel_tensor, target_tensor, valid_mask, initial_pred_tensor, classification_tensor
299
-
300
- def _augment_voxels(self, voxel_data: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
301
- """
302
- Apply data augmentation to voxel data.
303
- """
304
- # Random rotation around Z-axis
305
- if np.random.random() > 0.5:
306
- k = np.random.randint(1, 4) # 90, 180, or 270 degrees
307
- voxel_data = np.rot90(voxel_data, k, axes=(1, 2)) # Rotate around z-axis
308
-
309
- # Random flip
310
- if np.random.random() > 0.5:
311
- voxel_data = np.flip(voxel_data, axis=1) # Flip along x-axis
312
- if np.random.random() > 0.5:
313
- voxel_data = np.flip(voxel_data, axis=2) # Flip along y-axis
314
-
315
- return voxel_data, target
316
-
317
- def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
318
- """
319
- Save patches from prediction pipeline to create a training dataset.
320
-
321
- Args:
322
- patches: List of patch dictionaries from generate_patches()
323
- dataset_dir: Directory to save the dataset
324
- entry_id: Unique identifier for this entry/image
325
- """
326
- os.makedirs(dataset_dir, exist_ok=True)
327
-
328
- for i, patch in enumerate(patches):
329
- # Create unique filename
330
- filename = f"{entry_id}_patch_{i}.pkl"
331
- filepath = os.path.join(dataset_dir, filename)
332
-
333
- # Skip if file already exists
334
- if os.path.exists(filepath):
335
- continue
336
-
337
- # Save patch data
338
- with open(filepath, 'wb') as f:
339
- pickle.dump(patch, f)
340
-
341
- print(f"Saved {len(patches)} patches for entry {entry_id}")
342
-
343
- # Create dataloader with custom collate function to filter invalid samples
344
- def collate_fn(batch):
345
- valid_batch = []
346
- for voxel_data, target, valid_mask, initial_pred, classification in batch:
347
- # Filter out invalid samples
348
- if valid_mask > 0:
349
- valid_batch.append((voxel_data, target, valid_mask, initial_pred, classification))
350
-
351
- if len(valid_batch) == 0:
352
- return None
353
-
354
- # Stack valid samples
355
- voxel_data = torch.stack([item[0] for item in valid_batch])
356
- targets = torch.stack([item[1] for item in valid_batch])
357
- valid_masks = torch.stack([item[2] for item in valid_batch])
358
- initial_preds = torch.stack([item[3] for item in valid_batch])
359
- classifications = torch.stack([item[4] for item in valid_batch])
360
-
361
- return voxel_data, targets, valid_masks, initial_preds, classifications
362
-
363
- # Initialize weights using Xavier/Glorot initialization
364
- def init_weights(m):
365
- if isinstance(m, (nn.Conv3d, nn.Conv1d)):
366
- nn.init.xavier_uniform_(m.weight)
367
- if m.bias is not None:
368
- nn.init.zeros_(m.bias)
369
- elif isinstance(m, nn.Linear):
370
- nn.init.xavier_uniform_(m.weight)
371
- if m.bias is not None:
372
- nn.init.zeros_(m.bias)
373
- elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm1d)):
374
- nn.init.ones_(m.weight)
375
- nn.init.zeros_(m.bias)
376
-
377
- def train_3dcnn(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 16, lr: float = 0.001,
378
- voxel_size: int = 32, score_weight: float = 0.1, class_weight: float = 0.5):
379
- """
380
- Train the Fast3DCNN model on saved patches.
381
-
382
- Args:
383
- dataset_dir: Directory containing saved patch files
384
- model_save_path: Path to save the trained model
385
- epochs: Number of training epochs
386
- batch_size: Training batch size (reduced due to memory requirements of 3D conv)
387
- lr: Learning rate
388
- voxel_size: Size of voxel grid
389
- score_weight: Weight for the distance prediction loss
390
- class_weight: Weight for the classification loss
391
- """
392
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
393
- print(f"Training on device: {device}")
394
-
395
- # Create dataset and dataloader
396
- dataset = VoxelPatchDataset(dataset_dir, voxel_size=voxel_size, augment=True)
397
- print(f"Dataset loaded with {len(dataset)} samples")
398
-
399
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4,
400
- collate_fn=collate_fn, drop_last=True)
401
-
402
- # Initialize model with score and classification prediction
403
- model = Fast3DCNN(input_channels=7, output_dim=3, voxel_size=voxel_size,
404
- predict_score=True, predict_class=True, num_classes=1)
405
-
406
- model.apply(init_weights)
407
- model.to(device)
408
-
409
- # Loss functions
410
- position_criterion = nn.MSELoss()
411
- score_criterion = nn.MSELoss()
412
- classification_criterion = nn.BCEWithLogitsLoss()
413
-
414
- optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
415
- scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
416
-
417
- # Training loop
418
- model.train()
419
- for epoch in range(epochs):
420
- total_loss = 0.0
421
- total_pos_loss = 0.0
422
- total_score_loss = 0.0
423
- total_class_loss = 0.0
424
- num_batches = 0
425
-
426
- for batch_idx, batch_data in enumerate(dataloader):
427
- if batch_data is None: # Skip invalid batches
428
- continue
429
-
430
- voxel_data, targets, valid_masks, initial_preds, classifications = batch_data
431
- voxel_data = voxel_data.to(device) # (batch_size, 7, voxel_size, voxel_size, voxel_size)
432
- targets = targets.to(device) # (batch_size, 3)
433
- classifications = classifications.to(device) # (batch_size,)
434
-
435
- # Forward pass
436
- optimizer.zero_grad()
437
- predictions, predicted_scores, predicted_classes = model(voxel_data)
438
-
439
- # Compute actual distance from predictions to targets
440
- actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True)
441
-
442
- # Only compute position and score losses for samples with GT vertices
443
- has_gt_mask = classifications > 0.5
444
-
445
- if has_gt_mask.sum() > 0:
446
- # Position loss only for samples with GT vertices
447
- pos_loss = position_criterion(predictions[has_gt_mask], targets[has_gt_mask])
448
- score_loss = score_criterion(predicted_scores[has_gt_mask], actual_distances[has_gt_mask])
449
- else:
450
- pos_loss = torch.tensor(0.0, device=device)
451
- score_loss = torch.tensor(0.0, device=device)
452
-
453
- # Classification loss for all samples
454
- class_loss = classification_criterion(predicted_classes.squeeze(), classifications)
455
-
456
- # Combined loss
457
- total_batch_loss = pos_loss + score_weight * score_loss + class_weight * class_loss
458
-
459
- # Backward pass
460
- total_batch_loss.backward()
461
- optimizer.step()
462
-
463
- total_loss += total_batch_loss.item()
464
- total_pos_loss += pos_loss.item()
465
- total_score_loss += score_loss.item()
466
- total_class_loss += class_loss.item()
467
- num_batches += 1
468
-
469
- if batch_idx % 50 == 0:
470
- print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
471
- f"Total Loss: {total_batch_loss.item():.6f}, "
472
- f"Pos Loss: {pos_loss.item():.6f}, "
473
- f"Score Loss: {score_loss.item():.6f}, "
474
- f"Class Loss: {class_loss.item():.6f}")
475
-
476
- avg_loss = total_loss / num_batches if num_batches > 0 else 0
477
- avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
478
- avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
479
- avg_class_loss = total_class_loss / num_batches if num_batches > 0 else 0
480
-
481
- print(f"Epoch {epoch+1}/{epochs} completed, "
482
- f"Avg Total Loss: {avg_loss:.6f}, "
483
- f"Avg Pos Loss: {avg_pos_loss:.6f}, "
484
- f"Avg Score Loss: {avg_score_loss:.6f}, "
485
- f"Avg Class Loss: {avg_class_loss:.6f}")
486
-
487
- scheduler.step()
488
-
489
- # Save model checkpoint every epoch
490
- checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
491
- torch.save({
492
- 'model_state_dict': model.state_dict(),
493
- 'optimizer_state_dict': optimizer.state_dict(),
494
- 'epoch': epoch + 1,
495
- 'loss': avg_loss,
496
- }, checkpoint_path)
497
-
498
- # Save the trained model
499
- torch.save({
500
- 'model_state_dict': model.state_dict(),
501
- 'optimizer_state_dict': optimizer.state_dict(),
502
- 'epoch': epochs,
503
- }, model_save_path)
504
-
505
- print(f"Model saved to {model_save_path}")
506
- return model
507
-
508
- def load_3dcnn_model(model_path: str, device: torch.device = None, voxel_size: int = 32, predict_score: bool = True) -> Fast3DCNN:
509
- """
510
- Load a trained Fast3DCNN model.
511
-
512
- Args:
513
- model_path: Path to the saved model
514
- device: Device to load the model on
515
- voxel_size: Size of voxel grid
516
- predict_score: Whether the model predicts scores
517
-
518
- Returns:
519
- Loaded Fast3DCNN model
520
- """
521
- if device is None:
522
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
523
-
524
- model = Fast3DCNN(input_channels=7, output_dim=3, voxel_size=voxel_size, predict_score=predict_score)
525
-
526
- checkpoint = torch.load(model_path, map_location=device)
527
- model.load_state_dict(checkpoint['model_state_dict'])
528
-
529
- model.to(device)
530
- model.eval()
531
-
532
- return model
533
-
534
- def predict_vertex_from_patch_voxel(model: Fast3DCNN, patch: np.ndarray, device: torch.device = None, voxel_size: int = 32) -> Tuple[np.ndarray, float, float]:
535
- """
536
- Predict 3D vertex coordinates, confidence score, and classification from a patch using trained 3D CNN.
537
-
538
- Args:
539
- model: Trained Fast3DCNN model
540
- patch: Dictionary containing patch data with 'patch_7d' and 'cluster_center' keys
541
- device: Device to run prediction on
542
- voxel_size: Size of voxel grid
543
-
544
- Returns:
545
- tuple of (predicted_coordinates, confidence_score, classification_score)
546
- predicted_coordinates: (3,) numpy array of predicted 3D coordinates
547
- confidence_score: float representing predicted distance to GT (lower is better)
548
- classification_score: float representing probability of GT vertex presence (0-1)
549
- """
550
- if device is None:
551
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
552
-
553
- patch_7d = patch['patch_7d'] # (N, 7)
554
-
555
- # Voxelize the patch
556
- voxel_data = voxelize_patch(patch_7d, voxel_size)
557
-
558
- # Convert to tensor
559
- voxel_tensor = torch.from_numpy(voxel_data).float().unsqueeze(0) # (1, 7, voxel_size, voxel_size, voxel_size)
560
- voxel_tensor = voxel_tensor.to(device)
561
-
562
- # Predict
563
- with torch.no_grad():
564
- outputs = model(voxel_tensor)
565
-
566
- if model.predict_score and model.predict_class:
567
- position, score, classification = outputs
568
- position = position.cpu().numpy().squeeze()
569
- score = score.cpu().numpy().squeeze()
570
- classification = torch.sigmoid(classification).cpu().numpy().squeeze()
571
- elif model.predict_score:
572
- position, score = outputs
573
- position = position.cpu().numpy().squeeze()
574
- score = score.cpu().numpy().squeeze()
575
- classification = None
576
- elif model.predict_class:
577
- position, classification = outputs
578
- position = position.cpu().numpy().squeeze()
579
- score = None
580
- classification = torch.sigmoid(classification).cpu().numpy().squeeze()
581
- else:
582
- position = outputs
583
- position = position.cpu().numpy().squeeze()
584
- score = None
585
- classification = None
586
-
587
- # Apply offset correction
588
- offset = patch['cluster_center']
589
- position += offset
590
-
591
- return position, score, classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
find_best_results.py CHANGED
@@ -1,5 +1,12 @@
1
  #!/usr/bin/env python3
2
  # filepath: /home/skvrnjan/hoho/find_best_results.py
 
 
 
 
 
 
 
3
  import os
4
  import re
5
 
 
1
  #!/usr/bin/env python3
2
  # filepath: /home/skvrnjan/hoho/find_best_results.py
3
+ # This script scans a directory for result files (text files typically starting
4
+ # with "results_vt" within subdirectories matching a given prefix).
5
+ # It parses these files to extract metrics like Mean HSS, Mean F1, Mean IoU,
6
+ # Vertex Threshold, Edge Threshold, and Only Predicted Connections.
7
+ # The script then identifies and prints the top N results (default N=10)
8
+ # for Mean HSS, Mean F1, and Mean IoU, along with their associated configuration
9
+ # parameters.
10
  import os
11
  import re
12
 
fully_deep.py DELETED
@@ -1,1082 +0,0 @@
1
- import torch
2
- import os
3
- import pickle
4
- from torch.utils.data import Dataset, DataLoader
5
- import numpy as np
6
- from scipy.optimize import linear_sum_assignment
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- # =============================================================================
11
- # CONFIGURATION PARAMETERS
12
- # =============================================================================
13
-
14
- # Dataset Configuration
15
- DATA_DIR = '/mnt/personal/skvrnjan/hoho_fully'
16
- SPLIT = 'train'
17
- MAX_POINTS = 8096
18
- BATCH_SIZE = 32
19
- NUM_WORKERS = 8
20
-
21
- # Model Architecture Parameters
22
- PC_INPUT_FEATURES = 3
23
- PC_ENCODER_OUTPUT_FEATURES = 128
24
- MAX_VERTICES = 50
25
- VERTEX_COORD_DIM = 3
26
- GNN_HIDDEN_DIM = 64
27
- NUM_GNN_LAYERS = 2
28
- HIDDEN_DIM = 256
29
- NUM_DECODER_LAYERS = 3
30
- NUM_HEADS = 8
31
-
32
- # PointNet2 Encoder Parameters
33
- SA1_NPOINT = 1024
34
- SA1_RADIUS = 0.2
35
- SA1_NSAMPLE = 32
36
- SA1_MLP = [64, 64, 128]
37
-
38
- SA2_NPOINT = 256
39
- SA2_RADIUS = 0.4
40
- SA2_NSAMPLE = 64
41
- SA2_MLP = [128, 128, 256]
42
-
43
- SA3_MLP = [256, 512, 1024] # Global pooling layer
44
-
45
- FP3_MLP = [256, 256]
46
- FP2_MLP = [256, 128]
47
- FP1_MLP = [128, 128] # Will add PC_ENCODER_OUTPUT_FEATURES at the end
48
-
49
- # Vertex Prediction Head Parameters
50
- VERTEX_TRANSFORMER_DROPOUT = 0.1
51
- VERTEX_TRANSFORMER_FFN_RATIO = 4
52
-
53
- # Edge Prediction Head Parameters
54
- EDGE_GNN_NUM_HEADS = 4
55
- EDGE_GNN_DROPOUT = 0.1
56
- EDGE_K_NEIGHBORS = 8
57
-
58
- # Training Configuration
59
- NUM_EPOCHS = 100
60
- LEARNING_RATE = 1e-4
61
- WEIGHT_DECAY = 1e-5
62
- GRADIENT_CLIP_MAX_NORM = 1.0
63
-
64
- # Loss Weights
65
- VERTEX_LOSS_WEIGHT = 1.0
66
- EDGE_LOSS_WEIGHT = 0.5
67
- CONFIDENCE_LOSS_WEIGHT = 0.3
68
-
69
- # Learning Rate Scheduler Parameters
70
- LR_SCHEDULER_FACTOR = 0.5
71
- LR_SCHEDULER_PATIENCE = 10
72
-
73
- # Checkpoint and Logging
74
- CHECKPOINT_SAVE_FREQUENCY = 1 # Save every N epochs
75
- LOG_FREQUENCY = 10 # Print progress every N batches
76
-
77
- # Device Configuration
78
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
-
80
- # =============================================================================
81
- # MODEL IMPLEMENTATION
82
- # =============================================================================
83
-
84
- # You would likely need a library like torch_geometric for GNNs
85
- # from torch_geometric.nn import GATConv, EdgeConv # Example GNN layers
86
-
87
- # --- 1. Point Cloud Encoder Backbone (Placeholder) ---
88
- class PointNet2Encoder(nn.Module):
89
- def __init__(self, input_features, output_features):
90
- super().__init__()
91
- self.input_features = input_features
92
- self.output_features = output_features
93
-
94
- # Set Abstraction layers - adjusted for 8096 input points
95
- self.sa1 = SetAbstractionLayer(
96
- npoint=SA1_NPOINT, radius=SA1_RADIUS, nsample=SA1_NSAMPLE,
97
- in_channel=input_features + 3, mlp=SA1_MLP
98
- )
99
- self.sa2 = SetAbstractionLayer(
100
- npoint=SA2_NPOINT, radius=SA2_RADIUS, nsample=SA2_NSAMPLE,
101
- in_channel=SA1_MLP[-1] + 3, mlp=SA2_MLP
102
- )
103
- self.sa3 = SetAbstractionLayer(
104
- npoint=None, radius=None, nsample=None, # Global pooling
105
- in_channel=SA2_MLP[-1] + 3, mlp=SA3_MLP
106
- )
107
-
108
- # Feature Propagation layers for point-wise features
109
- self.fp3 = FeaturePropagationLayer(in_channel=SA3_MLP[-1] + SA2_MLP[-1], mlp=FP3_MLP)
110
- self.fp2 = FeaturePropagationLayer(in_channel=FP3_MLP[-1] + SA1_MLP[-1], mlp=FP2_MLP)
111
- self.fp1 = FeaturePropagationLayer(in_channel=FP2_MLP[-1] + input_features, mlp=FP1_MLP + [output_features])
112
-
113
- def forward(self, xyz):
114
- # xyz: (B, N, 3) where N = 8096
115
- B, N, _ = xyz.shape
116
-
117
- # Initial features (can be empty or coordinates)
118
- points = xyz if self.input_features == 3 else None
119
-
120
- # Set Abstraction
121
- l1_xyz, l1_points = self.sa1(xyz, points) # 8096 -> 1024 points
122
- l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) # 1024 -> 256 points
123
- l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) # 256 -> 1 point (global)
124
-
125
- # Feature Propagation
126
- l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
127
- l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
128
- l0_points = self.fp1(xyz, l1_xyz, points, l1_points)
129
-
130
- # Global feature from the most abstract level
131
- global_feature = l3_points.squeeze(-1) # (B, 1024)
132
-
133
- return l0_points, global_feature # (B, 8096, output_features), (B, 1024)
134
-
135
-
136
- class SetAbstractionLayer(nn.Module):
137
- def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all=False):
138
- super().__init__()
139
- self.npoint = npoint
140
- self.radius = radius
141
- self.nsample = nsample
142
- self.group_all = group_all
143
-
144
- self.mlp_convs = nn.ModuleList()
145
- self.mlp_bns = nn.ModuleList()
146
- last_channel = in_channel
147
- for out_channel in mlp:
148
- self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
149
- self.mlp_bns.append(nn.BatchNorm2d(out_channel))
150
- last_channel = out_channel
151
-
152
- def forward(self, xyz, points):
153
- # xyz: (B, N, 3)
154
- # points: (B, N, C) or None
155
- B, N, C = xyz.shape
156
-
157
- if self.group_all or self.npoint is None:
158
- # Global pooling
159
- new_xyz = xyz.mean(dim=1, keepdim=True) # (B, 1, 3)
160
- if points is not None:
161
- new_points = torch.cat([xyz, points], dim=-1) # (B, N, 3+C)
162
- new_points = new_points.transpose(1, 2).unsqueeze(-1) # (B, 3+C, N, 1)
163
- else:
164
- new_points = xyz.transpose(1, 2).unsqueeze(-1) # (B, 3, N, 1)
165
- else:
166
- # Farthest Point Sampling
167
- fps_idx = farthest_point_sample(xyz, self.npoint) # (B, npoint)
168
- new_xyz = index_points(xyz, fps_idx) # (B, npoint, 3)
169
-
170
- # Ball Query
171
- idx = ball_query(self.radius, self.nsample, xyz, new_xyz) # (B, npoint, nsample)
172
- grouped_xyz = index_points(xyz, idx) # (B, npoint, nsample, 3)
173
- grouped_xyz_norm = grouped_xyz - new_xyz.unsqueeze(2) # Relative positions
174
-
175
- if points is not None:
176
- grouped_points = index_points(points, idx) # (B, npoint, nsample, C)
177
- new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # (B, npoint, nsample, 3+C)
178
- else:
179
- new_points = grouped_xyz_norm # (B, npoint, nsample, 3)
180
-
181
- new_points = new_points.permute(0, 3, 1, 2) # (B, 3+C, npoint, nsample)
182
-
183
- # MLP
184
- for i, conv in enumerate(self.mlp_convs):
185
- bn = self.mlp_bns[i]
186
- new_points = F.relu(bn(conv(new_points)))
187
-
188
- # Max pooling
189
- new_points = torch.max(new_points, dim=-1)[0] # (B, mlp[-1], npoint)
190
- new_points = new_points.transpose(1, 2) # (B, npoint, mlp[-1])
191
-
192
- return new_xyz, new_points
193
-
194
-
195
- class FeaturePropagationLayer(nn.Module):
196
- def __init__(self, in_channel, mlp):
197
- super().__init__()
198
- self.mlp_convs = nn.ModuleList()
199
- self.mlp_bns = nn.ModuleList()
200
- last_channel = in_channel
201
- for out_channel in mlp:
202
- self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
203
- self.mlp_bns.append(nn.BatchNorm1d(out_channel))
204
- last_channel = out_channel
205
-
206
- def forward(self, xyz1, xyz2, points1, points2):
207
- # xyz1: (B, N1, 3) - target points
208
- # xyz2: (B, N2, 3) - source points
209
- # points1: (B, N1, C1) - target features
210
- # points2: (B, N2, C2) - source features
211
-
212
- # Interpolate features from xyz2 to xyz1
213
- if points2 is not None:
214
- interpolated_points = interpolate_features(xyz1, xyz2, points2) # (B, N1, C2)
215
- if points1 is not None:
216
- # Ensure both tensors have the same number of points (N1)
217
- assert points1.shape[1] == interpolated_points.shape[1], f"Point count mismatch: {points1.shape[1]} vs {interpolated_points.shape[1]}"
218
- new_points = torch.cat([points1, interpolated_points], dim=-1) # (B, N1, C1+C2)
219
- else:
220
- new_points = interpolated_points
221
- else:
222
- new_points = points1
223
-
224
- # Handle None case
225
- if new_points is None:
226
- return None
227
-
228
- # MLP
229
- new_points = new_points.transpose(1, 2) # (B, C, N1)
230
- for i, conv in enumerate(self.mlp_convs):
231
- bn = self.mlp_bns[i]
232
- new_points = F.relu(bn(conv(new_points)))
233
-
234
- return new_points.transpose(1, 2) # (B, N1, mlp[-1])
235
-
236
-
237
- def farthest_point_sample(xyz, npoint):
238
- """Farthest Point Sampling"""
239
- device = xyz.device
240
- B, N, C = xyz.shape
241
- centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
242
- distance = torch.ones(B, N).to(device) * 1e10
243
- farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
244
-
245
- for i in range(npoint):
246
- centroids[:, i] = farthest
247
- centroid = xyz[torch.arange(B), farthest, :].view(B, 1, 3)
248
- dist = torch.sum((xyz - centroid) ** 2, -1)
249
- mask = dist < distance
250
- distance[mask] = dist[mask]
251
- farthest = torch.max(distance, -1)[1]
252
-
253
- return centroids
254
-
255
-
256
- def ball_query(radius, nsample, xyz, new_xyz):
257
- """Ball Query"""
258
- device = xyz.device
259
- B, N, C = xyz.shape
260
- _, S, _ = new_xyz.shape
261
-
262
- group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
263
- sqrdists = square_distance(new_xyz, xyz)
264
- group_idx[sqrdists > radius ** 2] = N
265
- group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
266
- group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
267
- mask = group_idx == N
268
- group_idx[mask] = group_first[mask]
269
-
270
- # If group_first[mask] was N (i.e., no points in the ball for a centroid),
271
- # group_idx can still contain N. Clamp N to 0 to ensure valid indices.
272
- # N corresponds to xyz.shape[1], which is guaranteed to be > 0 by the dataloader logic.
273
- group_idx[group_idx == N] = 0
274
-
275
- return group_idx
276
-
277
-
278
- def square_distance(src, dst):
279
- """Calculate squared distance between each two points"""
280
- B, N, _ = src.shape
281
- _, M, _ = dst.shape
282
- dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
283
- dist += torch.sum(src ** 2, -1).view(B, N, 1)
284
- dist += torch.sum(dst ** 2, -1).view(B, 1, M)
285
- return dist
286
-
287
-
288
- def index_points(points, idx):
289
- """Index points using given indices"""
290
- device = points.device
291
- B = points.shape[0]
292
- view_shape = list(idx.shape)
293
- view_shape[1:] = [1] * (len(view_shape) - 1)
294
- repeat_shape = list(idx.shape)
295
- repeat_shape[0] = 1
296
- batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
297
- new_points = points[batch_indices, idx, :]
298
- return new_points
299
-
300
-
301
- def interpolate_features(xyz1, xyz2, points2):
302
- """Interpolate features using inverse distance weighting"""
303
- B, N1, C = xyz1.shape
304
- _, N2, _ = xyz2.shape
305
-
306
- if N2 == 1:
307
- # If only one point, broadcast to all target points
308
- interpolated_points = points2.expand(B, N1, -1)
309
- else:
310
- # Find 3 nearest neighbors and interpolate
311
- dists = square_distance(xyz1, xyz2) # (B, N1, N2)
312
- dists, idx = dists.sort(dim=-1)
313
-
314
- # Use min(3, N2) neighbors to handle cases with fewer source points
315
- k = min(3, N2)
316
- dists, idx = dists[:, :, :k], idx[:, :, :k]
317
-
318
- # Inverse distance weighting
319
- dists[dists < 1e-10] = 1e-10
320
- weight = 1.0 / dists # (B, N1, k)
321
- weight = weight / torch.sum(weight, dim=-1, keepdim=True) # Normalize
322
-
323
- # Interpolate
324
- interpolated_points = torch.sum(
325
- index_points(points2, idx) * weight.view(B, N1, k, 1), dim=2
326
- )
327
-
328
- return interpolated_points
329
-
330
- # --- 2. Vertex Prediction Head (Transformer-based) ---
331
- class VertexPredictionHead(nn.Module):
332
- def __init__(self, point_feature_dim, global_feature_dim, max_vertices, vertex_coord_dim=3,
333
- hidden_dim=256, num_decoder_layers=3, num_heads=8):
334
- super().__init__()
335
- self.max_vertices = max_vertices
336
- self.vertex_coord_dim = vertex_coord_dim
337
- self.hidden_dim = hidden_dim
338
-
339
- # Learnable vertex queries (similar to DETR object queries)
340
- self.vertex_queries = nn.Parameter(torch.randn(max_vertices, hidden_dim))
341
-
342
- # Project global feature to hidden dimension
343
- self.global_proj = nn.Linear(global_feature_dim, 1)
344
-
345
- # Project point features to hidden dimension for cross-attention
346
- self.point_proj = nn.Linear(point_feature_dim, hidden_dim)
347
-
348
- # Transformer decoder layers
349
- decoder_layer = nn.TransformerDecoderLayer(
350
- d_model=hidden_dim,
351
- nhead=num_heads,
352
- dim_feedforward=hidden_dim * VERTEX_TRANSFORMER_FFN_RATIO,
353
- dropout=VERTEX_TRANSFORMER_DROPOUT,
354
- batch_first=True
355
- )
356
- self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
357
-
358
- # Output heads
359
- self.vertex_coord_head = nn.Sequential(
360
- nn.Linear(hidden_dim, hidden_dim),
361
- nn.ReLU(),
362
- nn.Linear(hidden_dim, vertex_coord_dim)
363
- )
364
-
365
- # Confidence/existence head (predicts if vertex exists)
366
- self.vertex_conf_head = nn.Sequential(
367
- nn.Linear(hidden_dim, hidden_dim),
368
- nn.ReLU(),
369
- nn.Linear(hidden_dim, 1)
370
- )
371
-
372
- # Position encoding for point features
373
- self.pos_encoding = nn.Sequential(
374
- nn.Linear(3, hidden_dim // 2),
375
- nn.ReLU(),
376
- nn.Linear(hidden_dim // 2, hidden_dim)
377
- )
378
-
379
- def forward(self, point_features, global_feature, point_coords=None):
380
- # point_features: (B, N, point_feature_dim)
381
- # global_feature: (B, global_feature_dim)
382
- # point_coords: (B, N, 3) - optional point coordinates for positional encoding
383
-
384
- batch_size = point_features.shape[0]
385
-
386
- # Project features to hidden dimension
387
- point_features_proj = self.point_proj(point_features) # (B, N, hidden_dim)
388
-
389
- # Add positional encoding if coordinates are provided
390
- if point_coords is not None:
391
- pos_enc = self.pos_encoding(point_coords) # (B, N, hidden_dim)
392
- point_features_proj = point_features_proj + pos_enc
393
-
394
- # Prepare vertex queries
395
- vertex_queries = self.vertex_queries.unsqueeze(0).repeat(batch_size, 1, 1) # (B, max_vertices, hidden_dim)
396
-
397
- # Add global context to vertex queries
398
- global_proj = self.global_proj(global_feature).squeeze(-1).unsqueeze(1) # (B, 1, hidden_dim)
399
- vertex_queries = vertex_queries + global_proj # Broadcasting will handle (B, 1, hidden_dim) + (B, max_vertices, hidden_dim)
400
-
401
- # Transformer decoder: vertex queries attend to point features
402
- vertex_features = self.transformer_decoder(
403
- tgt=vertex_queries, # (B, max_vertices, hidden_dim)
404
- memory=point_features_proj # (B, N, hidden_dim)
405
- ) # (B, max_vertices, hidden_dim)
406
-
407
- # Predict vertex coordinates
408
- predicted_vertices = self.vertex_coord_head(vertex_features) # (B, max_vertices, 3)
409
-
410
- # Predict vertex confidence/existence
411
- vertex_confidence = self.vertex_conf_head(vertex_features).squeeze(-1) # (B, max_vertices)
412
-
413
- return predicted_vertices, vertex_confidence
414
-
415
- # --- 3. Edge Prediction Head (GNN-based) ---
416
- class EdgePredictionHeadGNN(nn.Module):
417
- def __init__(self, vertex_feature_dim, gnn_hidden_dim, num_gnn_layers):
418
- super().__init__()
419
- self.vertex_feature_dim = vertex_feature_dim
420
- self.gnn_hidden_dim = gnn_hidden_dim
421
- self.num_gnn_layers = num_gnn_layers
422
-
423
- # Initial vertex feature projection
424
- self.vertex_proj = nn.Linear(vertex_feature_dim, gnn_hidden_dim)
425
-
426
- # GNN layers using message passing
427
- self.gnn_layers = nn.ModuleList()
428
- for i in range(num_gnn_layers):
429
- self.gnn_layers.append(
430
- GraphAttentionLayer(
431
- in_features=gnn_hidden_dim,
432
- out_features=gnn_hidden_dim,
433
- num_heads=EDGE_GNN_NUM_HEADS,
434
- dropout=EDGE_GNN_DROPOUT
435
- )
436
- )
437
-
438
- # Edge classifier MLP
439
- self.edge_mlp = nn.Sequential(
440
- nn.Linear(gnn_hidden_dim * 2, gnn_hidden_dim),
441
- nn.ReLU(),
442
- nn.Dropout(EDGE_GNN_DROPOUT),
443
- nn.Linear(gnn_hidden_dim, gnn_hidden_dim // 2),
444
- nn.ReLU(),
445
- nn.Linear(gnn_hidden_dim // 2, 1)
446
- )
447
-
448
- # Learnable threshold for k-NN graph construction
449
- self.k_neighbors = EDGE_K_NEIGHBORS # Number of nearest neighbors for initial graph
450
-
451
- def forward(self, vertices):
452
- # vertices: (B, num_vertices, vertex_coord_dim)
453
- batch_size, num_vertices, _ = vertices.shape
454
-
455
- # Project vertex coordinates to hidden features
456
- vertex_features = self.vertex_proj(vertices) # (B, num_vertices, gnn_hidden_dim)
457
-
458
- # Construct initial graph based on spatial proximity (k-NN)
459
- adjacency_matrix = self.construct_knn_graph(vertices, k=self.k_neighbors) # (B, num_vertices, num_vertices)
460
-
461
- # Apply GNN layers
462
- for gnn_layer in self.gnn_layers:
463
- vertex_features = gnn_layer(vertex_features, adjacency_matrix) # (B, num_vertices, gnn_hidden_dim)
464
-
465
- # Generate all possible vertex pairs
466
- idx_pairs = torch.combinations(torch.arange(num_vertices), r=2).to(vertices.device) # (num_pairs, 2)
467
-
468
- # Gather features for all vertex pairs
469
- v1_features = vertex_features[:, idx_pairs[:, 0], :] # (B, num_pairs, gnn_hidden_dim)
470
- v2_features = vertex_features[:, idx_pairs[:, 1], :] # (B, num_pairs, gnn_hidden_dim)
471
-
472
- # Concatenate paired vertex features
473
- edge_features = torch.cat([v1_features, v2_features], dim=2) # (B, num_pairs, gnn_hidden_dim * 2)
474
-
475
- # Predict edge probabilities
476
- edge_logits = self.edge_mlp(edge_features).squeeze(-1) # (B, num_pairs)
477
-
478
- return edge_logits, idx_pairs
479
-
480
- def construct_knn_graph(self, vertices, k):
481
- # vertices: (B, num_vertices, 3)
482
- batch_size, num_vertices, _ = vertices.shape
483
-
484
- # Compute pairwise distances
485
- distances = torch.cdist(vertices, vertices, p=2) # (B, num_vertices, num_vertices)
486
-
487
- # Find k nearest neighbors for each vertex
488
- _, knn_indices = torch.topk(distances, k + 1, dim=-1, largest=False) # +1 to include self
489
- knn_indices = knn_indices[:, :, 1:] # Remove self-connection
490
-
491
- # Create adjacency matrix
492
- adjacency = torch.zeros(batch_size, num_vertices, num_vertices, device=vertices.device)
493
-
494
- # Fill adjacency matrix
495
- batch_idx = torch.arange(batch_size).view(-1, 1, 1).expand(-1, num_vertices, k)
496
- vertex_idx = torch.arange(num_vertices).view(1, -1, 1).expand(batch_size, -1, k)
497
-
498
- adjacency[batch_idx, vertex_idx, knn_indices] = 1.0
499
-
500
- # Make adjacency symmetric
501
- adjacency = torch.max(adjacency, adjacency.transpose(-1, -2))
502
-
503
- return adjacency
504
-
505
-
506
- class GraphAttentionLayer(nn.Module):
507
- def __init__(self, in_features, out_features, num_heads=1, dropout=0.1):
508
- super().__init__()
509
- self.in_features = in_features
510
- self.out_features = out_features
511
- self.num_heads = num_heads
512
- self.dropout = dropout
513
-
514
- assert out_features % num_heads == 0
515
- self.head_dim = out_features // num_heads
516
-
517
- # Linear transformations for queries, keys, values
518
- self.W_q = nn.Linear(in_features, out_features)
519
- self.W_k = nn.Linear(in_features, out_features)
520
- self.W_v = nn.Linear(in_features, out_features)
521
-
522
- # Output projection
523
- self.W_o = nn.Linear(out_features, out_features)
524
-
525
- # Attention mechanism
526
- self.attention = nn.MultiheadAttention(
527
- embed_dim=out_features,
528
- num_heads=num_heads,
529
- dropout=dropout,
530
- batch_first=True
531
- )
532
-
533
- # Layer normalization and residual connection
534
- self.layer_norm = nn.LayerNorm(out_features)
535
- self.ffn = nn.Sequential(
536
- nn.Linear(out_features, out_features * 2),
537
- nn.ReLU(),
538
- nn.Dropout(dropout),
539
- nn.Linear(out_features * 2, out_features)
540
- )
541
- self.layer_norm2 = nn.LayerNorm(out_features)
542
-
543
- def forward(self, x, adjacency_matrix):
544
- # x: (B, num_vertices, in_features)
545
- # adjacency_matrix: (B, num_vertices, num_vertices)
546
- batch_size, num_vertices, _ = x.shape
547
-
548
- # Project to query, key, value
549
- Q = self.W_q(x) # (B, num_vertices, out_features)
550
- K = self.W_k(x) # (B, num_vertices, out_features)
551
- V = self.W_v(x) # (B, num_vertices, out_features)
552
-
553
- # Create attention mask from adjacency matrix
554
- # Convert adjacency to attention mask (0 for allowed, -inf for masked)
555
- attention_mask = (1 - adjacency_matrix) * (-1e9) # (B, num_vertices, num_vertices)
556
-
557
- # Apply multi-head attention with adjacency-based masking
558
- attended_features = []
559
- for b in range(batch_size):
560
- q_b = Q[b:b+1] # (1, num_vertices, out_features)
561
- k_b = K[b:b+1] # (1, num_vertices, out_features)
562
- v_b = V[b:b+1] # (1, num_vertices, out_features)
563
- mask_b = attention_mask[b] # (num_vertices, num_vertices)
564
-
565
- # Apply attention
566
- attn_output, _ = self.attention(q_b, k_b, v_b, attn_mask=mask_b)
567
- attended_features.append(attn_output)
568
-
569
- attended_features = torch.cat(attended_features, dim=0) # (B, num_vertices, out_features)
570
-
571
- # Residual connection and layer norm
572
- x_residual = self.layer_norm(attended_features + Q)
573
-
574
- # Feed-forward network
575
- ffn_output = self.ffn(x_residual)
576
- output = self.layer_norm2(ffn_output + x_residual)
577
-
578
- return output
579
-
580
- # --- Main Model ---
581
- class PointCloudToWireframe(nn.Module):
582
- def __init__(self,
583
- pc_input_features=PC_INPUT_FEATURES,
584
- pc_encoder_output_features=PC_ENCODER_OUTPUT_FEATURES,
585
- max_vertices=MAX_VERTICES,
586
- vertex_coord_dim=VERTEX_COORD_DIM,
587
- gnn_hidden_dim=GNN_HIDDEN_DIM,
588
- num_gnn_layers=NUM_GNN_LAYERS,
589
- hidden_dim=HIDDEN_DIM,
590
- num_decoder_layers=NUM_DECODER_LAYERS,
591
- num_heads=NUM_HEADS):
592
- super().__init__()
593
-
594
- # Point cloud encoder using PointNet2-style architecture
595
- self.encoder = PointNet2Encoder(pc_input_features, pc_encoder_output_features)
596
-
597
- # Vertex prediction head using transformer decoder
598
- self.vertex_head = VertexPredictionHead(
599
- point_feature_dim=pc_encoder_output_features,
600
- global_feature_dim=SA3_MLP[-1], # From PointNet2Encoder global feature
601
- max_vertices=max_vertices,
602
- vertex_coord_dim=vertex_coord_dim,
603
- hidden_dim=hidden_dim,
604
- num_decoder_layers=num_decoder_layers,
605
- num_heads=num_heads
606
- )
607
-
608
- # Edge prediction head using GNN
609
- self.edge_head = EdgePredictionHeadGNN(
610
- vertex_feature_dim=vertex_coord_dim,
611
- gnn_hidden_dim=gnn_hidden_dim,
612
- num_gnn_layers=num_gnn_layers
613
- )
614
-
615
- def forward(self, point_cloud):
616
- # point_cloud: (B, N, 3)
617
- batch_size, num_points, _ = point_cloud.shape
618
-
619
- # Encode point cloud
620
- point_features, global_feature = self.encoder(point_cloud)
621
- # point_features: (B, N, pc_encoder_output_features)
622
- # global_feature: (B, 1024)
623
-
624
- # Predict vertices
625
- predicted_vertices, vertex_confidence = self.vertex_head(
626
- point_features, global_feature, point_coords=point_cloud
627
- )
628
- # predicted_vertices: (B, max_vertices, 3)
629
- # vertex_confidence: (B, max_vertices)
630
-
631
- # Predict edges using GNN (using vertex coordinates directly)
632
- edge_logits, edge_indices = self.edge_head(predicted_vertices)
633
- # edge_logits: (B, num_potential_edges)
634
- # edge_indices: (num_potential_edges, 2)
635
-
636
- return {
637
- 'vertices': predicted_vertices,
638
- 'vertex_confidence': vertex_confidence,
639
- 'edge_logits': edge_logits,
640
- 'edge_indices': edge_indices
641
- }
642
-
643
- class WireframeDataset(Dataset):
644
- def __init__(self, data_dir=DATA_DIR, split=SPLIT, transform=None, max_points=MAX_POINTS):
645
- """
646
- Dataset for point cloud to wireframe conversion.
647
-
648
- Args:
649
- data_dir: Directory containing the pickle files
650
- split: 'train', 'val', or 'test'
651
- transform: Optional transforms to apply to point clouds
652
- max_points: Maximum number of points in the point cloud (default: 8096)
653
- """
654
- self.data_dir = data_dir
655
- self.split = split
656
- self.transform = transform
657
- self.max_points = max_points
658
-
659
- # Get all pickle files in the directory
660
- self.data_files = []
661
- for file in os.listdir(data_dir):
662
- if file.endswith('.pkl'):
663
- self.data_files.append(os.path.join(data_dir, file))
664
-
665
- self.data_files.sort() # Ensure consistent ordering
666
-
667
- def __len__(self):
668
- return len(self.data_files)
669
-
670
- def __getitem__(self, idx):
671
- # Load the pickle file
672
- with open(self.data_files[idx], 'rb') as f:
673
- sample_data = pickle.load(f)
674
-
675
- # Extract data
676
- point_cloud = torch.tensor(sample_data['point_cloud'], dtype=torch.float32)
677
- point_colors = torch.tensor(sample_data['point_colors'], dtype=torch.float32)
678
- gt_vertices = torch.tensor(sample_data['gt_vertices'], dtype=torch.float32)
679
- gt_connections = sample_data['gt_connections'] # List of tuples
680
- sample_id = sample_data['sample_id']
681
-
682
- # Handle point cloud size to match max_points
683
- current_points = point_cloud.shape[0]
684
-
685
- if current_points > self.max_points:
686
- # Downsample using random sampling
687
- indices = torch.randperm(current_points)[:self.max_points]
688
- point_cloud = point_cloud[indices]
689
- point_colors = point_colors[indices]
690
- elif current_points < self.max_points:
691
- # Pad by repeating last point or duplicating random points
692
- pad_size = self.max_points - current_points
693
- if current_points > 0:
694
- # Randomly sample existing points to pad
695
- pad_indices = torch.randint(0, current_points, (pad_size,))
696
- pad_points = point_cloud[pad_indices]
697
- pad_colors = point_colors[pad_indices]
698
- point_cloud = torch.cat([point_cloud, pad_points], dim=0)
699
- point_colors = torch.cat([point_colors, pad_colors], dim=0)
700
- else:
701
- # Edge case: no points, pad with zeros
702
- point_cloud = torch.zeros(self.max_points, 3)
703
- point_colors = torch.zeros(self.max_points, 3)
704
-
705
- # Convert connections to edge format
706
- if len(gt_connections) > 0:
707
- edge_indices = torch.tensor(gt_connections, dtype=torch.long).t() # (2, num_edges)
708
- else:
709
- edge_indices = torch.zeros((2, 0), dtype=torch.long) # Empty edges
710
-
711
- # Apply transforms if any
712
- if self.transform:
713
- point_cloud = self.transform(point_cloud)
714
-
715
- return {
716
- 'point_cloud': point_cloud,
717
- 'point_colors': point_colors,
718
- 'gt_vertices': gt_vertices,
719
- 'edge_indices': edge_indices,
720
- 'sample_id': sample_id
721
- }
722
-
723
- def collate_fn(batch):
724
- """
725
- Custom collate function to handle variable number of vertices and edges.
726
- """
727
- point_clouds = []
728
- point_colors = []
729
- gt_vertices_list = []
730
- edge_indices_list = []
731
- sample_ids = []
732
-
733
- max_vertices = 0
734
-
735
- for sample in batch:
736
- point_clouds.append(sample['point_cloud'])
737
- point_colors.append(sample['point_colors'])
738
- gt_vertices_list.append(sample['gt_vertices'])
739
- edge_indices_list.append(sample['edge_indices'])
740
- sample_ids.append(sample['sample_id'])
741
-
742
- max_vertices = max(max_vertices, sample['gt_vertices'].shape[0])
743
-
744
- # Pad point clouds to same size if needed
745
- max_points = max(pc.shape[0] for pc in point_clouds)
746
- padded_point_clouds = []
747
- padded_point_colors = []
748
-
749
- for pc, colors in zip(point_clouds, point_colors):
750
- if pc.shape[0] < max_points:
751
- # Pad with zeros or repeat last point
752
- pad_size = max_points - pc.shape[0]
753
- pc_padded = torch.cat([pc, torch.zeros(pad_size, 3)], dim=0)
754
- colors_padded = torch.cat([colors, torch.zeros(pad_size, 3)], dim=0)
755
- else:
756
- pc_padded = pc
757
- colors_padded = colors
758
-
759
- padded_point_clouds.append(pc_padded)
760
- padded_point_colors.append(colors_padded)
761
-
762
- # Stack point clouds
763
- point_clouds_batch = torch.stack(padded_point_clouds)
764
- point_colors_batch = torch.stack(padded_point_colors)
765
-
766
- # Pad vertices to max_vertices
767
- padded_vertices = []
768
- vertex_masks = [] # To indicate which vertices are real vs padded
769
-
770
- for vertices in gt_vertices_list:
771
- num_vertices = vertices.shape[0]
772
- if num_vertices < max_vertices:
773
- # Pad with zeros
774
- pad_size = max_vertices - num_vertices
775
- vertices_padded = torch.cat([vertices, torch.zeros(pad_size, 3)], dim=0)
776
- mask = torch.cat([torch.ones(num_vertices), torch.zeros(pad_size)], dim=0).bool()
777
- else:
778
- vertices_padded = vertices
779
- mask = torch.ones(num_vertices).bool()
780
-
781
- padded_vertices.append(vertices_padded)
782
- vertex_masks.append(mask)
783
-
784
- gt_vertices_batch = torch.stack(padded_vertices)
785
- vertex_masks_batch = torch.stack(vertex_masks)
786
-
787
- # Create adjacency matrices for edges
788
- batch_size = len(batch)
789
- adjacency_matrices = torch.zeros(batch_size, max_vertices, max_vertices)
790
-
791
- for i, edge_indices in enumerate(edge_indices_list):
792
- if edge_indices.shape[1] > 0: # If there are edges
793
- src, dst = edge_indices[0], edge_indices[1]
794
- # Only add edges for valid vertices (within the actual vertex count)
795
- valid_edges = (src < gt_vertices_list[i].shape[0]) & (dst < gt_vertices_list[i].shape[0])
796
- src_valid = src[valid_edges]
797
- dst_valid = dst[valid_edges]
798
- adjacency_matrices[i, src_valid, dst_valid] = 1
799
- adjacency_matrices[i, dst_valid, src_valid] = 1 # Undirected graph
800
-
801
- return {
802
- 'point_cloud': point_clouds_batch,
803
- 'point_colors': point_colors_batch,
804
- 'gt_vertices': gt_vertices_batch,
805
- 'vertex_masks': vertex_masks_batch,
806
- 'adjacency_matrices': adjacency_matrices,
807
- 'edge_indices_list': edge_indices_list, # Keep original for loss computation
808
- 'sample_ids': sample_ids
809
- }
810
-
811
- # Loss functions
812
- def compute_vertex_loss(pred_vertices, gt_vertices, vertex_masks, vertex_confidence):
813
- """
814
- Compute vertex position loss using Hungarian matching
815
- """
816
- batch_size = pred_vertices.shape[0]
817
- total_loss = 0.0
818
- total_confidence_loss = 0.0
819
-
820
- for b in range(batch_size):
821
- # Get valid GT vertices for this sample
822
- valid_mask = vertex_masks[b]
823
- gt_verts = gt_vertices[b][valid_mask] # (num_valid_gt, 3)
824
- num_gt = gt_verts.shape[0]
825
-
826
- if num_gt == 0:
827
- # No GT vertices, penalize high confidence predictions
828
- confidence_target = torch.zeros_like(vertex_confidence[b])
829
- conf_loss = F.binary_cross_entropy_with_logits(vertex_confidence[b], confidence_target)
830
- total_confidence_loss += conf_loss
831
- continue
832
-
833
- pred_verts = pred_vertices[b] # (max_vertices, 3)
834
- pred_conf = vertex_confidence[b] # (max_vertices,)
835
-
836
- # Compute pairwise distances between predicted and GT vertices
837
- distances = torch.cdist(pred_verts, gt_verts) # (max_vertices, num_gt)
838
-
839
- # Hungarian matching to find optimal assignment
840
-
841
- # Convert to numpy for scipy
842
- cost_matrix = distances.detach().cpu().numpy()
843
-
844
- # Pad cost matrix if needed
845
- if distances.shape[0] < distances.shape[1]:
846
- # More GT vertices than predicted - pad with high cost
847
- padding = np.full((distances.shape[1] - distances.shape[0], distances.shape[1]), 1e6)
848
- cost_matrix = np.vstack([cost_matrix, padding])
849
- elif distances.shape[0] > distances.shape[1]:
850
- # More predicted vertices than GT - pad with high cost
851
- padding = np.full((distances.shape[0], distances.shape[0] - distances.shape[1]), 1e6)
852
- cost_matrix = np.hstack([cost_matrix, padding])
853
-
854
- # Solve assignment problem
855
- pred_indices, gt_indices = linear_sum_assignment(cost_matrix)
856
-
857
- # Filter out dummy assignments (high cost padding)
858
- # Ensure pred_indices are valid for pred_verts and gt_indices for gt_verts
859
- valid_assignments = (pred_indices < pred_verts.shape[0]) & (gt_indices < num_gt)
860
- pred_indices = pred_indices[valid_assignments]
861
- gt_indices = gt_indices[valid_assignments]
862
-
863
- if len(pred_indices) > 0:
864
- # Compute position loss for matched vertices
865
- matched_pred = pred_verts[pred_indices]
866
- matched_gt = gt_verts[gt_indices]
867
- position_loss = F.mse_loss(matched_pred, matched_gt)
868
- total_loss += position_loss
869
-
870
- # Confidence targets: 1 for matched vertices, 0 for unmatched
871
- confidence_target = torch.zeros_like(pred_conf)
872
- confidence_target[pred_indices] = 1.0
873
- conf_loss = F.binary_cross_entropy_with_logits(pred_conf, confidence_target)
874
- total_confidence_loss += conf_loss
875
- else:
876
- # No valid matches - penalize all predictions
877
- confidence_target = torch.zeros_like(pred_conf)
878
- conf_loss = F.binary_cross_entropy_with_logits(pred_conf, confidence_target)
879
- total_confidence_loss += conf_loss
880
-
881
- return total_loss / batch_size, total_confidence_loss / batch_size
882
-
883
- def compute_edge_loss(edge_logits, edge_indices, gt_adjacency_matrices):
884
- """
885
- Compute edge prediction loss
886
- """
887
- batch_size = gt_adjacency_matrices.shape[0]
888
-
889
- # Create edge targets from adjacency matrices
890
- edge_targets = []
891
- for b in range(batch_size):
892
- gt_adj_for_sample = gt_adjacency_matrices[b] # Shape: (batch_max_gt_verts, batch_max_gt_verts)
893
-
894
- # Create a target adjacency matrix of size (MAX_VERTICES, MAX_VERTICES)
895
- # as edge_indices are generated based on the global MAX_VERTICES.
896
- target_adj_full_size = torch.zeros(
897
- MAX_VERTICES,
898
- MAX_VERTICES,
899
- device=gt_adj_for_sample.device,
900
- dtype=gt_adj_for_sample.dtype
901
- )
902
-
903
- # Determine the actual dimension of the current sample's GT adjacency matrix (padded to batch max)
904
- current_gt_dim = gt_adj_for_sample.shape[0]
905
-
906
- # Copy the relevant part of gt_adj_for_sample into the full-sized target matrix.
907
- # The copy_dim is the minimum of MAX_VERTICES and the current GT dimension,
908
- # ensuring we don't read out of bounds from gt_adj_for_sample or write out of bounds to target_adj_full_size.
909
- copy_dim = min(MAX_VERTICES, current_gt_dim)
910
-
911
- target_adj_full_size[:copy_dim, :copy_dim] = gt_adj_for_sample[:copy_dim, :copy_dim]
912
-
913
- # Extract targets using edge_indices, which refer to pairs in a MAX_VERTICES graph.
914
- targets = target_adj_full_size[edge_indices[:, 0], edge_indices[:, 1]]
915
- edge_targets.append(targets)
916
-
917
- edge_targets = torch.stack(edge_targets) # Shape: (batch_size, num_potential_edges_in_MAX_VERTICES_graph)
918
- edge_targets = edge_targets.to(edge_logits.device)
919
-
920
- # Binary cross entropy loss
921
- edge_loss = F.binary_cross_entropy_with_logits(edge_logits, edge_targets)
922
-
923
- return edge_loss
924
-
925
- def compute_total_loss(model_output, batch):
926
- """
927
- Compute total loss combining vertex and edge losses
928
- """
929
- # Extract model outputs
930
- pred_vertices = model_output['vertices']
931
- vertex_confidence = model_output['vertex_confidence']
932
- edge_logits = model_output['edge_logits']
933
- edge_indices = model_output['edge_indices']
934
-
935
- # Extract ground truth
936
- gt_vertices = batch['gt_vertices'].to(DEVICE)
937
- vertex_masks = batch['vertex_masks'].to(DEVICE)
938
- gt_adjacency = batch['adjacency_matrices'].to(DEVICE)
939
-
940
- # Compute individual losses
941
- vertex_pos_loss, vertex_conf_loss = compute_vertex_loss(
942
- pred_vertices, gt_vertices, vertex_masks, vertex_confidence
943
- )
944
- edge_loss = compute_edge_loss(edge_logits, edge_indices, gt_adjacency)
945
-
946
- # Combine losses
947
- total_loss = (VERTEX_LOSS_WEIGHT * vertex_pos_loss +
948
- CONFIDENCE_LOSS_WEIGHT * vertex_conf_loss +
949
- EDGE_LOSS_WEIGHT * edge_loss)
950
-
951
- return {
952
- 'total_loss': total_loss,
953
- 'vertex_pos_loss': vertex_pos_loss,
954
- 'vertex_conf_loss': vertex_conf_loss,
955
- 'edge_loss': edge_loss
956
- }
957
-
958
- # =============================================================================
959
- # MAIN TRAINING SCRIPT
960
- # =============================================================================
961
-
962
- if __name__ == '__main__':
963
- # Create dataset and dataloader
964
- dataset = WireframeDataset(data_dir=DATA_DIR, split=SPLIT)
965
- dataloader = DataLoader(
966
- dataset,
967
- batch_size=BATCH_SIZE,
968
- shuffle=True,
969
- collate_fn=collate_fn,
970
- num_workers=NUM_WORKERS
971
- )
972
-
973
- # Initialize model
974
- model = PointCloudToWireframe()
975
-
976
- # Move model to device
977
- model = model.to(DEVICE)
978
- print(f"Model loaded on device: {DEVICE}")
979
-
980
- # Initialize optimizer and scheduler
981
- optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
982
- scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
983
- optimizer, mode='min', factor=LR_SCHEDULER_FACTOR, patience=LR_SCHEDULER_PATIENCE
984
- )
985
-
986
- # Training loop
987
- model.train()
988
- print("Starting training...")
989
-
990
- for epoch in range(NUM_EPOCHS):
991
- epoch_losses = {
992
- 'total_loss': 0.0,
993
- 'vertex_pos_loss': 0.0,
994
- 'vertex_conf_loss': 0.0,
995
- 'edge_loss': 0.0
996
- }
997
- num_batches = 0
998
-
999
- for batch_idx, batch in enumerate(dataloader):
1000
- # Move data to device
1001
- point_cloud = batch['point_cloud'].to(DEVICE)
1002
-
1003
- # Zero gradients
1004
- optimizer.zero_grad()
1005
-
1006
- # Forward pass
1007
- output = model(point_cloud)
1008
-
1009
- # Compute losses
1010
- losses = compute_total_loss(output, batch)
1011
-
1012
- # Backward pass
1013
- losses['total_loss'].backward()
1014
-
1015
- # Gradient clipping
1016
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRADIENT_CLIP_MAX_NORM)
1017
-
1018
- # Update weights
1019
- optimizer.step()
1020
-
1021
- # Accumulate losses
1022
- for key in epoch_losses:
1023
- epoch_losses[key] += losses[key].item()
1024
- num_batches += 1
1025
-
1026
- # Print progress
1027
- if batch_idx % LOG_FREQUENCY == 0:
1028
- print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Batch {batch_idx}/{len(dataloader)}")
1029
- print(f" Total Loss: {losses['total_loss'].item():.4f}")
1030
- print(f" Vertex Pos Loss: {losses['vertex_pos_loss'].item():.4f}")
1031
- print(f" Vertex Conf Loss: {losses['vertex_conf_loss'].item():.4f}")
1032
- print(f" Edge Loss: {losses['edge_loss'].item():.4f}")
1033
-
1034
- # Average losses for the epoch
1035
- for key in epoch_losses:
1036
- epoch_losses[key] /= num_batches
1037
-
1038
- # Update learning rate scheduler
1039
- scheduler.step(epoch_losses['total_loss'])
1040
-
1041
- # Print epoch summary
1042
- print(f"\nEpoch {epoch+1} Summary:")
1043
- print(f" Avg Total Loss: {epoch_losses['total_loss']:.4f}")
1044
- print(f" Avg Vertex Pos Loss: {epoch_losses['vertex_pos_loss']:.4f}")
1045
- print(f" Avg Vertex Conf Loss: {epoch_losses['vertex_conf_loss']:.4f}")
1046
- print(f" Avg Edge Loss: {epoch_losses['edge_loss']:.4f}")
1047
- print(f" Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
1048
- print("-" * 50)
1049
-
1050
- # Save checkpoint every epoch
1051
- if (epoch + 1) % CHECKPOINT_SAVE_FREQUENCY == 0:
1052
- checkpoint = {
1053
- 'epoch': epoch + 1,
1054
- 'model_state_dict': model.state_dict(),
1055
- 'optimizer_state_dict': optimizer.state_dict(),
1056
- 'scheduler_state_dict': scheduler.state_dict(),
1057
- 'losses': epoch_losses,
1058
- 'config': {
1059
- 'pc_input_features': PC_INPUT_FEATURES,
1060
- 'pc_encoder_output_features': PC_ENCODER_OUTPUT_FEATURES,
1061
- 'max_vertices': MAX_VERTICES,
1062
- 'gnn_hidden_dim': GNN_HIDDEN_DIM,
1063
- 'num_gnn_layers': NUM_GNN_LAYERS
1064
- }
1065
- }
1066
- torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
1067
- print(f"Checkpoint saved: checkpoint_epoch_{epoch+1}.pth")
1068
-
1069
- # Save final model
1070
- torch.save({
1071
- 'model_state_dict': model.state_dict(),
1072
- 'model_config': {
1073
- 'pc_input_features': PC_INPUT_FEATURES,
1074
- 'pc_encoder_output_features': PC_ENCODER_OUTPUT_FEATURES,
1075
- 'max_vertices': MAX_VERTICES,
1076
- 'gnn_hidden_dim': GNN_HIDDEN_DIM,
1077
- 'num_gnn_layers': NUM_GNN_LAYERS
1078
- }
1079
- }, 'final_model.pth')
1080
-
1081
- print("Training completed!")
1082
- print(f"Dataset size: {len(dataset)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generate_pcloud_dataset.py CHANGED
@@ -1,3 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
1
  from datasets import load_dataset
2
  from hoho2025.viz3d import *
3
  import os
@@ -55,3 +66,4 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
55
 
56
  print(f"Generated {counter} samples in {output_dir}")
57
 
 
 
1
+ # This script processes the 'usm3d/hoho25k' dataset.
2
+ # For each sample in the dataset, it performs the following steps:
3
+ # 1. Reads COLMAP reconstruction data.
4
+ # 2. Extracts 3D point coordinates and their corresponding colors.
5
+ # 3. Retrieves ground truth wireframe vertices and edges.
6
+ # 4. Skips processing if the output file already exists or if no 3D points are found.
7
+ # 5. Saves the extracted point cloud, colors, ground truth data, and sample ID
8
+ # into a pickle file in a specified output directory.
9
+ # The script shuffles the dataset before processing and keeps track of
10
+ # the number of samples successfully processed and saved.
11
+ #
12
  from datasets import load_dataset
13
  from hoho2025.viz3d import *
14
  import os
 
66
 
67
  print(f"Generated {counter} samples in {output_dir}")
68
 
69
+
hoho_cpu.batch DELETED
@@ -1,17 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --nodes=1 # 1 node
3
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
4
- #SBATCH --cpus-per-task=8 # 6 CPUS per task = 12 CPUS per node
5
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
6
- #SBATCH --time=24:00:00 # time limits: 1 hour
7
- #SBATCH --error=hoho_cpu.err # standard error file
8
- #SBATCH --output=hoho_cpu.out # standard output file
9
- #SBATCH --partition=amd # partition name
10
- #SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
11
- #SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
12
-
13
- cd /mnt/personal/skvrnjan/hoho/
14
- module purge
15
- module load Python/3.10.8-GCCcore-12.2.0
16
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
17
- python train.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hoho_cpu_gpu_intel.batch DELETED
@@ -1,19 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --nodes=1 # 1 node
3
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
4
- #SBATCH --cpus-per-task=8 # 6 CPUS per task = 12 CPUS per node
5
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
6
- #SBATCH --time=24:00:00 # time limits: 1 hour
7
- #SBATCH --error=hoho_cpu.err # standard error file
8
- #SBATCH --output=hoho_cpu.out # standard output file
9
- #SBATCH --partition=gpu # partition name
10
- #SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
11
- #SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
12
- #SBATCH --gres=gpu:1
13
-
14
- cd /mnt/personal/skvrnjan/hoho/
15
- module purge
16
- module load Python/3.10.8-GCCcore-12.2.0
17
- module load CUDA/12.6.0
18
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
19
- python train.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hoho_gpu.batch DELETED
@@ -1,19 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --nodes=1 # 1 node
3
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
4
- #SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
5
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
6
- #SBATCH --time=24:00:00 # time limits: 1 hour
7
- #SBATCH --error=hoho_gpu.err # standard error file
8
- #SBATCH --output=hoho_gpu.out # standard output file
9
- #SBATCH --partition=amdgpu # partition name
10
- #SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
11
- #SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
12
- #SBATCH --gres=gpu:1
13
-
14
- cd /mnt/personal/skvrnjan/hoho/
15
- module purge
16
- module load Python/3.10.8-GCCcore-12.2.0
17
- module load CUDA/12.6.0
18
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
19
- python train_pnet_cluster.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hoho_gpu_class.batch DELETED
@@ -1,19 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --nodes=1 # 1 node
3
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
4
- #SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
5
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
6
- #SBATCH --time=24:00:00 # time limits: 1 hour
7
- #SBATCH --error=hoho_gpu_class2_v4.err # standard error file
8
- #SBATCH --output=hoho_gpu_class2_v4.out # standard output file
9
- #SBATCH --partition=amdgpu # partition name
10
- #SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
11
- #SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
12
- #SBATCH --gres=gpu:1
13
-
14
- cd /mnt/personal/skvrnjan/hoho/
15
- module purge
16
- module load Python/3.10.8-GCCcore-12.2.0
17
- module load CUDA/12.6.0
18
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
19
- python train_pnet_class_cluster.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hoho_gpu_class_10d.batch DELETED
@@ -1,19 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --nodes=1 # 1 node
3
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
4
- #SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
5
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
6
- #SBATCH --time=24:00:00 # time limits: 1 hour
7
- #SBATCH --error=hoho_gpu_class_10d_v2.err # standard error file
8
- #SBATCH --output=hoho_gpu_class_10d_v2.out # standard output file
9
- #SBATCH --partition=amdgpu # partition name
10
- #SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
11
- #SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
12
- #SBATCH --gres=gpu:1
13
-
14
- cd /mnt/personal/skvrnjan/hoho/
15
- module purge
16
- module load Python/3.10.8-GCCcore-12.2.0
17
- module load CUDA/12.6.0
18
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
19
- python train_pnet_class_cluster_10d.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hoho_gpu_class_10d_2048.batch DELETED
@@ -1,19 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --nodes=1 # 1 node
3
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
4
- #SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
5
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
6
- #SBATCH --time=24:00:00 # time limits: 1 hour
7
- #SBATCH --error=hoho_gpu_class_10d_2048.err # standard error file
8
- #SBATCH --output=hoho_gpu_class_10d_2048.out # standard output file
9
- #SBATCH --partition=amdgpu # partition name
10
- #SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
11
- #SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
12
- #SBATCH --gres=gpu:1
13
-
14
- cd /mnt/personal/skvrnjan/hoho/
15
- module purge
16
- module load Python/3.10.8-GCCcore-12.2.0
17
- module load CUDA/12.6.0
18
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
19
- python train_pnet_class_cluster_10d_2048.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hoho_gpu_class_10d_deeper.batch DELETED
@@ -1,19 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --nodes=1 # 1 node
3
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
4
- #SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
5
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
6
- #SBATCH --time=24:00:00 # time limits: 1 hour
7
- #SBATCH --error=hoho_gpu_class_10d_v2_deeper_v2.err # standard error file
8
- #SBATCH --output=hoho_gpu_class_10d_v2_deeper_v2.out # standard output file
9
- #SBATCH --partition=amdgpu # partition name
10
- #SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
11
- #SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
12
- #SBATCH --gres=gpu:1
13
-
14
- cd /mnt/personal/skvrnjan/hoho/
15
- module purge
16
- module load Python/3.10.8-GCCcore-12.2.0
17
- module load CUDA/12.6.0
18
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
19
- python train_pnet_class_cluster_10d_deeper.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hoho_gpu_h200.batch DELETED
@@ -1,19 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --nodes=1 # 1 node
3
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
4
- #SBATCH --cpus-per-task=20 # 6 CPUS per task = 12 CPUS per node
5
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
6
- #SBATCH --time=24:00:00 # time limits: 1 hour
7
- #SBATCH --error=hoho_gpu_h200_v2_class.err # standard error file
8
- #SBATCH --output=hoho_gpu_h200_v2_class.out # standard output file
9
- #SBATCH --partition=h200 # partition name
10
- #SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
11
- #SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
12
- #SBATCH --gres=gpu:1
13
-
14
- cd /mnt/personal/skvrnjan/hoho/
15
- module purge
16
- module load Python/3.12.3-GCCcore-13.3.0
17
- module load CUDA/12.6.0
18
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
19
- python train_pnet_cluster_class_v2.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hoho_gpu_voxel.batch DELETED
@@ -1,19 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --nodes=1 # 1 node
3
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
4
- #SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
5
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
6
- #SBATCH --time=24:00:00 # time limits: 1 hour
7
- #SBATCH --error=hoho_gpu.err # standard error file
8
- #SBATCH --output=hoho_gpu.out # standard output file
9
- #SBATCH --partition=amdgpu # partition name
10
- #SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
11
- #SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
12
- #SBATCH --gres=gpu:1
13
-
14
- cd /mnt/personal/skvrnjan/hoho/
15
- module purge
16
- module load Python/3.10.8-GCCcore-12.2.0
17
- module load CUDA/12.6.0
18
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
19
- python train_voxel_cluster.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
initial_epoch_100.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8d9de42aafd8ca9f0e831c920a41d35f61f05ecf6f96c0227a46d16c34cd861c
3
- size 93364299
 
 
 
 
initial_epoch_100_class_v2.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:af653e5fe08dc57b2bb84996896c60d06a71e1c1a0197ad0e07ee2d03dc080e8
3
- size 92609251
 
 
 
 
initial_epoch_100_v2.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:be55556839f8d4fedad7a5cf7520b48859d7bd9a3fbb6b2efbae627ca8ca3ffc
3
- size 103080355
 
 
 
 
initial_epoch_100_v2_aug.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:66ba39359176a2d4d9f444912c58f24d5039b49c802dd8c2be45bcba10694054
3
- size 103080355
 
 
 
 
initial_epoch_60.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ef3f7c9462b297447ee24b7994919d625b958d1941626215d1d41f27faf7dac1
3
- size 93364051
 
 
 
 
initial_epoch_60_v2.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:759a5e61e5934aa2417d587621b10a0b390c821c414380185c75fc1379e07f90
3
- size 103080040
 
 
 
 
iterate.batch DELETED
@@ -1,50 +0,0 @@
1
- #!/bin/bash
2
-
3
- # Define parameter ranges
4
- vertex_min=0.4
5
- vertex_max=0.9
6
- vertex_step=0.02
7
-
8
- edge_min=0.4
9
- edge_max=0.9
10
- edge_step=0.02
11
-
12
- # Define results directory path or use first argument
13
- results_dir=${1:-"/mnt/personal/skvrnjan/hoho/results"}
14
-
15
- # Create results directory if it doesn't exist
16
- mkdir -p $results_dir
17
-
18
- # Iterate over all combinations
19
- for vertex_thresh in $(seq $vertex_min $vertex_step $vertex_max); do
20
- for edge_thresh in $(seq $edge_min $edge_step $edge_max); do
21
- # Create job name
22
- job_name="v10_train_v${vertex_thresh}_e${edge_thresh}"
23
-
24
- # Create SLURM script
25
- cat > "${job_name}.slurm" << EOF
26
- #!/bin/bash
27
- #SBATCH --job-name=${job_name}
28
- #SBATCH --output=${job_name}_%j.out
29
- #SBATCH --error=${job_name}_%j.err
30
- #SBATCH --time=4:00:00
31
- #SBATCH --partition=amdfast # partition name
32
- #SBATCH --cpus-per-task=4 # 6 CPUS per task = 12 CPUS per node
33
- #SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
34
- #SBATCH --nodes=1 # 1 node
35
- #SBATCH --ntasks-per-node=1 # 1 tasks per node
36
-
37
- # Run training with specific parameters
38
- cd /mnt/personal/skvrnjan/hoho/
39
- module purge
40
- module load Python/3.10.8-GCCcore-12.2.0
41
- source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
42
- python train.py --vertex_threshold ${vertex_thresh} --edge_threshold ${edge_thresh} --results_dir $results_dir/${job_name} --max_samples 100
43
- EOF
44
-
45
- # Submit job
46
- sbatch "${job_name}.slurm"
47
-
48
- echo "Submitted job: ${job_name}"
49
- done
50
- done
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pnet.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5eef9dc9903751fd7fb2418d729e7021ba7f90e9133f3c64a9694c78c10b61f7
3
- size 93358155
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be55556839f8d4fedad7a5cf7520b48859d7bd9a3fbb6b2efbae627ca8ca3ffc
3
+ size 103080355
predict.py CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  from typing import Tuple, List
3
  from hoho2025.example_solutions import empty_solution, read_colmap_rec, get_vertices_and_edges_from_segmentation, get_house_mask, fit_scale_robust_median, get_uv_depth, merge_vertices_3d, prune_not_connected, prune_too_far, point_to_segment_dist
 
1
+ # This script is designed for 3D wireframe reconstruction, primarily focusing on
2
+ # buildings, using multi-view imagery and associated 3D data.
3
+ # It leverages COLMAP reconstructions, depth maps, and semantic segmentations
4
+ # (ADE20k and Gestalt) to identify and predict structural elements.
5
+ # Core tasks include:
6
+ # - Processing and aligning 2D image data (segmentations, depth) with 3D COLMAP point clouds.
7
+ # - Extracting initial 2D/3D vertex candidates from segmentation maps.
8
+ # - Generating local point cloud patches around these candidates.
9
+ # - Employing machine learning models (e.g., PointNet variants) to refine vertex locations
10
+ # and classify potential edges between them.
11
+ # - Optionally, generating datasets of these patches for training ML models.
12
+ # - Merging information from multiple views to produce a final 3D wireframe.
13
  import numpy as np
14
  from typing import Tuple, List
15
  from hoho2025.example_solutions import empty_solution, read_colmap_rec, get_vertices_and_edges_from_segmentation, get_house_mask, fit_scale_robust_median, get_uv_depth, merge_vertices_3d, prune_not_connected, prune_too_far, point_to_segment_dist
predict_end.py DELETED
@@ -1,73 +0,0 @@
1
- import torch
2
- from typing import Tuple, List
3
- import numpy as np
4
-
5
- from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping
6
- from predict import create_pcloud, convert_entry_to_human_readable, empty_solution
7
- from end_to_end import save_data
8
-
9
- data_folder = '/mnt/personal/skvrnjan/hoho_end/'
10
-
11
- def predict_wireframe(entry, config) -> Tuple[np.ndarray, List[int]]:
12
- """
13
- Predict 3D wireframe from a dataset entry.
14
- """
15
-
16
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
-
18
- good_entry = convert_entry_to_human_readable(entry)
19
- colmap_rec = good_entry['colmap_binary']
20
-
21
- colmap_pcloud = create_pcloud(colmap_rec, good_entry)
22
-
23
- pcloud_14d = pcloud_7d_to_14d(colmap_pcloud)
24
-
25
- dict_to_save = {'pcloud_14d': pcloud_14d,
26
- 'wf_vertices': good_entry['wf_vertices'],
27
- 'wf_edges': good_entry['wf_edges']}
28
-
29
- save_data(dict_to_save, good_entry['order_id'], data_folder=data_folder)
30
-
31
- return empty_solution()
32
-
33
- def pcloud_7d_to_14d(pcloud_7d: np.ndarray) -> np.ndarray:
34
- """
35
- Convert 7D point cloud to higher dimensional by removing ID, then adding ADE and Gestalt segmentation
36
- with bin counting for edge classes.
37
-
38
- Args:
39
- pcloud_7d: Array of shape (N, 7) containing [x, y, z, r, g, b, confidence]
40
-
41
- Returns:
42
- Array of shape (N, 15) containing [x, y, z, r, g, b, ade_class, apex, eave_end_point,
43
- flashing_end_points, eave, ridge, rake, valley, gestalt_rgb]
44
- """
45
- edge_classes = ['apex', 'eave_end_point', 'flashing_end_points', 'eave', 'ridge', 'rake', 'valley']
46
-
47
- # Extract ADE and Gestalt data from colmap_pcloud
48
- ade_values = pcloud_7d['ade']
49
- gestalt_values = pcloud_7d['gestalt']
50
- point_cloud = pcloud_7d['points_7d']
51
-
52
- # Initialize output array (6D base + 1 ADE + 7 edge classes + 1 gestalt)
53
- pcloud_14d = np.zeros((point_cloud.shape[0], 14))
54
- pcloud_14d[:, :6] = point_cloud[:, :6] # Remove confidence/ID column
55
-
56
- # Process ADE segmentation
57
- pcloud_14d[:, 6] = ade_values
58
- pcloud_14d[:, 3:6] = pcloud_14d[:, 3:6] * 2 - 1
59
-
60
- # Process Gestalt segmentation with edge class bin counting
61
- for i, gestalt_list in enumerate(gestalt_values):
62
- if len(gestalt_list) > 0:
63
- gestalt_array = np.array(gestalt_list, dtype=np.uint32)
64
- if gestalt_array.ndim == 2:
65
- # Bin counting for edge classes (columns 7-13)
66
- for j, edge_class in enumerate(edge_classes):
67
- if edge_class in gestalt_color_mapping:
68
- target_color = np.array(gestalt_color_mapping[edge_class])
69
- # Count matches for this edge class
70
- matches = np.sum(np.all(gestalt_array == target_color, axis=1))
71
- pcloud_14d[i, 7 + j] = matches
72
-
73
- return pcloud_14d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
script.py CHANGED
@@ -74,7 +74,7 @@ if __name__ == "__main__":
74
 
75
  device = "cuda" if torch.cuda.is_available() else "cpu"
76
 
77
- pnet_model = load_pointnet_model(model_path="initial_epoch_100_v2.pth", device=device, predict_score=True)
78
 
79
  pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
80
 
 
74
 
75
  device = "cuda" if torch.cuda.is_available() else "cpu"
76
 
77
+ pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
78
 
79
  pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
80
 
train.py CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  from datasets import load_dataset
2
  from hoho2025.vis import plot_all_modalities
3
  from hoho2025.viz3d import *
@@ -16,11 +28,8 @@ from utils import read_colmap_rec, empty_solution
16
  from hoho2025.metric_helper import hss
17
  from predict import predict_wireframe, predict_wireframe_old
18
  from tqdm import tqdm
19
- #from fast_pointnet import load_pointnet_model
20
  from fast_pointnet_v2 import load_pointnet_model
21
- from fast_voxel import load_3dcnn_model
22
- from fast_pointnet_class_v2 import load_pointnet_model as load_pointnet_class_model
23
- from fast_pointnet_class_10d import load_pointnet_model as load_pointnet_class_model_10d
24
  import torch
25
  import time
26
 
@@ -47,15 +56,15 @@ print(f"Running with configuration: {config}")
47
  os.makedirs(args.results_dir, exist_ok=True)
48
 
49
 
50
- #ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
51
- ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
52
- ds = ds.shuffle()
53
 
54
  scores_hss = []
55
  scores_f1 = []
56
  scores_iou = []
57
 
58
- show_visu = False
59
 
60
  device = "cuda" if torch.cuda.is_available() else "cpu"
61
 
@@ -105,10 +114,10 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
105
  colmap = read_colmap_rec(a['colmap_binary'])
106
  pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
107
  wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
108
- wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
109
  bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
110
 
111
- visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
112
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
113
 
114
  idx += 1
 
1
+ """
2
+ Training and evaluation script for HoHo wireframe prediction model.
3
+ This script loads the HoHo25k dataset, processes samples through a wireframe prediction pipeline
4
+ using PointNet models, and evaluates performance using HSS, F1, and IoU metrics. It supports
5
+ configurable thresholds, visualization of results, and saves detailed performance metrics to files.
6
+ Key features:
7
+ - Command-line argument support for model configuration
8
+ - PointNet-based vertex and edge prediction
9
+ - Real-time performance monitoring and visualization
10
+ - Comprehensive metric evaluation and result logging
11
+ - Support for CUDA acceleration when available
12
+ """
13
  from datasets import load_dataset
14
  from hoho2025.vis import plot_all_modalities
15
  from hoho2025.viz3d import *
 
28
  from hoho2025.metric_helper import hss
29
  from predict import predict_wireframe, predict_wireframe_old
30
  from tqdm import tqdm
 
31
  from fast_pointnet_v2 import load_pointnet_model
32
+ from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
 
 
33
  import torch
34
  import time
35
 
 
56
  os.makedirs(args.results_dir, exist_ok=True)
57
 
58
 
59
+ ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
60
+ #ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
61
+ #ds = ds.shuffle()
62
 
63
  scores_hss = []
64
  scores_f1 = []
65
  scores_iou = []
66
 
67
+ show_visu = True
68
 
69
  device = "cuda" if torch.cuda.is_available() else "cpu"
70
 
 
114
  colmap = read_colmap_rec(a['colmap_binary'])
115
  pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
116
  wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
117
+ #wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
118
  bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
119
 
120
+ visu_all = [pcd] + geometries + wireframe + bpo_cams #+ wireframe2
121
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
122
 
123
  idx += 1
train_end.py DELETED
@@ -1,73 +0,0 @@
1
- from datasets import load_dataset
2
- from hoho2025.vis import plot_all_modalities
3
- from hoho2025.viz3d import *
4
- import open3d as o3d
5
-
6
- from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
7
- from utils import read_colmap_rec, empty_solution
8
-
9
- #from hoho2025.example_solutions import predict_wireframe
10
- from hoho2025.metric_helper import hss
11
- from predict import predict_wireframe_old
12
- from predict_end import predict_wireframe
13
- from tqdm import tqdm
14
- import torch
15
- import time
16
-
17
- #ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
18
- ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
19
- ds = ds.shuffle()
20
-
21
- scores_hss = []
22
- scores_f1 = []
23
- scores_iou = []
24
-
25
- show_visu = False
26
-
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
28
-
29
- config = {'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
30
-
31
- idx = 0
32
- prediction_times = []
33
- for a in tqdm(ds['train'], desc="Processing dataset"):
34
- #plot_all_modalities(a)
35
- #pred_vertices, pred_edges = predict_wireframe_old(a)
36
- #pred_vertices, pred_edges = predict_wireframe(a.copy(), config)
37
- try:
38
- start_time = time.time()
39
- pred_vertices, pred_edges = predict_wireframe(a.copy(), config)
40
- #pred_vertices, pred_edges = predict_wireframe_old(a)
41
- end_time = time.time()
42
- prediction_time = end_time - start_time
43
- prediction_times.append(prediction_time)
44
- mean_time = np.mean(prediction_times)
45
- print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
46
- except:
47
- pred_vertices, pred_edges = empty_solution()
48
-
49
- score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
50
- print(f"Score: {score}")
51
- scores_hss.append(score.hss)
52
- scores_f1.append(score.f1)
53
- scores_iou.append(score.iou)
54
-
55
- if show_visu:
56
- colmap = read_colmap_rec(a['colmap_binary'])
57
- pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
58
- wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
59
- wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
60
- bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
61
-
62
- visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
63
- o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
64
-
65
- idx += 1
66
-
67
- for i in range(10):
68
- print("END OF DATASET")
69
- print(f"Mean HSS: {np.mean(scores_hss):.4f}")
70
- print(f"Mean F1: {np.mean(scores_f1):.4f}")
71
- print(f"Mean IoU: {np.mean(scores_iou):.4f}")
72
- print(config)
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_pnet.py DELETED
@@ -1,13 +0,0 @@
1
- from fast_pointnet import train_pointnet
2
- import os
3
-
4
- if __name__ == "__main__":
5
-
6
- # Load the dataset
7
- dataset_path = "/home/skvrnjan/personal/hohocustom/"
8
- model_save_path = "/home/skvrnjan/personal/hoho_pnet/"
9
-
10
- os.makedirs(model_save_path, exist_ok=True)
11
-
12
- # Train the model
13
- train_pointnet(dataset_path, model_save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_pnet_class.py CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fast_pointnet_class import train_pointnet
2
  import os
3
 
@@ -10,4 +22,4 @@ if __name__ == "__main__":
10
  os.makedirs(model_save_path, exist_ok=True)
11
 
12
  # Train the model
13
- train_pointnet(dataset_path, model_save_path)
 
1
+ # This script serves as the main entry point for training a PointNet-based
2
+ # classification model.
3
+ #
4
+ # It imports the necessary training function `train_pointnet` from the
5
+ # `fast_pointnet_class` module.
6
+ #
7
+ # The script defines file paths for the input dataset and the directory
8
+ # where the trained model will be saved. It ensures that the model saving
9
+ # directory exists before starting the training.
10
+ #
11
+ # Finally, it initiates the training process by calling the `train_pointnet`
12
+ # function with the specified dataset path, model save path, and a batch size.
13
  from fast_pointnet_class import train_pointnet
14
  import os
15
 
 
22
  os.makedirs(model_save_path, exist_ok=True)
23
 
24
  # Train the model
25
+ train_pointnet(dataset_path, model_save_path, batch_size=4)
train_pnet_class_cluster.py DELETED
@@ -1,13 +0,0 @@
1
- from fast_pointnet_class import train_pointnet
2
- import os
3
-
4
- if __name__ == "__main__":
5
-
6
- # Load the dataset
7
- dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges/"
8
- model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_v4/initial.pth"
9
-
10
- os.makedirs(model_save_path, exist_ok=True)
11
-
12
- # Train the model
13
- train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_pnet_class_cluster_10d.py DELETED
@@ -1,13 +0,0 @@
1
- from fast_pointnet_class_10d import train_pointnet
2
- import os
3
-
4
- if __name__ == "__main__":
5
-
6
- # Load the dataset
7
- dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges_10d/"
8
- model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_10d_v2/initial.pth"
9
-
10
- os.makedirs(model_save_path, exist_ok=True)
11
-
12
- # Train the model
13
- train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_pnet_class_cluster_10d_2048.py DELETED
@@ -1,13 +0,0 @@
1
- from fast_pointnet_class_10d_2048 import train_pointnet
2
- import os
3
-
4
- if __name__ == "__main__":
5
-
6
- # Load the dataset
7
- dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges_10d_1m/"
8
- model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_10d_2048/initial.pth"
9
-
10
- os.makedirs(model_save_path, exist_ok=True)
11
-
12
- # Train the model
13
- train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_pnet_class_cluster_10d_deeper.py DELETED
@@ -1,13 +0,0 @@
1
- from fast_pointnet_class_10d_deeper import train_pointnet
2
- import os
3
-
4
- if __name__ == "__main__":
5
-
6
- # Load the dataset
7
- dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges_10d/"
8
- model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_10d_deeper_v2/initial.pth"
9
-
10
- os.makedirs(model_save_path, exist_ok=True)
11
-
12
- # Train the model
13
- train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_pnet_cluster.py DELETED
@@ -1,10 +0,0 @@
1
- from fast_pointnet import train_pointnet
2
-
3
- if __name__ == "__main__":
4
-
5
- # Load the dataset
6
- dataset_path = "/mnt/personal/skvrnjan/hohocustom/"
7
- model_save_path = "/mnt/personal/skvrnjan/hoho_pnet/initial.pth"
8
-
9
- # Train the model
10
- train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001, score_weight=1.0, class_weight=0.5)
 
 
 
 
 
 
 
 
 
 
 
train_pnet_cluster_class_v2.py DELETED
@@ -1,10 +0,0 @@
1
- from fast_pointnet_class_v2 import train_pointnet
2
-
3
- if __name__ == "__main__":
4
-
5
- # Load the dataset
6
- dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges_10d_v5/"
7
- model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_class_v2/initial.pth"
8
-
9
- # Train the model
10
- train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=512, lr=0.001)
 
 
 
 
 
 
 
 
 
 
 
train_pnet_cluster_v3.py DELETED
@@ -1,10 +0,0 @@
1
- from fast_pointnet_v3 import train_pointnet
2
-
3
- if __name__ == "__main__":
4
-
5
- # Load the dataset
6
- dataset_path = "/mnt/personal/skvrnjan/hohocustom_v4/"
7
- model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_v11/initial.pth"
8
-
9
- # Train the model
10
- train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=256, lr=0.001, score_weight=0.25, class_weight=1.0)
 
 
 
 
 
 
 
 
 
 
 
train_pnet_cluster_v2.py → train_pnet_v2.py RENAMED
@@ -3,8 +3,8 @@ from fast_pointnet_v2 import train_pointnet
3
  if __name__ == "__main__":
4
 
5
  # Load the dataset
6
- dataset_path = "/mnt/personal/skvrnjan/hohocustom_v4/"
7
- model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_v7/initial.pth"
8
 
9
  # Train the model
10
  train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=512, lr=0.001, score_weight=0.25, class_weight=1.0)
 
3
  if __name__ == "__main__":
4
 
5
  # Load the dataset
6
+ dataset_path = "xx"
7
+ model_save_path = "xx.pth"
8
 
9
  # Train the model
10
  train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=512, lr=0.001, score_weight=0.25, class_weight=1.0)
train_voxel.py DELETED
@@ -1,13 +0,0 @@
1
- from fast_voxel import train_3dcnn
2
- import os
3
-
4
- if __name__ == "__main__":
5
-
6
- # Load the dataset
7
- dataset_path = "/home/skvrnjan/personal/hohocustom/"
8
- model_save_path = "/home/skvrnjan/personal/hoho_voxel/"
9
-
10
- os.makedirs(model_save_path, exist_ok=True)
11
-
12
- # Train the model
13
- train_3dcnn(dataset_path, model_save_path, epochs=100, batch_size=16, lr=0.001, voxel_size=32, score_weight=0.5, class_weight=0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_voxel_cluster.py DELETED
@@ -1,13 +0,0 @@
1
- from fast_voxel import train_3dcnn
2
- import os
3
-
4
- if __name__ == "__main__":
5
-
6
- # Load the dataset
7
- dataset_path = "/mnt/personal/skvrnjan/hohocustom/"
8
- model_save_path = "/mnt/personal/skvrnjan/hoho_voxel/initial.pth"
9
-
10
- os.makedirs(model_save_path, exist_ok=True)
11
-
12
- # Train the model
13
- train_3dcnn(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001, voxel_size=32, score_weight=0.5, class_weight=0.5)