| |
| |
| |
| |
| |
| |
| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import pickle |
| from torch.utils.data import Dataset, DataLoader |
| from typing import List, Dict, Tuple, Optional |
| import json |
|
|
class ClassificationPointNet(nn.Module):
    """
    PointNet-style binary classifier for 6D point cloud patches.

    Takes 6D point clouds (x, y, z, r, g, b) and predicts a single logit per
    patch (edge / not edge). The network is a shared per-point MLP
    (1x1 convolutions with batch norm), a global max-pool over the point
    axis, and a fully connected head with dropout.

    Attribute names (conv1..conv6, bn1..bn6, fc1..fc6) and their registration
    order are part of the checkpoint format and must not be changed.
    """

    def __init__(self, input_dim=6, max_points=1024):
        """
        Args:
            input_dim: Number of channels per point (6 for xyz + rgb).
            max_points: Number of points per patch the caller pads/samples to.
                Stored for callers; the layers themselves accept any length.
        """
        super().__init__()
        self.max_points = max_points

        # Shared per-point feature extractor (kernel size 1 == per-point MLP).
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)
        self.conv4 = nn.Conv1d(256, 512, 1)
        self.conv5 = nn.Conv1d(512, 1024, 1)
        self.conv6 = nn.Conv1d(1024, 2048, 1)

        # Classification head operating on the pooled global feature.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 64)
        self.fc6 = nn.Linear(64, 1)

        # One batch norm per convolution stage.
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(1024)
        self.bn6 = nn.BatchNorm1d(2048)

        # One dropout per hidden fully connected layer.
        self.dropout1 = nn.Dropout(0.3)
        self.dropout2 = nn.Dropout(0.4)
        self.dropout3 = nn.Dropout(0.5)
        self.dropout4 = nn.Dropout(0.4)
        self.dropout5 = nn.Dropout(0.3)

    def forward(self, x: torch.Tensor, valid_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: (batch_size, input_dim, num_points) tensor.
            valid_mask: Optional (batch_size, num_points) boolean tensor; True
                marks real points, False marks padding. When given, padded
                points are excluded from the global max-pool. When omitted
                every point participates in the pooling.

        Returns:
            (batch_size, 1) tensor of logits (apply sigmoid for probability).

        Note:
            Masking only affects the pooling step. In training mode the batch
            norm statistics are still computed over all points, padding
            included.
        """
        point_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
            (self.conv6, self.bn6),
        )
        features = x
        for conv, bn in point_stages:
            features = F.relu(bn(conv(features)))

        if valid_mask is not None:
            # Features are post-ReLU (>= 0), so zero-filling padded points can
            # never exceed the maximum over the valid points, and a sample
            # with no valid points pools to zeros instead of -inf/NaN.
            padding = ~valid_mask.to(device=features.device, dtype=torch.bool)
            features = features.masked_fill(padding.unsqueeze(1), 0.0)

        # Symmetric aggregation over the point axis -> (batch_size, 2048).
        global_features = torch.max(features, 2)[0]

        head_stages = (
            (self.fc1, self.dropout1),
            (self.fc2, self.dropout2),
            (self.fc3, self.dropout3),
            (self.fc4, self.dropout4),
            (self.fc5, self.dropout5),
        )
        hidden = global_features
        for fc, dropout in head_stages:
            hidden = dropout(F.relu(fc(hidden)))

        return self.fc6(hidden)
|
|
class PatchClassificationDataset(Dataset):
    """
    Dataset class for loading saved patches for PointNet classification training.

    Each sample is a ``.pkl`` file in ``dataset_dir`` holding a dict with a
    ``'patch_6d'`` entry (N x 6 points: x, y, z, r, g, b) and an optional
    ``'label'`` entry (defaults to 0 when missing).
    """

    def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
        """
        Args:
            dataset_dir: Directory containing the ``.pkl`` patch files.
            max_points: Fixed number of points every sample is sampled/padded to.
            augment: Whether to apply random geometric augmentation per sample.
        """
        self.dataset_dir = dataset_dir
        self.max_points = max_points
        self.augment = augment

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        self.patch_files = sorted(
            os.path.join(dataset_dir, name)
            for name in os.listdir(dataset_dir)
            if name.endswith('.pkl')
        )

        print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")

    def __len__(self):
        return len(self.patch_files)

    def __getitem__(self, idx):
        """
        Load and process a patch for training.

        Returns:
            patch_data: (6, max_points) float32 tensor of point cloud data
            label: scalar float32 tensor for binary classification (0 or 1)
            valid_mask: (max_points,) boolean tensor indicating valid points
        """
        patch_file = self.patch_files[idx]

        with open(patch_file, 'rb') as f:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point this dataset at directories whose contents are trusted.
            patch_info = pickle.load(f)

        # Coerce to a float array so list-valued or integer patches work and
        # the in-place float augmentation below cannot hit a dtype error.
        patch_6d = np.asarray(patch_info['patch_6d'], dtype=np.float64)
        if patch_6d.size == 0:
            # Normalise any empty container to an explicit (0, 6) array.
            patch_6d = patch_6d.reshape(0, 6)
        label = patch_info.get('label', 0)

        num_points = patch_6d.shape[0]

        if num_points >= self.max_points:
            # Too many points: random subset without replacement.
            indices = np.random.choice(num_points, self.max_points, replace=False)
            patch_sampled = patch_6d[indices]
            valid_mask = np.ones(self.max_points, dtype=bool)
        else:
            # Too few points: zero-pad at the end and mark the padding invalid.
            patch_sampled = np.zeros((self.max_points, 6))
            patch_sampled[:num_points] = patch_6d
            valid_mask = np.zeros(self.max_points, dtype=bool)
            valid_mask[:num_points] = True

        if self.augment:
            patch_sampled = self._augment_patch(patch_sampled, valid_mask)

        # Transpose to channels-first (6, max_points), as Conv1d expects.
        patch_tensor = torch.from_numpy(patch_sampled.T).float()
        label_tensor = torch.tensor(label, dtype=torch.float32)
        valid_mask_tensor = torch.from_numpy(valid_mask)

        return patch_tensor, label_tensor, valid_mask_tensor

    def _augment_patch(self, patch, valid_mask):
        """
        Apply data augmentation to the patch.

        Only the xyz columns of valid points are changed: a random rotation
        about the z axis, Gaussian jitter (sigma 0.01), then a uniform scale
        in [0.9, 1.1]. Colour columns and padded rows are left untouched.

        Args:
            patch: (max_points, 6) float array; modified in place.
            valid_mask: (max_points,) boolean array selecting real points.

        Returns:
            The same ``patch`` array, after modification.
        """
        # Boolean indexing copies, so results are written back at the end.
        valid_points = patch[valid_mask]

        if len(valid_points) == 0:
            return patch

        # Random rotation about the z axis.
        angle = np.random.uniform(0, 2 * np.pi)
        cos_angle = np.cos(angle)
        sin_angle = np.sin(angle)
        rotation_matrix = np.array([
            [cos_angle, -sin_angle, 0],
            [sin_angle, cos_angle, 0],
            [0, 0, 1],
        ])
        valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T

        # Per-coordinate Gaussian jitter.
        noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
        valid_points[:, :3] += noise

        # Global isotropic scale.
        scale = np.random.uniform(0.9, 1.1)
        valid_points[:, :3] *= scale

        patch[valid_mask] = valid_points
        return patch
|
|
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
    """
    Save patches from prediction pipeline to create a training dataset.

    Each patch is pickled to ``<dataset_dir>/<entry_id>_patch_<i>.pkl``.
    Files that already exist are left untouched. Each file is written to a
    temporary name first and then renamed, so a failed or interrupted dump
    never leaves a truncated ``.pkl`` behind (which would otherwise be
    skipped on every later run).

    Args:
        patches: List of patch dictionaries from generate_patches()
        dataset_dir: Directory to save the dataset
        entry_id: Unique identifier for this entry/image
    """
    os.makedirs(dataset_dir, exist_ok=True)

    saved = 0
    skipped = 0
    for i, patch in enumerate(patches):
        filepath = os.path.join(dataset_dir, f"{entry_id}_patch_{i}.pkl")

        # Never overwrite a patch that is already on disk.
        if os.path.exists(filepath):
            skipped += 1
            continue

        # The temp name does not end in '.pkl', so dataset loaders that scan
        # for '.pkl' files ignore it; the pid keeps concurrent writers apart.
        tmp_path = f"{filepath}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(patch, f)
            os.replace(tmp_path, filepath)
        finally:
            # Only still present if the dump or rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        saved += 1

    # Report what was actually written rather than len(patches).
    message = f"Saved {saved} patches for entry {entry_id}"
    if skipped:
        message += f" (skipped {skipped} already on disk)"
    print(message)
|
|
| |
def collate_fn(batch):
    """
    Collate dataset samples into batched tensors, dropping empty patches.

    Args:
        batch: Iterable of (patch_data, label, valid_mask) tuples.

    Returns:
        A (patch_data, labels, valid_masks) tuple of stacked tensors, or None
        when no sample in the batch has at least one valid point.
    """
    # A sample is kept only if its mask marks at least one real point.
    kept = [sample for sample in batch if sample[2].sum() > 0]
    if not kept:
        return None

    patches, targets, masks = zip(*kept)
    return torch.stack(patches), torch.stack(targets), torch.stack(masks)
|
|
| |
def init_weights(m):
    """
    Initialise one module in place; intended for use with ``model.apply``.

    Conv1d and Linear layers get Xavier-uniform weights and zero biases.
    BatchNorm1d layers get unit scale and zero shift. Any other module type
    is left untouched.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm1d):
        # With affine=False the layer has weight and bias set to None, so
        # both must be checked before initialising.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
