| """Segment Anything Model for Point Clouds. |
| |
| References: |
| - https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/sam.py |
| """ |
|
|
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import repeat_interleave, sample_prompts, sample_prompts_adapter
from .mask_decoder import AuxInputs, MaskDecoder
from .pc_encoder import PointCloudEncoder
from .prompt_encoder import MaskEncoder, PointEncoder
|
|
|
|
class PointCloudSAM(nn.Module):
    """Segment Anything Model (SAM) adapted to point clouds.

    Pipeline: a point-cloud encoder tokenizes the input into patch
    embeddings, prompt encoders embed point/mask prompts, and a mask
    decoder predicts binary masks plus per-mask IoU estimates.
    """

    def __init__(
        self,
        pc_encoder: PointCloudEncoder,
        mask_encoder: MaskEncoder,
        mask_decoder: MaskDecoder,
        prompt_iters: int,
        enable_mask_refinement_iterations: bool = True,
    ):
        """Initialize the model.

        Args:
            pc_encoder: Backbone encoding the point cloud into patch embeddings.
            mask_encoder: Encodes a (previous) mask prompt into dense embeddings.
            mask_decoder: Predicts masks and IoU scores from the embeddings.
            prompt_iters: Number of interactive prompting iterations in `forward`.
            enable_mask_refinement_iterations: During training, run some
                iterations that only refine the previous mask without sampling
                new point prompts (mirrors SAM's training recipe).
        """
        super().__init__()
        self.pc_encoder = pc_encoder
        # The point prompt encoder shares the embedding width of the backbone.
        self.point_encoder = PointEncoder(pc_encoder.embed_dim)
        self.mask_encoder = mask_encoder
        self.mask_decoder = mask_decoder
        self.prompt_iters = prompt_iters
        self.enable_mask_refinement_iterations = enable_mask_refinement_iterations

    def predict_masks(
        self,
        coords: torch.Tensor,
        features: torch.Tensor,
        prompt_coords: torch.Tensor,
        prompt_labels: torch.Tensor,
        prompt_masks: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
    ):
        """Predict masks given point prompts.

        Args:
            coords: [B, N, 3]. Point cloud coordinates, normalized to [-1, 1].
            features: [B, N, F]. Point cloud features.
            prompt_coords: Prompt point coordinates.
            prompt_labels: Prompt labels (foreground/background).
            prompt_masks: Optional previous mask prediction used as a mask
                prompt; ``None`` means "no mask prompt".
            multimask_output: If True, the decoder returns multiple candidate
                masks per prompt.

        Returns:
            masks: Predicted masks.
            iou_preds: Predicted IoU score per mask.
        """
        pc_embeddings, patches = self.pc_encoder(coords, features)
        centers = patches["centers"]
        knn_idx = patches["knn_idx"]
        aux_inputs = AuxInputs(coords=coords, features=features, centers=centers)

        # Positional encoding of the patch centers, consumed by the decoder.
        pc_pe = self.point_encoder.pe_layer(centers)

        sparse_embeddings = self.point_encoder(prompt_coords, prompt_labels)
        # Pass center_idx for consistency with `forward`, which also forwards
        # the FPS indices (if present) to the mask encoder.
        dense_embeddings = self.mask_encoder(
            prompt_masks,
            coords,
            centers,
            knn_idx,
            center_idx=patches.get("fps_idx"),
        )
        # Dense (mask) embeddings are shared by all prompts of the same
        # point cloud; tile them to line up with the sparse embeddings.
        dense_embeddings = repeat_interleave(
            dense_embeddings,
            sparse_embeddings.shape[0] // dense_embeddings.shape[0],
            0,
        )

        masks, iou_preds = self.mask_decoder(
            pc_embeddings,
            pc_pe,
            sparse_embeddings,
            dense_embeddings,
            aux_inputs=aux_inputs,
            multimask_output=multimask_output,
        )
        return masks, iou_preds

    def forward(
        self,
        coords: torch.Tensor,
        features: torch.Tensor,
        gt_masks: torch.Tensor,
        is_eval: bool = False,
    ) -> List[Dict[str, torch.Tensor]]:
        """Forward pass for training. The prompts are sampled given the ground truth masks.

        Args:
            coords: [B, N, 3]. Point cloud coordinates, normalized to [-1, 1].
            features: [B, N, F]. Point cloud features.
            gt_masks: [B, M, N], bool. Ground truth binary masks.
            is_eval: Whether prompts are sampled in evaluation mode.

        Returns:
            outputs: List of dictionaries (one per prompt iteration) with keys:
                - prompt_coords: [B * M, num_queries, 3]. Coordinates of the sampled prompts.
                - prompt_labels: [B * M, num_queries], bool. Labels of the sampled prompts.
                - prompt_masks: [B * M, N]. The most confident mask.
                - masks: [B * M, num_outputs, N]. Predicted masks.
                - iou_preds: [B * M, num_outputs]. IoU predictions.
                - max_iou_pred_ind: Index of the most confident mask.
        """
        batch_size = coords.shape[0]
        num_masks = gt_masks.shape[1]

        # Encode the point cloud once; every prompt iteration reuses it.
        pc_embeddings, patches = self.pc_encoder(coords, features)
        centers = patches["centers"]
        knn_idx = patches["knn_idx"]

        outputs = []
        # Prompts accumulate across iterations; start with empty sets.
        prompt_coords = coords.new_empty((batch_size * num_masks, 0, 3))
        prompt_labels = gt_masks.new_empty((batch_size * num_masks, 0))
        prompt_masks = None
        aux_inputs = AuxInputs(coords=coords, features=features, centers=centers)

        # Following SAM's training recipe: the last iteration (and one randomly
        # chosen intermediate iteration, when prompt_iters > 1) only refines
        # the previous mask without sampling additional point prompts.
        if self.enable_mask_refinement_iterations and self.training:
            mask_refinement_iterations = [self.prompt_iters - 1]
            if self.prompt_iters > 1:
                sampled_iter = torch.randint(1, self.prompt_iters, (1,)).item()
                mask_refinement_iterations.append(sampled_iter)
        else:
            mask_refinement_iterations = []

        pc_pe = self.point_encoder.pe_layer(centers)

        for i in range(self.prompt_iters):
            # Sample a new point prompt unless this is a mask-refinement-only
            # iteration. Iteration 0 always samples (there is no mask yet).
            if i == 0 or i not in mask_refinement_iterations:
                new_prompt_coords, new_prompt_labels = sample_prompts_adapter(
                    coords, gt_masks, prompt_masks, is_eval=is_eval,
                )
                prompt_coords = torch.cat([prompt_coords, new_prompt_coords], dim=1)
                prompt_labels = torch.cat([prompt_labels, new_prompt_labels], dim=1)

            sparse_embeddings = self.point_encoder(prompt_coords, prompt_labels)
            dense_embeddings = self.mask_encoder(
                prompt_masks,
                coords,
                centers,
                knn_idx,
                center_idx=patches.get("fps_idx"),
            )
            # Dense embeddings are per point cloud; tile to per prompt.
            dense_embeddings = repeat_interleave(
                dense_embeddings,
                sparse_embeddings.shape[0] // dense_embeddings.shape[0],
                0,
            )

            masks, iou_preds = self.mask_decoder(
                pc_embeddings,
                pc_pe,
                sparse_embeddings,
                dense_embeddings,
                aux_inputs=aux_inputs,
                multimask_output=(i == 0),
            )

            if i == 0:
                # Multi-mask output: keep the candidate with the highest
                # predicted IoU as the mask prompt for the next iteration.
                max_iou_pred_ind = torch.argmax(iou_preds, dim=1)
                # squeeze(1) yields [B * M, N], matching both the documented
                # contract and the single-mask branch below (the original
                # kept a stray singleton dim here on iteration 0).
                prompt_masks = batch_index_select(
                    masks, max_iou_pred_ind, dim=1
                ).squeeze(1)
            else:
                max_iou_pred_ind = 0
                prompt_masks = masks[:, 0]

            outputs.append(
                dict(
                    prompt_coords=prompt_coords,
                    prompt_labels=prompt_labels,
                    masks=masks,
                    iou_preds=iou_preds,
                    max_iou_pred_ind=max_iou_pred_ind,
                    prompt_masks=prompt_masks,
                )
            )

        return outputs
|
|
|
|
def batch_index_select(data: torch.Tensor, index: torch.Tensor, dim: int):
    """Gather entries along `dim`, with an independent index per batch element.

    Args:
        data: [B, ...] source tensor.
        index: Per-batch indices into dimension `dim` of `data`; its total
            size must be B * k for some k >= 1.
        dim: Dimension of `data` to select along (must not be 0).

    Returns:
        Tensor shaped like `data`, except dimension `dim` has size k.
    """
    bsz = data.shape[0]
    # Broadcast shape: batch in dim 0, the selected count in `dim`,
    # singleton everywhere else.
    bcast = [1] * data.dim()
    bcast[0], bcast[dim] = bsz, -1
    gather_index = index.reshape(bcast)
    # Expand the index over all remaining dims so torch.gather can apply it.
    target_shape = list(data.shape)
    target_shape[dim] = gather_index.shape[dim]
    return torch.gather(data, dim, gather_index.expand(target_shape))
|
|