# Source: point_sam/model/pc_sam.py (upload by bdck, commit 4fafdbf, verified)
"""Segment Anything Model for Point Clouds.
References:
- https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/sam.py
"""
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import repeat_interleave, sample_prompts, sample_prompts_adapter
from .mask_decoder import AuxInputs, MaskDecoder
from .pc_encoder import PointCloudEncoder
from .prompt_encoder import MaskEncoder, PointEncoder
class PointCloudSAM(nn.Module):
    """Segment Anything Model adapted to point clouds.

    Combines a point-cloud encoder (patch tokens) with SAM-style prompt
    encoders (point prompts + previous-mask prompts) and a mask decoder.
    ``forward`` implements the iterative prompt-sampling training loop from
    Appendix A of the SAM paper; ``predict_masks`` is the single-shot
    inference path given explicit prompts.
    """

    def __init__(
        self,
        pc_encoder: PointCloudEncoder,
        mask_encoder: MaskEncoder,
        mask_decoder: MaskDecoder,
        prompt_iters: int,
        enable_mask_refinement_iterations: bool = True,
    ):
        """
        Args:
            pc_encoder: Backbone embedding the point cloud into patch tokens.
            mask_encoder: Encodes a previous mask prediction into dense
                per-patch embeddings.
            mask_decoder: Predicts masks and IoU scores from embeddings.
            prompt_iters: Number of prompt iterations run in ``forward``.
            enable_mask_refinement_iterations: If True (and in training mode),
                some iterations refine the previous mask without sampling new
                prompts, following the SAM training recipe.
        """
        super().__init__()
        self.pc_encoder = pc_encoder
        # The point-prompt encoder shares the backbone's embedding width.
        self.point_encoder = PointEncoder(pc_encoder.embed_dim)
        self.mask_encoder = mask_encoder
        self.mask_decoder = mask_decoder
        self.prompt_iters = prompt_iters
        self.enable_mask_refinement_iterations = enable_mask_refinement_iterations

    def predict_masks(
        self,
        coords: torch.Tensor,
        features: torch.Tensor,
        prompt_coords: torch.Tensor,
        prompt_labels: torch.Tensor,
        prompt_masks: torch.Tensor = None,
        multimask_output: bool = True,
    ):
        """Predict masks given point prompts.

        Args:
            coords: [B, N, 3]. Point cloud coordinates, normalized to [-1, 1].
            features: [B, N, F]. Point cloud features.
            prompt_coords: [B * M, num_queries, 3]. Prompt point coordinates.
            prompt_labels: [B * M, num_queries]. Prompt labels
                (foreground/background).
            prompt_masks: Optional previous mask prediction used as a dense
                prompt; None on the first call.
            multimask_output: If True, the decoder returns multiple candidate
                masks per prompt.

        Returns:
            masks: [B * M, num_outputs, N]. Predicted mask logits.
            iou_preds: [B * M, num_outputs]. Predicted IoU per mask.
        """
        # pc_embeddings: [B, num_patches, D]
        pc_embeddings, patches = self.pc_encoder(coords, features)
        centers = patches["centers"]  # [B, num_patches, 3]
        knn_idx = patches["knn_idx"]  # [B, N, K]
        aux_inputs = AuxInputs(coords=coords, features=features, centers=centers)

        # Positional encoding of patch centers: [B, num_patches, D]
        pc_pe = self.point_encoder.pe_layer(centers)

        # Sparse (point) prompt embeddings: [B * M, num_queries, D]
        sparse_embeddings = self.point_encoder(prompt_coords, prompt_labels)

        # Dense (mask) prompt embeddings:
        # [B * M, num_patches, D] or [B, num_patches, D] (if prompt_masks=None).
        # NOTE: center_idx is passed for consistency with the call in
        # ``forward``; ``patches.get`` yields None when "fps_idx" is absent.
        dense_embeddings = self.mask_encoder(
            prompt_masks,
            coords,
            centers,
            knn_idx,
            center_idx=patches.get("fps_idx"),
        )
        # Broadcast dense embeddings across the M prompts per scene:
        # [B * M, num_patches, D]
        dense_embeddings = repeat_interleave(
            dense_embeddings,
            sparse_embeddings.shape[0] // dense_embeddings.shape[0],
            0,
        )

        # [B * M, num_outputs, N], [B * M, num_outputs]
        masks, iou_preds = self.mask_decoder(
            pc_embeddings,
            pc_pe,
            sparse_embeddings,
            dense_embeddings,
            aux_inputs=aux_inputs,
            multimask_output=multimask_output,
        )
        return masks, iou_preds

    def forward(
        self,
        coords: torch.Tensor,
        features: torch.Tensor,
        gt_masks: torch.Tensor,
        is_eval: bool = False,
    ) -> List[Dict[str, torch.Tensor]]:
        """Forward pass for training. The prompts are sampled given the ground truth masks.

        Args:
            coords: [B, N, 3]. Point cloud coordinates, normalized to [-1, 1].
            features: [B, N, F]. Point cloud features.
            gt_masks: [B, M, N], bool. Ground truth binary masks.
            is_eval: Forwarded to the prompt sampler (e.g. deterministic
                sampling during evaluation).

        Returns:
            outputs: List of dictionaries, one per prompt iteration, with keys:
                - prompt_coords: [B * M, num_queries, 3]. Sampled prompt coordinates.
                - prompt_labels: [B * M, num_queries], bool. Sampled prompt labels.
                - prompt_masks: [B * M, N]. The most confident mask.
                - masks: [B * M, num_outputs, N]. Predicted masks.
                - iou_preds: [B * M, num_outputs]. IoU predictions.
                - max_iou_pred_ind: index of the selected mask.
        """
        batch_size = coords.shape[0]
        num_masks = gt_masks.shape[1]

        # pc_embeddings: [B, num_patches, D]
        pc_embeddings, patches = self.pc_encoder(coords, features)
        centers = patches["centers"]  # [B, num_patches, 3]
        knn_idx = patches["knn_idx"]  # [B, N, K]

        outputs = []  # Store the output at each iteration
        # Prompts accumulate across iterations; start with empty tensors so
        # torch.cat works uniformly on the first iteration.
        prompt_coords = coords.new_empty((batch_size * num_masks, 0, 3))
        prompt_labels = gt_masks.new_empty((batch_size * num_masks, 0))
        prompt_masks = None  # [B * M, N]
        aux_inputs = AuxInputs(coords=coords, features=features, centers=centers)

        # According to Appendix A (training algorithm) of the SAM paper,
        # there are two iterations where no additional prompts are sampled:
        # the final iteration and one randomly chosen iteration.
        if self.enable_mask_refinement_iterations and self.training:
            mask_refinement_iterations = [self.prompt_iters - 1]
            if self.prompt_iters > 1:
                sampled_iter = torch.randint(1, self.prompt_iters, (1,)).item()
                mask_refinement_iterations.append(sampled_iter)
        else:
            mask_refinement_iterations = []

        # Positional encoding of patch centers: [B, num_patches, D]
        pc_pe = self.point_encoder.pe_layer(centers)

        for i in range(self.prompt_iters):
            # Sample a new prompt unless this is a pure mask-refinement
            # iteration. Iteration 0 always samples (guards prompt_iters == 1,
            # where mask_refinement_iterations == [0]).
            if i == 0 or i not in mask_refinement_iterations:
                new_prompt_coords, new_prompt_labels = sample_prompts_adapter(
                    coords, gt_masks, prompt_masks, is_eval=is_eval,
                )
                prompt_coords = torch.cat([prompt_coords, new_prompt_coords], dim=1)
                prompt_labels = torch.cat([prompt_labels, new_prompt_labels], dim=1)

            # [B * M, num_queries, D]
            sparse_embeddings = self.point_encoder(prompt_coords, prompt_labels)
            # [B * M, num_patches, D] or [B, num_patches, D] (if prompt_masks=None)
            dense_embeddings = self.mask_encoder(
                prompt_masks,
                coords,
                centers,
                knn_idx,
                center_idx=patches.get("fps_idx"),
            )
            # [B * M, num_patches, D]
            dense_embeddings = repeat_interleave(
                dense_embeddings,
                sparse_embeddings.shape[0] // dense_embeddings.shape[0],
                0,
            )

            # [B * M, num_outputs, N], [B * M, num_outputs]
            masks, iou_preds = self.mask_decoder(
                pc_embeddings,
                pc_pe,
                sparse_embeddings,
                dense_embeddings,
                aux_inputs=aux_inputs,
                multimask_output=(i == 0),
            )

            # Select the most confident mask for the next iteration.
            if i == 0:
                max_iou_pred_ind = torch.argmax(iou_preds, dim=1)  # [B * M]
                # batch_index_select keeps the gathered dim at size 1
                # ([B * M, 1, N]); squeeze it so both branches produce the
                # documented [B * M, N] shape.
                prompt_masks = batch_index_select(
                    masks, max_iou_pred_ind, dim=1
                ).squeeze(1)  # [B * M, N]
            else:
                # Single-mask output: the only candidate is the selection.
                max_iou_pred_ind = 0
                prompt_masks = masks[:, 0]

            outputs.append(
                dict(
                    prompt_coords=prompt_coords,
                    prompt_labels=prompt_labels,
                    masks=masks,
                    iou_preds=iou_preds,
                    max_iou_pred_ind=max_iou_pred_ind,
                    prompt_masks=prompt_masks,
                )
            )
        return outputs
def batch_index_select(data: torch.Tensor, index: torch.Tensor, dim: int):
    """Gather entries along ``dim`` of ``data`` with a per-batch index.

    ``index`` holds, for each batch element (dim 0), the positions to select
    along ``dim``. It is reshaped to [B, ..., k, ..., 1] and broadcast over
    every remaining dimension of ``data`` before calling ``torch.gather``, so
    the output keeps ``data``'s shape except that ``dim`` has ``k`` entries.
    """
    broadcast_shape = [1] * data.dim()
    broadcast_shape[0] = data.shape[0]
    broadcast_shape[dim] = -1
    gather_index = index.view(broadcast_shape)

    expand_shape = list(data.shape)
    expand_shape[dim] = gather_index.shape[dim]
    return torch.gather(data, dim, gather_index.expand(expand_shape))