bdck commited on
Commit
f8ecd30
·
verified ·
1 Parent(s): 4d5a18c

Upload point_sam/model/mask_decoder.py

Browse files
Files changed (1) hide show
  1. point_sam/model/mask_decoder.py +211 -0
point_sam/model/mask_decoder.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py
2
+ import dataclasses
3
+ from typing import Dict, List, Tuple, Type
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from .common import compute_interp_weights, interpolate_features, repeat_interleave
10
+
11
+
12
+ @dataclasses.dataclass
13
+ class AuxInputs:
14
+ coords: torch.Tensor
15
+ features: torch.Tensor
16
+ centers: torch.Tensor
17
+ interp_index: torch.Tensor = None
18
+ interp_weight: torch.Tensor = None
19
+
20
+
21
+ class MaskDecoder(nn.Module):
22
+ def __init__(
23
+ self,
24
+ transformer_dim: int,
25
+ transformer: nn.Module,
26
+ num_multimask_outputs: int = 3,
27
+ iou_head_depth: int = 3,
28
+ iou_head_hidden_dim: int = 256,
29
+ ) -> None:
30
+ super().__init__()
31
+ self.transformer_dim = transformer_dim
32
+ self.transformer = transformer
33
+
34
+ self.num_multimask_outputs = num_multimask_outputs
35
+
36
+ self.iou_token = nn.Embedding(1, transformer_dim)
37
+ self.num_mask_tokens = num_multimask_outputs + 1
38
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
39
+
40
+ self.output_hypernetworks_mlps = nn.ModuleList(
41
+ [
42
+ MLP(transformer_dim, transformer_dim, transformer_dim, 3)
43
+ for i in range(self.num_mask_tokens)
44
+ ]
45
+ )
46
+ # self.output_upscaling = nn.Sequential(
47
+ # nn.Linear(transformer_dim, transformer_dim),
48
+ # nn.LayerNorm(transformer_dim),
49
+ # nn.GELU(),
50
+ # nn.Linear(transformer_dim, transformer_dim),
51
+ # nn.GELU(),
52
+ # )
53
+ self.output_upscaling = nn.Sequential(
54
+ nn.Linear(transformer_dim, transformer_dim),
55
+ nn.LayerNorm(transformer_dim),
56
+ nn.GELU(),
57
+ nn.Linear(transformer_dim, transformer_dim),
58
+ nn.GELU(),
59
+ )
60
+
61
+ self.iou_prediction_head = MLP(
62
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
63
+ )
64
+
65
+ def forward(
66
+ self,
67
+ pc_embeddings: torch.Tensor,
68
+ pc_pe: torch.Tensor,
69
+ sparse_prompt_embeddings: torch.Tensor,
70
+ dense_prompt_embeddings: torch.Tensor,
71
+ aux_inputs: AuxInputs,
72
+ multimask_output: bool,
73
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
74
+ """
75
+ Predict masks given pointcloud and prompt embeddings.
76
+
77
+ Arguments:
78
+ pc_embeddings (torch.Tensor): the embeddings from the point cloud encoder
79
+ pc_pe (torch.Tensor): positional encoding with the shape of pc_embeddings
80
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
81
+ [B, N_prompts, D]
82
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
83
+ [B, N_patches, D]
84
+ multimask_output (bool): Whether to return multiple masks or a single
85
+ mask.
86
+
87
+ Returns:
88
+ torch.Tensor: batched predicted masks
89
+ torch.Tensor: batched predictions of mask quality
90
+ """
91
+ # Select the correct mask or masks for output
92
+ if multimask_output:
93
+ mask_slice = slice(1, None)
94
+ else:
95
+ mask_slice = slice(0, 1)
96
+
97
+ masks, iou_pred = self.predict_masks(
98
+ pc_embeddings=pc_embeddings,
99
+ pc_pe=pc_pe,
100
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
101
+ dense_prompt_embeddings=dense_prompt_embeddings,
102
+ aux_inputs=aux_inputs,
103
+ mask_slice=mask_slice,
104
+ )
105
+
106
+ # # Select the correct mask or masks for output
107
+ # if multimask_output:
108
+ # mask_slice = slice(1, None)
109
+ # else:
110
+ # mask_slice = slice(0, 1)
111
+ # masks = masks[:, mask_slice, :]
112
+ # iou_pred = iou_pred[:, mask_slice]
113
+
114
+ return masks, iou_pred
115
+
116
+ def predict_masks(
117
+ self,
118
+ pc_embeddings: torch.Tensor,
119
+ pc_pe: torch.Tensor,
120
+ sparse_prompt_embeddings: torch.Tensor,
121
+ dense_prompt_embeddings: torch.Tensor,
122
+ aux_inputs: AuxInputs,
123
+ mask_slice: slice = None,
124
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
125
+ # Concatenate output tokens
126
+ output_tokens = torch.cat(
127
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
128
+ )
129
+ output_tokens = output_tokens.unsqueeze(0).expand(
130
+ sparse_prompt_embeddings.size(0), -1, -1
131
+ )
132
+ # [B*M, N_tokens, D]
133
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
134
+
135
+ # Expand per-image data in batch direction to be per-mask
136
+ repeats = tokens.shape[0] // pc_embeddings.shape[0]
137
+ src = repeat_interleave(pc_embeddings, repeats, dim=0)
138
+ pos_src = repeat_interleave(pc_pe, repeats, dim=0)
139
+ src = src + dense_prompt_embeddings
140
+
141
+ # Run the transformer
142
+ hs, src = self.transformer(src, pos_src, tokens)
143
+ iou_token_out = hs[:, 0, :]
144
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
145
+
146
+ # Upscale mask embeddings
147
+ coords = aux_inputs.coords # [B, N, 3]
148
+ centers = aux_inputs.centers # [B, L, 3]
149
+ interp_index = aux_inputs.interp_index # [B, N, 3]
150
+ interp_weight = aux_inputs.interp_weight # [B, N, 3]
151
+ if interp_index is None or interp_weight is None:
152
+ with torch.no_grad():
153
+ interp_index, interp_weight = compute_interp_weights(coords, centers)
154
+ # Update auxilary inputs for the next iteration
155
+ aux_inputs.interp_index = interp_index
156
+ aux_inputs.interp_weight = interp_weight
157
+
158
+ _repeats = tokens.shape[0] // interp_index.shape[0]
159
+ interp_index = repeat_interleave(interp_index, _repeats, dim=0)
160
+ interp_weight = repeat_interleave(interp_weight, _repeats, dim=0)
161
+
162
+ # [B*M, N, D]
163
+ interp_embedding = interpolate_features(src, interp_index, interp_weight)
164
+ upscaled_embedding = self.output_upscaling(interp_embedding)
165
+
166
+ # Predict masks using the mask tokens
167
+ hyper_in_list: List[torch.Tensor] = []
168
+ mask_indices = list(range(self.num_mask_tokens))
169
+ if mask_slice is not None:
170
+ mask_indices = mask_indices[mask_slice]
171
+ for i in mask_indices:
172
+ hyper_in_list.append(
173
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
174
+ )
175
+ hyper_in = torch.stack(hyper_in_list, dim=1) # [B*M, num_mask_tokens, D]
176
+ masks = hyper_in @ upscaled_embedding.transpose(-1, -2)
177
+ # masks = upscaled_embedding.transpose(-1, -2)
178
+
179
+ # Generate mask quality predictions
180
+ iou_pred = self.iou_prediction_head(iou_token_out)
181
+ if mask_slice is not None:
182
+ iou_pred = iou_pred[:, mask_slice]
183
+
184
+ return masks, iou_pred
185
+
186
+
187
+ # Adapted from https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
188
+ # Used in MaskDecoder for SAM
189
+ class MLP(nn.Module):
190
+ def __init__(
191
+ self,
192
+ input_dim: int,
193
+ hidden_dim: int,
194
+ output_dim: int,
195
+ num_layers: int,
196
+ sigmoid_output: bool = False,
197
+ ) -> None:
198
+ super().__init__()
199
+ self.num_layers = num_layers
200
+ h = [hidden_dim] * (num_layers - 1)
201
+ self.layers = nn.ModuleList(
202
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
203
+ )
204
+ self.sigmoid_output = sigmoid_output
205
+
206
+ def forward(self, x):
207
+ for i, layer in enumerate(self.layers):
208
+ x = F.relu(layer(x), inplace=True) if i < self.num_layers - 1 else layer(x)
209
+ if self.sigmoid_output:
210
+ x = F.sigmoid(x)
211
+ return x