# hoho / end_to_end.py
# jskvrna's picture
# Final submission code
# 9518589
"""
End-to-End Voxel-Based Vertex Detection Pipeline
This file implements a complete pipeline for detecting wireframe vertices from 3D point clouds using
a voxel-based deep learning approach. The pipeline includes:
1. Data preprocessing: Converting 14D point clouds into 3D voxel grids with averaged features
2. Ground truth generation: Creating binary vertex labels and refinement targets from wireframe vertices
3. Model architecture: VoxelUNet with encoder-decoder structure and 1x1x1 bottleneck for vertex detection
4. Training: Combined loss function with BCE, Dice loss, and MSE for offset regression
5. Inference: Predicting vertex locations from new point clouds with visualization
Key components:
- Voxelization with configurable grid size and metric voxel size
- Per-voxel MLP before convolutional processing
- Gaussian smoothing of ground truth labels
- Refinement prediction for sub-voxel accuracy
- PyVista-based visualization for results analysis
Usage:
- Set inference=False to train a new model
- Set inference=True to run predictions on existing data
"""
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import Dict, Any, Tuple, List
from torch.utils.data import Dataset, DataLoader
import glob
import pyvista as pv
import torch
# Data I/O helpers, voxelization, ground-truth generation, the VoxelUNet model and the
# VoxelDataset wrapper are defined below, followed by the loss, training and inference code.
def save_data(dict_to_save: Dict[str, Any], filename: str, data_folder: str = "data") -> None:
    """Pickle ``dict_to_save`` to ``<data_folder>/<filename>.pkl``.

    The folder is created when it does not exist yet. An existing file with
    the same name is overwritten.

    Args:
        dict_to_save: Object to serialize.
        filename: File name without the ``.pkl`` extension.
        data_folder: Destination directory.
    """
    os.makedirs(data_folder, exist_ok=True)
    target_path = os.path.join(data_folder, f"{filename}.pkl")
    with open(target_path, 'wb') as handle:
        pickle.dump(dict_to_save, handle)
def load_data(filepath: str) -> Dict[str, Any]:
    """Return the object stored in the pickle file at ``filepath``.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file, so this must only be used on files from a trusted source.
    """
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)
def get_data_files(data_folder: str = "data", pattern: str = "*.pkl") -> List[str]:
    """Return the paths in ``data_folder`` that match ``pattern``, sorted by name.

    ``glob.glob`` yields entries in an arbitrary, filesystem-dependent order.
    Sorting makes dataset indexing and the processing order reproducible
    across runs and machines.

    Args:
        data_folder: Directory to search (not recursive).
        pattern: Glob pattern applied inside ``data_folder``.

    Returns:
        Sorted list of matching paths; empty when nothing matches or the
        folder does not exist.
    """
    search_pattern = os.path.join(data_folder, pattern)
    return sorted(glob.glob(search_pattern))
