bdck committed on
Commit ac6e542 · verified · 1 Parent(s): 7f9bc92

Upload point_sam/model/pc_encoder.py

Files changed (1)
  1. point_sam/model/pc_encoder.py +198 -0
point_sam/model/pc_encoder.py ADDED
@@ -0,0 +1,198 @@
# https://github.com/baaivision/Uni3D/blob/main/models/point_encoder.py
from typing import Union

import timm
import torch
import torch.nn as nn
from timm.models.eva import Eva
from timm.models.vision_transformer import VisionTransformer

from .common import KNNGrouper, NNGrouper, PatchEncoder


class PatchEmbed(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        num_patches,
        patch_size,
        radius: float = None,
        centralize_features=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.grouper = KNNGrouper(
            num_patches,
            patch_size,
            radius=radius,
            centralize_features=centralize_features,
        )

        self.patch_encoder = PatchEncoder(in_channels, out_channels, [128, 512])

    def forward(self, coords: torch.Tensor, features: torch.Tensor):
        patches = self.grouper(coords, features)
        patch_features = patches["features"]  # [B, L, K, C_in]
        x = self.patch_encoder(patch_features)
        patches["embeddings"] = x
        return patches


class PatchDropout(nn.Module):
    """Randomly drop patches.

    References:
    - https://arxiv.org/abs/2212.00794
    - `timm.layers.patch_dropout`. It uses `argsort` rather than `topk`, which might be inefficient.
    """

    def __init__(self, prob, num_prefix_tokens: int = 1):
        super().__init__()
        assert 0.0 <= prob < 1.0, prob
        self.prob = prob
        # exclude CLS token (or other prefix tokens)
        self.num_prefix_tokens = num_prefix_tokens

    def forward(self, x: torch.Tensor):
        # x: [B, L, ...]
        if not self.training or self.prob == 0.0:
            return x

        if self.num_prefix_tokens:
            prefix_tokens = x[:, : self.num_prefix_tokens]
            x = x[:, self.num_prefix_tokens :]
        else:
            prefix_tokens = None

        B, L = x.shape[:2]
        num_keep = max(1, int(L * (1.0 - self.prob)))
        rand = torch.randn(B, L, device=x.device)
        keep_indices = rand.topk(num_keep, dim=1).indices
        _keep_indices = keep_indices.reshape((B, num_keep) + (-1,) * (x.dim() - 2))
        _keep_indices = _keep_indices.expand((-1, -1) + x.shape[2:])
        x = x.gather(1, _keep_indices)

        if prefix_tokens is not None:
            x = torch.cat((prefix_tokens, x), dim=1)

        return x


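A minimal sanity-check sketch (not part of the commit, assuming the definitions above are in scope): during training PatchDropout keeps max(1, int(L * (1 - prob))) randomly chosen tokens per sample, and it is a no-op in eval mode.

# Hypothetical sanity check for PatchDropout (illustrative only).
drop = PatchDropout(prob=0.5, num_prefix_tokens=0)
drop.train()
tokens = torch.rand(2, 512, 768)
print(drop(tokens).shape)  # torch.Size([2, 256, 768]) while training
drop.eval()
print(drop(tokens).shape)  # torch.Size([2, 512, 768]); dropout is disabled in eval

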
class PointCloudEncoder(nn.Module):
    def __init__(
        self,
        patch_embed: PatchEmbed,
        transformer: Union[VisionTransformer, Eva],
        embed_dim: int,
        patch_drop_rate=0.0,
    ):
        super().__init__()
        self.transformer_dim = transformer.embed_dim
        self.embed_dim = embed_dim

        # Patch embedding
        self.patch_embed = patch_embed
        # Project patch features to transformer input dim
        self.patch_proj = nn.Linear(self.patch_embed.out_channels, self.transformer_dim)

        # Positional embedding
        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128), nn.GELU(), nn.Linear(128, self.transformer_dim)
        )

        assert patch_drop_rate == 0, "PatchDropout is not compatible with decoder."
        if patch_drop_rate > 0:
            self.patch_dropout = PatchDropout(patch_drop_rate, num_prefix_tokens=0)
        else:
            self.patch_dropout = nn.Identity()

        # Transformer encoder
        self.transformer = transformer

        # Project transformer output to embedding dim
        self.out_proj = nn.Linear(self.transformer_dim, self.embed_dim)

    def forward(self, coords, features):
        # Group points into patches and get embeddings
        patches = self.patch_embed(coords, features)
        if isinstance(patches, list):
            patch_embed = patches[-1]["embeddings"]
            centers = patches[-1]["centers"]
        else:
            patch_embed = patches["embeddings"]  # [B, L, D]
            centers = patches["centers"]  # [B, L, 3]
        patch_embed = self.patch_proj(patch_embed)

        # Positional embedding for patches
        pos_embed = self.pos_embed(centers)
        x = patch_embed + pos_embed

        # Randomly drop patches
        x = self.patch_dropout(x)
        # Feature dropout (reuse the transformer's pos_drop)
        x = self.transformer.pos_drop(x)

        for block in self.transformer.blocks:
            x = block(x)
        # Note: only one of norm and fc_norm is non-identity in these transformers.
        x = self.transformer.norm(x)
        x = self.transformer.fc_norm(x)
        x = self.out_proj(x)

        return x, patches


class Block(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels):
        super().__init__()
        # Follow timm.layers.mlp
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_channels),
        )
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        # Pre-LN residual block. Follow timm.models.vision_transformer
        return x + self.mlp(self.norm(x))


class PatchEmbedNN(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels, num_patches) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        hidden_dim = hidden_dim or out_channels

        self.grouper = NNGrouper(num_patches)
        self.in_proj = nn.Linear(in_channels, hidden_dim)
        self.blocks1 = nn.Sequential(
            *[Block(hidden_dim, hidden_dim, hidden_dim) for _ in range(3)]
        )
        self.blocks2 = nn.Sequential(
            *[Block(hidden_dim, hidden_dim, hidden_dim) for _ in range(3)]
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, out_channels)

    def forward(self, coords: torch.Tensor, features: torch.Tensor):
        patches = self.grouper(coords, features)
        patch_features = patches["features"]  # [B, N, D]
        nn_idx = patches["nn_idx"]  # [B, N]

        x = self.in_proj(patch_features)
        x = self.blocks1(x)  # [B, N, D]
        # Max-pool per-point features into their nearest group centers
        y = x.new_zeros(x.shape[0], self.grouper.num_groups, x.shape[-1])
        y.scatter_reduce_(
            1, nn_idx.unsqueeze(-1).expand_as(x), x, "amax", include_self=False
        )
        x = self.blocks2(y)
        x = self.norm(x)
        x = self.out_proj(x)
        patches["embeddings"] = x
        return patches
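
A hedged end-to-end sketch (not part of the commit): one way PointCloudEncoder might be wired to a timm ViT backbone. The backbone name, batch/point counts, and the 6-channel feature layout are assumptions for illustration; the actual channel layout depends on KNNGrouper and PatchEncoder in .common.

# Hypothetical usage of PointCloudEncoder with a timm backbone (illustrative only).
import timm
import torch

from point_sam.model.pc_encoder import PatchEmbed, PointCloudEncoder

# Assumed backbone; any timm VisionTransformer/Eva exposing pos_drop, blocks, norm, fc_norm works.
vit = timm.create_model("vit_base_patch16_224", num_classes=0)  # embed_dim = 768

# Assumed channel counts: in_channels must match what KNNGrouper in .common emits.
patch_embed = PatchEmbed(in_channels=6, out_channels=384, num_patches=512, patch_size=64)
encoder = PointCloudEncoder(patch_embed, vit, embed_dim=256)

coords = torch.rand(2, 8192, 3)    # xyz coordinates
features = torch.rand(2, 8192, 6)  # per-point features (layout is an assumption)
embeddings, patches = encoder(coords, features)
# embeddings: [2, num_patches, 256]; patches: dict with "centers", "features", "embeddings"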