Preparation of the files for the public release.
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- color_visu.py +8 -0
- end_to_end.py +24 -0
- end_to_end_deeper.py +0 -946
- fast_pointnet.py +0 -520
- fast_pointnet_class.py +7 -0
- fast_pointnet_class_10d.py +0 -405
- fast_pointnet_class_10d_2048.py +0 -405
- fast_pointnet_class_10d_deeper.py +0 -438
- fast_pointnet_class_deeper.py +0 -527
- fast_pointnet_class_v2.py +0 -508
- fast_pointnet_v2.py +11 -1
- fast_pointnet_v3.py +0 -605
- fast_voxel.py +0 -591
- find_best_results.py +7 -0
- fully_deep.py +0 -1082
- generate_pcloud_dataset.py +12 -0
- hoho_cpu.batch +0 -17
- hoho_cpu_gpu_intel.batch +0 -19
- hoho_gpu.batch +0 -19
- hoho_gpu_class.batch +0 -19
- hoho_gpu_class_10d.batch +0 -19
- hoho_gpu_class_10d_2048.batch +0 -19
- hoho_gpu_class_10d_deeper.batch +0 -19
- hoho_gpu_h200.batch +0 -19
- hoho_gpu_voxel.batch +0 -19
- initial_epoch_100.pth +0 -3
- initial_epoch_100_class_v2.pth +0 -3
- initial_epoch_100_v2.pth +0 -3
- initial_epoch_100_v2_aug.pth +0 -3
- initial_epoch_60.pth +0 -3
- initial_epoch_60_v2.pth +0 -3
- iterate.batch +0 -50
- pnet.pth +2 -2
- predict.py +12 -0
- predict_end.py +0 -73
- script.py +1 -1
- train.py +19 -10
- train_end.py +0 -73
- train_pnet.py +0 -13
- train_pnet_class.py +13 -1
- train_pnet_class_cluster.py +0 -13
- train_pnet_class_cluster_10d.py +0 -13
- train_pnet_class_cluster_10d_2048.py +0 -13
- train_pnet_class_cluster_10d_deeper.py +0 -13
- train_pnet_cluster.py +0 -10
- train_pnet_cluster_class_v2.py +0 -10
- train_pnet_cluster_v3.py +0 -10
- train_pnet_cluster_v2.py → train_pnet_v2.py +2 -2
- train_voxel.py +0 -13
- train_voxel_cluster.py +0 -13
color_visu.py
CHANGED
|
@@ -1,3 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import cv2
|
| 2 |
import numpy as np
|
| 3 |
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file generates a color legend image for building component visualization.
|
| 3 |
+
It creates a PNG image showing color swatches and labels for two categories:
|
| 4 |
+
1. Gestalt Colors - for various building components like roof, walls, windows, etc.
|
| 5 |
+
2. Edge Colors - for architectural edges like ridges, eaves, hips, valleys, etc.
|
| 6 |
+
The legend helps visualize the color mappings used in building analysis and annotation.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
import cv2
|
| 10 |
import numpy as np
|
| 11 |
|
end_to_end.py
CHANGED
|
@@ -1,3 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import pickle
|
| 3 |
import torch
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
End-to-End Voxel-Based Vertex Detection Pipeline
|
| 3 |
+
|
| 4 |
+
This file implements a complete pipeline for detecting wireframe vertices from 3D point clouds using
|
| 5 |
+
a voxel-based deep learning approach. The pipeline includes:
|
| 6 |
+
|
| 7 |
+
1. Data preprocessing: Converting 14D point clouds into 3D voxel grids with averaged features
|
| 8 |
+
2. Ground truth generation: Creating binary vertex labels and refinement targets from wireframe vertices
|
| 9 |
+
3. Model architecture: VoxelUNet with encoder-decoder structure and 1x1x1 bottleneck for vertex detection
|
| 10 |
+
4. Training: Combined loss function with BCE, Dice loss, and MSE for offset regression
|
| 11 |
+
5. Inference: Predicting vertex locations from new point clouds with visualization
|
| 12 |
+
|
| 13 |
+
Key components:
|
| 14 |
+
- Voxelization with configurable grid size and metric voxel size
|
| 15 |
+
- Per-voxel MLP before convolutional processing
|
| 16 |
+
- Gaussian smoothing of ground truth labels
|
| 17 |
+
- Refinement prediction for sub-voxel accuracy
|
| 18 |
+
- PyVista-based visualization for results analysis
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
- Set inference=False to train a new model
|
| 22 |
+
- Set inference=True to run predictions on existing data
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
import os
|
| 26 |
import pickle
|
| 27 |
import torch
|
end_to_end_deeper.py
DELETED
|
@@ -1,946 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import pickle
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
import torch.optim as optim
|
| 6 |
-
import numpy as np
|
| 7 |
-
from typing import Dict, Any, Tuple, List
|
| 8 |
-
from torch.utils.data import Dataset, DataLoader
|
| 9 |
-
import glob
|
| 10 |
-
import pyvista as pv
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
# [Previous code from the existing document remains unchanged up to CombinedLoss class]
|
| 14 |
-
# ... (save_data, load_data, get_data_files, voxelize_points, create_ground_truth, VoxelUNet, VoxelDataset) ...
|
| 15 |
-
|
| 16 |
-
def save_data(dict_to_save: Dict[str, Any], filename: str, data_folder: str = "data") -> None:
|
| 17 |
-
"""Save dictionary data to pickle file"""
|
| 18 |
-
os.makedirs(data_folder, exist_ok=True)
|
| 19 |
-
filepath = os.path.join(data_folder, f"{filename}.pkl")
|
| 20 |
-
with open(filepath, 'wb') as f:
|
| 21 |
-
pickle.dump(dict_to_save, f)
|
| 22 |
-
#print(f"Data saved to {filepath}")
|
| 23 |
-
|
| 24 |
-
def load_data(filepath: str) -> Dict[str, Any]:
|
| 25 |
-
"""Load dictionary data from pickle file"""
|
| 26 |
-
with open(filepath, 'rb') as f:
|
| 27 |
-
data = pickle.load(f)
|
| 28 |
-
#print(f"Data loaded from {filepath}")
|
| 29 |
-
return data
|
| 30 |
-
|
| 31 |
-
def get_data_files(data_folder: str = "data", pattern: str = "*.pkl") -> List[str]:
|
| 32 |
-
"""Get list of data files from folder"""
|
| 33 |
-
search_pattern = os.path.join(data_folder, pattern)
|
| 34 |
-
files = glob.glob(search_pattern)
|
| 35 |
-
#print(f"Found {len(files)} data files in {data_folder}")
|
| 36 |
-
return files
|
| 37 |
-
|
| 38 |
-
def voxelize_points(points: np.ndarray,
|
| 39 |
-
grid_size_xy: int = 64,
|
| 40 |
-
voxel_size_metric: float = 0.25
|
| 41 |
-
) -> Tuple[torch.Tensor, np.ndarray, Dict[str, Any]]:
|
| 42 |
-
"""
|
| 43 |
-
Voxelize 14D point cloud into a 3D grid with a fixed number of voxels and fixed metric voxel size.
|
| 44 |
-
The Z dimension of the grid will have grid_size_xy / 4 voxels.
|
| 45 |
-
The point cloud is centered within this metric grid. Points outside are discarded.
|
| 46 |
-
|
| 47 |
-
Args:
|
| 48 |
-
points: (N, 14) array where first 3 dims are xyz (original coordinates).
|
| 49 |
-
grid_size_xy: Number of voxels along X and Y dimensions.
|
| 50 |
-
voxel_size_metric: The physical size of each voxel (e.g., 0.5 units).
|
| 51 |
-
|
| 52 |
-
Returns:
|
| 53 |
-
voxel_grid: (NUM_FEATURES, dim_z, dim_y, dim_x) tensor.
|
| 54 |
-
voxel_indices_for_points: (N_points_in_grid, 3) integer voxel indices (z, y, x)
|
| 55 |
-
for each input point that falls within the grid.
|
| 56 |
-
scale_info: Dict with transformation parameters:
|
| 57 |
-
'grid_origin_metric': Real-world metric coordinate of the corner of voxel [0,0,0] (x,y,z).
|
| 58 |
-
'voxel_size_metric': The metric size of a voxel.
|
| 59 |
-
'grid_dims_voxels': Tuple (dim_x, dim_y, dim_z) representing number of voxels.
|
| 60 |
-
'pc_centroid_metric': Centroid of the input point cloud (x,y,z).
|
| 61 |
-
"""
|
| 62 |
-
NUM_FEATURES = 14
|
| 63 |
-
dim_x = grid_size_xy
|
| 64 |
-
dim_y = grid_size_xy
|
| 65 |
-
dim_z = grid_size_xy # Assuming cubic grid based on usage in create_ground_truth and VoxelDataset
|
| 66 |
-
|
| 67 |
-
if dim_z == 0: dim_z = 1 # Ensure at least one voxel in Z
|
| 68 |
-
|
| 69 |
-
# grid_dims_voxels stores (num_voxels_x, num_voxels_y, num_voxels_z)
|
| 70 |
-
grid_dims_voxels = np.array([dim_x, dim_y, dim_z], dtype=int)
|
| 71 |
-
|
| 72 |
-
def _get_empty_return(reason: str = ""):
|
| 73 |
-
# Voxel grid shape is fixed: (NUM_FEATURES, dim_z, dim_y, dim_x)
|
| 74 |
-
voxel_grid_empty = torch.zeros(NUM_FEATURES, grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.float32)
|
| 75 |
-
voxel_indices_empty = np.empty((0, 3), dtype=int)
|
| 76 |
-
scale_info_empty = {
|
| 77 |
-
'grid_origin_metric': np.zeros(3, dtype=float),
|
| 78 |
-
'voxel_size_metric': voxel_size_metric,
|
| 79 |
-
'grid_dims_voxels': tuple(grid_dims_voxels.tolist()),
|
| 80 |
-
'pc_centroid_metric': np.zeros(3, dtype=float),
|
| 81 |
-
}
|
| 82 |
-
return voxel_grid_empty, voxel_indices_empty, scale_info_empty
|
| 83 |
-
|
| 84 |
-
if points.shape[0] == 0:
|
| 85 |
-
return _get_empty_return("Initial empty point cloud")
|
| 86 |
-
|
| 87 |
-
xyz = points[:, :3] # Metric coordinates of points
|
| 88 |
-
features_other = points[:, 3:] # Other features
|
| 89 |
-
|
| 90 |
-
pc_centroid_metric = xyz.mean(axis=0) # (cx, cy, cz)
|
| 91 |
-
|
| 92 |
-
# Calculate the metric origin of the grid such that the point cloud centroid
|
| 93 |
-
# aligns with the center of the metric grid.
|
| 94 |
-
# grid_metric_span is (total_metric_width_x, total_metric_height_y, total_metric_depth_z)
|
| 95 |
-
grid_metric_span = grid_dims_voxels * voxel_size_metric
|
| 96 |
-
# grid_origin_metric is the real-world (x,y,z) coordinate of the corner of voxel (0,0,0)
|
| 97 |
-
grid_origin_metric = pc_centroid_metric - (grid_metric_span / 2.0)
|
| 98 |
-
|
| 99 |
-
# Initialize voxel_grid: PyTorch expects (C, D, H, W)
|
| 100 |
-
# Here, D=dim_z, H=dim_y, W=dim_x
|
| 101 |
-
voxel_grid = torch.zeros(NUM_FEATURES, grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.float32)
|
| 102 |
-
|
| 103 |
-
# Convert point metric coordinates to continuous voxel coordinates (potentially fractional and outside [0, dim-1])
|
| 104 |
-
# continuous_voxel_coords[i] = (px_i, py_i, pz_i) in continuous voxel grid space
|
| 105 |
-
continuous_voxel_coords = (xyz - grid_origin_metric) / voxel_size_metric
|
| 106 |
-
|
| 107 |
-
# To store (z_idx, y_idx, x_idx) for each point, for easier indexing into torch tensors
|
| 108 |
-
voxel_indices_for_points_zyx_order = []
|
| 109 |
-
|
| 110 |
-
for i in range(points.shape[0]):
|
| 111 |
-
# current_point_continuous_coord_xyz is (x_coord, y_coord, z_coord) in continuous voxel space
|
| 112 |
-
current_point_continuous_coord_xyz = continuous_voxel_coords[i]
|
| 113 |
-
|
| 114 |
-
# Voxel integer indices (ix, iy, iz) by flooring. This is the voxel cell the point falls into.
|
| 115 |
-
voxel_idx_int_xyz = np.round(current_point_continuous_coord_xyz).astype(int)
|
| 116 |
-
|
| 117 |
-
# Check if the point falls outside the grid boundaries
|
| 118 |
-
# grid_dims_voxels is (dim_x, dim_y, dim_z)
|
| 119 |
-
idx_x, idx_y, idx_z = voxel_idx_int_xyz[0], voxel_idx_int_xyz[1], voxel_idx_int_xyz[2]
|
| 120 |
-
|
| 121 |
-
if not (0 <= idx_x < grid_dims_voxels[0] and \
|
| 122 |
-
0 <= idx_y < grid_dims_voxels[1] and \
|
| 123 |
-
0 <= idx_z < grid_dims_voxels[2]):
|
| 124 |
-
# Point is outside the grid, skip it
|
| 125 |
-
continue
|
| 126 |
-
|
| 127 |
-
# At this point, idx_x, idx_y, idx_z are guaranteed to be within grid bounds.
|
| 128 |
-
# No explicit clipping is needed here, but using them directly.
|
| 129 |
-
|
| 130 |
-
voxel_indices_for_points_zyx_order.append([idx_z, idx_y, idx_x])
|
| 131 |
-
|
| 132 |
-
# Calculate offset for the first 3 features:
|
| 133 |
-
# Center of the assigned voxel in continuous grid index space (e.g., [0.5,0.5,0.5] for voxel [0,0,0])
|
| 134 |
-
assigned_voxel_center_grid_idx_space = np.array([idx_x, idx_y, idx_z], dtype=float) + 0.5
|
| 135 |
-
|
| 136 |
-
# Offset of the point from its assigned voxel center, in grid units.
|
| 137 |
-
offset_xyz_in_grid_units = current_point_continuous_coord_xyz - assigned_voxel_center_grid_idx_space
|
| 138 |
-
|
| 139 |
-
# Accumulate features in voxel_grid (C, Z, Y, X)
|
| 140 |
-
# Store dx, dy, dz (offsets in X, Y, Z dimensions)
|
| 141 |
-
voxel_grid[0, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[0] # dx
|
| 142 |
-
voxel_grid[1, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[1] # dy
|
| 143 |
-
voxel_grid[2, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[2] # dz
|
| 144 |
-
|
| 145 |
-
# Accumulate other original features (from index 3 onwards)
|
| 146 |
-
if NUM_FEATURES > 3:
|
| 147 |
-
current_point_other_features = features_other[i]
|
| 148 |
-
voxel_grid[3:, idx_z, idx_y, idx_x] += torch.tensor(current_point_other_features, dtype=torch.float32)
|
| 149 |
-
|
| 150 |
-
final_voxel_indices_for_points_zyx = np.array(voxel_indices_for_points_zyx_order, dtype=int) if voxel_indices_for_points_zyx_order else np.empty((0,3), dtype=int)
|
| 151 |
-
|
| 152 |
-
scale_info = {
|
| 153 |
-
'grid_origin_metric': grid_origin_metric,
|
| 154 |
-
'voxel_size_metric': voxel_size_metric,
|
| 155 |
-
'grid_dims_voxels': tuple(grid_dims_voxels.tolist()),
|
| 156 |
-
'pc_centroid_metric': pc_centroid_metric,
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
return voxel_grid, final_voxel_indices_for_points_zyx, scale_info
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
def create_ground_truth(vertices: np.ndarray,
|
| 163 |
-
scale_info: Dict[str, Any]
|
| 164 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 165 |
-
"""
|
| 166 |
-
Create ground truth voxel labels and refinement targets using metric voxelization info.
|
| 167 |
-
The grid dimensions are taken from scale_info.
|
| 168 |
-
|
| 169 |
-
Args:
|
| 170 |
-
vertices: (M, 3) vertex coordinates in original metric space.
|
| 171 |
-
scale_info: Dict from voxelize_points. Requires:
|
| 172 |
-
'grid_origin_metric', 'voxel_size_metric', 'grid_dims_voxels'.
|
| 173 |
-
Returns:
|
| 174 |
-
vertex_labels: (dim_z, dim_y, dim_x) binary labels (1.0 for voxel containing a vertex).
|
| 175 |
-
refinement_targets: (3, dim_z, dim_y, dim_x) offset (dx,dy,dz) from voxel cell center
|
| 176 |
-
in grid units. Range approx [-0.5, 0.5).
|
| 177 |
-
"""
|
| 178 |
-
grid_origin_metric = scale_info['grid_origin_metric'] # (ox, oy, oz)
|
| 179 |
-
voxel_size_metric = scale_info['voxel_size_metric']
|
| 180 |
-
# grid_dims_voxels is (num_voxels_x, num_voxels_y, num_voxels_z)
|
| 181 |
-
grid_dims_voxels = np.array(scale_info['grid_dims_voxels'])
|
| 182 |
-
|
| 183 |
-
dim_x, dim_y, dim_z = grid_dims_voxels[0], grid_dims_voxels[1], grid_dims_voxels[2]
|
| 184 |
-
|
| 185 |
-
# Labels tensor: (dim_z, dim_y, dim_x)
|
| 186 |
-
vertex_labels = torch.zeros(dim_z, dim_y, dim_x, dtype=torch.float32)
|
| 187 |
-
# Refinement targets tensor: (3, dim_z, dim_y, dim_x) for (dx, dy, dz) offsets
|
| 188 |
-
refinement_targets = torch.zeros(3, dim_z, dim_y, dim_x, dtype=torch.float32)
|
| 189 |
-
|
| 190 |
-
if vertices.shape[0] == 0:
|
| 191 |
-
return vertex_labels, refinement_targets
|
| 192 |
-
|
| 193 |
-
# Convert vertex metric coordinates to continuous voxel coordinates
|
| 194 |
-
# (potentially fractional and outside [0, dim-1])
|
| 195 |
-
continuous_voxel_coords_vertices = (vertices - grid_origin_metric) / voxel_size_metric
|
| 196 |
-
|
| 197 |
-
for i in range(vertices.shape[0]):
|
| 198 |
-
# v_continuous_coord_xyz is (vx, vy, vz) for the current vertex in continuous voxel space
|
| 199 |
-
v_continuous_coord_xyz = continuous_voxel_coords_vertices[i]
|
| 200 |
-
|
| 201 |
-
# Integer voxel index (ix, iy, iz) by flooring
|
| 202 |
-
v_idx_int_xyz = np.floor(v_continuous_coord_xyz).astype(int)
|
| 203 |
-
|
| 204 |
-
# Clip to be within grid boundaries [0, dim-1]
|
| 205 |
-
idx_x = np.clip(v_idx_int_xyz[0], 0, dim_x - 1)
|
| 206 |
-
idx_y = np.clip(v_idx_int_xyz[1], 0, dim_y - 1)
|
| 207 |
-
idx_z = np.clip(v_idx_int_xyz[2], 0, dim_z - 1)
|
| 208 |
-
|
| 209 |
-
# Set label for this voxel (using z, y, x order for tensor access)
|
| 210 |
-
vertex_labels[idx_z, idx_y, idx_x] = 1.0
|
| 211 |
-
|
| 212 |
-
# Calculate refinement offset:
|
| 213 |
-
# Center of the *assigned* (clipped) voxel in continuous grid index space
|
| 214 |
-
assigned_voxel_center_grid_idx_space = np.array([idx_x, idx_y, idx_z], dtype=float) + 0.5
|
| 215 |
-
|
| 216 |
-
# Offset of the vertex from its *assigned* voxel center, in grid units.
|
| 217 |
-
offset_xyz_grid_units = v_continuous_coord_xyz - assigned_voxel_center_grid_idx_space
|
| 218 |
-
|
| 219 |
-
# Store dx, dy, dz in channels 0, 1, 2 respectively
|
| 220 |
-
# refinement_targets is (3, Z, Y, X)
|
| 221 |
-
refinement_targets[0, idx_z, idx_y, idx_x] = offset_xyz_grid_units[0] # dx
|
| 222 |
-
refinement_targets[1, idx_z, idx_y, idx_x] = offset_xyz_grid_units[1] # dy
|
| 223 |
-
refinement_targets[2, idx_z, idx_y, idx_x] = offset_xyz_grid_units[2] # dz
|
| 224 |
-
|
| 225 |
-
return vertex_labels, refinement_targets
|
| 226 |
-
|
| 227 |
-
class VoxelUNet(nn.Module):
|
| 228 |
-
"""Enhanced U-Net for voxel-based vertex detection with increased capacity and advanced features."""
|
| 229 |
-
|
| 230 |
-
def __init__(self, in_channels: int = 14, base_channels: int = 64, bottleneck_expansion: int = 4,
|
| 231 |
-
use_attention: bool = True, use_residual: bool = True, dropout_rate: float = 0.1):
|
| 232 |
-
super(VoxelUNet, self).__init__()
|
| 233 |
-
|
| 234 |
-
bc = base_channels
|
| 235 |
-
self.use_attention = use_attention
|
| 236 |
-
self.use_residual = use_residual
|
| 237 |
-
|
| 238 |
-
# Encoder with increased depth and capacity
|
| 239 |
-
self.enc1 = self._conv_block(in_channels, bc, use_residual=False) # bc
|
| 240 |
-
self.enc2 = self._conv_block(bc, bc * 2, dropout_rate) # bc*2
|
| 241 |
-
self.enc3 = self._conv_block(bc * 2, bc * 4, dropout_rate) # bc*4
|
| 242 |
-
self.enc4 = self._conv_block(bc * 4, bc * 8, dropout_rate) # bc*8
|
| 243 |
-
self.enc5 = self._conv_block(bc * 8, bc * 16, dropout_rate) # bc*16
|
| 244 |
-
self.enc6 = self._conv_block(bc * 16, bc * 32, dropout_rate) # bc*32 (new layer)
|
| 245 |
-
|
| 246 |
-
self.pool = nn.MaxPool3d(2)
|
| 247 |
-
|
| 248 |
-
# Enhanced bottleneck with multiple processing paths
|
| 249 |
-
self.adaptive_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
|
| 250 |
-
bottleneck_in_channels = bc * 32
|
| 251 |
-
bottleneck_width = bottleneck_in_channels * bottleneck_expansion
|
| 252 |
-
|
| 253 |
-
self.bottleneck = nn.Sequential(
|
| 254 |
-
nn.Conv3d(bottleneck_in_channels, bottleneck_width, kernel_size=1, padding=0, bias=True),
|
| 255 |
-
nn.BatchNorm3d(bottleneck_width),
|
| 256 |
-
nn.ReLU(inplace=True),
|
| 257 |
-
nn.Dropout3d(dropout_rate),
|
| 258 |
-
nn.Conv3d(bottleneck_width, bottleneck_width, kernel_size=1, padding=0, bias=True),
|
| 259 |
-
nn.BatchNorm3d(bottleneck_width),
|
| 260 |
-
nn.ReLU(inplace=True),
|
| 261 |
-
nn.Dropout3d(dropout_rate),
|
| 262 |
-
nn.Conv3d(bottleneck_width, bottleneck_width, kernel_size=1, padding=0, bias=True),
|
| 263 |
-
nn.BatchNorm3d(bottleneck_width),
|
| 264 |
-
nn.ReLU(inplace=True)
|
| 265 |
-
)
|
| 266 |
-
|
| 267 |
-
# Attention modules for skip connections (if enabled)
|
| 268 |
-
if self.use_attention:
|
| 269 |
-
self.att6 = self._attention_block(bottleneck_width, bc * 32, bc * 16)
|
| 270 |
-
self.att5 = self._attention_block(bc * 32, bc * 16, bc * 8)
|
| 271 |
-
self.att4 = self._attention_block(bc * 16, bc * 8, bc * 4)
|
| 272 |
-
self.att3 = self._attention_block(bc * 8, bc * 4, bc * 2)
|
| 273 |
-
self.att2 = self._attention_block(bc * 4, bc * 2, bc)
|
| 274 |
-
|
| 275 |
-
# Enhanced decoder with more capacity
|
| 276 |
-
self.dec6 = self._conv_block(bottleneck_width + bc * 32, bc * 32, dropout_rate)
|
| 277 |
-
|
| 278 |
-
self.up5 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
|
| 279 |
-
self.dec5 = self._conv_block(bc * 32 + bc * 16, bc * 16, dropout_rate)
|
| 280 |
-
|
| 281 |
-
self.up4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
|
| 282 |
-
self.dec4 = self._conv_block(bc * 16 + bc * 8, bc * 8, dropout_rate)
|
| 283 |
-
|
| 284 |
-
self.up3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
|
| 285 |
-
self.dec3 = self._conv_block(bc * 8 + bc * 4, bc * 4, dropout_rate)
|
| 286 |
-
|
| 287 |
-
self.up2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
|
| 288 |
-
self.dec2 = self._conv_block(bc * 4 + bc * 2, bc * 2, dropout_rate)
|
| 289 |
-
|
| 290 |
-
self.up1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
|
| 291 |
-
self.dec1 = self._conv_block(bc * 2 + bc, bc, dropout_rate)
|
| 292 |
-
|
| 293 |
-
# Enhanced output heads with intermediate processing
|
| 294 |
-
self.vertex_intermediate = self._conv_block(bc, bc // 2, 0.0, use_residual=False)
|
| 295 |
-
self.vertex_head = nn.Conv3d(bc // 2, 1, kernel_size=1)
|
| 296 |
-
|
| 297 |
-
self.refinement_intermediate = self._conv_block(bc, bc // 2, 0.0, use_residual=False)
|
| 298 |
-
self.refinement_head = nn.Conv3d(bc // 2, 3, kernel_size=1)
|
| 299 |
-
|
| 300 |
-
self.tanh = nn.Tanh()
|
| 301 |
-
|
| 302 |
-
def _conv_block(self, in_channels: int, out_channels: int, dropout_rate: float = 0.0,
|
| 303 |
-
use_residual: bool = None) -> nn.Sequential:
|
| 304 |
-
"""Enhanced convolutional block with optional residual connections and dropout."""
|
| 305 |
-
if use_residual is None:
|
| 306 |
-
use_residual = self.use_residual
|
| 307 |
-
|
| 308 |
-
layers = []
|
| 309 |
-
|
| 310 |
-
# First conv
|
| 311 |
-
layers.extend([
|
| 312 |
-
nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
| 313 |
-
nn.BatchNorm3d(out_channels),
|
| 314 |
-
nn.ReLU(inplace=True)
|
| 315 |
-
])
|
| 316 |
-
|
| 317 |
-
if dropout_rate > 0:
|
| 318 |
-
layers.append(nn.Dropout3d(dropout_rate))
|
| 319 |
-
|
| 320 |
-
# Second conv
|
| 321 |
-
layers.extend([
|
| 322 |
-
nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
| 323 |
-
nn.BatchNorm3d(out_channels)
|
| 324 |
-
])
|
| 325 |
-
|
| 326 |
-
# Third conv for extra capacity
|
| 327 |
-
if out_channels >= 128:
|
| 328 |
-
layers.extend([
|
| 329 |
-
nn.ReLU(inplace=True),
|
| 330 |
-
nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
| 331 |
-
nn.BatchNorm3d(out_channels)
|
| 332 |
-
])
|
| 333 |
-
|
| 334 |
-
block = nn.Sequential(*layers)
|
| 335 |
-
|
| 336 |
-
# Add residual connection if channels match and residual is enabled
|
| 337 |
-
if use_residual and in_channels == out_channels:
|
| 338 |
-
return ResidualBlock(block)
|
| 339 |
-
else:
|
| 340 |
-
return nn.Sequential(block, nn.ReLU(inplace=True))
|
| 341 |
-
|
| 342 |
-
def _attention_block(self, gate_channels: int, skip_channels: int, out_channels: int) -> nn.Module:
|
| 343 |
-
"""Attention gate for focusing on relevant features in skip connections."""
|
| 344 |
-
return AttentionGate(gate_channels, skip_channels, out_channels)
|
| 345 |
-
|
| 346 |
-
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 347 |
-
# Encoder path with increased depth
|
| 348 |
-
e1 = self.enc1(x) # bc
|
| 349 |
-
p1 = self.pool(e1)
|
| 350 |
-
|
| 351 |
-
e2 = self.enc2(p1) # bc*2
|
| 352 |
-
p2 = self.pool(e2)
|
| 353 |
-
|
| 354 |
-
e3 = self.enc3(p2) # bc*4
|
| 355 |
-
p3 = self.pool(e3)
|
| 356 |
-
|
| 357 |
-
e4 = self.enc4(p3) # bc*8
|
| 358 |
-
p4 = self.pool(e4)
|
| 359 |
-
|
| 360 |
-
e5 = self.enc5(p4) # bc*16
|
| 361 |
-
p5 = self.pool(e5)
|
| 362 |
-
|
| 363 |
-
e6 = self.enc6(p5) # bc*32
|
| 364 |
-
p6 = self.pool(e6)
|
| 365 |
-
|
| 366 |
-
# Enhanced bottleneck
|
| 367 |
-
b_pooled = self.adaptive_pool(p6)
|
| 368 |
-
b = self.bottleneck(b_pooled)
|
| 369 |
-
|
| 370 |
-
# Enhanced decoder path with attention
|
| 371 |
-
u6_from_b = nn.functional.interpolate(b, size=e6.shape[2:], mode='trilinear', align_corners=True)
|
| 372 |
-
if self.use_attention:
|
| 373 |
-
e6_att = self.att6(u6_from_b, e6)
|
| 374 |
-
cat6 = torch.cat([u6_from_b, e6_att], dim=1)
|
| 375 |
-
else:
|
| 376 |
-
cat6 = torch.cat([u6_from_b, e6], dim=1)
|
| 377 |
-
d6 = self.dec6(cat6)
|
| 378 |
-
|
| 379 |
-
u5 = self.up5(d6)
|
| 380 |
-
if self.use_attention:
|
| 381 |
-
e5_att = self.att5(u5, e5)
|
| 382 |
-
cat5 = torch.cat([u5, e5_att], dim=1)
|
| 383 |
-
else:
|
| 384 |
-
cat5 = torch.cat([u5, e5], dim=1)
|
| 385 |
-
d5 = self.dec5(cat5)
|
| 386 |
-
|
| 387 |
-
u4 = self.up4(d5)
|
| 388 |
-
if self.use_attention:
|
| 389 |
-
e4_att = self.att4(u4, e4)
|
| 390 |
-
cat4 = torch.cat([u4, e4_att], dim=1)
|
| 391 |
-
else:
|
| 392 |
-
cat4 = torch.cat([u4, e4], dim=1)
|
| 393 |
-
d4 = self.dec4(cat4)
|
| 394 |
-
|
| 395 |
-
u3 = self.up3(d4)
|
| 396 |
-
if self.use_attention:
|
| 397 |
-
e3_att = self.att3(u3, e3)
|
| 398 |
-
cat3 = torch.cat([u3, e3_att], dim=1)
|
| 399 |
-
else:
|
| 400 |
-
cat3 = torch.cat([u3, e3], dim=1)
|
| 401 |
-
d3 = self.dec3(cat3)
|
| 402 |
-
|
| 403 |
-
u2 = self.up2(d3)
|
| 404 |
-
if self.use_attention:
|
| 405 |
-
e2_att = self.att2(u2, e2)
|
| 406 |
-
cat2 = torch.cat([u2, e2_att], dim=1)
|
| 407 |
-
else:
|
| 408 |
-
cat2 = torch.cat([u2, e2], dim=1)
|
| 409 |
-
d2 = self.dec2(cat2)
|
| 410 |
-
|
| 411 |
-
u1 = self.up1(d2)
|
| 412 |
-
cat1 = torch.cat([u1, e1], dim=1)
|
| 413 |
-
d1 = self.dec1(cat1)
|
| 414 |
-
|
| 415 |
-
# Enhanced output heads
|
| 416 |
-
vertex_features = self.vertex_intermediate(d1)
|
| 417 |
-
vertex_logits = self.vertex_head(vertex_features)
|
| 418 |
-
|
| 419 |
-
refinement_features = self.refinement_intermediate(d1)
|
| 420 |
-
refinement = self.tanh(self.refinement_head(refinement_features)) * 0.5
|
| 421 |
-
|
| 422 |
-
return vertex_logits, refinement
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
class ResidualBlock(nn.Module):
|
| 426 |
-
"""Residual block wrapper for skip connections."""
|
| 427 |
-
def __init__(self, block):
|
| 428 |
-
super().__init__()
|
| 429 |
-
self.block = block
|
| 430 |
-
|
| 431 |
-
def forward(self, x):
|
| 432 |
-
return torch.relu(self.block(x) + x)
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
class AttentionGate(nn.Module):
|
| 436 |
-
"""Attention gate for U-Net skip connections."""
|
| 437 |
-
def __init__(self, gate_channels, skip_channels, out_channels):
|
| 438 |
-
super().__init__()
|
| 439 |
-
self.gate_conv = nn.Conv3d(gate_channels, out_channels, kernel_size=1, bias=True)
|
| 440 |
-
self.skip_conv = nn.Conv3d(skip_channels, out_channels, kernel_size=1, bias=True)
|
| 441 |
-
self.attention_conv = nn.Conv3d(out_channels, 1, kernel_size=1, bias=True)
|
| 442 |
-
self.relu = nn.ReLU(inplace=True)
|
| 443 |
-
self.sigmoid = nn.Sigmoid()
|
| 444 |
-
|
| 445 |
-
def forward(self, gate, skip):
|
| 446 |
-
gate_proj = self.gate_conv(gate)
|
| 447 |
-
skip_proj = self.skip_conv(skip)
|
| 448 |
-
|
| 449 |
-
# Ensure spatial dimensions match
|
| 450 |
-
if gate_proj.shape[2:] != skip_proj.shape[2:]:
|
| 451 |
-
gate_proj = nn.functional.interpolate(
|
| 452 |
-
gate_proj, size=skip_proj.shape[2:],
|
| 453 |
-
mode='trilinear', align_corners=True
|
| 454 |
-
)
|
| 455 |
-
|
| 456 |
-
combined = self.relu(gate_proj + skip_proj)
|
| 457 |
-
attention = self.sigmoid(self.attention_conv(combined))
|
| 458 |
-
|
| 459 |
-
return skip * attention
|
| 460 |
-
|
| 461 |
-
class VoxelDataset(Dataset):
|
| 462 |
-
def __init__(self, data_files: List[str], voxel_size: float = 0.1, grid_size: int = 64):
|
| 463 |
-
self.data_files = data_files
|
| 464 |
-
self.voxel_size = voxel_size
|
| 465 |
-
self.grid_size = grid_size
|
| 466 |
-
|
| 467 |
-
def __len__(self):
|
| 468 |
-
return len(self.data_files)
|
| 469 |
-
|
| 470 |
-
def __getitem__(self, idx):
|
| 471 |
-
data = load_data(self.data_files[idx])
|
| 472 |
-
|
| 473 |
-
voxel_grid, _, scale_info = voxelize_points(
|
| 474 |
-
data['pcloud_14d'], self.grid_size, self.voxel_size
|
| 475 |
-
)
|
| 476 |
-
|
| 477 |
-
wf_vertices_np = np.array(data['wf_vertices'])
|
| 478 |
-
vertex_labels, refinement_targets = create_ground_truth(
|
| 479 |
-
wf_vertices_np, scale_info
|
| 480 |
-
)
|
| 481 |
-
|
| 482 |
-
return voxel_grid, vertex_labels, refinement_targets, scale_info
|
| 483 |
-
|
| 484 |
-
import torch.nn as nn
|
| 485 |
-
import torch.nn.functional as F
|
| 486 |
-
|
| 487 |
-
class CombinedLoss(nn.Module):
|
| 488 |
-
"""
|
| 489 |
-
Combined loss for vertex classification and offset regression.
|
| 490 |
-
Uses:
|
| 491 |
-
- BCEWithLogitsLoss
|
| 492 |
-
- Dice loss
|
| 493 |
-
- MSE loss on refinement offsets (only over positive voxels)
|
| 494 |
-
- Gaussian blur on the GT labels
|
| 495 |
-
"""
|
| 496 |
-
def __init__(self,
|
| 497 |
-
vertex_weight: float = 1.0,
|
| 498 |
-
refinement_weight: float = 0.1,
|
| 499 |
-
dice_weight: float = 0.5,
|
| 500 |
-
blur_kernel_size: int = 5,
|
| 501 |
-
blur_sigma: float = 1.0,
|
| 502 |
-
eps: float = 1e-6):
|
| 503 |
-
super().__init__()
|
| 504 |
-
self.vertex_weight = vertex_weight
|
| 505 |
-
self.refinement_weight = refinement_weight
|
| 506 |
-
self.dice_weight = dice_weight
|
| 507 |
-
self.eps = eps
|
| 508 |
-
|
| 509 |
-
# BCE with logits
|
| 510 |
-
self.bce_loss = nn.BCEWithLogitsLoss()
|
| 511 |
-
# MSE for offset regression
|
| 512 |
-
self.mse_loss = nn.MSELoss()
|
| 513 |
-
|
| 514 |
-
# build 3D gaussian kernel
|
| 515 |
-
k = blur_kernel_size
|
| 516 |
-
coords = torch.arange(k, dtype=torch.float32) - (k - 1) / 2
|
| 517 |
-
xx, yy, zz = torch.meshgrid(coords, coords, coords, indexing='ij')
|
| 518 |
-
kernel = torch.exp(-(xx**2 + yy**2 + zz**2) / (2 * blur_sigma**2))
|
| 519 |
-
# shape (1,1,k,k,k)
|
| 520 |
-
kernel = kernel.view(1, 1, k, k, k)
|
| 521 |
-
self.register_buffer('gaussian_kernel', kernel)
|
| 522 |
-
self.pad = k // 2
|
| 523 |
-
|
| 524 |
-
def forward(self,
|
| 525 |
-
vertex_logits_pred: torch.Tensor, # (B,1,D,H,W)
|
| 526 |
-
refinement_pred: torch.Tensor, # (B,3,D,H,W)
|
| 527 |
-
vertex_gt: torch.Tensor, # (B,D,H,W), 0/1
|
| 528 |
-
refinement_gt: torch.Tensor # (B,3,D,H,W)
|
| 529 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 530 |
-
|
| 531 |
-
# logits & gt
|
| 532 |
-
logits = vertex_logits_pred.squeeze(1) # (B,D,H,W)
|
| 533 |
-
gt = vertex_gt.float() # (B,D,H,W)
|
| 534 |
-
|
| 535 |
-
# apply gaussian blur on gt
|
| 536 |
-
gt_unsq = gt.unsqueeze(1) # (B,1,D,H,W)
|
| 537 |
-
gt_blur = F.conv3d(gt_unsq, self.gaussian_kernel, padding=self.pad) # (B,1,D,H,W)
|
| 538 |
-
gt_blur = gt_blur.clamp(0, 1) # ensure values are in [0, 1]
|
| 539 |
-
gt_smooth = gt_blur.squeeze(1) # (B,D,H,W)
|
| 540 |
-
|
| 541 |
-
# 1) Weighted BCE loss - positive when gt_smooth > 1e-3
|
| 542 |
-
pos_mask = gt_smooth > 1e-3
|
| 543 |
-
neg_mask = ~pos_mask
|
| 544 |
-
|
| 545 |
-
# Compute BCE separately for positive and negative samples
|
| 546 |
-
bce_loss_fn = nn.BCEWithLogitsLoss(reduction='none')
|
| 547 |
-
bce_all = bce_loss_fn(logits, gt_smooth)
|
| 548 |
-
|
| 549 |
-
# Calculate weighted BCE
|
| 550 |
-
pos_weight = 1.0
|
| 551 |
-
neg_weight = 1.0
|
| 552 |
-
|
| 553 |
-
if pos_mask.sum() > 0 and neg_mask.sum() > 0:
|
| 554 |
-
pos_loss = bce_all[pos_mask].mean()
|
| 555 |
-
neg_loss = bce_all[neg_mask].mean()
|
| 556 |
-
bce = pos_weight * pos_loss + neg_weight * neg_loss
|
| 557 |
-
elif pos_mask.sum() > 0:
|
| 558 |
-
bce = pos_weight * bce_all[pos_mask].mean()
|
| 559 |
-
elif neg_mask.sum() > 0:
|
| 560 |
-
bce = neg_weight * bce_all[neg_mask].mean()
|
| 561 |
-
else:
|
| 562 |
-
bce = torch.tensor(0.0, device=logits.device)
|
| 563 |
-
|
| 564 |
-
# 2) Dice loss
|
| 565 |
-
prob = torch.sigmoid(logits)
|
| 566 |
-
gt_smooth_round = (gt_smooth > 0.5).float() # binary mask
|
| 567 |
-
intersection = (prob * gt_smooth_round).sum(dim=[1,2,3])
|
| 568 |
-
union = prob.sum(dim=[1,2,3]) + gt_smooth_round.sum(dim=[1,2,3])
|
| 569 |
-
dice_score = (2 * intersection + self.eps) / (union + self.eps)
|
| 570 |
-
dice_loss = 1 - dice_score.mean()
|
| 571 |
-
|
| 572 |
-
vertex_loss = bce + self.dice_weight * dice_loss
|
| 573 |
-
|
| 574 |
-
# 3) Refinement MSE (only where original gt==1)
|
| 575 |
-
mask_pos = (gt > 0.5).unsqueeze(1) # use hard mask for offsets
|
| 576 |
-
if mask_pos.sum() > 0:
|
| 577 |
-
pred_offsets = refinement_pred[mask_pos.expand_as(refinement_pred)] \
|
| 578 |
-
.view(-1, 3)
|
| 579 |
-
gt_offsets = refinement_gt[mask_pos.expand_as(refinement_gt)] \
|
| 580 |
-
.view(-1, 3)
|
| 581 |
-
refinement_loss = self.mse_loss(pred_offsets, gt_offsets)
|
| 582 |
-
else:
|
| 583 |
-
refinement_loss = torch.tensor(0., device=logits.device)
|
| 584 |
-
|
| 585 |
-
# 4) Total loss
|
| 586 |
-
total_loss = (self.vertex_weight * vertex_loss +
|
| 587 |
-
self.refinement_weight * refinement_loss)
|
| 588 |
-
|
| 589 |
-
return total_loss, vertex_loss, refinement_loss
|
| 590 |
-
|
| 591 |
-
def train_epoch(model, dataloader, optimizer, criterion, device, current_epoch: int):
|
| 592 |
-
model.train()
|
| 593 |
-
total_loss_epoch = 0.0
|
| 594 |
-
vertex_loss_epoch = 0.0
|
| 595 |
-
refinement_loss_epoch = 0.0
|
| 596 |
-
|
| 597 |
-
for batch_idx, (voxel_grid_batch, vertex_labels_batch, refinement_targets_batch, _) in enumerate(dataloader):
|
| 598 |
-
voxel_grid_batch = voxel_grid_batch.to(device)
|
| 599 |
-
vertex_labels_batch = vertex_labels_batch.to(device)
|
| 600 |
-
refinement_targets_batch = refinement_targets_batch.to(device)
|
| 601 |
-
|
| 602 |
-
if False:
|
| 603 |
-
print(f'Epoch {current_epoch+1}, Batch {batch_idx+1}/{len(dataloader)}')
|
| 604 |
-
|
| 605 |
-
sample_voxel_features = voxel_grid_batch[0].cpu().numpy()
|
| 606 |
-
sample_gt_labels = vertex_labels_batch[0].cpu().numpy()
|
| 607 |
-
sample_gt_refinement = refinement_targets_batch[0].cpu().numpy()
|
| 608 |
-
|
| 609 |
-
summed_xyz_in_voxels = sample_voxel_features[:3]
|
| 610 |
-
occupied_voxel_mask = np.any(summed_xyz_in_voxels != 0, axis=0)
|
| 611 |
-
|
| 612 |
-
plotter = pv.Plotter(window_size=[800,600])
|
| 613 |
-
plotter.background_color = 'white'
|
| 614 |
-
|
| 615 |
-
if np.any(occupied_voxel_mask):
|
| 616 |
-
occupied_voxel_indices = np.array(np.where(occupied_voxel_mask)).T
|
| 617 |
-
input_points_display = pv.PolyData(occupied_voxel_indices + 0.5)
|
| 618 |
-
plotter.add_mesh(input_points_display, color='cornflowerblue', point_size=5, render_points_as_spheres=True, label='Occupied Voxels (Centers)')
|
| 619 |
-
|
| 620 |
-
gt_vertex_voxel_mask = sample_gt_labels > 0.5
|
| 621 |
-
if np.any(gt_vertex_voxel_mask):
|
| 622 |
-
gt_vertex_indices_int = np.array(np.where(gt_vertex_voxel_mask)).T
|
| 623 |
-
gt_offsets = sample_gt_refinement[:, gt_vertex_voxel_mask].T
|
| 624 |
-
gt_vertex_positions_grid_space = gt_vertex_indices_int.astype(float) + 0.5 + gt_offsets
|
| 625 |
-
|
| 626 |
-
target_vertices_display = pv.PolyData(gt_vertex_positions_grid_space)
|
| 627 |
-
plotter.add_mesh(target_vertices_display, color='crimson', point_size=10, render_points_as_spheres=True, label='Target Vertices (GT)')
|
| 628 |
-
|
| 629 |
-
plotter.show(title=f"Debug Viz E{current_epoch+1} B{batch_idx+1}", auto_close=False)
|
| 630 |
-
else:
|
| 631 |
-
print(f"Epoch {current_epoch+1} Batch {batch_idx+1}: No data to visualize for the first sample.")
|
| 632 |
-
|
| 633 |
-
optimizer.zero_grad()
|
| 634 |
-
vertex_logits_pred, refinement_pred = model(voxel_grid_batch)
|
| 635 |
-
|
| 636 |
-
loss, vertex_loss, refinement_loss = criterion(
|
| 637 |
-
vertex_logits_pred, refinement_pred, vertex_labels_batch, refinement_targets_batch
|
| 638 |
-
)
|
| 639 |
-
|
| 640 |
-
print(f"Batch {batch_idx+1}/{len(dataloader)}: Loss={loss.item():.4f}, Vertex Loss={vertex_loss.item():.4f}, Refinement Loss={refinement_loss.item():.4f}")
|
| 641 |
-
|
| 642 |
-
if loss > 0.000001:
|
| 643 |
-
loss.backward()
|
| 644 |
-
optimizer.step()
|
| 645 |
-
|
| 646 |
-
total_loss_epoch += loss.item()
|
| 647 |
-
vertex_loss_epoch += vertex_loss.item()
|
| 648 |
-
refinement_loss_epoch += refinement_loss.item()
|
| 649 |
-
|
| 650 |
-
if (batch_idx + 1) % 200 == 0:
|
| 651 |
-
checkpoint_path = f"model_epoch_{current_epoch+1}_batch_{batch_idx+1}_grid_128v7.pth" # Consider updating filename if grid size changes
|
| 652 |
-
torch.save(model.state_dict(), checkpoint_path)
|
| 653 |
-
print(f"Saved batch checkpoint: {checkpoint_path}")
|
| 654 |
-
|
| 655 |
-
avg_total_loss = total_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0
|
| 656 |
-
avg_vertex_loss = vertex_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0
|
| 657 |
-
avg_refinement_loss = refinement_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0
|
| 658 |
-
|
| 659 |
-
return avg_total_loss, avg_vertex_loss, avg_refinement_loss
|
| 660 |
-
|
| 661 |
-
def train_model(data_folder: str = "data", num_epochs: int = 100, batch_size: int = 4, neg_pos_ratio_val: float = 1.0):
|
| 662 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 663 |
-
print(f"Using device: {device}")
|
| 664 |
-
|
| 665 |
-
data_files = get_data_files(data_folder)
|
| 666 |
-
if not data_files:
|
| 667 |
-
print(f"No data files found in {data_folder}. Exiting.")
|
| 668 |
-
return
|
| 669 |
-
|
| 670 |
-
GRID_SIZE_CFG = 64
|
| 671 |
-
VOXEL_SIZE_CFG = 0.75
|
| 672 |
-
|
| 673 |
-
dataset = VoxelDataset(data_files, voxel_size=VOXEL_SIZE_CFG, grid_size=GRID_SIZE_CFG)
|
| 674 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)
|
| 675 |
-
|
| 676 |
-
model = VoxelUNet().to(device)
|
| 677 |
-
optimizer = optim.Adam(model.parameters(), lr=1e-3)
|
| 678 |
-
|
| 679 |
-
criterion = CombinedLoss(
|
| 680 |
-
vertex_weight=1.0,
|
| 681 |
-
refinement_weight=0.0,
|
| 682 |
-
dice_weight=0.2
|
| 683 |
-
).to(device)
|
| 684 |
-
|
| 685 |
-
print(f"Starting training: {num_epochs} epochs, Batch size: {batch_size}, Grid size: {GRID_SIZE_CFG}, Voxel size: {VOXEL_SIZE_CFG}, Initial LR: {optimizer.param_groups[0]['lr']}")
|
| 686 |
-
|
| 687 |
-
for epoch in range(num_epochs):
|
| 688 |
-
print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")
|
| 689 |
-
|
| 690 |
-
avg_loss, avg_vertex_loss, avg_refinement_loss = train_epoch(
|
| 691 |
-
model, dataloader, optimizer, criterion, device, epoch
|
| 692 |
-
)
|
| 693 |
-
|
| 694 |
-
print(f"Epoch {epoch+1} Summary: Avg Loss: {avg_loss:.4f}, "
|
| 695 |
-
f"Avg Vertex Loss: {avg_vertex_loss:.4f}, "
|
| 696 |
-
f"Avg Refinement Loss: {avg_refinement_loss:.4f}, "
|
| 697 |
-
f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")
|
| 698 |
-
|
| 699 |
-
checkpoint_path = f"model_epoch_{epoch+1}_grid{GRID_SIZE_CFG}_smooth_bal{neg_pos_ratio_val}_v7.pth"
|
| 700 |
-
torch.save(model.state_dict(), checkpoint_path)
|
| 701 |
-
print(f"Saved checkpoint: {checkpoint_path}")
|
| 702 |
-
|
| 703 |
-
final_model_path = f"final_model_grid{GRID_SIZE_CFG}_epochs{num_epochs}_smooth_bal{neg_pos_ratio_val}_v7.pth"
|
| 704 |
-
torch.save(model.state_dict(), final_model_path)
|
| 705 |
-
print(f"Training completed! Final model saved as {final_model_path}")
|
| 706 |
-
|
| 707 |
-
def load_model_for_inference(model_path: str, device: torch.device,
|
| 708 |
-
in_channels: int = 14, base_channels: int = 32) -> VoxelUNet:
|
| 709 |
-
"""Load a VoxelUNet model for inference."""
|
| 710 |
-
model = VoxelUNet(in_channels=in_channels, base_channels=base_channels)
|
| 711 |
-
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 712 |
-
model.to(device)
|
| 713 |
-
model.eval()
|
| 714 |
-
print(f"Model loaded from {model_path} and set to evaluation mode on {device}.")
|
| 715 |
-
return model
|
| 716 |
-
|
| 717 |
-
def predict_vertices(model: VoxelUNet,
|
| 718 |
-
point_cloud_14d: np.ndarray,
|
| 719 |
-
grid_size: int,
|
| 720 |
-
device: torch.device,
|
| 721 |
-
voxel_size_metric: float = 0.35, # Added for consistency, default matches voxelize_points
|
| 722 |
-
vertex_threshold: float = 0.5) -> np.ndarray:
|
| 723 |
-
"""
|
| 724 |
-
Predict vertices from a 14D point cloud.
|
| 725 |
-
|
| 726 |
-
Args:
|
| 727 |
-
model: The trained VoxelUNet model.
|
| 728 |
-
point_cloud_14d: (N, 14) NumPy array of the input point cloud.
|
| 729 |
-
grid_size: The size of the voxel grid along X and Y dimensions (must match training).
|
| 730 |
-
device: PyTorch device ('cuda' or 'cpu').
|
| 731 |
-
voxel_size_metric: The metric size of each voxel (must match training).
|
| 732 |
-
vertex_threshold: Threshold for classifying a voxel as containing a vertex.
|
| 733 |
-
|
| 734 |
-
Returns:
|
| 735 |
-
predicted_vertices_original_space: (M, 3) NumPy array of predicted vertex
|
| 736 |
-
coordinates in the original point cloud space (X, Y, Z order).
|
| 737 |
-
Returns an empty array if no vertices are predicted
|
| 738 |
-
or if the input point cloud results in an empty voxel grid.
|
| 739 |
-
"""
|
| 740 |
-
voxel_grid_tensor, _, scale_info = voxelize_points(
|
| 741 |
-
point_cloud_14d,
|
| 742 |
-
grid_size_xy=grid_size,
|
| 743 |
-
voxel_size_metric=voxel_size_metric
|
| 744 |
-
)
|
| 745 |
-
|
| 746 |
-
# Check if voxelization produced a valid grid (e.g., if input point cloud was empty)
|
| 747 |
-
# voxelize_points returns a zero tensor for grid if input points are empty.
|
| 748 |
-
# If voxel_grid_tensor is all zeros and no points were input, scale_info might be default.
|
| 749 |
-
if voxel_grid_tensor.sum() == 0 and point_cloud_14d.shape[0] == 0:
|
| 750 |
-
# This case implies empty input point cloud, voxelize_points handles this.
|
| 751 |
-
# Predictions will naturally be empty if the grid is empty.
|
| 752 |
-
pass # Continue, model will predict on zero grid.
|
| 753 |
-
|
| 754 |
-
input_tensor = voxel_grid_tensor.unsqueeze(0).to(device)
|
| 755 |
-
|
| 756 |
-
with torch.no_grad():
|
| 757 |
-
vertex_logits_pred_tensor, refinement_pred_tensor = model(input_tensor)
|
| 758 |
-
|
| 759 |
-
vertex_prob_pred_tensor = torch.sigmoid(vertex_logits_pred_tensor)
|
| 760 |
-
|
| 761 |
-
vertex_prob_pred_np = vertex_prob_pred_tensor.squeeze(0).squeeze(0).cpu().numpy()
|
| 762 |
-
refinement_pred_np = refinement_pred_tensor.squeeze(0).cpu().numpy() # Shape (3, D, H, W) -> (dx,dy,dz channels)
|
| 763 |
-
|
| 764 |
-
print(f"Vertex Probabilities Stats: Min={np.min(vertex_prob_pred_np):.4f}, Max={np.max(vertex_prob_pred_np):.4f}, Mean={np.mean(vertex_prob_pred_np):.4f}, Median={np.median(vertex_prob_pred_np):.4f}")
|
| 765 |
-
if refinement_pred_np.size > 0:
|
| 766 |
-
print(f"Refinement Predictions Stats: Min={np.min(refinement_pred_np):.4f}, Max={np.max(refinement_pred_np):.4f}, Mean={np.mean(refinement_pred_np):.4f}, Median={np.median(refinement_pred_np):.4f}")
|
| 767 |
-
for i in range(refinement_pred_np.shape[0]): # Iterate over dx, dy, dz components
|
| 768 |
-
print(f" Refinement Dim {i} (dx,dy,dz order) Stats: Min={np.min(refinement_pred_np[i]):.4f}, Max={np.max(refinement_pred_np[i]):.4f}, Mean={np.mean(refinement_pred_np[i]):.4f}, Median={np.median(refinement_pred_np[i]):.4f}")
|
| 769 |
-
else:
|
| 770 |
-
print("Refinement Predictions Stats: Array is empty.")
|
| 771 |
-
|
| 772 |
-
predicted_mask = vertex_prob_pred_np > vertex_threshold
|
| 773 |
-
# predicted_voxel_indices are (N_preds, 3) with columns (idx_z, idx_y, idx_x)
|
| 774 |
-
predicted_voxel_indices_zyx = np.argwhere(predicted_mask)
|
| 775 |
-
|
| 776 |
-
if not predicted_voxel_indices_zyx.size:
|
| 777 |
-
return np.empty((0, 3), dtype=np.float32)
|
| 778 |
-
|
| 779 |
-
# Extract refinement offsets for the predicted voxels
|
| 780 |
-
# offsets_channels_first will be (3, N_preds) where channels are (dx, dy, dz)
|
| 781 |
-
offsets_channels_first = refinement_pred_np[:,
|
| 782 |
-
predicted_voxel_indices_zyx[:, 0], # z_indices
|
| 783 |
-
predicted_voxel_indices_zyx[:, 1], # y_indices
|
| 784 |
-
predicted_voxel_indices_zyx[:, 2]] # x_indices
|
| 785 |
-
|
| 786 |
-
# Transpose to (N_preds, 3) where columns are (dx, dy, dz)
|
| 787 |
-
offsets_xyz_order = offsets_channels_first.T
|
| 788 |
-
|
| 789 |
-
# Calculate refined coordinates in continuous voxel grid space (X, Y, Z order)
|
| 790 |
-
# Voxel center is at index + 0.5
|
| 791 |
-
# Refinement is added to this center.
|
| 792 |
-
# predicted_voxel_indices_zyx[:, 2] is x_idx
|
| 793 |
-
# predicted_voxel_indices_zyx[:, 1] is y_idx
|
| 794 |
-
# predicted_voxel_indices_zyx[:, 0] is z_idx
|
| 795 |
-
|
| 796 |
-
# offsets_xyz_order[:, 0] is dx
|
| 797 |
-
# offsets_xyz_order[:, 1] is dy
|
| 798 |
-
# offsets_xyz_order[:, 2] is dz
|
| 799 |
-
|
| 800 |
-
refined_x_grid = predicted_voxel_indices_zyx[:, 2].astype(np.float32) + 0.5 #+ offsets_xyz_order[:, 0]
|
| 801 |
-
refined_y_grid = predicted_voxel_indices_zyx[:, 1].astype(np.float32) + 0.5 #+ offsets_xyz_order[:, 1]
|
| 802 |
-
refined_z_grid = predicted_voxel_indices_zyx[:, 0].astype(np.float32) + 0.5 #+ offsets_xyz_order[:, 2]
|
| 803 |
-
|
| 804 |
-
# Stack to get (N_preds, 3) array in (X, Y, Z) order
|
| 805 |
-
refined_grid_coords_xyz = np.stack((refined_x_grid, refined_y_grid, refined_z_grid), axis=-1)
|
| 806 |
-
|
| 807 |
-
# Convert refined grid coordinates to original metric space
|
| 808 |
-
grid_origin_metric = np.array(scale_info['grid_origin_metric']) # (ox, oy, oz)
|
| 809 |
-
# Voxel_size_metric from scale_info should match the input voxel_size_metric parameter
|
| 810 |
-
current_voxel_size_metric = scale_info['voxel_size_metric']
|
| 811 |
-
|
| 812 |
-
# predicted_vertices_original_space are (N_preds, 3) in (X,Y,Z) order
|
| 813 |
-
predicted_vertices_original_space = refined_grid_coords_xyz * current_voxel_size_metric + grid_origin_metric
|
| 814 |
-
|
| 815 |
-
return predicted_vertices_original_space.astype(np.float32)
|
| 816 |
-
|
| 817 |
-
# Simple inference script
|
| 818 |
-
def run_inference(model_path: str,
|
| 819 |
-
data_file_path: str,
|
| 820 |
-
output_file: str = None,
|
| 821 |
-
grid_size: int = 128,
|
| 822 |
-
voxel_size: float = 0.5,
|
| 823 |
-
vertex_threshold: float = 0.5):
|
| 824 |
-
"""
|
| 825 |
-
Run inference on all data files in a directory, visualize with pyvista, and save results.
|
| 826 |
-
"""
|
| 827 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 828 |
-
print(f"Using device: {device}")
|
| 829 |
-
|
| 830 |
-
# Load model
|
| 831 |
-
model = load_model_for_inference(model_path, device)
|
| 832 |
-
|
| 833 |
-
# Get all data files from the directory
|
| 834 |
-
data_files = get_data_files(data_file_path)
|
| 835 |
-
if not data_files:
|
| 836 |
-
print(f"No data files found in {data_file_path}")
|
| 837 |
-
return
|
| 838 |
-
|
| 839 |
-
print(f"Found {len(data_files)} data files to process")
|
| 840 |
-
|
| 841 |
-
for i, file_path in enumerate(data_files):
|
| 842 |
-
print(f"\n--- Processing file {i+1}/{len(data_files)}: {os.path.basename(file_path)} ---")
|
| 843 |
-
|
| 844 |
-
# Load input data
|
| 845 |
-
try:
|
| 846 |
-
data = load_data(file_path)
|
| 847 |
-
except Exception as e:
|
| 848 |
-
print(f"Error loading {file_path}: {e}")
|
| 849 |
-
continue
|
| 850 |
-
|
| 851 |
-
if 'pcloud_14d' not in data:
|
| 852 |
-
print(f"Error: File {file_path} does not contain 'pcloud_14d' key, skipping")
|
| 853 |
-
continue
|
| 854 |
-
|
| 855 |
-
# Extract original point cloud and ground-truth vertices
|
| 856 |
-
pcloud = data['pcloud_14d'][:, :3] # (N,3)
|
| 857 |
-
gt_vertices = np.array(data.get('wf_vertices', [])) # (M,3) or empty
|
| 858 |
-
|
| 859 |
-
print(f"Input point cloud shape: {pcloud.shape}")
|
| 860 |
-
if gt_vertices.size:
|
| 861 |
-
print(f"GT vertices shape: {gt_vertices.shape}")
|
| 862 |
-
|
| 863 |
-
# Run prediction
|
| 864 |
-
print("Running inference...")
|
| 865 |
-
try:
|
| 866 |
-
predicted_vertices = predict_vertices(
|
| 867 |
-
model=model,
|
| 868 |
-
point_cloud_14d=data['pcloud_14d'],
|
| 869 |
-
grid_size=grid_size,
|
| 870 |
-
device=device,
|
| 871 |
-
voxel_size_metric=voxel_size,
|
| 872 |
-
vertex_threshold=vertex_threshold
|
| 873 |
-
)
|
| 874 |
-
except Exception as e:
|
| 875 |
-
print(f"Error during prediction for {file_path}: {e}")
|
| 876 |
-
continue
|
| 877 |
-
|
| 878 |
-
print(f"Predicted {len(predicted_vertices)} vertices")
|
| 879 |
-
|
| 880 |
-
# --- Visualization ---
|
| 881 |
-
plotter = pv.Plotter(window_size=[800,600])
|
| 882 |
-
plotter.background_color = 'white'
|
| 883 |
-
|
| 884 |
-
# Original point cloud in light gray
|
| 885 |
-
if pcloud.size:
|
| 886 |
-
pc_cloud = pv.PolyData(pcloud)
|
| 887 |
-
plotter.add_mesh(pc_cloud, color='lightgray', point_size=2, render_points_as_spheres=True, label='Input PC')
|
| 888 |
-
|
| 889 |
-
# Ground-truth vertices in red
|
| 890 |
-
if gt_vertices.size:
|
| 891 |
-
gt_pd = pv.PolyData(gt_vertices)
|
| 892 |
-
plotter.add_mesh(gt_pd, color='red', point_size=8, render_points_as_spheres=True, label='GT Vertices')
|
| 893 |
-
|
| 894 |
-
# Predicted vertices in blue
|
| 895 |
-
if predicted_vertices.size:
|
| 896 |
-
pred_pd = pv.PolyData(predicted_vertices)
|
| 897 |
-
plotter.add_mesh(pred_pd, color='blue', point_size=8, render_points_as_spheres=True, label='Predicted Vertices')
|
| 898 |
-
|
| 899 |
-
plotter.add_legend()
|
| 900 |
-
plotter.show(title=os.path.basename(file_path))
|
| 901 |
-
|
| 902 |
-
# Prepare output data
|
| 903 |
-
output_data = {
|
| 904 |
-
'predicted_vertices': predicted_vertices,
|
| 905 |
-
'input_file': file_path,
|
| 906 |
-
'model_used': model_path,
|
| 907 |
-
'grid_size': grid_size,
|
| 908 |
-
'voxel_size': voxel_size,
|
| 909 |
-
'vertex_threshold': vertex_threshold,
|
| 910 |
-
'original_data': data
|
| 911 |
-
}
|
| 912 |
-
|
| 913 |
-
# Save results
|
| 914 |
-
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
| 915 |
-
output_filename = f"{base_name}_predictions"
|
| 916 |
-
try:
|
| 917 |
-
save_data(output_data, output_filename)
|
| 918 |
-
print(f"Results saved to: {output_filename}.pkl")
|
| 919 |
-
except Exception as e:
|
| 920 |
-
print(f"Error saving results for {file_path}: {e}")
|
| 921 |
-
|
| 922 |
-
print(f"\nCompleted processing {len(data_files)} files")
|
| 923 |
-
|
| 924 |
-
if __name__ == "__main__":
|
| 925 |
-
inference = False
|
| 926 |
-
|
| 927 |
-
data_folder_train = '/mnt/personal/skvrnjan/hoho_end/'
|
| 928 |
-
#data_folder_train = '/home/skvrnjan/personal/hoho_end'
|
| 929 |
-
num_epochs_train = 100
|
| 930 |
-
batch_size_train = 4
|
| 931 |
-
# This parameter now controls the ratio of negative to positive samples for BCE loss
|
| 932 |
-
negative_to_positive_bce_ratio = 1
|
| 933 |
-
|
| 934 |
-
if inference:
|
| 935 |
-
run_inference(model_path='/home/skvrnjan/personal/hoho/model_epoch_100_grid128_smooth_bal1.pth',
|
| 936 |
-
data_file_path=data_folder_train,
|
| 937 |
-
output_file=None,
|
| 938 |
-
grid_size=128,
|
| 939 |
-
voxel_size=0.5,
|
| 940 |
-
vertex_threshold=0.5
|
| 941 |
-
)
|
| 942 |
-
else:
|
| 943 |
-
train_model(data_folder=data_folder_train,
|
| 944 |
-
num_epochs=num_epochs_train,
|
| 945 |
-
batch_size=batch_size_train,
|
| 946 |
-
neg_pos_ratio_val=negative_to_positive_bce_ratio)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fast_pointnet.py
DELETED
|
@@ -1,520 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pickle
|
| 7 |
-
from torch.utils.data import Dataset, DataLoader
|
| 8 |
-
from typing import List, Dict, Tuple, Optional
|
| 9 |
-
import json
|
| 10 |
-
|
| 11 |
-
class FastPointNet(nn.Module):
|
| 12 |
-
"""
|
| 13 |
-
Fast PointNet implementation for 3D vertex prediction from point cloud patches.
|
| 14 |
-
Takes 7D point clouds (x,y,z,r,g,b,filtered_flag) and predicts 3D vertex coordinates.
|
| 15 |
-
Enhanced with deeper architecture and more parameters for better generalization.
|
| 16 |
-
"""
|
| 17 |
-
def __init__(self, input_dim=7, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1):
|
| 18 |
-
super(FastPointNet, self).__init__()
|
| 19 |
-
self.max_points = max_points
|
| 20 |
-
self.predict_score = predict_score
|
| 21 |
-
self.predict_class = predict_class
|
| 22 |
-
self.num_classes = num_classes
|
| 23 |
-
|
| 24 |
-
# Enhanced point-wise MLPs with deeper architecture
|
| 25 |
-
self.conv1 = nn.Conv1d(input_dim, 128, 1)
|
| 26 |
-
self.conv2 = nn.Conv1d(128, 256, 1)
|
| 27 |
-
self.conv3 = nn.Conv1d(256, 512, 1)
|
| 28 |
-
self.conv4 = nn.Conv1d(512, 1024, 1)
|
| 29 |
-
|
| 30 |
-
# Additional layers for better feature extraction
|
| 31 |
-
self.conv5 = nn.Conv1d(1024, 1024, 1)
|
| 32 |
-
self.conv6 = nn.Conv1d(1024, 2048, 1)
|
| 33 |
-
|
| 34 |
-
# Larger shared features
|
| 35 |
-
self.shared_fc1 = nn.Linear(2048, 1024)
|
| 36 |
-
self.shared_fc2 = nn.Linear(1024, 512)
|
| 37 |
-
|
| 38 |
-
# Enhanced position prediction head
|
| 39 |
-
self.pos_fc1 = nn.Linear(512, 512)
|
| 40 |
-
self.pos_fc2 = nn.Linear(512, 256)
|
| 41 |
-
self.pos_fc3 = nn.Linear(256, 128)
|
| 42 |
-
self.pos_fc4 = nn.Linear(128, output_dim)
|
| 43 |
-
|
| 44 |
-
# Enhanced score prediction head
|
| 45 |
-
if self.predict_score:
|
| 46 |
-
self.score_fc1 = nn.Linear(512, 512)
|
| 47 |
-
self.score_fc2 = nn.Linear(512, 256)
|
| 48 |
-
self.score_fc3 = nn.Linear(256, 128)
|
| 49 |
-
self.score_fc4 = nn.Linear(128, 64)
|
| 50 |
-
self.score_fc5 = nn.Linear(64, 1)
|
| 51 |
-
|
| 52 |
-
# Classification head
|
| 53 |
-
if self.predict_class:
|
| 54 |
-
self.class_fc1 = nn.Linear(512, 512)
|
| 55 |
-
self.class_fc2 = nn.Linear(512, 256)
|
| 56 |
-
self.class_fc3 = nn.Linear(256, 128)
|
| 57 |
-
self.class_fc4 = nn.Linear(128, 64)
|
| 58 |
-
self.class_fc5 = nn.Linear(64, num_classes)
|
| 59 |
-
|
| 60 |
-
# Batch normalization layers
|
| 61 |
-
self.bn1 = nn.BatchNorm1d(128)
|
| 62 |
-
self.bn2 = nn.BatchNorm1d(256)
|
| 63 |
-
self.bn3 = nn.BatchNorm1d(512)
|
| 64 |
-
self.bn4 = nn.BatchNorm1d(1024)
|
| 65 |
-
self.bn5 = nn.BatchNorm1d(1024)
|
| 66 |
-
self.bn6 = nn.BatchNorm1d(2048)
|
| 67 |
-
|
| 68 |
-
# Dropout with different rates
|
| 69 |
-
self.dropout_light = nn.Dropout(0.2)
|
| 70 |
-
self.dropout_medium = nn.Dropout(0.3)
|
| 71 |
-
self.dropout_heavy = nn.Dropout(0.4)
|
| 72 |
-
|
| 73 |
-
def forward(self, x):
|
| 74 |
-
"""
|
| 75 |
-
Forward pass
|
| 76 |
-
Args:
|
| 77 |
-
x: (batch_size, input_dim, max_points) tensor
|
| 78 |
-
Returns:
|
| 79 |
-
Tuple containing predictions based on configuration:
|
| 80 |
-
- position: (batch_size, output_dim) tensor of predicted 3D coordinates
|
| 81 |
-
- score: (batch_size, 1) tensor of predicted distance to GT (if predict_score=True)
|
| 82 |
-
- classification: (batch_size, num_classes) tensor of class logits (if predict_class=True)
|
| 83 |
-
"""
|
| 84 |
-
batch_size = x.size(0)
|
| 85 |
-
|
| 86 |
-
# Enhanced point-wise feature extraction with residual-like connections
|
| 87 |
-
x1 = F.relu(self.bn1(self.conv1(x)))
|
| 88 |
-
x2 = F.relu(self.bn2(self.conv2(x1)))
|
| 89 |
-
x3 = F.relu(self.bn3(self.conv3(x2)))
|
| 90 |
-
x4 = F.relu(self.bn4(self.conv4(x3)))
|
| 91 |
-
x5 = F.relu(self.bn5(self.conv5(x4)))
|
| 92 |
-
x6 = F.relu(self.bn6(self.conv6(x5)))
|
| 93 |
-
|
| 94 |
-
# Global max pooling with additional global average pooling
|
| 95 |
-
max_pool = torch.max(x6, 2)[0] # (batch_size, 2048)
|
| 96 |
-
avg_pool = torch.mean(x6, 2) # (batch_size, 2048)
|
| 97 |
-
|
| 98 |
-
# Combine max and average pooling for richer global features
|
| 99 |
-
global_features = max_pool + avg_pool # (batch_size, 2048)
|
| 100 |
-
|
| 101 |
-
# Enhanced shared features with residual connection
|
| 102 |
-
shared1 = F.relu(self.shared_fc1(global_features))
|
| 103 |
-
shared1 = self.dropout_light(shared1)
|
| 104 |
-
shared2 = F.relu(self.shared_fc2(shared1))
|
| 105 |
-
shared_features = self.dropout_medium(shared2)
|
| 106 |
-
|
| 107 |
-
# Enhanced position prediction with skip connections
|
| 108 |
-
pos1 = F.relu(self.pos_fc1(shared_features))
|
| 109 |
-
pos1 = self.dropout_light(pos1)
|
| 110 |
-
pos2 = F.relu(self.pos_fc2(pos1))
|
| 111 |
-
pos2 = self.dropout_medium(pos2)
|
| 112 |
-
pos3 = F.relu(self.pos_fc3(pos2))
|
| 113 |
-
pos3 = self.dropout_light(pos3)
|
| 114 |
-
position = self.pos_fc4(pos3)
|
| 115 |
-
|
| 116 |
-
outputs = [position]
|
| 117 |
-
|
| 118 |
-
if self.predict_score:
|
| 119 |
-
# Enhanced score prediction
|
| 120 |
-
score1 = F.relu(self.score_fc1(shared_features))
|
| 121 |
-
score1 = self.dropout_light(score1)
|
| 122 |
-
score2 = F.relu(self.score_fc2(score1))
|
| 123 |
-
score2 = self.dropout_medium(score2)
|
| 124 |
-
score3 = F.relu(self.score_fc3(score2))
|
| 125 |
-
score3 = self.dropout_light(score3)
|
| 126 |
-
score4 = F.relu(self.score_fc4(score3))
|
| 127 |
-
score4 = self.dropout_light(score4)
|
| 128 |
-
score = F.relu(self.score_fc5(score4)) # Ensure positive distance
|
| 129 |
-
outputs.append(score)
|
| 130 |
-
|
| 131 |
-
if self.predict_class:
|
| 132 |
-
# Classification prediction
|
| 133 |
-
class1 = F.relu(self.class_fc1(shared_features))
|
| 134 |
-
class1 = self.dropout_light(class1)
|
| 135 |
-
class2 = F.relu(self.class_fc2(class1))
|
| 136 |
-
class2 = self.dropout_medium(class2)
|
| 137 |
-
class3 = F.relu(self.class_fc3(class2))
|
| 138 |
-
class3 = self.dropout_light(class3)
|
| 139 |
-
class4 = F.relu(self.class_fc4(class3))
|
| 140 |
-
class4 = self.dropout_light(class4)
|
| 141 |
-
classification = self.class_fc5(class4) # Raw logits
|
| 142 |
-
outputs.append(classification)
|
| 143 |
-
|
| 144 |
-
# Return outputs based on configuration
|
| 145 |
-
if len(outputs) == 1:
|
| 146 |
-
return outputs[0] # Only position
|
| 147 |
-
elif len(outputs) == 2:
|
| 148 |
-
if self.predict_score:
|
| 149 |
-
return outputs[0], outputs[1] # position, score
|
| 150 |
-
else:
|
| 151 |
-
return outputs[0], outputs[1] # position, classification
|
| 152 |
-
else:
|
| 153 |
-
return outputs[0], outputs[1], outputs[2] # position, score, classification
|
| 154 |
-
|
| 155 |
-
class PatchDataset(Dataset):
|
| 156 |
-
"""
|
| 157 |
-
Dataset class for loading saved patches for PointNet training.
|
| 158 |
-
"""
|
| 159 |
-
|
| 160 |
-
def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
|
| 161 |
-
self.dataset_dir = dataset_dir
|
| 162 |
-
self.max_points = max_points
|
| 163 |
-
self.augment = augment
|
| 164 |
-
|
| 165 |
-
# Load patch files
|
| 166 |
-
self.patch_files = []
|
| 167 |
-
for file in os.listdir(dataset_dir):
|
| 168 |
-
if file.endswith('.pkl'):
|
| 169 |
-
self.patch_files.append(os.path.join(dataset_dir, file))
|
| 170 |
-
|
| 171 |
-
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
|
| 172 |
-
|
| 173 |
-
def __len__(self):
|
| 174 |
-
return len(self.patch_files)
|
| 175 |
-
|
| 176 |
-
def __getitem__(self, idx):
|
| 177 |
-
"""
|
| 178 |
-
Load and process a patch for training.
|
| 179 |
-
Returns:
|
| 180 |
-
patch_data: (7, max_points) tensor of point cloud data
|
| 181 |
-
target: (3,) tensor of target 3D coordinates
|
| 182 |
-
valid_mask: (max_points,) boolean tensor indicating valid points
|
| 183 |
-
distance_to_gt: scalar tensor of distance from initial prediction to GT
|
| 184 |
-
classification: scalar tensor for binary classification (1 if GT vertex present, 0 if not)
|
| 185 |
-
"""
|
| 186 |
-
patch_file = self.patch_files[idx]
|
| 187 |
-
|
| 188 |
-
with open(patch_file, 'rb') as f:
|
| 189 |
-
patch_info = pickle.load(f)
|
| 190 |
-
|
| 191 |
-
patch_7d = patch_info['patch_7d'] # (N, 7)
|
| 192 |
-
target = patch_info.get('assigned_wf_vertex', None) # (3,) or None
|
| 193 |
-
initial_pred = patch_info.get('cluster_center', None) # (3,) or None
|
| 194 |
-
|
| 195 |
-
# Determine classification label based on GT vertex presence
|
| 196 |
-
has_gt_vertex = 1.0 if target is not None else 0.0
|
| 197 |
-
|
| 198 |
-
# Handle patches without ground truth
|
| 199 |
-
if target is None:
|
| 200 |
-
# Use a dummy target for consistency, but mark as invalid with classification
|
| 201 |
-
target = np.zeros(3)
|
| 202 |
-
else:
|
| 203 |
-
target = np.array(target)
|
| 204 |
-
|
| 205 |
-
# Pad or sample points to max_points
|
| 206 |
-
num_points = patch_7d.shape[0]
|
| 207 |
-
|
| 208 |
-
if num_points >= self.max_points:
|
| 209 |
-
# Randomly sample max_points
|
| 210 |
-
indices = np.random.choice(num_points, self.max_points, replace=False)
|
| 211 |
-
patch_sampled = patch_7d[indices]
|
| 212 |
-
valid_mask = np.ones(self.max_points, dtype=bool)
|
| 213 |
-
else:
|
| 214 |
-
# Pad with zeros
|
| 215 |
-
patch_sampled = np.zeros((self.max_points, 7))
|
| 216 |
-
patch_sampled[:num_points] = patch_7d
|
| 217 |
-
valid_mask = np.zeros(self.max_points, dtype=bool)
|
| 218 |
-
valid_mask[:num_points] = True
|
| 219 |
-
|
| 220 |
-
# Data augmentation (only if GT vertex is present)
|
| 221 |
-
if self.augment and has_gt_vertex > 0:
|
| 222 |
-
patch_sampled, target = self._augment_patch(patch_sampled, valid_mask, target)
|
| 223 |
-
|
| 224 |
-
# Convert to tensors and transpose for conv1d (channels first)
|
| 225 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (7, max_points)
|
| 226 |
-
target_tensor = torch.from_numpy(target).float() # (3,)
|
| 227 |
-
valid_mask_tensor = torch.from_numpy(valid_mask)
|
| 228 |
-
|
| 229 |
-
# Handle initial_pred
|
| 230 |
-
if initial_pred is not None:
|
| 231 |
-
initial_pred_tensor = torch.from_numpy(initial_pred).float()
|
| 232 |
-
else:
|
| 233 |
-
initial_pred_tensor = torch.zeros(3).float()
|
| 234 |
-
|
| 235 |
-
# Classification tensor
|
| 236 |
-
classification_tensor = torch.tensor(has_gt_vertex).float()
|
| 237 |
-
|
| 238 |
-
return patch_tensor, target_tensor, valid_mask_tensor, initial_pred_tensor, classification_tensor
|
| 239 |
-
|
| 240 |
-
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
| 241 |
-
"""
|
| 242 |
-
Save patches from prediction pipeline to create a training dataset.
|
| 243 |
-
|
| 244 |
-
Args:
|
| 245 |
-
patches: List of patch dictionaries from generate_patches()
|
| 246 |
-
dataset_dir: Directory to save the dataset
|
| 247 |
-
entry_id: Unique identifier for this entry/image
|
| 248 |
-
"""
|
| 249 |
-
os.makedirs(dataset_dir, exist_ok=True)
|
| 250 |
-
|
| 251 |
-
for i, patch in enumerate(patches):
|
| 252 |
-
# Create unique filename
|
| 253 |
-
filename = f"{entry_id}_patch_{i}.pkl"
|
| 254 |
-
filepath = os.path.join(dataset_dir, filename)
|
| 255 |
-
|
| 256 |
-
# Skip if file already exists
|
| 257 |
-
if os.path.exists(filepath):
|
| 258 |
-
continue
|
| 259 |
-
|
| 260 |
-
# Save patch data
|
| 261 |
-
with open(filepath, 'wb') as f:
|
| 262 |
-
pickle.dump(patch, f)
|
| 263 |
-
|
| 264 |
-
print(f"Saved {len(patches)} patches for entry {entry_id}")
|
| 265 |
-
|
| 266 |
-
# Create dataloader with custom collate function to filter invalid samples
|
| 267 |
-
def collate_fn(batch):
|
| 268 |
-
valid_batch = []
|
| 269 |
-
for patch_data, target, valid_mask, initial_pred, classification in batch:
|
| 270 |
-
# Filter out invalid samples (no valid points)
|
| 271 |
-
if valid_mask.sum() > 0:
|
| 272 |
-
valid_batch.append((patch_data, target, valid_mask, initial_pred, classification))
|
| 273 |
-
|
| 274 |
-
if len(valid_batch) == 0:
|
| 275 |
-
return None
|
| 276 |
-
|
| 277 |
-
# Stack valid samples
|
| 278 |
-
patch_data = torch.stack([item[0] for item in valid_batch])
|
| 279 |
-
targets = torch.stack([item[1] for item in valid_batch])
|
| 280 |
-
valid_masks = torch.stack([item[2] for item in valid_batch])
|
| 281 |
-
initial_preds = torch.stack([item[3] for item in valid_batch])
|
| 282 |
-
classifications = torch.stack([item[4] for item in valid_batch])
|
| 283 |
-
|
| 284 |
-
return patch_data, targets, valid_masks, initial_preds, classifications
|
| 285 |
-
|
| 286 |
-
# Initialize weights using Xavier/Glorot initialization
|
| 287 |
-
def init_weights(m):
|
| 288 |
-
if isinstance(m, nn.Conv1d):
|
| 289 |
-
nn.init.xavier_uniform_(m.weight)
|
| 290 |
-
if m.bias is not None:
|
| 291 |
-
nn.init.zeros_(m.bias)
|
| 292 |
-
elif isinstance(m, nn.Linear):
|
| 293 |
-
nn.init.xavier_uniform_(m.weight)
|
| 294 |
-
if m.bias is not None:
|
| 295 |
-
nn.init.zeros_(m.bias)
|
| 296 |
-
elif isinstance(m, nn.BatchNorm1d):
|
| 297 |
-
nn.init.ones_(m.weight)
|
| 298 |
-
nn.init.zeros_(m.bias)
|
| 299 |
-
|
| 300 |
-
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
|
| 301 |
-
score_weight: float = 0.1, class_weight: float = 0.5):
|
| 302 |
-
"""
|
| 303 |
-
Train the FastPointNet model on saved patches.
|
| 304 |
-
|
| 305 |
-
Args:
|
| 306 |
-
dataset_dir: Directory containing saved patch files
|
| 307 |
-
model_save_path: Path to save the trained model
|
| 308 |
-
epochs: Number of training epochs
|
| 309 |
-
batch_size: Training batch size
|
| 310 |
-
lr: Learning rate
|
| 311 |
-
score_weight: Weight for the distance prediction loss
|
| 312 |
-
class_weight: Weight for the classification loss
|
| 313 |
-
"""
|
| 314 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 315 |
-
print(f"Training on device: {device}")
|
| 316 |
-
|
| 317 |
-
# Create dataset and dataloader
|
| 318 |
-
dataset = PatchDataset(dataset_dir, max_points=1024, augment=False)
|
| 319 |
-
print(f"Dataset loaded with {len(dataset)} samples")
|
| 320 |
-
|
| 321 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
|
| 322 |
-
collate_fn=collate_fn, drop_last=True)
|
| 323 |
-
|
| 324 |
-
# Initialize model with score and classification prediction
|
| 325 |
-
model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1)
|
| 326 |
-
|
| 327 |
-
model.apply(init_weights)
|
| 328 |
-
model.to(device)
|
| 329 |
-
|
| 330 |
-
# Loss functions
|
| 331 |
-
position_criterion = nn.MSELoss()
|
| 332 |
-
score_criterion = nn.MSELoss()
|
| 333 |
-
classification_criterion = nn.BCEWithLogitsLoss()
|
| 334 |
-
|
| 335 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
| 336 |
-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
|
| 337 |
-
|
| 338 |
-
# Training loop
|
| 339 |
-
model.train()
|
| 340 |
-
for epoch in range(epochs):
|
| 341 |
-
total_loss = 0.0
|
| 342 |
-
total_pos_loss = 0.0
|
| 343 |
-
total_score_loss = 0.0
|
| 344 |
-
total_class_loss = 0.0
|
| 345 |
-
num_batches = 0
|
| 346 |
-
|
| 347 |
-
for batch_idx, batch_data in enumerate(dataloader):
|
| 348 |
-
if batch_data is None: # Skip invalid batches
|
| 349 |
-
continue
|
| 350 |
-
|
| 351 |
-
patch_data, targets, valid_masks, initial_preds, classifications = batch_data
|
| 352 |
-
patch_data = patch_data.to(device) # (batch_size, 7, max_points)
|
| 353 |
-
targets = targets.to(device) # (batch_size, 3)
|
| 354 |
-
classifications = classifications.to(device) # (batch_size,)
|
| 355 |
-
|
| 356 |
-
# Forward pass
|
| 357 |
-
optimizer.zero_grad()
|
| 358 |
-
predictions, predicted_scores, predicted_classes = model(patch_data)
|
| 359 |
-
|
| 360 |
-
# Compute actual distance from predictions to targets
|
| 361 |
-
actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True)
|
| 362 |
-
|
| 363 |
-
# Only compute position and score losses for samples with GT vertices
|
| 364 |
-
has_gt_mask = classifications > 0.5
|
| 365 |
-
|
| 366 |
-
if has_gt_mask.sum() > 0:
|
| 367 |
-
# Position loss only for samples with GT vertices
|
| 368 |
-
pos_loss = position_criterion(predictions[has_gt_mask], targets[has_gt_mask])
|
| 369 |
-
score_loss = score_criterion(predicted_scores[has_gt_mask], actual_distances[has_gt_mask])
|
| 370 |
-
else:
|
| 371 |
-
pos_loss = torch.tensor(0.0, device=device)
|
| 372 |
-
score_loss = torch.tensor(0.0, device=device)
|
| 373 |
-
|
| 374 |
-
# Classification loss for all samples
|
| 375 |
-
class_loss = classification_criterion(predicted_classes.squeeze(), classifications)
|
| 376 |
-
|
| 377 |
-
# Combined loss
|
| 378 |
-
total_batch_loss = pos_loss + score_weight * score_loss + class_weight * class_loss
|
| 379 |
-
|
| 380 |
-
# Backward pass
|
| 381 |
-
total_batch_loss.backward()
|
| 382 |
-
optimizer.step()
|
| 383 |
-
|
| 384 |
-
total_loss += total_batch_loss.item()
|
| 385 |
-
total_pos_loss += pos_loss.item()
|
| 386 |
-
total_score_loss += score_loss.item()
|
| 387 |
-
total_class_loss += class_loss.item()
|
| 388 |
-
num_batches += 1
|
| 389 |
-
|
| 390 |
-
if batch_idx % 50 == 0:
|
| 391 |
-
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
|
| 392 |
-
f"Total Loss: {total_batch_loss.item():.6f}, "
|
| 393 |
-
f"Pos Loss: {pos_loss.item():.6f}, "
|
| 394 |
-
f"Score Loss: {score_loss.item():.6f}, "
|
| 395 |
-
f"Class Loss: {class_loss.item():.6f}")
|
| 396 |
-
|
| 397 |
-
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
| 398 |
-
avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
|
| 399 |
-
avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
|
| 400 |
-
avg_class_loss = total_class_loss / num_batches if num_batches > 0 else 0
|
| 401 |
-
|
| 402 |
-
print(f"Epoch {epoch+1}/{epochs} completed, "
|
| 403 |
-
f"Avg Total Loss: {avg_loss:.6f}, "
|
| 404 |
-
f"Avg Pos Loss: {avg_pos_loss:.6f}, "
|
| 405 |
-
f"Avg Score Loss: {avg_score_loss:.6f}, "
|
| 406 |
-
f"Avg Class Loss: {avg_class_loss:.6f}")
|
| 407 |
-
|
| 408 |
-
scheduler.step()
|
| 409 |
-
|
| 410 |
-
# Save model checkpoint every epoch
|
| 411 |
-
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 412 |
-
torch.save({
|
| 413 |
-
'model_state_dict': model.state_dict(),
|
| 414 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 415 |
-
'epoch': epoch + 1,
|
| 416 |
-
'loss': avg_loss,
|
| 417 |
-
}, checkpoint_path)
|
| 418 |
-
|
| 419 |
-
# Save the trained model
|
| 420 |
-
torch.save({
|
| 421 |
-
'model_state_dict': model.state_dict(),
|
| 422 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 423 |
-
'epoch': epochs,
|
| 424 |
-
}, model_save_path)
|
| 425 |
-
|
| 426 |
-
print(f"Model saved to {model_save_path}")
|
| 427 |
-
return model
|
| 428 |
-
|
| 429 |
-
def load_pointnet_model(model_path: str, device: torch.device = None, predict_score: bool = True) -> FastPointNet:
|
| 430 |
-
"""
|
| 431 |
-
Load a trained FastPointNet model.
|
| 432 |
-
|
| 433 |
-
Args:
|
| 434 |
-
model_path: Path to the saved model
|
| 435 |
-
device: Device to load the model on
|
| 436 |
-
predict_score: Whether the model predicts scores
|
| 437 |
-
|
| 438 |
-
Returns:
|
| 439 |
-
Loaded FastPointNet model
|
| 440 |
-
"""
|
| 441 |
-
if device is None:
|
| 442 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 443 |
-
|
| 444 |
-
model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=predict_score)
|
| 445 |
-
|
| 446 |
-
checkpoint = torch.load(model_path, map_location=device)
|
| 447 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 448 |
-
|
| 449 |
-
model.to(device)
|
| 450 |
-
model.eval()
|
| 451 |
-
|
| 452 |
-
return model
|
| 453 |
-
|
| 454 |
-
def predict_vertex_from_patch(model: FastPointNet, patch: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float, float]:
|
| 455 |
-
"""
|
| 456 |
-
Predict 3D vertex coordinates, confidence score, and classification from a patch using trained PointNet.
|
| 457 |
-
|
| 458 |
-
Args:
|
| 459 |
-
model: Trained FastPointNet model
|
| 460 |
-
patch: Dictionary containing patch data with 'patch_7d' and 'offset' keys
|
| 461 |
-
device: Device to run prediction on
|
| 462 |
-
|
| 463 |
-
Returns:
|
| 464 |
-
tuple of (predicted_coordinates, confidence_score, classification_score)
|
| 465 |
-
predicted_coordinates: (3,) numpy array of predicted 3D coordinates
|
| 466 |
-
confidence_score: float representing predicted distance to GT (lower is better)
|
| 467 |
-
classification_score: float representing probability of GT vertex presence (0-1)
|
| 468 |
-
"""
|
| 469 |
-
if device is None:
|
| 470 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 471 |
-
|
| 472 |
-
patch_7d = patch['patch_7d'] # (N, 7)
|
| 473 |
-
|
| 474 |
-
# Prepare input
|
| 475 |
-
max_points = 1024
|
| 476 |
-
num_points = patch_7d.shape[0]
|
| 477 |
-
|
| 478 |
-
if num_points >= max_points:
|
| 479 |
-
# Sample points
|
| 480 |
-
indices = np.random.choice(num_points, max_points, replace=False)
|
| 481 |
-
patch_sampled = patch_7d[indices]
|
| 482 |
-
else:
|
| 483 |
-
# Pad with zeros
|
| 484 |
-
patch_sampled = np.zeros((max_points, 7))
|
| 485 |
-
patch_sampled[:num_points] = patch_7d
|
| 486 |
-
|
| 487 |
-
# Convert to tensor
|
| 488 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 7, max_points)
|
| 489 |
-
patch_tensor = patch_tensor.to(device)
|
| 490 |
-
|
| 491 |
-
# Predict
|
| 492 |
-
with torch.no_grad():
|
| 493 |
-
outputs = model(patch_tensor)
|
| 494 |
-
|
| 495 |
-
if model.predict_score and model.predict_class:
|
| 496 |
-
position, score, classification = outputs
|
| 497 |
-
position = position.cpu().numpy().squeeze()
|
| 498 |
-
score = score.cpu().numpy().squeeze()
|
| 499 |
-
classification = torch.sigmoid(classification).cpu().numpy().squeeze() # Apply sigmoid for probability
|
| 500 |
-
elif model.predict_score:
|
| 501 |
-
position, score = outputs
|
| 502 |
-
position = position.cpu().numpy().squeeze()
|
| 503 |
-
score = score.cpu().numpy().squeeze()
|
| 504 |
-
classification = None
|
| 505 |
-
elif model.predict_class:
|
| 506 |
-
position, classification = outputs
|
| 507 |
-
position = position.cpu().numpy().squeeze()
|
| 508 |
-
score = None
|
| 509 |
-
classification = torch.sigmoid(classification).cpu().numpy().squeeze() # Apply sigmoid for probability
|
| 510 |
-
else:
|
| 511 |
-
position = outputs
|
| 512 |
-
position = position.cpu().numpy().squeeze()
|
| 513 |
-
score = None
|
| 514 |
-
classification = None
|
| 515 |
-
|
| 516 |
-
# Apply offset correction
|
| 517 |
-
offset = patch['cluster_center']
|
| 518 |
-
position += offset
|
| 519 |
-
|
| 520 |
-
return position, score, classification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fast_pointnet_class.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
|
@@ -403,3 +409,4 @@ def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device:
|
|
| 403 |
predicted_class = int(probability > 0.5)
|
| 404 |
|
| 405 |
return predicted_class, probability
|
|
|
|
|
|
| 1 |
+
# This file defines a PointNet-based model for binary classification of 6D point cloud patches.
|
| 2 |
+
# It includes the model architecture (ClassificationPointNet), a custom dataset class
|
| 3 |
+
# (PatchClassificationDataset) for loading and augmenting patches, functions for saving
|
| 4 |
+
# patches to create a dataset, a training loop (train_pointnet), a function to load
|
| 5 |
+
# a trained model (load_pointnet_model), and a function for predicting class labels
|
| 6 |
+
# from new patches (predict_class_from_patch).
|
| 7 |
import os
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
|
|
|
| 409 |
predicted_class = int(probability > 0.5)
|
| 410 |
|
| 411 |
return predicted_class, probability
|
| 412 |
+
|
fast_pointnet_class_10d.py
DELETED
|
@@ -1,405 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pickle
|
| 7 |
-
from torch.utils.data import Dataset, DataLoader
|
| 8 |
-
from typing import List, Dict, Tuple, Optional
|
| 9 |
-
import json
|
| 10 |
-
|
| 11 |
-
class ClassificationPointNet(nn.Module):
|
| 12 |
-
"""
|
| 13 |
-
PointNet implementation for binary classification from 10D point cloud patches.
|
| 14 |
-
Takes 10D point clouds and predicts binary classification (edge/not edge).
|
| 15 |
-
"""
|
| 16 |
-
def __init__(self, input_dim=10, max_points=1024):
|
| 17 |
-
super(ClassificationPointNet, self).__init__()
|
| 18 |
-
self.max_points = max_points
|
| 19 |
-
|
| 20 |
-
# Point-wise MLPs for feature extraction (deeper network)
|
| 21 |
-
self.conv1 = nn.Conv1d(input_dim, 64, 1)
|
| 22 |
-
self.conv2 = nn.Conv1d(64, 128, 1)
|
| 23 |
-
self.conv3 = nn.Conv1d(128, 256, 1)
|
| 24 |
-
self.conv4 = nn.Conv1d(256, 512, 1)
|
| 25 |
-
self.conv5 = nn.Conv1d(512, 1024, 1)
|
| 26 |
-
self.conv6 = nn.Conv1d(1024, 2048, 1) # Additional layer
|
| 27 |
-
|
| 28 |
-
# Classification head (deeper with more capacity)
|
| 29 |
-
self.fc1 = nn.Linear(2048, 1024)
|
| 30 |
-
self.fc2 = nn.Linear(1024, 512)
|
| 31 |
-
self.fc3 = nn.Linear(512, 256)
|
| 32 |
-
self.fc4 = nn.Linear(256, 128)
|
| 33 |
-
self.fc5 = nn.Linear(128, 64)
|
| 34 |
-
self.fc6 = nn.Linear(64, 1) # Single output for binary classification
|
| 35 |
-
|
| 36 |
-
# Batch normalization layers
|
| 37 |
-
self.bn1 = nn.BatchNorm1d(64)
|
| 38 |
-
self.bn2 = nn.BatchNorm1d(128)
|
| 39 |
-
self.bn3 = nn.BatchNorm1d(256)
|
| 40 |
-
self.bn4 = nn.BatchNorm1d(512)
|
| 41 |
-
self.bn5 = nn.BatchNorm1d(1024)
|
| 42 |
-
self.bn6 = nn.BatchNorm1d(2048)
|
| 43 |
-
|
| 44 |
-
# Dropout layers
|
| 45 |
-
self.dropout1 = nn.Dropout(0.3)
|
| 46 |
-
self.dropout2 = nn.Dropout(0.4)
|
| 47 |
-
self.dropout3 = nn.Dropout(0.5)
|
| 48 |
-
self.dropout4 = nn.Dropout(0.4)
|
| 49 |
-
self.dropout5 = nn.Dropout(0.3)
|
| 50 |
-
|
| 51 |
-
def forward(self, x):
|
| 52 |
-
"""
|
| 53 |
-
Forward pass
|
| 54 |
-
Args:
|
| 55 |
-
x: (batch_size, input_dim, max_points) tensor
|
| 56 |
-
Returns:
|
| 57 |
-
classification: (batch_size, 1) tensor of logits (sigmoid for probability)
|
| 58 |
-
"""
|
| 59 |
-
batch_size = x.size(0)
|
| 60 |
-
|
| 61 |
-
# Point-wise feature extraction
|
| 62 |
-
x1 = F.relu(self.bn1(self.conv1(x)))
|
| 63 |
-
x2 = F.relu(self.bn2(self.conv2(x1)))
|
| 64 |
-
x3 = F.relu(self.bn3(self.conv3(x2)))
|
| 65 |
-
x4 = F.relu(self.bn4(self.conv4(x3)))
|
| 66 |
-
x5 = F.relu(self.bn5(self.conv5(x4)))
|
| 67 |
-
x6 = F.relu(self.bn6(self.conv6(x5)))
|
| 68 |
-
|
| 69 |
-
# Global max pooling
|
| 70 |
-
global_features = torch.max(x6, 2)[0] # (batch_size, 2048)
|
| 71 |
-
|
| 72 |
-
# Classification head
|
| 73 |
-
x = F.relu(self.fc1(global_features))
|
| 74 |
-
x = self.dropout1(x)
|
| 75 |
-
x = F.relu(self.fc2(x))
|
| 76 |
-
x = self.dropout2(x)
|
| 77 |
-
x = F.relu(self.fc3(x))
|
| 78 |
-
x = self.dropout3(x)
|
| 79 |
-
x = F.relu(self.fc4(x))
|
| 80 |
-
x = self.dropout4(x)
|
| 81 |
-
x = F.relu(self.fc5(x))
|
| 82 |
-
x = self.dropout5(x)
|
| 83 |
-
classification = self.fc6(x) # (batch_size, 1)
|
| 84 |
-
|
| 85 |
-
return classification
|
| 86 |
-
|
| 87 |
-
class PatchClassificationDataset(Dataset):
|
| 88 |
-
"""
|
| 89 |
-
Dataset class for loading saved patches for PointNet classification training.
|
| 90 |
-
"""
|
| 91 |
-
|
| 92 |
-
def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
|
| 93 |
-
self.dataset_dir = dataset_dir
|
| 94 |
-
self.max_points = max_points
|
| 95 |
-
self.augment = augment
|
| 96 |
-
|
| 97 |
-
# Load patch files
|
| 98 |
-
self.patch_files = []
|
| 99 |
-
for file in os.listdir(dataset_dir):
|
| 100 |
-
if file.endswith('.pkl'):
|
| 101 |
-
self.patch_files.append(os.path.join(dataset_dir, file))
|
| 102 |
-
|
| 103 |
-
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
|
| 104 |
-
|
| 105 |
-
def __len__(self):
|
| 106 |
-
return len(self.patch_files)
|
| 107 |
-
|
| 108 |
-
def __getitem__(self, idx):
|
| 109 |
-
"""
|
| 110 |
-
Load and process a patch for training.
|
| 111 |
-
Returns:
|
| 112 |
-
patch_data: (10, max_points) tensor of point cloud data
|
| 113 |
-
label: scalar tensor for binary classification (0 or 1)
|
| 114 |
-
valid_mask: (max_points,) boolean tensor indicating valid points
|
| 115 |
-
"""
|
| 116 |
-
patch_file = self.patch_files[idx]
|
| 117 |
-
|
| 118 |
-
with open(patch_file, 'rb') as f:
|
| 119 |
-
patch_info = pickle.load(f)
|
| 120 |
-
|
| 121 |
-
patch_10d = patch_info['patch_10d'] # (N, 10)
|
| 122 |
-
label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
|
| 123 |
-
|
| 124 |
-
# Pad or sample points to max_points
|
| 125 |
-
num_points = patch_10d.shape[0]
|
| 126 |
-
|
| 127 |
-
if num_points >= self.max_points:
|
| 128 |
-
# Randomly sample max_points
|
| 129 |
-
indices = np.random.choice(num_points, self.max_points, replace=False)
|
| 130 |
-
patch_sampled = patch_10d[indices]
|
| 131 |
-
valid_mask = np.ones(self.max_points, dtype=bool)
|
| 132 |
-
else:
|
| 133 |
-
# Pad with zeros
|
| 134 |
-
patch_sampled = np.zeros((self.max_points, 10))
|
| 135 |
-
patch_sampled[:num_points] = patch_10d
|
| 136 |
-
valid_mask = np.zeros(self.max_points, dtype=bool)
|
| 137 |
-
valid_mask[:num_points] = True
|
| 138 |
-
|
| 139 |
-
# Data augmentation
|
| 140 |
-
if self.augment:
|
| 141 |
-
patch_sampled = self._augment_patch(patch_sampled, valid_mask)
|
| 142 |
-
|
| 143 |
-
# Convert to tensors and transpose for conv1d (channels first)
|
| 144 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (10, max_points)
|
| 145 |
-
label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
|
| 146 |
-
valid_mask_tensor = torch.from_numpy(valid_mask)
|
| 147 |
-
|
| 148 |
-
return patch_tensor, label_tensor, valid_mask_tensor
|
| 149 |
-
|
| 150 |
-
def _augment_patch(self, patch, valid_mask):
|
| 151 |
-
"""
|
| 152 |
-
Apply data augmentation to the patch.
|
| 153 |
-
"""
|
| 154 |
-
valid_points = patch[valid_mask]
|
| 155 |
-
|
| 156 |
-
if len(valid_points) == 0:
|
| 157 |
-
return patch
|
| 158 |
-
|
| 159 |
-
# Random rotation around z-axis (only for xyz coordinates, first 3 dimensions)
|
| 160 |
-
angle = np.random.uniform(0, 2 * np.pi)
|
| 161 |
-
cos_angle = np.cos(angle)
|
| 162 |
-
sin_angle = np.sin(angle)
|
| 163 |
-
rotation_matrix = np.array([
|
| 164 |
-
[cos_angle, -sin_angle, 0],
|
| 165 |
-
[sin_angle, cos_angle, 0],
|
| 166 |
-
[0, 0, 1]
|
| 167 |
-
])
|
| 168 |
-
|
| 169 |
-
# Apply rotation to xyz coordinates (first 3 dimensions)
|
| 170 |
-
valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
|
| 171 |
-
|
| 172 |
-
# Random jittering (only for xyz coordinates)
|
| 173 |
-
noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
|
| 174 |
-
valid_points[:, :3] += noise
|
| 175 |
-
|
| 176 |
-
# Random scaling (only for xyz coordinates)
|
| 177 |
-
scale = np.random.uniform(0.9, 1.1)
|
| 178 |
-
valid_points[:, :3] *= scale
|
| 179 |
-
|
| 180 |
-
patch[valid_mask] = valid_points
|
| 181 |
-
return patch
|
| 182 |
-
|
| 183 |
-
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
| 184 |
-
"""
|
| 185 |
-
Save patches from prediction pipeline to create a training dataset.
|
| 186 |
-
|
| 187 |
-
Args:
|
| 188 |
-
patches: List of patch dictionaries from generate_patches()
|
| 189 |
-
dataset_dir: Directory to save the dataset
|
| 190 |
-
entry_id: Unique identifier for this entry/image
|
| 191 |
-
"""
|
| 192 |
-
os.makedirs(dataset_dir, exist_ok=True)
|
| 193 |
-
|
| 194 |
-
for i, patch in enumerate(patches):
|
| 195 |
-
# Create unique filename
|
| 196 |
-
filename = f"{entry_id}_patch_{i}.pkl"
|
| 197 |
-
filepath = os.path.join(dataset_dir, filename)
|
| 198 |
-
|
| 199 |
-
# Skip if file already exists
|
| 200 |
-
if os.path.exists(filepath):
|
| 201 |
-
continue
|
| 202 |
-
|
| 203 |
-
# Save patch data
|
| 204 |
-
with open(filepath, 'wb') as f:
|
| 205 |
-
pickle.dump(patch, f)
|
| 206 |
-
|
| 207 |
-
print(f"Saved {len(patches)} patches for entry {entry_id}")
|
| 208 |
-
|
| 209 |
-
# Create dataloader with custom collate function to filter invalid samples
|
| 210 |
-
def collate_fn(batch):
|
| 211 |
-
valid_batch = []
|
| 212 |
-
for patch_data, label, valid_mask in batch:
|
| 213 |
-
# Filter out invalid samples (no valid points)
|
| 214 |
-
if valid_mask.sum() > 0:
|
| 215 |
-
valid_batch.append((patch_data, label, valid_mask))
|
| 216 |
-
|
| 217 |
-
if len(valid_batch) == 0:
|
| 218 |
-
return None
|
| 219 |
-
|
| 220 |
-
# Stack valid samples
|
| 221 |
-
patch_data = torch.stack([item[0] for item in valid_batch])
|
| 222 |
-
labels = torch.stack([item[1] for item in valid_batch])
|
| 223 |
-
valid_masks = torch.stack([item[2] for item in valid_batch])
|
| 224 |
-
|
| 225 |
-
return patch_data, labels, valid_masks
|
| 226 |
-
|
| 227 |
-
# Initialize weights using Xavier/Glorot initialization
|
| 228 |
-
def init_weights(m):
|
| 229 |
-
if isinstance(m, nn.Conv1d):
|
| 230 |
-
nn.init.xavier_uniform_(m.weight)
|
| 231 |
-
if m.bias is not None:
|
| 232 |
-
nn.init.zeros_(m.bias)
|
| 233 |
-
elif isinstance(m, nn.Linear):
|
| 234 |
-
nn.init.xavier_uniform_(m.weight)
|
| 235 |
-
if m.bias is not None:
|
| 236 |
-
nn.init.zeros_(m.bias)
|
| 237 |
-
elif isinstance(m, nn.BatchNorm1d):
|
| 238 |
-
nn.init.ones_(m.weight)
|
| 239 |
-
nn.init.zeros_(m.bias)
|
| 240 |
-
|
| 241 |
-
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
|
| 242 |
-
lr: float = 0.001):
|
| 243 |
-
"""
|
| 244 |
-
Train the ClassificationPointNet model on saved patches.
|
| 245 |
-
|
| 246 |
-
Args:
|
| 247 |
-
dataset_dir: Directory containing saved patch files
|
| 248 |
-
model_save_path: Path to save the trained model
|
| 249 |
-
epochs: Number of training epochs
|
| 250 |
-
batch_size: Training batch size
|
| 251 |
-
lr: Learning rate
|
| 252 |
-
"""
|
| 253 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 254 |
-
print(f"Training on device: {device}")
|
| 255 |
-
|
| 256 |
-
# Create dataset and dataloader
|
| 257 |
-
dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
|
| 258 |
-
print(f"Dataset loaded with {len(dataset)} samples")
|
| 259 |
-
|
| 260 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
|
| 261 |
-
collate_fn=collate_fn, drop_last=True)
|
| 262 |
-
|
| 263 |
-
# Initialize model
|
| 264 |
-
model = ClassificationPointNet(input_dim=10, max_points=1024)
|
| 265 |
-
model.apply(init_weights)
|
| 266 |
-
model.to(device)
|
| 267 |
-
|
| 268 |
-
# Loss function and optimizer (BCE for binary classification)
|
| 269 |
-
criterion = nn.BCEWithLogitsLoss()
|
| 270 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
| 271 |
-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
|
| 272 |
-
|
| 273 |
-
# Training loop
|
| 274 |
-
model.train()
|
| 275 |
-
for epoch in range(epochs):
|
| 276 |
-
total_loss = 0.0
|
| 277 |
-
correct = 0
|
| 278 |
-
total = 0
|
| 279 |
-
num_batches = 0
|
| 280 |
-
|
| 281 |
-
for batch_idx, batch_data in enumerate(dataloader):
|
| 282 |
-
if batch_data is None: # Skip invalid batches
|
| 283 |
-
continue
|
| 284 |
-
|
| 285 |
-
patch_data, labels, valid_masks = batch_data
|
| 286 |
-
patch_data = patch_data.to(device) # (batch_size, 10, max_points)
|
| 287 |
-
labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
|
| 288 |
-
|
| 289 |
-
# Forward pass
|
| 290 |
-
optimizer.zero_grad()
|
| 291 |
-
outputs = model(patch_data) # (batch_size, 1)
|
| 292 |
-
loss = criterion(outputs, labels)
|
| 293 |
-
|
| 294 |
-
# Backward pass
|
| 295 |
-
loss.backward()
|
| 296 |
-
optimizer.step()
|
| 297 |
-
|
| 298 |
-
# Statistics
|
| 299 |
-
total_loss += loss.item()
|
| 300 |
-
predicted = (torch.sigmoid(outputs) > 0.5).float()
|
| 301 |
-
total += labels.size(0)
|
| 302 |
-
correct += (predicted == labels).sum().item()
|
| 303 |
-
num_batches += 1
|
| 304 |
-
|
| 305 |
-
if batch_idx % 50 == 0:
|
| 306 |
-
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
|
| 307 |
-
f"Loss: {loss.item():.6f}, "
|
| 308 |
-
f"Accuracy: {100 * correct / total:.2f}%")
|
| 309 |
-
|
| 310 |
-
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
| 311 |
-
accuracy = 100 * correct / total if total > 0 else 0
|
| 312 |
-
|
| 313 |
-
print(f"Epoch {epoch+1}/{epochs} completed, "
|
| 314 |
-
f"Avg Loss: {avg_loss:.6f}, "
|
| 315 |
-
f"Accuracy: {accuracy:.2f}%")
|
| 316 |
-
|
| 317 |
-
scheduler.step()
|
| 318 |
-
|
| 319 |
-
# Save model checkpoint every epoch
|
| 320 |
-
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 321 |
-
torch.save({
|
| 322 |
-
'model_state_dict': model.state_dict(),
|
| 323 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 324 |
-
'epoch': epoch + 1,
|
| 325 |
-
'loss': avg_loss,
|
| 326 |
-
'accuracy': accuracy,
|
| 327 |
-
}, checkpoint_path)
|
| 328 |
-
|
| 329 |
-
# Save the trained model
|
| 330 |
-
torch.save({
|
| 331 |
-
'model_state_dict': model.state_dict(),
|
| 332 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 333 |
-
'epoch': epochs,
|
| 334 |
-
}, model_save_path)
|
| 335 |
-
|
| 336 |
-
print(f"Model saved to {model_save_path}")
|
| 337 |
-
return model
|
| 338 |
-
|
| 339 |
-
def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
|
| 340 |
-
"""
|
| 341 |
-
Load a trained ClassificationPointNet model.
|
| 342 |
-
|
| 343 |
-
Args:
|
| 344 |
-
model_path: Path to the saved model
|
| 345 |
-
device: Device to load the model on
|
| 346 |
-
|
| 347 |
-
Returns:
|
| 348 |
-
Loaded ClassificationPointNet model
|
| 349 |
-
"""
|
| 350 |
-
if device is None:
|
| 351 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 352 |
-
|
| 353 |
-
model = ClassificationPointNet(input_dim=10, max_points=1024)
|
| 354 |
-
|
| 355 |
-
checkpoint = torch.load(model_path, map_location=device)
|
| 356 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 357 |
-
|
| 358 |
-
model.to(device)
|
| 359 |
-
model.eval()
|
| 360 |
-
|
| 361 |
-
return model
|
| 362 |
-
|
| 363 |
-
def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
|
| 364 |
-
"""
|
| 365 |
-
Predict binary classification from a patch using trained PointNet.
|
| 366 |
-
|
| 367 |
-
Args:
|
| 368 |
-
model: Trained ClassificationPointNet model
|
| 369 |
-
patch: Dictionary containing patch data with 'patch_10d' key
|
| 370 |
-
device: Device to run prediction on
|
| 371 |
-
|
| 372 |
-
Returns:
|
| 373 |
-
tuple of (predicted_class, confidence)
|
| 374 |
-
predicted_class: int (0 for not edge, 1 for edge)
|
| 375 |
-
confidence: float representing confidence score (0-1)
|
| 376 |
-
"""
|
| 377 |
-
if device is None:
|
| 378 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 379 |
-
|
| 380 |
-
patch_10d = patch['patch_10d'] # (N, 10)
|
| 381 |
-
|
| 382 |
-
# Prepare input
|
| 383 |
-
max_points = 1024
|
| 384 |
-
num_points = patch_10d.shape[0]
|
| 385 |
-
|
| 386 |
-
if num_points >= max_points:
|
| 387 |
-
# Sample points
|
| 388 |
-
indices = np.random.choice(num_points, max_points, replace=False)
|
| 389 |
-
patch_sampled = patch_10d[indices]
|
| 390 |
-
else:
|
| 391 |
-
# Pad with zeros
|
| 392 |
-
patch_sampled = np.zeros((max_points, 10))
|
| 393 |
-
patch_sampled[:num_points] = patch_10d
|
| 394 |
-
|
| 395 |
-
# Convert to tensor
|
| 396 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 10, max_points)
|
| 397 |
-
patch_tensor = patch_tensor.to(device)
|
| 398 |
-
|
| 399 |
-
# Predict
|
| 400 |
-
with torch.no_grad():
|
| 401 |
-
outputs = model(patch_tensor) # (1, 1)
|
| 402 |
-
probability = torch.sigmoid(outputs).item()
|
| 403 |
-
predicted_class = int(probability > 0.5)
|
| 404 |
-
|
| 405 |
-
return predicted_class, probability
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fast_pointnet_class_10d_2048.py
DELETED
|
@@ -1,405 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pickle
|
| 7 |
-
from torch.utils.data import Dataset, DataLoader
|
| 8 |
-
from typing import List, Dict, Tuple, Optional
|
| 9 |
-
import json
|
| 10 |
-
|
| 11 |
-
class ClassificationPointNet(nn.Module):
|
| 12 |
-
"""
|
| 13 |
-
PointNet implementation for binary classification from 10D point cloud patches.
|
| 14 |
-
Takes 10D point clouds and predicts binary classification (edge/not edge).
|
| 15 |
-
"""
|
| 16 |
-
def __init__(self, input_dim=10, max_points=2048):
|
| 17 |
-
super(ClassificationPointNet, self).__init__()
|
| 18 |
-
self.max_points = max_points
|
| 19 |
-
|
| 20 |
-
# Point-wise MLPs for feature extraction (deeper network)
|
| 21 |
-
self.conv1 = nn.Conv1d(input_dim, 64, 1)
|
| 22 |
-
self.conv2 = nn.Conv1d(64, 128, 1)
|
| 23 |
-
self.conv3 = nn.Conv1d(128, 256, 1)
|
| 24 |
-
self.conv4 = nn.Conv1d(256, 512, 1)
|
| 25 |
-
self.conv5 = nn.Conv1d(512, 1024, 1)
|
| 26 |
-
self.conv6 = nn.Conv1d(1024, 2048, 1) # Additional layer
|
| 27 |
-
|
| 28 |
-
# Classification head (deeper with more capacity)
|
| 29 |
-
self.fc1 = nn.Linear(2048, 1024)
|
| 30 |
-
self.fc2 = nn.Linear(1024, 512)
|
| 31 |
-
self.fc3 = nn.Linear(512, 256)
|
| 32 |
-
self.fc4 = nn.Linear(256, 128)
|
| 33 |
-
self.fc5 = nn.Linear(128, 64)
|
| 34 |
-
self.fc6 = nn.Linear(64, 1) # Single output for binary classification
|
| 35 |
-
|
| 36 |
-
# Batch normalization layers
|
| 37 |
-
self.bn1 = nn.BatchNorm1d(64)
|
| 38 |
-
self.bn2 = nn.BatchNorm1d(128)
|
| 39 |
-
self.bn3 = nn.BatchNorm1d(256)
|
| 40 |
-
self.bn4 = nn.BatchNorm1d(512)
|
| 41 |
-
self.bn5 = nn.BatchNorm1d(1024)
|
| 42 |
-
self.bn6 = nn.BatchNorm1d(2048)
|
| 43 |
-
|
| 44 |
-
# Dropout layers
|
| 45 |
-
self.dropout1 = nn.Dropout(0.3)
|
| 46 |
-
self.dropout2 = nn.Dropout(0.4)
|
| 47 |
-
self.dropout3 = nn.Dropout(0.5)
|
| 48 |
-
self.dropout4 = nn.Dropout(0.4)
|
| 49 |
-
self.dropout5 = nn.Dropout(0.3)
|
| 50 |
-
|
| 51 |
-
def forward(self, x):
|
| 52 |
-
"""
|
| 53 |
-
Forward pass
|
| 54 |
-
Args:
|
| 55 |
-
x: (batch_size, input_dim, max_points) tensor
|
| 56 |
-
Returns:
|
| 57 |
-
classification: (batch_size, 1) tensor of logits (sigmoid for probability)
|
| 58 |
-
"""
|
| 59 |
-
batch_size = x.size(0)
|
| 60 |
-
|
| 61 |
-
# Point-wise feature extraction
|
| 62 |
-
x1 = F.relu(self.bn1(self.conv1(x)))
|
| 63 |
-
x2 = F.relu(self.bn2(self.conv2(x1)))
|
| 64 |
-
x3 = F.relu(self.bn3(self.conv3(x2)))
|
| 65 |
-
x4 = F.relu(self.bn4(self.conv4(x3)))
|
| 66 |
-
x5 = F.relu(self.bn5(self.conv5(x4)))
|
| 67 |
-
x6 = F.relu(self.bn6(self.conv6(x5)))
|
| 68 |
-
|
| 69 |
-
# Global max pooling
|
| 70 |
-
global_features = torch.max(x6, 2)[0] # (batch_size, 2048)
|
| 71 |
-
|
| 72 |
-
# Classification head
|
| 73 |
-
x = F.relu(self.fc1(global_features))
|
| 74 |
-
x = self.dropout1(x)
|
| 75 |
-
x = F.relu(self.fc2(x))
|
| 76 |
-
x = self.dropout2(x)
|
| 77 |
-
x = F.relu(self.fc3(x))
|
| 78 |
-
x = self.dropout3(x)
|
| 79 |
-
x = F.relu(self.fc4(x))
|
| 80 |
-
x = self.dropout4(x)
|
| 81 |
-
x = F.relu(self.fc5(x))
|
| 82 |
-
x = self.dropout5(x)
|
| 83 |
-
classification = self.fc6(x) # (batch_size, 1)
|
| 84 |
-
|
| 85 |
-
return classification
|
| 86 |
-
|
| 87 |
-
class PatchClassificationDataset(Dataset):
|
| 88 |
-
"""
|
| 89 |
-
Dataset class for loading saved patches for PointNet classification training.
|
| 90 |
-
"""
|
| 91 |
-
|
| 92 |
-
def __init__(self, dataset_dir: str, max_points: int = 2048, augment: bool = True):
|
| 93 |
-
self.dataset_dir = dataset_dir
|
| 94 |
-
self.max_points = max_points
|
| 95 |
-
self.augment = augment
|
| 96 |
-
|
| 97 |
-
# Load patch files
|
| 98 |
-
self.patch_files = []
|
| 99 |
-
for file in os.listdir(dataset_dir):
|
| 100 |
-
if file.endswith('.pkl'):
|
| 101 |
-
self.patch_files.append(os.path.join(dataset_dir, file))
|
| 102 |
-
|
| 103 |
-
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
|
| 104 |
-
|
| 105 |
-
def __len__(self):
|
| 106 |
-
return len(self.patch_files)
|
| 107 |
-
|
| 108 |
-
def __getitem__(self, idx):
|
| 109 |
-
"""
|
| 110 |
-
Load and process a patch for training.
|
| 111 |
-
Returns:
|
| 112 |
-
patch_data: (10, max_points) tensor of point cloud data
|
| 113 |
-
label: scalar tensor for binary classification (0 or 1)
|
| 114 |
-
valid_mask: (max_points,) boolean tensor indicating valid points
|
| 115 |
-
"""
|
| 116 |
-
patch_file = self.patch_files[idx]
|
| 117 |
-
|
| 118 |
-
with open(patch_file, 'rb') as f:
|
| 119 |
-
patch_info = pickle.load(f)
|
| 120 |
-
|
| 121 |
-
patch_10d = patch_info['patch_10d'] # (N, 10)
|
| 122 |
-
label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
|
| 123 |
-
|
| 124 |
-
# Pad or sample points to max_points
|
| 125 |
-
num_points = patch_10d.shape[0]
|
| 126 |
-
|
| 127 |
-
if num_points >= self.max_points:
|
| 128 |
-
# Randomly sample max_points
|
| 129 |
-
indices = np.random.choice(num_points, self.max_points, replace=False)
|
| 130 |
-
patch_sampled = patch_10d[indices]
|
| 131 |
-
valid_mask = np.ones(self.max_points, dtype=bool)
|
| 132 |
-
else:
|
| 133 |
-
# Pad with zeros
|
| 134 |
-
patch_sampled = np.zeros((self.max_points, 10))
|
| 135 |
-
patch_sampled[:num_points] = patch_10d
|
| 136 |
-
valid_mask = np.zeros(self.max_points, dtype=bool)
|
| 137 |
-
valid_mask[:num_points] = True
|
| 138 |
-
|
| 139 |
-
# Data augmentation
|
| 140 |
-
if self.augment:
|
| 141 |
-
patch_sampled = self._augment_patch(patch_sampled, valid_mask)
|
| 142 |
-
|
| 143 |
-
# Convert to tensors and transpose for conv1d (channels first)
|
| 144 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (10, max_points)
|
| 145 |
-
label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
|
| 146 |
-
valid_mask_tensor = torch.from_numpy(valid_mask)
|
| 147 |
-
|
| 148 |
-
return patch_tensor, label_tensor, valid_mask_tensor
|
| 149 |
-
|
| 150 |
-
def _augment_patch(self, patch, valid_mask):
|
| 151 |
-
"""
|
| 152 |
-
Apply data augmentation to the patch.
|
| 153 |
-
"""
|
| 154 |
-
valid_points = patch[valid_mask]
|
| 155 |
-
|
| 156 |
-
if len(valid_points) == 0:
|
| 157 |
-
return patch
|
| 158 |
-
|
| 159 |
-
# Random rotation around z-axis (only for xyz coordinates, first 3 dimensions)
|
| 160 |
-
angle = np.random.uniform(0, 2 * np.pi)
|
| 161 |
-
cos_angle = np.cos(angle)
|
| 162 |
-
sin_angle = np.sin(angle)
|
| 163 |
-
rotation_matrix = np.array([
|
| 164 |
-
[cos_angle, -sin_angle, 0],
|
| 165 |
-
[sin_angle, cos_angle, 0],
|
| 166 |
-
[0, 0, 1]
|
| 167 |
-
])
|
| 168 |
-
|
| 169 |
-
# Apply rotation to xyz coordinates (first 3 dimensions)
|
| 170 |
-
valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
|
| 171 |
-
|
| 172 |
-
# Random jittering (only for xyz coordinates)
|
| 173 |
-
noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
|
| 174 |
-
valid_points[:, :3] += noise
|
| 175 |
-
|
| 176 |
-
# Random scaling (only for xyz coordinates)
|
| 177 |
-
scale = np.random.uniform(0.9, 1.1)
|
| 178 |
-
valid_points[:, :3] *= scale
|
| 179 |
-
|
| 180 |
-
patch[valid_mask] = valid_points
|
| 181 |
-
return patch
|
| 182 |
-
|
| 183 |
-
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
| 184 |
-
"""
|
| 185 |
-
Save patches from prediction pipeline to create a training dataset.
|
| 186 |
-
|
| 187 |
-
Args:
|
| 188 |
-
patches: List of patch dictionaries from generate_patches()
|
| 189 |
-
dataset_dir: Directory to save the dataset
|
| 190 |
-
entry_id: Unique identifier for this entry/image
|
| 191 |
-
"""
|
| 192 |
-
os.makedirs(dataset_dir, exist_ok=True)
|
| 193 |
-
|
| 194 |
-
for i, patch in enumerate(patches):
|
| 195 |
-
# Create unique filename
|
| 196 |
-
filename = f"{entry_id}_patch_{i}.pkl"
|
| 197 |
-
filepath = os.path.join(dataset_dir, filename)
|
| 198 |
-
|
| 199 |
-
# Skip if file already exists
|
| 200 |
-
if os.path.exists(filepath):
|
| 201 |
-
continue
|
| 202 |
-
|
| 203 |
-
# Save patch data
|
| 204 |
-
with open(filepath, 'wb') as f:
|
| 205 |
-
pickle.dump(patch, f)
|
| 206 |
-
|
| 207 |
-
print(f"Saved {len(patches)} patches for entry {entry_id}")
|
| 208 |
-
|
| 209 |
-
# Create dataloader with custom collate function to filter invalid samples
|
| 210 |
-
def collate_fn(batch):
|
| 211 |
-
valid_batch = []
|
| 212 |
-
for patch_data, label, valid_mask in batch:
|
| 213 |
-
# Filter out invalid samples (no valid points)
|
| 214 |
-
if valid_mask.sum() > 0:
|
| 215 |
-
valid_batch.append((patch_data, label, valid_mask))
|
| 216 |
-
|
| 217 |
-
if len(valid_batch) == 0:
|
| 218 |
-
return None
|
| 219 |
-
|
| 220 |
-
# Stack valid samples
|
| 221 |
-
patch_data = torch.stack([item[0] for item in valid_batch])
|
| 222 |
-
labels = torch.stack([item[1] for item in valid_batch])
|
| 223 |
-
valid_masks = torch.stack([item[2] for item in valid_batch])
|
| 224 |
-
|
| 225 |
-
return patch_data, labels, valid_masks
|
| 226 |
-
|
| 227 |
-
# Initialize weights using Xavier/Glorot initialization
|
| 228 |
-
def init_weights(m):
|
| 229 |
-
if isinstance(m, nn.Conv1d):
|
| 230 |
-
nn.init.xavier_uniform_(m.weight)
|
| 231 |
-
if m.bias is not None:
|
| 232 |
-
nn.init.zeros_(m.bias)
|
| 233 |
-
elif isinstance(m, nn.Linear):
|
| 234 |
-
nn.init.xavier_uniform_(m.weight)
|
| 235 |
-
if m.bias is not None:
|
| 236 |
-
nn.init.zeros_(m.bias)
|
| 237 |
-
elif isinstance(m, nn.BatchNorm1d):
|
| 238 |
-
nn.init.ones_(m.weight)
|
| 239 |
-
nn.init.zeros_(m.bias)
|
| 240 |
-
|
| 241 |
-
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
|
| 242 |
-
lr: float = 0.001):
|
| 243 |
-
"""
|
| 244 |
-
Train the ClassificationPointNet model on saved patches.
|
| 245 |
-
|
| 246 |
-
Args:
|
| 247 |
-
dataset_dir: Directory containing saved patch files
|
| 248 |
-
model_save_path: Path to save the trained model
|
| 249 |
-
epochs: Number of training epochs
|
| 250 |
-
batch_size: Training batch size
|
| 251 |
-
lr: Learning rate
|
| 252 |
-
"""
|
| 253 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 254 |
-
print(f"Training on device: {device}")
|
| 255 |
-
|
| 256 |
-
# Create dataset and dataloader
|
| 257 |
-
dataset = PatchClassificationDataset(dataset_dir, max_points=2048, augment=True)
|
| 258 |
-
print(f"Dataset loaded with {len(dataset)} samples")
|
| 259 |
-
|
| 260 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
|
| 261 |
-
collate_fn=collate_fn, drop_last=True)
|
| 262 |
-
|
| 263 |
-
# Initialize model
|
| 264 |
-
model = ClassificationPointNet(input_dim=10, max_points=2048)
|
| 265 |
-
model.apply(init_weights)
|
| 266 |
-
model.to(device)
|
| 267 |
-
|
| 268 |
-
# Loss function and optimizer (BCE for binary classification)
|
| 269 |
-
criterion = nn.BCEWithLogitsLoss()
|
| 270 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
| 271 |
-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
|
| 272 |
-
|
| 273 |
-
# Training loop
|
| 274 |
-
model.train()
|
| 275 |
-
for epoch in range(epochs):
|
| 276 |
-
total_loss = 0.0
|
| 277 |
-
correct = 0
|
| 278 |
-
total = 0
|
| 279 |
-
num_batches = 0
|
| 280 |
-
|
| 281 |
-
for batch_idx, batch_data in enumerate(dataloader):
|
| 282 |
-
if batch_data is None: # Skip invalid batches
|
| 283 |
-
continue
|
| 284 |
-
|
| 285 |
-
patch_data, labels, valid_masks = batch_data
|
| 286 |
-
patch_data = patch_data.to(device) # (batch_size, 10, max_points)
|
| 287 |
-
labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
|
| 288 |
-
|
| 289 |
-
# Forward pass
|
| 290 |
-
optimizer.zero_grad()
|
| 291 |
-
outputs = model(patch_data) # (batch_size, 1)
|
| 292 |
-
loss = criterion(outputs, labels)
|
| 293 |
-
|
| 294 |
-
# Backward pass
|
| 295 |
-
loss.backward()
|
| 296 |
-
optimizer.step()
|
| 297 |
-
|
| 298 |
-
# Statistics
|
| 299 |
-
total_loss += loss.item()
|
| 300 |
-
predicted = (torch.sigmoid(outputs) > 0.5).float()
|
| 301 |
-
total += labels.size(0)
|
| 302 |
-
correct += (predicted == labels).sum().item()
|
| 303 |
-
num_batches += 1
|
| 304 |
-
|
| 305 |
-
if batch_idx % 50 == 0:
|
| 306 |
-
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
|
| 307 |
-
f"Loss: {loss.item():.6f}, "
|
| 308 |
-
f"Accuracy: {100 * correct / total:.2f}%")
|
| 309 |
-
|
| 310 |
-
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
| 311 |
-
accuracy = 100 * correct / total if total > 0 else 0
|
| 312 |
-
|
| 313 |
-
print(f"Epoch {epoch+1}/{epochs} completed, "
|
| 314 |
-
f"Avg Loss: {avg_loss:.6f}, "
|
| 315 |
-
f"Accuracy: {accuracy:.2f}%")
|
| 316 |
-
|
| 317 |
-
scheduler.step()
|
| 318 |
-
|
| 319 |
-
# Save model checkpoint every epoch
|
| 320 |
-
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 321 |
-
torch.save({
|
| 322 |
-
'model_state_dict': model.state_dict(),
|
| 323 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 324 |
-
'epoch': epoch + 1,
|
| 325 |
-
'loss': avg_loss,
|
| 326 |
-
'accuracy': accuracy,
|
| 327 |
-
}, checkpoint_path)
|
| 328 |
-
|
| 329 |
-
# Save the trained model
|
| 330 |
-
torch.save({
|
| 331 |
-
'model_state_dict': model.state_dict(),
|
| 332 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 333 |
-
'epoch': epochs,
|
| 334 |
-
}, model_save_path)
|
| 335 |
-
|
| 336 |
-
print(f"Model saved to {model_save_path}")
|
| 337 |
-
return model
|
| 338 |
-
|
| 339 |
-
def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
|
| 340 |
-
"""
|
| 341 |
-
Load a trained ClassificationPointNet model.
|
| 342 |
-
|
| 343 |
-
Args:
|
| 344 |
-
model_path: Path to the saved model
|
| 345 |
-
device: Device to load the model on
|
| 346 |
-
|
| 347 |
-
Returns:
|
| 348 |
-
Loaded ClassificationPointNet model
|
| 349 |
-
"""
|
| 350 |
-
if device is None:
|
| 351 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 352 |
-
|
| 353 |
-
model = ClassificationPointNet(input_dim=10, max_points=2048)
|
| 354 |
-
|
| 355 |
-
checkpoint = torch.load(model_path, map_location=device)
|
| 356 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 357 |
-
|
| 358 |
-
model.to(device)
|
| 359 |
-
model.eval()
|
| 360 |
-
|
| 361 |
-
return model
|
| 362 |
-
|
| 363 |
-
def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
|
| 364 |
-
"""
|
| 365 |
-
Predict binary classification from a patch using trained PointNet.
|
| 366 |
-
|
| 367 |
-
Args:
|
| 368 |
-
model: Trained ClassificationPointNet model
|
| 369 |
-
patch: Dictionary containing patch data with 'patch_10d' key
|
| 370 |
-
device: Device to run prediction on
|
| 371 |
-
|
| 372 |
-
Returns:
|
| 373 |
-
tuple of (predicted_class, confidence)
|
| 374 |
-
predicted_class: int (0 for not edge, 1 for edge)
|
| 375 |
-
confidence: float representing confidence score (0-1)
|
| 376 |
-
"""
|
| 377 |
-
if device is None:
|
| 378 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 379 |
-
|
| 380 |
-
patch_10d = patch['patch_10d'] # (N, 10)
|
| 381 |
-
|
| 382 |
-
# Prepare input
|
| 383 |
-
max_points = 2048
|
| 384 |
-
num_points = patch_10d.shape[0]
|
| 385 |
-
|
| 386 |
-
if num_points >= max_points:
|
| 387 |
-
# Sample points
|
| 388 |
-
indices = np.random.choice(num_points, max_points, replace=False)
|
| 389 |
-
patch_sampled = patch_10d[indices]
|
| 390 |
-
else:
|
| 391 |
-
# Pad with zeros
|
| 392 |
-
patch_sampled = np.zeros((max_points, 10))
|
| 393 |
-
patch_sampled[:num_points] = patch_10d
|
| 394 |
-
|
| 395 |
-
# Convert to tensor
|
| 396 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 10, max_points)
|
| 397 |
-
patch_tensor = patch_tensor.to(device)
|
| 398 |
-
|
| 399 |
-
# Predict
|
| 400 |
-
with torch.no_grad():
|
| 401 |
-
outputs = model(patch_tensor) # (1, 1)
|
| 402 |
-
probability = torch.sigmoid(outputs).item()
|
| 403 |
-
predicted_class = int(probability > 0.5)
|
| 404 |
-
|
| 405 |
-
return predicted_class, probability
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fast_pointnet_class_10d_deeper.py
DELETED
|
@@ -1,438 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pickle
|
| 7 |
-
from torch.utils.data import Dataset, DataLoader
|
| 8 |
-
from typing import List, Dict, Tuple, Optional
|
| 9 |
-
import json
|
| 10 |
-
|
| 11 |
-
class ClassificationPointNet(nn.Module):
|
| 12 |
-
"""
|
| 13 |
-
PointNet implementation for binary classification from 10D point cloud patches.
|
| 14 |
-
Takes 10D point clouds and predicts binary classification (edge/not edge).
|
| 15 |
-
Enhanced with residual connections and attention mechanism.
|
| 16 |
-
"""
|
| 17 |
-
def __init__(self, input_dim=10, max_points=1024):
|
| 18 |
-
super(ClassificationPointNet, self).__init__()
|
| 19 |
-
self.max_points = max_points
|
| 20 |
-
|
| 21 |
-
# Point-wise MLPs for feature extraction (deeper with residual connections)
|
| 22 |
-
self.conv1 = nn.Conv1d(input_dim, 64, 1)
|
| 23 |
-
self.conv2 = nn.Conv1d(64, 128, 1)
|
| 24 |
-
self.conv3 = nn.Conv1d(128, 256, 1)
|
| 25 |
-
self.conv4 = nn.Conv1d(256, 512, 1)
|
| 26 |
-
self.conv5 = nn.Conv1d(512, 1024, 1)
|
| 27 |
-
self.conv6 = nn.Conv1d(1024, 2048, 1)
|
| 28 |
-
self.conv7 = nn.Conv1d(2048, 2048, 1) # Additional layer
|
| 29 |
-
|
| 30 |
-
# Residual connection layers
|
| 31 |
-
self.residual1 = nn.Conv1d(128, 256, 1)
|
| 32 |
-
self.residual2 = nn.Conv1d(512, 1024, 1)
|
| 33 |
-
|
| 34 |
-
# Attention mechanism
|
| 35 |
-
self.attention = nn.Conv1d(2048, 1, 1)
|
| 36 |
-
|
| 37 |
-
# Classification head (deeper with more capacity)
|
| 38 |
-
self.fc1 = nn.Linear(2048, 1536)
|
| 39 |
-
self.fc2 = nn.Linear(1536, 1024)
|
| 40 |
-
self.fc3 = nn.Linear(1024, 512)
|
| 41 |
-
self.fc4 = nn.Linear(512, 256)
|
| 42 |
-
self.fc5 = nn.Linear(256, 128)
|
| 43 |
-
self.fc6 = nn.Linear(128, 64)
|
| 44 |
-
self.fc7 = nn.Linear(64, 32)
|
| 45 |
-
self.fc8 = nn.Linear(32, 1) # Single output for binary classification
|
| 46 |
-
|
| 47 |
-
# Batch normalization layers
|
| 48 |
-
self.bn1 = nn.BatchNorm1d(64)
|
| 49 |
-
self.bn2 = nn.BatchNorm1d(128)
|
| 50 |
-
self.bn3 = nn.BatchNorm1d(256)
|
| 51 |
-
self.bn4 = nn.BatchNorm1d(512)
|
| 52 |
-
self.bn5 = nn.BatchNorm1d(1024)
|
| 53 |
-
self.bn6 = nn.BatchNorm1d(2048)
|
| 54 |
-
self.bn7 = nn.BatchNorm1d(2048)
|
| 55 |
-
|
| 56 |
-
# Dropout layers with varying rates
|
| 57 |
-
self.dropout1 = nn.Dropout(0.2)
|
| 58 |
-
self.dropout2 = nn.Dropout(0.3)
|
| 59 |
-
self.dropout3 = nn.Dropout(0.4)
|
| 60 |
-
self.dropout4 = nn.Dropout(0.5)
|
| 61 |
-
self.dropout5 = nn.Dropout(0.4)
|
| 62 |
-
self.dropout6 = nn.Dropout(0.3)
|
| 63 |
-
self.dropout7 = nn.Dropout(0.2)
|
| 64 |
-
|
| 65 |
-
def forward(self, x):
|
| 66 |
-
"""
|
| 67 |
-
Forward pass with residual connections and attention
|
| 68 |
-
Args:
|
| 69 |
-
x: (batch_size, input_dim, max_points) tensor
|
| 70 |
-
Returns:
|
| 71 |
-
classification: (batch_size, 1) tensor of logits (sigmoid for probability)
|
| 72 |
-
"""
|
| 73 |
-
batch_size = x.size(0)
|
| 74 |
-
|
| 75 |
-
# Point-wise feature extraction with residual connections
|
| 76 |
-
x1 = F.relu(self.bn1(self.conv1(x)))
|
| 77 |
-
x2 = F.relu(self.bn2(self.conv2(x1)))
|
| 78 |
-
x3 = F.relu(self.bn3(self.conv3(x2)))
|
| 79 |
-
|
| 80 |
-
# First residual connection
|
| 81 |
-
x3_res = x3 + self.residual1(x2)
|
| 82 |
-
|
| 83 |
-
x4 = F.relu(self.bn4(self.conv4(x3_res)))
|
| 84 |
-
x5 = F.relu(self.bn5(self.conv5(x4)))
|
| 85 |
-
|
| 86 |
-
# Second residual connection
|
| 87 |
-
x5_res = x5 + self.residual2(x4)
|
| 88 |
-
|
| 89 |
-
x6 = F.relu(self.bn6(self.conv6(x5_res)))
|
| 90 |
-
x7 = F.relu(self.bn7(self.conv7(x6)))
|
| 91 |
-
|
| 92 |
-
# Attention mechanism
|
| 93 |
-
attention_weights = F.softmax(self.attention(x7), dim=2) # (batch_size, 1, max_points)
|
| 94 |
-
x7_weighted = x7 * attention_weights # Apply attention
|
| 95 |
-
|
| 96 |
-
# Global max pooling combined with attention-weighted average pooling
|
| 97 |
-
global_max = torch.max(x7, 2)[0] # (batch_size, 2048)
|
| 98 |
-
global_avg = torch.sum(x7_weighted, 2) # (batch_size, 2048)
|
| 99 |
-
global_features = global_max + global_avg # Combine features
|
| 100 |
-
|
| 101 |
-
# Classification head with residual connections
|
| 102 |
-
x = F.relu(self.fc1(global_features))
|
| 103 |
-
x = self.dropout1(x)
|
| 104 |
-
x = F.relu(self.fc2(x))
|
| 105 |
-
x = self.dropout2(x)
|
| 106 |
-
x_mid = F.relu(self.fc3(x))
|
| 107 |
-
x = self.dropout3(x_mid)
|
| 108 |
-
x = F.relu(self.fc4(x))
|
| 109 |
-
x = self.dropout4(x)
|
| 110 |
-
x = F.relu(self.fc5(x))
|
| 111 |
-
x = self.dropout5(x)
|
| 112 |
-
x = F.relu(self.fc6(x))
|
| 113 |
-
x = self.dropout6(x)
|
| 114 |
-
x = F.relu(self.fc7(x))
|
| 115 |
-
x = self.dropout7(x)
|
| 116 |
-
classification = self.fc8(x) # (batch_size, 1)
|
| 117 |
-
|
| 118 |
-
return classification
|
| 119 |
-
|
| 120 |
-
class PatchClassificationDataset(Dataset):
|
| 121 |
-
"""
|
| 122 |
-
Dataset class for loading saved patches for PointNet classification training.
|
| 123 |
-
"""
|
| 124 |
-
|
| 125 |
-
def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
|
| 126 |
-
self.dataset_dir = dataset_dir
|
| 127 |
-
self.max_points = max_points
|
| 128 |
-
self.augment = augment
|
| 129 |
-
|
| 130 |
-
# Load patch files
|
| 131 |
-
self.patch_files = []
|
| 132 |
-
for file in os.listdir(dataset_dir):
|
| 133 |
-
if file.endswith('.pkl'):
|
| 134 |
-
self.patch_files.append(os.path.join(dataset_dir, file))
|
| 135 |
-
|
| 136 |
-
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
|
| 137 |
-
|
| 138 |
-
def __len__(self):
|
| 139 |
-
return len(self.patch_files)
|
| 140 |
-
|
| 141 |
-
def __getitem__(self, idx):
|
| 142 |
-
"""
|
| 143 |
-
Load and process a patch for training.
|
| 144 |
-
Returns:
|
| 145 |
-
patch_data: (10, max_points) tensor of point cloud data
|
| 146 |
-
label: scalar tensor for binary classification (0 or 1)
|
| 147 |
-
valid_mask: (max_points,) boolean tensor indicating valid points
|
| 148 |
-
"""
|
| 149 |
-
patch_file = self.patch_files[idx]
|
| 150 |
-
|
| 151 |
-
with open(patch_file, 'rb') as f:
|
| 152 |
-
patch_info = pickle.load(f)
|
| 153 |
-
|
| 154 |
-
patch_10d = patch_info['patch_10d'] # (N, 10)
|
| 155 |
-
label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
|
| 156 |
-
|
| 157 |
-
# Pad or sample points to max_points
|
| 158 |
-
num_points = patch_10d.shape[0]
|
| 159 |
-
|
| 160 |
-
if num_points >= self.max_points:
|
| 161 |
-
# Randomly sample max_points
|
| 162 |
-
indices = np.random.choice(num_points, self.max_points, replace=False)
|
| 163 |
-
patch_sampled = patch_10d[indices]
|
| 164 |
-
valid_mask = np.ones(self.max_points, dtype=bool)
|
| 165 |
-
else:
|
| 166 |
-
# Pad with zeros
|
| 167 |
-
patch_sampled = np.zeros((self.max_points, 10))
|
| 168 |
-
patch_sampled[:num_points] = patch_10d
|
| 169 |
-
valid_mask = np.zeros(self.max_points, dtype=bool)
|
| 170 |
-
valid_mask[:num_points] = True
|
| 171 |
-
|
| 172 |
-
# Data augmentation
|
| 173 |
-
if self.augment:
|
| 174 |
-
patch_sampled = self._augment_patch(patch_sampled, valid_mask)
|
| 175 |
-
|
| 176 |
-
# Convert to tensors and transpose for conv1d (channels first)
|
| 177 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (10, max_points)
|
| 178 |
-
label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
|
| 179 |
-
valid_mask_tensor = torch.from_numpy(valid_mask)
|
| 180 |
-
|
| 181 |
-
return patch_tensor, label_tensor, valid_mask_tensor
|
| 182 |
-
|
| 183 |
-
def _augment_patch(self, patch, valid_mask):
|
| 184 |
-
"""
|
| 185 |
-
Apply data augmentation to the patch.
|
| 186 |
-
"""
|
| 187 |
-
valid_points = patch[valid_mask]
|
| 188 |
-
|
| 189 |
-
if len(valid_points) == 0:
|
| 190 |
-
return patch
|
| 191 |
-
|
| 192 |
-
# Random rotation around z-axis (only for xyz coordinates, first 3 dimensions)
|
| 193 |
-
angle = np.random.uniform(0, 2 * np.pi)
|
| 194 |
-
cos_angle = np.cos(angle)
|
| 195 |
-
sin_angle = np.sin(angle)
|
| 196 |
-
rotation_matrix = np.array([
|
| 197 |
-
[cos_angle, -sin_angle, 0],
|
| 198 |
-
[sin_angle, cos_angle, 0],
|
| 199 |
-
[0, 0, 1]
|
| 200 |
-
])
|
| 201 |
-
|
| 202 |
-
# Apply rotation to xyz coordinates (first 3 dimensions)
|
| 203 |
-
valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
|
| 204 |
-
|
| 205 |
-
# Random jittering (only for xyz coordinates)
|
| 206 |
-
noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
|
| 207 |
-
valid_points[:, :3] += noise
|
| 208 |
-
|
| 209 |
-
# Random scaling (only for xyz coordinates)
|
| 210 |
-
scale = np.random.uniform(0.9, 1.1)
|
| 211 |
-
valid_points[:, :3] *= scale
|
| 212 |
-
|
| 213 |
-
patch[valid_mask] = valid_points
|
| 214 |
-
return patch
|
| 215 |
-
|
| 216 |
-
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
| 217 |
-
"""
|
| 218 |
-
Save patches from prediction pipeline to create a training dataset.
|
| 219 |
-
|
| 220 |
-
Args:
|
| 221 |
-
patches: List of patch dictionaries from generate_patches()
|
| 222 |
-
dataset_dir: Directory to save the dataset
|
| 223 |
-
entry_id: Unique identifier for this entry/image
|
| 224 |
-
"""
|
| 225 |
-
os.makedirs(dataset_dir, exist_ok=True)
|
| 226 |
-
|
| 227 |
-
for i, patch in enumerate(patches):
|
| 228 |
-
# Create unique filename
|
| 229 |
-
filename = f"{entry_id}_patch_{i}.pkl"
|
| 230 |
-
filepath = os.path.join(dataset_dir, filename)
|
| 231 |
-
|
| 232 |
-
# Skip if file already exists
|
| 233 |
-
if os.path.exists(filepath):
|
| 234 |
-
continue
|
| 235 |
-
|
| 236 |
-
# Save patch data
|
| 237 |
-
with open(filepath, 'wb') as f:
|
| 238 |
-
pickle.dump(patch, f)
|
| 239 |
-
|
| 240 |
-
print(f"Saved {len(patches)} patches for entry {entry_id}")
|
| 241 |
-
|
| 242 |
-
# Create dataloader with custom collate function to filter invalid samples
|
| 243 |
-
def collate_fn(batch):
|
| 244 |
-
valid_batch = []
|
| 245 |
-
for patch_data, label, valid_mask in batch:
|
| 246 |
-
# Filter out invalid samples (no valid points)
|
| 247 |
-
if valid_mask.sum() > 0:
|
| 248 |
-
valid_batch.append((patch_data, label, valid_mask))
|
| 249 |
-
|
| 250 |
-
if len(valid_batch) == 0:
|
| 251 |
-
return None
|
| 252 |
-
|
| 253 |
-
# Stack valid samples
|
| 254 |
-
patch_data = torch.stack([item[0] for item in valid_batch])
|
| 255 |
-
labels = torch.stack([item[1] for item in valid_batch])
|
| 256 |
-
valid_masks = torch.stack([item[2] for item in valid_batch])
|
| 257 |
-
|
| 258 |
-
return patch_data, labels, valid_masks
|
| 259 |
-
|
| 260 |
-
# Initialize weights using Xavier/Glorot initialization
|
| 261 |
-
def init_weights(m):
|
| 262 |
-
if isinstance(m, nn.Conv1d):
|
| 263 |
-
nn.init.xavier_uniform_(m.weight)
|
| 264 |
-
if m.bias is not None:
|
| 265 |
-
nn.init.zeros_(m.bias)
|
| 266 |
-
elif isinstance(m, nn.Linear):
|
| 267 |
-
nn.init.xavier_uniform_(m.weight)
|
| 268 |
-
if m.bias is not None:
|
| 269 |
-
nn.init.zeros_(m.bias)
|
| 270 |
-
elif isinstance(m, nn.BatchNorm1d):
|
| 271 |
-
nn.init.ones_(m.weight)
|
| 272 |
-
nn.init.zeros_(m.bias)
|
| 273 |
-
|
| 274 |
-
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
|
| 275 |
-
lr: float = 0.001):
|
| 276 |
-
"""
|
| 277 |
-
Train the ClassificationPointNet model on saved patches.
|
| 278 |
-
|
| 279 |
-
Args:
|
| 280 |
-
dataset_dir: Directory containing saved patch files
|
| 281 |
-
model_save_path: Path to save the trained model
|
| 282 |
-
epochs: Number of training epochs
|
| 283 |
-
batch_size: Training batch size
|
| 284 |
-
lr: Learning rate
|
| 285 |
-
"""
|
| 286 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 287 |
-
print(f"Training on device: {device}")
|
| 288 |
-
|
| 289 |
-
# Create dataset and dataloader
|
| 290 |
-
dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
|
| 291 |
-
print(f"Dataset loaded with {len(dataset)} samples")
|
| 292 |
-
|
| 293 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
|
| 294 |
-
collate_fn=collate_fn, drop_last=True)
|
| 295 |
-
|
| 296 |
-
# Initialize model
|
| 297 |
-
model = ClassificationPointNet(input_dim=10, max_points=1024)
|
| 298 |
-
model.apply(init_weights)
|
| 299 |
-
model.to(device)
|
| 300 |
-
|
| 301 |
-
# Loss function and optimizer (BCE for binary classification)
|
| 302 |
-
criterion = nn.BCEWithLogitsLoss()
|
| 303 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
| 304 |
-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
|
| 305 |
-
|
| 306 |
-
# Training loop
|
| 307 |
-
model.train()
|
| 308 |
-
for epoch in range(epochs):
|
| 309 |
-
total_loss = 0.0
|
| 310 |
-
correct = 0
|
| 311 |
-
total = 0
|
| 312 |
-
num_batches = 0
|
| 313 |
-
|
| 314 |
-
for batch_idx, batch_data in enumerate(dataloader):
|
| 315 |
-
if batch_data is None: # Skip invalid batches
|
| 316 |
-
continue
|
| 317 |
-
|
| 318 |
-
patch_data, labels, valid_masks = batch_data
|
| 319 |
-
patch_data = patch_data.to(device) # (batch_size, 10, max_points)
|
| 320 |
-
labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
|
| 321 |
-
|
| 322 |
-
# Forward pass
|
| 323 |
-
optimizer.zero_grad()
|
| 324 |
-
outputs = model(patch_data) # (batch_size, 1)
|
| 325 |
-
loss = criterion(outputs, labels)
|
| 326 |
-
|
| 327 |
-
# Backward pass
|
| 328 |
-
loss.backward()
|
| 329 |
-
optimizer.step()
|
| 330 |
-
|
| 331 |
-
# Statistics
|
| 332 |
-
total_loss += loss.item()
|
| 333 |
-
predicted = (torch.sigmoid(outputs) > 0.5).float()
|
| 334 |
-
total += labels.size(0)
|
| 335 |
-
correct += (predicted == labels).sum().item()
|
| 336 |
-
num_batches += 1
|
| 337 |
-
|
| 338 |
-
if batch_idx % 50 == 0:
|
| 339 |
-
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
|
| 340 |
-
f"Loss: {loss.item():.6f}, "
|
| 341 |
-
f"Accuracy: {100 * correct / total:.2f}%")
|
| 342 |
-
|
| 343 |
-
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
| 344 |
-
accuracy = 100 * correct / total if total > 0 else 0
|
| 345 |
-
|
| 346 |
-
print(f"Epoch {epoch+1}/{epochs} completed, "
|
| 347 |
-
f"Avg Loss: {avg_loss:.6f}, "
|
| 348 |
-
f"Accuracy: {accuracy:.2f}%")
|
| 349 |
-
|
| 350 |
-
scheduler.step()
|
| 351 |
-
|
| 352 |
-
# Save model checkpoint every epoch
|
| 353 |
-
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 354 |
-
torch.save({
|
| 355 |
-
'model_state_dict': model.state_dict(),
|
| 356 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 357 |
-
'epoch': epoch + 1,
|
| 358 |
-
'loss': avg_loss,
|
| 359 |
-
'accuracy': accuracy,
|
| 360 |
-
}, checkpoint_path)
|
| 361 |
-
|
| 362 |
-
# Save the trained model
|
| 363 |
-
torch.save({
|
| 364 |
-
'model_state_dict': model.state_dict(),
|
| 365 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 366 |
-
'epoch': epochs,
|
| 367 |
-
}, model_save_path)
|
| 368 |
-
|
| 369 |
-
print(f"Model saved to {model_save_path}")
|
| 370 |
-
return model
|
| 371 |
-
|
| 372 |
-
def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
|
| 373 |
-
"""
|
| 374 |
-
Load a trained ClassificationPointNet model.
|
| 375 |
-
|
| 376 |
-
Args:
|
| 377 |
-
model_path: Path to the saved model
|
| 378 |
-
device: Device to load the model on
|
| 379 |
-
|
| 380 |
-
Returns:
|
| 381 |
-
Loaded ClassificationPointNet model
|
| 382 |
-
"""
|
| 383 |
-
if device is None:
|
| 384 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 385 |
-
|
| 386 |
-
model = ClassificationPointNet(input_dim=10, max_points=1024)
|
| 387 |
-
|
| 388 |
-
checkpoint = torch.load(model_path, map_location=device)
|
| 389 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 390 |
-
|
| 391 |
-
model.to(device)
|
| 392 |
-
model.eval()
|
| 393 |
-
|
| 394 |
-
return model
|
| 395 |
-
|
| 396 |
-
def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
|
| 397 |
-
"""
|
| 398 |
-
Predict binary classification from a patch using trained PointNet.
|
| 399 |
-
|
| 400 |
-
Args:
|
| 401 |
-
model: Trained ClassificationPointNet model
|
| 402 |
-
patch: Dictionary containing patch data with 'patch_10d' key
|
| 403 |
-
device: Device to run prediction on
|
| 404 |
-
|
| 405 |
-
Returns:
|
| 406 |
-
tuple of (predicted_class, confidence)
|
| 407 |
-
predicted_class: int (0 for not edge, 1 for edge)
|
| 408 |
-
confidence: float representing confidence score (0-1)
|
| 409 |
-
"""
|
| 410 |
-
if device is None:
|
| 411 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 412 |
-
|
| 413 |
-
patch_10d = patch['patch_10d'] # (N, 10)
|
| 414 |
-
|
| 415 |
-
# Prepare input
|
| 416 |
-
max_points = 1024
|
| 417 |
-
num_points = patch_10d.shape[0]
|
| 418 |
-
|
| 419 |
-
if num_points >= max_points:
|
| 420 |
-
# Sample points
|
| 421 |
-
indices = np.random.choice(num_points, max_points, replace=False)
|
| 422 |
-
patch_sampled = patch_10d[indices]
|
| 423 |
-
else:
|
| 424 |
-
# Pad with zeros
|
| 425 |
-
patch_sampled = np.zeros((max_points, 10))
|
| 426 |
-
patch_sampled[:num_points] = patch_10d
|
| 427 |
-
|
| 428 |
-
# Convert to tensor
|
| 429 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 10, max_points)
|
| 430 |
-
patch_tensor = patch_tensor.to(device)
|
| 431 |
-
|
| 432 |
-
# Predict
|
| 433 |
-
with torch.no_grad():
|
| 434 |
-
outputs = model(patch_tensor) # (1, 1)
|
| 435 |
-
probability = torch.sigmoid(outputs).item()
|
| 436 |
-
predicted_class = int(probability > 0.5)
|
| 437 |
-
|
| 438 |
-
return predicted_class, probability
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fast_pointnet_class_deeper.py
DELETED
|
@@ -1,527 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pickle
|
| 7 |
-
from torch.utils.data import Dataset, DataLoader
|
| 8 |
-
from typing import List, Dict, Tuple, Optional
|
| 9 |
-
import json
|
| 10 |
-
|
| 11 |
-
class ClassificationPointNet(nn.Module):
|
| 12 |
-
"""
|
| 13 |
-
Enhanced PointNet implementation for binary classification from 6D point cloud patches.
|
| 14 |
-
Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
|
| 15 |
-
Features: Residual connections, attention mechanism, multi-scale features, deeper architecture.
|
| 16 |
-
"""
|
| 17 |
-
def __init__(self, input_dim=6, max_points=1024):
|
| 18 |
-
super(ClassificationPointNet, self).__init__()
|
| 19 |
-
self.max_points = max_points
|
| 20 |
-
|
| 21 |
-
# Point-wise MLPs with residual connections (much deeper)
|
| 22 |
-
self.conv1 = nn.Conv1d(input_dim, 64, 1)
|
| 23 |
-
self.conv2 = nn.Conv1d(64, 64, 1)
|
| 24 |
-
self.conv3 = nn.Conv1d(64, 128, 1)
|
| 25 |
-
self.conv4 = nn.Conv1d(128, 128, 1)
|
| 26 |
-
self.conv5 = nn.Conv1d(128, 256, 1)
|
| 27 |
-
self.conv6 = nn.Conv1d(256, 256, 1)
|
| 28 |
-
self.conv7 = nn.Conv1d(256, 512, 1)
|
| 29 |
-
self.conv8 = nn.Conv1d(512, 512, 1)
|
| 30 |
-
self.conv9 = nn.Conv1d(512, 1024, 1)
|
| 31 |
-
self.conv10 = nn.Conv1d(1024, 1024, 1)
|
| 32 |
-
self.conv11 = nn.Conv1d(1024, 2048, 1)
|
| 33 |
-
|
| 34 |
-
# Residual connection layers
|
| 35 |
-
self.res_conv1 = nn.Conv1d(64, 128, 1)
|
| 36 |
-
self.res_conv2 = nn.Conv1d(128, 256, 1)
|
| 37 |
-
self.res_conv3 = nn.Conv1d(256, 512, 1)
|
| 38 |
-
self.res_conv4 = nn.Conv1d(512, 1024, 1)
|
| 39 |
-
|
| 40 |
-
# Self-attention mechanism
|
| 41 |
-
self.attention = nn.MultiheadAttention(embed_dim=2048, num_heads=8, batch_first=True)
|
| 42 |
-
self.attention_norm = nn.LayerNorm(2048)
|
| 43 |
-
|
| 44 |
-
# Multi-scale feature aggregation
|
| 45 |
-
self.scale_conv1 = nn.Conv1d(2048, 512, 1)
|
| 46 |
-
self.scale_conv2 = nn.Conv1d(2048, 512, 1)
|
| 47 |
-
self.scale_conv3 = nn.Conv1d(2048, 512, 1)
|
| 48 |
-
|
| 49 |
-
# Enhanced classification head with residual connections
|
| 50 |
-
self.fc1 = nn.Linear(7680, 2048) # Updated input size: 2048*3 + 512*3 = 7680
|
| 51 |
-
self.fc2 = nn.Linear(2048, 2048)
|
| 52 |
-
self.fc3 = nn.Linear(2048, 1024)
|
| 53 |
-
self.fc4 = nn.Linear(1024, 1024)
|
| 54 |
-
self.fc5 = nn.Linear(1024, 512)
|
| 55 |
-
self.fc6 = nn.Linear(512, 512)
|
| 56 |
-
self.fc7 = nn.Linear(512, 256)
|
| 57 |
-
self.fc8 = nn.Linear(256, 128)
|
| 58 |
-
self.fc9 = nn.Linear(128, 64)
|
| 59 |
-
self.fc10 = nn.Linear(64, 1)
|
| 60 |
-
|
| 61 |
-
# Residual connections for FC layers
|
| 62 |
-
self.fc_res1 = nn.Linear(2048, 1024)
|
| 63 |
-
self.fc_res2 = nn.Linear(1024, 512)
|
| 64 |
-
self.fc_res3 = nn.Linear(512, 128)
|
| 65 |
-
|
| 66 |
-
# Batch normalization layers
|
| 67 |
-
self.bn1 = nn.BatchNorm1d(64)
|
| 68 |
-
self.bn2 = nn.BatchNorm1d(64)
|
| 69 |
-
self.bn3 = nn.BatchNorm1d(128)
|
| 70 |
-
self.bn4 = nn.BatchNorm1d(128)
|
| 71 |
-
self.bn5 = nn.BatchNorm1d(256)
|
| 72 |
-
self.bn6 = nn.BatchNorm1d(256)
|
| 73 |
-
self.bn7 = nn.BatchNorm1d(512)
|
| 74 |
-
self.bn8 = nn.BatchNorm1d(512)
|
| 75 |
-
self.bn9 = nn.BatchNorm1d(1024)
|
| 76 |
-
self.bn10 = nn.BatchNorm1d(1024)
|
| 77 |
-
self.bn11 = nn.BatchNorm1d(2048)
|
| 78 |
-
|
| 79 |
-
# Scale batch norms
|
| 80 |
-
self.scale_bn1 = nn.BatchNorm1d(512)
|
| 81 |
-
self.scale_bn2 = nn.BatchNorm1d(512)
|
| 82 |
-
self.scale_bn3 = nn.BatchNorm1d(512)
|
| 83 |
-
|
| 84 |
-
# FC batch norms
|
| 85 |
-
self.fc_bn1 = nn.BatchNorm1d(2048)
|
| 86 |
-
self.fc_bn2 = nn.BatchNorm1d(2048)
|
| 87 |
-
self.fc_bn3 = nn.BatchNorm1d(1024)
|
| 88 |
-
self.fc_bn4 = nn.BatchNorm1d(1024)
|
| 89 |
-
self.fc_bn5 = nn.BatchNorm1d(512)
|
| 90 |
-
self.fc_bn6 = nn.BatchNorm1d(512)
|
| 91 |
-
self.fc_bn7 = nn.BatchNorm1d(256)
|
| 92 |
-
self.fc_bn8 = nn.BatchNorm1d(128)
|
| 93 |
-
|
| 94 |
-
# Dropout layers with varying rates
|
| 95 |
-
self.dropout1 = nn.Dropout(0.1)
|
| 96 |
-
self.dropout2 = nn.Dropout(0.2)
|
| 97 |
-
self.dropout3 = nn.Dropout(0.3)
|
| 98 |
-
self.dropout4 = nn.Dropout(0.4)
|
| 99 |
-
self.dropout5 = nn.Dropout(0.5)
|
| 100 |
-
self.dropout6 = nn.Dropout(0.4)
|
| 101 |
-
self.dropout7 = nn.Dropout(0.3)
|
| 102 |
-
self.dropout8 = nn.Dropout(0.2)
|
| 103 |
-
|
| 104 |
-
def forward(self, x):
|
| 105 |
-
"""
|
| 106 |
-
Forward pass with residual connections and attention
|
| 107 |
-
Args:
|
| 108 |
-
x: (batch_size, input_dim, max_points) tensor
|
| 109 |
-
Returns:
|
| 110 |
-
classification: (batch_size, 1) tensor of logits
|
| 111 |
-
"""
|
| 112 |
-
batch_size = x.size(0)
|
| 113 |
-
|
| 114 |
-
# Deep point-wise feature extraction with residual connections
|
| 115 |
-
x1 = F.relu(self.bn1(self.conv1(x)))
|
| 116 |
-
x2 = F.relu(self.bn2(self.conv2(x1)))
|
| 117 |
-
x2 = x2 + x1 # Residual connection
|
| 118 |
-
|
| 119 |
-
x3 = F.relu(self.bn3(self.conv3(x2)))
|
| 120 |
-
x4 = F.relu(self.bn4(self.conv4(x3)))
|
| 121 |
-
res1 = self.res_conv1(x2)
|
| 122 |
-
x4 = x4 + res1 # Residual connection
|
| 123 |
-
|
| 124 |
-
x5 = F.relu(self.bn5(self.conv5(x4)))
|
| 125 |
-
x6 = F.relu(self.bn6(self.conv6(x5)))
|
| 126 |
-
res2 = self.res_conv2(x4)
|
| 127 |
-
x6 = x6 + res2 # Residual connection
|
| 128 |
-
|
| 129 |
-
x7 = F.relu(self.bn7(self.conv7(x6)))
|
| 130 |
-
x8 = F.relu(self.bn8(self.conv8(x7)))
|
| 131 |
-
res3 = self.res_conv3(x6)
|
| 132 |
-
x8 = x8 + res3 # Residual connection
|
| 133 |
-
|
| 134 |
-
x9 = F.relu(self.bn9(self.conv9(x8)))
|
| 135 |
-
x10 = F.relu(self.bn10(self.conv10(x9)))
|
| 136 |
-
res4 = self.res_conv4(x8)
|
| 137 |
-
x10 = x10 + res4 # Residual connection
|
| 138 |
-
|
| 139 |
-
x11 = F.relu(self.bn11(self.conv11(x10)))
|
| 140 |
-
|
| 141 |
-
# Multi-scale global pooling
|
| 142 |
-
# Max pooling
|
| 143 |
-
global_max = torch.max(x11, 2)[0] # (batch_size, 2048)
|
| 144 |
-
|
| 145 |
-
# Average pooling
|
| 146 |
-
global_avg = torch.mean(x11, 2) # (batch_size, 2048)
|
| 147 |
-
|
| 148 |
-
# Attention-based pooling
|
| 149 |
-
x11_transposed = x11.transpose(1, 2) # (batch_size, max_points, 2048)
|
| 150 |
-
attended, _ = self.attention(x11_transposed, x11_transposed, x11_transposed)
|
| 151 |
-
attended = self.attention_norm(attended + x11_transposed)
|
| 152 |
-
global_att = torch.mean(attended, 1) # (batch_size, 2048)
|
| 153 |
-
|
| 154 |
-
# Multi-scale feature extraction
|
| 155 |
-
scale1 = F.relu(self.scale_bn1(self.scale_conv1(x11)))
|
| 156 |
-
scale1_pool = torch.max(scale1, 2)[0]
|
| 157 |
-
|
| 158 |
-
scale2 = F.relu(self.scale_bn2(self.scale_conv2(x11)))
|
| 159 |
-
scale2_pool = torch.mean(scale2, 2)
|
| 160 |
-
|
| 161 |
-
scale3 = F.relu(self.scale_bn3(self.scale_conv3(x11)))
|
| 162 |
-
scale3_pool = torch.std(scale3, 2)
|
| 163 |
-
|
| 164 |
-
# Concatenate all global features
|
| 165 |
-
global_features = torch.cat([
|
| 166 |
-
global_max, global_avg, global_att,
|
| 167 |
-
scale1_pool, scale2_pool, scale3_pool
|
| 168 |
-
], dim=1) # (batch_size, 4096)
|
| 169 |
-
|
| 170 |
-
# Enhanced classification head with residual connections
|
| 171 |
-
x = F.relu(self.fc_bn1(self.fc1(global_features)))
|
| 172 |
-
x = self.dropout1(x)
|
| 173 |
-
|
| 174 |
-
x = F.relu(self.fc_bn2(self.fc2(x)))
|
| 175 |
-
identity1 = x
|
| 176 |
-
x = self.dropout2(x)
|
| 177 |
-
|
| 178 |
-
x = F.relu(self.fc_bn3(self.fc3(x)))
|
| 179 |
-
x = self.dropout3(x)
|
| 180 |
-
|
| 181 |
-
x = F.relu(self.fc_bn4(self.fc4(x)))
|
| 182 |
-
res_fc1 = self.fc_res1(identity1)
|
| 183 |
-
x = x + res_fc1 # Residual connection
|
| 184 |
-
identity2 = x
|
| 185 |
-
x = self.dropout4(x)
|
| 186 |
-
|
| 187 |
-
x = F.relu(self.fc_bn5(self.fc5(x)))
|
| 188 |
-
x = self.dropout5(x)
|
| 189 |
-
|
| 190 |
-
x = F.relu(self.fc_bn6(self.fc6(x)))
|
| 191 |
-
res_fc2 = self.fc_res2(identity2)
|
| 192 |
-
x = x + res_fc2 # Residual connection
|
| 193 |
-
identity3 = x
|
| 194 |
-
x = self.dropout6(x)
|
| 195 |
-
|
| 196 |
-
x = F.relu(self.fc_bn7(self.fc7(x)))
|
| 197 |
-
x = self.dropout7(x)
|
| 198 |
-
|
| 199 |
-
x = F.relu(self.fc_bn8(self.fc8(x)))
|
| 200 |
-
res_fc3 = self.fc_res3(identity3)
|
| 201 |
-
x = x + res_fc3 # Residual connection
|
| 202 |
-
x = self.dropout8(x)
|
| 203 |
-
|
| 204 |
-
x = F.relu(self.fc9(x))
|
| 205 |
-
classification = self.fc10(x) # (batch_size, 1)
|
| 206 |
-
|
| 207 |
-
return classification
|
| 208 |
-
|
| 209 |
-
class PatchClassificationDataset(Dataset):
|
| 210 |
-
"""
|
| 211 |
-
Dataset class for loading saved patches for PointNet classification training.
|
| 212 |
-
"""
|
| 213 |
-
|
| 214 |
-
def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
|
| 215 |
-
self.dataset_dir = dataset_dir
|
| 216 |
-
self.max_points = max_points
|
| 217 |
-
self.augment = augment
|
| 218 |
-
|
| 219 |
-
# Load patch files
|
| 220 |
-
self.patch_files = []
|
| 221 |
-
for file in os.listdir(dataset_dir):
|
| 222 |
-
if file.endswith('.pkl'):
|
| 223 |
-
self.patch_files.append(os.path.join(dataset_dir, file))
|
| 224 |
-
|
| 225 |
-
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
|
| 226 |
-
|
| 227 |
-
def __len__(self):
|
| 228 |
-
return len(self.patch_files)
|
| 229 |
-
|
| 230 |
-
def __getitem__(self, idx):
|
| 231 |
-
"""
|
| 232 |
-
Load and process a patch for training.
|
| 233 |
-
Returns:
|
| 234 |
-
patch_data: (6, max_points) tensor of point cloud data
|
| 235 |
-
label: scalar tensor for binary classification (0 or 1)
|
| 236 |
-
valid_mask: (max_points,) boolean tensor indicating valid points
|
| 237 |
-
"""
|
| 238 |
-
patch_file = self.patch_files[idx]
|
| 239 |
-
|
| 240 |
-
with open(patch_file, 'rb') as f:
|
| 241 |
-
patch_info = pickle.load(f)
|
| 242 |
-
|
| 243 |
-
patch_6d = patch_info['patch_6d'] # (N, 6)
|
| 244 |
-
label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
|
| 245 |
-
|
| 246 |
-
# Pad or sample points to max_points
|
| 247 |
-
num_points = patch_6d.shape[0]
|
| 248 |
-
|
| 249 |
-
if num_points >= self.max_points:
|
| 250 |
-
# Randomly sample max_points
|
| 251 |
-
indices = np.random.choice(num_points, self.max_points, replace=False)
|
| 252 |
-
patch_sampled = patch_6d[indices]
|
| 253 |
-
valid_mask = np.ones(self.max_points, dtype=bool)
|
| 254 |
-
else:
|
| 255 |
-
# Pad with zeros
|
| 256 |
-
patch_sampled = np.zeros((self.max_points, 6))
|
| 257 |
-
patch_sampled[:num_points] = patch_6d
|
| 258 |
-
valid_mask = np.zeros(self.max_points, dtype=bool)
|
| 259 |
-
valid_mask[:num_points] = True
|
| 260 |
-
|
| 261 |
-
# Data augmentation
|
| 262 |
-
if self.augment:
|
| 263 |
-
patch_sampled = self._augment_patch(patch_sampled, valid_mask)
|
| 264 |
-
|
| 265 |
-
# Convert to tensors and transpose for conv1d (channels first)
|
| 266 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (6, max_points)
|
| 267 |
-
label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
|
| 268 |
-
valid_mask_tensor = torch.from_numpy(valid_mask)
|
| 269 |
-
|
| 270 |
-
return patch_tensor, label_tensor, valid_mask_tensor
|
| 271 |
-
|
| 272 |
-
def _augment_patch(self, patch, valid_mask):
|
| 273 |
-
"""
|
| 274 |
-
Apply data augmentation to the patch.
|
| 275 |
-
"""
|
| 276 |
-
valid_points = patch[valid_mask]
|
| 277 |
-
|
| 278 |
-
if len(valid_points) == 0:
|
| 279 |
-
return patch
|
| 280 |
-
|
| 281 |
-
# Random rotation around z-axis
|
| 282 |
-
angle = np.random.uniform(0, 2 * np.pi)
|
| 283 |
-
cos_angle = np.cos(angle)
|
| 284 |
-
sin_angle = np.sin(angle)
|
| 285 |
-
rotation_matrix = np.array([
|
| 286 |
-
[cos_angle, -sin_angle, 0],
|
| 287 |
-
[sin_angle, cos_angle, 0],
|
| 288 |
-
[0, 0, 1]
|
| 289 |
-
])
|
| 290 |
-
|
| 291 |
-
# Apply rotation to xyz coordinates
|
| 292 |
-
valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
|
| 293 |
-
|
| 294 |
-
# Random jittering
|
| 295 |
-
noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
|
| 296 |
-
valid_points[:, :3] += noise
|
| 297 |
-
|
| 298 |
-
# Random scaling
|
| 299 |
-
scale = np.random.uniform(0.9, 1.1)
|
| 300 |
-
valid_points[:, :3] *= scale
|
| 301 |
-
|
| 302 |
-
patch[valid_mask] = valid_points
|
| 303 |
-
return patch
|
| 304 |
-
|
| 305 |
-
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
| 306 |
-
"""
|
| 307 |
-
Save patches from prediction pipeline to create a training dataset.
|
| 308 |
-
|
| 309 |
-
Args:
|
| 310 |
-
patches: List of patch dictionaries from generate_patches()
|
| 311 |
-
dataset_dir: Directory to save the dataset
|
| 312 |
-
entry_id: Unique identifier for this entry/image
|
| 313 |
-
"""
|
| 314 |
-
os.makedirs(dataset_dir, exist_ok=True)
|
| 315 |
-
|
| 316 |
-
for i, patch in enumerate(patches):
|
| 317 |
-
# Create unique filename
|
| 318 |
-
filename = f"{entry_id}_patch_{i}.pkl"
|
| 319 |
-
filepath = os.path.join(dataset_dir, filename)
|
| 320 |
-
|
| 321 |
-
# Skip if file already exists
|
| 322 |
-
if os.path.exists(filepath):
|
| 323 |
-
continue
|
| 324 |
-
|
| 325 |
-
# Save patch data
|
| 326 |
-
with open(filepath, 'wb') as f:
|
| 327 |
-
pickle.dump(patch, f)
|
| 328 |
-
|
| 329 |
-
print(f"Saved {len(patches)} patches for entry {entry_id}")
|
| 330 |
-
|
| 331 |
-
# Create dataloader with custom collate function to filter invalid samples
|
| 332 |
-
def collate_fn(batch):
|
| 333 |
-
valid_batch = []
|
| 334 |
-
for patch_data, label, valid_mask in batch:
|
| 335 |
-
# Filter out invalid samples (no valid points)
|
| 336 |
-
if valid_mask.sum() > 0:
|
| 337 |
-
valid_batch.append((patch_data, label, valid_mask))
|
| 338 |
-
|
| 339 |
-
if len(valid_batch) == 0:
|
| 340 |
-
return None
|
| 341 |
-
|
| 342 |
-
# Stack valid samples
|
| 343 |
-
patch_data = torch.stack([item[0] for item in valid_batch])
|
| 344 |
-
labels = torch.stack([item[1] for item in valid_batch])
|
| 345 |
-
valid_masks = torch.stack([item[2] for item in valid_batch])
|
| 346 |
-
|
| 347 |
-
return patch_data, labels, valid_masks
|
| 348 |
-
|
| 349 |
-
# Initialize weights using Xavier/Glorot initialization
|
| 350 |
-
def init_weights(m):
|
| 351 |
-
if isinstance(m, nn.Conv1d):
|
| 352 |
-
nn.init.xavier_uniform_(m.weight)
|
| 353 |
-
if m.bias is not None:
|
| 354 |
-
nn.init.zeros_(m.bias)
|
| 355 |
-
elif isinstance(m, nn.Linear):
|
| 356 |
-
nn.init.xavier_uniform_(m.weight)
|
| 357 |
-
if m.bias is not None:
|
| 358 |
-
nn.init.zeros_(m.bias)
|
| 359 |
-
elif isinstance(m, nn.BatchNorm1d):
|
| 360 |
-
nn.init.ones_(m.weight)
|
| 361 |
-
nn.init.zeros_(m.bias)
|
| 362 |
-
|
| 363 |
-
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
|
| 364 |
-
lr: float = 0.001):
|
| 365 |
-
"""
|
| 366 |
-
Train the ClassificationPointNet model on saved patches.
|
| 367 |
-
|
| 368 |
-
Args:
|
| 369 |
-
dataset_dir: Directory containing saved patch files
|
| 370 |
-
model_save_path: Path to save the trained model
|
| 371 |
-
epochs: Number of training epochs
|
| 372 |
-
batch_size: Training batch size
|
| 373 |
-
lr: Learning rate
|
| 374 |
-
"""
|
| 375 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 376 |
-
print(f"Training on device: {device}")
|
| 377 |
-
|
| 378 |
-
# Create dataset and dataloader
|
| 379 |
-
dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
|
| 380 |
-
print(f"Dataset loaded with {len(dataset)} samples")
|
| 381 |
-
|
| 382 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
|
| 383 |
-
collate_fn=collate_fn, drop_last=True)
|
| 384 |
-
|
| 385 |
-
# Initialize model
|
| 386 |
-
model = ClassificationPointNet(input_dim=6, max_points=1024)
|
| 387 |
-
model.apply(init_weights)
|
| 388 |
-
model.to(device)
|
| 389 |
-
|
| 390 |
-
# Loss function and optimizer (BCE for binary classification)
|
| 391 |
-
criterion = nn.BCEWithLogitsLoss()
|
| 392 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
| 393 |
-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
|
| 394 |
-
|
| 395 |
-
# Training loop
|
| 396 |
-
model.train()
|
| 397 |
-
for epoch in range(epochs):
|
| 398 |
-
total_loss = 0.0
|
| 399 |
-
correct = 0
|
| 400 |
-
total = 0
|
| 401 |
-
num_batches = 0
|
| 402 |
-
|
| 403 |
-
for batch_idx, batch_data in enumerate(dataloader):
|
| 404 |
-
if batch_data is None: # Skip invalid batches
|
| 405 |
-
continue
|
| 406 |
-
|
| 407 |
-
patch_data, labels, valid_masks = batch_data
|
| 408 |
-
patch_data = patch_data.to(device) # (batch_size, 6, max_points)
|
| 409 |
-
labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
|
| 410 |
-
|
| 411 |
-
# Forward pass
|
| 412 |
-
optimizer.zero_grad()
|
| 413 |
-
outputs = model(patch_data) # (batch_size, 1)
|
| 414 |
-
loss = criterion(outputs, labels)
|
| 415 |
-
|
| 416 |
-
# Backward pass
|
| 417 |
-
loss.backward()
|
| 418 |
-
optimizer.step()
|
| 419 |
-
|
| 420 |
-
# Statistics
|
| 421 |
-
total_loss += loss.item()
|
| 422 |
-
predicted = (torch.sigmoid(outputs) > 0.5).float()
|
| 423 |
-
total += labels.size(0)
|
| 424 |
-
correct += (predicted == labels).sum().item()
|
| 425 |
-
num_batches += 1
|
| 426 |
-
|
| 427 |
-
if batch_idx % 50 == 0:
|
| 428 |
-
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
|
| 429 |
-
f"Loss: {loss.item():.6f}, "
|
| 430 |
-
f"Accuracy: {100 * correct / total:.2f}%")
|
| 431 |
-
|
| 432 |
-
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
| 433 |
-
accuracy = 100 * correct / total if total > 0 else 0
|
| 434 |
-
|
| 435 |
-
print(f"Epoch {epoch+1}/{epochs} completed, "
|
| 436 |
-
f"Avg Loss: {avg_loss:.6f}, "
|
| 437 |
-
f"Accuracy: {accuracy:.2f}%")
|
| 438 |
-
|
| 439 |
-
scheduler.step()
|
| 440 |
-
|
| 441 |
-
# Save model checkpoint every epoch
|
| 442 |
-
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 443 |
-
torch.save({
|
| 444 |
-
'model_state_dict': model.state_dict(),
|
| 445 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 446 |
-
'epoch': epoch + 1,
|
| 447 |
-
'loss': avg_loss,
|
| 448 |
-
'accuracy': accuracy,
|
| 449 |
-
}, checkpoint_path)
|
| 450 |
-
|
| 451 |
-
# Save the trained model
|
| 452 |
-
torch.save({
|
| 453 |
-
'model_state_dict': model.state_dict(),
|
| 454 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 455 |
-
'epoch': epochs,
|
| 456 |
-
}, model_save_path)
|
| 457 |
-
|
| 458 |
-
print(f"Model saved to {model_save_path}")
|
| 459 |
-
return model
|
| 460 |
-
|
| 461 |
-
def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
|
| 462 |
-
"""
|
| 463 |
-
Load a trained ClassificationPointNet model.
|
| 464 |
-
|
| 465 |
-
Args:
|
| 466 |
-
model_path: Path to the saved model
|
| 467 |
-
device: Device to load the model on
|
| 468 |
-
|
| 469 |
-
Returns:
|
| 470 |
-
Loaded ClassificationPointNet model
|
| 471 |
-
"""
|
| 472 |
-
if device is None:
|
| 473 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 474 |
-
|
| 475 |
-
model = ClassificationPointNet(input_dim=6, max_points=1024)
|
| 476 |
-
|
| 477 |
-
checkpoint = torch.load(model_path, map_location=device)
|
| 478 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 479 |
-
|
| 480 |
-
model.to(device)
|
| 481 |
-
model.eval()
|
| 482 |
-
|
| 483 |
-
return model
|
| 484 |
-
|
| 485 |
-
def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
|
| 486 |
-
"""
|
| 487 |
-
Predict binary classification from a patch using trained PointNet.
|
| 488 |
-
|
| 489 |
-
Args:
|
| 490 |
-
model: Trained ClassificationPointNet model
|
| 491 |
-
patch: Dictionary containing patch data with 'patch_6d' key
|
| 492 |
-
device: Device to run prediction on
|
| 493 |
-
|
| 494 |
-
Returns:
|
| 495 |
-
tuple of (predicted_class, confidence)
|
| 496 |
-
predicted_class: int (0 for not edge, 1 for edge)
|
| 497 |
-
confidence: float representing confidence score (0-1)
|
| 498 |
-
"""
|
| 499 |
-
if device is None:
|
| 500 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 501 |
-
|
| 502 |
-
patch_6d = patch['patch_6d'] # (N, 6)
|
| 503 |
-
|
| 504 |
-
# Prepare input
|
| 505 |
-
max_points = 1024
|
| 506 |
-
num_points = patch_6d.shape[0]
|
| 507 |
-
|
| 508 |
-
if num_points >= max_points:
|
| 509 |
-
# Sample points
|
| 510 |
-
indices = np.random.choice(num_points, max_points, replace=False)
|
| 511 |
-
patch_sampled = patch_6d[indices]
|
| 512 |
-
else:
|
| 513 |
-
# Pad with zeros
|
| 514 |
-
patch_sampled = np.zeros((max_points, 6))
|
| 515 |
-
patch_sampled[:num_points] = patch_6d
|
| 516 |
-
|
| 517 |
-
# Convert to tensor
|
| 518 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 6, max_points)
|
| 519 |
-
patch_tensor = patch_tensor.to(device)
|
| 520 |
-
|
| 521 |
-
# Predict
|
| 522 |
-
with torch.no_grad():
|
| 523 |
-
outputs = model(patch_tensor) # (1, 1)
|
| 524 |
-
probability = torch.sigmoid(outputs).item()
|
| 525 |
-
predicted_class = int(probability > 0.5)
|
| 526 |
-
|
| 527 |
-
return predicted_class, probability
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fast_pointnet_class_v2.py
DELETED
|
@@ -1,508 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pickle
|
| 7 |
-
from torch.utils.data import Dataset, DataLoader
|
| 8 |
-
from typing import List, Dict, Tuple, Optional
|
| 9 |
-
import json
|
| 10 |
-
|
| 11 |
-
class ClassificationPointNet(nn.Module):
|
| 12 |
-
"""
|
| 13 |
-
Fast PointNet-like implementation for binary classification from point cloud patches.
|
| 14 |
-
Adapted from FastPointNet, focusing only on classification.
|
| 15 |
-
Takes N-dimensional point clouds and predicts a binary class.
|
| 16 |
-
"""
|
| 17 |
-
def __init__(self, input_dim: int = 10, max_points: int = 1024, num_classes: int = 1):
|
| 18 |
-
super(ClassificationPointNet, self).__init__()
|
| 19 |
-
self.max_points = max_points
|
| 20 |
-
self.num_classes = num_classes
|
| 21 |
-
|
| 22 |
-
# Enhanced point-wise MLPs with residual connections
|
| 23 |
-
self.conv1 = nn.Conv1d(input_dim, 64, 1)
|
| 24 |
-
self.conv2 = nn.Conv1d(64, 128, 1)
|
| 25 |
-
self.conv3 = nn.Conv1d(128, 256, 1)
|
| 26 |
-
self.conv4 = nn.Conv1d(256, 512, 1)
|
| 27 |
-
self.conv5 = nn.Conv1d(512, 1024, 1)
|
| 28 |
-
self.conv6 = nn.Conv1d(1024, 1024, 1) # Matches FastPointNet structure
|
| 29 |
-
self.conv7 = nn.Conv1d(1024, 2048, 1) # Matches FastPointNet structure
|
| 30 |
-
|
| 31 |
-
# Lightweight channel attention mechanism
|
| 32 |
-
self.channel_attention = nn.Sequential(
|
| 33 |
-
nn.AdaptiveAvgPool1d(1),
|
| 34 |
-
nn.Conv1d(2048, 128, 1),
|
| 35 |
-
nn.ReLU(inplace=True),
|
| 36 |
-
nn.Conv1d(128, 2048, 1),
|
| 37 |
-
nn.Sigmoid()
|
| 38 |
-
)
|
| 39 |
-
|
| 40 |
-
# Enhanced shared features with residual connections
|
| 41 |
-
self.shared_fc1 = nn.Linear(2048, 1024)
|
| 42 |
-
self.shared_fc2 = nn.Linear(1024, 512)
|
| 43 |
-
self.shared_fc3 = nn.Linear(512, 512)
|
| 44 |
-
|
| 45 |
-
# Classification head
|
| 46 |
-
self.class_fc1 = nn.Linear(512, 512)
|
| 47 |
-
self.class_fc2 = nn.Linear(512, 256)
|
| 48 |
-
self.class_fc3 = nn.Linear(256, 128)
|
| 49 |
-
self.class_fc4 = nn.Linear(128, 64)
|
| 50 |
-
self.class_fc5 = nn.Linear(64, self.num_classes) # Output for classification
|
| 51 |
-
|
| 52 |
-
# Batch normalization layers with momentum
|
| 53 |
-
self.bn1 = nn.BatchNorm1d(64, momentum=0.1)
|
| 54 |
-
self.bn2 = nn.BatchNorm1d(128, momentum=0.1)
|
| 55 |
-
self.bn3 = nn.BatchNorm1d(256, momentum=0.1)
|
| 56 |
-
self.bn4 = nn.BatchNorm1d(512, momentum=0.1)
|
| 57 |
-
self.bn5 = nn.BatchNorm1d(1024, momentum=0.1)
|
| 58 |
-
self.bn6 = nn.BatchNorm1d(1024, momentum=0.1)
|
| 59 |
-
self.bn7 = nn.BatchNorm1d(2048, momentum=0.1)
|
| 60 |
-
|
| 61 |
-
# Group normalization for shared layers
|
| 62 |
-
self.gn1 = nn.GroupNorm(32, 1024) # Assuming 1024 channels, 32 groups
|
| 63 |
-
self.gn2 = nn.GroupNorm(16, 512) # Assuming 512 channels, 16 groups
|
| 64 |
-
|
| 65 |
-
# Dropout layers
|
| 66 |
-
self.dropout_light = nn.Dropout(0.1)
|
| 67 |
-
self.dropout_medium = nn.Dropout(0.2)
|
| 68 |
-
# self.dropout_heavy = nn.Dropout(0.3) # Not used in the direct path to classification in this adaptation
|
| 69 |
-
|
| 70 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 71 |
-
"""
|
| 72 |
-
Forward pass with residual connections and attention for classification.
|
| 73 |
-
Args:
|
| 74 |
-
x: (batch_size, input_dim, max_points) tensor
|
| 75 |
-
Returns:
|
| 76 |
-
classification: (batch_size, num_classes) tensor of logits
|
| 77 |
-
"""
|
| 78 |
-
# Enhanced point-wise feature extraction
|
| 79 |
-
x1 = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.01, inplace=True)
|
| 80 |
-
x2 = F.leaky_relu(self.bn2(self.conv2(x1)), negative_slope=0.01, inplace=True)
|
| 81 |
-
x3 = F.leaky_relu(self.bn3(self.conv3(x2)), negative_slope=0.01, inplace=True)
|
| 82 |
-
x4 = F.leaky_relu(self.bn4(self.conv4(x3)), negative_slope=0.01, inplace=True)
|
| 83 |
-
x5 = F.leaky_relu(self.bn5(self.conv5(x4)), negative_slope=0.01, inplace=True)
|
| 84 |
-
|
| 85 |
-
# Residual connection for conv6
|
| 86 |
-
x6_conv = self.bn6(self.conv6(x5))
|
| 87 |
-
x6 = F.leaky_relu(x6_conv + x5, negative_slope=0.01, inplace=True) # Add residual before ReLU
|
| 88 |
-
|
| 89 |
-
x7 = F.leaky_relu(self.bn7(self.conv7(x6)), negative_slope=0.01, inplace=True)
|
| 90 |
-
|
| 91 |
-
# Apply channel attention
|
| 92 |
-
attention_weights = self.channel_attention(x7)
|
| 93 |
-
x7_attended = x7 * attention_weights
|
| 94 |
-
|
| 95 |
-
# Multi-scale global pooling
|
| 96 |
-
max_pool = torch.max(x7_attended, 2)[0]
|
| 97 |
-
avg_pool = torch.mean(x7_attended, 2)
|
| 98 |
-
global_features = 0.7 * max_pool + 0.3 * avg_pool
|
| 99 |
-
|
| 100 |
-
# Enhanced shared features
|
| 101 |
-
shared1_fc = self.shared_fc1(global_features)
|
| 102 |
-
shared1 = F.leaky_relu(self.gn1(shared1_fc.unsqueeze(-1)).squeeze(-1), negative_slope=0.01, inplace=True)
|
| 103 |
-
shared1 = self.dropout_light(shared1)
|
| 104 |
-
|
| 105 |
-
shared2_fc = self.shared_fc2(shared1)
|
| 106 |
-
shared2 = F.leaky_relu(self.gn2(shared2_fc.unsqueeze(-1)).squeeze(-1), negative_slope=0.01, inplace=True)
|
| 107 |
-
shared2 = self.dropout_medium(shared2)
|
| 108 |
-
|
| 109 |
-
shared3_fc = self.shared_fc3(shared2)
|
| 110 |
-
# Residual connection for shared_fc3
|
| 111 |
-
shared_features = F.leaky_relu(shared3_fc + shared2, negative_slope=0.01, inplace=True) # Add residual before ReLU
|
| 112 |
-
shared_features = self.dropout_light(shared_features) # Apply dropout after residual and ReLU
|
| 113 |
-
|
| 114 |
-
# Classification head
|
| 115 |
-
class1 = F.leaky_relu(self.class_fc1(shared_features), negative_slope=0.01, inplace=True)
|
| 116 |
-
class1 = self.dropout_light(class1)
|
| 117 |
-
|
| 118 |
-
class2 = F.leaky_relu(self.class_fc2(class1), negative_slope=0.01, inplace=True)
|
| 119 |
-
class2 = self.dropout_medium(class2)
|
| 120 |
-
|
| 121 |
-
class3 = F.leaky_relu(self.class_fc3(class2), negative_slope=0.01, inplace=True)
|
| 122 |
-
class3 = self.dropout_light(class3)
|
| 123 |
-
|
| 124 |
-
class4 = F.leaky_relu(self.class_fc4(class3), negative_slope=0.01, inplace=True)
|
| 125 |
-
# No dropout before the final layer typically
|
| 126 |
-
|
| 127 |
-
classification = self.class_fc5(class4) # Raw logits
|
| 128 |
-
|
| 129 |
-
return classification
|
| 130 |
-
|
| 131 |
-
class PatchClassificationDataset(Dataset):
|
| 132 |
-
"""
|
| 133 |
-
Dataset class for loading saved patches for PointNet classification training.
|
| 134 |
-
"""
|
| 135 |
-
|
| 136 |
-
def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = False, input_dim: int = 10): # Added input_dim
|
| 137 |
-
self.dataset_dir = dataset_dir
|
| 138 |
-
self.max_points = max_points
|
| 139 |
-
self.augment = augment
|
| 140 |
-
self.input_dim = input_dim # Store input_dim
|
| 141 |
-
|
| 142 |
-
# Load patch files
|
| 143 |
-
self.patch_files = []
|
| 144 |
-
for file in os.listdir(dataset_dir):
|
| 145 |
-
if file.endswith('.pkl'):
|
| 146 |
-
self.patch_files.append(os.path.join(dataset_dir, file))
|
| 147 |
-
|
| 148 |
-
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
|
| 149 |
-
|
| 150 |
-
def __len__(self):
|
| 151 |
-
return len(self.patch_files)
|
| 152 |
-
|
| 153 |
-
def __getitem__(self, idx):
|
| 154 |
-
"""
|
| 155 |
-
Load and process a patch for training.
|
| 156 |
-
Returns:
|
| 157 |
-
patch_data: (input_dim, max_points) tensor of point cloud data
|
| 158 |
-
label: scalar tensor for binary classification (0 or 1)
|
| 159 |
-
valid_mask: (max_points,) boolean tensor indicating valid points
|
| 160 |
-
"""
|
| 161 |
-
patch_file = self.patch_files[idx]
|
| 162 |
-
|
| 163 |
-
with open(patch_file, 'rb') as f:
|
| 164 |
-
patch_info = pickle.load(f)
|
| 165 |
-
|
| 166 |
-
# Assuming the key in patch_info is now 'patch_10d' or similar, or that patch_info['patch_data'] is (N, 10)
|
| 167 |
-
# For this example, let's assume the key is 'patch_data' and it holds the 10D data.
|
| 168 |
-
# If your key is 'patch_10d', change 'patch_data' to 'patch_10d' below.
|
| 169 |
-
patch_data_nd = patch_info.get('patch_data', patch_info.get('patch_10d', patch_info.get('patch_6d'))) # Try to get 10d, fallback to 6d for now
|
| 170 |
-
if patch_data_nd.shape[1] != self.input_dim:
|
| 171 |
-
# This is a fallback or error handling if the loaded data isn't 10D.
|
| 172 |
-
# You might want to raise an error or handle this case specifically.
|
| 173 |
-
# For now, if it's 6D, we'll pad it to 10D with zeros as a placeholder.
|
| 174 |
-
# This part needs to be adjusted based on how your 10D data is actually stored.
|
| 175 |
-
print(f"Warning: Patch {patch_file} has {patch_data_nd.shape[1]} dimensions, expected {self.input_dim}. Padding with zeros if necessary.")
|
| 176 |
-
if patch_data_nd.shape[1] < self.input_dim:
|
| 177 |
-
padding = np.zeros((patch_data_nd.shape[0], self.input_dim - patch_data_nd.shape[1]))
|
| 178 |
-
patch_data_nd = np.concatenate((patch_data_nd, padding), axis=1)
|
| 179 |
-
elif patch_data_nd.shape[1] > self.input_dim:
|
| 180 |
-
patch_data_nd = patch_data_nd[:, :self.input_dim]
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
|
| 184 |
-
|
| 185 |
-
# Pad or sample points to max_points
|
| 186 |
-
num_points = patch_data_nd.shape[0]
|
| 187 |
-
|
| 188 |
-
if num_points >= self.max_points:
|
| 189 |
-
# Randomly sample max_points
|
| 190 |
-
indices = np.random.choice(num_points, self.max_points, replace=False)
|
| 191 |
-
patch_sampled = patch_data_nd[indices]
|
| 192 |
-
valid_mask = np.ones(self.max_points, dtype=bool)
|
| 193 |
-
else:
|
| 194 |
-
# Pad with zeros
|
| 195 |
-
patch_sampled = np.zeros((self.max_points, self.input_dim)) # Changed to self.input_dim
|
| 196 |
-
patch_sampled[:num_points] = patch_data_nd
|
| 197 |
-
valid_mask = np.zeros(self.max_points, dtype=bool)
|
| 198 |
-
valid_mask[:num_points] = True
|
| 199 |
-
|
| 200 |
-
# Data augmentation
|
| 201 |
-
if self.augment:
|
| 202 |
-
# Note: _augment_patch currently only augments xyz (first 3 dims).
|
| 203 |
-
# If other dimensions are geometric and need augmentation, this function needs an update.
|
| 204 |
-
patch_sampled = self._augment_patch(patch_sampled, valid_mask)
|
| 205 |
-
|
| 206 |
-
# Convert to tensors and transpose for conv1d (channels first)
|
| 207 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (input_dim, max_points)
|
| 208 |
-
label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
|
| 209 |
-
valid_mask_tensor = torch.from_numpy(valid_mask)
|
| 210 |
-
|
| 211 |
-
return patch_tensor, label_tensor, valid_mask_tensor
|
| 212 |
-
|
| 213 |
-
def _augment_patch(self, patch, valid_mask):
|
| 214 |
-
"""
|
| 215 |
-
Apply data augmentation to the patch.
|
| 216 |
-
Note: This implementation only augments the first 3 dimensions (assumed to be XYZ).
|
| 217 |
-
If your 10D representation has other geometric features that need augmentation,
|
| 218 |
-
this function should be updated accordingly.
|
| 219 |
-
"""
|
| 220 |
-
valid_points_data = patch[valid_mask]
|
| 221 |
-
|
| 222 |
-
if len(valid_points_data) == 0:
|
| 223 |
-
return patch
|
| 224 |
-
|
| 225 |
-
# Extract XYZ for augmentation (first 3 columns)
|
| 226 |
-
valid_points_xyz = valid_points_data[:, :3].copy() # Operate on a copy
|
| 227 |
-
|
| 228 |
-
# Random rotation around z-axis
|
| 229 |
-
angle = np.random.uniform(0, 2 * np.pi)
|
| 230 |
-
cos_angle = np.cos(angle)
|
| 231 |
-
sin_angle = np.sin(angle)
|
| 232 |
-
rotation_matrix = np.array([
|
| 233 |
-
[cos_angle, -sin_angle, 0],
|
| 234 |
-
[sin_angle, cos_angle, 0],
|
| 235 |
-
[0, 0, 1]
|
| 236 |
-
])
|
| 237 |
-
|
| 238 |
-
# Apply rotation to xyz coordinates
|
| 239 |
-
valid_points_xyz = valid_points_xyz @ rotation_matrix.T
|
| 240 |
-
|
| 241 |
-
# Random jittering
|
| 242 |
-
noise = np.random.normal(0, 0.01, valid_points_xyz.shape)
|
| 243 |
-
valid_points_xyz += noise
|
| 244 |
-
|
| 245 |
-
# Random scaling
|
| 246 |
-
scale = np.random.uniform(0.9, 1.1)
|
| 247 |
-
valid_points_xyz *= scale
|
| 248 |
-
|
| 249 |
-
# Update the original patch data
|
| 250 |
-
augmented_patch = patch.copy()
|
| 251 |
-
augmented_patch[valid_mask, :3] = valid_points_xyz
|
| 252 |
-
|
| 253 |
-
return augmented_patch
|
| 254 |
-
|
| 255 |
-
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
| 256 |
-
"""
|
| 257 |
-
Save patches from prediction pipeline to create a training dataset.
|
| 258 |
-
Ensure 'patch_data' (or 'patch_10d') in the patch dictionary contains the 10D data.
|
| 259 |
-
|
| 260 |
-
Args:
|
| 261 |
-
patches: List of patch dictionaries from generate_patches()
|
| 262 |
-
dataset_dir: Directory to save the dataset
|
| 263 |
-
entry_id: Unique identifier for this entry/image
|
| 264 |
-
"""
|
| 265 |
-
os.makedirs(dataset_dir, exist_ok=True)
|
| 266 |
-
|
| 267 |
-
for i, patch in enumerate(patches):
|
| 268 |
-
# Create unique filename
|
| 269 |
-
filename = f"{entry_id}_patch_{i}.pkl"
|
| 270 |
-
filepath = os.path.join(dataset_dir, filename)
|
| 271 |
-
|
| 272 |
-
# Skip if file already exists
|
| 273 |
-
if os.path.exists(filepath):
|
| 274 |
-
continue
|
| 275 |
-
|
| 276 |
-
# Ensure the patch data being saved is 10D.
|
| 277 |
-
# Example: patch_data_key = 'patch_10d' or 'patch_data'
|
| 278 |
-
# if 'patch_data' not in patch or patch['patch_data'].shape[1] != 10:
|
| 279 |
-
# print(f"Warning: Patch {i} for entry {entry_id} does not seem to be 10D. Skipping or error handling needed.")
|
| 280 |
-
# continue
|
| 281 |
-
|
| 282 |
-
with open(filepath, 'wb') as f:
|
| 283 |
-
pickle.dump(patch, f)
|
| 284 |
-
|
| 285 |
-
print(f"Saved {len(patches)} patches for entry {entry_id}")
|
| 286 |
-
|
| 287 |
-
# Create dataloader with custom collate function to filter invalid samples
|
| 288 |
-
def collate_fn(batch):
|
| 289 |
-
valid_batch = []
|
| 290 |
-
for patch_data, label, valid_mask in batch:
|
| 291 |
-
# Filter out invalid samples (no valid points)
|
| 292 |
-
if valid_mask.sum() > 0:
|
| 293 |
-
valid_batch.append((patch_data, label, valid_mask))
|
| 294 |
-
|
| 295 |
-
if len(valid_batch) == 0:
|
| 296 |
-
return None
|
| 297 |
-
|
| 298 |
-
# Stack valid samples
|
| 299 |
-
patch_data = torch.stack([item[0] for item in valid_batch])
|
| 300 |
-
labels = torch.stack([item[1] for item in valid_batch])
|
| 301 |
-
valid_masks = torch.stack([item[2] for item in valid_batch])
|
| 302 |
-
|
| 303 |
-
return patch_data, labels, valid_masks
|
| 304 |
-
|
| 305 |
-
# Initialize weights using Xavier/Glorot initialization
|
| 306 |
-
def init_weights(m):
|
| 307 |
-
if isinstance(m, nn.Conv1d):
|
| 308 |
-
nn.init.xavier_uniform_(m.weight)
|
| 309 |
-
if m.bias is not None:
|
| 310 |
-
nn.init.zeros_(m.bias)
|
| 311 |
-
elif isinstance(m, nn.Linear):
|
| 312 |
-
nn.init.xavier_uniform_(m.weight)
|
| 313 |
-
if m.bias is not None:
|
| 314 |
-
nn.init.zeros_(m.bias)
|
| 315 |
-
elif isinstance(m, nn.BatchNorm1d):
|
| 316 |
-
nn.init.ones_(m.weight)
|
| 317 |
-
nn.init.zeros_(m.bias)
|
| 318 |
-
|
| 319 |
-
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
|
| 320 |
-
lr: float = 0.001, input_dim: int = 10): # Added input_dim
|
| 321 |
-
"""
|
| 322 |
-
Train the ClassificationPointNet model on saved patches.
|
| 323 |
-
|
| 324 |
-
Args:
|
| 325 |
-
dataset_dir: Directory containing saved patch files
|
| 326 |
-
model_save_path: Path to save the trained model
|
| 327 |
-
epochs: Number of training epochs
|
| 328 |
-
batch_size: Training batch size
|
| 329 |
-
lr: Learning rate
|
| 330 |
-
input_dim: Dimensionality of the input points (e.g., 10 for 10D)
|
| 331 |
-
"""
|
| 332 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 333 |
-
print(f"Training on device: {device}")
|
| 334 |
-
|
| 335 |
-
# Create dataset and dataloader
|
| 336 |
-
dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=False, input_dim=input_dim) # Pass input_dim
|
| 337 |
-
print(f"Dataset loaded with {len(dataset)} samples")
|
| 338 |
-
|
| 339 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=20,
|
| 340 |
-
collate_fn=collate_fn, drop_last=True)
|
| 341 |
-
|
| 342 |
-
# Initialize model
|
| 343 |
-
model = ClassificationPointNet(input_dim=input_dim, max_points=1024) # Pass input_dim
|
| 344 |
-
model.apply(init_weights)
|
| 345 |
-
model.to(device)
|
| 346 |
-
|
| 347 |
-
# Loss function and optimizer (BCE for binary classification)
|
| 348 |
-
criterion = nn.BCEWithLogitsLoss()
|
| 349 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
| 350 |
-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
|
| 351 |
-
|
| 352 |
-
# Training loop
|
| 353 |
-
model.train()
|
| 354 |
-
for epoch in range(epochs):
|
| 355 |
-
total_loss = 0.0
|
| 356 |
-
correct = 0
|
| 357 |
-
total = 0
|
| 358 |
-
num_batches = 0
|
| 359 |
-
|
| 360 |
-
for batch_idx, batch_data in enumerate(dataloader):
|
| 361 |
-
if batch_data is None: # Skip invalid batches
|
| 362 |
-
continue
|
| 363 |
-
|
| 364 |
-
patch_data, labels, valid_masks = batch_data
|
| 365 |
-
patch_data = patch_data.to(device) # (batch_size, input_dim, max_points)
|
| 366 |
-
labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
|
| 367 |
-
|
| 368 |
-
# Forward pass
|
| 369 |
-
optimizer.zero_grad()
|
| 370 |
-
outputs = model(patch_data) # (batch_size, 1)
|
| 371 |
-
loss = criterion(outputs, labels)
|
| 372 |
-
|
| 373 |
-
# Backward pass
|
| 374 |
-
loss.backward()
|
| 375 |
-
optimizer.step()
|
| 376 |
-
|
| 377 |
-
# Statistics
|
| 378 |
-
total_loss += loss.item()
|
| 379 |
-
predicted = (torch.sigmoid(outputs) > 0.5).float()
|
| 380 |
-
total += labels.size(0)
|
| 381 |
-
correct += (predicted == labels).sum().item()
|
| 382 |
-
num_batches += 1
|
| 383 |
-
|
| 384 |
-
if batch_idx % 50 == 0:
|
| 385 |
-
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
|
| 386 |
-
f"Loss: {loss.item():.6f}, "
|
| 387 |
-
f"Accuracy: {100 * correct / total:.2f}%")
|
| 388 |
-
|
| 389 |
-
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
| 390 |
-
accuracy = 100 * correct / total if total > 0 else 0
|
| 391 |
-
|
| 392 |
-
print(f"Epoch {epoch+1}/{epochs} completed, "
|
| 393 |
-
f"Avg Loss: {avg_loss:.6f}, "
|
| 394 |
-
f"Accuracy: {accuracy:.2f}%")
|
| 395 |
-
|
| 396 |
-
scheduler.step()
|
| 397 |
-
|
| 398 |
-
# Save model checkpoint every epoch
|
| 399 |
-
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 400 |
-
torch.save({
|
| 401 |
-
'model_state_dict': model.state_dict(),
|
| 402 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 403 |
-
'epoch': epoch + 1,
|
| 404 |
-
'loss': avg_loss,
|
| 405 |
-
'accuracy': accuracy,
|
| 406 |
-
'input_dim': input_dim, # Save input_dim with checkpoint
|
| 407 |
-
}, checkpoint_path)
|
| 408 |
-
|
| 409 |
-
# Save the trained model
|
| 410 |
-
torch.save({
|
| 411 |
-
'model_state_dict': model.state_dict(),
|
| 412 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 413 |
-
'epoch': epochs,
|
| 414 |
-
'input_dim': input_dim, # Save input_dim with final model
|
| 415 |
-
}, model_save_path)
|
| 416 |
-
|
| 417 |
-
print(f"Model saved to {model_save_path}")
|
| 418 |
-
return model
|
| 419 |
-
|
| 420 |
-
def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
|
| 421 |
-
"""
|
| 422 |
-
Load a trained ClassificationPointNet model.
|
| 423 |
-
|
| 424 |
-
Args:
|
| 425 |
-
model_path: Path to the saved model
|
| 426 |
-
device: Device to load the model on
|
| 427 |
-
|
| 428 |
-
Returns:
|
| 429 |
-
Loaded ClassificationPointNet model
|
| 430 |
-
"""
|
| 431 |
-
if device is None:
|
| 432 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 433 |
-
|
| 434 |
-
checkpoint = torch.load(model_path, map_location=device)
|
| 435 |
-
|
| 436 |
-
# Load input_dim from checkpoint if available, otherwise default to 10
|
| 437 |
-
# For older models saved without input_dim, you might need to specify it or assume a default.
|
| 438 |
-
input_dim = checkpoint.get('input_dim', 10)
|
| 439 |
-
|
| 440 |
-
model = ClassificationPointNet(input_dim=input_dim, max_points=1024) # Use loaded or default input_dim
|
| 441 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 442 |
-
|
| 443 |
-
model.to(device)
|
| 444 |
-
model.eval()
|
| 445 |
-
|
| 446 |
-
return model
|
| 447 |
-
|
| 448 |
-
def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
|
| 449 |
-
"""
|
| 450 |
-
Predict binary classification from a patch using trained PointNet.
|
| 451 |
-
Assumes the model's input_dim matches the data.
|
| 452 |
-
|
| 453 |
-
Args:
|
| 454 |
-
model: Trained ClassificationPointNet model
|
| 455 |
-
patch: Dictionary containing patch data. Expects a key like 'patch_data' or 'patch_10d' with (N, 10) shape.
|
| 456 |
-
device: Device to run prediction on
|
| 457 |
-
|
| 458 |
-
Returns:
|
| 459 |
-
tuple of (predicted_class, confidence)
|
| 460 |
-
predicted_class: int (0 for not edge, 1 for edge)
|
| 461 |
-
confidence: float representing confidence score (0-1)
|
| 462 |
-
"""
|
| 463 |
-
if device is None:
|
| 464 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 465 |
-
|
| 466 |
-
# Determine input_dim from the model
|
| 467 |
-
input_dim = model.conv1.in_channels
|
| 468 |
-
|
| 469 |
-
# Assuming the key in patch_info is now 'patch_10d' or similar, or that patch_info['patch_data'] is (N, 10)
|
| 470 |
-
# For this example, let's assume the key is 'patch_data' and it holds the 10D data.
|
| 471 |
-
# If your key is 'patch_10d', change 'patch_data' to 'patch_10d' below.
|
| 472 |
-
patch_data_nd = patch.get('patch_data', patch.get('patch_10d', patch.get('patch_6d'))) # Try to get 10d, fallback to 6d
|
| 473 |
-
|
| 474 |
-
if patch_data_nd.shape[1] != input_dim:
|
| 475 |
-
# Handle dimension mismatch, e.g., by padding or raising an error
|
| 476 |
-
print(f"Warning: Input patch has {patch_data_nd.shape[1]} dimensions, but model expects {input_dim}. Adjusting...")
|
| 477 |
-
if patch_data_nd.shape[1] < input_dim:
|
| 478 |
-
padding = np.zeros((patch_data_nd.shape[0], input_dim - patch_data_nd.shape[1]))
|
| 479 |
-
patch_data_nd = np.concatenate((patch_data_nd, padding), axis=1)
|
| 480 |
-
elif patch_data_nd.shape[1] > input_dim:
|
| 481 |
-
patch_data_nd = patch_data_nd[:, :input_dim]
|
| 482 |
-
|
| 483 |
-
# Prepare input
|
| 484 |
-
max_points = model.max_points # Use max_points from the model instance
|
| 485 |
-
num_points = patch_data_nd.shape[0]
|
| 486 |
-
|
| 487 |
-
if num_points >= max_points:
|
| 488 |
-
# Sample points
|
| 489 |
-
indices = np.random.choice(num_points, max_points, replace=False)
|
| 490 |
-
patch_sampled = patch_data_nd[indices]
|
| 491 |
-
else:
|
| 492 |
-
# Pad with zeros
|
| 493 |
-
patch_sampled = np.zeros((max_points, input_dim)) # Use model's input_dim
|
| 494 |
-
patch_sampled[:num_points] = patch_data_nd
|
| 495 |
-
|
| 496 |
-
# Convert to tensor
|
| 497 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, input_dim, max_points)
|
| 498 |
-
patch_tensor = patch_tensor.to(device)
|
| 499 |
-
|
| 500 |
-
# Predict
|
| 501 |
-
model.eval() # Ensure model is in eval mode
|
| 502 |
-
with torch.no_grad():
|
| 503 |
-
outputs = model(patch_tensor) # (1, 1)
|
| 504 |
-
probability = torch.sigmoid(outputs).item()
|
| 505 |
-
predicted_class = int(probability > 0.5)
|
| 506 |
-
|
| 507 |
-
return predicted_class, probability
|
| 508 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fast_pointnet_v2.py
CHANGED
|
@@ -1,3 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
|
@@ -568,4 +578,4 @@ def predict_vertex_from_patch(model: FastPointNet, patch: np.ndarray, device: to
|
|
| 568 |
offset = patch['cluster_center']
|
| 569 |
position += offset
|
| 570 |
|
| 571 |
-
return position, score, classification
|
|
|
|
| 1 |
+
# This file defines a FastPointNet model for 3D vertex prediction from point clouds.
|
| 2 |
+
# It includes:
|
| 3 |
+
# 1. `FastPointNet`: A deep neural network with enhancements like residual connections,
|
| 4 |
+
# channel attention, and multi-scale pooling. It predicts 3D coordinates,
|
| 5 |
+
# and optionally, confidence scores and classification labels.
|
| 6 |
+
# 2. `PatchDataset`: A PyTorch Dataset for loading, preprocessing, and augmenting
|
| 7 |
+
# 11-dimensional point cloud patches.
|
| 8 |
+
# 3. Utility functions for:
|
| 9 |
+
# - Training the model (`train_pointnet`) with custom loss and optimization.
|
| 10 |
+
# - Loading/saving models, and performing inference (`predict_vertex_from_patch`).
|
| 11 |
import os
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
|
|
|
| 578 |
offset = patch['cluster_center']
|
| 579 |
position += offset
|
| 580 |
|
| 581 |
+
return position, score, classification
|
fast_pointnet_v3.py
DELETED
|
@@ -1,605 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pickle
|
| 7 |
-
from torch.utils.data import Dataset, DataLoader
|
| 8 |
-
from typing import List, Dict, Tuple, Optional
|
| 9 |
-
import json
|
| 10 |
-
|
| 11 |
-
class FastPointNet(nn.Module):
|
| 12 |
-
"""
|
| 13 |
-
Fast PointNet implementation for 3D vertex prediction from point cloud patches.
|
| 14 |
-
Takes 11D point clouds and predicts 3D vertex coordinates.
|
| 15 |
-
Enhanced with transformer attention, deeper architecture, and moderate capacity increase.
|
| 16 |
-
"""
|
| 17 |
-
def __init__(self, input_dim=11, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1):
|
| 18 |
-
super(FastPointNet, self).__init__()
|
| 19 |
-
self.max_points = max_points
|
| 20 |
-
self.predict_score = predict_score
|
| 21 |
-
self.predict_class = predict_class
|
| 22 |
-
self.num_classes = num_classes
|
| 23 |
-
|
| 24 |
-
# Enhanced point-wise MLPs with moderate capacity increase
|
| 25 |
-
self.conv1 = nn.Conv1d(input_dim, 96, 1) # 64 -> 96
|
| 26 |
-
self.conv2 = nn.Conv1d(96, 192, 1) # 128 -> 192
|
| 27 |
-
self.conv3 = nn.Conv1d(192, 384, 1) # 256 -> 384
|
| 28 |
-
self.conv4 = nn.Conv1d(384, 768, 1) # 512 -> 768
|
| 29 |
-
self.conv5 = nn.Conv1d(768, 1536, 1) # 1024 -> 1536
|
| 30 |
-
self.conv6 = nn.Conv1d(1536, 1536, 1) # Keep same
|
| 31 |
-
self.conv7 = nn.Conv1d(1536, 2048, 1) # Reduce from 1536 to 2048 for transformer
|
| 32 |
-
|
| 33 |
-
# Lightweight Self-Attention Transformer Block
|
| 34 |
-
self.transformer_dim = 2048
|
| 35 |
-
self.num_heads = 8
|
| 36 |
-
self.transformer_block = nn.MultiheadAttention(
|
| 37 |
-
embed_dim=self.transformer_dim,
|
| 38 |
-
num_heads=self.num_heads,
|
| 39 |
-
dropout=0.1,
|
| 40 |
-
batch_first=False
|
| 41 |
-
)
|
| 42 |
-
self.transformer_norm1 = nn.LayerNorm(self.transformer_dim)
|
| 43 |
-
self.transformer_norm2 = nn.LayerNorm(self.transformer_dim)
|
| 44 |
-
|
| 45 |
-
# Transformer FFN
|
| 46 |
-
self.transformer_ffn = nn.Sequential(
|
| 47 |
-
nn.Linear(self.transformer_dim, self.transformer_dim * 2),
|
| 48 |
-
nn.GELU(),
|
| 49 |
-
nn.Dropout(0.1),
|
| 50 |
-
nn.Linear(self.transformer_dim * 2, self.transformer_dim),
|
| 51 |
-
nn.Dropout(0.1)
|
| 52 |
-
)
|
| 53 |
-
|
| 54 |
-
# Enhanced channel attention mechanism
|
| 55 |
-
self.channel_attention = nn.Sequential(
|
| 56 |
-
nn.AdaptiveAvgPool1d(1),
|
| 57 |
-
nn.Conv1d(2048, 192, 1), # 128 -> 192
|
| 58 |
-
nn.GELU(),
|
| 59 |
-
nn.Conv1d(192, 2048, 1),
|
| 60 |
-
nn.Sigmoid()
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
# Enhanced shared features with moderate increase
|
| 64 |
-
self.shared_fc1 = nn.Linear(2048, 1536) # 1024 -> 1536
|
| 65 |
-
self.shared_fc2 = nn.Linear(1536, 768) # 512 -> 768
|
| 66 |
-
self.shared_fc3 = nn.Linear(768, 768) # Additional layer
|
| 67 |
-
|
| 68 |
-
# Enhanced position prediction head
|
| 69 |
-
self.pos_fc1 = nn.Linear(768, 768) # 512 -> 768
|
| 70 |
-
self.pos_fc2 = nn.Linear(768, 384) # 256 -> 384
|
| 71 |
-
self.pos_fc3 = nn.Linear(384, 192) # 128 -> 192
|
| 72 |
-
self.pos_fc4 = nn.Linear(192, 96) # 64 -> 96
|
| 73 |
-
self.pos_fc5 = nn.Linear(96, output_dim)
|
| 74 |
-
|
| 75 |
-
# Enhanced score prediction head
|
| 76 |
-
if self.predict_score:
|
| 77 |
-
self.score_fc1 = nn.Linear(768, 768)
|
| 78 |
-
self.score_fc2 = nn.Linear(768, 384)
|
| 79 |
-
self.score_fc3 = nn.Linear(384, 192)
|
| 80 |
-
self.score_fc4 = nn.Linear(192, 96)
|
| 81 |
-
self.score_fc5 = nn.Linear(96, 1)
|
| 82 |
-
|
| 83 |
-
# Enhanced classification head
|
| 84 |
-
if self.predict_class:
|
| 85 |
-
self.class_fc1 = nn.Linear(768, 768)
|
| 86 |
-
self.class_fc2 = nn.Linear(768, 384)
|
| 87 |
-
self.class_fc3 = nn.Linear(384, 192)
|
| 88 |
-
self.class_fc4 = nn.Linear(192, 96)
|
| 89 |
-
self.class_fc5 = nn.Linear(96, num_classes)
|
| 90 |
-
|
| 91 |
-
# Batch normalization layers
|
| 92 |
-
self.bn1 = nn.BatchNorm1d(96, momentum=0.1)
|
| 93 |
-
self.bn2 = nn.BatchNorm1d(192, momentum=0.1)
|
| 94 |
-
self.bn3 = nn.BatchNorm1d(384, momentum=0.1)
|
| 95 |
-
self.bn4 = nn.BatchNorm1d(768, momentum=0.1)
|
| 96 |
-
self.bn5 = nn.BatchNorm1d(1536, momentum=0.1)
|
| 97 |
-
self.bn6 = nn.BatchNorm1d(1536, momentum=0.1)
|
| 98 |
-
self.bn7 = nn.BatchNorm1d(2048, momentum=0.1)
|
| 99 |
-
|
| 100 |
-
# Group normalization for shared layers
|
| 101 |
-
self.gn1 = nn.GroupNorm(48, 1536) # 32 -> 48 groups
|
| 102 |
-
self.gn2 = nn.GroupNorm(24, 768) # 16 -> 24 groups
|
| 103 |
-
|
| 104 |
-
# Dropout with different rates
|
| 105 |
-
self.dropout_light = nn.Dropout(0.1)
|
| 106 |
-
self.dropout_medium = nn.Dropout(0.2)
|
| 107 |
-
self.dropout_heavy = nn.Dropout(0.3)
|
| 108 |
-
|
| 109 |
-
def forward(self, x):
|
| 110 |
-
"""
|
| 111 |
-
Forward pass with transformer attention and residual connections
|
| 112 |
-
Args:
|
| 113 |
-
x: (batch_size, input_dim, max_points) tensor
|
| 114 |
-
Returns:
|
| 115 |
-
Tuple containing predictions based on configuration
|
| 116 |
-
"""
|
| 117 |
-
batch_size = x.size(0)
|
| 118 |
-
|
| 119 |
-
# Enhanced point-wise feature extraction
|
| 120 |
-
x1 = F.gelu(self.bn1(self.conv1(x)))
|
| 121 |
-
x2 = F.gelu(self.bn2(self.conv2(x1)))
|
| 122 |
-
x3 = F.gelu(self.bn3(self.conv3(x2)))
|
| 123 |
-
x4 = F.gelu(self.bn4(self.conv4(x3)))
|
| 124 |
-
x5 = F.gelu(self.bn5(self.conv5(x4)))
|
| 125 |
-
|
| 126 |
-
# Residual connection
|
| 127 |
-
x6 = F.gelu(self.bn6(self.conv6(x5)) + x5)
|
| 128 |
-
x7 = F.gelu(self.bn7(self.conv7(x6)))
|
| 129 |
-
|
| 130 |
-
# Apply transformer attention
|
| 131 |
-
# Reshape for transformer: (seq_len, batch_size, embed_dim)
|
| 132 |
-
x7_reshaped = x7.permute(2, 0, 1) # (max_points, batch_size, 2048)
|
| 133 |
-
|
| 134 |
-
# Self-attention with residual connection
|
| 135 |
-
attn_out, _ = self.transformer_block(x7_reshaped, x7_reshaped, x7_reshaped)
|
| 136 |
-
x7_attn = self.transformer_norm1(x7_reshaped + attn_out)
|
| 137 |
-
|
| 138 |
-
# Transformer FFN with residual connection
|
| 139 |
-
ffn_out = self.transformer_ffn(x7_attn)
|
| 140 |
-
x7_transformer = self.transformer_norm2(x7_attn + ffn_out)
|
| 141 |
-
|
| 142 |
-
# Reshape back: (batch_size, embed_dim, seq_len)
|
| 143 |
-
x7_transformer = x7_transformer.permute(1, 2, 0)
|
| 144 |
-
|
| 145 |
-
# Apply channel attention
|
| 146 |
-
attention_weights = self.channel_attention(x7_transformer)
|
| 147 |
-
x7_attended = x7_transformer * attention_weights
|
| 148 |
-
|
| 149 |
-
# Multi-scale global pooling
|
| 150 |
-
max_pool = torch.max(x7_attended, 2)[0] # (batch_size, 2048)
|
| 151 |
-
avg_pool = torch.mean(x7_attended, 2) # (batch_size, 2048)
|
| 152 |
-
|
| 153 |
-
# Weighted combination of pooling operations
|
| 154 |
-
global_features = 0.7 * max_pool + 0.3 * avg_pool
|
| 155 |
-
|
| 156 |
-
# Enhanced shared features with residual connections
|
| 157 |
-
shared1 = F.gelu(self.gn1(self.shared_fc1(global_features).unsqueeze(-1)).squeeze(-1))
|
| 158 |
-
shared1 = self.dropout_light(shared1)
|
| 159 |
-
|
| 160 |
-
shared2 = F.gelu(self.gn2(self.shared_fc2(shared1).unsqueeze(-1)).squeeze(-1))
|
| 161 |
-
shared2 = self.dropout_medium(shared2)
|
| 162 |
-
|
| 163 |
-
# Additional shared layer with residual connection
|
| 164 |
-
shared3 = F.gelu(self.shared_fc3(shared2))
|
| 165 |
-
shared_features = self.dropout_light(shared3) + shared2
|
| 166 |
-
|
| 167 |
-
# Enhanced position prediction
|
| 168 |
-
pos1 = F.gelu(self.pos_fc1(shared_features))
|
| 169 |
-
pos1 = self.dropout_light(pos1)
|
| 170 |
-
|
| 171 |
-
pos2 = F.gelu(self.pos_fc2(pos1))
|
| 172 |
-
pos2 = self.dropout_medium(pos2)
|
| 173 |
-
|
| 174 |
-
pos3 = F.gelu(self.pos_fc3(pos2))
|
| 175 |
-
pos3 = self.dropout_light(pos3)
|
| 176 |
-
|
| 177 |
-
pos4 = F.gelu(self.pos_fc4(pos3))
|
| 178 |
-
position = self.pos_fc5(pos4)
|
| 179 |
-
|
| 180 |
-
outputs = [position]
|
| 181 |
-
|
| 182 |
-
if self.predict_score:
|
| 183 |
-
# Enhanced score prediction
|
| 184 |
-
score1 = F.gelu(self.score_fc1(shared_features))
|
| 185 |
-
score1 = self.dropout_light(score1)
|
| 186 |
-
score2 = F.gelu(self.score_fc2(score1))
|
| 187 |
-
score2 = self.dropout_medium(score2)
|
| 188 |
-
score3 = F.gelu(self.score_fc3(score2))
|
| 189 |
-
score3 = self.dropout_light(score3)
|
| 190 |
-
score4 = F.gelu(self.score_fc4(score3))
|
| 191 |
-
score = F.softplus(self.score_fc5(score4))
|
| 192 |
-
outputs.append(score)
|
| 193 |
-
|
| 194 |
-
if self.predict_class:
|
| 195 |
-
# Classification prediction
|
| 196 |
-
class1 = F.gelu(self.class_fc1(shared_features))
|
| 197 |
-
class1 = self.dropout_light(class1)
|
| 198 |
-
class2 = F.gelu(self.class_fc2(class1))
|
| 199 |
-
class2 = self.dropout_medium(class2)
|
| 200 |
-
class3 = F.gelu(self.class_fc3(class2))
|
| 201 |
-
class3 = self.dropout_light(class3)
|
| 202 |
-
class4 = F.gelu(self.class_fc4(class3))
|
| 203 |
-
classification = self.class_fc5(class4)
|
| 204 |
-
outputs.append(classification)
|
| 205 |
-
|
| 206 |
-
# Return outputs based on configuration
|
| 207 |
-
if len(outputs) == 1:
|
| 208 |
-
return outputs[0]
|
| 209 |
-
elif len(outputs) == 2:
|
| 210 |
-
if self.predict_score:
|
| 211 |
-
return outputs[0], outputs[1]
|
| 212 |
-
else:
|
| 213 |
-
return outputs[0], outputs[1]
|
| 214 |
-
else:
|
| 215 |
-
return outputs[0], outputs[1], outputs[2]
|
| 216 |
-
|
| 217 |
-
class PatchDataset(Dataset):
|
| 218 |
-
"""
|
| 219 |
-
Dataset class for loading saved patches for PointNet training.
|
| 220 |
-
Updated for 11D patches.
|
| 221 |
-
"""
|
| 222 |
-
|
| 223 |
-
def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
|
| 224 |
-
self.dataset_dir = dataset_dir
|
| 225 |
-
self.max_points = max_points
|
| 226 |
-
self.augment = augment
|
| 227 |
-
|
| 228 |
-
# Load patch files
|
| 229 |
-
self.patch_files = []
|
| 230 |
-
for file in os.listdir(dataset_dir):
|
| 231 |
-
if file.endswith('.pkl'):
|
| 232 |
-
self.patch_files.append(os.path.join(dataset_dir, file))
|
| 233 |
-
|
| 234 |
-
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
|
| 235 |
-
|
| 236 |
-
def __len__(self):
|
| 237 |
-
return len(self.patch_files)
|
| 238 |
-
|
| 239 |
-
def __getitem__(self, idx):
|
| 240 |
-
"""
|
| 241 |
-
Load and process a patch for training.
|
| 242 |
-
Returns:
|
| 243 |
-
patch_data: (11, max_points) tensor of point cloud data
|
| 244 |
-
target: (3,) tensor of target 3D coordinates
|
| 245 |
-
valid_mask: (max_points,) boolean tensor indicating valid points
|
| 246 |
-
distance_to_gt: scalar tensor of distance from initial prediction to GT
|
| 247 |
-
classification: scalar tensor for binary classification (1 if GT vertex present, 0 if not)
|
| 248 |
-
"""
|
| 249 |
-
patch_file = self.patch_files[idx]
|
| 250 |
-
|
| 251 |
-
with open(patch_file, 'rb') as f:
|
| 252 |
-
patch_info = pickle.load(f)
|
| 253 |
-
|
| 254 |
-
patch_11d = patch_info['patch_11d'] # (N, 11) - Updated for 11D
|
| 255 |
-
target = patch_info.get('assigned_wf_vertex', None) # (3,) or None
|
| 256 |
-
initial_pred = patch_info.get('cluster_center', None) # (3,) or None
|
| 257 |
-
|
| 258 |
-
# Determine classification label based on GT vertex presence
|
| 259 |
-
has_gt_vertex = 1.0 if target is not None else 0.0
|
| 260 |
-
|
| 261 |
-
# Handle patches without ground truth
|
| 262 |
-
if target is None:
|
| 263 |
-
# Use a dummy target for consistency, but mark as invalid with classification
|
| 264 |
-
target = np.zeros(3)
|
| 265 |
-
else:
|
| 266 |
-
target = np.array(target)
|
| 267 |
-
|
| 268 |
-
# Pad or sample points to max_points
|
| 269 |
-
num_points = patch_11d.shape[0]
|
| 270 |
-
|
| 271 |
-
if num_points >= self.max_points:
|
| 272 |
-
# Randomly sample max_points
|
| 273 |
-
indices = np.random.choice(num_points, self.max_points, replace=False)
|
| 274 |
-
patch_sampled = patch_11d[indices]
|
| 275 |
-
valid_mask = np.ones(self.max_points, dtype=bool)
|
| 276 |
-
else:
|
| 277 |
-
# Pad with zeros
|
| 278 |
-
patch_sampled = np.zeros((self.max_points, 11)) # Updated for 11D
|
| 279 |
-
patch_sampled[:num_points] = patch_11d
|
| 280 |
-
valid_mask = np.zeros(self.max_points, dtype=bool)
|
| 281 |
-
valid_mask[:num_points] = True
|
| 282 |
-
|
| 283 |
-
# Data augmentation (only if GT vertex is present)
|
| 284 |
-
if self.augment and has_gt_vertex > 0:
|
| 285 |
-
patch_sampled, target = self._augment_patch(patch_sampled, valid_mask, target)
|
| 286 |
-
|
| 287 |
-
# Convert to tensors and transpose for conv1d (channels first)
|
| 288 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (11, max_points)
|
| 289 |
-
target_tensor = torch.from_numpy(target).float() # (3,)
|
| 290 |
-
valid_mask_tensor = torch.from_numpy(valid_mask)
|
| 291 |
-
|
| 292 |
-
# Handle initial_pred
|
| 293 |
-
if initial_pred is not None:
|
| 294 |
-
initial_pred_tensor = torch.from_numpy(initial_pred).float()
|
| 295 |
-
else:
|
| 296 |
-
initial_pred_tensor = torch.zeros(3).float()
|
| 297 |
-
|
| 298 |
-
# Classification tensor
|
| 299 |
-
classification_tensor = torch.tensor(has_gt_vertex).float()
|
| 300 |
-
|
| 301 |
-
return patch_tensor, target_tensor, valid_mask_tensor, initial_pred_tensor, classification_tensor
|
| 302 |
-
|
| 303 |
-
def _augment_patch(self, patch_sampled, valid_mask, target):
|
| 304 |
-
"""
|
| 305 |
-
Apply data augmentation to patch and target.
|
| 306 |
-
Only augment valid points and update target accordingly.
|
| 307 |
-
"""
|
| 308 |
-
valid_points = patch_sampled[valid_mask]
|
| 309 |
-
|
| 310 |
-
if len(valid_points) > 0:
|
| 311 |
-
# Random rotation around Z-axis (small angle)
|
| 312 |
-
angle = np.random.uniform(-np.pi/12, np.pi/12) # ±15 degrees
|
| 313 |
-
cos_a, sin_a = np.cos(angle), np.sin(angle)
|
| 314 |
-
rotation_matrix = np.array([[cos_a, -sin_a, 0],
|
| 315 |
-
[sin_a, cos_a, 0],
|
| 316 |
-
[0, 0, 1]])
|
| 317 |
-
|
| 318 |
-
# Apply rotation to xyz coordinates
|
| 319 |
-
valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
|
| 320 |
-
target = target @ rotation_matrix.T
|
| 321 |
-
|
| 322 |
-
# Small random translation
|
| 323 |
-
translation = np.random.uniform(-0.05, 0.05, 3)
|
| 324 |
-
valid_points[:, :3] += translation
|
| 325 |
-
target += translation
|
| 326 |
-
|
| 327 |
-
# Random scaling (small)
|
| 328 |
-
scale = np.random.uniform(0.95, 1.05)
|
| 329 |
-
valid_points[:, :3] *= scale
|
| 330 |
-
target *= scale
|
| 331 |
-
|
| 332 |
-
# Add small noise to features (not coordinates)
|
| 333 |
-
if valid_points.shape[1] > 3:
|
| 334 |
-
noise = np.random.normal(0, 0.01, valid_points[:, 3:].shape)
|
| 335 |
-
valid_points[:, 3:] += noise
|
| 336 |
-
|
| 337 |
-
# Update patch with augmented valid points
|
| 338 |
-
patch_sampled[valid_mask] = valid_points
|
| 339 |
-
|
| 340 |
-
return patch_sampled, target
|
| 341 |
-
|
| 342 |
-
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
| 343 |
-
"""
|
| 344 |
-
Save patches from prediction pipeline to create a training dataset.
|
| 345 |
-
|
| 346 |
-
Args:
|
| 347 |
-
patches: List of patch dictionaries from generate_patches()
|
| 348 |
-
dataset_dir: Directory to save the dataset
|
| 349 |
-
entry_id: Unique identifier for this entry/image
|
| 350 |
-
"""
|
| 351 |
-
os.makedirs(dataset_dir, exist_ok=True)
|
| 352 |
-
|
| 353 |
-
for i, patch in enumerate(patches):
|
| 354 |
-
# Create unique filename
|
| 355 |
-
filename = f"{entry_id}_patch_{i}.pkl"
|
| 356 |
-
filepath = os.path.join(dataset_dir, filename)
|
| 357 |
-
|
| 358 |
-
# Skip if file already exists
|
| 359 |
-
if os.path.exists(filepath):
|
| 360 |
-
continue
|
| 361 |
-
|
| 362 |
-
# Save patch data
|
| 363 |
-
with open(filepath, 'wb') as f:
|
| 364 |
-
pickle.dump(patch, f)
|
| 365 |
-
|
| 366 |
-
print(f"Saved {len(patches)} patches for entry {entry_id}")
|
| 367 |
-
|
| 368 |
-
# Create dataloader with custom collate function to filter invalid samples
|
| 369 |
-
def collate_fn(batch):
|
| 370 |
-
valid_batch = []
|
| 371 |
-
for patch_data, target, valid_mask, initial_pred, classification in batch:
|
| 372 |
-
# Filter out invalid samples (no valid points)
|
| 373 |
-
if valid_mask.sum() > 0:
|
| 374 |
-
valid_batch.append((patch_data, target, valid_mask, initial_pred, classification))
|
| 375 |
-
|
| 376 |
-
if len(valid_batch) == 0:
|
| 377 |
-
return None
|
| 378 |
-
|
| 379 |
-
# Stack valid samples
|
| 380 |
-
patch_data = torch.stack([item[0] for item in valid_batch])
|
| 381 |
-
targets = torch.stack([item[1] for item in valid_batch])
|
| 382 |
-
valid_masks = torch.stack([item[2] for item in valid_batch])
|
| 383 |
-
initial_preds = torch.stack([item[3] for item in valid_batch])
|
| 384 |
-
classifications = torch.stack([item[4] for item in valid_batch])
|
| 385 |
-
|
| 386 |
-
return patch_data, targets, valid_masks, initial_preds, classifications
|
| 387 |
-
|
| 388 |
-
# Initialize weights using Kaiming initialization for LeakyReLU
|
| 389 |
-
def init_weights(m):
|
| 390 |
-
if isinstance(m, nn.Conv1d):
|
| 391 |
-
nn.init.kaiming_uniform_(m.weight, a=0.01, mode='fan_in', nonlinearity='leaky_relu')
|
| 392 |
-
if m.bias is not None:
|
| 393 |
-
nn.init.zeros_(m.bias)
|
| 394 |
-
elif isinstance(m, nn.Linear):
|
| 395 |
-
nn.init.kaiming_uniform_(m.weight, a=0.01, mode='fan_in', nonlinearity='leaky_relu')
|
| 396 |
-
if m.bias is not None:
|
| 397 |
-
nn.init.zeros_(m.bias)
|
| 398 |
-
elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
|
| 399 |
-
nn.init.ones_(m.weight)
|
| 400 |
-
nn.init.zeros_(m.bias)
|
| 401 |
-
|
| 402 |
-
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
|
| 403 |
-
score_weight: float = 0.1, class_weight: float = 0.5):
|
| 404 |
-
"""
|
| 405 |
-
Train the FastPointNet model on saved patches.
|
| 406 |
-
Updated for 11D input.
|
| 407 |
-
"""
|
| 408 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 409 |
-
print(f"Training on device: {device}")
|
| 410 |
-
|
| 411 |
-
# Create dataset and dataloader
|
| 412 |
-
dataset = PatchDataset(dataset_dir, max_points=1024, augment=True) # Enable augmentation
|
| 413 |
-
print(f"Dataset loaded with {len(dataset)} samples")
|
| 414 |
-
|
| 415 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=20,
|
| 416 |
-
collate_fn=collate_fn, drop_last=True)
|
| 417 |
-
|
| 418 |
-
# Initialize model with 11D input
|
| 419 |
-
model = FastPointNet(input_dim=11, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1)
|
| 420 |
-
|
| 421 |
-
model.apply(init_weights)
|
| 422 |
-
model.to(device)
|
| 423 |
-
|
| 424 |
-
# Loss functions with label smoothing for classification
|
| 425 |
-
position_criterion = nn.SmoothL1Loss() # More robust than MSE
|
| 426 |
-
score_criterion = nn.SmoothL1Loss()
|
| 427 |
-
classification_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(2.0)) # Weight positive class more
|
| 428 |
-
|
| 429 |
-
# AdamW optimizer with weight decay
|
| 430 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4, betas=(0.9, 0.999))
|
| 431 |
-
|
| 432 |
-
# Cosine annealing scheduler for better convergence
|
| 433 |
-
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)
|
| 434 |
-
|
| 435 |
-
# Training loop
|
| 436 |
-
model.train()
|
| 437 |
-
for epoch in range(epochs):
|
| 438 |
-
total_loss = 0.0
|
| 439 |
-
total_pos_loss = 0.0
|
| 440 |
-
total_score_loss = 0.0
|
| 441 |
-
total_class_loss = 0.0
|
| 442 |
-
num_batches = 0
|
| 443 |
-
|
| 444 |
-
for batch_idx, batch_data in enumerate(dataloader):
|
| 445 |
-
if batch_data is None: # Skip invalid batches
|
| 446 |
-
continue
|
| 447 |
-
|
| 448 |
-
patch_data, targets, valid_masks, initial_preds, classifications = batch_data
|
| 449 |
-
patch_data = patch_data.to(device) # (batch_size, 11, max_points)
|
| 450 |
-
targets = targets.to(device) # (batch_size, 3)
|
| 451 |
-
classifications = classifications.to(device) # (batch_size,)
|
| 452 |
-
|
| 453 |
-
# Forward pass
|
| 454 |
-
optimizer.zero_grad()
|
| 455 |
-
predictions, predicted_scores, predicted_classes = model(patch_data)
|
| 456 |
-
|
| 457 |
-
# Compute actual distance from predictions to targets
|
| 458 |
-
actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True)
|
| 459 |
-
|
| 460 |
-
# Only compute position and score losses for samples with GT vertices
|
| 461 |
-
has_gt_mask = classifications > 0.5
|
| 462 |
-
|
| 463 |
-
if has_gt_mask.sum() > 0:
|
| 464 |
-
# Position loss only for samples with GT vertices
|
| 465 |
-
pos_loss = position_criterion(predictions[has_gt_mask], targets[has_gt_mask])
|
| 466 |
-
score_loss = score_criterion(predicted_scores[has_gt_mask], actual_distances[has_gt_mask])
|
| 467 |
-
else:
|
| 468 |
-
pos_loss = torch.tensor(0.0, device=device)
|
| 469 |
-
score_loss = torch.tensor(0.0, device=device)
|
| 470 |
-
|
| 471 |
-
# Classification loss for all samples
|
| 472 |
-
class_loss = classification_criterion(predicted_classes.squeeze(), classifications)
|
| 473 |
-
|
| 474 |
-
# Combined loss
|
| 475 |
-
total_batch_loss = pos_loss + score_weight * score_loss + class_weight * class_loss
|
| 476 |
-
|
| 477 |
-
# Backward pass
|
| 478 |
-
total_batch_loss.backward()
|
| 479 |
-
|
| 480 |
-
# Gradient clipping for stability
|
| 481 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 482 |
-
|
| 483 |
-
optimizer.step()
|
| 484 |
-
|
| 485 |
-
total_loss += total_batch_loss.item()
|
| 486 |
-
total_pos_loss += pos_loss.item()
|
| 487 |
-
total_score_loss += score_loss.item()
|
| 488 |
-
total_class_loss += class_loss.item()
|
| 489 |
-
num_batches += 1
|
| 490 |
-
|
| 491 |
-
if batch_idx % 50 == 0:
|
| 492 |
-
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
|
| 493 |
-
f"Total Loss: {total_batch_loss.item():.6f}, "
|
| 494 |
-
f"Pos Loss: {pos_loss.item():.6f}, "
|
| 495 |
-
f"Score Loss: {score_loss.item():.6f}, "
|
| 496 |
-
f"Class Loss: {class_loss.item():.6f}")
|
| 497 |
-
|
| 498 |
-
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
| 499 |
-
avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
|
| 500 |
-
avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
|
| 501 |
-
avg_class_loss = total_class_loss / num_batches if num_batches > 0 else 0
|
| 502 |
-
|
| 503 |
-
print(f"Epoch {epoch+1}/{epochs} completed, "
|
| 504 |
-
f"Avg Total Loss: {avg_loss:.6f}, "
|
| 505 |
-
f"Avg Pos Loss: {avg_pos_loss:.6f}, "
|
| 506 |
-
f"Avg Score Loss: {avg_score_loss:.6f}, "
|
| 507 |
-
f"Avg Class Loss: {avg_class_loss:.6f}")
|
| 508 |
-
|
| 509 |
-
scheduler.step()
|
| 510 |
-
|
| 511 |
-
# Save model checkpoint every 10 epochs
|
| 512 |
-
if (epoch + 1) % 10 == 0:
|
| 513 |
-
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 514 |
-
torch.save({
|
| 515 |
-
'model_state_dict': model.state_dict(),
|
| 516 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 517 |
-
'epoch': epoch + 1,
|
| 518 |
-
'loss': avg_loss,
|
| 519 |
-
}, checkpoint_path)
|
| 520 |
-
|
| 521 |
-
# Save the trained model
|
| 522 |
-
torch.save({
|
| 523 |
-
'model_state_dict': model.state_dict(),
|
| 524 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 525 |
-
'epoch': epochs,
|
| 526 |
-
}, model_save_path)
|
| 527 |
-
|
| 528 |
-
print(f"Model saved to {model_save_path}")
|
| 529 |
-
return model
|
| 530 |
-
|
| 531 |
-
def load_pointnet_model(model_path: str, device: torch.device = None, predict_score: bool = True) -> FastPointNet:
|
| 532 |
-
"""
|
| 533 |
-
Load a trained FastPointNet model.
|
| 534 |
-
Updated for 11D input.
|
| 535 |
-
"""
|
| 536 |
-
if device is None:
|
| 537 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 538 |
-
|
| 539 |
-
model = FastPointNet(input_dim=11, output_dim=3, max_points=1024, predict_score=predict_score)
|
| 540 |
-
|
| 541 |
-
checkpoint = torch.load(model_path, map_location=device)
|
| 542 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 543 |
-
|
| 544 |
-
model.to(device)
|
| 545 |
-
model.eval()
|
| 546 |
-
|
| 547 |
-
return model
|
| 548 |
-
|
| 549 |
-
def predict_vertex_from_patch(model: FastPointNet, patch: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float, float]:
|
| 550 |
-
"""
|
| 551 |
-
Predict 3D vertex coordinates, confidence score, and classification from a patch using trained PointNet.
|
| 552 |
-
Updated for 11D patches.
|
| 553 |
-
"""
|
| 554 |
-
if device is None:
|
| 555 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 556 |
-
|
| 557 |
-
patch_11d = patch['patch_11d'] # (N, 11) - Updated for 11D
|
| 558 |
-
|
| 559 |
-
# Prepare input
|
| 560 |
-
max_points = 1024
|
| 561 |
-
num_points = patch_11d.shape[0]
|
| 562 |
-
|
| 563 |
-
if num_points >= max_points:
|
| 564 |
-
# Sample points
|
| 565 |
-
indices = np.random.choice(num_points, max_points, replace=False)
|
| 566 |
-
patch_sampled = patch_11d[indices]
|
| 567 |
-
else:
|
| 568 |
-
# Pad with zeros
|
| 569 |
-
patch_sampled = np.zeros((max_points, 11)) # Updated for 11D
|
| 570 |
-
patch_sampled[:num_points] = patch_11d
|
| 571 |
-
|
| 572 |
-
# Convert to tensor
|
| 573 |
-
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 11, max_points)
|
| 574 |
-
patch_tensor = patch_tensor.to(device)
|
| 575 |
-
|
| 576 |
-
# Predict
|
| 577 |
-
with torch.no_grad():
|
| 578 |
-
outputs = model(patch_tensor)
|
| 579 |
-
|
| 580 |
-
if model.predict_score and model.predict_class:
|
| 581 |
-
position, score, classification = outputs
|
| 582 |
-
position = position.cpu().numpy().squeeze()
|
| 583 |
-
score = score.cpu().numpy().squeeze()
|
| 584 |
-
classification = torch.sigmoid(classification).cpu().numpy().squeeze() # Apply sigmoid for probability
|
| 585 |
-
elif model.predict_score:
|
| 586 |
-
position, score = outputs
|
| 587 |
-
position = position.cpu().numpy().squeeze()
|
| 588 |
-
score = score.cpu().numpy().squeeze()
|
| 589 |
-
classification = None
|
| 590 |
-
elif model.predict_class:
|
| 591 |
-
position, classification = outputs
|
| 592 |
-
position = position.cpu().numpy().squeeze()
|
| 593 |
-
score = None
|
| 594 |
-
classification = torch.sigmoid(classification).cpu().numpy().squeeze() # Apply sigmoid for probability
|
| 595 |
-
else:
|
| 596 |
-
position = outputs
|
| 597 |
-
position = position.cpu().numpy().squeeze()
|
| 598 |
-
score = None
|
| 599 |
-
classification = None
|
| 600 |
-
|
| 601 |
-
# Apply offset correction
|
| 602 |
-
offset = patch['cluster_center']
|
| 603 |
-
position += offset
|
| 604 |
-
|
| 605 |
-
return position, score, classification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fast_voxel.py
DELETED
|
@@ -1,591 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pickle
|
| 7 |
-
from torch.utils.data import Dataset, DataLoader
|
| 8 |
-
from typing import List, Dict, Tuple, Optional
|
| 9 |
-
import json
|
| 10 |
-
|
| 11 |
-
class Fast3DCNN(nn.Module):
|
| 12 |
-
"""
|
| 13 |
-
Fast 3D CNN implementation for 3D vertex prediction from voxelized point cloud patches.
|
| 14 |
-
Takes 7D point clouds (x,y,z,r,g,b,filtered_flag) and predicts 3D vertex coordinates.
|
| 15 |
-
Uses voxelization and 3D convolutions instead of PointNet architecture.
|
| 16 |
-
"""
|
| 17 |
-
def __init__(self, input_channels=7, output_dim=3, voxel_size=32, predict_score=True, predict_class=True, num_classes=1):
|
| 18 |
-
super(Fast3DCNN, self).__init__()
|
| 19 |
-
self.voxel_size = voxel_size
|
| 20 |
-
self.predict_score = predict_score
|
| 21 |
-
self.predict_class = predict_class
|
| 22 |
-
self.num_classes = num_classes
|
| 23 |
-
|
| 24 |
-
# 3D Convolutional layers for feature extraction
|
| 25 |
-
self.conv1 = nn.Conv3d(input_channels, 64, kernel_size=3, padding=1)
|
| 26 |
-
self.conv2 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
|
| 27 |
-
self.conv3 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
|
| 28 |
-
self.conv4 = nn.Conv3d(256, 512, kernel_size=3, padding=1)
|
| 29 |
-
self.conv5 = nn.Conv3d(512, 512, kernel_size=3, padding=1)
|
| 30 |
-
|
| 31 |
-
# Additional convolutional layers for deeper feature extraction
|
| 32 |
-
self.conv6 = nn.Conv3d(512, 1024, kernel_size=3, padding=1)
|
| 33 |
-
|
| 34 |
-
# Batch normalization layers
|
| 35 |
-
self.bn1 = nn.BatchNorm3d(64)
|
| 36 |
-
self.bn2 = nn.BatchNorm3d(128)
|
| 37 |
-
self.bn3 = nn.BatchNorm3d(256)
|
| 38 |
-
self.bn4 = nn.BatchNorm3d(512)
|
| 39 |
-
self.bn5 = nn.BatchNorm3d(512)
|
| 40 |
-
self.bn6 = nn.BatchNorm3d(1024)
|
| 41 |
-
|
| 42 |
-
# Max pooling layers
|
| 43 |
-
self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
|
| 44 |
-
|
| 45 |
-
# Calculate the size after convolutions and pooling
|
| 46 |
-
# Starting with voxel_size^3, after 3 pooling operations: voxel_size / 8
|
| 47 |
-
final_size = voxel_size // 8
|
| 48 |
-
flattened_size = 1024 * (final_size ** 3)
|
| 49 |
-
|
| 50 |
-
# Adaptive pooling to handle variable sizes
|
| 51 |
-
self.adaptive_pool = nn.AdaptiveAvgPool3d((4, 4, 4))
|
| 52 |
-
flattened_size = 1024 * 4 * 4 * 4
|
| 53 |
-
|
| 54 |
-
# Shared fully connected layers
|
| 55 |
-
self.shared_fc1 = nn.Linear(flattened_size, 1024)
|
| 56 |
-
self.shared_fc2 = nn.Linear(1024, 512)
|
| 57 |
-
|
| 58 |
-
# Position prediction head
|
| 59 |
-
self.pos_fc1 = nn.Linear(512, 512)
|
| 60 |
-
self.pos_fc2 = nn.Linear(512, 256)
|
| 61 |
-
self.pos_fc3 = nn.Linear(256, 128)
|
| 62 |
-
self.pos_fc4 = nn.Linear(128, output_dim)
|
| 63 |
-
|
| 64 |
-
# Score prediction head
|
| 65 |
-
if self.predict_score:
|
| 66 |
-
self.score_fc1 = nn.Linear(512, 512)
|
| 67 |
-
self.score_fc2 = nn.Linear(512, 256)
|
| 68 |
-
self.score_fc3 = nn.Linear(256, 128)
|
| 69 |
-
self.score_fc4 = nn.Linear(128, 64)
|
| 70 |
-
self.score_fc5 = nn.Linear(64, 1)
|
| 71 |
-
|
| 72 |
-
# Classification head
|
| 73 |
-
if self.predict_class:
|
| 74 |
-
self.class_fc1 = nn.Linear(512, 512)
|
| 75 |
-
self.class_fc2 = nn.Linear(512, 256)
|
| 76 |
-
self.class_fc3 = nn.Linear(256, 128)
|
| 77 |
-
self.class_fc4 = nn.Linear(128, 64)
|
| 78 |
-
self.class_fc5 = nn.Linear(64, num_classes)
|
| 79 |
-
|
| 80 |
-
# Dropout layers
|
| 81 |
-
self.dropout_light = nn.Dropout(0.2)
|
| 82 |
-
self.dropout_medium = nn.Dropout(0.3)
|
| 83 |
-
self.dropout_heavy = nn.Dropout(0.4)
|
| 84 |
-
|
| 85 |
-
def forward(self, x):
|
| 86 |
-
"""
|
| 87 |
-
Forward pass
|
| 88 |
-
Args:
|
| 89 |
-
x: (batch_size, input_channels, voxel_size, voxel_size, voxel_size) tensor
|
| 90 |
-
Returns:
|
| 91 |
-
Tuple containing predictions based on configuration:
|
| 92 |
-
- position: (batch_size, output_dim) tensor of predicted 3D coordinates
|
| 93 |
-
- score: (batch_size, 1) tensor of predicted distance to GT (if predict_score=True)
|
| 94 |
-
- classification: (batch_size, num_classes) tensor of class logits (if predict_class=True)
|
| 95 |
-
"""
|
| 96 |
-
batch_size = x.size(0)
|
| 97 |
-
|
| 98 |
-
# 3D Convolutional feature extraction
|
| 99 |
-
x1 = F.relu(self.bn1(self.conv1(x)))
|
| 100 |
-
x1 = self.pool(x1)
|
| 101 |
-
|
| 102 |
-
x2 = F.relu(self.bn2(self.conv2(x1)))
|
| 103 |
-
x2 = self.pool(x2)
|
| 104 |
-
|
| 105 |
-
x3 = F.relu(self.bn3(self.conv3(x2)))
|
| 106 |
-
x3 = self.pool(x3)
|
| 107 |
-
|
| 108 |
-
x4 = F.relu(self.bn4(self.conv4(x3)))
|
| 109 |
-
x5 = F.relu(self.bn5(self.conv5(x4)))
|
| 110 |
-
x6 = F.relu(self.bn6(self.conv6(x5)))
|
| 111 |
-
|
| 112 |
-
# Adaptive pooling to ensure consistent size
|
| 113 |
-
x6 = self.adaptive_pool(x6)
|
| 114 |
-
|
| 115 |
-
# Flatten for fully connected layers
|
| 116 |
-
global_features = x6.view(batch_size, -1)
|
| 117 |
-
|
| 118 |
-
# Shared features
|
| 119 |
-
shared1 = F.relu(self.shared_fc1(global_features))
|
| 120 |
-
shared1 = self.dropout_light(shared1)
|
| 121 |
-
shared2 = F.relu(self.shared_fc2(shared1))
|
| 122 |
-
shared_features = self.dropout_medium(shared2)
|
| 123 |
-
|
| 124 |
-
# Position prediction
|
| 125 |
-
pos1 = F.relu(self.pos_fc1(shared_features))
|
| 126 |
-
pos1 = self.dropout_light(pos1)
|
| 127 |
-
pos2 = F.relu(self.pos_fc2(pos1))
|
| 128 |
-
pos2 = self.dropout_medium(pos2)
|
| 129 |
-
pos3 = F.relu(self.pos_fc3(pos2))
|
| 130 |
-
pos3 = self.dropout_light(pos3)
|
| 131 |
-
position = self.pos_fc4(pos3)
|
| 132 |
-
|
| 133 |
-
outputs = [position]
|
| 134 |
-
|
| 135 |
-
if self.predict_score:
|
| 136 |
-
# Score prediction
|
| 137 |
-
score1 = F.relu(self.score_fc1(shared_features))
|
| 138 |
-
score1 = self.dropout_light(score1)
|
| 139 |
-
score2 = F.relu(self.score_fc2(score1))
|
| 140 |
-
score2 = self.dropout_medium(score2)
|
| 141 |
-
score3 = F.relu(self.score_fc3(score2))
|
| 142 |
-
score3 = self.dropout_light(score3)
|
| 143 |
-
score4 = F.relu(self.score_fc4(score3))
|
| 144 |
-
score4 = self.dropout_light(score4)
|
| 145 |
-
score = F.relu(self.score_fc5(score4))
|
| 146 |
-
outputs.append(score)
|
| 147 |
-
|
| 148 |
-
if self.predict_class:
|
| 149 |
-
# Classification prediction
|
| 150 |
-
class1 = F.relu(self.class_fc1(shared_features))
|
| 151 |
-
class1 = self.dropout_light(class1)
|
| 152 |
-
class2 = F.relu(self.class_fc2(class1))
|
| 153 |
-
class2 = self.dropout_medium(class2)
|
| 154 |
-
class3 = F.relu(self.class_fc3(class2))
|
| 155 |
-
class3 = self.dropout_light(class3)
|
| 156 |
-
class4 = F.relu(self.class_fc4(class3))
|
| 157 |
-
class4 = self.dropout_light(class4)
|
| 158 |
-
classification = self.class_fc5(class4)
|
| 159 |
-
outputs.append(classification)
|
| 160 |
-
|
| 161 |
-
# Return outputs based on configuration
|
| 162 |
-
if len(outputs) == 1:
|
| 163 |
-
return outputs[0]
|
| 164 |
-
elif len(outputs) == 2:
|
| 165 |
-
if self.predict_score:
|
| 166 |
-
return outputs[0], outputs[1]
|
| 167 |
-
else:
|
| 168 |
-
return outputs[0], outputs[1]
|
| 169 |
-
else:
|
| 170 |
-
return outputs[0], outputs[1], outputs[2]
|
| 171 |
-
|
| 172 |
-
def voxelize_patch(patch_7d: np.ndarray, voxel_size: int = 32, patch_size: float = 1.0) -> np.ndarray:
|
| 173 |
-
"""
|
| 174 |
-
Convert point cloud patch to voxel grid.
|
| 175 |
-
|
| 176 |
-
Args:
|
| 177 |
-
patch_7d: (N, 7) array of points with [x, y, z, r, g, b, filtered_flag]
|
| 178 |
-
voxel_size: Size of the voxel grid (voxel_size^3)
|
| 179 |
-
patch_size: Physical size of the patch in world coordinates
|
| 180 |
-
|
| 181 |
-
Returns:
|
| 182 |
-
voxels: (7, voxel_size, voxel_size, voxel_size) array of voxelized features
|
| 183 |
-
"""
|
| 184 |
-
if len(patch_7d) == 0:
|
| 185 |
-
return np.zeros((7, voxel_size, voxel_size, voxel_size))
|
| 186 |
-
|
| 187 |
-
# Extract coordinates and features
|
| 188 |
-
coords = patch_7d[:, :3] # x, y, z
|
| 189 |
-
features = patch_7d[:, 3:] # r, g, b, filtered_flag
|
| 190 |
-
|
| 191 |
-
# Normalize coordinates to [0, voxel_size-1]
|
| 192 |
-
coords_min = coords.min(axis=0)
|
| 193 |
-
coords_max = coords.max(axis=0)
|
| 194 |
-
coords_range = coords_max - coords_min
|
| 195 |
-
coords_range[coords_range == 0] = 1 # Avoid division by zero
|
| 196 |
-
|
| 197 |
-
normalized_coords = (coords - coords_min) / coords_range * (voxel_size - 1)
|
| 198 |
-
voxel_indices = normalized_coords.astype(int)
|
| 199 |
-
|
| 200 |
-
# Clip to valid range
|
| 201 |
-
voxel_indices = np.clip(voxel_indices, 0, voxel_size - 1)
|
| 202 |
-
|
| 203 |
-
# Initialize voxel grid
|
| 204 |
-
voxels = np.zeros((7, voxel_size, voxel_size, voxel_size))
|
| 205 |
-
|
| 206 |
-
# Fill voxels with features (average if multiple points fall in same voxel)
|
| 207 |
-
counts = np.zeros((voxel_size, voxel_size, voxel_size))
|
| 208 |
-
|
| 209 |
-
for i in range(len(patch_7d)):
|
| 210 |
-
x, y, z = voxel_indices[i]
|
| 211 |
-
# Store normalized coordinates in first 3 channels
|
| 212 |
-
voxels[0, x, y, z] += normalized_coords[i, 0] / (voxel_size - 1) # normalized x
|
| 213 |
-
voxels[1, x, y, z] += normalized_coords[i, 1] / (voxel_size - 1) # normalized y
|
| 214 |
-
voxels[2, x, y, z] += normalized_coords[i, 2] / (voxel_size - 1) # normalized z
|
| 215 |
-
# Store RGB and filtered_flag in remaining channels
|
| 216 |
-
voxels[3:, x, y, z] += features[i]
|
| 217 |
-
counts[x, y, z] += 1
|
| 218 |
-
|
| 219 |
-
# Average features where multiple points exist
|
| 220 |
-
mask = counts > 0
|
| 221 |
-
for c in range(7):
|
| 222 |
-
voxels[c][mask] /= counts[mask]
|
| 223 |
-
|
| 224 |
-
return voxels
|
| 225 |
-
|
| 226 |
-
class VoxelPatchDataset(Dataset):
|
| 227 |
-
"""
|
| 228 |
-
Dataset class for loading saved patches and converting them to voxel grids for 3D CNN training.
|
| 229 |
-
"""
|
| 230 |
-
|
| 231 |
-
def __init__(self, dataset_dir: str, voxel_size: int = 32, augment: bool = False):
|
| 232 |
-
self.dataset_dir = dataset_dir
|
| 233 |
-
self.voxel_size = voxel_size
|
| 234 |
-
self.augment = augment
|
| 235 |
-
|
| 236 |
-
# Load patch files
|
| 237 |
-
self.patch_files = []
|
| 238 |
-
for file in os.listdir(dataset_dir):
|
| 239 |
-
if file.endswith('.pkl'):
|
| 240 |
-
self.patch_files.append(os.path.join(dataset_dir, file))
|
| 241 |
-
|
| 242 |
-
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
|
| 243 |
-
|
| 244 |
-
def __len__(self):
|
| 245 |
-
return len(self.patch_files)
|
| 246 |
-
|
| 247 |
-
def __getitem__(self, idx):
|
| 248 |
-
"""
|
| 249 |
-
Load and process a patch for training.
|
| 250 |
-
Returns:
|
| 251 |
-
voxel_data: (7, voxel_size, voxel_size, voxel_size) tensor of voxelized data
|
| 252 |
-
target: (3,) tensor of target 3D coordinates
|
| 253 |
-
valid_mask: scalar tensor indicating if this is a valid sample
|
| 254 |
-
distance_to_gt: scalar tensor of distance from initial prediction to GT
|
| 255 |
-
classification: scalar tensor for binary classification (1 if GT vertex present, 0 if not)
|
| 256 |
-
"""
|
| 257 |
-
patch_file = self.patch_files[idx]
|
| 258 |
-
|
| 259 |
-
with open(patch_file, 'rb') as f:
|
| 260 |
-
patch_info = pickle.load(f)
|
| 261 |
-
|
| 262 |
-
patch_7d = patch_info['patch_7d'] # (N, 7)
|
| 263 |
-
target = patch_info.get('assigned_wf_vertex', None) # (3,) or None
|
| 264 |
-
initial_pred = patch_info.get('cluster_center', None) # (3,) or None
|
| 265 |
-
|
| 266 |
-
# Determine classification label based on GT vertex presence
|
| 267 |
-
has_gt_vertex = 1.0 if target is not None else 0.0
|
| 268 |
-
|
| 269 |
-
# Handle patches without ground truth
|
| 270 |
-
if target is None:
|
| 271 |
-
target = np.zeros(3)
|
| 272 |
-
else:
|
| 273 |
-
target = np.array(target)
|
| 274 |
-
|
| 275 |
-
# Voxelize the patch
|
| 276 |
-
voxel_data = voxelize_patch(patch_7d, self.voxel_size)
|
| 277 |
-
|
| 278 |
-
# Data augmentation (only if GT vertex is present)
|
| 279 |
-
if self.augment and has_gt_vertex > 0:
|
| 280 |
-
voxel_data, target = self._augment_voxels(voxel_data, target)
|
| 281 |
-
|
| 282 |
-
# Convert to tensors (copy arrays to handle negative strides from augmentation)
|
| 283 |
-
voxel_tensor = torch.from_numpy(voxel_data.copy()).float() # (7, voxel_size, voxel_size, voxel_size)
|
| 284 |
-
target_tensor = torch.from_numpy(target.copy()).float() # (3,)
|
| 285 |
-
|
| 286 |
-
# Valid mask (check if voxel grid has any non-zero values)
|
| 287 |
-
valid_mask = torch.tensor(1.0 if voxel_data.sum() > 0 else 0.0)
|
| 288 |
-
|
| 289 |
-
# Handle initial_pred
|
| 290 |
-
if initial_pred is not None:
|
| 291 |
-
initial_pred_tensor = torch.from_numpy(initial_pred).float()
|
| 292 |
-
else:
|
| 293 |
-
initial_pred_tensor = torch.zeros(3).float()
|
| 294 |
-
|
| 295 |
-
# Classification tensor
|
| 296 |
-
classification_tensor = torch.tensor(has_gt_vertex).float()
|
| 297 |
-
|
| 298 |
-
return voxel_tensor, target_tensor, valid_mask, initial_pred_tensor, classification_tensor
|
| 299 |
-
|
| 300 |
-
def _augment_voxels(self, voxel_data: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 301 |
-
"""
|
| 302 |
-
Apply data augmentation to voxel data.
|
| 303 |
-
"""
|
| 304 |
-
# Random rotation around Z-axis
|
| 305 |
-
if np.random.random() > 0.5:
|
| 306 |
-
k = np.random.randint(1, 4) # 90, 180, or 270 degrees
|
| 307 |
-
voxel_data = np.rot90(voxel_data, k, axes=(1, 2)) # Rotate around z-axis
|
| 308 |
-
|
| 309 |
-
# Random flip
|
| 310 |
-
if np.random.random() > 0.5:
|
| 311 |
-
voxel_data = np.flip(voxel_data, axis=1) # Flip along x-axis
|
| 312 |
-
if np.random.random() > 0.5:
|
| 313 |
-
voxel_data = np.flip(voxel_data, axis=2) # Flip along y-axis
|
| 314 |
-
|
| 315 |
-
return voxel_data, target
|
| 316 |
-
|
| 317 |
-
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
| 318 |
-
"""
|
| 319 |
-
Save patches from prediction pipeline to create a training dataset.
|
| 320 |
-
|
| 321 |
-
Args:
|
| 322 |
-
patches: List of patch dictionaries from generate_patches()
|
| 323 |
-
dataset_dir: Directory to save the dataset
|
| 324 |
-
entry_id: Unique identifier for this entry/image
|
| 325 |
-
"""
|
| 326 |
-
os.makedirs(dataset_dir, exist_ok=True)
|
| 327 |
-
|
| 328 |
-
for i, patch in enumerate(patches):
|
| 329 |
-
# Create unique filename
|
| 330 |
-
filename = f"{entry_id}_patch_{i}.pkl"
|
| 331 |
-
filepath = os.path.join(dataset_dir, filename)
|
| 332 |
-
|
| 333 |
-
# Skip if file already exists
|
| 334 |
-
if os.path.exists(filepath):
|
| 335 |
-
continue
|
| 336 |
-
|
| 337 |
-
# Save patch data
|
| 338 |
-
with open(filepath, 'wb') as f:
|
| 339 |
-
pickle.dump(patch, f)
|
| 340 |
-
|
| 341 |
-
print(f"Saved {len(patches)} patches for entry {entry_id}")
|
| 342 |
-
|
| 343 |
-
# Create dataloader with custom collate function to filter invalid samples
|
| 344 |
-
def collate_fn(batch):
|
| 345 |
-
valid_batch = []
|
| 346 |
-
for voxel_data, target, valid_mask, initial_pred, classification in batch:
|
| 347 |
-
# Filter out invalid samples
|
| 348 |
-
if valid_mask > 0:
|
| 349 |
-
valid_batch.append((voxel_data, target, valid_mask, initial_pred, classification))
|
| 350 |
-
|
| 351 |
-
if len(valid_batch) == 0:
|
| 352 |
-
return None
|
| 353 |
-
|
| 354 |
-
# Stack valid samples
|
| 355 |
-
voxel_data = torch.stack([item[0] for item in valid_batch])
|
| 356 |
-
targets = torch.stack([item[1] for item in valid_batch])
|
| 357 |
-
valid_masks = torch.stack([item[2] for item in valid_batch])
|
| 358 |
-
initial_preds = torch.stack([item[3] for item in valid_batch])
|
| 359 |
-
classifications = torch.stack([item[4] for item in valid_batch])
|
| 360 |
-
|
| 361 |
-
return voxel_data, targets, valid_masks, initial_preds, classifications
|
| 362 |
-
|
| 363 |
-
# Initialize weights using Xavier/Glorot initialization
|
| 364 |
-
def init_weights(m):
|
| 365 |
-
if isinstance(m, (nn.Conv3d, nn.Conv1d)):
|
| 366 |
-
nn.init.xavier_uniform_(m.weight)
|
| 367 |
-
if m.bias is not None:
|
| 368 |
-
nn.init.zeros_(m.bias)
|
| 369 |
-
elif isinstance(m, nn.Linear):
|
| 370 |
-
nn.init.xavier_uniform_(m.weight)
|
| 371 |
-
if m.bias is not None:
|
| 372 |
-
nn.init.zeros_(m.bias)
|
| 373 |
-
elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm1d)):
|
| 374 |
-
nn.init.ones_(m.weight)
|
| 375 |
-
nn.init.zeros_(m.bias)
|
| 376 |
-
|
| 377 |
-
def train_3dcnn(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 16, lr: float = 0.001,
|
| 378 |
-
voxel_size: int = 32, score_weight: float = 0.1, class_weight: float = 0.5):
|
| 379 |
-
"""
|
| 380 |
-
Train the Fast3DCNN model on saved patches.
|
| 381 |
-
|
| 382 |
-
Args:
|
| 383 |
-
dataset_dir: Directory containing saved patch files
|
| 384 |
-
model_save_path: Path to save the trained model
|
| 385 |
-
epochs: Number of training epochs
|
| 386 |
-
batch_size: Training batch size (reduced due to memory requirements of 3D conv)
|
| 387 |
-
lr: Learning rate
|
| 388 |
-
voxel_size: Size of voxel grid
|
| 389 |
-
score_weight: Weight for the distance prediction loss
|
| 390 |
-
class_weight: Weight for the classification loss
|
| 391 |
-
"""
|
| 392 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 393 |
-
print(f"Training on device: {device}")
|
| 394 |
-
|
| 395 |
-
# Create dataset and dataloader
|
| 396 |
-
dataset = VoxelPatchDataset(dataset_dir, voxel_size=voxel_size, augment=True)
|
| 397 |
-
print(f"Dataset loaded with {len(dataset)} samples")
|
| 398 |
-
|
| 399 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4,
|
| 400 |
-
collate_fn=collate_fn, drop_last=True)
|
| 401 |
-
|
| 402 |
-
# Initialize model with score and classification prediction
|
| 403 |
-
model = Fast3DCNN(input_channels=7, output_dim=3, voxel_size=voxel_size,
|
| 404 |
-
predict_score=True, predict_class=True, num_classes=1)
|
| 405 |
-
|
| 406 |
-
model.apply(init_weights)
|
| 407 |
-
model.to(device)
|
| 408 |
-
|
| 409 |
-
# Loss functions
|
| 410 |
-
position_criterion = nn.MSELoss()
|
| 411 |
-
score_criterion = nn.MSELoss()
|
| 412 |
-
classification_criterion = nn.BCEWithLogitsLoss()
|
| 413 |
-
|
| 414 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
| 415 |
-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
|
| 416 |
-
|
| 417 |
-
# Training loop
|
| 418 |
-
model.train()
|
| 419 |
-
for epoch in range(epochs):
|
| 420 |
-
total_loss = 0.0
|
| 421 |
-
total_pos_loss = 0.0
|
| 422 |
-
total_score_loss = 0.0
|
| 423 |
-
total_class_loss = 0.0
|
| 424 |
-
num_batches = 0
|
| 425 |
-
|
| 426 |
-
for batch_idx, batch_data in enumerate(dataloader):
|
| 427 |
-
if batch_data is None: # Skip invalid batches
|
| 428 |
-
continue
|
| 429 |
-
|
| 430 |
-
voxel_data, targets, valid_masks, initial_preds, classifications = batch_data
|
| 431 |
-
voxel_data = voxel_data.to(device) # (batch_size, 7, voxel_size, voxel_size, voxel_size)
|
| 432 |
-
targets = targets.to(device) # (batch_size, 3)
|
| 433 |
-
classifications = classifications.to(device) # (batch_size,)
|
| 434 |
-
|
| 435 |
-
# Forward pass
|
| 436 |
-
optimizer.zero_grad()
|
| 437 |
-
predictions, predicted_scores, predicted_classes = model(voxel_data)
|
| 438 |
-
|
| 439 |
-
# Compute actual distance from predictions to targets
|
| 440 |
-
actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True)
|
| 441 |
-
|
| 442 |
-
# Only compute position and score losses for samples with GT vertices
|
| 443 |
-
has_gt_mask = classifications > 0.5
|
| 444 |
-
|
| 445 |
-
if has_gt_mask.sum() > 0:
|
| 446 |
-
# Position loss only for samples with GT vertices
|
| 447 |
-
pos_loss = position_criterion(predictions[has_gt_mask], targets[has_gt_mask])
|
| 448 |
-
score_loss = score_criterion(predicted_scores[has_gt_mask], actual_distances[has_gt_mask])
|
| 449 |
-
else:
|
| 450 |
-
pos_loss = torch.tensor(0.0, device=device)
|
| 451 |
-
score_loss = torch.tensor(0.0, device=device)
|
| 452 |
-
|
| 453 |
-
# Classification loss for all samples
|
| 454 |
-
class_loss = classification_criterion(predicted_classes.squeeze(), classifications)
|
| 455 |
-
|
| 456 |
-
# Combined loss
|
| 457 |
-
total_batch_loss = pos_loss + score_weight * score_loss + class_weight * class_loss
|
| 458 |
-
|
| 459 |
-
# Backward pass
|
| 460 |
-
total_batch_loss.backward()
|
| 461 |
-
optimizer.step()
|
| 462 |
-
|
| 463 |
-
total_loss += total_batch_loss.item()
|
| 464 |
-
total_pos_loss += pos_loss.item()
|
| 465 |
-
total_score_loss += score_loss.item()
|
| 466 |
-
total_class_loss += class_loss.item()
|
| 467 |
-
num_batches += 1
|
| 468 |
-
|
| 469 |
-
if batch_idx % 50 == 0:
|
| 470 |
-
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
|
| 471 |
-
f"Total Loss: {total_batch_loss.item():.6f}, "
|
| 472 |
-
f"Pos Loss: {pos_loss.item():.6f}, "
|
| 473 |
-
f"Score Loss: {score_loss.item():.6f}, "
|
| 474 |
-
f"Class Loss: {class_loss.item():.6f}")
|
| 475 |
-
|
| 476 |
-
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
| 477 |
-
avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
|
| 478 |
-
avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
|
| 479 |
-
avg_class_loss = total_class_loss / num_batches if num_batches > 0 else 0
|
| 480 |
-
|
| 481 |
-
print(f"Epoch {epoch+1}/{epochs} completed, "
|
| 482 |
-
f"Avg Total Loss: {avg_loss:.6f}, "
|
| 483 |
-
f"Avg Pos Loss: {avg_pos_loss:.6f}, "
|
| 484 |
-
f"Avg Score Loss: {avg_score_loss:.6f}, "
|
| 485 |
-
f"Avg Class Loss: {avg_class_loss:.6f}")
|
| 486 |
-
|
| 487 |
-
scheduler.step()
|
| 488 |
-
|
| 489 |
-
# Save model checkpoint every epoch
|
| 490 |
-
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 491 |
-
torch.save({
|
| 492 |
-
'model_state_dict': model.state_dict(),
|
| 493 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 494 |
-
'epoch': epoch + 1,
|
| 495 |
-
'loss': avg_loss,
|
| 496 |
-
}, checkpoint_path)
|
| 497 |
-
|
| 498 |
-
# Save the trained model
|
| 499 |
-
torch.save({
|
| 500 |
-
'model_state_dict': model.state_dict(),
|
| 501 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 502 |
-
'epoch': epochs,
|
| 503 |
-
}, model_save_path)
|
| 504 |
-
|
| 505 |
-
print(f"Model saved to {model_save_path}")
|
| 506 |
-
return model
|
| 507 |
-
|
| 508 |
-
def load_3dcnn_model(model_path: str, device: torch.device = None, voxel_size: int = 32, predict_score: bool = True) -> Fast3DCNN:
|
| 509 |
-
"""
|
| 510 |
-
Load a trained Fast3DCNN model.
|
| 511 |
-
|
| 512 |
-
Args:
|
| 513 |
-
model_path: Path to the saved model
|
| 514 |
-
device: Device to load the model on
|
| 515 |
-
voxel_size: Size of voxel grid
|
| 516 |
-
predict_score: Whether the model predicts scores
|
| 517 |
-
|
| 518 |
-
Returns:
|
| 519 |
-
Loaded Fast3DCNN model
|
| 520 |
-
"""
|
| 521 |
-
if device is None:
|
| 522 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 523 |
-
|
| 524 |
-
model = Fast3DCNN(input_channels=7, output_dim=3, voxel_size=voxel_size, predict_score=predict_score)
|
| 525 |
-
|
| 526 |
-
checkpoint = torch.load(model_path, map_location=device)
|
| 527 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 528 |
-
|
| 529 |
-
model.to(device)
|
| 530 |
-
model.eval()
|
| 531 |
-
|
| 532 |
-
return model
|
| 533 |
-
|
| 534 |
-
def predict_vertex_from_patch_voxel(model: Fast3DCNN, patch: np.ndarray, device: torch.device = None, voxel_size: int = 32) -> Tuple[np.ndarray, float, float]:
|
| 535 |
-
"""
|
| 536 |
-
Predict 3D vertex coordinates, confidence score, and classification from a patch using trained 3D CNN.
|
| 537 |
-
|
| 538 |
-
Args:
|
| 539 |
-
model: Trained Fast3DCNN model
|
| 540 |
-
patch: Dictionary containing patch data with 'patch_7d' and 'cluster_center' keys
|
| 541 |
-
device: Device to run prediction on
|
| 542 |
-
voxel_size: Size of voxel grid
|
| 543 |
-
|
| 544 |
-
Returns:
|
| 545 |
-
tuple of (predicted_coordinates, confidence_score, classification_score)
|
| 546 |
-
predicted_coordinates: (3,) numpy array of predicted 3D coordinates
|
| 547 |
-
confidence_score: float representing predicted distance to GT (lower is better)
|
| 548 |
-
classification_score: float representing probability of GT vertex presence (0-1)
|
| 549 |
-
"""
|
| 550 |
-
if device is None:
|
| 551 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 552 |
-
|
| 553 |
-
patch_7d = patch['patch_7d'] # (N, 7)
|
| 554 |
-
|
| 555 |
-
# Voxelize the patch
|
| 556 |
-
voxel_data = voxelize_patch(patch_7d, voxel_size)
|
| 557 |
-
|
| 558 |
-
# Convert to tensor
|
| 559 |
-
voxel_tensor = torch.from_numpy(voxel_data).float().unsqueeze(0) # (1, 7, voxel_size, voxel_size, voxel_size)
|
| 560 |
-
voxel_tensor = voxel_tensor.to(device)
|
| 561 |
-
|
| 562 |
-
# Predict
|
| 563 |
-
with torch.no_grad():
|
| 564 |
-
outputs = model(voxel_tensor)
|
| 565 |
-
|
| 566 |
-
if model.predict_score and model.predict_class:
|
| 567 |
-
position, score, classification = outputs
|
| 568 |
-
position = position.cpu().numpy().squeeze()
|
| 569 |
-
score = score.cpu().numpy().squeeze()
|
| 570 |
-
classification = torch.sigmoid(classification).cpu().numpy().squeeze()
|
| 571 |
-
elif model.predict_score:
|
| 572 |
-
position, score = outputs
|
| 573 |
-
position = position.cpu().numpy().squeeze()
|
| 574 |
-
score = score.cpu().numpy().squeeze()
|
| 575 |
-
classification = None
|
| 576 |
-
elif model.predict_class:
|
| 577 |
-
position, classification = outputs
|
| 578 |
-
position = position.cpu().numpy().squeeze()
|
| 579 |
-
score = None
|
| 580 |
-
classification = torch.sigmoid(classification).cpu().numpy().squeeze()
|
| 581 |
-
else:
|
| 582 |
-
position = outputs
|
| 583 |
-
position = position.cpu().numpy().squeeze()
|
| 584 |
-
score = None
|
| 585 |
-
classification = None
|
| 586 |
-
|
| 587 |
-
# Apply offset correction
|
| 588 |
-
offset = patch['cluster_center']
|
| 589 |
-
position += offset
|
| 590 |
-
|
| 591 |
-
return position, score, classification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
find_best_results.py
CHANGED
|
@@ -1,5 +1,12 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# filepath: /home/skvrnjan/hoho/find_best_results.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# filepath: /home/skvrnjan/hoho/find_best_results.py
|
| 3 |
+
# This script scans a directory for result files (text files typically starting
|
| 4 |
+
# with "results_vt" within subdirectories matching a given prefix).
|
| 5 |
+
# It parses these files to extract metrics like Mean HSS, Mean F1, Mean IoU,
|
| 6 |
+
# Vertex Threshold, Edge Threshold, and Only Predicted Connections.
|
| 7 |
+
# The script then identifies and prints the top N results (default N=10)
|
| 8 |
+
# for Mean HSS, Mean F1, and Mean IoU, along with their associated configuration
|
| 9 |
+
# parameters.
|
| 10 |
import os
|
| 11 |
import re
|
| 12 |
|
fully_deep.py
DELETED
|
@@ -1,1082 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import os
|
| 3 |
-
import pickle
|
| 4 |
-
from torch.utils.data import Dataset, DataLoader
|
| 5 |
-
import numpy as np
|
| 6 |
-
from scipy.optimize import linear_sum_assignment
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
# =============================================================================
|
| 11 |
-
# CONFIGURATION PARAMETERS
|
| 12 |
-
# =============================================================================
|
| 13 |
-
|
| 14 |
-
# Dataset Configuration
|
| 15 |
-
DATA_DIR = '/mnt/personal/skvrnjan/hoho_fully'
|
| 16 |
-
SPLIT = 'train'
|
| 17 |
-
MAX_POINTS = 8096
|
| 18 |
-
BATCH_SIZE = 32
|
| 19 |
-
NUM_WORKERS = 8
|
| 20 |
-
|
| 21 |
-
# Model Architecture Parameters
|
| 22 |
-
PC_INPUT_FEATURES = 3
|
| 23 |
-
PC_ENCODER_OUTPUT_FEATURES = 128
|
| 24 |
-
MAX_VERTICES = 50
|
| 25 |
-
VERTEX_COORD_DIM = 3
|
| 26 |
-
GNN_HIDDEN_DIM = 64
|
| 27 |
-
NUM_GNN_LAYERS = 2
|
| 28 |
-
HIDDEN_DIM = 256
|
| 29 |
-
NUM_DECODER_LAYERS = 3
|
| 30 |
-
NUM_HEADS = 8
|
| 31 |
-
|
| 32 |
-
# PointNet2 Encoder Parameters
|
| 33 |
-
SA1_NPOINT = 1024
|
| 34 |
-
SA1_RADIUS = 0.2
|
| 35 |
-
SA1_NSAMPLE = 32
|
| 36 |
-
SA1_MLP = [64, 64, 128]
|
| 37 |
-
|
| 38 |
-
SA2_NPOINT = 256
|
| 39 |
-
SA2_RADIUS = 0.4
|
| 40 |
-
SA2_NSAMPLE = 64
|
| 41 |
-
SA2_MLP = [128, 128, 256]
|
| 42 |
-
|
| 43 |
-
SA3_MLP = [256, 512, 1024] # Global pooling layer
|
| 44 |
-
|
| 45 |
-
FP3_MLP = [256, 256]
|
| 46 |
-
FP2_MLP = [256, 128]
|
| 47 |
-
FP1_MLP = [128, 128] # Will add PC_ENCODER_OUTPUT_FEATURES at the end
|
| 48 |
-
|
| 49 |
-
# Vertex Prediction Head Parameters
|
| 50 |
-
VERTEX_TRANSFORMER_DROPOUT = 0.1
|
| 51 |
-
VERTEX_TRANSFORMER_FFN_RATIO = 4
|
| 52 |
-
|
| 53 |
-
# Edge Prediction Head Parameters
|
| 54 |
-
EDGE_GNN_NUM_HEADS = 4
|
| 55 |
-
EDGE_GNN_DROPOUT = 0.1
|
| 56 |
-
EDGE_K_NEIGHBORS = 8
|
| 57 |
-
|
| 58 |
-
# Training Configuration
|
| 59 |
-
NUM_EPOCHS = 100
|
| 60 |
-
LEARNING_RATE = 1e-4
|
| 61 |
-
WEIGHT_DECAY = 1e-5
|
| 62 |
-
GRADIENT_CLIP_MAX_NORM = 1.0
|
| 63 |
-
|
| 64 |
-
# Loss Weights
|
| 65 |
-
VERTEX_LOSS_WEIGHT = 1.0
|
| 66 |
-
EDGE_LOSS_WEIGHT = 0.5
|
| 67 |
-
CONFIDENCE_LOSS_WEIGHT = 0.3
|
| 68 |
-
|
| 69 |
-
# Learning Rate Scheduler Parameters
|
| 70 |
-
LR_SCHEDULER_FACTOR = 0.5
|
| 71 |
-
LR_SCHEDULER_PATIENCE = 10
|
| 72 |
-
|
| 73 |
-
# Checkpoint and Logging
|
| 74 |
-
CHECKPOINT_SAVE_FREQUENCY = 1 # Save every N epochs
|
| 75 |
-
LOG_FREQUENCY = 10 # Print progress every N batches
|
| 76 |
-
|
| 77 |
-
# Device Configuration
|
| 78 |
-
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 79 |
-
|
| 80 |
-
# =============================================================================
|
| 81 |
-
# MODEL IMPLEMENTATION
|
| 82 |
-
# =============================================================================
|
| 83 |
-
|
| 84 |
-
# You would likely need a library like torch_geometric for GNNs
|
| 85 |
-
# from torch_geometric.nn import GATConv, EdgeConv # Example GNN layers
|
| 86 |
-
|
| 87 |
-
# --- 1. Point Cloud Encoder Backbone (Placeholder) ---
|
| 88 |
-
class PointNet2Encoder(nn.Module):
|
| 89 |
-
def __init__(self, input_features, output_features):
|
| 90 |
-
super().__init__()
|
| 91 |
-
self.input_features = input_features
|
| 92 |
-
self.output_features = output_features
|
| 93 |
-
|
| 94 |
-
# Set Abstraction layers - adjusted for 8096 input points
|
| 95 |
-
self.sa1 = SetAbstractionLayer(
|
| 96 |
-
npoint=SA1_NPOINT, radius=SA1_RADIUS, nsample=SA1_NSAMPLE,
|
| 97 |
-
in_channel=input_features + 3, mlp=SA1_MLP
|
| 98 |
-
)
|
| 99 |
-
self.sa2 = SetAbstractionLayer(
|
| 100 |
-
npoint=SA2_NPOINT, radius=SA2_RADIUS, nsample=SA2_NSAMPLE,
|
| 101 |
-
in_channel=SA1_MLP[-1] + 3, mlp=SA2_MLP
|
| 102 |
-
)
|
| 103 |
-
self.sa3 = SetAbstractionLayer(
|
| 104 |
-
npoint=None, radius=None, nsample=None, # Global pooling
|
| 105 |
-
in_channel=SA2_MLP[-1] + 3, mlp=SA3_MLP
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
# Feature Propagation layers for point-wise features
|
| 109 |
-
self.fp3 = FeaturePropagationLayer(in_channel=SA3_MLP[-1] + SA2_MLP[-1], mlp=FP3_MLP)
|
| 110 |
-
self.fp2 = FeaturePropagationLayer(in_channel=FP3_MLP[-1] + SA1_MLP[-1], mlp=FP2_MLP)
|
| 111 |
-
self.fp1 = FeaturePropagationLayer(in_channel=FP2_MLP[-1] + input_features, mlp=FP1_MLP + [output_features])
|
| 112 |
-
|
| 113 |
-
def forward(self, xyz):
|
| 114 |
-
# xyz: (B, N, 3) where N = 8096
|
| 115 |
-
B, N, _ = xyz.shape
|
| 116 |
-
|
| 117 |
-
# Initial features (can be empty or coordinates)
|
| 118 |
-
points = xyz if self.input_features == 3 else None
|
| 119 |
-
|
| 120 |
-
# Set Abstraction
|
| 121 |
-
l1_xyz, l1_points = self.sa1(xyz, points) # 8096 -> 1024 points
|
| 122 |
-
l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) # 1024 -> 256 points
|
| 123 |
-
l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) # 256 -> 1 point (global)
|
| 124 |
-
|
| 125 |
-
# Feature Propagation
|
| 126 |
-
l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
|
| 127 |
-
l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
|
| 128 |
-
l0_points = self.fp1(xyz, l1_xyz, points, l1_points)
|
| 129 |
-
|
| 130 |
-
# Global feature from the most abstract level
|
| 131 |
-
global_feature = l3_points.squeeze(-1) # (B, 1024)
|
| 132 |
-
|
| 133 |
-
return l0_points, global_feature # (B, 8096, output_features), (B, 1024)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
class SetAbstractionLayer(nn.Module):
|
| 137 |
-
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all=False):
|
| 138 |
-
super().__init__()
|
| 139 |
-
self.npoint = npoint
|
| 140 |
-
self.radius = radius
|
| 141 |
-
self.nsample = nsample
|
| 142 |
-
self.group_all = group_all
|
| 143 |
-
|
| 144 |
-
self.mlp_convs = nn.ModuleList()
|
| 145 |
-
self.mlp_bns = nn.ModuleList()
|
| 146 |
-
last_channel = in_channel
|
| 147 |
-
for out_channel in mlp:
|
| 148 |
-
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
|
| 149 |
-
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
|
| 150 |
-
last_channel = out_channel
|
| 151 |
-
|
| 152 |
-
def forward(self, xyz, points):
|
| 153 |
-
# xyz: (B, N, 3)
|
| 154 |
-
# points: (B, N, C) or None
|
| 155 |
-
B, N, C = xyz.shape
|
| 156 |
-
|
| 157 |
-
if self.group_all or self.npoint is None:
|
| 158 |
-
# Global pooling
|
| 159 |
-
new_xyz = xyz.mean(dim=1, keepdim=True) # (B, 1, 3)
|
| 160 |
-
if points is not None:
|
| 161 |
-
new_points = torch.cat([xyz, points], dim=-1) # (B, N, 3+C)
|
| 162 |
-
new_points = new_points.transpose(1, 2).unsqueeze(-1) # (B, 3+C, N, 1)
|
| 163 |
-
else:
|
| 164 |
-
new_points = xyz.transpose(1, 2).unsqueeze(-1) # (B, 3, N, 1)
|
| 165 |
-
else:
|
| 166 |
-
# Farthest Point Sampling
|
| 167 |
-
fps_idx = farthest_point_sample(xyz, self.npoint) # (B, npoint)
|
| 168 |
-
new_xyz = index_points(xyz, fps_idx) # (B, npoint, 3)
|
| 169 |
-
|
| 170 |
-
# Ball Query
|
| 171 |
-
idx = ball_query(self.radius, self.nsample, xyz, new_xyz) # (B, npoint, nsample)
|
| 172 |
-
grouped_xyz = index_points(xyz, idx) # (B, npoint, nsample, 3)
|
| 173 |
-
grouped_xyz_norm = grouped_xyz - new_xyz.unsqueeze(2) # Relative positions
|
| 174 |
-
|
| 175 |
-
if points is not None:
|
| 176 |
-
grouped_points = index_points(points, idx) # (B, npoint, nsample, C)
|
| 177 |
-
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # (B, npoint, nsample, 3+C)
|
| 178 |
-
else:
|
| 179 |
-
new_points = grouped_xyz_norm # (B, npoint, nsample, 3)
|
| 180 |
-
|
| 181 |
-
new_points = new_points.permute(0, 3, 1, 2) # (B, 3+C, npoint, nsample)
|
| 182 |
-
|
| 183 |
-
# MLP
|
| 184 |
-
for i, conv in enumerate(self.mlp_convs):
|
| 185 |
-
bn = self.mlp_bns[i]
|
| 186 |
-
new_points = F.relu(bn(conv(new_points)))
|
| 187 |
-
|
| 188 |
-
# Max pooling
|
| 189 |
-
new_points = torch.max(new_points, dim=-1)[0] # (B, mlp[-1], npoint)
|
| 190 |
-
new_points = new_points.transpose(1, 2) # (B, npoint, mlp[-1])
|
| 191 |
-
|
| 192 |
-
return new_xyz, new_points
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
class FeaturePropagationLayer(nn.Module):
|
| 196 |
-
def __init__(self, in_channel, mlp):
|
| 197 |
-
super().__init__()
|
| 198 |
-
self.mlp_convs = nn.ModuleList()
|
| 199 |
-
self.mlp_bns = nn.ModuleList()
|
| 200 |
-
last_channel = in_channel
|
| 201 |
-
for out_channel in mlp:
|
| 202 |
-
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
|
| 203 |
-
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
|
| 204 |
-
last_channel = out_channel
|
| 205 |
-
|
| 206 |
-
def forward(self, xyz1, xyz2, points1, points2):
|
| 207 |
-
# xyz1: (B, N1, 3) - target points
|
| 208 |
-
# xyz2: (B, N2, 3) - source points
|
| 209 |
-
# points1: (B, N1, C1) - target features
|
| 210 |
-
# points2: (B, N2, C2) - source features
|
| 211 |
-
|
| 212 |
-
# Interpolate features from xyz2 to xyz1
|
| 213 |
-
if points2 is not None:
|
| 214 |
-
interpolated_points = interpolate_features(xyz1, xyz2, points2) # (B, N1, C2)
|
| 215 |
-
if points1 is not None:
|
| 216 |
-
# Ensure both tensors have the same number of points (N1)
|
| 217 |
-
assert points1.shape[1] == interpolated_points.shape[1], f"Point count mismatch: {points1.shape[1]} vs {interpolated_points.shape[1]}"
|
| 218 |
-
new_points = torch.cat([points1, interpolated_points], dim=-1) # (B, N1, C1+C2)
|
| 219 |
-
else:
|
| 220 |
-
new_points = interpolated_points
|
| 221 |
-
else:
|
| 222 |
-
new_points = points1
|
| 223 |
-
|
| 224 |
-
# Handle None case
|
| 225 |
-
if new_points is None:
|
| 226 |
-
return None
|
| 227 |
-
|
| 228 |
-
# MLP
|
| 229 |
-
new_points = new_points.transpose(1, 2) # (B, C, N1)
|
| 230 |
-
for i, conv in enumerate(self.mlp_convs):
|
| 231 |
-
bn = self.mlp_bns[i]
|
| 232 |
-
new_points = F.relu(bn(conv(new_points)))
|
| 233 |
-
|
| 234 |
-
return new_points.transpose(1, 2) # (B, N1, mlp[-1])
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
def farthest_point_sample(xyz, npoint):
|
| 238 |
-
"""Farthest Point Sampling"""
|
| 239 |
-
device = xyz.device
|
| 240 |
-
B, N, C = xyz.shape
|
| 241 |
-
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
|
| 242 |
-
distance = torch.ones(B, N).to(device) * 1e10
|
| 243 |
-
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
|
| 244 |
-
|
| 245 |
-
for i in range(npoint):
|
| 246 |
-
centroids[:, i] = farthest
|
| 247 |
-
centroid = xyz[torch.arange(B), farthest, :].view(B, 1, 3)
|
| 248 |
-
dist = torch.sum((xyz - centroid) ** 2, -1)
|
| 249 |
-
mask = dist < distance
|
| 250 |
-
distance[mask] = dist[mask]
|
| 251 |
-
farthest = torch.max(distance, -1)[1]
|
| 252 |
-
|
| 253 |
-
return centroids
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
def ball_query(radius, nsample, xyz, new_xyz):
|
| 257 |
-
"""Ball Query"""
|
| 258 |
-
device = xyz.device
|
| 259 |
-
B, N, C = xyz.shape
|
| 260 |
-
_, S, _ = new_xyz.shape
|
| 261 |
-
|
| 262 |
-
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
|
| 263 |
-
sqrdists = square_distance(new_xyz, xyz)
|
| 264 |
-
group_idx[sqrdists > radius ** 2] = N
|
| 265 |
-
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
|
| 266 |
-
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
|
| 267 |
-
mask = group_idx == N
|
| 268 |
-
group_idx[mask] = group_first[mask]
|
| 269 |
-
|
| 270 |
-
# If group_first[mask] was N (i.e., no points in the ball for a centroid),
|
| 271 |
-
# group_idx can still contain N. Clamp N to 0 to ensure valid indices.
|
| 272 |
-
# N corresponds to xyz.shape[1], which is guaranteed to be > 0 by the dataloader logic.
|
| 273 |
-
group_idx[group_idx == N] = 0
|
| 274 |
-
|
| 275 |
-
return group_idx
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
def square_distance(src, dst):
|
| 279 |
-
"""Calculate squared distance between each two points"""
|
| 280 |
-
B, N, _ = src.shape
|
| 281 |
-
_, M, _ = dst.shape
|
| 282 |
-
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
|
| 283 |
-
dist += torch.sum(src ** 2, -1).view(B, N, 1)
|
| 284 |
-
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
|
| 285 |
-
return dist
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
def index_points(points, idx):
|
| 289 |
-
"""Index points using given indices"""
|
| 290 |
-
device = points.device
|
| 291 |
-
B = points.shape[0]
|
| 292 |
-
view_shape = list(idx.shape)
|
| 293 |
-
view_shape[1:] = [1] * (len(view_shape) - 1)
|
| 294 |
-
repeat_shape = list(idx.shape)
|
| 295 |
-
repeat_shape[0] = 1
|
| 296 |
-
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
|
| 297 |
-
new_points = points[batch_indices, idx, :]
|
| 298 |
-
return new_points
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
def interpolate_features(xyz1, xyz2, points2):
|
| 302 |
-
"""Interpolate features using inverse distance weighting"""
|
| 303 |
-
B, N1, C = xyz1.shape
|
| 304 |
-
_, N2, _ = xyz2.shape
|
| 305 |
-
|
| 306 |
-
if N2 == 1:
|
| 307 |
-
# If only one point, broadcast to all target points
|
| 308 |
-
interpolated_points = points2.expand(B, N1, -1)
|
| 309 |
-
else:
|
| 310 |
-
# Find 3 nearest neighbors and interpolate
|
| 311 |
-
dists = square_distance(xyz1, xyz2) # (B, N1, N2)
|
| 312 |
-
dists, idx = dists.sort(dim=-1)
|
| 313 |
-
|
| 314 |
-
# Use min(3, N2) neighbors to handle cases with fewer source points
|
| 315 |
-
k = min(3, N2)
|
| 316 |
-
dists, idx = dists[:, :, :k], idx[:, :, :k]
|
| 317 |
-
|
| 318 |
-
# Inverse distance weighting
|
| 319 |
-
dists[dists < 1e-10] = 1e-10
|
| 320 |
-
weight = 1.0 / dists # (B, N1, k)
|
| 321 |
-
weight = weight / torch.sum(weight, dim=-1, keepdim=True) # Normalize
|
| 322 |
-
|
| 323 |
-
# Interpolate
|
| 324 |
-
interpolated_points = torch.sum(
|
| 325 |
-
index_points(points2, idx) * weight.view(B, N1, k, 1), dim=2
|
| 326 |
-
)
|
| 327 |
-
|
| 328 |
-
return interpolated_points
|
| 329 |
-
|
| 330 |
-
# --- 2. Vertex Prediction Head (Transformer-based) ---
|
| 331 |
-
class VertexPredictionHead(nn.Module):
|
| 332 |
-
def __init__(self, point_feature_dim, global_feature_dim, max_vertices, vertex_coord_dim=3,
|
| 333 |
-
hidden_dim=256, num_decoder_layers=3, num_heads=8):
|
| 334 |
-
super().__init__()
|
| 335 |
-
self.max_vertices = max_vertices
|
| 336 |
-
self.vertex_coord_dim = vertex_coord_dim
|
| 337 |
-
self.hidden_dim = hidden_dim
|
| 338 |
-
|
| 339 |
-
# Learnable vertex queries (similar to DETR object queries)
|
| 340 |
-
self.vertex_queries = nn.Parameter(torch.randn(max_vertices, hidden_dim))
|
| 341 |
-
|
| 342 |
-
# Project global feature to hidden dimension
|
| 343 |
-
self.global_proj = nn.Linear(global_feature_dim, 1)
|
| 344 |
-
|
| 345 |
-
# Project point features to hidden dimension for cross-attention
|
| 346 |
-
self.point_proj = nn.Linear(point_feature_dim, hidden_dim)
|
| 347 |
-
|
| 348 |
-
# Transformer decoder layers
|
| 349 |
-
decoder_layer = nn.TransformerDecoderLayer(
|
| 350 |
-
d_model=hidden_dim,
|
| 351 |
-
nhead=num_heads,
|
| 352 |
-
dim_feedforward=hidden_dim * VERTEX_TRANSFORMER_FFN_RATIO,
|
| 353 |
-
dropout=VERTEX_TRANSFORMER_DROPOUT,
|
| 354 |
-
batch_first=True
|
| 355 |
-
)
|
| 356 |
-
self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
|
| 357 |
-
|
| 358 |
-
# Output heads
|
| 359 |
-
self.vertex_coord_head = nn.Sequential(
|
| 360 |
-
nn.Linear(hidden_dim, hidden_dim),
|
| 361 |
-
nn.ReLU(),
|
| 362 |
-
nn.Linear(hidden_dim, vertex_coord_dim)
|
| 363 |
-
)
|
| 364 |
-
|
| 365 |
-
# Confidence/existence head (predicts if vertex exists)
|
| 366 |
-
self.vertex_conf_head = nn.Sequential(
|
| 367 |
-
nn.Linear(hidden_dim, hidden_dim),
|
| 368 |
-
nn.ReLU(),
|
| 369 |
-
nn.Linear(hidden_dim, 1)
|
| 370 |
-
)
|
| 371 |
-
|
| 372 |
-
# Position encoding for point features
|
| 373 |
-
self.pos_encoding = nn.Sequential(
|
| 374 |
-
nn.Linear(3, hidden_dim // 2),
|
| 375 |
-
nn.ReLU(),
|
| 376 |
-
nn.Linear(hidden_dim // 2, hidden_dim)
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
def forward(self, point_features, global_feature, point_coords=None):
|
| 380 |
-
# point_features: (B, N, point_feature_dim)
|
| 381 |
-
# global_feature: (B, global_feature_dim)
|
| 382 |
-
# point_coords: (B, N, 3) - optional point coordinates for positional encoding
|
| 383 |
-
|
| 384 |
-
batch_size = point_features.shape[0]
|
| 385 |
-
|
| 386 |
-
# Project features to hidden dimension
|
| 387 |
-
point_features_proj = self.point_proj(point_features) # (B, N, hidden_dim)
|
| 388 |
-
|
| 389 |
-
# Add positional encoding if coordinates are provided
|
| 390 |
-
if point_coords is not None:
|
| 391 |
-
pos_enc = self.pos_encoding(point_coords) # (B, N, hidden_dim)
|
| 392 |
-
point_features_proj = point_features_proj + pos_enc
|
| 393 |
-
|
| 394 |
-
# Prepare vertex queries
|
| 395 |
-
vertex_queries = self.vertex_queries.unsqueeze(0).repeat(batch_size, 1, 1) # (B, max_vertices, hidden_dim)
|
| 396 |
-
|
| 397 |
-
# Add global context to vertex queries
|
| 398 |
-
global_proj = self.global_proj(global_feature).squeeze(-1).unsqueeze(1) # (B, 1, hidden_dim)
|
| 399 |
-
vertex_queries = vertex_queries + global_proj # Broadcasting will handle (B, 1, hidden_dim) + (B, max_vertices, hidden_dim)
|
| 400 |
-
|
| 401 |
-
# Transformer decoder: vertex queries attend to point features
|
| 402 |
-
vertex_features = self.transformer_decoder(
|
| 403 |
-
tgt=vertex_queries, # (B, max_vertices, hidden_dim)
|
| 404 |
-
memory=point_features_proj # (B, N, hidden_dim)
|
| 405 |
-
) # (B, max_vertices, hidden_dim)
|
| 406 |
-
|
| 407 |
-
# Predict vertex coordinates
|
| 408 |
-
predicted_vertices = self.vertex_coord_head(vertex_features) # (B, max_vertices, 3)
|
| 409 |
-
|
| 410 |
-
# Predict vertex confidence/existence
|
| 411 |
-
vertex_confidence = self.vertex_conf_head(vertex_features).squeeze(-1) # (B, max_vertices)
|
| 412 |
-
|
| 413 |
-
return predicted_vertices, vertex_confidence
|
| 414 |
-
|
| 415 |
-
# --- 3. Edge Prediction Head (GNN-based) ---
|
| 416 |
-
class EdgePredictionHeadGNN(nn.Module):
|
| 417 |
-
def __init__(self, vertex_feature_dim, gnn_hidden_dim, num_gnn_layers):
|
| 418 |
-
super().__init__()
|
| 419 |
-
self.vertex_feature_dim = vertex_feature_dim
|
| 420 |
-
self.gnn_hidden_dim = gnn_hidden_dim
|
| 421 |
-
self.num_gnn_layers = num_gnn_layers
|
| 422 |
-
|
| 423 |
-
# Initial vertex feature projection
|
| 424 |
-
self.vertex_proj = nn.Linear(vertex_feature_dim, gnn_hidden_dim)
|
| 425 |
-
|
| 426 |
-
# GNN layers using message passing
|
| 427 |
-
self.gnn_layers = nn.ModuleList()
|
| 428 |
-
for i in range(num_gnn_layers):
|
| 429 |
-
self.gnn_layers.append(
|
| 430 |
-
GraphAttentionLayer(
|
| 431 |
-
in_features=gnn_hidden_dim,
|
| 432 |
-
out_features=gnn_hidden_dim,
|
| 433 |
-
num_heads=EDGE_GNN_NUM_HEADS,
|
| 434 |
-
dropout=EDGE_GNN_DROPOUT
|
| 435 |
-
)
|
| 436 |
-
)
|
| 437 |
-
|
| 438 |
-
# Edge classifier MLP
|
| 439 |
-
self.edge_mlp = nn.Sequential(
|
| 440 |
-
nn.Linear(gnn_hidden_dim * 2, gnn_hidden_dim),
|
| 441 |
-
nn.ReLU(),
|
| 442 |
-
nn.Dropout(EDGE_GNN_DROPOUT),
|
| 443 |
-
nn.Linear(gnn_hidden_dim, gnn_hidden_dim // 2),
|
| 444 |
-
nn.ReLU(),
|
| 445 |
-
nn.Linear(gnn_hidden_dim // 2, 1)
|
| 446 |
-
)
|
| 447 |
-
|
| 448 |
-
# Learnable threshold for k-NN graph construction
|
| 449 |
-
self.k_neighbors = EDGE_K_NEIGHBORS # Number of nearest neighbors for initial graph
|
| 450 |
-
|
| 451 |
-
def forward(self, vertices):
|
| 452 |
-
# vertices: (B, num_vertices, vertex_coord_dim)
|
| 453 |
-
batch_size, num_vertices, _ = vertices.shape
|
| 454 |
-
|
| 455 |
-
# Project vertex coordinates to hidden features
|
| 456 |
-
vertex_features = self.vertex_proj(vertices) # (B, num_vertices, gnn_hidden_dim)
|
| 457 |
-
|
| 458 |
-
# Construct initial graph based on spatial proximity (k-NN)
|
| 459 |
-
adjacency_matrix = self.construct_knn_graph(vertices, k=self.k_neighbors) # (B, num_vertices, num_vertices)
|
| 460 |
-
|
| 461 |
-
# Apply GNN layers
|
| 462 |
-
for gnn_layer in self.gnn_layers:
|
| 463 |
-
vertex_features = gnn_layer(vertex_features, adjacency_matrix) # (B, num_vertices, gnn_hidden_dim)
|
| 464 |
-
|
| 465 |
-
# Generate all possible vertex pairs
|
| 466 |
-
idx_pairs = torch.combinations(torch.arange(num_vertices), r=2).to(vertices.device) # (num_pairs, 2)
|
| 467 |
-
|
| 468 |
-
# Gather features for all vertex pairs
|
| 469 |
-
v1_features = vertex_features[:, idx_pairs[:, 0], :] # (B, num_pairs, gnn_hidden_dim)
|
| 470 |
-
v2_features = vertex_features[:, idx_pairs[:, 1], :] # (B, num_pairs, gnn_hidden_dim)
|
| 471 |
-
|
| 472 |
-
# Concatenate paired vertex features
|
| 473 |
-
edge_features = torch.cat([v1_features, v2_features], dim=2) # (B, num_pairs, gnn_hidden_dim * 2)
|
| 474 |
-
|
| 475 |
-
# Predict edge probabilities
|
| 476 |
-
edge_logits = self.edge_mlp(edge_features).squeeze(-1) # (B, num_pairs)
|
| 477 |
-
|
| 478 |
-
return edge_logits, idx_pairs
|
| 479 |
-
|
| 480 |
-
def construct_knn_graph(self, vertices, k):
|
| 481 |
-
# vertices: (B, num_vertices, 3)
|
| 482 |
-
batch_size, num_vertices, _ = vertices.shape
|
| 483 |
-
|
| 484 |
-
# Compute pairwise distances
|
| 485 |
-
distances = torch.cdist(vertices, vertices, p=2) # (B, num_vertices, num_vertices)
|
| 486 |
-
|
| 487 |
-
# Find k nearest neighbors for each vertex
|
| 488 |
-
_, knn_indices = torch.topk(distances, k + 1, dim=-1, largest=False) # +1 to include self
|
| 489 |
-
knn_indices = knn_indices[:, :, 1:] # Remove self-connection
|
| 490 |
-
|
| 491 |
-
# Create adjacency matrix
|
| 492 |
-
adjacency = torch.zeros(batch_size, num_vertices, num_vertices, device=vertices.device)
|
| 493 |
-
|
| 494 |
-
# Fill adjacency matrix
|
| 495 |
-
batch_idx = torch.arange(batch_size).view(-1, 1, 1).expand(-1, num_vertices, k)
|
| 496 |
-
vertex_idx = torch.arange(num_vertices).view(1, -1, 1).expand(batch_size, -1, k)
|
| 497 |
-
|
| 498 |
-
adjacency[batch_idx, vertex_idx, knn_indices] = 1.0
|
| 499 |
-
|
| 500 |
-
# Make adjacency symmetric
|
| 501 |
-
adjacency = torch.max(adjacency, adjacency.transpose(-1, -2))
|
| 502 |
-
|
| 503 |
-
return adjacency
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
class GraphAttentionLayer(nn.Module):
|
| 507 |
-
def __init__(self, in_features, out_features, num_heads=1, dropout=0.1):
|
| 508 |
-
super().__init__()
|
| 509 |
-
self.in_features = in_features
|
| 510 |
-
self.out_features = out_features
|
| 511 |
-
self.num_heads = num_heads
|
| 512 |
-
self.dropout = dropout
|
| 513 |
-
|
| 514 |
-
assert out_features % num_heads == 0
|
| 515 |
-
self.head_dim = out_features // num_heads
|
| 516 |
-
|
| 517 |
-
# Linear transformations for queries, keys, values
|
| 518 |
-
self.W_q = nn.Linear(in_features, out_features)
|
| 519 |
-
self.W_k = nn.Linear(in_features, out_features)
|
| 520 |
-
self.W_v = nn.Linear(in_features, out_features)
|
| 521 |
-
|
| 522 |
-
# Output projection
|
| 523 |
-
self.W_o = nn.Linear(out_features, out_features)
|
| 524 |
-
|
| 525 |
-
# Attention mechanism
|
| 526 |
-
self.attention = nn.MultiheadAttention(
|
| 527 |
-
embed_dim=out_features,
|
| 528 |
-
num_heads=num_heads,
|
| 529 |
-
dropout=dropout,
|
| 530 |
-
batch_first=True
|
| 531 |
-
)
|
| 532 |
-
|
| 533 |
-
# Layer normalization and residual connection
|
| 534 |
-
self.layer_norm = nn.LayerNorm(out_features)
|
| 535 |
-
self.ffn = nn.Sequential(
|
| 536 |
-
nn.Linear(out_features, out_features * 2),
|
| 537 |
-
nn.ReLU(),
|
| 538 |
-
nn.Dropout(dropout),
|
| 539 |
-
nn.Linear(out_features * 2, out_features)
|
| 540 |
-
)
|
| 541 |
-
self.layer_norm2 = nn.LayerNorm(out_features)
|
| 542 |
-
|
| 543 |
-
def forward(self, x, adjacency_matrix):
|
| 544 |
-
# x: (B, num_vertices, in_features)
|
| 545 |
-
# adjacency_matrix: (B, num_vertices, num_vertices)
|
| 546 |
-
batch_size, num_vertices, _ = x.shape
|
| 547 |
-
|
| 548 |
-
# Project to query, key, value
|
| 549 |
-
Q = self.W_q(x) # (B, num_vertices, out_features)
|
| 550 |
-
K = self.W_k(x) # (B, num_vertices, out_features)
|
| 551 |
-
V = self.W_v(x) # (B, num_vertices, out_features)
|
| 552 |
-
|
| 553 |
-
# Create attention mask from adjacency matrix
|
| 554 |
-
# Convert adjacency to attention mask (0 for allowed, -inf for masked)
|
| 555 |
-
attention_mask = (1 - adjacency_matrix) * (-1e9) # (B, num_vertices, num_vertices)
|
| 556 |
-
|
| 557 |
-
# Apply multi-head attention with adjacency-based masking
|
| 558 |
-
attended_features = []
|
| 559 |
-
for b in range(batch_size):
|
| 560 |
-
q_b = Q[b:b+1] # (1, num_vertices, out_features)
|
| 561 |
-
k_b = K[b:b+1] # (1, num_vertices, out_features)
|
| 562 |
-
v_b = V[b:b+1] # (1, num_vertices, out_features)
|
| 563 |
-
mask_b = attention_mask[b] # (num_vertices, num_vertices)
|
| 564 |
-
|
| 565 |
-
# Apply attention
|
| 566 |
-
attn_output, _ = self.attention(q_b, k_b, v_b, attn_mask=mask_b)
|
| 567 |
-
attended_features.append(attn_output)
|
| 568 |
-
|
| 569 |
-
attended_features = torch.cat(attended_features, dim=0) # (B, num_vertices, out_features)
|
| 570 |
-
|
| 571 |
-
# Residual connection and layer norm
|
| 572 |
-
x_residual = self.layer_norm(attended_features + Q)
|
| 573 |
-
|
| 574 |
-
# Feed-forward network
|
| 575 |
-
ffn_output = self.ffn(x_residual)
|
| 576 |
-
output = self.layer_norm2(ffn_output + x_residual)
|
| 577 |
-
|
| 578 |
-
return output
|
| 579 |
-
|
| 580 |
-
# --- Main Model ---
|
| 581 |
-
class PointCloudToWireframe(nn.Module):
|
| 582 |
-
def __init__(self,
|
| 583 |
-
pc_input_features=PC_INPUT_FEATURES,
|
| 584 |
-
pc_encoder_output_features=PC_ENCODER_OUTPUT_FEATURES,
|
| 585 |
-
max_vertices=MAX_VERTICES,
|
| 586 |
-
vertex_coord_dim=VERTEX_COORD_DIM,
|
| 587 |
-
gnn_hidden_dim=GNN_HIDDEN_DIM,
|
| 588 |
-
num_gnn_layers=NUM_GNN_LAYERS,
|
| 589 |
-
hidden_dim=HIDDEN_DIM,
|
| 590 |
-
num_decoder_layers=NUM_DECODER_LAYERS,
|
| 591 |
-
num_heads=NUM_HEADS):
|
| 592 |
-
super().__init__()
|
| 593 |
-
|
| 594 |
-
# Point cloud encoder using PointNet2-style architecture
|
| 595 |
-
self.encoder = PointNet2Encoder(pc_input_features, pc_encoder_output_features)
|
| 596 |
-
|
| 597 |
-
# Vertex prediction head using transformer decoder
|
| 598 |
-
self.vertex_head = VertexPredictionHead(
|
| 599 |
-
point_feature_dim=pc_encoder_output_features,
|
| 600 |
-
global_feature_dim=SA3_MLP[-1], # From PointNet2Encoder global feature
|
| 601 |
-
max_vertices=max_vertices,
|
| 602 |
-
vertex_coord_dim=vertex_coord_dim,
|
| 603 |
-
hidden_dim=hidden_dim,
|
| 604 |
-
num_decoder_layers=num_decoder_layers,
|
| 605 |
-
num_heads=num_heads
|
| 606 |
-
)
|
| 607 |
-
|
| 608 |
-
# Edge prediction head using GNN
|
| 609 |
-
self.edge_head = EdgePredictionHeadGNN(
|
| 610 |
-
vertex_feature_dim=vertex_coord_dim,
|
| 611 |
-
gnn_hidden_dim=gnn_hidden_dim,
|
| 612 |
-
num_gnn_layers=num_gnn_layers
|
| 613 |
-
)
|
| 614 |
-
|
| 615 |
-
def forward(self, point_cloud):
|
| 616 |
-
# point_cloud: (B, N, 3)
|
| 617 |
-
batch_size, num_points, _ = point_cloud.shape
|
| 618 |
-
|
| 619 |
-
# Encode point cloud
|
| 620 |
-
point_features, global_feature = self.encoder(point_cloud)
|
| 621 |
-
# point_features: (B, N, pc_encoder_output_features)
|
| 622 |
-
# global_feature: (B, 1024)
|
| 623 |
-
|
| 624 |
-
# Predict vertices
|
| 625 |
-
predicted_vertices, vertex_confidence = self.vertex_head(
|
| 626 |
-
point_features, global_feature, point_coords=point_cloud
|
| 627 |
-
)
|
| 628 |
-
# predicted_vertices: (B, max_vertices, 3)
|
| 629 |
-
# vertex_confidence: (B, max_vertices)
|
| 630 |
-
|
| 631 |
-
# Predict edges using GNN (using vertex coordinates directly)
|
| 632 |
-
edge_logits, edge_indices = self.edge_head(predicted_vertices)
|
| 633 |
-
# edge_logits: (B, num_potential_edges)
|
| 634 |
-
# edge_indices: (num_potential_edges, 2)
|
| 635 |
-
|
| 636 |
-
return {
|
| 637 |
-
'vertices': predicted_vertices,
|
| 638 |
-
'vertex_confidence': vertex_confidence,
|
| 639 |
-
'edge_logits': edge_logits,
|
| 640 |
-
'edge_indices': edge_indices
|
| 641 |
-
}
|
| 642 |
-
|
| 643 |
-
class WireframeDataset(Dataset):
|
| 644 |
-
def __init__(self, data_dir=DATA_DIR, split=SPLIT, transform=None, max_points=MAX_POINTS):
|
| 645 |
-
"""
|
| 646 |
-
Dataset for point cloud to wireframe conversion.
|
| 647 |
-
|
| 648 |
-
Args:
|
| 649 |
-
data_dir: Directory containing the pickle files
|
| 650 |
-
split: 'train', 'val', or 'test'
|
| 651 |
-
transform: Optional transforms to apply to point clouds
|
| 652 |
-
max_points: Maximum number of points in the point cloud (default: 8096)
|
| 653 |
-
"""
|
| 654 |
-
self.data_dir = data_dir
|
| 655 |
-
self.split = split
|
| 656 |
-
self.transform = transform
|
| 657 |
-
self.max_points = max_points
|
| 658 |
-
|
| 659 |
-
# Get all pickle files in the directory
|
| 660 |
-
self.data_files = []
|
| 661 |
-
for file in os.listdir(data_dir):
|
| 662 |
-
if file.endswith('.pkl'):
|
| 663 |
-
self.data_files.append(os.path.join(data_dir, file))
|
| 664 |
-
|
| 665 |
-
self.data_files.sort() # Ensure consistent ordering
|
| 666 |
-
|
| 667 |
-
def __len__(self):
|
| 668 |
-
return len(self.data_files)
|
| 669 |
-
|
| 670 |
-
def __getitem__(self, idx):
|
| 671 |
-
# Load the pickle file
|
| 672 |
-
with open(self.data_files[idx], 'rb') as f:
|
| 673 |
-
sample_data = pickle.load(f)
|
| 674 |
-
|
| 675 |
-
# Extract data
|
| 676 |
-
point_cloud = torch.tensor(sample_data['point_cloud'], dtype=torch.float32)
|
| 677 |
-
point_colors = torch.tensor(sample_data['point_colors'], dtype=torch.float32)
|
| 678 |
-
gt_vertices = torch.tensor(sample_data['gt_vertices'], dtype=torch.float32)
|
| 679 |
-
gt_connections = sample_data['gt_connections'] # List of tuples
|
| 680 |
-
sample_id = sample_data['sample_id']
|
| 681 |
-
|
| 682 |
-
# Handle point cloud size to match max_points
|
| 683 |
-
current_points = point_cloud.shape[0]
|
| 684 |
-
|
| 685 |
-
if current_points > self.max_points:
|
| 686 |
-
# Downsample using random sampling
|
| 687 |
-
indices = torch.randperm(current_points)[:self.max_points]
|
| 688 |
-
point_cloud = point_cloud[indices]
|
| 689 |
-
point_colors = point_colors[indices]
|
| 690 |
-
elif current_points < self.max_points:
|
| 691 |
-
# Pad by repeating last point or duplicating random points
|
| 692 |
-
pad_size = self.max_points - current_points
|
| 693 |
-
if current_points > 0:
|
| 694 |
-
# Randomly sample existing points to pad
|
| 695 |
-
pad_indices = torch.randint(0, current_points, (pad_size,))
|
| 696 |
-
pad_points = point_cloud[pad_indices]
|
| 697 |
-
pad_colors = point_colors[pad_indices]
|
| 698 |
-
point_cloud = torch.cat([point_cloud, pad_points], dim=0)
|
| 699 |
-
point_colors = torch.cat([point_colors, pad_colors], dim=0)
|
| 700 |
-
else:
|
| 701 |
-
# Edge case: no points, pad with zeros
|
| 702 |
-
point_cloud = torch.zeros(self.max_points, 3)
|
| 703 |
-
point_colors = torch.zeros(self.max_points, 3)
|
| 704 |
-
|
| 705 |
-
# Convert connections to edge format
|
| 706 |
-
if len(gt_connections) > 0:
|
| 707 |
-
edge_indices = torch.tensor(gt_connections, dtype=torch.long).t() # (2, num_edges)
|
| 708 |
-
else:
|
| 709 |
-
edge_indices = torch.zeros((2, 0), dtype=torch.long) # Empty edges
|
| 710 |
-
|
| 711 |
-
# Apply transforms if any
|
| 712 |
-
if self.transform:
|
| 713 |
-
point_cloud = self.transform(point_cloud)
|
| 714 |
-
|
| 715 |
-
return {
|
| 716 |
-
'point_cloud': point_cloud,
|
| 717 |
-
'point_colors': point_colors,
|
| 718 |
-
'gt_vertices': gt_vertices,
|
| 719 |
-
'edge_indices': edge_indices,
|
| 720 |
-
'sample_id': sample_id
|
| 721 |
-
}
|
| 722 |
-
|
| 723 |
-
def collate_fn(batch):
|
| 724 |
-
"""
|
| 725 |
-
Custom collate function to handle variable number of vertices and edges.
|
| 726 |
-
"""
|
| 727 |
-
point_clouds = []
|
| 728 |
-
point_colors = []
|
| 729 |
-
gt_vertices_list = []
|
| 730 |
-
edge_indices_list = []
|
| 731 |
-
sample_ids = []
|
| 732 |
-
|
| 733 |
-
max_vertices = 0
|
| 734 |
-
|
| 735 |
-
for sample in batch:
|
| 736 |
-
point_clouds.append(sample['point_cloud'])
|
| 737 |
-
point_colors.append(sample['point_colors'])
|
| 738 |
-
gt_vertices_list.append(sample['gt_vertices'])
|
| 739 |
-
edge_indices_list.append(sample['edge_indices'])
|
| 740 |
-
sample_ids.append(sample['sample_id'])
|
| 741 |
-
|
| 742 |
-
max_vertices = max(max_vertices, sample['gt_vertices'].shape[0])
|
| 743 |
-
|
| 744 |
-
# Pad point clouds to same size if needed
|
| 745 |
-
max_points = max(pc.shape[0] for pc in point_clouds)
|
| 746 |
-
padded_point_clouds = []
|
| 747 |
-
padded_point_colors = []
|
| 748 |
-
|
| 749 |
-
for pc, colors in zip(point_clouds, point_colors):
|
| 750 |
-
if pc.shape[0] < max_points:
|
| 751 |
-
# Pad with zeros or repeat last point
|
| 752 |
-
pad_size = max_points - pc.shape[0]
|
| 753 |
-
pc_padded = torch.cat([pc, torch.zeros(pad_size, 3)], dim=0)
|
| 754 |
-
colors_padded = torch.cat([colors, torch.zeros(pad_size, 3)], dim=0)
|
| 755 |
-
else:
|
| 756 |
-
pc_padded = pc
|
| 757 |
-
colors_padded = colors
|
| 758 |
-
|
| 759 |
-
padded_point_clouds.append(pc_padded)
|
| 760 |
-
padded_point_colors.append(colors_padded)
|
| 761 |
-
|
| 762 |
-
# Stack point clouds
|
| 763 |
-
point_clouds_batch = torch.stack(padded_point_clouds)
|
| 764 |
-
point_colors_batch = torch.stack(padded_point_colors)
|
| 765 |
-
|
| 766 |
-
# Pad vertices to max_vertices
|
| 767 |
-
padded_vertices = []
|
| 768 |
-
vertex_masks = [] # To indicate which vertices are real vs padded
|
| 769 |
-
|
| 770 |
-
for vertices in gt_vertices_list:
|
| 771 |
-
num_vertices = vertices.shape[0]
|
| 772 |
-
if num_vertices < max_vertices:
|
| 773 |
-
# Pad with zeros
|
| 774 |
-
pad_size = max_vertices - num_vertices
|
| 775 |
-
vertices_padded = torch.cat([vertices, torch.zeros(pad_size, 3)], dim=0)
|
| 776 |
-
mask = torch.cat([torch.ones(num_vertices), torch.zeros(pad_size)], dim=0).bool()
|
| 777 |
-
else:
|
| 778 |
-
vertices_padded = vertices
|
| 779 |
-
mask = torch.ones(num_vertices).bool()
|
| 780 |
-
|
| 781 |
-
padded_vertices.append(vertices_padded)
|
| 782 |
-
vertex_masks.append(mask)
|
| 783 |
-
|
| 784 |
-
gt_vertices_batch = torch.stack(padded_vertices)
|
| 785 |
-
vertex_masks_batch = torch.stack(vertex_masks)
|
| 786 |
-
|
| 787 |
-
# Create adjacency matrices for edges
|
| 788 |
-
batch_size = len(batch)
|
| 789 |
-
adjacency_matrices = torch.zeros(batch_size, max_vertices, max_vertices)
|
| 790 |
-
|
| 791 |
-
for i, edge_indices in enumerate(edge_indices_list):
|
| 792 |
-
if edge_indices.shape[1] > 0: # If there are edges
|
| 793 |
-
src, dst = edge_indices[0], edge_indices[1]
|
| 794 |
-
# Only add edges for valid vertices (within the actual vertex count)
|
| 795 |
-
valid_edges = (src < gt_vertices_list[i].shape[0]) & (dst < gt_vertices_list[i].shape[0])
|
| 796 |
-
src_valid = src[valid_edges]
|
| 797 |
-
dst_valid = dst[valid_edges]
|
| 798 |
-
adjacency_matrices[i, src_valid, dst_valid] = 1
|
| 799 |
-
adjacency_matrices[i, dst_valid, src_valid] = 1 # Undirected graph
|
| 800 |
-
|
| 801 |
-
return {
|
| 802 |
-
'point_cloud': point_clouds_batch,
|
| 803 |
-
'point_colors': point_colors_batch,
|
| 804 |
-
'gt_vertices': gt_vertices_batch,
|
| 805 |
-
'vertex_masks': vertex_masks_batch,
|
| 806 |
-
'adjacency_matrices': adjacency_matrices,
|
| 807 |
-
'edge_indices_list': edge_indices_list, # Keep original for loss computation
|
| 808 |
-
'sample_ids': sample_ids
|
| 809 |
-
}
|
| 810 |
-
|
| 811 |
-
# Loss functions
|
| 812 |
-
def compute_vertex_loss(pred_vertices, gt_vertices, vertex_masks, vertex_confidence):
|
| 813 |
-
"""
|
| 814 |
-
Compute vertex position loss using Hungarian matching
|
| 815 |
-
"""
|
| 816 |
-
batch_size = pred_vertices.shape[0]
|
| 817 |
-
total_loss = 0.0
|
| 818 |
-
total_confidence_loss = 0.0
|
| 819 |
-
|
| 820 |
-
for b in range(batch_size):
|
| 821 |
-
# Get valid GT vertices for this sample
|
| 822 |
-
valid_mask = vertex_masks[b]
|
| 823 |
-
gt_verts = gt_vertices[b][valid_mask] # (num_valid_gt, 3)
|
| 824 |
-
num_gt = gt_verts.shape[0]
|
| 825 |
-
|
| 826 |
-
if num_gt == 0:
|
| 827 |
-
# No GT vertices, penalize high confidence predictions
|
| 828 |
-
confidence_target = torch.zeros_like(vertex_confidence[b])
|
| 829 |
-
conf_loss = F.binary_cross_entropy_with_logits(vertex_confidence[b], confidence_target)
|
| 830 |
-
total_confidence_loss += conf_loss
|
| 831 |
-
continue
|
| 832 |
-
|
| 833 |
-
pred_verts = pred_vertices[b] # (max_vertices, 3)
|
| 834 |
-
pred_conf = vertex_confidence[b] # (max_vertices,)
|
| 835 |
-
|
| 836 |
-
# Compute pairwise distances between predicted and GT vertices
|
| 837 |
-
distances = torch.cdist(pred_verts, gt_verts) # (max_vertices, num_gt)
|
| 838 |
-
|
| 839 |
-
# Hungarian matching to find optimal assignment
|
| 840 |
-
|
| 841 |
-
# Convert to numpy for scipy
|
| 842 |
-
cost_matrix = distances.detach().cpu().numpy()
|
| 843 |
-
|
| 844 |
-
# Pad cost matrix if needed
|
| 845 |
-
if distances.shape[0] < distances.shape[1]:
|
| 846 |
-
# More GT vertices than predicted - pad with high cost
|
| 847 |
-
padding = np.full((distances.shape[1] - distances.shape[0], distances.shape[1]), 1e6)
|
| 848 |
-
cost_matrix = np.vstack([cost_matrix, padding])
|
| 849 |
-
elif distances.shape[0] > distances.shape[1]:
|
| 850 |
-
# More predicted vertices than GT - pad with high cost
|
| 851 |
-
padding = np.full((distances.shape[0], distances.shape[0] - distances.shape[1]), 1e6)
|
| 852 |
-
cost_matrix = np.hstack([cost_matrix, padding])
|
| 853 |
-
|
| 854 |
-
# Solve assignment problem
|
| 855 |
-
pred_indices, gt_indices = linear_sum_assignment(cost_matrix)
|
| 856 |
-
|
| 857 |
-
# Filter out dummy assignments (high cost padding)
|
| 858 |
-
# Ensure pred_indices are valid for pred_verts and gt_indices for gt_verts
|
| 859 |
-
valid_assignments = (pred_indices < pred_verts.shape[0]) & (gt_indices < num_gt)
|
| 860 |
-
pred_indices = pred_indices[valid_assignments]
|
| 861 |
-
gt_indices = gt_indices[valid_assignments]
|
| 862 |
-
|
| 863 |
-
if len(pred_indices) > 0:
|
| 864 |
-
# Compute position loss for matched vertices
|
| 865 |
-
matched_pred = pred_verts[pred_indices]
|
| 866 |
-
matched_gt = gt_verts[gt_indices]
|
| 867 |
-
position_loss = F.mse_loss(matched_pred, matched_gt)
|
| 868 |
-
total_loss += position_loss
|
| 869 |
-
|
| 870 |
-
# Confidence targets: 1 for matched vertices, 0 for unmatched
|
| 871 |
-
confidence_target = torch.zeros_like(pred_conf)
|
| 872 |
-
confidence_target[pred_indices] = 1.0
|
| 873 |
-
conf_loss = F.binary_cross_entropy_with_logits(pred_conf, confidence_target)
|
| 874 |
-
total_confidence_loss += conf_loss
|
| 875 |
-
else:
|
| 876 |
-
# No valid matches - penalize all predictions
|
| 877 |
-
confidence_target = torch.zeros_like(pred_conf)
|
| 878 |
-
conf_loss = F.binary_cross_entropy_with_logits(pred_conf, confidence_target)
|
| 879 |
-
total_confidence_loss += conf_loss
|
| 880 |
-
|
| 881 |
-
return total_loss / batch_size, total_confidence_loss / batch_size
|
| 882 |
-
|
| 883 |
-
def compute_edge_loss(edge_logits, edge_indices, gt_adjacency_matrices):
|
| 884 |
-
"""
|
| 885 |
-
Compute edge prediction loss
|
| 886 |
-
"""
|
| 887 |
-
batch_size = gt_adjacency_matrices.shape[0]
|
| 888 |
-
|
| 889 |
-
# Create edge targets from adjacency matrices
|
| 890 |
-
edge_targets = []
|
| 891 |
-
for b in range(batch_size):
|
| 892 |
-
gt_adj_for_sample = gt_adjacency_matrices[b] # Shape: (batch_max_gt_verts, batch_max_gt_verts)
|
| 893 |
-
|
| 894 |
-
# Create a target adjacency matrix of size (MAX_VERTICES, MAX_VERTICES)
|
| 895 |
-
# as edge_indices are generated based on the global MAX_VERTICES.
|
| 896 |
-
target_adj_full_size = torch.zeros(
|
| 897 |
-
MAX_VERTICES,
|
| 898 |
-
MAX_VERTICES,
|
| 899 |
-
device=gt_adj_for_sample.device,
|
| 900 |
-
dtype=gt_adj_for_sample.dtype
|
| 901 |
-
)
|
| 902 |
-
|
| 903 |
-
# Determine the actual dimension of the current sample's GT adjacency matrix (padded to batch max)
|
| 904 |
-
current_gt_dim = gt_adj_for_sample.shape[0]
|
| 905 |
-
|
| 906 |
-
# Copy the relevant part of gt_adj_for_sample into the full-sized target matrix.
|
| 907 |
-
# The copy_dim is the minimum of MAX_VERTICES and the current GT dimension,
|
| 908 |
-
# ensuring we don't read out of bounds from gt_adj_for_sample or write out of bounds to target_adj_full_size.
|
| 909 |
-
copy_dim = min(MAX_VERTICES, current_gt_dim)
|
| 910 |
-
|
| 911 |
-
target_adj_full_size[:copy_dim, :copy_dim] = gt_adj_for_sample[:copy_dim, :copy_dim]
|
| 912 |
-
|
| 913 |
-
# Extract targets using edge_indices, which refer to pairs in a MAX_VERTICES graph.
|
| 914 |
-
targets = target_adj_full_size[edge_indices[:, 0], edge_indices[:, 1]]
|
| 915 |
-
edge_targets.append(targets)
|
| 916 |
-
|
| 917 |
-
edge_targets = torch.stack(edge_targets) # Shape: (batch_size, num_potential_edges_in_MAX_VERTICES_graph)
|
| 918 |
-
edge_targets = edge_targets.to(edge_logits.device)
|
| 919 |
-
|
| 920 |
-
# Binary cross entropy loss
|
| 921 |
-
edge_loss = F.binary_cross_entropy_with_logits(edge_logits, edge_targets)
|
| 922 |
-
|
| 923 |
-
return edge_loss
|
| 924 |
-
|
| 925 |
-
def compute_total_loss(model_output, batch):
|
| 926 |
-
"""
|
| 927 |
-
Compute total loss combining vertex and edge losses
|
| 928 |
-
"""
|
| 929 |
-
# Extract model outputs
|
| 930 |
-
pred_vertices = model_output['vertices']
|
| 931 |
-
vertex_confidence = model_output['vertex_confidence']
|
| 932 |
-
edge_logits = model_output['edge_logits']
|
| 933 |
-
edge_indices = model_output['edge_indices']
|
| 934 |
-
|
| 935 |
-
# Extract ground truth
|
| 936 |
-
gt_vertices = batch['gt_vertices'].to(DEVICE)
|
| 937 |
-
vertex_masks = batch['vertex_masks'].to(DEVICE)
|
| 938 |
-
gt_adjacency = batch['adjacency_matrices'].to(DEVICE)
|
| 939 |
-
|
| 940 |
-
# Compute individual losses
|
| 941 |
-
vertex_pos_loss, vertex_conf_loss = compute_vertex_loss(
|
| 942 |
-
pred_vertices, gt_vertices, vertex_masks, vertex_confidence
|
| 943 |
-
)
|
| 944 |
-
edge_loss = compute_edge_loss(edge_logits, edge_indices, gt_adjacency)
|
| 945 |
-
|
| 946 |
-
# Combine losses
|
| 947 |
-
total_loss = (VERTEX_LOSS_WEIGHT * vertex_pos_loss +
|
| 948 |
-
CONFIDENCE_LOSS_WEIGHT * vertex_conf_loss +
|
| 949 |
-
EDGE_LOSS_WEIGHT * edge_loss)
|
| 950 |
-
|
| 951 |
-
return {
|
| 952 |
-
'total_loss': total_loss,
|
| 953 |
-
'vertex_pos_loss': vertex_pos_loss,
|
| 954 |
-
'vertex_conf_loss': vertex_conf_loss,
|
| 955 |
-
'edge_loss': edge_loss
|
| 956 |
-
}
|
| 957 |
-
|
| 958 |
-
# =============================================================================
|
| 959 |
-
# MAIN TRAINING SCRIPT
|
| 960 |
-
# =============================================================================
|
| 961 |
-
|
| 962 |
-
if __name__ == '__main__':
|
| 963 |
-
# Create dataset and dataloader
|
| 964 |
-
dataset = WireframeDataset(data_dir=DATA_DIR, split=SPLIT)
|
| 965 |
-
dataloader = DataLoader(
|
| 966 |
-
dataset,
|
| 967 |
-
batch_size=BATCH_SIZE,
|
| 968 |
-
shuffle=True,
|
| 969 |
-
collate_fn=collate_fn,
|
| 970 |
-
num_workers=NUM_WORKERS
|
| 971 |
-
)
|
| 972 |
-
|
| 973 |
-
# Initialize model
|
| 974 |
-
model = PointCloudToWireframe()
|
| 975 |
-
|
| 976 |
-
# Move model to device
|
| 977 |
-
model = model.to(DEVICE)
|
| 978 |
-
print(f"Model loaded on device: {DEVICE}")
|
| 979 |
-
|
| 980 |
-
# Initialize optimizer and scheduler
|
| 981 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
|
| 982 |
-
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 983 |
-
optimizer, mode='min', factor=LR_SCHEDULER_FACTOR, patience=LR_SCHEDULER_PATIENCE
|
| 984 |
-
)
|
| 985 |
-
|
| 986 |
-
# Training loop
|
| 987 |
-
model.train()
|
| 988 |
-
print("Starting training...")
|
| 989 |
-
|
| 990 |
-
for epoch in range(NUM_EPOCHS):
|
| 991 |
-
epoch_losses = {
|
| 992 |
-
'total_loss': 0.0,
|
| 993 |
-
'vertex_pos_loss': 0.0,
|
| 994 |
-
'vertex_conf_loss': 0.0,
|
| 995 |
-
'edge_loss': 0.0
|
| 996 |
-
}
|
| 997 |
-
num_batches = 0
|
| 998 |
-
|
| 999 |
-
for batch_idx, batch in enumerate(dataloader):
|
| 1000 |
-
# Move data to device
|
| 1001 |
-
point_cloud = batch['point_cloud'].to(DEVICE)
|
| 1002 |
-
|
| 1003 |
-
# Zero gradients
|
| 1004 |
-
optimizer.zero_grad()
|
| 1005 |
-
|
| 1006 |
-
# Forward pass
|
| 1007 |
-
output = model(point_cloud)
|
| 1008 |
-
|
| 1009 |
-
# Compute losses
|
| 1010 |
-
losses = compute_total_loss(output, batch)
|
| 1011 |
-
|
| 1012 |
-
# Backward pass
|
| 1013 |
-
losses['total_loss'].backward()
|
| 1014 |
-
|
| 1015 |
-
# Gradient clipping
|
| 1016 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRADIENT_CLIP_MAX_NORM)
|
| 1017 |
-
|
| 1018 |
-
# Update weights
|
| 1019 |
-
optimizer.step()
|
| 1020 |
-
|
| 1021 |
-
# Accumulate losses
|
| 1022 |
-
for key in epoch_losses:
|
| 1023 |
-
epoch_losses[key] += losses[key].item()
|
| 1024 |
-
num_batches += 1
|
| 1025 |
-
|
| 1026 |
-
# Print progress
|
| 1027 |
-
if batch_idx % LOG_FREQUENCY == 0:
|
| 1028 |
-
print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Batch {batch_idx}/{len(dataloader)}")
|
| 1029 |
-
print(f" Total Loss: {losses['total_loss'].item():.4f}")
|
| 1030 |
-
print(f" Vertex Pos Loss: {losses['vertex_pos_loss'].item():.4f}")
|
| 1031 |
-
print(f" Vertex Conf Loss: {losses['vertex_conf_loss'].item():.4f}")
|
| 1032 |
-
print(f" Edge Loss: {losses['edge_loss'].item():.4f}")
|
| 1033 |
-
|
| 1034 |
-
# Average losses for the epoch
|
| 1035 |
-
for key in epoch_losses:
|
| 1036 |
-
epoch_losses[key] /= num_batches
|
| 1037 |
-
|
| 1038 |
-
# Update learning rate scheduler
|
| 1039 |
-
scheduler.step(epoch_losses['total_loss'])
|
| 1040 |
-
|
| 1041 |
-
# Print epoch summary
|
| 1042 |
-
print(f"\nEpoch {epoch+1} Summary:")
|
| 1043 |
-
print(f" Avg Total Loss: {epoch_losses['total_loss']:.4f}")
|
| 1044 |
-
print(f" Avg Vertex Pos Loss: {epoch_losses['vertex_pos_loss']:.4f}")
|
| 1045 |
-
print(f" Avg Vertex Conf Loss: {epoch_losses['vertex_conf_loss']:.4f}")
|
| 1046 |
-
print(f" Avg Edge Loss: {epoch_losses['edge_loss']:.4f}")
|
| 1047 |
-
print(f" Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
|
| 1048 |
-
print("-" * 50)
|
| 1049 |
-
|
| 1050 |
-
# Save checkpoint every epoch
|
| 1051 |
-
if (epoch + 1) % CHECKPOINT_SAVE_FREQUENCY == 0:
|
| 1052 |
-
checkpoint = {
|
| 1053 |
-
'epoch': epoch + 1,
|
| 1054 |
-
'model_state_dict': model.state_dict(),
|
| 1055 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 1056 |
-
'scheduler_state_dict': scheduler.state_dict(),
|
| 1057 |
-
'losses': epoch_losses,
|
| 1058 |
-
'config': {
|
| 1059 |
-
'pc_input_features': PC_INPUT_FEATURES,
|
| 1060 |
-
'pc_encoder_output_features': PC_ENCODER_OUTPUT_FEATURES,
|
| 1061 |
-
'max_vertices': MAX_VERTICES,
|
| 1062 |
-
'gnn_hidden_dim': GNN_HIDDEN_DIM,
|
| 1063 |
-
'num_gnn_layers': NUM_GNN_LAYERS
|
| 1064 |
-
}
|
| 1065 |
-
}
|
| 1066 |
-
torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
|
| 1067 |
-
print(f"Checkpoint saved: checkpoint_epoch_{epoch+1}.pth")
|
| 1068 |
-
|
| 1069 |
-
# Save final model
|
| 1070 |
-
torch.save({
|
| 1071 |
-
'model_state_dict': model.state_dict(),
|
| 1072 |
-
'model_config': {
|
| 1073 |
-
'pc_input_features': PC_INPUT_FEATURES,
|
| 1074 |
-
'pc_encoder_output_features': PC_ENCODER_OUTPUT_FEATURES,
|
| 1075 |
-
'max_vertices': MAX_VERTICES,
|
| 1076 |
-
'gnn_hidden_dim': GNN_HIDDEN_DIM,
|
| 1077 |
-
'num_gnn_layers': NUM_GNN_LAYERS
|
| 1078 |
-
}
|
| 1079 |
-
}, 'final_model.pth')
|
| 1080 |
-
|
| 1081 |
-
print("Training completed!")
|
| 1082 |
-
print(f"Dataset size: {len(dataset)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generate_pcloud_dataset.py
CHANGED
|
@@ -1,3 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from datasets import load_dataset
|
| 2 |
from hoho2025.viz3d import *
|
| 3 |
import os
|
|
@@ -55,3 +66,4 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
|
|
| 55 |
|
| 56 |
print(f"Generated {counter} samples in {output_dir}")
|
| 57 |
|
|
|
|
|
|
| 1 |
+
# This script processes the 'usm3d/hoho25k' dataset.
|
| 2 |
+
# For each sample in the dataset, it performs the following steps:
|
| 3 |
+
# 1. Reads COLMAP reconstruction data.
|
| 4 |
+
# 2. Extracts 3D point coordinates and their corresponding colors.
|
| 5 |
+
# 3. Retrieves ground truth wireframe vertices and edges.
|
| 6 |
+
# 4. Skips processing if the output file already exists or if no 3D points are found.
|
| 7 |
+
# 5. Saves the extracted point cloud, colors, ground truth data, and sample ID
|
| 8 |
+
# into a pickle file in a specified output directory.
|
| 9 |
+
# The script shuffles the dataset before processing and keeps track of
|
| 10 |
+
# the number of samples successfully processed and saved.
|
| 11 |
+
#
|
| 12 |
from datasets import load_dataset
|
| 13 |
from hoho2025.viz3d import *
|
| 14 |
import os
|
|
|
|
| 66 |
|
| 67 |
print(f"Generated {counter} samples in {output_dir}")
|
| 68 |
|
| 69 |
+
|
hoho_cpu.batch
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --nodes=1 # 1 node
|
| 3 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 4 |
-
#SBATCH --cpus-per-task=8 # 6 CPUS per task = 12 CPUS per node
|
| 5 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 6 |
-
#SBATCH --time=24:00:00 # time limits: 1 hour
|
| 7 |
-
#SBATCH --error=hoho_cpu.err # standard error file
|
| 8 |
-
#SBATCH --output=hoho_cpu.out # standard output file
|
| 9 |
-
#SBATCH --partition=amd # partition name
|
| 10 |
-
#SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
|
| 11 |
-
#SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
|
| 12 |
-
|
| 13 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 14 |
-
module purge
|
| 15 |
-
module load Python/3.10.8-GCCcore-12.2.0
|
| 16 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 17 |
-
python train.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hoho_cpu_gpu_intel.batch
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --nodes=1 # 1 node
|
| 3 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 4 |
-
#SBATCH --cpus-per-task=8 # 6 CPUS per task = 12 CPUS per node
|
| 5 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 6 |
-
#SBATCH --time=24:00:00 # time limits: 1 hour
|
| 7 |
-
#SBATCH --error=hoho_cpu.err # standard error file
|
| 8 |
-
#SBATCH --output=hoho_cpu.out # standard output file
|
| 9 |
-
#SBATCH --partition=gpu # partition name
|
| 10 |
-
#SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
|
| 11 |
-
#SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
|
| 12 |
-
#SBATCH --gres=gpu:1
|
| 13 |
-
|
| 14 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 15 |
-
module purge
|
| 16 |
-
module load Python/3.10.8-GCCcore-12.2.0
|
| 17 |
-
module load CUDA/12.6.0
|
| 18 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 19 |
-
python train.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hoho_gpu.batch
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --nodes=1 # 1 node
|
| 3 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 4 |
-
#SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
|
| 5 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 6 |
-
#SBATCH --time=24:00:00 # time limits: 1 hour
|
| 7 |
-
#SBATCH --error=hoho_gpu.err # standard error file
|
| 8 |
-
#SBATCH --output=hoho_gpu.out # standard output file
|
| 9 |
-
#SBATCH --partition=amdgpu # partition name
|
| 10 |
-
#SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
|
| 11 |
-
#SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
|
| 12 |
-
#SBATCH --gres=gpu:1
|
| 13 |
-
|
| 14 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 15 |
-
module purge
|
| 16 |
-
module load Python/3.10.8-GCCcore-12.2.0
|
| 17 |
-
module load CUDA/12.6.0
|
| 18 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 19 |
-
python train_pnet_cluster.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hoho_gpu_class.batch
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --nodes=1 # 1 node
|
| 3 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 4 |
-
#SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
|
| 5 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 6 |
-
#SBATCH --time=24:00:00 # time limits: 1 hour
|
| 7 |
-
#SBATCH --error=hoho_gpu_class2_v4.err # standard error file
|
| 8 |
-
#SBATCH --output=hoho_gpu_class2_v4.out # standard output file
|
| 9 |
-
#SBATCH --partition=amdgpu # partition name
|
| 10 |
-
#SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
|
| 11 |
-
#SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
|
| 12 |
-
#SBATCH --gres=gpu:1
|
| 13 |
-
|
| 14 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 15 |
-
module purge
|
| 16 |
-
module load Python/3.10.8-GCCcore-12.2.0
|
| 17 |
-
module load CUDA/12.6.0
|
| 18 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 19 |
-
python train_pnet_class_cluster.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hoho_gpu_class_10d.batch
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --nodes=1 # 1 node
|
| 3 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 4 |
-
#SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
|
| 5 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 6 |
-
#SBATCH --time=24:00:00 # time limits: 1 hour
|
| 7 |
-
#SBATCH --error=hoho_gpu_class_10d_v2.err # standard error file
|
| 8 |
-
#SBATCH --output=hoho_gpu_class_10d_v2.out # standard output file
|
| 9 |
-
#SBATCH --partition=amdgpu # partition name
|
| 10 |
-
#SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
|
| 11 |
-
#SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
|
| 12 |
-
#SBATCH --gres=gpu:1
|
| 13 |
-
|
| 14 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 15 |
-
module purge
|
| 16 |
-
module load Python/3.10.8-GCCcore-12.2.0
|
| 17 |
-
module load CUDA/12.6.0
|
| 18 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 19 |
-
python train_pnet_class_cluster_10d.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hoho_gpu_class_10d_2048.batch
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --nodes=1 # 1 node
|
| 3 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 4 |
-
#SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
|
| 5 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 6 |
-
#SBATCH --time=24:00:00 # time limits: 1 hour
|
| 7 |
-
#SBATCH --error=hoho_gpu_class_10d_2048.err # standard error file
|
| 8 |
-
#SBATCH --output=hoho_gpu_class_10d_2048.out # standard output file
|
| 9 |
-
#SBATCH --partition=amdgpu # partition name
|
| 10 |
-
#SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
|
| 11 |
-
#SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
|
| 12 |
-
#SBATCH --gres=gpu:1
|
| 13 |
-
|
| 14 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 15 |
-
module purge
|
| 16 |
-
module load Python/3.10.8-GCCcore-12.2.0
|
| 17 |
-
module load CUDA/12.6.0
|
| 18 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 19 |
-
python train_pnet_class_cluster_10d_2048.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hoho_gpu_class_10d_deeper.batch
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --nodes=1 # 1 node
|
| 3 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 4 |
-
#SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
|
| 5 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 6 |
-
#SBATCH --time=24:00:00 # time limits: 1 hour
|
| 7 |
-
#SBATCH --error=hoho_gpu_class_10d_v2_deeper_v2.err # standard error file
|
| 8 |
-
#SBATCH --output=hoho_gpu_class_10d_v2_deeper_v2.out # standard output file
|
| 9 |
-
#SBATCH --partition=amdgpu # partition name
|
| 10 |
-
#SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
|
| 11 |
-
#SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
|
| 12 |
-
#SBATCH --gres=gpu:1
|
| 13 |
-
|
| 14 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 15 |
-
module purge
|
| 16 |
-
module load Python/3.10.8-GCCcore-12.2.0
|
| 17 |
-
module load CUDA/12.6.0
|
| 18 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 19 |
-
python train_pnet_class_cluster_10d_deeper.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hoho_gpu_h200.batch
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --nodes=1 # 1 node
|
| 3 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 4 |
-
#SBATCH --cpus-per-task=20 # 6 CPUS per task = 12 CPUS per node
|
| 5 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 6 |
-
#SBATCH --time=24:00:00 # time limits: 1 hour
|
| 7 |
-
#SBATCH --error=hoho_gpu_h200_v2_class.err # standard error file
|
| 8 |
-
#SBATCH --output=hoho_gpu_h200_v2_class.out # standard output file
|
| 9 |
-
#SBATCH --partition=h200 # partition name
|
| 10 |
-
#SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
|
| 11 |
-
#SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
|
| 12 |
-
#SBATCH --gres=gpu:1
|
| 13 |
-
|
| 14 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 15 |
-
module purge
|
| 16 |
-
module load Python/3.12.3-GCCcore-13.3.0
|
| 17 |
-
module load CUDA/12.6.0
|
| 18 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 19 |
-
python train_pnet_cluster_class_v2.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hoho_gpu_voxel.batch
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --nodes=1 # 1 node
|
| 3 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 4 |
-
#SBATCH --cpus-per-task=16 # 6 CPUS per task = 12 CPUS per node
|
| 5 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 6 |
-
#SBATCH --time=24:00:00 # time limits: 1 hour
|
| 7 |
-
#SBATCH --error=hoho_gpu.err # standard error file
|
| 8 |
-
#SBATCH --output=hoho_gpu.out # standard output file
|
| 9 |
-
#SBATCH --partition=amdgpu # partition name
|
| 10 |
-
#SBATCH --mail-user=skvrnjan@fel.cvut.cz # where send info about job
|
| 11 |
-
#SBATCH --mail-type=ALL # what to send, valid type values are NONE, BEGIN, END, FAIL, REQUEUE, ALL
|
| 12 |
-
#SBATCH --gres=gpu:1
|
| 13 |
-
|
| 14 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 15 |
-
module purge
|
| 16 |
-
module load Python/3.10.8-GCCcore-12.2.0
|
| 17 |
-
module load CUDA/12.6.0
|
| 18 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 19 |
-
python train_voxel_cluster.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
initial_epoch_100.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8d9de42aafd8ca9f0e831c920a41d35f61f05ecf6f96c0227a46d16c34cd861c
|
| 3 |
-
size 93364299
|
|
|
|
|
|
|
|
|
|
|
|
initial_epoch_100_class_v2.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:af653e5fe08dc57b2bb84996896c60d06a71e1c1a0197ad0e07ee2d03dc080e8
|
| 3 |
-
size 92609251
|
|
|
|
|
|
|
|
|
|
|
|
initial_epoch_100_v2.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:be55556839f8d4fedad7a5cf7520b48859d7bd9a3fbb6b2efbae627ca8ca3ffc
|
| 3 |
-
size 103080355
|
|
|
|
|
|
|
|
|
|
|
|
initial_epoch_100_v2_aug.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:66ba39359176a2d4d9f444912c58f24d5039b49c802dd8c2be45bcba10694054
|
| 3 |
-
size 103080355
|
|
|
|
|
|
|
|
|
|
|
|
initial_epoch_60.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:ef3f7c9462b297447ee24b7994919d625b958d1941626215d1d41f27faf7dac1
|
| 3 |
-
size 93364051
|
|
|
|
|
|
|
|
|
|
|
|
initial_epoch_60_v2.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:759a5e61e5934aa2417d587621b10a0b390c821c414380185c75fc1379e07f90
|
| 3 |
-
size 103080040
|
|
|
|
|
|
|
|
|
|
|
|
iterate.batch
DELETED
|
@@ -1,50 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
|
| 3 |
-
# Define parameter ranges
|
| 4 |
-
vertex_min=0.4
|
| 5 |
-
vertex_max=0.9
|
| 6 |
-
vertex_step=0.02
|
| 7 |
-
|
| 8 |
-
edge_min=0.4
|
| 9 |
-
edge_max=0.9
|
| 10 |
-
edge_step=0.02
|
| 11 |
-
|
| 12 |
-
# Define results directory path or use first argument
|
| 13 |
-
results_dir=${1:-"/mnt/personal/skvrnjan/hoho/results"}
|
| 14 |
-
|
| 15 |
-
# Create results directory if it doesn't exist
|
| 16 |
-
mkdir -p $results_dir
|
| 17 |
-
|
| 18 |
-
# Iterate over all combinations
|
| 19 |
-
for vertex_thresh in $(seq $vertex_min $vertex_step $vertex_max); do
|
| 20 |
-
for edge_thresh in $(seq $edge_min $edge_step $edge_max); do
|
| 21 |
-
# Create job name
|
| 22 |
-
job_name="v10_train_v${vertex_thresh}_e${edge_thresh}"
|
| 23 |
-
|
| 24 |
-
# Create SLURM script
|
| 25 |
-
cat > "${job_name}.slurm" << EOF
|
| 26 |
-
#!/bin/bash
|
| 27 |
-
#SBATCH --job-name=${job_name}
|
| 28 |
-
#SBATCH --output=${job_name}_%j.out
|
| 29 |
-
#SBATCH --error=${job_name}_%j.err
|
| 30 |
-
#SBATCH --time=4:00:00
|
| 31 |
-
#SBATCH --partition=amdfast # partition name
|
| 32 |
-
#SBATCH --cpus-per-task=4 # 6 CPUS per task = 12 CPUS per node
|
| 33 |
-
#SBATCH --mem-per-cpu=10G # 8GB per CPU = 96GB per node
|
| 34 |
-
#SBATCH --nodes=1 # 1 node
|
| 35 |
-
#SBATCH --ntasks-per-node=1 # 1 tasks per node
|
| 36 |
-
|
| 37 |
-
# Run training with specific parameters
|
| 38 |
-
cd /mnt/personal/skvrnjan/hoho/
|
| 39 |
-
module purge
|
| 40 |
-
module load Python/3.10.8-GCCcore-12.2.0
|
| 41 |
-
source /mnt/personal/skvrnjan/venvs/hoho/bin/activate
|
| 42 |
-
python train.py --vertex_threshold ${vertex_thresh} --edge_threshold ${edge_thresh} --results_dir $results_dir/${job_name} --max_samples 100
|
| 43 |
-
EOF
|
| 44 |
-
|
| 45 |
-
# Submit job
|
| 46 |
-
sbatch "${job_name}.slurm"
|
| 47 |
-
|
| 48 |
-
echo "Submitted job: ${job_name}"
|
| 49 |
-
done
|
| 50 |
-
done
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pnet.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be55556839f8d4fedad7a5cf7520b48859d7bd9a3fbb6b2efbae627ca8ca3ffc
|
| 3 |
+
size 103080355
|
predict.py
CHANGED
|
@@ -1,3 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from typing import Tuple, List
|
| 3 |
from hoho2025.example_solutions import empty_solution, read_colmap_rec, get_vertices_and_edges_from_segmentation, get_house_mask, fit_scale_robust_median, get_uv_depth, merge_vertices_3d, prune_not_connected, prune_too_far, point_to_segment_dist
|
|
|
|
| 1 |
+
# This script is designed for 3D wireframe reconstruction, primarily focusing on
|
| 2 |
+
# buildings, using multi-view imagery and associated 3D data.
|
| 3 |
+
# It leverages COLMAP reconstructions, depth maps, and semantic segmentations
|
| 4 |
+
# (ADE20k and Gestalt) to identify and predict structural elements.
|
| 5 |
+
# Core tasks include:
|
| 6 |
+
# - Processing and aligning 2D image data (segmentations, depth) with 3D COLMAP point clouds.
|
| 7 |
+
# - Extracting initial 2D/3D vertex candidates from segmentation maps.
|
| 8 |
+
# - Generating local point cloud patches around these candidates.
|
| 9 |
+
# - Employing machine learning models (e.g., PointNet variants) to refine vertex locations
|
| 10 |
+
# and classify potential edges between them.
|
| 11 |
+
# - Optionally, generating datasets of these patches for training ML models.
|
| 12 |
+
# - Merging information from multiple views to produce a final 3D wireframe.
|
| 13 |
import numpy as np
|
| 14 |
from typing import Tuple, List
|
| 15 |
from hoho2025.example_solutions import empty_solution, read_colmap_rec, get_vertices_and_edges_from_segmentation, get_house_mask, fit_scale_robust_median, get_uv_depth, merge_vertices_3d, prune_not_connected, prune_too_far, point_to_segment_dist
|
predict_end.py
DELETED
|
@@ -1,73 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from typing import Tuple, List
|
| 3 |
-
import numpy as np
|
| 4 |
-
|
| 5 |
-
from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping
|
| 6 |
-
from predict import create_pcloud, convert_entry_to_human_readable, empty_solution
|
| 7 |
-
from end_to_end import save_data
|
| 8 |
-
|
| 9 |
-
data_folder = '/mnt/personal/skvrnjan/hoho_end/'
|
| 10 |
-
|
| 11 |
-
def predict_wireframe(entry, config) -> Tuple[np.ndarray, List[int]]:
|
| 12 |
-
"""
|
| 13 |
-
Predict 3D wireframe from a dataset entry.
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 17 |
-
|
| 18 |
-
good_entry = convert_entry_to_human_readable(entry)
|
| 19 |
-
colmap_rec = good_entry['colmap_binary']
|
| 20 |
-
|
| 21 |
-
colmap_pcloud = create_pcloud(colmap_rec, good_entry)
|
| 22 |
-
|
| 23 |
-
pcloud_14d = pcloud_7d_to_14d(colmap_pcloud)
|
| 24 |
-
|
| 25 |
-
dict_to_save = {'pcloud_14d': pcloud_14d,
|
| 26 |
-
'wf_vertices': good_entry['wf_vertices'],
|
| 27 |
-
'wf_edges': good_entry['wf_edges']}
|
| 28 |
-
|
| 29 |
-
save_data(dict_to_save, good_entry['order_id'], data_folder=data_folder)
|
| 30 |
-
|
| 31 |
-
return empty_solution()
|
| 32 |
-
|
| 33 |
-
def pcloud_7d_to_14d(pcloud_7d: np.ndarray) -> np.ndarray:
|
| 34 |
-
"""
|
| 35 |
-
Convert 7D point cloud to higher dimensional by removing ID, then adding ADE and Gestalt segmentation
|
| 36 |
-
with bin counting for edge classes.
|
| 37 |
-
|
| 38 |
-
Args:
|
| 39 |
-
pcloud_7d: Array of shape (N, 7) containing [x, y, z, r, g, b, confidence]
|
| 40 |
-
|
| 41 |
-
Returns:
|
| 42 |
-
Array of shape (N, 15) containing [x, y, z, r, g, b, ade_class, apex, eave_end_point,
|
| 43 |
-
flashing_end_points, eave, ridge, rake, valley, gestalt_rgb]
|
| 44 |
-
"""
|
| 45 |
-
edge_classes = ['apex', 'eave_end_point', 'flashing_end_points', 'eave', 'ridge', 'rake', 'valley']
|
| 46 |
-
|
| 47 |
-
# Extract ADE and Gestalt data from colmap_pcloud
|
| 48 |
-
ade_values = pcloud_7d['ade']
|
| 49 |
-
gestalt_values = pcloud_7d['gestalt']
|
| 50 |
-
point_cloud = pcloud_7d['points_7d']
|
| 51 |
-
|
| 52 |
-
# Initialize output array (6D base + 1 ADE + 7 edge classes + 1 gestalt)
|
| 53 |
-
pcloud_14d = np.zeros((point_cloud.shape[0], 14))
|
| 54 |
-
pcloud_14d[:, :6] = point_cloud[:, :6] # Remove confidence/ID column
|
| 55 |
-
|
| 56 |
-
# Process ADE segmentation
|
| 57 |
-
pcloud_14d[:, 6] = ade_values
|
| 58 |
-
pcloud_14d[:, 3:6] = pcloud_14d[:, 3:6] * 2 - 1
|
| 59 |
-
|
| 60 |
-
# Process Gestalt segmentation with edge class bin counting
|
| 61 |
-
for i, gestalt_list in enumerate(gestalt_values):
|
| 62 |
-
if len(gestalt_list) > 0:
|
| 63 |
-
gestalt_array = np.array(gestalt_list, dtype=np.uint32)
|
| 64 |
-
if gestalt_array.ndim == 2:
|
| 65 |
-
# Bin counting for edge classes (columns 7-13)
|
| 66 |
-
for j, edge_class in enumerate(edge_classes):
|
| 67 |
-
if edge_class in gestalt_color_mapping:
|
| 68 |
-
target_color = np.array(gestalt_color_mapping[edge_class])
|
| 69 |
-
# Count matches for this edge class
|
| 70 |
-
matches = np.sum(np.all(gestalt_array == target_color, axis=1))
|
| 71 |
-
pcloud_14d[i, 7 + j] = matches
|
| 72 |
-
|
| 73 |
-
return pcloud_14d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
script.py
CHANGED
|
@@ -74,7 +74,7 @@ if __name__ == "__main__":
|
|
| 74 |
|
| 75 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 76 |
|
| 77 |
-
pnet_model = load_pointnet_model(model_path="
|
| 78 |
|
| 79 |
pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
|
| 80 |
|
|
|
|
| 74 |
|
| 75 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 76 |
|
| 77 |
+
pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
|
| 78 |
|
| 79 |
pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
|
| 80 |
|
train.py
CHANGED
|
@@ -1,3 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from datasets import load_dataset
|
| 2 |
from hoho2025.vis import plot_all_modalities
|
| 3 |
from hoho2025.viz3d import *
|
|
@@ -16,11 +28,8 @@ from utils import read_colmap_rec, empty_solution
|
|
| 16 |
from hoho2025.metric_helper import hss
|
| 17 |
from predict import predict_wireframe, predict_wireframe_old
|
| 18 |
from tqdm import tqdm
|
| 19 |
-
#from fast_pointnet import load_pointnet_model
|
| 20 |
from fast_pointnet_v2 import load_pointnet_model
|
| 21 |
-
from
|
| 22 |
-
from fast_pointnet_class_v2 import load_pointnet_model as load_pointnet_class_model
|
| 23 |
-
from fast_pointnet_class_10d import load_pointnet_model as load_pointnet_class_model_10d
|
| 24 |
import torch
|
| 25 |
import time
|
| 26 |
|
|
@@ -47,15 +56,15 @@ print(f"Running with configuration: {config}")
|
|
| 47 |
os.makedirs(args.results_dir, exist_ok=True)
|
| 48 |
|
| 49 |
|
| 50 |
-
|
| 51 |
-
ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
|
| 52 |
-
ds = ds.shuffle()
|
| 53 |
|
| 54 |
scores_hss = []
|
| 55 |
scores_f1 = []
|
| 56 |
scores_iou = []
|
| 57 |
|
| 58 |
-
show_visu =
|
| 59 |
|
| 60 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 61 |
|
|
@@ -105,10 +114,10 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
|
|
| 105 |
colmap = read_colmap_rec(a['colmap_binary'])
|
| 106 |
pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
|
| 107 |
wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
|
| 108 |
-
wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
|
| 109 |
bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
|
| 110 |
|
| 111 |
-
visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
|
| 112 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 113 |
|
| 114 |
idx += 1
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training and evaluation script for HoHo wireframe prediction model.
|
| 3 |
+
This script loads the HoHo25k dataset, processes samples through a wireframe prediction pipeline
|
| 4 |
+
using PointNet models, and evaluates performance using HSS, F1, and IoU metrics. It supports
|
| 5 |
+
configurable thresholds, visualization of results, and saves detailed performance metrics to files.
|
| 6 |
+
Key features:
|
| 7 |
+
- Command-line argument support for model configuration
|
| 8 |
+
- PointNet-based vertex and edge prediction
|
| 9 |
+
- Real-time performance monitoring and visualization
|
| 10 |
+
- Comprehensive metric evaluation and result logging
|
| 11 |
+
- Support for CUDA acceleration when available
|
| 12 |
+
"""
|
| 13 |
from datasets import load_dataset
|
| 14 |
from hoho2025.vis import plot_all_modalities
|
| 15 |
from hoho2025.viz3d import *
|
|
|
|
| 28 |
from hoho2025.metric_helper import hss
|
| 29 |
from predict import predict_wireframe, predict_wireframe_old
|
| 30 |
from tqdm import tqdm
|
|
|
|
| 31 |
from fast_pointnet_v2 import load_pointnet_model
|
| 32 |
+
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
|
|
|
|
|
|
|
| 33 |
import torch
|
| 34 |
import time
|
| 35 |
|
|
|
|
| 56 |
os.makedirs(args.results_dir, exist_ok=True)
|
| 57 |
|
| 58 |
|
| 59 |
+
ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
|
| 60 |
+
#ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
|
| 61 |
+
#ds = ds.shuffle()
|
| 62 |
|
| 63 |
scores_hss = []
|
| 64 |
scores_f1 = []
|
| 65 |
scores_iou = []
|
| 66 |
|
| 67 |
+
show_visu = True
|
| 68 |
|
| 69 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 70 |
|
|
|
|
| 114 |
colmap = read_colmap_rec(a['colmap_binary'])
|
| 115 |
pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
|
| 116 |
wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
|
| 117 |
+
#wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
|
| 118 |
bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
|
| 119 |
|
| 120 |
+
visu_all = [pcd] + geometries + wireframe + bpo_cams #+ wireframe2
|
| 121 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 122 |
|
| 123 |
idx += 1
|
train_end.py
DELETED
|
@@ -1,73 +0,0 @@
|
|
| 1 |
-
from datasets import load_dataset
|
| 2 |
-
from hoho2025.vis import plot_all_modalities
|
| 3 |
-
from hoho2025.viz3d import *
|
| 4 |
-
import open3d as o3d
|
| 5 |
-
|
| 6 |
-
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
|
| 7 |
-
from utils import read_colmap_rec, empty_solution
|
| 8 |
-
|
| 9 |
-
#from hoho2025.example_solutions import predict_wireframe
|
| 10 |
-
from hoho2025.metric_helper import hss
|
| 11 |
-
from predict import predict_wireframe_old
|
| 12 |
-
from predict_end import predict_wireframe
|
| 13 |
-
from tqdm import tqdm
|
| 14 |
-
import torch
|
| 15 |
-
import time
|
| 16 |
-
|
| 17 |
-
#ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
|
| 18 |
-
ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
|
| 19 |
-
ds = ds.shuffle()
|
| 20 |
-
|
| 21 |
-
scores_hss = []
|
| 22 |
-
scores_f1 = []
|
| 23 |
-
scores_iou = []
|
| 24 |
-
|
| 25 |
-
show_visu = False
|
| 26 |
-
|
| 27 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
-
|
| 29 |
-
config = {'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
|
| 30 |
-
|
| 31 |
-
idx = 0
|
| 32 |
-
prediction_times = []
|
| 33 |
-
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 34 |
-
#plot_all_modalities(a)
|
| 35 |
-
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 36 |
-
#pred_vertices, pred_edges = predict_wireframe(a.copy(), config)
|
| 37 |
-
try:
|
| 38 |
-
start_time = time.time()
|
| 39 |
-
pred_vertices, pred_edges = predict_wireframe(a.copy(), config)
|
| 40 |
-
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 41 |
-
end_time = time.time()
|
| 42 |
-
prediction_time = end_time - start_time
|
| 43 |
-
prediction_times.append(prediction_time)
|
| 44 |
-
mean_time = np.mean(prediction_times)
|
| 45 |
-
print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
|
| 46 |
-
except:
|
| 47 |
-
pred_vertices, pred_edges = empty_solution()
|
| 48 |
-
|
| 49 |
-
score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
|
| 50 |
-
print(f"Score: {score}")
|
| 51 |
-
scores_hss.append(score.hss)
|
| 52 |
-
scores_f1.append(score.f1)
|
| 53 |
-
scores_iou.append(score.iou)
|
| 54 |
-
|
| 55 |
-
if show_visu:
|
| 56 |
-
colmap = read_colmap_rec(a['colmap_binary'])
|
| 57 |
-
pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
|
| 58 |
-
wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
|
| 59 |
-
wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
|
| 60 |
-
bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
|
| 61 |
-
|
| 62 |
-
visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
|
| 63 |
-
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 64 |
-
|
| 65 |
-
idx += 1
|
| 66 |
-
|
| 67 |
-
for i in range(10):
|
| 68 |
-
print("END OF DATASET")
|
| 69 |
-
print(f"Mean HSS: {np.mean(scores_hss):.4f}")
|
| 70 |
-
print(f"Mean F1: {np.mean(scores_f1):.4f}")
|
| 71 |
-
print(f"Mean IoU: {np.mean(scores_iou):.4f}")
|
| 72 |
-
print(config)
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_pnet.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from fast_pointnet import train_pointnet
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
if __name__ == "__main__":
|
| 5 |
-
|
| 6 |
-
# Load the dataset
|
| 7 |
-
dataset_path = "/home/skvrnjan/personal/hohocustom/"
|
| 8 |
-
model_save_path = "/home/skvrnjan/personal/hoho_pnet/"
|
| 9 |
-
|
| 10 |
-
os.makedirs(model_save_path, exist_ok=True)
|
| 11 |
-
|
| 12 |
-
# Train the model
|
| 13 |
-
train_pointnet(dataset_path, model_save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_pnet_class.py
CHANGED
|
@@ -1,3 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from fast_pointnet_class import train_pointnet
|
| 2 |
import os
|
| 3 |
|
|
@@ -10,4 +22,4 @@ if __name__ == "__main__":
|
|
| 10 |
os.makedirs(model_save_path, exist_ok=True)
|
| 11 |
|
| 12 |
# Train the model
|
| 13 |
-
train_pointnet(dataset_path, model_save_path)
|
|
|
|
| 1 |
+
# This script serves as the main entry point for training a PointNet-based
|
| 2 |
+
# classification model.
|
| 3 |
+
#
|
| 4 |
+
# It imports the necessary training function `train_pointnet` from the
|
| 5 |
+
# `fast_pointnet_class` module.
|
| 6 |
+
#
|
| 7 |
+
# The script defines file paths for the input dataset and the directory
|
| 8 |
+
# where the trained model will be saved. It ensures that the model saving
|
| 9 |
+
# directory exists before starting the training.
|
| 10 |
+
#
|
| 11 |
+
# Finally, it initiates the training process by calling the `train_pointnet`
|
| 12 |
+
# function with the specified dataset path, model save path, and a batch size.
|
| 13 |
from fast_pointnet_class import train_pointnet
|
| 14 |
import os
|
| 15 |
|
|
|
|
| 22 |
os.makedirs(model_save_path, exist_ok=True)
|
| 23 |
|
| 24 |
# Train the model
|
| 25 |
+
train_pointnet(dataset_path, model_save_path, batch_size=4)
|
train_pnet_class_cluster.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from fast_pointnet_class import train_pointnet
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
if __name__ == "__main__":
|
| 5 |
-
|
| 6 |
-
# Load the dataset
|
| 7 |
-
dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges/"
|
| 8 |
-
model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_v4/initial.pth"
|
| 9 |
-
|
| 10 |
-
os.makedirs(model_save_path, exist_ok=True)
|
| 11 |
-
|
| 12 |
-
# Train the model
|
| 13 |
-
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_pnet_class_cluster_10d.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from fast_pointnet_class_10d import train_pointnet
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
if __name__ == "__main__":
|
| 5 |
-
|
| 6 |
-
# Load the dataset
|
| 7 |
-
dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges_10d/"
|
| 8 |
-
model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_10d_v2/initial.pth"
|
| 9 |
-
|
| 10 |
-
os.makedirs(model_save_path, exist_ok=True)
|
| 11 |
-
|
| 12 |
-
# Train the model
|
| 13 |
-
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_pnet_class_cluster_10d_2048.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from fast_pointnet_class_10d_2048 import train_pointnet
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
if __name__ == "__main__":
|
| 5 |
-
|
| 6 |
-
# Load the dataset
|
| 7 |
-
dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges_10d_1m/"
|
| 8 |
-
model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_10d_2048/initial.pth"
|
| 9 |
-
|
| 10 |
-
os.makedirs(model_save_path, exist_ok=True)
|
| 11 |
-
|
| 12 |
-
# Train the model
|
| 13 |
-
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_pnet_class_cluster_10d_deeper.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from fast_pointnet_class_10d_deeper import train_pointnet
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
if __name__ == "__main__":
|
| 5 |
-
|
| 6 |
-
# Load the dataset
|
| 7 |
-
dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges_10d/"
|
| 8 |
-
model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_10d_deeper_v2/initial.pth"
|
| 9 |
-
|
| 10 |
-
os.makedirs(model_save_path, exist_ok=True)
|
| 11 |
-
|
| 12 |
-
# Train the model
|
| 13 |
-
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_pnet_cluster.py
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
from fast_pointnet import train_pointnet
|
| 2 |
-
|
| 3 |
-
if __name__ == "__main__":
|
| 4 |
-
|
| 5 |
-
# Load the dataset
|
| 6 |
-
dataset_path = "/mnt/personal/skvrnjan/hohocustom/"
|
| 7 |
-
model_save_path = "/mnt/personal/skvrnjan/hoho_pnet/initial.pth"
|
| 8 |
-
|
| 9 |
-
# Train the model
|
| 10 |
-
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001, score_weight=1.0, class_weight=0.5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_pnet_cluster_class_v2.py
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
from fast_pointnet_class_v2 import train_pointnet
|
| 2 |
-
|
| 3 |
-
if __name__ == "__main__":
|
| 4 |
-
|
| 5 |
-
# Load the dataset
|
| 6 |
-
dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges_10d_v5/"
|
| 7 |
-
model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_class_v2/initial.pth"
|
| 8 |
-
|
| 9 |
-
# Train the model
|
| 10 |
-
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=512, lr=0.001)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_pnet_cluster_v3.py
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
from fast_pointnet_v3 import train_pointnet
|
| 2 |
-
|
| 3 |
-
if __name__ == "__main__":
|
| 4 |
-
|
| 5 |
-
# Load the dataset
|
| 6 |
-
dataset_path = "/mnt/personal/skvrnjan/hohocustom_v4/"
|
| 7 |
-
model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_v11/initial.pth"
|
| 8 |
-
|
| 9 |
-
# Train the model
|
| 10 |
-
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=256, lr=0.001, score_weight=0.25, class_weight=1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_pnet_cluster_v2.py → train_pnet_v2.py
RENAMED
|
@@ -3,8 +3,8 @@ from fast_pointnet_v2 import train_pointnet
|
|
| 3 |
if __name__ == "__main__":
|
| 4 |
|
| 5 |
# Load the dataset
|
| 6 |
-
dataset_path = "
|
| 7 |
-
model_save_path = "
|
| 8 |
|
| 9 |
# Train the model
|
| 10 |
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=512, lr=0.001, score_weight=0.25, class_weight=1.0)
|
|
|
|
| 3 |
if __name__ == "__main__":
|
| 4 |
|
| 5 |
# Load the dataset
|
| 6 |
+
dataset_path = "xx"
|
| 7 |
+
model_save_path = "xx.pth"
|
| 8 |
|
| 9 |
# Train the model
|
| 10 |
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=512, lr=0.001, score_weight=0.25, class_weight=1.0)
|
train_voxel.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from fast_voxel import train_3dcnn
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
if __name__ == "__main__":
|
| 5 |
-
|
| 6 |
-
# Load the dataset
|
| 7 |
-
dataset_path = "/home/skvrnjan/personal/hohocustom/"
|
| 8 |
-
model_save_path = "/home/skvrnjan/personal/hoho_voxel/"
|
| 9 |
-
|
| 10 |
-
os.makedirs(model_save_path, exist_ok=True)
|
| 11 |
-
|
| 12 |
-
# Train the model
|
| 13 |
-
train_3dcnn(dataset_path, model_save_path, epochs=100, batch_size=16, lr=0.001, voxel_size=32, score_weight=0.5, class_weight=0.5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_voxel_cluster.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from fast_voxel import train_3dcnn
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
if __name__ == "__main__":
|
| 5 |
-
|
| 6 |
-
# Load the dataset
|
| 7 |
-
dataset_path = "/mnt/personal/skvrnjan/hohocustom/"
|
| 8 |
-
model_save_path = "/mnt/personal/skvrnjan/hoho_voxel/initial.pth"
|
| 9 |
-
|
| 10 |
-
os.makedirs(model_save_path, exist_ok=True)
|
| 11 |
-
|
| 12 |
-
# Train the model
|
| 13 |
-
train_3dcnn(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001, voxel_size=32, score_weight=0.5, class_weight=0.5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|