def voxelize_points(points: np.ndarray,
                    grid_size_xy: int = 64,
                    voxel_size_metric: float = 0.25
                    ) -> Tuple[torch.Tensor, np.ndarray, Dict[str, Any]]:
    """
    Voxelize a 14D point cloud into a cubic grid with a fixed metric voxel size.

    The grid has ``grid_size_xy`` voxels along X, Y and Z and is centered on the
    centroid of the point cloud. Points whose voxel index falls outside the grid
    are discarded. Features of all points assigned to the same voxel are averaged.

    Voxel assignment uses ``np.round`` on the continuous grid coordinate, while the
    stored offset is measured from ``index + 0.5``. The first three channels
    therefore hold the mean of ``coord - (index + 0.5)``, which lies in roughly
    [-1.0, 0.0].

    NOTE(review): ``create_ground_truth`` assigns vertices with ``np.floor``, so the
    input occupancy is shifted by half a voxel relative to the labels. The rounding
    is kept as-is here because changing it would alter the inputs that existing
    checkpoints were trained on - confirm before unifying the two conventions.

    Args:
        points: (N, 14) array; columns 0-2 are xyz in original coordinates, columns
            3-13 are additional per-point features.
        grid_size_xy: Number of voxels along each of X, Y and Z.
        voxel_size_metric: Physical edge length of one voxel.

    Returns:
        voxel_grid: (14, dim_z, dim_y, dim_x) float32 tensor. Channels 0-2 are the
            averaged offsets (dx, dy, dz) in grid units, channels 3-13 the averaged
            remaining features. Empty voxels are all zeros.
        voxel_indices_for_points: (N_points_in_grid, 3) integer voxel indices
            (z, y, x), one row per retained point, in input order.
        scale_info: Dict with transformation parameters:
            'grid_origin_metric': metric coordinate (x, y, z) that maps to continuous
                grid coordinate 0.
            'voxel_size_metric': The metric size of a voxel.
            'grid_dims_voxels': Tuple (dim_x, dim_y, dim_z).
            'pc_centroid_metric': Centroid of the input point cloud (x, y, z).

    Raises:
        ValueError: If a non-empty ``points`` is not of shape (N, 14).
    """
    NUM_FEATURES = 14
    dim_x = grid_size_xy
    dim_y = grid_size_xy
    # Same guard as before: Z never collapses to zero voxels.
    dim_z = grid_size_xy if grid_size_xy != 0 else 1
    grid_dims_voxels = np.array([dim_x, dim_y, dim_z], dtype=int)
    grid_dims_tuple = tuple(grid_dims_voxels.tolist())

    if points.shape[0] == 0:
        # Empty cloud: all-zero grid with neutral transformation parameters.
        return (
            torch.zeros(NUM_FEATURES, dim_z, dim_y, dim_x, dtype=torch.float32),
            np.empty((0, 3), dtype=int),
            {
                'grid_origin_metric': np.zeros(3, dtype=float),
                'voxel_size_metric': voxel_size_metric,
                'grid_dims_voxels': grid_dims_tuple,
                'pc_centroid_metric': np.zeros(3, dtype=float),
            },
        )

    if points.ndim != 2 or points.shape[1] != NUM_FEATURES:
        raise ValueError(
            f"points must have shape (N, {NUM_FEATURES}), got {points.shape}"
        )

    xyz = points[:, :3]
    pc_centroid_metric = xyz.mean(axis=0)
    grid_metric_span = grid_dims_voxels * voxel_size_metric
    grid_origin_metric = pc_centroid_metric - (grid_metric_span / 2.0)

    # Continuous grid coordinates and the (rounded) voxel each point is assigned to.
    continuous_voxel_coords = (xyz - grid_origin_metric) / voxel_size_metric
    rounded_idx_xyz = np.round(continuous_voxel_coords).astype(int)
    inside_grid = np.all(
        (rounded_idx_xyz >= 0) & (rounded_idx_xyz < grid_dims_voxels), axis=1
    )
    idx_xyz = rounded_idx_xyz[inside_grid]

    # Per-point feature rows: offset from (index + 0.5) followed by the raw features.
    offsets_grid_units = continuous_voxel_coords[inside_grid] - (idx_xyz + 0.5)
    per_point_features = np.concatenate(
        (offsets_grid_units, points[inside_grid, 3:]), axis=1
    ).astype(np.float64)

    # Row-major flat index into a (dim_z, dim_y, dim_x) volume.
    flat_idx = (idx_xyz[:, 2] * dim_y + idx_xyz[:, 1]) * dim_x + idx_xyz[:, 0]
    num_voxels = dim_x * dim_y * dim_z

    counts = np.bincount(flat_idx, minlength=num_voxels).astype(np.float64)
    # Empty voxels have a zero sum; dividing by 1 keeps them at 0 and avoids 0/0.
    counts[counts == 0] = 1.0

    averaged = np.empty((NUM_FEATURES, num_voxels), dtype=np.float32)
    for channel in range(NUM_FEATURES):
        channel_sum = np.bincount(
            flat_idx, weights=per_point_features[:, channel], minlength=num_voxels
        )
        averaged[channel] = channel_sum / counts

    voxel_grid = torch.from_numpy(averaged).view(NUM_FEATURES, dim_z, dim_y, dim_x)
    voxel_indices_zyx = np.stack(
        (idx_xyz[:, 2], idx_xyz[:, 1], idx_xyz[:, 0]), axis=1
    ).astype(int)

    scale_info = {
        'grid_origin_metric': grid_origin_metric,
        'voxel_size_metric': voxel_size_metric,
        'grid_dims_voxels': grid_dims_tuple,
        'pc_centroid_metric': pc_centroid_metric,
    }
    return voxel_grid, voxel_indices_zyx, scale_info
