bdck committed on
Commit
4fafdbf
·
verified ·
1 Parent(s): ac6e542

Upload point_sam/model/pc_sam.py

Browse files
Files changed (1) hide show
  1. point_sam/model/pc_sam.py +208 -0
point_sam/model/pc_sam.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Segment Anything Model for Point Clouds.
2
+
3
+ References:
4
+ - https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/sam.py
5
+ """
6
+
7
+ from typing import Dict, List
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .common import repeat_interleave, sample_prompts, sample_prompts_adapter
14
+ from .mask_decoder import AuxInputs, MaskDecoder
15
+ from .pc_encoder import PointCloudEncoder
16
+ from .prompt_encoder import MaskEncoder, PointEncoder
17
+
18
+
19
class PointCloudSAM(nn.Module):
    """Promptable segmentation model for point clouds.

    Wires together a point-cloud encoder, a point-prompt encoder, a
    mask-prompt encoder and a mask decoder.
    """

    def __init__(
        self,
        pc_encoder: PointCloudEncoder,
        mask_encoder: MaskEncoder,
        mask_decoder: MaskDecoder,
        prompt_iters: int,
        enable_mask_refinement_iterations=True,
    ):
        """Store the sub-modules and the iterative-prompting settings.

        Args:
            pc_encoder: Point cloud encoder. Its `embed_dim` attribute sizes
                the point-prompt encoder created here.
            mask_encoder: Encoder for mask prompts.
            mask_decoder: Decoder producing masks and IoU predictions.
            prompt_iters: Number of prompting iterations run by `forward`.
            enable_mask_refinement_iterations: If True (and the module is in
                training mode), some iterations of `forward` reuse the
                existing prompts instead of sampling new ones.
        """
        super().__init__()
        self.pc_encoder = pc_encoder
        self.point_encoder = PointEncoder(pc_encoder.embed_dim)
        self.mask_encoder = mask_encoder
        self.mask_decoder = mask_decoder
        self.prompt_iters = prompt_iters
        self.enable_mask_refinement_iterations = enable_mask_refinement_iterations

    def predict_masks(
        self,
        coords: torch.Tensor,
        features: torch.Tensor,
        prompt_coords: torch.Tensor,
        prompt_labels: torch.Tensor,
        prompt_masks: torch.Tensor = None,
        multimask_output: bool = True,
    ):
        """Predict masks given point prompts.

        Args:
            coords: [B, N, 3]. Point cloud coordinates, normalized to [-1, 1].
            features: [B, N, F]. Point cloud features.
            prompt_coords: [B * M, num_queries, 3]. Coordinates of the prompts.
            prompt_labels: [B * M, num_queries]. Labels of the prompts.
            prompt_masks: [B * M, N] or None. Optional mask prompts.
            multimask_output: Forwarded unchanged to the mask decoder.

        Returns:
            masks: [B * M, num_outputs, N]. Predicted masks.
            iou_preds: [B * M, num_outputs]. IoU predictions.
        """
        # pc_embeddings: [B, num_patches, D]
        pc_embeddings, patches = self.pc_encoder(coords, features)
        centers = patches["centers"]  # [B, num_patches, 3]
        knn_idx = patches["knn_idx"]  # [B, N, K]
        aux_inputs = AuxInputs(coords=coords, features=features, centers=centers)

        # [B, num_patches, D]
        pc_pe = self.point_encoder.pe_layer(centers)

        # [B * M, num_queries, D]
        sparse_embeddings = self.point_encoder(prompt_coords, prompt_labels)

        # [B * M, num_patches, D] or [B, num_patches, D] (if prompt_masks=None)
        # NOTE(review): `forward` additionally passes
        # `center_idx=patches.get("fps_idx")` to the mask encoder; confirm
        # whether this inference path should do the same.
        dense_embeddings = self.mask_encoder(
            prompt_masks,
            coords,
            centers,
            knn_idx,
        )

        # Repeat along the batch axis to match the sparse embeddings:
        # [B * M, num_patches, D]
        dense_embeddings = repeat_interleave(
            dense_embeddings,
            sparse_embeddings.shape[0] // dense_embeddings.shape[0],
            0,
        )

        # [B * M, num_outputs, N], [B * M, num_outputs]
        masks, iou_preds = self.mask_decoder(
            pc_embeddings,
            pc_pe,
            sparse_embeddings,
            dense_embeddings,
            aux_inputs=aux_inputs,
            multimask_output=multimask_output,
        )
        return masks, iou_preds

    def forward(
        self,
        coords: torch.Tensor,
        features: torch.Tensor,
        gt_masks: torch.Tensor,
        is_eval: bool = False,
    ) -> List[Dict[str, torch.Tensor]]:
        """Forward pass for training. The prompts are sampled given the ground truth masks.

        Args:
            coords: [B, N, 3]. Point cloud coordinates, normalized to [-1, 1].
            features: [B, N, F]. Point cloud features.
            gt_masks: [B, M, N], bool. Ground truth binary masks.
            is_eval: Forwarded unchanged to the prompt sampler.

        Returns:
            outputs: List with one dictionary per prompting iteration. Each
                dictionary contains the following keys:
                - prompt_coords: [B * M, num_queries, 3]. Coordinates of the sampled prompts.
                - prompt_labels: [B * M, num_queries], bool. Labels of the sampled prompts.
                - prompt_masks: [B * M, N]. The most confident mask.
                - masks: [B * M, num_outputs, N]. Predicted masks.
                - iou_preds: [B * M, num_outputs]. IoU predictions.
                - max_iou_pred_ind: [B * M] tensor of selected output indices
                  for the first iteration; the plain int 0 afterwards.
        """
        batch_size = coords.shape[0]
        num_masks = gt_masks.shape[1]

        # pc_embeddings: [B, num_patches, D]
        pc_embeddings, patches = self.pc_encoder(coords, features)
        centers = patches["centers"]  # [B, num_patches, 3]
        knn_idx = patches["knn_idx"]  # [B, N, K]
        # Loop-invariant, so looked up once; None when the key is absent.
        center_idx = patches.get("fps_idx")

        outputs = []  # Store the output at each iteration
        # Start with zero prompts; new ones are concatenated along dim 1.
        prompt_coords = coords.new_empty((batch_size * num_masks, 0, 3))
        prompt_labels = gt_masks.new_empty((batch_size * num_masks, 0))
        prompt_masks = None  # [B * M, N] once the first prediction exists
        aux_inputs = AuxInputs(coords=coords, features=features, centers=centers)

        # According to Appendix A (training algorithm) of SAM paper,
        # there are two iterations where no additional prompts are sampled:
        # the last one, plus one drawn at random from [1, prompt_iters).
        if self.enable_mask_refinement_iterations and self.training:
            mask_refinement_iterations = [self.prompt_iters - 1]
            if self.prompt_iters > 1:
                sampled_iter = torch.randint(1, self.prompt_iters, (1,)).item()
                mask_refinement_iterations.append(sampled_iter)
        else:
            mask_refinement_iterations = []

        # [B, num_patches, D]
        pc_pe = self.point_encoder.pe_layer(centers)

        for i in range(self.prompt_iters):
            # The first iteration always samples, otherwise there would be
            # no prompt at all.
            if i == 0 or i not in mask_refinement_iterations:
                new_prompt_coords, new_prompt_labels = sample_prompts_adapter(
                    coords,
                    gt_masks,
                    prompt_masks,
                    is_eval=is_eval,
                )
                prompt_coords = torch.cat([prompt_coords, new_prompt_coords], dim=1)
                prompt_labels = torch.cat([prompt_labels, new_prompt_labels], dim=1)

            # [B * M, num_queries, D]
            sparse_embeddings = self.point_encoder(prompt_coords, prompt_labels)

            # [B * M, num_patches, D] or [B, num_patches, D] (if prompt_masks=None)
            dense_embeddings = self.mask_encoder(
                prompt_masks,
                coords,
                centers,
                knn_idx,
                center_idx=center_idx,
            )
            # [B * M, num_patches, D]
            dense_embeddings = repeat_interleave(
                dense_embeddings,
                sparse_embeddings.shape[0] // dense_embeddings.shape[0],
                0,
            )

            # [B * M, num_outputs, N], [B * M, num_outputs]
            masks, iou_preds = self.mask_decoder(
                pc_embeddings,
                pc_pe,
                sparse_embeddings,
                dense_embeddings,
                aux_inputs=aux_inputs,
                multimask_output=(i == 0),
            )

            # Select the most confident mask for the next iteration
            if i == 0:
                max_iou_pred_ind = torch.argmax(iou_preds, dim=1)  # [B * M]
                # Pick one output per row by direct indexing so the result is
                # [B * M, N], the same rank as `masks[:, 0]` below. A
                # gather that keeps the indexed dim would give [B * M, 1, N].
                row_ind = torch.arange(masks.shape[0], device=masks.device)
                prompt_masks = masks[row_ind, max_iou_pred_ind]  # [B * M, N]
            else:
                max_iou_pred_ind = 0
                prompt_masks = masks[:, 0]  # [B * M, N]

            outputs.append(
                dict(
                    prompt_coords=prompt_coords,
                    prompt_labels=prompt_labels,
                    masks=masks,
                    iou_preds=iou_preds,
                    max_iou_pred_ind=max_iou_pred_ind,
                    prompt_masks=prompt_masks,
                )
            )

        return outputs
196
+
197
+
198
def batch_index_select(data: torch.Tensor, index: torch.Tensor, dim: int):
    """Gather entries of `data` along `dim` with a separate index per batch item.

    The indexed dimension is kept in the result: a 1-D `index` of shape [B]
    yields size 1 along `dim`, and an index of shape [B, K] yields size K.

    Args:
        data: Tensor whose first dimension is the batch dimension.
        index: Integer tensor of positions along `dim`; it is reshaped to
            `[B, 1, ..., -1, ..., 1]` with the free axis placed at `dim`.
        dim: Dimension of `data` to index.

    Returns:
        Tensor shaped like `data`, except that `dim` has as many entries as
        there are indices per batch item.
    """
    view_shape = [1] * data.dim()
    view_shape[0] = data.shape[0]
    view_shape[dim] = -1
    # `reshape` instead of `view` so non-contiguous index tensors work too.
    index = index.reshape(view_shape)

    # Broadcast the index over every non-indexed dimension for `gather`.
    gather_shape = list(data.shape)
    gather_shape[dim] = index.shape[dim]
    return torch.gather(data, dim, index.expand(gather_shape))