| """ |
| End-to-End Voxel-Based Vertex Detection Pipeline |
| |
| This file implements a complete pipeline for detecting wireframe vertices from 3D point clouds using |
| a voxel-based deep learning approach. The pipeline includes: |
| |
| 1. Data preprocessing: Converting 14D point clouds into 3D voxel grids with averaged features |
| 2. Ground truth generation: Creating binary vertex labels and refinement targets from wireframe vertices |
| 3. Model architecture: VoxelUNet with encoder-decoder structure and 1x1x1 bottleneck for vertex detection |
| 4. Training: Combined loss function with BCE, Dice loss, and MSE for offset regression |
| 5. Inference: Predicting vertex locations from new point clouds with visualization |
| |
| Key components: |
| - Voxelization with configurable grid size and metric voxel size |
| - Per-voxel MLP before convolutional processing |
| - Gaussian smoothing of ground truth labels |
| - Refinement prediction for sub-voxel accuracy |
| - PyVista-based visualization for results analysis |
| |
| Usage: |
| - Set inference=False to train a new model |
| - Set inference=True to run predictions on existing data |
| """ |
|
|
import glob
import os
import pickle
from typing import Any, Dict, List, Tuple

import numpy as np
import pyvista as pv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
|
|
| |
| |
|
|
| def save_data(dict_to_save: Dict[str, Any], filename: str, data_folder: str = "data") -> None: |
| """Save dictionary data to pickle file""" |
| os.makedirs(data_folder, exist_ok=True) |
| filepath = os.path.join(data_folder, f"{filename}.pkl") |
| with open(filepath, 'wb') as f: |
| pickle.dump(dict_to_save, f) |
| |
|
|
| def load_data(filepath: str) -> Dict[str, Any]: |
| """Load dictionary data from pickle file""" |
| with open(filepath, 'rb') as f: |
| data = pickle.load(f) |
| |
| return data |
|
|
| def get_data_files(data_folder: str = "data", pattern: str = "*.pkl") -> List[str]: |
| """Get list of data files from folder""" |
| search_pattern = os.path.join(data_folder, pattern) |
| files = glob.glob(search_pattern) |
| |
| return files |
|
|
| def voxelize_points(points: np.ndarray, |
| grid_size_xy: int = 64, |
| voxel_size_metric: float = 0.25 |
| ) -> Tuple[torch.Tensor, np.ndarray, Dict[str, Any]]: |
| """ |
| Voxelize 14D point cloud into a 3D grid with a fixed number of voxels and fixed metric voxel size. |
| The Z dimension of the grid will also have `grid_size_xy` voxels, forming a cubic grid. |
| The point cloud is centered within this metric grid. Points outside are discarded. |
| Features from points falling into the same voxel are averaged. |
| |
| Args: |
| points: (N, 14) array where first 3 dims are xyz (original coordinates). |
| grid_size_xy: Number of voxels along X and Y dimensions (and Z). |
| voxel_size_metric: The physical size of each voxel (e.g., 0.5 units). |
| |
| Returns: |
        voxel_grid: (NUM_FEATURES, dim_z, dim_y, dim_x) tensor with averaged features.
                    Channels 0-2 hold the mean offset of the voxel's points from the
                    voxel center (in grid units); channels 3+ hold the averaged
                    remaining point features.
        voxel_indices_for_points: (N_points_in_grid, 3) integer voxel indices (z, y, x)
                                  for each input point that falls within the grid.
| scale_info: Dict with transformation parameters: |
| 'grid_origin_metric': Real-world metric coordinate of the corner of voxel [0,0,0] (x,y,z). |
| 'voxel_size_metric': The metric size of a voxel. |
| 'grid_dims_voxels': Tuple (dim_x, dim_y, dim_z) representing number of voxels. |
| 'pc_centroid_metric': Centroid of the input point cloud (x,y,z). |
| """ |
    NUM_FEATURES = 14
    # Cubic grid: Z gets the same voxel count as X and Y; guard against a
    # degenerate zero-size Z dimension.
    dim_x = dim_y = grid_size_xy
    dim_z = max(grid_size_xy, 1)

    grid_dims_voxels = np.array([dim_x, dim_y, dim_z], dtype=int)
|
|
| def _get_empty_return(reason: str = ""): |
| voxel_grid_empty = torch.zeros(NUM_FEATURES, grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.float32) |
| voxel_indices_empty = np.empty((0, 3), dtype=int) |
| scale_info_empty = { |
| 'grid_origin_metric': np.zeros(3, dtype=float), |
| 'voxel_size_metric': voxel_size_metric, |
| 'grid_dims_voxels': tuple(grid_dims_voxels.tolist()), |
| 'pc_centroid_metric': np.zeros(3, dtype=float), |
| } |
| return voxel_grid_empty, voxel_indices_empty, scale_info_empty |
|
|
| if points.shape[0] == 0: |
| return _get_empty_return("Initial empty point cloud") |
|
|
| xyz = points[:, :3] |
| features_other = points[:, 3:] |
|
|
| pc_centroid_metric = xyz.mean(axis=0) |
| |
| grid_metric_span = grid_dims_voxels * voxel_size_metric |
| grid_origin_metric = pc_centroid_metric - (grid_metric_span / 2.0) |
| |
| |
| voxel_grid_sum = torch.zeros(NUM_FEATURES, grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.float32) |
| |
| point_counts_in_voxel = torch.zeros(grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.int32) |
| |
| continuous_voxel_coords = (xyz - grid_origin_metric) / voxel_size_metric |
| |
| voxel_indices_for_points_zyx_order = [] |
|
|
    for i in range(points.shape[0]):
        current_point_continuous_coord_xyz = continuous_voxel_coords[i]

        # Use floor (not round) so voxel `idx` covers [idx, idx + 1) in grid
        # units and its center sits at idx + 0.5. This matches the convention
        # in create_ground_truth and keeps the offsets below in [-0.5, 0.5).
        voxel_idx_int_xyz = np.floor(current_point_continuous_coord_xyz).astype(int)

        idx_x, idx_y, idx_z = voxel_idx_int_xyz[0], voxel_idx_int_xyz[1], voxel_idx_int_xyz[2]
|
|
| if not (0 <= idx_x < grid_dims_voxels[0] and \ |
| 0 <= idx_y < grid_dims_voxels[1] and \ |
| 0 <= idx_z < grid_dims_voxels[2]): |
| continue |
| |
| voxel_indices_for_points_zyx_order.append([idx_z, idx_y, idx_x]) |
|
|
| assigned_voxel_center_grid_idx_space = np.array([idx_x, idx_y, idx_z], dtype=float) + 0.5 |
| offset_xyz_in_grid_units = current_point_continuous_coord_xyz - assigned_voxel_center_grid_idx_space |
| |
| |
| voxel_grid_sum[0, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[0] |
| voxel_grid_sum[1, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[1] |
| voxel_grid_sum[2, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[2] |
| |
| if NUM_FEATURES > 3: |
| current_point_other_features = features_other[i] |
| voxel_grid_sum[3:, idx_z, idx_y, idx_x] += torch.tensor(current_point_other_features, dtype=torch.float32) |
| |
| point_counts_in_voxel[idx_z, idx_y, idx_x] += 1 |
|
|
| |
| |
    # Average the accumulated features per voxel. Empty voxels get a divisor of
    # 1 so the division is a no-op (their sums are zero anyway).
    counts_for_division = point_counts_in_voxel.float()
    counts_for_division[counts_for_division == 0] = 1.0
    voxel_grid = voxel_grid_sum / counts_for_division.unsqueeze(0)
|
|
| final_voxel_indices_for_points_zyx = np.array(voxel_indices_for_points_zyx_order, dtype=int) if voxel_indices_for_points_zyx_order else np.empty((0,3), dtype=int) |
|
|
| scale_info = { |
| 'grid_origin_metric': grid_origin_metric, |
| 'voxel_size_metric': voxel_size_metric, |
| 'grid_dims_voxels': tuple(grid_dims_voxels.tolist()), |
| 'pc_centroid_metric': pc_centroid_metric, |
| } |
| |
| return voxel_grid, final_voxel_indices_for_points_zyx, scale_info |
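

# For large point clouds the per-point Python loop in voxelize_points dominates
# runtime. Below is a minimal vectorized sketch of the same sum-then-average
# step, under the same floor / (idx + 0.5)-center convention; it assumes a
# non-empty input cloud and is illustrative only, not wired into the pipeline.
def voxelize_points_vectorized(points: np.ndarray,
                               grid_size_xy: int = 64,
                               voxel_size_metric: float = 0.25
                               ) -> Tuple[torch.Tensor, np.ndarray, Dict[str, Any]]:
    NUM_FEATURES = 14
    dims = np.array([grid_size_xy] * 3, dtype=int)  # (dim_x, dim_y, dim_z)

    xyz = points[:, :3]
    pc_centroid_metric = xyz.mean(axis=0)
    grid_origin_metric = pc_centroid_metric - dims * voxel_size_metric / 2.0

    cont = (xyz - grid_origin_metric) / voxel_size_metric           # (N, 3), xyz order
    idx = np.floor(cont).astype(int)
    inside = np.all((idx >= 0) & (idx < dims), axis=1)              # drop out-of-grid points
    idx, cont = idx[inside], cont[inside]

    offsets = cont - (idx + 0.5)                                    # grid-unit offsets
    feats = np.concatenate([offsets, points[inside, 3:]], axis=1)   # (M, NUM_FEATURES)

    # Flatten (z, y, x) indices, then scatter-add per-voxel sums and counts.
    flat = np.ravel_multi_index((idx[:, 2], idx[:, 1], idx[:, 0]),
                                (dims[2], dims[1], dims[0]))
    sums = np.zeros((dims.prod(), NUM_FEATURES), dtype=np.float64)
    np.add.at(sums, flat, feats)
    counts = np.bincount(flat, minlength=dims.prod()).clip(min=1)

    grid_zyxc = (sums / counts[:, None]).reshape(dims[2], dims[1], dims[0], NUM_FEATURES)
    voxel_grid = torch.from_numpy(np.ascontiguousarray(grid_zyxc.transpose(3, 0, 1, 2))).float()

    scale_info = {
        'grid_origin_metric': grid_origin_metric,
        'voxel_size_metric': voxel_size_metric,
        'grid_dims_voxels': tuple(dims.tolist()),
        'pc_centroid_metric': pc_centroid_metric,
    }
    return voxel_grid, idx[:, ::-1].copy(), scale_info              # indices in (z, y, x)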
|
|
|
|
| def create_ground_truth(vertices: np.ndarray, |
| scale_info: Dict[str, Any] |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Create ground truth voxel labels and refinement targets using metric voxelization info. |
| The grid dimensions are taken from scale_info. |
| |
| Args: |
| vertices: (M, 3) vertex coordinates in original metric space. |
| scale_info: Dict from voxelize_points. Requires: |
| 'grid_origin_metric', 'voxel_size_metric', 'grid_dims_voxels'. |
| Returns: |
| vertex_labels: (dim_z, dim_y, dim_x) binary labels (1.0 for voxel containing a vertex). |
| refinement_targets: (3, dim_z, dim_y, dim_x) offset (dx,dy,dz) from voxel cell center |
| in grid units. Range approx [-0.5, 0.5). |
| """ |
| grid_origin_metric = scale_info['grid_origin_metric'] |
| voxel_size_metric = scale_info['voxel_size_metric'] |
| |
| grid_dims_voxels = np.array(scale_info['grid_dims_voxels']) |
|
|
| dim_x, dim_y, dim_z = grid_dims_voxels[0], grid_dims_voxels[1], grid_dims_voxels[2] |
|
|
| |
| vertex_labels = torch.zeros(dim_z, dim_y, dim_x, dtype=torch.float32) |
| |
| refinement_targets = torch.zeros(3, dim_z, dim_y, dim_x, dtype=torch.float32) |
|
|
| if vertices.shape[0] == 0: |
| return vertex_labels, refinement_targets |
|
|
| |
| |
| continuous_voxel_coords_vertices = (vertices - grid_origin_metric) / voxel_size_metric |
| |
| for i in range(vertices.shape[0]): |
| |
| v_continuous_coord_xyz = continuous_voxel_coords_vertices[i] |
| |
| |
| v_idx_int_xyz = np.floor(v_continuous_coord_xyz).astype(int) |
| |
| |
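        # Clamp out-of-grid vertices to the nearest boundary voxel. Note that a
        # clamped vertex ends up with an offset outside the nominal [-0.5, 0.5) range.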
| idx_x = np.clip(v_idx_int_xyz[0], 0, dim_x - 1) |
| idx_y = np.clip(v_idx_int_xyz[1], 0, dim_y - 1) |
| idx_z = np.clip(v_idx_int_xyz[2], 0, dim_z - 1) |
| |
| |
| vertex_labels[idx_z, idx_y, idx_x] = 1.0 |
| |
| |
| |
| assigned_voxel_center_grid_idx_space = np.array([idx_x, idx_y, idx_z], dtype=float) + 0.5 |
| |
| |
| offset_xyz_grid_units = v_continuous_coord_xyz - assigned_voxel_center_grid_idx_space |
| |
| |
| |
| refinement_targets[0, idx_z, idx_y, idx_x] = offset_xyz_grid_units[0] |
| refinement_targets[1, idx_z, idx_y, idx_x] = offset_xyz_grid_units[1] |
| refinement_targets[2, idx_z, idx_y, idx_x] = offset_xyz_grid_units[2] |
| |
| return vertex_labels, refinement_targets |
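

# Worked example of the ground-truth convention above (illustrative numbers):
# with grid_origin_metric = (0, 0, 0) and voxel_size_metric = 0.5, a vertex at
# metric (1.3, 0.2, 0.0) has continuous grid coords (2.6, 0.4, 0.0), lands in
# voxel (x, y, z) = (2, 0, 0), and gets refinement offsets
# (2.6, 0.4, 0.0) - (2.5, 0.5, 0.5) = (0.1, -0.1, -0.5) in grid units.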
|
|
| class VoxelUNet(nn.Module): |
| """Encoder-decoder network with a 1x1x1 bottleneck for voxel-based vertex detection. |
| Includes a per-voxel MLP before the first convolutional block.""" |
| |
| def __init__(self, in_channels: int = 14, base_channels: int = 32, bottleneck_expansion: int = 2, mlp_hidden_factor: int = 2): |
| super(VoxelUNet, self).__init__() |
| |
| bc = base_channels |
| |
| |
| |
| |
| |
| mlp_hidden_dim = in_channels * mlp_hidden_factor |
| self.voxel_mlp = nn.Sequential( |
| nn.Linear(in_channels, mlp_hidden_dim), |
| nn.ReLU(inplace=True), |
| nn.Linear(mlp_hidden_dim, bc) |
| ) |
|
|
| |
| |
| self.enc1 = self._conv_block(bc, bc) |
| self.enc2 = self._conv_block(bc, bc * 2) |
| self.enc3 = self._conv_block(bc * 2, bc * 4) |
| self.enc4 = self._conv_block(bc * 4, bc * 8) |
| self.enc5 = self._conv_block(bc * 8, bc * 16) |
| |
| self.pool = nn.MaxPool3d(2) |
| |
| |
| self.adaptive_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) |
| bottleneck_in_channels = bc * 16 |
| |
| bottleneck_width = bottleneck_in_channels * bottleneck_expansion |
| |
| self.bottleneck = nn.Sequential( |
| nn.Conv3d(bottleneck_in_channels, bottleneck_width, kernel_size=1, padding=0, bias=True), |
| nn.ReLU(inplace=True), |
| |
| nn.Conv3d(bottleneck_width, bottleneck_width, kernel_size=1, padding=0, bias=True), |
| nn.ReLU(inplace=True) |
| ) |
| |
| |
| |
| self.dec5 = self._conv_block(bottleneck_width, bc * 16) |
| |
| self.up4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) |
| self.dec4 = self._conv_block(bc * 16, bc * 8) |
| |
| self.up3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) |
| self.dec3 = self._conv_block(bc * 8, bc * 4) |
| |
| self.up2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) |
| self.dec2 = self._conv_block(bc * 4, bc * 2) |
| |
| self.up1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) |
| self.dec1 = self._conv_block(bc * 2, bc) |
| |
| |
| |
| |
| self.vertex_head = nn.Sequential( |
| nn.Conv3d(bc, bc // 2, kernel_size=1), |
| nn.ReLU(inplace=True), |
| nn.Conv3d(bc // 2, bc // 4, kernel_size=1), |
| nn.ReLU(inplace=True), |
| nn.Conv3d(bc // 4, 1, kernel_size=1) |
| ) |
| self.refinement_head = nn.Conv3d(bc, 3, kernel_size=1) |
| |
| self.tanh = nn.Tanh() |
|
|
| def _conv_block(self, in_channels: int, out_channels: int) -> nn.Sequential: |
| |
| |
| return nn.Sequential( |
| nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), |
| nn.BatchNorm3d(out_channels), |
| nn.ReLU(inplace=True), |
| nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), |
| nn.BatchNorm3d(out_channels), |
| nn.ReLU(inplace=True) |
| ) |
| |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| |
| |
| |
| B, C_in_raw, D, H, W = x.shape |
| |
| x_permuted = x.permute(0, 2, 3, 4, 1).contiguous() |
| |
| x_flattened = x_permuted.view(-1, C_in_raw) |
| |
| |
| mlp_out_flattened = self.voxel_mlp(x_flattened) |
| |
| C_mlp_out = mlp_out_flattened.shape[-1] |
| |
| x_mlp_reshaped = mlp_out_flattened.view(B, D, H, W, C_mlp_out) |
| |
| x_processed = x_mlp_reshaped.permute(0, 4, 1, 2, 3).contiguous() |
| |
| |
| e1 = self.enc1(x_processed) |
| p1 = self.pool(e1) |
| |
| e2 = self.enc2(p1) |
| p2 = self.pool(e2) |
| |
| e3 = self.enc3(p2) |
| p3 = self.pool(e3) |
| |
| e4 = self.enc4(p3) |
| p4 = self.pool(e4) |
| |
| e5 = self.enc5(p4) |
| p5 = self.pool(e5) |
| |
| |
| b_pooled = self.adaptive_pool(p5) |
| b = self.bottleneck(b_pooled) |
| |
| |
| |
| u5_from_b = nn.functional.interpolate(b, size=e5.shape[2:], mode='trilinear', align_corners=True) |
| d5 = self.dec5(u5_from_b) |
| |
| u4 = self.up4(d5) |
| d4 = self.dec4(u4) |
| |
| u3 = self.up3(d4) |
| d3 = self.dec3(u3) |
| |
| u2 = self.up2(d3) |
| d2 = self.dec2(u2) |
| |
| u1 = self.up1(d2) |
| d1 = self.dec1(u1) |
| |
| |
| vertex_logits = self.vertex_head(d1) |
| refinement = self.tanh(self.refinement_head(d1)) * 0.5 |
| |
| return vertex_logits, refinement |
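

# A minimal shape sanity-check for VoxelUNet (illustrative; the input grid size
# must be a multiple of 16 and at least 32 so the five pooling stages and the
# upsampling chain round-trip to the input resolution):
#
#     model = VoxelUNet(in_channels=14, base_channels=32)
#     x = torch.randn(1, 14, 64, 64, 64)              # (B, C, D, H, W)
#     logits, refinement = model(x)
#     assert logits.shape == (1, 1, 64, 64, 64)       # per-voxel vertex logits
#     assert refinement.shape == (1, 3, 64, 64, 64)   # offsets in [-0.5, 0.5]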
|
|
| class VoxelDataset(Dataset): |
| def __init__(self, data_files: List[str], voxel_size: float = 0.1, grid_size: int = 64): |
| self.data_files = data_files |
| self.voxel_size = voxel_size |
| self.grid_size = grid_size |
| |
| def __len__(self): |
| return len(self.data_files) |
| |
| def __getitem__(self, idx): |
| data = load_data(self.data_files[idx]) |
| |
| voxel_grid, _, scale_info = voxelize_points( |
| data['pcloud_14d'], self.grid_size, self.voxel_size |
| ) |
| |
| wf_vertices_np = np.array(data['wf_vertices']) |
| vertex_labels, refinement_targets = create_ground_truth( |
| wf_vertices_np, scale_info |
| ) |
| |
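        # scale_info is a dict of NumPy arrays and tuples; torch's default
        # collate batches it field by field (train_epoch ignores it via `_`).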
| return voxel_grid, vertex_labels, refinement_targets, scale_info |
|
|
|
|
| class CombinedLoss(nn.Module): |
| """ |
| Combined loss for vertex classification and offset regression. |
| Uses: |
| - BCEWithLogitsLoss (with configurable negative/positive sample weighting) |
| - Dice loss |
| - MSE loss on refinement offsets (only over positive voxels) |
| - Gaussian blur on the GT labels |
| """ |
| def __init__(self, |
| vertex_weight: float = 1.0, |
| refinement_weight: float = 0.0, |
| dice_weight: float = 0.5, |
| bce_neg_pos_ratio: float = 1.0, |
| blur_kernel_size: int = 5, |
| blur_sigma: float = 1.0, |
| eps: float = 1e-6): |
| super().__init__() |
| self.vertex_weight = vertex_weight |
| self.refinement_weight = refinement_weight |
| self.dice_weight = dice_weight |
| self.bce_neg_pos_ratio = bce_neg_pos_ratio |
| self.eps = eps |
|
|
| |
| self.bce_loss_fn = nn.BCEWithLogitsLoss(reduction='none') |
| |
| self.mse_loss = nn.MSELoss() |
|
|
| |
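        # Build a 3D Gaussian kernel for smoothing the GT labels. It is not
        # normalized: for odd kernel sizes its peak is 1.0, so after clamping
        # the blurred label keeps 1.0 at the vertex voxel and decays around it.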
| k = blur_kernel_size |
| coords = torch.arange(k, dtype=torch.float32) - (k - 1) / 2 |
| xx, yy, zz = torch.meshgrid(coords, coords, coords, indexing='ij') |
| kernel = torch.exp(-(xx**2 + yy**2 + zz**2) / (2 * blur_sigma**2)) |
| |
| kernel = kernel.view(1, 1, k, k, k) |
| self.register_buffer('gaussian_kernel', kernel) |
| self.pad = k // 2 |
|
|
| def forward(self, |
| vertex_logits_pred: torch.Tensor, |
| refinement_pred: torch.Tensor, |
| vertex_gt: torch.Tensor, |
| refinement_gt: torch.Tensor |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
| |
| logits = vertex_logits_pred.squeeze(1) |
| gt = vertex_gt.float() |
|
|
| |
| gt_unsq = gt.unsqueeze(1) |
| gt_blur = F.conv3d(gt_unsq, self.gaussian_kernel, padding=self.pad) |
| gt_blur = gt_blur.clamp(0, 1) |
| gt_smooth = gt_blur.squeeze(1) |
|
|
| |
| pos_mask = gt_smooth > 1e-3 |
| neg_mask = ~pos_mask |
|
|
| bce_all = self.bce_loss_fn(logits, gt_smooth) |
|
|
| |
| pos_weight_factor = 1.0 |
| neg_weight_factor = self.bce_neg_pos_ratio |
| |
| bce = torch.tensor(0.0, device=logits.device) |
| num_pos = pos_mask.sum() |
| num_neg = neg_mask.sum() |
|
|
| if num_pos > 0 and num_neg > 0: |
| mean_pos_loss = bce_all[pos_mask].mean() |
| mean_neg_loss = bce_all[neg_mask].mean() |
| bce = pos_weight_factor * mean_pos_loss + neg_weight_factor * mean_neg_loss |
| elif num_pos > 0: |
| mean_pos_loss = bce_all[pos_mask].mean() |
| bce = pos_weight_factor * mean_pos_loss |
| elif num_neg > 0: |
| mean_neg_loss = bce_all[neg_mask].mean() |
| bce = neg_weight_factor * mean_neg_loss |
| |
|
|
| |
        # Soft Dice loss against the smoothed targets
        prob = torch.sigmoid(logits)
        intersection = (prob * gt_smooth).sum(dim=[1, 2, 3])
        union = prob.sum(dim=[1, 2, 3]) + gt_smooth.sum(dim=[1, 2, 3])
        dice_score = (2 * intersection + self.eps) / (union + self.eps)
        dice_loss = 1 - dice_score.mean()
|
|
| vertex_loss = bce + self.dice_weight * dice_loss |
|
|
| |
| |
| mask_pos_refinement = (gt > 0.5).unsqueeze(1) |
| |
| refinement_loss = torch.tensor(0., device=logits.device) |
| if mask_pos_refinement.sum() > 0: |
| |
| expanded_mask = mask_pos_refinement.expand_as(refinement_pred) |
| |
| pred_offsets = refinement_pred[expanded_mask].view(-1, 3) |
| gt_offsets = refinement_gt[expanded_mask].view(-1, 3) |
| |
| if pred_offsets.numel() > 0: |
| refinement_loss = self.mse_loss(pred_offsets, gt_offsets) |
| |
| |
| total_loss = (self.vertex_weight * vertex_loss + |
| self.refinement_weight * refinement_loss) |
|
|
| return total_loss, vertex_loss, refinement_loss |
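

# Minimal usage sketch for CombinedLoss (illustrative shapes and values):
#
#     criterion = CombinedLoss(vertex_weight=1.0, dice_weight=0.5)
#     logits = torch.randn(2, 1, 32, 32, 32)            # vertex head output
#     offsets = torch.rand(2, 3, 32, 32, 32) - 0.5      # refinement head output
#     gt = torch.zeros(2, 32, 32, 32); gt[:, 16, 16, 16] = 1.0
#     gt_offsets = torch.zeros(2, 3, 32, 32, 32)
#     total, vertex_l, refine_l = criterion(logits, offsets, gt, gt_offsets)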
|
|
| def train_epoch(model, dataloader, optimizer, criterion, device, current_epoch: int): |
| model.train() |
| total_loss_epoch = 0.0 |
| vertex_loss_epoch = 0.0 |
| refinement_loss_epoch = 0.0 |
| |
| for batch_idx, (voxel_grid_batch, vertex_labels_batch, refinement_targets_batch, _) in enumerate(dataloader): |
| voxel_grid_batch = voxel_grid_batch.to(device) |
| vertex_labels_batch = vertex_labels_batch.to(device) |
| refinement_targets_batch = refinement_targets_batch.to(device) |
|
|
        # Optional one-shot debug visualization of the first sample in the batch.
        debug_visualize = False
        if debug_visualize:
| print(f'Epoch {current_epoch+1}, Batch {batch_idx+1}/{len(dataloader)}') |
| |
| sample_voxel_features = voxel_grid_batch[0].cpu().numpy() |
| sample_gt_labels = vertex_labels_batch[0].cpu().numpy() |
| sample_gt_refinement = refinement_targets_batch[0].cpu().numpy() |
| |
| summed_xyz_in_voxels = sample_voxel_features[:3] |
| occupied_voxel_mask = np.any(summed_xyz_in_voxels != 0, axis=0) |
| |
| plotter = pv.Plotter(window_size=[800,600]) |
| plotter.background_color = 'white' |
| |
| if np.any(occupied_voxel_mask): |
| occupied_voxel_indices = np.array(np.where(occupied_voxel_mask)).T |
| input_points_display = pv.PolyData(occupied_voxel_indices + 0.5) |
| plotter.add_mesh(input_points_display, color='cornflowerblue', point_size=5, render_points_as_spheres=True, label='Occupied Voxels (Centers)') |
|
|
| gt_vertex_voxel_mask = sample_gt_labels > 0.5 |
| if np.any(gt_vertex_voxel_mask): |
| gt_vertex_indices_int = np.array(np.where(gt_vertex_voxel_mask)).T |
| gt_offsets = sample_gt_refinement[:, gt_vertex_voxel_mask].T |
| gt_vertex_positions_grid_space = gt_vertex_indices_int.astype(float) + 0.5 + gt_offsets |
| |
| target_vertices_display = pv.PolyData(gt_vertex_positions_grid_space) |
| plotter.add_mesh(target_vertices_display, color='crimson', point_size=10, render_points_as_spheres=True, label='Target Vertices (GT)') |
| |
| plotter.show(title=f"Debug Viz E{current_epoch+1} B{batch_idx+1}", auto_close=False) |
| else: |
| print(f"Epoch {current_epoch+1} Batch {batch_idx+1}: No data to visualize for the first sample.") |
| |
| optimizer.zero_grad() |
| vertex_logits_pred, refinement_pred = model(voxel_grid_batch) |
| |
| loss, vertex_loss, refinement_loss = criterion( |
| vertex_logits_pred, refinement_pred, vertex_labels_batch, refinement_targets_batch |
| ) |
|
|
| print(f"Batch {batch_idx+1}/{len(dataloader)}: Loss={loss.item():.4f}, Vertex Loss={vertex_loss.item():.4f}, Refinement Loss={refinement_loss.item():.4f}") |
| |
        # Skip the backward pass when the loss is numerically (near-)zero,
        # e.g. a degenerate batch with no positive or negative voxels.
        if loss > 1e-6:
| loss.backward() |
| optimizer.step() |
| |
| total_loss_epoch += loss.item() |
| vertex_loss_epoch += vertex_loss.item() |
| refinement_loss_epoch += refinement_loss.item() |
|
|
| if (batch_idx + 1) % 200 == 0: |
| checkpoint_path = f"model_epoch_{current_epoch+1}_batch_{batch_idx+1}_grid_128v9.pth" |
| torch.save(model.state_dict(), checkpoint_path) |
| print(f"Saved batch checkpoint: {checkpoint_path}") |
|
|
| avg_total_loss = total_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0 |
| avg_vertex_loss = vertex_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0 |
| avg_refinement_loss = refinement_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0 |
| |
| return avg_total_loss, avg_vertex_loss, avg_refinement_loss |
|
|
| def train_model(data_folder: str = "data", num_epochs: int = 100, batch_size: int = 4, neg_pos_ratio_val: float = 1.0): |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| print(f"Using device: {device}") |
| |
| data_files = get_data_files(data_folder) |
| if not data_files: |
| print(f"No data files found in {data_folder}. Exiting.") |
| return |
| |
| GRID_SIZE_CFG = 128 |
| VOXEL_SIZE_CFG = 0.5 |
|
|
| dataset = VoxelDataset(data_files, voxel_size=VOXEL_SIZE_CFG, grid_size=GRID_SIZE_CFG) |
| dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8) |
| |
    model = VoxelUNet(in_channels=14, base_channels=32,
                      bottleneck_expansion=4, mlp_hidden_factor=10).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    criterion = CombinedLoss(
        vertex_weight=10.0,
        refinement_weight=0.0,
        dice_weight=0.0,
        bce_neg_pos_ratio=neg_pos_ratio_val
    ).to(device)
| |
| print(f"Starting training: {num_epochs} epochs, Batch size: {batch_size}, Grid size: {GRID_SIZE_CFG}, Voxel size: {VOXEL_SIZE_CFG}, Initial LR: {optimizer.param_groups[0]['lr']}") |
|
|
| for epoch in range(num_epochs): |
| print(f"\n--- Epoch {epoch+1}/{num_epochs} ---") |
| |
| avg_loss, avg_vertex_loss, avg_refinement_loss = train_epoch( |
| model, dataloader, optimizer, criterion, device, epoch |
| ) |
| |
| print(f"Epoch {epoch+1} Summary: Avg Loss: {avg_loss:.4f}, " |
| f"Avg Vertex Loss: {avg_vertex_loss:.4f}, " |
| f"Avg Refinement Loss: {avg_refinement_loss:.4f}, " |
| f"Current LR: {optimizer.param_groups[0]['lr']:.6f}") |
| |
| checkpoint_path = f"model_epoch_{epoch+1}_grid{GRID_SIZE_CFG}_smooth_bal{neg_pos_ratio_val}_v9.pth" |
| torch.save(model.state_dict(), checkpoint_path) |
| print(f"Saved checkpoint: {checkpoint_path}") |
| |
| final_model_path = f"final_model_grid{GRID_SIZE_CFG}_epochs{num_epochs}_smooth_bal{neg_pos_ratio_val}_v9.pth" |
| torch.save(model.state_dict(), final_model_path) |
| print(f"Training completed! Final model saved as {final_model_path}") |
|
|
| def load_model_for_inference(model_path: str, device: torch.device, |
| in_channels: int = 14, base_channels: int = 32) -> VoxelUNet: |
| """Load a VoxelUNet model for inference.""" |
    # bottleneck_expansion and mlp_hidden_factor must match the training config.
    model = VoxelUNet(in_channels=in_channels, base_channels=base_channels,
                      bottleneck_expansion=4, mlp_hidden_factor=10)
| model.load_state_dict(torch.load(model_path, map_location=device)) |
| model.to(device) |
| model.eval() |
| print(f"Model loaded from {model_path} and set to evaluation mode on {device}.") |
| return model |
|
|
| def predict_vertices(model: VoxelUNet, |
| point_cloud_14d: np.ndarray, |
| grid_size: int, |
| device: torch.device, |
| voxel_size_metric: float = 0.5, |
| vertex_threshold: float = 0.5) -> np.ndarray: |
| """ |
| Predict vertices from a 14D point cloud. |
| |
| Args: |
| model: The trained VoxelUNet model. |
| point_cloud_14d: (N, 14) NumPy array of the input point cloud. |
| grid_size: The size of the voxel grid along X and Y dimensions (must match training). |
| device: PyTorch device ('cuda' or 'cpu'). |
| voxel_size_metric: The metric size of each voxel (must match training). |
| vertex_threshold: Threshold for classifying a voxel as containing a vertex. |
| |
| Returns: |
| predicted_vertices_original_space: (M, 3) NumPy array of predicted vertex |
| coordinates in the original point cloud space (X, Y, Z order). |
| Returns an empty array if no vertices are predicted |
| or if the input point cloud results in an empty voxel grid. |
| """ |
| voxel_grid_tensor, _, scale_info = voxelize_points( |
| point_cloud_14d, |
| grid_size_xy=grid_size, |
| voxel_size_metric=voxel_size_metric |
| ) |
|
|
| |
| |
| |
    # An empty input cloud yields an empty grid; per the docstring, return no vertices.
    if voxel_grid_tensor.sum() == 0 and point_cloud_14d.shape[0] == 0:
        return np.empty((0, 3), dtype=np.float32)
|
|
| input_tensor = voxel_grid_tensor.unsqueeze(0).to(device) |
|
|
| with torch.no_grad(): |
| vertex_logits_pred_tensor, refinement_pred_tensor = model(input_tensor) |
|
|
| vertex_prob_pred_tensor = torch.sigmoid(vertex_logits_pred_tensor) |
|
|
| vertex_prob_pred_np = vertex_prob_pred_tensor.squeeze(0).squeeze(0).cpu().numpy() |
| refinement_pred_np = refinement_pred_tensor.squeeze(0).cpu().numpy() |
|
|
| print(f"Vertex Probabilities Stats: Min={np.min(vertex_prob_pred_np):.4f}, Max={np.max(vertex_prob_pred_np):.4f}, Mean={np.mean(vertex_prob_pred_np):.4f}, Median={np.median(vertex_prob_pred_np):.4f}") |
| if refinement_pred_np.size > 0: |
| print(f"Refinement Predictions Stats: Min={np.min(refinement_pred_np):.4f}, Max={np.max(refinement_pred_np):.4f}, Mean={np.mean(refinement_pred_np):.4f}, Median={np.median(refinement_pred_np):.4f}") |
| for i in range(refinement_pred_np.shape[0]): |
| print(f" Refinement Dim {i} (dx,dy,dz order) Stats: Min={np.min(refinement_pred_np[i]):.4f}, Max={np.max(refinement_pred_np[i]):.4f}, Mean={np.mean(refinement_pred_np[i]):.4f}, Median={np.median(refinement_pred_np[i]):.4f}") |
| else: |
| print("Refinement Predictions Stats: Array is empty.") |
|
|
| predicted_mask = vertex_prob_pred_np > vertex_threshold |
| |
| predicted_voxel_indices_zyx = np.argwhere(predicted_mask) |
|
|
| if not predicted_voxel_indices_zyx.size: |
| return np.empty((0, 3), dtype=np.float32) |
|
|
| |
| |
| offsets_channels_first = refinement_pred_np[:, |
| predicted_voxel_indices_zyx[:, 0], |
| predicted_voxel_indices_zyx[:, 1], |
| predicted_voxel_indices_zyx[:, 2]] |
| |
| |
| offsets_xyz_order = offsets_channels_first.T |
|

    # Sub-voxel position: voxel center (idx + 0.5) plus the predicted
    # (dx, dy, dz) offset in grid units, matching the ground-truth convention.
    # Caveat: with refinement_weight=0.0 during training the refinement head is
    # untrained; in that case drop the offsets and keep plain voxel centers.
    refined_x_grid = predicted_voxel_indices_zyx[:, 2].astype(np.float32) + 0.5 + offsets_xyz_order[:, 0]
    refined_y_grid = predicted_voxel_indices_zyx[:, 1].astype(np.float32) + 0.5 + offsets_xyz_order[:, 1]
    refined_z_grid = predicted_voxel_indices_zyx[:, 0].astype(np.float32) + 0.5 + offsets_xyz_order[:, 2]
|
|
| |
| refined_grid_coords_xyz = np.stack((refined_x_grid, refined_y_grid, refined_z_grid), axis=-1) |
|
|
| |
| grid_origin_metric = np.array(scale_info['grid_origin_metric']) |
| |
| current_voxel_size_metric = scale_info['voxel_size_metric'] |
|
|
| |
| predicted_vertices_original_space = refined_grid_coords_xyz * current_voxel_size_metric + grid_origin_metric |
| |
| return predicted_vertices_original_space.astype(np.float32) |
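

# Grid-to-metric conversion used above, as a worked example (illustrative
# numbers): a detection in voxel (z, y, x) = (10, 20, 30) with zero predicted
# offset, voxel_size_metric = 0.5, and grid_origin_metric = (-16, -16, -16)
# maps to (30.5, 20.5, 10.5) * 0.5 + (-16, -16, -16) = (-0.75, -5.75, -10.75)
# in xyz.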
|
|
| |
| def run_inference(model_path: str, |
| data_file_path: str, |
| output_file: str = None, |
| grid_size: int = 128, |
| voxel_size: float = 0.5, |
| vertex_threshold: float = 0.5): |
| """ |
| Run inference on all data files in a directory, visualize with pyvista, and save results. |
| """ |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| print(f"Using device: {device}") |
| |
| |
| model = load_model_for_inference(model_path, device) |
| |
| |
| data_files = get_data_files(data_file_path) |
| if not data_files: |
| print(f"No data files found in {data_file_path}") |
| return |
| |
| print(f"Found {len(data_files)} data files to process") |
| |
| for i, file_path in enumerate(data_files): |
| print(f"\n--- Processing file {i+1}/{len(data_files)}: {os.path.basename(file_path)} ---") |
| |
| |
| try: |
| data = load_data(file_path) |
| except Exception as e: |
| print(f"Error loading {file_path}: {e}") |
| continue |
| |
| if 'pcloud_14d' not in data: |
| print(f"Error: File {file_path} does not contain 'pcloud_14d' key, skipping") |
| continue |
| |
| |
| pcloud = data['pcloud_14d'][:, :3] |
| gt_vertices = np.array(data.get('wf_vertices', [])) |
| |
| print(f"Input point cloud shape: {pcloud.shape}") |
| if gt_vertices.size: |
| print(f"GT vertices shape: {gt_vertices.shape}") |
| |
| |
| print("Running inference...") |
| try: |
| predicted_vertices = predict_vertices( |
| model=model, |
| point_cloud_14d=data['pcloud_14d'], |
| grid_size=grid_size, |
| device=device, |
| voxel_size_metric=voxel_size, |
| vertex_threshold=vertex_threshold |
| ) |
| except Exception as e: |
| print(f"Error during prediction for {file_path}: {e}") |
| continue |
| |
| print(f"Predicted {len(predicted_vertices)} vertices") |
| |
| |
| plotter = pv.Plotter(window_size=[800,600]) |
| plotter.background_color = 'white' |
| |
| |
| if pcloud.size: |
| pc_cloud = pv.PolyData(pcloud) |
| plotter.add_mesh(pc_cloud, color='lightgray', point_size=2, render_points_as_spheres=True, label='Input PC') |
| |
| |
| if gt_vertices.size: |
| gt_pd = pv.PolyData(gt_vertices) |
| plotter.add_mesh(gt_pd, color='red', point_size=8, render_points_as_spheres=True, label='GT Vertices') |
| |
| |
| if predicted_vertices.size: |
| pred_pd = pv.PolyData(predicted_vertices) |
| plotter.add_mesh(pred_pd, color='blue', point_size=8, render_points_as_spheres=True, label='Predicted Vertices') |
| |
| plotter.add_legend() |
| plotter.show(title=os.path.basename(file_path)) |
| |
| |
| output_data = { |
| 'predicted_vertices': predicted_vertices, |
| 'input_file': file_path, |
| 'model_used': model_path, |
| 'grid_size': grid_size, |
| 'voxel_size': voxel_size, |
| 'vertex_threshold': vertex_threshold, |
| 'original_data': data |
| } |
| |
| |
| base_name = os.path.splitext(os.path.basename(file_path))[0] |
| output_filename = f"{base_name}_predictions" |
| try: |
| save_data(output_data, output_filename) |
| print(f"Results saved to: data/{output_filename}.pkl") |
| except Exception as e: |
| print(f"Error saving results for {file_path}: {e}") |
| |
| print(f"\nCompleted processing {len(data_files)} files") |
|
|
| if __name__ == "__main__": |
| inference = False |
|
|
| |
| data_folder_train = 'YOUR_LOCAL_DATA_FOLDER_PATH' |
| |
| |
| num_epochs_train = 100 |
| batch_size_train = 16 |
| |
| negative_to_positive_bce_ratio = 1 |
| |
| if inference: |
| |
| run_inference(model_path='YOUR_MODEL_PATH.pth', |
| data_file_path='YOUR_INFERENCE_DATA_FOLDER_PATH', |
| output_file=None, |
| grid_size=128, |
| voxel_size=0.5, |
| vertex_threshold=0.5 |
| ) |
| else: |
| train_model(data_folder=data_folder_train, |
| num_epochs=num_epochs_train, |
| batch_size=batch_size_train, |
| neg_pos_ratio_val=negative_to_positive_bce_ratio) |
|
|
|
|