def create_ground_truth(vertices: np.ndarray,
                        scale_info: Dict[str, Any]
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build per-voxel vertex labels and sub-voxel refinement targets.

    Each vertex is converted to continuous grid coordinates using the origin and
    voxel size from ``scale_info`` and assigned to the voxel obtained by flooring
    that coordinate. Indices are clipped to the grid, so a vertex lying outside
    the grid is labelled on the nearest boundary voxel and its offset can then
    fall outside [-0.5, 0.5). When several vertices share a voxel, the offset of
    the last one in ``vertices`` is kept.

    Args:
        vertices: (M, 3) vertex coordinates (x, y, z) in original metric space.
        scale_info: Dict as returned by ``voxelize_points``; the keys
            'grid_origin_metric', 'voxel_size_metric' and 'grid_dims_voxels'
            (dim_x, dim_y, dim_z) are read.

    Returns:
        vertex_labels: (dim_z, dim_y, dim_x) float32 tensor, 1.0 where a vertex
            was assigned, 0.0 elsewhere.
        refinement_targets: (3, dim_z, dim_y, dim_x) float32 tensor holding
            (dx, dy, dz) = vertex coordinate minus voxel center (index + 0.5),
            in grid units; zero where no vertex was assigned.
    """
    origin_xyz = scale_info['grid_origin_metric']
    cell_size = scale_info['voxel_size_metric']
    dims_xyz = np.array(scale_info['grid_dims_voxels'])
    nx, ny, nz = dims_xyz[0], dims_xyz[1], dims_xyz[2]

    labels = torch.zeros(nz, ny, nx, dtype=torch.float32)
    targets = torch.zeros(3, nz, ny, nx, dtype=torch.float32)
    if vertices.shape[0] == 0:
        return labels, targets

    # Continuous grid coordinates; may be fractional and outside [0, dim - 1].
    grid_coords = (vertices - origin_xyz) / cell_size
    highest_idx = np.array([nx - 1, ny - 1, nz - 1])
    voxel_idx_xyz = np.clip(np.floor(grid_coords).astype(int), 0, highest_idx)
    # Offset from the center of the assigned (clipped) voxel, in grid units.
    offsets_xyz = grid_coords - (voxel_idx_xyz.astype(float) + 0.5)

    # Written in input order so that the last vertex wins on shared voxels.
    for (ix, iy, iz), (dx, dy, dz) in zip(voxel_idx_xyz, offsets_xyz):
        labels[iz, iy, ix] = 1.0
        targets[0, iz, iy, ix] = dx
        targets[1, iz, iy, ix] = dy
        targets[2, iz, iy, ix] = dz

    return labels, targets
class VoxelUNet(nn.Module):
    """Encoder-decoder 3D CNN for voxel-wise vertex detection.

    A per-voxel MLP lifts the raw input features to ``base_channels``. Five
    conv blocks, each followed by 2x max-pooling, then reduce the volume, and
    an adaptive average pool collapses it to a single 1x1x1 code. The decoder
    has no skip connections: it is fed only by that code, which is
    interpolated back to the resolution of the deepest encoder block and
    upsampled 2x four times.

    Two heads produce the outputs:
      * vertex logits, shape (B, 1, D', H', W')
      * refinement offsets, shape (B, 3, D', H', W'), bounded to [-0.5, 0.5]
        by ``tanh(.) * 0.5``

    NOTE(review): because of the five 2x poolings and four 2x upsamplings, the
    output matches the input resolution only for spatial sizes that are
    multiples of 16 and at least 32 - confirm against the grid sizes in use.
    """

    def __init__(self, in_channels: int = 14, base_channels: int = 32, bottleneck_expansion: int = 2, mlp_hidden_factor: int = 2):
        super().__init__()
        width = base_channels

        # Per-voxel MLP: in_channels -> in_channels * mlp_hidden_factor -> base_channels.
        hidden_width = in_channels * mlp_hidden_factor
        self.voxel_mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_width),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_width, width),
        )

        # Encoder: channel count doubles at every level after the first.
        self.enc1 = self._conv_block(width, width)
        self.enc2 = self._conv_block(width, width * 2)
        self.enc3 = self._conv_block(width * 2, width * 4)
        self.enc4 = self._conv_block(width * 4, width * 8)
        self.enc5 = self._conv_block(width * 8, width * 16)
        self.pool = nn.MaxPool3d(2)

        # Bottleneck: global average pool followed by two 1x1x1 convolutions.
        self.adaptive_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        deepest_width = width * 16
        code_width = deepest_width * bottleneck_expansion
        self.bottleneck = nn.Sequential(
            nn.Conv3d(deepest_width, code_width, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv3d(code_width, code_width, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
        )

        # Decoder: each block consumes only the previous decoder output.
        self.dec5 = self._conv_block(code_width, width * 16)
        self.up4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.dec4 = self._conv_block(width * 16, width * 8)
        self.up3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.dec3 = self._conv_block(width * 8, width * 4)
        self.up2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.dec2 = self._conv_block(width * 4, width * 2)
        self.up1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.dec1 = self._conv_block(width * 2, width)

        # Output heads.
        self.vertex_head = nn.Sequential(
            nn.Conv3d(width, width // 2, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(width // 2, width // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(width // 4, 1, kernel_size=1),
        )
        self.refinement_head = nn.Conv3d(width, 3, kernel_size=1)
        self.tanh = nn.Tanh()

    def _conv_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Return two (3x3x3 conv -> BatchNorm -> ReLU) stages.

        The convolutions carry no bias because each is followed by BatchNorm3d.
        """
        layers = [
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map a (B, C_in, D, H, W) feature volume to (vertex_logits, refinement)."""
        batch, raw_channels, depth, height, width = x.shape

        # nn.Linear acts on the last axis: move channels last and flatten the voxels.
        per_voxel = x.permute(0, 2, 3, 4, 1).contiguous().view(-1, raw_channels)
        embedded = self.voxel_mlp(per_voxel)
        features = (
            embedded.view(batch, depth, height, width, embedded.shape[-1])
            .permute(0, 4, 1, 2, 3)
            .contiguous()
        )

        # Encoder: conv block then 2x max-pool, five times. Remember the spatial
        # size of the last block's output (before its pooling) for the decoder.
        deepest_size = None
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            features = encoder(features)
            deepest_size = features.shape[2:]
            features = self.pool(features)

        # Bottleneck: collapse to 1x1x1 and mix channels.
        code = self.bottleneck(self.adaptive_pool(features))

        # Decoder: broadcast the code back to the deepest encoder resolution,
        # then alternate 2x upsampling and conv blocks.
        decoded = self.dec5(
            nn.functional.interpolate(code, size=deepest_size, mode='trilinear', align_corners=True)
        )
        stages = (
            (self.up4, self.dec4),
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )
        for upsample, decoder in stages:
            decoded = decoder(upsample(decoded))

        vertex_logits = self.vertex_head(decoded)
        # tanh in [-1, 1] scaled to [-0.5, 0.5].
        refinement = self.tanh(self.refinement_head(decoded)) * 0.5
        return vertex_logits, refinement
class VoxelDataset(Dataset):
    """Dataset that turns pickled samples into voxel grids and training targets.

    Each file is loaded lazily in ``__getitem__`` and must contain the keys
    'pcloud_14d' (the point cloud) and 'wf_vertices' (the wireframe vertices).
    """

    def __init__(self, data_files: List[str], voxel_size: float = 0.1, grid_size: int = 64):
        # Paths of the pickle files, one per sample.
        self.data_files = data_files
        # Metric edge length of a voxel, forwarded to voxelize_points.
        self.voxel_size = voxel_size
        # Number of voxels per axis, forwarded to voxelize_points.
        self.grid_size = grid_size

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        """Return (voxel_grid, vertex_labels, refinement_targets, scale_info) for sample ``idx``."""
        sample = load_data(self.data_files[idx])
        voxel_grid, _point_voxel_indices, scale_info = voxelize_points(
            sample['pcloud_14d'], self.grid_size, self.voxel_size
        )
        # The targets are built on the same grid placement as the input volume.
        vertex_labels, refinement_targets = create_ground_truth(
            np.array(sample['wf_vertices']), scale_info
        )
        return voxel_grid, vertex_labels, refinement_targets, scale_info
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple # Added for Tuple type hint
class CombinedLoss(nn.Module):
    """
    Loss for voxel-wise vertex classification plus offset regression.

    The hard 0/1 vertex labels are first blurred with an unnormalized 3D Gaussian
    kernel (peak value 1) and clamped to [0, 1]. On that soft target:
    - BCE-with-logits is averaged separately over "positive" voxels (soft target
      > 1e-3) and "negative" voxels; the negative mean is scaled by
      ``bce_neg_pos_ratio`` and the two are summed.
    - A soft Dice loss compares the sigmoid probabilities with the soft target
      (the target is not binarized), scaled by ``dice_weight``.
    The refinement term is an MSE over the offset channels, restricted to voxels
    whose hard label is 1.

    total = vertex_weight * (bce + dice_weight * dice) + refinement_weight * mse
    """

    def __init__(self,
                 vertex_weight: float = 1.0,
                 refinement_weight: float = 0.0,
                 dice_weight: float = 0.5,
                 bce_neg_pos_ratio: float = 1.0,  # Weight of the negative-voxel BCE mean relative to the positive one
                 blur_kernel_size: int = 5,
                 blur_sigma: float = 1.0,
                 eps: float = 1e-6):
        super().__init__()
        self.vertex_weight = vertex_weight
        self.refinement_weight = refinement_weight
        self.dice_weight = dice_weight
        self.bce_neg_pos_ratio = bce_neg_pos_ratio
        self.eps = eps

        # Per-element BCE so that positive and negative voxels can be averaged separately.
        self.bce_loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = nn.MSELoss()

        # Unnormalized 3D Gaussian, shape (1, 1, k, k, k), centered on the middle tap.
        k = blur_kernel_size
        axis = torch.arange(k, dtype=torch.float32) - (k - 1) / 2
        xx, yy, zz = torch.meshgrid(axis, axis, axis, indexing='ij')
        squared_radius = xx**2 + yy**2 + zz**2
        gaussian = torch.exp(-squared_radius / (2 * blur_sigma**2)).view(1, 1, k, k, k)
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('gaussian_kernel', gaussian)
        self.pad = k // 2

    def forward(self,
                vertex_logits_pred: torch.Tensor,  # (B,1,D,H,W)
                refinement_pred: torch.Tensor,     # (B,3,D,H,W)
                vertex_gt: torch.Tensor,           # (B,D,H,W), 0/1
                refinement_gt: torch.Tensor        # (B,3,D,H,W)
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (total_loss, vertex_loss, refinement_loss) as scalar tensors."""
        logits = vertex_logits_pred.squeeze(1)  # (B,D,H,W)
        hard_gt = vertex_gt.float()             # (B,D,H,W)

        # Soft target: Gaussian blur of the hard labels, clamped to [0, 1].
        soft_gt = (
            F.conv3d(hard_gt.unsqueeze(1), self.gaussian_kernel, padding=self.pad)
            .clamp(0, 1)
            .squeeze(1)
        )

        # 1) BCE, averaged separately over positive / negative voxels of the soft target.
        positive = soft_gt > 1e-3
        negative = ~positive
        per_voxel_bce = self.bce_loss_fn(logits, soft_gt)

        weighted_terms = []
        if positive.sum() > 0:
            weighted_terms.append(1.0 * per_voxel_bce[positive].mean())
        if negative.sum() > 0:
            weighted_terms.append(self.bce_neg_pos_ratio * per_voxel_bce[negative].mean())

        if len(weighted_terms) == 2:
            bce = weighted_terms[0] + weighted_terms[1]
        elif weighted_terms:
            bce = weighted_terms[0]
        else:
            # No voxels at all: contribute nothing.
            bce = torch.tensor(0.0, device=logits.device)

        # 2) Soft Dice between probabilities and the (non-binarized) soft target, per sample.
        prob = torch.sigmoid(logits)
        spatial_dims = [1, 2, 3]
        intersection = (prob * soft_gt).sum(dim=spatial_dims)
        union = prob.sum(dim=spatial_dims) + soft_gt.sum(dim=spatial_dims)
        dice_score = (2 * intersection + self.eps) / (union + self.eps)
        dice_loss = 1 - dice_score.mean()

        vertex_loss = bce + self.dice_weight * dice_loss

        # 3) Offset MSE, only on voxels whose hard label marks a true vertex.
        vertex_voxels = (hard_gt > 0.5).unsqueeze(1)
        refinement_loss = torch.tensor(0., device=logits.device)
        if vertex_voxels.sum() > 0:
            selector = vertex_voxels.expand_as(refinement_pred)
            # Boolean selection yields values in memory order, so the rows of the
            # (-1, 3) view are not per-voxel (dx, dy, dz) triples; the mean-reduced
            # MSE is unaffected because both tensors are selected identically.
            selected_pred = refinement_pred[selector].view(-1, 3)
            selected_gt = refinement_gt[selector].view(-1, 3)
            if selected_pred.numel() > 0:
                refinement_loss = self.mse_loss(selected_pred, selected_gt)

        # 4) Weighted total.
        total_loss = (self.vertex_weight * vertex_loss +
                      self.refinement_weight * refinement_loss)

        return total_loss, vertex_loss, refinement_loss
def _debug_visualize_batch(voxel_grid_batch, vertex_labels_batch, refinement_targets_batch,
                           current_epoch: int, batch_idx: int, num_batches: int) -> None:
    """Show the first sample of a batch (occupied voxels and GT vertices) in a PyVista window."""
    print(f'Epoch {current_epoch+1}, Batch {batch_idx+1}/{num_batches}')
    sample_voxel_features = voxel_grid_batch[0].cpu().numpy()
    sample_gt_labels = vertex_labels_batch[0].cpu().numpy()
    sample_gt_refinement = refinement_targets_batch[0].cpu().numpy()

    # A voxel counts as occupied when any of its first three feature channels is non-zero.
    occupied_voxel_mask = np.any(sample_voxel_features[:3] != 0, axis=0)
    if not np.any(occupied_voxel_mask):
        print(f"Epoch {current_epoch+1} Batch {batch_idx+1}: No data to visualize for the first sample.")
        return

    plotter = pv.Plotter(window_size=[800, 600])
    plotter.background_color = 'white'

    # The masks are indexed (z, y, x); flip the columns so plotted points are (x, y, z),
    # matching the (dx, dy, dz) channel order of the refinement targets.
    occupied_xyz = np.argwhere(occupied_voxel_mask)[:, ::-1].astype(float)
    plotter.add_mesh(pv.PolyData(occupied_xyz + 0.5), color='cornflowerblue', point_size=5,
                     render_points_as_spheres=True, label='Occupied Voxels (Centers)')

    gt_vertex_voxel_mask = sample_gt_labels > 0.5
    if np.any(gt_vertex_voxel_mask):
        gt_xyz = np.argwhere(gt_vertex_voxel_mask)[:, ::-1].astype(float)
        gt_offsets = sample_gt_refinement[:, gt_vertex_voxel_mask].T  # (K, 3) as (dx, dy, dz)
        plotter.add_mesh(pv.PolyData(gt_xyz + 0.5 + gt_offsets), color='crimson', point_size=10,
                         render_points_as_spheres=True, label='Target Vertices (GT)')

    plotter.show(title=f"Debug Viz E{current_epoch+1} B{batch_idx+1}", auto_close=False)


def train_epoch(model, dataloader, optimizer, criterion, device, current_epoch: int,
                debug_visualize: bool = False, checkpoint_interval: int = 200):
    """Run one training epoch and return the mean losses.

    Args:
        model: Network called as ``model(voxel_grid_batch)``; must return
            ``(vertex_logits, refinement)``.
        dataloader: Iterable of ``(voxel_grid, vertex_labels, refinement_targets, scale_info)``
            batches that also supports ``len()``.
        optimizer: Optimizer over the model parameters.
        criterion: Callable returning ``(total_loss, vertex_loss, refinement_loss)``.
        device: Device the batch tensors are moved to.
        current_epoch: Zero-based epoch index (used in logs and checkpoint names).
        debug_visualize: When True, show the first sample of every batch in a
            blocking PyVista window before the optimization step. Off by default;
            this replaces a debug branch that was permanently disabled.
        checkpoint_interval: Save the model state every this many batches; a value
            of 0 or less disables intra-epoch checkpoints.

    Returns:
        Tuple ``(avg_total_loss, avg_vertex_loss, avg_refinement_loss)`` averaged over
        the number of batches; all zeros for an empty dataloader.
    """
    model.train()
    num_batches = len(dataloader)
    total_loss_epoch = 0.0
    vertex_loss_epoch = 0.0
    refinement_loss_epoch = 0.0

    for batch_idx, (voxel_grid_batch, vertex_labels_batch, refinement_targets_batch, _) in enumerate(dataloader):
        voxel_grid_batch = voxel_grid_batch.to(device)
        vertex_labels_batch = vertex_labels_batch.to(device)
        refinement_targets_batch = refinement_targets_batch.to(device)

        if debug_visualize:
            _debug_visualize_batch(voxel_grid_batch, vertex_labels_batch, refinement_targets_batch,
                                   current_epoch, batch_idx, num_batches)

        optimizer.zero_grad()
        vertex_logits_pred, refinement_pred = model(voxel_grid_batch)
        loss, vertex_loss, refinement_loss = criterion(
            vertex_logits_pred, refinement_pred, vertex_labels_batch, refinement_targets_batch
        )
        loss_value = loss.item()
        vertex_loss_value = vertex_loss.item()
        refinement_loss_value = refinement_loss.item()
        print(f"Batch {batch_idx+1}/{num_batches}: Loss={loss_value:.4f}, Vertex Loss={vertex_loss_value:.4f}, Refinement Loss={refinement_loss_value:.4f}")

        # Only step when the loss is meaningfully positive; a (near-)zero or NaN loss
        # fails this comparison and the update is skipped.
        if loss > 0.000001:
            loss.backward()
            optimizer.step()

        total_loss_epoch += loss_value
        vertex_loss_epoch += vertex_loss_value
        refinement_loss_epoch += refinement_loss_value

        if checkpoint_interval > 0 and (batch_idx + 1) % checkpoint_interval == 0:
            checkpoint_path = f"model_epoch_{current_epoch+1}_batch_{batch_idx+1}_grid_128v9.pth"  # Consider updating filename if grid size changes
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved batch checkpoint: {checkpoint_path}")

    if num_batches == 0:
        return 0, 0, 0
    return (
        total_loss_epoch / num_batches,
        vertex_loss_epoch / num_batches,
        refinement_loss_epoch / num_batches,
    )
def train_model(data_folder: str = "data", num_epochs: int = 100, batch_size: int = 4, neg_pos_ratio_val: float = 1.0):
    """Train a VoxelUNet on every pickle file found in ``data_folder``.

    A checkpoint is written to the current working directory after every epoch,
    and a final model file is written once all epochs are done.

    Args:
        data_folder: Directory searched by ``get_data_files`` for training samples.
        num_epochs: Number of passes over the dataset.
        batch_size: Samples per batch.
        neg_pos_ratio_val: Weight of the negative-voxel BCE term relative to the
            positive one. It is passed to ``CombinedLoss`` as ``bce_neg_pos_ratio``
            and also embedded in the checkpoint file names. Previously the value
            only affected the file names and never reached the loss.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    data_files = get_data_files(data_folder)
    if not data_files:
        print(f"No data files found in {data_folder}. Exiting.")
        return

    GRID_SIZE_CFG = 128
    VOXEL_SIZE_CFG = 0.5

    dataset = VoxelDataset(data_files, voxel_size=VOXEL_SIZE_CFG, grid_size=GRID_SIZE_CFG)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

    # These architecture settings must match load_model_for_inference.
    model = VoxelUNet(in_channels=14, base_channels=32, bottleneck_expansion=4, mlp_hidden_factor=10).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Refinement and Dice terms are switched off (weight 0); only the weighted BCE trains.
    criterion = CombinedLoss(
        vertex_weight=10.0,
        refinement_weight=0.0,
        dice_weight=-0.0,
        bce_neg_pos_ratio=neg_pos_ratio_val,
    ).to(device)

    print(f"Starting training: {num_epochs} epochs, Batch size: {batch_size}, Grid size: {GRID_SIZE_CFG}, Voxel size: {VOXEL_SIZE_CFG}, Initial LR: {optimizer.param_groups[0]['lr']}")

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")
        avg_loss, avg_vertex_loss, avg_refinement_loss = train_epoch(
            model, dataloader, optimizer, criterion, device, epoch
        )
        print(f"Epoch {epoch+1} Summary: Avg Loss: {avg_loss:.4f}, "
              f"Avg Vertex Loss: {avg_vertex_loss:.4f}, "
              f"Avg Refinement Loss: {avg_refinement_loss:.4f}, "
              f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

        checkpoint_path = f"model_epoch_{epoch+1}_grid{GRID_SIZE_CFG}_smooth_bal{neg_pos_ratio_val}_v9.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Saved checkpoint: {checkpoint_path}")

    final_model_path = f"final_model_grid{GRID_SIZE_CFG}_epochs{num_epochs}_smooth_bal{neg_pos_ratio_val}_v9.pth"
    torch.save(model.state_dict(), final_model_path)
    print(f"Training completed! Final model saved as {final_model_path}")
def load_model_for_inference(model_path: str, device: torch.device,
                             in_channels: int = 14, base_channels: int = 32) -> VoxelUNet:
    """Build a VoxelUNet, load weights from ``model_path`` and switch it to eval mode.

    Args:
        model_path: Path to a state dict saved with ``torch.save(model.state_dict(), ...)``.
        device: Device the weights are mapped to and the model is moved to.
        in_channels: Number of input feature channels. This was previously ignored
            in favour of a hard-coded 14.
        base_channels: Base channel width of the network. This was previously ignored
            in favour of a hard-coded 32.

    Returns:
        The model on ``device`` in evaluation mode.

    ``bottleneck_expansion=4`` and ``mlp_hidden_factor=10`` mirror the values used in
    ``train_model``; a checkpoint trained with different values will fail to load.

    NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints from a
    trusted source.
    """
    model = VoxelUNet(in_channels=in_channels, base_channels=base_channels,
                      bottleneck_expansion=4, mlp_hidden_factor=10)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    print(f"Model loaded from {model_path} and set to evaluation mode on {device}.")
    return model
def predict_vertices(model: VoxelUNet,
                     point_cloud_14d: np.ndarray,
                     grid_size: int,
                     device: torch.device,
                     voxel_size_metric: float = 0.5,
                     vertex_threshold: float = 0.5,
                     apply_refinement: bool = False) -> np.ndarray:
    """
    Predict vertices from a 14D point cloud.

    The cloud is voxelized and passed through the model. Every voxel whose sigmoid
    probability exceeds ``vertex_threshold`` yields one vertex at the voxel center
    (index + 0.5), optionally shifted by the predicted refinement offset, and the
    result is mapped back to metric coordinates. Summary statistics of the
    predictions are printed.

    Args:
        model: The trained VoxelUNet model (expected to already be in eval mode).
        point_cloud_14d: (N, 14) NumPy array of the input point cloud.
        grid_size: Number of voxels per axis (must match training).
        device: PyTorch device the input is moved to.
        voxel_size_metric: The metric size of each voxel (must match training).
        vertex_threshold: Probability threshold for classifying a voxel as a vertex.
        apply_refinement: When True, add the predicted (dx, dy, dz) offsets to the
            voxel centers. Defaults to False, which reproduces the previous
            behaviour where the offsets were extracted but never applied.

    Returns:
        (M, 3) float32 array of predicted vertex coordinates in the original point
        cloud space (X, Y, Z order); shape (0, 3) when no voxel passes the threshold.
    """
    voxel_grid_tensor, _, scale_info = voxelize_points(
        point_cloud_14d,
        grid_size_xy=grid_size,
        voxel_size_metric=voxel_size_metric
    )
    # An empty point cloud yields an all-zero grid; the model is still run on it.
    input_tensor = voxel_grid_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        vertex_logits_pred_tensor, refinement_pred_tensor = model(input_tensor)

    vertex_prob_pred_tensor = torch.sigmoid(vertex_logits_pred_tensor)
    vertex_prob_pred_np = vertex_prob_pred_tensor.squeeze(0).squeeze(0).cpu().numpy()  # (D, H, W)
    refinement_pred_np = refinement_pred_tensor.squeeze(0).cpu().numpy()  # (3, D, H, W), channels (dx, dy, dz)

    print(f"Vertex Probabilities Stats: Min={np.min(vertex_prob_pred_np):.4f}, Max={np.max(vertex_prob_pred_np):.4f}, Mean={np.mean(vertex_prob_pred_np):.4f}, Median={np.median(vertex_prob_pred_np):.4f}")
    if refinement_pred_np.size > 0:
        print(f"Refinement Predictions Stats: Min={np.min(refinement_pred_np):.4f}, Max={np.max(refinement_pred_np):.4f}, Mean={np.mean(refinement_pred_np):.4f}, Median={np.median(refinement_pred_np):.4f}")
        for i in range(refinement_pred_np.shape[0]):  # Iterate over dx, dy, dz components
            print(f" Refinement Dim {i} (dx,dy,dz order) Stats: Min={np.min(refinement_pred_np[i]):.4f}, Max={np.max(refinement_pred_np[i]):.4f}, Mean={np.mean(refinement_pred_np[i]):.4f}, Median={np.median(refinement_pred_np[i]):.4f}")
    else:
        print("Refinement Predictions Stats: Array is empty.")

    predicted_mask = vertex_prob_pred_np > vertex_threshold
    # (N_preds, 3) with columns (idx_z, idx_y, idx_x)
    predicted_voxel_indices_zyx = np.argwhere(predicted_mask)
    if not predicted_voxel_indices_zyx.size:
        return np.empty((0, 3), dtype=np.float32)

    # Voxel centers in continuous grid space, reordered from (z, y, x) to (x, y, z).
    refined_grid_coords_xyz = predicted_voxel_indices_zyx[:, ::-1].astype(np.float32) + 0.5

    if apply_refinement:
        # (3, N_preds) -> (N_preds, 3); columns are already (dx, dy, dz).
        offsets_xyz_order = refinement_pred_np[:,
                                               predicted_voxel_indices_zyx[:, 0],
                                               predicted_voxel_indices_zyx[:, 1],
                                               predicted_voxel_indices_zyx[:, 2]].T
        refined_grid_coords_xyz = refined_grid_coords_xyz + offsets_xyz_order

    # Back to original metric space using the grid placement chosen during voxelization.
    grid_origin_metric = np.array(scale_info['grid_origin_metric'])  # (ox, oy, oz)
    current_voxel_size_metric = scale_info['voxel_size_metric']
    predicted_vertices_original_space = refined_grid_coords_xyz * current_voxel_size_metric + grid_origin_metric
    return predicted_vertices_original_space.astype(np.float32)
# Simple inference script
def run_inference(model_path: str,
                  data_file_path: str,
                  output_file: str = None,
                  grid_size: int = 128,
                  voxel_size: float = 0.5,
                  vertex_threshold: float = 0.5):
    """
    Predict vertices for every data file in a directory, visualize and save them.

    For each file found by ``get_data_files(data_file_path)`` the point cloud is run
    through the model. The input cloud, the ground-truth vertices (if present) and
    the predictions are shown in a PyVista window. A ``<name>_predictions.pkl`` file
    is then written through ``save_data`` into its default folder. Files that cannot
    be loaded, lack the 'pcloud_14d' key or fail during prediction are reported and
    skipped.

    Args:
        model_path: Path of the checkpoint to load.
        data_file_path: Directory containing the input pickle files.
        output_file: Currently unused.
        grid_size: Number of voxels per axis (must match training).
        voxel_size: Metric voxel size (must match training).
        vertex_threshold: Probability threshold passed to ``predict_vertices``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = load_model_for_inference(model_path, device)

    data_files = get_data_files(data_file_path)
    if not data_files:
        print(f"No data files found in {data_file_path}")
        return
    total_files = len(data_files)
    print(f"Found {total_files} data files to process")

    for file_number, file_path in enumerate(data_files, start=1):
        display_name = os.path.basename(file_path)
        print(f"\n--- Processing file {file_number}/{total_files}: {display_name} ---")

        # Load the sample; unreadable files are skipped.
        try:
            data = load_data(file_path)
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            continue

        if 'pcloud_14d' not in data:
            print(f"Error: File {file_path} does not contain 'pcloud_14d' key, skipping")
            continue

        pcloud = data['pcloud_14d'][:, :3]                    # (N,3) xyz only, for display
        gt_vertices = np.array(data.get('wf_vertices', []))  # (M,3) or empty

        print(f"Input point cloud shape: {pcloud.shape}")
        if gt_vertices.size:
            print(f"GT vertices shape: {gt_vertices.shape}")

        print("Running inference...")
        try:
            predicted_vertices = predict_vertices(
                model=model,
                point_cloud_14d=data['pcloud_14d'],
                grid_size=grid_size,
                device=device,
                voxel_size_metric=voxel_size,
                vertex_threshold=vertex_threshold
            )
        except Exception as e:
            print(f"Error during prediction for {file_path}: {e}")
            continue

        print(f"Predicted {len(predicted_vertices)} vertices")

        # --- Visualization: one point layer per non-empty array ---
        plotter = pv.Plotter(window_size=[800,600])
        plotter.background_color = 'white'
        layers = (
            (pcloud, {'color': 'lightgray', 'point_size': 2, 'label': 'Input PC'}),
            (gt_vertices, {'color': 'red', 'point_size': 8, 'label': 'GT Vertices'}),
            (predicted_vertices, {'color': 'blue', 'point_size': 8, 'label': 'Predicted Vertices'}),
        )
        for layer_points, layer_style in layers:
            if layer_points.size:
                plotter.add_mesh(pv.PolyData(layer_points), render_points_as_spheres=True, **layer_style)
        plotter.add_legend()
        plotter.show(title=display_name)

        # --- Persist predictions together with the settings that produced them ---
        output_data = {
            'predicted_vertices': predicted_vertices,
            'input_file': file_path,
            'model_used': model_path,
            'grid_size': grid_size,
            'voxel_size': voxel_size,
            'vertex_threshold': vertex_threshold,
            'original_data': data
        }
        output_filename = f"{os.path.splitext(display_name)[0]}_predictions"
        try:
            save_data(output_data, output_filename)  # Saves to 'data' subfolder by default
            print(f"Results saved to: data/{output_filename}.pkl")
        except Exception as e:
            print(f"Error saving results for {file_path}: {e}")

    print(f"\nCompleted processing {total_files} files")
def _main() -> None:
    """Script entry point: train a model, or run inference when ``inference`` is True."""
    inference = False

    # Replace with your actual data folder path
    # Example: data_folder_train = '/path/to/your/training_data'
    data_folder_train = 'YOUR_LOCAL_DATA_FOLDER_PATH'
    num_epochs_train = 100
    batch_size_train = 16
    # Forwarded to train_model as neg_pos_ratio_val.
    negative_to_positive_bce_ratio = 1

    if not inference:
        train_model(data_folder=data_folder_train,
                    num_epochs=num_epochs_train,
                    batch_size=batch_size_train,
                    neg_pos_ratio_val=negative_to_positive_bce_ratio)
        return

    # Replace with your actual model path and data path for inference
    run_inference(model_path='YOUR_MODEL_PATH.pth',  # Example: '/path/to/your/model.pth'
                  data_file_path='YOUR_INFERENCE_DATA_FOLDER_PATH',  # Example: '/path/to/your/inference_data'
                  output_file=None,  # Output will be saved in a 'data' subfolder relative to script
                  grid_size=128,
                  voxel_size=0.5,
                  vertex_threshold=0.5
                  )


if __name__ == "__main__":
    _main()