"""Segment Anything Model for Point Clouds.

References:
- https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/sam.py
"""

from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import repeat_interleave, sample_prompts, sample_prompts_adapter
from .mask_decoder import AuxInputs, MaskDecoder
from .pc_encoder import PointCloudEncoder
from .prompt_encoder import MaskEncoder, PointEncoder


class PointCloudSAM(nn.Module):
    """Promptable (SAM-style) mask prediction for point clouds.

    Args:
        pc_encoder: Encodes the point cloud into patch embeddings. Its
            ``embed_dim`` sizes the internally created ``PointEncoder``.
        mask_encoder: Encodes a previous mask prediction (dense prompt).
        mask_decoder: Predicts masks and IoU scores from the embeddings.
        prompt_iters: Number of interactive prompting iterations run by
            ``forward``.
        enable_mask_refinement_iterations: If True (and only while
            ``self.training``), some iterations reuse the existing prompts
            instead of sampling new points.
    """

    def __init__(
        self,
        pc_encoder: PointCloudEncoder,
        mask_encoder: MaskEncoder,
        mask_decoder: MaskDecoder,
        prompt_iters: int,
        enable_mask_refinement_iterations=True,
    ):
        super().__init__()
        self.pc_encoder = pc_encoder
        self.point_encoder = PointEncoder(pc_encoder.embed_dim)
        self.mask_encoder = mask_encoder
        self.mask_decoder = mask_decoder
        self.prompt_iters = prompt_iters
        self.enable_mask_refinement_iterations = enable_mask_refinement_iterations

    def _decode_prompts(
        self,
        coords: torch.Tensor,
        pc_embeddings: torch.Tensor,
        pc_pe: torch.Tensor,
        patches: Dict[str, torch.Tensor],
        aux_inputs: AuxInputs,
        prompt_coords: torch.Tensor,
        prompt_labels: torch.Tensor,
        prompt_masks: Optional[torch.Tensor],
        multimask_output: bool,
    ):
        """Encode point/mask prompts and decode masks for an encoded cloud.

        Shared by ``predict_masks`` and ``forward`` so that inference and
        training feed the prompt encoders and the mask decoder identically
        (in particular, both pass ``center_idx`` to the mask encoder).

        Returns:
            masks: [B * M, num_outputs, N]. Predicted masks.
            iou_preds: [B * M, num_outputs]. IoU predictions.
        """
        # [B * M, num_queries, D]
        sparse_embeddings = self.point_encoder(prompt_coords, prompt_labels)
        # [B * M, num_patches, D] or [B, num_patches, D] (if prompt_masks=None)
        dense_embeddings = self.mask_encoder(
            prompt_masks,
            coords,
            patches["centers"],
            patches["knn_idx"],
            center_idx=patches.get("fps_idx"),
        )
        # Expand to one row per prompt: the repeat factor is M when the mask
        # encoder returned [B, ...], and 1 when it already returned [B * M, ...].
        # (Assumes `repeat_interleave` follows torch.repeat_interleave semantics.)
        # [B * M, num_patches, D]
        dense_embeddings = repeat_interleave(
            dense_embeddings,
            sparse_embeddings.shape[0] // dense_embeddings.shape[0],
            0,
        )
        # [B * M, num_outputs, N], [B * M, num_outputs]
        masks, iou_preds = self.mask_decoder(
            pc_embeddings,
            pc_pe,
            sparse_embeddings,
            dense_embeddings,
            aux_inputs=aux_inputs,
            multimask_output=multimask_output,
        )
        return masks, iou_preds

    def predict_masks(
        self,
        coords: torch.Tensor,
        features: torch.Tensor,
        prompt_coords: torch.Tensor,
        prompt_labels: torch.Tensor,
        prompt_masks: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
    ):
        """Predict masks given point prompts.

        Args:
            coords: [B, N, 3]. Point cloud coordinates, normalized to [-1, 1].
            features: [B, N, F]. Point cloud features.
            prompt_coords: [B * M, num_queries, 3]. Prompt point coordinates.
            prompt_labels: [B * M, num_queries]. Prompt point labels.
            prompt_masks: Optional previous mask prediction used as a dense
                prompt (presumably [B * M, N], as produced by ``forward``).
                When None, the mask encoder output is expanded to B * M rows.
            multimask_output: Forwarded to the mask decoder.

        Returns:
            masks: [B * M, num_outputs, N]. Predicted masks.
            iou_preds: [B * M, num_outputs]. IoU predictions.
        """
        # pc_embeddings: [B, num_patches, D]
        pc_embeddings, patches = self.pc_encoder(coords, features)
        centers = patches["centers"]  # [B, num_patches, 3]
        aux_inputs = AuxInputs(coords=coords, features=features, centers=centers)
        # [B, num_patches, D]
        pc_pe = self.point_encoder.pe_layer(centers)

        return self._decode_prompts(
            coords,
            pc_embeddings,
            pc_pe,
            patches,
            aux_inputs,
            prompt_coords,
            prompt_labels,
            prompt_masks,
            multimask_output,
        )

    def _sample_mask_refinement_iterations(self) -> List[int]:
        """Return the iteration indices at which no new prompts are sampled.

        According to Appendix A (training algorithm) of the SAM paper, there
        are two iterations where no additional prompts are sampled: the last
        one and a randomly chosen one. Only active while training.
        """
        if not (self.enable_mask_refinement_iterations and self.training):
            return []
        iterations = [self.prompt_iters - 1]
        if self.prompt_iters > 1:
            # torch.randint's upper bound is exclusive, so this samples from
            # [1, prompt_iters - 1]. It may coincide with the last iteration,
            # in which case only one refinement iteration happens.
            # NOTE(review): confirm this matches the intended schedule.
            iterations.append(torch.randint(1, self.prompt_iters, (1,)).item())
        return iterations

    def forward(
        self,
        coords: torch.Tensor,
        features: torch.Tensor,
        gt_masks: torch.Tensor,
        is_eval: bool = False,
    ) -> List[Dict[str, torch.Tensor]]:
        """Forward pass for training. The prompts are sampled given the ground truth masks.

        Args:
            coords: [B, N, 3]. Point cloud coordinates, normalized to [-1, 1].
            features: [B, N, F]. Point cloud features.
            gt_masks: [B, M, N], bool. Ground truth binary masks.
            is_eval: Forwarded to ``sample_prompts_adapter``.

        Returns:
            outputs: List (one entry per prompt iteration) of dictionaries.
                Each dictionary contains the following keys:
                - prompt_coords: [B * M, num_queries, 3]. Coordinates of the sampled prompts.
                - prompt_labels: [B * M, num_queries], bool. Labels of the sampled prompts.
                - prompt_masks: [B * M, N]. The most confident mask.
                - masks: [B * M, num_outputs, N]. Predicted masks.
                - iou_preds: [B * M, num_outputs]. IoU predictions.
                - max_iou_pred_ind: [B * M] argmax of iou_preds at the first
                  iteration; the int 0 at later iterations (single output).
        """
        batch_size = coords.shape[0]
        num_masks = gt_masks.shape[1]

        # pc_embeddings: [B, num_patches, D]
        pc_embeddings, patches = self.pc_encoder(coords, features)
        centers = patches["centers"]  # [B, num_patches, 3]

        outputs = []  # Store the output at each iteration
        prompt_coords = coords.new_empty((batch_size * num_masks, 0, 3))
        prompt_labels = gt_masks.new_empty((batch_size * num_masks, 0))
        prompt_masks = None  # [B * M, N]
        aux_inputs = AuxInputs(coords=coords, features=features, centers=centers)

        mask_refinement_iterations = self._sample_mask_refinement_iterations()

        # [B, num_patches, D]
        pc_pe = self.point_encoder.pe_layer(centers)

        for i in range(self.prompt_iters):
            # The first iteration always needs prompts; refinement iterations
            # reuse the accumulated prompts and only update the mask prompt.
            if i == 0 or i not in mask_refinement_iterations:
                new_prompt_coords, new_prompt_labels = sample_prompts_adapter(
                    coords,
                    gt_masks,
                    prompt_masks,
                    is_eval=is_eval,
                )
                prompt_coords = torch.cat([prompt_coords, new_prompt_coords], dim=1)
                prompt_labels = torch.cat([prompt_labels, new_prompt_labels], dim=1)

            # [B * M, num_outputs, N], [B * M, num_outputs]
            masks, iou_preds = self._decode_prompts(
                coords,
                pc_embeddings,
                pc_pe,
                patches,
                aux_inputs,
                prompt_coords,
                prompt_labels,
                prompt_masks,
                multimask_output=(i == 0),
            )

            # Select the most confident mask for the next iteration
            if i == 0:
                max_iou_pred_ind = torch.argmax(iou_preds, dim=1)  # [B * M]
                # batch_index_select keeps the selected dim ([B * M, 1, N]);
                # squeeze it so the shape matches `masks[:, 0]` below.
                prompt_masks = batch_index_select(
                    masks, max_iou_pred_ind, dim=1
                ).squeeze(1)  # [B * M, N]
            else:
                max_iou_pred_ind = 0
                prompt_masks = masks[:, 0]  # [B * M, N]

            outputs.append(
                dict(
                    prompt_coords=prompt_coords,
                    prompt_labels=prompt_labels,
                    masks=masks,
                    iou_preds=iou_preds,
                    max_iou_pred_ind=max_iou_pred_ind,
                    prompt_masks=prompt_masks,
                )
            )

        return outputs


def batch_index_select(data: torch.Tensor, index: torch.Tensor, dim: int):
    """Select entries of `data` along `dim` with a different index per batch row.

    Args:
        data: Tensor whose first dimension is the batch dimension.
        index: [B] or [B, K] integer indices into `data`'s dimension `dim`.
        dim: Dimension to select along; must be >= 1 (dim 0 is the batch).

    Returns:
        Tensor shaped like `data` except that dimension `dim` has size K
        (1 for a 1-D `index`); that dimension is kept, not squeezed.
    """
    batch_size = data.shape[0]
    # Reshape index to [B, 1, ..., K (at dim), ..., 1] so it can be expanded.
    view_shape = [1] * data.dim()
    view_shape[0] = batch_size
    view_shape[dim] = -1
    index = index.view(view_shape)
    # Broadcast the index over every other (non-batch) dimension of `data`.
    shape = list(data.shape)
    shape[dim] = index.shape[dim]
    index = index.expand(shape)
    return torch.gather(data, dim, index)