|
|
def _epoch_checkpoint_path(model_save_path: str, epoch_number: int) -> str:
    """Return the per-epoch checkpoint path: '<root>_epoch_<n><ext>'."""
    # splitext only touches the final extension, unlike str.replace('.pth'),
    # which would also rewrite '.pth' inside directory names and do nothing
    # for other extensions.
    root, ext = os.path.splitext(model_save_path)
    return f"{root}_epoch_{epoch_number}{ext}"


def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
                   lr: float = 0.001, num_workers: int = 8):
    """
    Train the ClassificationPointNet model on saved patches.

    A checkpoint is written after every epoch next to ``model_save_path``
    (``<root>_epoch_<n><ext>``), and the final weights are written to
    ``model_save_path`` itself.

    Args:
        dataset_dir: Directory containing saved patch files
        model_save_path: Path to save the trained model
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
        num_workers: Number of DataLoader worker processes (0 loads in the
            main process)

    Returns:
        The trained model. It is left in training mode; call ``.eval()``
        before using it for inference.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    # Create the output directory up front so a bad path fails before
    # training rather than after the first epoch.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
    print(f"Dataset loaded with {len(dataset)} samples")
    if len(dataset) < batch_size:
        # drop_last=True discards the only (incomplete) batch in this case.
        print(f"Warning: dataset has {len(dataset)} samples but batch_size is {batch_size}; "
              f"no batches will be produced and the model will not be trained")

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                            collate_fn=collate_fn, drop_last=True)

    model = ClassificationPointNet(input_dim=6, max_points=1024)
    model.apply(init_weights)
    model.to(device)

    # The model outputs raw logits, so the sigmoid is folded into the loss.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for batch_idx, batch_data in enumerate(dataloader):
            # collate_fn returns None when every sample in the batch is empty.
            if batch_data is None:
                continue

            patch_data, labels, _ = batch_data
            patch_data = patch_data.to(device)
            # (batch,) -> (batch, 1) to match the model's output shape.
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(patch_data)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
                      f"Loss: {loss.item():.6f}, "
                      f"Accuracy: {100 * correct / total:.2f}%")

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        accuracy = 100 * correct / total if total > 0 else 0

        print(f"Epoch {epoch+1}/{epochs} completed, "
              f"Avg Loss: {avg_loss:.6f}, "
              f"Accuracy: {accuracy:.2f}%")

        scheduler.step()

        # Per-epoch checkpoint, including the running training metrics.
        checkpoint_path = _epoch_checkpoint_path(model_save_path, epoch + 1)
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
            'accuracy': accuracy,
        }, checkpoint_path)

    # Final model after the last epoch.
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epochs,
    }, model_save_path)

    print(f"Model saved to {model_save_path}")
    return model
|
|
def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
    """
    Restore a trained ClassificationPointNet from a checkpoint file.

    The checkpoint must be a dict with a 'model_state_dict' entry. The
    returned model is moved to the target device and put in eval mode.

    Args:
        model_path: Path to the saved model
        device: Device to load the model on; when None, CUDA is used if it
            is available, otherwise the CPU

    Returns:
        Loaded ClassificationPointNet model
    """
    target_device = device
    if target_device is None:
        target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = ClassificationPointNet(input_dim=6, max_points=1024)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    saved_state = torch.load(model_path, map_location=target_device)
    net.load_state_dict(saved_state['model_state_dict'])

    # Both Module.to and Module.eval return the module itself.
    return net.to(target_device).eval()
|
|
def predict_class_from_patch(model: nn.Module, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
    """
    Predict binary classification from a patch using trained PointNet.

    The patch is randomly subsampled (too many points) or zero-padded (too
    few points) to the model's ``max_points`` before inference, so results
    for oversized patches may vary between calls.

    Args:
        model: Trained ClassificationPointNet model (any module mapping a
            (1, 6, max_points) tensor to a single logit works). The model
            should already be in eval mode; this function does not change
            its mode.
        patch: Dictionary containing patch data with 'patch_6d' key
        device: Device to run prediction on. When None, the device of the
            model's parameters is used, so the input always lands on the
            same device as the model. An explicitly passed device only
            moves the input, never the model.

    Returns:
        tuple of (predicted_class, confidence)
        predicted_class: int (0 for not edge, 1 for edge)
        confidence: float in [0, 1]; the sigmoid probability of class 1
    """
    if device is None:
        # Follow the model rather than guessing: a CPU model on a machine
        # with CUDA would otherwise receive a CUDA input and fail.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Accept lists as well as arrays; normalise empty input to (0, 6).
    patch_6d = np.asarray(patch['patch_6d'], dtype=np.float64)
    if patch_6d.size == 0:
        patch_6d = patch_6d.reshape(0, 6)

    # Use the size the model was built for; 1024 is the historical default.
    max_points = getattr(model, 'max_points', 1024)
    num_points = patch_6d.shape[0]

    if num_points >= max_points:
        # Too many points: random subset without replacement.
        indices = np.random.choice(num_points, max_points, replace=False)
        patch_sampled = patch_6d[indices]
    else:
        # Too few points: zero-pad at the end.
        patch_sampled = np.zeros((max_points, 6))
        patch_sampled[:num_points] = patch_6d

    # (max_points, 6) -> (1, 6, max_points): channels first plus batch axis.
    patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0)
    patch_tensor = patch_tensor.to(device)

    with torch.no_grad():
        outputs = model(patch_tensor)
        probability = torch.sigmoid(outputs).item()
        predicted_class = int(probability > 0.5)

    return predicted_class, probability
|
|
|
|